mirror of
https://github.com/asterinas/asterinas.git
synced 2025-06-24 01:43:22 +00:00
Refactor Rwlock to take type parameter
This commit is contained in:
committed by
Tate, Hongliang Tian
parent
ac1a6be05d
commit
495c93c2ad
@ -12,7 +12,7 @@ use aster_util::slot_vec::SlotVec;
|
||||
use hashbrown::HashMap;
|
||||
use ostd::{
|
||||
mm::{Frame, VmIo},
|
||||
sync::RwLockWriteGuard,
|
||||
sync::{PreemptDisabled, RwLockWriteGuard},
|
||||
};
|
||||
|
||||
use super::*;
|
||||
@ -1195,8 +1195,8 @@ fn write_lock_two_direntries_by_ino<'a>(
|
||||
this: (u64, &'a RwLock<DirEntry>),
|
||||
other: (u64, &'a RwLock<DirEntry>),
|
||||
) -> (
|
||||
RwLockWriteGuard<'a, DirEntry>,
|
||||
RwLockWriteGuard<'a, DirEntry>,
|
||||
RwLockWriteGuard<'a, DirEntry, PreemptDisabled>,
|
||||
RwLockWriteGuard<'a, DirEntry, PreemptDisabled>,
|
||||
) {
|
||||
if this.0 < other.0 {
|
||||
let this = this.1.write();
|
||||
|
@ -290,11 +290,11 @@ pub fn create_sem_set(nsems: usize, mode: u16, credentials: Credentials<ReadOp>)
|
||||
Ok(id)
|
||||
}
|
||||
|
||||
pub fn sem_sets<'a>() -> RwLockReadGuard<'a, BTreeMap<key_t, SemaphoreSet>> {
|
||||
pub fn sem_sets<'a>() -> RwLockReadGuard<'a, BTreeMap<key_t, SemaphoreSet>, PreemptDisabled> {
|
||||
SEMAPHORE_SETS.read()
|
||||
}
|
||||
|
||||
pub fn sem_sets_mut<'a>() -> RwLockWriteGuard<'a, BTreeMap<key_t, SemaphoreSet>> {
|
||||
pub fn sem_sets_mut<'a>() -> RwLockWriteGuard<'a, BTreeMap<key_t, SemaphoreSet>, PreemptDisabled> {
|
||||
SEMAPHORE_SETS.write()
|
||||
}
|
||||
|
||||
|
@ -6,6 +6,7 @@ use aster_bigtcp::{
|
||||
socket::{SocketEventObserver, SocketEvents},
|
||||
wire::IpEndpoint,
|
||||
};
|
||||
use ostd::sync::LocalIrqDisabled;
|
||||
use takeable::Takeable;
|
||||
|
||||
use self::{bound::BoundDatagram, unbound::UnboundDatagram};
|
||||
@ -51,7 +52,7 @@ impl OptionSet {
|
||||
|
||||
pub struct DatagramSocket {
|
||||
options: RwLock<OptionSet>,
|
||||
inner: RwLock<Takeable<Inner>>,
|
||||
inner: RwLock<Takeable<Inner>, LocalIrqDisabled>,
|
||||
nonblocking: AtomicBool,
|
||||
pollee: Pollee,
|
||||
}
|
||||
@ -134,7 +135,7 @@ impl DatagramSocket {
|
||||
}
|
||||
|
||||
// Slow path
|
||||
let mut inner = self.inner.write_irq_disabled();
|
||||
let mut inner = self.inner.write();
|
||||
inner.borrow_result(|owned_inner| {
|
||||
let bound_datagram = match owned_inner.bind_to_ephemeral_endpoint(remote_endpoint) {
|
||||
Ok(bound_datagram) => bound_datagram,
|
||||
@ -277,7 +278,7 @@ impl Socket for DatagramSocket {
|
||||
let endpoint = socket_addr.try_into()?;
|
||||
|
||||
let can_reuse = self.options.read().socket.reuse_addr();
|
||||
let mut inner = self.inner.write_irq_disabled();
|
||||
let mut inner = self.inner.write();
|
||||
inner.borrow_result(|owned_inner| {
|
||||
let bound_datagram = match owned_inner.bind(&endpoint, can_reuse) {
|
||||
Ok(bound_datagram) => bound_datagram,
|
||||
@ -294,7 +295,7 @@ impl Socket for DatagramSocket {
|
||||
|
||||
self.try_bind_ephemeral(&endpoint)?;
|
||||
|
||||
let mut inner = self.inner.write_irq_disabled();
|
||||
let mut inner = self.inner.write();
|
||||
let Inner::Bound(bound_datagram) = inner.as_mut() else {
|
||||
return_errno_with_message!(Errno::EINVAL, "the socket is not bound")
|
||||
};
|
||||
|
@ -3,6 +3,7 @@
|
||||
use aster_bigtcp::{
|
||||
errors::tcp::ListenError, iface::BindPortConfig, socket::UnboundTcpSocket, wire::IpEndpoint,
|
||||
};
|
||||
use ostd::sync::LocalIrqDisabled;
|
||||
|
||||
use super::connected::ConnectedStream;
|
||||
use crate::{
|
||||
@ -16,7 +17,7 @@ pub struct ListenStream {
|
||||
/// A bound socket held to ensure the TCP port cannot be released
|
||||
bound_socket: BoundTcpSocket,
|
||||
/// Backlog sockets listening at the local endpoint
|
||||
backlog_sockets: RwLock<Vec<BacklogSocket>>,
|
||||
backlog_sockets: RwLock<Vec<BacklogSocket>, LocalIrqDisabled>,
|
||||
}
|
||||
|
||||
impl ListenStream {
|
||||
@ -40,7 +41,7 @@ impl ListenStream {
|
||||
|
||||
/// Append sockets listening at LocalEndPoint to support backlog
|
||||
fn fill_backlog_sockets(&self) -> Result<()> {
|
||||
let mut backlog_sockets = self.backlog_sockets.write_irq_disabled();
|
||||
let mut backlog_sockets = self.backlog_sockets.write();
|
||||
|
||||
let backlog = self.backlog;
|
||||
let current_backlog_len = backlog_sockets.len();
|
||||
@ -58,7 +59,7 @@ impl ListenStream {
|
||||
}
|
||||
|
||||
pub fn try_accept(&self) -> Result<ConnectedStream> {
|
||||
let mut backlog_sockets = self.backlog_sockets.write_irq_disabled();
|
||||
let mut backlog_sockets = self.backlog_sockets.write();
|
||||
|
||||
let index = backlog_sockets
|
||||
.iter()
|
||||
|
@ -11,7 +11,7 @@ use connecting::{ConnResult, ConnectingStream};
|
||||
use init::InitStream;
|
||||
use listen::ListenStream;
|
||||
use options::{Congestion, MaxSegment, NoDelay, WindowClamp};
|
||||
use ostd::sync::{RwLockReadGuard, RwLockWriteGuard};
|
||||
use ostd::sync::{LocalIrqDisabled, PreemptDisabled, RwLockReadGuard, RwLockWriteGuard};
|
||||
use takeable::Takeable;
|
||||
use util::TcpOptionSet;
|
||||
|
||||
@ -50,7 +50,7 @@ pub use self::util::CongestionControl;
|
||||
|
||||
pub struct StreamSocket {
|
||||
options: RwLock<OptionSet>,
|
||||
state: RwLock<Takeable<State>>,
|
||||
state: RwLock<Takeable<State>, LocalIrqDisabled>,
|
||||
is_nonblocking: AtomicBool,
|
||||
pollee: Pollee,
|
||||
}
|
||||
@ -116,7 +116,7 @@ impl StreamSocket {
|
||||
/// Ensures that the socket state is up to date and obtains a read lock on it.
|
||||
///
|
||||
/// For a description of what "up-to-date" means, see [`Self::update_connecting`].
|
||||
fn read_updated_state(&self) -> RwLockReadGuard<Takeable<State>> {
|
||||
fn read_updated_state(&self) -> RwLockReadGuard<Takeable<State>, LocalIrqDisabled> {
|
||||
loop {
|
||||
let state = self.state.read();
|
||||
match state.as_ref() {
|
||||
@ -132,7 +132,7 @@ impl StreamSocket {
|
||||
/// Ensures that the socket state is up to date and obtains a write lock on it.
|
||||
///
|
||||
/// For a description of what "up-to-date" means, see [`Self::update_connecting`].
|
||||
fn write_updated_state(&self) -> RwLockWriteGuard<Takeable<State>> {
|
||||
fn write_updated_state(&self) -> RwLockWriteGuard<Takeable<State>, LocalIrqDisabled> {
|
||||
self.update_connecting().1
|
||||
}
|
||||
|
||||
@ -148,12 +148,12 @@ impl StreamSocket {
|
||||
fn update_connecting(
|
||||
&self,
|
||||
) -> (
|
||||
RwLockWriteGuard<OptionSet>,
|
||||
RwLockWriteGuard<Takeable<State>>,
|
||||
RwLockWriteGuard<OptionSet, PreemptDisabled>,
|
||||
RwLockWriteGuard<Takeable<State>, LocalIrqDisabled>,
|
||||
) {
|
||||
// Hold the lock in advance to avoid race conditions.
|
||||
let mut options = self.options.write();
|
||||
let mut state = self.state.write_irq_disabled();
|
||||
let mut state = self.state.write();
|
||||
|
||||
match state.as_ref() {
|
||||
State::Connecting(connection_stream) if connection_stream.has_result() => (),
|
||||
|
@ -7,6 +7,7 @@ use aster_virtio::device::socket::{
|
||||
device::SocketDevice,
|
||||
error::SocketError,
|
||||
};
|
||||
use ostd::sync::LocalIrqDisabled;
|
||||
|
||||
use super::{
|
||||
addr::VsockSocketAddr,
|
||||
@ -26,7 +27,7 @@ pub struct VsockSpace {
|
||||
// (key, value) = (local_addr, listen)
|
||||
listen_sockets: SpinLock<BTreeMap<VsockSocketAddr, Arc<Listen>>>,
|
||||
// (key, value) = (id(local_addr,peer_addr), connected)
|
||||
connected_sockets: RwLock<BTreeMap<ConnectionID, Arc<Connected>>>,
|
||||
connected_sockets: RwLock<BTreeMap<ConnectionID, Arc<Connected>>, LocalIrqDisabled>,
|
||||
// Used ports
|
||||
used_ports: SpinLock<BTreeSet<u32>>,
|
||||
}
|
||||
@ -54,10 +55,7 @@ impl VsockSpace {
|
||||
.disable_irq()
|
||||
.lock()
|
||||
.contains_key(&event.destination.into())
|
||||
|| self
|
||||
.connected_sockets
|
||||
.read_irq_disabled()
|
||||
.contains_key(&(*event).into())
|
||||
|| self.connected_sockets.read().contains_key(&(*event).into())
|
||||
}
|
||||
|
||||
/// Alloc an unused port range
|
||||
@ -91,13 +89,13 @@ impl VsockSpace {
|
||||
id: ConnectionID,
|
||||
connected: Arc<Connected>,
|
||||
) -> Option<Arc<Connected>> {
|
||||
let mut connected_sockets = self.connected_sockets.write_irq_disabled();
|
||||
let mut connected_sockets = self.connected_sockets.write();
|
||||
connected_sockets.insert(id, connected)
|
||||
}
|
||||
|
||||
/// Remove a connected socket
|
||||
pub fn remove_connected_socket(&self, id: &ConnectionID) -> Option<Arc<Connected>> {
|
||||
let mut connected_sockets = self.connected_sockets.write_irq_disabled();
|
||||
let mut connected_sockets = self.connected_sockets.write();
|
||||
connected_sockets.remove(id)
|
||||
}
|
||||
|
||||
@ -214,11 +212,7 @@ impl VsockSpace {
|
||||
|
||||
debug!("vsock receive event: {:?}", event);
|
||||
// The socket must be stored in the VsockSpace.
|
||||
if let Some(connected) = self
|
||||
.connected_sockets
|
||||
.read_irq_disabled()
|
||||
.get(&event.into())
|
||||
{
|
||||
if let Some(connected) = self.connected_sockets.read().get(&event.into()) {
|
||||
connected.update_info(&event);
|
||||
}
|
||||
|
||||
@ -255,7 +249,7 @@ impl VsockSpace {
|
||||
connecting.set_connected();
|
||||
}
|
||||
VsockEventType::Disconnected { .. } => {
|
||||
let connected_sockets = self.connected_sockets.read_irq_disabled();
|
||||
let connected_sockets = self.connected_sockets.read();
|
||||
let Some(connected) = connected_sockets.get(&event.into()) else {
|
||||
return_errno_with_message!(Errno::ENOTCONN, "the socket hasn't connected");
|
||||
};
|
||||
@ -263,7 +257,7 @@ impl VsockSpace {
|
||||
}
|
||||
VsockEventType::Received { .. } => {}
|
||||
VsockEventType::CreditRequest => {
|
||||
let connected_sockets = self.connected_sockets.read_irq_disabled();
|
||||
let connected_sockets = self.connected_sockets.read();
|
||||
let Some(connected) = connected_sockets.get(&event.into()) else {
|
||||
return_errno_with_message!(Errno::ENOTCONN, "the socket hasn't connected");
|
||||
};
|
||||
@ -272,7 +266,7 @@ impl VsockSpace {
|
||||
})?;
|
||||
}
|
||||
VsockEventType::CreditUpdate => {
|
||||
let connected_sockets = self.connected_sockets.read_irq_disabled();
|
||||
let connected_sockets = self.connected_sockets.read();
|
||||
let Some(connected) = connected_sockets.get(&event.into()) else {
|
||||
return_errno_with_message!(Errno::ENOTCONN, "the socket hasn't connected");
|
||||
};
|
||||
@ -289,7 +283,7 @@ impl VsockSpace {
|
||||
// Deal with Received before the buffer are recycled.
|
||||
if let VsockEventType::Received { .. } = event.event_type {
|
||||
// Only consider the connected socket and copy body to buffer
|
||||
let connected_sockets = self.connected_sockets.read_irq_disabled();
|
||||
let connected_sockets = self.connected_sockets.read();
|
||||
let connected = connected_sockets.get(&event.into()).unwrap();
|
||||
debug!("Rw matches a connection with id {:?}", connected.id());
|
||||
if !connected.add_connection_buffer(body) {
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
use core::sync::atomic::Ordering;
|
||||
|
||||
use ostd::sync::{RwLockReadGuard, RwLockWriteGuard};
|
||||
use ostd::sync::{PreemptDisabled, RwLockReadGuard, RwLockWriteGuard};
|
||||
|
||||
use super::{group::AtomicGid, user::AtomicUid, Gid, Uid};
|
||||
use crate::{
|
||||
@ -387,11 +387,11 @@ impl Credentials_ {
|
||||
|
||||
// ******* Supplementary groups methods *******
|
||||
|
||||
pub(super) fn groups(&self) -> RwLockReadGuard<BTreeSet<Gid>> {
|
||||
pub(super) fn groups(&self) -> RwLockReadGuard<BTreeSet<Gid>, PreemptDisabled> {
|
||||
self.supplementary_gids.read()
|
||||
}
|
||||
|
||||
pub(super) fn groups_mut(&self) -> RwLockWriteGuard<BTreeSet<Gid>> {
|
||||
pub(super) fn groups_mut(&self) -> RwLockWriteGuard<BTreeSet<Gid>, PreemptDisabled> {
|
||||
self.supplementary_gids.write()
|
||||
}
|
||||
|
||||
|
@ -4,7 +4,7 @@
|
||||
|
||||
use aster_rights::{Dup, Read, TRights, Write};
|
||||
use aster_rights_proc::require;
|
||||
use ostd::sync::{RwLockReadGuard, RwLockWriteGuard};
|
||||
use ostd::sync::{PreemptDisabled, RwLockReadGuard, RwLockWriteGuard};
|
||||
|
||||
use super::{capabilities::CapSet, credentials_::Credentials_, Credentials, Gid, Uid};
|
||||
use crate::prelude::*;
|
||||
@ -239,7 +239,7 @@ impl<R: TRights> Credentials<R> {
|
||||
///
|
||||
/// This method requires the `Read` right.
|
||||
#[require(R > Read)]
|
||||
pub fn groups(&self) -> RwLockReadGuard<BTreeSet<Gid>> {
|
||||
pub fn groups(&self) -> RwLockReadGuard<BTreeSet<Gid>, PreemptDisabled> {
|
||||
self.0.groups()
|
||||
}
|
||||
|
||||
@ -247,7 +247,7 @@ impl<R: TRights> Credentials<R> {
|
||||
///
|
||||
/// This method requires the `Write` right.
|
||||
#[require(R > Write)]
|
||||
pub fn groups_mut(&self) -> RwLockWriteGuard<BTreeSet<Gid>> {
|
||||
pub fn groups_mut(&self) -> RwLockWriteGuard<BTreeSet<Gid>, PreemptDisabled> {
|
||||
self.0.groups_mut()
|
||||
}
|
||||
|
||||
|
@ -3,9 +3,13 @@
|
||||
use alloc::{boxed::Box, vec::Vec};
|
||||
|
||||
use aster_softirq::{softirq_id::TIMER_SOFTIRQ_ID, SoftIrqLine};
|
||||
use ostd::{sync::RwLock, timer};
|
||||
use ostd::{
|
||||
sync::{LocalIrqDisabled, RwLock},
|
||||
timer,
|
||||
};
|
||||
|
||||
static TIMER_SOFTIRQ_CALLBACKS: RwLock<Vec<Box<dyn Fn() + Sync + Send>>> = RwLock::new(Vec::new());
|
||||
static TIMER_SOFTIRQ_CALLBACKS: RwLock<Vec<Box<dyn Fn() + Sync + Send>>, LocalIrqDisabled> =
|
||||
RwLock::new(Vec::new());
|
||||
|
||||
pub(super) fn init() {
|
||||
SoftIrqLine::get(TIMER_SOFTIRQ_ID).enable(timer_softirq_handler);
|
||||
@ -20,13 +24,11 @@ pub(super) fn register_callback<F>(func: F)
|
||||
where
|
||||
F: Fn() + Sync + Send + 'static,
|
||||
{
|
||||
TIMER_SOFTIRQ_CALLBACKS
|
||||
.write_irq_disabled()
|
||||
.push(Box::new(func));
|
||||
TIMER_SOFTIRQ_CALLBACKS.write().push(Box::new(func));
|
||||
}
|
||||
|
||||
fn timer_softirq_handler() {
|
||||
let callbacks = TIMER_SOFTIRQ_CALLBACKS.read_irq_disabled();
|
||||
let callbacks = TIMER_SOFTIRQ_CALLBACKS.read();
|
||||
for callback in callbacks.iter() {
|
||||
(callback)();
|
||||
}
|
||||
|
Reference in New Issue
Block a user