diff --git a/kernel/src/fs/utils/channel.rs b/kernel/src/fs/utils/channel.rs index 5c43e7fea..bd4cddc68 100644 --- a/kernel/src/fs/utils/channel.rs +++ b/kernel/src/fs/utils/channel.rs @@ -59,15 +59,11 @@ pub struct Consumer(Fifo); macro_rules! impl_common_methods_for_channel { () => { pub fn shutdown(&self) { - self.this_end().shutdown() + self.0.common.shutdown() } pub fn is_shutdown(&self) -> bool { - self.this_end().is_shutdown() - } - - pub fn is_peer_shutdown(&self) -> bool { - self.peer_end().is_shutdown() + self.0.common.is_shutdown() } pub fn poll(&self, mask: IoEvents, poller: Option<&mut Poller>) -> IoEvents { @@ -109,7 +105,7 @@ impl Producer { let this_end = self.this_end(); let rb = this_end.rb(); - if self.is_shutdown() || self.is_peer_shutdown() { + if self.is_shutdown() { // The POLLOUT event is always set in this case. Don't try to remove it. } else if rb.is_full() { this_end.pollee.del_events(IoEvents::OUT); @@ -139,7 +135,7 @@ impl Producer { return Ok(0); } - if self.is_shutdown() || self.is_peer_shutdown() { + if self.is_shutdown() { return_errno_with_message!(Errno::EPIPE, "the channel is shut down"); } @@ -161,7 +157,7 @@ impl Producer { /// - Returns `Err(EPIPE)` if the channel is shut down. /// - Returns `Err(EAGAIN)` if the channel is full. pub fn try_push(&self, item: T) -> core::result::Result<(), (Error, T)> { - if self.is_shutdown() || self.is_peer_shutdown() { + if self.is_shutdown() { let err = Error::with_message(Errno::EPIPE, "the channel is shut down"); return Err((err, item)); } @@ -179,11 +175,6 @@ impl Producer { impl Drop for Producer { fn drop(&mut self) { self.shutdown(); - - // The POLLHUP event indicates that the write end is shut down. - // - // No need to take a lock. There is no race because no one is modifying this particular event. - self.peer_end().pollee.add_events(IoEvents::HUP); } } @@ -232,7 +223,7 @@ impl Consumer { } // This must be recorded before the actual operation to avoid race conditions. - let is_shutdown = self.is_shutdown() || self.is_peer_shutdown(); + let is_shutdown = self.is_shutdown(); let read_len = self.0.read(writer); self.update_pollee(); @@ -255,7 +246,7 @@ impl Consumer { /// - Returns `Err(EAGAIN)` if the channel is empty. pub fn try_pop(&self) -> Result> { // This must be recorded before the actual operation to avoid race conditions. - let is_shutdown = self.is_shutdown() || self.is_peer_shutdown(); + let is_shutdown = self.is_shutdown(); let item = self.0.pop(); self.update_pollee(); @@ -273,15 +264,6 @@ impl Consumer { impl Drop for Consumer { fn drop(&mut self) { self.shutdown(); - - // The POLLERR event indicates that the read end is closed (so any subsequent writes will - // fail with an `EPIPE` error). - // - // The lock is taken because we are also adding the POLLOUT event, which may have races - // with the event updates triggered by the writer. - let peer_end = self.peer_end(); - let _rb = peer_end.rb(); - peer_end.pollee.add_events(IoEvents::ERR | IoEvents::OUT); } } @@ -346,6 +328,7 @@ impl Fifo { struct Common { producer: FifoInner>, consumer: FifoInner>, + is_shutdown: AtomicBool, } impl Common { @@ -356,18 +339,46 @@ impl Common { let producer = FifoInner::new(rb_producer, IoEvents::OUT); let consumer = FifoInner::new(rb_consumer, IoEvents::empty()); - Self { producer, consumer } + Self { + producer, + consumer, + is_shutdown: AtomicBool::new(false), + } } pub fn capacity(&self) -> usize { self.producer.rb().capacity() } + + pub fn is_shutdown(&self) -> bool { + self.is_shutdown.load(Ordering::Relaxed) + } + + pub fn shutdown(&self) { + if self.is_shutdown.swap(true, Ordering::Relaxed) { + return; + } + + // The POLLHUP event indicates that the write end is shut down. + // + // No need to take a lock. There is no race because no one is modifying this particular event. + self.consumer.pollee.add_events(IoEvents::HUP); + + // The POLLERR event indicates that the read end is shut down (so any subsequent writes + // will fail with an `EPIPE` error). + // + // The lock is taken because we are also adding the POLLOUT event, which may have races + // with the event updates triggered by the writer. + let _rb = self.producer.rb(); + self.producer + .pollee + .add_events(IoEvents::ERR | IoEvents::OUT); + } } struct FifoInner { rb: Mutex, pollee: Pollee, - is_shutdown: AtomicBool, } impl FifoInner { @@ -375,21 +386,12 @@ impl FifoInner { Self { rb: Mutex::new(rb), pollee: Pollee::new(init_events), - is_shutdown: AtomicBool::new(false), } } pub fn rb(&self) -> MutexGuard { self.rb.lock() } - - pub fn is_shutdown(&self) -> bool { - self.is_shutdown.load(Ordering::Acquire) - } - - pub fn shutdown(&self) { - self.is_shutdown.store(true, Ordering::Release) - } } #[cfg(ktest)] diff --git a/kernel/src/net/socket/unix/stream/connected.rs b/kernel/src/net/socket/unix/stream/connected.rs index 77f0d3705..f016610b0 100644 --- a/kernel/src/net/socket/unix/stream/connected.rs +++ b/kernel/src/net/socket/unix/stream/connected.rs @@ -58,8 +58,6 @@ impl Connected { } pub(super) fn shutdown(&self, cmd: SockShutdownCmd) -> Result<()> { - // FIXME: If the socket has already been shut down, should we return an error code? - if cmd.shut_read() { self.reader.shutdown(); } @@ -72,23 +70,27 @@ impl Connected { } pub(super) fn poll(&self, mask: IoEvents, mut poller: Option<&mut Poller>) -> IoEvents { - let mut events = IoEvents::empty(); - - // FIXME: should reader and writer use the same mask? + // Note that `mask | IoEvents::ALWAYS_POLL` contains all the events we care about. let reader_events = self.reader.poll(mask, poller.as_deref_mut()); let writer_events = self.writer.poll(mask, poller); - // FIXME: Check this logic later. - if reader_events.contains(IoEvents::HUP) || self.reader.is_shutdown() { + let mut events = IoEvents::empty(); + + if reader_events.contains(IoEvents::HUP) { + // The socket is shut down in one direction: the remote socket has shut down for + // writing or the local socket has shut down for reading. events |= IoEvents::RDHUP | IoEvents::IN; - if writer_events.contains(IoEvents::ERR) || self.writer.is_shutdown() { - events |= IoEvents::HUP | IoEvents::OUT; + + if writer_events.contains(IoEvents::ERR) { + // The socket is shut down in both directions. Neither reading nor writing is + // possible. + events |= IoEvents::HUP; } } events |= (reader_events & IoEvents::IN) | (writer_events & IoEvents::OUT); - events + events & (mask | IoEvents::ALWAYS_POLL) } pub(super) fn register_observer( @@ -96,14 +98,8 @@ impl Connected { observer: Weak>, mask: IoEvents, ) -> Result<()> { - if mask.contains(IoEvents::IN) { - self.reader.register_observer(observer.clone(), mask)? - } - - if mask.contains(IoEvents::OUT) { - self.writer.register_observer(observer, mask)? - } - + self.reader.register_observer(observer.clone(), mask)?; + self.writer.register_observer(observer, mask)?; Ok(()) } @@ -111,16 +107,9 @@ impl Connected { &self, observer: &Weak>, ) -> Option>> { - let observer0 = self.reader.unregister_observer(observer); - let observer1 = self.writer.unregister_observer(observer); - - if observer0.is_some() { - observer0 - } else if observer1.is_some() { - observer1 - } else { - None - } + let reader_observer = self.reader.unregister_observer(observer); + let writer_observer = self.writer.unregister_observer(observer); + reader_observer.or(writer_observer) } } diff --git a/test/apps/network/unix_err.c b/test/apps/network/unix_err.c index 5ea2df88a..49d5f468e 100644 --- a/test/apps/network/unix_err.c +++ b/test/apps/network/unix_err.c @@ -1,5 +1,7 @@ // SPDX-License-Identifier: MPL-2.0 +#define _GNU_SOURCE + #include #include #include @@ -304,6 +306,75 @@ FN_TEST(ns_abs) } END_TEST() +FN_TEST(shutdown_connected) +{ + int fildes[2]; + + TEST_SUCC(socketpair(PF_UNIX, SOCK_STREAM, 0, fildes)); + + TEST_SUCC(shutdown(fildes[0], SHUT_RD)); + TEST_SUCC(shutdown(fildes[0], SHUT_WR)); + TEST_SUCC(shutdown(fildes[0], SHUT_RDWR)); + + TEST_SUCC(shutdown(fildes[0], SHUT_RD)); + TEST_SUCC(shutdown(fildes[0], SHUT_WR)); + TEST_SUCC(shutdown(fildes[0], SHUT_RDWR)); + + TEST_SUCC(close(fildes[0])); + TEST_SUCC(close(fildes[1])); +} +END_TEST() + +FN_TEST(poll_connected_close) +{ + int fildes[2]; + struct pollfd pfd = { .events = POLLIN | POLLOUT | POLLRDHUP }; + + TEST_SUCC(socketpair(PF_UNIX, SOCK_STREAM, 0, fildes)); + + pfd.fd = fildes[1]; + TEST_RES(poll(&pfd, 1, 0), pfd.revents == POLLOUT); + + TEST_SUCC(close(fildes[0])); + + pfd.fd = fildes[1]; + TEST_RES(poll(&pfd, 1, 0), + pfd.revents == (POLLIN | POLLOUT | POLLRDHUP | POLLHUP)); + + TEST_SUCC(close(fildes[1])); +} +END_TEST() + +FN_TEST(poll_connected_shutdown) +{ + int fildes[2]; + struct pollfd pfd = { .events = POLLIN | POLLOUT | POLLRDHUP }; + +#define MAKE_TEST(shut, ev1, ev2) \ + TEST_SUCC(socketpair(PF_UNIX, SOCK_STREAM, 0, fildes)); \ + \ + TEST_SUCC(shutdown(fildes[0], shut)); \ + \ + pfd.fd = fildes[0]; \ + TEST_RES(poll(&pfd, 1, 0), pfd.revents == (ev1)); \ + \ + pfd.fd = fildes[1]; \ + TEST_RES(poll(&pfd, 1, 0), pfd.revents == (ev2)); \ + \ + TEST_SUCC(close(fildes[0])); \ + TEST_SUCC(close(fildes[1])); + + MAKE_TEST(SHUT_RD, POLLIN | POLLOUT | POLLRDHUP, POLLOUT); + + MAKE_TEST(SHUT_WR, POLLOUT, POLLIN | POLLOUT | POLLRDHUP); + + MAKE_TEST(SHUT_RDWR, POLLIN | POLLOUT | POLLRDHUP | POLLHUP, + POLLIN | POLLOUT | POLLRDHUP | POLLHUP); + +#undef MAKE_TEST +} +END_TEST() + FN_SETUP(cleanup) { CHECK(close(sk_unbound));