Refactor Rwlock to take type parameter

This commit is contained in:
jiangjianfeng
2024-11-19 10:35:20 +00:00
committed by Tate, Hongliang Tian
parent ac1a6be05d
commit 495c93c2ad
20 changed files with 205 additions and 363 deletions

View File

@ -14,7 +14,7 @@ use alloc::sync::Arc;
use core::{cmp::max, ops::Add, time::Duration};
use aster_util::coeff::Coeff;
use ostd::sync::RwLock;
use ostd::sync::{LocalIrqDisabled, RwLock};
use crate::NANOS_PER_SECOND;
@ -55,7 +55,7 @@ pub struct ClockSource {
base: ClockSourceBase,
coeff: Coeff,
/// A record to an `Instant` and the corresponding cycles of this `ClockSource`.
last_record: RwLock<(Instant, u64)>,
last_record: RwLock<(Instant, u64), LocalIrqDisabled>,
}
impl ClockSource {
@ -91,7 +91,7 @@ impl ClockSource {
/// Returns the calculated instant and instant cycles.
fn calculate_instant(&self) -> (Instant, u64) {
let (instant_cycles, last_instant, last_cycles) = {
let last_record = self.last_record.read_irq_disabled();
let last_record = self.last_record.read();
let (last_instant, last_cycles) = *last_record;
(self.read_cycles(), last_instant, last_cycles)
};
@ -121,7 +121,7 @@ impl ClockSource {
/// Uses an input instant and cycles to update the `last_record` in the `ClockSource`.
fn update_last_record(&self, record: (Instant, u64)) {
*self.last_record.write_irq_disabled() = record;
*self.last_record.write() = record;
}
/// Reads current cycles of the `ClockSource`.
@ -131,7 +131,7 @@ impl ClockSource {
/// Returns the last instant and last cycles recorded in the `ClockSource`.
pub fn last_record(&self) -> (Instant, u64) {
return *self.last_record.read_irq_disabled();
return *self.last_record.read();
}
/// Returns the maximum delay seconds for updating of the `ClockSource`.

View File

@ -9,7 +9,7 @@ use log::debug;
use ostd::{
io_mem::IoMem,
mm::{DmaDirection, DmaStream, DmaStreamSlice, FrameAllocOptions, VmReader},
sync::{RwLock, SpinLock},
sync::{LocalIrqDisabled, RwLock, SpinLock},
trap::TrapFrame,
};
@ -27,7 +27,7 @@ pub struct ConsoleDevice {
transmit_queue: SpinLock<VirtQueue>,
send_buffer: DmaStream,
receive_buffer: DmaStream,
callbacks: RwLock<Vec<&'static ConsoleCallback>>,
callbacks: RwLock<Vec<&'static ConsoleCallback>, LocalIrqDisabled>,
}
impl AnyConsoleDevice for ConsoleDevice {
@ -54,7 +54,7 @@ impl AnyConsoleDevice for ConsoleDevice {
}
fn register_callback(&self, callback: &'static ConsoleCallback) {
self.callbacks.write_irq_disabled().push(callback);
self.callbacks.write().push(callback);
}
}
@ -136,7 +136,7 @@ impl ConsoleDevice {
};
self.receive_buffer.sync(0..len as usize).unwrap();
let callbacks = self.callbacks.read_irq_disabled();
let callbacks = self.callbacks.read();
for callback in callbacks.iter() {
let reader = self.receive_buffer.reader().unwrap().limit(len as usize);
callback(reader);

View File

@ -19,7 +19,7 @@ use ostd::{
io_mem::IoMem,
mm::{DmaDirection, DmaStream, FrameAllocOptions, HasDaddr, VmIo, PAGE_SIZE},
offset_of,
sync::{RwLock, SpinLock},
sync::{LocalIrqDisabled, RwLock, SpinLock},
trap::TrapFrame,
};
@ -76,7 +76,7 @@ pub struct InputDevice {
status_queue: VirtQueue,
event_table: EventTable,
#[allow(clippy::type_complexity)]
callbacks: RwLock<Vec<Arc<dyn Fn(InputEvent) + Send + Sync + 'static>>>,
callbacks: RwLock<Vec<Arc<dyn Fn(InputEvent) + Send + Sync + 'static>>, LocalIrqDisabled>,
transport: SpinLock<Box<dyn VirtioTransport>>,
}
@ -209,7 +209,7 @@ impl InputDevice {
}
fn handle_irq(&self) {
let callbacks = self.callbacks.read_irq_disabled();
let callbacks = self.callbacks.read();
// Returns true if there may be more events to handle
let handle_event = |event: &EventBuf| -> bool {
event.sync().unwrap();
@ -295,7 +295,7 @@ impl<T, M: HasDaddr> DmaBuf for SafePtr<T, M> {
impl aster_input::InputDevice for InputDevice {
fn register_callbacks(&self, function: &'static (dyn Fn(InputEvent) + Send + Sync)) {
self.callbacks.write_irq_disabled().push(Arc::new(function))
self.callbacks.write().push(Arc::new(function))
}
}

View File

@ -46,7 +46,7 @@ pub struct BoundSocketInner<T, E> {
iface: Arc<dyn Iface<E>>,
port: u16,
socket: T,
observer: RwLock<Weak<dyn SocketEventObserver>>,
observer: RwLock<Weak<dyn SocketEventObserver>, LocalIrqDisabled>,
events: AtomicU8,
next_poll_at_ms: AtomicU64,
}
@ -223,7 +223,7 @@ impl<T: AnySocket, E> BoundSocket<T, E> {
/// that the old observer will never be called after the setting. Users should be aware of this
/// and proactively handle the race conditions if necessary.
pub fn set_observer(&self, new_observer: Weak<dyn SocketEventObserver>) {
*self.0.observer.write_irq_disabled() = new_observer;
*self.0.observer.write() = new_observer;
self.0.on_events();
}

View File

@ -12,7 +12,7 @@ use aster_util::slot_vec::SlotVec;
use hashbrown::HashMap;
use ostd::{
mm::{Frame, VmIo},
sync::RwLockWriteGuard,
sync::{PreemptDisabled, RwLockWriteGuard},
};
use super::*;
@ -1195,8 +1195,8 @@ fn write_lock_two_direntries_by_ino<'a>(
this: (u64, &'a RwLock<DirEntry>),
other: (u64, &'a RwLock<DirEntry>),
) -> (
RwLockWriteGuard<'a, DirEntry>,
RwLockWriteGuard<'a, DirEntry>,
RwLockWriteGuard<'a, DirEntry, PreemptDisabled>,
RwLockWriteGuard<'a, DirEntry, PreemptDisabled>,
) {
if this.0 < other.0 {
let this = this.1.write();

View File

@ -290,11 +290,11 @@ pub fn create_sem_set(nsems: usize, mode: u16, credentials: Credentials<ReadOp>)
Ok(id)
}
pub fn sem_sets<'a>() -> RwLockReadGuard<'a, BTreeMap<key_t, SemaphoreSet>> {
pub fn sem_sets<'a>() -> RwLockReadGuard<'a, BTreeMap<key_t, SemaphoreSet>, PreemptDisabled> {
SEMAPHORE_SETS.read()
}
pub fn sem_sets_mut<'a>() -> RwLockWriteGuard<'a, BTreeMap<key_t, SemaphoreSet>> {
pub fn sem_sets_mut<'a>() -> RwLockWriteGuard<'a, BTreeMap<key_t, SemaphoreSet>, PreemptDisabled> {
SEMAPHORE_SETS.write()
}

View File

@ -6,6 +6,7 @@ use aster_bigtcp::{
socket::{SocketEventObserver, SocketEvents},
wire::IpEndpoint,
};
use ostd::sync::LocalIrqDisabled;
use takeable::Takeable;
use self::{bound::BoundDatagram, unbound::UnboundDatagram};
@ -51,7 +52,7 @@ impl OptionSet {
pub struct DatagramSocket {
options: RwLock<OptionSet>,
inner: RwLock<Takeable<Inner>>,
inner: RwLock<Takeable<Inner>, LocalIrqDisabled>,
nonblocking: AtomicBool,
pollee: Pollee,
}
@ -134,7 +135,7 @@ impl DatagramSocket {
}
// Slow path
let mut inner = self.inner.write_irq_disabled();
let mut inner = self.inner.write();
inner.borrow_result(|owned_inner| {
let bound_datagram = match owned_inner.bind_to_ephemeral_endpoint(remote_endpoint) {
Ok(bound_datagram) => bound_datagram,
@ -277,7 +278,7 @@ impl Socket for DatagramSocket {
let endpoint = socket_addr.try_into()?;
let can_reuse = self.options.read().socket.reuse_addr();
let mut inner = self.inner.write_irq_disabled();
let mut inner = self.inner.write();
inner.borrow_result(|owned_inner| {
let bound_datagram = match owned_inner.bind(&endpoint, can_reuse) {
Ok(bound_datagram) => bound_datagram,
@ -294,7 +295,7 @@ impl Socket for DatagramSocket {
self.try_bind_ephemeral(&endpoint)?;
let mut inner = self.inner.write_irq_disabled();
let mut inner = self.inner.write();
let Inner::Bound(bound_datagram) = inner.as_mut() else {
return_errno_with_message!(Errno::EINVAL, "the socket is not bound")
};

View File

@ -3,6 +3,7 @@
use aster_bigtcp::{
errors::tcp::ListenError, iface::BindPortConfig, socket::UnboundTcpSocket, wire::IpEndpoint,
};
use ostd::sync::LocalIrqDisabled;
use super::connected::ConnectedStream;
use crate::{
@ -16,7 +17,7 @@ pub struct ListenStream {
/// A bound socket held to ensure the TCP port cannot be released
bound_socket: BoundTcpSocket,
/// Backlog sockets listening at the local endpoint
backlog_sockets: RwLock<Vec<BacklogSocket>>,
backlog_sockets: RwLock<Vec<BacklogSocket>, LocalIrqDisabled>,
}
impl ListenStream {
@ -40,7 +41,7 @@ impl ListenStream {
/// Append sockets listening at LocalEndPoint to support backlog
fn fill_backlog_sockets(&self) -> Result<()> {
let mut backlog_sockets = self.backlog_sockets.write_irq_disabled();
let mut backlog_sockets = self.backlog_sockets.write();
let backlog = self.backlog;
let current_backlog_len = backlog_sockets.len();
@ -58,7 +59,7 @@ impl ListenStream {
}
pub fn try_accept(&self) -> Result<ConnectedStream> {
let mut backlog_sockets = self.backlog_sockets.write_irq_disabled();
let mut backlog_sockets = self.backlog_sockets.write();
let index = backlog_sockets
.iter()

View File

@ -11,7 +11,7 @@ use connecting::{ConnResult, ConnectingStream};
use init::InitStream;
use listen::ListenStream;
use options::{Congestion, MaxSegment, NoDelay, WindowClamp};
use ostd::sync::{RwLockReadGuard, RwLockWriteGuard};
use ostd::sync::{LocalIrqDisabled, PreemptDisabled, RwLockReadGuard, RwLockWriteGuard};
use takeable::Takeable;
use util::TcpOptionSet;
@ -50,7 +50,7 @@ pub use self::util::CongestionControl;
pub struct StreamSocket {
options: RwLock<OptionSet>,
state: RwLock<Takeable<State>>,
state: RwLock<Takeable<State>, LocalIrqDisabled>,
is_nonblocking: AtomicBool,
pollee: Pollee,
}
@ -116,7 +116,7 @@ impl StreamSocket {
/// Ensures that the socket state is up to date and obtains a read lock on it.
///
/// For a description of what "up-to-date" means, see [`Self::update_connecting`].
fn read_updated_state(&self) -> RwLockReadGuard<Takeable<State>> {
fn read_updated_state(&self) -> RwLockReadGuard<Takeable<State>, LocalIrqDisabled> {
loop {
let state = self.state.read();
match state.as_ref() {
@ -132,7 +132,7 @@ impl StreamSocket {
/// Ensures that the socket state is up to date and obtains a write lock on it.
///
/// For a description of what "up-to-date" means, see [`Self::update_connecting`].
fn write_updated_state(&self) -> RwLockWriteGuard<Takeable<State>> {
fn write_updated_state(&self) -> RwLockWriteGuard<Takeable<State>, LocalIrqDisabled> {
self.update_connecting().1
}
@ -148,12 +148,12 @@ impl StreamSocket {
fn update_connecting(
&self,
) -> (
RwLockWriteGuard<OptionSet>,
RwLockWriteGuard<Takeable<State>>,
RwLockWriteGuard<OptionSet, PreemptDisabled>,
RwLockWriteGuard<Takeable<State>, LocalIrqDisabled>,
) {
// Hold the lock in advance to avoid race conditions.
let mut options = self.options.write();
let mut state = self.state.write_irq_disabled();
let mut state = self.state.write();
match state.as_ref() {
State::Connecting(connection_stream) if connection_stream.has_result() => (),

View File

@ -7,6 +7,7 @@ use aster_virtio::device::socket::{
device::SocketDevice,
error::SocketError,
};
use ostd::sync::LocalIrqDisabled;
use super::{
addr::VsockSocketAddr,
@ -26,7 +27,7 @@ pub struct VsockSpace {
// (key, value) = (local_addr, listen)
listen_sockets: SpinLock<BTreeMap<VsockSocketAddr, Arc<Listen>>>,
// (key, value) = (id(local_addr,peer_addr), connected)
connected_sockets: RwLock<BTreeMap<ConnectionID, Arc<Connected>>>,
connected_sockets: RwLock<BTreeMap<ConnectionID, Arc<Connected>>, LocalIrqDisabled>,
// Used ports
used_ports: SpinLock<BTreeSet<u32>>,
}
@ -54,10 +55,7 @@ impl VsockSpace {
.disable_irq()
.lock()
.contains_key(&event.destination.into())
|| self
.connected_sockets
.read_irq_disabled()
.contains_key(&(*event).into())
|| self.connected_sockets.read().contains_key(&(*event).into())
}
/// Alloc an unused port range
@ -91,13 +89,13 @@ impl VsockSpace {
id: ConnectionID,
connected: Arc<Connected>,
) -> Option<Arc<Connected>> {
let mut connected_sockets = self.connected_sockets.write_irq_disabled();
let mut connected_sockets = self.connected_sockets.write();
connected_sockets.insert(id, connected)
}
/// Remove a connected socket
pub fn remove_connected_socket(&self, id: &ConnectionID) -> Option<Arc<Connected>> {
let mut connected_sockets = self.connected_sockets.write_irq_disabled();
let mut connected_sockets = self.connected_sockets.write();
connected_sockets.remove(id)
}
@ -214,11 +212,7 @@ impl VsockSpace {
debug!("vsock receive event: {:?}", event);
// The socket must be stored in the VsockSpace.
if let Some(connected) = self
.connected_sockets
.read_irq_disabled()
.get(&event.into())
{
if let Some(connected) = self.connected_sockets.read().get(&event.into()) {
connected.update_info(&event);
}
@ -255,7 +249,7 @@ impl VsockSpace {
connecting.set_connected();
}
VsockEventType::Disconnected { .. } => {
let connected_sockets = self.connected_sockets.read_irq_disabled();
let connected_sockets = self.connected_sockets.read();
let Some(connected) = connected_sockets.get(&event.into()) else {
return_errno_with_message!(Errno::ENOTCONN, "the socket hasn't connected");
};
@ -263,7 +257,7 @@ impl VsockSpace {
}
VsockEventType::Received { .. } => {}
VsockEventType::CreditRequest => {
let connected_sockets = self.connected_sockets.read_irq_disabled();
let connected_sockets = self.connected_sockets.read();
let Some(connected) = connected_sockets.get(&event.into()) else {
return_errno_with_message!(Errno::ENOTCONN, "the socket hasn't connected");
};
@ -272,7 +266,7 @@ impl VsockSpace {
})?;
}
VsockEventType::CreditUpdate => {
let connected_sockets = self.connected_sockets.read_irq_disabled();
let connected_sockets = self.connected_sockets.read();
let Some(connected) = connected_sockets.get(&event.into()) else {
return_errno_with_message!(Errno::ENOTCONN, "the socket hasn't connected");
};
@ -289,7 +283,7 @@ impl VsockSpace {
// Deal with Received before the buffer are recycled.
if let VsockEventType::Received { .. } = event.event_type {
// Only consider the connected socket and copy body to buffer
let connected_sockets = self.connected_sockets.read_irq_disabled();
let connected_sockets = self.connected_sockets.read();
let connected = connected_sockets.get(&event.into()).unwrap();
debug!("Rw matches a connection with id {:?}", connected.id());
if !connected.add_connection_buffer(body) {

View File

@ -2,7 +2,7 @@
use core::sync::atomic::Ordering;
use ostd::sync::{RwLockReadGuard, RwLockWriteGuard};
use ostd::sync::{PreemptDisabled, RwLockReadGuard, RwLockWriteGuard};
use super::{group::AtomicGid, user::AtomicUid, Gid, Uid};
use crate::{
@ -387,11 +387,11 @@ impl Credentials_ {
// ******* Supplementary groups methods *******
pub(super) fn groups(&self) -> RwLockReadGuard<BTreeSet<Gid>> {
pub(super) fn groups(&self) -> RwLockReadGuard<BTreeSet<Gid>, PreemptDisabled> {
self.supplementary_gids.read()
}
pub(super) fn groups_mut(&self) -> RwLockWriteGuard<BTreeSet<Gid>> {
pub(super) fn groups_mut(&self) -> RwLockWriteGuard<BTreeSet<Gid>, PreemptDisabled> {
self.supplementary_gids.write()
}

View File

@ -4,7 +4,7 @@
use aster_rights::{Dup, Read, TRights, Write};
use aster_rights_proc::require;
use ostd::sync::{RwLockReadGuard, RwLockWriteGuard};
use ostd::sync::{PreemptDisabled, RwLockReadGuard, RwLockWriteGuard};
use super::{capabilities::CapSet, credentials_::Credentials_, Credentials, Gid, Uid};
use crate::prelude::*;
@ -239,7 +239,7 @@ impl<R: TRights> Credentials<R> {
///
/// This method requires the `Read` right.
#[require(R > Read)]
pub fn groups(&self) -> RwLockReadGuard<BTreeSet<Gid>> {
pub fn groups(&self) -> RwLockReadGuard<BTreeSet<Gid>, PreemptDisabled> {
self.0.groups()
}
@ -247,7 +247,7 @@ impl<R: TRights> Credentials<R> {
///
/// This method requires the `Write` right.
#[require(R > Write)]
pub fn groups_mut(&self) -> RwLockWriteGuard<BTreeSet<Gid>> {
pub fn groups_mut(&self) -> RwLockWriteGuard<BTreeSet<Gid>, PreemptDisabled> {
self.0.groups_mut()
}

View File

@ -3,9 +3,13 @@
use alloc::{boxed::Box, vec::Vec};
use aster_softirq::{softirq_id::TIMER_SOFTIRQ_ID, SoftIrqLine};
use ostd::{sync::RwLock, timer};
use ostd::{
sync::{LocalIrqDisabled, RwLock},
timer,
};
static TIMER_SOFTIRQ_CALLBACKS: RwLock<Vec<Box<dyn Fn() + Sync + Send>>> = RwLock::new(Vec::new());
static TIMER_SOFTIRQ_CALLBACKS: RwLock<Vec<Box<dyn Fn() + Sync + Send>>, LocalIrqDisabled> =
RwLock::new(Vec::new());
pub(super) fn init() {
SoftIrqLine::get(TIMER_SOFTIRQ_ID).enable(timer_softirq_handler);
@ -20,13 +24,11 @@ pub(super) fn register_callback<F>(func: F)
where
F: Fn() + Sync + Send + 'static,
{
TIMER_SOFTIRQ_CALLBACKS
.write_irq_disabled()
.push(Box::new(func));
TIMER_SOFTIRQ_CALLBACKS.write().push(Box::new(func));
}
fn timer_softirq_handler() {
let callbacks = TIMER_SOFTIRQ_CALLBACKS.read_irq_disabled();
let callbacks = TIMER_SOFTIRQ_CALLBACKS.read();
for callback in callbacks.iter() {
(callback)();
}