Avoid O(n) iteration when sending TCP packets

Ruihan Li 2025-03-18 09:34:15 +08:00 committed by Tate, Hongliang Tian
parent 58ad43b0a9
commit a7e718e812
8 changed files with 476 additions and 185 deletions
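
The change in a nutshell, as a minimal standalone sketch (the `Pending` and `PendingSet` types below are illustrative simplifications, not the crate's actual APIs): every TCP connection with a pending timer lives in a `BTreeSet` ordered by (next poll time, unique id). Finding the next socket to dispatch becomes a pop of the first element, and computing the interface's next poll time becomes a peek, instead of the previous O(n) scan over all TCP sockets on every poll.

use std::collections::BTreeSet;

// Illustrative stand-in for a TCP connection with a pending timer.
#[derive(PartialEq, Eq, PartialOrd, Ord)]
struct Pending {
    next_poll_at_ms: u64, // primary sort key: earliest deadline first
    id: usize,            // tie-breaker: unique per socket
}

struct PendingSet(BTreeSet<Pending>);

impl PendingSet {
    // O(1) peek at the earliest deadline (what `next_poll_at_ms` now returns).
    fn next_poll_at_ms(&self) -> Option<u64> {
        self.0.first().map(|p| p.next_poll_at_ms)
    }

    // O(log n) pop of the next socket due at or before `now_ms`, replacing
    // the old `tcp_conn_iter().map(|s| s.next_poll_at_ms()).min()` scan.
    fn pop_due(&mut self, now_ms: u64) -> Option<Pending> {
        if self.0.first()?.next_poll_at_ms <= now_ms {
            self.0.pop_first()
        } else {
            None
        }
    }
}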

View File

@ -16,6 +16,7 @@ use smoltcp::{
use super::{
poll::{FnHelper, PollContext},
poll_iface::PollableIface,
port::BindPortConfig,
time::get_network_timestamp,
Iface,
@ -29,7 +30,7 @@ use crate::{
pub struct IfaceCommon<E: Ext> {
name: String,
interface: SpinLock<smoltcp::iface::Interface, LocalIrqDisabled>,
interface: SpinLock<PollableIface<E>, LocalIrqDisabled>,
used_ports: SpinLock<BTreeMap<u16, usize>, LocalIrqDisabled>,
sockets: SpinLock<SocketTable<E>, LocalIrqDisabled>,
sched_poll: E::ScheduleNextPoll,
@ -41,13 +42,11 @@ impl<E: Ext> IfaceCommon<E> {
interface: smoltcp::iface::Interface,
sched_poll: E::ScheduleNextPoll,
) -> Self {
let sockets = SocketTable::new();
Self {
name,
interface: SpinLock::new(interface),
interface: SpinLock::new(PollableIface::new(interface)),
used_ports: SpinLock::new(BTreeMap::new()),
sockets: SpinLock::new(sockets),
sockets: SpinLock::new(SocketTable::new()),
sched_poll,
}
}
@ -65,10 +64,10 @@ impl<E: Ext> IfaceCommon<E> {
}
}
// Lock order: interface -> sockets
// Lock order: `interface` -> `sockets`
impl<E: Ext> IfaceCommon<E> {
/// Acquires the lock to the interface.
pub(crate) fn interface(&self) -> SpinLockGuard<smoltcp::iface::Interface, LocalIrqDisabled> {
pub(crate) fn interface(&self) -> SpinLockGuard<'_, PollableIface<E>, LocalIrqDisabled> {
self.interface.lock()
}
@ -181,51 +180,42 @@ impl<E: Ext> IfaceCommon<E> {
Q: FnMut(&Packet, &mut Context, D::TxToken<'_>),
{
let mut interface = self.interface();
interface.context().now = get_network_timestamp();
interface.context_mut().now = get_network_timestamp();
let mut sockets = self.sockets.lock();
let mut dead_tcp_conns = Vec::new();
loop {
let mut new_tcp_conns = Vec::new();
let mut new_tcp_conns = Vec::new();
let mut context = PollContext::new(
interface.context(),
&sockets,
&mut new_tcp_conns,
&mut dead_tcp_conns,
);
context.poll_ingress(device, &mut process_phy, &mut dispatch_phy);
context.poll_egress(device, &mut dispatch_phy);
let mut context = PollContext::new(
interface.as_mut(),
&sockets,
&mut new_tcp_conns,
&mut dead_tcp_conns,
);
context.poll_ingress(device, &mut process_phy, &mut dispatch_phy);
context.poll_egress(device, &mut dispatch_phy);
// New packets sent by new connections are not handled. So if there are new
// connections, try again.
if new_tcp_conns.is_empty() {
break;
} else {
new_tcp_conns.into_iter().for_each(|tcp_conn| {
let res = sockets.insert_connection(tcp_conn);
debug_assert!(res.is_ok());
});
}
// Insert new connections and remove dead connections.
for new_tcp_conn in new_tcp_conns.into_iter() {
let res = sockets.insert_connection(new_tcp_conn);
debug_assert!(res.is_ok());
}
for dead_conn_key in dead_tcp_conns.into_iter() {
sockets.remove_dead_tcp_connection(&dead_conn_key);
}
// Notify all socket events.
for socket in sockets.tcp_listener_iter() {
if socket.has_events() {
socket.on_events();
}
}
for socket in sockets.tcp_conn_iter() {
if socket.has_events() {
socket.on_events();
}
}
for socket in sockets.udp_socket_iter() {
if socket.has_events() {
socket.on_events();
@ -234,10 +224,7 @@ impl<E: Ext> IfaceCommon<E> {
// Note that only TCP connections can have timers set, so as far as the time to poll is
// concerned, we only need to consider TCP connections.
sockets
.tcp_conn_iter()
.map(|socket| socket.next_poll_at_ms())
.min()
interface.next_poll_at_ms()
}
}

View File

@ -5,6 +5,7 @@ mod common;
mod iface;
mod phy;
mod poll;
mod poll_iface;
mod port;
mod sched;
mod time;
@ -12,5 +13,6 @@ mod time;
pub use common::BoundPort;
pub use iface::Iface;
pub use phy::{EtherIface, IpIface};
pub(crate) use poll_iface::{PollKey, PollableIfaceMut};
pub use port::BindPortConfig;
pub use sched::ScheduleNextPoll;

View File

@ -15,6 +15,7 @@ use smoltcp::{
},
};
use super::poll_iface::PollableIfaceMut;
use crate::{
ext::Ext,
socket::{TcpConnectionBg, TcpProcessResult},
@ -22,7 +23,7 @@ use crate::{
};
pub(super) struct PollContext<'a, E: Ext> {
iface_cx: &'a mut Context,
iface: PollableIfaceMut<'a, E>,
sockets: &'a SocketTable<E>,
new_tcp_conns: &'a mut Vec<Arc<TcpConnectionBg<E>>>,
dead_tcp_conns: &'a mut Vec<ConnectionKey>,
@ -30,13 +31,13 @@ pub(super) struct PollContext<'a, E: Ext> {
impl<'a, E: Ext> PollContext<'a, E> {
pub(super) fn new(
iface_cx: &'a mut Context,
iface: PollableIfaceMut<'a, E>,
sockets: &'a SocketTable<E>,
new_tcp_conns: &'a mut Vec<Arc<TcpConnectionBg<E>>>,
dead_tcp_conns: &'a mut Vec<ConnectionKey>,
) -> Self {
Self {
iface_cx,
iface,
sockets,
new_tcp_conns,
dead_tcp_conns,
@ -65,9 +66,10 @@ impl<E: Ext> PollContext<'_, E> {
>,
Q: FnMut(&Packet, &mut Context, D::TxToken<'_>),
{
while let Some((rx_token, tx_token)) = device.receive(self.iface_cx.now()) {
while let Some((rx_token, tx_token)) = device.receive(self.iface.context().now()) {
rx_token.consume(|data| {
let Some((pkt, tx_token)) = process_phy(data, self.iface_cx, tx_token) else {
let Some((pkt, tx_token)) = process_phy(data, self.iface.context_mut(), tx_token)
else {
return;
};
@ -75,7 +77,7 @@ impl<E: Ext> PollContext<'_, E> {
return;
};
dispatch_phy(&reply, self.iface_cx, tx_token);
dispatch_phy(&reply, self.iface.context_mut(), tx_token);
});
}
}
@ -85,7 +87,7 @@ impl<E: Ext> PollContext<'_, E> {
pkt: Ipv4Packet<&'pkt [u8]>,
) -> Option<Packet<'pkt>> {
// Parse the IP header. Ignore the packet if the header is ill-formed.
let repr = Ipv4Repr::parse(&pkt, &self.iface_cx.checksum_caps()).ok()?;
let repr = Ipv4Repr::parse(&pkt, &self.iface.context().checksum_caps()).ok()?;
if !repr.dst_addr.is_broadcast() && !self.is_unicast_local(IpAddress::Ipv4(repr.dst_addr)) {
return self.generate_icmp_unreachable(
@ -95,17 +97,14 @@ impl<E: Ext> PollContext<'_, E> {
);
}
let checksum_caps = self.iface.context().checksum_caps();
match repr.next_header {
IpProtocol::Tcp => self.parse_and_process_tcp(
&IpRepr::Ipv4(repr),
pkt.payload(),
&self.iface_cx.checksum_caps(),
),
IpProtocol::Udp => self.parse_and_process_udp(
&IpRepr::Ipv4(repr),
pkt.payload(),
&self.iface_cx.checksum_caps(),
),
IpProtocol::Tcp => {
self.parse_and_process_tcp(&IpRepr::Ipv4(repr), pkt.payload(), &checksum_caps)
}
IpProtocol::Udp => {
self.parse_and_process_udp(&IpRepr::Ipv4(repr), pkt.payload(), &checksum_caps)
}
_ => None,
}
}
@ -164,7 +163,8 @@ impl<E: Ext> PollContext<'_, E> {
if tcp_repr.control == TcpControl::Syn && tcp_repr.ack_number.is_none() {
let listener_key = ListenerKey::new(ip_repr.dst_addr(), tcp_repr.dst_port);
if let Some(listener) = self.sockets.lookup_listener(&listener_key) {
let (processed, new_tcp_conn) = listener.process(self.iface_cx, ip_repr, tcp_repr);
let (processed, new_tcp_conn) =
listener.process(&mut self.iface, ip_repr, tcp_repr);
if let Some(tcp_conn) = new_tcp_conn {
self.new_tcp_conns.push(tcp_conn);
@ -197,7 +197,7 @@ impl<E: Ext> PollContext<'_, E> {
if let Some(connection) = connection {
let (process_result, became_dead) =
connection.process(self.iface_cx, ip_repr, tcp_repr);
connection.process(&mut self.iface, ip_repr, tcp_repr);
if *became_dead {
self.dead_tcp_conns.push(*connection.connection_key());
}
@ -254,7 +254,7 @@ impl<E: Ext> PollContext<'_, E> {
continue;
}
processed |= socket.process(self.iface_cx, ip_repr, udp_repr, udp_payload);
processed |= socket.process(self.iface.context_mut(), ip_repr, udp_repr, udp_payload);
if processed && ip_repr.dst_addr().is_unicast() {
break;
}
@ -295,7 +295,8 @@ impl<E: Ext> PollContext<'_, E> {
Some(Packet::new_ipv4(
Ipv4Repr {
src_addr: self
.iface_cx
.iface
.context()
.ipv4_addr()
.unwrap_or(Ipv4Address::UNSPECIFIED),
dst_addr: ipv4_repr.src_addr,
@ -314,7 +315,8 @@ impl<E: Ext> PollContext<'_, E> {
fn is_unicast_local(&self, dst_addr: IpAddress) -> bool {
match dst_addr {
IpAddress::Ipv4(dst_addr) => self
.iface_cx
.iface
.context()
.ipv4_addr()
.is_some_and(|addr| addr == dst_addr),
}
@ -327,7 +329,7 @@ impl<E: Ext> PollContext<'_, E> {
D: Device + ?Sized,
Q: FnMut(&Packet, &mut Context, D::TxToken<'_>),
{
while let Some(tx_token) = device.transmit(self.iface_cx.now()) {
while let Some(tx_token) = device.transmit(self.iface.context().now()) {
if !self.dispatch_ipv4(tx_token, dispatch_phy) {
break;
}
@ -359,12 +361,10 @@ impl<E: Ext> PollContext<'_, E> {
let mut did_something = false;
let mut dead_conns = Vec::new();
// We cannot dispatch packets from `new_tcp_conns` because we cannot borrow an immutable
// reference at this point. Instead, we will retry after the entire poll is complete.
for socket in self.sockets.tcp_conn_iter() {
if !socket.need_dispatch(self.iface_cx.now()) {
continue;
}
loop {
let Some(socket) = self.iface.pop_pending_tcp() else {
break;
};
// We set `did_something` even if no packets are actually generated, because a
// timer can expire without any packets being generated.
@ -373,14 +373,14 @@ impl<E: Ext> PollContext<'_, E> {
let mut deferred = None;
let (reply, became_dead) =
TcpConnectionBg::dispatch(socket, self.iface_cx, |cx, ip_repr, tcp_repr| {
TcpConnectionBg::dispatch(&socket, &mut self.iface, |iface, ip_repr, tcp_repr| {
let mut this =
PollContext::new(cx, self.sockets, self.new_tcp_conns, &mut dead_conns);
PollContext::new(iface, self.sockets, self.new_tcp_conns, &mut dead_conns);
if !this.is_unicast_local(ip_repr.dst_addr()) {
dispatch_phy(
&Packet::new(ip_repr.clone(), IpPayload::Tcp(*tcp_repr)),
this.iface_cx,
this.iface.context_mut(),
tx_token.take().unwrap(),
);
return None;
@ -418,13 +418,13 @@ impl<E: Ext> PollContext<'_, E> {
&ip_payload,
&ChecksumCapabilities::ignored(),
) {
dispatch_phy(&reply, self.iface_cx, tx_token.take().unwrap());
dispatch_phy(&reply, self.iface.context_mut(), tx_token.take().unwrap());
}
}
(None, Some((ip_repr, tcp_repr))) if !self.is_unicast_local(ip_repr.dst_addr()) => {
dispatch_phy(
&Packet::new(ip_repr, IpPayload::Tcp(tcp_repr)),
self.iface_cx,
self.iface.context_mut(),
tx_token.take().unwrap(),
);
}
@ -434,7 +434,7 @@ impl<E: Ext> PollContext<'_, E> {
{
dispatch_phy(
&Packet::new(new_ip_repr, IpPayload::Tcp(new_tcp_repr)),
self.iface_cx,
self.iface.context_mut(),
tx_token.take().unwrap(),
);
}
@ -462,7 +462,7 @@ impl<E: Ext> PollContext<'_, E> {
let mut dead_conns = Vec::new();
for socket in self.sockets.udp_socket_iter() {
if !socket.need_dispatch(self.iface_cx.now()) {
if !socket.need_dispatch() {
continue;
}
@ -472,14 +472,16 @@ impl<E: Ext> PollContext<'_, E> {
let mut deferred = None;
socket.dispatch(self.iface_cx, |cx, ip_repr, udp_repr, udp_payload| {
let (cx, pending) = self.iface.inner_mut();
socket.dispatch(cx, |cx, ip_repr, udp_repr, udp_payload| {
let iface = PollableIfaceMut::new(cx, pending);
let mut this =
PollContext::new(cx, self.sockets, self.new_tcp_conns, &mut dead_conns);
PollContext::new(iface, self.sockets, self.new_tcp_conns, &mut dead_conns);
if ip_repr.dst_addr().is_broadcast() || !this.is_unicast_local(ip_repr.dst_addr()) {
dispatch_phy(
&Packet::new(ip_repr.clone(), IpPayload::Udp(*udp_repr, udp_payload)),
this.iface_cx,
this.iface.context_mut(),
tx_token.take().unwrap(),
);
if !ip_repr.dst_addr().is_broadcast() {
@ -516,7 +518,7 @@ impl<E: Ext> PollContext<'_, E> {
&ip_payload,
&ChecksumCapabilities::ignored(),
) {
dispatch_phy(&reply, self.iface_cx, tx_token.take().unwrap());
dispatch_phy(&reply, self.iface.context_mut(), tx_token.take().unwrap());
}
}

View File

@ -0,0 +1,282 @@
// SPDX-License-Identifier: MPL-2.0
use alloc::{collections::btree_set::BTreeSet, sync::Arc};
use core::{
borrow::Borrow,
sync::atomic::{AtomicU64, Ordering},
};
use crate::{
ext::Ext,
socket::{NeedIfacePoll, TcpConnectionBg},
};
/// An interface with auxiliary data that makes it pollable.
///
/// This is used, for example, when updating a socket's next poll time and finding a socket to
/// poll.
pub(crate) struct PollableIface<E: Ext> {
interface: smoltcp::iface::Interface,
pending_conns: PendingConnSet<E>,
}
impl<E: Ext> PollableIface<E> {
pub(super) fn new(interface: smoltcp::iface::Interface) -> Self {
Self {
interface,
pending_conns: PendingConnSet::new(),
}
}
pub(super) fn as_mut(&mut self) -> PollableIfaceMut<E> {
PollableIfaceMut {
context: self.interface.context(),
pending_conns: &mut self.pending_conns,
}
}
pub(super) fn ipv4_addr(&self) -> Option<smoltcp::wire::Ipv4Address> {
self.interface.ipv4_addr()
}
/// Returns the next poll time.
pub(super) fn next_poll_at_ms(&self) -> Option<u64> {
self.pending_conns.next_poll_at_ms()
}
}
impl<E: Ext> PollableIface<E> {
/// Returns the `smoltcp` context for passing to the `smoltcp` APIs.
pub(crate) fn context_mut(&mut self) -> &mut smoltcp::iface::Context {
self.interface.context()
}
/// Updates the next poll time of `socket` to `poll_at`.
///
/// This method (or [`PollableIfaceMut::update_next_poll_at_ms`]) should be called after network or
/// user events that change the poll time occur.
pub(crate) fn update_next_poll_at_ms(
&mut self,
socket: &Arc<TcpConnectionBg<E>>,
poll_at: smoltcp::socket::PollAt,
) -> NeedIfacePoll {
self.pending_conns.update_next_poll_at_ms(socket, poll_at)
}
}
/// A mutable reference to a [`PollableIface`].
///
/// This type is reconstructed from mutable references to fields in [`PollableIface`], since the
/// struct must be broken down into individual field borrows during interface polling due to
/// limitations of the [`smoltcp`] APIs.
pub(crate) struct PollableIfaceMut<'a, E: Ext> {
context: &'a mut smoltcp::iface::Context,
pending_conns: &'a mut PendingConnSet<E>,
}
// FIXME: We provide `new()` and `inner_mut()` as `pub(crate)` methods because they are necessary
// for the Rust compiler to borrow-check the lifetimes of the individual fields separately. We
// should find better ways to avoid these `pub(crate)` methods in the future.
impl<'a, E: Ext> PollableIfaceMut<'a, E> {
pub(crate) fn new(
context: &'a mut smoltcp::iface::Context,
pending_conns: &'a mut PendingConnSet<E>,
) -> Self {
Self {
context,
pending_conns,
}
}
pub(crate) fn inner_mut(&mut self) -> (&mut smoltcp::iface::Context, &mut PendingConnSet<E>) {
(self.context, self.pending_conns)
}
}
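
A hedged illustration of the borrow-splitting problem behind this FIXME (types simplified to plain integers, not the crate's actual code): `smoltcp`'s dispatch callbacks hand back a bare `&mut Context`, so the wrapper must be torn into disjoint field borrows before the call and reassembled inside the callback, which requires an explicit splitting method like `inner_mut()`.

// Stand-ins: `u32` for `smoltcp::iface::Context`, `Vec<u32>` for `PendingConnSet<E>`.
struct IfaceMut<'a> {
    context: &'a mut u32,
    pending: &'a mut Vec<u32>,
}

impl<'a> IfaceMut<'a> {
    // Splits the wrapper into two disjoint `&mut` borrows. The caller can pass
    // one into a callback and use the other to rebuild an `IfaceMut` inside
    // that callback; a single `&mut self` borrow could not express this.
    fn inner_mut(&mut self) -> (&mut u32, &mut Vec<u32>) {
        (&mut *self.context, &mut *self.pending)
    }
}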
impl<E: Ext> PollableIfaceMut<'_, E> {
pub(super) fn pop_pending_tcp(&mut self) -> Option<Arc<TcpConnectionBg<E>>> {
let now = self.context.now.total_millis() as u64;
self.pending_conns.pop_tcp_before_now(now)
}
}
impl<E: Ext> PollableIfaceMut<'_, E> {
/// Returns an immutable reference to the `smoltcp` context.
pub(crate) fn context(&self) -> &smoltcp::iface::Context {
self.context
}
/// Returns the `smoltcp` context for passing to the `smoltcp` APIs.
pub(crate) fn context_mut(&mut self) -> &mut smoltcp::iface::Context {
self.context
}
/// Updates the next poll time of `socket` to `poll_at`.
///
/// This method (or [`PollableIface::update_next_poll_at_ms`]) should be called after network
/// or user events that change the poll time occur.
pub(crate) fn update_next_poll_at_ms(
&mut self,
socket: &Arc<TcpConnectionBg<E>>,
poll_at: smoltcp::socket::PollAt,
) -> NeedIfacePoll {
self.pending_conns.update_next_poll_at_ms(socket, poll_at)
}
}
/// A key to sort sockets by their next poll time.
pub(crate) struct PollKey {
next_poll_at_ms: AtomicU64,
id: usize,
}
impl PartialEq for PollKey {
fn eq(&self, other: &Self) -> bool {
self.next_poll_at_ms.load(Ordering::Relaxed)
== other.next_poll_at_ms.load(Ordering::Relaxed)
&& self.id == other.id
}
}
impl Eq for PollKey {}
impl PartialOrd for PollKey {
fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
Some(self.cmp(other))
}
}
impl Ord for PollKey {
fn cmp(&self, other: &Self) -> core::cmp::Ordering {
self.next_poll_at_ms
.load(Ordering::Relaxed)
.cmp(&other.next_poll_at_ms.load(Ordering::Relaxed))
.then_with(|| self.id.cmp(&other.id))
}
}
impl PollKey {
/// A value indicating that an immediate poll is required.
const IMMEDIATE_VAL: u64 = 0;
/// A value indicating that no poll is required.
const INACTIVE_VAL: u64 = u64::MAX;
/// Creates a new [`PollKey`].
///
/// `id` must be a unique identifier for the associated socket, as it will be used to locate
/// the socket to update its next poll time. This is usually done using the address of the
/// [`Arc`] socket (see [`Arc::as_ptr`]).
///
/// [`Arc`]: alloc::sync::Arc
/// [`Arc::as_ptr`]: alloc::sync::Arc::as_ptr
pub(crate) fn new(id: usize) -> Self {
Self {
next_poll_at_ms: AtomicU64::new(Self::INACTIVE_VAL),
id,
}
}
}
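
A sketch of the ordering contract (module-internal usage for illustration, not a test from this commit): the poll time is the primary key and `id` breaks ties, so two sockets due at the same millisecond remain distinct set elements. `Ordering::Relaxed` suffices because, presumably, the key is only read and written while the interface lock is held; the atomic merely permits updates through `&self`.

let a = PollKey::new(0x1000);
let b = PollKey::new(0x1008);
a.next_poll_at_ms.store(5, Ordering::Relaxed);
b.next_poll_at_ms.store(5, Ordering::Relaxed);
assert!(a < b);  // same deadline; the tie is broken by `id`
assert!(a != b); // keys compare equal only if both deadline and `id` match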
/// Sockets to poll in the future, sorted by poll time.
pub(crate) struct PendingConnSet<E: Ext>(BTreeSet<PendingTcpConn<E>>);
/// A TCP socket to poll in the future.
///
/// Note that currently only TCP sockets can set a timer to fire in the future, so a
/// [`PendingConnSet`] contains only [`PendingTcpConn`]s.
struct PendingTcpConn<E: Ext>(Arc<TcpConnectionBg<E>>);
impl<E: Ext> PartialEq for PendingTcpConn<E> {
fn eq(&self, other: &Self) -> bool {
self.0.poll_key() == other.0.poll_key()
}
}
impl<E: Ext> Eq for PendingTcpConn<E> {}
impl<E: Ext> PartialOrd for PendingTcpConn<E> {
fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
Some(self.cmp(other))
}
}
impl<E: Ext> Ord for PendingTcpConn<E> {
fn cmp(&self, other: &Self) -> core::cmp::Ordering {
self.0.poll_key().cmp(other.0.poll_key())
}
}
impl<E: Ext> Borrow<PollKey> for PendingTcpConn<E> {
fn borrow(&self) -> &PollKey {
self.0.poll_key()
}
}
impl<E: Ext> PendingConnSet<E> {
fn new() -> Self {
Self(BTreeSet::new())
}
fn update_next_poll_at_ms(
&mut self,
socket: &Arc<TcpConnectionBg<E>>,
poll_at: smoltcp::socket::PollAt,
) -> NeedIfacePoll {
let key = socket.poll_key();
let old_poll_at_ms = key.next_poll_at_ms.load(Ordering::Relaxed);
let new_poll_at_ms = match poll_at {
smoltcp::socket::PollAt::Now => PollKey::IMMEDIATE_VAL,
smoltcp::socket::PollAt::Time(instant) => instant.total_millis() as u64,
smoltcp::socket::PollAt::Ingress => PollKey::INACTIVE_VAL,
};
// Fast path: There is nothing to update.
if old_poll_at_ms == new_poll_at_ms {
return NeedIfacePoll::FALSE;
}
// Remove the socket from the pending queue if it is in the queue.
let owned_socket = if old_poll_at_ms != PollKey::INACTIVE_VAL {
self.0.take(key).unwrap()
} else {
PendingTcpConn(socket.clone())
};
// Update the poll time _after_ it is removed from the queue.
key.next_poll_at_ms.store(new_poll_at_ms, Ordering::Relaxed);
// If no new poll is required, do not add the socket to the pending queue.
if new_poll_at_ms == PollKey::INACTIVE_VAL {
return NeedIfacePoll::FALSE;
}
// Add the socket back to the queue.
let inserted = self.0.insert(owned_socket);
debug_assert!(inserted);
if new_poll_at_ms < old_poll_at_ms {
NeedIfacePoll::TRUE
} else {
NeedIfacePoll::FALSE
}
}
fn pop_tcp_before_now(&mut self, now_at_ms: u64) -> Option<Arc<TcpConnectionBg<E>>> {
if self.0.first().is_some_and(|first| {
first.0.poll_key().next_poll_at_ms.load(Ordering::Relaxed) <= now_at_ms
}) {
self.0.pop_first().map(|first| {
// Reset `next_poll_at_ms` since the socket is no longer in the queue.
first
.0
.poll_key()
.next_poll_at_ms
.store(PollKey::INACTIVE_VAL, Ordering::Relaxed);
first.0
})
} else {
None
}
}
fn next_poll_at_ms(&self) -> Option<u64> {
self.0
.first()
.map(|first| first.0.poll_key().next_poll_at_ms.load(Ordering::Relaxed))
}
}
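
The "update the poll time _after_ it is removed from the queue" step above is load-bearing: mutating the sort key of an element while it still sits inside a `BTreeSet` silently corrupts the set's ordering and breaks later lookups. A minimal sketch of the same take-update-reinsert protocol, with a plain `u64` in place of the atomic key:

use std::collections::BTreeSet;

#[derive(PartialEq, Eq, PartialOrd, Ord)]
struct Key(u64);

fn main() {
    let mut set = BTreeSet::from([Key(10), Key(20)]);

    // Take the entry out, change its key, then insert it back. Changing the
    // key in place while the entry is still in the set would violate the
    // set's invariants, which is exactly what the ordering of operations in
    // `update_next_poll_at_ms` avoids.
    let mut key = set.take(&Key(10)).unwrap();
    key.0 = 30;
    set.insert(key);

    assert_eq!(set.first().unwrap().0, 20);
}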

View File

@ -1,9 +1,9 @@
// SPDX-License-Identifier: MPL-2.0
use alloc::sync::Arc;
use core::sync::atomic::{AtomicU64, AtomicU8, Ordering};
use alloc::sync::{Arc, Weak};
use core::sync::atomic::{AtomicU8, Ordering};
use smoltcp::{socket::PollAt, time::Instant, wire::IpEndpoint};
use smoltcp::wire::IpEndpoint;
use spin::once::Once;
use takeable::Takeable;
@ -46,7 +46,6 @@ pub struct SocketBg<T: Inner<E>, E: Ext> {
pub(super) inner: T,
observer: Once<T::Observer>,
events: AtomicU8,
next_poll_at_ms: AtomicU64,
}
impl<T: Inner<E>, E: Ext> Drop for Socket<T, E> {
@ -58,13 +57,24 @@ impl<T: Inner<E>, E: Ext> Drop for Socket<T, E> {
}
impl<T: Inner<E>, E: Ext> Socket<T, E> {
pub(crate) fn new(bound: BoundPort<E>, inner: T) -> Self {
pub(super) fn new(bound: BoundPort<E>, inner: T) -> Self {
Self(Takeable::new(Arc::new(SocketBg {
bound,
inner,
observer: Once::new(),
events: AtomicU8::new(0),
next_poll_at_ms: AtomicU64::new(u64::MAX),
})))
}
pub(super) fn new_cyclic<F>(bound: BoundPort<E>, inner_fn: F) -> Self
where
F: FnOnce(&Weak<SocketBg<T, E>>) -> T,
{
Self(Takeable::new(Arc::new_cyclic(|weak| SocketBg {
bound,
inner: inner_fn(weak),
observer: Once::new(),
events: AtomicU8::new(0),
})))
}
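
`new_cyclic` builds on [`Arc::new_cyclic`]: the closure receives a `Weak` that already points at the allocation under construction, which is how `TcpConnectionInner::new` can derive a stable `PollKey` id from the socket's own address before the `Arc` exists. A self-contained sketch of the pattern (the crate uses `Weak::as_ptr(..).addr()`; the `as usize` cast below is an illustrative equivalent):

use std::sync::{Arc, Weak};

struct Node {
    id: usize, // stable identifier derived from this allocation's own address
}

fn main() {
    let node: Arc<Node> = Arc::new_cyclic(|weak: &Weak<Node>| Node {
        id: Weak::as_ptr(weak) as usize,
    });
    // The id equals the final `Arc`'s address, hence unique per live socket.
    assert_eq!(node.id, Arc::as_ptr(&node) as usize);
}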
@ -119,10 +129,8 @@ impl<T: Inner<E>, E: Ext> SocketBg<T, E> {
where
T::Observer: Clone,
{
// This method can only be called to process network events, so we assume we are holding the
// poll lock and no race conditions can occur.
// There is no need to clear the events because the socket is dead.
let events = self.events.load(Ordering::Relaxed);
self.events.store(0, Ordering::Relaxed);
let observer = self.observer.get().cloned();
drop(self);
@ -141,41 +149,6 @@ impl<T: Inner<E>, E: Ext> SocketBg<T, E> {
self.events
.store(events | new_events.bits(), Ordering::Relaxed);
}
/// Returns the next polling time.
///
/// Note: a zero means polling should be done now and a `u64::MAX` means no polling is required
/// before new network or user events.
pub(crate) fn next_poll_at_ms(&self) -> u64 {
self.next_poll_at_ms.load(Ordering::Relaxed)
}
/// Updates the next polling time according to `poll_at`.
///
/// The update is typically needed after new network or user events have been handled, so this
/// method also marks that there may be new events, allowing the event observer provided by
/// [`Socket::init_observer`] to be notified later.
pub(super) fn update_next_poll_at_ms(&self, poll_at: PollAt) -> NeedIfacePoll {
match poll_at {
PollAt::Now => {
self.next_poll_at_ms.store(0, Ordering::Relaxed);
NeedIfacePoll::TRUE
}
PollAt::Time(instant) => {
let old_total_millis = self.next_poll_at_ms.load(Ordering::Relaxed);
let new_total_millis = instant.total_millis() as u64;
self.next_poll_at_ms
.store(new_total_millis, Ordering::Relaxed);
NeedIfacePoll(new_total_millis < old_total_millis)
}
PollAt::Ingress => {
self.next_poll_at_ms.store(u64::MAX, Ordering::Relaxed);
NeedIfacePoll::FALSE
}
}
}
}
impl<T: Inner<E>, E: Ext> SocketBg<T, E> {
@ -185,11 +158,4 @@ impl<T: Inner<E>, E: Ext> SocketBg<T, E> {
pub(crate) fn can_process(&self, dst_port: u16) -> bool {
self.bound.port() == dst_port
}
/// Returns whether the socket _may_ generate an outgoing packet.
///
/// The check is intended to be lock-free and fast, but may have false positives.
pub(crate) fn need_dispatch(&self, now: Instant) -> bool {
now.total_millis() as u64 >= self.next_poll_at_ms.load(Ordering::Relaxed)
}
}

View File

@ -1,11 +1,13 @@
// SPDX-License-Identifier: MPL-2.0
use alloc::{boxed::Box, sync::Arc};
use alloc::{
boxed::Box,
sync::{Arc, Weak},
};
use core::ops::{Deref, DerefMut};
use ostd::sync::{LocalIrqDisabled, SpinLock, SpinLockGuard};
use smoltcp::{
iface::Context,
socket::{tcp::State, PollAt},
time::Duration,
wire::{IpEndpoint, IpRepr, TcpControl, TcpRepr},
@ -19,7 +21,7 @@ use crate::{
define_boolean_value,
errors::tcp::{ConnectError, RecvError, SendError},
ext::Ext,
iface::BoundPort,
iface::{BoundPort, PollKey, PollableIfaceMut},
socket::{
event::SocketEvents,
option::{RawTcpOption, RawTcpSetOption},
@ -33,6 +35,7 @@ pub type TcpConnection<E> = Socket<TcpConnectionInner<E>, E>;
/// States needed by [`TcpConnectionBg`].
pub struct TcpConnectionInner<E: Ext> {
socket: SpinLock<RawTcpSocketExt<E>, LocalIrqDisabled>,
poll_key: PollKey,
connection_key: ConnectionKey,
}
@ -216,7 +219,11 @@ impl<E: Ext> RawTcpSocketExt<E> {
}
impl<E: Ext> TcpConnectionInner<E> {
pub(super) fn new(socket: Box<RawTcpSocket>, listener: Option<Arc<TcpListenerBg<E>>>) -> Self {
pub(super) fn new(
socket: Box<RawTcpSocket>,
listener: Option<Arc<TcpListenerBg<E>>>,
weak_self: &Weak<TcpConnectionBg<E>>,
) -> Self {
let connection_key = {
// Since the socket is connected, the following unwrap can never fail
let local_endpoint = socket.local_endpoint().unwrap();
@ -224,6 +231,8 @@ impl<E: Ext> TcpConnectionInner<E> {
ConnectionKey::from((local_endpoint, remote_endpoint))
};
let poll_key = PollKey::new(Weak::as_ptr(weak_self).addr());
let socket_ext = RawTcpSocketExt {
socket,
listener,
@ -234,6 +243,7 @@ impl<E: Ext> TcpConnectionInner<E> {
TcpConnectionInner {
socket: SpinLock::new(socket_ext),
poll_key,
connection_key,
}
}
@ -293,7 +303,7 @@ impl<E: Ext> TcpConnection<E> {
};
let iface = bound.iface().clone();
// We have to lock interface before locking interface
// We have to lock `interface` before locking `sockets`
// to avoid dead lock due to inconsistent lock orders.
let mut interface = iface.common().interface();
let mut sockets = iface.common().sockets();
@ -309,18 +319,19 @@ impl<E: Ext> TcpConnection<E> {
option.apply(&mut socket);
if let Err(err) = socket.connect(interface.context(), remote_endpoint, bound.port()) {
if let Err(err) = socket.connect(interface.context_mut(), remote_endpoint, bound.port())
{
return Err((bound, err.into()));
}
socket
};
let inner = TcpConnectionInner::new(socket, None);
let connection = Self::new(bound, inner);
connection.0.update_next_poll_at_ms(PollAt::Now);
let connection =
Self::new_cyclic(bound, |weak| TcpConnectionInner::new(socket, None, weak));
interface.update_next_poll_at_ms(&connection.0, PollAt::Now);
connection.init_observer(observer);
let res = sockets.insert_connection(connection.inner().clone());
debug_assert!(res.is_ok());
@ -378,9 +389,8 @@ impl<E: Ext> TcpConnection<E> {
}
let result = socket.send(f)?;
let need_poll = self
.0
.update_next_poll_at_ms(socket.poll_at(iface.context()));
let poll_at = socket.poll_at(iface.context_mut());
let need_poll = iface.update_next_poll_at_ms(&self.0, poll_at);
Ok((result, need_poll))
}
@ -408,9 +418,8 @@ impl<E: Ext> TcpConnection<E> {
res => res,
}?;
let need_poll = self
.0
.update_next_poll_at_ms(socket.poll_at(iface.context()));
let poll_at = socket.poll_at(iface.context_mut());
let need_poll = iface.update_next_poll_at_ms(&self.0, poll_at);
Ok((result, need_poll))
}
@ -433,6 +442,7 @@ impl<E: Ext> TcpConnection<E> {
///
/// Polling the iface is _always_ required after this method succeeds.
pub fn shut_send(&self) -> bool {
let mut iface = self.iface().common().interface();
let mut socket = self.0.inner.lock();
if matches!(socket.state(), State::Closed | State::TimeWait) {
@ -440,7 +450,9 @@ impl<E: Ext> TcpConnection<E> {
}
socket.close();
self.0.update_next_poll_at_ms(PollAt::Now);
let poll_at = socket.poll_at(iface.context_mut());
iface.update_next_poll_at_ms(&self.0, poll_at);
true
}
@ -469,6 +481,7 @@ impl<E: Ext> TcpConnection<E> {
/// Note that either this method or [`Self::reset`] must be called before dropping the TCP
/// connection to avoid resource leakage.
pub fn close(&self) {
let mut iface = self.iface().common().interface();
let mut socket = self.0.inner.lock();
socket.is_recv_shut = true;
@ -479,7 +492,9 @@ impl<E: Ext> TcpConnection<E> {
} else {
socket.close();
}
self.0.update_next_poll_at_ms(PollAt::Now);
let poll_at = socket.poll_at(iface.context_mut());
iface.update_next_poll_at_ms(&self.0, poll_at);
}
/// Resets the connection.
@ -489,10 +504,13 @@ impl<E: Ext> TcpConnection<E> {
/// Note that either this method or [`Self::close`] must be called before dropping the TCP
/// connection to avoid resource leakage.
pub fn reset(&self) {
let mut iface = self.iface().common().interface();
let mut socket = self.0.inner.lock();
socket.abort();
self.0.update_next_poll_at_ms(PollAt::Now);
let poll_at = socket.poll_at(iface.context_mut());
iface.update_next_poll_at_ms(&self.0, poll_at);
}
/// Calls `f` with an immutable reference to the associated [`RawTcpSocket`].
@ -510,15 +528,13 @@ impl<E: Ext> TcpConnection<E> {
impl<E: Ext> RawTcpSetOption for TcpConnection<E> {
fn set_keep_alive(&self, interval: Option<Duration>) -> NeedIfacePoll {
let mut iface = self.iface().common().interface();
let mut socket = self.0.inner.lock();
socket.set_keep_alive(interval);
if interval.is_some() {
self.0.update_next_poll_at_ms(PollAt::Now);
NeedIfacePoll::TRUE
} else {
NeedIfacePoll::FALSE
}
let poll_at = socket.poll_at(iface.context_mut());
iface.update_next_poll_at_ms(&self.0, poll_at)
}
fn set_nagle_enabled(&self, enabled: bool) {
@ -528,6 +544,10 @@ impl<E: Ext> RawTcpSetOption for TcpConnection<E> {
}
impl<E: Ext> TcpConnectionBg<E> {
pub(crate) const fn poll_key(&self) -> &PollKey {
&self.inner.poll_key
}
pub(crate) const fn connection_key(&self) -> &ConnectionKey {
&self.inner.connection_key
}
@ -544,13 +564,13 @@ impl<E: Ext> TcpConnectionBg<E> {
/// Tries to process an incoming packet and returns whether the packet is processed.
pub(crate) fn process(
self: &Arc<Self>,
cx: &mut Context,
iface: &mut PollableIfaceMut<E>,
ip_repr: &IpRepr,
tcp_repr: &TcpRepr,
) -> (TcpProcessResult, TcpConnBecameDead) {
let mut socket = self.inner.lock();
if !socket.accepts(cx, ip_repr, tcp_repr) {
if !socket.accepts(iface.context_mut(), ip_repr, tcp_repr) {
return (TcpProcessResult::NotProcessed, TcpConnBecameDead::FALSE);
}
@ -581,7 +601,7 @@ impl<E: Ext> TcpConnectionBg<E> {
// to be queued.
let mut events = SocketEvents::CAN_RECV | SocketEvents::CAN_SEND;
let result = match socket.process(cx, ip_repr, tcp_repr) {
let result = match socket.process(iface.context_mut(), ip_repr, tcp_repr) {
None => TcpProcessResult::Processed,
Some((ip_repr, tcp_repr)) => TcpProcessResult::ProcessedWithReply(ip_repr, tcp_repr),
};
@ -591,7 +611,9 @@ impl<E: Ext> TcpConnectionBg<E> {
events |= state_events;
self.add_events(events);
self.update_next_poll_at_ms(socket.poll_at(cx));
let poll_at = socket.poll_at(iface.context_mut());
iface.update_next_poll_at_ms(self, poll_at);
(result, became_dead)
}
@ -599,11 +621,11 @@ impl<E: Ext> TcpConnectionBg<E> {
/// Tries to generate an outgoing packet and dispatches the generated packet.
pub(crate) fn dispatch<D>(
self: &Arc<Self>,
cx: &mut Context,
iface: &mut PollableIfaceMut<E>,
dispatch: D,
) -> (Option<(IpRepr, TcpRepr<'static>)>, TcpConnBecameDead)
where
D: FnOnce(&mut Context, &IpRepr, &TcpRepr) -> Option<(IpRepr, TcpRepr<'static>)>,
D: FnOnce(PollableIfaceMut<E>, &IpRepr, &TcpRepr) -> Option<(IpRepr, TcpRepr<'static>)>,
{
let mut socket = self.inner.lock();
@ -613,9 +635,10 @@ impl<E: Ext> TcpConnectionBg<E> {
let mut events = SocketEvents::empty();
let mut reply = None;
let (cx, pending) = iface.inner_mut();
socket
.dispatch(cx, |cx, (ip_repr, tcp_repr)| {
reply = dispatch(cx, &ip_repr, &tcp_repr);
reply = dispatch(PollableIfaceMut::new(cx, pending), &ip_repr, &tcp_repr);
Ok::<(), ()>(())
})
.unwrap();
@ -623,12 +646,12 @@ impl<E: Ext> TcpConnectionBg<E> {
// `dispatch` can return a packet in response to the generated packet. If the socket
// accepts the packet, we can process it directly.
while let Some((ref ip_repr, ref tcp_repr)) = reply {
if !socket.accepts(cx, ip_repr, tcp_repr) {
if !socket.accepts(iface.context_mut(), ip_repr, tcp_repr) {
break;
}
is_rst |= tcp_repr.control == TcpControl::Rst;
events |= SocketEvents::CAN_RECV | SocketEvents::CAN_SEND;
reply = socket.process(cx, ip_repr, tcp_repr);
reply = socket.process(iface.context_mut(), ip_repr, tcp_repr);
}
let (state_events, became_dead) =
@ -636,7 +659,9 @@ impl<E: Ext> TcpConnectionBg<E> {
events |= state_events;
self.add_events(events);
self.update_next_poll_at_ms(socket.poll_at(cx));
let poll_at = socket.poll_at(iface.context_mut());
iface.update_next_poll_at_ms(self, poll_at);
(reply, became_dead)
}

View File

@ -4,7 +4,6 @@ use alloc::{boxed::Box, collections::btree_map::BTreeMap, sync::Arc, vec::Vec};
use ostd::sync::{LocalIrqDisabled, SpinLock};
use smoltcp::{
iface::Context,
socket::PollAt,
time::Duration,
wire::{IpEndpoint, IpRepr, TcpRepr},
@ -17,7 +16,7 @@ use super::{
use crate::{
errors::tcp::ListenError,
ext::Ext,
iface::{BindPortConfig, BoundPort},
iface::{BindPortConfig, BoundPort, PollableIfaceMut},
socket::{
option::{RawTcpOption, RawTcpSetOption},
unbound::{new_tcp_socket, RawTcpSocket},
@ -194,13 +193,16 @@ impl<E: Ext> TcpListenerBg<E> {
/// Tries to process an incoming packet and returns whether the packet is processed.
pub(crate) fn process(
self: &Arc<Self>,
cx: &mut Context,
iface: &mut PollableIfaceMut<E>,
ip_repr: &IpRepr,
tcp_repr: &TcpRepr,
) -> (TcpProcessResult, Option<Arc<TcpConnectionBg<E>>>) {
let mut backlog = self.inner.backlog.lock();
if !backlog.socket.accepts(cx, ip_repr, tcp_repr) {
if !backlog
.socket
.accepts(iface.context_mut(), ip_repr, tcp_repr)
{
return (TcpProcessResult::NotProcessed, None);
}
@ -211,7 +213,10 @@ impl<E: Ext> TcpListenerBg<E> {
return (TcpProcessResult::Processed, None);
}
let result = match backlog.socket.process(cx, ip_repr, tcp_repr) {
let result = match backlog
.socket
.process(iface.context_mut(), ip_repr, tcp_repr)
{
None => TcpProcessResult::Processed,
Some((ip_repr, tcp_repr)) => TcpProcessResult::ProcessedWithReply(ip_repr, tcp_repr),
};
@ -227,23 +232,25 @@ impl<E: Ext> TcpListenerBg<E> {
socket
};
let inner = TcpConnectionInner::new(
core::mem::replace(&mut backlog.socket, new_socket),
Some(self.clone()),
);
let conn = TcpConnection::new(
let conn = TcpConnection::new_cyclic(
self.bound
.iface()
.bind(BindPortConfig::CanReuse(self.bound.port()))
.unwrap(),
inner,
|weak| {
TcpConnectionInner::new(
core::mem::replace(&mut backlog.socket, new_socket),
Some(self.clone()),
weak,
)
},
);
let conn_bg = conn.inner().clone();
let old_conn = backlog.connecting.insert(*conn_bg.connection_key(), conn);
debug_assert!(old_conn.is_none());
conn_bg.update_next_poll_at_ms(PollAt::Now);
iface.update_next_poll_at_ms(&conn_bg, PollAt::Now);
(result, Some(conn_bg))
}

View File

@ -1,11 +1,12 @@
// SPDX-License-Identifier: MPL-2.0
use alloc::{boxed::Box, sync::Arc};
use core::sync::atomic::{AtomicBool, Ordering};
use ostd::sync::{LocalIrqDisabled, SpinLock};
use smoltcp::{
iface::Context,
socket::{udp::UdpMetadata, PollAt},
socket::udp::UdpMetadata,
wire::{IpRepr, UdpRepr},
};
@ -20,13 +21,16 @@ use crate::{
pub type UdpSocket<E> = Socket<UdpSocketInner, E>;
/// States needed by [`UdpSocketBg`].
type UdpSocketInner = SpinLock<Box<RawUdpSocket>, LocalIrqDisabled>;
pub struct UdpSocketInner {
socket: SpinLock<Box<RawUdpSocket>, LocalIrqDisabled>,
need_dispatch: AtomicBool,
}
impl<E: Ext> Inner<E> for UdpSocketInner {
type Observer = E::UdpEventObserver;
fn on_drop(this: &Arc<SocketBg<Self, E>>) {
this.inner.lock().close();
this.inner.socket.lock().close();
// A UDP socket can be removed immediately.
this.bound.iface().common().remove_udp_socket(this);
@ -44,7 +48,7 @@ impl<E: Ext> UdpSocketBg<E> {
udp_repr: &UdpRepr,
udp_payload: &[u8],
) -> bool {
let mut socket = self.inner.lock();
let mut socket = self.inner.socket.lock();
if !socket.accepts(cx, ip_repr, udp_repr) {
return false;
@ -59,7 +63,6 @@ impl<E: Ext> UdpSocketBg<E> {
);
self.add_events(SocketEvents::CAN_RECV);
self.update_next_poll_at_ms(socket.poll_at(cx));
true
}
@ -69,7 +72,7 @@ impl<E: Ext> UdpSocketBg<E> {
where
D: FnOnce(&mut Context, &IpRepr, &UdpRepr, &[u8]),
{
let mut socket = self.inner.lock();
let mut socket = self.inner.socket.lock();
socket
.dispatch(cx, |cx, _meta, (ip_repr, udp_repr, udp_payload)| {
@ -80,7 +83,17 @@ impl<E: Ext> UdpSocketBg<E> {
// For UDP, dequeuing a packet means that we can queue more packets.
self.add_events(SocketEvents::CAN_SEND);
self.update_next_poll_at_ms(socket.poll_at(cx));
self.inner
.need_dispatch
.store(socket.send_queue() > 0, Ordering::Relaxed);
}
/// Returns whether the socket _may_ generate an outgoing packet.
///
/// The check is intended to be lock-free and fast, but may have false positives.
pub(crate) fn need_dispatch(&self) -> bool {
self.inner.need_dispatch.load(Ordering::Relaxed)
}
}
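
Since UDP sockets no longer track a per-socket `next_poll_at_ms`, a relaxed `AtomicBool` is enough: it is written while the socket lock is held, and a stale `true` merely costs one extra lock acquisition in the dispatch loop, while `false` is stored only when the send queue is empty, so no pending packet is ever missed. A hedged sketch of the flag protocol in isolation:

use std::sync::atomic::{AtomicBool, Ordering};

struct DispatchFlag(AtomicBool);

impl DispatchFlag {
    // Called with the socket lock held, after enqueueing or dequeueing.
    fn update(&self, queued_packets: usize) {
        self.0.store(queued_packets > 0, Ordering::Relaxed);
    }

    // Lock-free fast path: may report a false positive (harmless), but never
    // a false negative for packets enqueued before the last `update` call.
    fn need_dispatch(&self) -> bool {
        self.0.load(Ordering::Relaxed)
    }
}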
@ -106,7 +119,10 @@ impl<E: Ext> UdpSocket<E> {
socket
};
let inner = UdpSocketInner::new(socket);
let inner = UdpSocketInner {
socket: SpinLock::new(socket),
need_dispatch: AtomicBool::new(false),
};
let socket = Self::new(bound, inner);
socket.init_observer(observer);
@ -130,7 +146,7 @@ impl<E: Ext> UdpSocket<E> {
where
F: FnOnce(&mut [u8]) -> R,
{
let mut socket = self.0.inner.lock();
let mut socket = self.0.inner.socket.lock();
if size > socket.packet_send_capacity() {
return Err(SendError::TooLarge);
@ -141,7 +157,11 @@ impl<E: Ext> UdpSocket<E> {
Err(err) => return Err(err.into()),
};
let result = f(buffer);
self.0.update_next_poll_at_ms(PollAt::Now);
self.0
.inner
.need_dispatch
.store(socket.send_queue() > 0, Ordering::Relaxed);
Ok(result)
}
@ -153,7 +173,7 @@ impl<E: Ext> UdpSocket<E> {
where
F: FnOnce(&[u8], UdpMetadata) -> R,
{
let mut socket = self.0.inner.lock();
let mut socket = self.0.inner.socket.lock();
let (data, meta) = socket.recv()?;
let result = f(data, meta);
@ -169,7 +189,7 @@ impl<E: Ext> UdpSocket<E> {
where
F: FnOnce(&RawUdpSocket) -> R,
{
let socket = self.0.inner.lock();
let socket = self.0.inner.socket.lock();
f(&socket)
}
}