From d40d452e9d525f40f064379241abc6e69f74ce02 Mon Sep 17 00:00:00 2001 From: Ruihan Li Date: Mon, 24 Feb 2025 23:07:58 +0800 Subject: [PATCH] Receive RST packets as `ECONNRESET` errors --- kernel/libs/aster-bigtcp/src/errors.rs | 40 ++++++- .../aster-bigtcp/src/socket/bound/tcp_conn.rs | 63 +++++++++-- kernel/src/net/socket/ip/stream/connected.rs | 34 ++++-- kernel/src/net/socket/ip/stream/mod.rs | 17 ++- test/apps/network/tcp_err.c | 106 ++++++++++++++++++ test/apps/network/tcp_poll.c | 27 +++-- 6 files changed, 256 insertions(+), 31 deletions(-) diff --git a/kernel/libs/aster-bigtcp/src/errors.rs b/kernel/libs/aster-bigtcp/src/errors.rs index a2089033..c3aeb74d 100644 --- a/kernel/libs/aster-bigtcp/src/errors.rs +++ b/kernel/libs/aster-bigtcp/src/errors.rs @@ -10,8 +10,6 @@ pub enum BindError { } pub mod tcp { - pub use smoltcp::socket::tcp::{RecvError, SendError}; - /// An error returned by [`TcpListener::new_listen`]. /// /// [`TcpListener::new_listen`]: crate::socket::TcpListener::new_listen @@ -51,6 +49,44 @@ pub mod tcp { } } } + + /// An error returned by [`TcpConnection::send`]. + /// + /// [`TcpConnection::send`]: crate::socket::TcpConnection::send + #[derive(Debug, PartialEq, Eq, Clone, Copy)] + pub enum SendError { + InvalidState, + /// The connection is reset. + ConnReset, + } + + impl From for SendError { + fn from(value: smoltcp::socket::tcp::SendError) -> Self { + match value { + smoltcp::socket::tcp::SendError::InvalidState => Self::InvalidState, + } + } + } + + /// An error returned by [`TcpConnection::recv`]. + /// + /// [`TcpConnection::recv`]: crate::socket::TcpConnection::recv + #[derive(Debug, PartialEq, Eq, Clone, Copy)] + pub enum RecvError { + InvalidState, + Finished, + /// The connection is reset. + ConnReset, + } + + impl From for RecvError { + fn from(value: smoltcp::socket::tcp::RecvError) -> Self { + match value { + smoltcp::socket::tcp::RecvError::InvalidState => Self::InvalidState, + smoltcp::socket::tcp::RecvError::Finished => Self::Finished, + } + } + } } pub mod udp { diff --git a/kernel/libs/aster-bigtcp/src/socket/bound/tcp_conn.rs b/kernel/libs/aster-bigtcp/src/socket/bound/tcp_conn.rs index 33db931a..aedf2700 100644 --- a/kernel/libs/aster-bigtcp/src/socket/bound/tcp_conn.rs +++ b/kernel/libs/aster-bigtcp/src/socket/bound/tcp_conn.rs @@ -17,7 +17,7 @@ use super::{ }; use crate::{ define_boolean_value, - errors::tcp::ConnectError, + errors::tcp::{ConnectError, RecvError, SendError}, ext::Ext, iface::BoundPort, socket::{ @@ -42,6 +42,8 @@ pub struct RawTcpSocketExt { has_connected: bool, /// Indicates if the receiving side of this socket is shut down by the user. is_recv_shut: bool, + /// Indicates if the socket is closed by a RST packet. + is_rst_closed: bool, } impl Deref for RawTcpSocketExt { @@ -92,6 +94,14 @@ impl RawTcpSocketExt { pub fn is_recv_shut(&self) -> bool { self.is_recv_shut } + + /// Checks if the socket is closed by a RST packet. + /// + /// Note that the flag is automatically cleared when it is read by + /// [`TcpConnection::clear_rst_closed`], [`TcpConnection::send`], or [`TcpConnection::recv`]. + pub fn is_rst_closed(&self) -> bool { + self.is_rst_closed + } } define_boolean_value!( @@ -106,6 +116,7 @@ impl RawTcpSocketExt { this: &Arc>, old_state: State, old_recv_queue: usize, + is_rst: bool, ) -> (SocketEvents, TcpConnBecameDead) { let became_dead = if self.state() != State::Established { // After the connection is closed by the user, no new data can be read, and such unread @@ -117,6 +128,10 @@ impl RawTcpSocketExt { && matches!(old_state, State::FinWait1 | State::FinWait2) && self.recv_queue() > old_recv_queue { + // Strictly speaking, the socket isn't closed by an incoming RST packet in this + // situation. Instead, we reset the connection and _send_ an outgoing RST packet. + // However, Linux reports `ECONNRESET`, so we have to follow Linux. + self.is_rst_closed = true; self.abort(); } self.check_dead(this) @@ -125,6 +140,9 @@ impl RawTcpSocketExt { }; let events = if self.state() != old_state { + if self.state() == State::Closed && is_rst { + self.is_rst_closed = true; + } self.on_new_state(this) } else { SocketEvents::empty() @@ -211,6 +229,7 @@ impl TcpConnectionInner { listener, has_connected: false, is_recv_shut: false, + is_rst_closed: false, }; TcpConnectionInner { @@ -344,7 +363,7 @@ impl TcpConnection { /// Sends some data. /// /// Polling the iface _may_ be required after this method succeeds. - pub fn send(&self, f: F) -> Result<(R, NeedIfacePoll), smoltcp::socket::tcp::SendError> + pub fn send(&self, f: F) -> Result<(R, NeedIfacePoll), SendError> where F: FnOnce(&mut [u8]) -> (usize, R), { @@ -353,7 +372,12 @@ impl TcpConnection { let mut socket = self.0.inner.lock(); + if socket.is_rst_closed { + socket.is_rst_closed = false; + return Err(SendError::ConnReset); + } let result = socket.send(f)?; + let need_poll = self .0 .update_next_poll_at_ms(socket.poll_at(iface.context())); @@ -364,7 +388,7 @@ impl TcpConnection { /// Receives some data. /// /// Polling the iface _may_ be required after this method succeeds. - pub fn recv(&self, f: F) -> Result<(R, NeedIfacePoll), smoltcp::socket::tcp::RecvError> + pub fn recv(&self, f: F) -> Result<(R, NeedIfacePoll), RecvError> where F: FnOnce(&mut [u8]) -> (usize, R), { @@ -374,10 +398,16 @@ impl TcpConnection { let mut socket = self.0.inner.lock(); if socket.is_recv_shut && socket.recv_queue() == 0 { - return Err(smoltcp::socket::tcp::RecvError::Finished); + return Err(RecvError::Finished); } + let result = match socket.recv(f) { + Err(_) if socket.is_rst_closed => { + socket.is_rst_closed = false; + return Err(RecvError::ConnReset); + } + res => res, + }?; - let result = socket.recv(f)?; let need_poll = self .0 .update_next_poll_at_ms(socket.poll_at(iface.context())); @@ -385,6 +415,18 @@ impl TcpConnection { Ok((result, need_poll)) } + /// Checks if the socket is closed by a RST packet and clears the flag. + /// + /// This flag is set when the socket is closed by a RST packet, and cleared when the connection + /// reset error is reported via one of the [`Self::send`], [`Self::recv`], or this method. + pub fn clear_rst_closed(&self) -> bool { + let mut socket = self.0.inner.lock(); + + let is_rst = socket.is_rst_closed; + socket.is_rst_closed = false; + is_rst + } + /// Shuts down the sending half of the connection. /// /// This method will return `false` if the socket is in the CLOSED or TIME_WAIT state. @@ -534,6 +576,7 @@ impl TcpConnectionBg { let old_state = socket.state(); let old_recv_queue = socket.recv_queue(); + let is_rst = tcp_repr.control == TcpControl::Rst; // For TCP, receiving an ACK packet can free up space in the queue, allowing more packets // to be queued. let mut events = SocketEvents::CAN_RECV | SocketEvents::CAN_SEND; @@ -543,7 +586,8 @@ impl TcpConnectionBg { Some((ip_repr, tcp_repr)) => TcpProcessResult::ProcessedWithReply(ip_repr, tcp_repr), }; - let (state_events, became_dead) = socket.check_state(self, old_state, old_recv_queue); + let (state_events, became_dead) = + socket.check_state(self, old_state, old_recv_queue, is_rst); events |= state_events; self.add_events(events); @@ -565,6 +609,7 @@ impl TcpConnectionBg { let old_state = socket.state(); let old_recv_queue = socket.recv_queue(); + let mut is_rst = false; let mut events = SocketEvents::empty(); let mut reply = None; @@ -581,11 +626,13 @@ impl TcpConnectionBg { if !socket.accepts(cx, ip_repr, tcp_repr) { break; } - reply = socket.process(cx, ip_repr, tcp_repr); + is_rst |= tcp_repr.control == TcpControl::Rst; events |= SocketEvents::CAN_RECV | SocketEvents::CAN_SEND; + reply = socket.process(cx, ip_repr, tcp_repr); } - let (state_events, became_dead) = socket.check_state(self, old_state, old_recv_queue); + let (state_events, became_dead) = + socket.check_state(self, old_state, old_recv_queue, is_rst); events |= state_events; self.add_events(events); diff --git a/kernel/src/net/socket/ip/stream/connected.rs b/kernel/src/net/socket/ip/stream/connected.rs index 08ea8235..781e13c6 100644 --- a/kernel/src/net/socket/ip/stream/connected.rs +++ b/kernel/src/net/socket/ip/stream/connected.rs @@ -94,8 +94,12 @@ impl ConnectedStream { debug_assert!(!*need_poll); Err(e) } - Err(RecvError::Finished) => Ok((0, NeedIfacePoll::FALSE)), - Err(RecvError::InvalidState) => { + Err(RecvError::Finished) | Err(RecvError::InvalidState) => { + // `InvalidState` occurs when the connection is reset but `ECONNRESET` was reported + // earlier. Linux returns EOF in this case, so we follow it. + Ok((0, NeedIfacePoll::FALSE)) + } + Err(RecvError::ConnReset) => { return_errno_with_message!(Errno::ECONNRESET, "the connection is reset") } } @@ -106,10 +110,6 @@ impl ConnectedStream { reader: &mut dyn MultiRead, _flags: SendRecvFlags, ) -> Result<(usize, NeedIfacePoll)> { - if reader.is_empty() { - return Ok((0, NeedIfacePoll::FALSE)); - } - let result = self.tcp_conn.send(|socket_buffer| { match reader.read(&mut VmWriter::from(socket_buffer)) { Ok(len) => (len, Ok(len)), @@ -128,9 +128,9 @@ impl ConnectedStream { Err(e) } Err(SendError::InvalidState) => { - // FIXME: `EPIPE` is another possibility, which means that the socket is shut down - // for writing. In that case, we should also trigger a `SIGPIPE` if `MSG_NOSIGNAL` - // is not specified. + return_errno_with_message!(Errno::EPIPE, "the connection is closed"); + } + Err(SendError::ConnReset) => { return_errno_with_message!(Errno::ECONNRESET, "the connection is reset"); } } @@ -187,10 +187,26 @@ impl ConnectedStream { events |= IoEvents::HUP; } + // If the connection is reset, add an ERR event. + if socket.is_rst_closed() { + events |= IoEvents::ERR; + } + events }) } + pub(super) fn test_and_clear_error(&self) -> Option { + if self.tcp_conn.clear_rst_closed() { + Some(Error::with_message( + Errno::ECONNRESET, + "the connection is reset", + )) + } else { + None + } + } + pub(super) fn set_raw_option( &self, set_option: impl FnOnce(&dyn RawTcpSetOption) -> R, diff --git a/kernel/src/net/socket/ip/stream/mod.rs b/kernel/src/net/socket/ip/stream/mod.rs index 27645fba..8a4c9842 100644 --- a/kernel/src/net/socket/ip/stream/mod.rs +++ b/kernel/src/net/socket/ip/stream/mod.rs @@ -331,7 +331,10 @@ impl StreamSocket { } }; - let (recv_bytes, need_poll) = connected_stream.try_recv(writer, flags)?; + let result = connected_stream.try_recv(writer, flags); + self.pollee.invalidate(); + + let (recv_bytes, need_poll) = result?; let iface_to_poll = need_poll.then(|| connected_stream.iface().clone()); let remote_endpoint = connected_stream.remote_endpoint(); @@ -355,7 +358,6 @@ impl StreamSocket { return result; } State::Listen(_) => { - // TODO: Trigger `SIGPIPE` if `MSG_NOSIGNAL` is not specified return_errno_with_message!(Errno::EPIPE, "the socket is not connected"); } State::Connecting(_) => { @@ -365,11 +367,13 @@ impl StreamSocket { } }; - let (sent_bytes, need_poll) = connected_stream.try_send(reader, flags)?; + let result = connected_stream.try_send(reader, flags); + self.pollee.invalidate(); + + let (sent_bytes, need_poll) = result?; let iface_to_poll = need_poll.then(|| connected_stream.iface().clone()); drop(state); - self.pollee.invalidate(); if let Some(iface) = iface_to_poll { iface.poll(); } @@ -393,7 +397,8 @@ impl StreamSocket { let error = match state.as_ref() { State::Init(init_stream) => init_stream.test_and_clear_error(), - State::Connecting(_) | State::Listen(_) | State::Connected(_) => None, + State::Connected(connected_stream) => connected_stream.test_and_clear_error(), + State::Connecting(_) | State::Listen(_) => None, }; self.pollee.invalidate(); error @@ -553,6 +558,8 @@ impl Socket for StreamSocket { } self.block_on(IoEvents::OUT, || self.try_send(reader, flags)) + + // TODO: Trigger `SIGPIPE` if the error code is `EPIPE` and `MSG_NOSIGNAL` is not specified } fn recvmsg( diff --git a/test/apps/network/tcp_err.c b/test/apps/network/tcp_err.c index 2e727631..772726c6 100644 --- a/test/apps/network/tcp_err.c +++ b/test/apps/network/tcp_err.c @@ -616,6 +616,7 @@ END_TEST() sk_addr.sin_port = S_PORT; \ \ sk_connect = TEST_SUCC(socket(PF_INET, SOCK_STREAM, 0)); \ + pfd.fd = sk_connect; \ TEST_SUCC(connect(sk_connect, (struct sockaddr *)&sk_addr, \ sizeof(sk_addr))); \ \ @@ -628,6 +629,7 @@ FN_TEST(shutdown_shutdown) int sk_accept; int sk_connect; socklen_t len; + struct pollfd pfd __attribute__((unused)); SETUP_CONN; @@ -647,4 +649,108 @@ FN_TEST(shutdown_shutdown) } END_TEST() +FN_TEST(connreset) +{ + int sk_accept; + int sk_connect; + struct linger lin = { .l_onoff = 1, .l_linger = 0 }; + struct pollfd pfd = { .events = POLLIN | POLLOUT }; + char buf[6] = "hello"; + int err; + socklen_t len; + +#define RESET_CONN \ + TEST_SUCC(setsockopt(sk_accept, SOL_SOCKET, SO_LINGER, &lin, \ + sizeof(lin))); \ + TEST_SUCC(close(sk_accept)); + +#define EV_ERR (POLLIN | POLLOUT | POLLHUP | POLLERR) +#define EV_NO_ERR (POLLIN | POLLOUT | POLLHUP) + + // Test 1: `recv` should fail with `ECONNRESET` + + SETUP_CONN; + RESET_CONN; + + TEST_RES(poll(&pfd, 1, 0), pfd.revents == EV_ERR); + TEST_ERRNO(recv(sk_connect, buf, 0, 0), ECONNRESET); + TEST_RES(poll(&pfd, 1, 0), pfd.revents == EV_NO_ERR); + + TEST_RES(recv(sk_connect, buf, 0, 0), _ret == 0); + TEST_SUCC(close(sk_connect)); + + // Test 2: `send` should fail with `ECONNRESET` + + SETUP_CONN; + RESET_CONN; + + TEST_RES(poll(&pfd, 1, 0), pfd.revents == EV_ERR); + TEST_ERRNO(send(sk_connect, buf, 0, 0), ECONNRESET); + TEST_RES(poll(&pfd, 1, 0), pfd.revents == EV_NO_ERR); + + TEST_ERRNO(send(sk_connect, buf, 0, 0), EPIPE); + TEST_SUCC(close(sk_connect)); + + // Test 3: `recv` should drain the buffer, then fail with `ECONNRESET` + + SETUP_CONN; + TEST_RES(send(sk_accept, buf, sizeof(buf), 0), _ret == sizeof(buf)); + RESET_CONN; + + TEST_RES(recv(sk_connect, buf, 4, 0), + _ret == 4 && memcmp(buf, "hell", 4) == 0); + TEST_RES(recv(sk_connect, buf, sizeof(buf), 0), + _ret == 2 && memcmp(buf, "o", 2) == 0); + TEST_ERRNO(recv(sk_connect, buf, sizeof(buf), 0), ECONNRESET); + + TEST_RES(recv(sk_connect, buf, 0, 0), _ret == 0); + TEST_SUCC(close(sk_connect)); + + // Test 3: `getsockopt(SO_ERROR)` should report `ECONNRESET` + + SETUP_CONN; + RESET_CONN; + + TEST_RES(poll(&pfd, 1, 0), pfd.revents == EV_ERR); + len = sizeof(err); + TEST_RES(getsockopt(sk_connect, SOL_SOCKET, SO_ERROR, &err, &len), + len == sizeof(err) && err == ECONNRESET); + TEST_RES(poll(&pfd, 1, 0), pfd.revents == EV_NO_ERR); + + TEST_RES(getsockopt(sk_connect, SOL_SOCKET, SO_ERROR, &err, &len), + len == sizeof(err) && err == 0); + TEST_SUCC(close(sk_connect)); + +#undef EV_ERR +#undef EV_NO_ERR + +#undef RESET_CONN +} +END_TEST() + #undef SETUP_CONN + +FN_TEST(listen_close) +{ + int sk_listen; + int sk_connect; + + sk_addr.sin_port = htons(0x4321); + + sk_listen = TEST_SUCC(socket(PF_INET, SOCK_STREAM, 0)); + TEST_SUCC( + bind(sk_listen, (struct sockaddr *)&sk_addr, sizeof(sk_addr))); + TEST_SUCC(listen(sk_listen, 10)); + + sk_connect = TEST_SUCC(socket(PF_INET, SOCK_STREAM, 0)); + TEST_SUCC(connect(sk_connect, (struct sockaddr *)&sk_addr, + sizeof(sk_addr))); + + // Test: `close(sk_listen)` will reset all connections in the backlog + TEST_SUCC(close(sk_listen)); + TEST_ERRNO(send(sk_connect, &sk_connect, sizeof(sk_connect), 0), + ECONNRESET); + + TEST_SUCC(close(sk_connect)); +} +END_TEST() diff --git a/test/apps/network/tcp_poll.c b/test/apps/network/tcp_poll.c index 7359bdf8..b5272011 100644 --- a/test/apps/network/tcp_poll.c +++ b/test/apps/network/tcp_poll.c @@ -230,14 +230,27 @@ FN_TEST(poll_shutdown_readwrite) CHECK(write(sk_connect, buf, 4096)); - // TODO: The following test cannot be passed on Asterinas due to the following reasons: - // 1. On Linux, an RST packet is generated when attempting to write to a closed socket. - // However, Asterinas currently does not generate this packet. - // 2. RST packets cause a POLLERR on Linux, but Asterinas currently lack support for this. + // 1. An RST packet is generated when attempting to write to a closed socket. + // 2. The RST packet will cause a POLLERR. + pfd.fd = sk_connect; + TEST_RES(poll(&pfd, 1, 0), + pfd.revents == + (POLLIN | POLLOUT | POLLRDHUP | POLLHUP | POLLERR)); + pfd.fd = sk_accept; + TEST_RES(poll(&pfd, 1, 0), + pfd.revents == + (POLLIN | POLLOUT | POLLRDHUP | POLLHUP | POLLERR)); - // TEST_RES(poll(&pfd, 1, 0), - // pfd.revents == - // (POLLIN | POLLOUT | POLLRDHUP | POLLHUP | POLLERR)); + int err = 0; + socklen_t errlen = sizeof(err); + // FIXME: This socket error should be `EPIPE`, but in Asterinas it is + // `ECONNRESET`. See the Linux implementation for details: + // . + // + // TEST_RES(getsockopt(sk_connect, SOL_SOCKET, SO_ERROR, &err, &errlen), + // errlen == sizeof(err) && err == EPIPE); + TEST_RES(getsockopt(sk_accept, SOL_SOCKET, SO_ERROR, &err, &errlen), + errlen == sizeof(err) && err == ECONNRESET); } END_TEST()