diff --git a/kernel/libs/aster-bigtcp/Cargo.toml b/kernel/libs/aster-bigtcp/Cargo.toml index 06c5bc49a..90b1e03c8 100644 --- a/kernel/libs/aster-bigtcp/Cargo.toml +++ b/kernel/libs/aster-bigtcp/Cargo.toml @@ -7,7 +7,7 @@ edition = "2021" [dependencies] bitflags = "1.3" -keyable-arc = { path = "../keyable-arc" } +jhash = { path = "../jhash" } ostd = { path = "../../../ostd" } smoltcp = { git = "https://github.com/asterinas/smoltcp", tag = "r_2024-11-08_f07e5b5", default-features = false, features = [ "alloc", @@ -19,4 +19,5 @@ smoltcp = { git = "https://github.com/asterinas/smoltcp", tag = "r_2024-11-08_f0 "socket-tcp", ] } spin = "0.9.4" +static_assertions = "1.1.0" takeable = "0.2.2" diff --git a/kernel/libs/aster-bigtcp/src/errors.rs b/kernel/libs/aster-bigtcp/src/errors.rs index 5f058b1cc..8e221c26c 100644 --- a/kernel/libs/aster-bigtcp/src/errors.rs +++ b/kernel/libs/aster-bigtcp/src/errors.rs @@ -10,7 +10,41 @@ pub enum BindError { } pub mod tcp { - pub use smoltcp::socket::tcp::{ConnectError, ListenError, RecvError, SendError}; + pub use smoltcp::socket::tcp::{RecvError, SendError}; + + #[derive(Debug, PartialEq, Eq, Clone, Copy)] + pub enum ListenError { + InvalidState, + Unaddressable, + /// The specified address is in use. + AddressInUse, + } + + impl From for ListenError { + fn from(value: smoltcp::socket::tcp::ListenError) -> Self { + match value { + smoltcp::socket::tcp::ListenError::InvalidState => Self::InvalidState, + smoltcp::socket::tcp::ListenError::Unaddressable => Self::Unaddressable, + } + } + } + + #[derive(Debug, PartialEq, Eq, Clone, Copy)] + pub enum ConnectError { + InvalidState, + Unaddressable, + /// The specified address is in use. + AddressInUse, + } + + impl From for ConnectError { + fn from(value: smoltcp::socket::tcp::ConnectError) -> Self { + match value { + smoltcp::socket::tcp::ConnectError::InvalidState => Self::InvalidState, + smoltcp::socket::tcp::ConnectError::Unaddressable => Self::Unaddressable, + } + } + } } pub mod udp { diff --git a/kernel/libs/aster-bigtcp/src/iface/common.rs b/kernel/libs/aster-bigtcp/src/iface/common.rs index a11816be1..28a2d6644 100644 --- a/kernel/libs/aster-bigtcp/src/iface/common.rs +++ b/kernel/libs/aster-bigtcp/src/iface/common.rs @@ -1,16 +1,12 @@ // SPDX-License-Identifier: MPL-2.0 use alloc::{ - collections::{ - btree_map::{BTreeMap, Entry}, - btree_set::BTreeSet, - }, + collections::btree_map::{BTreeMap, Entry}, string::String, sync::Arc, vec::Vec, }; -use keyable_arc::KeyableArc; use ostd::sync::{LocalIrqDisabled, SpinLock, SpinLockGuard}; use smoltcp::{ iface::{packet::Packet, Context}, @@ -27,34 +23,25 @@ use super::{ use crate::{ errors::BindError, ext::Ext, - socket::{TcpConnectionBg, TcpListenerBg, UdpSocketBg}, + socket::{TcpListenerBg, UdpSocketBg}, + socket_table::SocketTable, }; pub struct IfaceCommon { name: String, interface: SpinLock, used_ports: SpinLock, LocalIrqDisabled>, - sockets: SpinLock, LocalIrqDisabled>, + sockets: SpinLock, LocalIrqDisabled>, sched_poll: E::ScheduleNextPoll, } -pub(super) struct SocketSet { - pub(super) tcp_conn: BTreeSet>>, - pub(super) tcp_listen: BTreeSet>>, - pub(super) udp: BTreeSet>>, -} - impl IfaceCommon { pub(super) fn new( name: String, interface: smoltcp::iface::Interface, sched_poll: E::ScheduleNextPoll, ) -> Self { - let sockets = SocketSet { - tcp_conn: BTreeSet::new(), - tcp_listen: BTreeSet::new(), - udp: BTreeSet::new(), - }; + let sockets = SocketTable::new(); Self { name, @@ -78,11 +65,17 @@ impl IfaceCommon { } } +// Lock order: interface -> sockets impl IfaceCommon { /// Acquires the lock to the interface. pub(crate) fn interface(&self) -> SpinLockGuard { self.interface.lock() } + + /// Acquires the lock to the socket table. + pub(crate) fn sockets(&self) -> SpinLockGuard<'_, SocketTable, LocalIrqDisabled> { + self.sockets.lock() + } } const IP_LOCAL_PORT_START: u16 = 32768; @@ -152,41 +145,21 @@ impl IfaceCommon { } impl IfaceCommon { - pub(crate) fn register_tcp_connection(&self, socket: KeyableArc>) { + pub(crate) fn register_udp_socket(&self, socket: Arc>) { let mut sockets = self.sockets.lock(); - let inserted = sockets.tcp_conn.insert(socket); - debug_assert!(inserted); + sockets.insert_udp_socket(socket); } - pub(crate) fn register_tcp_listener(&self, socket: KeyableArc>) { + pub(crate) fn remove_tcp_listener(&self, socket: &Arc>) { let mut sockets = self.sockets.lock(); - let inserted = sockets.tcp_listen.insert(socket); - debug_assert!(inserted); + let removed = sockets.remove_listener(socket); + debug_assert!(removed.is_some()); } - pub(crate) fn register_udp_socket(&self, socket: KeyableArc>) { + pub(crate) fn remove_udp_socket(&self, socket: &Arc>) { let mut sockets = self.sockets.lock(); - let inserted = sockets.udp.insert(socket); - debug_assert!(inserted); - } - - #[allow(clippy::mutable_key_type)] - fn remove_dead_tcp_connections(sockets: &mut BTreeSet>>) { - for socket in sockets.extract_if(|socket| socket.is_dead()) { - TcpConnectionBg::on_dead_events(socket); - } - } - - pub(crate) fn remove_tcp_listener(&self, socket: &KeyableArc>) { - let mut sockets = self.sockets.lock(); - let removed = sockets.tcp_listen.remove(socket); - debug_assert!(removed); - } - - pub(crate) fn remove_udp_socket(&self, socket: &KeyableArc>) { - let mut sockets = self.sockets.lock(); - let removed = sockets.udp.remove(socket); - debug_assert!(removed); + let removed = sockets.remove_udp_socket(socket); + debug_assert!(removed.is_some()); } } @@ -224,33 +197,37 @@ impl IfaceCommon { if new_tcp_conns.is_empty() { break; } else { - sockets.tcp_conn.extend(new_tcp_conns); + new_tcp_conns.into_iter().for_each(|tcp_conn| { + let res = sockets.insert_connection(tcp_conn); + debug_assert!(res.is_ok()); + }); } } - Self::remove_dead_tcp_connections(&mut sockets.tcp_conn); + sockets.remove_dead_tcp_connections(); - sockets.tcp_conn.iter().for_each(|socket| { + for socket in sockets.tcp_listener_iter() { if socket.has_events() { socket.on_events(); } - }); - sockets.tcp_listen.iter().for_each(|socket| { + } + + for socket in sockets.tcp_conn_iter() { if socket.has_events() { socket.on_events(); } - }); - sockets.udp.iter().for_each(|socket| { + } + + for socket in sockets.udp_socket_iter() { if socket.has_events() { socket.on_events(); } - }); + } // Note that only TCP connections can have timers set, so as far as the time to poll is // concerned, we only need to consider TCP connections. sockets - .tcp_conn - .iter() + .tcp_conn_iter() .map(|socket| socket.next_poll_at_ms()) .min() } diff --git a/kernel/libs/aster-bigtcp/src/iface/poll.rs b/kernel/libs/aster-bigtcp/src/iface/poll.rs index 97f6b5072..9c7921a6f 100644 --- a/kernel/libs/aster-bigtcp/src/iface/poll.rs +++ b/kernel/libs/aster-bigtcp/src/iface/poll.rs @@ -1,8 +1,7 @@ // SPDX-License-Identifier: MPL-2.0 -use alloc::{vec, vec::Vec}; +use alloc::{sync::Arc, vec, vec::Vec}; -use keyable_arc::KeyableArc; use smoltcp::{ iface::{ packet::{icmp_reply_payload_len, IpPayload, Packet}, @@ -16,23 +15,23 @@ use smoltcp::{ }, }; -use super::common::SocketSet; use crate::{ ext::Ext, - socket::{TcpConnectionBg, TcpListenerBg, TcpProcessResult}, + socket::{TcpConnectionBg, TcpProcessResult}, + socket_table::{ConnectionKey, ListenerKey, SocketTable}, }; pub(super) struct PollContext<'a, E: Ext> { iface_cx: &'a mut Context, - sockets: &'a SocketSet, - new_tcp_conns: &'a mut Vec>>, + sockets: &'a SocketTable, + new_tcp_conns: &'a mut Vec>>, } impl<'a, E: Ext> PollContext<'a, E> { pub(super) fn new( iface_cx: &'a mut Context, - sockets: &'a SocketSet, - new_tcp_conns: &'a mut Vec>>, + sockets: &'a SocketTable, + new_tcp_conns: &'a mut Vec>>, ) -> Self { Self { iface_cx, @@ -158,40 +157,18 @@ impl PollContext<'_, E> { ip_repr: &IpRepr, tcp_repr: &TcpRepr, ) -> Option<(IpRepr, TcpRepr<'static>)> { - for socket in self - .sockets - .tcp_conn - .iter() - .chain(self.new_tcp_conns.iter()) - { - if !socket.can_process(tcp_repr.dst_port) { - continue; - } - - match TcpConnectionBg::process(socket, self.iface_cx, ip_repr, tcp_repr) { - TcpProcessResult::NotProcessed => continue, - TcpProcessResult::Processed => return None, - TcpProcessResult::ProcessedWithReply(ip_repr, tcp_repr) => { - return Some((ip_repr, tcp_repr)) - } - } - } - + // Process packets that request to create new connections first. if tcp_repr.control == TcpControl::Syn && tcp_repr.ack_number.is_none() { - for socket in self.sockets.tcp_listen.iter() { - if !socket.can_process(tcp_repr.dst_port) { - continue; - } - - let (processed, new_tcp_conn) = - TcpListenerBg::process(socket, self.iface_cx, ip_repr, tcp_repr); + let listener_key = ListenerKey::new(ip_repr.dst_addr(), tcp_repr.dst_port); + if let Some(listener) = self.sockets.lookup_listener(&listener_key) { + let (processed, new_tcp_conn) = listener.process(self.iface_cx, ip_repr, tcp_repr); if let Some(tcp_conn) = new_tcp_conn { self.new_tcp_conns.push(tcp_conn); } match processed { - TcpProcessResult::NotProcessed => continue, + TcpProcessResult::NotProcessed => {} TcpProcessResult::Processed => return None, TcpProcessResult::ProcessedWithReply(ip_repr, tcp_repr) => { return Some((ip_repr, tcp_repr)) @@ -200,6 +177,31 @@ impl PollContext<'_, E> { } } + // Process packets belonging to existing connections second. + let connection_key = ConnectionKey::new( + ip_repr.dst_addr(), + tcp_repr.dst_port, + ip_repr.src_addr(), + tcp_repr.src_port, + ); + let connection = if let Some(connection) = self.sockets.lookup_connection(&connection_key) { + Some(connection) + } else { + self.new_tcp_conns + .iter() + .find(|tcp_conn| tcp_conn.connection_key() == &connection_key) + }; + + if let Some(connection) = connection { + match connection.process(self.iface_cx, ip_repr, tcp_repr) { + TcpProcessResult::NotProcessed => {} + TcpProcessResult::Processed => return None, + TcpProcessResult::ProcessedWithReply(ip_repr, tcp_repr) => { + return Some((ip_repr, tcp_repr)) + } + } + } + // "In no case does receipt of a segment containing RST give rise to a RST in response." // See . if tcp_repr.control == TcpControl::Rst { @@ -239,7 +241,7 @@ impl PollContext<'_, E> { fn process_udp(&mut self, ip_repr: &IpRepr, udp_repr: &UdpRepr, udp_payload: &[u8]) -> bool { let mut processed = false; - for socket in self.sockets.udp.iter() { + for socket in self.sockets.udp_socket_iter() { if !socket.can_process(udp_repr.dst_port) { continue; } @@ -350,7 +352,7 @@ impl PollContext<'_, E> { // We cannot dispatch packets from `new_tcp_conns` because we cannot borrow an immutable // reference at this point. Instead, we will retry after the entire poll is complete. - for socket in self.sockets.tcp_conn.iter() { + for socket in self.sockets.tcp_conn_iter() { if !socket.need_dispatch(self.iface_cx.now()) { continue; } @@ -442,7 +444,7 @@ impl PollContext<'_, E> { let mut tx_token = Some(tx_token); let mut did_something = false; - for socket in self.sockets.udp.iter() { + for socket in self.sockets.udp_socket_iter() { if !socket.need_dispatch(self.iface_cx.now()) { continue; } diff --git a/kernel/libs/aster-bigtcp/src/lib.rs b/kernel/libs/aster-bigtcp/src/lib.rs index 1873e1ea2..67aef5bfa 100644 --- a/kernel/libs/aster-bigtcp/src/lib.rs +++ b/kernel/libs/aster-bigtcp/src/lib.rs @@ -12,13 +12,14 @@ #![no_std] #![deny(unsafe_code)] -#![feature(btree_extract_if)] +#![feature(extract_if)] pub mod device; pub mod errors; pub mod ext; pub mod iface; pub mod socket; +pub mod socket_table; pub mod time; pub mod wire; diff --git a/kernel/libs/aster-bigtcp/src/socket/bound.rs b/kernel/libs/aster-bigtcp/src/socket/bound.rs index f0b25ed2e..eeb2b88ba 100644 --- a/kernel/libs/aster-bigtcp/src/socket/bound.rs +++ b/kernel/libs/aster-bigtcp/src/socket/bound.rs @@ -1,13 +1,11 @@ // SPDX-License-Identifier: MPL-2.0 -use alloc::{boxed::Box, collections::btree_set::BTreeSet, sync::Arc, vec::Vec}; +use alloc::{boxed::Box, collections::btree_map::BTreeMap, sync::Arc, vec::Vec}; use core::{ - borrow::Borrow, ops::{Deref, DerefMut}, sync::atomic::{AtomicBool, AtomicU64, AtomicU8, Ordering}, }; -use keyable_arc::KeyableArc; use ostd::sync::{LocalIrqDisabled, SpinLock, SpinLockGuard}; use smoltcp::{ iface::Context, @@ -15,7 +13,7 @@ use smoltcp::{ time::{Duration, Instant}, wire::{IpEndpoint, IpRepr, TcpControl, TcpRepr, UdpRepr}, }; -use spin::Once; +use spin::once::Once; use takeable::Takeable; use super::{ @@ -25,40 +23,20 @@ use super::{ RawTcpSocket, RawUdpSocket, TcpStateCheck, }; use crate::{ + errors::tcp::{ConnectError, ListenError}, ext::Ext, iface::{BindPortConfig, BoundPort, Iface}, + socket_table::{ConnectionKey, ListenerKey}, }; -pub struct Socket, E: Ext>(Takeable>>); - -impl, E: Ext> PartialEq for Socket { - fn eq(&self, other: &Self) -> bool { - self.0.eq(&other.0) - } -} -impl, E: Ext> Eq for Socket {} -impl, E: Ext> PartialOrd for Socket { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} -impl, E: Ext> Ord for Socket { - fn cmp(&self, other: &Self) -> core::cmp::Ordering { - self.0.cmp(&other.0) - } -} -impl, E: Ext> Borrow>> for Socket { - fn borrow(&self) -> &KeyableArc> { - self.0.as_ref() - } -} +pub struct Socket, E: Ext>(Takeable>>); /// [`TcpConnectionInner`] or [`UdpSocketInner`]. pub trait Inner { type Observer: SocketEventObserver; /// Called by [`Socket::drop`]. - fn on_drop(this: &KeyableArc>) + fn on_drop(this: &Arc>) where E: Ext, Self: Sized; @@ -85,6 +63,7 @@ pub struct SocketBg, E: Ext> { pub struct TcpConnectionInner { socket: SpinLock, LocalIrqDisabled>, is_dead: AtomicBool, + connection_key: ConnectionKey, } struct RawTcpSocketExt { @@ -108,13 +87,13 @@ impl DerefMut for RawTcpSocketExt { } impl RawTcpSocketExt { - fn on_new_state(&mut self, this: &KeyableArc>) -> SocketEvents { + fn on_new_state(&mut self, this: &Arc>) -> SocketEvents { if self.may_send() && !self.has_connected { self.has_connected = true; if let Some(ref listener) = self.listener { - let mut backlog = listener.inner.lock(); - if let Some(value) = backlog.connecting.take(this) { + let mut backlog = listener.inner.backlog.lock(); + if let Some(value) = backlog.connecting.remove(this.connection_key()) { backlog.connected.push(value); } listener.add_events(SocketEvents::CAN_RECV); @@ -139,7 +118,7 @@ impl RawTcpSocketExt { /// This method must be called after handling network events. However, it is not necessary to /// call this method after handling non-closing user events, because the socket can never be /// dead if it is not closed. - fn update_dead(&self, this: &KeyableArc>) { + fn update_dead(&self, this: &Arc>) { if self.state() == smoltcp::socket::tcp::State::Closed { this.inner.is_dead.store(true, Ordering::Relaxed); } @@ -150,9 +129,9 @@ impl RawTcpSocketExt { this.inner.is_dead.store(true, Ordering::Relaxed); if let Some(ref listener) = self.listener { - let mut backlog = listener.inner.lock(); + let mut backlog = listener.inner.backlog.lock(); // This may fail due to race conditions, but it's fine. - let _ = backlog.connecting.remove(this); + let _ = backlog.connecting.remove(&this.inner.connection_key); } } } @@ -160,6 +139,13 @@ impl RawTcpSocketExt { impl TcpConnectionInner { fn new(socket: Box, listener: Option>>) -> Self { + let connection_key = { + // Since the socket is connected, the following unwrap can never fail + let local_endpoint = socket.local_endpoint().unwrap(); + let remote_endpoint = socket.remote_endpoint().unwrap(); + ConnectionKey::from((local_endpoint, remote_endpoint)) + }; + let socket_ext = RawTcpSocketExt { socket, listener, @@ -169,6 +155,7 @@ impl TcpConnectionInner { TcpConnectionInner { socket: SpinLock::new(socket_ext), is_dead: AtomicBool::new(false), + connection_key, } } @@ -197,7 +184,7 @@ impl TcpConnectionInner { impl Inner for TcpConnectionInner { type Observer = E::TcpEventObserver; - fn on_drop(this: &KeyableArc>) { + fn on_drop(this: &Arc>) { let mut socket = this.inner.lock(); // FIXME: Send RSTs when there is unread data. @@ -213,21 +200,33 @@ impl Inner for TcpConnectionInner { pub struct TcpBacklog { socket: Box, max_conn: usize, - connecting: BTreeSet>, + connecting: BTreeMap>, connected: Vec>, } -pub type TcpListenerInner = SpinLock, LocalIrqDisabled>; +pub struct TcpListenerInner { + backlog: SpinLock, LocalIrqDisabled>, + listener_key: ListenerKey, +} + +impl TcpListenerInner { + fn new(backlog: TcpBacklog, listener_key: ListenerKey) -> Self { + Self { + backlog: SpinLock::new(backlog), + listener_key, + } + } +} impl Inner for TcpListenerInner { type Observer = E::TcpEventObserver; - fn on_drop(this: &KeyableArc>) { + fn on_drop(this: &Arc>) { // A TCP listener can be removed immediately. this.bound.iface().common().remove_tcp_listener(this); let (connecting, connected) = { - let mut socket = this.inner.lock(); + let mut socket = this.inner.backlog.lock(); ( core::mem::take(&mut socket.connecting), core::mem::take(&mut socket.connected), @@ -249,7 +248,7 @@ type UdpSocketInner = SpinLock, LocalIrqDisabled>; impl Inner for UdpSocketInner { type Observer = E::UdpEventObserver; - fn on_drop(this: &KeyableArc>) { + fn on_drop(this: &Arc>) { this.inner.lock().close(); // A UDP socket can be removed immediately. @@ -271,7 +270,7 @@ pub(crate) type UdpSocketBg = SocketBg; impl, E: Ext> Socket { pub(crate) fn new(bound: BoundPort, inner: T) -> Self { - Self(Takeable::new(KeyableArc::new(SocketBg { + Self(Takeable::new(Arc::new(SocketBg { bound, inner, observer: Once::new(), @@ -280,7 +279,7 @@ impl, E: Ext> Socket { }))) } - pub(crate) fn inner(&self) -> &KeyableArc> { + pub(crate) fn inner(&self) -> &Arc> { &self.0 } } @@ -337,18 +336,30 @@ impl TcpConnection { remote_endpoint: IpEndpoint, option: &RawTcpOption, observer: E::TcpEventObserver, - ) -> Result, smoltcp::socket::tcp::ConnectError)> { + ) -> Result, ConnectError)> { + let Some(local_endpoint) = bound.endpoint() else { + return Err((bound, ConnectError::Unaddressable)); + }; + + let iface = bound.iface().clone(); + // We have to lock interface before locking interface + // to avoid dead lock due to inconsistent lock orders. + let mut interface = iface.common().interface(); + let mut sockets = iface.common().sockets(); + + let connection_key = ConnectionKey::from((local_endpoint, remote_endpoint)); + + if sockets.lookup_connection(&connection_key).is_some() { + return Err((bound, ConnectError::AddressInUse)); + } + let socket = { let mut socket = new_tcp_socket(); option.apply(&mut socket); - let common = bound.iface().common(); - let mut iface = common.interface(); - - if let Err(err) = socket.connect(iface.context(), remote_endpoint, bound.port()) { - drop(iface); - return Err((bound, err)); + if let Err(err) = socket.connect(interface.context(), remote_endpoint, bound.port()) { + return Err((bound, err.into())); } socket @@ -359,10 +370,8 @@ impl TcpConnection { let connection = Self::new(bound, inner); connection.0.update_next_poll_at_ms(PollAt::Now); connection.init_observer(observer); - connection - .iface() - .common() - .register_tcp_connection(connection.inner().clone()); + let res = sockets.insert_connection(connection.inner().clone()); + debug_assert!(res.is_ok()); Ok(connection) } @@ -375,7 +384,7 @@ impl TcpConnection { ConnectState::Connecting } else if socket.has_connected { ConnectState::Connected - } else if KeyableArc::strong_count(self.0.as_ref()) > 1 { + } else if Arc::strong_count(self.0.as_ref()) > 1 { // Now we should return `ConnectState::Refused`. However, when we do this, we must // guarantee that `into_bound_port` can succeed (see the method's doc comments). We can // only guarantee this after we have removed all `Arc` in the iface's @@ -396,7 +405,7 @@ impl TcpConnection { /// this connection. We guarantee that this method will always succeed if /// [`Self::connect_state`] returns [`ConnectState::Refused`]. pub fn into_bound_port(mut self) -> Option> { - let this: TcpConnectionBg = Arc::into_inner(self.0.take().into())?; + let this: TcpConnectionBg = Arc::into_inner(self.0.take())?; Some(this.bound) } @@ -492,36 +501,47 @@ impl TcpListener { max_conn: usize, option: &RawTcpOption, observer: E::TcpEventObserver, - ) -> Result, smoltcp::socket::tcp::ListenError)> { + ) -> Result, ListenError)> { let Some(local_endpoint) = bound.endpoint() else { - return Err((bound, smoltcp::socket::tcp::ListenError::Unaddressable)); + return Err((bound, ListenError::Unaddressable)); }; + let iface = bound.iface().clone(); + let mut sockets = iface.common().sockets(); + + let listener_key = ListenerKey::new(local_endpoint.addr, local_endpoint.port); + + if sockets.lookup_listener(&listener_key).is_some() { + return Err((bound, ListenError::AddressInUse)); + } + let socket = { let mut socket = new_tcp_socket(); option.apply(&mut socket); if let Err(err) = socket.listen(local_endpoint) { - return Err((bound, err)); + return Err((bound, err.into())); } socket }; - let inner = TcpListenerInner::new(TcpBacklog { - socket, - max_conn, - connecting: BTreeSet::new(), - connected: Vec::new(), - }); + let inner = { + let backlog = TcpBacklog { + socket, + max_conn, + connecting: BTreeMap::new(), + connected: Vec::new(), + }; + + TcpListenerInner::new(backlog, listener_key) + }; let listener = Self::new(bound, inner); listener.init_observer(observer); - listener - .iface() - .common() - .register_tcp_listener(listener.inner().clone()); + let res = sockets.insert_listener(listener.inner().clone()); + debug_assert!(res.is_ok()); Ok(listener) } @@ -531,7 +551,7 @@ impl TcpListener { /// Polling the iface is _not_ required after this method succeeds. pub fn accept(&self) -> Option<(TcpConnection, IpEndpoint)> { let accepted = { - let mut backlog = self.0.inner.lock(); + let mut backlog = self.0.inner.backlog.lock(); backlog.connected.pop()? }; @@ -551,20 +571,20 @@ impl TcpListener { /// /// It's the caller's responsibility to deal with race conditions when using this method. pub fn can_accept(&self) -> bool { - !self.0.inner.lock().connected.is_empty() + !self.0.inner.backlog.lock().connected.is_empty() } } impl RawTcpSetOption for TcpListener { fn set_keep_alive(&self, interval: Option) -> NeedIfacePoll { - let mut backlog = self.0.inner.lock(); + let mut backlog = self.0.inner.backlog.lock(); backlog.socket.set_keep_alive(interval); NeedIfacePoll::FALSE } fn set_nagle_enabled(&self, enabled: bool) { - let mut backlog = self.0.inner.lock(); + let mut backlog = self.0.inner.backlog.lock(); backlog.socket.set_nagle_enabled(enabled); } } @@ -680,17 +700,17 @@ impl, E: Ext> SocketBg { } } - pub(crate) fn on_dead_events(this: KeyableArc) + pub(crate) fn on_dead_events(self: Arc) where T::Observer: Clone, { // This method can only be called to process network events, so we assume we are holding the // poll lock and no race conditions can occur. - let events = this.events.load(Ordering::Relaxed); - this.events.store(0, Ordering::Relaxed); + let events = self.events.load(Ordering::Relaxed); + self.events.store(0, Ordering::Relaxed); - let observer = this.observer.get().cloned(); - drop(this); + let observer = self.observer.get().cloned(); + drop(self); // Notify dead events after the `Arc` is dropped to ensure the observer sees this event // with the expected reference count. See `TcpConnection::connect_state` for an example. @@ -752,6 +772,16 @@ impl TcpConnectionBg { pub(crate) fn is_dead(&self) -> bool { self.inner.is_dead() } + + pub(crate) const fn connection_key(&self) -> &ConnectionKey { + &self.inner.connection_key + } +} + +impl TcpListenerBg { + pub(crate) const fn listener_key(&self) -> &ListenerKey { + &self.inner.listener_key + } } impl, E: Ext> SocketBg { @@ -780,12 +810,12 @@ pub(crate) enum TcpProcessResult { impl TcpConnectionBg { /// Tries to process an incoming packet and returns whether the packet is processed. pub(crate) fn process( - this: &KeyableArc, + self: &Arc, cx: &mut Context, ip_repr: &IpRepr, tcp_repr: &TcpRepr, ) -> TcpProcessResult { - let mut socket = this.inner.lock(); + let mut socket = self.inner.lock(); if !socket.accepts(cx, ip_repr, tcp_repr) { return TcpProcessResult::NotProcessed; @@ -808,7 +838,7 @@ impl TcpConnectionBg { && tcp_repr.control == TcpControl::Syn && tcp_repr.ack_number.is_none() { - this.inner.set_dead_timewait(&socket); + self.inner.set_dead_timewait(&socket); return TcpProcessResult::NotProcessed; } @@ -823,18 +853,18 @@ impl TcpConnectionBg { }; if socket.state() != old_state { - events |= socket.on_new_state(this); + events |= socket.on_new_state(self); } - this.add_events(events); - this.update_next_poll_at_ms(socket.poll_at(cx)); + self.add_events(events); + self.update_next_poll_at_ms(socket.poll_at(cx)); result } /// Tries to generate an outgoing packet and dispatches the generated packet. pub(crate) fn dispatch( - this: &KeyableArc, + this: &Arc, cx: &mut Context, dispatch: D, ) -> Option<(IpRepr, TcpRepr<'static>)> @@ -878,12 +908,12 @@ impl TcpConnectionBg { impl TcpListenerBg { /// Tries to process an incoming packet and returns whether the packet is processed. pub(crate) fn process( - this: &KeyableArc, + self: &Arc, cx: &mut Context, ip_repr: &IpRepr, tcp_repr: &TcpRepr, - ) -> (TcpProcessResult, Option>>) { - let mut backlog = this.inner.lock(); + ) -> (TcpProcessResult, Option>>) { + let mut backlog = self.inner.backlog.lock(); if !backlog.socket.accepts(cx, ip_repr, tcp_repr) { return (TcpProcessResult::NotProcessed, None); @@ -914,19 +944,19 @@ impl TcpListenerBg { let inner = TcpConnectionInner::new( core::mem::replace(&mut backlog.socket, new_socket), - Some(this.clone().into()), + Some(self.clone()), ); let conn = TcpConnection::new( - this.bound + self.bound .iface() - .bind(BindPortConfig::CanReuse(this.bound.port())) + .bind(BindPortConfig::CanReuse(self.bound.port())) .unwrap(), inner, ); let conn_bg = conn.inner().clone(); - let inserted = backlog.connecting.insert(conn); - assert!(inserted); + let old_conn = backlog.connecting.insert(*conn_bg.connection_key(), conn); + debug_assert!(old_conn.is_none()); conn_bg.update_next_poll_at_ms(PollAt::Now); diff --git a/kernel/libs/aster-bigtcp/src/socket_table.rs b/kernel/libs/aster-bigtcp/src/socket_table.rs new file mode 100644 index 000000000..cc7f615e4 --- /dev/null +++ b/kernel/libs/aster-bigtcp/src/socket_table.rs @@ -0,0 +1,352 @@ +// SPDX-License-Identifier: MPL-2.0 + +//! This module defines the socket table, which manages all TCP and UDP sockets, +//! for efficiently inserting, looking up, and removing sockets. + +use alloc::{boxed::Box, sync::Arc, vec::Vec}; +use core::net::Ipv4Addr; + +use jhash::{jhash_1vals, jhash_3vals}; +use smoltcp::wire::{IpAddress, IpEndpoint, IpListenEndpoint}; +use static_assertions::const_assert; + +use crate::{ + ext::Ext, + socket::{TcpConnectionBg, TcpListenerBg, UdpSocketBg}, + wire::PortNum, +}; + +pub type SocketHash = u32; + +/// A unique key for identifying a `TcpListener`. +/// +/// Note that two `TcpListener`s cannot listen on the same address +/// even if both sockets set SO_REUSEADDR to true, +/// so there cannot be multiple listeners with the same `ListenerKey`. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(crate) struct ListenerKey { + addr: IpAddress, + port: PortNum, + hash: SocketHash, +} + +impl ListenerKey { + pub(crate) const fn new(addr: IpAddress, port: PortNum) -> Self { + // FIXME: If the socket is listening on an unspecified address (0.0.0.0), + // Linux will get the hash value by port only. + let hash = hash_addr_port(addr, port); + Self { addr, port, hash } + } + + pub(crate) const fn hash(&self) -> SocketHash { + self.hash + } +} + +impl From for ListenerKey { + fn from(listen_endpoint: IpListenEndpoint) -> Self { + let addr = listen_endpoint + .addr + .unwrap_or(IpAddress::Ipv4(Ipv4Addr::UNSPECIFIED)); + let port = listen_endpoint.port; + Self::new(addr, port) + } +} + +/// A unique key for identifying a `TcpConnection`. +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +pub(crate) struct ConnectionKey { + local_addr: IpAddress, + local_port: PortNum, + remote_addr: IpAddress, + remote_port: PortNum, + hash: SocketHash, +} + +impl ConnectionKey { + pub(crate) const fn new( + local_addr: IpAddress, + local_port: PortNum, + remote_addr: IpAddress, + remote_port: PortNum, + ) -> Self { + let hash = hash_local_remote(local_addr, local_port, remote_addr, remote_port); + Self { + local_addr, + local_port, + remote_addr, + remote_port, + hash, + } + } + + pub(crate) const fn hash(&self) -> SocketHash { + self.hash + } +} + +impl From<(IpEndpoint, IpEndpoint)> for ConnectionKey { + fn from(value: (IpEndpoint, IpEndpoint)) -> Self { + Self::new(value.0.addr, value.0.port, value.1.addr, value.1.port) + } +} + +// FIXME: The following two constants should be randomly-generated at runtime +const HASH_SECRET: u32 = 0xdeadbeef; + +// FIXME: This constant should be a per-net-namespace value +const NET_HASHMIX: u32 = 0xbeefdead; + +const fn hash_local_remote( + local_addr: IpAddress, + local_port: PortNum, + remote_addr: IpAddress, + remote_port: PortNum, +) -> SocketHash { + // FIXME: Deal with IPv6 addresses once IPv6 is supported. + let IpAddress::Ipv4(local_ipv4) = local_addr; + let IpAddress::Ipv4(remote_ipv4) = remote_addr; + + jhash_3vals( + local_ipv4.to_bits(), + remote_ipv4.to_bits(), + (local_port as u32).wrapping_shl(16) | remote_port as u32, + HASH_SECRET.wrapping_add(NET_HASHMIX), + ) +} + +const fn hash_addr_port(addr: IpAddress, port: PortNum) -> SocketHash { + // FIXME: Deal with IPv6 addresses once IPv6 is supported. + let IpAddress::Ipv4(ipv4_addr) = addr; + + jhash_1vals(ipv4_addr.to_bits(), NET_HASHMIX) ^ (port as u32) +} + +/// The socket table manages TCP and UDP sockets. +/// +/// Unlike the Linux inet hashtable, which is shared across a single network namespace, +/// this table is currently limited to a single interface. +/// +// TODO: Modify the table to be shared across a single network namespace +// to support INADDR_ANY (0.0.0.0). +pub(crate) struct SocketTable { + // TODO: Linux has two hashtables for listeners: + // the first is hashed by local address and port, + // the second is hashed by local port only. + // The second table is the only place where sockets listening on INADDR_ANY (0.0.0.0) can exist. + // Since we do not yet support INADDR_ANY, we only have the first table here. + listener_buckets: Box<[ListenerHashBucket]>, + connection_buckets: Box<[ConnectionHashBucket]>, + // Linux does not include UDP sockets in the inet hashtable. + // Here we include UDP sockets in the socket table for simplicity. + // Note that multiple UDP sockets can be bound to the same address, + // so we cannot use (addr, port) as a _unique_ key for UDP sockets. + udp_sockets: Vec>>, +} + +// On Linux, the number of buckets is determined at runtime based on the available memory. +// For simplicity, we use fixed values here. +// The bucket count should be a power of 2 to ensure efficient modulo calculations. +const LISTENER_BUCKET_COUNT: u32 = 64; +const LISTENER_BUCKET_MASK: u32 = LISTENER_BUCKET_COUNT - 1; +const CONNECTION_BUCKET_COUNT: u32 = 8192; +const CONNECTION_BUCKET_MASK: u32 = CONNECTION_BUCKET_COUNT - 1; + +const_assert!(LISTENER_BUCKET_COUNT.is_power_of_two()); +const_assert!(CONNECTION_BUCKET_COUNT.is_power_of_two()); + +impl SocketTable { + pub(crate) fn new() -> Self { + let listener_buckets = (0..LISTENER_BUCKET_COUNT) + .map(|_| ListenerHashBucket::new()) + .collect(); + + let connection_buckets = (0..CONNECTION_BUCKET_COUNT) + .map(|_| ConnectionHashBucket::new()) + .collect(); + + let udp_sockets = Vec::new(); + + Self { + listener_buckets, + connection_buckets, + udp_sockets, + } + } + + /// Inserts a TCP listener into the table. + /// + /// If a socket with the same [`ListenerKey`] has already been inserted, + /// this method will return an error and the listener will not be inserted. + pub(crate) fn insert_listener( + &mut self, + listener: Arc>, + ) -> Result<(), Arc>> { + let key = listener.listener_key(); + + let bucket = { + let hash = key.hash(); + let bucket_index = hash & LISTENER_BUCKET_MASK; + &mut self.listener_buckets[bucket_index as usize] + }; + + if bucket + .listeners + .iter() + .any(|tcp_listener| tcp_listener.listener_key() == listener.listener_key()) + { + return Err(listener); + } + + bucket.listeners.push(listener); + Ok(()) + } + + pub(crate) fn insert_connection( + &mut self, + connection: Arc>, + ) -> Result<(), Arc>> { + let key = connection.connection_key(); + + let bucket = { + let hash = key.hash(); + let bucket_index = hash & CONNECTION_BUCKET_MASK; + &mut self.connection_buckets[bucket_index as usize] + }; + + if bucket + .connections + .iter() + .any(|tcp_connection| tcp_connection.connection_key() == connection.connection_key()) + { + return Err(connection); + } + + bucket.connections.push(connection); + Ok(()) + } + + pub(crate) fn insert_udp_socket(&mut self, udp_socket: Arc>) { + debug_assert!(!self + .udp_sockets + .iter() + .any(|socket| Arc::ptr_eq(socket, &udp_socket))); + self.udp_sockets.push(udp_socket); + } + + pub(crate) fn lookup_listener(&self, key: &ListenerKey) -> Option<&Arc>> { + let bucket = { + let hash = key.hash(); + let bucket_index = hash & LISTENER_BUCKET_MASK; + &self.listener_buckets[bucket_index as usize] + }; + + bucket + .listeners + .iter() + .find(|listener| listener.listener_key() == key) + } + + pub(crate) fn lookup_connection( + &self, + key: &ConnectionKey, + ) -> Option<&Arc>> { + let bucket = { + let hash = key.hash(); + let bucket_index = hash & CONNECTION_BUCKET_MASK; + &self.connection_buckets[bucket_index as usize] + }; + + bucket + .connections + .iter() + .find(|connection| connection.connection_key() == key) + } + + pub(crate) fn remove_listener( + &mut self, + listener: &TcpListenerBg, + ) -> Option>> { + let key = listener.listener_key(); + + let bucket = { + let hash = key.hash(); + let bucket_index = hash & LISTENER_BUCKET_MASK; + &mut self.listener_buckets[bucket_index as usize] + }; + + let index = bucket + .listeners + .iter() + .position(|tcp_listener| tcp_listener.listener_key() == listener.listener_key())?; + Some(bucket.listeners.swap_remove(index)) + } + + pub(crate) fn remove_udp_socket( + &mut self, + socket: &Arc>, + ) -> Option>> { + let index = self + .udp_sockets + .iter() + .position(|udp_socket| Arc::ptr_eq(udp_socket, socket))?; + Some(self.udp_sockets.swap_remove(index)) + } + + pub(crate) fn remove_dead_tcp_connections(&mut self) { + for connection_bucket in self.connection_buckets.iter_mut() { + for tcp_conn in connection_bucket + .connections + .extract_if(|connection| connection.is_dead()) + { + tcp_conn.on_dead_events(); + } + } + } + + pub(crate) fn tcp_listener_iter(&self) -> impl Iterator>> { + self.listener_buckets + .iter() + .flat_map(|bucket| bucket.listeners.iter()) + } + + pub(crate) fn tcp_conn_iter(&self) -> impl Iterator>> { + self.connection_buckets + .iter() + .flat_map(|bucket| bucket.connections.iter()) + } + + pub(crate) fn udp_socket_iter(&self) -> impl Iterator>> { + self.udp_sockets.iter() + } +} + +impl Default for SocketTable { + fn default() -> Self { + Self::new() + } +} + +struct ListenerHashBucket { + listeners: Vec>>, +} + +impl ListenerHashBucket { + const fn new() -> Self { + Self { + listeners: Vec::new(), + } + } +} + +struct ConnectionHashBucket { + connections: Vec>>, +} + +impl ConnectionHashBucket { + const fn new() -> Self { + Self { + connections: Vec::new(), + } + } +} diff --git a/kernel/src/net/socket/ip/stream/connecting.rs b/kernel/src/net/socket/ip/stream/connecting.rs index 001b8ff36..1efc9eaeb 100644 --- a/kernel/src/net/socket/ip/stream/connecting.rs +++ b/kernel/src/net/socket/ip/stream/connecting.rs @@ -1,6 +1,7 @@ // SPDX-License-Identifier: MPL-2.0 use aster_bigtcp::{ + errors::tcp::ConnectError, socket::{ConnectState, RawTcpOption, RawTcpSetOption}, wire::IpEndpoint, }; @@ -30,22 +31,30 @@ impl ConnectingStream { option: &RawTcpOption, observer: StreamObserver, ) -> core::result::Result { - // The only reason this method might fail is because we're trying to connect to an - // unspecified address (i.e. 0.0.0.0). We currently have no support for binding to, - // listening on, or connecting to the unspecified address. - // - // We assume the remote will just refuse to connect, so we return `ECONNREFUSED`. let tcp_conn = match TcpConnection::new_connect(bound_port, remote_endpoint, option, observer) { Ok(tcp_conn) => tcp_conn, + Err((bound_port, ConnectError::AddressInUse)) => { + return Err(( + Error::with_message(Errno::EADDRNOTAVAIL, "connection key conflicts"), + bound_port, + )) + } Err((bound_port, _)) => { + // The only reason this method might go to this branch is because + // we're trying to connect to an unspecified address (i.e. 0.0.0.0). + // We currently have no support for binding to, + // listening on, or connecting to the unspecified address. + // + // We assume the remote will just refuse to connect, + // so we return `ECONNREFUSED`. return Err(( Error::with_message( Errno::ECONNREFUSED, "connecting to an unspecified address is not supported", ), bound_port, - )) + )); } }; diff --git a/kernel/src/net/socket/ip/stream/init.rs b/kernel/src/net/socket/ip/stream/init.rs index f39f9964a..6c37b3d28 100644 --- a/kernel/src/net/socket/ip/stream/init.rs +++ b/kernel/src/net/socket/ip/stream/init.rs @@ -88,7 +88,10 @@ impl InitStream { )); }; - Ok(ListenStream::new(bound_port, backlog, option, observer)) + match ListenStream::new(bound_port, backlog, option, observer) { + Ok(listen_stream) => Ok(listen_stream), + Err((bound_port, error)) => Err((error, Self::Bound(bound_port))), + } } pub fn local_endpoint(&self) -> Option { diff --git a/kernel/src/net/socket/ip/stream/listen.rs b/kernel/src/net/socket/ip/stream/listen.rs index 74c9256ff..686ee3027 100644 --- a/kernel/src/net/socket/ip/stream/listen.rs +++ b/kernel/src/net/socket/ip/stream/listen.rs @@ -1,6 +1,7 @@ // SPDX-License-Identifier: MPL-2.0 use aster_bigtcp::{ + errors::tcp::ListenError, socket::{RawTcpOption, RawTcpSetOption}, wire::IpEndpoint, }; @@ -22,18 +23,20 @@ impl ListenStream { backlog: usize, option: &RawTcpOption, observer: StreamObserver, - ) -> Self { + ) -> core::result::Result { const SOMAXCONN: usize = 4096; let max_conn = SOMAXCONN.min(backlog); - let tcp_listener = match TcpListener::new_listen(bound_port, max_conn, option, observer) { - Ok(tcp_listener) => tcp_listener, + match TcpListener::new_listen(bound_port, max_conn, option, observer) { + Ok(tcp_listener) => Ok(Self { tcp_listener }), + Err((bound_port, ListenError::AddressInUse)) => Err(( + bound_port, + Error::with_message(Errno::EADDRINUSE, "listener key conflicts"), + )), Err((_, err)) => { unreachable!("`new_listen` fails with {:?}, which should not happen", err) } - }; - - Self { tcp_listener } + } } pub fn try_accept(&self) -> Result { diff --git a/test/apps/network/tcp_err.c b/test/apps/network/tcp_err.c index 6e6738e5e..8c42c0e0e 100644 --- a/test/apps/network/tcp_err.c +++ b/test/apps/network/tcp_err.c @@ -478,3 +478,73 @@ FN_TEST(self_connect) TEST_SUCC(close(sk)); } END_TEST() + +FN_TEST(listen_at_the_same_address) +{ + int sk_listen1; + int sk_listen2; + + sk_listen1 = TEST_SUCC(socket(PF_INET, SOCK_STREAM, 0)); + sk_listen2 = TEST_SUCC(socket(PF_INET, SOCK_STREAM, 0)); + + int reuse_option = 1; + TEST_SUCC(setsockopt(sk_listen1, SOL_SOCKET, SO_REUSEADDR, + &reuse_option, sizeof(reuse_option))); + TEST_SUCC(setsockopt(sk_listen2, SOL_SOCKET, SO_REUSEADDR, + &reuse_option, sizeof(reuse_option))); + + sk_addr.sin_port = htons(8889); + TEST_SUCC( + bind(sk_listen1, (struct sockaddr *)&sk_addr, sizeof(sk_addr))); + TEST_SUCC( + bind(sk_listen2, (struct sockaddr *)&sk_addr, sizeof(sk_addr))); + + TEST_SUCC(listen(sk_listen1, 3)); + TEST_ERRNO(listen(sk_listen2, 3), EADDRINUSE); + + TEST_SUCC(close(sk_listen1)); + TEST_SUCC(close(sk_listen2)); +} +END_TEST() + +FN_TEST(bind_and_connect_same_address) +{ + int sk_listen; + int sk_connect1; + int sk_connect2; + + sk_listen = TEST_SUCC(socket(PF_INET, SOCK_STREAM, 0)); + sk_connect1 = TEST_SUCC(socket(PF_INET, SOCK_STREAM, 0)); + sk_connect2 = TEST_SUCC(socket(PF_INET, SOCK_STREAM, 0)); + + int reuse_option = 1; + TEST_SUCC(setsockopt(sk_connect1, SOL_SOCKET, SO_REUSEADDR, + &reuse_option, sizeof(reuse_option))); + TEST_SUCC(setsockopt(sk_connect2, SOL_SOCKET, SO_REUSEADDR, + &reuse_option, sizeof(reuse_option))); + + int listen_port = 8890; + int connect_port = 8891; + sk_addr.sin_port = htons(listen_port); + TEST_SUCC( + bind(sk_listen, (struct sockaddr *)&sk_addr, sizeof(sk_addr))); + sk_addr.sin_port = htons(connect_port); + TEST_SUCC(bind(sk_connect1, (struct sockaddr *)&sk_addr, + sizeof(sk_addr))); + TEST_SUCC(bind(sk_connect2, (struct sockaddr *)&sk_addr, + sizeof(sk_addr))); + + TEST_SUCC(listen(sk_listen, 3)); + + sk_addr.sin_port = htons(listen_port); + TEST_SUCC(connect(sk_connect1, (struct sockaddr *)&sk_addr, + sizeof(sk_addr))); + TEST_ERRNO(connect(sk_connect2, (struct sockaddr *)&sk_addr, + sizeof(sk_addr)), + EADDRNOTAVAIL); + + TEST_SUCC(close(sk_listen)); + TEST_SUCC(close(sk_connect1)); + TEST_SUCC(close(sk_connect2)); +} +END_TEST() \ No newline at end of file