Add socket hash table

This commit is contained in:
jiangjianfeng
2024-12-27 08:28:11 +00:00
committed by Tate, Hongliang Tian
parent 783345b90b
commit 39cc0dca26
11 changed files with 684 additions and 202 deletions

View File

@ -1,13 +1,11 @@
// SPDX-License-Identifier: MPL-2.0
use alloc::{boxed::Box, collections::btree_set::BTreeSet, sync::Arc, vec::Vec};
use alloc::{boxed::Box, collections::btree_map::BTreeMap, sync::Arc, vec::Vec};
use core::{
borrow::Borrow,
ops::{Deref, DerefMut},
sync::atomic::{AtomicBool, AtomicU64, AtomicU8, Ordering},
};
use keyable_arc::KeyableArc;
use ostd::sync::{LocalIrqDisabled, SpinLock, SpinLockGuard};
use smoltcp::{
iface::Context,
@ -15,7 +13,7 @@ use smoltcp::{
time::{Duration, Instant},
wire::{IpEndpoint, IpRepr, TcpControl, TcpRepr, UdpRepr},
};
use spin::Once;
use spin::once::Once;
use takeable::Takeable;
use super::{
@ -25,40 +23,20 @@ use super::{
RawTcpSocket, RawUdpSocket, TcpStateCheck,
};
use crate::{
errors::tcp::{ConnectError, ListenError},
ext::Ext,
iface::{BindPortConfig, BoundPort, Iface},
socket_table::{ConnectionKey, ListenerKey},
};
pub struct Socket<T: Inner<E>, E: Ext>(Takeable<KeyableArc<SocketBg<T, E>>>);
impl<T: Inner<E>, E: Ext> PartialEq for Socket<T, E> {
fn eq(&self, other: &Self) -> bool {
self.0.eq(&other.0)
}
}
impl<T: Inner<E>, E: Ext> Eq for Socket<T, E> {}
impl<T: Inner<E>, E: Ext> PartialOrd for Socket<T, E> {
fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
Some(self.cmp(other))
}
}
impl<T: Inner<E>, E: Ext> Ord for Socket<T, E> {
fn cmp(&self, other: &Self) -> core::cmp::Ordering {
self.0.cmp(&other.0)
}
}
impl<T: Inner<E>, E: Ext> Borrow<KeyableArc<SocketBg<T, E>>> for Socket<T, E> {
fn borrow(&self) -> &KeyableArc<SocketBg<T, E>> {
self.0.as_ref()
}
}
pub struct Socket<T: Inner<E>, E: Ext>(Takeable<Arc<SocketBg<T, E>>>);
/// [`TcpConnectionInner`] or [`UdpSocketInner`].
pub trait Inner<E: Ext> {
type Observer: SocketEventObserver;
/// Called by [`Socket::drop`].
fn on_drop(this: &KeyableArc<SocketBg<Self, E>>)
fn on_drop(this: &Arc<SocketBg<Self, E>>)
where
E: Ext,
Self: Sized;
@ -85,6 +63,7 @@ pub struct SocketBg<T: Inner<E>, E: Ext> {
pub struct TcpConnectionInner<E: Ext> {
socket: SpinLock<RawTcpSocketExt<E>, LocalIrqDisabled>,
is_dead: AtomicBool,
connection_key: ConnectionKey,
}
struct RawTcpSocketExt<E: Ext> {
@ -108,13 +87,13 @@ impl<E: Ext> DerefMut for RawTcpSocketExt<E> {
}
impl<E: Ext> RawTcpSocketExt<E> {
fn on_new_state(&mut self, this: &KeyableArc<TcpConnectionBg<E>>) -> SocketEvents {
fn on_new_state(&mut self, this: &Arc<TcpConnectionBg<E>>) -> SocketEvents {
if self.may_send() && !self.has_connected {
self.has_connected = true;
if let Some(ref listener) = self.listener {
let mut backlog = listener.inner.lock();
if let Some(value) = backlog.connecting.take(this) {
let mut backlog = listener.inner.backlog.lock();
if let Some(value) = backlog.connecting.remove(this.connection_key()) {
backlog.connected.push(value);
}
listener.add_events(SocketEvents::CAN_RECV);
@ -139,7 +118,7 @@ impl<E: Ext> RawTcpSocketExt<E> {
/// This method must be called after handling network events. However, it is not necessary to
/// call this method after handling non-closing user events, because the socket can never be
/// dead if it is not closed.
fn update_dead(&self, this: &KeyableArc<TcpConnectionBg<E>>) {
fn update_dead(&self, this: &Arc<TcpConnectionBg<E>>) {
if self.state() == smoltcp::socket::tcp::State::Closed {
this.inner.is_dead.store(true, Ordering::Relaxed);
}
@ -150,9 +129,9 @@ impl<E: Ext> RawTcpSocketExt<E> {
this.inner.is_dead.store(true, Ordering::Relaxed);
if let Some(ref listener) = self.listener {
let mut backlog = listener.inner.lock();
let mut backlog = listener.inner.backlog.lock();
// This may fail due to race conditions, but it's fine.
let _ = backlog.connecting.remove(this);
let _ = backlog.connecting.remove(&this.inner.connection_key);
}
}
}
@ -160,6 +139,13 @@ impl<E: Ext> RawTcpSocketExt<E> {
impl<E: Ext> TcpConnectionInner<E> {
fn new(socket: Box<RawTcpSocket>, listener: Option<Arc<TcpListenerBg<E>>>) -> Self {
let connection_key = {
// Since the socket is connected, the following unwrap can never fail
let local_endpoint = socket.local_endpoint().unwrap();
let remote_endpoint = socket.remote_endpoint().unwrap();
ConnectionKey::from((local_endpoint, remote_endpoint))
};
let socket_ext = RawTcpSocketExt {
socket,
listener,
@ -169,6 +155,7 @@ impl<E: Ext> TcpConnectionInner<E> {
TcpConnectionInner {
socket: SpinLock::new(socket_ext),
is_dead: AtomicBool::new(false),
connection_key,
}
}
@ -197,7 +184,7 @@ impl<E: Ext> TcpConnectionInner<E> {
impl<E: Ext> Inner<E> for TcpConnectionInner<E> {
type Observer = E::TcpEventObserver;
fn on_drop(this: &KeyableArc<SocketBg<Self, E>>) {
fn on_drop(this: &Arc<SocketBg<Self, E>>) {
let mut socket = this.inner.lock();
// FIXME: Send RSTs when there is unread data.
@ -213,21 +200,33 @@ impl<E: Ext> Inner<E> for TcpConnectionInner<E> {
pub struct TcpBacklog<E: Ext> {
socket: Box<RawTcpSocket>,
max_conn: usize,
connecting: BTreeSet<TcpConnection<E>>,
connecting: BTreeMap<ConnectionKey, TcpConnection<E>>,
connected: Vec<TcpConnection<E>>,
}
pub type TcpListenerInner<E> = SpinLock<TcpBacklog<E>, LocalIrqDisabled>;
pub struct TcpListenerInner<E: Ext> {
backlog: SpinLock<TcpBacklog<E>, LocalIrqDisabled>,
listener_key: ListenerKey,
}
impl<E: Ext> TcpListenerInner<E> {
fn new(backlog: TcpBacklog<E>, listener_key: ListenerKey) -> Self {
Self {
backlog: SpinLock::new(backlog),
listener_key,
}
}
}
impl<E: Ext> Inner<E> for TcpListenerInner<E> {
type Observer = E::TcpEventObserver;
fn on_drop(this: &KeyableArc<SocketBg<Self, E>>) {
fn on_drop(this: &Arc<SocketBg<Self, E>>) {
// A TCP listener can be removed immediately.
this.bound.iface().common().remove_tcp_listener(this);
let (connecting, connected) = {
let mut socket = this.inner.lock();
let mut socket = this.inner.backlog.lock();
(
core::mem::take(&mut socket.connecting),
core::mem::take(&mut socket.connected),
@ -249,7 +248,7 @@ type UdpSocketInner = SpinLock<Box<RawUdpSocket>, LocalIrqDisabled>;
impl<E: Ext> Inner<E> for UdpSocketInner {
type Observer = E::UdpEventObserver;
fn on_drop(this: &KeyableArc<SocketBg<Self, E>>) {
fn on_drop(this: &Arc<SocketBg<Self, E>>) {
this.inner.lock().close();
// A UDP socket can be removed immediately.
@ -271,7 +270,7 @@ pub(crate) type UdpSocketBg<E> = SocketBg<UdpSocketInner, E>;
impl<T: Inner<E>, E: Ext> Socket<T, E> {
pub(crate) fn new(bound: BoundPort<E>, inner: T) -> Self {
Self(Takeable::new(KeyableArc::new(SocketBg {
Self(Takeable::new(Arc::new(SocketBg {
bound,
inner,
observer: Once::new(),
@ -280,7 +279,7 @@ impl<T: Inner<E>, E: Ext> Socket<T, E> {
})))
}
pub(crate) fn inner(&self) -> &KeyableArc<SocketBg<T, E>> {
pub(crate) fn inner(&self) -> &Arc<SocketBg<T, E>> {
&self.0
}
}
@ -337,18 +336,30 @@ impl<E: Ext> TcpConnection<E> {
remote_endpoint: IpEndpoint,
option: &RawTcpOption,
observer: E::TcpEventObserver,
) -> Result<Self, (BoundPort<E>, smoltcp::socket::tcp::ConnectError)> {
) -> Result<Self, (BoundPort<E>, ConnectError)> {
let Some(local_endpoint) = bound.endpoint() else {
return Err((bound, ConnectError::Unaddressable));
};
let iface = bound.iface().clone();
// We have to lock interface before locking interface
// to avoid dead lock due to inconsistent lock orders.
let mut interface = iface.common().interface();
let mut sockets = iface.common().sockets();
let connection_key = ConnectionKey::from((local_endpoint, remote_endpoint));
if sockets.lookup_connection(&connection_key).is_some() {
return Err((bound, ConnectError::AddressInUse));
}
let socket = {
let mut socket = new_tcp_socket();
option.apply(&mut socket);
let common = bound.iface().common();
let mut iface = common.interface();
if let Err(err) = socket.connect(iface.context(), remote_endpoint, bound.port()) {
drop(iface);
return Err((bound, err));
if let Err(err) = socket.connect(interface.context(), remote_endpoint, bound.port()) {
return Err((bound, err.into()));
}
socket
@ -359,10 +370,8 @@ impl<E: Ext> TcpConnection<E> {
let connection = Self::new(bound, inner);
connection.0.update_next_poll_at_ms(PollAt::Now);
connection.init_observer(observer);
connection
.iface()
.common()
.register_tcp_connection(connection.inner().clone());
let res = sockets.insert_connection(connection.inner().clone());
debug_assert!(res.is_ok());
Ok(connection)
}
@ -375,7 +384,7 @@ impl<E: Ext> TcpConnection<E> {
ConnectState::Connecting
} else if socket.has_connected {
ConnectState::Connected
} else if KeyableArc::strong_count(self.0.as_ref()) > 1 {
} else if Arc::strong_count(self.0.as_ref()) > 1 {
// Now we should return `ConnectState::Refused`. However, when we do this, we must
// guarantee that `into_bound_port` can succeed (see the method's doc comments). We can
// only guarantee this after we have removed all `Arc<TcpConnectionBg>` in the iface's
@ -396,7 +405,7 @@ impl<E: Ext> TcpConnection<E> {
/// this connection. We guarantee that this method will always succeed if
/// [`Self::connect_state`] returns [`ConnectState::Refused`].
pub fn into_bound_port(mut self) -> Option<BoundPort<E>> {
let this: TcpConnectionBg<E> = Arc::into_inner(self.0.take().into())?;
let this: TcpConnectionBg<E> = Arc::into_inner(self.0.take())?;
Some(this.bound)
}
@ -492,36 +501,47 @@ impl<E: Ext> TcpListener<E> {
max_conn: usize,
option: &RawTcpOption,
observer: E::TcpEventObserver,
) -> Result<Self, (BoundPort<E>, smoltcp::socket::tcp::ListenError)> {
) -> Result<Self, (BoundPort<E>, ListenError)> {
let Some(local_endpoint) = bound.endpoint() else {
return Err((bound, smoltcp::socket::tcp::ListenError::Unaddressable));
return Err((bound, ListenError::Unaddressable));
};
let iface = bound.iface().clone();
let mut sockets = iface.common().sockets();
let listener_key = ListenerKey::new(local_endpoint.addr, local_endpoint.port);
if sockets.lookup_listener(&listener_key).is_some() {
return Err((bound, ListenError::AddressInUse));
}
let socket = {
let mut socket = new_tcp_socket();
option.apply(&mut socket);
if let Err(err) = socket.listen(local_endpoint) {
return Err((bound, err));
return Err((bound, err.into()));
}
socket
};
let inner = TcpListenerInner::new(TcpBacklog {
socket,
max_conn,
connecting: BTreeSet::new(),
connected: Vec::new(),
});
let inner = {
let backlog = TcpBacklog {
socket,
max_conn,
connecting: BTreeMap::new(),
connected: Vec::new(),
};
TcpListenerInner::new(backlog, listener_key)
};
let listener = Self::new(bound, inner);
listener.init_observer(observer);
listener
.iface()
.common()
.register_tcp_listener(listener.inner().clone());
let res = sockets.insert_listener(listener.inner().clone());
debug_assert!(res.is_ok());
Ok(listener)
}
@ -531,7 +551,7 @@ impl<E: Ext> TcpListener<E> {
/// Polling the iface is _not_ required after this method succeeds.
pub fn accept(&self) -> Option<(TcpConnection<E>, IpEndpoint)> {
let accepted = {
let mut backlog = self.0.inner.lock();
let mut backlog = self.0.inner.backlog.lock();
backlog.connected.pop()?
};
@ -551,20 +571,20 @@ impl<E: Ext> TcpListener<E> {
///
/// It's the caller's responsibility to deal with race conditions when using this method.
pub fn can_accept(&self) -> bool {
!self.0.inner.lock().connected.is_empty()
!self.0.inner.backlog.lock().connected.is_empty()
}
}
impl<E: Ext> RawTcpSetOption for TcpListener<E> {
fn set_keep_alive(&self, interval: Option<Duration>) -> NeedIfacePoll {
let mut backlog = self.0.inner.lock();
let mut backlog = self.0.inner.backlog.lock();
backlog.socket.set_keep_alive(interval);
NeedIfacePoll::FALSE
}
fn set_nagle_enabled(&self, enabled: bool) {
let mut backlog = self.0.inner.lock();
let mut backlog = self.0.inner.backlog.lock();
backlog.socket.set_nagle_enabled(enabled);
}
}
@ -680,17 +700,17 @@ impl<T: Inner<E>, E: Ext> SocketBg<T, E> {
}
}
pub(crate) fn on_dead_events(this: KeyableArc<Self>)
pub(crate) fn on_dead_events(self: Arc<Self>)
where
T::Observer: Clone,
{
// This method can only be called to process network events, so we assume we are holding the
// poll lock and no race conditions can occur.
let events = this.events.load(Ordering::Relaxed);
this.events.store(0, Ordering::Relaxed);
let events = self.events.load(Ordering::Relaxed);
self.events.store(0, Ordering::Relaxed);
let observer = this.observer.get().cloned();
drop(this);
let observer = self.observer.get().cloned();
drop(self);
// Notify dead events after the `Arc` is dropped to ensure the observer sees this event
// with the expected reference count. See `TcpConnection::connect_state` for an example.
@ -752,6 +772,16 @@ impl<E: Ext> TcpConnectionBg<E> {
pub(crate) fn is_dead(&self) -> bool {
self.inner.is_dead()
}
pub(crate) const fn connection_key(&self) -> &ConnectionKey {
&self.inner.connection_key
}
}
impl<E: Ext> TcpListenerBg<E> {
pub(crate) const fn listener_key(&self) -> &ListenerKey {
&self.inner.listener_key
}
}
impl<T: Inner<E>, E: Ext> SocketBg<T, E> {
@ -780,12 +810,12 @@ pub(crate) enum TcpProcessResult {
impl<E: Ext> TcpConnectionBg<E> {
/// Tries to process an incoming packet and returns whether the packet is processed.
pub(crate) fn process(
this: &KeyableArc<Self>,
self: &Arc<Self>,
cx: &mut Context,
ip_repr: &IpRepr,
tcp_repr: &TcpRepr,
) -> TcpProcessResult {
let mut socket = this.inner.lock();
let mut socket = self.inner.lock();
if !socket.accepts(cx, ip_repr, tcp_repr) {
return TcpProcessResult::NotProcessed;
@ -808,7 +838,7 @@ impl<E: Ext> TcpConnectionBg<E> {
&& tcp_repr.control == TcpControl::Syn
&& tcp_repr.ack_number.is_none()
{
this.inner.set_dead_timewait(&socket);
self.inner.set_dead_timewait(&socket);
return TcpProcessResult::NotProcessed;
}
@ -823,18 +853,18 @@ impl<E: Ext> TcpConnectionBg<E> {
};
if socket.state() != old_state {
events |= socket.on_new_state(this);
events |= socket.on_new_state(self);
}
this.add_events(events);
this.update_next_poll_at_ms(socket.poll_at(cx));
self.add_events(events);
self.update_next_poll_at_ms(socket.poll_at(cx));
result
}
/// Tries to generate an outgoing packet and dispatches the generated packet.
pub(crate) fn dispatch<D>(
this: &KeyableArc<Self>,
this: &Arc<Self>,
cx: &mut Context,
dispatch: D,
) -> Option<(IpRepr, TcpRepr<'static>)>
@ -878,12 +908,12 @@ impl<E: Ext> TcpConnectionBg<E> {
impl<E: Ext> TcpListenerBg<E> {
/// Tries to process an incoming packet and returns whether the packet is processed.
pub(crate) fn process(
this: &KeyableArc<Self>,
self: &Arc<Self>,
cx: &mut Context,
ip_repr: &IpRepr,
tcp_repr: &TcpRepr,
) -> (TcpProcessResult, Option<KeyableArc<TcpConnectionBg<E>>>) {
let mut backlog = this.inner.lock();
) -> (TcpProcessResult, Option<Arc<TcpConnectionBg<E>>>) {
let mut backlog = self.inner.backlog.lock();
if !backlog.socket.accepts(cx, ip_repr, tcp_repr) {
return (TcpProcessResult::NotProcessed, None);
@ -914,19 +944,19 @@ impl<E: Ext> TcpListenerBg<E> {
let inner = TcpConnectionInner::new(
core::mem::replace(&mut backlog.socket, new_socket),
Some(this.clone().into()),
Some(self.clone()),
);
let conn = TcpConnection::new(
this.bound
self.bound
.iface()
.bind(BindPortConfig::CanReuse(this.bound.port()))
.bind(BindPortConfig::CanReuse(self.bound.port()))
.unwrap(),
inner,
);
let conn_bg = conn.inner().clone();
let inserted = backlog.connecting.insert(conn);
assert!(inserted);
let old_conn = backlog.connecting.insert(*conn_bg.connection_key(), conn);
debug_assert!(old_conn.is_none());
conn_bg.update_next_poll_at_ms(PollAt::Now);