Refactor Rwlock to take type parameter

This commit is contained in:
jiangjianfeng
2024-11-19 10:35:20 +00:00
committed by Tate, Hongliang Tian
parent ac1a6be05d
commit 495c93c2ad
20 changed files with 205 additions and 363 deletions

View File

@ -14,7 +14,7 @@ use alloc::sync::Arc;
use core::{cmp::max, ops::Add, time::Duration}; use core::{cmp::max, ops::Add, time::Duration};
use aster_util::coeff::Coeff; use aster_util::coeff::Coeff;
use ostd::sync::RwLock; use ostd::sync::{LocalIrqDisabled, RwLock};
use crate::NANOS_PER_SECOND; use crate::NANOS_PER_SECOND;
@ -55,7 +55,7 @@ pub struct ClockSource {
base: ClockSourceBase, base: ClockSourceBase,
coeff: Coeff, coeff: Coeff,
/// A record to an `Instant` and the corresponding cycles of this `ClockSource`. /// A record to an `Instant` and the corresponding cycles of this `ClockSource`.
last_record: RwLock<(Instant, u64)>, last_record: RwLock<(Instant, u64), LocalIrqDisabled>,
} }
impl ClockSource { impl ClockSource {
@ -91,7 +91,7 @@ impl ClockSource {
/// Returns the calculated instant and instant cycles. /// Returns the calculated instant and instant cycles.
fn calculate_instant(&self) -> (Instant, u64) { fn calculate_instant(&self) -> (Instant, u64) {
let (instant_cycles, last_instant, last_cycles) = { let (instant_cycles, last_instant, last_cycles) = {
let last_record = self.last_record.read_irq_disabled(); let last_record = self.last_record.read();
let (last_instant, last_cycles) = *last_record; let (last_instant, last_cycles) = *last_record;
(self.read_cycles(), last_instant, last_cycles) (self.read_cycles(), last_instant, last_cycles)
}; };
@ -121,7 +121,7 @@ impl ClockSource {
/// Uses an input instant and cycles to update the `last_record` in the `ClockSource`. /// Uses an input instant and cycles to update the `last_record` in the `ClockSource`.
fn update_last_record(&self, record: (Instant, u64)) { fn update_last_record(&self, record: (Instant, u64)) {
*self.last_record.write_irq_disabled() = record; *self.last_record.write() = record;
} }
/// Reads current cycles of the `ClockSource`. /// Reads current cycles of the `ClockSource`.
@ -131,7 +131,7 @@ impl ClockSource {
/// Returns the last instant and last cycles recorded in the `ClockSource`. /// Returns the last instant and last cycles recorded in the `ClockSource`.
pub fn last_record(&self) -> (Instant, u64) { pub fn last_record(&self) -> (Instant, u64) {
return *self.last_record.read_irq_disabled(); return *self.last_record.read();
} }
/// Returns the maximum delay seconds for updating of the `ClockSource`. /// Returns the maximum delay seconds for updating of the `ClockSource`.

View File

@ -9,7 +9,7 @@ use log::debug;
use ostd::{ use ostd::{
io_mem::IoMem, io_mem::IoMem,
mm::{DmaDirection, DmaStream, DmaStreamSlice, FrameAllocOptions, VmReader}, mm::{DmaDirection, DmaStream, DmaStreamSlice, FrameAllocOptions, VmReader},
sync::{RwLock, SpinLock}, sync::{LocalIrqDisabled, RwLock, SpinLock},
trap::TrapFrame, trap::TrapFrame,
}; };
@ -27,7 +27,7 @@ pub struct ConsoleDevice {
transmit_queue: SpinLock<VirtQueue>, transmit_queue: SpinLock<VirtQueue>,
send_buffer: DmaStream, send_buffer: DmaStream,
receive_buffer: DmaStream, receive_buffer: DmaStream,
callbacks: RwLock<Vec<&'static ConsoleCallback>>, callbacks: RwLock<Vec<&'static ConsoleCallback>, LocalIrqDisabled>,
} }
impl AnyConsoleDevice for ConsoleDevice { impl AnyConsoleDevice for ConsoleDevice {
@ -54,7 +54,7 @@ impl AnyConsoleDevice for ConsoleDevice {
} }
fn register_callback(&self, callback: &'static ConsoleCallback) { fn register_callback(&self, callback: &'static ConsoleCallback) {
self.callbacks.write_irq_disabled().push(callback); self.callbacks.write().push(callback);
} }
} }
@ -136,7 +136,7 @@ impl ConsoleDevice {
}; };
self.receive_buffer.sync(0..len as usize).unwrap(); self.receive_buffer.sync(0..len as usize).unwrap();
let callbacks = self.callbacks.read_irq_disabled(); let callbacks = self.callbacks.read();
for callback in callbacks.iter() { for callback in callbacks.iter() {
let reader = self.receive_buffer.reader().unwrap().limit(len as usize); let reader = self.receive_buffer.reader().unwrap().limit(len as usize);
callback(reader); callback(reader);

View File

@ -19,7 +19,7 @@ use ostd::{
io_mem::IoMem, io_mem::IoMem,
mm::{DmaDirection, DmaStream, FrameAllocOptions, HasDaddr, VmIo, PAGE_SIZE}, mm::{DmaDirection, DmaStream, FrameAllocOptions, HasDaddr, VmIo, PAGE_SIZE},
offset_of, offset_of,
sync::{RwLock, SpinLock}, sync::{LocalIrqDisabled, RwLock, SpinLock},
trap::TrapFrame, trap::TrapFrame,
}; };
@ -76,7 +76,7 @@ pub struct InputDevice {
status_queue: VirtQueue, status_queue: VirtQueue,
event_table: EventTable, event_table: EventTable,
#[allow(clippy::type_complexity)] #[allow(clippy::type_complexity)]
callbacks: RwLock<Vec<Arc<dyn Fn(InputEvent) + Send + Sync + 'static>>>, callbacks: RwLock<Vec<Arc<dyn Fn(InputEvent) + Send + Sync + 'static>>, LocalIrqDisabled>,
transport: SpinLock<Box<dyn VirtioTransport>>, transport: SpinLock<Box<dyn VirtioTransport>>,
} }
@ -209,7 +209,7 @@ impl InputDevice {
} }
fn handle_irq(&self) { fn handle_irq(&self) {
let callbacks = self.callbacks.read_irq_disabled(); let callbacks = self.callbacks.read();
// Returns true if there may be more events to handle // Returns true if there may be more events to handle
let handle_event = |event: &EventBuf| -> bool { let handle_event = |event: &EventBuf| -> bool {
event.sync().unwrap(); event.sync().unwrap();
@ -295,7 +295,7 @@ impl<T, M: HasDaddr> DmaBuf for SafePtr<T, M> {
impl aster_input::InputDevice for InputDevice { impl aster_input::InputDevice for InputDevice {
fn register_callbacks(&self, function: &'static (dyn Fn(InputEvent) + Send + Sync)) { fn register_callbacks(&self, function: &'static (dyn Fn(InputEvent) + Send + Sync)) {
self.callbacks.write_irq_disabled().push(Arc::new(function)) self.callbacks.write().push(Arc::new(function))
} }
} }

View File

@ -46,7 +46,7 @@ pub struct BoundSocketInner<T, E> {
iface: Arc<dyn Iface<E>>, iface: Arc<dyn Iface<E>>,
port: u16, port: u16,
socket: T, socket: T,
observer: RwLock<Weak<dyn SocketEventObserver>>, observer: RwLock<Weak<dyn SocketEventObserver>, LocalIrqDisabled>,
events: AtomicU8, events: AtomicU8,
next_poll_at_ms: AtomicU64, next_poll_at_ms: AtomicU64,
} }
@ -223,7 +223,7 @@ impl<T: AnySocket, E> BoundSocket<T, E> {
/// that the old observer will never be called after the setting. Users should be aware of this /// that the old observer will never be called after the setting. Users should be aware of this
/// and proactively handle the race conditions if necessary. /// and proactively handle the race conditions if necessary.
pub fn set_observer(&self, new_observer: Weak<dyn SocketEventObserver>) { pub fn set_observer(&self, new_observer: Weak<dyn SocketEventObserver>) {
*self.0.observer.write_irq_disabled() = new_observer; *self.0.observer.write() = new_observer;
self.0.on_events(); self.0.on_events();
} }

View File

@ -12,7 +12,7 @@ use aster_util::slot_vec::SlotVec;
use hashbrown::HashMap; use hashbrown::HashMap;
use ostd::{ use ostd::{
mm::{Frame, VmIo}, mm::{Frame, VmIo},
sync::RwLockWriteGuard, sync::{PreemptDisabled, RwLockWriteGuard},
}; };
use super::*; use super::*;
@ -1195,8 +1195,8 @@ fn write_lock_two_direntries_by_ino<'a>(
this: (u64, &'a RwLock<DirEntry>), this: (u64, &'a RwLock<DirEntry>),
other: (u64, &'a RwLock<DirEntry>), other: (u64, &'a RwLock<DirEntry>),
) -> ( ) -> (
RwLockWriteGuard<'a, DirEntry>, RwLockWriteGuard<'a, DirEntry, PreemptDisabled>,
RwLockWriteGuard<'a, DirEntry>, RwLockWriteGuard<'a, DirEntry, PreemptDisabled>,
) { ) {
if this.0 < other.0 { if this.0 < other.0 {
let this = this.1.write(); let this = this.1.write();

View File

@ -290,11 +290,11 @@ pub fn create_sem_set(nsems: usize, mode: u16, credentials: Credentials<ReadOp>)
Ok(id) Ok(id)
} }
pub fn sem_sets<'a>() -> RwLockReadGuard<'a, BTreeMap<key_t, SemaphoreSet>> { pub fn sem_sets<'a>() -> RwLockReadGuard<'a, BTreeMap<key_t, SemaphoreSet>, PreemptDisabled> {
SEMAPHORE_SETS.read() SEMAPHORE_SETS.read()
} }
pub fn sem_sets_mut<'a>() -> RwLockWriteGuard<'a, BTreeMap<key_t, SemaphoreSet>> { pub fn sem_sets_mut<'a>() -> RwLockWriteGuard<'a, BTreeMap<key_t, SemaphoreSet>, PreemptDisabled> {
SEMAPHORE_SETS.write() SEMAPHORE_SETS.write()
} }

View File

@ -6,6 +6,7 @@ use aster_bigtcp::{
socket::{SocketEventObserver, SocketEvents}, socket::{SocketEventObserver, SocketEvents},
wire::IpEndpoint, wire::IpEndpoint,
}; };
use ostd::sync::LocalIrqDisabled;
use takeable::Takeable; use takeable::Takeable;
use self::{bound::BoundDatagram, unbound::UnboundDatagram}; use self::{bound::BoundDatagram, unbound::UnboundDatagram};
@ -51,7 +52,7 @@ impl OptionSet {
pub struct DatagramSocket { pub struct DatagramSocket {
options: RwLock<OptionSet>, options: RwLock<OptionSet>,
inner: RwLock<Takeable<Inner>>, inner: RwLock<Takeable<Inner>, LocalIrqDisabled>,
nonblocking: AtomicBool, nonblocking: AtomicBool,
pollee: Pollee, pollee: Pollee,
} }
@ -134,7 +135,7 @@ impl DatagramSocket {
} }
// Slow path // Slow path
let mut inner = self.inner.write_irq_disabled(); let mut inner = self.inner.write();
inner.borrow_result(|owned_inner| { inner.borrow_result(|owned_inner| {
let bound_datagram = match owned_inner.bind_to_ephemeral_endpoint(remote_endpoint) { let bound_datagram = match owned_inner.bind_to_ephemeral_endpoint(remote_endpoint) {
Ok(bound_datagram) => bound_datagram, Ok(bound_datagram) => bound_datagram,
@ -277,7 +278,7 @@ impl Socket for DatagramSocket {
let endpoint = socket_addr.try_into()?; let endpoint = socket_addr.try_into()?;
let can_reuse = self.options.read().socket.reuse_addr(); let can_reuse = self.options.read().socket.reuse_addr();
let mut inner = self.inner.write_irq_disabled(); let mut inner = self.inner.write();
inner.borrow_result(|owned_inner| { inner.borrow_result(|owned_inner| {
let bound_datagram = match owned_inner.bind(&endpoint, can_reuse) { let bound_datagram = match owned_inner.bind(&endpoint, can_reuse) {
Ok(bound_datagram) => bound_datagram, Ok(bound_datagram) => bound_datagram,
@ -294,7 +295,7 @@ impl Socket for DatagramSocket {
self.try_bind_ephemeral(&endpoint)?; self.try_bind_ephemeral(&endpoint)?;
let mut inner = self.inner.write_irq_disabled(); let mut inner = self.inner.write();
let Inner::Bound(bound_datagram) = inner.as_mut() else { let Inner::Bound(bound_datagram) = inner.as_mut() else {
return_errno_with_message!(Errno::EINVAL, "the socket is not bound") return_errno_with_message!(Errno::EINVAL, "the socket is not bound")
}; };

View File

@ -3,6 +3,7 @@
use aster_bigtcp::{ use aster_bigtcp::{
errors::tcp::ListenError, iface::BindPortConfig, socket::UnboundTcpSocket, wire::IpEndpoint, errors::tcp::ListenError, iface::BindPortConfig, socket::UnboundTcpSocket, wire::IpEndpoint,
}; };
use ostd::sync::LocalIrqDisabled;
use super::connected::ConnectedStream; use super::connected::ConnectedStream;
use crate::{ use crate::{
@ -16,7 +17,7 @@ pub struct ListenStream {
/// A bound socket held to ensure the TCP port cannot be released /// A bound socket held to ensure the TCP port cannot be released
bound_socket: BoundTcpSocket, bound_socket: BoundTcpSocket,
/// Backlog sockets listening at the local endpoint /// Backlog sockets listening at the local endpoint
backlog_sockets: RwLock<Vec<BacklogSocket>>, backlog_sockets: RwLock<Vec<BacklogSocket>, LocalIrqDisabled>,
} }
impl ListenStream { impl ListenStream {
@ -40,7 +41,7 @@ impl ListenStream {
/// Append sockets listening at LocalEndPoint to support backlog /// Append sockets listening at LocalEndPoint to support backlog
fn fill_backlog_sockets(&self) -> Result<()> { fn fill_backlog_sockets(&self) -> Result<()> {
let mut backlog_sockets = self.backlog_sockets.write_irq_disabled(); let mut backlog_sockets = self.backlog_sockets.write();
let backlog = self.backlog; let backlog = self.backlog;
let current_backlog_len = backlog_sockets.len(); let current_backlog_len = backlog_sockets.len();
@ -58,7 +59,7 @@ impl ListenStream {
} }
pub fn try_accept(&self) -> Result<ConnectedStream> { pub fn try_accept(&self) -> Result<ConnectedStream> {
let mut backlog_sockets = self.backlog_sockets.write_irq_disabled(); let mut backlog_sockets = self.backlog_sockets.write();
let index = backlog_sockets let index = backlog_sockets
.iter() .iter()

View File

@ -11,7 +11,7 @@ use connecting::{ConnResult, ConnectingStream};
use init::InitStream; use init::InitStream;
use listen::ListenStream; use listen::ListenStream;
use options::{Congestion, MaxSegment, NoDelay, WindowClamp}; use options::{Congestion, MaxSegment, NoDelay, WindowClamp};
use ostd::sync::{RwLockReadGuard, RwLockWriteGuard}; use ostd::sync::{LocalIrqDisabled, PreemptDisabled, RwLockReadGuard, RwLockWriteGuard};
use takeable::Takeable; use takeable::Takeable;
use util::TcpOptionSet; use util::TcpOptionSet;
@ -50,7 +50,7 @@ pub use self::util::CongestionControl;
pub struct StreamSocket { pub struct StreamSocket {
options: RwLock<OptionSet>, options: RwLock<OptionSet>,
state: RwLock<Takeable<State>>, state: RwLock<Takeable<State>, LocalIrqDisabled>,
is_nonblocking: AtomicBool, is_nonblocking: AtomicBool,
pollee: Pollee, pollee: Pollee,
} }
@ -116,7 +116,7 @@ impl StreamSocket {
/// Ensures that the socket state is up to date and obtains a read lock on it. /// Ensures that the socket state is up to date and obtains a read lock on it.
/// ///
/// For a description of what "up-to-date" means, see [`Self::update_connecting`]. /// For a description of what "up-to-date" means, see [`Self::update_connecting`].
fn read_updated_state(&self) -> RwLockReadGuard<Takeable<State>> { fn read_updated_state(&self) -> RwLockReadGuard<Takeable<State>, LocalIrqDisabled> {
loop { loop {
let state = self.state.read(); let state = self.state.read();
match state.as_ref() { match state.as_ref() {
@ -132,7 +132,7 @@ impl StreamSocket {
/// Ensures that the socket state is up to date and obtains a write lock on it. /// Ensures that the socket state is up to date and obtains a write lock on it.
/// ///
/// For a description of what "up-to-date" means, see [`Self::update_connecting`]. /// For a description of what "up-to-date" means, see [`Self::update_connecting`].
fn write_updated_state(&self) -> RwLockWriteGuard<Takeable<State>> { fn write_updated_state(&self) -> RwLockWriteGuard<Takeable<State>, LocalIrqDisabled> {
self.update_connecting().1 self.update_connecting().1
} }
@ -148,12 +148,12 @@ impl StreamSocket {
fn update_connecting( fn update_connecting(
&self, &self,
) -> ( ) -> (
RwLockWriteGuard<OptionSet>, RwLockWriteGuard<OptionSet, PreemptDisabled>,
RwLockWriteGuard<Takeable<State>>, RwLockWriteGuard<Takeable<State>, LocalIrqDisabled>,
) { ) {
// Hold the lock in advance to avoid race conditions. // Hold the lock in advance to avoid race conditions.
let mut options = self.options.write(); let mut options = self.options.write();
let mut state = self.state.write_irq_disabled(); let mut state = self.state.write();
match state.as_ref() { match state.as_ref() {
State::Connecting(connection_stream) if connection_stream.has_result() => (), State::Connecting(connection_stream) if connection_stream.has_result() => (),

View File

@ -7,6 +7,7 @@ use aster_virtio::device::socket::{
device::SocketDevice, device::SocketDevice,
error::SocketError, error::SocketError,
}; };
use ostd::sync::LocalIrqDisabled;
use super::{ use super::{
addr::VsockSocketAddr, addr::VsockSocketAddr,
@ -26,7 +27,7 @@ pub struct VsockSpace {
// (key, value) = (local_addr, listen) // (key, value) = (local_addr, listen)
listen_sockets: SpinLock<BTreeMap<VsockSocketAddr, Arc<Listen>>>, listen_sockets: SpinLock<BTreeMap<VsockSocketAddr, Arc<Listen>>>,
// (key, value) = (id(local_addr,peer_addr), connected) // (key, value) = (id(local_addr,peer_addr), connected)
connected_sockets: RwLock<BTreeMap<ConnectionID, Arc<Connected>>>, connected_sockets: RwLock<BTreeMap<ConnectionID, Arc<Connected>>, LocalIrqDisabled>,
// Used ports // Used ports
used_ports: SpinLock<BTreeSet<u32>>, used_ports: SpinLock<BTreeSet<u32>>,
} }
@ -54,10 +55,7 @@ impl VsockSpace {
.disable_irq() .disable_irq()
.lock() .lock()
.contains_key(&event.destination.into()) .contains_key(&event.destination.into())
|| self || self.connected_sockets.read().contains_key(&(*event).into())
.connected_sockets
.read_irq_disabled()
.contains_key(&(*event).into())
} }
/// Alloc an unused port range /// Alloc an unused port range
@ -91,13 +89,13 @@ impl VsockSpace {
id: ConnectionID, id: ConnectionID,
connected: Arc<Connected>, connected: Arc<Connected>,
) -> Option<Arc<Connected>> { ) -> Option<Arc<Connected>> {
let mut connected_sockets = self.connected_sockets.write_irq_disabled(); let mut connected_sockets = self.connected_sockets.write();
connected_sockets.insert(id, connected) connected_sockets.insert(id, connected)
} }
/// Remove a connected socket /// Remove a connected socket
pub fn remove_connected_socket(&self, id: &ConnectionID) -> Option<Arc<Connected>> { pub fn remove_connected_socket(&self, id: &ConnectionID) -> Option<Arc<Connected>> {
let mut connected_sockets = self.connected_sockets.write_irq_disabled(); let mut connected_sockets = self.connected_sockets.write();
connected_sockets.remove(id) connected_sockets.remove(id)
} }
@ -214,11 +212,7 @@ impl VsockSpace {
debug!("vsock receive event: {:?}", event); debug!("vsock receive event: {:?}", event);
// The socket must be stored in the VsockSpace. // The socket must be stored in the VsockSpace.
if let Some(connected) = self if let Some(connected) = self.connected_sockets.read().get(&event.into()) {
.connected_sockets
.read_irq_disabled()
.get(&event.into())
{
connected.update_info(&event); connected.update_info(&event);
} }
@ -255,7 +249,7 @@ impl VsockSpace {
connecting.set_connected(); connecting.set_connected();
} }
VsockEventType::Disconnected { .. } => { VsockEventType::Disconnected { .. } => {
let connected_sockets = self.connected_sockets.read_irq_disabled(); let connected_sockets = self.connected_sockets.read();
let Some(connected) = connected_sockets.get(&event.into()) else { let Some(connected) = connected_sockets.get(&event.into()) else {
return_errno_with_message!(Errno::ENOTCONN, "the socket hasn't connected"); return_errno_with_message!(Errno::ENOTCONN, "the socket hasn't connected");
}; };
@ -263,7 +257,7 @@ impl VsockSpace {
} }
VsockEventType::Received { .. } => {} VsockEventType::Received { .. } => {}
VsockEventType::CreditRequest => { VsockEventType::CreditRequest => {
let connected_sockets = self.connected_sockets.read_irq_disabled(); let connected_sockets = self.connected_sockets.read();
let Some(connected) = connected_sockets.get(&event.into()) else { let Some(connected) = connected_sockets.get(&event.into()) else {
return_errno_with_message!(Errno::ENOTCONN, "the socket hasn't connected"); return_errno_with_message!(Errno::ENOTCONN, "the socket hasn't connected");
}; };
@ -272,7 +266,7 @@ impl VsockSpace {
})?; })?;
} }
VsockEventType::CreditUpdate => { VsockEventType::CreditUpdate => {
let connected_sockets = self.connected_sockets.read_irq_disabled(); let connected_sockets = self.connected_sockets.read();
let Some(connected) = connected_sockets.get(&event.into()) else { let Some(connected) = connected_sockets.get(&event.into()) else {
return_errno_with_message!(Errno::ENOTCONN, "the socket hasn't connected"); return_errno_with_message!(Errno::ENOTCONN, "the socket hasn't connected");
}; };
@ -289,7 +283,7 @@ impl VsockSpace {
// Deal with Received before the buffer are recycled. // Deal with Received before the buffer are recycled.
if let VsockEventType::Received { .. } = event.event_type { if let VsockEventType::Received { .. } = event.event_type {
// Only consider the connected socket and copy body to buffer // Only consider the connected socket and copy body to buffer
let connected_sockets = self.connected_sockets.read_irq_disabled(); let connected_sockets = self.connected_sockets.read();
let connected = connected_sockets.get(&event.into()).unwrap(); let connected = connected_sockets.get(&event.into()).unwrap();
debug!("Rw matches a connection with id {:?}", connected.id()); debug!("Rw matches a connection with id {:?}", connected.id());
if !connected.add_connection_buffer(body) { if !connected.add_connection_buffer(body) {

View File

@ -2,7 +2,7 @@
use core::sync::atomic::Ordering; use core::sync::atomic::Ordering;
use ostd::sync::{RwLockReadGuard, RwLockWriteGuard}; use ostd::sync::{PreemptDisabled, RwLockReadGuard, RwLockWriteGuard};
use super::{group::AtomicGid, user::AtomicUid, Gid, Uid}; use super::{group::AtomicGid, user::AtomicUid, Gid, Uid};
use crate::{ use crate::{
@ -387,11 +387,11 @@ impl Credentials_ {
// ******* Supplementary groups methods ******* // ******* Supplementary groups methods *******
pub(super) fn groups(&self) -> RwLockReadGuard<BTreeSet<Gid>> { pub(super) fn groups(&self) -> RwLockReadGuard<BTreeSet<Gid>, PreemptDisabled> {
self.supplementary_gids.read() self.supplementary_gids.read()
} }
pub(super) fn groups_mut(&self) -> RwLockWriteGuard<BTreeSet<Gid>> { pub(super) fn groups_mut(&self) -> RwLockWriteGuard<BTreeSet<Gid>, PreemptDisabled> {
self.supplementary_gids.write() self.supplementary_gids.write()
} }

View File

@ -4,7 +4,7 @@
use aster_rights::{Dup, Read, TRights, Write}; use aster_rights::{Dup, Read, TRights, Write};
use aster_rights_proc::require; use aster_rights_proc::require;
use ostd::sync::{RwLockReadGuard, RwLockWriteGuard}; use ostd::sync::{PreemptDisabled, RwLockReadGuard, RwLockWriteGuard};
use super::{capabilities::CapSet, credentials_::Credentials_, Credentials, Gid, Uid}; use super::{capabilities::CapSet, credentials_::Credentials_, Credentials, Gid, Uid};
use crate::prelude::*; use crate::prelude::*;
@ -239,7 +239,7 @@ impl<R: TRights> Credentials<R> {
/// ///
/// This method requires the `Read` right. /// This method requires the `Read` right.
#[require(R > Read)] #[require(R > Read)]
pub fn groups(&self) -> RwLockReadGuard<BTreeSet<Gid>> { pub fn groups(&self) -> RwLockReadGuard<BTreeSet<Gid>, PreemptDisabled> {
self.0.groups() self.0.groups()
} }
@ -247,7 +247,7 @@ impl<R: TRights> Credentials<R> {
/// ///
/// This method requires the `Write` right. /// This method requires the `Write` right.
#[require(R > Write)] #[require(R > Write)]
pub fn groups_mut(&self) -> RwLockWriteGuard<BTreeSet<Gid>> { pub fn groups_mut(&self) -> RwLockWriteGuard<BTreeSet<Gid>, PreemptDisabled> {
self.0.groups_mut() self.0.groups_mut()
} }

View File

@ -3,9 +3,13 @@
use alloc::{boxed::Box, vec::Vec}; use alloc::{boxed::Box, vec::Vec};
use aster_softirq::{softirq_id::TIMER_SOFTIRQ_ID, SoftIrqLine}; use aster_softirq::{softirq_id::TIMER_SOFTIRQ_ID, SoftIrqLine};
use ostd::{sync::RwLock, timer}; use ostd::{
sync::{LocalIrqDisabled, RwLock},
timer,
};
static TIMER_SOFTIRQ_CALLBACKS: RwLock<Vec<Box<dyn Fn() + Sync + Send>>> = RwLock::new(Vec::new()); static TIMER_SOFTIRQ_CALLBACKS: RwLock<Vec<Box<dyn Fn() + Sync + Send>>, LocalIrqDisabled> =
RwLock::new(Vec::new());
pub(super) fn init() { pub(super) fn init() {
SoftIrqLine::get(TIMER_SOFTIRQ_ID).enable(timer_softirq_handler); SoftIrqLine::get(TIMER_SOFTIRQ_ID).enable(timer_softirq_handler);
@ -20,13 +24,11 @@ pub(super) fn register_callback<F>(func: F)
where where
F: Fn() + Sync + Send + 'static, F: Fn() + Sync + Send + 'static,
{ {
TIMER_SOFTIRQ_CALLBACKS TIMER_SOFTIRQ_CALLBACKS.write().push(Box::new(func));
.write_irq_disabled()
.push(Box::new(func));
} }
fn timer_softirq_handler() { fn timer_softirq_handler() {
let callbacks = TIMER_SOFTIRQ_CALLBACKS.read_irq_disabled(); let callbacks = TIMER_SOFTIRQ_CALLBACKS.read();
for callback in callbacks.iter() { for callback in callbacks.iter() {
(callback)(); (callback)();
} }

View File

@ -13,7 +13,7 @@ use x86_64::registers::rflags::{self, RFlags};
use super::iommu::{alloc_irt_entry, has_interrupt_remapping, IrtEntryHandle}; use super::iommu::{alloc_irt_entry, has_interrupt_remapping, IrtEntryHandle};
use crate::{ use crate::{
cpu::CpuId, cpu::CpuId,
sync::{LocalIrqDisabled, Mutex, RwLock, RwLockReadGuard, SpinLock}, sync::{LocalIrqDisabled, Mutex, PreemptDisabled, RwLock, RwLockReadGuard, SpinLock},
trap::TrapFrame, trap::TrapFrame,
}; };
@ -119,7 +119,9 @@ impl IrqLine {
self.irq_num self.irq_num
} }
pub fn callback_list(&self) -> RwLockReadGuard<alloc::vec::Vec<CallbackElement>> { pub fn callback_list(
&self,
) -> RwLockReadGuard<alloc::vec::Vec<CallbackElement>, PreemptDisabled> {
self.callback_list.read() self.callback_list.read()
} }

View File

@ -25,7 +25,7 @@ use crate::{
Frame, PageProperty, VmReader, VmWriter, MAX_USERSPACE_VADDR, Frame, PageProperty, VmReader, VmWriter, MAX_USERSPACE_VADDR,
}, },
prelude::*, prelude::*,
sync::{RwLock, RwLockReadGuard}, sync::{PreemptDisabled, RwLock, RwLockReadGuard},
task::{disable_preempt, DisabledPreemptGuard}, task::{disable_preempt, DisabledPreemptGuard},
Error, Error,
}; };
@ -283,7 +283,7 @@ impl Cursor<'_> {
pub struct CursorMut<'a, 'b> { pub struct CursorMut<'a, 'b> {
pt_cursor: page_table::CursorMut<'a, UserMode, PageTableEntry, PagingConsts>, pt_cursor: page_table::CursorMut<'a, UserMode, PageTableEntry, PagingConsts>,
#[allow(dead_code)] #[allow(dead_code)]
activation_lock: RwLockReadGuard<'b, ()>, activation_lock: RwLockReadGuard<'b, (), PreemptDisabled>,
// We have a read lock so the CPU set in the flusher is always a superset // We have a read lock so the CPU set in the flusher is always a superset
// of actual activated CPUs. // of actual activated CPUs.
flusher: TlbFlusher<DisabledPreemptGuard>, flusher: TlbFlusher<DisabledPreemptGuard>,

View File

@ -22,6 +22,8 @@ pub use self::{
ArcRwMutexReadGuard, ArcRwMutexUpgradeableGuard, ArcRwMutexWriteGuard, RwMutex, ArcRwMutexReadGuard, ArcRwMutexUpgradeableGuard, ArcRwMutexWriteGuard, RwMutex,
RwMutexReadGuard, RwMutexUpgradeableGuard, RwMutexWriteGuard, RwMutexReadGuard, RwMutexUpgradeableGuard, RwMutexWriteGuard,
}, },
spin::{ArcSpinLockGuard, LocalIrqDisabled, PreemptDisabled, SpinLock, SpinLockGuard}, spin::{
ArcSpinLockGuard, GuardTransfer, LocalIrqDisabled, PreemptDisabled, SpinLock, SpinLockGuard,
},
wait::{WaitQueue, Waiter, Waker}, wait::{WaitQueue, Waiter, Waker},
}; };

View File

@ -6,6 +6,7 @@ use alloc::sync::Arc;
use core::{ use core::{
cell::UnsafeCell, cell::UnsafeCell,
fmt, fmt,
marker::PhantomData,
ops::{Deref, DerefMut}, ops::{Deref, DerefMut},
sync::atomic::{ sync::atomic::{
AtomicUsize, AtomicUsize,
@ -13,10 +14,7 @@ use core::{
}, },
}; };
use crate::{ use super::{spin::Guardian, GuardTransfer, PreemptDisabled};
task::{disable_preempt, DisabledPreemptGuard},
trap::{disable_local, DisabledLocalIrqGuard},
};
/// Spin-based Read-write Lock /// Spin-based Read-write Lock
/// ///
@ -33,10 +31,6 @@ use crate::{
/// periods of time, and the overhead of context switching is higher than /// periods of time, and the overhead of context switching is higher than
/// the cost of spinning. /// the cost of spinning.
/// ///
/// The lock provides methods to safely acquire locks with interrupts
/// disabled, preventing deadlocks in scenarios where locks are used within
/// interrupt handlers.
///
/// In addition to traditional read and write locks, this implementation /// In addition to traditional read and write locks, this implementation
/// provides the upgradeable read lock (`upread lock`). The `upread lock` /// provides the upgradeable read lock (`upread lock`). The `upread lock`
/// can be upgraded to write locks atomically, useful in scenarios /// can be upgraded to write locks atomically, useful in scenarios
@ -59,10 +53,9 @@ use crate::{
/// This lock should not be used in scenarios where lock-holding times are /// This lock should not be used in scenarios where lock-holding times are
/// long as it can lead to CPU resource wastage due to spinning. /// long as it can lead to CPU resource wastage due to spinning.
/// ///
/// # Safety /// # About Guard
/// ///
/// Use interrupt-disabled version methods when dealing with interrupt-related read-write locks, /// See the comments of [`SpinLock`].
/// as nested interrupts may lead to a deadlock if not properly handled.
/// ///
/// # Examples /// # Examples
/// ///
@ -99,7 +92,10 @@ use crate::{
/// assert_eq!(*w2, 7); /// assert_eq!(*w2, 7);
/// } // write lock is dropped at this point /// } // write lock is dropped at this point
/// ``` /// ```
pub struct RwLock<T: ?Sized> { ///
/// [`SpinLock`]: super::SpinLock
pub struct RwLock<T: ?Sized, Guard = PreemptDisabled> {
guard: PhantomData<Guard>,
/// The internal representation of the lock state is as follows: /// The internal representation of the lock state is as follows:
/// - **Bit 63:** Writer lock. /// - **Bit 63:** Writer lock.
/// - **Bit 62:** Upgradeable reader lock. /// - **Bit 62:** Upgradeable reader lock.
@ -115,152 +111,25 @@ const UPGRADEABLE_READER: usize = 1 << (usize::BITS - 2);
const BEING_UPGRADED: usize = 1 << (usize::BITS - 3); const BEING_UPGRADED: usize = 1 << (usize::BITS - 3);
const MAX_READER: usize = 1 << (usize::BITS - 4); const MAX_READER: usize = 1 << (usize::BITS - 4);
impl<T> RwLock<T> { impl<T, G> RwLock<T, G> {
/// Creates a new spin-based read-write lock with an initial value. /// Creates a new spin-based read-write lock with an initial value.
pub const fn new(val: T) -> Self { pub const fn new(val: T) -> Self {
Self { Self {
val: UnsafeCell::new(val), guard: PhantomData,
lock: AtomicUsize::new(0), lock: AtomicUsize::new(0),
val: UnsafeCell::new(val),
} }
} }
} }
impl<T: ?Sized> RwLock<T> { impl<T: ?Sized, G: Guardian> RwLock<T, G> {
/// Acquires a read lock while disabling the local IRQs and spin-wait
/// until it can be acquired.
///
/// The calling thread will spin-wait until there are no writers or
/// upgrading upreaders present. There is no guarantee for the order
/// in which other readers or writers waiting simultaneously will
/// obtain the lock. Once this lock is acquired, the calling thread
/// will not be interrupted.
pub fn read_irq_disabled(&self) -> RwLockReadGuard<T> {
loop {
if let Some(readguard) = self.try_read_irq_disabled() {
return readguard;
} else {
core::hint::spin_loop();
}
}
}
/// Acquires a write lock while disabling the local IRQs and spin-wait
/// until it can be acquired.
///
/// The calling thread will spin-wait until there are no other writers,
/// upreaders or readers present. There is no guarantee for the order
/// in which other readers or writers waiting simultaneously will
/// obtain the lock. Once this lock is acquired, the calling thread
/// will not be interrupted.
pub fn write_irq_disabled(&self) -> RwLockWriteGuard<T> {
loop {
if let Some(writeguard) = self.try_write_irq_disabled() {
return writeguard;
} else {
core::hint::spin_loop();
}
}
}
/// Acquires an upgradeable reader (upreader) while disabling local IRQs
/// and spin-wait until it can be acquired.
///
/// The calling thread will spin-wait until there are no other writers,
/// or upreaders. There is no guarantee for the order in which other
/// readers or writers waiting simultaneously will obtain the lock. Once
/// this lock is acquired, the calling thread will not be interrupted.
///
/// Upreader will not block new readers until it tries to upgrade. Upreader
/// and reader do not differ before invoking the upgread method. However,
/// only one upreader can exist at any time to avoid deadlock in the
/// upgread method.
pub fn upread_irq_disabled(&self) -> RwLockUpgradeableGuard<T> {
loop {
if let Some(guard) = self.try_upread_irq_disabled() {
return guard;
} else {
core::hint::spin_loop();
}
}
}
/// Attempts to acquire a read lock while disabling local IRQs.
///
/// This function will never spin-wait and will return immediately. When
/// multiple readers or writers attempt to acquire the lock, this method
/// does not guarantee any order. Interrupts will automatically be restored
/// when acquiring fails.
pub fn try_read_irq_disabled(&self) -> Option<RwLockReadGuard<T>> {
let irq_guard = disable_local();
let lock = self.lock.fetch_add(READER, Acquire);
if lock & (WRITER | MAX_READER | BEING_UPGRADED) == 0 {
Some(RwLockReadGuard {
inner: self,
inner_guard: InnerGuard::IrqGuard(irq_guard),
})
} else {
self.lock.fetch_sub(READER, Release);
None
}
}
/// Attempts to acquire a write lock while disabling local IRQs.
///
/// This function will never spin-wait and will return immediately. When
/// multiple readers or writers attempt to acquire the lock, this method
/// does not guarantee any order. Interrupts will automatically be restored
/// when acquiring fails.
pub fn try_write_irq_disabled(&self) -> Option<RwLockWriteGuard<T>> {
let irq_guard = disable_local();
if self
.lock
.compare_exchange(0, WRITER, Acquire, Relaxed)
.is_ok()
{
Some(RwLockWriteGuard {
inner: self,
inner_guard: InnerGuard::IrqGuard(irq_guard),
})
} else {
None
}
}
/// Attempts to acquire a upread lock while disabling local IRQs.
///
/// This function will never spin-wait and will return immediately. When
/// multiple readers or writers attempt to acquire the lock, this method
/// does not guarantee any order. Interrupts will automatically be restored
/// when acquiring fails.
pub fn try_upread_irq_disabled(&self) -> Option<RwLockUpgradeableGuard<T>> {
let irq_guard = disable_local();
let lock = self.lock.fetch_or(UPGRADEABLE_READER, Acquire) & (WRITER | UPGRADEABLE_READER);
if lock == 0 {
return Some(RwLockUpgradeableGuard {
inner: self,
inner_guard: InnerGuard::IrqGuard(irq_guard),
});
} else if lock == WRITER {
self.lock.fetch_sub(UPGRADEABLE_READER, Release);
}
None
}
/// Acquires a read lock and spin-wait until it can be acquired. /// Acquires a read lock and spin-wait until it can be acquired.
/// ///
/// The calling thread will spin-wait until there are no writers or /// The calling thread will spin-wait until there are no writers or
/// upgrading upreaders present. There is no guarantee for the order /// upgrading upreaders present. There is no guarantee for the order
/// in which other readers or writers waiting simultaneously will /// in which other readers or writers waiting simultaneously will
/// obtain the lock. /// obtain the lock.
/// pub fn read(&self) -> RwLockReadGuard<T, G> {
/// This method does not disable interrupts, so any locks related to
/// interrupt context should avoid using this method, and use [`read_irq_disabled`]
/// instead. When IRQ handlers are allowed to be executed while holding
/// this lock, it is preferable to use this method over the [`read_irq_disabled`]
/// method as it has a higher efficiency.
///
/// [`read_irq_disabled`]: Self::read_irq_disabled
pub fn read(&self) -> RwLockReadGuard<T> {
loop { loop {
if let Some(readguard) = self.try_read() { if let Some(readguard) = self.try_read() {
return readguard; return readguard;
@ -276,7 +145,7 @@ impl<T: ?Sized> RwLock<T> {
/// for compile-time checked lifetimes of the read guard. /// for compile-time checked lifetimes of the read guard.
/// ///
/// [`read`]: Self::read /// [`read`]: Self::read
pub fn read_arc(self: &Arc<Self>) -> ArcRwLockReadGuard<T> { pub fn read_arc(self: &Arc<Self>) -> ArcRwLockReadGuard<T, G> {
loop { loop {
if let Some(readguard) = self.try_read_arc() { if let Some(readguard) = self.try_read_arc() {
return readguard; return readguard;
@ -292,15 +161,7 @@ impl<T: ?Sized> RwLock<T> {
/// upreaders or readers present. There is no guarantee for the order /// upreaders or readers present. There is no guarantee for the order
/// in which other readers or writers waiting simultaneously will /// in which other readers or writers waiting simultaneously will
/// obtain the lock. /// obtain the lock.
/// pub fn write(&self) -> RwLockWriteGuard<T, G> {
/// This method does not disable interrupts, so any locks related to
/// interrupt context should avoid using this method, and use [`write_irq_disabled`]
/// instead. When IRQ handlers are allowed to be executed while holding
/// this lock, it is preferable to use this method over the [`write_irq_disabled`]
/// method as it has a higher efficiency.
///
/// [`write_irq_disabled`]: Self::write_irq_disabled
pub fn write(&self) -> RwLockWriteGuard<T> {
loop { loop {
if let Some(writeguard) = self.try_write() { if let Some(writeguard) = self.try_write() {
return writeguard; return writeguard;
@ -316,7 +177,7 @@ impl<T: ?Sized> RwLock<T> {
/// for compile-time checked lifetimes of the lock guard. /// for compile-time checked lifetimes of the lock guard.
/// ///
/// [`write`]: Self::write /// [`write`]: Self::write
pub fn write_arc(self: &Arc<Self>) -> ArcRwLockWriteGuard<T> { pub fn write_arc(self: &Arc<Self>) -> ArcRwLockWriteGuard<T, G> {
loop { loop {
if let Some(writeguard) = self.try_write_arc() { if let Some(writeguard) = self.try_write_arc() {
return writeguard; return writeguard;
@ -336,15 +197,7 @@ impl<T: ?Sized> RwLock<T> {
/// and reader do not differ before invoking the upgread method. However, /// and reader do not differ before invoking the upgread method. However,
/// only one upreader can exist at any time to avoid deadlock in the /// only one upreader can exist at any time to avoid deadlock in the
/// upgread method. /// upgread method.
/// pub fn upread(&self) -> RwLockUpgradeableGuard<T, G> {
/// This method does not disable interrupts, so any locks related to
/// interrupt context should avoid using this method, and use [`upread_irq_disabled`]
/// instead. When IRQ handlers are allowed to be executed while holding
/// this lock, it is preferable to use this method over the [`upread_irq_disabled`]
/// method as it has a higher efficiency.
///
/// [`upread_irq_disabled`]: Self::upread_irq_disabled
pub fn upread(&self) -> RwLockUpgradeableGuard<T> {
loop { loop {
if let Some(guard) = self.try_upread() { if let Some(guard) = self.try_upread() {
return guard; return guard;
@ -360,7 +213,7 @@ impl<T: ?Sized> RwLock<T> {
/// for compile-time checked lifetimes of the lock guard. /// for compile-time checked lifetimes of the lock guard.
/// ///
/// [`upread`]: Self::upread /// [`upread`]: Self::upread
pub fn upread_arc(self: &Arc<Self>) -> ArcRwLockUpgradeableGuard<T> { pub fn upread_arc(self: &Arc<Self>) -> ArcRwLockUpgradeableGuard<T, G> {
loop { loop {
if let Some(guard) = self.try_upread_arc() { if let Some(guard) = self.try_upread_arc() {
return guard; return guard;
@ -373,23 +226,11 @@ impl<T: ?Sized> RwLock<T> {
/// Attempts to acquire a read lock. /// Attempts to acquire a read lock.
/// ///
/// This function will never spin-wait and will return immediately. /// This function will never spin-wait and will return immediately.
/// pub fn try_read(&self) -> Option<RwLockReadGuard<T, G>> {
/// This method does not disable interrupts, so any locks related to let guard = G::guard();
/// interrupt context should avoid using this method, and use
/// [`try_read_irq_disabled`] instead. When IRQ handlers are allowed to
/// be executed while holding this lock, it is preferable to use this
/// method over the [`try_read_irq_disabled`] method as it has a higher
/// efficiency.
///
/// [`try_read_irq_disabled`]: Self::try_read_irq_disabled
pub fn try_read(&self) -> Option<RwLockReadGuard<T>> {
let guard = disable_preempt();
let lock = self.lock.fetch_add(READER, Acquire); let lock = self.lock.fetch_add(READER, Acquire);
if lock & (WRITER | MAX_READER | BEING_UPGRADED) == 0 { if lock & (WRITER | MAX_READER | BEING_UPGRADED) == 0 {
Some(RwLockReadGuard { Some(RwLockReadGuard { inner: self, guard })
inner: self,
inner_guard: InnerGuard::PreemptGuard(guard),
})
} else { } else {
self.lock.fetch_sub(READER, Release); self.lock.fetch_sub(READER, Release);
None None
@ -402,13 +243,13 @@ impl<T: ?Sized> RwLock<T> {
/// for compile-time checked lifetimes of the lock guard. /// for compile-time checked lifetimes of the lock guard.
/// ///
/// [`try_read`]: Self::try_read /// [`try_read`]: Self::try_read
pub fn try_read_arc(self: &Arc<Self>) -> Option<ArcRwLockReadGuard<T>> { pub fn try_read_arc(self: &Arc<Self>) -> Option<ArcRwLockReadGuard<T, G>> {
let guard = disable_preempt(); let guard = G::guard();
let lock = self.lock.fetch_add(READER, Acquire); let lock = self.lock.fetch_add(READER, Acquire);
if lock & (WRITER | MAX_READER | BEING_UPGRADED) == 0 { if lock & (WRITER | MAX_READER | BEING_UPGRADED) == 0 {
Some(ArcRwLockReadGuard { Some(ArcRwLockReadGuard {
inner: self.clone(), inner: self.clone(),
inner_guard: InnerGuard::PreemptGuard(guard), guard,
}) })
} else { } else {
self.lock.fetch_sub(READER, Release); self.lock.fetch_sub(READER, Release);
@ -419,26 +260,14 @@ impl<T: ?Sized> RwLock<T> {
/// Attempts to acquire a write lock. /// Attempts to acquire a write lock.
/// ///
/// This function will never spin-wait and will return immediately. /// This function will never spin-wait and will return immediately.
/// pub fn try_write(&self) -> Option<RwLockWriteGuard<T, G>> {
/// This method does not disable interrupts, so any locks related to let guard = G::guard();
/// interrupt context should avoid using this method, and use
/// [`try_write_irq_disabled`] instead. When IRQ handlers are allowed to
/// be executed while holding this lock, it is preferable to use this
/// method over the [`try_write_irq_disabled`] method as it has a higher
/// efficiency.
///
/// [`try_write_irq_disabled`]: Self::try_write_irq_disabled
pub fn try_write(&self) -> Option<RwLockWriteGuard<T>> {
let guard = disable_preempt();
if self if self
.lock .lock
.compare_exchange(0, WRITER, Acquire, Relaxed) .compare_exchange(0, WRITER, Acquire, Relaxed)
.is_ok() .is_ok()
{ {
Some(RwLockWriteGuard { Some(RwLockWriteGuard { inner: self, guard })
inner: self,
inner_guard: InnerGuard::PreemptGuard(guard),
})
} else { } else {
None None
} }
@ -450,8 +279,8 @@ impl<T: ?Sized> RwLock<T> {
/// for compile-time checked lifetimes of the lock guard. /// for compile-time checked lifetimes of the lock guard.
/// ///
/// [`try_write`]: Self::try_write /// [`try_write`]: Self::try_write
fn try_write_arc(self: &Arc<Self>) -> Option<ArcRwLockWriteGuard<T>> { fn try_write_arc(self: &Arc<Self>) -> Option<ArcRwLockWriteGuard<T, G>> {
let guard = disable_preempt(); let guard = G::guard();
if self if self
.lock .lock
.compare_exchange(0, WRITER, Acquire, Relaxed) .compare_exchange(0, WRITER, Acquire, Relaxed)
@ -459,7 +288,7 @@ impl<T: ?Sized> RwLock<T> {
{ {
Some(ArcRwLockWriteGuard { Some(ArcRwLockWriteGuard {
inner: self.clone(), inner: self.clone(),
inner_guard: InnerGuard::PreemptGuard(guard), guard,
}) })
} else { } else {
None None
@ -469,23 +298,11 @@ impl<T: ?Sized> RwLock<T> {
/// Attempts to acquire an upread lock. /// Attempts to acquire an upread lock.
/// ///
/// This function will never spin-wait and will return immediately. /// This function will never spin-wait and will return immediately.
/// pub fn try_upread(&self) -> Option<RwLockUpgradeableGuard<T, G>> {
/// This method does not disable interrupts, so any locks related to let guard = G::guard();
/// interrupt context should avoid using this method, and use
/// [`try_upread_irq_disabled`] instead. When IRQ handlers are allowed to
/// be executed while holding this lock, it is preferable to use this
/// method over the [`try_upread_irq_disabled`] method as it has a higher
/// efficiency.
///
/// [`try_upread_irq_disabled`]: Self::try_upread_irq_disabled
pub fn try_upread(&self) -> Option<RwLockUpgradeableGuard<T>> {
let guard = disable_preempt();
let lock = self.lock.fetch_or(UPGRADEABLE_READER, Acquire) & (WRITER | UPGRADEABLE_READER); let lock = self.lock.fetch_or(UPGRADEABLE_READER, Acquire) & (WRITER | UPGRADEABLE_READER);
if lock == 0 { if lock == 0 {
return Some(RwLockUpgradeableGuard { return Some(RwLockUpgradeableGuard { inner: self, guard });
inner: self,
inner_guard: InnerGuard::PreemptGuard(guard),
});
} else if lock == WRITER { } else if lock == WRITER {
self.lock.fetch_sub(UPGRADEABLE_READER, Release); self.lock.fetch_sub(UPGRADEABLE_READER, Release);
} }
@ -498,13 +315,13 @@ impl<T: ?Sized> RwLock<T> {
/// for compile-time checked lifetimes of the lock guard. /// for compile-time checked lifetimes of the lock guard.
/// ///
/// [`try_upread`]: Self::try_upread /// [`try_upread`]: Self::try_upread
pub fn try_upread_arc(self: &Arc<Self>) -> Option<ArcRwLockUpgradeableGuard<T>> { pub fn try_upread_arc(self: &Arc<Self>) -> Option<ArcRwLockUpgradeableGuard<T, G>> {
let guard = disable_preempt(); let guard = G::guard();
let lock = self.lock.fetch_or(UPGRADEABLE_READER, Acquire) & (WRITER | UPGRADEABLE_READER); let lock = self.lock.fetch_or(UPGRADEABLE_READER, Acquire) & (WRITER | UPGRADEABLE_READER);
if lock == 0 { if lock == 0 {
return Some(ArcRwLockUpgradeableGuard { return Some(ArcRwLockUpgradeableGuard {
inner: self.clone(), inner: self.clone(),
inner_guard: InnerGuard::PreemptGuard(guard), guard,
}); });
} else if lock == WRITER { } else if lock == WRITER {
self.lock.fetch_sub(UPGRADEABLE_READER, Release); self.lock.fetch_sub(UPGRADEABLE_READER, Release);
@ -513,7 +330,7 @@ impl<T: ?Sized> RwLock<T> {
} }
} }
impl<T: ?Sized + fmt::Debug> fmt::Debug for RwLock<T> { impl<T: ?Sized + fmt::Debug, G> fmt::Debug for RwLock<T, G> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
fmt::Debug::fmt(&self.val, f) fmt::Debug::fmt(&self.val, f)
} }
@ -521,62 +338,53 @@ impl<T: ?Sized + fmt::Debug> fmt::Debug for RwLock<T> {
/// Because there can be more than one readers to get the T's immutable ref, /// Because there can be more than one readers to get the T's immutable ref,
/// so T must be Sync to guarantee the sharing safety. /// so T must be Sync to guarantee the sharing safety.
unsafe impl<T: ?Sized + Send> Send for RwLock<T> {} unsafe impl<T: ?Sized + Send, G> Send for RwLock<T, G> {}
unsafe impl<T: ?Sized + Send + Sync> Sync for RwLock<T> {} unsafe impl<T: ?Sized + Send + Sync, G> Sync for RwLock<T, G> {}
impl<T: ?Sized, R: Deref<Target = RwLock<T>> + Clone> !Send for RwLockWriteGuard_<T, R> {} impl<T: ?Sized, R: Deref<Target = RwLock<T, G>> + Clone, G: Guardian> !Send
unsafe impl<T: ?Sized + Sync, R: Deref<Target = RwLock<T>> + Clone + Sync> Sync for RwLockWriteGuard_<T, R, G>
for RwLockWriteGuard_<T, R> {
}
unsafe impl<T: ?Sized + Sync, R: Deref<Target = RwLock<T, G>> + Clone + Sync, G: Guardian> Sync
for RwLockWriteGuard_<T, R, G>
{ {
} }
impl<T: ?Sized, R: Deref<Target = RwLock<T>> + Clone> !Send for RwLockReadGuard_<T, R> {} impl<T: ?Sized, R: Deref<Target = RwLock<T, G>> + Clone, G: Guardian> !Send
unsafe impl<T: ?Sized + Sync, R: Deref<Target = RwLock<T>> + Clone + Sync> Sync for RwLockReadGuard_<T, R, G>
for RwLockReadGuard_<T, R> {
}
unsafe impl<T: ?Sized + Sync, R: Deref<Target = RwLock<T, G>> + Clone + Sync, G: Guardian> Sync
for RwLockReadGuard_<T, R, G>
{ {
} }
impl<T: ?Sized, R: Deref<Target = RwLock<T>> + Clone> !Send for RwLockUpgradeableGuard_<T, R> {} impl<T: ?Sized, R: Deref<Target = RwLock<T, G>> + Clone, G: Guardian> !Send
unsafe impl<T: ?Sized + Sync, R: Deref<Target = RwLock<T>> + Clone + Sync> Sync for RwLockUpgradeableGuard_<T, R, G>
for RwLockUpgradeableGuard_<T, R>
{ {
} }
unsafe impl<T: ?Sized + Sync, R: Deref<Target = RwLock<T, G>> + Clone + Sync, G: Guardian> Sync
enum InnerGuard { for RwLockUpgradeableGuard_<T, R, G>
IrqGuard(DisabledLocalIrqGuard), {
PreemptGuard(DisabledPreemptGuard),
}
impl InnerGuard {
/// Transfers the current guard to a new `InnerGuard` instance ensuring atomicity during lock upgrades or downgrades.
///
/// This function guarantees that there will be no 'gaps' between the destruction of the old guard and
/// the creation of the new guard, maintaining the atomicity of lock transitions.
fn transfer_to(&mut self) -> Self {
match self {
InnerGuard::IrqGuard(irq_guard) => InnerGuard::IrqGuard(irq_guard.transfer_to()),
InnerGuard::PreemptGuard(preempt_guard) => {
InnerGuard::PreemptGuard(preempt_guard.transfer_to())
}
}
}
} }
/// A guard that provides immutable data access. /// A guard that provides immutable data access.
#[clippy::has_significant_drop] #[clippy::has_significant_drop]
#[must_use] #[must_use]
pub struct RwLockReadGuard_<T: ?Sized, R: Deref<Target = RwLock<T>> + Clone> { pub struct RwLockReadGuard_<T: ?Sized, R: Deref<Target = RwLock<T, G>> + Clone, G: Guardian> {
inner_guard: InnerGuard, guard: G::Guard,
inner: R, inner: R,
} }
/// A guard that provides shared read access to the data protected by a [`RwLock`]. /// A guard that provides shared read access to the data protected by a [`RwLock`].
pub type RwLockReadGuard<'a, T> = RwLockReadGuard_<T, &'a RwLock<T>>; pub type RwLockReadGuard<'a, T, G> = RwLockReadGuard_<T, &'a RwLock<T, G>, G>;
/// A guard that provides shared read access to the data protected by a `Arc<RwLock>`. /// A guard that provides shared read access to the data protected by a `Arc<RwLock>`.
pub type ArcRwLockReadGuard<T> = RwLockReadGuard_<T, Arc<RwLock<T>>>; pub type ArcRwLockReadGuard<T, G> = RwLockReadGuard_<T, Arc<RwLock<T, G>>, G>;
impl<T: ?Sized, R: Deref<Target = RwLock<T>> + Clone> Deref for RwLockReadGuard_<T, R> { impl<T: ?Sized, R: Deref<Target = RwLock<T, G>> + Clone, G: Guardian> Deref
for RwLockReadGuard_<T, R, G>
{
type Target = T; type Target = T;
fn deref(&self) -> &T { fn deref(&self) -> &T {
@ -584,14 +392,16 @@ impl<T: ?Sized, R: Deref<Target = RwLock<T>> + Clone> Deref for RwLockReadGuard_
} }
} }
impl<T: ?Sized, R: Deref<Target = RwLock<T>> + Clone> Drop for RwLockReadGuard_<T, R> { impl<T: ?Sized, R: Deref<Target = RwLock<T, G>> + Clone, G: Guardian> Drop
for RwLockReadGuard_<T, R, G>
{
fn drop(&mut self) { fn drop(&mut self) {
self.inner.lock.fetch_sub(READER, Release); self.inner.lock.fetch_sub(READER, Release);
} }
} }
impl<T: ?Sized + fmt::Debug, R: Deref<Target = RwLock<T>> + Clone> fmt::Debug impl<T: ?Sized + fmt::Debug, R: Deref<Target = RwLock<T, G>> + Clone, G: Guardian> fmt::Debug
for RwLockReadGuard_<T, R> for RwLockReadGuard_<T, R, G>
{ {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
fmt::Debug::fmt(&**self, f) fmt::Debug::fmt(&**self, f)
@ -599,17 +409,19 @@ impl<T: ?Sized + fmt::Debug, R: Deref<Target = RwLock<T>> + Clone> fmt::Debug
} }
/// A guard that provides mutable data access. /// A guard that provides mutable data access.
pub struct RwLockWriteGuard_<T: ?Sized, R: Deref<Target = RwLock<T>> + Clone> { pub struct RwLockWriteGuard_<T: ?Sized, R: Deref<Target = RwLock<T, G>> + Clone, G: Guardian> {
inner_guard: InnerGuard, guard: G::Guard,
inner: R, inner: R,
} }
/// A guard that provides exclusive write access to the data protected by a [`RwLock`]. /// A guard that provides exclusive write access to the data protected by a [`RwLock`].
pub type RwLockWriteGuard<'a, T> = RwLockWriteGuard_<T, &'a RwLock<T>>; pub type RwLockWriteGuard<'a, T, G> = RwLockWriteGuard_<T, &'a RwLock<T, G>, G>;
/// A guard that provides exclusive write access to the data protected by a `Arc<RwLock>`. /// A guard that provides exclusive write access to the data protected by a `Arc<RwLock>`.
pub type ArcRwLockWriteGuard<T> = RwLockWriteGuard_<T, Arc<RwLock<T>>>; pub type ArcRwLockWriteGuard<T, G> = RwLockWriteGuard_<T, Arc<RwLock<T, G>>, G>;
impl<T: ?Sized, R: Deref<Target = RwLock<T>> + Clone> Deref for RwLockWriteGuard_<T, R> { impl<T: ?Sized, R: Deref<Target = RwLock<T, G>> + Clone, G: Guardian> Deref
for RwLockWriteGuard_<T, R, G>
{
type Target = T; type Target = T;
fn deref(&self) -> &T { fn deref(&self) -> &T {
@ -617,11 +429,11 @@ impl<T: ?Sized, R: Deref<Target = RwLock<T>> + Clone> Deref for RwLockWriteGuard
} }
} }
impl<T: ?Sized, R: Deref<Target = RwLock<T>> + Clone> RwLockWriteGuard_<T, R> { impl<T: ?Sized, R: Deref<Target = RwLock<T, G>> + Clone, G: Guardian> RwLockWriteGuard_<T, R, G> {
/// Atomically downgrades a write guard to an upgradeable reader guard. /// Atomically downgrades a write guard to an upgradeable reader guard.
/// ///
/// This method always succeeds because the lock is exclusively held by the writer. /// This method always succeeds because the lock is exclusively held by the writer.
pub fn downgrade(mut self) -> RwLockUpgradeableGuard_<T, R> { pub fn downgrade(mut self) -> RwLockUpgradeableGuard_<T, R, G> {
loop { loop {
self = match self.try_downgrade() { self = match self.try_downgrade() {
Ok(guard) => return guard, Ok(guard) => return guard,
@ -632,36 +444,40 @@ impl<T: ?Sized, R: Deref<Target = RwLock<T>> + Clone> RwLockWriteGuard_<T, R> {
/// This is not exposed as a public method to prevent intermediate lock states from affecting the /// This is not exposed as a public method to prevent intermediate lock states from affecting the
/// downgrade process. /// downgrade process.
fn try_downgrade(mut self) -> Result<RwLockUpgradeableGuard_<T, R>, Self> { fn try_downgrade(mut self) -> Result<RwLockUpgradeableGuard_<T, R, G>, Self> {
let inner = self.inner.clone(); let inner = self.inner.clone();
let res = self let res = self
.inner .inner
.lock .lock
.compare_exchange(WRITER, UPGRADEABLE_READER, AcqRel, Relaxed); .compare_exchange(WRITER, UPGRADEABLE_READER, AcqRel, Relaxed);
if res.is_ok() { if res.is_ok() {
let inner_guard = self.inner_guard.transfer_to(); let guard = self.guard.transfer_to();
drop(self); drop(self);
Ok(RwLockUpgradeableGuard_ { inner, inner_guard }) Ok(RwLockUpgradeableGuard_ { inner, guard })
} else { } else {
Err(self) Err(self)
} }
} }
} }
impl<T: ?Sized, R: Deref<Target = RwLock<T>> + Clone> DerefMut for RwLockWriteGuard_<T, R> { impl<T: ?Sized, R: Deref<Target = RwLock<T, G>> + Clone, G: Guardian> DerefMut
for RwLockWriteGuard_<T, R, G>
{
fn deref_mut(&mut self) -> &mut Self::Target { fn deref_mut(&mut self) -> &mut Self::Target {
unsafe { &mut *self.inner.val.get() } unsafe { &mut *self.inner.val.get() }
} }
} }
impl<T: ?Sized, R: Deref<Target = RwLock<T>> + Clone> Drop for RwLockWriteGuard_<T, R> { impl<T: ?Sized, R: Deref<Target = RwLock<T, G>> + Clone, G: Guardian> Drop
for RwLockWriteGuard_<T, R, G>
{
fn drop(&mut self) { fn drop(&mut self) {
self.inner.lock.fetch_and(!WRITER, Release); self.inner.lock.fetch_and(!WRITER, Release);
} }
} }
impl<T: ?Sized + fmt::Debug, R: Deref<Target = RwLock<T>> + Clone> fmt::Debug impl<T: ?Sized + fmt::Debug, R: Deref<Target = RwLock<T, G>> + Clone, G: Guardian> fmt::Debug
for RwLockWriteGuard_<T, R> for RwLockWriteGuard_<T, R, G>
{ {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
fmt::Debug::fmt(&**self, f) fmt::Debug::fmt(&**self, f)
@ -670,23 +486,26 @@ impl<T: ?Sized + fmt::Debug, R: Deref<Target = RwLock<T>> + Clone> fmt::Debug
/// A guard that provides immutable data access but can be atomically /// A guard that provides immutable data access but can be atomically
/// upgraded to `RwLockWriteGuard`. /// upgraded to `RwLockWriteGuard`.
pub struct RwLockUpgradeableGuard_<T: ?Sized, R: Deref<Target = RwLock<T>> + Clone> { pub struct RwLockUpgradeableGuard_<T: ?Sized, R: Deref<Target = RwLock<T, G>> + Clone, G: Guardian>
inner_guard: InnerGuard, {
guard: G::Guard,
inner: R, inner: R,
} }
/// A upgradable guard that provides read access to the data protected by a [`RwLock`]. /// A upgradable guard that provides read access to the data protected by a [`RwLock`].
pub type RwLockUpgradeableGuard<'a, T> = RwLockUpgradeableGuard_<T, &'a RwLock<T>>; pub type RwLockUpgradeableGuard<'a, T, G> = RwLockUpgradeableGuard_<T, &'a RwLock<T, G>, G>;
/// A upgradable guard that provides read access to the data protected by a `Arc<RwLock>`. /// A upgradable guard that provides read access to the data protected by a `Arc<RwLock>`.
pub type ArcRwLockUpgradeableGuard<T> = RwLockUpgradeableGuard_<T, Arc<RwLock<T>>>; pub type ArcRwLockUpgradeableGuard<T, G> = RwLockUpgradeableGuard_<T, Arc<RwLock<T, G>>, G>;
impl<T: ?Sized, R: Deref<Target = RwLock<T>> + Clone> RwLockUpgradeableGuard_<T, R> { impl<T: ?Sized, R: Deref<Target = RwLock<T, G>> + Clone, G: Guardian>
RwLockUpgradeableGuard_<T, R, G>
{
/// Upgrades this upread guard to a write guard atomically. /// Upgrades this upread guard to a write guard atomically.
/// ///
/// After calling this method, subsequent readers will be blocked /// After calling this method, subsequent readers will be blocked
/// while previous readers remain unaffected. The calling thread /// while previous readers remain unaffected. The calling thread
/// will spin-wait until previous readers finish. /// will spin-wait until previous readers finish.
pub fn upgrade(mut self) -> RwLockWriteGuard_<T, R> { pub fn upgrade(mut self) -> RwLockWriteGuard_<T, R, G> {
self.inner.lock.fetch_or(BEING_UPGRADED, Acquire); self.inner.lock.fetch_or(BEING_UPGRADED, Acquire);
loop { loop {
self = match self.try_upgrade() { self = match self.try_upgrade() {
@ -698,7 +517,7 @@ impl<T: ?Sized, R: Deref<Target = RwLock<T>> + Clone> RwLockUpgradeableGuard_<T,
/// Attempts to upgrade this upread guard to a write guard atomically. /// Attempts to upgrade this upread guard to a write guard atomically.
/// ///
/// This function will never spin-wait and will return immediately. /// This function will never spin-wait and will return immediately.
pub fn try_upgrade(mut self) -> Result<RwLockWriteGuard_<T, R>, Self> { pub fn try_upgrade(mut self) -> Result<RwLockWriteGuard_<T, R, G>, Self> {
let res = self.inner.lock.compare_exchange( let res = self.inner.lock.compare_exchange(
UPGRADEABLE_READER | BEING_UPGRADED, UPGRADEABLE_READER | BEING_UPGRADED,
WRITER | UPGRADEABLE_READER, WRITER | UPGRADEABLE_READER,
@ -707,16 +526,18 @@ impl<T: ?Sized, R: Deref<Target = RwLock<T>> + Clone> RwLockUpgradeableGuard_<T,
); );
if res.is_ok() { if res.is_ok() {
let inner = self.inner.clone(); let inner = self.inner.clone();
let inner_guard = self.inner_guard.transfer_to(); let guard = self.guard.transfer_to();
drop(self); drop(self);
Ok(RwLockWriteGuard_ { inner, inner_guard }) Ok(RwLockWriteGuard_ { inner, guard })
} else { } else {
Err(self) Err(self)
} }
} }
} }
impl<T: ?Sized, R: Deref<Target = RwLock<T>> + Clone> Deref for RwLockUpgradeableGuard_<T, R> { impl<T: ?Sized, R: Deref<Target = RwLock<T, G>> + Clone, G: Guardian> Deref
for RwLockUpgradeableGuard_<T, R, G>
{
type Target = T; type Target = T;
fn deref(&self) -> &T { fn deref(&self) -> &T {
@ -724,14 +545,16 @@ impl<T: ?Sized, R: Deref<Target = RwLock<T>> + Clone> Deref for RwLockUpgradeabl
} }
} }
impl<T: ?Sized, R: Deref<Target = RwLock<T>> + Clone> Drop for RwLockUpgradeableGuard_<T, R> { impl<T: ?Sized, R: Deref<Target = RwLock<T, G>> + Clone, G: Guardian> Drop
for RwLockUpgradeableGuard_<T, R, G>
{
fn drop(&mut self) { fn drop(&mut self) {
self.inner.lock.fetch_sub(UPGRADEABLE_READER, Release); self.inner.lock.fetch_sub(UPGRADEABLE_READER, Release);
} }
} }
impl<T: ?Sized + fmt::Debug, R: Deref<Target = RwLock<T>> + Clone> fmt::Debug impl<T: ?Sized + fmt::Debug, R: Deref<Target = RwLock<T, G>> + Clone, G: Guardian> fmt::Debug
for RwLockUpgradeableGuard_<T, R> for RwLockUpgradeableGuard_<T, R, G>
{ {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
fmt::Debug::fmt(&**self, f) fmt::Debug::fmt(&**self, f)

View File

@ -24,12 +24,15 @@ use crate::{
/// - if `G` is [`PreemptDisabled`], preemption is disabled; /// - if `G` is [`PreemptDisabled`], preemption is disabled;
/// - if `G` is [`LocalIrqDisabled`], local IRQs are disabled. /// - if `G` is [`LocalIrqDisabled`], local IRQs are disabled.
/// ///
/// The `G` can also be provided by other crates other than ostd,
/// if it behaves similar like [`PreemptDisabled`] or [`LocalIrqDisabled`].
///
/// The guard behavior can be temporarily upgraded from [`PreemptDisabled`] to /// The guard behavior can be temporarily upgraded from [`PreemptDisabled`] to
/// [`LocalIrqDisabled`] using the [`disable_irq`] method. /// [`LocalIrqDisabled`] using the [`disable_irq`] method.
/// ///
/// [`disable_irq`]: Self::disable_irq /// [`disable_irq`]: Self::disable_irq
#[repr(transparent)] #[repr(transparent)]
pub struct SpinLock<T: ?Sized, G = PreemptDisabled> { pub struct SpinLock<T: ?Sized, G: Guardian = PreemptDisabled> {
phantom: PhantomData<G>, phantom: PhantomData<G>,
/// Only the last field of a struct may have a dynamically sized type. /// Only the last field of a struct may have a dynamically sized type.
/// That's why SpinLockInner is put in the last field. /// That's why SpinLockInner is put in the last field.
@ -44,12 +47,23 @@ struct SpinLockInner<T: ?Sized> {
/// A guardian that denotes the guard behavior for holding the spin lock. /// A guardian that denotes the guard behavior for holding the spin lock.
pub trait Guardian { pub trait Guardian {
/// The guard type. /// The guard type.
type Guard; type Guard: GuardTransfer;
/// Creates a new guard. /// Creates a new guard.
fn guard() -> Self::Guard; fn guard() -> Self::Guard;
} }
/// The Guard can be transferred atomically.
pub trait GuardTransfer {
/// Atomically transfers the current guard to a new instance.
///
/// This function ensures that there are no 'gaps' between the destruction of the old guard and
/// the creation of the new guard, thereby maintaining the atomicity of guard transitions.
///
/// The original guard must be dropped immediately after calling this method.
fn transfer_to(&mut self) -> Self;
}
/// A guardian that disables preemption while holding the spin lock. /// A guardian that disables preemption while holding the spin lock.
pub struct PreemptDisabled; pub struct PreemptDisabled;
@ -165,15 +179,15 @@ impl<T: ?Sized, G: Guardian> SpinLock<T, G> {
} }
} }
impl<T: ?Sized + fmt::Debug, G> fmt::Debug for SpinLock<T, G> { impl<T: ?Sized + fmt::Debug, G: Guardian> fmt::Debug for SpinLock<T, G> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
fmt::Debug::fmt(&self.inner.val, f) fmt::Debug::fmt(&self.inner.val, f)
} }
} }
// SAFETY: Only a single lock holder is permitted to access the inner data of Spinlock. // SAFETY: Only a single lock holder is permitted to access the inner data of Spinlock.
unsafe impl<T: ?Sized + Send, G> Send for SpinLock<T, G> {} unsafe impl<T: ?Sized + Send, G: Guardian> Send for SpinLock<T, G> {}
unsafe impl<T: ?Sized + Send, G> Sync for SpinLock<T, G> {} unsafe impl<T: ?Sized + Send, G: Guardian> Sync for SpinLock<T, G> {}
/// A guard that provides exclusive access to the data protected by a [`SpinLock`]. /// A guard that provides exclusive access to the data protected by a [`SpinLock`].
pub type SpinLockGuard<'a, T, G> = SpinLockGuard_<T, &'a SpinLock<T, G>, G>; pub type SpinLockGuard<'a, T, G> = SpinLockGuard_<T, &'a SpinLock<T, G>, G>;

View File

@ -1,5 +1,7 @@
// SPDX-License-Identifier: MPL-2.0 // SPDX-License-Identifier: MPL-2.0
use crate::sync::GuardTransfer;
/// A guard for disable preempt. /// A guard for disable preempt.
#[clippy::has_significant_drop] #[clippy::has_significant_drop]
#[must_use] #[must_use]
@ -16,10 +18,10 @@ impl DisabledPreemptGuard {
super::cpu_local::inc_guard_count(); super::cpu_local::inc_guard_count();
Self { _private: () } Self { _private: () }
} }
}
/// Transfer this guard to a new guard. impl GuardTransfer for DisabledPreemptGuard {
/// This guard must be dropped after this function. fn transfer_to(&mut self) -> Self {
pub fn transfer_to(&self) -> Self {
disable_preempt() disable_preempt()
} }
} }

View File

@ -7,6 +7,7 @@ use core::fmt::Debug;
use crate::{ use crate::{
arch::irq::{self, IrqCallbackHandle, IRQ_ALLOCATOR}, arch::irq::{self, IrqCallbackHandle, IRQ_ALLOCATOR},
prelude::*, prelude::*,
sync::GuardTransfer,
trap::TrapFrame, trap::TrapFrame,
Error, Error,
}; };
@ -149,10 +150,10 @@ impl DisabledLocalIrqGuard {
} }
Self { was_enabled } Self { was_enabled }
} }
}
/// Transfers the saved IRQ status of this guard to a new guard. impl GuardTransfer for DisabledLocalIrqGuard {
/// The saved IRQ status of this guard is cleared. fn transfer_to(&mut self) -> Self {
pub fn transfer_to(&mut self) -> Self {
let was_enabled = self.was_enabled; let was_enabled = self.was_enabled;
self.was_enabled = false; self.was_enabled = false;
Self { was_enabled } Self { was_enabled }