diff --git a/kernel/src/fs/utils/channel.rs b/kernel/src/fs/utils/channel.rs index bd4cddc68..f2f537fec 100644 --- a/kernel/src/fs/utils/channel.rs +++ b/kernel/src/fs/utils/channel.rs @@ -26,7 +26,20 @@ impl Channel { /// /// This method will panic if the given capacity is zero. pub fn with_capacity(capacity: usize) -> Self { - let common = Arc::new(Common::new(capacity)); + Self::with_capacity_and_pollees(capacity, None, None) + } + + /// Creates a new channel with the given capacity and pollees. + /// + /// # Panics + /// + /// This method will panic if the given capacity is zero. + pub fn with_capacity_and_pollees( + capacity: usize, + producer_pollee: Option, + consumer_pollee: Option, + ) -> Self { + let common = Arc::new(Common::new(capacity, producer_pollee, consumer_pollee)); let producer = Producer(Fifo::new(common.clone())); let consumer = Consumer(Fifo::new(common)); @@ -332,12 +345,36 @@ struct Common { } impl Common { - fn new(capacity: usize) -> Self { + fn new( + capacity: usize, + producer_pollee: Option, + consumer_pollee: Option, + ) -> Self { let rb: RingBuffer = RingBuffer::new(capacity); let (rb_producer, rb_consumer) = rb.split(); - let producer = FifoInner::new(rb_producer, IoEvents::OUT); - let consumer = FifoInner::new(rb_consumer, IoEvents::empty()); + let producer = { + let polee = if let Some(pollee) = producer_pollee { + pollee.reset_events(); + pollee.add_events(IoEvents::OUT); + pollee + } else { + Pollee::new(IoEvents::OUT) + }; + + FifoInner::new(rb_producer, polee) + }; + + let consumer = { + let pollee = if let Some(pollee) = consumer_pollee { + pollee.reset_events(); + pollee + } else { + Pollee::new(IoEvents::empty()) + }; + + FifoInner::new(rb_consumer, pollee) + }; Self { producer, @@ -382,10 +419,10 @@ struct FifoInner { } impl FifoInner { - pub fn new(rb: T, init_events: IoEvents) -> Self { + pub fn new(rb: T, pollee: Pollee) -> Self { Self { rb: Mutex::new(rb), - pollee: Pollee::new(init_events), + pollee, } } diff --git a/kernel/src/net/socket/unix/stream/connected.rs b/kernel/src/net/socket/unix/stream/connected.rs index f016610b0..75fcdf95e 100644 --- a/kernel/src/net/socket/unix/stream/connected.rs +++ b/kernel/src/net/socket/unix/stream/connected.rs @@ -5,7 +5,7 @@ use crate::{ fs::utils::{Channel, Consumer, Producer}, net::socket::{unix::addr::UnixSocketAddrBound, SockShutdownCmd}, prelude::*, - process::signal::Poller, + process::signal::{Pollee, Poller}, }; pub(super) struct Connected { @@ -19,9 +19,13 @@ impl Connected { pub(super) fn new_pair( addr: Option, peer_addr: Option, + reader_pollee: Option, + writer_pollee: Option, ) -> (Connected, Connected) { - let (writer_this, reader_peer) = Channel::with_capacity(DEFAULT_BUF_SIZE).split(); - let (writer_peer, reader_this) = Channel::with_capacity(DEFAULT_BUF_SIZE).split(); + let (writer_peer, reader_this) = + Channel::with_capacity_and_pollees(DEFAULT_BUF_SIZE, None, reader_pollee).split(); + let (writer_this, reader_peer) = + Channel::with_capacity_and_pollees(DEFAULT_BUF_SIZE, writer_pollee, None).split(); let this = Connected { addr: addr.clone(), diff --git a/kernel/src/net/socket/unix/stream/init.rs b/kernel/src/net/socket/unix/stream/init.rs index 74a159cfd..8daab079f 100644 --- a/kernel/src/net/socket/unix/stream/init.rs +++ b/kernel/src/net/socket/unix/stream/init.rs @@ -1,5 +1,6 @@ // SPDX-License-Identifier: MPL-2.0 +use super::{connected::Connected, listener::Listener}; use crate::{ events::{IoEvents, Observer}, net::socket::unix::addr::{UnixSocketAddr, UnixSocketAddrBound}, @@ -9,14 +10,16 @@ use crate::{ pub(super) struct Init { addr: Option, - pollee: Pollee, + reader_pollee: Pollee, + writer_pollee: Pollee, } impl Init { pub(super) fn new() -> Self { Self { addr: None, - pollee: Pollee::new(IoEvents::empty()), + reader_pollee: Pollee::new(IoEvents::empty()), + writer_pollee: Pollee::new(IoEvents::empty()), } } @@ -31,12 +34,44 @@ impl Init { Ok(()) } + pub(super) fn into_connected(self, peer_addr: UnixSocketAddrBound) -> (Connected, Connected) { + let Init { + addr, + reader_pollee, + writer_pollee, + } = self; + + Connected::new_pair( + addr, + Some(peer_addr), + Some(reader_pollee), + Some(writer_pollee), + ) + } + + pub(super) fn listen(self, backlog: usize) -> core::result::Result { + let Some(addr) = self.addr else { + return Err(( + Error::with_message(Errno::EINVAL, "the socket is not bound"), + self, + )); + }; + + // There is no `writer_pollee` in `Listener`. + Ok(Listener::new(addr, self.reader_pollee, backlog)) + } + pub(super) fn addr(&self) -> Option<&UnixSocketAddrBound> { self.addr.as_ref() } - pub(super) fn poll(&self, mask: IoEvents, poller: Option<&mut Poller>) -> IoEvents { - self.pollee.poll(mask, poller) + pub(super) fn poll(&self, mask: IoEvents, mut poller: Option<&mut Poller>) -> IoEvents { + // To avoid loss of events, this must be compatible with + // `Connected::poll`/`Listener::poll`. + self.reader_pollee.poll(mask, poller.as_deref_mut()); + self.writer_pollee.poll(mask, poller); + + (IoEvents::OUT | IoEvents::HUP) & (mask | IoEvents::ALWAYS_POLL) } pub(super) fn register_observer( @@ -44,7 +79,10 @@ impl Init { observer: Weak>, mask: IoEvents, ) -> Result<()> { - self.pollee.register_observer(observer, mask); + // To avoid loss of events, this must be compatible with + // `Connected::poll`/`Listener::poll`. + self.reader_pollee.register_observer(observer.clone(), mask); + self.writer_pollee.register_observer(observer, mask); Ok(()) } @@ -52,6 +90,8 @@ impl Init { &self, observer: &Weak>, ) -> Option>> { - self.pollee.unregister_observer(observer) + let reader_observer = self.reader_pollee.unregister_observer(observer); + let writer_observer = self.writer_pollee.unregister_observer(observer); + reader_observer.or(writer_observer) } } diff --git a/kernel/src/net/socket/unix/stream/listener.rs b/kernel/src/net/socket/unix/stream/listener.rs index d3c6f6d5a..c691e6024 100644 --- a/kernel/src/net/socket/unix/stream/listener.rs +++ b/kernel/src/net/socket/unix/stream/listener.rs @@ -2,7 +2,7 @@ use core::sync::atomic::{AtomicUsize, Ordering}; -use super::{connected::Connected, UnixStreamSocket}; +use super::{connected::Connected, init::Init, UnixStreamSocket}; use crate::{ events::{IoEvents, Observer}, fs::file_handle::FileLike, @@ -19,8 +19,8 @@ pub(super) struct Listener { } impl Listener { - pub(super) fn new(addr: UnixSocketAddrBound, backlog: usize) -> Self { - let backlog = BACKLOG_TABLE.add_backlog(addr, backlog).unwrap(); + pub(super) fn new(addr: UnixSocketAddrBound, pollee: Pollee, backlog: usize) -> Self { + let backlog = BACKLOG_TABLE.add_backlog(addr, pollee, backlog).unwrap(); Self { backlog } } @@ -36,9 +36,8 @@ impl Listener { Ok((socket, peer_addr)) } - pub(super) fn listen(&self, backlog: usize) -> Result<()> { + pub(super) fn listen(&self, backlog: usize) { self.backlog.set_backlog(backlog); - Ok(()) } pub(super) fn poll(&self, mask: IoEvents, poller: Option<&mut Poller>) -> IoEvents { @@ -80,7 +79,12 @@ impl BacklogTable { } } - fn add_backlog(&self, addr: UnixSocketAddrBound, backlog: usize) -> Option> { + fn add_backlog( + &self, + addr: UnixSocketAddrBound, + pollee: Pollee, + backlog: usize, + ) -> Option> { let addr_key = addr.to_key(); let mut backlog_sockets = self.backlog_sockets.write(); @@ -89,7 +93,7 @@ impl BacklogTable { return None; } - let new_backlog = Arc::new(Backlog::new(addr, backlog)); + let new_backlog = Arc::new(Backlog::new(addr, pollee, backlog)); backlog_sockets.insert(addr_key, new_backlog.clone()); Some(new_backlog) @@ -102,16 +106,22 @@ impl BacklogTable { fn push_incoming( &self, server_key: &UnixSocketAddrKey, - client_addr: Option, - ) -> Result { - let backlog = self.get_backlog(server_key).ok_or_else(|| { - Error::with_message( - Errno::ECONNREFUSED, - "no socket is listening at the remote address", - ) - })?; + init: Init, + ) -> core::result::Result { + let backlog = match self.get_backlog(server_key) { + Some(backlog) => backlog, + None => { + return Err(( + Error::with_message( + Errno::ECONNREFUSED, + "no socket is listening at the remote address", + ), + init, + )) + } + }; - backlog.push_incoming(client_addr) + backlog.push_incoming(init) } fn remove_backlog(&self, addr_key: &UnixSocketAddrKey) { @@ -127,10 +137,12 @@ struct Backlog { } impl Backlog { - fn new(addr: UnixSocketAddrBound, backlog: usize) -> Self { + fn new(addr: UnixSocketAddrBound, pollee: Pollee, backlog: usize) -> Self { + pollee.reset_events(); + Self { addr, - pollee: Pollee::new(IoEvents::empty()), + pollee, backlog: AtomicUsize::new(backlog), incoming_conns: Mutex::new(VecDeque::with_capacity(backlog)), } @@ -140,17 +152,20 @@ impl Backlog { &self.addr } - fn push_incoming(&self, client_addr: Option) -> Result { + fn push_incoming(&self, init: Init) -> core::result::Result { let mut incoming_conns = self.incoming_conns.lock(); if incoming_conns.len() >= self.backlog.load(Ordering::Relaxed) { - return_errno_with_message!( - Errno::EAGAIN, - "the pending connection queue on the listening socket is full" - ); + return Err(( + Error::with_message( + Errno::EAGAIN, + "the pending connection queue on the listening socket is full", + ), + init, + )); } - let (server_conn, client_conn) = Connected::new_pair(Some(self.addr.clone()), client_addr); + let (client_conn, server_conn) = init.into_connected(self.addr.clone()); incoming_conns.push_back(server_conn); self.pollee.add_events(IoEvents::IN); @@ -200,7 +215,7 @@ fn unregister_backlog(addr: &UnixSocketAddrKey) { pub(super) fn push_incoming( server_key: &UnixSocketAddrKey, - client_addr: Option, -) -> Result { - BACKLOG_TABLE.push_incoming(server_key, client_addr) + init: Init, +) -> core::result::Result { + BACKLOG_TABLE.push_incoming(server_key, init) } diff --git a/kernel/src/net/socket/unix/stream/socket.rs b/kernel/src/net/socket/unix/stream/socket.rs index e5eec582a..f3f981905 100644 --- a/kernel/src/net/socket/unix/stream/socket.rs +++ b/kernel/src/net/socket/unix/stream/socket.rs @@ -3,6 +3,7 @@ use core::sync::atomic::AtomicBool; use atomic::Ordering; +use takeable::Takeable; use super::{ connected::Connected, @@ -13,7 +14,7 @@ use crate::{ events::{IoEvents, Observer}, fs::{file_handle::FileLike, utils::StatusFlags}, net::socket::{ - unix::UnixSocketAddr, + unix::{addr::UnixSocketAddrKey, UnixSocketAddr}, util::{ copy_message_from_user, copy_message_to_user, create_message_buffer, send_recv_flags::SendRecvFlags, socket_addr::SocketAddr, MessageHeader, @@ -27,21 +28,21 @@ use crate::{ }; pub struct UnixStreamSocket { - state: RwLock, + state: RwLock>, is_nonblocking: AtomicBool, } impl UnixStreamSocket { pub(super) fn new_init(init: Init, is_nonblocking: bool) -> Arc { Arc::new(Self { - state: RwLock::new(State::Init(init)), + state: RwLock::new(Takeable::new(State::Init(init))), is_nonblocking: AtomicBool::new(is_nonblocking), }) } pub(super) fn new_connected(connected: Connected, is_nonblocking: bool) -> Arc { Arc::new(Self { - state: RwLock::new(State::Connected(connected)), + state: RwLock::new(Takeable::new(State::Connected(connected))), is_nonblocking: AtomicBool::new(is_nonblocking), }) } @@ -59,7 +60,7 @@ impl UnixStreamSocket { } pub fn new_pair(is_nonblocking: bool) -> (Arc, Arc) { - let (conn_a, conn_b) = Connected::new_pair(None, None); + let (conn_a, conn_b) = Connected::new_pair(None, None, None, None); ( Self::new_connected(conn_a, is_nonblocking), Self::new_connected(conn_b, is_nonblocking), @@ -75,7 +76,7 @@ impl UnixStreamSocket { } fn try_send(&self, buf: &[u8], _flags: SendRecvFlags) -> Result { - match &*self.state.read() { + match self.state.read().as_ref() { State::Connected(connected) => connected.try_write(buf), _ => return_errno_with_message!(Errno::ENOTCONN, "the socket is not connected"), } @@ -90,14 +91,49 @@ impl UnixStreamSocket { } fn try_recv(&self, buf: &mut [u8], _flags: SendRecvFlags) -> Result { - match &*self.state.read() { + match self.state.read().as_ref() { State::Connected(connected) => connected.try_read(buf), _ => return_errno_with_message!(Errno::ENOTCONN, "the socket is not connected"), } } + fn try_connect(&self, remote_addr: &UnixSocketAddrKey) -> Result<()> { + let mut state = self.state.write(); + + state.borrow_result(|owned_state| { + let init = match owned_state { + State::Init(init) => init, + State::Listen(listener) => { + return ( + State::Listen(listener), + Err(Error::with_message( + Errno::EINVAL, + "the socket is listening", + )), + ); + } + State::Connected(connected) => { + return ( + State::Connected(connected), + Err(Error::with_message( + Errno::EISCONN, + "the socket is connected", + )), + ); + } + }; + + let connected = match push_incoming(remote_addr, init) { + Ok(connected) => connected, + Err((err, init)) => return (State::Init(init), Err(err)), + }; + + (State::Connected(connected), Ok(())) + }) + } + fn try_accept(&self) -> Result<(Arc, SocketAddr)> { - match &*self.state.read() { + match self.state.read().as_ref() { State::Listen(listen) => listen.try_accept() as _, _ => return_errno_with_message!(Errno::EINVAL, "the socket is not listening"), } @@ -115,7 +151,7 @@ impl UnixStreamSocket { impl Pollable for UnixStreamSocket { fn poll(&self, mask: IoEvents, poller: Option<&mut Poller>) -> IoEvents { let inner = self.state.read(); - match &*inner { + match inner.as_ref() { State::Init(init) => init.poll(mask, poller), State::Listen(listen) => listen.poll(mask, poller), State::Connected(connected) => connected.poll(mask, poller), @@ -162,8 +198,7 @@ impl FileLike for UnixStreamSocket { observer: Weak>, mask: IoEvents, ) -> Result<()> { - let inner = self.state.write(); - match &*inner { + match self.state.read().as_ref() { State::Init(init) => init.register_observer(observer, mask), State::Listen(listen) => listen.register_observer(observer, mask), State::Connected(connected) => connected.register_observer(observer, mask), @@ -174,8 +209,7 @@ impl FileLike for UnixStreamSocket { &self, observer: &Weak>, ) -> Option>> { - let inner = self.state.write(); - match &*inner { + match self.state.read().as_ref() { State::Init(init) => init.unregister_observer(observer), State::Listen(listen) => listen.unregister_observer(observer), State::Connected(connected) => connected.unregister_observer(observer), @@ -187,7 +221,7 @@ impl Socket for UnixStreamSocket { fn bind(&self, socket_addr: SocketAddr) -> Result<()> { let addr = UnixSocketAddr::try_from(socket_addr)?; - match &mut *self.state.write() { + match self.state.write().as_mut() { State::Init(init) => init.bind(addr), _ => return_errno_with_message!( Errno::EINVAL, @@ -209,58 +243,48 @@ impl Socket for UnixStreamSocket { // // See also . - let client_addr = match &*self.state.read() { - State::Init(init) => init.addr().cloned(), - State::Listen(_) => { - return_errno_with_message!(Errno::EINVAL, "the socket is listening") - } - State::Connected(_) => { - return_errno_with_message!(Errno::EISCONN, "the socket is connected") - } - }; - - // We use the `push_incoming` directly to avoid holding the read lock of `self.state` - // because it might call `Thread::yield_now` to wait for connection. loop { - let res = push_incoming(&remote_addr, client_addr.clone()); - match res { - Ok(connected) => { - *self.state.write() = State::Connected(connected); - return Ok(()); - } - Err(err) if err.error() == Errno::EAGAIN => { - // FIXME: Calling `Thread::yield_now` can cause the thread to run when the backlog is full, - // which wastes a lot of CPU time. Using `WaitQueue` maybe a better solution. - Thread::yield_now() - } - Err(err) => return Err(err), + let res = self.try_connect(&remote_addr); + + if !res.is_err_and(|err| err.error() == Errno::EAGAIN) { + return res; } + + // FIXME: Add `Pauser` in `Backlog` and use it to avoid this `Thread::yield_now`. + Thread::yield_now(); } } fn listen(&self, backlog: usize) -> Result<()> { let mut state = self.state.write(); - let addr = match &*state { - State::Init(init) => init - .addr() - .ok_or(Error::with_message( - Errno::EINVAL, - "the socket is not bound", - ))? - .clone(), - State::Listen(listen) => { - return listen.listen(backlog); - } - State::Connected(_) => { - return_errno_with_message!(Errno::EINVAL, "the socket is connected") - } - }; + state.borrow_result(|owned_state| { + let init = match owned_state { + State::Init(init) => init, + State::Listen(listener) => { + listener.listen(backlog); + return (State::Listen(listener), Ok(())); + } + State::Connected(connected) => { + return ( + State::Connected(connected), + Err(Error::with_message( + Errno::EINVAL, + "the socket is connected", + )), + ); + } + }; - let listener = Listener::new(addr, backlog); - *state = State::Listen(listener); + let listener = match init.listen(backlog) { + Ok(listener) => listener, + Err((err, init)) => { + return (State::Init(init), Err(err)); + } + }; - Ok(()) + (State::Listen(listener), Ok(())) + }) } fn accept(&self) -> Result<(Arc, SocketAddr)> { @@ -272,14 +296,14 @@ impl Socket for UnixStreamSocket { } fn shutdown(&self, cmd: SockShutdownCmd) -> Result<()> { - match &*self.state.read() { + match self.state.read().as_ref() { State::Connected(connected) => connected.shutdown(cmd), _ => return_errno_with_message!(Errno::ENOTCONN, "the socked is not connected"), } } fn addr(&self) -> Result { - let addr = match &*self.state.read() { + let addr = match self.state.read().as_ref() { State::Init(init) => init.addr().cloned(), State::Listen(listen) => Some(listen.addr().clone()), State::Connected(connected) => connected.addr().cloned(), @@ -289,7 +313,7 @@ impl Socket for UnixStreamSocket { } fn peer_addr(&self) -> Result { - let peer_addr = match &*self.state.read() { + let peer_addr = match self.state.read().as_ref() { State::Connected(connected) => connected.peer_addr().cloned(), _ => return_errno_with_message!(Errno::ENOTCONN, "the socket is not connected"), }; diff --git a/kernel/src/process/signal/poll.rs b/kernel/src/process/signal/poll.rs index 75aaa362d..9a3f5b46c 100644 --- a/kernel/src/process/signal/poll.rs +++ b/kernel/src/process/signal/poll.rs @@ -13,7 +13,6 @@ use crate::{ /// A pollee maintains a set of active events, which can be polled with /// pollers or be monitored with observers. -#[derive(Clone)] pub struct Pollee { inner: Arc, } diff --git a/test/apps/network/unix_err.c b/test/apps/network/unix_err.c index 49d5f468e..5cadc17e8 100644 --- a/test/apps/network/unix_err.c +++ b/test/apps/network/unix_err.c @@ -5,6 +5,7 @@ #include #include #include +#include #include #include #include @@ -375,6 +376,61 @@ FN_TEST(poll_connected_shutdown) } END_TEST() +FN_TEST(epoll) +{ + int sk2_listen, sk2_connected, sk2_accepted; + int epfd_listen, epfd_connected; + struct epoll_event ev; + + // Setup + + sk2_listen = TEST_SUCC(socket(PF_UNIX, SOCK_STREAM, 0)); + sk2_connected = TEST_SUCC(socket(PF_UNIX, SOCK_STREAM, 0)); + + epfd_listen = TEST_SUCC(epoll_create1(0)); + ev.events = EPOLLIN; + ev.data.fd = sk2_listen; + TEST_SUCC(epoll_ctl(epfd_listen, EPOLL_CTL_ADD, sk2_listen, &ev)); + + epfd_connected = TEST_SUCC(epoll_create1(0)); + ev.events = EPOLLIN; + ev.data.fd = sk2_connected; + TEST_SUCC(epoll_ctl(epfd_connected, EPOLL_CTL_ADD, sk2_connected, &ev)); + + // Test 1: Switch from the unbound state to the listening state + + TEST_SUCC(bind(sk2_listen, (struct sockaddr *)&UNIX_ADDR("\0"), + PATH_OFFSET + 1)); + TEST_SUCC(listen(sk2_listen, 10)); + TEST_RES(epoll_wait(epfd_listen, &ev, 1, 0), _ret == 0); + + TEST_SUCC(connect(sk2_connected, (struct sockaddr *)&UNIX_ADDR("\0"), + PATH_OFFSET + 1)); + ev.data.fd = -1; + TEST_RES(epoll_wait(epfd_listen, &ev, 1, 0), + _ret == 1 && ev.data.fd == sk2_listen); + + // Test 2: Switch from the unbound state to the connected state + + TEST_RES(epoll_wait(epfd_connected, &ev, 1, 0), _ret == 0); + + sk2_accepted = TEST_SUCC(accept(sk2_listen, NULL, 0)); + TEST_SUCC(write(sk2_accepted, "", 1)); + + ev.data.fd = -1; + TEST_RES(epoll_wait(epfd_connected, &ev, 1, 0), + _ret == 1 && ev.data.fd == sk2_connected); + + // Clean up + + TEST_SUCC(close(epfd_listen)); + TEST_SUCC(close(epfd_connected)); + TEST_SUCC(close(sk2_connected)); + TEST_SUCC(close(sk2_accepted)); + TEST_SUCC(close(sk2_listen)); +} +END_TEST() + FN_SETUP(cleanup) { CHECK(close(sk_unbound));