Fix I/O events for UNIX connected sockets

This commit is contained in:
Ruihan Li 2024-07-28 14:18:02 +08:00 committed by Tate, Hongliang Tian
parent 7ddb69f4db
commit f831f5685f
3 changed files with 126 additions and 64 deletions

View File

@ -59,15 +59,11 @@ pub struct Consumer<T>(Fifo<T, ReadOp>);
macro_rules! impl_common_methods_for_channel {
() => {
pub fn shutdown(&self) {
self.this_end().shutdown()
self.0.common.shutdown()
}
pub fn is_shutdown(&self) -> bool {
self.this_end().is_shutdown()
}
pub fn is_peer_shutdown(&self) -> bool {
self.peer_end().is_shutdown()
self.0.common.is_shutdown()
}
pub fn poll(&self, mask: IoEvents, poller: Option<&mut Poller>) -> IoEvents {
@ -109,7 +105,7 @@ impl<T> Producer<T> {
let this_end = self.this_end();
let rb = this_end.rb();
if self.is_shutdown() || self.is_peer_shutdown() {
if self.is_shutdown() {
// The POLLOUT event is always set in this case. Don't try to remove it.
} else if rb.is_full() {
this_end.pollee.del_events(IoEvents::OUT);
@ -139,7 +135,7 @@ impl Producer<u8> {
return Ok(0);
}
if self.is_shutdown() || self.is_peer_shutdown() {
if self.is_shutdown() {
return_errno_with_message!(Errno::EPIPE, "the channel is shut down");
}
@ -161,7 +157,7 @@ impl<T: Pod> Producer<T> {
/// - Returns `Err(EPIPE)` if the channel is shut down.
/// - Returns `Err(EAGAIN)` if the channel is full.
pub fn try_push(&self, item: T) -> core::result::Result<(), (Error, T)> {
if self.is_shutdown() || self.is_peer_shutdown() {
if self.is_shutdown() {
let err = Error::with_message(Errno::EPIPE, "the channel is shut down");
return Err((err, item));
}
@ -179,11 +175,6 @@ impl<T: Pod> Producer<T> {
impl<T> Drop for Producer<T> {
fn drop(&mut self) {
self.shutdown();
// The POLLHUP event indicates that the write end is shut down.
//
// No need to take a lock. There is no race because no one is modifying this particular event.
self.peer_end().pollee.add_events(IoEvents::HUP);
}
}
@ -232,7 +223,7 @@ impl Consumer<u8> {
}
// This must be recorded before the actual operation to avoid race conditions.
let is_shutdown = self.is_shutdown() || self.is_peer_shutdown();
let is_shutdown = self.is_shutdown();
let read_len = self.0.read(writer);
self.update_pollee();
@ -255,7 +246,7 @@ impl<T: Pod> Consumer<T> {
/// - Returns `Err(EAGAIN)` if the channel is empty.
pub fn try_pop(&self) -> Result<Option<T>> {
// This must be recorded before the actual operation to avoid race conditions.
let is_shutdown = self.is_shutdown() || self.is_peer_shutdown();
let is_shutdown = self.is_shutdown();
let item = self.0.pop();
self.update_pollee();
@ -273,15 +264,6 @@ impl<T: Pod> Consumer<T> {
impl<T> Drop for Consumer<T> {
fn drop(&mut self) {
self.shutdown();
// The POLLERR event indicates that the read end is closed (so any subsequent writes will
// fail with an `EPIPE` error).
//
// The lock is taken because we are also adding the POLLOUT event, which may have races
// with the event updates triggered by the writer.
let peer_end = self.peer_end();
let _rb = peer_end.rb();
peer_end.pollee.add_events(IoEvents::ERR | IoEvents::OUT);
}
}
@ -346,6 +328,7 @@ impl<T: Pod, R: TRights> Fifo<T, R> {
struct Common<T> {
producer: FifoInner<RbProducer<T>>,
consumer: FifoInner<RbConsumer<T>>,
is_shutdown: AtomicBool,
}
impl<T> Common<T> {
@ -356,18 +339,46 @@ impl<T> Common<T> {
let producer = FifoInner::new(rb_producer, IoEvents::OUT);
let consumer = FifoInner::new(rb_consumer, IoEvents::empty());
Self { producer, consumer }
Self {
producer,
consumer,
is_shutdown: AtomicBool::new(false),
}
}
pub fn capacity(&self) -> usize {
self.producer.rb().capacity()
}
pub fn is_shutdown(&self) -> bool {
self.is_shutdown.load(Ordering::Relaxed)
}
pub fn shutdown(&self) {
if self.is_shutdown.swap(true, Ordering::Relaxed) {
return;
}
// The POLLHUP event indicates that the write end is shut down.
//
// No need to take a lock. There is no race because no one is modifying this particular event.
self.consumer.pollee.add_events(IoEvents::HUP);
// The POLLERR event indicates that the read end is shut down (so any subsequent writes
// will fail with an `EPIPE` error).
//
// The lock is taken because we are also adding the POLLOUT event, which may have races
// with the event updates triggered by the writer.
let _rb = self.producer.rb();
self.producer
.pollee
.add_events(IoEvents::ERR | IoEvents::OUT);
}
}
struct FifoInner<T> {
rb: Mutex<T>,
pollee: Pollee,
is_shutdown: AtomicBool,
}
impl<T> FifoInner<T> {
@ -375,21 +386,12 @@ impl<T> FifoInner<T> {
Self {
rb: Mutex::new(rb),
pollee: Pollee::new(init_events),
is_shutdown: AtomicBool::new(false),
}
}
pub fn rb(&self) -> MutexGuard<T> {
self.rb.lock()
}
pub fn is_shutdown(&self) -> bool {
self.is_shutdown.load(Ordering::Acquire)
}
pub fn shutdown(&self) {
self.is_shutdown.store(true, Ordering::Release)
}
}
#[cfg(ktest)]

View File

@ -58,8 +58,6 @@ impl Connected {
}
pub(super) fn shutdown(&self, cmd: SockShutdownCmd) -> Result<()> {
// FIXME: If the socket has already been shut down, should we return an error code?
if cmd.shut_read() {
self.reader.shutdown();
}
@ -72,23 +70,27 @@ impl Connected {
}
pub(super) fn poll(&self, mask: IoEvents, mut poller: Option<&mut Poller>) -> IoEvents {
let mut events = IoEvents::empty();
// FIXME: should reader and writer use the same mask?
// Note that `mask | IoEvents::ALWAYS_POLL` contains all the events we care about.
let reader_events = self.reader.poll(mask, poller.as_deref_mut());
let writer_events = self.writer.poll(mask, poller);
// FIXME: Check this logic later.
if reader_events.contains(IoEvents::HUP) || self.reader.is_shutdown() {
let mut events = IoEvents::empty();
if reader_events.contains(IoEvents::HUP) {
// The socket is shut down in one direction: the remote socket has shut down for
// writing or the local socket has shut down for reading.
events |= IoEvents::RDHUP | IoEvents::IN;
if writer_events.contains(IoEvents::ERR) || self.writer.is_shutdown() {
events |= IoEvents::HUP | IoEvents::OUT;
if writer_events.contains(IoEvents::ERR) {
// The socket is shut down in both directions. Neither reading nor writing is
// possible.
events |= IoEvents::HUP;
}
}
events |= (reader_events & IoEvents::IN) | (writer_events & IoEvents::OUT);
events
events & (mask | IoEvents::ALWAYS_POLL)
}
pub(super) fn register_observer(
@ -96,14 +98,8 @@ impl Connected {
observer: Weak<dyn Observer<IoEvents>>,
mask: IoEvents,
) -> Result<()> {
if mask.contains(IoEvents::IN) {
self.reader.register_observer(observer.clone(), mask)?
}
if mask.contains(IoEvents::OUT) {
self.writer.register_observer(observer, mask)?
}
self.reader.register_observer(observer.clone(), mask)?;
self.writer.register_observer(observer, mask)?;
Ok(())
}
@ -111,16 +107,9 @@ impl Connected {
&self,
observer: &Weak<dyn Observer<IoEvents>>,
) -> Option<Weak<dyn Observer<IoEvents>>> {
let observer0 = self.reader.unregister_observer(observer);
let observer1 = self.writer.unregister_observer(observer);
if observer0.is_some() {
observer0
} else if observer1.is_some() {
observer1
} else {
None
}
let reader_observer = self.reader.unregister_observer(observer);
let writer_observer = self.writer.unregister_observer(observer);
reader_observer.or(writer_observer)
}
}

View File

@ -1,5 +1,7 @@
// SPDX-License-Identifier: MPL-2.0
#define _GNU_SOURCE
#include <sys/socket.h>
#include <sys/un.h>
#include <sys/poll.h>
@ -304,6 +306,75 @@ FN_TEST(ns_abs)
}
END_TEST()
FN_TEST(shutdown_connected)
{
int fildes[2];
TEST_SUCC(socketpair(PF_UNIX, SOCK_STREAM, 0, fildes));
TEST_SUCC(shutdown(fildes[0], SHUT_RD));
TEST_SUCC(shutdown(fildes[0], SHUT_WR));
TEST_SUCC(shutdown(fildes[0], SHUT_RDWR));
TEST_SUCC(shutdown(fildes[0], SHUT_RD));
TEST_SUCC(shutdown(fildes[0], SHUT_WR));
TEST_SUCC(shutdown(fildes[0], SHUT_RDWR));
TEST_SUCC(close(fildes[0]));
TEST_SUCC(close(fildes[1]));
}
END_TEST()
FN_TEST(poll_connected_close)
{
int fildes[2];
struct pollfd pfd = { .events = POLLIN | POLLOUT | POLLRDHUP };
TEST_SUCC(socketpair(PF_UNIX, SOCK_STREAM, 0, fildes));
pfd.fd = fildes[1];
TEST_RES(poll(&pfd, 1, 0), pfd.revents == POLLOUT);
TEST_SUCC(close(fildes[0]));
pfd.fd = fildes[1];
TEST_RES(poll(&pfd, 1, 0),
pfd.revents == (POLLIN | POLLOUT | POLLRDHUP | POLLHUP));
TEST_SUCC(close(fildes[1]));
}
END_TEST()
FN_TEST(poll_connected_shutdown)
{
int fildes[2];
struct pollfd pfd = { .events = POLLIN | POLLOUT | POLLRDHUP };
#define MAKE_TEST(shut, ev1, ev2) \
TEST_SUCC(socketpair(PF_UNIX, SOCK_STREAM, 0, fildes)); \
\
TEST_SUCC(shutdown(fildes[0], shut)); \
\
pfd.fd = fildes[0]; \
TEST_RES(poll(&pfd, 1, 0), pfd.revents == (ev1)); \
\
pfd.fd = fildes[1]; \
TEST_RES(poll(&pfd, 1, 0), pfd.revents == (ev2)); \
\
TEST_SUCC(close(fildes[0])); \
TEST_SUCC(close(fildes[1]));
MAKE_TEST(SHUT_RD, POLLIN | POLLOUT | POLLRDHUP, POLLOUT);
MAKE_TEST(SHUT_WR, POLLOUT, POLLIN | POLLOUT | POLLRDHUP);
MAKE_TEST(SHUT_RDWR, POLLIN | POLLOUT | POLLRDHUP | POLLHUP,
POLLIN | POLLOUT | POLLRDHUP | POLLHUP);
#undef MAKE_TEST
}
END_TEST()
FN_SETUP(cleanup)
{
CHECK(close(sk_unbound));