From a260411a2a8bb1bcc12b9920b33ad19c8a82899e Mon Sep 17 00:00:00 2001 From: Ruihan Li Date: Tue, 3 Dec 2024 10:51:11 +0800 Subject: [PATCH] Implement `WriteIrqDisabled` --- kernel/libs/aster-bigtcp/src/socket/bound.rs | 6 +-- kernel/src/net/socket/ip/datagram/mod.rs | 4 +- kernel/src/net/socket/ip/stream/listen.rs | 4 +- kernel/src/net/socket/ip/stream/mod.rs | 10 ++-- ostd/src/sync/guard.rs | 48 ++++++++++++++++++-- ostd/src/sync/mod.rs | 2 +- ostd/src/sync/rwlock.rs | 6 +-- 7 files changed, 58 insertions(+), 22 deletions(-) diff --git a/kernel/libs/aster-bigtcp/src/socket/bound.rs b/kernel/libs/aster-bigtcp/src/socket/bound.rs index 1eba0421..0501004a 100644 --- a/kernel/libs/aster-bigtcp/src/socket/bound.rs +++ b/kernel/libs/aster-bigtcp/src/socket/bound.rs @@ -9,7 +9,7 @@ use core::{ sync::atomic::{AtomicBool, AtomicU64, AtomicU8, Ordering}, }; -use ostd::sync::{LocalIrqDisabled, RwLock, SpinLock, SpinLockGuard}; +use ostd::sync::{LocalIrqDisabled, RwLock, SpinLock, SpinLockGuard, WriteIrqDisabled}; use smoltcp::{ iface::Context, socket::{tcp::State, udp::UdpMetadata, PollAt}, @@ -46,7 +46,7 @@ pub struct BoundSocketInner { iface: Arc>, port: u16, socket: T, - observer: RwLock, LocalIrqDisabled>, + observer: RwLock, WriteIrqDisabled>, events: AtomicU8, next_poll_at_ms: AtomicU64, } @@ -232,8 +232,6 @@ impl BoundSocket { /// /// See also [`Self::set_observer`]. pub fn observer(&self) -> Weak { - // We never hold the write lock in IRQ handlers, so we don't need to disable IRQs when we - // get the read lock. self.0.observer.read().clone() } diff --git a/kernel/src/net/socket/ip/datagram/mod.rs b/kernel/src/net/socket/ip/datagram/mod.rs index e1c32fae..b41c005a 100644 --- a/kernel/src/net/socket/ip/datagram/mod.rs +++ b/kernel/src/net/socket/ip/datagram/mod.rs @@ -6,7 +6,7 @@ use aster_bigtcp::{ socket::{SocketEventObserver, SocketEvents}, wire::IpEndpoint, }; -use ostd::sync::LocalIrqDisabled; +use ostd::sync::WriteIrqDisabled; use takeable::Takeable; use self::{bound::BoundDatagram, unbound::UnboundDatagram}; @@ -52,7 +52,7 @@ impl OptionSet { pub struct DatagramSocket { options: RwLock, - inner: RwLock, LocalIrqDisabled>, + inner: RwLock, WriteIrqDisabled>, nonblocking: AtomicBool, pollee: Pollee, } diff --git a/kernel/src/net/socket/ip/stream/listen.rs b/kernel/src/net/socket/ip/stream/listen.rs index bf2b7ec9..613f5338 100644 --- a/kernel/src/net/socket/ip/stream/listen.rs +++ b/kernel/src/net/socket/ip/stream/listen.rs @@ -3,7 +3,7 @@ use aster_bigtcp::{ errors::tcp::ListenError, iface::BindPortConfig, socket::UnboundTcpSocket, wire::IpEndpoint, }; -use ostd::sync::LocalIrqDisabled; +use ostd::sync::WriteIrqDisabled; use super::connected::ConnectedStream; use crate::{ @@ -17,7 +17,7 @@ pub struct ListenStream { /// A bound socket held to ensure the TCP port cannot be released bound_socket: BoundTcpSocket, /// Backlog sockets listening at the local endpoint - backlog_sockets: RwLock, LocalIrqDisabled>, + backlog_sockets: RwLock, WriteIrqDisabled>, } impl ListenStream { diff --git a/kernel/src/net/socket/ip/stream/mod.rs b/kernel/src/net/socket/ip/stream/mod.rs index 3d186b8d..a2e0c710 100644 --- a/kernel/src/net/socket/ip/stream/mod.rs +++ b/kernel/src/net/socket/ip/stream/mod.rs @@ -11,7 +11,7 @@ use connecting::{ConnResult, ConnectingStream}; use init::InitStream; use listen::ListenStream; use options::{Congestion, MaxSegment, NoDelay, WindowClamp}; -use ostd::sync::{LocalIrqDisabled, PreemptDisabled, RwLockReadGuard, RwLockWriteGuard}; +use ostd::sync::{PreemptDisabled, RwLockReadGuard, RwLockWriteGuard, WriteIrqDisabled}; use takeable::Takeable; use util::TcpOptionSet; @@ -50,7 +50,7 @@ pub use self::util::CongestionControl; pub struct StreamSocket { options: RwLock, - state: RwLock, LocalIrqDisabled>, + state: RwLock, WriteIrqDisabled>, is_nonblocking: AtomicBool, pollee: Pollee, } @@ -116,7 +116,7 @@ impl StreamSocket { /// Ensures that the socket state is up to date and obtains a read lock on it. /// /// For a description of what "up-to-date" means, see [`Self::update_connecting`]. - fn read_updated_state(&self) -> RwLockReadGuard, LocalIrqDisabled> { + fn read_updated_state(&self) -> RwLockReadGuard, WriteIrqDisabled> { loop { let state = self.state.read(); match state.as_ref() { @@ -132,7 +132,7 @@ impl StreamSocket { /// Ensures that the socket state is up to date and obtains a write lock on it. /// /// For a description of what "up-to-date" means, see [`Self::update_connecting`]. - fn write_updated_state(&self) -> RwLockWriteGuard, LocalIrqDisabled> { + fn write_updated_state(&self) -> RwLockWriteGuard, WriteIrqDisabled> { self.update_connecting().1 } @@ -149,7 +149,7 @@ impl StreamSocket { &self, ) -> ( RwLockWriteGuard, - RwLockWriteGuard, LocalIrqDisabled>, + RwLockWriteGuard, WriteIrqDisabled>, ) { // Hold the lock in advance to avoid race conditions. let mut options = self.options.write(); diff --git a/ostd/src/sync/guard.rs b/ostd/src/sync/guard.rs index 2dfb9a75..ba58276a 100644 --- a/ostd/src/sync/guard.rs +++ b/ostd/src/sync/guard.rs @@ -5,13 +5,17 @@ use crate::{ trap::{disable_local, DisabledLocalIrqGuard}, }; -/// A guardian that denotes the guard behavior for holding the spin lock. +/// A guardian that denotes the guard behavior for holding a lock. pub trait Guardian { - /// The guard type. + /// The guard type for holding a spin lock or a write lock. type Guard: GuardTransfer; + /// The guard type for holding a read lock. + type ReadGuard: GuardTransfer; /// Creates a new guard. fn guard() -> Self::Guard; + /// Creates a new read guard. + fn read_guard() -> Self::ReadGuard; } /// The Guard can be transferred atomically. @@ -25,21 +29,25 @@ pub trait GuardTransfer { fn transfer_to(&mut self) -> Self; } -/// A guardian that disables preemption while holding the spin lock. +/// A guardian that disables preemption while holding a lock. pub struct PreemptDisabled; impl Guardian for PreemptDisabled { type Guard = DisabledPreemptGuard; + type ReadGuard = DisabledPreemptGuard; fn guard() -> Self::Guard { disable_preempt() } + fn read_guard() -> Self::Guard { + disable_preempt() + } } -/// A guardian that disables IRQs while holding the spin lock. +/// A guardian that disables IRQs while holding a lock. /// /// This guardian would incur a certain time overhead over -/// [`PreemptDisabled']. So prefer avoiding using this guardian when +/// [`PreemptDisabled`]. So prefer avoiding using this guardian when /// IRQ handlers are allowed to get executed while holding the /// lock. For example, if a lock is never used in the interrupt /// context, then it is ok not to use this guardian in the process context. @@ -47,8 +55,38 @@ pub struct LocalIrqDisabled; impl Guardian for LocalIrqDisabled { type Guard = DisabledLocalIrqGuard; + type ReadGuard = DisabledLocalIrqGuard; fn guard() -> Self::Guard { disable_local() } + fn read_guard() -> Self::Guard { + disable_local() + } +} + +/// A guardian that disables IRQs while holding a write lock. +/// +/// This guardian should only be used for a [`RwLock`]. Using it with a [`SpinLock`] will behave in +/// the same way as using [`LocalIrqDisabled`]. +/// +/// When using this guardian with a [`RwLock`], holding the read lock will only disable preemption, +/// but holding a write lock will disable local IRQs. The user must ensure that the IRQ handlers +/// never take the write lock, so we can take the read lock without disabling IRQs, but we are +/// still free of deadlock even if the IRQ handlers are triggered in the middle. +/// +/// [`RwLock`]: super::RwLock +/// [`SpinLock`]: super::SpinLock +pub struct WriteIrqDisabled; + +impl Guardian for WriteIrqDisabled { + type Guard = DisabledLocalIrqGuard; + type ReadGuard = DisabledPreemptGuard; + + fn guard() -> Self::Guard { + disable_local() + } + fn read_guard() -> Self::ReadGuard { + disable_preempt() + } } diff --git a/ostd/src/sync/mod.rs b/ostd/src/sync/mod.rs index ec75ba0e..63ba1be6 100644 --- a/ostd/src/sync/mod.rs +++ b/ostd/src/sync/mod.rs @@ -15,7 +15,7 @@ mod wait; // pub use self::rcu::{pass_quiescent_state, OwnerPtr, Rcu, RcuReadGuard, RcuReclaimer}; pub(crate) use self::guard::GuardTransfer; pub use self::{ - guard::{LocalIrqDisabled, PreemptDisabled}, + guard::{LocalIrqDisabled, PreemptDisabled, WriteIrqDisabled}, mutex::{ArcMutexGuard, Mutex, MutexGuard}, rwlock::{ ArcRwLockReadGuard, ArcRwLockUpgradeableGuard, ArcRwLockWriteGuard, RwLock, diff --git a/ostd/src/sync/rwlock.rs b/ostd/src/sync/rwlock.rs index db6b023f..af61cb2d 100644 --- a/ostd/src/sync/rwlock.rs +++ b/ostd/src/sync/rwlock.rs @@ -230,7 +230,7 @@ impl RwLock { /// /// This function will never spin-wait and will return immediately. pub fn try_read(&self) -> Option> { - let guard = G::guard(); + let guard = G::read_guard(); let lock = self.lock.fetch_add(READER, Acquire); if lock & (WRITER | MAX_READER | BEING_UPGRADED) == 0 { Some(RwLockReadGuard { inner: self, guard }) @@ -247,7 +247,7 @@ impl RwLock { /// /// [`try_read`]: Self::try_read pub fn try_read_arc(self: &Arc) -> Option> { - let guard = G::guard(); + let guard = G::read_guard(); let lock = self.lock.fetch_add(READER, Acquire); if lock & (WRITER | MAX_READER | BEING_UPGRADED) == 0 { Some(ArcRwLockReadGuard { @@ -375,7 +375,7 @@ unsafe impl> + Clone + Sync, G: #[clippy::has_significant_drop] #[must_use] pub struct RwLockReadGuard_> + Clone, G: Guardian> { - guard: G::Guard, + guard: G::ReadGuard, inner: R, }