diff --git a/kernel/libs/aster-bigtcp/src/iface/common.rs b/kernel/libs/aster-bigtcp/src/iface/common.rs index 61d385fd..b6f1e668 100644 --- a/kernel/libs/aster-bigtcp/src/iface/common.rs +++ b/kernel/libs/aster-bigtcp/src/iface/common.rs @@ -15,7 +15,7 @@ use smoltcp::{ }; use super::{ - poll::{FnHelper, PollContext}, + poll::{FnHelper, PollContext, SocketTableAction}, poll_iface::PollableIface, port::BindPortConfig, time::get_network_timestamp, @@ -183,26 +183,23 @@ impl IfaceCommon { interface.context_mut().now = get_network_timestamp(); let mut sockets = self.sockets.lock(); - let mut dead_tcp_conns = Vec::new(); + let mut socket_actions = Vec::new(); - let mut new_tcp_conns = Vec::new(); - - let mut context = PollContext::new( - interface.as_mut(), - &sockets, - &mut new_tcp_conns, - &mut dead_tcp_conns, - ); + let mut context = PollContext::new(interface.as_mut(), &sockets, &mut socket_actions); context.poll_ingress(device, &mut process_phy, &mut dispatch_phy); context.poll_egress(device, &mut dispatch_phy); // Insert new connections and remove dead connections. - for new_tcp_conn in new_tcp_conns.into_iter() { - let res = sockets.insert_connection(new_tcp_conn); - debug_assert!(res.is_ok()); - } - for dead_conn_key in dead_tcp_conns.into_iter() { - sockets.remove_dead_tcp_connection(&dead_conn_key); + for action in socket_actions.into_iter() { + match action { + SocketTableAction::AddTcpConn(new_tcp_conn) => { + let res = sockets.insert_connection(new_tcp_conn); + debug_assert!(res.is_ok()); + } + SocketTableAction::DelTcpConn(dead_conn_key) => { + sockets.remove_dead_tcp_connection(&dead_conn_key); + } + } } // Notify all socket events. diff --git a/kernel/libs/aster-bigtcp/src/iface/poll.rs b/kernel/libs/aster-bigtcp/src/iface/poll.rs index 853a1f95..6c6aa805 100644 --- a/kernel/libs/aster-bigtcp/src/iface/poll.rs +++ b/kernel/libs/aster-bigtcp/src/iface/poll.rs @@ -25,22 +25,28 @@ use crate::{ pub(super) struct PollContext<'a, E: Ext> { iface: PollableIfaceMut<'a, E>, sockets: &'a SocketTable, - new_tcp_conns: &'a mut Vec>>, - dead_tcp_conns: &'a mut Vec, + actions: &'a mut Vec>, +} + +/// Socket table actions such as adding or removing TCP connections. +/// +/// Note that they must be performed in order. This is because the same connection key can occur +/// multiple times, but with different types of operations (e.g., add or remove). +pub(super) enum SocketTableAction { + AddTcpConn(Arc>), + DelTcpConn(ConnectionKey), } impl<'a, E: Ext> PollContext<'a, E> { pub(super) fn new( iface: PollableIfaceMut<'a, E>, sockets: &'a SocketTable, - new_tcp_conns: &'a mut Vec>>, - dead_tcp_conns: &'a mut Vec, + actions: &'a mut Vec>, ) -> Self { Self { iface, sockets, - new_tcp_conns, - dead_tcp_conns, + actions, } } } @@ -159,7 +165,60 @@ impl PollContext<'_, E> { ip_repr: &IpRepr, tcp_repr: &TcpRepr, ) -> Option<(IpRepr, TcpRepr<'static>)> { - // Process packets that request to create new connections first. + // Process packets belonging to existing connections first. + // Note that we must do this first because SYN packets may match existing TIME-WAIT + // sockets. See comments in `TcpConnectionBg::process` for details. + let connection_key = ConnectionKey::new( + ip_repr.dst_addr(), + tcp_repr.dst_port, + ip_repr.src_addr(), + tcp_repr.src_port, + ); + let mut connection_in_table = self.sockets.lookup_connection(&connection_key); + + loop { + // First try the connection in the socket table, as this is the most common. If it + // fails, it might mean that the connection is dead, the next step is to try the new + // connections instead. + let (should_break, connection) = if let Some(conn) = connection_in_table.take() { + (false, Some(conn)) + } else { + // Find in reverse order because old connections must have been dead. + ( + true, + self.actions + .iter() + .rev() + .flat_map(|action| match action { + SocketTableAction::AddTcpConn(conn) => Some(conn), + SocketTableAction::DelTcpConn(_) => None, + }) + .find(|conn| conn.connection_key() == &connection_key), + ) + }; + + if let Some(connection) = connection { + let (process_result, became_dead) = + connection.process(&mut self.iface, ip_repr, tcp_repr); + if *became_dead { + self.actions + .push(SocketTableAction::DelTcpConn(*connection.connection_key())); + } + match process_result { + TcpProcessResult::NotProcessed => {} + TcpProcessResult::Processed => return None, + TcpProcessResult::ProcessedWithReply(ip_repr, tcp_repr) => { + return Some((ip_repr, tcp_repr)) + } + } + } + + if should_break { + break; + } + } + + // Process packets that request to create new connections second. if tcp_repr.control == TcpControl::Syn && tcp_repr.ack_number.is_none() { let listener_key = ListenerKey::new(ip_repr.dst_addr(), tcp_repr.dst_port); if let Some(listener) = self.sockets.lookup_listener(&listener_key) { @@ -167,7 +226,7 @@ impl PollContext<'_, E> { listener.process(&mut self.iface, ip_repr, tcp_repr); if let Some(tcp_conn) = new_tcp_conn { - self.new_tcp_conns.push(tcp_conn); + self.actions.push(SocketTableAction::AddTcpConn(tcp_conn)); } match processed { @@ -180,36 +239,6 @@ impl PollContext<'_, E> { } } - // Process packets belonging to existing connections second. - let connection_key = ConnectionKey::new( - ip_repr.dst_addr(), - tcp_repr.dst_port, - ip_repr.src_addr(), - tcp_repr.src_port, - ); - let connection = if let Some(connection) = self.sockets.lookup_connection(&connection_key) { - Some(connection) - } else { - self.new_tcp_conns - .iter() - .find(|tcp_conn| tcp_conn.connection_key() == &connection_key) - }; - - if let Some(connection) = connection { - let (process_result, became_dead) = - connection.process(&mut self.iface, ip_repr, tcp_repr); - if *became_dead { - self.dead_tcp_conns.push(*connection.connection_key()); - } - match process_result { - TcpProcessResult::NotProcessed => {} - TcpProcessResult::Processed => return None, - TcpProcessResult::ProcessedWithReply(ip_repr, tcp_repr) => { - return Some((ip_repr, tcp_repr)) - } - } - } - // "In no case does receipt of a segment containing RST give rise to a RST in response." // See . if tcp_repr.control == TcpControl::Rst { @@ -359,7 +388,6 @@ impl PollContext<'_, E> { { let mut tx_token = Some(tx_token); let mut did_something = false; - let mut dead_conns = Vec::new(); loop { let Some(socket) = self.iface.pop_pending_tcp() else { @@ -374,8 +402,7 @@ impl PollContext<'_, E> { let (reply, became_dead) = TcpConnectionBg::dispatch(&socket, &mut self.iface, |iface, ip_repr, tcp_repr| { - let mut this = - PollContext::new(iface, self.sockets, self.new_tcp_conns, &mut dead_conns); + let mut this = PollContext::new(iface, self.sockets, self.actions); if !this.is_unicast_local(ip_repr.dst_addr()) { dispatch_phy( @@ -407,7 +434,8 @@ impl PollContext<'_, E> { }); if *became_dead { - self.dead_tcp_conns.push(*socket.connection_key()); + self.actions + .push(SocketTableAction::DelTcpConn(*socket.connection_key())); } match (deferred, reply) { @@ -447,8 +475,6 @@ impl PollContext<'_, E> { } } - self.dead_tcp_conns.append(&mut dead_conns); - (did_something, tx_token) } @@ -459,7 +485,8 @@ impl PollContext<'_, E> { { let mut tx_token = Some(tx_token); let mut did_something = false; - let mut dead_conns = Vec::new(); + + let mut actions = Vec::new(); for socket in self.sockets.udp_socket_iter() { if !socket.need_dispatch() { @@ -475,8 +502,7 @@ impl PollContext<'_, E> { let (cx, pending) = self.iface.inner_mut(); socket.dispatch(cx, |cx, ip_repr, udp_repr, udp_payload| { let iface = PollableIfaceMut::new(cx, pending); - let mut this = - PollContext::new(iface, self.sockets, self.new_tcp_conns, &mut dead_conns); + let mut this = PollContext::new(iface, self.sockets, &mut actions); if ip_repr.dst_addr().is_broadcast() || !this.is_unicast_local(ip_repr.dst_addr()) { dispatch_phy( @@ -527,10 +553,10 @@ impl PollContext<'_, E> { } } - // `dead_conns` should be empty, - // because we are using UDP sockets, - // and the `dead_conns` contains only dead TCP connections. - debug_assert!(dead_conns.is_empty()); + // `actions` should be empty, + // because we are dealing with UDP sockets, + // and the `actions` contains only TCP actions. + debug_assert!(actions.is_empty()); (did_something, tx_token) } diff --git a/kernel/libs/aster-bigtcp/src/socket_table.rs b/kernel/libs/aster-bigtcp/src/socket_table.rs index 4244b099..3e681c42 100644 --- a/kernel/libs/aster-bigtcp/src/socket_table.rs +++ b/kernel/libs/aster-bigtcp/src/socket_table.rs @@ -286,14 +286,13 @@ impl SocketTable { &mut self.connection_buckets[bucket_index as usize] }; - if let Some(index) = bucket + let index = bucket .connections .iter() .position(|tcp_connection| tcp_connection.connection_key() == key) - { - let connection = bucket.connections.swap_remove(index); - connection.on_dead_events(); - } + .unwrap(); + let connection = bucket.connections.swap_remove(index); + connection.on_dead_events(); } pub(crate) fn remove_udp_socket(