From 31d99c66c2894b2bdda700b0935ed92b6db8a540 Mon Sep 17 00:00:00 2001 From: Ruihan Li Date: Sun, 30 Jun 2024 21:58:38 +0800 Subject: [PATCH] Unpack `Arc` in UNIX sockets --- .../src/net/socket/unix/stream/connected.rs | 8 +- .../src/net/socket/unix/stream/endpoint.rs | 106 +++++++++--------- .../src/net/socket/unix/stream/init.rs | 7 +- .../src/net/socket/unix/stream/listener.rs | 15 +-- .../src/net/socket/unix/stream/socket.rs | 8 +- 5 files changed, 68 insertions(+), 76 deletions(-) diff --git a/kernel/aster-nix/src/net/socket/unix/stream/connected.rs b/kernel/aster-nix/src/net/socket/unix/stream/connected.rs index b04044d36..9ef1efcc0 100644 --- a/kernel/aster-nix/src/net/socket/unix/stream/connected.rs +++ b/kernel/aster-nix/src/net/socket/unix/stream/connected.rs @@ -9,19 +9,19 @@ use crate::{ }; pub(super) struct Connected { - local_endpoint: Arc, + local_endpoint: Endpoint, } impl Connected { - pub(super) fn new(local_endpoint: Arc) -> Self { + pub(super) fn new(local_endpoint: Endpoint) -> Self { Connected { local_endpoint } } - pub(super) fn addr(&self) -> Option { + pub(super) fn addr(&self) -> Option<&UnixSocketAddrBound> { self.local_endpoint.addr() } - pub(super) fn peer_addr(&self) -> Option { + pub(super) fn peer_addr(&self) -> Option<&UnixSocketAddrBound> { self.local_endpoint.peer_addr() } diff --git a/kernel/aster-nix/src/net/socket/unix/stream/endpoint.rs b/kernel/aster-nix/src/net/socket/unix/stream/endpoint.rs index c048ae475..7486a3c65 100644 --- a/kernel/aster-nix/src/net/socket/unix/stream/endpoint.rs +++ b/kernel/aster-nix/src/net/socket/unix/stream/endpoint.rs @@ -8,118 +8,114 @@ use crate::{ process::signal::Poller, }; -pub(super) struct Endpoint(Inner); - -struct Inner { - addr: RwLock>, +pub(super) struct Endpoint { + addr: Option, + peer_addr: Option, reader: Consumer, writer: Producer, - peer: Weak, } impl Endpoint { - pub(super) fn new_pair(is_nonblocking: bool) -> Result<(Arc, Arc)> { + pub(super) fn new_pair( + addr: Option, + peer_addr: Option, + is_nonblocking: bool, + ) -> Result<(Endpoint, Endpoint)> { let flags = if is_nonblocking { StatusFlags::O_NONBLOCK } else { StatusFlags::empty() }; - let (writer_a, reader_b) = + + let (writer_this, reader_peer) = Channel::with_capacity_and_flags(DAFAULT_BUF_SIZE, flags)?.split(); - let (writer_b, reader_a) = + let (writer_peer, reader_this) = Channel::with_capacity_and_flags(DAFAULT_BUF_SIZE, flags)?.split(); - let mut endpoint_b = None; - let endpoint_a = Arc::new_cyclic(|endpoint_a_ref| { - let peer = Arc::new(Endpoint::new(reader_b, writer_b, endpoint_a_ref.clone())); - let endpoint_a = Endpoint::new(reader_a, writer_a, Arc::downgrade(&peer)); - endpoint_b = Some(peer); - endpoint_a - }); - Ok((endpoint_a, endpoint_b.unwrap())) + + let this = Endpoint { + addr: addr.clone(), + peer_addr: peer_addr.clone(), + reader: reader_this, + writer: writer_this, + }; + let peer = Endpoint { + addr: peer_addr, + peer_addr: addr, + reader: reader_peer, + writer: writer_peer, + }; + + Ok((this, peer)) } - fn new(reader: Consumer, writer: Producer, peer: Weak) -> Self { - Self(Inner { - addr: RwLock::new(None), - reader, - writer, - peer, - }) + pub(super) fn addr(&self) -> Option<&UnixSocketAddrBound> { + self.addr.as_ref() } - pub(super) fn addr(&self) -> Option { - self.0.addr.read().clone() - } - - pub(super) fn set_addr(&self, addr: UnixSocketAddrBound) { - *self.0.addr.write() = Some(addr); - } - - pub(super) fn peer_addr(&self) -> Option { - self.0.peer.upgrade().and_then(|peer| peer.addr()) + pub(super) fn peer_addr(&self) -> Option<&UnixSocketAddrBound> { + self.peer_addr.as_ref() } pub(super) fn is_nonblocking(&self) -> bool { - let reader_status = self.0.reader.is_nonblocking(); - let writer_status = self.0.writer.is_nonblocking(); + let reader_status = self.reader.is_nonblocking(); + let writer_status = self.writer.is_nonblocking(); + debug_assert!(reader_status == writer_status); + reader_status } pub(super) fn set_nonblocking(&self, is_nonblocking: bool) -> Result<()> { - let mut reader_flags = self.0.reader.status_flags(); + let mut reader_flags = self.reader.status_flags(); reader_flags.set(StatusFlags::O_NONBLOCK, is_nonblocking); - self.0.reader.set_status_flags(reader_flags)?; + self.reader.set_status_flags(reader_flags)?; - let mut writer_flags = self.0.writer.status_flags(); + let mut writer_flags = self.writer.status_flags(); writer_flags.set(StatusFlags::O_NONBLOCK, is_nonblocking); - self.0.writer.set_status_flags(writer_flags)?; + self.writer.set_status_flags(writer_flags)?; Ok(()) } pub(super) fn read(&self, buf: &mut [u8]) -> Result { - self.0.reader.read(buf) + self.reader.read(buf) } pub(super) fn write(&self, buf: &[u8]) -> Result { - self.0.writer.write(buf) + self.writer.write(buf) } pub(super) fn shutdown(&self, cmd: SockShutdownCmd) -> Result<()> { - if !self.is_connected() { - return_errno_with_message!(Errno::ENOTCONN, "The socket is not connected."); - } + // FIXME: If the socket has already been shut down, should we return an error code? if cmd.shut_read() { - self.0.reader.shutdown(); + self.reader.shutdown(); } if cmd.shut_write() { - self.0.writer.shutdown(); + self.writer.shutdown(); } Ok(()) } - pub(super) fn is_connected(&self) -> bool { - self.0.peer.upgrade().is_some() - } - pub(super) fn poll(&self, mask: IoEvents, mut poller: Option<&mut Poller>) -> IoEvents { let mut events = IoEvents::empty(); - // FIXME: should reader and writer use the same mask? - let reader_events = self.0.reader.poll(mask, poller.as_deref_mut()); - let writer_events = self.0.writer.poll(mask, poller); - if reader_events.contains(IoEvents::HUP) || self.0.reader.is_shutdown() { + // FIXME: should reader and writer use the same mask? + let reader_events = self.reader.poll(mask, poller.as_deref_mut()); + let writer_events = self.writer.poll(mask, poller); + + // FIXME: Check this logic later. + if reader_events.contains(IoEvents::HUP) || self.reader.is_shutdown() { events |= IoEvents::RDHUP | IoEvents::IN; - if writer_events.contains(IoEvents::ERR) || self.0.writer.is_shutdown() { + if writer_events.contains(IoEvents::ERR) || self.writer.is_shutdown() { events |= IoEvents::HUP | IoEvents::OUT; } } events |= (reader_events & IoEvents::IN) | (writer_events & IoEvents::OUT); + events } } diff --git a/kernel/aster-nix/src/net/socket/unix/stream/init.rs b/kernel/aster-nix/src/net/socket/unix/stream/init.rs index d89b05e61..39245de11 100644 --- a/kernel/aster-nix/src/net/socket/unix/stream/init.rs +++ b/kernel/aster-nix/src/net/socket/unix/stream/init.rs @@ -57,11 +57,8 @@ impl Init { } } - let (this_end, remote_end) = Endpoint::new_pair(self.is_nonblocking())?; - remote_end.set_addr(remote_addr.clone()); - if let Some(addr) = addr { - this_end.set_addr(addr.clone()); - }; + let (this_end, remote_end) = + Endpoint::new_pair(addr, Some(remote_addr.clone()), self.is_nonblocking())?; push_incoming(remote_addr, remote_end)?; Ok(Connected::new(this_end)) diff --git a/kernel/aster-nix/src/net/socket/unix/stream/listener.rs b/kernel/aster-nix/src/net/socket/unix/stream/listener.rs index 4b69f22b0..9f75c9926 100644 --- a/kernel/aster-nix/src/net/socket/unix/stream/listener.rs +++ b/kernel/aster-nix/src/net/socket/unix/stream/listener.rs @@ -118,7 +118,7 @@ impl BacklogTable { .ok_or_else(|| Error::with_message(Errno::EINVAL, "the socket is not listened")) } - fn pop_incoming(&self, nonblocking: bool, addr: &UnixSocketAddrBound) -> Result> { + fn pop_incoming(&self, nonblocking: bool, addr: &UnixSocketAddrBound) -> Result { let mut poller = Poller::new(); loop { let backlog = self.get_backlog(addr)?; @@ -147,7 +147,7 @@ impl BacklogTable { } } - fn push_incoming(&self, addr: &UnixSocketAddrBound, endpoint: Arc) -> Result<()> { + fn push_incoming(&self, addr: &UnixSocketAddrBound, endpoint: Endpoint) -> Result<()> { let backlog = self.get_backlog(addr).map_err(|_| { Error::with_message( Errno::ECONNREFUSED, @@ -171,7 +171,7 @@ impl BacklogTable { struct Backlog { pollee: Pollee, backlog: usize, - incoming_endpoints: Mutex>>, + incoming_endpoints: Mutex>, } impl Backlog { @@ -183,7 +183,7 @@ impl Backlog { } } - fn push_incoming(&self, endpoint: Arc) -> Result<()> { + fn push_incoming(&self, endpoint: Endpoint) -> Result<()> { let mut endpoints = self.incoming_endpoints.lock(); if endpoints.len() >= self.backlog { return_errno_with_message!(Errno::ECONNREFUSED, "incoming_endpoints is full"); @@ -193,7 +193,7 @@ impl Backlog { Ok(()) } - fn pop_incoming(&self) -> Option> { + fn pop_incoming(&self) -> Option { let mut incoming_endpoints = self.incoming_endpoints.lock(); let endpoint = incoming_endpoints.pop_front(); if incoming_endpoints.is_empty() { @@ -218,9 +218,6 @@ pub(super) fn unregister_backlog(addr: &UnixSocketAddrBound) { BACKLOG_TABLE.remove_backlog(addr); } -pub(super) fn push_incoming( - remote_addr: &UnixSocketAddrBound, - remote_end: Arc, -) -> Result<()> { +pub(super) fn push_incoming(remote_addr: &UnixSocketAddrBound, remote_end: Endpoint) -> Result<()> { BACKLOG_TABLE.push_incoming(remote_addr, remote_end) } diff --git a/kernel/aster-nix/src/net/socket/unix/stream/socket.rs b/kernel/aster-nix/src/net/socket/unix/stream/socket.rs index f32b58287..ac82bf396 100644 --- a/kernel/aster-nix/src/net/socket/unix/stream/socket.rs +++ b/kernel/aster-nix/src/net/socket/unix/stream/socket.rs @@ -52,7 +52,8 @@ impl UnixStreamSocket { } pub fn new_pair(nonblocking: bool) -> Result<(Arc, Arc)> { - let (end_a, end_b) = Endpoint::new_pair(nonblocking)?; + let (end_a, end_b) = Endpoint::new_pair(None, None, nonblocking)?; + let connected_a = { let connected = Connected::new(end_a); Self::new_connected(connected) @@ -61,6 +62,7 @@ impl UnixStreamSocket { let connected = Connected::new(end_b); Self::new_connected(connected) }; + Ok((Arc::new(connected_a), Arc::new(connected_b))) } @@ -69,7 +71,7 @@ impl UnixStreamSocket { match &*status { State::Init(init) => init.addr(), State::Listen(listen) => Some(listen.addr().clone()), - State::Connected(connected) => connected.addr(), + State::Connected(connected) => connected.addr().cloned(), } } @@ -250,7 +252,7 @@ impl Socket for UnixStreamSocket { let addr = match &*self.0.read() { State::Init(init) => init.addr(), State::Listen(listen) => Some(listen.addr().clone()), - State::Connected(connected) => connected.addr(), + State::Connected(connected) => connected.addr().cloned(), }; addr.map(Into::::into)