From e83e1fc01ba38ad2a405d7d710ec7258fb664f60 Mon Sep 17 00:00:00 2001 From: Ruihan Li Date: Sun, 30 Jun 2024 23:08:25 +0800 Subject: [PATCH] Unpack states from `Arc` in UNIX sockets --- .../src/net/socket/unix/stream/connected.rs | 16 +- .../src/net/socket/unix/stream/endpoint.rs | 41 +--- .../src/net/socket/unix/stream/init.rs | 17 +- .../src/net/socket/unix/stream/listener.rs | 60 +----- .../src/net/socket/unix/stream/socket.rs | 203 +++++++++--------- kernel/aster-nix/src/syscall/socket.rs | 2 +- 6 files changed, 123 insertions(+), 216 deletions(-) diff --git a/kernel/aster-nix/src/net/socket/unix/stream/connected.rs b/kernel/aster-nix/src/net/socket/unix/stream/connected.rs index 9ef1efcc0..e8deb0855 100644 --- a/kernel/aster-nix/src/net/socket/unix/stream/connected.rs +++ b/kernel/aster-nix/src/net/socket/unix/stream/connected.rs @@ -25,26 +25,18 @@ impl Connected { self.local_endpoint.peer_addr() } - pub(super) fn write(&self, buf: &[u8]) -> Result { - self.local_endpoint.write(buf) + pub(super) fn try_write(&self, buf: &[u8]) -> Result { + self.local_endpoint.try_write(buf) } - pub(super) fn read(&self, buf: &mut [u8]) -> Result { - self.local_endpoint.read(buf) + pub(super) fn try_read(&self, buf: &mut [u8]) -> Result { + self.local_endpoint.try_read(buf) } pub(super) fn shutdown(&self, cmd: SockShutdownCmd) -> Result<()> { self.local_endpoint.shutdown(cmd) } - pub(super) fn is_nonblocking(&self) -> bool { - self.local_endpoint.is_nonblocking() - } - - pub(super) fn set_nonblocking(&self, is_nonblocking: bool) { - self.local_endpoint.set_nonblocking(is_nonblocking).unwrap(); - } - pub(super) fn poll(&self, mask: IoEvents, poller: Option<&mut Poller>) -> IoEvents { self.local_endpoint.poll(mask, poller) } diff --git a/kernel/aster-nix/src/net/socket/unix/stream/endpoint.rs b/kernel/aster-nix/src/net/socket/unix/stream/endpoint.rs index 769d0932a..c780c0610 100644 --- a/kernel/aster-nix/src/net/socket/unix/stream/endpoint.rs +++ b/kernel/aster-nix/src/net/socket/unix/stream/endpoint.rs @@ -1,15 +1,11 @@ // SPDX-License-Identifier: MPL-2.0 -use core::sync::atomic::AtomicBool; - -use atomic::Ordering; - use crate::{ events::IoEvents, fs::utils::{Channel, Consumer, Producer}, net::socket::{unix::addr::UnixSocketAddrBound, SockShutdownCmd}, prelude::*, - process::signal::{Pollable, Poller}, + process::signal::Poller, }; pub(super) struct Endpoint { @@ -17,14 +13,12 @@ pub(super) struct Endpoint { peer_addr: Option, reader: Consumer, writer: Producer, - is_nonblocking: AtomicBool, } impl Endpoint { pub(super) fn new_pair( addr: Option, peer_addr: Option, - is_nonblocking: bool, ) -> (Endpoint, Endpoint) { let (writer_this, reader_peer) = Channel::new(DAFAULT_BUF_SIZE).split(); let (writer_peer, reader_this) = Channel::new(DAFAULT_BUF_SIZE).split(); @@ -34,14 +28,12 @@ impl Endpoint { peer_addr: peer_addr.clone(), reader: reader_this, writer: writer_this, - is_nonblocking: AtomicBool::new(is_nonblocking), }; let peer = Endpoint { addr: peer_addr, peer_addr: addr, reader: reader_peer, writer: writer_peer, - is_nonblocking: AtomicBool::new(is_nonblocking), }; (this, peer) @@ -55,29 +47,12 @@ impl Endpoint { self.peer_addr.as_ref() } - pub(super) fn is_nonblocking(&self) -> bool { - self.is_nonblocking.load(Ordering::Relaxed) + pub(super) fn try_read(&self, buf: &mut [u8]) -> Result { + self.reader.try_read(buf) } - pub(super) fn set_nonblocking(&self, is_nonblocking: bool) -> Result<()> { - self.is_nonblocking.store(is_nonblocking, Ordering::Relaxed); - Ok(()) - } - - pub(super) fn read(&self, buf: &mut [u8]) -> Result { - if self.is_nonblocking() { - self.reader.try_read(buf) - } else { - self.wait_events(IoEvents::IN, || self.reader.try_read(buf)) - } - } - - pub(super) fn write(&self, buf: &[u8]) -> Result { - if self.is_nonblocking() { - self.writer.try_write(buf) - } else { - self.wait_events(IoEvents::OUT, || self.writer.try_write(buf)) - } + pub(super) fn try_write(&self, buf: &[u8]) -> Result { + self.writer.try_write(buf) } pub(super) fn shutdown(&self, cmd: SockShutdownCmd) -> Result<()> { @@ -115,10 +90,4 @@ impl Endpoint { } } -impl Pollable for Endpoint { - fn poll(&self, mask: IoEvents, poller: Option<&mut Poller>) -> IoEvents { - self.poll(mask, poller) - } -} - const DAFAULT_BUF_SIZE: usize = 4096; diff --git a/kernel/aster-nix/src/net/socket/unix/stream/init.rs b/kernel/aster-nix/src/net/socket/unix/stream/init.rs index d53331a28..2abaeacb9 100644 --- a/kernel/aster-nix/src/net/socket/unix/stream/init.rs +++ b/kernel/aster-nix/src/net/socket/unix/stream/init.rs @@ -1,7 +1,5 @@ // SPDX-License-Identifier: MPL-2.0 -use core::sync::atomic::{AtomicBool, Ordering}; - use super::{connected::Connected, endpoint::Endpoint, listener::push_incoming}; use crate::{ events::IoEvents, @@ -18,15 +16,13 @@ use crate::{ pub(super) struct Init { addr: Mutex>, pollee: Pollee, - is_nonblocking: AtomicBool, } impl Init { - pub(super) fn new(is_nonblocking: bool) -> Self { + pub(super) fn new() -> Self { Self { addr: Mutex::new(None), pollee: Pollee::new(IoEvents::empty()), - is_nonblocking: AtomicBool::new(is_nonblocking), } } @@ -57,8 +53,7 @@ impl Init { } } - let (this_end, remote_end) = - Endpoint::new_pair(addr, Some(remote_addr.clone()), self.is_nonblocking()); + let (this_end, remote_end) = Endpoint::new_pair(addr, Some(remote_addr.clone())); push_incoming(remote_addr, remote_end)?; Ok(Connected::new(this_end)) @@ -68,14 +63,6 @@ impl Init { self.addr.lock().clone() } - pub(super) fn is_nonblocking(&self) -> bool { - self.is_nonblocking.load(Ordering::Relaxed) - } - - pub(super) fn set_nonblocking(&self, is_nonblocking: bool) { - self.is_nonblocking.store(is_nonblocking, Ordering::Relaxed); - } - pub(super) fn poll(&self, mask: IoEvents, poller: Option<&mut Poller>) -> IoEvents { self.pollee.poll(mask, poller) } diff --git a/kernel/aster-nix/src/net/socket/unix/stream/listener.rs b/kernel/aster-nix/src/net/socket/unix/stream/listener.rs index 3d540312b..90b8e3515 100644 --- a/kernel/aster-nix/src/net/socket/unix/stream/listener.rs +++ b/kernel/aster-nix/src/net/socket/unix/stream/listener.rs @@ -1,7 +1,5 @@ // SPDX-License-Identifier: MPL-2.0 -use core::sync::atomic::{AtomicBool, Ordering}; - use keyable_arc::KeyableWeak; use super::{connected::Connected, endpoint::Endpoint, UnixStreamSocket}; @@ -18,40 +16,23 @@ use crate::{ pub(super) struct Listener { addr: UnixSocketAddrBound, - is_nonblocking: AtomicBool, } impl Listener { - pub(super) fn new( - addr: UnixSocketAddrBound, - backlog: usize, - nonblocking: bool, - ) -> Result { + pub(super) fn new(addr: UnixSocketAddrBound, backlog: usize) -> Result { BACKLOG_TABLE.add_backlog(&addr, backlog)?; - Ok(Self { - addr, - is_nonblocking: AtomicBool::new(nonblocking), - }) + Ok(Self { addr }) } pub(super) fn addr(&self) -> &UnixSocketAddrBound { &self.addr } - pub(super) fn is_nonblocking(&self) -> bool { - self.is_nonblocking.load(Ordering::Relaxed) - } - - pub(super) fn set_nonblocking(&self, is_nonblocking: bool) { - self.is_nonblocking.store(is_nonblocking, Ordering::Relaxed); - } - - pub(super) fn accept(&self) -> Result<(Arc, SocketAddr)> { + pub(super) fn try_accept(&self) -> Result<(Arc, SocketAddr)> { let addr = self.addr().clone(); - let is_nonblocking = self.is_nonblocking(); let connected = { - let local_endpoint = BACKLOG_TABLE.pop_incoming(is_nonblocking, &addr)?; + let local_endpoint = BACKLOG_TABLE.pop_incoming(&addr)?; Connected::new(local_endpoint) }; @@ -60,7 +41,7 @@ impl Listener { Some(addr) => SocketAddr::from(addr.clone()), }; - let socket = Arc::new(UnixStreamSocket::new_connected(connected)); + let socket = UnixStreamSocket::new_connected(connected, false); Ok((socket, peer_addr)) } @@ -118,32 +99,13 @@ impl BacklogTable { .ok_or_else(|| Error::with_message(Errno::EINVAL, "the socket is not listened")) } - fn pop_incoming(&self, nonblocking: bool, addr: &UnixSocketAddrBound) -> Result { - let mut poller = Poller::new(); - loop { - let backlog = self.get_backlog(addr)?; + fn pop_incoming(&self, addr: &UnixSocketAddrBound) -> Result { + let backlog = self.get_backlog(addr)?; - if let Some(endpoint) = backlog.pop_incoming() { - return Ok(endpoint); - } - - if nonblocking { - return_errno_with_message!(Errno::EAGAIN, "no connection comes"); - } - - let events = { - let mask = IoEvents::IN; - backlog.poll(mask, Some(&mut poller)) - }; - - if events.contains(IoEvents::ERR) | events.contains(IoEvents::HUP) { - return_errno_with_message!(Errno::ECONNABORTED, "connection is aborted"); - } - - // FIXME: deal with accept timeout - if events.is_empty() { - poller.wait()?; - } + if let Some(endpoint) = backlog.pop_incoming() { + Ok(endpoint) + } else { + return_errno_with_message!(Errno::EAGAIN, "no pending connection is available") } } diff --git a/kernel/aster-nix/src/net/socket/unix/stream/socket.rs b/kernel/aster-nix/src/net/socket/unix/stream/socket.rs index 41e41529c..9bdea7595 100644 --- a/kernel/aster-nix/src/net/socket/unix/stream/socket.rs +++ b/kernel/aster-nix/src/net/socket/unix/stream/socket.rs @@ -1,5 +1,9 @@ // SPDX-License-Identifier: MPL-2.0 +use core::sync::atomic::AtomicBool; + +use atomic::Ordering; + use super::{ connected::Connected, endpoint::Endpoint, @@ -27,87 +31,104 @@ use crate::{ util::IoVec, }; -pub struct UnixStreamSocket(RwLock); +pub struct UnixStreamSocket { + state: RwLock, + is_nonblocking: AtomicBool, +} impl UnixStreamSocket { - pub(super) fn new_init(init: Init) -> Self { - Self(RwLock::new(State::Init(Arc::new(init)))) + pub(super) fn new_init(init: Init, is_nonblocking: bool) -> Arc { + Arc::new(Self { + state: RwLock::new(State::Init(init)), + is_nonblocking: AtomicBool::new(is_nonblocking), + }) } - pub(super) fn new_connected(connected: Connected) -> Self { - Self(RwLock::new(State::Connected(Arc::new(connected)))) + pub(super) fn new_connected(connected: Connected, is_nonblocking: bool) -> Arc { + Arc::new(Self { + state: RwLock::new(State::Connected(connected)), + is_nonblocking: AtomicBool::new(is_nonblocking), + }) } } enum State { - Init(Arc), - Listen(Arc), - Connected(Arc), + Init(Init), + Listen(Listener), + Connected(Connected), } impl UnixStreamSocket { - pub fn new(nonblocking: bool) -> Self { - let init = Init::new(nonblocking); - Self::new_init(init) + pub fn new(is_nonblocking: bool) -> Arc { + Self::new_init(Init::new(), is_nonblocking) } - pub fn new_pair(nonblocking: bool) -> (Arc, Arc) { - let (end_a, end_b) = Endpoint::new_pair(None, None, nonblocking); - - let connected_a = { - let connected = Connected::new(end_a); - Self::new_connected(connected) - }; - let connected_b = { - let connected = Connected::new(end_b); - Self::new_connected(connected) - }; - - (Arc::new(connected_a), Arc::new(connected_b)) + pub fn new_pair(is_nonblocking: bool) -> (Arc, Arc) { + let (end_a, end_b) = Endpoint::new_pair(None, None); + ( + Self::new_connected(Connected::new(end_a), is_nonblocking), + Self::new_connected(Connected::new(end_b), is_nonblocking), + ) } fn bound_addr(&self) -> Option { - let status = self.0.read(); - match &*status { + let state = self.state.read(); + match &*state { State::Init(init) => init.addr(), State::Listen(listen) => Some(listen.addr().clone()), State::Connected(connected) => connected.addr().cloned(), } } - fn mask_flags(status_flags: &StatusFlags) -> StatusFlags { - const SUPPORTED_FLAGS: StatusFlags = StatusFlags::O_NONBLOCK; - const UNSUPPORTED_FLAGS: StatusFlags = SUPPORTED_FLAGS.complement(); - - if status_flags.intersects(UNSUPPORTED_FLAGS) { - warn!("ignore unsupported flags"); + fn send(&self, buf: &[u8], flags: SendRecvFlags) -> Result { + if self.is_nonblocking() { + self.try_send(buf, flags) + } else { + self.wait_events(IoEvents::OUT, || self.try_send(buf, flags)) } - - status_flags.intersection(SUPPORTED_FLAGS) } - fn send(&self, buf: &[u8], _flags: SendRecvFlags) -> Result { - let connected = match &*self.0.read() { - State::Connected(connected) => connected.clone(), + fn try_send(&self, buf: &[u8], _flags: SendRecvFlags) -> Result { + match &*self.state.read() { + State::Connected(connected) => connected.try_write(buf), _ => return_errno_with_message!(Errno::ENOTCONN, "the socket is not connected"), - }; - - connected.write(buf) + } } - fn recv(&self, buf: &mut [u8], _flags: SendRecvFlags) -> Result { - let connected = match &*self.0.read() { - State::Connected(connected) => connected.clone(), - _ => return_errno_with_message!(Errno::ENOTCONN, "the socket is not connected"), - }; + fn recv(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result { + if self.is_nonblocking() { + self.try_recv(buf, flags) + } else { + self.wait_events(IoEvents::IN, || self.try_recv(buf, flags)) + } + } - connected.read(buf) + fn try_recv(&self, buf: &mut [u8], _flags: SendRecvFlags) -> Result { + match &*self.state.read() { + State::Connected(connected) => connected.try_read(buf), + _ => return_errno_with_message!(Errno::ENOTCONN, "the socket is not connected"), + } + } + + fn try_accept(&self) -> Result<(Arc, SocketAddr)> { + match &*self.state.read() { + State::Listen(listen) => listen.try_accept() as _, + _ => return_errno_with_message!(Errno::EINVAL, "the socket is not listening"), + } + } + + fn is_nonblocking(&self) -> bool { + self.is_nonblocking.load(Ordering::Relaxed) + } + + fn set_nonblocking(&self, nonblocking: bool) { + self.is_nonblocking.store(nonblocking, Ordering::Relaxed); } } impl Pollable for UnixStreamSocket { fn poll(&self, mask: IoEvents, poller: Option<&mut Poller>) -> IoEvents { - let inner = self.0.read(); + let inner = self.state.read(); match &*inner { State::Init(init) => init.poll(mask, poller), State::Listen(listen) => listen.poll(mask, poller), @@ -134,15 +155,7 @@ impl FileLike for UnixStreamSocket { } fn status_flags(&self) -> StatusFlags { - let inner = self.0.read(); - let is_nonblocking = match &*inner { - State::Init(init) => init.is_nonblocking(), - State::Listen(listen) => listen.is_nonblocking(), - State::Connected(connected) => connected.is_nonblocking(), - }; - - // TODO: when we fully support O_ASYNC, return the flag - if is_nonblocking { + if self.is_nonblocking() { StatusFlags::O_NONBLOCK } else { StatusFlags::empty() @@ -150,17 +163,7 @@ impl FileLike for UnixStreamSocket { } fn set_status_flags(&self, new_flags: StatusFlags) -> Result<()> { - let is_nonblocking = { - let supported_flags = Self::mask_flags(&new_flags); - supported_flags.contains(StatusFlags::O_NONBLOCK) - }; - - let mut inner = self.0.write(); - match &mut *inner { - State::Init(init) => init.set_nonblocking(is_nonblocking), - State::Listen(listen) => listen.set_nonblocking(is_nonblocking), - State::Connected(connected) => connected.set_nonblocking(is_nonblocking), - } + self.set_nonblocking(new_flags.contains(StatusFlags::O_NONBLOCK)); Ok(()) } } @@ -169,16 +172,14 @@ impl Socket for UnixStreamSocket { fn bind(&self, socket_addr: SocketAddr) -> Result<()> { let addr = UnixSocketAddr::try_from(socket_addr)?; - let init = match &*self.0.read() { - State::Init(init) => init.clone(), + match &*self.state.read() { + State::Init(init) => init.bind(&addr), _ => return_errno_with_message!( Errno::EINVAL, "cannot bind a listening or connected socket" ), // FIXME: Maybe binding a connected socket should also be allowed? - }; - - init.bind(&addr) + } } fn connect(&self, socket_addr: SocketAddr) -> Result<()> { @@ -195,23 +196,27 @@ impl Socket for UnixStreamSocket { } }; - let init = match &*self.0.read() { - State::Init(init) => init.clone(), + let connected = match &*self.state.read() { + State::Init(init) => init.connect(&remote_addr)?, State::Listen(_) => return_errno_with_message!(Errno::EINVAL, "the socket is listened"), State::Connected(_) => { return_errno_with_message!(Errno::EISCONN, "the socket is connected") } }; - let connected = init.connect(&remote_addr)?; - - *self.0.write() = State::Connected(Arc::new(connected)); + *self.state.write() = State::Connected(connected); Ok(()) } fn listen(&self, backlog: usize) -> Result<()> { - let init = match &*self.0.read() { - State::Init(init) => init.clone(), + let addr = match &*self.state.read() { + State::Init(init) => init + .addr() + .ok_or(Error::with_message( + Errno::EINVAL, + "the socket is not bound", + ))? + .clone(), State::Listen(_) => { return_errno_with_message!(Errno::EINVAL, "the socket is already listening") } @@ -220,36 +225,28 @@ impl Socket for UnixStreamSocket { } }; - let addr = init.addr().ok_or(Error::with_message( - Errno::EINVAL, - "the socket is not bound", - ))?; - - let listener = Listener::new(addr.clone(), backlog, init.is_nonblocking())?; - *self.0.write() = State::Listen(Arc::new(listener)); + let listener = Listener::new(addr, backlog)?; + *self.state.write() = State::Listen(listener); Ok(()) } fn accept(&self) -> Result<(Arc, SocketAddr)> { - let listen = match &*self.0.read() { - State::Listen(listen) => listen.clone(), - _ => return_errno_with_message!(Errno::EINVAL, "the socket is not listening"), - }; - - listen.accept() + if self.is_nonblocking() { + self.try_accept() + } else { + self.wait_events(IoEvents::IN, || self.try_accept()) + } } fn shutdown(&self, cmd: SockShutdownCmd) -> Result<()> { - let connected = match &*self.0.read() { - State::Connected(connected) => connected.clone(), + match &*self.state.read() { + State::Connected(connected) => connected.shutdown(cmd), _ => return_errno_with_message!(Errno::ENOTCONN, "the socked is not connected"), - }; - - connected.shutdown(cmd) + } } fn addr(&self) -> Result { - let addr = match &*self.0.read() { + let addr = match &*self.state.read() { State::Init(init) => init.addr(), State::Listen(listen) => Some(listen.addr().clone()), State::Connected(connected) => connected.addr().cloned(), @@ -263,14 +260,14 @@ impl Socket for UnixStreamSocket { } fn peer_addr(&self) -> Result { - let connected = match &*self.0.read() { - State::Connected(connected) => connected.clone(), + let peer_addr = match &*self.state.read() { + State::Connected(connected) => connected.peer_addr().cloned(), _ => return_errno_with_message!(Errno::ENOTCONN, "the socket is not connected"), }; - match connected.peer_addr() { + match peer_addr { None => Ok(SocketAddr::Unix(UnixSocketAddr::Path(String::new()))), - Some(peer_addr) => Ok(SocketAddr::from(peer_addr.clone())), + Some(peer_addr) => Ok(SocketAddr::from(peer_addr)), } } @@ -323,7 +320,7 @@ impl Drop for UnixStreamSocket { return; }; - if let State::Listen(_) = &*self.0.read() { + if let State::Listen(_) = &*self.state.read() { unregister_backlog(&bound_addr); } } diff --git a/kernel/aster-nix/src/syscall/socket.rs b/kernel/aster-nix/src/syscall/socket.rs index f3d8fc98d..83262650b 100644 --- a/kernel/aster-nix/src/syscall/socket.rs +++ b/kernel/aster-nix/src/syscall/socket.rs @@ -24,7 +24,7 @@ pub fn sys_socket(domain: i32, type_: i32, protocol: i32) -> Result { - Arc::new(UnixStreamSocket::new(nonblocking)) as Arc + UnixStreamSocket::new(nonblocking) as Arc } ( CSocketAddrFamily::AF_INET,