Make use of new SpinLock APIs

This commit is contained in:
Ruihan Li 2024-09-12 09:52:40 +08:00 committed by Tate, Hongliang Tian
parent 67d3682116
commit 1b9b76d27c
3 changed files with 36 additions and 52 deletions

View File

@ -19,7 +19,7 @@ pub use buffer::{RxBuffer, TxBuffer, RX_BUFFER_POOL, TX_BUFFER_POOL};
use component::{init_component, ComponentInitError};
pub use dma_pool::DmaSegment;
use ostd::{
sync::{PreemptDisabled, SpinLock},
sync::{LocalIrqDisabled, SpinLock},
Pod,
};
use smoltcp::phy;
@ -55,23 +55,20 @@ pub trait AnyNetworkDevice: Send + Sync + Any + Debug {
pub trait NetDeviceIrqHandler = Fn() + Send + Sync + 'static;
pub fn register_device(name: String, device: Arc<SpinLock<dyn AnyNetworkDevice, PreemptDisabled>>) {
pub fn register_device(
name: String,
device: Arc<SpinLock<dyn AnyNetworkDevice, LocalIrqDisabled>>,
) {
COMPONENT
.get()
.unwrap()
.network_device_table
.disable_irq()
.lock()
.insert(name, (Arc::new(SpinLock::new(Vec::new())), device));
}
pub fn get_device(str: &str) -> Option<Arc<SpinLock<dyn AnyNetworkDevice, PreemptDisabled>>> {
let table = COMPONENT
.get()
.unwrap()
.network_device_table
.disable_irq()
.lock();
pub fn get_device(str: &str) -> Option<Arc<SpinLock<dyn AnyNetworkDevice, LocalIrqDisabled>>> {
let table = COMPONENT.get().unwrap().network_device_table.lock();
let (_, device) = table.get(str)?;
Some(device.clone())
}
@ -81,41 +78,26 @@ pub fn get_device(str: &str) -> Option<Arc<SpinLock<dyn AnyNetworkDevice, Preemp
/// Since the callback will be called in interrupt context,
/// the callback function should NOT sleep.
pub fn register_recv_callback(name: &str, callback: impl NetDeviceIrqHandler) {
let device_table = COMPONENT
.get()
.unwrap()
.network_device_table
.disable_irq()
.lock();
let device_table = COMPONENT.get().unwrap().network_device_table.lock();
let Some((callbacks, _)) = device_table.get(name) else {
return;
};
callbacks.disable_irq().lock().push(Arc::new(callback));
callbacks.lock().push(Arc::new(callback));
}
pub fn handle_recv_irq(name: &str) {
let device_table = COMPONENT
.get()
.unwrap()
.network_device_table
.disable_irq()
.lock();
let device_table = COMPONENT.get().unwrap().network_device_table.lock();
let Some((callbacks, _)) = device_table.get(name) else {
return;
};
let callbacks = callbacks.disable_irq().lock();
let callbacks = callbacks.lock();
for callback in callbacks.iter() {
callback();
}
}
pub fn all_devices() -> Vec<(String, NetworkDeviceRef)> {
let network_devs = COMPONENT
.get()
.unwrap()
.network_device_table
.disable_irq()
.lock();
let network_devs = COMPONENT.get().unwrap().network_device_table.lock();
network_devs
.iter()
.map(|(name, (_, device))| (name.clone(), device.clone()))
@ -124,7 +106,7 @@ pub fn all_devices() -> Vec<(String, NetworkDeviceRef)> {
static COMPONENT: Once<Component> = Once::new();
pub(crate) static NETWORK_IRQ_HANDLERS: Once<
SpinLock<Vec<Arc<dyn NetDeviceIrqHandler>>, PreemptDisabled>,
SpinLock<Vec<Arc<dyn NetDeviceIrqHandler>>, LocalIrqDisabled>,
> = Once::new();
#[init_component]
@ -136,13 +118,16 @@ fn init() -> Result<(), ComponentInitError> {
Ok(())
}
type NetDeviceIrqHandlerListRef = Arc<SpinLock<Vec<Arc<dyn NetDeviceIrqHandler>>, PreemptDisabled>>;
type NetworkDeviceRef = Arc<SpinLock<dyn AnyNetworkDevice, PreemptDisabled>>;
type NetDeviceIrqHandlerListRef =
Arc<SpinLock<Vec<Arc<dyn NetDeviceIrqHandler>>, LocalIrqDisabled>>;
type NetworkDeviceRef = Arc<SpinLock<dyn AnyNetworkDevice, LocalIrqDisabled>>;
struct Component {
/// Device list, the key is device name, value is (callbacks, device);
network_device_table:
SpinLock<BTreeMap<String, (NetDeviceIrqHandlerListRef, NetworkDeviceRef)>, PreemptDisabled>,
network_device_table: SpinLock<
BTreeMap<String, (NetDeviceIrqHandlerListRef, NetworkDeviceRef)>,
LocalIrqDisabled,
>,
}
impl Component {

View File

@ -25,11 +25,11 @@ use crate::{
};
pub struct IfaceCommon<E> {
interface: SpinLock<smoltcp::iface::Interface>,
sockets: SpinLock<SocketSet<'static>>,
interface: SpinLock<smoltcp::iface::Interface, LocalIrqDisabled>,
sockets: SpinLock<SocketSet<'static>, LocalIrqDisabled>,
used_ports: RwLock<BTreeMap<u16, usize>>,
bound_sockets: RwLock<BTreeSet<KeyableArc<AnyBoundSocketInner<E>>>>,
closing_sockets: SpinLock<BTreeSet<KeyableArc<AnyBoundSocketInner<E>>>>,
closing_sockets: SpinLock<BTreeSet<KeyableArc<AnyBoundSocketInner<E>>>, LocalIrqDisabled>,
ext: E,
}
@ -51,7 +51,7 @@ impl<E> IfaceCommon<E> {
///
/// *Lock ordering:* [`Self::sockets`] first, [`Self::interface`] second.
pub(crate) fn interface(&self) -> SpinLockGuard<smoltcp::iface::Interface, LocalIrqDisabled> {
self.interface.disable_irq().lock()
self.interface.lock()
}
/// Acuqires the lock to the sockets.
@ -60,11 +60,11 @@ impl<E> IfaceCommon<E> {
pub(crate) fn sockets(
&self,
) -> SpinLockGuard<smoltcp::iface::SocketSet<'static>, LocalIrqDisabled> {
self.sockets.disable_irq().lock()
self.sockets.lock()
}
pub(super) fn ipv4_addr(&self) -> Option<Ipv4Address> {
self.interface.disable_irq().lock().ipv4_addr()
self.interface.lock().ipv4_addr()
}
/// Alloc an unused port range from 49152 ~ 65535 (According to smoltcp docs)
@ -125,12 +125,12 @@ impl<E> IfaceCommon<E> {
let (handle, socket_family, observer) = match socket.into_raw() {
(AnyRawSocket::Tcp(tcp_socket), observer) => (
self.sockets.disable_irq().lock().add(tcp_socket),
self.sockets.lock().add(tcp_socket),
SocketFamily::Tcp,
observer,
),
(AnyRawSocket::Udp(udp_socket), observer) => (
self.sockets.disable_irq().lock().add(udp_socket),
self.sockets.lock().add(udp_socket),
SocketFamily::Udp,
observer,
),
@ -143,13 +143,13 @@ impl<E> IfaceCommon<E> {
/// Remove a socket from the interface
pub(crate) fn remove_socket(&self, handle: SocketHandle) {
self.sockets.disable_irq().lock().remove(handle);
self.sockets.lock().remove(handle);
}
#[must_use]
pub(super) fn poll<D: Device + ?Sized>(&self, device: &mut D) -> Option<u64> {
let mut sockets = self.sockets.disable_irq().lock();
let mut interface = self.interface.disable_irq().lock();
let mut sockets = self.sockets.lock();
let mut interface = self.interface.lock();
let timestamp = get_network_timestamp();
let (has_events, poll_at) = {
@ -192,7 +192,6 @@ impl<E> IfaceCommon<E> {
let closed_sockets = self
.closing_sockets
.disable_irq()
.lock()
.extract_if(|closing_socket| closing_socket.is_closed())
.collect::<Vec<_>>();
@ -235,7 +234,7 @@ impl<E> IfaceCommon<E> {
.remove(&keyable_socket);
assert!(removed);
let mut closing_sockets = self.closing_sockets.disable_irq().lock();
let mut closing_sockets = self.closing_sockets.lock();
// Check `is_closed` after holding the lock to avoid race conditions.
if keyable_socket.is_closed() {

View File

@ -3,7 +3,7 @@
use alloc::{borrow::ToOwned, sync::Arc};
use aster_bigtcp::device::WithDevice;
use ostd::sync::PreemptDisabled;
use ostd::sync::LocalIrqDisabled;
use spin::Once;
use super::{poll_ifaces, Iface};
@ -39,9 +39,9 @@ fn new_virtio() -> Arc<Iface> {
let virtio_net = aster_network::get_device(DEVICE_NAME).unwrap();
let ether_addr = virtio_net.disable_irq().lock().mac_addr().0;
let ether_addr = virtio_net.lock().mac_addr().0;
struct Wrapper(Arc<SpinLock<dyn AnyNetworkDevice, PreemptDisabled>>);
struct Wrapper(Arc<SpinLock<dyn AnyNetworkDevice, LocalIrqDisabled>>);
impl WithDevice for Wrapper {
type Device = dyn AnyNetworkDevice;
@ -50,7 +50,7 @@ fn new_virtio() -> Arc<Iface> {
where
F: FnOnce(&mut Self::Device) -> R,
{
let mut device = self.0.disable_irq().lock();
let mut device = self.0.lock();
f(&mut *device)
}
}