From aa29640ed7690563bd0cd8f6ef27e83874b9e601 Mon Sep 17 00:00:00 2001 From: Ruihan Li Date: Fri, 21 Feb 2025 11:47:04 +0800 Subject: [PATCH] Send RST packets when appropriate --- kernel/libs/aster-bigtcp/src/iface/common.rs | 7 +- .../aster-bigtcp/src/socket/bound/tcp_conn.rs | 200 +++++++++++++----- .../src/socket/bound/tcp_listen.rs | 48 +++-- kernel/src/net/socket/ip/stream/connected.rs | 62 ++++-- kernel/src/net/socket/ip/stream/connecting.rs | 4 + kernel/src/net/socket/ip/stream/listen.rs | 4 + kernel/src/net/socket/ip/stream/mod.rs | 23 +- kernel/src/process/signal/poll.rs | 7 +- test/apps/network/tcp_err.c | 37 ++++ 9 files changed, 293 insertions(+), 99 deletions(-) diff --git a/kernel/libs/aster-bigtcp/src/iface/common.rs b/kernel/libs/aster-bigtcp/src/iface/common.rs index e0f52538c..afe94033d 100644 --- a/kernel/libs/aster-bigtcp/src/iface/common.rs +++ b/kernel/libs/aster-bigtcp/src/iface/common.rs @@ -23,7 +23,7 @@ use super::{ use crate::{ errors::BindError, ext::Ext, - socket::{TcpConnectionBg, TcpListenerBg, UdpSocketBg}, + socket::{TcpListenerBg, UdpSocketBg}, socket_table::SocketTable, }; @@ -156,11 +156,6 @@ impl IfaceCommon { debug_assert!(removed.is_some()); } - pub(crate) fn remove_dead_tcp_connection(&self, socket: &Arc>) { - let mut sockets = self.sockets.lock(); - sockets.remove_dead_tcp_connection(socket.connection_key()); - } - pub(crate) fn remove_udp_socket(&self, socket: &Arc>) { let mut sockets = self.sockets.lock(); let removed = sockets.remove_udp_socket(socket); diff --git a/kernel/libs/aster-bigtcp/src/socket/bound/tcp_conn.rs b/kernel/libs/aster-bigtcp/src/socket/bound/tcp_conn.rs index c18f190a9..33db931a3 100644 --- a/kernel/libs/aster-bigtcp/src/socket/bound/tcp_conn.rs +++ b/kernel/libs/aster-bigtcp/src/socket/bound/tcp_conn.rs @@ -40,6 +40,8 @@ pub struct RawTcpSocketExt { socket: Box, pub(super) listener: Option>>, has_connected: bool, + /// Indicates if the receiving side of this socket is shut down by the user. + is_recv_shut: bool, } impl Deref for RawTcpSocketExt { @@ -73,6 +75,23 @@ impl RawTcpSocketExt { _ => false, } } + + /// Checks if the socket is closing. + /// + /// More specifically, we say a socket is closing if and only if it has sent its FIN packet but + /// is still waiting for an ACK packet from the peer to acknowledge the FIN it sent. + pub fn is_closing(&self) -> bool { + let state = self.state(); + matches!(state, State::FinWait1 | State::Closing | State::LastAck) + } + + /// Returns whether the receiving half of the socket is shut down. + /// + /// This method will return true if and only if [`TcpConnection::shut_recv`] or + /// [`TcpConnection::close`] is called. + pub fn is_recv_shut(&self) -> bool { + self.is_recv_shut + } } define_boolean_value!( @@ -81,10 +100,40 @@ define_boolean_value!( ); impl RawTcpSocketExt { - fn on_new_state( + /// Checks the TCP state for additional events and whether the connection is dead. + fn check_state( &mut self, this: &Arc>, + old_state: State, + old_recv_queue: usize, ) -> (SocketEvents, TcpConnBecameDead) { + let became_dead = if self.state() != State::Established { + // After the connection is closed by the user, no new data can be read, and such unread + // data will immediately cause the connection to be reset. + // Note that "closed" here means that either (1) `close()` or (2) both `shut_send()` + // and `shut_recv()` are called. In the latter case, there may be some buffered data. + if self.is_recv_shut + // These are states where the sending half is closed but new data can come in. + && matches!(old_state, State::FinWait1 | State::FinWait2) + && self.recv_queue() > old_recv_queue + { + self.abort(); + } + self.check_dead(this) + } else { + TcpConnBecameDead::FALSE + }; + + let events = if self.state() != old_state { + self.on_new_state(this) + } else { + SocketEvents::empty() + }; + + (events, became_dead) + } + + fn on_new_state(&mut self, this: &Arc>) -> SocketEvents { let may_send = self.may_send(); if may_send && !self.has_connected { @@ -99,8 +148,6 @@ impl RawTcpSocketExt { } } - let became_dead = self.check_dead(this); - let mut events = SocketEvents::empty(); if !self.may_recv_new() { events |= SocketEvents::CLOSED_RECV; @@ -109,7 +156,7 @@ impl RawTcpSocketExt { events |= SocketEvents::CLOSED_SEND; } - (events, became_dead) + events } /// Checks whether the TCP connection becomes dead. @@ -123,15 +170,20 @@ impl RawTcpSocketExt { /// dead if it is not closed. fn check_dead(&self, this: &Arc>) -> TcpConnBecameDead { // FIXME: This is a temporary workaround to mark TimeWait socket as dead. - if self.state() == smoltcp::socket::tcp::State::Closed - || self.state() == smoltcp::socket::tcp::State::TimeWait - { + if self.state() == State::TimeWait { + return TcpConnBecameDead::TRUE; + } + + // According to the current smoltcp implementation, a socket in the CLOSED state with the + // remote endpoint set means that an outgoing RST packet is pending. We cannot simply mark + // such a socket as dead. + if self.state() == State::Closed && self.remote_endpoint().is_none() { return TcpConnBecameDead::TRUE; } // According to the current smoltcp implementation, a backlog socket will return back to // the `Listen` state if the connection is RSTed before its establishment. - if self.state() == smoltcp::socket::tcp::State::Listen { + if self.state() == State::Listen { if let Some(ref listener) = self.listener { let mut backlog = listener.inner.backlog.lock(); // This may fail due to race conditions, but it's fine. @@ -158,6 +210,7 @@ impl TcpConnectionInner { socket, listener, has_connected: false, + is_recv_shut: false, }; TcpConnectionInner { @@ -175,25 +228,26 @@ impl Inner for TcpConnectionInner { type Observer = E::TcpEventObserver; fn on_drop(this: &Arc>) { - let became_dead = { - let mut socket = this.inner.lock(); - - // FIXME: Send RSTs when there is unread data. - socket.close(); - - if *socket.check_dead(this) { - true - } else { - // A TCP connection may not be appropriate for immediate removal. We leave the removal - // decision to the polling logic. - this.update_next_poll_at_ms(PollAt::Now); - false - } - }; - - if became_dead { - this.bound.iface().common().remove_dead_tcp_connection(this); - } + debug_assert!( + { + let socket = this.inner.lock(); + if socket.state() == State::Closed { + // (1) The socket is fully closed. + true + } else { + // (2) The receiving half is closed and the sending half is closing. + socket.is_recv_shut + && !matches!( + socket.state(), + State::SynSent + | State::SynReceived + | State::Established + | State::CloseWait, + ) + } + }, + "a connection must be either closed or reset before dropping" + ); } } @@ -319,6 +373,10 @@ impl TcpConnection { let mut socket = self.0.inner.lock(); + if socket.is_recv_shut && socket.recv_queue() == 0 { + return Err(smoltcp::socket::tcp::RecvError::Finished); + } + let result = socket.recv(f)?; let need_poll = self .0 @@ -327,17 +385,15 @@ impl TcpConnection { Ok((result, need_poll)) } - /// Closes the connection. + /// Shuts down the sending half of the connection. /// - /// This method returns `false` if the socket is closed _before_ calling this method. + /// This method will return `false` if the socket is in the CLOSED or TIME_WAIT state. /// /// Polling the iface is _always_ required after this method succeeds. - pub fn close(&self) -> bool { + pub fn shut_send(&self) -> bool { let mut socket = self.0.inner.lock(); - socket.listener = None; - - if socket.state() == State::Closed { + if matches!(socket.state(), State::Closed | State::TimeWait) { return false; } @@ -347,6 +403,56 @@ impl TcpConnection { true } + /// Shuts down the receiving half of the connection. + /// + /// This method will return `false` if the socket is in the CLOSED or TIME_WAIT state. + /// + /// Polling the iface is _not_ required after this method succeeds. + pub fn shut_recv(&self) -> bool { + let mut socket = self.0.inner.lock(); + + if matches!(socket.state(), State::Closed | State::TimeWait) { + return false; + } + + socket.is_recv_shut = true; + + true + } + + /// Closes the connection. + /// + /// Polling the iface is _always_ required after this method succeeds. + /// + /// Note that either this method or [`Self::reset`] must be called before dropping the TCP + /// connection to avoid resource leakage. + pub fn close(&self) { + let mut socket = self.0.inner.lock(); + + socket.is_recv_shut = true; + + if socket.recv_queue() != 0 { + // If there is unread data, reset the connection immediately. + socket.abort(); + } else { + socket.close(); + } + self.0.update_next_poll_at_ms(PollAt::Now); + } + + /// Resets the connection. + /// + /// Polling the iface is _always_ required after this method succeeds. + /// + /// Note that either this method or [`Self::close`] must be called before dropping the TCP + /// connection to avoid resource leakage. + pub fn reset(&self) { + let mut socket = self.0.inner.lock(); + + socket.abort(); + self.0.update_next_poll_at_ms(PollAt::Now); + } + /// Calls `f` with an immutable reference to the associated [`RawTcpSocket`]. // // NOTE: If a mutable reference is required, add a method above that correctly updates the next @@ -427,6 +533,7 @@ impl TcpConnectionBg { } let old_state = socket.state(); + let old_recv_queue = socket.recv_queue(); // For TCP, receiving an ACK packet can free up space in the queue, allowing more packets // to be queued. let mut events = SocketEvents::CAN_RECV | SocketEvents::CAN_SEND; @@ -436,13 +543,8 @@ impl TcpConnectionBg { Some((ip_repr, tcp_repr)) => TcpProcessResult::ProcessedWithReply(ip_repr, tcp_repr), }; - let became_dead = if socket.state() != old_state { - let (new_events, became_dead) = socket.on_new_state(self); - events |= new_events; - became_dead - } else { - TcpConnBecameDead::FALSE - }; + let (state_events, became_dead) = socket.check_state(self, old_state, old_recv_queue); + events |= state_events; self.add_events(events); self.update_next_poll_at_ms(socket.poll_at(cx)); @@ -452,16 +554,17 @@ impl TcpConnectionBg { /// Tries to generate an outgoing packet and dispatches the generated packet. pub(crate) fn dispatch( - this: &Arc, + self: &Arc, cx: &mut Context, dispatch: D, ) -> (Option<(IpRepr, TcpRepr<'static>)>, TcpConnBecameDead) where D: FnOnce(&mut Context, &IpRepr, &TcpRepr) -> Option<(IpRepr, TcpRepr<'static>)>, { - let mut socket = this.inner.lock(); + let mut socket = self.inner.lock(); let old_state = socket.state(); + let old_recv_queue = socket.recv_queue(); let mut events = SocketEvents::empty(); let mut reply = None; @@ -482,16 +585,11 @@ impl TcpConnectionBg { events |= SocketEvents::CAN_RECV | SocketEvents::CAN_SEND; } - let became_dead = if socket.state() != old_state { - let (new_events, became_dead) = socket.on_new_state(this); - events |= new_events; - became_dead - } else { - TcpConnBecameDead::FALSE - }; + let (state_events, became_dead) = socket.check_state(self, old_state, old_recv_queue); + events |= state_events; - this.add_events(events); - this.update_next_poll_at_ms(socket.poll_at(cx)); + self.add_events(events); + self.update_next_poll_at_ms(socket.poll_at(cx)); (reply, became_dead) } diff --git a/kernel/libs/aster-bigtcp/src/socket/bound/tcp_listen.rs b/kernel/libs/aster-bigtcp/src/socket/bound/tcp_listen.rs index 195052ee9..e766d24ec 100644 --- a/kernel/libs/aster-bigtcp/src/socket/bound/tcp_listen.rs +++ b/kernel/libs/aster-bigtcp/src/socket/bound/tcp_listen.rs @@ -53,23 +53,11 @@ impl Inner for TcpListenerInner { type Observer = E::TcpEventObserver; fn on_drop(this: &Arc>) { - // A TCP listener can be removed immediately. - this.bound.iface().common().remove_tcp_listener(this); - - let (connecting, connected) = { - let mut socket = this.inner.backlog.lock(); - ( - core::mem::take(&mut socket.connecting), - core::mem::take(&mut socket.connected), - ) - }; - - // The lock on `connecting`/`connected` cannot be locked after locking `self`, otherwise we - // might get a deadlock. due to inconsistent lock order problems. - // - // FIXME: Send RSTs instead of going through the normal socket close process. - drop(connecting); - drop(connected); + debug_assert_eq!( + Arc::strong_count(this), + 1, + "a listener must be closed before dropping" + ); } } @@ -140,7 +128,7 @@ impl TcpListener { let remote_endpoint = { // The lock on `accepted` cannot be locked after locking `self`, otherwise we might get - // a deadlock. due to inconsistent lock order problems. + // a deadlock due to inconsistent lock order problems. let mut socket = accepted.0.inner.lock(); socket.listener = None; @@ -156,6 +144,30 @@ impl TcpListener { pub fn can_accept(&self) -> bool { !self.0.inner.backlog.lock().connected.is_empty() } + + /// Closes the listener. + /// + /// Polling the iface is _always_ required after this method succeeds. + /// + /// Note that this method must be called before dropping the TCP listener to avoid resource + /// leakage. + pub fn close(&self) { + // A TCP listener can be removed immediately. + self.0.bound.iface().common().remove_tcp_listener(&self.0); + + let (connecting, connected) = { + let mut socket = self.0.inner.backlog.lock(); + ( + core::mem::take(&mut socket.connecting), + core::mem::take(&mut socket.connected), + ) + }; + + // The lock on `connecting`/`connected` cannot be locked after locking `self`, otherwise we + // might get a deadlock. due to inconsistent lock order problems. + connecting.values().for_each(|socket| socket.reset()); + connected.iter().for_each(|socket| socket.reset()); + } } impl RawTcpSetOption for TcpListener { diff --git a/kernel/src/net/socket/ip/stream/connected.rs b/kernel/src/net/socket/ip/stream/connected.rs index 9cf9feb4d..08ea8235d 100644 --- a/kernel/src/net/socket/ip/stream/connected.rs +++ b/kernel/src/net/socket/ip/stream/connected.rs @@ -1,7 +1,5 @@ // SPDX-License-Identifier: MPL-2.0 -use core::sync::atomic::{AtomicBool, Ordering}; - use aster_bigtcp::{ errors::tcp::{RecvError, SendError}, socket::{NeedIfacePoll, RawTcpSetOption}, @@ -13,10 +11,13 @@ use crate::{ events::IoEvents, net::{ iface::{Iface, RawTcpSocketExt, TcpConnection}, - socket::util::{send_recv_flags::SendRecvFlags, shutdown_cmd::SockShutdownCmd}, + socket::{ + util::{send_recv_flags::SendRecvFlags, shutdown_cmd::SockShutdownCmd}, + LingerOption, + }, }, prelude::*, - process::signal::Pollee, + process::signal::{Pollee, Poller}, util::{MultiRead, MultiWrite}, }; @@ -34,8 +35,6 @@ pub struct ConnectedStream { /// connection is established asynchronously will succeed and any subsequent `connect()` will /// fail. is_new_connection: bool, - /// Indicates if the receiving side of this socket is shut down by the user. - is_receiving_shut: AtomicBool, } impl ConnectedStream { @@ -48,7 +47,6 @@ impl ConnectedStream { tcp_conn, remote_endpoint, is_new_connection, - is_receiving_shut: AtomicBool::new(false), } } @@ -56,12 +54,14 @@ impl ConnectedStream { let mut events = IoEvents::empty(); if cmd.shut_read() { - self.is_receiving_shut.store(true, Ordering::Relaxed); + if !self.tcp_conn.shut_recv() { + return_errno_with_message!(Errno::ENOTCONN, "the socket is not connected"); + } events |= IoEvents::IN | IoEvents::RDHUP; } if cmd.shut_write() { - if !self.tcp_conn.close() { + if !self.tcp_conn.shut_send() { return_errno_with_message!(Errno::ENOTCONN, "the socket is not connected"); } events |= IoEvents::OUT | IoEvents::HUP; @@ -85,9 +85,6 @@ impl ConnectedStream { }); match result { - Ok((Ok(0), need_poll)) if self.is_receiving_shut.load(Ordering::Relaxed) => { - Ok((0, need_poll)) - } Ok((Ok(0), need_poll)) => { debug_assert!(!*need_poll); return_errno_with_message!(Errno::EAGAIN, "the receive buffer is empty") @@ -166,8 +163,7 @@ impl ConnectedStream { pub(super) fn check_io_events(&self) -> IoEvents { self.tcp_conn.raw_with(|socket| { - let is_receiving_closed = - self.is_receiving_shut.load(Ordering::Relaxed) || !socket.may_recv_new(); + let is_receiving_closed = socket.is_recv_shut() || !socket.may_recv_new(); let is_sending_closed = !socket.may_send(); let mut events = IoEvents::empty(); @@ -205,4 +201,42 @@ impl ConnectedStream { pub(super) fn raw_with(&self, f: impl FnOnce(&RawTcpSocketExt) -> R) -> R { self.tcp_conn.raw_with(f) } + + pub(super) fn into_connection(self) -> TcpConnection { + self.tcp_conn + } +} + +pub(super) fn close_and_linger(tcp_conn: TcpConnection, linger: LingerOption, pollee: &Pollee) { + let timeout = match (linger.is_on(), linger.timeout()) { + // No linger. Drain the send buffer in the background. + (false, _) => { + tcp_conn.close(); + tcp_conn.iface().poll(); + return; + } + // Linger with a zero timeout. Reset the connection immediately. + (true, duration) if duration.is_zero() => { + tcp_conn.reset(); + tcp_conn.iface().poll(); + return; + } + // Linger with a non-zero timeout. See below. + (true, duration) => { + tcp_conn.close(); + tcp_conn.iface().poll(); + duration + } + }; + + let mut poller = Poller::new(Some(&timeout)); + pollee.register_poller(poller.as_handle_mut(), IoEvents::HUP); + + // Now wait for the ACK packet to acknowledge the FIN packet we sent. If the timeout expires or + // we are interrupted by signals, the remaining task is done in the background. + while tcp_conn.raw_with(|socket| socket.is_closing()) { + if poller.wait().is_err() { + break; + } + } } diff --git a/kernel/src/net/socket/ip/stream/connecting.rs b/kernel/src/net/socket/ip/stream/connecting.rs index cf5477a38..ce7e0a373 100644 --- a/kernel/src/net/socket/ip/stream/connecting.rs +++ b/kernel/src/net/socket/ip/stream/connecting.rs @@ -110,4 +110,8 @@ impl ConnectingStream { ) -> R { set_option(&self.tcp_conn) } + + pub(super) fn into_connection(self) -> TcpConnection { + self.tcp_conn + } } diff --git a/kernel/src/net/socket/ip/stream/listen.rs b/kernel/src/net/socket/ip/stream/listen.rs index 686ee3027..c2a01a9c6 100644 --- a/kernel/src/net/socket/ip/stream/listen.rs +++ b/kernel/src/net/socket/ip/stream/listen.rs @@ -73,4 +73,8 @@ impl ListenStream { ) -> R { set_option(&self.tcp_listener) } + + pub(super) fn into_listener(self) -> TcpListener { + self.tcp_listener + } } diff --git a/kernel/src/net/socket/ip/stream/mod.rs b/kernel/src/net/socket/ip/stream/mod.rs index 64180273e..27645fba9 100644 --- a/kernel/src/net/socket/ip/stream/mod.rs +++ b/kernel/src/net/socket/ip/stream/mod.rs @@ -6,7 +6,7 @@ use aster_bigtcp::{ socket::{NeedIfacePoll, RawTcpOption, RawTcpSetOption}, wire::IpEndpoint, }; -use connected::ConnectedStream; +use connected::{close_and_linger, ConnectedStream}; use connecting::{ConnResult, ConnectingStream}; use init::InitStream; use listen::ListenStream; @@ -825,14 +825,19 @@ impl Drop for StreamSocket { fn drop(&mut self) { let state = self.state.get_mut().take(); - let iface_to_poll = state.iface().cloned(); + let conn = match state { + State::Init(_) => return, + State::Connecting(connecting_stream) => connecting_stream.into_connection(), + State::Connected(connected_stream) => connected_stream.into_connection(), + State::Listen(listen_stream) => { + let listener = listen_stream.into_listener(); + listener.close(); + listener.iface().poll(); + return; + } + }; - // Dropping the state will drop the sockets. This will trigger the socket close process (if - // needed) and require immediate iface polling afterwards. - drop(state); - - if let Some(iface) = iface_to_poll { - iface.poll(); - } + let linger = self.options.get_mut().socket.linger(); + close_and_linger(conn, linger, &self.pollee); } } diff --git a/kernel/src/process/signal/poll.rs b/kernel/src/process/signal/poll.rs index 9d7c116f3..853fb2ab1 100644 --- a/kernel/src/process/signal/poll.rs +++ b/kernel/src/process/signal/poll.rs @@ -151,7 +151,12 @@ impl Pollee { new_events & mask } - fn register_poller(&self, poller: &mut PollHandle, mask: IoEvents) { + /// Registers a poller to listen notification for new events. + /// + /// The functionality of this method is a subset of calling [`Self::poll_with`] and providing + /// the same poller. Unlike [`Self::poll_with`], this method performs poller registration + /// without checking (and perhaps caching) the current events. + pub fn register_poller(&self, poller: &mut PollHandle, mask: IoEvents) { self.inner .subject .register_observer(poller.observer.clone(), mask); diff --git a/test/apps/network/tcp_err.c b/test/apps/network/tcp_err.c index f060b0c89..2e7276315 100644 --- a/test/apps/network/tcp_err.c +++ b/test/apps/network/tcp_err.c @@ -611,3 +611,40 @@ FN_TEST(bind_and_connect_same_address) TEST_SUCC(close(sk_connect2)); } END_TEST() + +#define SETUP_CONN \ + sk_addr.sin_port = S_PORT; \ + \ + sk_connect = TEST_SUCC(socket(PF_INET, SOCK_STREAM, 0)); \ + TEST_SUCC(connect(sk_connect, (struct sockaddr *)&sk_addr, \ + sizeof(sk_addr))); \ + \ + len = sizeof(sk_addr); \ + sk_accept = TEST_SUCC( \ + accept(sk_listen, (struct sockaddr *)&sk_addr, &len)); + +FN_TEST(shutdown_shutdown) +{ + int sk_accept; + int sk_connect; + socklen_t len; + + SETUP_CONN; + + // Test 1: Perform `shutdown` multiple times + TEST_SUCC(shutdown(sk_accept, SHUT_RDWR)); + TEST_SUCC(shutdown(sk_accept, SHUT_RDWR)); + + // Test 2: Perform `shutdown` after the connection is closed + TEST_SUCC(shutdown(sk_connect, SHUT_RDWR)); + TEST_ERRNO(shutdown(sk_connect, SHUT_RD), ENOTCONN); + TEST_ERRNO(shutdown(sk_connect, SHUT_WR), ENOTCONN); + TEST_ERRNO(shutdown(sk_accept, SHUT_RD), ENOTCONN); + TEST_ERRNO(shutdown(sk_accept, SHUT_WR), ENOTCONN); + + TEST_SUCC(close(sk_accept)); + TEST_SUCC(close(sk_connect)); +} +END_TEST() + +#undef SETUP_CONN