diff --git a/kernel/src/net/socket/ip/stream/init.rs b/kernel/src/net/socket/ip/stream/init.rs index cb1018d2..d696b7da 100644 --- a/kernel/src/net/socket/ip/stream/init.rs +++ b/kernel/src/net/socket/ip/stream/init.rs @@ -1,6 +1,9 @@ // SPDX-License-Identifier: MPL-2.0 -use core::sync::atomic::{AtomicBool, Ordering}; +use core::{ + net::Ipv4Addr, + sync::atomic::{AtomicBool, Ordering}, +}; use aster_bigtcp::{socket::RawTcpOption, wire::IpEndpoint}; @@ -9,7 +12,10 @@ use crate::{ events::IoEvents, net::{ iface::BoundPort, - socket::ip::common::{bind_port, get_ephemeral_endpoint}, + socket::{ + ip::common::{bind_port, get_ephemeral_endpoint}, + SocketAddr, + }, }, prelude::*, }; @@ -33,7 +39,8 @@ pub struct InitStream { /// Indicates whether the socket error is `ECONNREFUSED`. /// /// This boolean value is set to true when the connection is refused and set to false when the - /// error code is reported via either `getsockopt(SOL_SOCKET, SO_ERROR)` or `connect()`. + /// error code is reported via `getsockopt(SOL_SOCKET, SO_ERROR)`, `send()`, `recv()`, or + /// `connect()`. is_conn_refused: AtomicBool, } @@ -157,6 +164,33 @@ impl InitStream { } } + pub fn try_recv(&self) -> Result<(usize, SocketAddr)> { + // FIXME: Linux does not return addresses for `recvfrom` on connection-oriented sockets. + // This is a placeholder that has no Linux equivalent. (Note also that in this case + // `getpeeraddr` will simply fail with `ENOTCONN`). + const UNSPECIFIED_SOCKET_ADDR: SocketAddr = SocketAddr::IPv4(Ipv4Addr::UNSPECIFIED, 0); + + // Below are some magic checks to make our behavior identical to Linux. + + if self.is_connect_done { + return_errno_with_message!(Errno::ENOTCONN, "the socket is not connected"); + } + + if let Some(err) = self.test_and_clear_error() { + return Err(err); + } + + Ok((0, UNSPECIFIED_SOCKET_ADDR)) + } + + pub fn try_send(&self) -> Result { + if let Some(err) = self.test_and_clear_error() { + return Err(err); + } + + return_errno_with_message!(Errno::EPIPE, "the socket is not connected"); + } + pub fn local_endpoint(&self) -> Option { self.bound_port .as_ref() @@ -165,7 +199,13 @@ impl InitStream { pub(super) fn check_io_events(&self) -> IoEvents { // Linux adds OUT and HUP events for a newly created socket - IoEvents::OUT | IoEvents::HUP + let mut events = IoEvents::OUT | IoEvents::HUP; + + if self.is_conn_refused.load(Ordering::Relaxed) { + events |= IoEvents::ERR; + } + + events } pub(super) fn test_and_clear_error(&self) -> Option { diff --git a/kernel/src/net/socket/ip/stream/mod.rs b/kernel/src/net/socket/ip/stream/mod.rs index 0879a135..25ebfd27 100644 --- a/kernel/src/net/socket/ip/stream/mod.rs +++ b/kernel/src/net/socket/ip/stream/mod.rs @@ -313,7 +313,12 @@ impl StreamSocket { let connected_stream = match state.as_ref() { State::Connected(connected_stream) => connected_stream, - State::Init(_) | State::Listen(_) => { + State::Init(init_stream) => { + let result = init_stream.try_recv(); + self.pollee.invalidate(); + return result; + } + State::Listen(_) => { return_errno_with_message!(Errno::ENOTCONN, "the socket is not connected") } State::Connecting(_) => { @@ -339,7 +344,12 @@ impl StreamSocket { let connected_stream = match state.as_ref() { State::Connected(connected_stream) => connected_stream, - State::Init(_) | State::Listen(_) => { + State::Init(init_stream) => { + let result = init_stream.try_send(); + self.pollee.invalidate(); + return result; + } + State::Listen(_) => { // TODO: Trigger `SIGPIPE` if `MSG_NOSIGNAL` is not specified return_errno_with_message!(Errno::EPIPE, "the socket is not connected"); } @@ -376,10 +386,12 @@ impl StreamSocket { fn test_and_clear_error(&self) -> Option { let state = self.read_updated_state(); - match state.as_ref() { + let error = match state.as_ref() { State::Init(init_stream) => init_stream.test_and_clear_error(), State::Connecting(_) | State::Listen(_) | State::Connected(_) => None, - } + }; + self.pollee.invalidate(); + error } } diff --git a/kernel/src/net/socket/ip/stream/observer.rs b/kernel/src/net/socket/ip/stream/observer.rs index 862ae229..cb27d918 100644 --- a/kernel/src/net/socket/ip/stream/observer.rs +++ b/kernel/src/net/socket/ip/stream/observer.rs @@ -30,7 +30,8 @@ impl SocketEventObserver for StreamObserver { } if events.contains(SocketEvents::CLOSED) { - io_events |= IoEvents::IN | IoEvents::OUT | IoEvents::RDHUP | IoEvents::HUP; + io_events |= IoEvents::IN | IoEvents::OUT; + io_events |= IoEvents::RDHUP | IoEvents::HUP | IoEvents::ERR; } self.0.notify(io_events); diff --git a/test/apps/network/tcp_err.c b/test/apps/network/tcp_err.c index 4ce9b3e5..f060b0c8 100644 --- a/test/apps/network/tcp_err.c +++ b/test/apps/network/tcp_err.c @@ -320,7 +320,7 @@ FN_TEST(async_connect) { struct pollfd pfd = { .fd = sk_bound, .events = POLLOUT }; int err; - socklen_t errlen = sizeof(err); + socklen_t errlen; sk_addr.sin_port = 0xbeef; @@ -328,10 +328,16 @@ FN_TEST(async_connect) TEST_ERRNO(connect(sk_bound, (struct sockaddr *)&sk_addr, \ sizeof(sk_addr)), \ EINPROGRESS); \ - TEST_RES(poll(&pfd, 1, 60), pfd.revents &POLLOUT); + TEST_RES(poll(&pfd, 1, 60), \ + pfd.revents == (POLLOUT | POLLHUP | POLLERR)); ASYNC_CONNECT; + // `getpeername` will fail with `ENOTCONN` even before the second `connect`. + errlen = sizeof(sk_addr); + TEST_ERRNO(getpeername(sk_bound, (struct sockaddr *)&sk_addr, &errlen), + ENOTCONN); + // The second `connect` will fail with `ECONNREFUSED`. TEST_ERRNO(connect(sk_bound, (struct sockaddr *)&sk_addr, sizeof(sk_addr)), @@ -340,10 +346,12 @@ FN_TEST(async_connect) ASYNC_CONNECT; // Reading the socket error will cause it to be cleared + errlen = sizeof(err); TEST_RES(getsockopt(sk_bound, SOL_SOCKET, SO_ERROR, &err, &errlen), errlen == sizeof(err) && err == ECONNREFUSED); TEST_RES(getsockopt(sk_bound, SOL_SOCKET, SO_ERROR, &err, &errlen), errlen == sizeof(err) && err == 0); + TEST_RES(poll(&pfd, 1, 0), pfd.revents == (POLLOUT | POLLHUP)); // `listen` won't succeed until the second `connect`. TEST_ERRNO(listen(sk_bound, 10), EINVAL); @@ -354,6 +362,26 @@ FN_TEST(async_connect) sizeof(sk_addr)), ECONNABORTED); + ASYNC_CONNECT; + + // Testing `send` behavior before and after the second `connect`. + TEST_ERRNO(send(sk_bound, &err, 0, 0), ECONNREFUSED); + TEST_ERRNO(send(sk_bound, &err, 0, 0), EPIPE); + TEST_ERRNO(connect(sk_bound, (struct sockaddr *)&sk_addr, + sizeof(sk_addr)), + ECONNABORTED); + TEST_ERRNO(send(sk_bound, &err, 0, 0), EPIPE); + + ASYNC_CONNECT; + + // Testing `recv` behavior before and after the second `connect`. + TEST_ERRNO(recv(sk_bound, &err, 0, 0), ECONNREFUSED); + TEST_RES(recv(sk_bound, &err, 0, 0), _ret == 0); + TEST_ERRNO(connect(sk_bound, (struct sockaddr *)&sk_addr, + sizeof(sk_addr)), + ECONNABORTED); + TEST_ERRNO(recv(sk_bound, &err, 0, 0), ENOTCONN); + #undef ASYNC_CONNECT } END_TEST()