mirror of
https://github.com/asterinas/asterinas.git
synced 2025-06-17 12:47:16 +00:00
Fix IRQ-related socket locks
This commit is contained in:
parent
7e21e1e653
commit
0415225c19
@ -130,7 +130,7 @@ impl DatagramSocket {
|
||||
}
|
||||
|
||||
// Slow path
|
||||
let mut inner = self.inner.write();
|
||||
let mut inner = self.inner.write_irq_disabled();
|
||||
inner.borrow_result(|owned_inner| {
|
||||
let bound_datagram = match owned_inner.bind_to_ephemeral_endpoint(remote_endpoint) {
|
||||
Ok(bound_datagram) => bound_datagram,
|
||||
@ -278,7 +278,7 @@ impl Socket for DatagramSocket {
|
||||
let endpoint = socket_addr.try_into()?;
|
||||
|
||||
let can_reuse = self.options.read().socket.reuse_addr();
|
||||
let mut inner = self.inner.write();
|
||||
let mut inner = self.inner.write_irq_disabled();
|
||||
inner.borrow_result(|owned_inner| {
|
||||
let bound_datagram = match owned_inner.bind(&endpoint, can_reuse) {
|
||||
Ok(bound_datagram) => bound_datagram,
|
||||
@ -296,7 +296,7 @@ impl Socket for DatagramSocket {
|
||||
|
||||
self.try_bind_ephemeral(&endpoint)?;
|
||||
|
||||
let mut inner = self.inner.write();
|
||||
let mut inner = self.inner.write_irq_disabled();
|
||||
let Inner::Bound(bound_datagram) = inner.as_mut() else {
|
||||
return_errno_with_message!(Errno::EINVAL, "the socket is not bound")
|
||||
};
|
||||
|
@ -1,14 +1,15 @@
|
||||
// SPDX-License-Identifier: MPL-2.0
|
||||
|
||||
use aster_bigtcp::wire::IpEndpoint;
|
||||
use ostd::sync::LocalIrqDisabled;
|
||||
|
||||
use super::{connected::ConnectedStream, init::InitStream};
|
||||
use crate::{net::iface::BoundTcpSocket, prelude::*, process::signal::Pollee};
|
||||
use crate::{events::IoEvents, net::iface::BoundTcpSocket, prelude::*, process::signal::Pollee};
|
||||
|
||||
pub struct ConnectingStream {
|
||||
bound_socket: BoundTcpSocket,
|
||||
remote_endpoint: IpEndpoint,
|
||||
conn_result: RwLock<Option<ConnResult>>,
|
||||
conn_result: SpinLock<Option<ConnResult>, LocalIrqDisabled>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
@ -17,11 +18,6 @@ enum ConnResult {
|
||||
Refused,
|
||||
}
|
||||
|
||||
pub enum NonConnectedStream {
|
||||
Init(InitStream),
|
||||
Connecting(ConnectingStream),
|
||||
}
|
||||
|
||||
impl ConnectingStream {
|
||||
pub fn new(
|
||||
bound_socket: BoundTcpSocket,
|
||||
@ -45,12 +41,16 @@ impl ConnectingStream {
|
||||
Ok(Self {
|
||||
bound_socket,
|
||||
remote_endpoint,
|
||||
conn_result: RwLock::new(None),
|
||||
conn_result: SpinLock::new(None),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn into_result(self) -> core::result::Result<ConnectedStream, (Error, NonConnectedStream)> {
|
||||
let conn_result = *self.conn_result.read();
|
||||
pub fn has_result(&self) -> bool {
|
||||
self.conn_result.lock().is_some()
|
||||
}
|
||||
|
||||
pub fn into_result(self) -> core::result::Result<ConnectedStream, (Error, InitStream)> {
|
||||
let conn_result = *self.conn_result.lock();
|
||||
match conn_result {
|
||||
Some(ConnResult::Connected) => Ok(ConnectedStream::new(
|
||||
self.bound_socket,
|
||||
@ -59,12 +59,9 @@ impl ConnectingStream {
|
||||
)),
|
||||
Some(ConnResult::Refused) => Err((
|
||||
Error::with_message(Errno::ECONNREFUSED, "the connection is refused"),
|
||||
NonConnectedStream::Init(InitStream::new_bound(self.bound_socket)),
|
||||
)),
|
||||
None => Err((
|
||||
Error::with_message(Errno::EAGAIN, "the connection is pending"),
|
||||
NonConnectedStream::Connecting(self),
|
||||
InitStream::new_bound(self.bound_socket),
|
||||
)),
|
||||
None => unreachable!("`has_result` must be true before calling `into_result`"),
|
||||
}
|
||||
}
|
||||
|
||||
@ -80,35 +77,39 @@ impl ConnectingStream {
|
||||
pollee.reset_events();
|
||||
}
|
||||
|
||||
/// Returns `true` when `conn_result` becomes ready, which indicates that the caller should
|
||||
/// invoke the `into_result()` method as soon as possible.
|
||||
///
|
||||
/// Since `into_result()` needs to be called only once, this method will return `true`
|
||||
/// _exactly_ once. The caller is responsible for not missing this event.
|
||||
#[must_use]
|
||||
pub(super) fn update_io_events(&self) -> bool {
|
||||
if self.conn_result.read().is_some() {
|
||||
return false;
|
||||
pub(super) fn update_io_events(&self, pollee: &Pollee) {
|
||||
if self.conn_result.lock().is_some() {
|
||||
return;
|
||||
}
|
||||
|
||||
self.bound_socket.raw_with(|socket| {
|
||||
let mut result = self.conn_result.write();
|
||||
let mut result = self.conn_result.lock();
|
||||
if result.is_some() {
|
||||
return false;
|
||||
return;
|
||||
}
|
||||
|
||||
// Connected
|
||||
if socket.can_send() {
|
||||
*result = Some(ConnResult::Connected);
|
||||
return true;
|
||||
pollee.add_events(IoEvents::OUT);
|
||||
return;
|
||||
}
|
||||
// Connecting
|
||||
if socket.is_open() {
|
||||
return false;
|
||||
return;
|
||||
}
|
||||
// Refused
|
||||
*result = Some(ConnResult::Refused);
|
||||
true
|
||||
pollee.add_events(IoEvents::OUT);
|
||||
|
||||
// Add `IoEvents::OUT` because the man pages say "EINPROGRESS [..] It is possible to
|
||||
// select(2) or poll(2) for completion by selecting the socket for writing". For
|
||||
// details, see <https://man7.org/linux/man-pages/man2/connect.2.html>.
|
||||
//
|
||||
// TODO: It is better to do the state transition and let `ConnectedStream` or
|
||||
// `InitStream` set the correct I/O events. However, the state transition is delayed
|
||||
// because we're probably in IRQ handlers. Maybe mark the `pollee` as obsolete and
|
||||
// re-calculate the I/O events in `poll`.
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -36,7 +36,7 @@ impl ListenStream {
|
||||
|
||||
/// Append sockets listening at LocalEndPoint to support backlog
|
||||
fn fill_backlog_sockets(&self) -> Result<()> {
|
||||
let mut backlog_sockets = self.backlog_sockets.write();
|
||||
let mut backlog_sockets = self.backlog_sockets.write_irq_disabled();
|
||||
|
||||
let backlog = self.backlog;
|
||||
let current_backlog_len = backlog_sockets.len();
|
||||
@ -54,7 +54,7 @@ impl ListenStream {
|
||||
}
|
||||
|
||||
pub fn try_accept(&self) -> Result<ConnectedStream> {
|
||||
let mut backlog_sockets = self.backlog_sockets.write();
|
||||
let mut backlog_sockets = self.backlog_sockets.write_irq_disabled();
|
||||
|
||||
let index = backlog_sockets
|
||||
.iter()
|
||||
|
@ -8,8 +8,9 @@ use connecting::ConnectingStream;
|
||||
use init::InitStream;
|
||||
use listen::ListenStream;
|
||||
use options::{Congestion, MaxSegment, NoDelay, WindowClamp};
|
||||
use ostd::sync::{RwLockReadGuard, RwLockWriteGuard};
|
||||
use takeable::Takeable;
|
||||
use util::{TcpOptionSet, DEFAULT_MAXSEG};
|
||||
use util::TcpOptionSet;
|
||||
|
||||
use super::UNSPECIFIED_LOCAL_ENDPOINT;
|
||||
use crate::{
|
||||
@ -39,7 +40,6 @@ mod listen;
|
||||
pub mod options;
|
||||
mod util;
|
||||
|
||||
use self::connecting::NonConnectedStream;
|
||||
pub use self::util::CongestionControl;
|
||||
|
||||
pub struct StreamSocket {
|
||||
@ -111,11 +111,79 @@ impl StreamSocket {
|
||||
self.is_nonblocking.store(nonblocking, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
/// Ensures that the socket state is up to date and obtains a read lock on it.
|
||||
///
|
||||
/// For a description of what "up-to-date" means, see [`Self::update_connecting`].
|
||||
fn read_updated_state(&self) -> RwLockReadGuard<Takeable<State>> {
|
||||
loop {
|
||||
let state = self.state.read();
|
||||
match state.as_ref() {
|
||||
State::Connecting(connecting_stream) if connecting_stream.has_result() => (),
|
||||
_ => return state,
|
||||
};
|
||||
drop(state);
|
||||
|
||||
self.update_connecting();
|
||||
}
|
||||
}
|
||||
|
||||
/// Ensures that the socket state is up to date and obtains a write lock on it.
|
||||
///
|
||||
/// For a description of what "up-to-date" means, see [`Self::update_connecting`].
|
||||
fn write_updated_state(&self) -> RwLockWriteGuard<Takeable<State>> {
|
||||
self.update_connecting().1
|
||||
}
|
||||
|
||||
/// Updates the socket state if the socket is an obsolete connecting socket.
|
||||
///
|
||||
/// A connecting socket can become obsolete because some network events can set the socket to
|
||||
/// connected state (if the connection succeeds) or initial state (if the connection is
|
||||
/// refused) in [`Self::update_io_events`], but the state transition is delayed until the user
|
||||
/// operates on the socket to avoid too many locks in the interrupt handler.
|
||||
///
|
||||
/// This method performs the delayed state transition to ensure that the state is up to date
|
||||
/// and returns the guards of the write-locked options and state.
|
||||
fn update_connecting(
|
||||
&self,
|
||||
) -> (
|
||||
RwLockWriteGuard<OptionSet>,
|
||||
RwLockWriteGuard<Takeable<State>>,
|
||||
) {
|
||||
// Hold the lock in advance to avoid race conditions.
|
||||
let mut options = self.options.write();
|
||||
let mut state = self.state.write_irq_disabled();
|
||||
|
||||
match state.as_ref() {
|
||||
State::Connecting(connection_stream) if connection_stream.has_result() => (),
|
||||
_ => return (options, state),
|
||||
}
|
||||
|
||||
let result = state.borrow_result(|owned_state| {
|
||||
let State::Connecting(connecting_stream) = owned_state else {
|
||||
unreachable!("`State::Connecting` is checked before calling `borrow_result`");
|
||||
};
|
||||
|
||||
let connected_stream = match connecting_stream.into_result() {
|
||||
Ok(connected_stream) => connected_stream,
|
||||
Err((err, init_stream)) => {
|
||||
init_stream.init_pollee(&self.pollee);
|
||||
return (State::Init(init_stream), Err(err));
|
||||
}
|
||||
};
|
||||
connected_stream.init_pollee(&self.pollee);
|
||||
|
||||
(State::Connected(connected_stream), Ok(()))
|
||||
});
|
||||
options.socket.set_sock_errors(result.err());
|
||||
|
||||
(options, state)
|
||||
}
|
||||
|
||||
// Returns `None` to block the task and wait for the connection to be established, and returns
|
||||
// `Some(_)` if blocking is not necessary or not allowed.
|
||||
fn start_connect(&self, remote_endpoint: &IpEndpoint) -> Option<Result<()>> {
|
||||
let is_nonblocking = self.is_nonblocking();
|
||||
let mut state = self.state.write();
|
||||
let mut state = self.write_updated_state();
|
||||
|
||||
let result_or_block = state.borrow_result(|mut owned_state| {
|
||||
let init_stream = match owned_state {
|
||||
@ -174,41 +242,8 @@ impl StreamSocket {
|
||||
result_or_block
|
||||
}
|
||||
|
||||
fn finish_connect(&self) -> Result<()> {
|
||||
let mut state = self.state.write();
|
||||
|
||||
state.borrow_result(|owned_state| {
|
||||
let State::Connecting(connecting_stream) = owned_state else {
|
||||
debug_assert!(false, "the socket unexpectedly left the connecting state");
|
||||
return (
|
||||
owned_state,
|
||||
Err(Error::with_message(
|
||||
Errno::EINVAL,
|
||||
"the socket is not connecting",
|
||||
)),
|
||||
);
|
||||
};
|
||||
|
||||
let connected_stream = match connecting_stream.into_result() {
|
||||
Ok(connected_stream) => connected_stream,
|
||||
Err((err, NonConnectedStream::Init(init_stream))) => {
|
||||
init_stream.init_pollee(&self.pollee);
|
||||
return (State::Init(init_stream), Err(err));
|
||||
}
|
||||
Err((err, NonConnectedStream::Connecting(connecting_stream))) => {
|
||||
return (State::Connecting(connecting_stream), Err(err));
|
||||
}
|
||||
};
|
||||
connected_stream.init_pollee(&self.pollee);
|
||||
|
||||
(State::Connected(connected_stream), Ok(()))
|
||||
})
|
||||
}
|
||||
|
||||
fn check_connect(&self) -> Result<()> {
|
||||
// Hold the lock in advance to avoid deadlocks.
|
||||
let mut options = self.options.write();
|
||||
let mut state = self.state.write();
|
||||
let (mut options, mut state) = self.update_connecting();
|
||||
|
||||
match state.as_mut() {
|
||||
State::Connecting(_) => {
|
||||
@ -224,7 +259,7 @@ impl StreamSocket {
|
||||
}
|
||||
|
||||
fn try_accept(&self) -> Result<(Arc<dyn FileLike>, SocketAddr)> {
|
||||
let state = self.state.read();
|
||||
let state = self.read_updated_state();
|
||||
|
||||
let State::Listen(listen_stream) = state.as_ref() else {
|
||||
return_errno_with_message!(Errno::EINVAL, "the socket is not listening");
|
||||
@ -244,7 +279,7 @@ impl StreamSocket {
|
||||
writer: &mut dyn MultiWrite,
|
||||
flags: SendRecvFlags,
|
||||
) -> Result<(usize, SocketAddr)> {
|
||||
let state = self.state.read();
|
||||
let state = self.read_updated_state();
|
||||
|
||||
let connected_stream = match state.as_ref() {
|
||||
State::Connected(connected_stream) => connected_stream,
|
||||
@ -280,7 +315,7 @@ impl StreamSocket {
|
||||
}
|
||||
|
||||
fn try_send(&self, reader: &mut dyn MultiRead, flags: SendRecvFlags) -> Result<usize> {
|
||||
let state = self.state.read();
|
||||
let state = self.read_updated_state();
|
||||
|
||||
let connected_stream = match state.as_ref() {
|
||||
State::Connected(connected_stream) => connected_stream,
|
||||
@ -311,21 +346,24 @@ impl StreamSocket {
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
fn update_io_events(&self) -> bool {
|
||||
fn update_io_events(&self) {
|
||||
let state = self.state.read();
|
||||
match state.as_ref() {
|
||||
State::Init(_) => false,
|
||||
State::Connecting(connecting_stream) => connecting_stream.update_io_events(),
|
||||
State::Init(_) => (),
|
||||
State::Connecting(connecting_stream) => {
|
||||
connecting_stream.update_io_events(&self.pollee)
|
||||
}
|
||||
State::Listen(listen_stream) => {
|
||||
listen_stream.update_io_events(&self.pollee);
|
||||
false
|
||||
}
|
||||
State::Connected(connected_stream) => {
|
||||
connected_stream.update_io_events(&self.pollee);
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
// Note: Network events can cause a state transition from `State::Connecting` to
|
||||
// `State::Connected`/`State::Init`. The state transition is delayed until
|
||||
// `update_connecting`is triggered by user events, see that method for details.
|
||||
}
|
||||
}
|
||||
|
||||
@ -392,7 +430,7 @@ impl Socket for StreamSocket {
|
||||
let endpoint = socket_addr.try_into()?;
|
||||
|
||||
let can_reuse = self.options.read().socket.reuse_addr();
|
||||
let mut state = self.state.write();
|
||||
let mut state = self.write_updated_state();
|
||||
|
||||
state.borrow_result(|owned_state| {
|
||||
let State::Init(init_stream) = owned_state else {
|
||||
@ -427,7 +465,7 @@ impl Socket for StreamSocket {
|
||||
}
|
||||
|
||||
fn listen(&self, backlog: usize) -> Result<()> {
|
||||
let mut state = self.state.write();
|
||||
let mut state = self.write_updated_state();
|
||||
|
||||
state.borrow_result(|owned_state| {
|
||||
let init_stream = match owned_state {
|
||||
@ -467,7 +505,7 @@ impl Socket for StreamSocket {
|
||||
}
|
||||
|
||||
fn shutdown(&self, cmd: SockShutdownCmd) -> Result<()> {
|
||||
let state = self.state.read();
|
||||
let state = self.read_updated_state();
|
||||
match state.as_ref() {
|
||||
State::Connected(connected_stream) => connected_stream.shutdown(cmd),
|
||||
// TODO: shutdown listening stream
|
||||
@ -476,7 +514,7 @@ impl Socket for StreamSocket {
|
||||
}
|
||||
|
||||
fn addr(&self) -> Result<SocketAddr> {
|
||||
let state = self.state.read();
|
||||
let state = self.read_updated_state();
|
||||
let local_endpoint = match state.as_ref() {
|
||||
State::Init(init_stream) => init_stream
|
||||
.local_endpoint()
|
||||
@ -489,7 +527,7 @@ impl Socket for StreamSocket {
|
||||
}
|
||||
|
||||
fn peer_addr(&self) -> Result<SocketAddr> {
|
||||
let state = self.state.read();
|
||||
let state = self.read_updated_state();
|
||||
let remote_endpoint = match state.as_ref() {
|
||||
State::Init(_) | State::Listen(_) => {
|
||||
return_errno_with_message!(Errno::ENOTCONN, "the socket is not connected")
|
||||
@ -547,7 +585,8 @@ impl Socket for StreamSocket {
|
||||
fn get_option(&self, option: &mut dyn SocketOption) -> Result<()> {
|
||||
match_sock_option_mut!(option, {
|
||||
socket_errors: SocketError => {
|
||||
self.options.write().socket.get_and_clear_sock_errors(socket_errors);
|
||||
let mut options = self.update_connecting().0;
|
||||
options.socket.get_and_clear_sock_errors(socket_errors);
|
||||
return Ok(());
|
||||
},
|
||||
_ => ()
|
||||
@ -632,15 +671,7 @@ impl Socket for StreamSocket {
|
||||
|
||||
impl SocketEventObserver for StreamSocket {
|
||||
fn on_events(&self) {
|
||||
let conn_ready = self.update_io_events();
|
||||
|
||||
if conn_ready {
|
||||
// Hold the lock in advance to avoid race conditions.
|
||||
let mut options = self.options.write();
|
||||
|
||||
let result = self.finish_connect();
|
||||
options.socket.set_sock_errors(result.err());
|
||||
}
|
||||
self.update_io_events();
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user