From e3307a6945e129ac2f86db7a3b702f3324a76a16 Mon Sep 17 00:00:00 2001 From: Ruihan Li Date: Thu, 15 Aug 2024 12:22:16 +0800 Subject: [PATCH] Use `Connected` to replace `Endpoint` --- .../src/net/socket/unix/stream/connected.rs | 99 +++++++++++--- .../src/net/socket/unix/stream/endpoint.rs | 125 ------------------ .../src/net/socket/unix/stream/init.rs | 3 +- .../src/net/socket/unix/stream/listener.rs | 43 +++--- .../src/net/socket/unix/stream/mod.rs | 1 - .../src/net/socket/unix/stream/socket.rs | 8 +- 6 files changed, 106 insertions(+), 173 deletions(-) delete mode 100644 kernel/aster-nix/src/net/socket/unix/stream/endpoint.rs diff --git a/kernel/aster-nix/src/net/socket/unix/stream/connected.rs b/kernel/aster-nix/src/net/socket/unix/stream/connected.rs index f496d0b1f..ab740d6af 100644 --- a/kernel/aster-nix/src/net/socket/unix/stream/connected.rs +++ b/kernel/aster-nix/src/net/socket/unix/stream/connected.rs @@ -1,44 +1,92 @@ // SPDX-License-Identifier: MPL-2.0 -use super::endpoint::Endpoint; use crate::{ events::{IoEvents, Observer}, + fs::utils::{Channel, Consumer, Producer}, net::socket::{unix::addr::UnixSocketAddrBound, SockShutdownCmd}, prelude::*, process::signal::Poller, }; pub(super) struct Connected { - local_endpoint: Endpoint, + addr: Option, + peer_addr: Option, + reader: Consumer, + writer: Producer, } impl Connected { - pub(super) fn new(local_endpoint: Endpoint) -> Self { - Connected { local_endpoint } + pub(super) fn new_pair( + addr: Option, + peer_addr: Option, + ) -> (Connected, Connected) { + let (writer_this, reader_peer) = Channel::new(DAFAULT_BUF_SIZE).split(); + let (writer_peer, reader_this) = Channel::new(DAFAULT_BUF_SIZE).split(); + + let this = Connected { + addr: addr.clone(), + peer_addr: peer_addr.clone(), + reader: reader_this, + writer: writer_this, + }; + let peer = Connected { + addr: peer_addr, + peer_addr: addr, + reader: reader_peer, + writer: writer_peer, + }; + + (this, peer) } pub(super) fn addr(&self) -> Option<&UnixSocketAddrBound> { - self.local_endpoint.addr() + self.addr.as_ref() } pub(super) fn peer_addr(&self) -> Option<&UnixSocketAddrBound> { - self.local_endpoint.peer_addr() - } - - pub(super) fn try_write(&self, buf: &[u8]) -> Result { - self.local_endpoint.try_write(buf) + self.peer_addr.as_ref() } pub(super) fn try_read(&self, buf: &mut [u8]) -> Result { - self.local_endpoint.try_read(buf) + self.reader.try_read(buf) + } + + pub(super) fn try_write(&self, buf: &[u8]) -> Result { + self.writer.try_write(buf) } pub(super) fn shutdown(&self, cmd: SockShutdownCmd) -> Result<()> { - self.local_endpoint.shutdown(cmd) + // FIXME: If the socket has already been shut down, should we return an error code? + + if cmd.shut_read() { + self.reader.shutdown(); + } + + if cmd.shut_write() { + self.writer.shutdown(); + } + + Ok(()) } - pub(super) fn poll(&self, mask: IoEvents, poller: Option<&mut Poller>) -> IoEvents { - self.local_endpoint.poll(mask, poller) + pub(super) fn poll(&self, mask: IoEvents, mut poller: Option<&mut Poller>) -> IoEvents { + let mut events = IoEvents::empty(); + + // FIXME: should reader and writer use the same mask? + let reader_events = self.reader.poll(mask, poller.as_deref_mut()); + let writer_events = self.writer.poll(mask, poller); + + // FIXME: Check this logic later. + if reader_events.contains(IoEvents::HUP) || self.reader.is_shutdown() { + events |= IoEvents::RDHUP | IoEvents::IN; + if writer_events.contains(IoEvents::ERR) || self.writer.is_shutdown() { + events |= IoEvents::HUP | IoEvents::OUT; + } + } + + events |= (reader_events & IoEvents::IN) | (writer_events & IoEvents::OUT); + + events } pub(super) fn register_observer( @@ -46,13 +94,32 @@ impl Connected { observer: Weak>, mask: IoEvents, ) -> Result<()> { - self.local_endpoint.register_observer(observer, mask) + if mask.contains(IoEvents::IN) { + self.reader.register_observer(observer.clone(), mask)? + } + + if mask.contains(IoEvents::OUT) { + self.writer.register_observer(observer, mask)? + } + + Ok(()) } pub(super) fn unregister_observer( &self, observer: &Weak>, ) -> Option>> { - self.local_endpoint.unregister_observer(observer) + let observer0 = self.reader.unregister_observer(observer); + let observer1 = self.writer.unregister_observer(observer); + + if observer0.is_some() { + observer0 + } else if observer1.is_some() { + observer1 + } else { + None + } } } + +const DAFAULT_BUF_SIZE: usize = 65536; diff --git a/kernel/aster-nix/src/net/socket/unix/stream/endpoint.rs b/kernel/aster-nix/src/net/socket/unix/stream/endpoint.rs deleted file mode 100644 index e140708cb..000000000 --- a/kernel/aster-nix/src/net/socket/unix/stream/endpoint.rs +++ /dev/null @@ -1,125 +0,0 @@ -// SPDX-License-Identifier: MPL-2.0 - -use crate::{ - events::{IoEvents, Observer}, - fs::utils::{Channel, Consumer, Producer}, - net::socket::{unix::addr::UnixSocketAddrBound, SockShutdownCmd}, - prelude::*, - process::signal::Poller, -}; - -pub(super) struct Endpoint { - addr: Option, - peer_addr: Option, - reader: Consumer, - writer: Producer, -} - -impl Endpoint { - pub(super) fn new_pair( - addr: Option, - peer_addr: Option, - ) -> (Endpoint, Endpoint) { - let (writer_this, reader_peer) = Channel::new(DAFAULT_BUF_SIZE).split(); - let (writer_peer, reader_this) = Channel::new(DAFAULT_BUF_SIZE).split(); - - let this = Endpoint { - addr: addr.clone(), - peer_addr: peer_addr.clone(), - reader: reader_this, - writer: writer_this, - }; - let peer = Endpoint { - addr: peer_addr, - peer_addr: addr, - reader: reader_peer, - writer: writer_peer, - }; - - (this, peer) - } - - pub(super) fn addr(&self) -> Option<&UnixSocketAddrBound> { - self.addr.as_ref() - } - - pub(super) fn peer_addr(&self) -> Option<&UnixSocketAddrBound> { - self.peer_addr.as_ref() - } - - pub(super) fn try_read(&self, buf: &mut [u8]) -> Result { - self.reader.try_read(buf) - } - - pub(super) fn try_write(&self, buf: &[u8]) -> Result { - self.writer.try_write(buf) - } - - pub(super) fn shutdown(&self, cmd: SockShutdownCmd) -> Result<()> { - // FIXME: If the socket has already been shut down, should we return an error code? - - if cmd.shut_read() { - self.reader.shutdown(); - } - - if cmd.shut_write() { - self.writer.shutdown(); - } - - Ok(()) - } - - pub(super) fn poll(&self, mask: IoEvents, mut poller: Option<&mut Poller>) -> IoEvents { - let mut events = IoEvents::empty(); - - // FIXME: should reader and writer use the same mask? - let reader_events = self.reader.poll(mask, poller.as_deref_mut()); - let writer_events = self.writer.poll(mask, poller); - - // FIXME: Check this logic later. - if reader_events.contains(IoEvents::HUP) || self.reader.is_shutdown() { - events |= IoEvents::RDHUP | IoEvents::IN; - if writer_events.contains(IoEvents::ERR) || self.writer.is_shutdown() { - events |= IoEvents::HUP | IoEvents::OUT; - } - } - - events |= (reader_events & IoEvents::IN) | (writer_events & IoEvents::OUT); - - events - } - - pub(super) fn register_observer( - &self, - observer: Weak>, - mask: IoEvents, - ) -> Result<()> { - if mask.contains(IoEvents::IN) { - self.reader.register_observer(observer.clone(), mask)? - } - - if mask.contains(IoEvents::OUT) { - self.writer.register_observer(observer, mask)? - } - - Ok(()) - } - - pub(super) fn unregister_observer( - &self, - observer: &Weak>, - ) -> Option>> { - let observer0 = self.reader.unregister_observer(observer); - let observer1 = self.writer.unregister_observer(observer); - - if observer0.is_some() { - observer0 - } else if observer1.is_some() { - observer1 - } else { - None - } - } -} - -const DAFAULT_BUF_SIZE: usize = 65536; diff --git a/kernel/aster-nix/src/net/socket/unix/stream/init.rs b/kernel/aster-nix/src/net/socket/unix/stream/init.rs index 75c8b68a1..d03dc67a1 100644 --- a/kernel/aster-nix/src/net/socket/unix/stream/init.rs +++ b/kernel/aster-nix/src/net/socket/unix/stream/init.rs @@ -45,8 +45,7 @@ impl Init { } pub(super) fn connect(&self, remote_addr: &UnixSocketAddrBound) -> Result { - let endpoint = push_incoming(remote_addr, self.addr.clone())?; - Ok(Connected::new(endpoint)) + push_incoming(remote_addr, self.addr.clone()) } pub(super) fn addr(&self) -> Option<&UnixSocketAddrBound> { diff --git a/kernel/aster-nix/src/net/socket/unix/stream/listener.rs b/kernel/aster-nix/src/net/socket/unix/stream/listener.rs index a76522be8..86f288177 100644 --- a/kernel/aster-nix/src/net/socket/unix/stream/listener.rs +++ b/kernel/aster-nix/src/net/socket/unix/stream/listener.rs @@ -4,7 +4,7 @@ use core::sync::atomic::{AtomicUsize, Ordering}; use keyable_arc::KeyableWeak; -use super::{connected::Connected, endpoint::Endpoint, UnixStreamSocket}; +use super::{connected::Connected, UnixStreamSocket}; use crate::{ events::{IoEvents, Observer}, fs::{file_handle::FileLike, path::Dentry, utils::Inode}, @@ -28,15 +28,10 @@ impl Listener { } pub(super) fn try_accept(&self) -> Result<(Arc, SocketAddr)> { - let connected = { - let local_endpoint = self.backlog.pop_incoming()?; - Connected::new(local_endpoint) - }; - + let connected = self.backlog.pop_incoming()?; let peer_addr = connected.peer_addr().cloned().into(); let socket = UnixStreamSocket::new_connected(connected, false); - Ok((socket, peer_addr)) } @@ -119,7 +114,7 @@ impl BacklogTable { &self, server_addr: &UnixSocketAddrBound, client_addr: Option, - ) -> Result { + ) -> Result { let backlog = self.get_backlog(server_addr).ok_or_else(|| { Error::with_message( Errno::ECONNREFUSED, @@ -144,7 +139,7 @@ struct Backlog { addr: UnixSocketAddrBound, pollee: Pollee, backlog: AtomicUsize, - incoming_endpoints: Mutex>, + incoming_conns: Mutex>, } impl Backlog { @@ -153,7 +148,7 @@ impl Backlog { addr, pollee: Pollee::new(IoEvents::empty()), backlog: AtomicUsize::new(backlog), - incoming_endpoints: Mutex::new(VecDeque::with_capacity(backlog)), + incoming_conns: Mutex::new(VecDeque::with_capacity(backlog)), } } @@ -161,33 +156,31 @@ impl Backlog { &self.addr } - fn push_incoming(&self, client_addr: Option) -> Result { - let mut endpoints = self.incoming_endpoints.lock(); + fn push_incoming(&self, client_addr: Option) -> Result { + let mut incoming_conns = self.incoming_conns.lock(); - if endpoints.len() >= self.backlog.load(Ordering::Relaxed) { + if incoming_conns.len() >= self.backlog.load(Ordering::Relaxed) { return_errno_with_message!( Errno::ECONNREFUSED, "the pending connection queue on the listening socket is full" ); } - let (server_endpoint, client_endpoint) = - Endpoint::new_pair(Some(self.addr.clone()), client_addr); - endpoints.push_back(server_endpoint); + let (server_conn, client_conn) = Connected::new_pair(Some(self.addr.clone()), client_addr); + incoming_conns.push_back(server_conn); self.pollee.add_events(IoEvents::IN); - Ok(client_endpoint) + Ok(client_conn) } - fn pop_incoming(&self) -> Result { - let mut incoming_endpoints = self.incoming_endpoints.lock(); - let endpoint = incoming_endpoints.pop_front(); - if incoming_endpoints.is_empty() { + fn pop_incoming(&self) -> Result { + let mut incoming_conns = self.incoming_conns.lock(); + let conn = incoming_conns.pop_front(); + if incoming_conns.is_empty() { self.pollee.del_events(IoEvents::IN); } - endpoint - .ok_or_else(|| Error::with_message(Errno::EAGAIN, "no pending connection is available")) + conn.ok_or_else(|| Error::with_message(Errno::EAGAIN, "no pending connection is available")) } fn set_backlog(&self, backlog: usize) { @@ -196,7 +189,7 @@ impl Backlog { fn poll(&self, mask: IoEvents, poller: Option<&mut Poller>) -> IoEvents { // Lock to avoid any events may change pollee state when we poll - let _lock = self.incoming_endpoints.lock(); + let _lock = self.incoming_conns.lock(); self.pollee.poll(mask, poller) } @@ -229,6 +222,6 @@ fn unregister_backlog(addr: &UnixSocketAddrBound) { pub(super) fn push_incoming( server_addr: &UnixSocketAddrBound, client_addr: Option, -) -> Result { +) -> Result { BACKLOG_TABLE.push_incoming(server_addr, client_addr) } diff --git a/kernel/aster-nix/src/net/socket/unix/stream/mod.rs b/kernel/aster-nix/src/net/socket/unix/stream/mod.rs index ca5d91ff0..0a09bbab6 100644 --- a/kernel/aster-nix/src/net/socket/unix/stream/mod.rs +++ b/kernel/aster-nix/src/net/socket/unix/stream/mod.rs @@ -1,7 +1,6 @@ // SPDX-License-Identifier: MPL-2.0 mod connected; -mod endpoint; mod init; mod listener; mod socket; diff --git a/kernel/aster-nix/src/net/socket/unix/stream/socket.rs b/kernel/aster-nix/src/net/socket/unix/stream/socket.rs index 3ae2f7254..53919919d 100644 --- a/kernel/aster-nix/src/net/socket/unix/stream/socket.rs +++ b/kernel/aster-nix/src/net/socket/unix/stream/socket.rs @@ -4,7 +4,7 @@ use core::sync::atomic::AtomicBool; use atomic::Ordering; -use super::{connected::Connected, endpoint::Endpoint, init::Init, listener::Listener}; +use super::{connected::Connected, init::Init, listener::Listener}; use crate::{ events::{IoEvents, Observer}, fs::{ @@ -59,10 +59,10 @@ impl UnixStreamSocket { } pub fn new_pair(is_nonblocking: bool) -> (Arc, Arc) { - let (end_a, end_b) = Endpoint::new_pair(None, None); + let (conn_a, conn_b) = Connected::new_pair(None, None); ( - Self::new_connected(Connected::new(end_a), is_nonblocking), - Self::new_connected(Connected::new(end_b), is_nonblocking), + Self::new_connected(conn_a, is_nonblocking), + Self::new_connected(conn_b, is_nonblocking), ) }