Receive RST packets as ECONNRESET errors

This commit is contained in:
Ruihan Li
2025-02-24 23:07:58 +08:00
committed by Tate, Hongliang Tian
parent aa29640ed7
commit d40d452e9d
6 changed files with 256 additions and 31 deletions

View File

@ -10,8 +10,6 @@ pub enum BindError {
} }
pub mod tcp { pub mod tcp {
pub use smoltcp::socket::tcp::{RecvError, SendError};
/// An error returned by [`TcpListener::new_listen`]. /// An error returned by [`TcpListener::new_listen`].
/// ///
/// [`TcpListener::new_listen`]: crate::socket::TcpListener::new_listen /// [`TcpListener::new_listen`]: crate::socket::TcpListener::new_listen
@ -51,6 +49,44 @@ pub mod tcp {
} }
} }
} }
/// An error returned by [`TcpConnection::send`].
///
/// [`TcpConnection::send`]: crate::socket::TcpConnection::send
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum SendError {
InvalidState,
/// The connection is reset.
ConnReset,
}
impl From<smoltcp::socket::tcp::SendError> for SendError {
fn from(value: smoltcp::socket::tcp::SendError) -> Self {
match value {
smoltcp::socket::tcp::SendError::InvalidState => Self::InvalidState,
}
}
}
/// An error returned by [`TcpConnection::recv`].
///
/// [`TcpConnection::recv`]: crate::socket::TcpConnection::recv
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum RecvError {
InvalidState,
Finished,
/// The connection is reset.
ConnReset,
}
impl From<smoltcp::socket::tcp::RecvError> for RecvError {
fn from(value: smoltcp::socket::tcp::RecvError) -> Self {
match value {
smoltcp::socket::tcp::RecvError::InvalidState => Self::InvalidState,
smoltcp::socket::tcp::RecvError::Finished => Self::Finished,
}
}
}
} }
pub mod udp { pub mod udp {

View File

@ -17,7 +17,7 @@ use super::{
}; };
use crate::{ use crate::{
define_boolean_value, define_boolean_value,
errors::tcp::ConnectError, errors::tcp::{ConnectError, RecvError, SendError},
ext::Ext, ext::Ext,
iface::BoundPort, iface::BoundPort,
socket::{ socket::{
@ -42,6 +42,8 @@ pub struct RawTcpSocketExt<E: Ext> {
has_connected: bool, has_connected: bool,
/// Indicates if the receiving side of this socket is shut down by the user. /// Indicates if the receiving side of this socket is shut down by the user.
is_recv_shut: bool, is_recv_shut: bool,
/// Indicates if the socket is closed by a RST packet.
is_rst_closed: bool,
} }
impl<E: Ext> Deref for RawTcpSocketExt<E> { impl<E: Ext> Deref for RawTcpSocketExt<E> {
@ -92,6 +94,14 @@ impl<E: Ext> RawTcpSocketExt<E> {
pub fn is_recv_shut(&self) -> bool { pub fn is_recv_shut(&self) -> bool {
self.is_recv_shut self.is_recv_shut
} }
/// Checks if the socket is closed by a RST packet.
///
/// Note that the flag is automatically cleared when it is read by
/// [`TcpConnection::clear_rst_closed`], [`TcpConnection::send`], or [`TcpConnection::recv`].
pub fn is_rst_closed(&self) -> bool {
self.is_rst_closed
}
} }
define_boolean_value!( define_boolean_value!(
@ -106,6 +116,7 @@ impl<E: Ext> RawTcpSocketExt<E> {
this: &Arc<TcpConnectionBg<E>>, this: &Arc<TcpConnectionBg<E>>,
old_state: State, old_state: State,
old_recv_queue: usize, old_recv_queue: usize,
is_rst: bool,
) -> (SocketEvents, TcpConnBecameDead) { ) -> (SocketEvents, TcpConnBecameDead) {
let became_dead = if self.state() != State::Established { let became_dead = if self.state() != State::Established {
// After the connection is closed by the user, no new data can be read, and such unread // After the connection is closed by the user, no new data can be read, and such unread
@ -117,6 +128,10 @@ impl<E: Ext> RawTcpSocketExt<E> {
&& matches!(old_state, State::FinWait1 | State::FinWait2) && matches!(old_state, State::FinWait1 | State::FinWait2)
&& self.recv_queue() > old_recv_queue && self.recv_queue() > old_recv_queue
{ {
// Strictly speaking, the socket isn't closed by an incoming RST packet in this
// situation. Instead, we reset the connection and _send_ an outgoing RST packet.
// However, Linux reports `ECONNRESET`, so we have to follow Linux.
self.is_rst_closed = true;
self.abort(); self.abort();
} }
self.check_dead(this) self.check_dead(this)
@ -125,6 +140,9 @@ impl<E: Ext> RawTcpSocketExt<E> {
}; };
let events = if self.state() != old_state { let events = if self.state() != old_state {
if self.state() == State::Closed && is_rst {
self.is_rst_closed = true;
}
self.on_new_state(this) self.on_new_state(this)
} else { } else {
SocketEvents::empty() SocketEvents::empty()
@ -211,6 +229,7 @@ impl<E: Ext> TcpConnectionInner<E> {
listener, listener,
has_connected: false, has_connected: false,
is_recv_shut: false, is_recv_shut: false,
is_rst_closed: false,
}; };
TcpConnectionInner { TcpConnectionInner {
@ -344,7 +363,7 @@ impl<E: Ext> TcpConnection<E> {
/// Sends some data. /// Sends some data.
/// ///
/// Polling the iface _may_ be required after this method succeeds. /// Polling the iface _may_ be required after this method succeeds.
pub fn send<F, R>(&self, f: F) -> Result<(R, NeedIfacePoll), smoltcp::socket::tcp::SendError> pub fn send<F, R>(&self, f: F) -> Result<(R, NeedIfacePoll), SendError>
where where
F: FnOnce(&mut [u8]) -> (usize, R), F: FnOnce(&mut [u8]) -> (usize, R),
{ {
@ -353,7 +372,12 @@ impl<E: Ext> TcpConnection<E> {
let mut socket = self.0.inner.lock(); let mut socket = self.0.inner.lock();
if socket.is_rst_closed {
socket.is_rst_closed = false;
return Err(SendError::ConnReset);
}
let result = socket.send(f)?; let result = socket.send(f)?;
let need_poll = self let need_poll = self
.0 .0
.update_next_poll_at_ms(socket.poll_at(iface.context())); .update_next_poll_at_ms(socket.poll_at(iface.context()));
@ -364,7 +388,7 @@ impl<E: Ext> TcpConnection<E> {
/// Receives some data. /// Receives some data.
/// ///
/// Polling the iface _may_ be required after this method succeeds. /// Polling the iface _may_ be required after this method succeeds.
pub fn recv<F, R>(&self, f: F) -> Result<(R, NeedIfacePoll), smoltcp::socket::tcp::RecvError> pub fn recv<F, R>(&self, f: F) -> Result<(R, NeedIfacePoll), RecvError>
where where
F: FnOnce(&mut [u8]) -> (usize, R), F: FnOnce(&mut [u8]) -> (usize, R),
{ {
@ -374,10 +398,16 @@ impl<E: Ext> TcpConnection<E> {
let mut socket = self.0.inner.lock(); let mut socket = self.0.inner.lock();
if socket.is_recv_shut && socket.recv_queue() == 0 { if socket.is_recv_shut && socket.recv_queue() == 0 {
return Err(smoltcp::socket::tcp::RecvError::Finished); return Err(RecvError::Finished);
} }
let result = match socket.recv(f) {
Err(_) if socket.is_rst_closed => {
socket.is_rst_closed = false;
return Err(RecvError::ConnReset);
}
res => res,
}?;
let result = socket.recv(f)?;
let need_poll = self let need_poll = self
.0 .0
.update_next_poll_at_ms(socket.poll_at(iface.context())); .update_next_poll_at_ms(socket.poll_at(iface.context()));
@ -385,6 +415,18 @@ impl<E: Ext> TcpConnection<E> {
Ok((result, need_poll)) Ok((result, need_poll))
} }
/// Checks if the socket is closed by a RST packet and clears the flag.
///
/// This flag is set when the socket is closed by a RST packet, and cleared when the connection
/// reset error is reported via one of the [`Self::send`], [`Self::recv`], or this method.
pub fn clear_rst_closed(&self) -> bool {
let mut socket = self.0.inner.lock();
let is_rst = socket.is_rst_closed;
socket.is_rst_closed = false;
is_rst
}
/// Shuts down the sending half of the connection. /// Shuts down the sending half of the connection.
/// ///
/// This method will return `false` if the socket is in the CLOSED or TIME_WAIT state. /// This method will return `false` if the socket is in the CLOSED or TIME_WAIT state.
@ -534,6 +576,7 @@ impl<E: Ext> TcpConnectionBg<E> {
let old_state = socket.state(); let old_state = socket.state();
let old_recv_queue = socket.recv_queue(); let old_recv_queue = socket.recv_queue();
let is_rst = tcp_repr.control == TcpControl::Rst;
// For TCP, receiving an ACK packet can free up space in the queue, allowing more packets // For TCP, receiving an ACK packet can free up space in the queue, allowing more packets
// to be queued. // to be queued.
let mut events = SocketEvents::CAN_RECV | SocketEvents::CAN_SEND; let mut events = SocketEvents::CAN_RECV | SocketEvents::CAN_SEND;
@ -543,7 +586,8 @@ impl<E: Ext> TcpConnectionBg<E> {
Some((ip_repr, tcp_repr)) => TcpProcessResult::ProcessedWithReply(ip_repr, tcp_repr), Some((ip_repr, tcp_repr)) => TcpProcessResult::ProcessedWithReply(ip_repr, tcp_repr),
}; };
let (state_events, became_dead) = socket.check_state(self, old_state, old_recv_queue); let (state_events, became_dead) =
socket.check_state(self, old_state, old_recv_queue, is_rst);
events |= state_events; events |= state_events;
self.add_events(events); self.add_events(events);
@ -565,6 +609,7 @@ impl<E: Ext> TcpConnectionBg<E> {
let old_state = socket.state(); let old_state = socket.state();
let old_recv_queue = socket.recv_queue(); let old_recv_queue = socket.recv_queue();
let mut is_rst = false;
let mut events = SocketEvents::empty(); let mut events = SocketEvents::empty();
let mut reply = None; let mut reply = None;
@ -581,11 +626,13 @@ impl<E: Ext> TcpConnectionBg<E> {
if !socket.accepts(cx, ip_repr, tcp_repr) { if !socket.accepts(cx, ip_repr, tcp_repr) {
break; break;
} }
reply = socket.process(cx, ip_repr, tcp_repr); is_rst |= tcp_repr.control == TcpControl::Rst;
events |= SocketEvents::CAN_RECV | SocketEvents::CAN_SEND; events |= SocketEvents::CAN_RECV | SocketEvents::CAN_SEND;
reply = socket.process(cx, ip_repr, tcp_repr);
} }
let (state_events, became_dead) = socket.check_state(self, old_state, old_recv_queue); let (state_events, became_dead) =
socket.check_state(self, old_state, old_recv_queue, is_rst);
events |= state_events; events |= state_events;
self.add_events(events); self.add_events(events);

View File

@ -94,8 +94,12 @@ impl ConnectedStream {
debug_assert!(!*need_poll); debug_assert!(!*need_poll);
Err(e) Err(e)
} }
Err(RecvError::Finished) => Ok((0, NeedIfacePoll::FALSE)), Err(RecvError::Finished) | Err(RecvError::InvalidState) => {
Err(RecvError::InvalidState) => { // `InvalidState` occurs when the connection is reset but `ECONNRESET` was reported
// earlier. Linux returns EOF in this case, so we follow it.
Ok((0, NeedIfacePoll::FALSE))
}
Err(RecvError::ConnReset) => {
return_errno_with_message!(Errno::ECONNRESET, "the connection is reset") return_errno_with_message!(Errno::ECONNRESET, "the connection is reset")
} }
} }
@ -106,10 +110,6 @@ impl ConnectedStream {
reader: &mut dyn MultiRead, reader: &mut dyn MultiRead,
_flags: SendRecvFlags, _flags: SendRecvFlags,
) -> Result<(usize, NeedIfacePoll)> { ) -> Result<(usize, NeedIfacePoll)> {
if reader.is_empty() {
return Ok((0, NeedIfacePoll::FALSE));
}
let result = self.tcp_conn.send(|socket_buffer| { let result = self.tcp_conn.send(|socket_buffer| {
match reader.read(&mut VmWriter::from(socket_buffer)) { match reader.read(&mut VmWriter::from(socket_buffer)) {
Ok(len) => (len, Ok(len)), Ok(len) => (len, Ok(len)),
@ -128,9 +128,9 @@ impl ConnectedStream {
Err(e) Err(e)
} }
Err(SendError::InvalidState) => { Err(SendError::InvalidState) => {
// FIXME: `EPIPE` is another possibility, which means that the socket is shut down return_errno_with_message!(Errno::EPIPE, "the connection is closed");
// for writing. In that case, we should also trigger a `SIGPIPE` if `MSG_NOSIGNAL` }
// is not specified. Err(SendError::ConnReset) => {
return_errno_with_message!(Errno::ECONNRESET, "the connection is reset"); return_errno_with_message!(Errno::ECONNRESET, "the connection is reset");
} }
} }
@ -187,10 +187,26 @@ impl ConnectedStream {
events |= IoEvents::HUP; events |= IoEvents::HUP;
} }
// If the connection is reset, add an ERR event.
if socket.is_rst_closed() {
events |= IoEvents::ERR;
}
events events
}) })
} }
pub(super) fn test_and_clear_error(&self) -> Option<Error> {
if self.tcp_conn.clear_rst_closed() {
Some(Error::with_message(
Errno::ECONNRESET,
"the connection is reset",
))
} else {
None
}
}
pub(super) fn set_raw_option<R>( pub(super) fn set_raw_option<R>(
&self, &self,
set_option: impl FnOnce(&dyn RawTcpSetOption) -> R, set_option: impl FnOnce(&dyn RawTcpSetOption) -> R,

View File

@ -331,7 +331,10 @@ impl StreamSocket {
} }
}; };
let (recv_bytes, need_poll) = connected_stream.try_recv(writer, flags)?; let result = connected_stream.try_recv(writer, flags);
self.pollee.invalidate();
let (recv_bytes, need_poll) = result?;
let iface_to_poll = need_poll.then(|| connected_stream.iface().clone()); let iface_to_poll = need_poll.then(|| connected_stream.iface().clone());
let remote_endpoint = connected_stream.remote_endpoint(); let remote_endpoint = connected_stream.remote_endpoint();
@ -355,7 +358,6 @@ impl StreamSocket {
return result; return result;
} }
State::Listen(_) => { State::Listen(_) => {
// TODO: Trigger `SIGPIPE` if `MSG_NOSIGNAL` is not specified
return_errno_with_message!(Errno::EPIPE, "the socket is not connected"); return_errno_with_message!(Errno::EPIPE, "the socket is not connected");
} }
State::Connecting(_) => { State::Connecting(_) => {
@ -365,11 +367,13 @@ impl StreamSocket {
} }
}; };
let (sent_bytes, need_poll) = connected_stream.try_send(reader, flags)?; let result = connected_stream.try_send(reader, flags);
self.pollee.invalidate();
let (sent_bytes, need_poll) = result?;
let iface_to_poll = need_poll.then(|| connected_stream.iface().clone()); let iface_to_poll = need_poll.then(|| connected_stream.iface().clone());
drop(state); drop(state);
self.pollee.invalidate();
if let Some(iface) = iface_to_poll { if let Some(iface) = iface_to_poll {
iface.poll(); iface.poll();
} }
@ -393,7 +397,8 @@ impl StreamSocket {
let error = match state.as_ref() { let error = match state.as_ref() {
State::Init(init_stream) => init_stream.test_and_clear_error(), State::Init(init_stream) => init_stream.test_and_clear_error(),
State::Connecting(_) | State::Listen(_) | State::Connected(_) => None, State::Connected(connected_stream) => connected_stream.test_and_clear_error(),
State::Connecting(_) | State::Listen(_) => None,
}; };
self.pollee.invalidate(); self.pollee.invalidate();
error error
@ -553,6 +558,8 @@ impl Socket for StreamSocket {
} }
self.block_on(IoEvents::OUT, || self.try_send(reader, flags)) self.block_on(IoEvents::OUT, || self.try_send(reader, flags))
// TODO: Trigger `SIGPIPE` if the error code is `EPIPE` and `MSG_NOSIGNAL` is not specified
} }
fn recvmsg( fn recvmsg(

View File

@ -616,6 +616,7 @@ END_TEST()
sk_addr.sin_port = S_PORT; \ sk_addr.sin_port = S_PORT; \
\ \
sk_connect = TEST_SUCC(socket(PF_INET, SOCK_STREAM, 0)); \ sk_connect = TEST_SUCC(socket(PF_INET, SOCK_STREAM, 0)); \
pfd.fd = sk_connect; \
TEST_SUCC(connect(sk_connect, (struct sockaddr *)&sk_addr, \ TEST_SUCC(connect(sk_connect, (struct sockaddr *)&sk_addr, \
sizeof(sk_addr))); \ sizeof(sk_addr))); \
\ \
@ -628,6 +629,7 @@ FN_TEST(shutdown_shutdown)
int sk_accept; int sk_accept;
int sk_connect; int sk_connect;
socklen_t len; socklen_t len;
struct pollfd pfd __attribute__((unused));
SETUP_CONN; SETUP_CONN;
@ -647,4 +649,108 @@ FN_TEST(shutdown_shutdown)
} }
END_TEST() END_TEST()
FN_TEST(connreset)
{
int sk_accept;
int sk_connect;
struct linger lin = { .l_onoff = 1, .l_linger = 0 };
struct pollfd pfd = { .events = POLLIN | POLLOUT };
char buf[6] = "hello";
int err;
socklen_t len;
#define RESET_CONN \
TEST_SUCC(setsockopt(sk_accept, SOL_SOCKET, SO_LINGER, &lin, \
sizeof(lin))); \
TEST_SUCC(close(sk_accept));
#define EV_ERR (POLLIN | POLLOUT | POLLHUP | POLLERR)
#define EV_NO_ERR (POLLIN | POLLOUT | POLLHUP)
// Test 1: `recv` should fail with `ECONNRESET`
SETUP_CONN;
RESET_CONN;
TEST_RES(poll(&pfd, 1, 0), pfd.revents == EV_ERR);
TEST_ERRNO(recv(sk_connect, buf, 0, 0), ECONNRESET);
TEST_RES(poll(&pfd, 1, 0), pfd.revents == EV_NO_ERR);
TEST_RES(recv(sk_connect, buf, 0, 0), _ret == 0);
TEST_SUCC(close(sk_connect));
// Test 2: `send` should fail with `ECONNRESET`
SETUP_CONN;
RESET_CONN;
TEST_RES(poll(&pfd, 1, 0), pfd.revents == EV_ERR);
TEST_ERRNO(send(sk_connect, buf, 0, 0), ECONNRESET);
TEST_RES(poll(&pfd, 1, 0), pfd.revents == EV_NO_ERR);
TEST_ERRNO(send(sk_connect, buf, 0, 0), EPIPE);
TEST_SUCC(close(sk_connect));
// Test 3: `recv` should drain the buffer, then fail with `ECONNRESET`
SETUP_CONN;
TEST_RES(send(sk_accept, buf, sizeof(buf), 0), _ret == sizeof(buf));
RESET_CONN;
TEST_RES(recv(sk_connect, buf, 4, 0),
_ret == 4 && memcmp(buf, "hell", 4) == 0);
TEST_RES(recv(sk_connect, buf, sizeof(buf), 0),
_ret == 2 && memcmp(buf, "o", 2) == 0);
TEST_ERRNO(recv(sk_connect, buf, sizeof(buf), 0), ECONNRESET);
TEST_RES(recv(sk_connect, buf, 0, 0), _ret == 0);
TEST_SUCC(close(sk_connect));
// Test 3: `getsockopt(SO_ERROR)` should report `ECONNRESET`
SETUP_CONN;
RESET_CONN;
TEST_RES(poll(&pfd, 1, 0), pfd.revents == EV_ERR);
len = sizeof(err);
TEST_RES(getsockopt(sk_connect, SOL_SOCKET, SO_ERROR, &err, &len),
len == sizeof(err) && err == ECONNRESET);
TEST_RES(poll(&pfd, 1, 0), pfd.revents == EV_NO_ERR);
TEST_RES(getsockopt(sk_connect, SOL_SOCKET, SO_ERROR, &err, &len),
len == sizeof(err) && err == 0);
TEST_SUCC(close(sk_connect));
#undef EV_ERR
#undef EV_NO_ERR
#undef RESET_CONN
}
END_TEST()
#undef SETUP_CONN #undef SETUP_CONN
FN_TEST(listen_close)
{
int sk_listen;
int sk_connect;
sk_addr.sin_port = htons(0x4321);
sk_listen = TEST_SUCC(socket(PF_INET, SOCK_STREAM, 0));
TEST_SUCC(
bind(sk_listen, (struct sockaddr *)&sk_addr, sizeof(sk_addr)));
TEST_SUCC(listen(sk_listen, 10));
sk_connect = TEST_SUCC(socket(PF_INET, SOCK_STREAM, 0));
TEST_SUCC(connect(sk_connect, (struct sockaddr *)&sk_addr,
sizeof(sk_addr)));
// Test: `close(sk_listen)` will reset all connections in the backlog
TEST_SUCC(close(sk_listen));
TEST_ERRNO(send(sk_connect, &sk_connect, sizeof(sk_connect), 0),
ECONNRESET);
TEST_SUCC(close(sk_connect));
}
END_TEST()

View File

@ -230,14 +230,27 @@ FN_TEST(poll_shutdown_readwrite)
CHECK(write(sk_connect, buf, 4096)); CHECK(write(sk_connect, buf, 4096));
// TODO: The following test cannot be passed on Asterinas due to the following reasons: // 1. An RST packet is generated when attempting to write to a closed socket.
// 1. On Linux, an RST packet is generated when attempting to write to a closed socket. // 2. The RST packet will cause a POLLERR.
// However, Asterinas currently does not generate this packet. pfd.fd = sk_connect;
// 2. RST packets cause a POLLERR on Linux, but Asterinas currently lack support for this. TEST_RES(poll(&pfd, 1, 0),
pfd.revents ==
(POLLIN | POLLOUT | POLLRDHUP | POLLHUP | POLLERR));
pfd.fd = sk_accept;
TEST_RES(poll(&pfd, 1, 0),
pfd.revents ==
(POLLIN | POLLOUT | POLLRDHUP | POLLHUP | POLLERR));
// TEST_RES(poll(&pfd, 1, 0), int err = 0;
// pfd.revents == socklen_t errlen = sizeof(err);
// (POLLIN | POLLOUT | POLLRDHUP | POLLHUP | POLLERR)); // FIXME: This socket error should be `EPIPE`, but in Asterinas it is
// `ECONNRESET`. See the Linux implementation for details:
// <https://github.com/torvalds/linux/blob/848e076317446f9c663771ddec142d7c2eb4cb43/net/ipv4/tcp_input.c#L4553-L4555>.
//
// TEST_RES(getsockopt(sk_connect, SOL_SOCKET, SO_ERROR, &err, &errlen),
// errlen == sizeof(err) && err == EPIPE);
TEST_RES(getsockopt(sk_accept, SOL_SOCKET, SO_ERROR, &err, &errlen),
errlen == sizeof(err) && err == ECONNRESET);
} }
END_TEST() END_TEST()