diff --git a/kernel/libs/aster-bigtcp/src/boolean_value.rs b/kernel/libs/aster-bigtcp/src/boolean_value.rs new file mode 100644 index 00000000..eae5687e --- /dev/null +++ b/kernel/libs/aster-bigtcp/src/boolean_value.rs @@ -0,0 +1,32 @@ +// SPDX-License-Identifier: MPL-2.0 + +/// Defines a struct representing a boolean value. +/// +/// In some cases, it is beneficial to use a struct instead of +/// a plain boolean value to clarify the semantics. +/// This macro provides a convenient way to define a struct +/// that represents a boolean value. +#[macro_export] +macro_rules! define_boolean_value { + ( + $(#[$attr:meta])* + $name: ident + ) => { + $(#[$attr])* + #[derive(Debug, Clone, Copy)] + pub struct $name(bool); + + impl $name { + pub const TRUE: Self = Self(true); + pub const FALSE: Self = Self(false); + } + + impl core::ops::Deref for $name { + type Target = bool; + + fn deref(&self) -> &Self::Target { + &self.0 + } + } + }; +} diff --git a/kernel/libs/aster-bigtcp/src/iface/common.rs b/kernel/libs/aster-bigtcp/src/iface/common.rs index 28a2d664..e0f52538 100644 --- a/kernel/libs/aster-bigtcp/src/iface/common.rs +++ b/kernel/libs/aster-bigtcp/src/iface/common.rs @@ -23,7 +23,7 @@ use super::{ use crate::{ errors::BindError, ext::Ext, - socket::{TcpListenerBg, UdpSocketBg}, + socket::{TcpConnectionBg, TcpListenerBg, UdpSocketBg}, socket_table::SocketTable, }; @@ -152,10 +152,15 @@ impl IfaceCommon { pub(crate) fn remove_tcp_listener(&self, socket: &Arc>) { let mut sockets = self.sockets.lock(); - let removed = sockets.remove_listener(socket); + let removed = sockets.remove_listener(socket.listener_key()); debug_assert!(removed.is_some()); } + pub(crate) fn remove_dead_tcp_connection(&self, socket: &Arc>) { + let mut sockets = self.sockets.lock(); + sockets.remove_dead_tcp_connection(socket.connection_key()); + } + pub(crate) fn remove_udp_socket(&self, socket: &Arc>) { let mut sockets = self.sockets.lock(); let removed = sockets.remove_udp_socket(socket); @@ -184,11 +189,17 @@ impl IfaceCommon { interface.context().now = get_network_timestamp(); let mut sockets = self.sockets.lock(); + let mut dead_tcp_conns = Vec::new(); loop { let mut new_tcp_conns = Vec::new(); - let mut context = PollContext::new(interface.context(), &sockets, &mut new_tcp_conns); + let mut context = PollContext::new( + interface.context(), + &sockets, + &mut new_tcp_conns, + &mut dead_tcp_conns, + ); context.poll_ingress(device, &mut process_phy, &mut dispatch_phy); context.poll_egress(device, &mut dispatch_phy); @@ -204,7 +215,9 @@ impl IfaceCommon { } } - sockets.remove_dead_tcp_connections(); + for dead_conn_key in dead_tcp_conns.into_iter() { + sockets.remove_dead_tcp_connection(&dead_conn_key); + } for socket in sockets.tcp_listener_iter() { if socket.has_events() { diff --git a/kernel/libs/aster-bigtcp/src/iface/poll.rs b/kernel/libs/aster-bigtcp/src/iface/poll.rs index 9c7921a6..d5439a55 100644 --- a/kernel/libs/aster-bigtcp/src/iface/poll.rs +++ b/kernel/libs/aster-bigtcp/src/iface/poll.rs @@ -25,6 +25,7 @@ pub(super) struct PollContext<'a, E: Ext> { iface_cx: &'a mut Context, sockets: &'a SocketTable, new_tcp_conns: &'a mut Vec>>, + dead_tcp_conns: &'a mut Vec, } impl<'a, E: Ext> PollContext<'a, E> { @@ -32,11 +33,13 @@ impl<'a, E: Ext> PollContext<'a, E> { iface_cx: &'a mut Context, sockets: &'a SocketTable, new_tcp_conns: &'a mut Vec>>, + dead_tcp_conns: &'a mut Vec, ) -> Self { Self { iface_cx, sockets, new_tcp_conns, + dead_tcp_conns, } } } @@ -193,7 +196,12 @@ impl PollContext<'_, E> { }; if let Some(connection) = connection { - match connection.process(self.iface_cx, ip_repr, tcp_repr) { + let (process_result, became_dead) = + connection.process(self.iface_cx, ip_repr, tcp_repr); + if *became_dead { + self.dead_tcp_conns.push(*connection.connection_key()); + } + match process_result { TcpProcessResult::NotProcessed => {} TcpProcessResult::Processed => return None, TcpProcessResult::ProcessedWithReply(ip_repr, tcp_repr) => { @@ -349,6 +357,7 @@ impl PollContext<'_, E> { { let mut tx_token = Some(tx_token); let mut did_something = false; + let mut dead_conns = Vec::new(); // We cannot dispatch packets from `new_tcp_conns` because we cannot borrow an immutable // reference at this point. Instead, we will retry after the entire poll is complete. @@ -363,9 +372,10 @@ impl PollContext<'_, E> { let mut deferred = None; - let reply = + let (reply, became_dead) = TcpConnectionBg::dispatch(socket, self.iface_cx, |cx, ip_repr, tcp_repr| { - let mut this = PollContext::new(cx, self.sockets, self.new_tcp_conns); + let mut this = + PollContext::new(cx, self.sockets, self.new_tcp_conns, &mut dead_conns); if !this.is_unicast_local(ip_repr.dst_addr()) { dispatch_phy( @@ -396,6 +406,10 @@ impl PollContext<'_, E> { None }); + if *became_dead { + self.dead_tcp_conns.push(*socket.connection_key()); + } + match (deferred, reply) { (None, None) => (), (Some((ip_repr, ip_payload)), None) => { @@ -433,6 +447,8 @@ impl PollContext<'_, E> { } } + self.dead_tcp_conns.append(&mut dead_conns); + (did_something, tx_token) } @@ -443,6 +459,7 @@ impl PollContext<'_, E> { { let mut tx_token = Some(tx_token); let mut did_something = false; + let mut dead_conns = Vec::new(); for socket in self.sockets.udp_socket_iter() { if !socket.need_dispatch(self.iface_cx.now()) { @@ -456,7 +473,8 @@ impl PollContext<'_, E> { let mut deferred = None; socket.dispatch(self.iface_cx, |cx, ip_repr, udp_repr, udp_payload| { - let mut this = PollContext::new(cx, self.sockets, self.new_tcp_conns); + let mut this = + PollContext::new(cx, self.sockets, self.new_tcp_conns, &mut dead_conns); if ip_repr.dst_addr().is_broadcast() || !this.is_unicast_local(ip_repr.dst_addr()) { dispatch_phy( @@ -507,6 +525,11 @@ impl PollContext<'_, E> { } } + // `dead_conns` should be empty, + // because we are using UDP sockets, + // and the `dead_conns` contains only dead TCP connections. + debug_assert!(dead_conns.is_empty()); + (did_something, tx_token) } } diff --git a/kernel/libs/aster-bigtcp/src/lib.rs b/kernel/libs/aster-bigtcp/src/lib.rs index 67aef5bf..a0aadb7d 100644 --- a/kernel/libs/aster-bigtcp/src/lib.rs +++ b/kernel/libs/aster-bigtcp/src/lib.rs @@ -14,6 +14,7 @@ #![deny(unsafe_code)] #![feature(extract_if)] +pub mod boolean_value; pub mod device; pub mod errors; pub mod ext; diff --git a/kernel/libs/aster-bigtcp/src/socket/bound.rs b/kernel/libs/aster-bigtcp/src/socket/bound.rs index 6edf991b..2351c29c 100644 --- a/kernel/libs/aster-bigtcp/src/socket/bound.rs +++ b/kernel/libs/aster-bigtcp/src/socket/bound.rs @@ -3,7 +3,7 @@ use alloc::{boxed::Box, collections::btree_map::BTreeMap, sync::Arc, vec::Vec}; use core::{ ops::{Deref, DerefMut}, - sync::atomic::{AtomicBool, AtomicU64, AtomicU8, Ordering}, + sync::atomic::{AtomicU64, AtomicU8, Ordering}, }; use ostd::sync::{LocalIrqDisabled, SpinLock, SpinLockGuard}; @@ -23,6 +23,7 @@ use super::{ RawTcpSocket, RawUdpSocket, TcpStateCheck, }; use crate::{ + define_boolean_value, errors::{ tcp::{ConnectError, ListenError}, udp::SendError, @@ -65,7 +66,6 @@ pub struct SocketBg, E: Ext> { /// States needed by [`TcpConnectionBg`]. pub struct TcpConnectionInner { socket: SpinLock, LocalIrqDisabled>, - is_dead: AtomicBool, connection_key: ConnectionKey, } @@ -89,8 +89,16 @@ impl DerefMut for RawTcpSocketExt { } } +define_boolean_value!( + /// Whether the TCP connection became dead. + TcpConnBecameDead +); + impl RawTcpSocketExt { - fn on_new_state(&mut self, this: &Arc>) -> SocketEvents { + fn on_new_state( + &mut self, + this: &Arc>, + ) -> (SocketEvents, TcpConnBecameDead) { if self.may_send() && !self.has_connected { self.has_connected = true; @@ -103,43 +111,49 @@ impl RawTcpSocketExt { } } - self.update_dead(this); + let became_dead = self.check_dead(this); - if self.is_peer_closed() { + let events = if self.is_peer_closed() { SocketEvents::PEER_CLOSED } else if self.is_closed() { SocketEvents::CLOSED } else { SocketEvents::empty() - } + }; + + (events, became_dead) } - /// Updates whether the TCP connection is dead. + /// Checks whether the TCP connection becomes dead. /// - /// See [`TcpConnectionBg::is_dead`] for the definition of dead TCP connections. + /// A TCP connection is considered dead when and only when the TCP socket is in the closed + /// state, meaning it's no longer accepting packets from the network. This is different from + /// the socket file being closed, which only initiates the socket close process. /// /// This method must be called after handling network events. However, it is not necessary to /// call this method after handling non-closing user events, because the socket can never be /// dead if it is not closed. - fn update_dead(&self, this: &Arc>) { + fn check_dead(&self, this: &Arc>) -> TcpConnBecameDead { // FIXME: This is a temporary workaround to mark TimeWait socket as dead. if self.state() == smoltcp::socket::tcp::State::Closed || self.state() == smoltcp::socket::tcp::State::TimeWait { - this.inner.is_dead.store(true, Ordering::Relaxed); + return TcpConnBecameDead::TRUE; } // According to the current smoltcp implementation, a backlog socket will return back to // the `Listen` state if the connection is RSTed before its establishment. if self.state() == smoltcp::socket::tcp::State::Listen { - this.inner.is_dead.store(true, Ordering::Relaxed); - if let Some(ref listener) = self.listener { let mut backlog = listener.inner.backlog.lock(); // This may fail due to race conditions, but it's fine. let _ = backlog.connecting.remove(&this.inner.connection_key); } + + return TcpConnBecameDead::TRUE; } + + TcpConnBecameDead::FALSE } } @@ -160,7 +174,6 @@ impl TcpConnectionInner { TcpConnectionInner { socket: SpinLock::new(socket_ext), - is_dead: AtomicBool::new(false), connection_key, } } @@ -168,38 +181,31 @@ impl TcpConnectionInner { fn lock(&self) -> SpinLockGuard, LocalIrqDisabled> { self.socket.lock() } - - /// Returns whether the TCP connection is dead. - /// - /// See [`TcpConnectionBg::is_dead`] for the definition of dead TCP connections. - fn is_dead(&self) -> bool { - self.is_dead.load(Ordering::Relaxed) - } - - /// Sets the TCP connection in [`TimeWait`] state as dead. - /// - /// See [`TcpConnectionBg::is_dead`] for the definition of dead TCP connections. - /// - /// [`TimeWait`]: smoltcp::socket::tcp::State::TimeWait - fn set_dead_timewait(&self, socket: &RawTcpSocketExt) { - debug_assert!(socket.state() == smoltcp::socket::tcp::State::TimeWait); - self.is_dead.store(true, Ordering::Relaxed); - } } impl Inner for TcpConnectionInner { type Observer = E::TcpEventObserver; fn on_drop(this: &Arc>) { - let mut socket = this.inner.lock(); + let became_dead = { + let mut socket = this.inner.lock(); - // FIXME: Send RSTs when there is unread data. - socket.close(); + // FIXME: Send RSTs when there is unread data. + socket.close(); - // A TCP connection may not be appropriate for immediate removal. We leave the removal - // decision to the polling logic. - this.update_next_poll_at_ms(PollAt::Now); - socket.update_dead(this); + if *socket.check_dead(this) { + true + } else { + // A TCP connection may not be appropriate for immediate removal. We leave the removal + // decision to the polling logic. + this.update_next_poll_at_ms(PollAt::Now); + false + } + }; + + if became_dead { + this.bound.iface().common().remove_dead_tcp_connection(this); + } } } @@ -318,21 +324,10 @@ pub enum ConnectState { Refused, } -#[derive(Debug, Clone, Copy)] -pub struct NeedIfacePoll(bool); - -impl NeedIfacePoll { - pub const TRUE: Self = Self(true); - pub const FALSE: Self = Self(false); -} - -impl Deref for NeedIfacePoll { - type Target = bool; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} +define_boolean_value!( + /// Whether the iface needs to be polled + NeedIfacePoll +); impl TcpConnection { /// Connects to a remote endpoint. @@ -746,7 +741,7 @@ impl, E: Ext> SocketBg { match poll_at { PollAt::Now => { self.next_poll_at_ms.store(0, Ordering::Relaxed); - NeedIfacePoll(true) + NeedIfacePoll::TRUE } PollAt::Time(instant) => { let old_total_millis = self.next_poll_at_ms.load(Ordering::Relaxed); @@ -759,22 +754,13 @@ impl, E: Ext> SocketBg { } PollAt::Ingress => { self.next_poll_at_ms.store(u64::MAX, Ordering::Relaxed); - NeedIfacePoll(false) + NeedIfacePoll::FALSE } } } } impl TcpConnectionBg { - /// Returns whether the TCP connection is dead. - /// - /// A TCP connection is considered dead when and only when the TCP socket is in the closed - /// state, meaning it's no longer accepting packets from the network. This is different from - /// the socket file being closed, which only initiates the socket close process. - pub(crate) fn is_dead(&self) -> bool { - self.inner.is_dead() - } - pub(crate) const fn connection_key(&self) -> &ConnectionKey { &self.inner.connection_key } @@ -816,11 +802,11 @@ impl TcpConnectionBg { cx: &mut Context, ip_repr: &IpRepr, tcp_repr: &TcpRepr, - ) -> TcpProcessResult { + ) -> (TcpProcessResult, TcpConnBecameDead) { let mut socket = self.inner.lock(); if !socket.accepts(cx, ip_repr, tcp_repr) { - return TcpProcessResult::NotProcessed; + return (TcpProcessResult::NotProcessed, TcpConnBecameDead::FALSE); } // If the socket is in the TimeWait state and a new packet arrives that is a SYN packet @@ -840,8 +826,7 @@ impl TcpConnectionBg { && tcp_repr.control == TcpControl::Syn && tcp_repr.ack_number.is_none() { - self.inner.set_dead_timewait(&socket); - return TcpProcessResult::NotProcessed; + return (TcpProcessResult::NotProcessed, TcpConnBecameDead::TRUE); } let old_state = socket.state(); @@ -854,14 +839,18 @@ impl TcpConnectionBg { Some((ip_repr, tcp_repr)) => TcpProcessResult::ProcessedWithReply(ip_repr, tcp_repr), }; - if socket.state() != old_state { - events |= socket.on_new_state(self); - } + let became_dead = if socket.state() != old_state { + let (new_events, became_dead) = socket.on_new_state(self); + events |= new_events; + became_dead + } else { + TcpConnBecameDead::FALSE + }; self.add_events(events); self.update_next_poll_at_ms(socket.poll_at(cx)); - result + (result, became_dead) } /// Tries to generate an outgoing packet and dispatches the generated packet. @@ -869,7 +858,7 @@ impl TcpConnectionBg { this: &Arc, cx: &mut Context, dispatch: D, - ) -> Option<(IpRepr, TcpRepr<'static>)> + ) -> (Option<(IpRepr, TcpRepr<'static>)>, TcpConnBecameDead) where D: FnOnce(&mut Context, &IpRepr, &TcpRepr) -> Option<(IpRepr, TcpRepr<'static>)>, { @@ -896,14 +885,18 @@ impl TcpConnectionBg { events |= SocketEvents::CAN_RECV | SocketEvents::CAN_SEND; } - if socket.state() != old_state { - events |= socket.on_new_state(this); - } + let became_dead = if socket.state() != old_state { + let (new_events, became_dead) = socket.on_new_state(this); + events |= new_events; + became_dead + } else { + TcpConnBecameDead::FALSE + }; this.add_events(events); this.update_next_poll_at_ms(socket.poll_at(cx)); - reply + (reply, became_dead) } } diff --git a/kernel/libs/aster-bigtcp/src/socket_table.rs b/kernel/libs/aster-bigtcp/src/socket_table.rs index 4b25550b..52316b77 100644 --- a/kernel/libs/aster-bigtcp/src/socket_table.rs +++ b/kernel/libs/aster-bigtcp/src/socket_table.rs @@ -265,12 +265,7 @@ impl SocketTable { .find(|connection| connection.connection_key() == key) } - pub(crate) fn remove_listener( - &mut self, - listener: &TcpListenerBg, - ) -> Option>> { - let key = listener.listener_key(); - + pub(crate) fn remove_listener(&mut self, key: &ListenerKey) -> Option>> { let bucket = { let hash = key.hash(); let bucket_index = hash & LISTENER_BUCKET_MASK; @@ -280,10 +275,27 @@ impl SocketTable { let index = bucket .listeners .iter() - .position(|tcp_listener| tcp_listener.listener_key() == listener.listener_key())?; + .position(|tcp_listener| tcp_listener.listener_key() == key)?; Some(bucket.listeners.swap_remove(index)) } + pub(crate) fn remove_dead_tcp_connection(&mut self, key: &ConnectionKey) { + let bucket = { + let hash = key.hash(); + let bucket_index = hash & CONNECTION_BUCKET_MASK; + &mut self.connection_buckets[bucket_index as usize] + }; + + if let Some(index) = bucket + .connections + .iter() + .position(|tcp_connection| tcp_connection.connection_key() == key) + { + let connection = bucket.connections.swap_remove(index); + connection.on_dead_events(); + } + } + pub(crate) fn remove_udp_socket( &mut self, socket: &Arc>, @@ -295,17 +307,6 @@ impl SocketTable { Some(self.udp_sockets.swap_remove(index)) } - pub(crate) fn remove_dead_tcp_connections(&mut self) { - for connection_bucket in self.connection_buckets.iter_mut() { - for tcp_conn in connection_bucket - .connections - .extract_if(|connection| connection.is_dead()) - { - tcp_conn.on_dead_events(); - } - } - } - pub(crate) fn tcp_listener_iter(&self) -> impl Iterator>> { self.listener_buckets .iter()