Avoid O(n) iteration when sending TCP packets

Ruihan Li
2025-03-18 09:34:15 +08:00
committed by Tate, Hongliang Tian
parent 58ad43b0a9
commit a7e718e812
8 changed files with 476 additions and 185 deletions
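
The problem, in sketch form: previously every socket carried its own `next_poll_at_ms` atomic, and each iface poll had to scan all sockets through `need_dispatch()` to find the ones that might emit a packet, an O(n) pass per poll. This commit centralizes the bookkeeping behind `PollKey` and `PollableIfaceMut`, so scheduling and finding the next due socket become keyed operations. The internals of `PollableIfaceMut` are not part of this view; the model below is a hypothetical illustration (`PendingQueue`, `update`, and `next_due` are invented names), not the actual implementation.

```rust
use std::collections::BTreeSet;

// Hypothetical model only: one ordered set of (deadline_ms, key) pairs for
// the whole iface, instead of one `next_poll_at_ms` atomic per socket.
type PollKey = usize; // the real PollKey wraps a stable pointer address

struct PendingQueue {
    deadlines: BTreeSet<(u64, PollKey)>,
}

impl PendingQueue {
    /// Re-keys one socket's deadline in O(log n), instead of leaving the
    /// deadline on the socket and scanning every socket at dispatch time.
    fn update(&mut self, key: PollKey, old_ms: Option<u64>, new_ms: Option<u64>) {
        if let Some(old) = old_ms {
            self.deadlines.remove(&(old, key));
        }
        if let Some(new) = new_ms {
            self.deadlines.insert((new, key));
        }
    }

    /// Returns the next socket that may emit a packet, if its deadline passed.
    fn next_due(&self, now_ms: u64) -> Option<PollKey> {
        self.deadlines
            .first()
            .filter(|(ms, _)| *ms <= now_ms)
            .map(|(_, key)| *key)
    }
}

fn main() {
    let mut queue = PendingQueue { deadlines: BTreeSet::new() };
    queue.update(0x1000, None, Some(20)); // a socket became ready to send
    queue.update(0x2000, None, Some(90));
    assert_eq!(queue.next_due(50), Some(0x1000)); // found without a full scan
}
```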

View File

@@ -1,9 +1,9 @@
 // SPDX-License-Identifier: MPL-2.0

-use alloc::sync::Arc;
-use core::sync::atomic::{AtomicU64, AtomicU8, Ordering};
+use alloc::sync::{Arc, Weak};
+use core::sync::atomic::{AtomicU8, Ordering};

-use smoltcp::{socket::PollAt, time::Instant, wire::IpEndpoint};
+use smoltcp::wire::IpEndpoint;
 use spin::once::Once;
 use takeable::Takeable;
@@ -46,7 +46,6 @@ pub struct SocketBg<T: Inner<E>, E: Ext> {
     pub(super) inner: T,
     observer: Once<T::Observer>,
     events: AtomicU8,
-    next_poll_at_ms: AtomicU64,
 }

 impl<T: Inner<E>, E: Ext> Drop for Socket<T, E> {
@@ -58,13 +57,24 @@ impl<T: Inner<E>, E: Ext> Drop for Socket<T, E> {
 }

 impl<T: Inner<E>, E: Ext> Socket<T, E> {
-    pub(crate) fn new(bound: BoundPort<E>, inner: T) -> Self {
+    pub(super) fn new(bound: BoundPort<E>, inner: T) -> Self {
         Self(Takeable::new(Arc::new(SocketBg {
             bound,
             inner,
             observer: Once::new(),
             events: AtomicU8::new(0),
-            next_poll_at_ms: AtomicU64::new(u64::MAX),
         })))
     }
+
+    pub(super) fn new_cyclic<F>(bound: BoundPort<E>, inner_fn: F) -> Self
+    where
+        F: FnOnce(&Weak<SocketBg<T, E>>) -> T,
+    {
+        Self(Takeable::new(Arc::new_cyclic(|weak| SocketBg {
+            bound,
+            inner: inner_fn(weak),
+            observer: Once::new(),
+            events: AtomicU8::new(0),
+        })))
+    }
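
The added `new_cyclic` is a thin wrapper over `Arc::new_cyclic`: the closure receives a `Weak` handle to the allocation that is still being constructed, so the inner state can remember its own identity (used later in this commit to derive a stable `PollKey`). A minimal standalone illustration of the standard-library pattern:

```rust
use std::sync::{Arc, Weak};

struct Node {
    // A handle back to the very allocation this Node lives in.
    self_ref: Weak<Node>,
}

fn main() {
    // `Arc::new_cyclic` passes the closure a Weak pointer to the
    // not-yet-initialized allocation, letting the value refer to itself.
    let node = Arc::new_cyclic(|weak: &Weak<Node>| Node {
        self_ref: weak.clone(),
    });

    // The stored weak handle upgrades to the same allocation.
    assert!(Arc::ptr_eq(&node, &node.self_ref.upgrade().unwrap()));
}
```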
@@ -119,10 +129,8 @@ impl<T: Inner<E>, E: Ext> SocketBg<T, E> {
     where
         T::Observer: Clone,
     {
-        // This method can only be called to process network events, so we assume we are holding the
-        // poll lock and no race conditions can occur.
+        // There is no need to clear the events because the socket is dead.
         let events = self.events.load(Ordering::Relaxed);
-        self.events.store(0, Ordering::Relaxed);

         let observer = self.observer.get().cloned();
         drop(self);
@@ -141,41 +149,6 @@ impl<T: Inner<E>, E: Ext> SocketBg<T, E> {
         self.events
             .store(events | new_events.bits(), Ordering::Relaxed);
     }
-
-    /// Returns the next polling time.
-    ///
-    /// Note: a zero means polling should be done now and a `u64::MAX` means no polling is required
-    /// before new network or user events.
-    pub(crate) fn next_poll_at_ms(&self) -> u64 {
-        self.next_poll_at_ms.load(Ordering::Relaxed)
-    }
-
-    /// Updates the next polling time according to `poll_at`.
-    ///
-    /// The update is typically needed after new network or user events have been handled, so this
-    /// method also marks that there may be new events, so that the event observer provided by
-    /// [`Socket::init_observer`] can be notified later.
-    pub(super) fn update_next_poll_at_ms(&self, poll_at: PollAt) -> NeedIfacePoll {
-        match poll_at {
-            PollAt::Now => {
-                self.next_poll_at_ms.store(0, Ordering::Relaxed);
-                NeedIfacePoll::TRUE
-            }
-            PollAt::Time(instant) => {
-                let old_total_millis = self.next_poll_at_ms.load(Ordering::Relaxed);
-                let new_total_millis = instant.total_millis() as u64;
-                self.next_poll_at_ms
-                    .store(new_total_millis, Ordering::Relaxed);
-                NeedIfacePoll(new_total_millis < old_total_millis)
-            }
-            PollAt::Ingress => {
-                self.next_poll_at_ms.store(u64::MAX, Ordering::Relaxed);
-                NeedIfacePoll::FALSE
-            }
-        }
-    }
 }

 impl<T: Inner<E>, E: Ext> SocketBg<T, E> {
@@ -185,11 +158,4 @@ impl<T: Inner<E>, E: Ext> SocketBg<T, E> {
     pub(crate) fn can_process(&self, dst_port: u16) -> bool {
         self.bound.port() == dst_port
     }
-
-    /// Returns whether the socket _may_ generate an outgoing packet.
-    ///
-    /// The check is intended to be lock-free and fast, but may have false positives.
-    pub(crate) fn need_dispatch(&self, now: Instant) -> bool {
-        now.total_millis() as u64 >= self.next_poll_at_ms.load(Ordering::Relaxed)
-    }
 }
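
For contrast, the `need_dispatch` method deleted above existed to support a per-poll pass of roughly the following shape (reconstructed for illustration; the real loop lives in the iface code, which this view does not include):

```rust
// `Socket` is a stand-in for the crate's socket type, not the real one.
#[derive(Clone)]
struct Socket {
    next_poll_at_ms: u64, // u64::MAX meant "no polling scheduled"
}

fn dispatch_pass(sockets: &[Socket], now_ms: u64) -> usize {
    let mut dispatched = 0;
    // O(n) in the number of sockets on every poll, even when at most one
    // socket is actually due; this is the cost the commit removes.
    for socket in sockets {
        if now_ms >= socket.next_poll_at_ms {
            // ... generate and send packets for this socket ...
            dispatched += 1;
        }
    }
    dispatched
}

fn main() {
    let mut sockets = vec![Socket { next_poll_at_ms: u64::MAX }; 3];
    sockets[1].next_poll_at_ms = 50; // only one socket is due
    assert_eq!(dispatch_pass(&sockets, 100), 1);
}
```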

View File

@@ -1,11 +1,13 @@
 // SPDX-License-Identifier: MPL-2.0

-use alloc::{boxed::Box, sync::Arc};
+use alloc::{
+    boxed::Box,
+    sync::{Arc, Weak},
+};
 use core::ops::{Deref, DerefMut};

 use ostd::sync::{LocalIrqDisabled, SpinLock, SpinLockGuard};
 use smoltcp::{
-    iface::Context,
     socket::{tcp::State, PollAt},
     time::Duration,
     wire::{IpEndpoint, IpRepr, TcpControl, TcpRepr},
@@ -19,7 +21,7 @@ use crate::{
     define_boolean_value,
     errors::tcp::{ConnectError, RecvError, SendError},
     ext::Ext,
-    iface::BoundPort,
+    iface::{BoundPort, PollKey, PollableIfaceMut},
     socket::{
         event::SocketEvents,
         option::{RawTcpOption, RawTcpSetOption},
@@ -33,6 +35,7 @@ pub type TcpConnection<E> = Socket<TcpConnectionInner<E>, E>;
 /// States needed by [`TcpConnectionBg`].
 pub struct TcpConnectionInner<E: Ext> {
     socket: SpinLock<RawTcpSocketExt<E>, LocalIrqDisabled>,
+    poll_key: PollKey,
     connection_key: ConnectionKey,
 }
@@ -216,7 +219,11 @@ impl<E: Ext> RawTcpSocketExt<E> {
 }

 impl<E: Ext> TcpConnectionInner<E> {
-    pub(super) fn new(socket: Box<RawTcpSocket>, listener: Option<Arc<TcpListenerBg<E>>>) -> Self {
+    pub(super) fn new(
+        socket: Box<RawTcpSocket>,
+        listener: Option<Arc<TcpListenerBg<E>>>,
+        weak_self: &Weak<TcpConnectionBg<E>>,
+    ) -> Self {
         let connection_key = {
             // Since the socket is connected, the following unwrap can never fail
             let local_endpoint = socket.local_endpoint().unwrap();
@@ -224,6 +231,8 @@ impl<E: Ext> TcpConnectionInner<E> {
             ConnectionKey::from((local_endpoint, remote_endpoint))
         };

+        let poll_key = PollKey::new(Weak::as_ptr(weak_self).addr());
+
         let socket_ext = RawTcpSocketExt {
             socket,
             listener,
@@ -234,6 +243,7 @@ impl<E: Ext> TcpConnectionInner<E> {
         TcpConnectionInner {
             socket: SpinLock::new(socket_ext),
+            poll_key,
             connection_key,
         }
     }
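
`PollKey::new(Weak::as_ptr(weak_self).addr())` derives the poll key from the connection's own allocation address: the address stays fixed for as long as the `Arc` allocation lives and is unique among live sockets, so it works as an ordered-map key with no extra allocation. A small demonstration of that property (assuming a toolchain where the strict-provenance `addr()` method is stable, i.e. Rust 1.84+):

```rust
use std::sync::{Arc, Weak};

fn main() {
    let strong: Arc<u32> = Arc::new(1);
    let weak: Weak<u32> = Arc::downgrade(&strong);

    // `Weak::as_ptr` points at the allocation, and `.addr()` strips it down
    // to a plain usize: a stable, unique identity while the socket is alive.
    let key = Weak::as_ptr(&weak).addr();
    assert_eq!(key, Arc::as_ptr(&strong).addr());
}
```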
@@ -293,7 +303,7 @@
         };

         let iface = bound.iface().clone();
-        // We have to lock interface before locking interface
+        // We have to lock `interface` before locking `sockets`
         // to avoid dead lock due to inconsistent lock orders.
         let mut interface = iface.common().interface();
         let mut sockets = iface.common().sockets();
@@ -309,18 +319,19 @@
             option.apply(&mut socket);

-            if let Err(err) = socket.connect(interface.context(), remote_endpoint, bound.port()) {
+            if let Err(err) = socket.connect(interface.context_mut(), remote_endpoint, bound.port())
+            {
                 return Err((bound, err.into()));
             }

             socket
         };

-        let inner = TcpConnectionInner::new(socket, None);
-        let connection = Self::new(bound, inner);
-        connection.0.update_next_poll_at_ms(PollAt::Now);
+        let connection =
+            Self::new_cyclic(bound, |weak| TcpConnectionInner::new(socket, None, weak));
+        interface.update_next_poll_at_ms(&connection.0, PollAt::Now);

         connection.init_observer(observer);
         let res = sockets.insert_connection(connection.inner().clone());
         debug_assert!(res.is_ok());
@@ -378,9 +389,8 @@
         }

         let result = socket.send(f)?;
-        let need_poll = self
-            .0
-            .update_next_poll_at_ms(socket.poll_at(iface.context()));
+        let poll_at = socket.poll_at(iface.context_mut());
+        let need_poll = iface.update_next_poll_at_ms(&self.0, poll_at);

         Ok((result, need_poll))
     }
@@ -408,9 +418,8 @@
             res => res,
         }?;

-        let need_poll = self
-            .0
-            .update_next_poll_at_ms(socket.poll_at(iface.context()));
+        let poll_at = socket.poll_at(iface.context_mut());
+        let need_poll = iface.update_next_poll_at_ms(&self.0, poll_at);

         Ok((result, need_poll))
     }
@@ -433,6 +442,7 @@
     ///
     /// Polling the iface is _always_ required after this method succeeds.
     pub fn shut_send(&self) -> bool {
+        let mut iface = self.iface().common().interface();
         let mut socket = self.0.inner.lock();

         if matches!(socket.state(), State::Closed | State::TimeWait) {
@@ -440,7 +450,9 @@
         }

         socket.close();
-        self.0.update_next_poll_at_ms(PollAt::Now);
+
+        let poll_at = socket.poll_at(iface.context_mut());
+        iface.update_next_poll_at_ms(&self.0, poll_at);

         true
     }
@@ -469,6 +481,7 @@
     /// Note that either this method or [`Self::reset`] must be called before dropping the TCP
     /// connection to avoid resource leakage.
     pub fn close(&self) {
+        let mut iface = self.iface().common().interface();
         let mut socket = self.0.inner.lock();

         socket.is_recv_shut = true;
@@ -479,7 +492,9 @@
         } else {
             socket.close();
         }
-        self.0.update_next_poll_at_ms(PollAt::Now);
+
+        let poll_at = socket.poll_at(iface.context_mut());
+        iface.update_next_poll_at_ms(&self.0, poll_at);
     }

     /// Resets the connection.
@@ -489,10 +504,13 @@
     /// Note that either this method or [`Self::close`] must be called before dropping the TCP
     /// connection to avoid resource leakage.
     pub fn reset(&self) {
+        let mut iface = self.iface().common().interface();
         let mut socket = self.0.inner.lock();

         socket.abort();
-        self.0.update_next_poll_at_ms(PollAt::Now);
+
+        let poll_at = socket.poll_at(iface.context_mut());
+        iface.update_next_poll_at_ms(&self.0, poll_at);
     }

     /// Calls `f` with an immutable reference to the associated [`RawTcpSocket`].
@@ -510,15 +528,13 @@
 impl<E: Ext> RawTcpSetOption for TcpConnection<E> {
     fn set_keep_alive(&self, interval: Option<Duration>) -> NeedIfacePoll {
+        let mut iface = self.iface().common().interface();
         let mut socket = self.0.inner.lock();

         socket.set_keep_alive(interval);

-        if interval.is_some() {
-            self.0.update_next_poll_at_ms(PollAt::Now);
-            NeedIfacePoll::TRUE
-        } else {
-            NeedIfacePoll::FALSE
-        }
+        let poll_at = socket.poll_at(iface.context_mut());
+        iface.update_next_poll_at_ms(&self.0, poll_at)
     }

     fn set_nagle_enabled(&self, enabled: bool) {
@@ -528,6 +544,10 @@ impl<E: Ext> RawTcpSetOption for TcpConnection<E> {
 }

 impl<E: Ext> TcpConnectionBg<E> {
+    pub(crate) const fn poll_key(&self) -> &PollKey {
+        &self.inner.poll_key
+    }
+
     pub(crate) const fn connection_key(&self) -> &ConnectionKey {
         &self.inner.connection_key
     }
@@ -544,13 +564,13 @@ impl<E: Ext> TcpConnectionBg<E> {
     /// Tries to process an incoming packet and returns whether the packet is processed.
     pub(crate) fn process(
         self: &Arc<Self>,
-        cx: &mut Context,
+        iface: &mut PollableIfaceMut<E>,
         ip_repr: &IpRepr,
         tcp_repr: &TcpRepr,
     ) -> (TcpProcessResult, TcpConnBecameDead) {
         let mut socket = self.inner.lock();

-        if !socket.accepts(cx, ip_repr, tcp_repr) {
+        if !socket.accepts(iface.context_mut(), ip_repr, tcp_repr) {
             return (TcpProcessResult::NotProcessed, TcpConnBecameDead::FALSE);
         }
@@ -581,7 +601,7 @@
         // to be queued.
         let mut events = SocketEvents::CAN_RECV | SocketEvents::CAN_SEND;

-        let result = match socket.process(cx, ip_repr, tcp_repr) {
+        let result = match socket.process(iface.context_mut(), ip_repr, tcp_repr) {
             None => TcpProcessResult::Processed,
             Some((ip_repr, tcp_repr)) => TcpProcessResult::ProcessedWithReply(ip_repr, tcp_repr),
         };
@@ -591,7 +611,9 @@
         events |= state_events;
         self.add_events(events);
-        self.update_next_poll_at_ms(socket.poll_at(cx));
+
+        let poll_at = socket.poll_at(iface.context_mut());
+        iface.update_next_poll_at_ms(self, poll_at);

         (result, became_dead)
     }
@@ -599,11 +621,11 @@
     /// Tries to generate an outgoing packet and dispatches the generated packet.
     pub(crate) fn dispatch<D>(
         self: &Arc<Self>,
-        cx: &mut Context,
+        iface: &mut PollableIfaceMut<E>,
         dispatch: D,
     ) -> (Option<(IpRepr, TcpRepr<'static>)>, TcpConnBecameDead)
     where
-        D: FnOnce(&mut Context, &IpRepr, &TcpRepr) -> Option<(IpRepr, TcpRepr<'static>)>,
+        D: FnOnce(PollableIfaceMut<E>, &IpRepr, &TcpRepr) -> Option<(IpRepr, TcpRepr<'static>)>,
     {
         let mut socket = self.inner.lock();
@@ -613,9 +635,10 @@
         let mut events = SocketEvents::empty();
         let mut reply = None;

+        let (cx, pending) = iface.inner_mut();
         socket
             .dispatch(cx, |cx, (ip_repr, tcp_repr)| {
-                reply = dispatch(cx, &ip_repr, &tcp_repr);
+                reply = dispatch(PollableIfaceMut::new(cx, pending), &ip_repr, &tcp_repr);
                 Ok::<(), ()>(())
             })
             .unwrap();
@@ -623,12 +646,12 @@
         // `dispatch` can return a packet in response to the generated packet. If the socket
         // accepts the packet, we can process it directly.
         while let Some((ref ip_repr, ref tcp_repr)) = reply {
-            if !socket.accepts(cx, ip_repr, tcp_repr) {
+            if !socket.accepts(iface.context_mut(), ip_repr, tcp_repr) {
                 break;
             }
             is_rst |= tcp_repr.control == TcpControl::Rst;
             events |= SocketEvents::CAN_RECV | SocketEvents::CAN_SEND;
-            reply = socket.process(cx, ip_repr, tcp_repr);
+            reply = socket.process(iface.context_mut(), ip_repr, tcp_repr);
         }

         let (state_events, became_dead) =
@@ -636,7 +659,9 @@
         events |= state_events;
         self.add_events(events);
-        self.update_next_poll_at_ms(socket.poll_at(cx));
+
+        let poll_at = socket.poll_at(iface.context_mut());
+        iface.update_next_poll_at_ms(self, poll_at);

         (reply, became_dead)
     }
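
The `let (cx, pending) = iface.inner_mut();` dance in `dispatch` above is a split borrow: smoltcp's `dispatch` needs the `Context` alone, while the callback still needs the pending-poll state, so the wrapper is taken apart and reassembled inside the closure with `PollableIfaceMut::new(cx, pending)`. A standalone sketch of the pattern, with stand-in types in place of the real ones:

```rust
// Stand-ins for smoltcp's Context and the iface's pending-poll state; only
// the borrow-splitting pattern is shown here.
struct Context;
struct PendingPolls;

struct PollableIfaceMut<'a> {
    cx: &'a mut Context,
    pending: &'a mut PendingPolls,
}

impl<'a> PollableIfaceMut<'a> {
    fn new(cx: &'a mut Context, pending: &'a mut PendingPolls) -> Self {
        Self { cx, pending }
    }

    /// Splits the wrapper into two disjoint mutable borrows, so one part can
    /// be lent out (e.g. to smoltcp) while the other stays usable.
    fn inner_mut(&mut self) -> (&mut Context, &mut PendingPolls) {
        (&mut *self.cx, &mut *self.pending)
    }
}

fn dispatch_with(iface: &mut PollableIfaceMut<'_>) {
    let (cx, pending) = iface.inner_mut();
    // Inside a callback that receives `cx`, a short-lived wrapper can be
    // rebuilt from the two halves without violating aliasing rules.
    let rebuilt = PollableIfaceMut::new(cx, pending);
    let _ = rebuilt;
}

fn main() {
    let (mut cx, mut pending) = (Context, PendingPolls);
    let mut iface = PollableIfaceMut::new(&mut cx, &mut pending);
    dispatch_with(&mut iface);
}
```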

View File

@@ -4,7 +4,6 @@ use alloc::{boxed::Box, collections::btree_map::BTreeMap, sync::Arc, vec::Vec};
 use ostd::sync::{LocalIrqDisabled, SpinLock};
 use smoltcp::{
-    iface::Context,
     socket::PollAt,
     time::Duration,
     wire::{IpEndpoint, IpRepr, TcpRepr},
@@ -17,7 +16,7 @@ use super::{
 use crate::{
     errors::tcp::ListenError,
     ext::Ext,
-    iface::{BindPortConfig, BoundPort},
+    iface::{BindPortConfig, BoundPort, PollableIfaceMut},
     socket::{
         option::{RawTcpOption, RawTcpSetOption},
         unbound::{new_tcp_socket, RawTcpSocket},
@@ -194,13 +193,16 @@
     /// Tries to process an incoming packet and returns whether the packet is processed.
     pub(crate) fn process(
         self: &Arc<Self>,
-        cx: &mut Context,
+        iface: &mut PollableIfaceMut<E>,
         ip_repr: &IpRepr,
         tcp_repr: &TcpRepr,
     ) -> (TcpProcessResult, Option<Arc<TcpConnectionBg<E>>>) {
         let mut backlog = self.inner.backlog.lock();

-        if !backlog.socket.accepts(cx, ip_repr, tcp_repr) {
+        if !backlog
+            .socket
+            .accepts(iface.context_mut(), ip_repr, tcp_repr)
+        {
             return (TcpProcessResult::NotProcessed, None);
         }
@@ -211,7 +213,10 @@
             return (TcpProcessResult::Processed, None);
         }

-        let result = match backlog.socket.process(cx, ip_repr, tcp_repr) {
+        let result = match backlog
+            .socket
+            .process(iface.context_mut(), ip_repr, tcp_repr)
+        {
             None => TcpProcessResult::Processed,
             Some((ip_repr, tcp_repr)) => TcpProcessResult::ProcessedWithReply(ip_repr, tcp_repr),
         };
@@ -227,23 +232,25 @@
             socket
         };

-        let inner = TcpConnectionInner::new(
-            core::mem::replace(&mut backlog.socket, new_socket),
-            Some(self.clone()),
-        );
-        let conn = TcpConnection::new(
+        let conn = TcpConnection::new_cyclic(
             self.bound
                 .iface()
                 .bind(BindPortConfig::CanReuse(self.bound.port()))
                 .unwrap(),
-            inner,
+            |weak| {
+                TcpConnectionInner::new(
+                    core::mem::replace(&mut backlog.socket, new_socket),
+                    Some(self.clone()),
+                    weak,
+                )
+            },
         );
         let conn_bg = conn.inner().clone();

         let old_conn = backlog.connecting.insert(*conn_bg.connection_key(), conn);
         debug_assert!(old_conn.is_none());

-        conn_bg.update_next_poll_at_ms(PollAt::Now);
+        iface.update_next_poll_at_ms(&conn_bg, PollAt::Now);

         (result, Some(conn_bg))
     }

View File

@@ -1,11 +1,12 @@
 // SPDX-License-Identifier: MPL-2.0

 use alloc::{boxed::Box, sync::Arc};
+use core::sync::atomic::{AtomicBool, Ordering};

 use ostd::sync::{LocalIrqDisabled, SpinLock};
 use smoltcp::{
     iface::Context,
-    socket::{udp::UdpMetadata, PollAt},
+    socket::udp::UdpMetadata,
     wire::{IpRepr, UdpRepr},
 };
@@ -20,13 +21,16 @@
 pub type UdpSocket<E> = Socket<UdpSocketInner, E>;

 /// States needed by [`UdpSocketBg`].
-type UdpSocketInner = SpinLock<Box<RawUdpSocket>, LocalIrqDisabled>;
+pub struct UdpSocketInner {
+    socket: SpinLock<Box<RawUdpSocket>, LocalIrqDisabled>,
+    need_dispatch: AtomicBool,
+}

 impl<E: Ext> Inner<E> for UdpSocketInner {
     type Observer = E::UdpEventObserver;

     fn on_drop(this: &Arc<SocketBg<Self, E>>) {
-        this.inner.lock().close();
+        this.inner.socket.lock().close();

         // A UDP socket can be removed immediately.
         this.bound.iface().common().remove_udp_socket(this);
@@ -44,7 +48,7 @@ impl<E: Ext> UdpSocketBg<E> {
         udp_repr: &UdpRepr,
         udp_payload: &[u8],
     ) -> bool {
-        let mut socket = self.inner.lock();
+        let mut socket = self.inner.socket.lock();

         if !socket.accepts(cx, ip_repr, udp_repr) {
             return false;
@@ -59,7 +63,6 @@ impl<E: Ext> UdpSocketBg<E> {
         );

         self.add_events(SocketEvents::CAN_RECV);
-        self.update_next_poll_at_ms(socket.poll_at(cx));

         true
     }
@@ -69,7 +72,7 @@ impl<E: Ext> UdpSocketBg<E> {
     where
         D: FnOnce(&mut Context, &IpRepr, &UdpRepr, &[u8]),
     {
-        let mut socket = self.inner.lock();
+        let mut socket = self.inner.socket.lock();

         socket
             .dispatch(cx, |cx, _meta, (ip_repr, udp_repr, udp_payload)| {
@@ -80,7 +83,17 @@ impl<E: Ext> UdpSocketBg<E> {
         // For UDP, dequeuing a packet means that we can queue more packets.
         self.add_events(SocketEvents::CAN_SEND);
-        self.update_next_poll_at_ms(socket.poll_at(cx));
+        self.inner
+            .need_dispatch
+            .store(socket.send_queue() > 0, Ordering::Relaxed);
     }
+
+    /// Returns whether the socket _may_ generate an outgoing packet.
+    ///
+    /// The check is intended to be lock-free and fast, but may have false positives.
+    pub(crate) fn need_dispatch(&self) -> bool {
+        self.inner.need_dispatch.load(Ordering::Relaxed)
+    }
 }
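
UDP gets a far simpler treatment than TCP: a UDP socket needs dispatching exactly when its send queue is non-empty, so a relaxed `AtomicBool` written under the socket lock and read lock-free is enough, and a stale read costs at most one spurious dispatch attempt. A reduced model of this bookkeeping (`queue_len` stands in for `RawUdpSocket::send_queue()`):

```rust
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Mutex;

struct UdpInner {
    queue_len: Mutex<usize>, // stand-in for the locked RawUdpSocket
    need_dispatch: AtomicBool,
}

impl UdpInner {
    fn send(&self) {
        let mut len = self.queue_len.lock().unwrap();
        *len += 1;
        // Published with Relaxed ordering: readers may observe a stale value,
        // which at worst triggers one harmless extra dispatch check.
        self.need_dispatch.store(*len > 0, Ordering::Relaxed);
    }

    /// Lock-free fast path for the iface poll loop.
    fn need_dispatch(&self) -> bool {
        self.need_dispatch.load(Ordering::Relaxed)
    }
}

fn main() {
    let inner = UdpInner {
        queue_len: Mutex::new(0),
        need_dispatch: AtomicBool::new(false),
    };
    assert!(!inner.need_dispatch());
    inner.send();
    assert!(inner.need_dispatch());
}
```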
@@ -106,7 +119,10 @@ impl<E: Ext> UdpSocket<E> {
             socket
         };

-        let inner = UdpSocketInner::new(socket);
+        let inner = UdpSocketInner {
+            socket: SpinLock::new(socket),
+            need_dispatch: AtomicBool::new(false),
+        };
         let socket = Self::new(bound, inner);
         socket.init_observer(observer);
@@ -130,7 +146,7 @@ impl<E: Ext> UdpSocket<E> {
     where
         F: FnOnce(&mut [u8]) -> R,
     {
-        let mut socket = self.0.inner.lock();
+        let mut socket = self.0.inner.socket.lock();

         if size > socket.packet_send_capacity() {
             return Err(SendError::TooLarge);
@@ -141,7 +157,11 @@ impl<E: Ext> UdpSocket<E> {
             Err(err) => return Err(err.into()),
         };
         let result = f(buffer);
-        self.0.update_next_poll_at_ms(PollAt::Now);
+
+        self.0
+            .inner
+            .need_dispatch
+            .store(socket.send_queue() > 0, Ordering::Relaxed);

         Ok(result)
     }
@@ -153,7 +173,7 @@ impl<E: Ext> UdpSocket<E> {
     where
         F: FnOnce(&[u8], UdpMetadata) -> R,
     {
-        let mut socket = self.0.inner.lock();
+        let mut socket = self.0.inner.socket.lock();

         let (data, meta) = socket.recv()?;
         let result = f(data, meta);
@@ -169,7 +189,7 @@ impl<E: Ext> UdpSocket<E> {
     where
         F: FnOnce(&RawUdpSocket) -> R,
     {
-        let socket = self.0.inner.lock();
+        let socket = self.0.inner.socket.lock();

         f(&socket)
     }
 }