diff --git a/kernel/libs/aster-bigtcp/src/iface/common.rs b/kernel/libs/aster-bigtcp/src/iface/common.rs index afe94033d..61d385fdd 100644 --- a/kernel/libs/aster-bigtcp/src/iface/common.rs +++ b/kernel/libs/aster-bigtcp/src/iface/common.rs @@ -16,6 +16,7 @@ use smoltcp::{ use super::{ poll::{FnHelper, PollContext}, + poll_iface::PollableIface, port::BindPortConfig, time::get_network_timestamp, Iface, @@ -29,7 +30,7 @@ use crate::{ pub struct IfaceCommon { name: String, - interface: SpinLock, + interface: SpinLock, LocalIrqDisabled>, used_ports: SpinLock, LocalIrqDisabled>, sockets: SpinLock, LocalIrqDisabled>, sched_poll: E::ScheduleNextPoll, @@ -41,13 +42,11 @@ impl IfaceCommon { interface: smoltcp::iface::Interface, sched_poll: E::ScheduleNextPoll, ) -> Self { - let sockets = SocketTable::new(); - Self { name, - interface: SpinLock::new(interface), + interface: SpinLock::new(PollableIface::new(interface)), used_ports: SpinLock::new(BTreeMap::new()), - sockets: SpinLock::new(sockets), + sockets: SpinLock::new(SocketTable::new()), sched_poll, } } @@ -65,10 +64,10 @@ impl IfaceCommon { } } -// Lock order: interface -> sockets +// Lock order: `interface` -> `sockets` impl IfaceCommon { /// Acquires the lock to the interface. - pub(crate) fn interface(&self) -> SpinLockGuard { + pub(crate) fn interface(&self) -> SpinLockGuard<'_, PollableIface, LocalIrqDisabled> { self.interface.lock() } @@ -181,51 +180,42 @@ impl IfaceCommon { Q: FnMut(&Packet, &mut Context, D::TxToken<'_>), { let mut interface = self.interface(); - interface.context().now = get_network_timestamp(); + interface.context_mut().now = get_network_timestamp(); let mut sockets = self.sockets.lock(); let mut dead_tcp_conns = Vec::new(); - loop { - let mut new_tcp_conns = Vec::new(); + let mut new_tcp_conns = Vec::new(); - let mut context = PollContext::new( - interface.context(), - &sockets, - &mut new_tcp_conns, - &mut dead_tcp_conns, - ); - context.poll_ingress(device, &mut process_phy, &mut dispatch_phy); - context.poll_egress(device, &mut dispatch_phy); + let mut context = PollContext::new( + interface.as_mut(), + &sockets, + &mut new_tcp_conns, + &mut dead_tcp_conns, + ); + context.poll_ingress(device, &mut process_phy, &mut dispatch_phy); + context.poll_egress(device, &mut dispatch_phy); - // New packets sent by new connections are not handled. So if there are new - // connections, try again. - if new_tcp_conns.is_empty() { - break; - } else { - new_tcp_conns.into_iter().for_each(|tcp_conn| { - let res = sockets.insert_connection(tcp_conn); - debug_assert!(res.is_ok()); - }); - } + // Insert new connections and remove dead connections. + for new_tcp_conn in new_tcp_conns.into_iter() { + let res = sockets.insert_connection(new_tcp_conn); + debug_assert!(res.is_ok()); } - for dead_conn_key in dead_tcp_conns.into_iter() { sockets.remove_dead_tcp_connection(&dead_conn_key); } + // Notify all socket events. for socket in sockets.tcp_listener_iter() { if socket.has_events() { socket.on_events(); } } - for socket in sockets.tcp_conn_iter() { if socket.has_events() { socket.on_events(); } } - for socket in sockets.udp_socket_iter() { if socket.has_events() { socket.on_events(); @@ -234,10 +224,7 @@ impl IfaceCommon { // Note that only TCP connections can have timers set, so as far as the time to poll is // concerned, we only need to consider TCP connections. - sockets - .tcp_conn_iter() - .map(|socket| socket.next_poll_at_ms()) - .min() + interface.next_poll_at_ms() } } diff --git a/kernel/libs/aster-bigtcp/src/iface/mod.rs b/kernel/libs/aster-bigtcp/src/iface/mod.rs index e5f8c8ed8..9cff7ef7e 100644 --- a/kernel/libs/aster-bigtcp/src/iface/mod.rs +++ b/kernel/libs/aster-bigtcp/src/iface/mod.rs @@ -5,6 +5,7 @@ mod common; mod iface; mod phy; mod poll; +mod poll_iface; mod port; mod sched; mod time; @@ -12,5 +13,6 @@ mod time; pub use common::BoundPort; pub use iface::Iface; pub use phy::{EtherIface, IpIface}; +pub(crate) use poll_iface::{PollKey, PollableIfaceMut}; pub use port::BindPortConfig; pub use sched::ScheduleNextPoll; diff --git a/kernel/libs/aster-bigtcp/src/iface/poll.rs b/kernel/libs/aster-bigtcp/src/iface/poll.rs index d5439a552..853a1f95b 100644 --- a/kernel/libs/aster-bigtcp/src/iface/poll.rs +++ b/kernel/libs/aster-bigtcp/src/iface/poll.rs @@ -15,6 +15,7 @@ use smoltcp::{ }, }; +use super::poll_iface::PollableIfaceMut; use crate::{ ext::Ext, socket::{TcpConnectionBg, TcpProcessResult}, @@ -22,7 +23,7 @@ use crate::{ }; pub(super) struct PollContext<'a, E: Ext> { - iface_cx: &'a mut Context, + iface: PollableIfaceMut<'a, E>, sockets: &'a SocketTable, new_tcp_conns: &'a mut Vec>>, dead_tcp_conns: &'a mut Vec, @@ -30,13 +31,13 @@ pub(super) struct PollContext<'a, E: Ext> { impl<'a, E: Ext> PollContext<'a, E> { pub(super) fn new( - iface_cx: &'a mut Context, + iface: PollableIfaceMut<'a, E>, sockets: &'a SocketTable, new_tcp_conns: &'a mut Vec>>, dead_tcp_conns: &'a mut Vec, ) -> Self { Self { - iface_cx, + iface, sockets, new_tcp_conns, dead_tcp_conns, @@ -65,9 +66,10 @@ impl PollContext<'_, E> { >, Q: FnMut(&Packet, &mut Context, D::TxToken<'_>), { - while let Some((rx_token, tx_token)) = device.receive(self.iface_cx.now()) { + while let Some((rx_token, tx_token)) = device.receive(self.iface.context().now()) { rx_token.consume(|data| { - let Some((pkt, tx_token)) = process_phy(data, self.iface_cx, tx_token) else { + let Some((pkt, tx_token)) = process_phy(data, self.iface.context_mut(), tx_token) + else { return; }; @@ -75,7 +77,7 @@ impl PollContext<'_, E> { return; }; - dispatch_phy(&reply, self.iface_cx, tx_token); + dispatch_phy(&reply, self.iface.context_mut(), tx_token); }); } } @@ -85,7 +87,7 @@ impl PollContext<'_, E> { pkt: Ipv4Packet<&'pkt [u8]>, ) -> Option> { // Parse the IP header. Ignore the packet if the header is ill-formed. - let repr = Ipv4Repr::parse(&pkt, &self.iface_cx.checksum_caps()).ok()?; + let repr = Ipv4Repr::parse(&pkt, &self.iface.context().checksum_caps()).ok()?; if !repr.dst_addr.is_broadcast() && !self.is_unicast_local(IpAddress::Ipv4(repr.dst_addr)) { return self.generate_icmp_unreachable( @@ -95,17 +97,14 @@ impl PollContext<'_, E> { ); } + let checksum_caps = self.iface.context().checksum_caps(); match repr.next_header { - IpProtocol::Tcp => self.parse_and_process_tcp( - &IpRepr::Ipv4(repr), - pkt.payload(), - &self.iface_cx.checksum_caps(), - ), - IpProtocol::Udp => self.parse_and_process_udp( - &IpRepr::Ipv4(repr), - pkt.payload(), - &self.iface_cx.checksum_caps(), - ), + IpProtocol::Tcp => { + self.parse_and_process_tcp(&IpRepr::Ipv4(repr), pkt.payload(), &checksum_caps) + } + IpProtocol::Udp => { + self.parse_and_process_udp(&IpRepr::Ipv4(repr), pkt.payload(), &checksum_caps) + } _ => None, } } @@ -164,7 +163,8 @@ impl PollContext<'_, E> { if tcp_repr.control == TcpControl::Syn && tcp_repr.ack_number.is_none() { let listener_key = ListenerKey::new(ip_repr.dst_addr(), tcp_repr.dst_port); if let Some(listener) = self.sockets.lookup_listener(&listener_key) { - let (processed, new_tcp_conn) = listener.process(self.iface_cx, ip_repr, tcp_repr); + let (processed, new_tcp_conn) = + listener.process(&mut self.iface, ip_repr, tcp_repr); if let Some(tcp_conn) = new_tcp_conn { self.new_tcp_conns.push(tcp_conn); @@ -197,7 +197,7 @@ impl PollContext<'_, E> { if let Some(connection) = connection { let (process_result, became_dead) = - connection.process(self.iface_cx, ip_repr, tcp_repr); + connection.process(&mut self.iface, ip_repr, tcp_repr); if *became_dead { self.dead_tcp_conns.push(*connection.connection_key()); } @@ -254,7 +254,7 @@ impl PollContext<'_, E> { continue; } - processed |= socket.process(self.iface_cx, ip_repr, udp_repr, udp_payload); + processed |= socket.process(self.iface.context_mut(), ip_repr, udp_repr, udp_payload); if processed && ip_repr.dst_addr().is_unicast() { break; } @@ -295,7 +295,8 @@ impl PollContext<'_, E> { Some(Packet::new_ipv4( Ipv4Repr { src_addr: self - .iface_cx + .iface + .context() .ipv4_addr() .unwrap_or(Ipv4Address::UNSPECIFIED), dst_addr: ipv4_repr.src_addr, @@ -314,7 +315,8 @@ impl PollContext<'_, E> { fn is_unicast_local(&self, dst_addr: IpAddress) -> bool { match dst_addr { IpAddress::Ipv4(dst_addr) => self - .iface_cx + .iface + .context() .ipv4_addr() .is_some_and(|addr| addr == dst_addr), } @@ -327,7 +329,7 @@ impl PollContext<'_, E> { D: Device + ?Sized, Q: FnMut(&Packet, &mut Context, D::TxToken<'_>), { - while let Some(tx_token) = device.transmit(self.iface_cx.now()) { + while let Some(tx_token) = device.transmit(self.iface.context().now()) { if !self.dispatch_ipv4(tx_token, dispatch_phy) { break; } @@ -359,12 +361,10 @@ impl PollContext<'_, E> { let mut did_something = false; let mut dead_conns = Vec::new(); - // We cannot dispatch packets from `new_tcp_conns` because we cannot borrow an immutable - // reference at this point. Instead, we will retry after the entire poll is complete. - for socket in self.sockets.tcp_conn_iter() { - if !socket.need_dispatch(self.iface_cx.now()) { - continue; - } + loop { + let Some(socket) = self.iface.pop_pending_tcp() else { + break; + }; // We set `did_something` even if no packets are actually generated. This is because a // timer can expire, but no packets are actually generated. @@ -373,14 +373,14 @@ impl PollContext<'_, E> { let mut deferred = None; let (reply, became_dead) = - TcpConnectionBg::dispatch(socket, self.iface_cx, |cx, ip_repr, tcp_repr| { + TcpConnectionBg::dispatch(&socket, &mut self.iface, |iface, ip_repr, tcp_repr| { let mut this = - PollContext::new(cx, self.sockets, self.new_tcp_conns, &mut dead_conns); + PollContext::new(iface, self.sockets, self.new_tcp_conns, &mut dead_conns); if !this.is_unicast_local(ip_repr.dst_addr()) { dispatch_phy( &Packet::new(ip_repr.clone(), IpPayload::Tcp(*tcp_repr)), - this.iface_cx, + this.iface.context_mut(), tx_token.take().unwrap(), ); return None; @@ -418,13 +418,13 @@ impl PollContext<'_, E> { &ip_payload, &ChecksumCapabilities::ignored(), ) { - dispatch_phy(&reply, self.iface_cx, tx_token.take().unwrap()); + dispatch_phy(&reply, self.iface.context_mut(), tx_token.take().unwrap()); } } (None, Some((ip_repr, tcp_repr))) if !self.is_unicast_local(ip_repr.dst_addr()) => { dispatch_phy( &Packet::new(ip_repr, IpPayload::Tcp(tcp_repr)), - self.iface_cx, + self.iface.context_mut(), tx_token.take().unwrap(), ); } @@ -434,7 +434,7 @@ impl PollContext<'_, E> { { dispatch_phy( &Packet::new(new_ip_repr, IpPayload::Tcp(new_tcp_repr)), - self.iface_cx, + self.iface.context_mut(), tx_token.take().unwrap(), ); } @@ -462,7 +462,7 @@ impl PollContext<'_, E> { let mut dead_conns = Vec::new(); for socket in self.sockets.udp_socket_iter() { - if !socket.need_dispatch(self.iface_cx.now()) { + if !socket.need_dispatch() { continue; } @@ -472,14 +472,16 @@ impl PollContext<'_, E> { let mut deferred = None; - socket.dispatch(self.iface_cx, |cx, ip_repr, udp_repr, udp_payload| { + let (cx, pending) = self.iface.inner_mut(); + socket.dispatch(cx, |cx, ip_repr, udp_repr, udp_payload| { + let iface = PollableIfaceMut::new(cx, pending); let mut this = - PollContext::new(cx, self.sockets, self.new_tcp_conns, &mut dead_conns); + PollContext::new(iface, self.sockets, self.new_tcp_conns, &mut dead_conns); if ip_repr.dst_addr().is_broadcast() || !this.is_unicast_local(ip_repr.dst_addr()) { dispatch_phy( &Packet::new(ip_repr.clone(), IpPayload::Udp(*udp_repr, udp_payload)), - this.iface_cx, + this.iface.context_mut(), tx_token.take().unwrap(), ); if !ip_repr.dst_addr().is_broadcast() { @@ -516,7 +518,7 @@ impl PollContext<'_, E> { &ip_payload, &ChecksumCapabilities::ignored(), ) { - dispatch_phy(&reply, self.iface_cx, tx_token.take().unwrap()); + dispatch_phy(&reply, self.iface.context_mut(), tx_token.take().unwrap()); } } diff --git a/kernel/libs/aster-bigtcp/src/iface/poll_iface.rs b/kernel/libs/aster-bigtcp/src/iface/poll_iface.rs new file mode 100644 index 000000000..b762a6704 --- /dev/null +++ b/kernel/libs/aster-bigtcp/src/iface/poll_iface.rs @@ -0,0 +1,282 @@ +// SPDX-License-Identifier: MPL-2.0 + +use alloc::{collections::btree_set::BTreeSet, sync::Arc}; +use core::{ + borrow::Borrow, + sync::atomic::{AtomicU64, Ordering}, +}; + +use crate::{ + ext::Ext, + socket::{NeedIfacePoll, TcpConnectionBg}, +}; + +/// An interface with auxiliary data that makes it pollable. +/// +/// This is used, for example, when updating a socket's next poll time and finding a socket to +/// poll. +pub(crate) struct PollableIface { + interface: smoltcp::iface::Interface, + pending_conns: PendingConnSet, +} + +impl PollableIface { + pub(super) fn new(interface: smoltcp::iface::Interface) -> Self { + Self { + interface, + pending_conns: PendingConnSet::new(), + } + } + + pub(super) fn as_mut(&mut self) -> PollableIfaceMut { + PollableIfaceMut { + context: self.interface.context(), + pending_conns: &mut self.pending_conns, + } + } + + pub(super) fn ipv4_addr(&self) -> Option { + self.interface.ipv4_addr() + } + + /// Returns the next poll time. + pub(super) fn next_poll_at_ms(&self) -> Option { + self.pending_conns.next_poll_at_ms() + } +} + +impl PollableIface { + /// Returns the `smoltcp` context for passing to the `smoltcp` APIs. + pub(crate) fn context_mut(&mut self) -> &mut smoltcp::iface::Context { + self.interface.context() + } + + /// Updates the next poll time of `socket` to `poll_at`. + /// + /// This method (or [`PollableIfaceMut::update_next_poll_at_ms`]) should be called after network or + /// user events that change the poll time occur. + pub(crate) fn update_next_poll_at_ms( + &mut self, + socket: &Arc>, + poll_at: smoltcp::socket::PollAt, + ) -> NeedIfacePoll { + self.pending_conns.update_next_poll_at_ms(socket, poll_at) + } +} + +/// A mutable reference to a [`PollableIface`]. +/// +/// This type is reconstructed from mutable references to fields in [`PollableIface`], since the fields +/// must be broken into individual fields during interface polling due to limitations of the +/// [`smoltcp`] APIs. +pub(crate) struct PollableIfaceMut<'a, E: Ext> { + context: &'a mut smoltcp::iface::Context, + pending_conns: &'a mut PendingConnSet, +} + +// FIXME: We provide `new()` and `inner_mut()` as `pub(crate)` methods because it's necessary to +// allow the Rust compiler to check the lifetime for separate fields. We should find better ways to +// avoid these `pub(crate)` methods in the future. +impl<'a, E: Ext> PollableIfaceMut<'a, E> { + pub(crate) fn new( + context: &'a mut smoltcp::iface::Context, + pending_conns: &'a mut PendingConnSet, + ) -> Self { + Self { + context, + pending_conns, + } + } + + pub(crate) fn inner_mut(&mut self) -> (&mut smoltcp::iface::Context, &mut PendingConnSet) { + (self.context, self.pending_conns) + } +} + +impl PollableIfaceMut<'_, E> { + pub(super) fn pop_pending_tcp(&mut self) -> Option>> { + let now = self.context.now.total_millis() as u64; + self.pending_conns.pop_tcp_before_now(now) + } +} + +impl PollableIfaceMut<'_, E> { + /// Returns an immutable reference to the `smoltcp` context. + pub(crate) fn context(&self) -> &smoltcp::iface::Context { + self.context + } + + /// Returns the `smoltcp` context for passing to the `smoltcp` APIs. + pub(crate) fn context_mut(&mut self) -> &mut smoltcp::iface::Context { + self.context + } + + /// Updates the next poll time of `socket` to `poll_at`. + /// + /// This method (or [`PollableIface::update_next_poll_at_ms`]) should be called after network + /// or user events that change the poll time occur. + pub(crate) fn update_next_poll_at_ms( + &mut self, + socket: &Arc>, + poll_at: smoltcp::socket::PollAt, + ) -> NeedIfacePoll { + self.pending_conns.update_next_poll_at_ms(socket, poll_at) + } +} + +/// A key to sort sockets by their next poll time. +pub(crate) struct PollKey { + next_poll_at_ms: AtomicU64, + id: usize, +} + +impl PartialEq for PollKey { + fn eq(&self, other: &Self) -> bool { + self.next_poll_at_ms.load(Ordering::Relaxed) + == other.next_poll_at_ms.load(Ordering::Relaxed) + && self.id == other.id + } +} +impl Eq for PollKey {} +impl PartialOrd for PollKey { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} +impl Ord for PollKey { + fn cmp(&self, other: &Self) -> core::cmp::Ordering { + self.next_poll_at_ms + .load(Ordering::Relaxed) + .cmp(&other.next_poll_at_ms.load(Ordering::Relaxed)) + .then_with(|| self.id.cmp(&other.id)) + } +} + +impl PollKey { + /// A value indicating that an immediate poll is required. + const IMMEDIATE_VAL: u64 = 0; + /// A value indicating that no poll is required. + const INACTIVE_VAL: u64 = u64::MAX; + + /// Creates a new [`PollKey`]. + /// + /// `id` must be a unique identifier for the associated socket, as it will be used to locate + /// the socket to update its next poll time. This is usually done using the address of the + /// [`Arc`] socket (see [`Arc::as_ptr`]). + /// + /// [`Arc`]: alloc::sync::Arc + /// [`Arc::as_ptr`]: alloc::sync::Arc::as_ptr + pub(crate) fn new(id: usize) -> Self { + Self { + next_poll_at_ms: AtomicU64::new(Self::INACTIVE_VAL), + id, + } + } +} + +/// Sockets to poll in the future, sorted by poll time. +pub(crate) struct PendingConnSet(BTreeSet>); + +/// A TCP socket to poll in the future. +/// +/// Note that currently only TCP sockets can set a timer to fire in the future, so a +/// [`PendingConnSet`] contains only [`PendingTcpConn`]s. +struct PendingTcpConn(Arc>); + +impl PartialEq for PendingTcpConn { + fn eq(&self, other: &Self) -> bool { + self.0.poll_key() == other.0.poll_key() + } +} +impl Eq for PendingTcpConn {} +impl PartialOrd for PendingTcpConn { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} +impl Ord for PendingTcpConn { + fn cmp(&self, other: &Self) -> core::cmp::Ordering { + self.0.poll_key().cmp(other.0.poll_key()) + } +} + +impl Borrow for PendingTcpConn { + fn borrow(&self) -> &PollKey { + self.0.poll_key() + } +} + +impl PendingConnSet { + fn new() -> Self { + Self(BTreeSet::new()) + } + + fn update_next_poll_at_ms( + &mut self, + socket: &Arc>, + poll_at: smoltcp::socket::PollAt, + ) -> NeedIfacePoll { + let key = socket.poll_key(); + let old_poll_at_ms = key.next_poll_at_ms.load(Ordering::Relaxed); + + let new_poll_at_ms = match poll_at { + smoltcp::socket::PollAt::Now => PollKey::IMMEDIATE_VAL, + smoltcp::socket::PollAt::Time(instant) => instant.total_millis() as u64, + smoltcp::socket::PollAt::Ingress => PollKey::INACTIVE_VAL, + }; + + // Fast path: There is nothing to update. + if old_poll_at_ms == new_poll_at_ms { + return NeedIfacePoll::FALSE; + } + + // Remove the socket from the pending queue if it is in the queue. + let owned_socket = if old_poll_at_ms != PollKey::INACTIVE_VAL { + self.0.take(key).unwrap() + } else { + PendingTcpConn(socket.clone()) + }; + + // Update the poll time _after_ it is removed from the queue. + key.next_poll_at_ms.store(new_poll_at_ms, Ordering::Relaxed); + + // If no new poll is required, do not add the socket to the pending queue. + if new_poll_at_ms == PollKey::INACTIVE_VAL { + return NeedIfacePoll::FALSE; + } + + // Add the socket back to the queue. + let inserted = self.0.insert(owned_socket); + debug_assert!(inserted); + + if new_poll_at_ms < old_poll_at_ms { + NeedIfacePoll::TRUE + } else { + NeedIfacePoll::FALSE + } + } + + fn pop_tcp_before_now(&mut self, now_at_ms: u64) -> Option>> { + if self.0.first().is_some_and(|first| { + first.0.poll_key().next_poll_at_ms.load(Ordering::Relaxed) <= now_at_ms + }) { + self.0.pop_first().map(|first| { + // Reset `next_poll_at_ms` since the socket is no longer in the queue. + first + .0 + .poll_key() + .next_poll_at_ms + .store(PollKey::INACTIVE_VAL, Ordering::Relaxed); + first.0 + }) + } else { + None + } + } + + fn next_poll_at_ms(&self) -> Option { + self.0 + .first() + .map(|first| first.0.poll_key().next_poll_at_ms.load(Ordering::Relaxed)) + } +} diff --git a/kernel/libs/aster-bigtcp/src/socket/bound/common.rs b/kernel/libs/aster-bigtcp/src/socket/bound/common.rs index 42dc35d32..3965ed53a 100644 --- a/kernel/libs/aster-bigtcp/src/socket/bound/common.rs +++ b/kernel/libs/aster-bigtcp/src/socket/bound/common.rs @@ -1,9 +1,9 @@ // SPDX-License-Identifier: MPL-2.0 -use alloc::sync::Arc; -use core::sync::atomic::{AtomicU64, AtomicU8, Ordering}; +use alloc::sync::{Arc, Weak}; +use core::sync::atomic::{AtomicU8, Ordering}; -use smoltcp::{socket::PollAt, time::Instant, wire::IpEndpoint}; +use smoltcp::wire::IpEndpoint; use spin::once::Once; use takeable::Takeable; @@ -46,7 +46,6 @@ pub struct SocketBg, E: Ext> { pub(super) inner: T, observer: Once, events: AtomicU8, - next_poll_at_ms: AtomicU64, } impl, E: Ext> Drop for Socket { @@ -58,13 +57,24 @@ impl, E: Ext> Drop for Socket { } impl, E: Ext> Socket { - pub(crate) fn new(bound: BoundPort, inner: T) -> Self { + pub(super) fn new(bound: BoundPort, inner: T) -> Self { Self(Takeable::new(Arc::new(SocketBg { bound, inner, observer: Once::new(), events: AtomicU8::new(0), - next_poll_at_ms: AtomicU64::new(u64::MAX), + }))) + } + + pub(super) fn new_cyclic(bound: BoundPort, inner_fn: F) -> Self + where + F: FnOnce(&Weak>) -> T, + { + Self(Takeable::new(Arc::new_cyclic(|weak| SocketBg { + bound, + inner: inner_fn(weak), + observer: Once::new(), + events: AtomicU8::new(0), }))) } @@ -119,10 +129,8 @@ impl, E: Ext> SocketBg { where T::Observer: Clone, { - // This method can only be called to process network events, so we assume we are holding the - // poll lock and no race conditions can occur. + // There is no need to clear the events because the socket is dead. let events = self.events.load(Ordering::Relaxed); - self.events.store(0, Ordering::Relaxed); let observer = self.observer.get().cloned(); drop(self); @@ -141,41 +149,6 @@ impl, E: Ext> SocketBg { self.events .store(events | new_events.bits(), Ordering::Relaxed); } - - /// Returns the next polling time. - /// - /// Note: a zero means polling should be done now and a `u64::MAX` means no polling is required - /// before new network or user events. - pub(crate) fn next_poll_at_ms(&self) -> u64 { - self.next_poll_at_ms.load(Ordering::Relaxed) - } - - /// Updates the next polling time according to `poll_at`. - /// - /// The update is typically needed after new network or user events have been handled, so this - /// method also marks that there may be new events, so that the event observer provided by - /// [`Socket::init_observer`] can be notified later. - pub(super) fn update_next_poll_at_ms(&self, poll_at: PollAt) -> NeedIfacePoll { - match poll_at { - PollAt::Now => { - self.next_poll_at_ms.store(0, Ordering::Relaxed); - NeedIfacePoll::TRUE - } - PollAt::Time(instant) => { - let old_total_millis = self.next_poll_at_ms.load(Ordering::Relaxed); - let new_total_millis = instant.total_millis() as u64; - - self.next_poll_at_ms - .store(new_total_millis, Ordering::Relaxed); - - NeedIfacePoll(new_total_millis < old_total_millis) - } - PollAt::Ingress => { - self.next_poll_at_ms.store(u64::MAX, Ordering::Relaxed); - NeedIfacePoll::FALSE - } - } - } } impl, E: Ext> SocketBg { @@ -185,11 +158,4 @@ impl, E: Ext> SocketBg { pub(crate) fn can_process(&self, dst_port: u16) -> bool { self.bound.port() == dst_port } - - /// Returns whether the socket _may_ generate an outgoing packet. - /// - /// The check is intended to be lock-free and fast, but may have false positives. - pub(crate) fn need_dispatch(&self, now: Instant) -> bool { - now.total_millis() as u64 >= self.next_poll_at_ms.load(Ordering::Relaxed) - } } diff --git a/kernel/libs/aster-bigtcp/src/socket/bound/tcp_conn.rs b/kernel/libs/aster-bigtcp/src/socket/bound/tcp_conn.rs index aedf27006..00e155002 100644 --- a/kernel/libs/aster-bigtcp/src/socket/bound/tcp_conn.rs +++ b/kernel/libs/aster-bigtcp/src/socket/bound/tcp_conn.rs @@ -1,11 +1,13 @@ // SPDX-License-Identifier: MPL-2.0 -use alloc::{boxed::Box, sync::Arc}; +use alloc::{ + boxed::Box, + sync::{Arc, Weak}, +}; use core::ops::{Deref, DerefMut}; use ostd::sync::{LocalIrqDisabled, SpinLock, SpinLockGuard}; use smoltcp::{ - iface::Context, socket::{tcp::State, PollAt}, time::Duration, wire::{IpEndpoint, IpRepr, TcpControl, TcpRepr}, @@ -19,7 +21,7 @@ use crate::{ define_boolean_value, errors::tcp::{ConnectError, RecvError, SendError}, ext::Ext, - iface::BoundPort, + iface::{BoundPort, PollKey, PollableIfaceMut}, socket::{ event::SocketEvents, option::{RawTcpOption, RawTcpSetOption}, @@ -33,6 +35,7 @@ pub type TcpConnection = Socket, E>; /// States needed by [`TcpConnectionBg`]. pub struct TcpConnectionInner { socket: SpinLock, LocalIrqDisabled>, + poll_key: PollKey, connection_key: ConnectionKey, } @@ -216,7 +219,11 @@ impl RawTcpSocketExt { } impl TcpConnectionInner { - pub(super) fn new(socket: Box, listener: Option>>) -> Self { + pub(super) fn new( + socket: Box, + listener: Option>>, + weak_self: &Weak>, + ) -> Self { let connection_key = { // Since the socket is connected, the following unwrap can never fail let local_endpoint = socket.local_endpoint().unwrap(); @@ -224,6 +231,8 @@ impl TcpConnectionInner { ConnectionKey::from((local_endpoint, remote_endpoint)) }; + let poll_key = PollKey::new(Weak::as_ptr(weak_self).addr()); + let socket_ext = RawTcpSocketExt { socket, listener, @@ -234,6 +243,7 @@ impl TcpConnectionInner { TcpConnectionInner { socket: SpinLock::new(socket_ext), + poll_key, connection_key, } } @@ -293,7 +303,7 @@ impl TcpConnection { }; let iface = bound.iface().clone(); - // We have to lock interface before locking interface + // We have to lock `interface` before locking `sockets` // to avoid dead lock due to inconsistent lock orders. let mut interface = iface.common().interface(); let mut sockets = iface.common().sockets(); @@ -309,18 +319,19 @@ impl TcpConnection { option.apply(&mut socket); - if let Err(err) = socket.connect(interface.context(), remote_endpoint, bound.port()) { + if let Err(err) = socket.connect(interface.context_mut(), remote_endpoint, bound.port()) + { return Err((bound, err.into())); } socket }; - let inner = TcpConnectionInner::new(socket, None); - - let connection = Self::new(bound, inner); - connection.0.update_next_poll_at_ms(PollAt::Now); + let connection = + Self::new_cyclic(bound, |weak| TcpConnectionInner::new(socket, None, weak)); + interface.update_next_poll_at_ms(&connection.0, PollAt::Now); connection.init_observer(observer); + let res = sockets.insert_connection(connection.inner().clone()); debug_assert!(res.is_ok()); @@ -378,9 +389,8 @@ impl TcpConnection { } let result = socket.send(f)?; - let need_poll = self - .0 - .update_next_poll_at_ms(socket.poll_at(iface.context())); + let poll_at = socket.poll_at(iface.context_mut()); + let need_poll = iface.update_next_poll_at_ms(&self.0, poll_at); Ok((result, need_poll)) } @@ -408,9 +418,8 @@ impl TcpConnection { res => res, }?; - let need_poll = self - .0 - .update_next_poll_at_ms(socket.poll_at(iface.context())); + let poll_at = socket.poll_at(iface.context_mut()); + let need_poll = iface.update_next_poll_at_ms(&self.0, poll_at); Ok((result, need_poll)) } @@ -433,6 +442,7 @@ impl TcpConnection { /// /// Polling the iface is _always_ required after this method succeeds. pub fn shut_send(&self) -> bool { + let mut iface = self.iface().common().interface(); let mut socket = self.0.inner.lock(); if matches!(socket.state(), State::Closed | State::TimeWait) { @@ -440,7 +450,9 @@ impl TcpConnection { } socket.close(); - self.0.update_next_poll_at_ms(PollAt::Now); + + let poll_at = socket.poll_at(iface.context_mut()); + iface.update_next_poll_at_ms(&self.0, poll_at); true } @@ -469,6 +481,7 @@ impl TcpConnection { /// Note that either this method or [`Self::reset`] must be called before dropping the TCP /// connection to avoid resource leakage. pub fn close(&self) { + let mut iface = self.iface().common().interface(); let mut socket = self.0.inner.lock(); socket.is_recv_shut = true; @@ -479,7 +492,9 @@ impl TcpConnection { } else { socket.close(); } - self.0.update_next_poll_at_ms(PollAt::Now); + + let poll_at = socket.poll_at(iface.context_mut()); + iface.update_next_poll_at_ms(&self.0, poll_at); } /// Resets the connection. @@ -489,10 +504,13 @@ impl TcpConnection { /// Note that either this method or [`Self::close`] must be called before dropping the TCP /// connection to avoid resource leakage. pub fn reset(&self) { + let mut iface = self.iface().common().interface(); let mut socket = self.0.inner.lock(); socket.abort(); - self.0.update_next_poll_at_ms(PollAt::Now); + + let poll_at = socket.poll_at(iface.context_mut()); + iface.update_next_poll_at_ms(&self.0, poll_at); } /// Calls `f` with an immutable reference to the associated [`RawTcpSocket`]. @@ -510,15 +528,13 @@ impl TcpConnection { impl RawTcpSetOption for TcpConnection { fn set_keep_alive(&self, interval: Option) -> NeedIfacePoll { + let mut iface = self.iface().common().interface(); let mut socket = self.0.inner.lock(); + socket.set_keep_alive(interval); - if interval.is_some() { - self.0.update_next_poll_at_ms(PollAt::Now); - NeedIfacePoll::TRUE - } else { - NeedIfacePoll::FALSE - } + let poll_at = socket.poll_at(iface.context_mut()); + iface.update_next_poll_at_ms(&self.0, poll_at) } fn set_nagle_enabled(&self, enabled: bool) { @@ -528,6 +544,10 @@ impl RawTcpSetOption for TcpConnection { } impl TcpConnectionBg { + pub(crate) const fn poll_key(&self) -> &PollKey { + &self.inner.poll_key + } + pub(crate) const fn connection_key(&self) -> &ConnectionKey { &self.inner.connection_key } @@ -544,13 +564,13 @@ impl TcpConnectionBg { /// Tries to process an incoming packet and returns whether the packet is processed. pub(crate) fn process( self: &Arc, - cx: &mut Context, + iface: &mut PollableIfaceMut, ip_repr: &IpRepr, tcp_repr: &TcpRepr, ) -> (TcpProcessResult, TcpConnBecameDead) { let mut socket = self.inner.lock(); - if !socket.accepts(cx, ip_repr, tcp_repr) { + if !socket.accepts(iface.context_mut(), ip_repr, tcp_repr) { return (TcpProcessResult::NotProcessed, TcpConnBecameDead::FALSE); } @@ -581,7 +601,7 @@ impl TcpConnectionBg { // to be queued. let mut events = SocketEvents::CAN_RECV | SocketEvents::CAN_SEND; - let result = match socket.process(cx, ip_repr, tcp_repr) { + let result = match socket.process(iface.context_mut(), ip_repr, tcp_repr) { None => TcpProcessResult::Processed, Some((ip_repr, tcp_repr)) => TcpProcessResult::ProcessedWithReply(ip_repr, tcp_repr), }; @@ -591,7 +611,9 @@ impl TcpConnectionBg { events |= state_events; self.add_events(events); - self.update_next_poll_at_ms(socket.poll_at(cx)); + + let poll_at = socket.poll_at(iface.context_mut()); + iface.update_next_poll_at_ms(self, poll_at); (result, became_dead) } @@ -599,11 +621,11 @@ impl TcpConnectionBg { /// Tries to generate an outgoing packet and dispatches the generated packet. pub(crate) fn dispatch( self: &Arc, - cx: &mut Context, + iface: &mut PollableIfaceMut, dispatch: D, ) -> (Option<(IpRepr, TcpRepr<'static>)>, TcpConnBecameDead) where - D: FnOnce(&mut Context, &IpRepr, &TcpRepr) -> Option<(IpRepr, TcpRepr<'static>)>, + D: FnOnce(PollableIfaceMut, &IpRepr, &TcpRepr) -> Option<(IpRepr, TcpRepr<'static>)>, { let mut socket = self.inner.lock(); @@ -613,9 +635,10 @@ impl TcpConnectionBg { let mut events = SocketEvents::empty(); let mut reply = None; + let (cx, pending) = iface.inner_mut(); socket .dispatch(cx, |cx, (ip_repr, tcp_repr)| { - reply = dispatch(cx, &ip_repr, &tcp_repr); + reply = dispatch(PollableIfaceMut::new(cx, pending), &ip_repr, &tcp_repr); Ok::<(), ()>(()) }) .unwrap(); @@ -623,12 +646,12 @@ impl TcpConnectionBg { // `dispatch` can return a packet in response to the generated packet. If the socket // accepts the packet, we can process it directly. while let Some((ref ip_repr, ref tcp_repr)) = reply { - if !socket.accepts(cx, ip_repr, tcp_repr) { + if !socket.accepts(iface.context_mut(), ip_repr, tcp_repr) { break; } is_rst |= tcp_repr.control == TcpControl::Rst; events |= SocketEvents::CAN_RECV | SocketEvents::CAN_SEND; - reply = socket.process(cx, ip_repr, tcp_repr); + reply = socket.process(iface.context_mut(), ip_repr, tcp_repr); } let (state_events, became_dead) = @@ -636,7 +659,9 @@ impl TcpConnectionBg { events |= state_events; self.add_events(events); - self.update_next_poll_at_ms(socket.poll_at(cx)); + + let poll_at = socket.poll_at(iface.context_mut()); + iface.update_next_poll_at_ms(self, poll_at); (reply, became_dead) } diff --git a/kernel/libs/aster-bigtcp/src/socket/bound/tcp_listen.rs b/kernel/libs/aster-bigtcp/src/socket/bound/tcp_listen.rs index e766d24ec..0d3a51867 100644 --- a/kernel/libs/aster-bigtcp/src/socket/bound/tcp_listen.rs +++ b/kernel/libs/aster-bigtcp/src/socket/bound/tcp_listen.rs @@ -4,7 +4,6 @@ use alloc::{boxed::Box, collections::btree_map::BTreeMap, sync::Arc, vec::Vec}; use ostd::sync::{LocalIrqDisabled, SpinLock}; use smoltcp::{ - iface::Context, socket::PollAt, time::Duration, wire::{IpEndpoint, IpRepr, TcpRepr}, @@ -17,7 +16,7 @@ use super::{ use crate::{ errors::tcp::ListenError, ext::Ext, - iface::{BindPortConfig, BoundPort}, + iface::{BindPortConfig, BoundPort, PollableIfaceMut}, socket::{ option::{RawTcpOption, RawTcpSetOption}, unbound::{new_tcp_socket, RawTcpSocket}, @@ -194,13 +193,16 @@ impl TcpListenerBg { /// Tries to process an incoming packet and returns whether the packet is processed. pub(crate) fn process( self: &Arc, - cx: &mut Context, + iface: &mut PollableIfaceMut, ip_repr: &IpRepr, tcp_repr: &TcpRepr, ) -> (TcpProcessResult, Option>>) { let mut backlog = self.inner.backlog.lock(); - if !backlog.socket.accepts(cx, ip_repr, tcp_repr) { + if !backlog + .socket + .accepts(iface.context_mut(), ip_repr, tcp_repr) + { return (TcpProcessResult::NotProcessed, None); } @@ -211,7 +213,10 @@ impl TcpListenerBg { return (TcpProcessResult::Processed, None); } - let result = match backlog.socket.process(cx, ip_repr, tcp_repr) { + let result = match backlog + .socket + .process(iface.context_mut(), ip_repr, tcp_repr) + { None => TcpProcessResult::Processed, Some((ip_repr, tcp_repr)) => TcpProcessResult::ProcessedWithReply(ip_repr, tcp_repr), }; @@ -227,23 +232,25 @@ impl TcpListenerBg { socket }; - let inner = TcpConnectionInner::new( - core::mem::replace(&mut backlog.socket, new_socket), - Some(self.clone()), - ); - let conn = TcpConnection::new( + let conn = TcpConnection::new_cyclic( self.bound .iface() .bind(BindPortConfig::CanReuse(self.bound.port())) .unwrap(), - inner, + |weak| { + TcpConnectionInner::new( + core::mem::replace(&mut backlog.socket, new_socket), + Some(self.clone()), + weak, + ) + }, ); let conn_bg = conn.inner().clone(); let old_conn = backlog.connecting.insert(*conn_bg.connection_key(), conn); debug_assert!(old_conn.is_none()); - conn_bg.update_next_poll_at_ms(PollAt::Now); + iface.update_next_poll_at_ms(&conn_bg, PollAt::Now); (result, Some(conn_bg)) } diff --git a/kernel/libs/aster-bigtcp/src/socket/bound/udp.rs b/kernel/libs/aster-bigtcp/src/socket/bound/udp.rs index eda71a04e..61b67d8c8 100644 --- a/kernel/libs/aster-bigtcp/src/socket/bound/udp.rs +++ b/kernel/libs/aster-bigtcp/src/socket/bound/udp.rs @@ -1,11 +1,12 @@ // SPDX-License-Identifier: MPL-2.0 use alloc::{boxed::Box, sync::Arc}; +use core::sync::atomic::{AtomicBool, Ordering}; use ostd::sync::{LocalIrqDisabled, SpinLock}; use smoltcp::{ iface::Context, - socket::{udp::UdpMetadata, PollAt}, + socket::udp::UdpMetadata, wire::{IpRepr, UdpRepr}, }; @@ -20,13 +21,16 @@ use crate::{ pub type UdpSocket = Socket; /// States needed by [`UdpSocketBg`]. -type UdpSocketInner = SpinLock, LocalIrqDisabled>; +pub struct UdpSocketInner { + socket: SpinLock, LocalIrqDisabled>, + need_dispatch: AtomicBool, +} impl Inner for UdpSocketInner { type Observer = E::UdpEventObserver; fn on_drop(this: &Arc>) { - this.inner.lock().close(); + this.inner.socket.lock().close(); // A UDP socket can be removed immediately. this.bound.iface().common().remove_udp_socket(this); @@ -44,7 +48,7 @@ impl UdpSocketBg { udp_repr: &UdpRepr, udp_payload: &[u8], ) -> bool { - let mut socket = self.inner.lock(); + let mut socket = self.inner.socket.lock(); if !socket.accepts(cx, ip_repr, udp_repr) { return false; @@ -59,7 +63,6 @@ impl UdpSocketBg { ); self.add_events(SocketEvents::CAN_RECV); - self.update_next_poll_at_ms(socket.poll_at(cx)); true } @@ -69,7 +72,7 @@ impl UdpSocketBg { where D: FnOnce(&mut Context, &IpRepr, &UdpRepr, &[u8]), { - let mut socket = self.inner.lock(); + let mut socket = self.inner.socket.lock(); socket .dispatch(cx, |cx, _meta, (ip_repr, udp_repr, udp_payload)| { @@ -80,7 +83,17 @@ impl UdpSocketBg { // For UDP, dequeuing a packet means that we can queue more packets. self.add_events(SocketEvents::CAN_SEND); - self.update_next_poll_at_ms(socket.poll_at(cx)); + + self.inner + .need_dispatch + .store(socket.send_queue() > 0, Ordering::Relaxed); + } + + /// Returns whether the socket _may_ generate an outgoing packet. + /// + /// The check is intended to be lock-free and fast, but may have false positives. + pub(crate) fn need_dispatch(&self) -> bool { + self.inner.need_dispatch.load(Ordering::Relaxed) } } @@ -106,7 +119,10 @@ impl UdpSocket { socket }; - let inner = UdpSocketInner::new(socket); + let inner = UdpSocketInner { + socket: SpinLock::new(socket), + need_dispatch: AtomicBool::new(false), + }; let socket = Self::new(bound, inner); socket.init_observer(observer); @@ -130,7 +146,7 @@ impl UdpSocket { where F: FnOnce(&mut [u8]) -> R, { - let mut socket = self.0.inner.lock(); + let mut socket = self.0.inner.socket.lock(); if size > socket.packet_send_capacity() { return Err(SendError::TooLarge); @@ -141,7 +157,11 @@ impl UdpSocket { Err(err) => return Err(err.into()), }; let result = f(buffer); - self.0.update_next_poll_at_ms(PollAt::Now); + + self.0 + .inner + .need_dispatch + .store(socket.send_queue() > 0, Ordering::Relaxed); Ok(result) } @@ -153,7 +173,7 @@ impl UdpSocket { where F: FnOnce(&[u8], UdpMetadata) -> R, { - let mut socket = self.0.inner.lock(); + let mut socket = self.0.inner.socket.lock(); let (data, meta) = socket.recv()?; let result = f(data, meta); @@ -169,7 +189,7 @@ impl UdpSocket { where F: FnOnce(&RawUdpSocket) -> R, { - let socket = self.0.inner.lock(); + let socket = self.0.inner.socket.lock(); f(&socket) } }