From 495c93c2ada285d8abe44e7a6abd9e7adf0917b3 Mon Sep 17 00:00:00 2001 From: jiangjianfeng Date: Tue, 19 Nov 2024 10:35:20 +0000 Subject: [PATCH] Refactor Rwlock to take type parameter --- kernel/comps/time/src/clocksource.rs | 10 +- .../comps/virtio/src/device/console/device.rs | 8 +- .../comps/virtio/src/device/input/device.rs | 8 +- kernel/libs/aster-bigtcp/src/socket/bound.rs | 4 +- kernel/src/fs/ramfs/fs.rs | 6 +- kernel/src/ipc/semaphore/system_v/sem_set.rs | 4 +- kernel/src/net/socket/ip/datagram/mod.rs | 9 +- kernel/src/net/socket/ip/stream/listen.rs | 7 +- kernel/src/net/socket/ip/stream/mod.rs | 14 +- kernel/src/net/socket/vsock/common.rs | 26 +- .../src/process/credentials/credentials_.rs | 6 +- kernel/src/process/credentials/static_cap.rs | 6 +- kernel/src/time/softirq.rs | 14 +- ostd/src/arch/x86/irq.rs | 6 +- ostd/src/mm/vm_space.rs | 4 +- ostd/src/sync/mod.rs | 4 +- ostd/src/sync/rwlock.rs | 393 +++++------------- ostd/src/sync/spin.rs | 24 +- ostd/src/task/preempt/guard.rs | 8 +- ostd/src/trap/irq.rs | 7 +- 20 files changed, 205 insertions(+), 363 deletions(-) diff --git a/kernel/comps/time/src/clocksource.rs b/kernel/comps/time/src/clocksource.rs index 658869069..61780bb0d 100644 --- a/kernel/comps/time/src/clocksource.rs +++ b/kernel/comps/time/src/clocksource.rs @@ -14,7 +14,7 @@ use alloc::sync::Arc; use core::{cmp::max, ops::Add, time::Duration}; use aster_util::coeff::Coeff; -use ostd::sync::RwLock; +use ostd::sync::{LocalIrqDisabled, RwLock}; use crate::NANOS_PER_SECOND; @@ -55,7 +55,7 @@ pub struct ClockSource { base: ClockSourceBase, coeff: Coeff, /// A record to an `Instant` and the corresponding cycles of this `ClockSource`. - last_record: RwLock<(Instant, u64)>, + last_record: RwLock<(Instant, u64), LocalIrqDisabled>, } impl ClockSource { @@ -91,7 +91,7 @@ impl ClockSource { /// Returns the calculated instant and instant cycles. fn calculate_instant(&self) -> (Instant, u64) { let (instant_cycles, last_instant, last_cycles) = { - let last_record = self.last_record.read_irq_disabled(); + let last_record = self.last_record.read(); let (last_instant, last_cycles) = *last_record; (self.read_cycles(), last_instant, last_cycles) }; @@ -121,7 +121,7 @@ impl ClockSource { /// Uses an input instant and cycles to update the `last_record` in the `ClockSource`. fn update_last_record(&self, record: (Instant, u64)) { - *self.last_record.write_irq_disabled() = record; + *self.last_record.write() = record; } /// Reads current cycles of the `ClockSource`. @@ -131,7 +131,7 @@ impl ClockSource { /// Returns the last instant and last cycles recorded in the `ClockSource`. pub fn last_record(&self) -> (Instant, u64) { - return *self.last_record.read_irq_disabled(); + return *self.last_record.read(); } /// Returns the maximum delay seconds for updating of the `ClockSource`. 
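The clocksource hunk above shows the core pattern of this refactor: the IRQ-disabling behavior moves out of dedicated `*_irq_disabled()` methods and into the lock's type parameter, so call sites go back to plain `read()`/`write()`. A minimal sketch of that pattern, assuming only the `ostd::sync` items touched by this patch (the `Sample` type and its field are hypothetical):

use ostd::sync::{LocalIrqDisabled, RwLock};

struct Sample {
    // The guard behavior is now part of the lock's type.
    record: RwLock<(u64, u64), LocalIrqDisabled>,
}

impl Sample {
    fn snapshot(&self) -> (u64, u64) {
        // Local IRQs stay disabled for as long as the returned guard is
        // alive, matching the old `read_irq_disabled()` behavior.
        *self.record.read()
    }

    fn update(&self, record: (u64, u64)) {
        *self.record.write() = record;
    }
}
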
diff --git a/kernel/comps/virtio/src/device/console/device.rs b/kernel/comps/virtio/src/device/console/device.rs index bb4d1f89c..4e945cc71 100644 --- a/kernel/comps/virtio/src/device/console/device.rs +++ b/kernel/comps/virtio/src/device/console/device.rs @@ -9,7 +9,7 @@ use log::debug; use ostd::{ io_mem::IoMem, mm::{DmaDirection, DmaStream, DmaStreamSlice, FrameAllocOptions, VmReader}, - sync::{RwLock, SpinLock}, + sync::{LocalIrqDisabled, RwLock, SpinLock}, trap::TrapFrame, }; @@ -27,7 +27,7 @@ pub struct ConsoleDevice { transmit_queue: SpinLock, send_buffer: DmaStream, receive_buffer: DmaStream, - callbacks: RwLock>, + callbacks: RwLock, LocalIrqDisabled>, } impl AnyConsoleDevice for ConsoleDevice { @@ -54,7 +54,7 @@ impl AnyConsoleDevice for ConsoleDevice { } fn register_callback(&self, callback: &'static ConsoleCallback) { - self.callbacks.write_irq_disabled().push(callback); + self.callbacks.write().push(callback); } } @@ -136,7 +136,7 @@ impl ConsoleDevice { }; self.receive_buffer.sync(0..len as usize).unwrap(); - let callbacks = self.callbacks.read_irq_disabled(); + let callbacks = self.callbacks.read(); for callback in callbacks.iter() { let reader = self.receive_buffer.reader().unwrap().limit(len as usize); callback(reader); diff --git a/kernel/comps/virtio/src/device/input/device.rs b/kernel/comps/virtio/src/device/input/device.rs index 2378603f3..6b76cccee 100644 --- a/kernel/comps/virtio/src/device/input/device.rs +++ b/kernel/comps/virtio/src/device/input/device.rs @@ -19,7 +19,7 @@ use ostd::{ io_mem::IoMem, mm::{DmaDirection, DmaStream, FrameAllocOptions, HasDaddr, VmIo, PAGE_SIZE}, offset_of, - sync::{RwLock, SpinLock}, + sync::{LocalIrqDisabled, RwLock, SpinLock}, trap::TrapFrame, }; @@ -76,7 +76,7 @@ pub struct InputDevice { status_queue: VirtQueue, event_table: EventTable, #[allow(clippy::type_complexity)] - callbacks: RwLock>>, + callbacks: RwLock>, LocalIrqDisabled>, transport: SpinLock>, } @@ -209,7 +209,7 @@ impl InputDevice { } fn handle_irq(&self) { - let callbacks = self.callbacks.read_irq_disabled(); + let callbacks = self.callbacks.read(); // Returns true if there may be more events to handle let handle_event = |event: &EventBuf| -> bool { event.sync().unwrap(); @@ -295,7 +295,7 @@ impl DmaBuf for SafePtr { impl aster_input::InputDevice for InputDevice { fn register_callbacks(&self, function: &'static (dyn Fn(InputEvent) + Send + Sync)) { - self.callbacks.write_irq_disabled().push(Arc::new(function)) + self.callbacks.write().push(Arc::new(function)) } } diff --git a/kernel/libs/aster-bigtcp/src/socket/bound.rs b/kernel/libs/aster-bigtcp/src/socket/bound.rs index 1fe45fd6c..1eba0421f 100644 --- a/kernel/libs/aster-bigtcp/src/socket/bound.rs +++ b/kernel/libs/aster-bigtcp/src/socket/bound.rs @@ -46,7 +46,7 @@ pub struct BoundSocketInner { iface: Arc>, port: u16, socket: T, - observer: RwLock>, + observer: RwLock, LocalIrqDisabled>, events: AtomicU8, next_poll_at_ms: AtomicU64, } @@ -223,7 +223,7 @@ impl BoundSocket { /// that the old observer will never be called after the setting. Users should be aware of this /// and proactively handle the race conditions if necessary. 
pub fn set_observer(&self, new_observer: Weak) { - *self.0.observer.write_irq_disabled() = new_observer; + *self.0.observer.write() = new_observer; self.0.on_events(); } diff --git a/kernel/src/fs/ramfs/fs.rs b/kernel/src/fs/ramfs/fs.rs index 6f36a4077..a996cf459 100644 --- a/kernel/src/fs/ramfs/fs.rs +++ b/kernel/src/fs/ramfs/fs.rs @@ -12,7 +12,7 @@ use aster_util::slot_vec::SlotVec; use hashbrown::HashMap; use ostd::{ mm::{Frame, VmIo}, - sync::RwLockWriteGuard, + sync::{PreemptDisabled, RwLockWriteGuard}, }; use super::*; @@ -1195,8 +1195,8 @@ fn write_lock_two_direntries_by_ino<'a>( this: (u64, &'a RwLock), other: (u64, &'a RwLock), ) -> ( - RwLockWriteGuard<'a, DirEntry>, - RwLockWriteGuard<'a, DirEntry>, + RwLockWriteGuard<'a, DirEntry, PreemptDisabled>, + RwLockWriteGuard<'a, DirEntry, PreemptDisabled>, ) { if this.0 < other.0 { let this = this.1.write(); diff --git a/kernel/src/ipc/semaphore/system_v/sem_set.rs b/kernel/src/ipc/semaphore/system_v/sem_set.rs index 497330d5b..fbd9b1cc0 100644 --- a/kernel/src/ipc/semaphore/system_v/sem_set.rs +++ b/kernel/src/ipc/semaphore/system_v/sem_set.rs @@ -290,11 +290,11 @@ pub fn create_sem_set(nsems: usize, mode: u16, credentials: Credentials) Ok(id) } -pub fn sem_sets<'a>() -> RwLockReadGuard<'a, BTreeMap> { +pub fn sem_sets<'a>() -> RwLockReadGuard<'a, BTreeMap, PreemptDisabled> { SEMAPHORE_SETS.read() } -pub fn sem_sets_mut<'a>() -> RwLockWriteGuard<'a, BTreeMap> { +pub fn sem_sets_mut<'a>() -> RwLockWriteGuard<'a, BTreeMap, PreemptDisabled> { SEMAPHORE_SETS.write() } diff --git a/kernel/src/net/socket/ip/datagram/mod.rs b/kernel/src/net/socket/ip/datagram/mod.rs index 8f136eddd..e1c32fae3 100644 --- a/kernel/src/net/socket/ip/datagram/mod.rs +++ b/kernel/src/net/socket/ip/datagram/mod.rs @@ -6,6 +6,7 @@ use aster_bigtcp::{ socket::{SocketEventObserver, SocketEvents}, wire::IpEndpoint, }; +use ostd::sync::LocalIrqDisabled; use takeable::Takeable; use self::{bound::BoundDatagram, unbound::UnboundDatagram}; @@ -51,7 +52,7 @@ impl OptionSet { pub struct DatagramSocket { options: RwLock, - inner: RwLock>, + inner: RwLock, LocalIrqDisabled>, nonblocking: AtomicBool, pollee: Pollee, } @@ -134,7 +135,7 @@ impl DatagramSocket { } // Slow path - let mut inner = self.inner.write_irq_disabled(); + let mut inner = self.inner.write(); inner.borrow_result(|owned_inner| { let bound_datagram = match owned_inner.bind_to_ephemeral_endpoint(remote_endpoint) { Ok(bound_datagram) => bound_datagram, @@ -277,7 +278,7 @@ impl Socket for DatagramSocket { let endpoint = socket_addr.try_into()?; let can_reuse = self.options.read().socket.reuse_addr(); - let mut inner = self.inner.write_irq_disabled(); + let mut inner = self.inner.write(); inner.borrow_result(|owned_inner| { let bound_datagram = match owned_inner.bind(&endpoint, can_reuse) { Ok(bound_datagram) => bound_datagram, @@ -294,7 +295,7 @@ impl Socket for DatagramSocket { self.try_bind_ephemeral(&endpoint)?; - let mut inner = self.inner.write_irq_disabled(); + let mut inner = self.inner.write(); let Inner::Bound(bound_datagram) = inner.as_mut() else { return_errno_with_message!(Errno::EINVAL, "the socket is not bound") }; diff --git a/kernel/src/net/socket/ip/stream/listen.rs b/kernel/src/net/socket/ip/stream/listen.rs index fba1163eb..bf2b7ec97 100644 --- a/kernel/src/net/socket/ip/stream/listen.rs +++ b/kernel/src/net/socket/ip/stream/listen.rs @@ -3,6 +3,7 @@ use aster_bigtcp::{ errors::tcp::ListenError, iface::BindPortConfig, socket::UnboundTcpSocket, wire::IpEndpoint, }; +use 
ostd::sync::LocalIrqDisabled; use super::connected::ConnectedStream; use crate::{ @@ -16,7 +17,7 @@ pub struct ListenStream { /// A bound socket held to ensure the TCP port cannot be released bound_socket: BoundTcpSocket, /// Backlog sockets listening at the local endpoint - backlog_sockets: RwLock>, + backlog_sockets: RwLock, LocalIrqDisabled>, } impl ListenStream { @@ -40,7 +41,7 @@ impl ListenStream { /// Append sockets listening at LocalEndPoint to support backlog fn fill_backlog_sockets(&self) -> Result<()> { - let mut backlog_sockets = self.backlog_sockets.write_irq_disabled(); + let mut backlog_sockets = self.backlog_sockets.write(); let backlog = self.backlog; let current_backlog_len = backlog_sockets.len(); @@ -58,7 +59,7 @@ impl ListenStream { } pub fn try_accept(&self) -> Result { - let mut backlog_sockets = self.backlog_sockets.write_irq_disabled(); + let mut backlog_sockets = self.backlog_sockets.write(); let index = backlog_sockets .iter() diff --git a/kernel/src/net/socket/ip/stream/mod.rs b/kernel/src/net/socket/ip/stream/mod.rs index 5eb89d9fd..3d186b8d4 100644 --- a/kernel/src/net/socket/ip/stream/mod.rs +++ b/kernel/src/net/socket/ip/stream/mod.rs @@ -11,7 +11,7 @@ use connecting::{ConnResult, ConnectingStream}; use init::InitStream; use listen::ListenStream; use options::{Congestion, MaxSegment, NoDelay, WindowClamp}; -use ostd::sync::{RwLockReadGuard, RwLockWriteGuard}; +use ostd::sync::{LocalIrqDisabled, PreemptDisabled, RwLockReadGuard, RwLockWriteGuard}; use takeable::Takeable; use util::TcpOptionSet; @@ -50,7 +50,7 @@ pub use self::util::CongestionControl; pub struct StreamSocket { options: RwLock, - state: RwLock>, + state: RwLock, LocalIrqDisabled>, is_nonblocking: AtomicBool, pollee: Pollee, } @@ -116,7 +116,7 @@ impl StreamSocket { /// Ensures that the socket state is up to date and obtains a read lock on it. /// /// For a description of what "up-to-date" means, see [`Self::update_connecting`]. - fn read_updated_state(&self) -> RwLockReadGuard> { + fn read_updated_state(&self) -> RwLockReadGuard, LocalIrqDisabled> { loop { let state = self.state.read(); match state.as_ref() { @@ -132,7 +132,7 @@ impl StreamSocket { /// Ensures that the socket state is up to date and obtains a write lock on it. /// /// For a description of what "up-to-date" means, see [`Self::update_connecting`]. - fn write_updated_state(&self) -> RwLockWriteGuard> { + fn write_updated_state(&self) -> RwLockWriteGuard, LocalIrqDisabled> { self.update_connecting().1 } @@ -148,12 +148,12 @@ impl StreamSocket { fn update_connecting( &self, ) -> ( - RwLockWriteGuard, - RwLockWriteGuard>, + RwLockWriteGuard, + RwLockWriteGuard, LocalIrqDisabled>, ) { // Hold the lock in advance to avoid race conditions. 
let mut options = self.options.write(); - let mut state = self.state.write_irq_disabled(); + let mut state = self.state.write(); match state.as_ref() { State::Connecting(connection_stream) if connection_stream.has_result() => (), diff --git a/kernel/src/net/socket/vsock/common.rs b/kernel/src/net/socket/vsock/common.rs index eedf83036..da8c33953 100644 --- a/kernel/src/net/socket/vsock/common.rs +++ b/kernel/src/net/socket/vsock/common.rs @@ -7,6 +7,7 @@ use aster_virtio::device::socket::{ device::SocketDevice, error::SocketError, }; +use ostd::sync::LocalIrqDisabled; use super::{ addr::VsockSocketAddr, @@ -26,7 +27,7 @@ pub struct VsockSpace { // (key, value) = (local_addr, listen) listen_sockets: SpinLock>>, // (key, value) = (id(local_addr,peer_addr), connected) - connected_sockets: RwLock>>, + connected_sockets: RwLock>, LocalIrqDisabled>, // Used ports used_ports: SpinLock>, } @@ -54,10 +55,7 @@ impl VsockSpace { .disable_irq() .lock() .contains_key(&event.destination.into()) - || self - .connected_sockets - .read_irq_disabled() - .contains_key(&(*event).into()) + || self.connected_sockets.read().contains_key(&(*event).into()) } /// Alloc an unused port range @@ -91,13 +89,13 @@ impl VsockSpace { id: ConnectionID, connected: Arc, ) -> Option> { - let mut connected_sockets = self.connected_sockets.write_irq_disabled(); + let mut connected_sockets = self.connected_sockets.write(); connected_sockets.insert(id, connected) } /// Remove a connected socket pub fn remove_connected_socket(&self, id: &ConnectionID) -> Option> { - let mut connected_sockets = self.connected_sockets.write_irq_disabled(); + let mut connected_sockets = self.connected_sockets.write(); connected_sockets.remove(id) } @@ -214,11 +212,7 @@ impl VsockSpace { debug!("vsock receive event: {:?}", event); // The socket must be stored in the VsockSpace. - if let Some(connected) = self - .connected_sockets - .read_irq_disabled() - .get(&event.into()) - { + if let Some(connected) = self.connected_sockets.read().get(&event.into()) { connected.update_info(&event); } @@ -255,7 +249,7 @@ impl VsockSpace { connecting.set_connected(); } VsockEventType::Disconnected { .. } => { - let connected_sockets = self.connected_sockets.read_irq_disabled(); + let connected_sockets = self.connected_sockets.read(); let Some(connected) = connected_sockets.get(&event.into()) else { return_errno_with_message!(Errno::ENOTCONN, "the socket hasn't connected"); }; @@ -263,7 +257,7 @@ impl VsockSpace { } VsockEventType::Received { .. } => {} VsockEventType::CreditRequest => { - let connected_sockets = self.connected_sockets.read_irq_disabled(); + let connected_sockets = self.connected_sockets.read(); let Some(connected) = connected_sockets.get(&event.into()) else { return_errno_with_message!(Errno::ENOTCONN, "the socket hasn't connected"); }; @@ -272,7 +266,7 @@ impl VsockSpace { })?; } VsockEventType::CreditUpdate => { - let connected_sockets = self.connected_sockets.read_irq_disabled(); + let connected_sockets = self.connected_sockets.read(); let Some(connected) = connected_sockets.get(&event.into()) else { return_errno_with_message!(Errno::ENOTCONN, "the socket hasn't connected"); }; @@ -289,7 +283,7 @@ impl VsockSpace { // Deal with Received before the buffer are recycled. if let VsockEventType::Received { .. 
} = event.event_type { // Only consider the connected socket and copy body to buffer - let connected_sockets = self.connected_sockets.read_irq_disabled(); + let connected_sockets = self.connected_sockets.read(); let connected = connected_sockets.get(&event.into()).unwrap(); debug!("Rw matches a connection with id {:?}", connected.id()); if !connected.add_connection_buffer(body) { diff --git a/kernel/src/process/credentials/credentials_.rs b/kernel/src/process/credentials/credentials_.rs index c54f4f6be..9a601fadb 100644 --- a/kernel/src/process/credentials/credentials_.rs +++ b/kernel/src/process/credentials/credentials_.rs @@ -2,7 +2,7 @@ use core::sync::atomic::Ordering; -use ostd::sync::{RwLockReadGuard, RwLockWriteGuard}; +use ostd::sync::{PreemptDisabled, RwLockReadGuard, RwLockWriteGuard}; use super::{group::AtomicGid, user::AtomicUid, Gid, Uid}; use crate::{ @@ -387,11 +387,11 @@ impl Credentials_ { // ******* Supplementary groups methods ******* - pub(super) fn groups(&self) -> RwLockReadGuard> { + pub(super) fn groups(&self) -> RwLockReadGuard, PreemptDisabled> { self.supplementary_gids.read() } - pub(super) fn groups_mut(&self) -> RwLockWriteGuard> { + pub(super) fn groups_mut(&self) -> RwLockWriteGuard, PreemptDisabled> { self.supplementary_gids.write() } diff --git a/kernel/src/process/credentials/static_cap.rs b/kernel/src/process/credentials/static_cap.rs index 6323569df..5ddbc5441 100644 --- a/kernel/src/process/credentials/static_cap.rs +++ b/kernel/src/process/credentials/static_cap.rs @@ -4,7 +4,7 @@ use aster_rights::{Dup, Read, TRights, Write}; use aster_rights_proc::require; -use ostd::sync::{RwLockReadGuard, RwLockWriteGuard}; +use ostd::sync::{PreemptDisabled, RwLockReadGuard, RwLockWriteGuard}; use super::{capabilities::CapSet, credentials_::Credentials_, Credentials, Gid, Uid}; use crate::prelude::*; @@ -239,7 +239,7 @@ impl Credentials { /// /// This method requires the `Read` right. #[require(R > Read)] - pub fn groups(&self) -> RwLockReadGuard> { + pub fn groups(&self) -> RwLockReadGuard, PreemptDisabled> { self.0.groups() } @@ -247,7 +247,7 @@ impl Credentials { /// /// This method requires the `Write` right. 
#[require(R > Write)] - pub fn groups_mut(&self) -> RwLockWriteGuard> { + pub fn groups_mut(&self) -> RwLockWriteGuard, PreemptDisabled> { self.0.groups_mut() } diff --git a/kernel/src/time/softirq.rs b/kernel/src/time/softirq.rs index 18ad97ccd..a2c4bdfe4 100644 --- a/kernel/src/time/softirq.rs +++ b/kernel/src/time/softirq.rs @@ -3,9 +3,13 @@ use alloc::{boxed::Box, vec::Vec}; use aster_softirq::{softirq_id::TIMER_SOFTIRQ_ID, SoftIrqLine}; -use ostd::{sync::RwLock, timer}; +use ostd::{ + sync::{LocalIrqDisabled, RwLock}, + timer, +}; -static TIMER_SOFTIRQ_CALLBACKS: RwLock>> = RwLock::new(Vec::new()); +static TIMER_SOFTIRQ_CALLBACKS: RwLock>, LocalIrqDisabled> = + RwLock::new(Vec::new()); pub(super) fn init() { SoftIrqLine::get(TIMER_SOFTIRQ_ID).enable(timer_softirq_handler); @@ -20,13 +24,11 @@ pub(super) fn register_callback(func: F) where F: Fn() + Sync + Send + 'static, { - TIMER_SOFTIRQ_CALLBACKS - .write_irq_disabled() - .push(Box::new(func)); + TIMER_SOFTIRQ_CALLBACKS.write().push(Box::new(func)); } fn timer_softirq_handler() { - let callbacks = TIMER_SOFTIRQ_CALLBACKS.read_irq_disabled(); + let callbacks = TIMER_SOFTIRQ_CALLBACKS.read(); for callback in callbacks.iter() { (callback)(); } diff --git a/ostd/src/arch/x86/irq.rs b/ostd/src/arch/x86/irq.rs index 3635cd75e..030ae14d7 100644 --- a/ostd/src/arch/x86/irq.rs +++ b/ostd/src/arch/x86/irq.rs @@ -13,7 +13,7 @@ use x86_64::registers::rflags::{self, RFlags}; use super::iommu::{alloc_irt_entry, has_interrupt_remapping, IrtEntryHandle}; use crate::{ cpu::CpuId, - sync::{LocalIrqDisabled, Mutex, RwLock, RwLockReadGuard, SpinLock}, + sync::{LocalIrqDisabled, Mutex, PreemptDisabled, RwLock, RwLockReadGuard, SpinLock}, trap::TrapFrame, }; @@ -119,7 +119,9 @@ impl IrqLine { self.irq_num } - pub fn callback_list(&self) -> RwLockReadGuard> { + pub fn callback_list( + &self, + ) -> RwLockReadGuard, PreemptDisabled> { self.callback_list.read() } diff --git a/ostd/src/mm/vm_space.rs b/ostd/src/mm/vm_space.rs index 60de8a73a..482ce0a52 100644 --- a/ostd/src/mm/vm_space.rs +++ b/ostd/src/mm/vm_space.rs @@ -25,7 +25,7 @@ use crate::{ Frame, PageProperty, VmReader, VmWriter, MAX_USERSPACE_VADDR, }, prelude::*, - sync::{RwLock, RwLockReadGuard}, + sync::{PreemptDisabled, RwLock, RwLockReadGuard}, task::{disable_preempt, DisabledPreemptGuard}, Error, }; @@ -283,7 +283,7 @@ impl Cursor<'_> { pub struct CursorMut<'a, 'b> { pt_cursor: page_table::CursorMut<'a, UserMode, PageTableEntry, PagingConsts>, #[allow(dead_code)] - activation_lock: RwLockReadGuard<'b, ()>, + activation_lock: RwLockReadGuard<'b, (), PreemptDisabled>, // We have a read lock so the CPU set in the flusher is always a superset // of actual activated CPUs. 
flusher: TlbFlusher, diff --git a/ostd/src/sync/mod.rs b/ostd/src/sync/mod.rs index 6dd9dda45..02053f2c7 100644 --- a/ostd/src/sync/mod.rs +++ b/ostd/src/sync/mod.rs @@ -22,6 +22,8 @@ pub use self::{ ArcRwMutexReadGuard, ArcRwMutexUpgradeableGuard, ArcRwMutexWriteGuard, RwMutex, RwMutexReadGuard, RwMutexUpgradeableGuard, RwMutexWriteGuard, }, - spin::{ArcSpinLockGuard, LocalIrqDisabled, PreemptDisabled, SpinLock, SpinLockGuard}, + spin::{ + ArcSpinLockGuard, GuardTransfer, LocalIrqDisabled, PreemptDisabled, SpinLock, SpinLockGuard, + }, wait::{WaitQueue, Waiter, Waker}, }; diff --git a/ostd/src/sync/rwlock.rs b/ostd/src/sync/rwlock.rs index 2f02bd955..dd73ff812 100644 --- a/ostd/src/sync/rwlock.rs +++ b/ostd/src/sync/rwlock.rs @@ -6,6 +6,7 @@ use alloc::sync::Arc; use core::{ cell::UnsafeCell, fmt, + marker::PhantomData, ops::{Deref, DerefMut}, sync::atomic::{ AtomicUsize, @@ -13,10 +14,7 @@ use core::{ }, }; -use crate::{ - task::{disable_preempt, DisabledPreemptGuard}, - trap::{disable_local, DisabledLocalIrqGuard}, -}; +use super::{spin::Guardian, GuardTransfer, PreemptDisabled}; /// Spin-based Read-write Lock /// @@ -33,10 +31,6 @@ use crate::{ /// periods of time, and the overhead of context switching is higher than /// the cost of spinning. /// -/// The lock provides methods to safely acquire locks with interrupts -/// disabled, preventing deadlocks in scenarios where locks are used within -/// interrupt handlers. -/// /// In addition to traditional read and write locks, this implementation /// provides the upgradeable read lock (`upread lock`). The `upread lock` /// can be upgraded to write locks atomically, useful in scenarios @@ -59,10 +53,9 @@ use crate::{ /// This lock should not be used in scenarios where lock-holding times are /// long as it can lead to CPU resource wastage due to spinning. /// -/// # Safety +/// # About Guard /// -/// Use interrupt-disabled version methods when dealing with interrupt-related read-write locks, -/// as nested interrupts may lead to a deadlock if not properly handled. +/// See the comments of [`SpinLock`]. /// /// # Examples /// @@ -99,7 +92,10 @@ use crate::{ /// assert_eq!(*w2, 7); /// } // write lock is dropped at this point /// ``` -pub struct RwLock { +/// +/// [`SpinLock`]: super::SpinLock +pub struct RwLock { + guard: PhantomData, /// The internal representation of the lock state is as follows: /// - **Bit 63:** Writer lock. /// - **Bit 62:** Upgradeable reader lock. @@ -115,152 +111,25 @@ const UPGRADEABLE_READER: usize = 1 << (usize::BITS - 2); const BEING_UPGRADED: usize = 1 << (usize::BITS - 3); const MAX_READER: usize = 1 << (usize::BITS - 4); -impl RwLock { +impl RwLock { /// Creates a new spin-based read-write lock with an initial value. pub const fn new(val: T) -> Self { Self { - val: UnsafeCell::new(val), + guard: PhantomData, lock: AtomicUsize::new(0), + val: UnsafeCell::new(val), } } } -impl RwLock { - /// Acquires a read lock while disabling the local IRQs and spin-wait - /// until it can be acquired. - /// - /// The calling thread will spin-wait until there are no writers or - /// upgrading upreaders present. There is no guarantee for the order - /// in which other readers or writers waiting simultaneously will - /// obtain the lock. Once this lock is acquired, the calling thread - /// will not be interrupted. 
- pub fn read_irq_disabled(&self) -> RwLockReadGuard { - loop { - if let Some(readguard) = self.try_read_irq_disabled() { - return readguard; - } else { - core::hint::spin_loop(); - } - } - } - - /// Acquires a write lock while disabling the local IRQs and spin-wait - /// until it can be acquired. - /// - /// The calling thread will spin-wait until there are no other writers, - /// upreaders or readers present. There is no guarantee for the order - /// in which other readers or writers waiting simultaneously will - /// obtain the lock. Once this lock is acquired, the calling thread - /// will not be interrupted. - pub fn write_irq_disabled(&self) -> RwLockWriteGuard { - loop { - if let Some(writeguard) = self.try_write_irq_disabled() { - return writeguard; - } else { - core::hint::spin_loop(); - } - } - } - - /// Acquires an upgradeable reader (upreader) while disabling local IRQs - /// and spin-wait until it can be acquired. - /// - /// The calling thread will spin-wait until there are no other writers, - /// or upreaders. There is no guarantee for the order in which other - /// readers or writers waiting simultaneously will obtain the lock. Once - /// this lock is acquired, the calling thread will not be interrupted. - /// - /// Upreader will not block new readers until it tries to upgrade. Upreader - /// and reader do not differ before invoking the upgread method. However, - /// only one upreader can exist at any time to avoid deadlock in the - /// upgread method. - pub fn upread_irq_disabled(&self) -> RwLockUpgradeableGuard { - loop { - if let Some(guard) = self.try_upread_irq_disabled() { - return guard; - } else { - core::hint::spin_loop(); - } - } - } - - /// Attempts to acquire a read lock while disabling local IRQs. - /// - /// This function will never spin-wait and will return immediately. When - /// multiple readers or writers attempt to acquire the lock, this method - /// does not guarantee any order. Interrupts will automatically be restored - /// when acquiring fails. - pub fn try_read_irq_disabled(&self) -> Option> { - let irq_guard = disable_local(); - let lock = self.lock.fetch_add(READER, Acquire); - if lock & (WRITER | MAX_READER | BEING_UPGRADED) == 0 { - Some(RwLockReadGuard { - inner: self, - inner_guard: InnerGuard::IrqGuard(irq_guard), - }) - } else { - self.lock.fetch_sub(READER, Release); - None - } - } - - /// Attempts to acquire a write lock while disabling local IRQs. - /// - /// This function will never spin-wait and will return immediately. When - /// multiple readers or writers attempt to acquire the lock, this method - /// does not guarantee any order. Interrupts will automatically be restored - /// when acquiring fails. - pub fn try_write_irq_disabled(&self) -> Option> { - let irq_guard = disable_local(); - if self - .lock - .compare_exchange(0, WRITER, Acquire, Relaxed) - .is_ok() - { - Some(RwLockWriteGuard { - inner: self, - inner_guard: InnerGuard::IrqGuard(irq_guard), - }) - } else { - None - } - } - - /// Attempts to acquire a upread lock while disabling local IRQs. - /// - /// This function will never spin-wait and will return immediately. When - /// multiple readers or writers attempt to acquire the lock, this method - /// does not guarantee any order. Interrupts will automatically be restored - /// when acquiring fails. 
- pub fn try_upread_irq_disabled(&self) -> Option> { - let irq_guard = disable_local(); - let lock = self.lock.fetch_or(UPGRADEABLE_READER, Acquire) & (WRITER | UPGRADEABLE_READER); - if lock == 0 { - return Some(RwLockUpgradeableGuard { - inner: self, - inner_guard: InnerGuard::IrqGuard(irq_guard), - }); - } else if lock == WRITER { - self.lock.fetch_sub(UPGRADEABLE_READER, Release); - } - None - } - +impl RwLock { /// Acquires a read lock and spin-wait until it can be acquired. /// /// The calling thread will spin-wait until there are no writers or /// upgrading upreaders present. There is no guarantee for the order /// in which other readers or writers waiting simultaneously will /// obtain the lock. - /// - /// This method does not disable interrupts, so any locks related to - /// interrupt context should avoid using this method, and use [`read_irq_disabled`] - /// instead. When IRQ handlers are allowed to be executed while holding - /// this lock, it is preferable to use this method over the [`read_irq_disabled`] - /// method as it has a higher efficiency. - /// - /// [`read_irq_disabled`]: Self::read_irq_disabled - pub fn read(&self) -> RwLockReadGuard { + pub fn read(&self) -> RwLockReadGuard { loop { if let Some(readguard) = self.try_read() { return readguard; @@ -276,7 +145,7 @@ impl RwLock { /// for compile-time checked lifetimes of the read guard. /// /// [`read`]: Self::read - pub fn read_arc(self: &Arc) -> ArcRwLockReadGuard { + pub fn read_arc(self: &Arc) -> ArcRwLockReadGuard { loop { if let Some(readguard) = self.try_read_arc() { return readguard; @@ -292,15 +161,7 @@ impl RwLock { /// upreaders or readers present. There is no guarantee for the order /// in which other readers or writers waiting simultaneously will /// obtain the lock. - /// - /// This method does not disable interrupts, so any locks related to - /// interrupt context should avoid using this method, and use [`write_irq_disabled`] - /// instead. When IRQ handlers are allowed to be executed while holding - /// this lock, it is preferable to use this method over the [`write_irq_disabled`] - /// method as it has a higher efficiency. - /// - /// [`write_irq_disabled`]: Self::write_irq_disabled - pub fn write(&self) -> RwLockWriteGuard { + pub fn write(&self) -> RwLockWriteGuard { loop { if let Some(writeguard) = self.try_write() { return writeguard; @@ -316,7 +177,7 @@ impl RwLock { /// for compile-time checked lifetimes of the lock guard. /// /// [`write`]: Self::write - pub fn write_arc(self: &Arc) -> ArcRwLockWriteGuard { + pub fn write_arc(self: &Arc) -> ArcRwLockWriteGuard { loop { if let Some(writeguard) = self.try_write_arc() { return writeguard; @@ -336,15 +197,7 @@ impl RwLock { /// and reader do not differ before invoking the upgread method. However, /// only one upreader can exist at any time to avoid deadlock in the /// upgread method. - /// - /// This method does not disable interrupts, so any locks related to - /// interrupt context should avoid using this method, and use [`upread_irq_disabled`] - /// instead. When IRQ handlers are allowed to be executed while holding - /// this lock, it is preferable to use this method over the [`upread_irq_disabled`] - /// method as it has a higher efficiency. 
- /// - /// [`upread_irq_disabled`]: Self::upread_irq_disabled - pub fn upread(&self) -> RwLockUpgradeableGuard { + pub fn upread(&self) -> RwLockUpgradeableGuard { loop { if let Some(guard) = self.try_upread() { return guard; @@ -360,7 +213,7 @@ impl RwLock { /// for compile-time checked lifetimes of the lock guard. /// /// [`upread`]: Self::upread - pub fn upread_arc(self: &Arc) -> ArcRwLockUpgradeableGuard { + pub fn upread_arc(self: &Arc) -> ArcRwLockUpgradeableGuard { loop { if let Some(guard) = self.try_upread_arc() { return guard; @@ -373,23 +226,11 @@ impl RwLock { /// Attempts to acquire a read lock. /// /// This function will never spin-wait and will return immediately. - /// - /// This method does not disable interrupts, so any locks related to - /// interrupt context should avoid using this method, and use - /// [`try_read_irq_disabled`] instead. When IRQ handlers are allowed to - /// be executed while holding this lock, it is preferable to use this - /// method over the [`try_read_irq_disabled`] method as it has a higher - /// efficiency. - /// - /// [`try_read_irq_disabled`]: Self::try_read_irq_disabled - pub fn try_read(&self) -> Option> { - let guard = disable_preempt(); + pub fn try_read(&self) -> Option> { + let guard = G::guard(); let lock = self.lock.fetch_add(READER, Acquire); if lock & (WRITER | MAX_READER | BEING_UPGRADED) == 0 { - Some(RwLockReadGuard { - inner: self, - inner_guard: InnerGuard::PreemptGuard(guard), - }) + Some(RwLockReadGuard { inner: self, guard }) } else { self.lock.fetch_sub(READER, Release); None @@ -402,13 +243,13 @@ impl RwLock { /// for compile-time checked lifetimes of the lock guard. /// /// [`try_read`]: Self::try_read - pub fn try_read_arc(self: &Arc) -> Option> { - let guard = disable_preempt(); + pub fn try_read_arc(self: &Arc) -> Option> { + let guard = G::guard(); let lock = self.lock.fetch_add(READER, Acquire); if lock & (WRITER | MAX_READER | BEING_UPGRADED) == 0 { Some(ArcRwLockReadGuard { inner: self.clone(), - inner_guard: InnerGuard::PreemptGuard(guard), + guard, }) } else { self.lock.fetch_sub(READER, Release); @@ -419,26 +260,14 @@ impl RwLock { /// Attempts to acquire a write lock. /// /// This function will never spin-wait and will return immediately. - /// - /// This method does not disable interrupts, so any locks related to - /// interrupt context should avoid using this method, and use - /// [`try_write_irq_disabled`] instead. When IRQ handlers are allowed to - /// be executed while holding this lock, it is preferable to use this - /// method over the [`try_write_irq_disabled`] method as it has a higher - /// efficiency. - /// - /// [`try_write_irq_disabled`]: Self::try_write_irq_disabled - pub fn try_write(&self) -> Option> { - let guard = disable_preempt(); + pub fn try_write(&self) -> Option> { + let guard = G::guard(); if self .lock .compare_exchange(0, WRITER, Acquire, Relaxed) .is_ok() { - Some(RwLockWriteGuard { - inner: self, - inner_guard: InnerGuard::PreemptGuard(guard), - }) + Some(RwLockWriteGuard { inner: self, guard }) } else { None } @@ -450,8 +279,8 @@ impl RwLock { /// for compile-time checked lifetimes of the lock guard. 
/// /// [`try_write`]: Self::try_write - fn try_write_arc(self: &Arc) -> Option> { - let guard = disable_preempt(); + fn try_write_arc(self: &Arc) -> Option> { + let guard = G::guard(); if self .lock .compare_exchange(0, WRITER, Acquire, Relaxed) @@ -459,7 +288,7 @@ impl RwLock { { Some(ArcRwLockWriteGuard { inner: self.clone(), - inner_guard: InnerGuard::PreemptGuard(guard), + guard, }) } else { None @@ -469,23 +298,11 @@ impl RwLock { /// Attempts to acquire an upread lock. /// /// This function will never spin-wait and will return immediately. - /// - /// This method does not disable interrupts, so any locks related to - /// interrupt context should avoid using this method, and use - /// [`try_upread_irq_disabled`] instead. When IRQ handlers are allowed to - /// be executed while holding this lock, it is preferable to use this - /// method over the [`try_upread_irq_disabled`] method as it has a higher - /// efficiency. - /// - /// [`try_upread_irq_disabled`]: Self::try_upread_irq_disabled - pub fn try_upread(&self) -> Option> { - let guard = disable_preempt(); + pub fn try_upread(&self) -> Option> { + let guard = G::guard(); let lock = self.lock.fetch_or(UPGRADEABLE_READER, Acquire) & (WRITER | UPGRADEABLE_READER); if lock == 0 { - return Some(RwLockUpgradeableGuard { - inner: self, - inner_guard: InnerGuard::PreemptGuard(guard), - }); + return Some(RwLockUpgradeableGuard { inner: self, guard }); } else if lock == WRITER { self.lock.fetch_sub(UPGRADEABLE_READER, Release); } @@ -498,13 +315,13 @@ impl RwLock { /// for compile-time checked lifetimes of the lock guard. /// /// [`try_upread`]: Self::try_upread - pub fn try_upread_arc(self: &Arc) -> Option> { - let guard = disable_preempt(); + pub fn try_upread_arc(self: &Arc) -> Option> { + let guard = G::guard(); let lock = self.lock.fetch_or(UPGRADEABLE_READER, Acquire) & (WRITER | UPGRADEABLE_READER); if lock == 0 { return Some(ArcRwLockUpgradeableGuard { inner: self.clone(), - inner_guard: InnerGuard::PreemptGuard(guard), + guard, }); } else if lock == WRITER { self.lock.fetch_sub(UPGRADEABLE_READER, Release); @@ -513,7 +330,7 @@ impl RwLock { } } -impl fmt::Debug for RwLock { +impl fmt::Debug for RwLock { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fmt::Debug::fmt(&self.val, f) } @@ -521,62 +338,53 @@ impl fmt::Debug for RwLock { /// Because there can be more than one readers to get the T's immutable ref, /// so T must be Sync to guarantee the sharing safety. -unsafe impl Send for RwLock {} -unsafe impl Sync for RwLock {} +unsafe impl Send for RwLock {} +unsafe impl Sync for RwLock {} -impl> + Clone> !Send for RwLockWriteGuard_ {} -unsafe impl> + Clone + Sync> Sync - for RwLockWriteGuard_ +impl> + Clone, G: Guardian> !Send + for RwLockWriteGuard_ +{ +} +unsafe impl> + Clone + Sync, G: Guardian> Sync + for RwLockWriteGuard_ { } -impl> + Clone> !Send for RwLockReadGuard_ {} -unsafe impl> + Clone + Sync> Sync - for RwLockReadGuard_ +impl> + Clone, G: Guardian> !Send + for RwLockReadGuard_ +{ +} +unsafe impl> + Clone + Sync, G: Guardian> Sync + for RwLockReadGuard_ { } -impl> + Clone> !Send for RwLockUpgradeableGuard_ {} -unsafe impl> + Clone + Sync> Sync - for RwLockUpgradeableGuard_ +impl> + Clone, G: Guardian> !Send + for RwLockUpgradeableGuard_ { } - -enum InnerGuard { - IrqGuard(DisabledLocalIrqGuard), - PreemptGuard(DisabledPreemptGuard), -} - -impl InnerGuard { - /// Transfers the current guard to a new `InnerGuard` instance ensuring atomicity during lock upgrades or downgrades. 
- /// - /// This function guarantees that there will be no 'gaps' between the destruction of the old guard and - /// the creation of the new guard, maintaining the atomicity of lock transitions. - fn transfer_to(&mut self) -> Self { - match self { - InnerGuard::IrqGuard(irq_guard) => InnerGuard::IrqGuard(irq_guard.transfer_to()), - InnerGuard::PreemptGuard(preempt_guard) => { - InnerGuard::PreemptGuard(preempt_guard.transfer_to()) - } - } - } +unsafe impl> + Clone + Sync, G: Guardian> Sync + for RwLockUpgradeableGuard_ +{ } /// A guard that provides immutable data access. #[clippy::has_significant_drop] #[must_use] -pub struct RwLockReadGuard_> + Clone> { - inner_guard: InnerGuard, +pub struct RwLockReadGuard_> + Clone, G: Guardian> { + guard: G::Guard, inner: R, } /// A guard that provides shared read access to the data protected by a [`RwLock`]. -pub type RwLockReadGuard<'a, T> = RwLockReadGuard_>; +pub type RwLockReadGuard<'a, T, G> = RwLockReadGuard_, G>; /// A guard that provides shared read access to the data protected by a `Arc`. -pub type ArcRwLockReadGuard = RwLockReadGuard_>>; +pub type ArcRwLockReadGuard = RwLockReadGuard_>, G>; -impl> + Clone> Deref for RwLockReadGuard_ { +impl> + Clone, G: Guardian> Deref + for RwLockReadGuard_ +{ type Target = T; fn deref(&self) -> &T { @@ -584,14 +392,16 @@ impl> + Clone> Deref for RwLockReadGuard_ } } -impl> + Clone> Drop for RwLockReadGuard_ { +impl> + Clone, G: Guardian> Drop + for RwLockReadGuard_ +{ fn drop(&mut self) { self.inner.lock.fetch_sub(READER, Release); } } -impl> + Clone> fmt::Debug - for RwLockReadGuard_ +impl> + Clone, G: Guardian> fmt::Debug + for RwLockReadGuard_ { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fmt::Debug::fmt(&**self, f) @@ -599,17 +409,19 @@ impl> + Clone> fmt::Debug } /// A guard that provides mutable data access. -pub struct RwLockWriteGuard_> + Clone> { - inner_guard: InnerGuard, +pub struct RwLockWriteGuard_> + Clone, G: Guardian> { + guard: G::Guard, inner: R, } /// A guard that provides exclusive write access to the data protected by a [`RwLock`]. -pub type RwLockWriteGuard<'a, T> = RwLockWriteGuard_>; +pub type RwLockWriteGuard<'a, T, G> = RwLockWriteGuard_, G>; /// A guard that provides exclusive write access to the data protected by a `Arc`. -pub type ArcRwLockWriteGuard = RwLockWriteGuard_>>; +pub type ArcRwLockWriteGuard = RwLockWriteGuard_>, G>; -impl> + Clone> Deref for RwLockWriteGuard_ { +impl> + Clone, G: Guardian> Deref + for RwLockWriteGuard_ +{ type Target = T; fn deref(&self) -> &T { @@ -617,11 +429,11 @@ impl> + Clone> Deref for RwLockWriteGuard } } -impl> + Clone> RwLockWriteGuard_ { +impl> + Clone, G: Guardian> RwLockWriteGuard_ { /// Atomically downgrades a write guard to an upgradeable reader guard. /// /// This method always succeeds because the lock is exclusively held by the writer. - pub fn downgrade(mut self) -> RwLockUpgradeableGuard_ { + pub fn downgrade(mut self) -> RwLockUpgradeableGuard_ { loop { self = match self.try_downgrade() { Ok(guard) => return guard, @@ -632,36 +444,40 @@ impl> + Clone> RwLockWriteGuard_ { /// This is not exposed as a public method to prevent intermediate lock states from affecting the /// downgrade process. 
- fn try_downgrade(mut self) -> Result, Self> { + fn try_downgrade(mut self) -> Result, Self> { let inner = self.inner.clone(); let res = self .inner .lock .compare_exchange(WRITER, UPGRADEABLE_READER, AcqRel, Relaxed); if res.is_ok() { - let inner_guard = self.inner_guard.transfer_to(); + let guard = self.guard.transfer_to(); drop(self); - Ok(RwLockUpgradeableGuard_ { inner, inner_guard }) + Ok(RwLockUpgradeableGuard_ { inner, guard }) } else { Err(self) } } } -impl> + Clone> DerefMut for RwLockWriteGuard_ { +impl> + Clone, G: Guardian> DerefMut + for RwLockWriteGuard_ +{ fn deref_mut(&mut self) -> &mut Self::Target { unsafe { &mut *self.inner.val.get() } } } -impl> + Clone> Drop for RwLockWriteGuard_ { +impl> + Clone, G: Guardian> Drop + for RwLockWriteGuard_ +{ fn drop(&mut self) { self.inner.lock.fetch_and(!WRITER, Release); } } -impl> + Clone> fmt::Debug - for RwLockWriteGuard_ +impl> + Clone, G: Guardian> fmt::Debug + for RwLockWriteGuard_ { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fmt::Debug::fmt(&**self, f) @@ -670,23 +486,26 @@ impl> + Clone> fmt::Debug /// A guard that provides immutable data access but can be atomically /// upgraded to `RwLockWriteGuard`. -pub struct RwLockUpgradeableGuard_> + Clone> { - inner_guard: InnerGuard, +pub struct RwLockUpgradeableGuard_> + Clone, G: Guardian> +{ + guard: G::Guard, inner: R, } /// A upgradable guard that provides read access to the data protected by a [`RwLock`]. -pub type RwLockUpgradeableGuard<'a, T> = RwLockUpgradeableGuard_>; +pub type RwLockUpgradeableGuard<'a, T, G> = RwLockUpgradeableGuard_, G>; /// A upgradable guard that provides read access to the data protected by a `Arc`. -pub type ArcRwLockUpgradeableGuard = RwLockUpgradeableGuard_>>; +pub type ArcRwLockUpgradeableGuard = RwLockUpgradeableGuard_>, G>; -impl> + Clone> RwLockUpgradeableGuard_ { +impl> + Clone, G: Guardian> + RwLockUpgradeableGuard_ +{ /// Upgrades this upread guard to a write guard atomically. /// /// After calling this method, subsequent readers will be blocked /// while previous readers remain unaffected. The calling thread /// will spin-wait until previous readers finish. 
- pub fn upgrade(mut self) -> RwLockWriteGuard_ { + pub fn upgrade(mut self) -> RwLockWriteGuard_ { self.inner.lock.fetch_or(BEING_UPGRADED, Acquire); loop { self = match self.try_upgrade() { @@ -698,7 +517,7 @@ impl> + Clone> RwLockUpgradeableGuard_ Result, Self> { + pub fn try_upgrade(mut self) -> Result, Self> { let res = self.inner.lock.compare_exchange( UPGRADEABLE_READER | BEING_UPGRADED, WRITER | UPGRADEABLE_READER, @@ -707,16 +526,18 @@ impl> + Clone> RwLockUpgradeableGuard_> + Clone> Deref for RwLockUpgradeableGuard_ { +impl> + Clone, G: Guardian> Deref + for RwLockUpgradeableGuard_ +{ type Target = T; fn deref(&self) -> &T { @@ -724,14 +545,16 @@ impl> + Clone> Deref for RwLockUpgradeabl } } -impl> + Clone> Drop for RwLockUpgradeableGuard_ { +impl> + Clone, G: Guardian> Drop + for RwLockUpgradeableGuard_ +{ fn drop(&mut self) { self.inner.lock.fetch_sub(UPGRADEABLE_READER, Release); } } -impl> + Clone> fmt::Debug - for RwLockUpgradeableGuard_ +impl> + Clone, G: Guardian> fmt::Debug + for RwLockUpgradeableGuard_ { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fmt::Debug::fmt(&**self, f) diff --git a/ostd/src/sync/spin.rs b/ostd/src/sync/spin.rs index 87a6e7463..a8211f799 100644 --- a/ostd/src/sync/spin.rs +++ b/ostd/src/sync/spin.rs @@ -24,12 +24,15 @@ use crate::{ /// - if `G` is [`PreemptDisabled`], preemption is disabled; /// - if `G` is [`LocalIrqDisabled`], local IRQs are disabled. /// +/// The `G` can also be provided by other crates other than ostd, +/// if it behaves similar like [`PreemptDisabled`] or [`LocalIrqDisabled`]. +/// /// The guard behavior can be temporarily upgraded from [`PreemptDisabled`] to /// [`LocalIrqDisabled`] using the [`disable_irq`] method. /// /// [`disable_irq`]: Self::disable_irq #[repr(transparent)] -pub struct SpinLock { +pub struct SpinLock { phantom: PhantomData, /// Only the last field of a struct may have a dynamically sized type. /// That's why SpinLockInner is put in the last field. @@ -44,12 +47,23 @@ struct SpinLockInner { /// A guardian that denotes the guard behavior for holding the spin lock. pub trait Guardian { /// The guard type. - type Guard; + type Guard: GuardTransfer; /// Creates a new guard. fn guard() -> Self::Guard; } +/// The Guard can be transferred atomically. +pub trait GuardTransfer { + /// Atomically transfers the current guard to a new instance. + /// + /// This function ensures that there are no 'gaps' between the destruction of the old guard and + /// the creation of the new guard, thereby maintaining the atomicity of guard transitions. + /// + /// The original guard must be dropped immediately after calling this method. + fn transfer_to(&mut self) -> Self; +} + /// A guardian that disables preemption while holding the spin lock. pub struct PreemptDisabled; @@ -165,15 +179,15 @@ impl SpinLock { } } -impl fmt::Debug for SpinLock { +impl fmt::Debug for SpinLock { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fmt::Debug::fmt(&self.inner.val, f) } } // SAFETY: Only a single lock holder is permitted to access the inner data of Spinlock. -unsafe impl Send for SpinLock {} -unsafe impl Sync for SpinLock {} +unsafe impl Send for SpinLock {} +unsafe impl Sync for SpinLock {} /// A guard that provides exclusive access to the data protected by a [`SpinLock`]. 
pub type SpinLockGuard<'a, T, G> = SpinLockGuard_, G>; diff --git a/ostd/src/task/preempt/guard.rs b/ostd/src/task/preempt/guard.rs index 1aa88c6e1..6df1ba6ee 100644 --- a/ostd/src/task/preempt/guard.rs +++ b/ostd/src/task/preempt/guard.rs @@ -1,5 +1,7 @@ // SPDX-License-Identifier: MPL-2.0 +use crate::sync::GuardTransfer; + /// A guard for disable preempt. #[clippy::has_significant_drop] #[must_use] @@ -16,10 +18,10 @@ impl DisabledPreemptGuard { super::cpu_local::inc_guard_count(); Self { _private: () } } +} - /// Transfer this guard to a new guard. - /// This guard must be dropped after this function. - pub fn transfer_to(&self) -> Self { +impl GuardTransfer for DisabledPreemptGuard { + fn transfer_to(&mut self) -> Self { disable_preempt() } } diff --git a/ostd/src/trap/irq.rs b/ostd/src/trap/irq.rs index ca735a14d..152d933f8 100644 --- a/ostd/src/trap/irq.rs +++ b/ostd/src/trap/irq.rs @@ -7,6 +7,7 @@ use core::fmt::Debug; use crate::{ arch::irq::{self, IrqCallbackHandle, IRQ_ALLOCATOR}, prelude::*, + sync::GuardTransfer, trap::TrapFrame, Error, }; @@ -149,10 +150,10 @@ impl DisabledLocalIrqGuard { } Self { was_enabled } } +} - /// Transfers the saved IRQ status of this guard to a new guard. - /// The saved IRQ status of this guard is cleared. - pub fn transfer_to(&mut self) -> Self { +impl GuardTransfer for DisabledLocalIrqGuard { + fn transfer_to(&mut self) -> Self { let was_enabled = self.was_enabled; self.was_enabled = false; Self { was_enabled }
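Taken together, the `Guardian`/`GuardTransfer` plumbing means a lock's synchronization discipline is chosen once, at the type level, and the guard produced by `G::guard()` is stored inside every read/write/upread guard so that `transfer_to()` can keep upgrades and downgrades atomic. A hedged sketch of how the refactored API reads at a call site, assuming the guard parameter defaults to `PreemptDisabled` (as the `PreemptDisabled`-typed guard signatures in the ramfs, SysV semaphore, and credentials hunks above suggest; the `Stats` type and its fields are hypothetical):

use ostd::sync::{LocalIrqDisabled, RwLock};

struct Stats {
    // Default guard behavior: preemption is disabled while a guard is held.
    task_count: RwLock<u64>,
    // IRQ-disabling is now a property of the type, not of a separate
    // `read_irq_disabled()`/`write_irq_disabled()` method family.
    irq_count: RwLock<u64, LocalIrqDisabled>,
}

fn bump(stats: &Stats) {
    *stats.task_count.write() += 1;
    *stats.irq_count.write() += 1;
}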