Fix I/O events for UNIX connected sockets

This commit is contained in:
Ruihan Li 2024-07-28 14:18:02 +08:00 committed by Tate, Hongliang Tian
parent 7ddb69f4db
commit f831f5685f
3 changed files with 126 additions and 64 deletions

View File

@ -59,15 +59,11 @@ pub struct Consumer<T>(Fifo<T, ReadOp>);
macro_rules! impl_common_methods_for_channel { macro_rules! impl_common_methods_for_channel {
() => { () => {
pub fn shutdown(&self) { pub fn shutdown(&self) {
self.this_end().shutdown() self.0.common.shutdown()
} }
pub fn is_shutdown(&self) -> bool { pub fn is_shutdown(&self) -> bool {
self.this_end().is_shutdown() self.0.common.is_shutdown()
}
pub fn is_peer_shutdown(&self) -> bool {
self.peer_end().is_shutdown()
} }
pub fn poll(&self, mask: IoEvents, poller: Option<&mut Poller>) -> IoEvents { pub fn poll(&self, mask: IoEvents, poller: Option<&mut Poller>) -> IoEvents {
@ -109,7 +105,7 @@ impl<T> Producer<T> {
let this_end = self.this_end(); let this_end = self.this_end();
let rb = this_end.rb(); let rb = this_end.rb();
if self.is_shutdown() || self.is_peer_shutdown() { if self.is_shutdown() {
// The POLLOUT event is always set in this case. Don't try to remove it. // The POLLOUT event is always set in this case. Don't try to remove it.
} else if rb.is_full() { } else if rb.is_full() {
this_end.pollee.del_events(IoEvents::OUT); this_end.pollee.del_events(IoEvents::OUT);
@ -139,7 +135,7 @@ impl Producer<u8> {
return Ok(0); return Ok(0);
} }
if self.is_shutdown() || self.is_peer_shutdown() { if self.is_shutdown() {
return_errno_with_message!(Errno::EPIPE, "the channel is shut down"); return_errno_with_message!(Errno::EPIPE, "the channel is shut down");
} }
@ -161,7 +157,7 @@ impl<T: Pod> Producer<T> {
/// - Returns `Err(EPIPE)` if the channel is shut down. /// - Returns `Err(EPIPE)` if the channel is shut down.
/// - Returns `Err(EAGAIN)` if the channel is full. /// - Returns `Err(EAGAIN)` if the channel is full.
pub fn try_push(&self, item: T) -> core::result::Result<(), (Error, T)> { pub fn try_push(&self, item: T) -> core::result::Result<(), (Error, T)> {
if self.is_shutdown() || self.is_peer_shutdown() { if self.is_shutdown() {
let err = Error::with_message(Errno::EPIPE, "the channel is shut down"); let err = Error::with_message(Errno::EPIPE, "the channel is shut down");
return Err((err, item)); return Err((err, item));
} }
@ -179,11 +175,6 @@ impl<T: Pod> Producer<T> {
impl<T> Drop for Producer<T> { impl<T> Drop for Producer<T> {
fn drop(&mut self) { fn drop(&mut self) {
self.shutdown(); self.shutdown();
// The POLLHUP event indicates that the write end is shut down.
//
// No need to take a lock. There is no race because no one is modifying this particular event.
self.peer_end().pollee.add_events(IoEvents::HUP);
} }
} }
@ -232,7 +223,7 @@ impl Consumer<u8> {
} }
// This must be recorded before the actual operation to avoid race conditions. // This must be recorded before the actual operation to avoid race conditions.
let is_shutdown = self.is_shutdown() || self.is_peer_shutdown(); let is_shutdown = self.is_shutdown();
let read_len = self.0.read(writer); let read_len = self.0.read(writer);
self.update_pollee(); self.update_pollee();
@ -255,7 +246,7 @@ impl<T: Pod> Consumer<T> {
/// - Returns `Err(EAGAIN)` if the channel is empty. /// - Returns `Err(EAGAIN)` if the channel is empty.
pub fn try_pop(&self) -> Result<Option<T>> { pub fn try_pop(&self) -> Result<Option<T>> {
// This must be recorded before the actual operation to avoid race conditions. // This must be recorded before the actual operation to avoid race conditions.
let is_shutdown = self.is_shutdown() || self.is_peer_shutdown(); let is_shutdown = self.is_shutdown();
let item = self.0.pop(); let item = self.0.pop();
self.update_pollee(); self.update_pollee();
@ -273,15 +264,6 @@ impl<T: Pod> Consumer<T> {
impl<T> Drop for Consumer<T> { impl<T> Drop for Consumer<T> {
fn drop(&mut self) { fn drop(&mut self) {
self.shutdown(); self.shutdown();
// The POLLERR event indicates that the read end is closed (so any subsequent writes will
// fail with an `EPIPE` error).
//
// The lock is taken because we are also adding the POLLOUT event, which may have races
// with the event updates triggered by the writer.
let peer_end = self.peer_end();
let _rb = peer_end.rb();
peer_end.pollee.add_events(IoEvents::ERR | IoEvents::OUT);
} }
} }
@ -346,6 +328,7 @@ impl<T: Pod, R: TRights> Fifo<T, R> {
struct Common<T> { struct Common<T> {
producer: FifoInner<RbProducer<T>>, producer: FifoInner<RbProducer<T>>,
consumer: FifoInner<RbConsumer<T>>, consumer: FifoInner<RbConsumer<T>>,
is_shutdown: AtomicBool,
} }
impl<T> Common<T> { impl<T> Common<T> {
@ -356,18 +339,46 @@ impl<T> Common<T> {
let producer = FifoInner::new(rb_producer, IoEvents::OUT); let producer = FifoInner::new(rb_producer, IoEvents::OUT);
let consumer = FifoInner::new(rb_consumer, IoEvents::empty()); let consumer = FifoInner::new(rb_consumer, IoEvents::empty());
Self { producer, consumer } Self {
producer,
consumer,
is_shutdown: AtomicBool::new(false),
}
} }
pub fn capacity(&self) -> usize { pub fn capacity(&self) -> usize {
self.producer.rb().capacity() self.producer.rb().capacity()
} }
pub fn is_shutdown(&self) -> bool {
self.is_shutdown.load(Ordering::Relaxed)
}
pub fn shutdown(&self) {
if self.is_shutdown.swap(true, Ordering::Relaxed) {
return;
}
// The POLLHUP event indicates that the write end is shut down.
//
// No need to take a lock. There is no race because no one is modifying this particular event.
self.consumer.pollee.add_events(IoEvents::HUP);
// The POLLERR event indicates that the read end is shut down (so any subsequent writes
// will fail with an `EPIPE` error).
//
// The lock is taken because we are also adding the POLLOUT event, which may have races
// with the event updates triggered by the writer.
let _rb = self.producer.rb();
self.producer
.pollee
.add_events(IoEvents::ERR | IoEvents::OUT);
}
} }
struct FifoInner<T> { struct FifoInner<T> {
rb: Mutex<T>, rb: Mutex<T>,
pollee: Pollee, pollee: Pollee,
is_shutdown: AtomicBool,
} }
impl<T> FifoInner<T> { impl<T> FifoInner<T> {
@ -375,21 +386,12 @@ impl<T> FifoInner<T> {
Self { Self {
rb: Mutex::new(rb), rb: Mutex::new(rb),
pollee: Pollee::new(init_events), pollee: Pollee::new(init_events),
is_shutdown: AtomicBool::new(false),
} }
} }
pub fn rb(&self) -> MutexGuard<T> { pub fn rb(&self) -> MutexGuard<T> {
self.rb.lock() self.rb.lock()
} }
pub fn is_shutdown(&self) -> bool {
self.is_shutdown.load(Ordering::Acquire)
}
pub fn shutdown(&self) {
self.is_shutdown.store(true, Ordering::Release)
}
} }
#[cfg(ktest)] #[cfg(ktest)]

View File

@ -58,8 +58,6 @@ impl Connected {
} }
pub(super) fn shutdown(&self, cmd: SockShutdownCmd) -> Result<()> { pub(super) fn shutdown(&self, cmd: SockShutdownCmd) -> Result<()> {
// FIXME: If the socket has already been shut down, should we return an error code?
if cmd.shut_read() { if cmd.shut_read() {
self.reader.shutdown(); self.reader.shutdown();
} }
@ -72,23 +70,27 @@ impl Connected {
} }
pub(super) fn poll(&self, mask: IoEvents, mut poller: Option<&mut Poller>) -> IoEvents { pub(super) fn poll(&self, mask: IoEvents, mut poller: Option<&mut Poller>) -> IoEvents {
let mut events = IoEvents::empty(); // Note that `mask | IoEvents::ALWAYS_POLL` contains all the events we care about.
// FIXME: should reader and writer use the same mask?
let reader_events = self.reader.poll(mask, poller.as_deref_mut()); let reader_events = self.reader.poll(mask, poller.as_deref_mut());
let writer_events = self.writer.poll(mask, poller); let writer_events = self.writer.poll(mask, poller);
// FIXME: Check this logic later. let mut events = IoEvents::empty();
if reader_events.contains(IoEvents::HUP) || self.reader.is_shutdown() {
if reader_events.contains(IoEvents::HUP) {
// The socket is shut down in one direction: the remote socket has shut down for
// writing or the local socket has shut down for reading.
events |= IoEvents::RDHUP | IoEvents::IN; events |= IoEvents::RDHUP | IoEvents::IN;
if writer_events.contains(IoEvents::ERR) || self.writer.is_shutdown() {
events |= IoEvents::HUP | IoEvents::OUT; if writer_events.contains(IoEvents::ERR) {
// The socket is shut down in both directions. Neither reading nor writing is
// possible.
events |= IoEvents::HUP;
} }
} }
events |= (reader_events & IoEvents::IN) | (writer_events & IoEvents::OUT); events |= (reader_events & IoEvents::IN) | (writer_events & IoEvents::OUT);
events events & (mask | IoEvents::ALWAYS_POLL)
} }
pub(super) fn register_observer( pub(super) fn register_observer(
@ -96,14 +98,8 @@ impl Connected {
observer: Weak<dyn Observer<IoEvents>>, observer: Weak<dyn Observer<IoEvents>>,
mask: IoEvents, mask: IoEvents,
) -> Result<()> { ) -> Result<()> {
if mask.contains(IoEvents::IN) { self.reader.register_observer(observer.clone(), mask)?;
self.reader.register_observer(observer.clone(), mask)? self.writer.register_observer(observer, mask)?;
}
if mask.contains(IoEvents::OUT) {
self.writer.register_observer(observer, mask)?
}
Ok(()) Ok(())
} }
@ -111,16 +107,9 @@ impl Connected {
&self, &self,
observer: &Weak<dyn Observer<IoEvents>>, observer: &Weak<dyn Observer<IoEvents>>,
) -> Option<Weak<dyn Observer<IoEvents>>> { ) -> Option<Weak<dyn Observer<IoEvents>>> {
let observer0 = self.reader.unregister_observer(observer); let reader_observer = self.reader.unregister_observer(observer);
let observer1 = self.writer.unregister_observer(observer); let writer_observer = self.writer.unregister_observer(observer);
reader_observer.or(writer_observer)
if observer0.is_some() {
observer0
} else if observer1.is_some() {
observer1
} else {
None
}
} }
} }

View File

@ -1,5 +1,7 @@
// SPDX-License-Identifier: MPL-2.0 // SPDX-License-Identifier: MPL-2.0
#define _GNU_SOURCE
#include <sys/socket.h> #include <sys/socket.h>
#include <sys/un.h> #include <sys/un.h>
#include <sys/poll.h> #include <sys/poll.h>
@ -304,6 +306,75 @@ FN_TEST(ns_abs)
} }
END_TEST() END_TEST()
FN_TEST(shutdown_connected)
{
int fildes[2];
TEST_SUCC(socketpair(PF_UNIX, SOCK_STREAM, 0, fildes));
TEST_SUCC(shutdown(fildes[0], SHUT_RD));
TEST_SUCC(shutdown(fildes[0], SHUT_WR));
TEST_SUCC(shutdown(fildes[0], SHUT_RDWR));
TEST_SUCC(shutdown(fildes[0], SHUT_RD));
TEST_SUCC(shutdown(fildes[0], SHUT_WR));
TEST_SUCC(shutdown(fildes[0], SHUT_RDWR));
TEST_SUCC(close(fildes[0]));
TEST_SUCC(close(fildes[1]));
}
END_TEST()
FN_TEST(poll_connected_close)
{
int fildes[2];
struct pollfd pfd = { .events = POLLIN | POLLOUT | POLLRDHUP };
TEST_SUCC(socketpair(PF_UNIX, SOCK_STREAM, 0, fildes));
pfd.fd = fildes[1];
TEST_RES(poll(&pfd, 1, 0), pfd.revents == POLLOUT);
TEST_SUCC(close(fildes[0]));
pfd.fd = fildes[1];
TEST_RES(poll(&pfd, 1, 0),
pfd.revents == (POLLIN | POLLOUT | POLLRDHUP | POLLHUP));
TEST_SUCC(close(fildes[1]));
}
END_TEST()
FN_TEST(poll_connected_shutdown)
{
int fildes[2];
struct pollfd pfd = { .events = POLLIN | POLLOUT | POLLRDHUP };
#define MAKE_TEST(shut, ev1, ev2) \
TEST_SUCC(socketpair(PF_UNIX, SOCK_STREAM, 0, fildes)); \
\
TEST_SUCC(shutdown(fildes[0], shut)); \
\
pfd.fd = fildes[0]; \
TEST_RES(poll(&pfd, 1, 0), pfd.revents == (ev1)); \
\
pfd.fd = fildes[1]; \
TEST_RES(poll(&pfd, 1, 0), pfd.revents == (ev2)); \
\
TEST_SUCC(close(fildes[0])); \
TEST_SUCC(close(fildes[1]));
MAKE_TEST(SHUT_RD, POLLIN | POLLOUT | POLLRDHUP, POLLOUT);
MAKE_TEST(SHUT_WR, POLLOUT, POLLIN | POLLOUT | POLLRDHUP);
MAKE_TEST(SHUT_RDWR, POLLIN | POLLOUT | POLLRDHUP | POLLHUP,
POLLIN | POLLOUT | POLLRDHUP | POLLHUP);
#undef MAKE_TEST
}
END_TEST()
FN_SETUP(cleanup) FN_SETUP(cleanup)
{ {
CHECK(close(sk_unbound)); CHECK(close(sk_unbound));