diff --git a/kernel/src/net/socket/unix/stream/connected.rs b/kernel/src/net/socket/unix/stream/connected.rs index 75fcdf95e..30a7e5319 100644 --- a/kernel/src/net/socket/unix/stream/connected.rs +++ b/kernel/src/net/socket/unix/stream/connected.rs @@ -61,7 +61,7 @@ impl Connected { self.writer.try_write(&mut reader) } - pub(super) fn shutdown(&self, cmd: SockShutdownCmd) -> Result<()> { + pub(super) fn shutdown(&self, cmd: SockShutdownCmd) { if cmd.shut_read() { self.reader.shutdown(); } @@ -69,8 +69,6 @@ impl Connected { if cmd.shut_write() { self.writer.shutdown(); } - - Ok(()) } pub(super) fn poll(&self, mask: IoEvents, mut poller: Option<&mut Poller>) -> IoEvents { @@ -78,23 +76,7 @@ impl Connected { let reader_events = self.reader.poll(mask, poller.as_deref_mut()); let writer_events = self.writer.poll(mask, poller); - let mut events = IoEvents::empty(); - - if reader_events.contains(IoEvents::HUP) { - // The socket is shut down in one direction: the remote socket has shut down for - // writing or the local socket has shut down for reading. - events |= IoEvents::RDHUP | IoEvents::IN; - - if writer_events.contains(IoEvents::ERR) { - // The socket is shut down in both directions. Neither reading nor writing is - // possible. - events |= IoEvents::HUP; - } - } - - events |= (reader_events & IoEvents::IN) | (writer_events & IoEvents::OUT); - - events & (mask | IoEvents::ALWAYS_POLL) + combine_io_events(mask, reader_events, writer_events) } pub(super) fn register_observer( @@ -117,4 +99,28 @@ impl Connected { } } +pub(super) fn combine_io_events( + mask: IoEvents, + reader_events: IoEvents, + writer_events: IoEvents, +) -> IoEvents { + let mut events = IoEvents::empty(); + + if reader_events.contains(IoEvents::HUP) { + // The socket is shut down in one direction: the remote socket has shut down for + // writing or the local socket has shut down for reading. + events |= IoEvents::RDHUP | IoEvents::IN; + + if writer_events.contains(IoEvents::ERR) { + // The socket is shut down in both directions. Neither reading nor writing is + // possible. + events |= IoEvents::HUP; + } + } + + events |= (reader_events & IoEvents::IN) | (writer_events & IoEvents::OUT); + + events & (mask | IoEvents::ALWAYS_POLL) +} + const DEFAULT_BUF_SIZE: usize = 65536; diff --git a/kernel/src/net/socket/unix/stream/init.rs b/kernel/src/net/socket/unix/stream/init.rs index 8daab079f..3493c3f85 100644 --- a/kernel/src/net/socket/unix/stream/init.rs +++ b/kernel/src/net/socket/unix/stream/init.rs @@ -1,9 +1,17 @@ // SPDX-License-Identifier: MPL-2.0 -use super::{connected::Connected, listener::Listener}; +use core::sync::atomic::{AtomicBool, Ordering}; + +use super::{ + connected::{combine_io_events, Connected}, + listener::Listener, +}; use crate::{ events::{IoEvents, Observer}, - net::socket::unix::addr::{UnixSocketAddr, UnixSocketAddrBound}, + net::socket::{ + unix::addr::{UnixSocketAddr, UnixSocketAddrBound}, + SockShutdownCmd, + }, prelude::*, process::signal::{Pollee, Poller}, }; @@ -12,6 +20,8 @@ pub(super) struct Init { addr: Option, reader_pollee: Pollee, writer_pollee: Pollee, + is_read_shutdown: AtomicBool, + is_write_shutdown: AtomicBool, } impl Init { @@ -19,7 +29,9 @@ impl Init { Self { addr: None, reader_pollee: Pollee::new(IoEvents::empty()), - writer_pollee: Pollee::new(IoEvents::empty()), + writer_pollee: Pollee::new(IoEvents::OUT), + is_read_shutdown: AtomicBool::new(false), + is_write_shutdown: AtomicBool::new(false), } } @@ -39,14 +51,26 @@ impl Init { addr, reader_pollee, writer_pollee, + is_read_shutdown, + is_write_shutdown, } = self; - Connected::new_pair( + let (this_conn, peer_conn) = Connected::new_pair( addr, Some(peer_addr), Some(reader_pollee), Some(writer_pollee), - ) + ); + + if is_read_shutdown.into_inner() { + this_conn.shutdown(SockShutdownCmd::SHUT_RD); + } + + if is_write_shutdown.into_inner() { + this_conn.shutdown(SockShutdownCmd::SHUT_WR) + } + + (this_conn, peer_conn) } pub(super) fn listen(self, backlog: usize) -> core::result::Result { @@ -57,8 +81,31 @@ impl Init { )); }; - // There is no `writer_pollee` in `Listener`. - Ok(Listener::new(addr, self.reader_pollee, backlog)) + Ok(Listener::new( + addr, + self.reader_pollee, + self.writer_pollee, + backlog, + self.is_read_shutdown.into_inner(), + )) + } + + pub(super) fn shutdown(&self, cmd: SockShutdownCmd) { + match cmd { + SockShutdownCmd::SHUT_WR | SockShutdownCmd::SHUT_RDWR => { + self.is_write_shutdown.store(true, Ordering::Relaxed); + self.writer_pollee.add_events(IoEvents::ERR); + } + SockShutdownCmd::SHUT_RD => (), + } + + match cmd { + SockShutdownCmd::SHUT_RD | SockShutdownCmd::SHUT_RDWR => { + self.is_read_shutdown.store(true, Ordering::Relaxed); + self.reader_pollee.add_events(IoEvents::HUP); + } + SockShutdownCmd::SHUT_WR => (), + } } pub(super) fn addr(&self) -> Option<&UnixSocketAddrBound> { @@ -68,10 +115,12 @@ impl Init { pub(super) fn poll(&self, mask: IoEvents, mut poller: Option<&mut Poller>) -> IoEvents { // To avoid loss of events, this must be compatible with // `Connected::poll`/`Listener::poll`. - self.reader_pollee.poll(mask, poller.as_deref_mut()); - self.writer_pollee.poll(mask, poller); + let reader_events = self.reader_pollee.poll(mask, poller.as_deref_mut()); + let writer_events = self.writer_pollee.poll(mask, poller); - (IoEvents::OUT | IoEvents::HUP) & (mask | IoEvents::ALWAYS_POLL) + // According to the Linux implementation, we always have `IoEvents::HUP` in this state. + // Meanwhile, it is in `IoEvents::ALWAYS_POLL`, so we always return it. + combine_io_events(mask, reader_events, writer_events) | IoEvents::HUP } pub(super) fn register_observer( diff --git a/kernel/src/net/socket/unix/stream/listener.rs b/kernel/src/net/socket/unix/stream/listener.rs index c691e6024..3d8e3108b 100644 --- a/kernel/src/net/socket/unix/stream/listener.rs +++ b/kernel/src/net/socket/unix/stream/listener.rs @@ -2,13 +2,17 @@ use core::sync::atomic::{AtomicUsize, Ordering}; -use super::{connected::Connected, init::Init, UnixStreamSocket}; +use super::{ + connected::{combine_io_events, Connected}, + init::Init, + UnixStreamSocket, +}; use crate::{ events::{IoEvents, Observer}, fs::file_handle::FileLike, net::socket::{ unix::addr::{UnixSocketAddrBound, UnixSocketAddrKey}, - SocketAddr, + SockShutdownCmd, SocketAddr, }, prelude::*, process::signal::{Pollee, Poller}, @@ -16,12 +20,28 @@ use crate::{ pub(super) struct Listener { backlog: Arc, + writer_pollee: Pollee, } impl Listener { - pub(super) fn new(addr: UnixSocketAddrBound, pollee: Pollee, backlog: usize) -> Self { - let backlog = BACKLOG_TABLE.add_backlog(addr, pollee, backlog).unwrap(); - Self { backlog } + pub(super) fn new( + addr: UnixSocketAddrBound, + reader_pollee: Pollee, + writer_pollee: Pollee, + backlog: usize, + is_shutdown: bool, + ) -> Self { + // Note that the I/O events can be correctly inherited from `Init`. There is no need to + // explicitly call `Pollee::reset_io_events`. + let backlog = BACKLOG_TABLE + .add_backlog(addr, reader_pollee, backlog, is_shutdown) + .unwrap(); + writer_pollee.del_events(IoEvents::OUT); + + Self { + backlog, + writer_pollee, + } } pub(super) fn addr(&self) -> &UnixSocketAddrBound { @@ -40,8 +60,27 @@ impl Listener { self.backlog.set_backlog(backlog); } - pub(super) fn poll(&self, mask: IoEvents, poller: Option<&mut Poller>) -> IoEvents { - self.backlog.poll(mask, poller) + pub(super) fn shutdown(&self, cmd: SockShutdownCmd) { + match cmd { + SockShutdownCmd::SHUT_WR | SockShutdownCmd::SHUT_RDWR => { + self.writer_pollee.add_events(IoEvents::ERR); + } + SockShutdownCmd::SHUT_RD => (), + } + + match cmd { + SockShutdownCmd::SHUT_RD | SockShutdownCmd::SHUT_RDWR => { + self.backlog.shutdown(); + } + SockShutdownCmd::SHUT_WR => (), + } + } + + pub(super) fn poll(&self, mask: IoEvents, mut poller: Option<&mut Poller>) -> IoEvents { + let reader_events = self.backlog.poll(mask, poller.as_deref_mut()); + let writer_events = self.writer_pollee.poll(mask, poller); + + combine_io_events(mask, reader_events, writer_events) } pub(super) fn register_observer( @@ -49,14 +88,18 @@ impl Listener { observer: Weak>, mask: IoEvents, ) -> Result<()> { - self.backlog.register_observer(observer, mask) + self.backlog.register_observer(observer.clone(), mask)?; + self.writer_pollee.register_observer(observer, mask); + Ok(()) } pub(super) fn unregister_observer( &self, observer: &Weak>, ) -> Option>> { - self.backlog.unregister_observer(observer) + let reader_observer = self.backlog.unregister_observer(observer); + let writer_observer = self.writer_pollee.unregister_observer(observer); + reader_observer.or(writer_observer) } } @@ -84,6 +127,7 @@ impl BacklogTable { addr: UnixSocketAddrBound, pollee: Pollee, backlog: usize, + is_shutdown: bool, ) -> Option> { let addr_key = addr.to_key(); @@ -93,7 +137,7 @@ impl BacklogTable { return None; } - let new_backlog = Arc::new(Backlog::new(addr, pollee, backlog)); + let new_backlog = Arc::new(Backlog::new(addr, pollee, backlog, is_shutdown)); backlog_sockets.insert(addr_key, new_backlog.clone()); Some(new_backlog) @@ -133,18 +177,22 @@ struct Backlog { addr: UnixSocketAddrBound, pollee: Pollee, backlog: AtomicUsize, - incoming_conns: Mutex>, + incoming_conns: Mutex>>, } impl Backlog { - fn new(addr: UnixSocketAddrBound, pollee: Pollee, backlog: usize) -> Self { - pollee.reset_events(); + fn new(addr: UnixSocketAddrBound, pollee: Pollee, backlog: usize, is_shutdown: bool) -> Self { + let incoming_sockets = if is_shutdown { + None + } else { + Some(VecDeque::with_capacity(backlog)) + }; Self { addr, pollee, backlog: AtomicUsize::new(backlog), - incoming_conns: Mutex::new(VecDeque::with_capacity(backlog)), + incoming_conns: Mutex::new(incoming_sockets), } } @@ -153,7 +201,17 @@ impl Backlog { } fn push_incoming(&self, init: Init) -> core::result::Result { - let mut incoming_conns = self.incoming_conns.lock(); + let mut locked_incoming_conns = self.incoming_conns.lock(); + + let Some(incoming_conns) = &mut *locked_incoming_conns else { + return Err(( + Error::with_message( + Errno::ECONNREFUSED, + "the listening socket is shut down for reading", + ), + init, + )); + }; if incoming_conns.len() >= self.backlog.load(Ordering::Relaxed) { return Err(( @@ -174,11 +232,17 @@ impl Backlog { } fn pop_incoming(&self) -> Result { - let mut incoming_conns = self.incoming_conns.lock(); + let mut locked_incoming_conns = self.incoming_conns.lock(); + + let Some(incoming_conns) = &mut *locked_incoming_conns else { + return_errno_with_message!(Errno::EINVAL, "the socket is shut down for reading"); + }; + let conn = incoming_conns.pop_front(); if incoming_conns.is_empty() { self.pollee.del_events(IoEvents::IN); } + conn.ok_or_else(|| Error::with_message(Errno::EAGAIN, "no pending connection is available")) } @@ -186,9 +250,15 @@ impl Backlog { self.backlog.store(backlog, Ordering::Relaxed); } + fn shutdown(&self) { + let mut incoming_conns = self.incoming_conns.lock(); + + *incoming_conns = None; + self.pollee.add_events(IoEvents::HUP); + self.pollee.del_events(IoEvents::IN); + } + fn poll(&self, mask: IoEvents, poller: Option<&mut Poller>) -> IoEvents { - // Lock to avoid any events may change pollee state when we poll - let _lock = self.incoming_conns.lock(); self.pollee.poll(mask, poller) } diff --git a/kernel/src/net/socket/unix/stream/socket.rs b/kernel/src/net/socket/unix/stream/socket.rs index f3f981905..0931fb8e7 100644 --- a/kernel/src/net/socket/unix/stream/socket.rs +++ b/kernel/src/net/socket/unix/stream/socket.rs @@ -297,9 +297,12 @@ impl Socket for UnixStreamSocket { fn shutdown(&self, cmd: SockShutdownCmd) -> Result<()> { match self.state.read().as_ref() { + State::Init(init) => init.shutdown(cmd), + State::Listen(listen) => listen.shutdown(cmd), State::Connected(connected) => connected.shutdown(cmd), - _ => return_errno_with_message!(Errno::ENOTCONN, "the socked is not connected"), } + + Ok(()) } fn addr(&self) -> Result { diff --git a/test/apps/network/unix_err.c b/test/apps/network/unix_err.c index 5cadc17e8..0bdbd68b9 100644 --- a/test/apps/network/unix_err.c +++ b/test/apps/network/unix_err.c @@ -326,6 +326,59 @@ FN_TEST(shutdown_connected) } END_TEST() +FN_TEST(poll_unbound) +{ + int sk; + struct pollfd pfd = { .events = POLLIN | POLLOUT | POLLRDHUP }; + + sk = TEST_SUCC(socket(PF_UNIX, SOCK_STREAM, 0)); + pfd.fd = sk; + + TEST_RES(poll(&pfd, 1, 0), pfd.revents == (POLLOUT | POLLHUP)); + + TEST_SUCC(shutdown(sk, SHUT_WR)); + TEST_RES(poll(&pfd, 1, 0), pfd.revents == (POLLOUT | POLLHUP)); + + TEST_SUCC(shutdown(sk, SHUT_RD)); + TEST_RES(poll(&pfd, 1, 0), + pfd.revents == (POLLIN | POLLOUT | POLLRDHUP | POLLHUP)); + + TEST_SUCC( + bind(sk, (struct sockaddr *)&UNIX_ADDR("\0"), PATH_OFFSET + 1)); + TEST_SUCC(listen(sk, 10)); + + TEST_RES(poll(&pfd, 1, 0), + pfd.revents == (POLLIN | POLLRDHUP | POLLHUP)); + + TEST_SUCC(close(sk)); +} +END_TEST() + +FN_TEST(poll_listen) +{ + int sk; + struct pollfd pfd = { .events = POLLIN | POLLOUT | POLLRDHUP }; + + sk = TEST_SUCC(socket(PF_UNIX, SOCK_STREAM, 0)); + pfd.fd = sk; + + TEST_SUCC( + bind(sk, (struct sockaddr *)&UNIX_ADDR("\0"), PATH_OFFSET + 1)); + TEST_SUCC(listen(sk, 10)); + + TEST_RES(poll(&pfd, 1, 0), pfd.revents == 0); + + TEST_SUCC(shutdown(sk, SHUT_RD)); + TEST_RES(poll(&pfd, 1, 0), pfd.revents == (POLLIN | POLLRDHUP)); + + TEST_SUCC(shutdown(sk, SHUT_WR)); + TEST_RES(poll(&pfd, 1, 0), + pfd.revents == (POLLIN | POLLRDHUP | POLLHUP)); + + TEST_SUCC(close(sk)); +} +END_TEST() + FN_TEST(poll_connected_close) { int fildes[2];