diff --git a/Cargo.lock b/Cargo.lock index 9fc2da63e..e308e82a9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -74,6 +74,8 @@ dependencies = [ "keyable-arc", "ostd", "smoltcp", + "spin 0.9.8", + "takeable", ] [[package]] diff --git a/kernel/libs/aster-bigtcp/Cargo.toml b/kernel/libs/aster-bigtcp/Cargo.toml index dddbb5794..06c5bc49a 100644 --- a/kernel/libs/aster-bigtcp/Cargo.toml +++ b/kernel/libs/aster-bigtcp/Cargo.toml @@ -18,3 +18,5 @@ smoltcp = { git = "https://github.com/asterinas/smoltcp", tag = "r_2024-11-08_f0 "socket-udp", "socket-tcp", ] } +spin = "0.9.4" +takeable = "0.2.2" diff --git a/kernel/libs/aster-bigtcp/src/ext.rs b/kernel/libs/aster-bigtcp/src/ext.rs index ec39f68e9..13ca32108 100644 --- a/kernel/libs/aster-bigtcp/src/ext.rs +++ b/kernel/libs/aster-bigtcp/src/ext.rs @@ -15,7 +15,7 @@ pub trait Ext { type ScheduleNextPoll: ScheduleNextPoll; /// The type for TCP sockets to observe events. - type TcpEventObserver: SocketEventObserver; + type TcpEventObserver: SocketEventObserver + Clone; /// The type for UDP sockets to observe events. 
type UdpEventObserver: SocketEventObserver; diff --git a/kernel/libs/aster-bigtcp/src/iface/common.rs b/kernel/libs/aster-bigtcp/src/iface/common.rs index 3de05bbde..a11816be1 100644 --- a/kernel/libs/aster-bigtcp/src/iface/common.rs +++ b/kernel/libs/aster-bigtcp/src/iface/common.rs @@ -1,13 +1,13 @@ // SPDX-License-Identifier: MPL-2.0 use alloc::{ - boxed::Box, collections::{ btree_map::{BTreeMap, Entry}, btree_set::BTreeSet, }, string::String, sync::Arc, + vec::Vec, }; use keyable_arc::KeyableArc; @@ -15,7 +15,7 @@ use ostd::sync::{LocalIrqDisabled, SpinLock, SpinLockGuard}; use smoltcp::{ iface::{packet::Packet, Context}, phy::Device, - wire::{Ipv4Address, Ipv4Packet}, + wire::{IpAddress, IpEndpoint, Ipv4Address, Ipv4Packet}, }; use super::{ @@ -27,33 +27,40 @@ use super::{ use crate::{ errors::BindError, ext::Ext, - socket::{ - BoundTcpSocket, BoundTcpSocketInner, BoundUdpSocket, BoundUdpSocketInner, UnboundTcpSocket, - UnboundUdpSocket, - }, + socket::{TcpConnectionBg, TcpListenerBg, UdpSocketBg}, }; pub struct IfaceCommon { name: String, interface: SpinLock, used_ports: SpinLock, LocalIrqDisabled>, - tcp_sockets: SpinLock>>, LocalIrqDisabled>, - udp_sockets: SpinLock>>, LocalIrqDisabled>, + sockets: SpinLock, LocalIrqDisabled>, sched_poll: E::ScheduleNextPoll, } +pub(super) struct SocketSet { + pub(super) tcp_conn: BTreeSet>>, + pub(super) tcp_listen: BTreeSet>>, + pub(super) udp: BTreeSet>>, +} + impl IfaceCommon { pub(super) fn new( name: String, interface: smoltcp::iface::Interface, sched_poll: E::ScheduleNextPoll, ) -> Self { + let sockets = SocketSet { + tcp_conn: BTreeSet::new(), + tcp_listen: BTreeSet::new(), + udp: BTreeSet::new(), + }; + Self { name, interface: SpinLock::new(interface), used_ports: SpinLock::new(BTreeMap::new()), - tcp_sockets: SpinLock::new(BTreeSet::new()), - udp_sockets: SpinLock::new(BTreeSet::new()), + sockets: SpinLock::new(sockets), sched_poll, } } @@ -82,52 +89,13 @@ const IP_LOCAL_PORT_START: u16 = 32768; const 
IP_LOCAL_PORT_END: u16 = 60999; impl IfaceCommon { - pub(super) fn bind_tcp( + pub(super) fn bind( &self, iface: Arc>, - socket: Box, - observer: E::TcpEventObserver, config: BindPortConfig, - ) -> core::result::Result, (BindError, Box)> { - let port = match self.bind_port(config) { - Ok(port) => port, - Err(err) => return Err((err, socket)), - }; - - let raw_socket = socket.into_raw(); - let bound_socket = BoundTcpSocket::new(iface, port, raw_socket, observer); - - let inserted = self - .tcp_sockets - .lock() - .insert(KeyableArc::from(bound_socket.inner().clone())); - assert!(inserted); - - Ok(bound_socket) - } - - pub(super) fn bind_udp( - &self, - iface: Arc>, - socket: Box, - observer: E::UdpEventObserver, - config: BindPortConfig, - ) -> core::result::Result, (BindError, Box)> { - let port = match self.bind_port(config) { - Ok(port) => port, - Err(err) => return Err((err, socket)), - }; - - let raw_socket = socket.into_raw(); - let bound_socket = BoundUdpSocket::new(iface, port, raw_socket, observer); - - let inserted = self - .udp_sockets - .lock() - .insert(KeyableArc::from(bound_socket.inner().clone())); - assert!(inserted); - - Ok(bound_socket) + ) -> core::result::Result, BindError> { + let port = self.bind_port(config)?; + Ok(BoundPort { iface, port }) } /// Allocates an unused ephemeral port. @@ -171,29 +139,6 @@ impl IfaceCommon { Ok(port) } -} - -impl IfaceCommon { - #[allow(clippy::mutable_key_type)] - fn remove_dead_tcp_sockets(&self, sockets: &mut BTreeSet>>) { - sockets.retain(|socket| { - if socket.is_dead() { - self.release_port(socket.port()); - false - } else { - true - } - }); - } - - pub(crate) fn remove_udp_socket(&self, socket: &Arc>) { - let keyable_socket = KeyableArc::from(socket.clone()); - - let removed = self.udp_sockets.lock().remove(&keyable_socket); - assert!(removed); - - self.release_port(keyable_socket.port()); - } /// Releases the port so that it can be used again (if it is not being reused). 
fn release_port(&self, port: u16) { @@ -206,11 +151,50 @@ impl IfaceCommon { } } +impl IfaceCommon { + pub(crate) fn register_tcp_connection(&self, socket: KeyableArc>) { + let mut sockets = self.sockets.lock(); + let inserted = sockets.tcp_conn.insert(socket); + debug_assert!(inserted); + } + + pub(crate) fn register_tcp_listener(&self, socket: KeyableArc>) { + let mut sockets = self.sockets.lock(); + let inserted = sockets.tcp_listen.insert(socket); + debug_assert!(inserted); + } + + pub(crate) fn register_udp_socket(&self, socket: KeyableArc>) { + let mut sockets = self.sockets.lock(); + let inserted = sockets.udp.insert(socket); + debug_assert!(inserted); + } + + #[allow(clippy::mutable_key_type)] + fn remove_dead_tcp_connections(sockets: &mut BTreeSet>>) { + for socket in sockets.extract_if(|socket| socket.is_dead()) { + TcpConnectionBg::on_dead_events(socket); + } + } + + pub(crate) fn remove_tcp_listener(&self, socket: &KeyableArc>) { + let mut sockets = self.sockets.lock(); + let removed = sockets.tcp_listen.remove(socket); + debug_assert!(removed); + } + + pub(crate) fn remove_udp_socket(&self, socket: &KeyableArc>) { + let mut sockets = self.sockets.lock(); + let removed = sockets.udp.remove(socket); + debug_assert!(removed); + } +} + impl IfaceCommon { pub(super) fn poll( &self, device: &mut D, - process_phy: P, + mut process_phy: P, mut dispatch_phy: Q, ) -> Option where @@ -226,41 +210,85 @@ impl IfaceCommon { let mut interface = self.interface(); interface.context().now = get_network_timestamp(); - let mut tcp_sockets = self.tcp_sockets.lock(); - let udp_sockets = self.udp_sockets.lock(); + let mut sockets = self.sockets.lock(); - let mut context = PollContext::new(interface.context(), &tcp_sockets, &udp_sockets); - context.poll_ingress(device, process_phy, &mut dispatch_phy); - context.poll_egress(device, dispatch_phy); + loop { + let mut new_tcp_conns = Vec::new(); - tcp_sockets.iter().for_each(|socket| { - if socket.has_events() { - 
socket.on_events(); + let mut context = PollContext::new(interface.context(), &sockets, &mut new_tcp_conns); + context.poll_ingress(device, &mut process_phy, &mut dispatch_phy); + context.poll_egress(device, &mut dispatch_phy); + + // New packets sent by new connections are not handled. So if there are new + // connections, try again. + if new_tcp_conns.is_empty() { + break; + } else { + sockets.tcp_conn.extend(new_tcp_conns); } - }); - udp_sockets.iter().for_each(|socket| { - if socket.has_events() { - socket.on_events(); - } - }); - - self.remove_dead_tcp_sockets(&mut tcp_sockets); - - match ( - tcp_sockets - .iter() - .map(|socket| socket.next_poll_at_ms()) - .min(), - udp_sockets - .iter() - .map(|socket| socket.next_poll_at_ms()) - .min(), - ) { - (Some(tcp_poll_at), Some(udp_poll_at)) if tcp_poll_at <= udp_poll_at => { - Some(tcp_poll_at) - } - (tcp_poll_at, None) => tcp_poll_at, - (_, udp_poll_at) => udp_poll_at, } + + Self::remove_dead_tcp_connections(&mut sockets.tcp_conn); + + sockets.tcp_conn.iter().for_each(|socket| { + if socket.has_events() { + socket.on_events(); + } + }); + sockets.tcp_listen.iter().for_each(|socket| { + if socket.has_events() { + socket.on_events(); + } + }); + sockets.udp.iter().for_each(|socket| { + if socket.has_events() { + socket.on_events(); + } + }); + + // Note that only TCP connections can have timers set, so as far as the time to poll is + // concerned, we only need to consider TCP connections. + sockets + .tcp_conn + .iter() + .map(|socket| socket.next_poll_at_ms()) + .min() + } +} + +/// A port bound to an iface. +/// +/// When dropped, the port is automatically released. +// +// FIXME: TCP and UDP ports are independent. Find a way to track the protocol here. +pub struct BoundPort { + iface: Arc>, + port: u16, +} + +impl BoundPort { + /// Returns a reference to the iface. + pub fn iface(&self) -> &Arc> { + &self.iface + } + + /// Returns the port number. 
+ pub fn port(&self) -> u16 { + self.port + } + + /// Returns the bound endpoint. + pub fn endpoint(&self) -> Option { + let ip_addr = { + let ipv4_addr = self.iface().ipv4_addr()?; + IpAddress::Ipv4(ipv4_addr) + }; + Some(IpEndpoint::new(ip_addr, self.port)) + } +} + +impl Drop for BoundPort { + fn drop(&mut self) { + self.iface.common().release_port(self.port); } } diff --git a/kernel/libs/aster-bigtcp/src/iface/iface.rs b/kernel/libs/aster-bigtcp/src/iface/iface.rs index 6aa7865b7..eedf2c75e 100644 --- a/kernel/libs/aster-bigtcp/src/iface/iface.rs +++ b/kernel/libs/aster-bigtcp/src/iface/iface.rs @@ -1,15 +1,11 @@ // SPDX-License-Identifier: MPL-2.0 -use alloc::{boxed::Box, sync::Arc}; +use alloc::sync::Arc; use smoltcp::wire::Ipv4Address; -use super::port::BindPortConfig; -use crate::{ - errors::BindError, - ext::Ext, - socket::{BoundTcpSocket, BoundUdpSocket, UnboundTcpSocket, UnboundUdpSocket}, -}; +use super::{port::BindPortConfig, BoundPort}; +use crate::{errors::BindError, ext::Ext}; /// A network interface. /// @@ -34,24 +30,12 @@ impl dyn Iface { /// FIXME: The reason for binding the socket and the iface together is because there are /// limitations inside smoltcp. See discussion at /// . - pub fn bind_tcp( + pub fn bind( self: &Arc, - socket: Box, - observer: E::TcpEventObserver, config: BindPortConfig, - ) -> core::result::Result, (BindError, Box)> { + ) -> core::result::Result, BindError> { let common = self.common(); - common.bind_tcp(self.clone(), socket, observer, config) - } - - pub fn bind_udp( - self: &Arc, - socket: Box, - observer: E::UdpEventObserver, - config: BindPortConfig, - ) -> core::result::Result, (BindError, Box)> { - let common = self.common(); - common.bind_udp(self.clone(), socket, observer, config) + common.bind(self.clone(), config) } /// Gets the name of the iface. 
diff --git a/kernel/libs/aster-bigtcp/src/iface/mod.rs b/kernel/libs/aster-bigtcp/src/iface/mod.rs index 7c0608542..428d7a438 100644 --- a/kernel/libs/aster-bigtcp/src/iface/mod.rs +++ b/kernel/libs/aster-bigtcp/src/iface/mod.rs @@ -9,6 +9,7 @@ mod port; mod sched; mod time; +pub use common::BoundPort; pub use iface::Iface; pub use phy::{EtherIface, IpIface}; pub use port::BindPortConfig; diff --git a/kernel/libs/aster-bigtcp/src/iface/poll.rs b/kernel/libs/aster-bigtcp/src/iface/poll.rs index 729522f4b..97f6b5072 100644 --- a/kernel/libs/aster-bigtcp/src/iface/poll.rs +++ b/kernel/libs/aster-bigtcp/src/iface/poll.rs @@ -1,6 +1,6 @@ // SPDX-License-Identifier: MPL-2.0 -use alloc::{collections::btree_set::BTreeSet, vec}; +use alloc::{vec, vec::Vec}; use keyable_arc::KeyableArc; use smoltcp::{ @@ -16,28 +16,28 @@ use smoltcp::{ }, }; +use super::common::SocketSet; use crate::{ ext::Ext, - socket::{BoundTcpSocketInner, BoundUdpSocketInner, TcpProcessResult}, + socket::{TcpConnectionBg, TcpListenerBg, TcpProcessResult}, }; pub(super) struct PollContext<'a, E: Ext> { iface_cx: &'a mut Context, - tcp_sockets: &'a BTreeSet>>, - udp_sockets: &'a BTreeSet>>, + sockets: &'a SocketSet, + new_tcp_conns: &'a mut Vec>>, } impl<'a, E: Ext> PollContext<'a, E> { - #[allow(clippy::mutable_key_type)] pub(super) fn new( iface_cx: &'a mut Context, - tcp_sockets: &'a BTreeSet>>, - udp_sockets: &'a BTreeSet>>, + sockets: &'a SocketSet, + new_tcp_conns: &'a mut Vec>>, ) -> Self { Self { iface_cx, - tcp_sockets, - udp_sockets, + sockets, + new_tcp_conns, } } } @@ -51,7 +51,7 @@ impl PollContext<'_, E> { pub(super) fn poll_ingress( &mut self, device: &mut D, - mut process_phy: P, + process_phy: &mut P, dispatch_phy: &mut Q, ) where D: Device + ?Sized, @@ -158,12 +158,17 @@ impl PollContext<'_, E> { ip_repr: &IpRepr, tcp_repr: &TcpRepr, ) -> Option<(IpRepr, TcpRepr<'static>)> { - for socket in self.tcp_sockets.iter() { + for socket in self + .sockets + .tcp_conn + .iter() + 
.chain(self.new_tcp_conns.iter()) + { if !socket.can_process(tcp_repr.dst_port) { continue; } - match socket.process(self.iface_cx, ip_repr, tcp_repr) { + match TcpConnectionBg::process(socket, self.iface_cx, ip_repr, tcp_repr) { TcpProcessResult::NotProcessed => continue, TcpProcessResult::Processed => return None, TcpProcessResult::ProcessedWithReply(ip_repr, tcp_repr) => { @@ -172,6 +177,29 @@ impl PollContext<'_, E> { } } + if tcp_repr.control == TcpControl::Syn && tcp_repr.ack_number.is_none() { + for socket in self.sockets.tcp_listen.iter() { + if !socket.can_process(tcp_repr.dst_port) { + continue; + } + + let (processed, new_tcp_conn) = + TcpListenerBg::process(socket, self.iface_cx, ip_repr, tcp_repr); + + if let Some(tcp_conn) = new_tcp_conn { + self.new_tcp_conns.push(tcp_conn); + } + + match processed { + TcpProcessResult::NotProcessed => continue, + TcpProcessResult::Processed => return None, + TcpProcessResult::ProcessedWithReply(ip_repr, tcp_repr) => { + return Some((ip_repr, tcp_repr)) + } + } + } + } + // "In no case does receipt of a segment containing RST give rise to a RST in response." // See . 
if tcp_repr.control == TcpControl::Rst { @@ -211,7 +239,7 @@ impl PollContext<'_, E> { fn process_udp(&mut self, ip_repr: &IpRepr, udp_repr: &UdpRepr, udp_payload: &[u8]) -> bool { let mut processed = false; - for socket in self.udp_sockets.iter() { + for socket in self.sockets.udp.iter() { if !socket.can_process(udp_repr.dst_port) { continue; } @@ -284,13 +312,13 @@ impl PollContext<'_, E> { } impl PollContext<'_, E> { - pub(super) fn poll_egress(&mut self, device: &mut D, mut dispatch_phy: Q) + pub(super) fn poll_egress(&mut self, device: &mut D, dispatch_phy: &mut Q) where D: Device + ?Sized, Q: FnMut(&Packet, &mut Context, D::TxToken<'_>), { while let Some(tx_token) = device.transmit(self.iface_cx.now()) { - if !self.dispatch_ipv4(tx_token, &mut dispatch_phy) { + if !self.dispatch_ipv4(tx_token, dispatch_phy) { break; } } @@ -320,7 +348,9 @@ impl PollContext<'_, E> { let mut tx_token = Some(tx_token); let mut did_something = false; - for socket in self.tcp_sockets.iter() { + // We cannot dispatch packets from `new_tcp_conns` because we cannot borrow an immutable + // reference at this point. Instead, we will retry after the entire poll is complete. 
+ for socket in self.sockets.tcp_conn.iter() { if !socket.need_dispatch(self.iface_cx.now()) { continue; } @@ -331,37 +361,38 @@ impl PollContext<'_, E> { let mut deferred = None; - let reply = socket.dispatch(self.iface_cx, |cx, ip_repr, tcp_repr| { - let mut this = PollContext::new(cx, self.tcp_sockets, self.udp_sockets); + let reply = + TcpConnectionBg::dispatch(socket, self.iface_cx, |cx, ip_repr, tcp_repr| { + let mut this = PollContext::new(cx, self.sockets, self.new_tcp_conns); - if !this.is_unicast_local(ip_repr.dst_addr()) { - dispatch_phy( - &Packet::new(ip_repr.clone(), IpPayload::Tcp(*tcp_repr)), - this.iface_cx, - tx_token.take().unwrap(), - ); - return None; - } + if !this.is_unicast_local(ip_repr.dst_addr()) { + dispatch_phy( + &Packet::new(ip_repr.clone(), IpPayload::Tcp(*tcp_repr)), + this.iface_cx, + tx_token.take().unwrap(), + ); + return None; + } - if !socket.can_process(tcp_repr.dst_port) { - return this.process_tcp(ip_repr, tcp_repr); - } + if !socket.can_process(tcp_repr.dst_port) { + return this.process_tcp(ip_repr, tcp_repr); + } - // We cannot call `process_tcp` now because it may cause deadlocks. We will copy - // the packet and call `process_tcp` after releasing the socket lock. - deferred = Some((ip_repr.clone(), { - let mut data = vec![0; tcp_repr.buffer_len()]; - tcp_repr.emit( - &mut TcpPacket::new_unchecked(data.as_mut_slice()), - &ip_repr.src_addr(), - &ip_repr.dst_addr(), - &ChecksumCapabilities::ignored(), - ); - data - })); + // We cannot call `process_tcp` now because it may cause deadlocks. We will copy + // the packet and call `process_tcp` after releasing the socket lock. 
+ deferred = Some((ip_repr.clone(), { + let mut data = vec![0; tcp_repr.buffer_len()]; + tcp_repr.emit( + &mut TcpPacket::new_unchecked(data.as_mut_slice()), + &ip_repr.src_addr(), + &ip_repr.dst_addr(), + &ChecksumCapabilities::ignored(), + ); + data + })); - None - }); + None + }); match (deferred, reply) { (None, None) => (), @@ -411,7 +442,7 @@ impl PollContext<'_, E> { let mut tx_token = Some(tx_token); let mut did_something = false; - for socket in self.udp_sockets.iter() { + for socket in self.sockets.udp.iter() { if !socket.need_dispatch(self.iface_cx.now()) { continue; } @@ -423,7 +454,7 @@ impl PollContext<'_, E> { let mut deferred = None; socket.dispatch(self.iface_cx, |cx, ip_repr, udp_repr, udp_payload| { - let mut this = PollContext::new(cx, self.tcp_sockets, self.udp_sockets); + let mut this = PollContext::new(cx, self.sockets, self.new_tcp_conns); if ip_repr.dst_addr().is_broadcast() || !this.is_unicast_local(ip_repr.dst_addr()) { dispatch_phy( diff --git a/kernel/libs/aster-bigtcp/src/socket/bound.rs b/kernel/libs/aster-bigtcp/src/socket/bound.rs index cecc737b8..f0b25ed2e 100644 --- a/kernel/libs/aster-bigtcp/src/socket/bound.rs +++ b/kernel/libs/aster-bigtcp/src/socket/bound.rs @@ -1,75 +1,99 @@ // SPDX-License-Identifier: MPL-2.0 -use alloc::{boxed::Box, sync::Arc}; +use alloc::{boxed::Box, collections::btree_set::BTreeSet, sync::Arc, vec::Vec}; use core::{ + borrow::Borrow, ops::{Deref, DerefMut}, sync::atomic::{AtomicBool, AtomicU64, AtomicU8, Ordering}, }; -use ostd::sync::{LocalIrqDisabled, RwLock, SpinLock, SpinLockGuard, WriteIrqDisabled}; +use keyable_arc::KeyableArc; +use ostd::sync::{LocalIrqDisabled, SpinLock, SpinLockGuard}; use smoltcp::{ iface::Context, socket::{tcp::State, udp::UdpMetadata, PollAt}, time::{Duration, Instant}, - wire::{IpAddress, IpEndpoint, IpRepr, TcpControl, TcpRepr, UdpRepr}, + wire::{IpEndpoint, IpRepr, TcpControl, TcpRepr, UdpRepr}, }; +use spin::Once; +use takeable::Takeable; use super::{ 
event::{SocketEventObserver, SocketEvents}, - option::RawTcpSetOption, + option::{RawTcpOption, RawTcpSetOption}, + unbound::{new_tcp_socket, new_udp_socket}, RawTcpSocket, RawUdpSocket, TcpStateCheck, }; -use crate::{ext::Ext, iface::Iface}; +use crate::{ + ext::Ext, + iface::{BindPortConfig, BoundPort, Iface}, +}; -pub struct BoundSocket, E: Ext>(Arc>); +pub struct Socket, E: Ext>(Takeable>>); -/// [`TcpSocket`] or [`UdpSocket`]. -pub trait AnySocket { - type RawSocket; +impl, E: Ext> PartialEq for Socket { + fn eq(&self, other: &Self) -> bool { + self.0.eq(&other.0) + } +} +impl, E: Ext> Eq for Socket {} +impl, E: Ext> PartialOrd for Socket { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} +impl, E: Ext> Ord for Socket { + fn cmp(&self, other: &Self) -> core::cmp::Ordering { + self.0.cmp(&other.0) + } +} +impl, E: Ext> Borrow>> for Socket { + fn borrow(&self) -> &KeyableArc> { + self.0.as_ref() + } +} + +/// [`TcpConnectionInner`] or [`UdpSocketInner`]. +pub trait Inner { type Observer: SocketEventObserver; - /// Called by [`BoundSocket::new`]. - fn new(socket: Box) -> Self; - - /// Called by [`BoundSocket::drop`]. - fn on_drop(this: &Arc>) + /// Called by [`Socket::drop`]. + fn on_drop(this: &KeyableArc>) where E: Ext, Self: Sized; } -pub type BoundTcpSocket = BoundSocket; -pub type BoundUdpSocket = BoundSocket; +pub type TcpConnection = Socket, E>; +pub type TcpListener = Socket, E>; +pub type UdpSocket = Socket; -/// Common states shared by [`BoundTcpSocketInner`] and [`BoundUdpSocketInner`]. -pub struct BoundSocketInner, E> { - iface: Arc>, - port: u16, - socket: T, - observer: RwLock, +/// Common states shared by [`TcpConnectionBg`] and [`UdpSocketBg`]. +/// +/// In the type name, `Bg` means "background". Its meaning is described below: +/// - A foreground socket (e.g., [`TcpConnection`]) handles system calls from the user program. +/// - A background socket (e.g., [`TcpConnectionBg`]) handles packets from the network. 
+pub struct SocketBg, E: Ext> { + bound: BoundPort, + inner: T, + observer: Once, events: AtomicU8, next_poll_at_ms: AtomicU64, } -/// States needed by [`BoundTcpSocketInner`] but not [`BoundUdpSocketInner`]. -pub struct TcpSocket { - socket: SpinLock, +/// States needed by [`TcpConnectionBg`] but not [`UdpSocketBg`]. +pub struct TcpConnectionInner { + socket: SpinLock, LocalIrqDisabled>, is_dead: AtomicBool, } -struct RawTcpSocketExt { +struct RawTcpSocketExt { socket: Box, + listener: Option>>, has_connected: bool, - /// Whether the socket is in the background. - /// - /// A background socket is a socket with its corresponding [`BoundSocket`] dropped. This means - /// that no more user events (like `send`/`recv`) can reach the socket, but it can be in a - /// state of waiting for certain network events (e.g., remote FIN/ACK packets), so - /// [`BoundSocketInner`] may still be alive for a while. - in_background: bool, } -impl Deref for RawTcpSocketExt { +impl Deref for RawTcpSocketExt { type Target = RawTcpSocket; fn deref(&self) -> &Self::Target { @@ -77,18 +101,28 @@ impl Deref for RawTcpSocketExt { } } -impl DerefMut for RawTcpSocketExt { +impl DerefMut for RawTcpSocketExt { fn deref_mut(&mut self) -> &mut Self::Target { &mut self.socket } } -impl RawTcpSocketExt { - fn on_new_state(&mut self) -> SocketEvents { - if self.may_send() { +impl RawTcpSocketExt { + fn on_new_state(&mut self, this: &KeyableArc>) -> SocketEvents { + if self.may_send() && !self.has_connected { self.has_connected = true; + + if let Some(ref listener) = self.listener { + let mut backlog = listener.inner.lock(); + if let Some(value) = backlog.connecting.take(this) { + backlog.connected.push(value); + } + listener.add_events(SocketEvents::CAN_RECV); + } } + self.update_dead(this); + if self.is_peer_closed() { SocketEvents::PEER_CLOSED } else if self.is_closed() { @@ -97,148 +131,178 @@ impl RawTcpSocketExt { SocketEvents::empty() } } -} -impl TcpSocket { - fn lock(&self) -> SpinLockGuard { 
- self.socket.lock() - } - - /// Returns whether the TCP socket is dead. + /// Updates whether the TCP connection is dead. /// - /// See [`BoundTcpSocketInner::is_dead`] for the definition of dead TCP sockets. - fn is_dead(&self) -> bool { - self.is_dead.load(Ordering::Relaxed) - } - - /// Updates whether the TCP socket is dead. - /// - /// See [`BoundTcpSocketInner::is_dead`] for the definition of dead TCP sockets. + /// See [`TcpConnectionBg::is_dead`] for the definition of dead TCP connections. /// /// This method must be called after handling network events. However, it is not necessary to /// call this method after handling non-closing user events, because the socket can never be - /// dead if user events can reach the socket. - fn update_dead(&self, socket: &RawTcpSocketExt) { - if socket.in_background && socket.state() == smoltcp::socket::tcp::State::Closed { - self.is_dead.store(true, Ordering::Relaxed); + /// dead if it is not closed. + fn update_dead(&self, this: &KeyableArc>) { + if self.state() == smoltcp::socket::tcp::State::Closed { + this.inner.is_dead.store(true, Ordering::Relaxed); } - } - /// Sets the TCP socket in [`TimeWait`] state as dead. - /// - /// See [`BoundTcpSocketInner::is_dead`] for the definition of dead TCP sockets. - /// - /// [`TimeWait`]: smoltcp::socket::tcp::State::TimeWait - fn set_dead_timewait(&self, socket: &RawTcpSocketExt) { - debug_assert!( - socket.in_background && socket.state() == smoltcp::socket::tcp::State::TimeWait - ); - self.is_dead.store(true, Ordering::Relaxed); + // According to the current smoltcp implementation, a backlog socket will return back to + // the `Listen` state if the connection is RSTed before its establishment. + if self.state() == smoltcp::socket::tcp::State::Listen { + this.inner.is_dead.store(true, Ordering::Relaxed); + + if let Some(ref listener) = self.listener { + let mut backlog = listener.inner.lock(); + // This may fail due to race conditions, but it's fine. 
+ let _ = backlog.connecting.remove(this); + } + } } } -impl AnySocket for TcpSocket { - type RawSocket = RawTcpSocket; - type Observer = E::TcpEventObserver; - - fn new(socket: Box) -> Self { +impl TcpConnectionInner { + fn new(socket: Box, listener: Option>>) -> Self { let socket_ext = RawTcpSocketExt { socket, + listener, has_connected: false, - in_background: false, }; - Self { + TcpConnectionInner { socket: SpinLock::new(socket_ext), is_dead: AtomicBool::new(false), } } - fn on_drop(this: &Arc>) { - let mut socket = this.socket.lock(); + fn lock(&self) -> SpinLockGuard, LocalIrqDisabled> { + self.socket.lock() + } - socket.in_background = true; + /// Returns whether the TCP connection is dead. + /// + /// See [`TcpConnectionBg::is_dead`] for the definition of dead TCP connections. + fn is_dead(&self) -> bool { + self.is_dead.load(Ordering::Relaxed) + } + + /// Sets the TCP connection in [`TimeWait`] state as dead. + /// + /// See [`TcpConnectionBg::is_dead`] for the definition of dead TCP connections. + /// + /// [`TimeWait`]: smoltcp::socket::tcp::State::TimeWait + fn set_dead_timewait(&self, socket: &RawTcpSocketExt) { + debug_assert!(socket.state() == smoltcp::socket::tcp::State::TimeWait); + self.is_dead.store(true, Ordering::Relaxed); + } +} + +impl Inner for TcpConnectionInner { + type Observer = E::TcpEventObserver; + + fn on_drop(this: &KeyableArc>) { + let mut socket = this.inner.lock(); + + // FIXME: Send RSTs when there is unread data. socket.close(); - // A TCP socket may not be appropriate for immediate removal. We leave the removal decision - // to the polling logic. + // A TCP connection may not be appropriate for immediate removal. We leave the removal + // decision to the polling logic. this.update_next_poll_at_ms(PollAt::Now); - this.socket.update_dead(&socket); + socket.update_dead(this); } } -/// States needed by [`BoundUdpSocketInner`] but not [`BoundTcpSocketInner`]. 
-type UdpSocket = SpinLock, LocalIrqDisabled>; +pub struct TcpBacklog { + socket: Box, + max_conn: usize, + connecting: BTreeSet>, + connected: Vec>, +} -impl AnySocket for UdpSocket { - type RawSocket = RawUdpSocket; +pub type TcpListenerInner = SpinLock, LocalIrqDisabled>; + +impl Inner for TcpListenerInner { + type Observer = E::TcpEventObserver; + + fn on_drop(this: &KeyableArc>) { + // A TCP listener can be removed immediately. + this.bound.iface().common().remove_tcp_listener(this); + + let (connecting, connected) = { + let mut socket = this.inner.lock(); + ( + core::mem::take(&mut socket.connecting), + core::mem::take(&mut socket.connected), + ) + }; + + // The lock on `connecting`/`connected` cannot be locked after locking `self`, otherwise we + // might get a deadlock. due to inconsistent lock order problems. + // + // FIXME: Send RSTs instead of going through the normal socket close process. + drop(connecting); + drop(connected); + } +} + +/// States needed by [`UdpSocketBg`] but not [`TcpConnectionBg`]. +type UdpSocketInner = SpinLock, LocalIrqDisabled>; + +impl Inner for UdpSocketInner { type Observer = E::UdpEventObserver; - fn new(socket: Box) -> Self { - Self::new(socket) - } - - fn on_drop(this: &Arc>) - where - E: Ext, - { - this.socket.lock().close(); + fn on_drop(this: &KeyableArc>) { + this.inner.lock().close(); // A UDP socket can be removed immediately. 
- this.iface.common().remove_udp_socket(this); + this.bound.iface().common().remove_udp_socket(this); } } -impl, E: Ext> Drop for BoundSocket { +impl, E: Ext> Drop for Socket { fn drop(&mut self) { - T::on_drop(&self.0); + if self.0.is_usable() { + T::on_drop(&self.0); + } } } -pub(crate) type BoundTcpSocketInner = BoundSocketInner; -pub(crate) type BoundUdpSocketInner = BoundSocketInner; +pub(crate) type TcpConnectionBg = SocketBg, E>; +pub(crate) type TcpListenerBg = SocketBg, E>; +pub(crate) type UdpSocketBg = SocketBg; -impl, E: Ext> BoundSocket { - pub(crate) fn new( - iface: Arc>, - port: u16, - socket: Box, - observer: T::Observer, - ) -> Self { - Self(Arc::new(BoundSocketInner { - iface, - port, - socket: T::new(socket), - observer: RwLock::new(observer), +impl, E: Ext> Socket { + pub(crate) fn new(bound: BoundPort, inner: T) -> Self { + Self(Takeable::new(KeyableArc::new(SocketBg { + bound, + inner, + observer: Once::new(), events: AtomicU8::new(0), next_poll_at_ms: AtomicU64::new(u64::MAX), - })) + }))) } - pub(crate) fn inner(&self) -> &Arc> { + pub(crate) fn inner(&self) -> &KeyableArc> { &self.0 } } -impl, E: Ext> BoundSocket { - /// Sets the observer whose `on_events` will be called when certain iface events happen. +impl, E: Ext> Socket { + /// Initializes the observer whose `on_events` will be called when certain iface events happen. /// /// The caller needs to be responsible for race conditions if network events can occur /// simultaneously. - pub fn set_observer(&self, new_observer: T::Observer) { - *self.0.observer.write() = new_observer; + /// + /// Calling this method on a socket whose observer has already been initialized will have no + /// effect. 
+ pub fn init_observer(&self, new_observer: T::Observer) { + self.0.observer.call_once(|| new_observer); } pub fn local_endpoint(&self) -> Option { - let ip_addr = { - let ipv4_addr = self.0.iface.ipv4_addr()?; - IpAddress::Ipv4(ipv4_addr) - }; - Some(IpEndpoint::new(ip_addr, self.0.port)) + self.0.bound.endpoint() } pub fn iface(&self) -> &Arc> { - &self.0.iface + self.0.bound.iface() } } @@ -264,50 +328,76 @@ impl Deref for NeedIfacePoll { } } -impl BoundTcpSocket { +impl TcpConnection { /// Connects to a remote endpoint. /// /// Polling the iface is _always_ required after this method succeeds. - pub fn connect( - &self, + pub fn new_connect( + bound: BoundPort, remote_endpoint: IpEndpoint, - ) -> Result<(), smoltcp::socket::tcp::ConnectError> { - let common = self.iface().common(); - let mut iface = common.interface(); + option: &RawTcpOption, + observer: E::TcpEventObserver, + ) -> Result, smoltcp::socket::tcp::ConnectError)> { + let socket = { + let mut socket = new_tcp_socket(); - let mut socket = self.0.socket.lock(); + option.apply(&mut socket); - socket.connect(iface.context(), remote_endpoint, self.0.port)?; + let common = bound.iface().common(); + let mut iface = common.interface(); - socket.has_connected = false; - self.0.update_next_poll_at_ms(PollAt::Now); + if let Err(err) = socket.connect(iface.context(), remote_endpoint, bound.port()) { + drop(iface); + return Err((bound, err)); + } - Ok(()) + socket + }; + + let inner = TcpConnectionInner::new(socket, None); + + let connection = Self::new(bound, inner); + connection.0.update_next_poll_at_ms(PollAt::Now); + connection.init_observer(observer); + connection + .iface() + .common() + .register_tcp_connection(connection.inner().clone()); + + Ok(connection) } /// Returns the state of the connecting procedure. 
pub fn connect_state(&self) -> ConnectState { - let socket = self.0.socket.lock(); + let socket = self.0.inner.lock(); if socket.state() == State::SynSent || socket.state() == State::SynReceived { ConnectState::Connecting } else if socket.has_connected { ConnectState::Connected + } else if KeyableArc::strong_count(self.0.as_ref()) > 1 { + // Now we should return `ConnectState::Refused`. However, when we do this, we must + // guarantee that `into_bound_port` can succeed (see the method's doc comments). We can + // only guarantee this after we have removed all `Arc` in the iface's + // socket set. + // + // This branch serves to avoid a race condition: if the removal process hasn't + // finished, we will return `Connecting` so that the caller won't try to call + // `into_bound_port` (which may fail immediately). + ConnectState::Connecting } else { ConnectState::Refused } } - /// Listens at a specified endpoint. + /// Converts back to the [`BoundPort`]. /// - /// Polling the iface is _not_ required after this method succeeds. - pub fn listen( - &self, - local_endpoint: IpEndpoint, - ) -> Result<(), smoltcp::socket::tcp::ListenError> { - let mut socket = self.0.socket.lock(); - - socket.listen(local_endpoint) + /// This method will succeed if the connection is fully closed and no network events can reach + /// this connection. We guarantee that this method will always succeed if + /// [`Self::connect_state`] returns [`ConnectState::Refused`]. + pub fn into_bound_port(mut self) -> Option> { + let this: TcpConnectionBg = Arc::into_inner(self.0.take().into())?; + Some(this.bound) } /// Sends some data. 
@@ -320,7 +410,7 @@ impl BoundTcpSocket { let common = self.iface().common(); let mut iface = common.interface(); - let mut socket = self.0.socket.lock(); + let mut socket = self.0.inner.lock(); let result = socket.send(f)?; let need_poll = self @@ -340,7 +430,7 @@ impl BoundTcpSocket { let common = self.iface().common(); let mut iface = common.interface(); - let mut socket = self.0.socket.lock(); + let mut socket = self.0.inner.lock(); let result = socket.recv(f)?; let need_poll = self @@ -354,8 +444,9 @@ impl BoundTcpSocket { /// /// Polling the iface is _always_ required after this method succeeds. pub fn close(&self) { - let mut socket = self.0.socket.lock(); + let mut socket = self.0.inner.lock(); + socket.listener = None; socket.close(); self.0.update_next_poll_at_ms(PollAt::Now); } @@ -368,14 +459,14 @@ impl BoundTcpSocket { where F: FnOnce(&RawTcpSocket) -> R, { - let socket = self.0.socket.lock(); + let socket = self.0.inner.lock(); f(&socket) } } -impl RawTcpSetOption for BoundTcpSocket { - fn set_keep_alive(&mut self, interval: Option) -> NeedIfacePoll { - let mut socket = self.0.socket.lock(); +impl RawTcpSetOption for TcpConnection { + fn set_keep_alive(&self, interval: Option) -> NeedIfacePoll { + let mut socket = self.0.inner.lock(); socket.set_keep_alive(interval); if interval.is_some() { @@ -386,20 +477,130 @@ impl RawTcpSetOption for BoundTcpSocket { } } - fn set_nagle_enabled(&mut self, enabled: bool) { - let mut socket = self.0.socket.lock(); + fn set_nagle_enabled(&self, enabled: bool) { + let mut socket = self.0.inner.lock(); socket.set_nagle_enabled(enabled); } } -impl BoundUdpSocket { +impl TcpListener { + /// Listens at a specified endpoint. + /// + /// Polling the iface is _not_ required after this method succeeds. 
+ pub fn new_listen( + bound: BoundPort, + max_conn: usize, + option: &RawTcpOption, + observer: E::TcpEventObserver, + ) -> Result, smoltcp::socket::tcp::ListenError)> { + let Some(local_endpoint) = bound.endpoint() else { + return Err((bound, smoltcp::socket::tcp::ListenError::Unaddressable)); + }; + + let socket = { + let mut socket = new_tcp_socket(); + + option.apply(&mut socket); + + if let Err(err) = socket.listen(local_endpoint) { + return Err((bound, err)); + } + + socket + }; + + let inner = TcpListenerInner::new(TcpBacklog { + socket, + max_conn, + connecting: BTreeSet::new(), + connected: Vec::new(), + }); + + let listener = Self::new(bound, inner); + listener.init_observer(observer); + listener + .iface() + .common() + .register_tcp_listener(listener.inner().clone()); + + Ok(listener) + } + + /// Accepts a TCP connection. + /// + /// Polling the iface is _not_ required after this method succeeds. + pub fn accept(&self) -> Option<(TcpConnection, IpEndpoint)> { + let accepted = { + let mut backlog = self.0.inner.lock(); + backlog.connected.pop()? + }; + + let remote_endpoint = { + // The lock on `accepted` cannot be locked after locking `self`, otherwise we might get + // a deadlock due to inconsistent lock order problems. + let mut socket = accepted.0.inner.lock(); + + socket.listener = None; + socket.remote_endpoint() + }; + + Some((accepted, remote_endpoint.unwrap())) + } + + /// Returns whether there is a TCP connection to accept. + /// + /// It's the caller's responsibility to deal with race conditions when using this method.
+ pub fn can_accept(&self) -> bool { + !self.0.inner.lock().connected.is_empty() + } +} + +impl RawTcpSetOption for TcpListener { + fn set_keep_alive(&self, interval: Option) -> NeedIfacePoll { + let mut backlog = self.0.inner.lock(); + backlog.socket.set_keep_alive(interval); + + NeedIfacePoll::FALSE + } + + fn set_nagle_enabled(&self, enabled: bool) { + let mut backlog = self.0.inner.lock(); + backlog.socket.set_nagle_enabled(enabled); + } +} + +impl UdpSocket { /// Binds to a specified endpoint. /// /// Polling the iface is _not_ required after this method succeeds. - pub fn bind(&self, local_endpoint: IpEndpoint) -> Result<(), smoltcp::socket::udp::BindError> { - let mut socket = self.0.socket.lock(); + pub fn new_bind( + bound: BoundPort, + observer: E::UdpEventObserver, + ) -> Result, smoltcp::socket::udp::BindError)> { + let Some(local_endpoint) = bound.endpoint() else { + return Err((bound, smoltcp::socket::udp::BindError::Unaddressable)); + }; - socket.bind(local_endpoint) + let socket = { + let mut socket = new_udp_socket(); + + if let Err(err) = socket.bind(local_endpoint) { + return Err((bound, err)); + } + + socket + }; + + let inner = UdpSocketInner::new(socket); + + let socket = Self::new(bound, inner); + socket.init_observer(observer); + socket + .iface() + .common() + .register_udp_socket(socket.inner().clone()); + + Ok(socket) } /// Sends some data. 
@@ -418,7 +619,7 @@ impl BoundUdpSocket { use crate::errors::udp::SendError; - let mut socket = self.0.socket.lock(); + let mut socket = self.0.inner.lock(); if size > socket.packet_send_capacity() { return Err(SendError::TooLarge); @@ -442,7 +643,7 @@ impl BoundUdpSocket { where F: FnOnce(&[u8], UdpMetadata) -> R, { - let mut socket = self.0.socket.lock(); + let mut socket = self.0.inner.lock(); let (data, meta) = socket.recv()?; let result = f(data, meta); @@ -458,12 +659,12 @@ impl BoundUdpSocket { where F: FnOnce(&RawUdpSocket) -> R, { - let socket = self.0.socket.lock(); + let socket = self.0.inner.lock(); f(&socket) } } -impl, E> BoundSocketInner { +impl, E: Ext> SocketBg { pub(crate) fn has_events(&self) -> bool { self.events.load(Ordering::Relaxed) != 0 } @@ -474,8 +675,28 @@ impl, E> BoundSocketInner { let events = self.events.load(Ordering::Relaxed); self.events.store(0, Ordering::Relaxed); - let observer = self.observer.read(); - observer.on_events(SocketEvents::from_bits_truncate(events)); + if let Some(observer) = self.observer.get() { + observer.on_events(SocketEvents::from_bits_truncate(events)); + } + } + + pub(crate) fn on_dead_events(this: KeyableArc) + where + T::Observer: Clone, + { + // This method can only be called to process network events, so we assume we are holding the + // poll lock and no race conditions can occur. + let events = this.events.load(Ordering::Relaxed); + this.events.store(0, Ordering::Relaxed); + + let observer = this.observer.get().cloned(); + drop(this); + + // Notify dead events after the `Arc` is dropped to ensure the observer sees this event + // with the expected reference count. See `TcpConnection::connect_state` for an example. 
+ if let Some(ref observer) = observer { + observer.on_events(SocketEvents::from_bits_truncate(events)); + } } fn add_events(&self, new_events: SocketEvents) { @@ -498,7 +719,7 @@ impl, E> BoundSocketInner { /// /// The update is typically needed after new network or user events have been handled, so this /// method also marks that there may be new events, so that the event observer provided by - /// [`BoundSocket::set_observer`] can be notified later. + /// [`Socket::init_observer`] can be notified later. fn update_next_poll_at_ms(&self, poll_at: PollAt) -> NeedIfacePoll { match poll_at { PollAt::Now => { @@ -522,30 +743,23 @@ impl, E> BoundSocketInner { } } -impl, E> BoundSocketInner { - pub(crate) fn port(&self) -> u16 { - self.port - } -} - -impl BoundTcpSocketInner { - /// Returns whether the TCP socket is dead. +impl TcpConnectionBg { + /// Returns whether the TCP connection is dead. /// - /// A TCP socket is considered dead if and only if the following two conditions are met: - /// 1. The TCP connection is closed, so this socket cannot process any network events. - /// 2. The socket handle [`BoundTcpSocket`] is dropped, which means that this - /// [`BoundSocketInner`] is in background and no more user events can reach it. + /// A TCP connection is considered dead when and only when the TCP socket is in the closed + /// state, meaning it's no longer accepting packets from the network. This is different from + /// the socket file being closed, which only initiates the socket close process. pub(crate) fn is_dead(&self) -> bool { - self.socket.is_dead() + self.inner.is_dead() } } -impl, E> BoundSocketInner { +impl, E: Ext> SocketBg { /// Returns whether an incoming packet _may_ be processed by the socket. /// /// The check is intended to be lock-free and fast, but may have false positives. 
pub(crate) fn can_process(&self, dst_port: u16) -> bool { - self.port == dst_port + self.bound.port() == dst_port } /// Returns whether the socket _may_ generate an outgoing packet. @@ -563,15 +777,15 @@ pub(crate) enum TcpProcessResult { ProcessedWithReply(IpRepr, TcpRepr<'static>), } -impl BoundTcpSocketInner { +impl TcpConnectionBg { /// Tries to process an incoming packet and returns whether the packet is processed. pub(crate) fn process( - &self, + this: &KeyableArc, cx: &mut Context, ip_repr: &IpRepr, tcp_repr: &TcpRepr, ) -> TcpProcessResult { - let mut socket = self.socket.lock(); + let mut socket = this.inner.lock(); if !socket.accepts(cx, ip_repr, tcp_repr) { return TcpProcessResult::NotProcessed; @@ -594,7 +808,7 @@ impl BoundTcpSocketInner { && tcp_repr.control == TcpControl::Syn && tcp_repr.ack_number.is_none() { - self.socket.set_dead_timewait(&socket); + this.inner.set_dead_timewait(&socket); return TcpProcessResult::NotProcessed; } @@ -609,26 +823,25 @@ impl BoundTcpSocketInner { }; if socket.state() != old_state { - events |= socket.on_new_state(); + events |= socket.on_new_state(this); } - self.add_events(events); - self.update_next_poll_at_ms(socket.poll_at(cx)); - self.socket.update_dead(&socket); + this.add_events(events); + this.update_next_poll_at_ms(socket.poll_at(cx)); result } /// Tries to generate an outgoing packet and dispatches the generated packet. 
pub(crate) fn dispatch( - &self, + this: &KeyableArc, cx: &mut Context, dispatch: D, ) -> Option<(IpRepr, TcpRepr<'static>)> where D: FnOnce(&mut Context, &IpRepr, &TcpRepr) -> Option<(IpRepr, TcpRepr<'static>)>, { - let mut socket = self.socket.lock(); + let mut socket = this.inner.lock(); let old_state = socket.state(); let mut events = SocketEvents::empty(); @@ -652,18 +865,76 @@ impl BoundTcpSocketInner { } if socket.state() != old_state { - events |= socket.on_new_state(); + events |= socket.on_new_state(this); } - self.add_events(events); - self.update_next_poll_at_ms(socket.poll_at(cx)); - self.socket.update_dead(&socket); + this.add_events(events); + this.update_next_poll_at_ms(socket.poll_at(cx)); reply } } -impl BoundUdpSocketInner { +impl TcpListenerBg { + /// Tries to process an incoming packet and returns whether the packet is processed. + pub(crate) fn process( + this: &KeyableArc, + cx: &mut Context, + ip_repr: &IpRepr, + tcp_repr: &TcpRepr, + ) -> (TcpProcessResult, Option>>) { + let mut backlog = this.inner.lock(); + + if !backlog.socket.accepts(cx, ip_repr, tcp_repr) { + return (TcpProcessResult::NotProcessed, None); + } + + // FIXME: According to the Linux implementation, `max_conn` is the upper bound of + // `connected.len()`. We currently limit it to `connected.len() + connecting.len()` for + // simplicity. 
+ if backlog.connected.len() + backlog.connecting.len() >= backlog.max_conn { + return (TcpProcessResult::Processed, None); + } + + let result = match backlog.socket.process(cx, ip_repr, tcp_repr) { + None => TcpProcessResult::Processed, + Some((ip_repr, tcp_repr)) => TcpProcessResult::ProcessedWithReply(ip_repr, tcp_repr), + }; + + if backlog.socket.state() == smoltcp::socket::tcp::State::Listen { + return (result, None); + } + + let new_socket = { + let mut socket = new_tcp_socket(); + RawTcpOption::inherit(&backlog.socket, &mut socket); + socket.listen(backlog.socket.listen_endpoint()).unwrap(); + socket + }; + + let inner = TcpConnectionInner::new( + core::mem::replace(&mut backlog.socket, new_socket), + Some(this.clone().into()), + ); + let conn = TcpConnection::new( + this.bound + .iface() + .bind(BindPortConfig::CanReuse(this.bound.port())) + .unwrap(), + inner, + ); + let conn_bg = conn.inner().clone(); + + let inserted = backlog.connecting.insert(conn); + assert!(inserted); + + conn_bg.update_next_poll_at_ms(PollAt::Now); + + (result, Some(conn_bg)) + } +} + +impl UdpSocketBg { /// Tries to process an incoming packet and returns whether the packet is processed. 
pub(crate) fn process( &self, @@ -672,7 +943,7 @@ impl BoundUdpSocketInner { udp_repr: &UdpRepr, udp_payload: &[u8], ) -> bool { - let mut socket = self.socket.lock(); + let mut socket = self.inner.lock(); if !socket.accepts(cx, ip_repr, udp_repr) { return false; @@ -697,7 +968,7 @@ impl BoundUdpSocketInner { where D: FnOnce(&mut Context, &IpRepr, &UdpRepr, &[u8]), { - let mut socket = self.socket.lock(); + let mut socket = self.inner.lock(); socket .dispatch(cx, |cx, _meta, (ip_repr, udp_repr, udp_payload)| { diff --git a/kernel/libs/aster-bigtcp/src/socket/mod.rs b/kernel/libs/aster-bigtcp/src/socket/mod.rs index 293526713..136a635a9 100644 --- a/kernel/libs/aster-bigtcp/src/socket/mod.rs +++ b/kernel/libs/aster-bigtcp/src/socket/mod.rs @@ -6,15 +6,12 @@ mod option; mod state; mod unbound; -pub use bound::{BoundTcpSocket, BoundUdpSocket, ConnectState, NeedIfacePoll}; -pub(crate) use bound::{BoundTcpSocketInner, BoundUdpSocketInner, TcpProcessResult}; +pub use bound::{ConnectState, NeedIfacePoll, TcpConnection, TcpListener, UdpSocket}; +pub(crate) use bound::{TcpConnectionBg, TcpListenerBg, TcpProcessResult, UdpSocketBg}; pub use event::{SocketEventObserver, SocketEvents}; -pub use option::RawTcpSetOption; -pub use state::{TcpState, TcpStateCheck}; -pub use unbound::{ - UnboundTcpSocket, UnboundUdpSocket, TCP_RECV_BUF_LEN, TCP_SEND_BUF_LEN, UDP_RECV_PAYLOAD_LEN, - UDP_SEND_PAYLOAD_LEN, -}; +pub use option::{RawTcpOption, RawTcpSetOption}; +pub use state::TcpStateCheck; +pub use unbound::{TCP_RECV_BUF_LEN, TCP_SEND_BUF_LEN, UDP_RECV_PAYLOAD_LEN, UDP_SEND_PAYLOAD_LEN}; pub type RawTcpSocket = smoltcp::socket::tcp::Socket<'static>; pub type RawUdpSocket = smoltcp::socket::udp::Socket<'static>; diff --git a/kernel/libs/aster-bigtcp/src/socket/option.rs b/kernel/libs/aster-bigtcp/src/socket/option.rs index c5d79b0ef..e597a5ba8 100644 --- a/kernel/libs/aster-bigtcp/src/socket/option.rs +++ b/kernel/libs/aster-bigtcp/src/socket/option.rs @@ -2,20 +2,37 @@ use 
smoltcp::time::Duration; -use super::NeedIfacePoll; +use super::{NeedIfacePoll, RawTcpSocket}; /// A trait defines setting socket options on a raw socket. -/// -/// TODO: When `UnboundSocket` is removed, all methods in this trait can accept -/// `&self` instead of `&mut self` as parameter. pub trait RawTcpSetOption { /// Sets the keep alive interval. /// /// Polling the iface _may_ be required after this method succeeds. - fn set_keep_alive(&mut self, interval: Option) -> NeedIfacePoll; + fn set_keep_alive(&self, interval: Option) -> NeedIfacePoll; /// Enables or disables Nagle’s Algorithm. /// - /// Polling the iface is not required after this method succeeds. - fn set_nagle_enabled(&mut self, enabled: bool); + /// Polling the iface is _not_ required after this method succeeds. + fn set_nagle_enabled(&self, enabled: bool); +} + +/// Socket options on a raw socket. +pub struct RawTcpOption { + /// The keep alive interval. + pub keep_alive: Option, + /// Whether Nagle's algorithm is enabled. 
+ pub is_nagle_enabled: bool, +} + +impl RawTcpOption { + pub(super) fn apply(&self, socket: &mut RawTcpSocket) { + socket.set_keep_alive(self.keep_alive); + socket.set_nagle_enabled(self.is_nagle_enabled); + } + + pub(super) fn inherit(from: &RawTcpSocket, to: &mut RawTcpSocket) { + to.set_keep_alive(from.keep_alive()); + to.set_nagle_enabled(from.nagle_enabled()); + } } diff --git a/kernel/libs/aster-bigtcp/src/socket/unbound.rs b/kernel/libs/aster-bigtcp/src/socket/unbound.rs index b46dcb0ef..724298c80 100644 --- a/kernel/libs/aster-bigtcp/src/socket/unbound.rs +++ b/kernel/libs/aster-bigtcp/src/socket/unbound.rs @@ -2,75 +2,31 @@ use alloc::{boxed::Box, vec}; -use super::{option::RawTcpSetOption, NeedIfacePoll, RawTcpSocket, RawUdpSocket}; +use super::{RawTcpSocket, RawUdpSocket}; -pub struct UnboundSocket { - socket: Box, +pub(super) fn new_tcp_socket() -> Box { + let raw_tcp_socket = { + let rx_buffer = smoltcp::socket::tcp::SocketBuffer::new(vec![0u8; TCP_RECV_BUF_LEN]); + let tx_buffer = smoltcp::socket::tcp::SocketBuffer::new(vec![0u8; TCP_SEND_BUF_LEN]); + RawTcpSocket::new(rx_buffer, tx_buffer) + }; + Box::new(raw_tcp_socket) } -pub type UnboundTcpSocket = UnboundSocket; -pub type UnboundUdpSocket = UnboundSocket; - -impl UnboundTcpSocket { - pub fn new() -> Self { - let raw_tcp_socket = { - let rx_buffer = smoltcp::socket::tcp::SocketBuffer::new(vec![0u8; TCP_RECV_BUF_LEN]); - let tx_buffer = smoltcp::socket::tcp::SocketBuffer::new(vec![0u8; TCP_SEND_BUF_LEN]); - RawTcpSocket::new(rx_buffer, tx_buffer) - }; - Self { - socket: Box::new(raw_tcp_socket), - } - } -} - -impl Default for UnboundTcpSocket { - fn default() -> Self { - Self::new() - } -} - -impl RawTcpSetOption for UnboundTcpSocket { - fn set_keep_alive(&mut self, interval: Option) -> NeedIfacePoll { - self.socket.set_keep_alive(interval); - NeedIfacePoll::FALSE - } - - fn set_nagle_enabled(&mut self, enabled: bool) { - self.socket.set_nagle_enabled(enabled); - } -} - -impl UnboundUdpSocket { - 
pub fn new() -> Self { - let raw_udp_socket = { - let metadata = smoltcp::socket::udp::PacketMetadata::EMPTY; - let rx_buffer = smoltcp::socket::udp::PacketBuffer::new( - vec![metadata; UDP_METADATA_LEN], - vec![0u8; UDP_RECV_PAYLOAD_LEN], - ); - let tx_buffer = smoltcp::socket::udp::PacketBuffer::new( - vec![metadata; UDP_METADATA_LEN], - vec![0u8; UDP_SEND_PAYLOAD_LEN], - ); - RawUdpSocket::new(rx_buffer, tx_buffer) - }; - Self { - socket: Box::new(raw_udp_socket), - } - } -} - -impl Default for UnboundUdpSocket { - fn default() -> Self { - Self::new() - } -} - -impl UnboundSocket { - pub(crate) fn into_raw(self) -> Box { - self.socket - } +pub(super) fn new_udp_socket() -> Box { + let raw_udp_socket = { + let metadata = smoltcp::socket::udp::PacketMetadata::EMPTY; + let rx_buffer = smoltcp::socket::udp::PacketBuffer::new( + vec![metadata; UDP_METADATA_LEN], + vec![0u8; UDP_RECV_PAYLOAD_LEN], + ); + let tx_buffer = smoltcp::socket::udp::PacketBuffer::new( + vec![metadata; UDP_METADATA_LEN], + vec![0u8; UDP_SEND_PAYLOAD_LEN], + ); + RawUdpSocket::new(rx_buffer, tx_buffer) + }; + Box::new(raw_udp_socket) } // TCP socket buffer sizes: diff --git a/kernel/libs/keyable-arc/src/lib.rs b/kernel/libs/keyable-arc/src/lib.rs index b8800ceae..ea24d11b4 100644 --- a/kernel/libs/keyable-arc/src/lib.rs +++ b/kernel/libs/keyable-arc/src/lib.rs @@ -138,9 +138,22 @@ impl KeyableArc { } /// Creates a new `KeyableWeak` pointer to this allocation. + #[inline] pub fn downgrade(this: &Self) -> KeyableWeak { Arc::downgrade(&this.0).into() } + + /// Gets the number of strong pointers pointing to this allocation. + #[inline] + pub fn strong_count(this: &Self) -> usize { + Arc::strong_count(&this.0) + } + + /// Gets the number of weak pointers pointing to this allocation. 
+ #[inline] + pub fn weak_count(this: &Self) -> usize { + Arc::weak_count(&this.0) + } } impl Deref for KeyableArc { diff --git a/kernel/src/net/iface/mod.rs b/kernel/src/net/iface/mod.rs index 76bd7d73e..6284b015f 100644 --- a/kernel/src/net/iface/mod.rs +++ b/kernel/src/net/iface/mod.rs @@ -9,5 +9,8 @@ pub use init::{init, IFACES}; pub use poll::lazy_init; pub type Iface = dyn aster_bigtcp::iface::Iface; -pub type BoundTcpSocket = aster_bigtcp::socket::BoundTcpSocket; -pub type BoundUdpSocket = aster_bigtcp::socket::BoundUdpSocket; +pub type BoundPort = aster_bigtcp::iface::BoundPort; + +pub type TcpConnection = aster_bigtcp::socket::TcpConnection; +pub type TcpListener = aster_bigtcp::socket::TcpListener; +pub type UdpSocket = aster_bigtcp::socket::UdpSocket; diff --git a/kernel/src/net/socket/ip/common.rs b/kernel/src/net/socket/ip/common.rs index bd081ad20..ce3761a3a 100644 --- a/kernel/src/net/socket/ip/common.rs +++ b/kernel/src/net/socket/ip/common.rs @@ -7,7 +7,7 @@ use aster_bigtcp::{ }; use crate::{ - net::iface::{Iface, IFACES}, + net::iface::{BoundPort, Iface, IFACES}, prelude::*, }; @@ -45,30 +45,20 @@ fn get_ephemeral_iface(remote_ip_addr: &IpAddress) -> Arc { ifaces[0].clone() } -pub(super) fn bind_socket( - unbound_socket: Box, - endpoint: &IpEndpoint, - can_reuse: bool, - bind: impl FnOnce( - Arc, - Box, - BindPortConfig, - ) -> core::result::Result)>, -) -> core::result::Result)> { +pub(super) fn bind_port(endpoint: &IpEndpoint, can_reuse: bool) -> Result { let iface = match get_iface_to_bind(&endpoint.addr) { Some(iface) => iface, None => { - let err = Error::with_message( + return_errno_with_message!( Errno::EADDRNOTAVAIL, - "the address is not available from the local machine", + "the address is not available from the local machine" ); - return Err((err, unbound_socket)); } }; let bind_port_config = BindPortConfig::new(endpoint.port, can_reuse); - bind(iface, unbound_socket, bind_port_config).map_err(|(err, unbound)| (err.into(), unbound)) + 
Ok(iface.bind(bind_port_config)?) } impl From for Error { diff --git a/kernel/src/net/socket/ip/datagram/bound.rs b/kernel/src/net/socket/ip/datagram/bound.rs index 8cb82b06e..b44a2d308 100644 --- a/kernel/src/net/socket/ip/datagram/bound.rs +++ b/kernel/src/net/socket/ip/datagram/bound.rs @@ -8,7 +8,7 @@ use aster_bigtcp::{ use crate::{ events::IoEvents, net::{ - iface::{BoundUdpSocket, Iface}, + iface::{Iface, UdpSocket}, socket::util::send_recv_flags::SendRecvFlags, }, prelude::*, @@ -16,12 +16,12 @@ use crate::{ }; pub struct BoundDatagram { - bound_socket: BoundUdpSocket, + bound_socket: UdpSocket, remote_endpoint: Option, } impl BoundDatagram { - pub fn new(bound_socket: BoundUdpSocket) -> Self { + pub fn new(bound_socket: UdpSocket) -> Self { Self { bound_socket, remote_endpoint: None, diff --git a/kernel/src/net/socket/ip/datagram/unbound.rs b/kernel/src/net/socket/ip/datagram/unbound.rs index 29d56c4bb..2ba33dea7 100644 --- a/kernel/src/net/socket/ip/datagram/unbound.rs +++ b/kernel/src/net/socket/ip/datagram/unbound.rs @@ -1,19 +1,17 @@ // SPDX-License-Identifier: MPL-2.0 -use aster_bigtcp::{socket::UnboundUdpSocket, wire::IpEndpoint}; +use aster_bigtcp::{socket::UdpSocket, wire::IpEndpoint}; use super::{bound::BoundDatagram, DatagramObserver}; -use crate::{events::IoEvents, net::socket::ip::common::bind_socket, prelude::*}; +use crate::{events::IoEvents, net::socket::ip::common::bind_port, prelude::*}; pub struct UnboundDatagram { - unbound_socket: Box, + _private: (), } impl UnboundDatagram { pub fn new() -> Self { - Self { - unbound_socket: Box::new(UnboundUdpSocket::new()), - } + Self { _private: () } } pub fn bind( @@ -22,18 +20,17 @@ impl UnboundDatagram { can_reuse: bool, observer: DatagramObserver, ) -> core::result::Result { - let bound_socket = match bind_socket( - self.unbound_socket, - endpoint, - can_reuse, - |iface, socket, config| iface.bind_udp(socket, observer, config), - ) { - Ok(bound_socket) => bound_socket, - Err((err, 
unbound_socket)) => return Err((err, Self { unbound_socket })), + let bound_port = match bind_port(endpoint, can_reuse) { + Ok(bound_port) => bound_port, + Err(err) => return Err((err, self)), }; - let bound_endpoint = bound_socket.local_endpoint().unwrap(); - bound_socket.bind(bound_endpoint).unwrap(); + let bound_socket = match UdpSocket::new_bind(bound_port, observer) { + Ok(bound_socket) => bound_socket, + Err((_, err)) => { + unreachable!("`new_bind fails with {:?}, which should not happen", err) + } + }; Ok(BoundDatagram::new(bound_socket)) } diff --git a/kernel/src/net/socket/ip/stream/connected.rs b/kernel/src/net/socket/ip/stream/connected.rs index bb673b2ab..71f705882 100644 --- a/kernel/src/net/socket/ip/stream/connected.rs +++ b/kernel/src/net/socket/ip/stream/connected.rs @@ -12,7 +12,7 @@ use super::StreamObserver; use crate::{ events::IoEvents, net::{ - iface::{BoundTcpSocket, Iface}, + iface::{Iface, TcpConnection}, socket::util::{send_recv_flags::SendRecvFlags, shutdown_cmd::SockShutdownCmd}, }, prelude::*, @@ -21,7 +21,7 @@ use crate::{ }; pub struct ConnectedStream { - bound_socket: BoundTcpSocket, + tcp_conn: TcpConnection, remote_endpoint: IpEndpoint, /// Indicates whether this connection is "new" in a `connect()` system call. 
/// @@ -47,12 +47,12 @@ pub struct ConnectedStream { impl ConnectedStream { pub fn new( - bound_socket: BoundTcpSocket, + tcp_conn: TcpConnection, remote_endpoint: IpEndpoint, is_new_connection: bool, ) -> Self { Self { - bound_socket, + tcp_conn, remote_endpoint, is_new_connection, is_receiving_closed: AtomicBool::new(false), @@ -70,7 +70,7 @@ impl ConnectedStream { if cmd.shut_write() { self.is_sending_closed.store(true, Ordering::Relaxed); - self.bound_socket.close(); + self.tcp_conn.close(); events |= IoEvents::OUT | IoEvents::HUP; } @@ -84,7 +84,7 @@ impl ConnectedStream { writer: &mut dyn MultiWrite, _flags: SendRecvFlags, ) -> Result<(usize, NeedIfacePoll)> { - let result = self.bound_socket.recv(|socket_buffer| { + let result = self.tcp_conn.recv(|socket_buffer| { match writer.write(&mut VmReader::from(&*socket_buffer)) { Ok(len) => (len, Ok(len)), Err(e) => (0, Err(e)), @@ -116,7 +116,7 @@ impl ConnectedStream { reader: &mut dyn MultiRead, _flags: SendRecvFlags, ) -> Result<(usize, NeedIfacePoll)> { - let result = self.bound_socket.send(|socket_buffer| { + let result = self.tcp_conn.send(|socket_buffer| { match reader.read(&mut VmWriter::from(socket_buffer)) { Ok(len) => (len, Ok(len)), Err(e) => (0, Err(e)), @@ -143,7 +143,7 @@ impl ConnectedStream { } pub fn local_endpoint(&self) -> IpEndpoint { - self.bound_socket.local_endpoint().unwrap() + self.tcp_conn.local_endpoint().unwrap() } pub fn remote_endpoint(&self) -> IpEndpoint { @@ -151,7 +151,7 @@ impl ConnectedStream { } pub fn iface(&self) -> &Arc { - self.bound_socket.iface() + self.tcp_conn.iface() } pub fn check_new(&mut self) -> Result<()> { @@ -163,8 +163,12 @@ impl ConnectedStream { Ok(()) } + pub(super) fn init_observer(&self, observer: StreamObserver) { + self.tcp_conn.init_observer(observer); + } + pub(super) fn check_io_events(&self) -> IoEvents { - self.bound_socket.raw_with(|socket| { + self.tcp_conn.raw_with(|socket| { if socket.is_peer_closed() { // Only the sending side of peer socket 
is closed self.is_receiving_closed.store(true, Ordering::Relaxed); @@ -202,18 +206,14 @@ impl ConnectedStream { }) } - pub(super) fn set_observer(&self, observer: StreamObserver) { - self.bound_socket.set_observer(observer) - } - pub(super) fn set_raw_option( - &mut self, - set_option: impl Fn(&mut dyn RawTcpSetOption) -> R, + &self, + set_option: impl FnOnce(&dyn RawTcpSetOption) -> R, ) -> R { - set_option(&mut self.bound_socket) + set_option(&self.tcp_conn) } pub(super) fn raw_with(&self, f: impl FnOnce(&RawTcpSocket) -> R) -> R { - self.bound_socket.raw_with(f) + self.tcp_conn.raw_with(f) } } diff --git a/kernel/src/net/socket/ip/stream/connecting.rs b/kernel/src/net/socket/ip/stream/connecting.rs index e73d1fd0f..001b8ff36 100644 --- a/kernel/src/net/socket/ip/stream/connecting.rs +++ b/kernel/src/net/socket/ip/stream/connecting.rs @@ -1,19 +1,19 @@ // SPDX-License-Identifier: MPL-2.0 use aster_bigtcp::{ - socket::{ConnectState, RawTcpSetOption}, + socket::{ConnectState, RawTcpOption, RawTcpSetOption}, wire::IpEndpoint, }; -use super::{connected::ConnectedStream, init::InitStream}; +use super::{connected::ConnectedStream, init::InitStream, StreamObserver}; use crate::{ events::IoEvents, - net::iface::{BoundTcpSocket, Iface}, + net::iface::{BoundPort, Iface, TcpConnection}, prelude::*, }; pub struct ConnectingStream { - bound_socket: BoundTcpSocket, + tcp_conn: TcpConnection, remote_endpoint: IpEndpoint, } @@ -25,32 +25,38 @@ pub enum ConnResult { impl ConnectingStream { pub fn new( - bound_socket: BoundTcpSocket, + bound_port: BoundPort, remote_endpoint: IpEndpoint, - ) -> core::result::Result { + option: &RawTcpOption, + observer: StreamObserver, + ) -> core::result::Result { // The only reason this method might fail is because we're trying to connect to an // unspecified address (i.e. 0.0.0.0). We currently have no support for binding to, // listening on, or connecting to the unspecified address. 
// // We assume the remote will just refuse to connect, so we return `ECONNREFUSED`. - if bound_socket.connect(remote_endpoint).is_err() { - return Err(( - Error::with_message( - Errno::ECONNREFUSED, - "connecting to an unspecified address is not supported", - ), - bound_socket, - )); - } + let tcp_conn = + match TcpConnection::new_connect(bound_port, remote_endpoint, option, observer) { + Ok(tcp_conn) => tcp_conn, + Err((bound_port, _)) => { + return Err(( + Error::with_message( + Errno::ECONNREFUSED, + "connecting to an unspecified address is not supported", + ), + bound_port, + )) + } + }; Ok(Self { - bound_socket, + tcp_conn, remote_endpoint, }) } pub fn has_result(&self) -> bool { - match self.bound_socket.connect_state() { + match self.tcp_conn.connect_state() { ConnectState::Connecting => false, ConnectState::Connected => true, ConnectState::Refused => true, @@ -58,21 +64,23 @@ impl ConnectingStream { } pub fn into_result(self) -> ConnResult { - let next_state = self.bound_socket.connect_state(); + let next_state = self.tcp_conn.connect_state(); match next_state { ConnectState::Connecting => ConnResult::Connecting(self), ConnectState::Connected => ConnResult::Connected(ConnectedStream::new( - self.bound_socket, + self.tcp_conn, self.remote_endpoint, true, )), - ConnectState::Refused => ConnResult::Refused(InitStream::new_bound(self.bound_socket)), + ConnectState::Refused => ConnResult::Refused(InitStream::new_bound( + self.tcp_conn.into_bound_port().unwrap(), + )), } } pub fn local_endpoint(&self) -> IpEndpoint { - self.bound_socket.local_endpoint().unwrap() + self.tcp_conn.local_endpoint().unwrap() } pub fn remote_endpoint(&self) -> IpEndpoint { @@ -80,7 +88,7 @@ impl ConnectingStream { } pub fn iface(&self) -> &Arc { - self.bound_socket.iface() + self.tcp_conn.iface() } pub(super) fn check_io_events(&self) -> IoEvents { @@ -88,9 +96,9 @@ impl ConnectingStream { } pub(super) fn set_raw_option( - &mut self, - set_option: impl Fn(&mut dyn RawTcpSetOption) -> 
R, + &self, + set_option: impl FnOnce(&dyn RawTcpSetOption) -> R, ) -> R { - set_option(&mut self.bound_socket) + set_option(&self.tcp_conn) } } diff --git a/kernel/src/net/socket/ip/stream/init.rs b/kernel/src/net/socket/ip/stream/init.rs index 17ebdeb0e..f39f9964a 100644 --- a/kernel/src/net/socket/ip/stream/init.rs +++ b/kernel/src/net/socket/ip/stream/init.rs @@ -1,43 +1,38 @@ // SPDX-License-Identifier: MPL-2.0 -use aster_bigtcp::{ - socket::{RawTcpSetOption, UnboundTcpSocket}, - wire::IpEndpoint, -}; +use aster_bigtcp::{socket::RawTcpOption, wire::IpEndpoint}; use super::{connecting::ConnectingStream, listen::ListenStream, StreamObserver}; use crate::{ events::IoEvents, net::{ - iface::BoundTcpSocket, - socket::ip::common::{bind_socket, get_ephemeral_endpoint}, + iface::BoundPort, + socket::ip::common::{bind_port, get_ephemeral_endpoint}, }, prelude::*, - process::signal::Pollee, }; pub enum InitStream { - Unbound(Box), - Bound(BoundTcpSocket), + Unbound, + Bound(BoundPort), } impl InitStream { pub fn new() -> Self { - InitStream::Unbound(Box::new(UnboundTcpSocket::new())) + InitStream::Unbound } - pub fn new_bound(bound_socket: BoundTcpSocket) -> Self { - InitStream::Bound(bound_socket) + pub fn new_bound(bound_port: BoundPort) -> Self { + InitStream::Bound(bound_port) } pub fn bind( self, endpoint: &IpEndpoint, can_reuse: bool, - observer: StreamObserver, - ) -> core::result::Result { - let unbound_socket = match self { - InitStream::Unbound(unbound_socket) => unbound_socket, + ) -> core::result::Result { + match self { + InitStream::Unbound => (), InitStream::Bound(bound_socket) => { return Err(( Error::with_message(Errno::EINVAL, "the socket is already bound to an address"), @@ -45,48 +40,45 @@ impl InitStream { )); } }; - let bound_socket = match bind_socket( - unbound_socket, - endpoint, - can_reuse, - |iface, socket, config| iface.bind_tcp(socket, observer, config), - ) { - Ok(bound_socket) => bound_socket, - Err((err, unbound_socket)) => return 
Err((err, InitStream::Unbound(unbound_socket))), + + let bound_port = match bind_port(endpoint, can_reuse) { + Ok(bound_port) => bound_port, + Err(err) => return Err((err, Self::Unbound)), }; - Ok(bound_socket) + + Ok(bound_port) } fn bind_to_ephemeral_endpoint( self, remote_endpoint: &IpEndpoint, - observer: StreamObserver, - ) -> core::result::Result { + ) -> core::result::Result { let endpoint = get_ephemeral_endpoint(remote_endpoint); - self.bind(&endpoint, false, observer) + self.bind(&endpoint, false) } pub fn connect( self, remote_endpoint: &IpEndpoint, - pollee: &Pollee, + option: &RawTcpOption, + observer: StreamObserver, ) -> core::result::Result { - let bound_socket = match self { - InitStream::Bound(bound_socket) => bound_socket, - InitStream::Unbound(_) => self - .bind_to_ephemeral_endpoint(remote_endpoint, StreamObserver::new(pollee.clone()))?, + let bound_port = match self { + InitStream::Bound(bound_port) => bound_port, + InitStream::Unbound => self.bind_to_ephemeral_endpoint(remote_endpoint)?, }; - ConnectingStream::new(bound_socket, *remote_endpoint) - .map_err(|(err, bound_socket)| (err, InitStream::Bound(bound_socket))) + ConnectingStream::new(bound_port, *remote_endpoint, option, observer) + .map_err(|(err, bound_port)| (err, InitStream::Bound(bound_port))) } pub fn listen( self, backlog: usize, - pollee: &Pollee, + option: &RawTcpOption, + observer: StreamObserver, ) -> core::result::Result { - let InitStream::Bound(bound_socket) = self else { + let InitStream::Bound(bound_port) = self else { // FIXME: The socket should be bound to INADDR_ANY (i.e., 0.0.0.0) with an ephemeral // port. However, INADDR_ANY is not yet supported, so we need to return an error first. 
debug_assert!(false, "listen() without bind() is not implemented"); @@ -96,14 +88,13 @@ impl InitStream { )); }; - ListenStream::new(bound_socket, backlog, pollee) - .map_err(|(err, bound_socket)| (err, InitStream::Bound(bound_socket))) + Ok(ListenStream::new(bound_port, backlog, option, observer)) } pub fn local_endpoint(&self) -> Option { match self { - InitStream::Unbound(_) => None, - InitStream::Bound(bound_socket) => Some(bound_socket.local_endpoint().unwrap()), + InitStream::Unbound => None, + InitStream::Bound(bound_port) => Some(bound_port.endpoint().unwrap()), } } @@ -111,14 +102,4 @@ impl InitStream { // Linux adds OUT and HUP events for a newly created socket IoEvents::OUT | IoEvents::HUP } - - pub(super) fn set_raw_option( - &mut self, - set_option: impl Fn(&mut dyn RawTcpSetOption) -> R, - ) -> R { - match self { - InitStream::Unbound(unbound_socket) => set_option(unbound_socket.as_mut()), - InitStream::Bound(bound_socket) => set_option(bound_socket), - } - } } diff --git a/kernel/src/net/socket/ip/stream/listen.rs b/kernel/src/net/socket/ip/stream/listen.rs index 514b37fe0..74c9256ff 100644 --- a/kernel/src/net/socket/ip/stream/listen.rs +++ b/kernel/src/net/socket/ip/stream/listen.rs @@ -1,103 +1,59 @@ // SPDX-License-Identifier: MPL-2.0 use aster_bigtcp::{ - errors::tcp::ListenError, - iface::BindPortConfig, - socket::{RawTcpSetOption, TcpState, UnboundTcpSocket}, + socket::{RawTcpOption, RawTcpSetOption}, wire::IpEndpoint, }; -use ostd::sync::PreemptDisabled; use super::{connected::ConnectedStream, StreamObserver}; use crate::{ events::IoEvents, - net::iface::{BoundTcpSocket, Iface}, + net::iface::{BoundPort, Iface, TcpListener}, prelude::*, - process::signal::Pollee, }; pub struct ListenStream { - backlog: usize, - /// A bound socket held to ensure the TCP port cannot be released - bound_socket: BoundTcpSocket, - /// Backlog sockets listening at the local endpoint - backlog_sockets: RwLock, PreemptDisabled>, + tcp_listener: TcpListener, } impl 
ListenStream { pub fn new( - bound_socket: BoundTcpSocket, + bound_port: BoundPort, backlog: usize, - pollee: &Pollee, - ) -> core::result::Result { + option: &RawTcpOption, + observer: StreamObserver, + ) -> Self { const SOMAXCONN: usize = 4096; - let somaxconn = SOMAXCONN.min(backlog); + let max_conn = SOMAXCONN.min(backlog); - let listen_stream = Self { - backlog: somaxconn, - bound_socket, - backlog_sockets: RwLock::new(Vec::new()), + let tcp_listener = match TcpListener::new_listen(bound_port, max_conn, option, observer) { + Ok(tcp_listener) => tcp_listener, + Err((_, err)) => { + unreachable!("`new_listen` fails with {:?}, which should not happen", err) + } }; - if let Err(err) = listen_stream.fill_backlog_sockets(pollee) { - return Err((err, listen_stream.bound_socket)); - } - Ok(listen_stream) + + Self { tcp_listener } } - /// Append sockets listening at LocalEndPoint to support backlog - fn fill_backlog_sockets(&self, pollee: &Pollee) -> Result<()> { - let mut backlog_sockets = self.backlog_sockets.write(); + pub fn try_accept(&self) -> Result { + let (new_conn, remote_endpoint) = self.tcp_listener.accept().ok_or_else(|| { + Error::with_message(Errno::EAGAIN, "no pending connection is available") + })?; - let backlog = self.backlog; - let current_backlog_len = backlog_sockets.len(); - debug_assert!(backlog >= current_backlog_len); - if backlog == current_backlog_len { - return Ok(()); - } - - for _ in current_backlog_len..backlog { - let backlog_socket = BacklogSocket::new(&self.bound_socket, pollee)?; - backlog_sockets.push(backlog_socket); - } - - Ok(()) - } - - pub fn try_accept(&self, pollee: &Pollee) -> Result { - let mut backlog_sockets = self.backlog_sockets.write(); - - let index = backlog_sockets - .iter() - .position(|backlog_socket| backlog_socket.can_accept()) - .ok_or_else(|| { - Error::with_message(Errno::EAGAIN, "no pending connection is available") - })?; - let active_backlog_socket = backlog_sockets.remove(index); - - if let 
Ok(backlog_socket) = BacklogSocket::new(&self.bound_socket, pollee) { - backlog_sockets.push(backlog_socket); - } - - let remote_endpoint = active_backlog_socket.remote_endpoint().unwrap(); - Ok(ConnectedStream::new( - active_backlog_socket.into_bound_socket(), - remote_endpoint, - false, - )) + Ok(ConnectedStream::new(new_conn, remote_endpoint, false)) } pub fn local_endpoint(&self) -> IpEndpoint { - self.bound_socket.local_endpoint().unwrap() + self.tcp_listener.local_endpoint().unwrap() } pub fn iface(&self) -> &Arc { - self.bound_socket.iface() + self.tcp_listener.iface() } pub(super) fn check_io_events(&self) -> IoEvents { - let backlog_sockets = self.backlog_sockets.read(); - - let can_accept = backlog_sockets.iter().any(|socket| socket.can_accept()); + let can_accept = self.tcp_listener.can_accept(); // If network packets come in simultaneously, the socket state may change in the middle. // However, the current pollee implementation should be able to handle this race condition. @@ -108,97 +64,10 @@ impl ListenStream { } } - /// Calls `f` to set socket option on raw socket. - /// - /// This method will call `f` on the bound socket and each backlog socket that is in `Listen` state . pub(super) fn set_raw_option( - &mut self, - set_option: impl Fn(&mut dyn RawTcpSetOption) -> R, + &self, + set_option: impl FnOnce(&dyn RawTcpSetOption) -> R, ) -> R { - self.backlog_sockets.write().iter_mut().for_each(|socket| { - if socket - .bound_socket - .raw_with(|raw_tcp_socket| raw_tcp_socket.state() != TcpState::Listen) - { - return; - } - - // If the socket receives SYN after above check, - // we will also set keep alive on the socket that is not in `Listen` state. - // But such a race doesn't matter, we just let it happen. 
- set_option(&mut socket.bound_socket); - }); - - set_option(&mut self.bound_socket) - } -} - -struct BacklogSocket { - bound_socket: BoundTcpSocket, -} - -impl BacklogSocket { - // FIXME: All of the error codes below seem to have no Linux equivalents, and I see no reason - // why the error may occur. Perhaps it is better to call `unwrap()` directly? - fn new(bound_socket: &BoundTcpSocket, pollee: &Pollee) -> Result { - let local_endpoint = bound_socket.local_endpoint().ok_or(Error::with_message( - Errno::EINVAL, - "the socket is not bound", - ))?; - - let unbound_socket = { - let mut unbound = UnboundTcpSocket::new(); - unbound.set_keep_alive(bound_socket.raw_with(|socket| socket.keep_alive())); - unbound.set_nagle_enabled(bound_socket.raw_with(|socket| socket.nagle_enabled())); - - // TODO: Inherit other options that can be set via `setsockopt` from bound socket - - Box::new(unbound) - }; - let bound_socket = { - let iface = bound_socket.iface(); - let bind_port_config = BindPortConfig::new(local_endpoint.port, true); - iface - .bind_tcp( - unbound_socket, - StreamObserver::new(pollee.clone()), - bind_port_config, - ) - .map_err(|(err, _)| err)? - }; - - match bound_socket.listen(local_endpoint) { - Ok(()) => Ok(Self { bound_socket }), - Err(ListenError::Unaddressable) => { - return_errno_with_message!(Errno::EINVAL, "the listening address is invalid") - } - Err(ListenError::InvalidState) => { - return_errno_with_message!(Errno::EINVAL, "the listening socket is invalid") - } - } - } - - /// Returns whether the backlog socket can be `accept`ed. - /// - /// According to the Linux implementation, assuming the TCP Fast Open mechanism is off, a - /// backlog socket becomes ready to be returned in the `accept` system call when the 3-way - /// handshake is complete (i.e., when it enters the ESTABLISHED state). - /// - /// The Linux kernel implementation can be found at - /// . 
- // - // FIMXE: Some sockets may be dead (e.g., RSTed), and such sockets can never become alive - // again. We need to remove them from the backlog sockets. - fn can_accept(&self) -> bool { - self.bound_socket.raw_with(|socket| socket.may_send()) - } - - fn remote_endpoint(&self) -> Option { - self.bound_socket - .raw_with(|socket| socket.remote_endpoint()) - } - - fn into_bound_socket(self) -> BoundTcpSocket { - self.bound_socket + set_option(&self.tcp_listener) } } diff --git a/kernel/src/net/socket/ip/stream/mod.rs b/kernel/src/net/socket/ip/stream/mod.rs index 0f5ade33d..d05eb6eed 100644 --- a/kernel/src/net/socket/ip/stream/mod.rs +++ b/kernel/src/net/socket/ip/stream/mod.rs @@ -3,14 +3,14 @@ use core::sync::atomic::{AtomicBool, Ordering}; use aster_bigtcp::{ - socket::{NeedIfacePoll, RawTcpSetOption}, + socket::{NeedIfacePoll, RawTcpOption, RawTcpSetOption}, wire::IpEndpoint, }; use connected::ConnectedStream; use connecting::{ConnResult, ConnectingStream}; use init::InitStream; use listen::ListenStream; -use options::{Congestion, MaxSegment, NoDelay, WindowClamp}; +use options::{Congestion, MaxSegment, NoDelay, WindowClamp, KEEPALIVE_INTERVAL}; use ostd::sync::{PreemptDisabled, RwLockReadGuard, RwLockWriteGuard}; use takeable::Takeable; use util::TcpOptionSet; @@ -83,6 +83,13 @@ impl OptionSet { let tcp = TcpOptionSet::new(); OptionSet { socket, tcp } } + + fn raw(&self) -> RawTcpOption { + RawTcpOption { + keep_alive: self.socket.keep_alive().then_some(KEEPALIVE_INTERVAL), + is_nagle_enabled: !self.tcp.no_delay(), + } + } } impl StreamSocket { @@ -114,7 +121,7 @@ impl StreamSocket { }); let pollee = Pollee::new(); - connected_stream.set_observer(StreamObserver::new(pollee.clone())); + connected_stream.init_observer(StreamObserver::new(pollee.clone())); Arc::new(Self { options: RwLock::new(options), @@ -207,7 +214,9 @@ impl StreamSocket { // `Some(_)` if blocking is not necessary or not allowed. 
fn start_connect(&self, remote_endpoint: &IpEndpoint) -> Option> { let is_nonblocking = self.is_nonblocking(); - let mut state = self.write_updated_state(); + let (options, mut state) = self.update_connecting(); + + let raw_option = options.raw(); let (result_or_block, iface_to_poll) = state.borrow_result(|mut owned_state| { let init_stream = match owned_state { @@ -243,7 +252,11 @@ impl StreamSocket { } }; - let connecting_stream = match init_stream.connect(remote_endpoint, &self.pollee) { + let connecting_stream = match init_stream.connect( + remote_endpoint, + &raw_option, + StreamObserver::new(self.pollee.clone()), + ) { Ok(connecting_stream) => connecting_stream, Err((err, init_stream)) => { return (State::Init(init_stream), (Some(Err(err)), None)); @@ -298,13 +311,11 @@ impl StreamSocket { return_errno_with_message!(Errno::EINVAL, "the socket is not listening"); }; - let accepted = listen_stream - .try_accept(&self.pollee) - .map(|connected_stream| { - let remote_endpoint = connected_stream.remote_endpoint(); - let accepted_socket = Self::new_accepted(connected_stream); - (accepted_socket as _, remote_endpoint.into()) - }); + let accepted = listen_stream.try_accept().map(|connected_stream| { + let remote_endpoint = connected_stream.remote_endpoint(); + let accepted_socket = Self::new_accepted(connected_stream); + (accepted_socket as _, remote_endpoint.into()) + }); let iface_to_poll = listen_stream.iface().clone(); drop(state); @@ -475,18 +486,14 @@ impl Socket for StreamSocket { ); }; - let bound_socket = match init_stream.bind( - &endpoint, - can_reuse, - StreamObserver::new(self.pollee.clone()), - ) { - Ok(bound_socket) => bound_socket, + let bound_port = match init_stream.bind(&endpoint, can_reuse) { + Ok(bound_port) => bound_port, Err((err, init_stream)) => { return (State::Init(init_stream), Err(err)); } }; - (State::Init(InitStream::new_bound(bound_socket)), Ok(())) + (State::Init(InitStream::new_bound(bound_port)), Ok(())) }) } @@ -501,7 +508,9 @@ 
impl Socket for StreamSocket { } fn listen(&self, backlog: usize) -> Result<()> { - let mut state = self.write_updated_state(); + let (options, mut state) = self.update_connecting(); + + let raw_option = options.raw(); state.borrow_result(|owned_state| { let init_stream = match owned_state { @@ -520,7 +529,11 @@ impl Socket for StreamSocket { } }; - let listen_stream = match init_stream.listen(backlog, &self.pollee) { + let listen_stream = match init_stream.listen( + backlog, + &raw_option, + StreamObserver::new(self.pollee.clone()), + ) { Ok(listen_stream) => listen_stream, Err((err, init_stream)) => { return (State::Init(init_stream), Err(err)); @@ -701,7 +714,7 @@ impl Socket for StreamSocket { tcp_no_delay: NoDelay => { let no_delay = tcp_no_delay.get().unwrap(); options.tcp.set_no_delay(*no_delay); - state.set_raw_option(|raw_socket: &mut dyn RawTcpSetOption| raw_socket.set_nagle_enabled(!no_delay)); + state.set_raw_option(|raw_socket: &dyn RawTcpSetOption| raw_socket.set_nagle_enabled(!no_delay)); }, tcp_congestion: Congestion => { let congestion = tcp_congestion.get().unwrap(); @@ -736,14 +749,16 @@ impl Socket for StreamSocket { impl State { /// Calls `f` to set raw socket option. /// - /// Note that for listening socket, `f` is called on all backlog sockets in `Listen` State. - /// That is to say, `f` won't be called on backlog sockets in `SynReceived` or `Established` state. - fn set_raw_option(&mut self, set_option: impl Fn(&mut dyn RawTcpSetOption) -> R) -> R { + /// For listening sockets, socket options are inherited by new connections. However, they are + /// not updated for connections in the backlog queue. 
+ fn set_raw_option(&self, set_option: impl FnOnce(&dyn RawTcpSetOption) -> R) -> Option { match self { - State::Init(init_stream) => init_stream.set_raw_option(set_option), - State::Connecting(connecting_stream) => connecting_stream.set_raw_option(set_option), - State::Connected(connected_stream) => connected_stream.set_raw_option(set_option), - State::Listen(listen_stream) => listen_stream.set_raw_option(set_option), + State::Init(_) => None, + State::Connecting(connecting_stream) => { + Some(connecting_stream.set_raw_option(set_option)) + } + State::Connected(connected_stream) => Some(connected_stream.set_raw_option(set_option)), + State::Listen(listen_stream) => Some(listen_stream.set_raw_option(set_option)), } } @@ -758,24 +773,17 @@ impl State { } impl SetSocketLevelOption for State { - fn set_keep_alive(&mut self, keep_alive: bool) -> NeedIfacePoll { - /// The keepalive interval. - /// - /// The linux value can be found at `/proc/sys/net/ipv4/tcp_keepalive_intvl`, - /// which is by default 75 seconds for most Linux distributions. 
- const KEEPALIVE_INTERVAL: aster_bigtcp::time::Duration = - aster_bigtcp::time::Duration::from_secs(75); - + fn set_keep_alive(&self, keep_alive: bool) -> NeedIfacePoll { let interval = if keep_alive { Some(KEEPALIVE_INTERVAL) } else { None }; - let set_keepalive = - |raw_socket: &mut dyn RawTcpSetOption| raw_socket.set_keep_alive(interval); + let set_keepalive = |raw_socket: &dyn RawTcpSetOption| raw_socket.set_keep_alive(interval); self.set_raw_option(set_keepalive) + .unwrap_or(NeedIfacePoll::FALSE) } } diff --git a/kernel/src/net/socket/ip/stream/observer.rs b/kernel/src/net/socket/ip/stream/observer.rs index 87cd45231..862ae229f 100644 --- a/kernel/src/net/socket/ip/stream/observer.rs +++ b/kernel/src/net/socket/ip/stream/observer.rs @@ -4,6 +4,7 @@ use aster_bigtcp::socket::{SocketEventObserver, SocketEvents}; use crate::{events::IoEvents, process::signal::Pollee}; +#[derive(Clone)] pub struct StreamObserver(Pollee); impl StreamObserver { diff --git a/kernel/src/net/socket/ip/stream/options.rs b/kernel/src/net/socket/ip/stream/options.rs index 5b3478313..fbb1c903d 100644 --- a/kernel/src/net/socket/ip/stream/options.rs +++ b/kernel/src/net/socket/ip/stream/options.rs @@ -9,3 +9,10 @@ impl_socket_options!( pub struct MaxSegment(u32); pub struct WindowClamp(u32); ); + +/// The keepalive interval. +/// +/// The linux value can be found at `/proc/sys/net/ipv4/tcp_keepalive_intvl`, +/// which is by default 75 seconds for most Linux distributions. +pub(super) const KEEPALIVE_INTERVAL: aster_bigtcp::time::Duration = + aster_bigtcp::time::Duration::from_secs(75); diff --git a/kernel/src/net/socket/util/options.rs b/kernel/src/net/socket/util/options.rs index 2015c7a65..a2d720d01 100644 --- a/kernel/src/net/socket/util/options.rs +++ b/kernel/src/net/socket/util/options.rs @@ -173,7 +173,7 @@ impl LingerOption { /// A trait used for setting socket level options on actual sockets. 
pub(in crate::net) trait SetSocketLevelOption { /// Sets whether keepalive messages are enabled. - fn set_keep_alive(&mut self, _keep_alive: bool) -> NeedIfacePoll { + fn set_keep_alive(&self, _keep_alive: bool) -> NeedIfacePoll { NeedIfacePoll::FALSE } }