Fix IRQ-related socket locks

This commit is contained in:
Ruihan Li 2024-10-03 12:24:55 +08:00 committed by Tate, Hongliang Tian
parent 7e21e1e653
commit 0415225c19
4 changed files with 127 additions and 95 deletions

View File

@ -130,7 +130,7 @@ impl DatagramSocket {
}
// Slow path
let mut inner = self.inner.write();
let mut inner = self.inner.write_irq_disabled();
inner.borrow_result(|owned_inner| {
let bound_datagram = match owned_inner.bind_to_ephemeral_endpoint(remote_endpoint) {
Ok(bound_datagram) => bound_datagram,
@ -278,7 +278,7 @@ impl Socket for DatagramSocket {
let endpoint = socket_addr.try_into()?;
let can_reuse = self.options.read().socket.reuse_addr();
let mut inner = self.inner.write();
let mut inner = self.inner.write_irq_disabled();
inner.borrow_result(|owned_inner| {
let bound_datagram = match owned_inner.bind(&endpoint, can_reuse) {
Ok(bound_datagram) => bound_datagram,
@ -296,7 +296,7 @@ impl Socket for DatagramSocket {
self.try_bind_ephemeral(&endpoint)?;
let mut inner = self.inner.write();
let mut inner = self.inner.write_irq_disabled();
let Inner::Bound(bound_datagram) = inner.as_mut() else {
return_errno_with_message!(Errno::EINVAL, "the socket is not bound")
};

View File

@ -1,14 +1,15 @@
// SPDX-License-Identifier: MPL-2.0
use aster_bigtcp::wire::IpEndpoint;
use ostd::sync::LocalIrqDisabled;
use super::{connected::ConnectedStream, init::InitStream};
use crate::{net::iface::BoundTcpSocket, prelude::*, process::signal::Pollee};
use crate::{events::IoEvents, net::iface::BoundTcpSocket, prelude::*, process::signal::Pollee};
pub struct ConnectingStream {
bound_socket: BoundTcpSocket,
remote_endpoint: IpEndpoint,
conn_result: RwLock<Option<ConnResult>>,
conn_result: SpinLock<Option<ConnResult>, LocalIrqDisabled>,
}
#[derive(Clone, Copy)]
@ -17,11 +18,6 @@ enum ConnResult {
Refused,
}
pub enum NonConnectedStream {
Init(InitStream),
Connecting(ConnectingStream),
}
impl ConnectingStream {
pub fn new(
bound_socket: BoundTcpSocket,
@ -45,12 +41,16 @@ impl ConnectingStream {
Ok(Self {
bound_socket,
remote_endpoint,
conn_result: RwLock::new(None),
conn_result: SpinLock::new(None),
})
}
pub fn into_result(self) -> core::result::Result<ConnectedStream, (Error, NonConnectedStream)> {
let conn_result = *self.conn_result.read();
pub fn has_result(&self) -> bool {
self.conn_result.lock().is_some()
}
pub fn into_result(self) -> core::result::Result<ConnectedStream, (Error, InitStream)> {
let conn_result = *self.conn_result.lock();
match conn_result {
Some(ConnResult::Connected) => Ok(ConnectedStream::new(
self.bound_socket,
@ -59,12 +59,9 @@ impl ConnectingStream {
)),
Some(ConnResult::Refused) => Err((
Error::with_message(Errno::ECONNREFUSED, "the connection is refused"),
NonConnectedStream::Init(InitStream::new_bound(self.bound_socket)),
)),
None => Err((
Error::with_message(Errno::EAGAIN, "the connection is pending"),
NonConnectedStream::Connecting(self),
InitStream::new_bound(self.bound_socket),
)),
None => unreachable!("`has_result` must be true before calling `into_result`"),
}
}
@ -80,35 +77,39 @@ impl ConnectingStream {
pollee.reset_events();
}
/// Returns `true` when `conn_result` becomes ready, which indicates that the caller should
/// invoke the `into_result()` method as soon as possible.
///
/// Since `into_result()` needs to be called only once, this method will return `true`
/// _exactly_ once. The caller is responsible for not missing this event.
#[must_use]
pub(super) fn update_io_events(&self) -> bool {
if self.conn_result.read().is_some() {
return false;
pub(super) fn update_io_events(&self, pollee: &Pollee) {
if self.conn_result.lock().is_some() {
return;
}
self.bound_socket.raw_with(|socket| {
let mut result = self.conn_result.write();
let mut result = self.conn_result.lock();
if result.is_some() {
return false;
return;
}
// Connected
if socket.can_send() {
*result = Some(ConnResult::Connected);
return true;
pollee.add_events(IoEvents::OUT);
return;
}
// Connecting
if socket.is_open() {
return false;
return;
}
// Refused
*result = Some(ConnResult::Refused);
true
pollee.add_events(IoEvents::OUT);
// Add `IoEvents::OUT` because the man pages say "EINPROGRESS [..] It is possible to
// select(2) or poll(2) for completion by selecting the socket for writing". For
// details, see <https://man7.org/linux/man-pages/man2/connect.2.html>.
//
// TODO: It is better to do the state transition and let `ConnectedStream` or
// `InitStream` set the correct I/O events. However, the state transition is delayed
// because we're probably in IRQ handlers. Maybe mark the `pollee` as obsolete and
// re-calculate the I/O events in `poll`.
})
}
}

View File

@ -36,7 +36,7 @@ impl ListenStream {
/// Append sockets listening at LocalEndPoint to support backlog
fn fill_backlog_sockets(&self) -> Result<()> {
let mut backlog_sockets = self.backlog_sockets.write();
let mut backlog_sockets = self.backlog_sockets.write_irq_disabled();
let backlog = self.backlog;
let current_backlog_len = backlog_sockets.len();
@ -54,7 +54,7 @@ impl ListenStream {
}
pub fn try_accept(&self) -> Result<ConnectedStream> {
let mut backlog_sockets = self.backlog_sockets.write();
let mut backlog_sockets = self.backlog_sockets.write_irq_disabled();
let index = backlog_sockets
.iter()

View File

@ -8,8 +8,9 @@ use connecting::ConnectingStream;
use init::InitStream;
use listen::ListenStream;
use options::{Congestion, MaxSegment, NoDelay, WindowClamp};
use ostd::sync::{RwLockReadGuard, RwLockWriteGuard};
use takeable::Takeable;
use util::{TcpOptionSet, DEFAULT_MAXSEG};
use util::TcpOptionSet;
use super::UNSPECIFIED_LOCAL_ENDPOINT;
use crate::{
@ -39,7 +40,6 @@ mod listen;
pub mod options;
mod util;
use self::connecting::NonConnectedStream;
pub use self::util::CongestionControl;
pub struct StreamSocket {
@ -111,11 +111,79 @@ impl StreamSocket {
self.is_nonblocking.store(nonblocking, Ordering::Relaxed);
}
/// Ensures that the socket state is up to date and obtains a read lock on it.
///
/// For a description of what "up-to-date" means, see [`Self::update_connecting`].
fn read_updated_state(&self) -> RwLockReadGuard<Takeable<State>> {
loop {
let state = self.state.read();
match state.as_ref() {
State::Connecting(connecting_stream) if connecting_stream.has_result() => (),
_ => return state,
};
drop(state);
self.update_connecting();
}
}
/// Ensures that the socket state is up to date and obtains a write lock on it.
///
/// For a description of what "up-to-date" means, see [`Self::update_connecting`].
fn write_updated_state(&self) -> RwLockWriteGuard<Takeable<State>> {
self.update_connecting().1
}
/// Updates the socket state if the socket is an obsolete connecting socket.
///
/// A connecting socket can become obsolete because some network events can set the socket to
/// connected state (if the connection succeeds) or initial state (if the connection is
/// refused) in [`Self::update_io_events`], but the state transition is delayed until the user
/// operates on the socket to avoid too many locks in the interrupt handler.
///
/// This method performs the delayed state transition to ensure that the state is up to date
/// and returns the guards of the write-locked options and state.
fn update_connecting(
&self,
) -> (
RwLockWriteGuard<OptionSet>,
RwLockWriteGuard<Takeable<State>>,
) {
// Hold the lock in advance to avoid race conditions.
let mut options = self.options.write();
let mut state = self.state.write_irq_disabled();
match state.as_ref() {
State::Connecting(connection_stream) if connection_stream.has_result() => (),
_ => return (options, state),
}
let result = state.borrow_result(|owned_state| {
let State::Connecting(connecting_stream) = owned_state else {
unreachable!("`State::Connecting` is checked before calling `borrow_result`");
};
let connected_stream = match connecting_stream.into_result() {
Ok(connected_stream) => connected_stream,
Err((err, init_stream)) => {
init_stream.init_pollee(&self.pollee);
return (State::Init(init_stream), Err(err));
}
};
connected_stream.init_pollee(&self.pollee);
(State::Connected(connected_stream), Ok(()))
});
options.socket.set_sock_errors(result.err());
(options, state)
}
// Returns `None` to block the task and wait for the connection to be established, and returns
// `Some(_)` if blocking is not necessary or not allowed.
fn start_connect(&self, remote_endpoint: &IpEndpoint) -> Option<Result<()>> {
let is_nonblocking = self.is_nonblocking();
let mut state = self.state.write();
let mut state = self.write_updated_state();
let result_or_block = state.borrow_result(|mut owned_state| {
let init_stream = match owned_state {
@ -174,41 +242,8 @@ impl StreamSocket {
result_or_block
}
fn finish_connect(&self) -> Result<()> {
let mut state = self.state.write();
state.borrow_result(|owned_state| {
let State::Connecting(connecting_stream) = owned_state else {
debug_assert!(false, "the socket unexpectedly left the connecting state");
return (
owned_state,
Err(Error::with_message(
Errno::EINVAL,
"the socket is not connecting",
)),
);
};
let connected_stream = match connecting_stream.into_result() {
Ok(connected_stream) => connected_stream,
Err((err, NonConnectedStream::Init(init_stream))) => {
init_stream.init_pollee(&self.pollee);
return (State::Init(init_stream), Err(err));
}
Err((err, NonConnectedStream::Connecting(connecting_stream))) => {
return (State::Connecting(connecting_stream), Err(err));
}
};
connected_stream.init_pollee(&self.pollee);
(State::Connected(connected_stream), Ok(()))
})
}
fn check_connect(&self) -> Result<()> {
// Hold the lock in advance to avoid deadlocks.
let mut options = self.options.write();
let mut state = self.state.write();
let (mut options, mut state) = self.update_connecting();
match state.as_mut() {
State::Connecting(_) => {
@ -224,7 +259,7 @@ impl StreamSocket {
}
fn try_accept(&self) -> Result<(Arc<dyn FileLike>, SocketAddr)> {
let state = self.state.read();
let state = self.read_updated_state();
let State::Listen(listen_stream) = state.as_ref() else {
return_errno_with_message!(Errno::EINVAL, "the socket is not listening");
@ -244,7 +279,7 @@ impl StreamSocket {
writer: &mut dyn MultiWrite,
flags: SendRecvFlags,
) -> Result<(usize, SocketAddr)> {
let state = self.state.read();
let state = self.read_updated_state();
let connected_stream = match state.as_ref() {
State::Connected(connected_stream) => connected_stream,
@ -280,7 +315,7 @@ impl StreamSocket {
}
fn try_send(&self, reader: &mut dyn MultiRead, flags: SendRecvFlags) -> Result<usize> {
let state = self.state.read();
let state = self.read_updated_state();
let connected_stream = match state.as_ref() {
State::Connected(connected_stream) => connected_stream,
@ -311,21 +346,24 @@ impl StreamSocket {
}
}
#[must_use]
fn update_io_events(&self) -> bool {
fn update_io_events(&self) {
let state = self.state.read();
match state.as_ref() {
State::Init(_) => false,
State::Connecting(connecting_stream) => connecting_stream.update_io_events(),
State::Init(_) => (),
State::Connecting(connecting_stream) => {
connecting_stream.update_io_events(&self.pollee)
}
State::Listen(listen_stream) => {
listen_stream.update_io_events(&self.pollee);
false
}
State::Connected(connected_stream) => {
connected_stream.update_io_events(&self.pollee);
false
}
}
// Note: Network events can cause a state transition from `State::Connecting` to
// `State::Connected`/`State::Init`. The state transition is delayed until
// `update_connecting`is triggered by user events, see that method for details.
}
}
@ -392,7 +430,7 @@ impl Socket for StreamSocket {
let endpoint = socket_addr.try_into()?;
let can_reuse = self.options.read().socket.reuse_addr();
let mut state = self.state.write();
let mut state = self.write_updated_state();
state.borrow_result(|owned_state| {
let State::Init(init_stream) = owned_state else {
@ -427,7 +465,7 @@ impl Socket for StreamSocket {
}
fn listen(&self, backlog: usize) -> Result<()> {
let mut state = self.state.write();
let mut state = self.write_updated_state();
state.borrow_result(|owned_state| {
let init_stream = match owned_state {
@ -467,7 +505,7 @@ impl Socket for StreamSocket {
}
fn shutdown(&self, cmd: SockShutdownCmd) -> Result<()> {
let state = self.state.read();
let state = self.read_updated_state();
match state.as_ref() {
State::Connected(connected_stream) => connected_stream.shutdown(cmd),
// TODO: shutdown listening stream
@ -476,7 +514,7 @@ impl Socket for StreamSocket {
}
fn addr(&self) -> Result<SocketAddr> {
let state = self.state.read();
let state = self.read_updated_state();
let local_endpoint = match state.as_ref() {
State::Init(init_stream) => init_stream
.local_endpoint()
@ -489,7 +527,7 @@ impl Socket for StreamSocket {
}
fn peer_addr(&self) -> Result<SocketAddr> {
let state = self.state.read();
let state = self.read_updated_state();
let remote_endpoint = match state.as_ref() {
State::Init(_) | State::Listen(_) => {
return_errno_with_message!(Errno::ENOTCONN, "the socket is not connected")
@ -547,7 +585,8 @@ impl Socket for StreamSocket {
fn get_option(&self, option: &mut dyn SocketOption) -> Result<()> {
match_sock_option_mut!(option, {
socket_errors: SocketError => {
self.options.write().socket.get_and_clear_sock_errors(socket_errors);
let mut options = self.update_connecting().0;
options.socket.get_and_clear_sock_errors(socket_errors);
return Ok(());
},
_ => ()
@ -632,15 +671,7 @@ impl Socket for StreamSocket {
impl SocketEventObserver for StreamSocket {
fn on_events(&self) {
let conn_ready = self.update_io_events();
if conn_ready {
// Hold the lock in advance to avoid race conditions.
let mut options = self.options.write();
let result = self.finish_connect();
options.socket.set_sock_errors(result.err());
}
self.update_io_events();
}
}