diff --git a/kernel/aster-nix/src/net/socket/ip/datagram/mod.rs b/kernel/aster-nix/src/net/socket/ip/datagram/mod.rs index 991f5668a..7e6af5a8a 100644 --- a/kernel/aster-nix/src/net/socket/ip/datagram/mod.rs +++ b/kernel/aster-nix/src/net/socket/ip/datagram/mod.rs @@ -210,6 +210,24 @@ impl FileLike for DatagramSocket { } Ok(()) } + + fn register_observer( + &self, + observer: Weak>, + mask: IoEvents, + ) -> Result<()> { + self.pollee.register_observer(observer, mask); + Ok(()) + } + + fn unregister_observer( + &self, + observer: &Weak>, + ) -> Result>> { + self.pollee + .unregister_observer(observer) + .ok_or_else(|| Error::with_message(Errno::ENOENT, "observer is not registered")) + } } impl Socket for DatagramSocket { diff --git a/kernel/aster-nix/src/net/socket/ip/stream/connected.rs b/kernel/aster-nix/src/net/socket/ip/stream/connected.rs index 4ca60ae1c..20e204cc7 100644 --- a/kernel/aster-nix/src/net/socket/ip/stream/connected.rs +++ b/kernel/aster-nix/src/net/socket/ip/stream/connected.rs @@ -78,26 +78,6 @@ impl ConnectedStream { self.update_io_events(pollee); } - pub fn register_observer( - &self, - pollee: &Pollee, - observer: Weak>, - mask: IoEvents, - ) -> Result<()> { - pollee.register_observer(observer, mask); - Ok(()) - } - - pub fn unregister_observer( - &self, - pollee: &Pollee, - observer: &Weak>, - ) -> Result>> { - pollee - .unregister_observer(observer) - .ok_or_else(|| Error::with_message(Errno::EINVAL, "fails to unregister observer")) - } - pub(super) fn update_io_events(&self, pollee: &Pollee) { self.bound_socket.raw_with(|socket: &mut RawTcpSocket| { if socket.can_recv() { diff --git a/kernel/aster-nix/src/net/socket/ip/stream/connecting.rs b/kernel/aster-nix/src/net/socket/ip/stream/connecting.rs index 2af4eb24f..d1ac85383 100644 --- a/kernel/aster-nix/src/net/socket/ip/stream/connecting.rs +++ b/kernel/aster-nix/src/net/socket/ip/stream/connecting.rs @@ -2,7 +2,7 @@ use super::{connected::ConnectedStream, init::InitStream}; use crate::{ - events::{IoEvents, Observer}, + events::IoEvents, net::iface::{AnyBoundSocket, IpEndpoint, RawTcpSocket}, prelude::*, process::signal::Pollee, @@ -71,26 +71,6 @@ impl ConnectingStream { self.update_io_events(pollee); } - pub fn register_observer( - &self, - pollee: &Pollee, - observer: Weak>, - mask: IoEvents, - ) -> Result<()> { - pollee.register_observer(observer, mask); - Ok(()) - } - - pub fn unregister_observer( - &self, - pollee: &Pollee, - observer: &Weak>, - ) -> Result>> { - pollee - .unregister_observer(observer) - .ok_or_else(|| Error::with_message(Errno::EINVAL, "fails to unregister observer")) - } - pub(super) fn update_io_events(&self, pollee: &Pollee) { if self.conn_result.read().is_some() { return; diff --git a/kernel/aster-nix/src/net/socket/ip/stream/init.rs b/kernel/aster-nix/src/net/socket/ip/stream/init.rs index 9c84655da..80ee0c815 100644 --- a/kernel/aster-nix/src/net/socket/ip/stream/init.rs +++ b/kernel/aster-nix/src/net/socket/ip/stream/init.rs @@ -4,13 +4,12 @@ use alloc::sync::Weak; use super::{connecting::ConnectingStream, listen::ListenStream}; use crate::{ - events::{IoEvents, Observer}, + events::Observer, net::{ iface::{AnyBoundSocket, AnyUnboundSocket, IpEndpoint}, socket::ip::common::{bind_socket, get_ephemeral_endpoint}, }, prelude::*, - process::signal::Pollee, }; pub enum InitStream { @@ -92,24 +91,4 @@ impl InitStream { InitStream::Bound(bound_socket) => Some(bound_socket.local_endpoint().unwrap()), } } - - pub fn register_observer( - &self, - pollee: &Pollee, - observer: Weak>, - mask: IoEvents, - ) -> Result<()> { - pollee.register_observer(observer, mask); - Ok(()) - } - - pub fn unregister_observer( - &self, - pollee: &Pollee, - observer: &Weak>, - ) -> Result>> { - pollee - .unregister_observer(observer) - .ok_or_else(|| Error::with_message(Errno::EINVAL, "fails to unregister observer")) - } } diff --git a/kernel/aster-nix/src/net/socket/ip/stream/listen.rs b/kernel/aster-nix/src/net/socket/ip/stream/listen.rs index 3625f1b66..dceed257c 100644 --- a/kernel/aster-nix/src/net/socket/ip/stream/listen.rs +++ b/kernel/aster-nix/src/net/socket/ip/stream/listen.rs @@ -3,7 +3,7 @@ use smoltcp::socket::tcp::ListenError; use super::connected::ConnectedStream; use crate::{ - events::{IoEvents, Observer}, + events::IoEvents, net::iface::{AnyBoundSocket, AnyUnboundSocket, BindPortConfig, IpEndpoint, RawTcpSocket}, prelude::*, process::signal::Pollee, @@ -95,26 +95,6 @@ impl ListenStream { pollee.del_events(IoEvents::IN); } } - - pub fn register_observer( - &self, - pollee: &Pollee, - observer: Weak>, - mask: IoEvents, - ) -> Result<()> { - pollee.register_observer(observer, mask); - Ok(()) - } - - pub fn unregister_observer( - &self, - pollee: &Pollee, - observer: &Weak>, - ) -> Result>> { - pollee - .unregister_observer(observer) - .ok_or_else(|| Error::with_message(Errno::EINVAL, "fails to unregister observer")) - } } struct BacklogSocket { diff --git a/kernel/aster-nix/src/net/socket/ip/stream/mod.rs b/kernel/aster-nix/src/net/socket/ip/stream/mod.rs index 3e425b5e7..7567a27ca 100644 --- a/kernel/aster-nix/src/net/socket/ip/stream/mod.rs +++ b/kernel/aster-nix/src/net/socket/ip/stream/mod.rs @@ -296,32 +296,20 @@ impl FileLike for StreamSocket { fn register_observer( &self, - observer: Weak>, + observer: Weak>, mask: IoEvents, ) -> Result<()> { - let state = self.state.read(); - match state.as_ref() { - State::Init(init) => init.register_observer(&self.pollee, observer, mask), - State::Connecting(connecting) => { - connecting.register_observer(&self.pollee, observer, mask) - } - State::Connected(connected) => { - connected.register_observer(&self.pollee, observer, mask) - } - State::Listen(listen) => listen.register_observer(&self.pollee, observer, mask), - } + self.pollee.register_observer(observer, mask); + Ok(()) } fn unregister_observer( &self, - observer: &Weak>, - ) -> Result>> { - match self.state.read().as_ref() { - State::Init(init) => init.unregister_observer(&self.pollee, observer), - State::Connecting(connecting) => connecting.unregister_observer(&self.pollee, observer), - State::Connected(connected) => connected.unregister_observer(&self.pollee, observer), - State::Listen(listen) => listen.unregister_observer(&self.pollee, observer), - } + observer: &Weak>, + ) -> Result>> { + self.pollee + .unregister_observer(observer) + .ok_or_else(|| Error::with_message(Errno::ENOENT, "observer is not registered")) } } diff --git a/regression/apps/network/tcp_err.c b/regression/apps/network/tcp_err.c index 034d93c97..839772410 100644 --- a/regression/apps/network/tcp_err.c +++ b/regression/apps/network/tcp_err.c @@ -3,6 +3,7 @@ #include #include #include +#include #include #include @@ -70,11 +71,12 @@ FN_SETUP(accpected) { struct sockaddr addr; socklen_t addrlen = sizeof(addr); + struct pollfd pfd = { .fd = sk_listen, .events = POLLIN }; - do { - sk_accepted = CHECK_WITH(accept(sk_listen, &addr, &addrlen), - _ret >= 0 || errno == EAGAIN); - } while (sk_accepted < 0); + CHECK_WITH(poll(&pfd, 1, 1000), + _ret >= 0 && ((pfd.revents & (POLLIN | POLLOUT)) & POLLIN)); + + sk_accepted = CHECK(accept(sk_listen, &addr, &addrlen)); } END_SETUP() @@ -226,3 +228,30 @@ FN_TEST(accept) TEST_ERRNO(accept(sk_accepted, psaddr, &addrlen), EINVAL); } END_TEST() + +FN_TEST(poll) +{ + struct pollfd pfd = { .events = POLLIN | POLLOUT }; + + pfd.fd = sk_unbound; + // FIXME: Uncomment this + // TEST_RES(poll(&pfd, 1, 0), + // (pfd.revents & (POLLIN | POLLOUT)) == POLLOUT); + + pfd.fd = sk_bound; + // FIXME: Uncomment this + // TEST_RES(poll(&pfd, 1, 0), + // (pfd.revents & (POLLIN | POLLOUT)) == POLLOUT); + + pfd.fd = sk_listen; + TEST_RES(poll(&pfd, 1, 0), (pfd.revents & (POLLIN | POLLOUT)) == 0); + + pfd.fd = sk_connected; + TEST_RES(poll(&pfd, 1, 0), + (pfd.revents & (POLLIN | POLLOUT)) == POLLOUT); + + pfd.fd = sk_accepted; + TEST_RES(poll(&pfd, 1, 0), + (pfd.revents & (POLLIN | POLLOUT)) == POLLOUT); +} +END_TEST() diff --git a/regression/apps/network/udp_err.c b/regression/apps/network/udp_err.c index 41abc38ce..240a6349b 100644 --- a/regression/apps/network/udp_err.c +++ b/regression/apps/network/udp_err.c @@ -3,6 +3,7 @@ #include #include #include +#include #include #include @@ -170,3 +171,22 @@ FN_TEST(accept) TEST_ERRNO(accept(sk_connected, psaddr, &addrlen), EOPNOTSUPP); } END_TEST() + +FN_TEST(poll) +{ + struct pollfd pfd = { .events = POLLIN | POLLOUT }; + + pfd.fd = sk_unbound; + // FIXME: Uncomment this + // TEST_RES(poll(&pfd, 1, 0), + // (pfd.revents & (POLLIN | POLLOUT)) == POLLOUT); + + pfd.fd = sk_bound; + TEST_RES(poll(&pfd, 1, 0), + (pfd.revents & (POLLIN | POLLOUT)) == POLLOUT); + + pfd.fd = sk_connected; + TEST_RES(poll(&pfd, 1, 0), + (pfd.revents & (POLLIN | POLLOUT)) == POLLOUT); +} +END_TEST()