From a10d04c5f985988dc8140a0ff1b35e10fb396bff Mon Sep 17 00:00:00 2001 From: Ruihan Li Date: Sun, 7 Jan 2024 23:55:23 +0800 Subject: [PATCH] Remove `Arc`s in TCP and UDP states --- kernel/aster-nix/src/net/iface/any_socket.rs | 21 +- kernel/aster-nix/src/net/iface/common.rs | 28 +- kernel/aster-nix/src/net/iface/mod.rs | 1 - kernel/aster-nix/src/net/socket/ip/common.rs | 2 +- .../src/net/socket/ip/datagram/bound.rs | 76 ++-- .../src/net/socket/ip/datagram/mod.rs | 260 +++++++++---- .../src/net/socket/ip/datagram/unbound.rs | 34 +- kernel/aster-nix/src/net/socket/ip/mod.rs | 1 - .../src/net/socket/ip/stream/connected.rs | 140 ++----- .../src/net/socket/ip/stream/connecting.rs | 123 ++---- .../src/net/socket/ip/stream/init.rs | 197 ++++------ .../src/net/socket/ip/stream/listen.rs | 141 +++---- .../aster-nix/src/net/socket/ip/stream/mod.rs | 359 ++++++++++++------ kernel/aster-nix/src/syscall/socket.rs | 4 +- 14 files changed, 672 insertions(+), 715 deletions(-) diff --git a/kernel/aster-nix/src/net/iface/any_socket.rs b/kernel/aster-nix/src/net/iface/any_socket.rs index 3173f1d37..6c7634bdc 100644 --- a/kernel/aster-nix/src/net/iface/any_socket.rs +++ b/kernel/aster-nix/src/net/iface/any_socket.rs @@ -9,6 +9,7 @@ pub type RawSocketHandle = smoltcp::iface::SocketHandle; pub struct AnyUnboundSocket { socket_family: AnyRawSocket, + observer: Weak>, } #[allow(clippy::large_enum_variant)] @@ -23,7 +24,7 @@ pub(super) enum SocketFamily { } impl AnyUnboundSocket { - pub fn new_tcp() -> Self { + pub fn new_tcp(observer: Weak>) -> Self { let raw_tcp_socket = { let rx_buffer = smoltcp::socket::tcp::SocketBuffer::new(vec![0u8; RECV_BUF_LEN]); let tx_buffer = smoltcp::socket::tcp::SocketBuffer::new(vec![0u8; SEND_BUF_LEN]); @@ -31,10 +32,11 @@ impl AnyUnboundSocket { }; AnyUnboundSocket { socket_family: AnyRawSocket::Tcp(raw_tcp_socket), + observer, } } - pub fn new_udp() -> Self { + pub fn new_udp(observer: Weak>) -> Self { let raw_udp_socket = { let metadata = smoltcp::socket::udp::PacketMetadata::EMPTY; let rx_buffer = smoltcp::socket::udp::PacketBuffer::new( @@ -49,18 +51,12 @@ impl AnyUnboundSocket { }; AnyUnboundSocket { socket_family: AnyRawSocket::Udp(raw_udp_socket), + observer, } } - pub(super) fn raw_socket_family(self) -> AnyRawSocket { - self.socket_family - } - - pub(super) fn socket_family(&self) -> SocketFamily { - match &self.socket_family { - AnyRawSocket::Tcp(_) => SocketFamily::Tcp, - AnyRawSocket::Udp(_) => SocketFamily::Udp, - } + pub(super) fn into_raw(self) -> (AnyRawSocket, Weak>) { + (self.socket_family, self.observer) } } @@ -79,13 +75,14 @@ impl AnyBoundSocket { handle: smoltcp::iface::SocketHandle, port: u16, socket_family: SocketFamily, + observer: Weak>, ) -> Arc { Arc::new_cyclic(|weak_self| Self { iface, handle, port, socket_family, - observer: RwLock::new(Weak::<()>::new()), + observer: RwLock::new(observer), weak_self: weak_self.clone(), }) } diff --git a/kernel/aster-nix/src/net/iface/common.rs b/kernel/aster-nix/src/net/iface/common.rs index 7f77dd3a9..f7923024e 100644 --- a/kernel/aster-nix/src/net/iface/common.rs +++ b/kernel/aster-nix/src/net/iface/common.rs @@ -11,7 +11,7 @@ use smoltcp::{ }; use super::{ - any_socket::{AnyBoundSocket, AnyRawSocket, AnyUnboundSocket}, + any_socket::{AnyBoundSocket, AnyRawSocket, AnyUnboundSocket, SocketFamily}, time::get_network_timestamp, util::BindPortConfig, Iface, Ipv4Address, @@ -107,20 +107,28 @@ impl IfaceCommon { } else { match self.alloc_ephemeral_port() { Ok(port) => port, - Err(e) => return Err((e, socket)), + Err(err) => return Err((err, socket)), } }; - if let Some(e) = self.bind_port(port, config.can_reuse()).err() { - return Err((e, socket)); + if let Some(err) = self.bind_port(port, config.can_reuse()).err() { + return Err((err, socket)); } - let socket_family = socket.socket_family(); - let mut sockets = self.sockets.lock_irq_disabled(); - let handle = match socket.raw_socket_family() { - AnyRawSocket::Tcp(tcp_socket) => sockets.add(tcp_socket), - AnyRawSocket::Udp(udp_socket) => sockets.add(udp_socket), + + let (handle, socket_family, observer) = match socket.into_raw() { + (AnyRawSocket::Tcp(tcp_socket), observer) => ( + self.sockets.lock_irq_disabled().add(tcp_socket), + SocketFamily::Tcp, + observer, + ), + (AnyRawSocket::Udp(udp_socket), observer) => ( + self.sockets.lock_irq_disabled().add(udp_socket), + SocketFamily::Udp, + observer, + ), }; - let bound_socket = AnyBoundSocket::new(iface, handle, port, socket_family); + let bound_socket = AnyBoundSocket::new(iface, handle, port, socket_family, observer); self.insert_bound_socket(&bound_socket).unwrap(); + Ok(bound_socket) } diff --git a/kernel/aster-nix/src/net/iface/mod.rs b/kernel/aster-nix/src/net/iface/mod.rs index 09cedb84f..b40c21eb4 100644 --- a/kernel/aster-nix/src/net/iface/mod.rs +++ b/kernel/aster-nix/src/net/iface/mod.rs @@ -46,7 +46,6 @@ pub trait Iface: internal::IfaceInternal + Send + Sync { config: BindPortConfig, ) -> core::result::Result, (Error, Box)> { let common = self.common(); - let socket_type_inner = socket.socket_family(); common.bind_socket(self.arc_self(), socket, config) } diff --git a/kernel/aster-nix/src/net/socket/ip/common.rs b/kernel/aster-nix/src/net/socket/ip/common.rs index 306847966..5edfa2ebe 100644 --- a/kernel/aster-nix/src/net/socket/ip/common.rs +++ b/kernel/aster-nix/src/net/socket/ip/common.rs @@ -44,7 +44,7 @@ fn get_ephemeral_iface(remote_ip_addr: &IpAddress) -> Arc { pub(super) fn bind_socket( unbound_socket: Box, - endpoint: IpEndpoint, + endpoint: &IpEndpoint, can_reuse: bool, ) -> core::result::Result, (Error, Box)> { let iface = match get_iface_to_bind(&endpoint.addr) { diff --git a/kernel/aster-nix/src/net/socket/ip/datagram/bound.rs b/kernel/aster-nix/src/net/socket/ip/datagram/bound.rs index 2962d5ae5..a161405c6 100644 --- a/kernel/aster-nix/src/net/socket/ip/datagram/bound.rs +++ b/kernel/aster-nix/src/net/socket/ip/datagram/bound.rs @@ -1,61 +1,49 @@ // SPDX-License-Identifier: MPL-2.0 use crate::{ - events::{IoEvents, Observer}, + events::IoEvents, net::{ iface::{AnyBoundSocket, IpEndpoint, RawUdpSocket}, - poll_ifaces, socket::util::send_recv_flags::SendRecvFlags, }, prelude::*, - process::signal::{Pollee, Poller}, + process::signal::Pollee, }; pub struct BoundDatagram { bound_socket: Arc, - remote_endpoint: RwLock>, - pollee: Pollee, + remote_endpoint: Option, } impl BoundDatagram { - pub fn new(bound_socket: Arc, pollee: Pollee) -> Arc { - let bound = Arc::new(Self { + pub fn new(bound_socket: Arc) -> Self { + Self { bound_socket, - remote_endpoint: RwLock::new(None), - pollee, - }); - bound.bound_socket.set_observer(Arc::downgrade(&bound) as _); - bound + remote_endpoint: None, + } + } + + pub fn local_endpoint(&self) -> IpEndpoint { + self.bound_socket.local_endpoint().unwrap() } pub fn remote_endpoint(&self) -> Result { self.remote_endpoint - .read() .ok_or_else(|| Error::with_message(Errno::EINVAL, "remote endpoint is not specified")) } - pub fn set_remote_endpoint(&self, endpoint: IpEndpoint) { - *self.remote_endpoint.write() = Some(endpoint); - } - - pub fn local_endpoint(&self) -> Result { - self.bound_socket.local_endpoint().ok_or_else(|| { - Error::with_message(Errno::EINVAL, "socket does not bind to local endpoint") - }) + pub fn set_remote_endpoint(&mut self, endpoint: &IpEndpoint) { + self.remote_endpoint = Some(*endpoint) } pub fn try_recvfrom( &self, buf: &mut [u8], - flags: &SendRecvFlags, + flags: SendRecvFlags, ) -> Result<(usize, IpEndpoint)> { - poll_ifaces(); - let recv_slice = |socket: &mut RawUdpSocket| { - socket - .recv_slice(buf) - .map_err(|_| Error::with_message(Errno::EAGAIN, "recv buf is empty")) - }; - self.bound_socket.raw_with(recv_slice) + self.bound_socket + .raw_with(|socket: &mut RawUdpSocket| socket.recv_slice(buf)) + .map_err(|_| Error::with_message(Errno::EAGAIN, "recv buf is empty")) } pub fn try_sendto( @@ -65,27 +53,21 @@ impl BoundDatagram { flags: SendRecvFlags, ) -> Result { let remote_endpoint = remote - .or_else(|| self.remote_endpoint().ok()) + .or(self.remote_endpoint) .ok_or_else(|| Error::with_message(Errno::EINVAL, "udp should provide remote addr"))?; - let send_slice = |socket: &mut RawUdpSocket| { - socket - .send_slice(buf, remote_endpoint) - .map(|_| buf.len()) - .map_err(|_| Error::with_message(Errno::EAGAIN, "send udp packet fails")) - }; - let len = self.bound_socket.raw_with(send_slice)?; - poll_ifaces(); - Ok(len) + self.bound_socket + .raw_with(|socket: &mut RawUdpSocket| socket.send_slice(buf, remote_endpoint)) + .map(|_| buf.len()) + .map_err(|_| Error::with_message(Errno::EAGAIN, "send udp packet fails")) } - pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents { - self.pollee.poll(mask, poller) + pub(super) fn init_pollee(&self, pollee: &Pollee) { + pollee.reset_events(); + self.update_io_events(pollee) } - fn update_io_events(&self) { + pub(super) fn update_io_events(&self, pollee: &Pollee) { self.bound_socket.raw_with(|socket: &mut RawUdpSocket| { - let pollee = &self.pollee; - if socket.can_recv() { pollee.add_events(IoEvents::IN); } else { @@ -100,9 +82,3 @@ impl BoundDatagram { }); } } - -impl Observer<()> for BoundDatagram { - fn on_events(&self, _: &()) { - self.update_io_events(); - } -} diff --git a/kernel/aster-nix/src/net/socket/ip/datagram/mod.rs b/kernel/aster-nix/src/net/socket/ip/datagram/mod.rs index 19ae2ca6f..f73547309 100644 --- a/kernel/aster-nix/src/net/socket/ip/datagram/mod.rs +++ b/kernel/aster-nix/src/net/socket/ip/datagram/mod.rs @@ -1,81 +1,91 @@ // SPDX-License-Identifier: MPL-2.0 -use core::sync::atomic::{AtomicBool, Ordering}; +use core::{ + mem, + sync::atomic::{AtomicBool, Ordering}, +}; use self::{bound::BoundDatagram, unbound::UnboundDatagram}; -use super::{always_some::AlwaysSome, common::get_ephemeral_endpoint}; +use super::common::get_ephemeral_endpoint; use crate::{ - events::IoEvents, + events::{IoEvents, Observer}, fs::{file_handle::FileLike, utils::StatusFlags}, net::{ iface::IpEndpoint, + poll_ifaces, socket::{ util::{send_recv_flags::SendRecvFlags, socket_addr::SocketAddr}, Socket, }, }, prelude::*, - process::signal::Poller, + process::signal::{Pollee, Poller}, }; mod bound; mod unbound; pub struct DatagramSocket { - nonblocking: AtomicBool, inner: RwLock, + nonblocking: AtomicBool, + pollee: Pollee, } enum Inner { - Unbound(AlwaysSome), - Bound(Arc), + Unbound(UnboundDatagram), + Bound(BoundDatagram), + Poisoned, } impl Inner { - fn is_bound(&self) -> bool { - matches!(self, Inner::Bound { .. }) - } - - fn bind(&mut self, endpoint: IpEndpoint) -> Result> { - let unbound = match self { - Inner::Unbound(unbound) => unbound, - Inner::Bound(..) => return_errno_with_message!( - Errno::EINVAL, - "the socket is already bound to an address" - ), + fn bind(self, endpoint: &IpEndpoint) -> core::result::Result { + let unbound_datagram = match self { + Inner::Unbound(unbound_datagram) => unbound_datagram, + Inner::Bound(bound_datagram) => { + return Err(( + Error::with_message(Errno::EINVAL, "the socket is already bound to an address"), + Inner::Bound(bound_datagram), + )); + } + Inner::Poisoned => { + return Err(( + Error::with_message(Errno::EINVAL, "the socket is poisoned"), + Inner::Poisoned, + )); + } }; - let bound = unbound.try_take_with(|unbound| unbound.bind(endpoint))?; - *self = Inner::Bound(bound.clone()); - Ok(bound) + + let bound_datagram = match unbound_datagram.bind(endpoint) { + Ok(bound_datagram) => bound_datagram, + Err((err, unbound_datagram)) => return Err((err, Inner::Unbound(unbound_datagram))), + }; + Ok(bound_datagram) } fn bind_to_ephemeral_endpoint( - &mut self, + self, remote_endpoint: &IpEndpoint, - ) -> Result> { - let endpoint = get_ephemeral_endpoint(remote_endpoint); - self.bind(endpoint) - } - - fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents { - match self { - Inner::Unbound(unbound) => unbound.poll(mask, poller), - Inner::Bound(bound) => bound.poll(mask, poller), + ) -> core::result::Result { + if let Inner::Bound(bound_datagram) = self { + return Ok(bound_datagram); } + + let endpoint = get_ephemeral_endpoint(remote_endpoint); + self.bind(&endpoint) } } impl DatagramSocket { - pub fn new(nonblocking: bool) -> Self { - let unbound = UnboundDatagram::new(); - Self { - inner: RwLock::new(Inner::Unbound(AlwaysSome::new(unbound))), - nonblocking: AtomicBool::new(nonblocking), - } - } - - pub fn is_bound(&self) -> bool { - self.inner.read().is_bound() + pub fn new(nonblocking: bool) -> Arc { + Arc::new_cyclic(|me| { + let unbound_datagram = UnboundDatagram::new(me.clone() as _); + let pollee = Pollee::new(IoEvents::empty()); + Self { + inner: RwLock::new(Inner::Unbound(unbound_datagram)), + nonblocking: AtomicBool::new(nonblocking), + pollee, + } + }) } pub fn is_nonblocking(&self) -> bool { @@ -86,26 +96,81 @@ impl DatagramSocket { self.nonblocking.store(nonblocking, Ordering::SeqCst); } - fn bound(&self) -> Result> { - if let Inner::Bound(bound) = &*self.inner.read() { - Ok(bound.clone()) - } else { - return_errno_with_message!(Errno::EINVAL, "socket does not bind to local endpoint") - } - } - - fn try_bind_empheral(&self, remote_endpoint: &IpEndpoint) -> Result> { + fn try_bind_empheral(&self, remote_endpoint: &IpEndpoint) -> Result<()> { // Fast path - if let Inner::Bound(bound) = &*self.inner.read() { - return Ok(bound.clone()); + if let Inner::Bound(_) = &*self.inner.read() { + return Ok(()); } // Slow path let mut inner = self.inner.write(); - if let Inner::Bound(bound) = &*inner { - return Ok(bound.clone()); + let owned_inner = mem::replace(&mut *inner, Inner::Poisoned); + + let bound_datagram = match owned_inner.bind_to_ephemeral_endpoint(remote_endpoint) { + Ok(bound_datagram) => bound_datagram, + Err((err, err_inner)) => { + *inner = err_inner; + return Err(err); + } + }; + bound_datagram.init_pollee(&self.pollee); + *inner = Inner::Bound(bound_datagram); + Ok(()) + } + + fn try_recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, SocketAddr)> { + let inner = self.inner.read(); + let Inner::Bound(bound_datagram) = &*inner else { + return_errno_with_message!(Errno::EINVAL, "the socket is not bound"); + }; + let (recv_bytes, remote_endpoint) = bound_datagram.try_recvfrom(buf, flags)?; + bound_datagram.update_io_events(&self.pollee); + Ok((recv_bytes, remote_endpoint.try_into()?)) + } + + fn try_sendto( + &self, + buf: &[u8], + remote: Option, + flags: SendRecvFlags, + ) -> Result { + let inner = self.inner.read(); + let Inner::Bound(bound_datagram) = &*inner else { + return_errno_with_message!(Errno::EINVAL, "the socket is not bound"); + }; + let sent_bytes = bound_datagram.try_sendto(buf, remote, flags)?; + bound_datagram.update_io_events(&self.pollee); + Ok(sent_bytes) + } + + // TODO: Support timeout + fn wait_events(&self, mask: IoEvents, mut cond: F) -> Result + where + F: FnMut() -> Result, + { + let poller = Poller::new(); + + loop { + match cond() { + Err(err) if err.error() == Errno::EAGAIN => (), + result => return result, + }; + + let events = self.poll(mask, Some(&poller)); + if !events.is_empty() { + continue; + } + + poller.wait()?; } - inner.bind_to_ephemeral_endpoint(remote_endpoint) + } + + fn update_io_events(&self) { + let inner = self.inner.read(); + let Inner::Bound(bound_datagram) = &*inner else { + return; + }; + bound_datagram.update_io_events(&self.pollee); } } @@ -124,7 +189,7 @@ impl FileLike for DatagramSocket { } fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents { - self.inner.read().poll(mask, poller) + self.pollee.poll(mask, poller) } fn as_socket(self: Arc) -> Option> { @@ -152,43 +217,61 @@ impl FileLike for DatagramSocket { impl Socket for DatagramSocket { fn bind(&self, socket_addr: SocketAddr) -> Result<()> { let endpoint = socket_addr.try_into()?; - self.inner.write().bind(endpoint)?; + + let mut inner = self.inner.write(); + let owned_inner = mem::replace(&mut *inner, Inner::Poisoned); + + let bound_datagram = match owned_inner.bind(&endpoint) { + Ok(bound_datagram) => bound_datagram, + Err((err, err_inner)) => { + *inner = err_inner; + return Err(err); + } + }; + bound_datagram.init_pollee(&self.pollee); + *inner = Inner::Bound(bound_datagram); Ok(()) } fn connect(&self, socket_addr: SocketAddr) -> Result<()> { let endpoint = socket_addr.try_into()?; - let bound = self.try_bind_empheral(&endpoint)?; - bound.set_remote_endpoint(endpoint); + + self.try_bind_empheral(&endpoint)?; + + let mut inner = self.inner.write(); + let Inner::Bound(bound_datagram) = &mut *inner else { + return_errno_with_message!(Errno::EINVAL, "the socket is not bound") + }; + bound_datagram.set_remote_endpoint(&endpoint); + Ok(()) } fn addr(&self) -> Result { - self.bound()?.local_endpoint()?.try_into() + let inner = self.inner.read(); + let Inner::Bound(bound_datagram) = &*inner else { + return_errno_with_message!(Errno::EINVAL, "the socket is not bound"); + }; + bound_datagram.local_endpoint().try_into() } fn peer_addr(&self) -> Result { - self.bound()?.remote_endpoint()?.try_into() + let inner = self.inner.read(); + let Inner::Bound(bound_datagram) = &*inner else { + return_errno_with_message!(Errno::EINVAL, "the socket is not bound"); + }; + bound_datagram.remote_endpoint()?.try_into() } // FIXME: respect RecvFromFlags fn recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, SocketAddr)> { debug_assert!(flags.is_all_supported()); - let bound = self.bound()?; - let poller = Poller::new(); - loop { - if let Ok((recv_len, remote_endpoint)) = bound.try_recvfrom(buf, &flags) { - let remote_addr = remote_endpoint.try_into()?; - return Ok((recv_len, remote_addr)); - } - let events = bound.poll(IoEvents::IN, Some(&poller)); - if !events.contains(IoEvents::IN) { - if self.is_nonblocking() { - return_errno_with_message!(Errno::EAGAIN, "try to receive again"); - } - // FIXME: deal with recvfrom timeout - poller.wait()?; - } + + poll_ifaces(); + if self.is_nonblocking() { + self.try_recvfrom(buf, flags) + } else { + self.wait_events(IoEvents::IN, || self.try_recvfrom(buf, flags)) } } @@ -199,13 +282,24 @@ impl Socket for DatagramSocket { flags: SendRecvFlags, ) -> Result { debug_assert!(flags.is_all_supported()); - let (bound, remote_endpoint) = if let Some(addr) = remote { - let endpoint = addr.try_into()?; - (self.try_bind_empheral(&endpoint)?, Some(endpoint)) - } else { - let bound = self.bound()?; - (bound, None) + + let remote_endpoint = match remote { + Some(remote_addr) => Some(remote_addr.try_into()?), + None => None, }; - bound.try_sendto(buf, remote_endpoint, flags) + if let Some(endpoint) = remote_endpoint { + self.try_bind_empheral(&endpoint)?; + } + + // TODO: Block if the send buffer is full + let sent_bytes = self.try_sendto(buf, remote_endpoint, flags)?; + poll_ifaces(); + Ok(sent_bytes) + } +} + +impl Observer<()> for DatagramSocket { + fn on_events(&self, events: &()) { + self.update_io_events(); } } diff --git a/kernel/aster-nix/src/net/socket/ip/datagram/unbound.rs b/kernel/aster-nix/src/net/socket/ip/datagram/unbound.rs index 219b2406b..9e56c7cc7 100644 --- a/kernel/aster-nix/src/net/socket/ip/datagram/unbound.rs +++ b/kernel/aster-nix/src/net/socket/ip/datagram/unbound.rs @@ -1,53 +1,39 @@ // SPDX-License-Identifier: MPL-2.0 +use alloc::sync::Weak; + use super::bound::BoundDatagram; use crate::{ - events::IoEvents, + events::Observer, net::{ iface::{AnyUnboundSocket, IpEndpoint, RawUdpSocket}, socket::ip::common::bind_socket, }, prelude::*, - process::signal::{Pollee, Poller}, }; pub struct UnboundDatagram { unbound_socket: Box, - pollee: Pollee, } impl UnboundDatagram { - pub fn new() -> Self { + pub fn new(observer: Weak>) -> Self { Self { - unbound_socket: Box::new(AnyUnboundSocket::new_udp()), - pollee: Pollee::new(IoEvents::empty()), + unbound_socket: Box::new(AnyUnboundSocket::new_udp(observer)), } } - pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents { - self.pollee.poll(mask, poller) - } - - pub fn bind( - self, - endpoint: IpEndpoint, - ) -> core::result::Result, (Error, Self)> { + pub fn bind(self, endpoint: &IpEndpoint) -> core::result::Result { let bound_socket = match bind_socket(self.unbound_socket, endpoint, false) { Ok(bound_socket) => bound_socket, - Err((err, unbound_socket)) => { - return Err(( - err, - Self { - unbound_socket, - pollee: self.pollee, - }, - )) - } + Err((err, unbound_socket)) => return Err((err, Self { unbound_socket })), }; + let bound_endpoint = bound_socket.local_endpoint().unwrap(); bound_socket.raw_with(|socket: &mut RawUdpSocket| { socket.bind(bound_endpoint).unwrap(); }); - Ok(BoundDatagram::new(bound_socket, self.pollee)) + + Ok(BoundDatagram::new(bound_socket)) } } diff --git a/kernel/aster-nix/src/net/socket/ip/mod.rs b/kernel/aster-nix/src/net/socket/ip/mod.rs index b6cf51ac6..30a8c4ee1 100644 --- a/kernel/aster-nix/src/net/socket/ip/mod.rs +++ b/kernel/aster-nix/src/net/socket/ip/mod.rs @@ -1,6 +1,5 @@ // SPDX-License-Identifier: MPL-2.0 -mod always_some; mod common; mod datagram; pub mod stream; diff --git a/kernel/aster-nix/src/net/socket/ip/stream/connected.rs b/kernel/aster-nix/src/net/socket/ip/stream/connected.rs index 62a73b0f5..e26e3d9de 100644 --- a/kernel/aster-nix/src/net/socket/ip/stream/connected.rs +++ b/kernel/aster-nix/src/net/socket/ip/stream/connected.rs @@ -1,42 +1,28 @@ // SPDX-License-Identifier: MPL-2.0 -use core::sync::atomic::{AtomicBool, Ordering}; +use alloc::sync::Weak; use crate::{ events::{IoEvents, Observer}, net::{ iface::{AnyBoundSocket, IpEndpoint, RawTcpSocket}, - poll_ifaces, socket::util::{send_recv_flags::SendRecvFlags, shutdown_cmd::SockShutdownCmd}, }, prelude::*, - process::signal::{Pollee, Poller}, + process::signal::Pollee, }; pub struct ConnectedStream { - nonblocking: AtomicBool, bound_socket: Arc, remote_endpoint: IpEndpoint, - pollee: Pollee, } impl ConnectedStream { - pub fn new( - is_nonblocking: bool, - bound_socket: Arc, - remote_endpoint: IpEndpoint, - pollee: Pollee, - ) -> Arc { - let connected = Arc::new(Self { - nonblocking: AtomicBool::new(is_nonblocking), + pub fn new(bound_socket: Arc, remote_endpoint: IpEndpoint) -> Self { + Self { bound_socket, remote_endpoint, - pollee, - }); - connected - .bound_socket - .set_observer(Arc::downgrade(&connected) as _); - connected + } } pub fn shutdown(&self, cmd: SockShutdownCmd) -> Result<()> { @@ -44,102 +30,46 @@ impl ConnectedStream { self.bound_socket.raw_with(|socket: &mut RawTcpSocket| { socket.close(); }); - poll_ifaces(); Ok(()) } - pub fn recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, IpEndpoint)> { - debug_assert!(flags.is_all_supported()); - - let poller = Poller::new(); - loop { - let recv_len = self.try_recvfrom(buf, flags)?; - if recv_len > 0 { - let remote_endpoint = self.remote_endpoint()?; - return Ok((recv_len, remote_endpoint)); - } - let events = self.poll(IoEvents::IN, Some(&poller)); - if events.contains(IoEvents::HUP) || events.contains(IoEvents::ERR) { - return_errno_with_message!(Errno::ENOTCONN, "recv packet fails"); - } - if !events.contains(IoEvents::IN) { - if self.is_nonblocking() { - return_errno_with_message!(Errno::EAGAIN, "try to recv again"); - } - // FIXME: deal with receive timeout - poller.wait()?; - } + pub fn try_recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result { + let recv_bytes = self + .bound_socket + .raw_with(|socket: &mut RawTcpSocket| socket.recv_slice(buf)) + .map_err(|_| Error::with_message(Errno::ENOTCONN, "fail to recv packet"))?; + if recv_bytes == 0 { + return_errno_with_message!(Errno::EAGAIN, "try to recv again"); } + Ok(recv_bytes) } - fn try_recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result { - poll_ifaces(); - let res = self.bound_socket.raw_with(|socket: &mut RawTcpSocket| { - socket - .recv_slice(buf) - .map_err(|_| Error::with_message(Errno::ENOTCONN, "fail to recv packet")) - }); - self.update_io_events(); - res - } - - pub fn sendto(&self, buf: &[u8], flags: SendRecvFlags) -> Result { - debug_assert!(flags.is_all_supported()); - - let poller = Poller::new(); - loop { - let sent_len = self.try_sendto(buf, flags)?; - if sent_len > 0 { - return Ok(sent_len); - } - let events = self.poll(IoEvents::OUT, Some(&poller)); - if events.contains(IoEvents::HUP) || events.contains(IoEvents::ERR) { - return_errno_with_message!(Errno::ENOBUFS, "fail to send packets"); - } - if !events.contains(IoEvents::OUT) { - if self.is_nonblocking() { - return_errno_with_message!(Errno::EAGAIN, "try to send again"); - } - // FIXME: deal with send timeout - poller.wait()?; - } - } - } - - fn try_sendto(&self, buf: &[u8], flags: SendRecvFlags) -> Result { - let res = self + pub fn try_sendto(&self, buf: &[u8], flags: SendRecvFlags) -> Result { + let sent_bytes = self .bound_socket .raw_with(|socket: &mut RawTcpSocket| socket.send_slice(buf)) - .map_err(|_| Error::with_message(Errno::ENOBUFS, "cannot send packet")); - match res { - // We have to explicitly invoke `update_io_events` when the send buffer becomes - // full. Note that smoltcp does not think it is an interface event, so calling - // `poll_ifaces` alone is not enough. - Ok(0) => self.update_io_events(), - Ok(_) => poll_ifaces(), - _ => (), - }; - res + .map_err(|_| Error::with_message(Errno::ENOBUFS, "cannot send packet"))?; + if sent_bytes == 0 { + return_errno_with_message!(Errno::EAGAIN, "try to send again"); + } + Ok(sent_bytes) } - pub fn local_endpoint(&self) -> Result { - self.bound_socket - .local_endpoint() - .ok_or_else(|| Error::with_message(Errno::EINVAL, "does not has remote endpoint")) + pub fn local_endpoint(&self) -> IpEndpoint { + self.bound_socket.local_endpoint().unwrap() } - pub fn remote_endpoint(&self) -> Result { - Ok(self.remote_endpoint) + pub fn remote_endpoint(&self) -> IpEndpoint { + self.remote_endpoint } - pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents { - self.pollee.poll(mask, poller) + pub(super) fn init_pollee(&self, pollee: &Pollee) { + pollee.reset_events(); + self.update_io_events(pollee); } - fn update_io_events(&self) { + pub(super) fn update_io_events(&self, pollee: &Pollee) { self.bound_socket.raw_with(|socket: &mut RawTcpSocket| { - let pollee = &self.pollee; - if socket.can_recv() { pollee.add_events(IoEvents::IN); } else { @@ -154,17 +84,7 @@ impl ConnectedStream { }); } - pub fn is_nonblocking(&self) -> bool { - self.nonblocking.load(Ordering::Relaxed) - } - - pub fn set_nonblocking(&self, nonblocking: bool) { - self.nonblocking.store(nonblocking, Ordering::Relaxed); - } -} - -impl Observer<()> for ConnectedStream { - fn on_events(&self, _: &()) { - self.update_io_events(); + pub(super) fn set_observer(&self, observer: Weak>) { + self.bound_socket.set_observer(observer) } } diff --git a/kernel/aster-nix/src/net/socket/ip/stream/connecting.rs b/kernel/aster-nix/src/net/socket/ip/stream/connecting.rs index 626fe01df..d1ac85383 100644 --- a/kernel/aster-nix/src/net/socket/ip/stream/connecting.rs +++ b/kernel/aster-nix/src/net/socket/ip/stream/connecting.rs @@ -1,116 +1,77 @@ // SPDX-License-Identifier: MPL-2.0 -use core::sync::atomic::{AtomicBool, Ordering}; - use super::{connected::ConnectedStream, init::InitStream}; use crate::{ - events::{IoEvents, Observer}, - net::{ - iface::{AnyBoundSocket, IpEndpoint, RawTcpSocket}, - poll_ifaces, - }, + events::IoEvents, + net::iface::{AnyBoundSocket, IpEndpoint, RawTcpSocket}, prelude::*, - process::signal::{Pollee, Poller}, + process::signal::Pollee, }; pub struct ConnectingStream { - nonblocking: AtomicBool, bound_socket: Arc, remote_endpoint: IpEndpoint, conn_result: RwLock>, - pollee: Pollee, } +#[derive(Clone, Copy)] enum ConnResult { Connected, Refused, } +pub enum NonConnectedStream { + Init(InitStream), + Connecting(ConnectingStream), +} + impl ConnectingStream { pub fn new( - nonblocking: bool, bound_socket: Arc, remote_endpoint: IpEndpoint, - pollee: Pollee, - ) -> Result> { - bound_socket.do_connect(remote_endpoint)?; - - let connecting = Arc::new(Self { - nonblocking: AtomicBool::new(nonblocking), + ) -> core::result::Result)> { + if let Err(err) = bound_socket.do_connect(remote_endpoint) { + return Err((err, bound_socket)); + } + Ok(Self { bound_socket, remote_endpoint, conn_result: RwLock::new(None), - pollee, - }); - connecting.pollee.reset_events(); - connecting - .bound_socket - .set_observer(Arc::downgrade(&connecting) as _); - Ok(connecting) + }) } - pub fn wait_conn( - &self, - ) -> core::result::Result, (Error, Arc)> { - debug_assert!(!self.is_nonblocking()); - - let poller = Poller::new(); - loop { - poll_ifaces(); - - match *self.conn_result.read() { - Some(ConnResult::Connected) => { - return Ok(ConnectedStream::new( - self.is_nonblocking(), - self.bound_socket.clone(), - self.remote_endpoint, - self.pollee.clone(), - )); - } - Some(ConnResult::Refused) => { - return Err(( - Error::with_message(Errno::ECONNREFUSED, "connection refused"), - InitStream::new_bound( - self.is_nonblocking(), - self.bound_socket.clone(), - self.pollee.clone(), - ), - )); - } - None => (), - }; - - let events = self.poll(IoEvents::OUT, Some(&poller)); - if !events.contains(IoEvents::OUT) { - // FIXME: deal with nonblocking mode & connecting timeout - poller.wait().expect("async connect() not implemented"); - } + pub fn into_result(self) -> core::result::Result { + let conn_result = *self.conn_result.read(); + match conn_result { + Some(ConnResult::Connected) => Ok(ConnectedStream::new( + self.bound_socket, + self.remote_endpoint, + )), + Some(ConnResult::Refused) => Err(( + Error::with_message(Errno::ECONNREFUSED, "the connection is refused"), + NonConnectedStream::Init(InitStream::new_bound(self.bound_socket)), + )), + None => Err(( + Error::with_message(Errno::EAGAIN, "the connection is pending"), + NonConnectedStream::Connecting(self), + )), } } - pub fn local_endpoint(&self) -> Result { - self.bound_socket - .local_endpoint() - .ok_or_else(|| Error::with_message(Errno::EINVAL, "no local endpoint")) + pub fn local_endpoint(&self) -> IpEndpoint { + self.bound_socket.local_endpoint().unwrap() } - pub fn remote_endpoint(&self) -> Result { - Ok(self.remote_endpoint) + pub fn remote_endpoint(&self) -> IpEndpoint { + self.remote_endpoint } - pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents { - self.pollee.poll(mask, poller) + pub(super) fn init_pollee(&self, pollee: &Pollee) { + pollee.reset_events(); + self.update_io_events(pollee); } - pub fn is_nonblocking(&self) -> bool { - self.nonblocking.load(Ordering::Relaxed) - } - - pub fn set_nonblocking(&self, nonblocking: bool) { - self.nonblocking.store(nonblocking, Ordering::Relaxed); - } - - fn update_io_events(&self) { + pub(super) fn update_io_events(&self, pollee: &Pollee) { if self.conn_result.read().is_some() { return; } @@ -143,13 +104,7 @@ impl ConnectingStream { // be responsible to initialize all the I/O events including `IoEvents::OUT`, so the // following hard-coded event addition can be removed. if became_writable { - self.pollee.add_events(IoEvents::OUT); + pollee.add_events(IoEvents::OUT); } } } - -impl Observer<()> for ConnectingStream { - fn on_events(&self, _: &()) { - self.update_io_events(); - } -} diff --git a/kernel/aster-nix/src/net/socket/ip/stream/init.rs b/kernel/aster-nix/src/net/socket/ip/stream/init.rs index b48c3e833..352237a9c 100644 --- a/kernel/aster-nix/src/net/socket/ip/stream/init.rs +++ b/kernel/aster-nix/src/net/socket/ip/stream/init.rs @@ -1,156 +1,93 @@ // SPDX-License-Identifier: MPL-2.0 -use core::sync::atomic::{AtomicBool, Ordering}; +use alloc::sync::Weak; use super::{connecting::ConnectingStream, listen::ListenStream}; use crate::{ - events::IoEvents, + events::Observer, net::{ - iface::{AnyBoundSocket, AnyUnboundSocket, Iface, IpEndpoint}, - socket::ip::{ - always_some::AlwaysSome, - common::{bind_socket, get_ephemeral_endpoint}, - }, + iface::{AnyBoundSocket, AnyUnboundSocket, IpEndpoint}, + socket::ip::common::{bind_socket, get_ephemeral_endpoint}, }, prelude::*, - process::signal::{Pollee, Poller}, }; -pub struct InitStream { - inner: RwLock, - is_nonblocking: AtomicBool, - pollee: Pollee, -} - -enum Inner { - Unbound(AlwaysSome>), - Bound(AlwaysSome>), -} - -impl Inner { - fn new() -> Inner { - let unbound_socket = Box::new(AnyUnboundSocket::new_tcp()); - Inner::Unbound(AlwaysSome::new(unbound_socket)) - } - - fn is_bound(&self) -> bool { - match self { - Self::Unbound(_) => false, - Self::Bound(_) => true, - } - } - - fn bind(&mut self, endpoint: IpEndpoint) -> Result<()> { - let unbound_socket = if let Inner::Unbound(unbound_socket) = self { - unbound_socket - } else { - return_errno_with_message!(Errno::EINVAL, "the socket is already bound to an address"); - }; - let bound_socket = - unbound_socket.try_take_with(|raw_socket| bind_socket(raw_socket, endpoint, false))?; - *self = Inner::Bound(AlwaysSome::new(bound_socket)); - Ok(()) - } - - fn bind_to_ephemeral_endpoint(&mut self, remote_endpoint: &IpEndpoint) -> Result<()> { - let endpoint = get_ephemeral_endpoint(remote_endpoint); - self.bind(endpoint) - } - - fn bound_socket(&self) -> Option<&Arc> { - match self { - Inner::Bound(bound_socket) => Some(bound_socket), - Inner::Unbound(_) => None, - } - } - - fn iface(&self) -> Option> { - match self { - Inner::Bound(bound_socket) => Some(bound_socket.iface().clone()), - Inner::Unbound(_) => None, - } - } - - fn local_endpoint(&self) -> Option { - self.bound_socket() - .and_then(|socket| socket.local_endpoint()) - } +pub enum InitStream { + Unbound(Box), + Bound(Arc), } impl InitStream { // FIXME: In Linux we have the `POLLOUT` event for a newly created socket, while calling // `write()` on it triggers `SIGPIPE`/`EPIPE`. No documentation found yet, but confirmed by // experimentation and Linux source code. - pub fn new(nonblocking: bool) -> Arc { - Arc::new(Self { - inner: RwLock::new(Inner::new()), - is_nonblocking: AtomicBool::new(nonblocking), - pollee: Pollee::new(IoEvents::empty()), - }) + pub fn new(observer: Weak>) -> Self { + InitStream::Unbound(Box::new(AnyUnboundSocket::new_tcp(observer))) } - pub fn new_bound( - nonblocking: bool, - bound_socket: Arc, - pollee: Pollee, - ) -> Arc { - bound_socket.set_observer(Weak::<()>::new()); - let inner = Inner::Bound(AlwaysSome::new(bound_socket)); - Arc::new(Self { - is_nonblocking: AtomicBool::new(nonblocking), - inner: RwLock::new(inner), - pollee, - }) + pub fn new_bound(bound_socket: Arc) -> Self { + InitStream::Bound(bound_socket) } - pub fn bind(&self, endpoint: IpEndpoint) -> Result<()> { - self.inner.write().bind(endpoint) - } - - pub fn connect(&self, remote_endpoint: &IpEndpoint) -> Result> { - if !self.inner.read().is_bound() { - self.inner - .write() - .bind_to_ephemeral_endpoint(remote_endpoint)? - } - ConnectingStream::new( - self.is_nonblocking(), - self.inner.read().bound_socket().unwrap().clone(), - *remote_endpoint, - self.pollee.clone(), - ) - } - - pub fn listen(&self, backlog: usize) -> Result> { - let bound_socket = if let Some(bound_socket) = self.inner.read().bound_socket() { - bound_socket.clone() - } else { - return_errno_with_message!(Errno::EINVAL, "cannot listen without bound") + pub fn bind( + self, + endpoint: &IpEndpoint, + ) -> core::result::Result, (Error, Self)> { + let unbound_socket = match self { + InitStream::Unbound(unbound_socket) => unbound_socket, + InitStream::Bound(bound_socket) => { + return Err(( + Error::with_message(Errno::EINVAL, "the socket is already bound to an address"), + InitStream::Bound(bound_socket), + )); + } }; - ListenStream::new( - self.is_nonblocking(), - bound_socket, - backlog, - self.pollee.clone(), - ) + let bound_socket = match bind_socket(unbound_socket, endpoint, false) { + Ok(bound_socket) => bound_socket, + Err((err, unbound_socket)) => return Err((err, InitStream::Unbound(unbound_socket))), + }; + Ok(bound_socket) + } + + fn bind_to_ephemeral_endpoint( + self, + remote_endpoint: &IpEndpoint, + ) -> core::result::Result, (Error, Self)> { + let endpoint = get_ephemeral_endpoint(remote_endpoint); + self.bind(&endpoint) + } + + pub fn connect( + self, + remote_endpoint: &IpEndpoint, + ) -> core::result::Result { + let bound_socket = match self { + InitStream::Bound(bound_socket) => bound_socket, + InitStream::Unbound(_) => self.bind_to_ephemeral_endpoint(remote_endpoint)?, + }; + + ConnectingStream::new(bound_socket, *remote_endpoint) + .map_err(|(err, bound_socket)| (err, InitStream::Bound(bound_socket))) + } + + pub fn listen(self, backlog: usize) -> core::result::Result { + let InitStream::Bound(bound_socket) = self else { + return Err(( + Error::with_message(Errno::EINVAL, "cannot listen without bound"), + self, + )); + }; + + ListenStream::new(bound_socket, backlog) + .map_err(|(err, bound_socket)| (err, InitStream::Bound(bound_socket))) } pub fn local_endpoint(&self) -> Result { - self.inner - .read() - .local_endpoint() - .ok_or_else(|| Error::with_message(Errno::EINVAL, "does not has local endpoint")) - } - - pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents { - self.pollee.poll(mask, poller) - } - - pub fn is_nonblocking(&self) -> bool { - self.is_nonblocking.load(Ordering::Relaxed) - } - - pub fn set_nonblocking(&self, nonblocking: bool) { - self.is_nonblocking.store(nonblocking, Ordering::Relaxed); + match self { + InitStream::Unbound(_) => { + return_errno_with_message!(Errno::EINVAL, "does not has local endpoint") + } + InitStream::Bound(bound_socket) => Ok(bound_socket.local_endpoint().unwrap()), + } } } diff --git a/kernel/aster-nix/src/net/socket/ip/stream/listen.rs b/kernel/aster-nix/src/net/socket/ip/stream/listen.rs index 624f8b20f..a970eaea1 100644 --- a/kernel/aster-nix/src/net/socket/ip/stream/listen.rs +++ b/kernel/aster-nix/src/net/socket/ip/stream/listen.rs @@ -1,148 +1,97 @@ // SPDX-License-Identifier: MPL-2.0 -use core::sync::atomic::{AtomicBool, Ordering}; - use super::connected::ConnectedStream; use crate::{ - events::{IoEvents, Observer}, - net::{ - iface::{AnyBoundSocket, AnyUnboundSocket, BindPortConfig, IpEndpoint, RawTcpSocket}, - poll_ifaces, - }, + events::IoEvents, + net::iface::{AnyBoundSocket, AnyUnboundSocket, BindPortConfig, IpEndpoint, RawTcpSocket}, prelude::*, - process::signal::{Pollee, Poller}, + process::signal::Pollee, }; pub struct ListenStream { - is_nonblocking: AtomicBool, backlog: usize, /// A bound socket held to ensure the TCP port cannot be released bound_socket: Arc, /// Backlog sockets listening at the local endpoint backlog_sockets: RwLock>, - pollee: Pollee, } impl ListenStream { pub fn new( - nonblocking: bool, bound_socket: Arc, backlog: usize, - pollee: Pollee, - ) -> Result> { - let listen_stream = Arc::new(Self { - is_nonblocking: AtomicBool::new(nonblocking), + ) -> core::result::Result)> { + let listen_stream = Self { backlog, bound_socket, backlog_sockets: RwLock::new(Vec::new()), - pollee, - }); - listen_stream.fill_backlog_sockets()?; - listen_stream.pollee.reset_events(); - listen_stream - .bound_socket - .set_observer(Arc::downgrade(&listen_stream) as _); - Ok(listen_stream) - } - - pub fn accept(&self) -> Result<(Arc, IpEndpoint)> { - // wait to accept - let poller = Poller::new(); - loop { - poll_ifaces(); - let accepted_socket = if let Some(accepted_socket) = self.try_accept() { - accepted_socket - } else { - let events = self.poll(IoEvents::IN, Some(&poller)); - if !events.contains(IoEvents::IN) { - if self.is_nonblocking() { - return_errno_with_message!(Errno::EAGAIN, "try accept again"); - } - // FIXME: deal with accept timeout - poller.wait()?; - } - continue; - }; - let remote_endpoint = accepted_socket.remote_endpoint().unwrap(); - let connected_stream = { - let BacklogSocket { - bound_socket: backlog_socket, - } = accepted_socket; - ConnectedStream::new( - false, - backlog_socket, - remote_endpoint, - Pollee::new(IoEvents::empty()), - ) - }; - return Ok((connected_stream, remote_endpoint)); + }; + if let Err(err) = listen_stream.fill_backlog_sockets() { + return Err((err, listen_stream.bound_socket)); } + Ok(listen_stream) } /// Append sockets listening at LocalEndPoint to support backlog fn fill_backlog_sockets(&self) -> Result<()> { - let backlog = self.backlog; let mut backlog_sockets = self.backlog_sockets.write(); + + let backlog = self.backlog; let current_backlog_len = backlog_sockets.len(); debug_assert!(backlog >= current_backlog_len); if backlog == current_backlog_len { return Ok(()); } + for _ in current_backlog_len..backlog { let backlog_socket = BacklogSocket::new(&self.bound_socket)?; backlog_sockets.push(backlog_socket); } + Ok(()) } - fn try_accept(&self) -> Option { - let backlog_socket = { - let mut backlog_sockets = self.backlog_sockets.write(); - let index = backlog_sockets - .iter() - .position(|backlog_socket| backlog_socket.is_active())?; - backlog_sockets.remove(index) - }; - self.fill_backlog_sockets().unwrap(); - self.update_io_events(); - Some(backlog_socket) + pub fn try_accept(&self) -> Result { + let mut backlog_sockets = self.backlog_sockets.write(); + + let index = backlog_sockets + .iter() + .position(|backlog_socket| backlog_socket.is_active()) + .ok_or_else(|| Error::with_message(Errno::EAGAIN, "try to accept again"))?; + let active_backlog_socket = backlog_sockets.remove(index); + + match BacklogSocket::new(&self.bound_socket) { + Ok(backlog_socket) => backlog_sockets.push(backlog_socket), + Err(err) => (), + } + + let remote_endpoint = active_backlog_socket.remote_endpoint().unwrap(); + Ok(ConnectedStream::new( + active_backlog_socket.into_bound_socket(), + remote_endpoint, + )) } - pub fn local_endpoint(&self) -> Result { - self.bound_socket - .local_endpoint() - .ok_or_else(|| Error::with_message(Errno::EINVAL, "does not has remote endpoint")) + pub fn local_endpoint(&self) -> IpEndpoint { + self.bound_socket.local_endpoint().unwrap() } - pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents { - self.pollee.poll(mask, poller) + pub(super) fn init_pollee(&self, pollee: &Pollee) { + pollee.reset_events(); + self.update_io_events(pollee); } - fn update_io_events(&self) { + pub(super) fn update_io_events(&self, pollee: &Pollee) { // The lock should be held to avoid data races let backlog_sockets = self.backlog_sockets.read(); let can_accept = backlog_sockets.iter().any(|socket| socket.is_active()); if can_accept { - self.pollee.add_events(IoEvents::IN); + pollee.add_events(IoEvents::IN); } else { - self.pollee.del_events(IoEvents::IN); + pollee.del_events(IoEvents::IN); } } - - pub fn is_nonblocking(&self) -> bool { - self.is_nonblocking.load(Ordering::Relaxed) - } - - pub fn set_nonblocking(&self, nonblocking: bool) { - self.is_nonblocking.store(nonblocking, Ordering::Relaxed); - } -} - -impl Observer<()> for ListenStream { - fn on_events(&self, _: &()) { - self.update_io_events(); - } } struct BacklogSocket { @@ -155,19 +104,21 @@ impl BacklogSocket { Errno::EINVAL, "the socket is not bound", ))?; - let unbound_socket = Box::new(AnyUnboundSocket::new_tcp()); + + let unbound_socket = Box::new(AnyUnboundSocket::new_tcp(Weak::<()>::new())); let bound_socket = { let iface = bound_socket.iface(); let bind_port_config = BindPortConfig::new(local_endpoint.port, true)?; iface .bind_socket(unbound_socket, bind_port_config) - .map_err(|(e, _)| e)? + .map_err(|(err, _)| err)? }; bound_socket.raw_with(|raw_tcp_socket: &mut RawTcpSocket| { raw_tcp_socket .listen(local_endpoint) .map_err(|_| Error::with_message(Errno::EINVAL, "fail to listen")) })?; + Ok(Self { bound_socket }) } @@ -180,4 +131,8 @@ impl BacklogSocket { self.bound_socket .raw_with(|socket: &mut RawTcpSocket| socket.remote_endpoint()) } + + fn into_bound_socket(self) -> Arc { + self.bound_socket + } } diff --git a/kernel/aster-nix/src/net/socket/ip/stream/mod.rs b/kernel/aster-nix/src/net/socket/ip/stream/mod.rs index 74185c3ec..df23c2899 100644 --- a/kernel/aster-nix/src/net/socket/ip/stream/mod.rs +++ b/kernel/aster-nix/src/net/socket/ip/stream/mod.rs @@ -1,5 +1,10 @@ // SPDX-License-Identifier: MPL-2.0 +use core::{ + mem, + sync::atomic::{AtomicBool, Ordering}, +}; + use connected::ConnectedStream; use connecting::ConnectingStream; use init::InitStream; @@ -9,21 +14,24 @@ use smoltcp::wire::IpEndpoint; use util::{TcpOptionSet, DEFAULT_MAXSEG}; use crate::{ - events::IoEvents, + events::{IoEvents, Observer}, fs::{file_handle::FileLike, utils::StatusFlags}, match_sock_option_mut, match_sock_option_ref, - net::socket::{ - options::{Error, Linger, RecvBuf, ReuseAddr, ReusePort, SendBuf, SocketOption}, - util::{ - options::{SocketOptionSet, MIN_RECVBUF, MIN_SENDBUF}, - send_recv_flags::SendRecvFlags, - shutdown_cmd::SockShutdownCmd, - socket_addr::SocketAddr, + net::{ + poll_ifaces, + socket::{ + options::{Error, Linger, RecvBuf, ReuseAddr, ReusePort, SendBuf, SocketOption}, + util::{ + options::{SocketOptionSet, MIN_RECVBUF, MIN_SENDBUF}, + send_recv_flags::SendRecvFlags, + shutdown_cmd::SockShutdownCmd, + socket_addr::SocketAddr, + }, + Socket, }, - Socket, }, prelude::*, - process::signal::Poller, + process::signal::{Pollee, Poller}, }; mod connected; @@ -33,22 +41,27 @@ mod listen; pub mod options; mod util; +use self::connecting::NonConnectedStream; pub use self::util::CongestionControl; pub struct StreamSocket { options: RwLock, state: RwLock, + is_nonblocking: AtomicBool, + pollee: Pollee, } enum State { // Start state - Init(Arc), + Init(InitStream), // Intermediate state - Connecting(Arc), + Connecting(ConnectingStream), // Final State 1 - Connected(Arc), + Connected(ConnectedStream), // Final State 2 - Listen(Arc), + Listen(ListenStream), + // Poisoned state + Poisoned, } #[derive(Debug, Clone)] @@ -66,45 +79,159 @@ impl OptionSet { } impl StreamSocket { - pub fn new(nonblocking: bool) -> Self { - let options = OptionSet::new(); - let state = State::Init(InitStream::new(nonblocking)); - Self { - options: RwLock::new(options), - state: RwLock::new(state), - } + pub fn new(nonblocking: bool) -> Arc { + Arc::new_cyclic(|me| { + let init_stream = InitStream::new(me.clone() as _); + let pollee = Pollee::new(IoEvents::empty()); + Self { + options: RwLock::new(OptionSet::new()), + state: RwLock::new(State::Init(init_stream)), + is_nonblocking: AtomicBool::new(nonblocking), + pollee, + } + }) + } + + fn new_connected(connected_stream: ConnectedStream) -> Arc { + Arc::new_cyclic(move |me| { + let pollee = Pollee::new(IoEvents::empty()); + connected_stream.set_observer(me.clone() as _); + connected_stream.init_pollee(&pollee); + Self { + options: RwLock::new(OptionSet::new()), + state: RwLock::new(State::Connected(connected_stream)), + is_nonblocking: AtomicBool::new(false), + pollee, + } + }) } fn is_nonblocking(&self) -> bool { - match &*self.state.read() { - State::Init(init) => init.is_nonblocking(), - State::Connecting(connecting) => connecting.is_nonblocking(), - State::Connected(connected) => connected.is_nonblocking(), - State::Listen(listen) => listen.is_nonblocking(), - } + self.is_nonblocking.load(Ordering::Relaxed) } fn set_nonblocking(&self, nonblocking: bool) { - match &*self.state.read() { - State::Init(init) => init.set_nonblocking(nonblocking), - State::Connecting(connecting) => connecting.set_nonblocking(nonblocking), - State::Connected(connected) => connected.set_nonblocking(nonblocking), - State::Listen(listen) => listen.set_nonblocking(nonblocking), + self.is_nonblocking.store(nonblocking, Ordering::Relaxed); + } + + fn start_connect(&self, remote_endpoint: &IpEndpoint) -> Result<()> { + let mut state = self.state.write(); + + let owned_state = mem::replace(&mut *state, State::Poisoned); + let State::Init(init_stream) = owned_state else { + *state = owned_state; + return_errno_with_message!(Errno::EINVAL, "cannot connect") + }; + + let connecting_stream = match init_stream.connect(remote_endpoint) { + Ok(connecting_stream) => connecting_stream, + Err((err, init_stream)) => { + *state = State::Init(init_stream); + return Err(err); + } + }; + connecting_stream.init_pollee(&self.pollee); + *state = State::Connecting(connecting_stream); + + Ok(()) + } + + fn finish_connect(&self) -> Result<()> { + let mut state = self.state.write(); + + let owned_state = mem::replace(&mut *state, State::Poisoned); + let State::Connecting(connecting_stream) = owned_state else { + *state = owned_state; + debug_assert!(false, "the socket unexpectedly left the connecting state"); + return_errno_with_message!(Errno::EINVAL, "the socket is not connecting"); + }; + + let connected_stream = match connecting_stream.into_result() { + Ok(connected_stream) => connected_stream, + Err((err, NonConnectedStream::Init(init_stream))) => { + *state = State::Init(init_stream); + return Err(err); + } + Err((err, NonConnectedStream::Connecting(connecting_stream))) => { + *state = State::Connecting(connecting_stream); + return Err(err); + } + }; + connected_stream.init_pollee(&self.pollee); + *state = State::Connected(connected_stream); + + Ok(()) + } + + fn try_accept(&self) -> Result<(Arc, SocketAddr)> { + let state = self.state.read(); + + let State::Listen(listen_stream) = &*state else { + return_errno_with_message!(Errno::EINVAL, "the socket is not listening"); + }; + + let connected_stream = listen_stream.try_accept()?; + listen_stream.update_io_events(&self.pollee); + + let remote_endpoint = connected_stream.remote_endpoint(); + let accepted_socket = Self::new_connected(connected_stream); + Ok((accepted_socket, remote_endpoint.try_into()?)) + } + + fn try_recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, SocketAddr)> { + let state = self.state.read(); + + let State::Connected(connected_stream) = &*state else { + return_errno_with_message!(Errno::EINVAL, "the socket is not connected"); + }; + let recv_bytes = connected_stream.try_recvfrom(buf, flags)?; + connected_stream.update_io_events(&self.pollee); + Ok((recv_bytes, connected_stream.remote_endpoint().try_into()?)) + } + + fn try_sendto(&self, buf: &[u8], flags: SendRecvFlags) -> Result { + let state = self.state.read(); + + let State::Connected(connected_stream) = &*state else { + return_errno_with_message!(Errno::EINVAL, "the socket is not connected"); + }; + let sent_bytes = connected_stream.try_sendto(buf, flags)?; + connected_stream.update_io_events(&self.pollee); + Ok(sent_bytes) + } + + // TODO: Support timeout + fn wait_events(&self, mask: IoEvents, mut cond: F) -> Result + where + F: FnMut() -> Result, + { + let poller = Poller::new(); + + loop { + match cond() { + Err(err) if err.error() == Errno::EAGAIN => (), + result => return result, + }; + + let events = self.poll(mask, Some(&poller)); + if !events.is_empty() { + continue; + } + + poller.wait()?; } } - fn do_connect(&self, remote_endpoint: &IpEndpoint) -> Result> { - let mut state = self.state.write(); - let init_stream = match &*state { - State::Init(init_stream) => init_stream, - State::Listen(_) | State::Connecting(_) | State::Connected(_) => { - return_errno_with_message!(Errno::EINVAL, "cannot connect") + fn update_io_events(&self) { + let state = self.state.read(); + match &*state { + State::Init(_) | State::Poisoned => (), + State::Connecting(connecting_stream) => { + connecting_stream.update_io_events(&self.pollee) } - }; - - let connecting = init_stream.connect(remote_endpoint)?; - *state = State::Connecting(connecting.clone()); - Ok(connecting) + State::Listen(listen_stream) => listen_stream.update_io_events(&self.pollee), + State::Connected(connected_stream) => connected_stream.update_io_events(&self.pollee), + } } } @@ -123,13 +250,7 @@ impl FileLike for StreamSocket { } fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents { - let state = self.state.read(); - match &*state { - State::Init(init) => init.poll(mask, poller), - State::Connecting(connecting) => connecting.poll(mask, poller), - State::Connected(connected) => connected.poll(mask, poller), - State::Listen(listen) => listen.poll(mask, poller), - } + self.pollee.poll(mask, poller) } fn status_flags(&self) -> StatusFlags { @@ -157,68 +278,65 @@ impl FileLike for StreamSocket { impl Socket for StreamSocket { fn bind(&self, socket_addr: SocketAddr) -> Result<()> { let endpoint = socket_addr.try_into()?; - let state = self.state.read(); - match &*state { - State::Init(init_stream) => init_stream.bind(endpoint), - _ => return_errno_with_message!(Errno::EINVAL, "cannot bind"), - } + + let mut state = self.state.write(); + + let owned_state = mem::replace(&mut *state, State::Poisoned); + let State::Init(init_stream) = owned_state else { + *state = owned_state; + return_errno_with_message!(Errno::EINVAL, "cannot bind"); + }; + + let bound_socket = match init_stream.bind(&endpoint) { + Ok(bound_socket) => bound_socket, + Err((err, init_stream)) => { + *state = State::Init(init_stream); + return Err(err); + } + }; + *state = State::Init(InitStream::new_bound(bound_socket)); + + Ok(()) } + // TODO: Support nonblocking mode fn connect(&self, socket_addr: SocketAddr) -> Result<()> { let remote_endpoint = socket_addr.try_into()?; + self.start_connect(&remote_endpoint)?; - let connecting_stream = self.do_connect(&remote_endpoint)?; - match connecting_stream.wait_conn() { - Ok(connected_stream) => { - *self.state.write() = State::Connected(connected_stream); - Ok(()) - } - Err((err, init_stream)) => { - *self.state.write() = State::Init(init_stream); - Err(err) - } - } + poll_ifaces(); + self.wait_events(IoEvents::OUT, || self.finish_connect()) } fn listen(&self, backlog: usize) -> Result<()> { let mut state = self.state.write(); - let init_stream = match &*state { - State::Init(init_stream) => init_stream, - State::Connecting(connecting_stream) => { - return_errno_with_message!(Errno::EINVAL, "cannot listen for a connecting stream") - } - State::Listen(listen_stream) => { - return_errno_with_message!(Errno::EINVAL, "cannot listen for a listening stream") - } - State::Connected(_) => return_errno_with_message!(Errno::EINVAL, "cannot listen"), + + let owned_state = mem::replace(&mut *state, State::Poisoned); + let State::Init(init_stream) = owned_state else { + *state = owned_state; + return_errno_with_message!(Errno::EINVAL, "cannot listen"); }; - let listener = init_stream.listen(backlog)?; - *state = State::Listen(listener); + let listen_stream = match init_stream.listen(backlog) { + Ok(listen_stream) => listen_stream, + Err((err, init_stream)) => { + *state = State::Init(init_stream); + return Err(err); + } + }; + listen_stream.init_pollee(&self.pollee); + *state = State::Listen(listen_stream); + Ok(()) } fn accept(&self) -> Result<(Arc, SocketAddr)> { - let listen_stream = match &*self.state.read() { - State::Listen(listen_stream) => listen_stream.clone(), - _ => return_errno_with_message!(Errno::EINVAL, "the socket is not listening"), - }; - - let (connected_stream, remote_endpoint) = { - let listen_stream = listen_stream.clone(); - listen_stream.accept()? - }; - - let accepted_socket = { - let state = RwLock::new(State::Connected(connected_stream)); - Arc::new(StreamSocket { - options: RwLock::new(OptionSet::new()), - state, - }) - }; - - let socket_addr = remote_endpoint.try_into()?; - Ok((accepted_socket, socket_addr)) + poll_ifaces(); + if self.is_nonblocking() { + self.try_accept() + } else { + self.wait_events(IoEvents::IN, || self.try_accept()) + } } fn shutdown(&self, cmd: SockShutdownCmd) -> Result<()> { @@ -233,11 +351,12 @@ impl Socket for StreamSocket { fn addr(&self) -> Result { let state = self.state.read(); let local_endpoint = match &*state { - State::Init(init_stream) => init_stream.local_endpoint(), + State::Init(init_stream) => init_stream.local_endpoint()?, State::Connecting(connecting_stream) => connecting_stream.local_endpoint(), State::Listen(listen_stream) => listen_stream.local_endpoint(), State::Connected(connected_stream) => connected_stream.local_endpoint(), - }?; + State::Poisoned => return_errno_with_message!(Errno::EINVAL, "socket is poisoned"), + }; local_endpoint.try_into() } @@ -252,19 +371,20 @@ impl Socket for StreamSocket { return_errno_with_message!(Errno::EINVAL, "listening socket does not have peer") } State::Connected(connected_stream) => connected_stream.remote_endpoint(), - }?; + State::Poisoned => return_errno_with_message!(Errno::EINVAL, "socket is poisoned"), + }; remote_endpoint.try_into() } fn recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, SocketAddr)> { - let connected_stream = match &*self.state.read() { - State::Connected(connected_stream) => connected_stream.clone(), - _ => return_errno_with_message!(Errno::EINVAL, "the socket is not connected"), - }; + debug_assert!(flags.is_all_supported()); - let (recv_size, remote_endpoint) = connected_stream.recvfrom(buf, flags)?; - let socket_addr = remote_endpoint.try_into()?; - Ok((recv_size, socket_addr)) + poll_ifaces(); + if self.is_nonblocking() { + self.try_recvfrom(buf, flags) + } else { + self.wait_events(IoEvents::IN, || self.try_recvfrom(buf, flags)) + } } fn sendto( @@ -273,16 +393,19 @@ impl Socket for StreamSocket { remote: Option, flags: SendRecvFlags, ) -> Result { - debug_assert!(remote.is_none()); + debug_assert!(flags.is_all_supported()); + if remote.is_some() { return_errno_with_message!(Errno::EINVAL, "tcp socked should not provide remote addr"); } - let connected_stream = match &*self.state.read() { - State::Connected(connected_stream) => connected_stream.clone(), - _ => return_errno_with_message!(Errno::EINVAL, "the socket is not connected"), + let sent_bytes = if self.is_nonblocking() { + self.try_sendto(buf, flags)? + } else { + self.wait_events(IoEvents::OUT, || self.try_sendto(buf, flags))? }; - connected_stream.sendto(buf, flags) + poll_ifaces(); + Ok(sent_bytes) } fn get_option(&self, option: &mut dyn SocketOption) -> Result<()> { @@ -324,7 +447,9 @@ impl Socket for StreamSocket { // FIXME: how to get the current MSS? let maxseg = match &*self.state.read() { - State::Init(_) | State::Listen(_) | State::Connecting(_) => DEFAULT_MAXSEG, + State::Init(_) | State::Listen(_) | State::Connecting(_) | State::Poisoned => { + DEFAULT_MAXSEG + } State::Connected(_) => options.tcp.maxseg(), }; tcp_maxseg.set(maxseg); @@ -348,7 +473,7 @@ impl Socket for StreamSocket { let recv_buf = socket_recv_buf.get().unwrap(); if *recv_buf <= MIN_RECVBUF { options.socket.set_recv_buf(MIN_RECVBUF); - } else{ + } else { options.socket.set_recv_buf(*recv_buf); } }, @@ -406,3 +531,9 @@ impl Socket for StreamSocket { Ok(()) } } + +impl Observer<()> for StreamSocket { + fn on_events(&self, events: &()) { + self.update_io_events(); + } +} diff --git a/kernel/aster-nix/src/syscall/socket.rs b/kernel/aster-nix/src/syscall/socket.rs index 34d6837a9..924a952fe 100644 --- a/kernel/aster-nix/src/syscall/socket.rs +++ b/kernel/aster-nix/src/syscall/socket.rs @@ -31,12 +31,12 @@ pub fn sys_socket(domain: i32, type_: i32, protocol: i32) -> Result Arc::new(StreamSocket::new(nonblocking)) as Arc, + ) => StreamSocket::new(nonblocking) as Arc, ( CSocketAddrFamily::AF_INET, SockType::SOCK_DGRAM, Protocol::IPPROTO_IP | Protocol::IPPROTO_UDP, - ) => Arc::new(DatagramSocket::new(nonblocking)) as Arc, + ) => DatagramSocket::new(nonblocking) as Arc, _ => return_errno_with_message!(Errno::EAFNOSUPPORT, "unsupported domain"), }; let fd = {