Fix I/O events related to shutdown

This commit is contained in:
Ruihan Li 2024-07-27 13:10:43 +08:00 committed by Tate, Hongliang Tian
parent 421f6b8e5b
commit a8592a16ea
5 changed files with 230 additions and 49 deletions

View File

@ -61,7 +61,7 @@ impl Connected {
self.writer.try_write(&mut reader)
}
pub(super) fn shutdown(&self, cmd: SockShutdownCmd) -> Result<()> {
pub(super) fn shutdown(&self, cmd: SockShutdownCmd) {
if cmd.shut_read() {
self.reader.shutdown();
}
@ -69,8 +69,6 @@ impl Connected {
if cmd.shut_write() {
self.writer.shutdown();
}
Ok(())
}
pub(super) fn poll(&self, mask: IoEvents, mut poller: Option<&mut Poller>) -> IoEvents {
@ -78,23 +76,7 @@ impl Connected {
let reader_events = self.reader.poll(mask, poller.as_deref_mut());
let writer_events = self.writer.poll(mask, poller);
let mut events = IoEvents::empty();
if reader_events.contains(IoEvents::HUP) {
// The socket is shut down in one direction: the remote socket has shut down for
// writing or the local socket has shut down for reading.
events |= IoEvents::RDHUP | IoEvents::IN;
if writer_events.contains(IoEvents::ERR) {
// The socket is shut down in both directions. Neither reading nor writing is
// possible.
events |= IoEvents::HUP;
}
}
events |= (reader_events & IoEvents::IN) | (writer_events & IoEvents::OUT);
events & (mask | IoEvents::ALWAYS_POLL)
combine_io_events(mask, reader_events, writer_events)
}
pub(super) fn register_observer(
@ -117,4 +99,28 @@ impl Connected {
}
}
pub(super) fn combine_io_events(
mask: IoEvents,
reader_events: IoEvents,
writer_events: IoEvents,
) -> IoEvents {
let mut events = IoEvents::empty();
if reader_events.contains(IoEvents::HUP) {
// The socket is shut down in one direction: the remote socket has shut down for
// writing or the local socket has shut down for reading.
events |= IoEvents::RDHUP | IoEvents::IN;
if writer_events.contains(IoEvents::ERR) {
// The socket is shut down in both directions. Neither reading nor writing is
// possible.
events |= IoEvents::HUP;
}
}
events |= (reader_events & IoEvents::IN) | (writer_events & IoEvents::OUT);
events & (mask | IoEvents::ALWAYS_POLL)
}
const DEFAULT_BUF_SIZE: usize = 65536;

View File

@ -1,9 +1,17 @@
// SPDX-License-Identifier: MPL-2.0
use super::{connected::Connected, listener::Listener};
use core::sync::atomic::{AtomicBool, Ordering};
use super::{
connected::{combine_io_events, Connected},
listener::Listener,
};
use crate::{
events::{IoEvents, Observer},
net::socket::unix::addr::{UnixSocketAddr, UnixSocketAddrBound},
net::socket::{
unix::addr::{UnixSocketAddr, UnixSocketAddrBound},
SockShutdownCmd,
},
prelude::*,
process::signal::{Pollee, Poller},
};
@ -12,6 +20,8 @@ pub(super) struct Init {
addr: Option<UnixSocketAddrBound>,
reader_pollee: Pollee,
writer_pollee: Pollee,
is_read_shutdown: AtomicBool,
is_write_shutdown: AtomicBool,
}
impl Init {
@ -19,7 +29,9 @@ impl Init {
Self {
addr: None,
reader_pollee: Pollee::new(IoEvents::empty()),
writer_pollee: Pollee::new(IoEvents::empty()),
writer_pollee: Pollee::new(IoEvents::OUT),
is_read_shutdown: AtomicBool::new(false),
is_write_shutdown: AtomicBool::new(false),
}
}
@ -39,14 +51,26 @@ impl Init {
addr,
reader_pollee,
writer_pollee,
is_read_shutdown,
is_write_shutdown,
} = self;
Connected::new_pair(
let (this_conn, peer_conn) = Connected::new_pair(
addr,
Some(peer_addr),
Some(reader_pollee),
Some(writer_pollee),
)
);
if is_read_shutdown.into_inner() {
this_conn.shutdown(SockShutdownCmd::SHUT_RD);
}
if is_write_shutdown.into_inner() {
this_conn.shutdown(SockShutdownCmd::SHUT_WR)
}
(this_conn, peer_conn)
}
pub(super) fn listen(self, backlog: usize) -> core::result::Result<Listener, (Error, Self)> {
@ -57,8 +81,31 @@ impl Init {
));
};
// There is no `writer_pollee` in `Listener`.
Ok(Listener::new(addr, self.reader_pollee, backlog))
Ok(Listener::new(
addr,
self.reader_pollee,
self.writer_pollee,
backlog,
self.is_read_shutdown.into_inner(),
))
}
pub(super) fn shutdown(&self, cmd: SockShutdownCmd) {
match cmd {
SockShutdownCmd::SHUT_WR | SockShutdownCmd::SHUT_RDWR => {
self.is_write_shutdown.store(true, Ordering::Relaxed);
self.writer_pollee.add_events(IoEvents::ERR);
}
SockShutdownCmd::SHUT_RD => (),
}
match cmd {
SockShutdownCmd::SHUT_RD | SockShutdownCmd::SHUT_RDWR => {
self.is_read_shutdown.store(true, Ordering::Relaxed);
self.reader_pollee.add_events(IoEvents::HUP);
}
SockShutdownCmd::SHUT_WR => (),
}
}
pub(super) fn addr(&self) -> Option<&UnixSocketAddrBound> {
@ -68,10 +115,12 @@ impl Init {
pub(super) fn poll(&self, mask: IoEvents, mut poller: Option<&mut Poller>) -> IoEvents {
// To avoid loss of events, this must be compatible with
// `Connected::poll`/`Listener::poll`.
self.reader_pollee.poll(mask, poller.as_deref_mut());
self.writer_pollee.poll(mask, poller);
let reader_events = self.reader_pollee.poll(mask, poller.as_deref_mut());
let writer_events = self.writer_pollee.poll(mask, poller);
(IoEvents::OUT | IoEvents::HUP) & (mask | IoEvents::ALWAYS_POLL)
// According to the Linux implementation, we always have `IoEvents::HUP` in this state.
// Meanwhile, it is in `IoEvents::ALWAYS_POLL`, so we always return it.
combine_io_events(mask, reader_events, writer_events) | IoEvents::HUP
}
pub(super) fn register_observer(

View File

@ -2,13 +2,17 @@
use core::sync::atomic::{AtomicUsize, Ordering};
use super::{connected::Connected, init::Init, UnixStreamSocket};
use super::{
connected::{combine_io_events, Connected},
init::Init,
UnixStreamSocket,
};
use crate::{
events::{IoEvents, Observer},
fs::file_handle::FileLike,
net::socket::{
unix::addr::{UnixSocketAddrBound, UnixSocketAddrKey},
SocketAddr,
SockShutdownCmd, SocketAddr,
},
prelude::*,
process::signal::{Pollee, Poller},
@ -16,12 +20,28 @@ use crate::{
pub(super) struct Listener {
backlog: Arc<Backlog>,
writer_pollee: Pollee,
}
impl Listener {
pub(super) fn new(addr: UnixSocketAddrBound, pollee: Pollee, backlog: usize) -> Self {
let backlog = BACKLOG_TABLE.add_backlog(addr, pollee, backlog).unwrap();
Self { backlog }
pub(super) fn new(
addr: UnixSocketAddrBound,
reader_pollee: Pollee,
writer_pollee: Pollee,
backlog: usize,
is_shutdown: bool,
) -> Self {
// Note that the I/O events can be correctly inherited from `Init`. There is no need to
// explicitly call `Pollee::reset_io_events`.
let backlog = BACKLOG_TABLE
.add_backlog(addr, reader_pollee, backlog, is_shutdown)
.unwrap();
writer_pollee.del_events(IoEvents::OUT);
Self {
backlog,
writer_pollee,
}
}
pub(super) fn addr(&self) -> &UnixSocketAddrBound {
@ -40,8 +60,27 @@ impl Listener {
self.backlog.set_backlog(backlog);
}
pub(super) fn poll(&self, mask: IoEvents, poller: Option<&mut Poller>) -> IoEvents {
self.backlog.poll(mask, poller)
pub(super) fn shutdown(&self, cmd: SockShutdownCmd) {
match cmd {
SockShutdownCmd::SHUT_WR | SockShutdownCmd::SHUT_RDWR => {
self.writer_pollee.add_events(IoEvents::ERR);
}
SockShutdownCmd::SHUT_RD => (),
}
match cmd {
SockShutdownCmd::SHUT_RD | SockShutdownCmd::SHUT_RDWR => {
self.backlog.shutdown();
}
SockShutdownCmd::SHUT_WR => (),
}
}
pub(super) fn poll(&self, mask: IoEvents, mut poller: Option<&mut Poller>) -> IoEvents {
let reader_events = self.backlog.poll(mask, poller.as_deref_mut());
let writer_events = self.writer_pollee.poll(mask, poller);
combine_io_events(mask, reader_events, writer_events)
}
pub(super) fn register_observer(
@ -49,14 +88,18 @@ impl Listener {
observer: Weak<dyn Observer<IoEvents>>,
mask: IoEvents,
) -> Result<()> {
self.backlog.register_observer(observer, mask)
self.backlog.register_observer(observer.clone(), mask)?;
self.writer_pollee.register_observer(observer, mask);
Ok(())
}
pub(super) fn unregister_observer(
&self,
observer: &Weak<dyn Observer<IoEvents>>,
) -> Option<Weak<dyn Observer<IoEvents>>> {
self.backlog.unregister_observer(observer)
let reader_observer = self.backlog.unregister_observer(observer);
let writer_observer = self.writer_pollee.unregister_observer(observer);
reader_observer.or(writer_observer)
}
}
@ -84,6 +127,7 @@ impl BacklogTable {
addr: UnixSocketAddrBound,
pollee: Pollee,
backlog: usize,
is_shutdown: bool,
) -> Option<Arc<Backlog>> {
let addr_key = addr.to_key();
@ -93,7 +137,7 @@ impl BacklogTable {
return None;
}
let new_backlog = Arc::new(Backlog::new(addr, pollee, backlog));
let new_backlog = Arc::new(Backlog::new(addr, pollee, backlog, is_shutdown));
backlog_sockets.insert(addr_key, new_backlog.clone());
Some(new_backlog)
@ -133,18 +177,22 @@ struct Backlog {
addr: UnixSocketAddrBound,
pollee: Pollee,
backlog: AtomicUsize,
incoming_conns: Mutex<VecDeque<Connected>>,
incoming_conns: Mutex<Option<VecDeque<Connected>>>,
}
impl Backlog {
fn new(addr: UnixSocketAddrBound, pollee: Pollee, backlog: usize) -> Self {
pollee.reset_events();
fn new(addr: UnixSocketAddrBound, pollee: Pollee, backlog: usize, is_shutdown: bool) -> Self {
let incoming_sockets = if is_shutdown {
None
} else {
Some(VecDeque::with_capacity(backlog))
};
Self {
addr,
pollee,
backlog: AtomicUsize::new(backlog),
incoming_conns: Mutex::new(VecDeque::with_capacity(backlog)),
incoming_conns: Mutex::new(incoming_sockets),
}
}
@ -153,7 +201,17 @@ impl Backlog {
}
fn push_incoming(&self, init: Init) -> core::result::Result<Connected, (Error, Init)> {
let mut incoming_conns = self.incoming_conns.lock();
let mut locked_incoming_conns = self.incoming_conns.lock();
let Some(incoming_conns) = &mut *locked_incoming_conns else {
return Err((
Error::with_message(
Errno::ECONNREFUSED,
"the listening socket is shut down for reading",
),
init,
));
};
if incoming_conns.len() >= self.backlog.load(Ordering::Relaxed) {
return Err((
@ -174,11 +232,17 @@ impl Backlog {
}
fn pop_incoming(&self) -> Result<Connected> {
let mut incoming_conns = self.incoming_conns.lock();
let mut locked_incoming_conns = self.incoming_conns.lock();
let Some(incoming_conns) = &mut *locked_incoming_conns else {
return_errno_with_message!(Errno::EINVAL, "the socket is shut down for reading");
};
let conn = incoming_conns.pop_front();
if incoming_conns.is_empty() {
self.pollee.del_events(IoEvents::IN);
}
conn.ok_or_else(|| Error::with_message(Errno::EAGAIN, "no pending connection is available"))
}
@ -186,9 +250,15 @@ impl Backlog {
self.backlog.store(backlog, Ordering::Relaxed);
}
fn shutdown(&self) {
let mut incoming_conns = self.incoming_conns.lock();
*incoming_conns = None;
self.pollee.add_events(IoEvents::HUP);
self.pollee.del_events(IoEvents::IN);
}
fn poll(&self, mask: IoEvents, poller: Option<&mut Poller>) -> IoEvents {
// Lock to avoid any events may change pollee state when we poll
let _lock = self.incoming_conns.lock();
self.pollee.poll(mask, poller)
}

View File

@ -297,9 +297,12 @@ impl Socket for UnixStreamSocket {
fn shutdown(&self, cmd: SockShutdownCmd) -> Result<()> {
match self.state.read().as_ref() {
State::Init(init) => init.shutdown(cmd),
State::Listen(listen) => listen.shutdown(cmd),
State::Connected(connected) => connected.shutdown(cmd),
_ => return_errno_with_message!(Errno::ENOTCONN, "the socked is not connected"),
}
Ok(())
}
fn addr(&self) -> Result<SocketAddr> {

View File

@ -326,6 +326,59 @@ FN_TEST(shutdown_connected)
}
END_TEST()
FN_TEST(poll_unbound)
{
int sk;
struct pollfd pfd = { .events = POLLIN | POLLOUT | POLLRDHUP };
sk = TEST_SUCC(socket(PF_UNIX, SOCK_STREAM, 0));
pfd.fd = sk;
TEST_RES(poll(&pfd, 1, 0), pfd.revents == (POLLOUT | POLLHUP));
TEST_SUCC(shutdown(sk, SHUT_WR));
TEST_RES(poll(&pfd, 1, 0), pfd.revents == (POLLOUT | POLLHUP));
TEST_SUCC(shutdown(sk, SHUT_RD));
TEST_RES(poll(&pfd, 1, 0),
pfd.revents == (POLLIN | POLLOUT | POLLRDHUP | POLLHUP));
TEST_SUCC(
bind(sk, (struct sockaddr *)&UNIX_ADDR("\0"), PATH_OFFSET + 1));
TEST_SUCC(listen(sk, 10));
TEST_RES(poll(&pfd, 1, 0),
pfd.revents == (POLLIN | POLLRDHUP | POLLHUP));
TEST_SUCC(close(sk));
}
END_TEST()
FN_TEST(poll_listen)
{
int sk;
struct pollfd pfd = { .events = POLLIN | POLLOUT | POLLRDHUP };
sk = TEST_SUCC(socket(PF_UNIX, SOCK_STREAM, 0));
pfd.fd = sk;
TEST_SUCC(
bind(sk, (struct sockaddr *)&UNIX_ADDR("\0"), PATH_OFFSET + 1));
TEST_SUCC(listen(sk, 10));
TEST_RES(poll(&pfd, 1, 0), pfd.revents == 0);
TEST_SUCC(shutdown(sk, SHUT_RD));
TEST_RES(poll(&pfd, 1, 0), pfd.revents == (POLLIN | POLLRDHUP));
TEST_SUCC(shutdown(sk, SHUT_WR));
TEST_RES(poll(&pfd, 1, 0),
pfd.revents == (POLLIN | POLLRDHUP | POLLHUP));
TEST_SUCC(close(sk));
}
END_TEST()
FN_TEST(poll_connected_close)
{
int fildes[2];