diff --git a/kernel/src/net/socket/ip/datagram/mod.rs b/kernel/src/net/socket/ip/datagram/mod.rs index b1989bc08..b27634af2 100644 --- a/kernel/src/net/socket/ip/datagram/mod.rs +++ b/kernel/src/net/socket/ip/datagram/mod.rs @@ -130,7 +130,7 @@ impl DatagramSocket { } // Slow path - let mut inner = self.inner.write(); + let mut inner = self.inner.write_irq_disabled(); inner.borrow_result(|owned_inner| { let bound_datagram = match owned_inner.bind_to_ephemeral_endpoint(remote_endpoint) { Ok(bound_datagram) => bound_datagram, @@ -278,7 +278,7 @@ impl Socket for DatagramSocket { let endpoint = socket_addr.try_into()?; let can_reuse = self.options.read().socket.reuse_addr(); - let mut inner = self.inner.write(); + let mut inner = self.inner.write_irq_disabled(); inner.borrow_result(|owned_inner| { let bound_datagram = match owned_inner.bind(&endpoint, can_reuse) { Ok(bound_datagram) => bound_datagram, @@ -296,7 +296,7 @@ impl Socket for DatagramSocket { self.try_bind_ephemeral(&endpoint)?; - let mut inner = self.inner.write(); + let mut inner = self.inner.write_irq_disabled(); let Inner::Bound(bound_datagram) = inner.as_mut() else { return_errno_with_message!(Errno::EINVAL, "the socket is not bound") }; diff --git a/kernel/src/net/socket/ip/stream/connecting.rs b/kernel/src/net/socket/ip/stream/connecting.rs index 8a3141208..d2ca8d7ff 100644 --- a/kernel/src/net/socket/ip/stream/connecting.rs +++ b/kernel/src/net/socket/ip/stream/connecting.rs @@ -1,14 +1,15 @@ // SPDX-License-Identifier: MPL-2.0 use aster_bigtcp::wire::IpEndpoint; +use ostd::sync::LocalIrqDisabled; use super::{connected::ConnectedStream, init::InitStream}; -use crate::{net::iface::BoundTcpSocket, prelude::*, process::signal::Pollee}; +use crate::{events::IoEvents, net::iface::BoundTcpSocket, prelude::*, process::signal::Pollee}; pub struct ConnectingStream { bound_socket: BoundTcpSocket, remote_endpoint: IpEndpoint, - conn_result: RwLock>, + conn_result: SpinLock, LocalIrqDisabled>, } #[derive(Clone, Copy)] @@ -17,11 +18,6 @@ enum ConnResult { Refused, } -pub enum NonConnectedStream { - Init(InitStream), - Connecting(ConnectingStream), -} - impl ConnectingStream { pub fn new( bound_socket: BoundTcpSocket, @@ -45,12 +41,16 @@ impl ConnectingStream { Ok(Self { bound_socket, remote_endpoint, - conn_result: RwLock::new(None), + conn_result: SpinLock::new(None), }) } - pub fn into_result(self) -> core::result::Result { - let conn_result = *self.conn_result.read(); + pub fn has_result(&self) -> bool { + self.conn_result.lock().is_some() + } + + pub fn into_result(self) -> core::result::Result { + let conn_result = *self.conn_result.lock(); match conn_result { Some(ConnResult::Connected) => Ok(ConnectedStream::new( self.bound_socket, @@ -59,12 +59,9 @@ impl ConnectingStream { )), Some(ConnResult::Refused) => Err(( Error::with_message(Errno::ECONNREFUSED, "the connection is refused"), - NonConnectedStream::Init(InitStream::new_bound(self.bound_socket)), - )), - None => Err(( - Error::with_message(Errno::EAGAIN, "the connection is pending"), - NonConnectedStream::Connecting(self), + InitStream::new_bound(self.bound_socket), )), + None => unreachable!("`has_result` must be true before calling `into_result`"), } } @@ -80,35 +77,39 @@ impl ConnectingStream { pollee.reset_events(); } - /// Returns `true` when `conn_result` becomes ready, which indicates that the caller should - /// invoke the `into_result()` method as soon as possible. - /// - /// Since `into_result()` needs to be called only once, this method will return `true` - /// _exactly_ once. The caller is responsible for not missing this event. - #[must_use] - pub(super) fn update_io_events(&self) -> bool { - if self.conn_result.read().is_some() { - return false; + pub(super) fn update_io_events(&self, pollee: &Pollee) { + if self.conn_result.lock().is_some() { + return; } self.bound_socket.raw_with(|socket| { - let mut result = self.conn_result.write(); + let mut result = self.conn_result.lock(); if result.is_some() { - return false; + return; } // Connected if socket.can_send() { *result = Some(ConnResult::Connected); - return true; + pollee.add_events(IoEvents::OUT); + return; } // Connecting if socket.is_open() { - return false; + return; } // Refused *result = Some(ConnResult::Refused); - true + pollee.add_events(IoEvents::OUT); + + // Add `IoEvents::OUT` because the man pages say "EINPROGRESS [..] It is possible to + // select(2) or poll(2) for completion by selecting the socket for writing". For + // details, see . + // + // TODO: It is better to do the state transition and let `ConnectedStream` or + // `InitStream` set the correct I/O events. However, the state transition is delayed + // because we're probably in IRQ handlers. Maybe mark the `pollee` as obsolete and + // re-calculate the I/O events in `poll`. }) } } diff --git a/kernel/src/net/socket/ip/stream/listen.rs b/kernel/src/net/socket/ip/stream/listen.rs index 51489d7f5..f701ddff5 100644 --- a/kernel/src/net/socket/ip/stream/listen.rs +++ b/kernel/src/net/socket/ip/stream/listen.rs @@ -36,7 +36,7 @@ impl ListenStream { /// Append sockets listening at LocalEndPoint to support backlog fn fill_backlog_sockets(&self) -> Result<()> { - let mut backlog_sockets = self.backlog_sockets.write(); + let mut backlog_sockets = self.backlog_sockets.write_irq_disabled(); let backlog = self.backlog; let current_backlog_len = backlog_sockets.len(); @@ -54,7 +54,7 @@ impl ListenStream { } pub fn try_accept(&self) -> Result { - let mut backlog_sockets = self.backlog_sockets.write(); + let mut backlog_sockets = self.backlog_sockets.write_irq_disabled(); let index = backlog_sockets .iter() diff --git a/kernel/src/net/socket/ip/stream/mod.rs b/kernel/src/net/socket/ip/stream/mod.rs index 2fcb0d81f..531319c80 100644 --- a/kernel/src/net/socket/ip/stream/mod.rs +++ b/kernel/src/net/socket/ip/stream/mod.rs @@ -8,8 +8,9 @@ use connecting::ConnectingStream; use init::InitStream; use listen::ListenStream; use options::{Congestion, MaxSegment, NoDelay, WindowClamp}; +use ostd::sync::{RwLockReadGuard, RwLockWriteGuard}; use takeable::Takeable; -use util::{TcpOptionSet, DEFAULT_MAXSEG}; +use util::TcpOptionSet; use super::UNSPECIFIED_LOCAL_ENDPOINT; use crate::{ @@ -39,7 +40,6 @@ mod listen; pub mod options; mod util; -use self::connecting::NonConnectedStream; pub use self::util::CongestionControl; pub struct StreamSocket { @@ -111,11 +111,79 @@ impl StreamSocket { self.is_nonblocking.store(nonblocking, Ordering::Relaxed); } + /// Ensures that the socket state is up to date and obtains a read lock on it. + /// + /// For a description of what "up-to-date" means, see [`Self::update_connecting`]. + fn read_updated_state(&self) -> RwLockReadGuard> { + loop { + let state = self.state.read(); + match state.as_ref() { + State::Connecting(connecting_stream) if connecting_stream.has_result() => (), + _ => return state, + }; + drop(state); + + self.update_connecting(); + } + } + + /// Ensures that the socket state is up to date and obtains a write lock on it. + /// + /// For a description of what "up-to-date" means, see [`Self::update_connecting`]. + fn write_updated_state(&self) -> RwLockWriteGuard> { + self.update_connecting().1 + } + + /// Updates the socket state if the socket is an obsolete connecting socket. + /// + /// A connecting socket can become obsolete because some network events can set the socket to + /// connected state (if the connection succeeds) or initial state (if the connection is + /// refused) in [`Self::update_io_events`], but the state transition is delayed until the user + /// operates on the socket to avoid too many locks in the interrupt handler. + /// + /// This method performs the delayed state transition to ensure that the state is up to date + /// and returns the guards of the write-locked options and state. + fn update_connecting( + &self, + ) -> ( + RwLockWriteGuard, + RwLockWriteGuard>, + ) { + // Hold the lock in advance to avoid race conditions. + let mut options = self.options.write(); + let mut state = self.state.write_irq_disabled(); + + match state.as_ref() { + State::Connecting(connection_stream) if connection_stream.has_result() => (), + _ => return (options, state), + } + + let result = state.borrow_result(|owned_state| { + let State::Connecting(connecting_stream) = owned_state else { + unreachable!("`State::Connecting` is checked before calling `borrow_result`"); + }; + + let connected_stream = match connecting_stream.into_result() { + Ok(connected_stream) => connected_stream, + Err((err, init_stream)) => { + init_stream.init_pollee(&self.pollee); + return (State::Init(init_stream), Err(err)); + } + }; + connected_stream.init_pollee(&self.pollee); + + (State::Connected(connected_stream), Ok(())) + }); + options.socket.set_sock_errors(result.err()); + + (options, state) + } + // Returns `None` to block the task and wait for the connection to be established, and returns // `Some(_)` if blocking is not necessary or not allowed. fn start_connect(&self, remote_endpoint: &IpEndpoint) -> Option> { let is_nonblocking = self.is_nonblocking(); - let mut state = self.state.write(); + let mut state = self.write_updated_state(); let result_or_block = state.borrow_result(|mut owned_state| { let init_stream = match owned_state { @@ -174,41 +242,8 @@ impl StreamSocket { result_or_block } - fn finish_connect(&self) -> Result<()> { - let mut state = self.state.write(); - - state.borrow_result(|owned_state| { - let State::Connecting(connecting_stream) = owned_state else { - debug_assert!(false, "the socket unexpectedly left the connecting state"); - return ( - owned_state, - Err(Error::with_message( - Errno::EINVAL, - "the socket is not connecting", - )), - ); - }; - - let connected_stream = match connecting_stream.into_result() { - Ok(connected_stream) => connected_stream, - Err((err, NonConnectedStream::Init(init_stream))) => { - init_stream.init_pollee(&self.pollee); - return (State::Init(init_stream), Err(err)); - } - Err((err, NonConnectedStream::Connecting(connecting_stream))) => { - return (State::Connecting(connecting_stream), Err(err)); - } - }; - connected_stream.init_pollee(&self.pollee); - - (State::Connected(connected_stream), Ok(())) - }) - } - fn check_connect(&self) -> Result<()> { - // Hold the lock in advance to avoid deadlocks. - let mut options = self.options.write(); - let mut state = self.state.write(); + let (mut options, mut state) = self.update_connecting(); match state.as_mut() { State::Connecting(_) => { @@ -224,7 +259,7 @@ impl StreamSocket { } fn try_accept(&self) -> Result<(Arc, SocketAddr)> { - let state = self.state.read(); + let state = self.read_updated_state(); let State::Listen(listen_stream) = state.as_ref() else { return_errno_with_message!(Errno::EINVAL, "the socket is not listening"); @@ -244,7 +279,7 @@ impl StreamSocket { writer: &mut dyn MultiWrite, flags: SendRecvFlags, ) -> Result<(usize, SocketAddr)> { - let state = self.state.read(); + let state = self.read_updated_state(); let connected_stream = match state.as_ref() { State::Connected(connected_stream) => connected_stream, @@ -280,7 +315,7 @@ impl StreamSocket { } fn try_send(&self, reader: &mut dyn MultiRead, flags: SendRecvFlags) -> Result { - let state = self.state.read(); + let state = self.read_updated_state(); let connected_stream = match state.as_ref() { State::Connected(connected_stream) => connected_stream, @@ -311,21 +346,24 @@ impl StreamSocket { } } - #[must_use] - fn update_io_events(&self) -> bool { + fn update_io_events(&self) { let state = self.state.read(); match state.as_ref() { - State::Init(_) => false, - State::Connecting(connecting_stream) => connecting_stream.update_io_events(), + State::Init(_) => (), + State::Connecting(connecting_stream) => { + connecting_stream.update_io_events(&self.pollee) + } State::Listen(listen_stream) => { listen_stream.update_io_events(&self.pollee); - false } State::Connected(connected_stream) => { connected_stream.update_io_events(&self.pollee); - false } } + + // Note: Network events can cause a state transition from `State::Connecting` to + // `State::Connected`/`State::Init`. The state transition is delayed until + // `update_connecting`is triggered by user events, see that method for details. } } @@ -392,7 +430,7 @@ impl Socket for StreamSocket { let endpoint = socket_addr.try_into()?; let can_reuse = self.options.read().socket.reuse_addr(); - let mut state = self.state.write(); + let mut state = self.write_updated_state(); state.borrow_result(|owned_state| { let State::Init(init_stream) = owned_state else { @@ -427,7 +465,7 @@ impl Socket for StreamSocket { } fn listen(&self, backlog: usize) -> Result<()> { - let mut state = self.state.write(); + let mut state = self.write_updated_state(); state.borrow_result(|owned_state| { let init_stream = match owned_state { @@ -467,7 +505,7 @@ impl Socket for StreamSocket { } fn shutdown(&self, cmd: SockShutdownCmd) -> Result<()> { - let state = self.state.read(); + let state = self.read_updated_state(); match state.as_ref() { State::Connected(connected_stream) => connected_stream.shutdown(cmd), // TODO: shutdown listening stream @@ -476,7 +514,7 @@ impl Socket for StreamSocket { } fn addr(&self) -> Result { - let state = self.state.read(); + let state = self.read_updated_state(); let local_endpoint = match state.as_ref() { State::Init(init_stream) => init_stream .local_endpoint() @@ -489,7 +527,7 @@ impl Socket for StreamSocket { } fn peer_addr(&self) -> Result { - let state = self.state.read(); + let state = self.read_updated_state(); let remote_endpoint = match state.as_ref() { State::Init(_) | State::Listen(_) => { return_errno_with_message!(Errno::ENOTCONN, "the socket is not connected") @@ -547,7 +585,8 @@ impl Socket for StreamSocket { fn get_option(&self, option: &mut dyn SocketOption) -> Result<()> { match_sock_option_mut!(option, { socket_errors: SocketError => { - self.options.write().socket.get_and_clear_sock_errors(socket_errors); + let mut options = self.update_connecting().0; + options.socket.get_and_clear_sock_errors(socket_errors); return Ok(()); }, _ => () @@ -632,15 +671,7 @@ impl Socket for StreamSocket { impl SocketEventObserver for StreamSocket { fn on_events(&self) { - let conn_ready = self.update_io_events(); - - if conn_ready { - // Hold the lock in advance to avoid race conditions. - let mut options = self.options.write(); - - let result = self.finish_connect(); - options.socket.set_sock_errors(result.err()); - } + self.update_io_events(); } }