diff --git a/services/libs/jinux-std/src/events/events.rs b/services/libs/jinux-std/src/events/events.rs index 9488b0033..1c4ff51ab 100644 --- a/services/libs/jinux-std/src/events/events.rs +++ b/services/libs/jinux-std/src/events/events.rs @@ -1,6 +1,14 @@ /// A trait to represent any events. +/// +/// # The unit event +/// +/// The unit type `()` can serve as a unit event. +/// It can be used if there is only one kind of event +/// and the event carries no additional information. pub trait Events: Copy + Clone + Send + Sync + 'static {} +impl Events for () {} + /// A trait to filter events. /// /// # The no-op event filter diff --git a/services/libs/jinux-std/src/events/observer.rs b/services/libs/jinux-std/src/events/observer.rs index 315abdd13..6f2052b49 100644 --- a/services/libs/jinux-std/src/events/observer.rs +++ b/services/libs/jinux-std/src/events/observer.rs @@ -5,7 +5,31 @@ use super::Events; /// In a sense, event observers are just a fancy form of callback functions. /// An observer's `on_events` methods are supposed to be called when /// some events that are interesting to the observer happen. +/// +/// # The no-op observer +/// +/// The unit type `()` can serve as a no-op observer. +/// It implements `Observer` for any events type `E`, +/// with an `on_events` method that simply does nothing. +/// +/// It can be used to create an empty `Weak`, as shown in the example below. +/// Using the unit type is necessary, as creating an empty `Weak` needs to +/// have a sized type (e.g. the unit type). +/// +/// # Examples +/// +/// ``` +/// use alloc::sync::Weak; +/// use crate::events::Observer; +/// +/// let empty: Weak> = Weak::<()>::new(); +/// assert!(empty.upgrade().is_empty()); +/// ``` pub trait Observer: Send + Sync { /// Notify the observer that some interesting events happen. fn on_events(&self, events: &E); } + +impl Observer for () { + fn on_events(&self, events: &E) {} +} diff --git a/services/libs/jinux-std/src/net/iface/any_socket.rs b/services/libs/jinux-std/src/net/iface/any_socket.rs index f32c2e09f..3fa4fcc57 100644 --- a/services/libs/jinux-std/src/net/iface/any_socket.rs +++ b/services/libs/jinux-std/src/net/iface/any_socket.rs @@ -1,6 +1,5 @@ -use crate::events::IoEvents; +use crate::events::Observer; use crate::prelude::*; -use crate::process::signal::{Pollee, Poller}; use super::Iface; use super::{IpAddress, IpEndpoint}; @@ -11,7 +10,6 @@ pub type RawSocketHandle = smoltcp::iface::SocketHandle; pub struct AnyUnboundSocket { socket_family: AnyRawSocket, - pollee: Pollee, } #[allow(clippy::large_enum_variant)] @@ -32,10 +30,8 @@ impl AnyUnboundSocket { let tx_buffer = smoltcp::socket::tcp::SocketBuffer::new(vec![0u8; SEND_BUF_LEN]); RawTcpSocket::new(rx_buffer, tx_buffer) }; - let pollee = Pollee::new(IoEvents::empty()); AnyUnboundSocket { socket_family: AnyRawSocket::Tcp(raw_tcp_socket), - pollee, } } @@ -54,7 +50,6 @@ impl AnyUnboundSocket { }; AnyUnboundSocket { socket_family: AnyRawSocket::Udp(raw_udp_socket), - pollee: Pollee::new(IoEvents::empty()), } } @@ -68,22 +63,14 @@ impl AnyUnboundSocket { AnyRawSocket::Udp(_) => SocketFamily::Udp, } } - - pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents { - self.pollee.poll(mask, poller) - } - - pub(super) fn pollee(&self) -> Pollee { - self.pollee.clone() - } } pub struct AnyBoundSocket { iface: Arc, handle: smoltcp::iface::SocketHandle, port: u16, - pollee: Pollee, socket_family: SocketFamily, + observer: RwLock>>, weak_self: Weak, } @@ -92,19 +79,36 @@ impl AnyBoundSocket { iface: Arc, handle: smoltcp::iface::SocketHandle, port: u16, - pollee: Pollee, socket_family: SocketFamily, ) -> Arc { Arc::new_cyclic(|weak_self| Self { iface, handle, port, - pollee, socket_family, + observer: RwLock::new(Weak::<()>::new()), weak_self: weak_self.clone(), }) } + pub(super) fn on_iface_events(&self) { + if let Some(observer) = Weak::upgrade(&*self.observer.read()) { + observer.on_events(&()) + } + } + + /// Set the observer whose `on_events` will be called when certain iface events happen. After + /// setting, the new observer will fire once immediately to avoid missing any events. + /// + /// If there is an existing observer, due to race conditions, this function does not guarentee + /// that the old observer will never be called after the setting. Users should be aware of this + /// and proactively handle the race conditions if necessary. + pub fn set_observer(&self, handler: Weak>) { + *self.observer.write() = handler; + + self.on_iface_events(); + } + pub fn local_endpoint(&self) -> Option { let ip_addr = { let ipv4_addr = self.iface.ipv4_addr()?; @@ -135,30 +139,10 @@ impl AnyBoundSocket { Ok(()) } - pub fn update_socket_state(&self) { - let handle = &self.handle; - let pollee = &self.pollee; - let sockets = self.iface().sockets(); - match self.socket_family { - SocketFamily::Tcp => { - let socket = sockets.get::(*handle); - update_tcp_socket_state(socket, pollee); - } - SocketFamily::Udp => { - let udp_socket = sockets.get::(*handle); - update_udp_socket_state(udp_socket, pollee); - } - } - } - pub fn iface(&self) -> &Arc { &self.iface } - pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents { - self.pollee.poll(mask, poller) - } - pub(super) fn weak_ref(&self) -> Weak { self.weak_self.clone() } @@ -181,34 +165,6 @@ impl Drop for AnyBoundSocket { } } -fn update_tcp_socket_state(socket: &RawTcpSocket, pollee: &Pollee) { - if socket.can_recv() { - pollee.add_events(IoEvents::IN); - } else { - pollee.del_events(IoEvents::IN); - } - - if socket.can_send() { - pollee.add_events(IoEvents::OUT); - } else { - pollee.del_events(IoEvents::OUT); - } -} - -fn update_udp_socket_state(socket: &RawUdpSocket, pollee: &Pollee) { - if socket.can_recv() { - pollee.add_events(IoEvents::IN); - } else { - pollee.del_events(IoEvents::IN); - } - - if socket.can_send() { - pollee.add_events(IoEvents::OUT); - } else { - pollee.del_events(IoEvents::OUT); - } -} - // For TCP const RECV_BUF_LEN: usize = 65536; const SEND_BUF_LEN: usize = 65536; diff --git a/services/libs/jinux-std/src/net/iface/common.rs b/services/libs/jinux-std/src/net/iface/common.rs index 7123f429b..01890a446 100644 --- a/services/libs/jinux-std/src/net/iface/common.rs +++ b/services/libs/jinux-std/src/net/iface/common.rs @@ -113,13 +113,12 @@ impl IfaceCommon { return Err((e, socket)); } let socket_family = socket.socket_family(); - let pollee = socket.pollee(); let mut sockets = self.sockets.lock_irq_disabled(); let handle = match socket.raw_socket_family() { AnyRawSocket::Tcp(tcp_socket) => sockets.add(tcp_socket), AnyRawSocket::Udp(udp_socket) => sockets.add(udp_socket), }; - let bound_socket = AnyBoundSocket::new(iface, handle, port, pollee, socket_family); + let bound_socket = AnyBoundSocket::new(iface, handle, port, socket_family); self.insert_bound_socket(&bound_socket).unwrap(); Ok(bound_socket) } @@ -140,7 +139,7 @@ impl IfaceCommon { if has_events { self.bound_sockets.read().iter().for_each(|bound_socket| { if let Some(bound_socket) = bound_socket.upgrade() { - bound_socket.update_socket_state(); + bound_socket.on_iface_events(); } }); } diff --git a/services/libs/jinux-std/src/net/socket/ip/datagram.rs b/services/libs/jinux-std/src/net/socket/ip/datagram.rs index 3e423798d..120a008cf 100644 --- a/services/libs/jinux-std/src/net/socket/ip/datagram.rs +++ b/services/libs/jinux-std/src/net/socket/ip/datagram.rs @@ -1,10 +1,10 @@ use core::sync::atomic::{AtomicBool, Ordering}; -use crate::events::IoEvents; +use crate::events::{IoEvents, Observer}; use crate::fs::utils::StatusFlags; use crate::net::iface::IpEndpoint; -use crate::process::signal::Poller; +use crate::process::signal::{Pollee, Poller}; use crate::{ fs::file_handle::FileLike, net::{ @@ -27,21 +27,63 @@ pub struct DatagramSocket { } enum Inner { - Unbound(AlwaysSome>), + Unbound(AlwaysSome), Bound(Arc), } +struct UnboundDatagram { + unbound_socket: Box, + pollee: Pollee, +} + +impl UnboundDatagram { + fn new() -> Self { + Self { + unbound_socket: Box::new(AnyUnboundSocket::new_udp()), + pollee: Pollee::new(IoEvents::empty()), + } + } + + fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents { + self.pollee.poll(mask, poller) + } + + fn bind(self, endpoint: IpEndpoint) -> core::result::Result, (Error, Self)> { + let bound_socket = match bind_socket(self.unbound_socket, endpoint, false) { + Ok(bound_socket) => bound_socket, + Err((err, unbound_socket)) => { + return Err(( + err, + Self { + unbound_socket, + pollee: self.pollee, + }, + )) + } + }; + let bound_endpoint = bound_socket.local_endpoint().unwrap(); + bound_socket.raw_with(|socket: &mut RawUdpSocket| { + socket.bind(bound_endpoint).unwrap(); + }); + Ok(BoundDatagram::new(bound_socket, self.pollee)) + } +} + struct BoundDatagram { bound_socket: Arc, remote_endpoint: RwLock>, + pollee: Pollee, } impl BoundDatagram { - fn new(bound_socket: Arc) -> Arc { - Arc::new(Self { + fn new(bound_socket: Arc, pollee: Pollee) -> Arc { + let bound = Arc::new(Self { bound_socket, remote_endpoint: RwLock::new(None), - }) + pollee, + }); + bound.bound_socket.set_observer(Arc::downgrade(&bound) as _); + bound } fn remote_endpoint(&self) -> Result { @@ -94,11 +136,31 @@ impl BoundDatagram { } fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents { - self.bound_socket.poll(mask, poller) + self.pollee.poll(mask, poller) } - fn update_socket_state(&self) { - self.bound_socket.update_socket_state(); + fn update_io_events(&self) { + self.bound_socket.raw_with(|socket: &mut RawUdpSocket| { + let pollee = &self.pollee; + + if socket.can_recv() { + pollee.add_events(IoEvents::IN); + } else { + pollee.del_events(IoEvents::IN); + } + + if socket.can_send() { + pollee.add_events(IoEvents::OUT); + } else { + pollee.del_events(IoEvents::OUT); + } + }); + } +} + +impl Observer<()> for BoundDatagram { + fn on_events(&self, _: &()) { + self.update_io_events(); } } @@ -108,25 +170,15 @@ impl Inner { } fn bind(&mut self, endpoint: IpEndpoint) -> Result> { - if self.is_bound() { - return_errno_with_message!(Errno::EINVAL, "the socket is already bound to an address"); - } - let unbound_socket = match self { - Inner::Unbound(unbound_socket) => unbound_socket, - _ => unreachable!(), + let unbound = match self { + Inner::Unbound(unbound) => unbound, + Inner::Bound(..) => return_errno_with_message!( + Errno::EINVAL, + "the socket is already bound to an address" + ), }; - let bound_socket = - unbound_socket.try_take_with(|socket| bind_socket(socket, endpoint, false))?; - let bound_endpoint = bound_socket.local_endpoint().unwrap(); - bound_socket.raw_with(|socket: &mut RawUdpSocket| { - socket - .bind(bound_endpoint) - .map_err(|_| Error::with_message(Errno::EINVAL, "cannot bind socket")) - })?; - let bound = BoundDatagram::new(bound_socket); + let bound = unbound.try_take_with(|unbound| unbound.bind(endpoint))?; *self = Inner::Bound(bound.clone()); - // Once the socket is bound, we should update the socket state at once. - bound.update_socket_state(); Ok(bound) } @@ -140,7 +192,7 @@ impl Inner { fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents { match self { - Inner::Unbound(unbound_socket) => unbound_socket.poll(mask, poller), + Inner::Unbound(unbound) => unbound.poll(mask, poller), Inner::Bound(bound) => bound.poll(mask, poller), } } @@ -148,9 +200,9 @@ impl Inner { impl DatagramSocket { pub fn new(nonblocking: bool) -> Self { - let udp_socket = Box::new(AnyUnboundSocket::new_udp()); + let unbound = UnboundDatagram::new(); Self { - inner: RwLock::new(Inner::Unbound(AlwaysSome::new(udp_socket))), + inner: RwLock::new(Inner::Unbound(AlwaysSome::new(unbound))), nonblocking: AtomicBool::new(nonblocking), } } diff --git a/services/libs/jinux-std/src/net/socket/ip/stream/connected.rs b/services/libs/jinux-std/src/net/socket/ip/stream/connected.rs index 55b8aceeb..a5d5d299e 100644 --- a/services/libs/jinux-std/src/net/socket/ip/stream/connected.rs +++ b/services/libs/jinux-std/src/net/socket/ip/stream/connected.rs @@ -1,8 +1,8 @@ use core::sync::atomic::{AtomicBool, Ordering}; -use crate::events::IoEvents; +use crate::events::{IoEvents, Observer}; use crate::net::iface::IpEndpoint; -use crate::process::signal::Poller; +use crate::process::signal::{Pollee, Poller}; use crate::{ net::{ iface::{AnyBoundSocket, RawTcpSocket}, @@ -16,6 +16,7 @@ pub struct ConnectedStream { nonblocking: AtomicBool, bound_socket: Arc, remote_endpoint: IpEndpoint, + pollee: Pollee, } impl ConnectedStream { @@ -23,12 +24,18 @@ impl ConnectedStream { is_nonblocking: bool, bound_socket: Arc, remote_endpoint: IpEndpoint, - ) -> Self { - Self { + pollee: Pollee, + ) -> Arc { + let connected = Arc::new(Self { nonblocking: AtomicBool::new(is_nonblocking), bound_socket, remote_endpoint, - } + pollee, + }); + connected + .bound_socket + .set_observer(Arc::downgrade(&connected) as _); + connected } pub fn shutdown(&self, cmd: SockShutdownCmd) -> Result<()> { @@ -50,7 +57,7 @@ impl ConnectedStream { let remote_endpoint = self.remote_endpoint()?; return Ok((recv_len, remote_endpoint)); } - let events = self.bound_socket.poll(IoEvents::IN, Some(&poller)); + let events = self.poll(IoEvents::IN, Some(&poller)); if events.contains(IoEvents::HUP) || events.contains(IoEvents::ERR) { return_errno_with_message!(Errno::ENOTCONN, "recv packet fails"); } @@ -71,7 +78,7 @@ impl ConnectedStream { .recv_slice(buf) .map_err(|_| Error::with_message(Errno::ENOTCONN, "fail to recv packet")) }); - self.bound_socket.update_socket_state(); + self.update_io_events(); res } @@ -84,7 +91,7 @@ impl ConnectedStream { if sent_len > 0 { return Ok(sent_len); } - let events = self.bound_socket.poll(IoEvents::OUT, Some(&poller)); + let events = self.poll(IoEvents::OUT, Some(&poller)); if events.contains(IoEvents::HUP) || events.contains(IoEvents::ERR) { return_errno_with_message!(Errno::ENOBUFS, "fail to send packets"); } @@ -104,10 +111,10 @@ impl ConnectedStream { .raw_with(|socket: &mut RawTcpSocket| socket.send_slice(buf)) .map_err(|_| Error::with_message(Errno::ENOBUFS, "cannot send packet")); match res { - // We have to explicitly invoke `update_socket_state` when the send buffer becomes + // We have to explicitly invoke `update_io_events` when the send buffer becomes // full. Note that smoltcp does not think it is an interface event, so calling // `poll_ifaces` alone is not enough. - Ok(0) => self.bound_socket.update_socket_state(), + Ok(0) => self.update_io_events(), Ok(_) => poll_ifaces(), _ => (), }; @@ -125,7 +132,25 @@ impl ConnectedStream { } pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents { - self.bound_socket.poll(mask, poller) + self.pollee.poll(mask, poller) + } + + fn update_io_events(&self) { + self.bound_socket.raw_with(|socket: &mut RawTcpSocket| { + let pollee = &self.pollee; + + if socket.can_recv() { + pollee.add_events(IoEvents::IN); + } else { + pollee.del_events(IoEvents::IN); + } + + if socket.can_send() { + pollee.add_events(IoEvents::OUT); + } else { + pollee.del_events(IoEvents::OUT); + } + }); } pub fn is_nonblocking(&self) -> bool { @@ -136,3 +161,9 @@ impl ConnectedStream { self.nonblocking.store(nonblocking, Ordering::Relaxed); } } + +impl Observer<()> for ConnectedStream { + fn on_events(&self, _: &()) { + self.update_io_events(); + } +} diff --git a/services/libs/jinux-std/src/net/socket/ip/stream/connecting.rs b/services/libs/jinux-std/src/net/socket/ip/stream/connecting.rs index fc98515a4..d97e3ff47 100644 --- a/services/libs/jinux-std/src/net/socket/ip/stream/connecting.rs +++ b/services/libs/jinux-std/src/net/socket/ip/stream/connecting.rs @@ -2,12 +2,13 @@ use core::sync::atomic::{AtomicBool, Ordering}; use alloc::sync::Arc; -use crate::events::IoEvents; +use crate::events::{IoEvents, Observer}; +use crate::net::iface::RawTcpSocket; use crate::net::poll_ifaces; use crate::prelude::*; use crate::net::iface::{AnyBoundSocket, IpEndpoint}; -use crate::process::signal::Poller; +use crate::process::signal::{Pollee, Poller}; use super::connected::ConnectedStream; use super::init::InitStream; @@ -16,6 +17,13 @@ pub struct ConnectingStream { nonblocking: AtomicBool, bound_socket: Arc, remote_endpoint: IpEndpoint, + conn_result: RwLock>, + pollee: Pollee, +} + +enum ConnResult { + Connected, + Refused, } impl ConnectingStream { @@ -23,36 +31,57 @@ impl ConnectingStream { nonblocking: bool, bound_socket: Arc, remote_endpoint: IpEndpoint, - ) -> Result { + pollee: Pollee, + ) -> Result> { bound_socket.do_connect(remote_endpoint)?; - Ok(Self { + let connecting = Arc::new(Self { nonblocking: AtomicBool::new(nonblocking), bound_socket, remote_endpoint, - }) + conn_result: RwLock::new(None), + pollee, + }); + connecting.pollee.reset_events(); + connecting + .bound_socket + .set_observer(Arc::downgrade(&connecting) as _); + Ok(connecting) } - pub fn wait_conn(&self) -> core::result::Result { + pub fn wait_conn( + &self, + ) -> core::result::Result, (Error, Arc)> { debug_assert!(!self.is_nonblocking()); let poller = Poller::new(); loop { poll_ifaces(); - let events = self.poll(IoEvents::OUT | IoEvents::IN, Some(&poller)); - if events.contains(IoEvents::IN) || events.contains(IoEvents::OUT) { - return Ok(ConnectedStream::new( - self.is_nonblocking(), - self.bound_socket.clone(), - self.remote_endpoint, - )); - } else if !events.is_empty() { - return Err(( - Error::with_message(Errno::ECONNREFUSED, "connection refused"), - InitStream::new_bound(self.is_nonblocking(), self.bound_socket.clone()), - )); - } else { + match *self.conn_result.read() { + Some(ConnResult::Connected) => { + return Ok(ConnectedStream::new( + self.is_nonblocking(), + self.bound_socket.clone(), + self.remote_endpoint, + self.pollee.clone(), + )); + } + Some(ConnResult::Refused) => { + return Err(( + Error::with_message(Errno::ECONNREFUSED, "connection refused"), + InitStream::new_bound( + self.is_nonblocking(), + self.bound_socket.clone(), + self.pollee.clone(), + ), + )); + } + None => (), + }; + + let events = self.poll(IoEvents::OUT, Some(&poller)); + if !events.contains(IoEvents::OUT) { // FIXME: deal with nonblocking mode & connecting timeout poller.wait().expect("async connect() not implemented"); } @@ -70,7 +99,7 @@ impl ConnectingStream { } pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents { - self.bound_socket.poll(mask, poller) + self.pollee.poll(mask, poller) } pub fn is_nonblocking(&self) -> bool { @@ -80,4 +109,47 @@ impl ConnectingStream { pub fn set_nonblocking(&self, nonblocking: bool) { self.nonblocking.store(nonblocking, Ordering::Relaxed); } + + fn update_io_events(&self) { + if self.conn_result.read().is_some() { + return; + } + + let became_writable = self.bound_socket.raw_with(|socket: &mut RawTcpSocket| { + let mut result = self.conn_result.write(); + if result.is_some() { + return false; + } + + // Connected + if socket.can_send() { + *result = Some(ConnResult::Connected); + return true; + } + // Connecting + if socket.is_open() { + return false; + } + // Refused + *result = Some(ConnResult::Refused); + true + }); + + // Either when the connection is established, or when the connection fails, the socket + // shall indicate that it is writable. + // + // TODO: Find a way to turn `ConnectingStream` into `ConnectedStream` or `InitStream` + // here, so non-blocking `connect()` can work correctly. Meanwhile, the latter should + // be responsible to initialize all the I/O events including `IoEvents::OUT`, so the + // following hard-coded event addition can be removed. + if became_writable { + self.pollee.add_events(IoEvents::OUT); + } + } +} + +impl Observer<()> for ConnectingStream { + fn on_events(&self, _: &()) { + self.update_io_events(); + } } diff --git a/services/libs/jinux-std/src/net/socket/ip/stream/init.rs b/services/libs/jinux-std/src/net/socket/ip/stream/init.rs index d82c16982..c2d33b4d1 100644 --- a/services/libs/jinux-std/src/net/socket/ip/stream/init.rs +++ b/services/libs/jinux-std/src/net/socket/ip/stream/init.rs @@ -7,6 +7,7 @@ use crate::net::iface::{AnyBoundSocket, AnyUnboundSocket}; use crate::net::socket::ip::always_some::AlwaysSome; use crate::net::socket::ip::common::{bind_socket, get_ephemeral_endpoint}; use crate::prelude::*; +use crate::process::signal::Pollee; use crate::process::signal::Poller; use super::connecting::ConnectingStream; @@ -15,6 +16,7 @@ use super::listen::ListenStream; pub struct InitStream { inner: RwLock, is_nonblocking: AtomicBool, + pollee: Pollee, } enum Inner { @@ -23,6 +25,11 @@ enum Inner { } impl Inner { + fn new() -> Inner { + let unbound_socket = Box::new(AnyUnboundSocket::new_tcp()); + Inner::Unbound(AlwaysSome::new(unbound_socket)) + } + fn is_bound(&self) -> bool { match self { Self::Unbound(_) => false, @@ -38,7 +45,6 @@ impl Inner { }; let bound_socket = unbound_socket.try_take_with(|raw_socket| bind_socket(raw_socket, endpoint, false))?; - bound_socket.update_socket_state(); *self = Inner::Bound(AlwaysSome::new(bound_socket)); Ok(()) } @@ -55,13 +61,6 @@ impl Inner { } } - fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents { - match self { - Inner::Bound(bound_socket) => bound_socket.poll(mask, poller), - Inner::Unbound(unbound_socket) => unbound_socket.poll(mask, poller), - } - } - fn iface(&self) -> Option> { match self { Inner::Bound(bound_socket) => Some(bound_socket.iface().clone()), @@ -76,28 +75,36 @@ impl Inner { } impl InitStream { - pub fn new(nonblocking: bool) -> Self { - let socket = Box::new(AnyUnboundSocket::new_tcp()); - let inner = Inner::Unbound(AlwaysSome::new(socket)); - Self { + // FIXME: In Linux we have the `POLLOUT` event for a newly created socket, while calling + // `write()` on it triggers `SIGPIPE`/`EPIPE`. No documentation found yet, but confirmed by + // experimentation and Linux source code. + pub fn new(nonblocking: bool) -> Arc { + Arc::new(Self { + inner: RwLock::new(Inner::new()), is_nonblocking: AtomicBool::new(nonblocking), - inner: RwLock::new(inner), - } + pollee: Pollee::new(IoEvents::empty()), + }) } - pub fn new_bound(nonblocking: bool, bound_socket: Arc) -> Self { + pub fn new_bound( + nonblocking: bool, + bound_socket: Arc, + pollee: Pollee, + ) -> Arc { + bound_socket.set_observer(Weak::<()>::new()); let inner = Inner::Bound(AlwaysSome::new(bound_socket)); - Self { + Arc::new(Self { is_nonblocking: AtomicBool::new(nonblocking), inner: RwLock::new(inner), - } + pollee, + }) } pub fn bind(&self, endpoint: IpEndpoint) -> Result<()> { self.inner.write().bind(endpoint) } - pub fn connect(&self, remote_endpoint: &IpEndpoint) -> Result { + pub fn connect(&self, remote_endpoint: &IpEndpoint) -> Result> { if !self.inner.read().is_bound() { self.inner .write() @@ -107,16 +114,22 @@ impl InitStream { self.is_nonblocking(), self.inner.read().bound_socket().unwrap().clone(), *remote_endpoint, + self.pollee.clone(), ) } - pub fn listen(&self, backlog: usize) -> Result { + pub fn listen(&self, backlog: usize) -> Result> { let bound_socket = if let Some(bound_socket) = self.inner.read().bound_socket() { bound_socket.clone() } else { return_errno_with_message!(Errno::EINVAL, "cannot listen without bound") }; - ListenStream::new(self.is_nonblocking(), bound_socket, backlog) + ListenStream::new( + self.is_nonblocking(), + bound_socket, + backlog, + self.pollee.clone(), + ) } pub fn local_endpoint(&self) -> Result { @@ -127,7 +140,7 @@ impl InitStream { } pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents { - self.inner.read().poll(mask, poller) + self.pollee.poll(mask, poller) } pub fn is_nonblocking(&self) -> bool { diff --git a/services/libs/jinux-std/src/net/socket/ip/stream/listen.rs b/services/libs/jinux-std/src/net/socket/ip/stream/listen.rs index 06cce98e9..20b3f114f 100644 --- a/services/libs/jinux-std/src/net/socket/ip/stream/listen.rs +++ b/services/libs/jinux-std/src/net/socket/ip/stream/listen.rs @@ -1,10 +1,10 @@ use core::sync::atomic::{AtomicBool, Ordering}; -use crate::events::IoEvents; +use crate::events::{IoEvents, Observer}; use crate::net::iface::{AnyUnboundSocket, BindPortConfig, IpEndpoint}; use crate::net::iface::{AnyBoundSocket, RawTcpSocket}; -use crate::process::signal::Poller; +use crate::process::signal::{Pollee, Poller}; use crate::{net::poll_ifaces, prelude::*}; use super::connected::ConnectedStream; @@ -16,6 +16,7 @@ pub struct ListenStream { bound_socket: Arc, /// Backlog sockets listening at the local endpoint backlog_sockets: RwLock>, + pollee: Pollee, } impl ListenStream { @@ -23,18 +24,24 @@ impl ListenStream { nonblocking: bool, bound_socket: Arc, backlog: usize, - ) -> Result { - let listen_stream = Self { + pollee: Pollee, + ) -> Result> { + let listen_stream = Arc::new(Self { is_nonblocking: AtomicBool::new(nonblocking), backlog, bound_socket, backlog_sockets: RwLock::new(Vec::new()), - }; + pollee, + }); listen_stream.fill_backlog_sockets()?; + listen_stream.pollee.reset_events(); + listen_stream + .bound_socket + .set_observer(Arc::downgrade(&listen_stream) as _); Ok(listen_stream) } - pub fn accept(&self) -> Result<(ConnectedStream, IpEndpoint)> { + pub fn accept(&self) -> Result<(Arc, IpEndpoint)> { // wait to accept let poller = Poller::new(); loop { @@ -42,8 +49,8 @@ impl ListenStream { let accepted_socket = if let Some(accepted_socket) = self.try_accept() { accepted_socket } else { - let events = self.poll(IoEvents::IN | IoEvents::OUT, Some(&poller)); - if !events.contains(IoEvents::IN) && !events.contains(IoEvents::OUT) { + let events = self.poll(IoEvents::IN, Some(&poller)); + if !events.contains(IoEvents::IN) { if self.is_nonblocking() { return_errno_with_message!(Errno::EAGAIN, "try accept again"); } @@ -57,7 +64,12 @@ impl ListenStream { let BacklogSocket { bound_socket: backlog_socket, } = accepted_socket; - ConnectedStream::new(false, backlog_socket, remote_endpoint) + ConnectedStream::new( + false, + backlog_socket, + remote_endpoint, + Pollee::new(IoEvents::empty()), + ) }; return Ok((connected_stream, remote_endpoint)); } @@ -88,6 +100,7 @@ impl ListenStream { backlog_sockets.remove(index) }; self.fill_backlog_sockets().unwrap(); + self.update_io_events(); Some(backlog_socket) } @@ -98,22 +111,25 @@ impl ListenStream { } pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents { - let backlog_sockets = self.backlog_sockets.read(); - for backlog_socket in backlog_sockets.iter() { - if backlog_socket.is_active() { - return IoEvents::IN; - } else { - // regiser poller to the backlog socket - backlog_socket.poll(mask, poller); - } - } - IoEvents::empty() + self.pollee.poll(mask, poller) } fn bound_socket(&self) -> Arc { self.backlog_sockets.read()[0].bound_socket.clone() } + fn update_io_events(&self) { + // The lock should be held to avoid data races + let backlog_sockets = self.backlog_sockets.read(); + + let can_accept = backlog_sockets.iter().any(|socket| socket.is_active()); + if can_accept { + self.pollee.add_events(IoEvents::IN); + } else { + self.pollee.del_events(IoEvents::IN); + } + } + pub fn is_nonblocking(&self) -> bool { self.is_nonblocking.load(Ordering::Relaxed) } @@ -123,6 +139,12 @@ impl ListenStream { } } +impl Observer<()> for ListenStream { + fn on_events(&self, _: &()) { + self.update_io_events(); + } +} + struct BacklogSocket { bound_socket: Arc, } @@ -146,7 +168,6 @@ impl BacklogSocket { .listen(local_endpoint) .map_err(|_| Error::with_message(Errno::EINVAL, "fail to listen")) })?; - bound_socket.update_socket_state(); Ok(Self { bound_socket }) } @@ -159,8 +180,4 @@ impl BacklogSocket { self.bound_socket .raw_with(|socket: &mut RawTcpSocket| socket.remote_endpoint()) } - - fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents { - self.bound_socket.poll(mask, poller) - } } diff --git a/services/libs/jinux-std/src/net/socket/ip/stream/mod.rs b/services/libs/jinux-std/src/net/socket/ip/stream/mod.rs index c364449ec..d96ce2156 100644 --- a/services/libs/jinux-std/src/net/socket/ip/stream/mod.rs +++ b/services/libs/jinux-std/src/net/socket/ip/stream/mod.rs @@ -38,7 +38,7 @@ enum State { impl StreamSocket { pub fn new(nonblocking: bool) -> Self { - let state = State::Init(Arc::new(InitStream::new(nonblocking))); + let state = State::Init(InitStream::new(nonblocking)); Self { state: RwLock::new(state), } @@ -71,7 +71,7 @@ impl StreamSocket { } }; - let connecting = Arc::new(init_stream.connect(remote_endpoint)?); + let connecting = init_stream.connect(remote_endpoint)?; *state = State::Connecting(connecting.clone()); Ok(connecting) } @@ -139,12 +139,10 @@ impl Socket for StreamSocket { let connecting_stream = self.do_connect(&remote_endpoint)?; match connecting_stream.wait_conn() { Ok(connected_stream) => { - let connected_stream = Arc::new(connected_stream); *self.state.write() = State::Connected(connected_stream); Ok(()) } Err((err, init_stream)) => { - let init_stream = Arc::new(init_stream); *self.state.write() = State::Init(init_stream); Err(err) } @@ -164,7 +162,7 @@ impl Socket for StreamSocket { State::Connected(_) => return_errno_with_message!(Errno::EINVAL, "cannot listen"), }; - let listener = Arc::new(init_stream.listen(backlog)?); + let listener = init_stream.listen(backlog)?; *state = State::Listen(listener); Ok(()) } @@ -181,7 +179,7 @@ impl Socket for StreamSocket { }; let accepted_socket = { - let state = RwLock::new(State::Connected(Arc::new(connected_stream))); + let state = RwLock::new(State::Connected(connected_stream)); Arc::new(StreamSocket { state }) };