From 1716f4f3243d3bc9b514829537c737d8938fec07 Mon Sep 17 00:00:00 2001 From: Ruihan Li Date: Mon, 2 Dec 2024 23:53:19 +0800 Subject: [PATCH] Use `Pollee` as the socket observer --- kernel/libs/aster-bigtcp/src/ext.rs | 8 +- kernel/libs/aster-bigtcp/src/iface/common.rs | 6 +- kernel/libs/aster-bigtcp/src/iface/iface.rs | 6 +- kernel/libs/aster-bigtcp/src/iface/poll.rs | 13 +-- kernel/libs/aster-bigtcp/src/socket/bound.rs | 78 ++++++++---------- .../libs/aster-bigtcp/src/socket/unbound.rs | 27 ++++-- kernel/src/net/iface/ext.rs | 4 + kernel/src/net/socket/ip/datagram/mod.rs | 55 ++++++------- kernel/src/net/socket/ip/datagram/observer.rs | 29 +++++++ kernel/src/net/socket/ip/datagram/unbound.rs | 16 ++-- kernel/src/net/socket/ip/mod.rs | 4 +- kernel/src/net/socket/ip/stream/connected.rs | 6 +- kernel/src/net/socket/ip/stream/init.rs | 32 ++++---- kernel/src/net/socket/ip/stream/listen.rs | 24 ++++-- kernel/src/net/socket/ip/stream/mod.rs | 82 +++++++------------ kernel/src/net/socket/ip/stream/observer.rs | 37 +++++++++ kernel/src/process/signal/poll.rs | 1 + kernel/src/syscall/socket.rs | 2 +- 18 files changed, 242 insertions(+), 188 deletions(-) create mode 100644 kernel/src/net/socket/ip/datagram/observer.rs create mode 100644 kernel/src/net/socket/ip/stream/observer.rs diff --git a/kernel/libs/aster-bigtcp/src/ext.rs b/kernel/libs/aster-bigtcp/src/ext.rs index 336889d6..ec39f68e 100644 --- a/kernel/libs/aster-bigtcp/src/ext.rs +++ b/kernel/libs/aster-bigtcp/src/ext.rs @@ -1,6 +1,6 @@ // SPDX-License-Identifier: MPL-2.0 -use crate::iface::ScheduleNextPoll; +use crate::{iface::ScheduleNextPoll, socket::SocketEventObserver}; /// Extension to be implemented by users of this crate. /// @@ -13,4 +13,10 @@ use crate::iface::ScheduleNextPoll; pub trait Ext { /// The type for ifaces to schedule the next poll. type ScheduleNextPoll: ScheduleNextPoll; + + /// The type for TCP sockets to observe events. + type TcpEventObserver: SocketEventObserver; + + /// The type for UDP sockets to observe events. + type UdpEventObserver: SocketEventObserver; } diff --git a/kernel/libs/aster-bigtcp/src/iface/common.rs b/kernel/libs/aster-bigtcp/src/iface/common.rs index 4422edca..21b53cc3 100644 --- a/kernel/libs/aster-bigtcp/src/iface/common.rs +++ b/kernel/libs/aster-bigtcp/src/iface/common.rs @@ -86,6 +86,7 @@ impl IfaceCommon { &self, iface: Arc>, socket: Box, + observer: E::TcpEventObserver, config: BindPortConfig, ) -> core::result::Result, (BindError, Box)> { let port = match self.bind_port(config) { @@ -93,7 +94,7 @@ impl IfaceCommon { Err(err) => return Err((err, socket)), }; - let (raw_socket, observer) = socket.into_raw(); + let raw_socket = socket.into_raw(); let bound_socket = BoundTcpSocket::new(iface, port, raw_socket, observer); let inserted = self @@ -109,6 +110,7 @@ impl IfaceCommon { &self, iface: Arc>, socket: Box, + observer: E::UdpEventObserver, config: BindPortConfig, ) -> core::result::Result, (BindError, Box)> { let port = match self.bind_port(config) { @@ -116,7 +118,7 @@ impl IfaceCommon { Err(err) => return Err((err, socket)), }; - let (raw_socket, observer) = socket.into_raw(); + let raw_socket = socket.into_raw(); let bound_socket = BoundUdpSocket::new(iface, port, raw_socket, observer); let inserted = self diff --git a/kernel/libs/aster-bigtcp/src/iface/iface.rs b/kernel/libs/aster-bigtcp/src/iface/iface.rs index 738f8f4f..6aa7865b 100644 --- a/kernel/libs/aster-bigtcp/src/iface/iface.rs +++ b/kernel/libs/aster-bigtcp/src/iface/iface.rs @@ -37,19 +37,21 @@ impl dyn Iface { pub fn bind_tcp( self: &Arc, socket: Box, + observer: E::TcpEventObserver, config: BindPortConfig, ) -> core::result::Result, (BindError, Box)> { let common = self.common(); - common.bind_tcp(self.clone(), socket, config) + common.bind_tcp(self.clone(), socket, observer, config) } pub fn bind_udp( self: &Arc, socket: Box, + observer: E::UdpEventObserver, config: BindPortConfig, ) -> core::result::Result, (BindError, Box)> { let common = self.common(); - common.bind_udp(self.clone(), socket, config) + common.bind_udp(self.clone(), socket, observer, config) } /// Gets the name of the iface. diff --git a/kernel/libs/aster-bigtcp/src/iface/poll.rs b/kernel/libs/aster-bigtcp/src/iface/poll.rs index 33673d31..729522f4 100644 --- a/kernel/libs/aster-bigtcp/src/iface/poll.rs +++ b/kernel/libs/aster-bigtcp/src/iface/poll.rs @@ -16,15 +16,18 @@ use smoltcp::{ }, }; -use crate::socket::{BoundTcpSocketInner, BoundUdpSocketInner, TcpProcessResult}; +use crate::{ + ext::Ext, + socket::{BoundTcpSocketInner, BoundUdpSocketInner, TcpProcessResult}, +}; -pub(super) struct PollContext<'a, E> { +pub(super) struct PollContext<'a, E: Ext> { iface_cx: &'a mut Context, tcp_sockets: &'a BTreeSet>>, udp_sockets: &'a BTreeSet>>, } -impl<'a, E> PollContext<'a, E> { +impl<'a, E: Ext> PollContext<'a, E> { #[allow(clippy::mutable_key_type)] pub(super) fn new( iface_cx: &'a mut Context, @@ -44,7 +47,7 @@ impl<'a, E> PollContext<'a, E> { pub(super) trait FnHelper: FnMut(A, B, C) -> O {} impl FnHelper for F where F: FnMut(A, B, C) -> O {} -impl PollContext<'_, E> { +impl PollContext<'_, E> { pub(super) fn poll_ingress( &mut self, device: &mut D, @@ -280,7 +283,7 @@ impl PollContext<'_, E> { } } -impl PollContext<'_, E> { +impl PollContext<'_, E> { pub(super) fn poll_egress(&mut self, device: &mut D, mut dispatch_phy: Q) where D: Device + ?Sized, diff --git a/kernel/libs/aster-bigtcp/src/socket/bound.rs b/kernel/libs/aster-bigtcp/src/socket/bound.rs index 74ed97a2..375d364f 100644 --- a/kernel/libs/aster-bigtcp/src/socket/bound.rs +++ b/kernel/libs/aster-bigtcp/src/socket/bound.rs @@ -1,9 +1,6 @@ // SPDX-License-Identifier: MPL-2.0 -use alloc::{ - boxed::Box, - sync::{Arc, Weak}, -}; +use alloc::{boxed::Box, sync::Arc}; use core::{ ops::{Deref, DerefMut}, sync::atomic::{AtomicBool, AtomicU64, AtomicU8, Ordering}, @@ -23,18 +20,20 @@ use super::{ }; use crate::{ext::Ext, iface::Iface}; -pub struct BoundSocket(Arc>); +pub struct BoundSocket, E: Ext>(Arc>); /// [`TcpSocket`] or [`UdpSocket`]. -pub trait AnySocket { +pub trait AnySocket { type RawSocket; + type Observer: SocketEventObserver; /// Called by [`BoundSocket::new`]. fn new(socket: Box) -> Self; /// Called by [`BoundSocket::drop`]. - fn on_drop(this: &Arc>) + fn on_drop(this: &Arc>) where + E: Ext, Self: Sized; } @@ -42,11 +41,11 @@ pub type BoundTcpSocket = BoundSocket; pub type BoundUdpSocket = BoundSocket; /// Common states shared by [`BoundTcpSocketInner`] and [`BoundUdpSocketInner`]. -pub struct BoundSocketInner { +pub struct BoundSocketInner, E> { iface: Arc>, port: u16, socket: T, - observer: RwLock, WriteIrqDisabled>, + observer: RwLock, events: AtomicU8, next_poll_at_ms: AtomicU64, } @@ -137,8 +136,9 @@ impl TcpSocket { } } -impl AnySocket for TcpSocket { +impl AnySocket for TcpSocket { type RawSocket = RawTcpSocket; + type Observer = E::TcpEventObserver; fn new(socket: Box) -> Self { let socket_ext = RawTcpSocketExt { @@ -153,7 +153,7 @@ impl AnySocket for TcpSocket { } } - fn on_drop(this: &Arc>) { + fn on_drop(this: &Arc>) { let mut socket = this.socket.lock(); socket.in_background = true; @@ -169,14 +169,18 @@ impl AnySocket for TcpSocket { /// States needed by [`BoundUdpSocketInner`] but not [`BoundTcpSocketInner`]. type UdpSocket = SpinLock, LocalIrqDisabled>; -impl AnySocket for UdpSocket { +impl AnySocket for UdpSocket { type RawSocket = RawUdpSocket; + type Observer = E::UdpEventObserver; fn new(socket: Box) -> Self { Self::new(socket) } - fn on_drop(this: &Arc>) { + fn on_drop(this: &Arc>) + where + E: Ext, + { this.socket.lock().close(); // A UDP socket can be removed immediately. @@ -184,7 +188,7 @@ impl AnySocket for UdpSocket { } } -impl Drop for BoundSocket { +impl, E: Ext> Drop for BoundSocket { fn drop(&mut self) { T::on_drop(&self.0); } @@ -193,12 +197,12 @@ impl Drop for BoundSocket { pub(crate) type BoundTcpSocketInner = BoundSocketInner; pub(crate) type BoundUdpSocketInner = BoundSocketInner; -impl BoundSocket { +impl, E: Ext> BoundSocket { pub(crate) fn new( iface: Arc>, port: u16, socket: Box, - observer: Weak, + observer: T::Observer, ) -> Self { Self(Arc::new(BoundSocketInner { iface, @@ -215,24 +219,13 @@ impl BoundSocket { } } -impl BoundSocket { - /// Sets the observer whose `on_events` will be called when certain iface events happen. After - /// setting, the new observer will fire once immediately to avoid missing any events. +impl, E: Ext> BoundSocket { + /// Sets the observer whose `on_events` will be called when certain iface events happen. /// - /// If there is an existing observer, due to race conditions, this function does not guarantee - /// that the old observer will never be called after the setting. Users should be aware of this - /// and proactively handle the race conditions if necessary. - pub fn set_observer(&self, new_observer: Weak) { + /// The caller needs to be responsible for race conditions if network events can occur + /// simultaneously. + pub fn set_observer(&self, new_observer: T::Observer) { *self.0.observer.write() = new_observer; - - self.0.on_events(); - } - - /// Returns the observer. - /// - /// See also [`Self::set_observer`]. - pub fn observer(&self) -> Weak { - self.0.observer.read().clone() } pub fn local_endpoint(&self) -> Option { @@ -449,7 +442,7 @@ impl BoundUdpSocket { } } -impl BoundSocketInner { +impl, E> BoundSocketInner { pub(crate) fn has_events(&self) -> bool { self.events.load(Ordering::Relaxed) != 0 } @@ -460,13 +453,8 @@ impl BoundSocketInner { let events = self.events.load(Ordering::Relaxed); self.events.store(0, Ordering::Relaxed); - // We never hold the write lock in IRQ handlers, so we don't need to disable IRQs when we - // get the read lock. - let observer = Weak::upgrade(&*self.observer.read()); - - if let Some(inner) = observer { - inner.on_events(SocketEvents::from_bits_truncate(events)); - } + let observer = self.observer.read(); + observer.on_events(SocketEvents::from_bits_truncate(events)); } fn add_events(&self, new_events: SocketEvents) { @@ -513,13 +501,13 @@ impl BoundSocketInner { } } -impl BoundSocketInner { +impl, E> BoundSocketInner { pub(crate) fn port(&self) -> u16 { self.port } } -impl BoundTcpSocketInner { +impl BoundTcpSocketInner { /// Returns whether the TCP socket is dead. /// /// A TCP socket is considered dead if and only if the following two conditions are met: @@ -531,7 +519,7 @@ impl BoundTcpSocketInner { } } -impl BoundSocketInner { +impl, E> BoundSocketInner { /// Returns whether an incoming packet _may_ be processed by the socket. /// /// The check is intended to be lock-free and fast, but may have false positives. @@ -554,7 +542,7 @@ pub(crate) enum TcpProcessResult { ProcessedWithReply(IpRepr, TcpRepr<'static>), } -impl BoundTcpSocketInner { +impl BoundTcpSocketInner { /// Tries to process an incoming packet and returns whether the packet is processed. pub(crate) fn process( &self, @@ -654,7 +642,7 @@ impl BoundTcpSocketInner { } } -impl BoundUdpSocketInner { +impl BoundUdpSocketInner { /// Tries to process an incoming packet and returns whether the packet is processed. pub(crate) fn process( &self, diff --git a/kernel/libs/aster-bigtcp/src/socket/unbound.rs b/kernel/libs/aster-bigtcp/src/socket/unbound.rs index b79f60bb..f68804e0 100644 --- a/kernel/libs/aster-bigtcp/src/socket/unbound.rs +++ b/kernel/libs/aster-bigtcp/src/socket/unbound.rs @@ -1,19 +1,18 @@ // SPDX-License-Identifier: MPL-2.0 -use alloc::{boxed::Box, sync::Weak, vec}; +use alloc::{boxed::Box, vec}; -use super::{event::SocketEventObserver, RawTcpSocket, RawUdpSocket}; +use super::{RawTcpSocket, RawUdpSocket}; pub struct UnboundSocket { socket: Box, - observer: Weak, } pub type UnboundTcpSocket = UnboundSocket; pub type UnboundUdpSocket = UnboundSocket; impl UnboundTcpSocket { - pub fn new(observer: Weak) -> Self { + pub fn new() -> Self { let raw_tcp_socket = { let rx_buffer = smoltcp::socket::tcp::SocketBuffer::new(vec![0u8; TCP_RECV_BUF_LEN]); let tx_buffer = smoltcp::socket::tcp::SocketBuffer::new(vec![0u8; TCP_SEND_BUF_LEN]); @@ -21,13 +20,18 @@ impl UnboundTcpSocket { }; Self { socket: Box::new(raw_tcp_socket), - observer, } } } +impl Default for UnboundTcpSocket { + fn default() -> Self { + Self::new() + } +} + impl UnboundUdpSocket { - pub fn new(observer: Weak) -> Self { + pub fn new() -> Self { let raw_udp_socket = { let metadata = smoltcp::socket::udp::PacketMetadata::EMPTY; let rx_buffer = smoltcp::socket::udp::PacketBuffer::new( @@ -42,14 +46,19 @@ impl UnboundUdpSocket { }; Self { socket: Box::new(raw_udp_socket), - observer, } } } +impl Default for UnboundUdpSocket { + fn default() -> Self { + Self::new() + } +} + impl UnboundSocket { - pub(crate) fn into_raw(self) -> (Box, Weak) { - (self.socket, self.observer) + pub(crate) fn into_raw(self) -> Box { + self.socket } } diff --git a/kernel/src/net/iface/ext.rs b/kernel/src/net/iface/ext.rs index 64846394..b6537833 100644 --- a/kernel/src/net/iface/ext.rs +++ b/kernel/src/net/iface/ext.rs @@ -1,9 +1,13 @@ // SPDX-License-Identifier: MPL-2.0 use super::sched::PollScheduler; +use crate::net::socket::ip::{datagram::DatagramObserver, stream::StreamObserver}; pub struct BigtcpExt; impl aster_bigtcp::ext::Ext for BigtcpExt { type ScheduleNextPoll = PollScheduler; + + type TcpEventObserver = StreamObserver; + type UdpEventObserver = DatagramObserver; } diff --git a/kernel/src/net/socket/ip/datagram/mod.rs b/kernel/src/net/socket/ip/datagram/mod.rs index 6cdb2e48..7d5e9243 100644 --- a/kernel/src/net/socket/ip/datagram/mod.rs +++ b/kernel/src/net/socket/ip/datagram/mod.rs @@ -2,10 +2,7 @@ use core::sync::atomic::{AtomicBool, Ordering}; -use aster_bigtcp::{ - socket::{SocketEventObserver, SocketEvents}, - wire::IpEndpoint, -}; +use aster_bigtcp::wire::IpEndpoint; use ostd::sync::WriteIrqDisabled; use takeable::Takeable; @@ -32,8 +29,11 @@ use crate::{ }; mod bound; +mod observer; mod unbound; +pub(in crate::net) use self::observer::DatagramObserver; + #[derive(Debug, Clone)] struct OptionSet { socket: SocketOptionSet, @@ -64,6 +64,7 @@ impl Inner { self, endpoint: &IpEndpoint, can_reuse: bool, + observer: DatagramObserver, ) -> core::result::Result { let unbound_datagram = match self { Inner::Unbound(unbound_datagram) => unbound_datagram, @@ -75,7 +76,7 @@ impl Inner { } }; - let bound_datagram = match unbound_datagram.bind(endpoint, can_reuse) { + let bound_datagram = match unbound_datagram.bind(endpoint, can_reuse, observer) { Ok(bound_datagram) => bound_datagram, Err((err, unbound_datagram)) => return Err((err, Inner::Unbound(unbound_datagram))), }; @@ -85,26 +86,25 @@ impl Inner { fn bind_to_ephemeral_endpoint( self, remote_endpoint: &IpEndpoint, + observer: DatagramObserver, ) -> core::result::Result { if let Inner::Bound(bound_datagram) = self { return Ok(bound_datagram); } let endpoint = get_ephemeral_endpoint(remote_endpoint); - self.bind(&endpoint, false) + self.bind(&endpoint, false, observer) } } impl DatagramSocket { pub fn new(nonblocking: bool) -> Arc { - Arc::new_cyclic(|me| { - let unbound_datagram = UnboundDatagram::new(me.clone() as _); - Self { - inner: RwLock::new(Takeable::new(Inner::Unbound(unbound_datagram))), - nonblocking: AtomicBool::new(nonblocking), - pollee: Pollee::new(), - options: RwLock::new(OptionSet::new()), - } + let unbound_datagram = UnboundDatagram::new(); + Arc::new(Self { + inner: RwLock::new(Takeable::new(Inner::Unbound(unbound_datagram))), + nonblocking: AtomicBool::new(nonblocking), + pollee: Pollee::new(), + options: RwLock::new(OptionSet::new()), }) } @@ -134,7 +134,10 @@ impl DatagramSocket { // Slow path let mut inner = self.inner.write(); inner.borrow_result(|owned_inner| { - let bound_datagram = match owned_inner.bind_to_ephemeral_endpoint(remote_endpoint) { + let bound_datagram = match owned_inner.bind_to_ephemeral_endpoint( + remote_endpoint, + DatagramObserver::new(self.pollee.clone()), + ) { Ok(bound_datagram) => bound_datagram, Err((err, err_inner)) => { return (err_inner, Err(err)); @@ -277,7 +280,11 @@ impl Socket for DatagramSocket { let can_reuse = self.options.read().socket.reuse_addr(); let mut inner = self.inner.write(); inner.borrow_result(|owned_inner| { - let bound_datagram = match owned_inner.bind(&endpoint, can_reuse) { + let bound_datagram = match owned_inner.bind( + &endpoint, + can_reuse, + DatagramObserver::new(self.pollee.clone()), + ) { Ok(bound_datagram) => bound_datagram, Err((err, err_inner)) => { return (err_inner, Err(err)); @@ -389,19 +396,3 @@ impl Socket for DatagramSocket { self.options.write().socket.set_option(option) } } - -impl SocketEventObserver for DatagramSocket { - fn on_events(&self, events: SocketEvents) { - let mut io_events = IoEvents::empty(); - - if events.contains(SocketEvents::CAN_RECV) { - io_events |= IoEvents::IN; - } - - if events.contains(SocketEvents::CAN_SEND) { - io_events |= IoEvents::OUT; - } - - self.pollee.notify(io_events); - } -} diff --git a/kernel/src/net/socket/ip/datagram/observer.rs b/kernel/src/net/socket/ip/datagram/observer.rs new file mode 100644 index 00000000..d52cc009 --- /dev/null +++ b/kernel/src/net/socket/ip/datagram/observer.rs @@ -0,0 +1,29 @@ +// SPDX-License-Identifier: MPL-2.0 + +use aster_bigtcp::socket::{SocketEventObserver, SocketEvents}; + +use crate::{events::IoEvents, process::signal::Pollee}; + +pub struct DatagramObserver(Pollee); + +impl DatagramObserver { + pub(super) fn new(pollee: Pollee) -> Self { + Self(pollee) + } +} + +impl SocketEventObserver for DatagramObserver { + fn on_events(&self, events: SocketEvents) { + let mut io_events = IoEvents::empty(); + + if events.contains(SocketEvents::CAN_RECV) { + io_events |= IoEvents::IN; + } + + if events.contains(SocketEvents::CAN_SEND) { + io_events |= IoEvents::OUT; + } + + self.0.notify(io_events); + } +} diff --git a/kernel/src/net/socket/ip/datagram/unbound.rs b/kernel/src/net/socket/ip/datagram/unbound.rs index 5eb51f94..29d56c4b 100644 --- a/kernel/src/net/socket/ip/datagram/unbound.rs +++ b/kernel/src/net/socket/ip/datagram/unbound.rs @@ -1,13 +1,8 @@ // SPDX-License-Identifier: MPL-2.0 -use alloc::sync::Weak; +use aster_bigtcp::{socket::UnboundUdpSocket, wire::IpEndpoint}; -use aster_bigtcp::{ - socket::{SocketEventObserver, UnboundUdpSocket}, - wire::IpEndpoint, -}; - -use super::bound::BoundDatagram; +use super::{bound::BoundDatagram, DatagramObserver}; use crate::{events::IoEvents, net::socket::ip::common::bind_socket, prelude::*}; pub struct UnboundDatagram { @@ -15,9 +10,9 @@ pub struct UnboundDatagram { } impl UnboundDatagram { - pub fn new(observer: Weak) -> Self { + pub fn new() -> Self { Self { - unbound_socket: Box::new(UnboundUdpSocket::new(observer)), + unbound_socket: Box::new(UnboundUdpSocket::new()), } } @@ -25,12 +20,13 @@ impl UnboundDatagram { self, endpoint: &IpEndpoint, can_reuse: bool, + observer: DatagramObserver, ) -> core::result::Result { let bound_socket = match bind_socket( self.unbound_socket, endpoint, can_reuse, - |iface, socket, config| iface.bind_udp(socket, config), + |iface, socket, config| iface.bind_udp(socket, observer, config), ) { Ok(bound_socket) => bound_socket, Err((err, unbound_socket)) => return Err((err, Self { unbound_socket })), diff --git a/kernel/src/net/socket/ip/mod.rs b/kernel/src/net/socket/ip/mod.rs index 31326060..7e4c8c98 100644 --- a/kernel/src/net/socket/ip/mod.rs +++ b/kernel/src/net/socket/ip/mod.rs @@ -2,9 +2,7 @@ mod addr; mod common; -mod datagram; +pub mod datagram; pub mod stream; use addr::UNSPECIFIED_LOCAL_ENDPOINT; -pub use datagram::DatagramSocket; -pub use stream::StreamSocket; diff --git a/kernel/src/net/socket/ip/stream/connected.rs b/kernel/src/net/socket/ip/stream/connected.rs index c6cd05fc..12395b1d 100644 --- a/kernel/src/net/socket/ip/stream/connected.rs +++ b/kernel/src/net/socket/ip/stream/connected.rs @@ -1,14 +1,14 @@ // SPDX-License-Identifier: MPL-2.0 -use alloc::sync::Weak; use core::sync::atomic::{AtomicBool, Ordering}; use aster_bigtcp::{ errors::tcp::{RecvError, SendError}, - socket::{NeedIfacePoll, SocketEventObserver, TcpStateCheck}, + socket::{NeedIfacePoll, TcpStateCheck}, wire::IpEndpoint, }; +use super::StreamObserver; use crate::{ events::IoEvents, net::{ @@ -202,7 +202,7 @@ impl ConnectedStream { }) } - pub(super) fn set_observer(&self, observer: Weak) { + pub(super) fn set_observer(&self, observer: StreamObserver) { self.bound_socket.set_observer(observer) } } diff --git a/kernel/src/net/socket/ip/stream/init.rs b/kernel/src/net/socket/ip/stream/init.rs index 5c28824f..0597dc98 100644 --- a/kernel/src/net/socket/ip/stream/init.rs +++ b/kernel/src/net/socket/ip/stream/init.rs @@ -1,13 +1,8 @@ // SPDX-License-Identifier: MPL-2.0 -use alloc::sync::Weak; +use aster_bigtcp::{socket::UnboundTcpSocket, wire::IpEndpoint}; -use aster_bigtcp::{ - socket::{SocketEventObserver, UnboundTcpSocket}, - wire::IpEndpoint, -}; - -use super::{connecting::ConnectingStream, listen::ListenStream}; +use super::{connecting::ConnectingStream, listen::ListenStream, StreamObserver}; use crate::{ events::IoEvents, net::{ @@ -15,6 +10,7 @@ use crate::{ socket::ip::common::{bind_socket, get_ephemeral_endpoint}, }, prelude::*, + process::signal::Pollee, }; pub enum InitStream { @@ -23,8 +19,8 @@ pub enum InitStream { } impl InitStream { - pub fn new(observer: Weak) -> Self { - InitStream::Unbound(Box::new(UnboundTcpSocket::new(observer))) + pub fn new() -> Self { + InitStream::Unbound(Box::new(UnboundTcpSocket::new())) } pub fn new_bound(bound_socket: BoundTcpSocket) -> Self { @@ -35,6 +31,7 @@ impl InitStream { self, endpoint: &IpEndpoint, can_reuse: bool, + observer: StreamObserver, ) -> core::result::Result { let unbound_socket = match self { InitStream::Unbound(unbound_socket) => unbound_socket, @@ -49,7 +46,7 @@ impl InitStream { unbound_socket, endpoint, can_reuse, - |iface, socket, config| iface.bind_tcp(socket, config), + |iface, socket, config| iface.bind_tcp(socket, observer, config), ) { Ok(bound_socket) => bound_socket, Err((err, unbound_socket)) => return Err((err, InitStream::Unbound(unbound_socket))), @@ -60,25 +57,32 @@ impl InitStream { fn bind_to_ephemeral_endpoint( self, remote_endpoint: &IpEndpoint, + observer: StreamObserver, ) -> core::result::Result { let endpoint = get_ephemeral_endpoint(remote_endpoint); - self.bind(&endpoint, false) + self.bind(&endpoint, false, observer) } pub fn connect( self, remote_endpoint: &IpEndpoint, + pollee: &Pollee, ) -> core::result::Result { let bound_socket = match self { InitStream::Bound(bound_socket) => bound_socket, - InitStream::Unbound(_) => self.bind_to_ephemeral_endpoint(remote_endpoint)?, + InitStream::Unbound(_) => self + .bind_to_ephemeral_endpoint(remote_endpoint, StreamObserver::new(pollee.clone()))?, }; ConnectingStream::new(bound_socket, *remote_endpoint) .map_err(|(err, bound_socket)| (err, InitStream::Bound(bound_socket))) } - pub fn listen(self, backlog: usize) -> core::result::Result { + pub fn listen( + self, + backlog: usize, + pollee: &Pollee, + ) -> core::result::Result { let InitStream::Bound(bound_socket) = self else { // FIXME: The socket should be bound to INADDR_ANY (i.e., 0.0.0.0) with an ephemeral // port. However, INADDR_ANY is not yet supported, so we need to return an error first. @@ -89,7 +93,7 @@ impl InitStream { )); }; - ListenStream::new(bound_socket, backlog) + ListenStream::new(bound_socket, backlog, pollee) .map_err(|(err, bound_socket)| (err, InitStream::Bound(bound_socket))) } diff --git a/kernel/src/net/socket/ip/stream/listen.rs b/kernel/src/net/socket/ip/stream/listen.rs index 613f5338..ba503eb3 100644 --- a/kernel/src/net/socket/ip/stream/listen.rs +++ b/kernel/src/net/socket/ip/stream/listen.rs @@ -5,11 +5,12 @@ use aster_bigtcp::{ }; use ostd::sync::WriteIrqDisabled; -use super::connected::ConnectedStream; +use super::{connected::ConnectedStream, StreamObserver}; use crate::{ events::IoEvents, net::iface::{BoundTcpSocket, Iface}, prelude::*, + process::signal::Pollee, }; pub struct ListenStream { @@ -24,6 +25,7 @@ impl ListenStream { pub fn new( bound_socket: BoundTcpSocket, backlog: usize, + pollee: &Pollee, ) -> core::result::Result { const SOMAXCONN: usize = 4096; let somaxconn = SOMAXCONN.min(backlog); @@ -33,14 +35,14 @@ impl ListenStream { bound_socket, backlog_sockets: RwLock::new(Vec::new()), }; - if let Err(err) = listen_stream.fill_backlog_sockets() { + if let Err(err) = listen_stream.fill_backlog_sockets(pollee) { return Err((err, listen_stream.bound_socket)); } Ok(listen_stream) } /// Append sockets listening at LocalEndPoint to support backlog - fn fill_backlog_sockets(&self) -> Result<()> { + fn fill_backlog_sockets(&self, pollee: &Pollee) -> Result<()> { let mut backlog_sockets = self.backlog_sockets.write(); let backlog = self.backlog; @@ -51,14 +53,14 @@ impl ListenStream { } for _ in current_backlog_len..backlog { - let backlog_socket = BacklogSocket::new(&self.bound_socket)?; + let backlog_socket = BacklogSocket::new(&self.bound_socket, pollee)?; backlog_sockets.push(backlog_socket); } Ok(()) } - pub fn try_accept(&self) -> Result { + pub fn try_accept(&self, pollee: &Pollee) -> Result { let mut backlog_sockets = self.backlog_sockets.write(); let index = backlog_sockets @@ -69,7 +71,7 @@ impl ListenStream { })?; let active_backlog_socket = backlog_sockets.remove(index); - if let Ok(backlog_socket) = BacklogSocket::new(&self.bound_socket) { + if let Ok(backlog_socket) = BacklogSocket::new(&self.bound_socket, pollee) { backlog_sockets.push(backlog_socket); } @@ -111,18 +113,22 @@ struct BacklogSocket { impl BacklogSocket { // FIXME: All of the error codes below seem to have no Linux equivalents, and I see no reason // why the error may occur. Perhaps it is better to call `unwrap()` directly? - fn new(bound_socket: &BoundTcpSocket) -> Result { + fn new(bound_socket: &BoundTcpSocket, pollee: &Pollee) -> Result { let local_endpoint = bound_socket.local_endpoint().ok_or(Error::with_message( Errno::EINVAL, "the socket is not bound", ))?; - let unbound_socket = Box::new(UnboundTcpSocket::new(bound_socket.observer())); + let unbound_socket = Box::new(UnboundTcpSocket::new()); let bound_socket = { let iface = bound_socket.iface(); let bind_port_config = BindPortConfig::new(local_endpoint.port, true); iface - .bind_tcp(unbound_socket, bind_port_config) + .bind_tcp( + unbound_socket, + StreamObserver::new(pollee.clone()), + bind_port_config, + ) .map_err(|(err, _)| err)? }; diff --git a/kernel/src/net/socket/ip/stream/mod.rs b/kernel/src/net/socket/ip/stream/mod.rs index 4a6d8a41..1723bf71 100644 --- a/kernel/src/net/socket/ip/stream/mod.rs +++ b/kernel/src/net/socket/ip/stream/mod.rs @@ -2,10 +2,7 @@ use core::sync::atomic::{AtomicBool, Ordering}; -use aster_bigtcp::{ - socket::{SocketEventObserver, SocketEvents}, - wire::IpEndpoint, -}; +use aster_bigtcp::wire::IpEndpoint; use connected::ConnectedStream; use connecting::{ConnResult, ConnectingStream}; use init::InitStream; @@ -40,9 +37,11 @@ mod connected; mod connecting; mod init; mod listen; +mod observer; pub mod options; mod util; +pub(in crate::net) use self::observer::StreamObserver; pub use self::util::CongestionControl; pub struct StreamSocket { @@ -79,26 +78,23 @@ impl OptionSet { impl StreamSocket { pub fn new(nonblocking: bool) -> Arc { - Arc::new_cyclic(|me| { - let init_stream = InitStream::new(me.clone() as _); - Self { - options: RwLock::new(OptionSet::new()), - state: RwLock::new(Takeable::new(State::Init(init_stream))), - is_nonblocking: AtomicBool::new(nonblocking), - pollee: Pollee::new(), - } + let init_stream = InitStream::new(); + Arc::new(Self { + options: RwLock::new(OptionSet::new()), + state: RwLock::new(Takeable::new(State::Init(init_stream))), + is_nonblocking: AtomicBool::new(nonblocking), + pollee: Pollee::new(), }) } fn new_connected(connected_stream: ConnectedStream) -> Arc { - Arc::new_cyclic(move |me| { - connected_stream.set_observer(me.clone() as _); - Self { - options: RwLock::new(OptionSet::new()), - state: RwLock::new(Takeable::new(State::Connected(connected_stream))), - is_nonblocking: AtomicBool::new(false), - pollee: Pollee::new(), - } + let pollee = Pollee::new(); + connected_stream.set_observer(StreamObserver::new(pollee.clone())); + Arc::new(Self { + options: RwLock::new(OptionSet::new()), + state: RwLock::new(Takeable::new(State::Connected(connected_stream))), + is_nonblocking: AtomicBool::new(false), + pollee, }) } @@ -221,7 +217,7 @@ impl StreamSocket { } }; - let connecting_stream = match init_stream.connect(remote_endpoint) { + let connecting_stream = match init_stream.connect(remote_endpoint, &self.pollee) { Ok(connecting_stream) => connecting_stream, Err((err, init_stream)) => { return (State::Init(init_stream), (Some(Err(err)), None)); @@ -276,11 +272,13 @@ impl StreamSocket { return_errno_with_message!(Errno::EINVAL, "the socket is not listening"); }; - let accepted = listen_stream.try_accept().map(|connected_stream| { - let remote_endpoint = connected_stream.remote_endpoint(); - let accepted_socket = Self::new_connected(connected_stream); - (accepted_socket as _, remote_endpoint.into()) - }); + let accepted = listen_stream + .try_accept(&self.pollee) + .map(|connected_stream| { + let remote_endpoint = connected_stream.remote_endpoint(); + let accepted_socket = Self::new_connected(connected_stream); + (accepted_socket as _, remote_endpoint.into()) + }); let iface_to_poll = listen_stream.iface().clone(); drop(state); @@ -451,7 +449,11 @@ impl Socket for StreamSocket { ); }; - let bound_socket = match init_stream.bind(&endpoint, can_reuse) { + let bound_socket = match init_stream.bind( + &endpoint, + can_reuse, + StreamObserver::new(self.pollee.clone()), + ) { Ok(bound_socket) => bound_socket, Err((err, init_stream)) => { return (State::Init(init_stream), Err(err)); @@ -492,7 +494,7 @@ impl Socket for StreamSocket { } }; - let listen_stream = match init_stream.listen(backlog) { + let listen_stream = match init_stream.listen(backlog, &self.pollee) { Ok(listen_stream) => listen_stream, Err((err, init_stream)) => { return (State::Init(init_stream), Err(err)); @@ -692,30 +694,6 @@ impl Socket for StreamSocket { } } -impl SocketEventObserver for StreamSocket { - fn on_events(&self, events: SocketEvents) { - let mut io_events = IoEvents::empty(); - - if events.contains(SocketEvents::CAN_RECV) { - io_events |= IoEvents::IN; - } - - if events.contains(SocketEvents::CAN_SEND) { - io_events |= IoEvents::OUT; - } - - if events.contains(SocketEvents::PEER_CLOSED) { - io_events |= IoEvents::IN | IoEvents::RDHUP; - } - - if events.contains(SocketEvents::CLOSED) { - io_events |= IoEvents::IN | IoEvents::OUT | IoEvents::RDHUP | IoEvents::HUP; - } - - self.pollee.notify(io_events); - } -} - impl Drop for StreamSocket { fn drop(&mut self) { let state = self.state.write().take(); diff --git a/kernel/src/net/socket/ip/stream/observer.rs b/kernel/src/net/socket/ip/stream/observer.rs new file mode 100644 index 00000000..87cd4523 --- /dev/null +++ b/kernel/src/net/socket/ip/stream/observer.rs @@ -0,0 +1,37 @@ +// SPDX-License-Identifier: MPL-2.0 + +use aster_bigtcp::socket::{SocketEventObserver, SocketEvents}; + +use crate::{events::IoEvents, process::signal::Pollee}; + +pub struct StreamObserver(Pollee); + +impl StreamObserver { + pub(super) fn new(pollee: Pollee) -> Self { + Self(pollee) + } +} + +impl SocketEventObserver for StreamObserver { + fn on_events(&self, events: SocketEvents) { + let mut io_events = IoEvents::empty(); + + if events.contains(SocketEvents::CAN_RECV) { + io_events |= IoEvents::IN; + } + + if events.contains(SocketEvents::CAN_SEND) { + io_events |= IoEvents::OUT; + } + + if events.contains(SocketEvents::PEER_CLOSED) { + io_events |= IoEvents::IN | IoEvents::RDHUP; + } + + if events.contains(SocketEvents::CLOSED) { + io_events |= IoEvents::IN | IoEvents::OUT | IoEvents::RDHUP | IoEvents::HUP; + } + + self.0.notify(io_events); + } +} diff --git a/kernel/src/process/signal/poll.rs b/kernel/src/process/signal/poll.rs index bf3c6366..956b4f69 100644 --- a/kernel/src/process/signal/poll.rs +++ b/kernel/src/process/signal/poll.rs @@ -28,6 +28,7 @@ use crate::{ /// /// Then, [`Pollee::poll_with`] can allow you to register a [`Poller`] to wait for certain events, /// or register a [`PollAdaptor`] to be notified when certain events occur. +#[derive(Clone)] pub struct Pollee { inner: Arc, } diff --git a/kernel/src/syscall/socket.rs b/kernel/src/syscall/socket.rs index fa87294f..fdb9a813 100644 --- a/kernel/src/syscall/socket.rs +++ b/kernel/src/syscall/socket.rs @@ -4,7 +4,7 @@ use super::SyscallReturn; use crate::{ fs::{file_handle::FileLike, file_table::FdFlags}, net::socket::{ - ip::{DatagramSocket, StreamSocket}, + ip::{datagram::DatagramSocket, stream::StreamSocket}, unix::UnixStreamSocket, vsock::VsockStreamSocket, },