diff --git a/kernel/src/net/socket/ip/stream/connected.rs b/kernel/src/net/socket/ip/stream/connected.rs index cba0e9a1..ca5de834 100644 --- a/kernel/src/net/socket/ip/stream/connected.rs +++ b/kernel/src/net/socket/ip/stream/connected.rs @@ -1,10 +1,11 @@ // SPDX-License-Identifier: MPL-2.0 use alloc::sync::Weak; +use core::sync::atomic::{AtomicBool, Ordering}; use aster_bigtcp::{ errors::tcp::{RecvError, SendError}, - socket::{SocketEventObserver, TcpState}, + socket::{RawTcpSocket, SocketEventObserver, TcpState}, wire::IpEndpoint, }; @@ -33,6 +34,15 @@ pub struct ConnectedStream { /// connection is established asynchronously will succeed and any subsequent `connect()` will /// fail. is_new_connection: bool, + /// Indicates if the receiving side of this socket is closed. + /// + /// The receiving side may be closed if this side disables reading + /// or if the peer side closes its sending half. + is_receiving_closed: AtomicBool, + /// Indicates if the sending side of this socket is closed. + /// + /// The sending side can only be closed if this side disables writing. + is_sending_closed: AtomicBool, } impl ConnectedStream { @@ -45,12 +55,22 @@ impl ConnectedStream { bound_socket, remote_endpoint, is_new_connection, + is_receiving_closed: AtomicBool::new(false), + is_sending_closed: AtomicBool::new(false), } } - pub fn shutdown(&self, _cmd: SockShutdownCmd) -> Result<()> { - // TODO: deal with cmd - self.bound_socket.close(); + pub fn shutdown(&self, cmd: SockShutdownCmd, pollee: &Pollee) -> Result<()> { + if cmd.shut_read() { + self.is_receiving_closed.store(true, Ordering::Relaxed); + self.update_io_events(pollee); + } + + if cmd.shut_write() { + self.is_sending_closed.store(true, Ordering::Relaxed); + self.bound_socket.close(); + } + Ok(()) } @@ -63,6 +83,7 @@ impl ConnectedStream { }); match result { + Ok(Ok(0)) if self.is_receiving_closed.load(Ordering::Relaxed) => Ok(0), Ok(Ok(0)) => return_errno_with_message!(Errno::EAGAIN, "the receive buffer is empty"), Ok(Ok(recv_bytes)) => Ok(recv_bytes), Ok(Err(e)) => Err(e), @@ -124,17 +145,40 @@ impl ConnectedStream { pub(super) fn update_io_events(&self, pollee: &Pollee) { self.bound_socket.raw_with(|socket| { - if socket.can_recv() { + if is_peer_closed(socket) { + // Only the sending side of peer socket is closed + self.is_receiving_closed.store(true, Ordering::Relaxed); + } else if is_closed(socket) { + // The sending side of both peer socket and this socket are closed + self.is_receiving_closed.store(true, Ordering::Relaxed); + self.is_sending_closed.store(true, Ordering::Relaxed); + } + + let is_receiving_closed = self.is_receiving_closed.load(Ordering::Relaxed); + let is_sending_closed = self.is_sending_closed.load(Ordering::Relaxed); + + // If the receiving side is closed, always add events IN and RDHUP; + // otherwise, check if the socket can receive. + if is_receiving_closed { + pollee.add_events(IoEvents::IN | IoEvents::RDHUP); + } else if socket.can_recv() { pollee.add_events(IoEvents::IN); } else { pollee.del_events(IoEvents::IN); } - if socket.can_send() { + // If the sending side is closed, always add an OUT event; + // otherwise, check if the socket can send. + if is_sending_closed || socket.can_send() { pollee.add_events(IoEvents::OUT); } else { pollee.del_events(IoEvents::OUT); } + + // If both sending and receiving sides are closed, add a HUP event. + if is_receiving_closed && is_sending_closed { + pollee.add_events(IoEvents::HUP); + } }); } @@ -158,3 +202,21 @@ impl ConnectedStream { }) } } + +/// Checks if the peer socket has closed its sending side. +/// +/// If the sending side of this socket is also closed, this method will return `false`. +/// In such cases, you should verify using [`is_closed`]. +fn is_peer_closed(socket: &RawTcpSocket) -> bool { + socket.state() == TcpState::CloseWait +} + +/// Checks if the socket is fully closed. +/// +/// This function returns `true` if both this socket and the peer have closed their sending sides. +/// +/// This TCP state corresponds to the `Normal Close Sequence` and `Simultaneous Close Sequence` +/// as outlined in RFC793 (https://datatracker.ietf.org/doc/html/rfc793#page-39). +fn is_closed(socket: &RawTcpSocket) -> bool { + !socket.is_open() || socket.state() == TcpState::Closing || socket.state() == TcpState::LastAck +} diff --git a/kernel/src/net/socket/ip/stream/init.rs b/kernel/src/net/socket/ip/stream/init.rs index 1650318f..d188099c 100644 --- a/kernel/src/net/socket/ip/stream/init.rs +++ b/kernel/src/net/socket/ip/stream/init.rs @@ -103,6 +103,7 @@ impl InitStream { pub(super) fn init_pollee(&self, pollee: &Pollee) { pollee.reset_events(); - pollee.add_events(IoEvents::OUT); + // Linux adds OUT and HUP events for a newly created socket + pollee.add_events(IoEvents::OUT | IoEvents::HUP); } } diff --git a/kernel/src/net/socket/ip/stream/mod.rs b/kernel/src/net/socket/ip/stream/mod.rs index 3077710f..b0dc1434 100644 --- a/kernel/src/net/socket/ip/stream/mod.rs +++ b/kernel/src/net/socket/ip/stream/mod.rs @@ -511,11 +511,16 @@ impl Socket for StreamSocket { fn shutdown(&self, cmd: SockShutdownCmd) -> Result<()> { let state = self.read_updated_state(); - match state.as_ref() { - State::Connected(connected_stream) => connected_stream.shutdown(cmd), + let res = match state.as_ref() { + State::Connected(connected_stream) => connected_stream.shutdown(cmd, &self.pollee), // TODO: shutdown listening stream _ => return_errno_with_message!(Errno::EINVAL, "cannot shutdown"), - } + }; + + drop(state); + poll_ifaces(); + + res } fn addr(&self) -> Result {