Make use of new SpinLock APIs

This commit is contained in:
Ruihan Li 2024-09-12 09:52:40 +08:00 committed by Tate, Hongliang Tian
parent 67d3682116
commit 1b9b76d27c
3 changed files with 36 additions and 52 deletions

View File

@ -19,7 +19,7 @@ pub use buffer::{RxBuffer, TxBuffer, RX_BUFFER_POOL, TX_BUFFER_POOL};
use component::{init_component, ComponentInitError}; use component::{init_component, ComponentInitError};
pub use dma_pool::DmaSegment; pub use dma_pool::DmaSegment;
use ostd::{ use ostd::{
sync::{PreemptDisabled, SpinLock}, sync::{LocalIrqDisabled, SpinLock},
Pod, Pod,
}; };
use smoltcp::phy; use smoltcp::phy;
@ -55,23 +55,20 @@ pub trait AnyNetworkDevice: Send + Sync + Any + Debug {
pub trait NetDeviceIrqHandler = Fn() + Send + Sync + 'static; pub trait NetDeviceIrqHandler = Fn() + Send + Sync + 'static;
pub fn register_device(name: String, device: Arc<SpinLock<dyn AnyNetworkDevice, PreemptDisabled>>) { pub fn register_device(
name: String,
device: Arc<SpinLock<dyn AnyNetworkDevice, LocalIrqDisabled>>,
) {
COMPONENT COMPONENT
.get() .get()
.unwrap() .unwrap()
.network_device_table .network_device_table
.disable_irq()
.lock() .lock()
.insert(name, (Arc::new(SpinLock::new(Vec::new())), device)); .insert(name, (Arc::new(SpinLock::new(Vec::new())), device));
} }
pub fn get_device(str: &str) -> Option<Arc<SpinLock<dyn AnyNetworkDevice, PreemptDisabled>>> { pub fn get_device(str: &str) -> Option<Arc<SpinLock<dyn AnyNetworkDevice, LocalIrqDisabled>>> {
let table = COMPONENT let table = COMPONENT.get().unwrap().network_device_table.lock();
.get()
.unwrap()
.network_device_table
.disable_irq()
.lock();
let (_, device) = table.get(str)?; let (_, device) = table.get(str)?;
Some(device.clone()) Some(device.clone())
} }
@ -81,41 +78,26 @@ pub fn get_device(str: &str) -> Option<Arc<SpinLock<dyn AnyNetworkDevice, Preemp
/// Since the callback will be called in interrupt context, /// Since the callback will be called in interrupt context,
/// the callback function should NOT sleep. /// the callback function should NOT sleep.
pub fn register_recv_callback(name: &str, callback: impl NetDeviceIrqHandler) { pub fn register_recv_callback(name: &str, callback: impl NetDeviceIrqHandler) {
let device_table = COMPONENT let device_table = COMPONENT.get().unwrap().network_device_table.lock();
.get()
.unwrap()
.network_device_table
.disable_irq()
.lock();
let Some((callbacks, _)) = device_table.get(name) else { let Some((callbacks, _)) = device_table.get(name) else {
return; return;
}; };
callbacks.disable_irq().lock().push(Arc::new(callback)); callbacks.lock().push(Arc::new(callback));
} }
pub fn handle_recv_irq(name: &str) { pub fn handle_recv_irq(name: &str) {
let device_table = COMPONENT let device_table = COMPONENT.get().unwrap().network_device_table.lock();
.get()
.unwrap()
.network_device_table
.disable_irq()
.lock();
let Some((callbacks, _)) = device_table.get(name) else { let Some((callbacks, _)) = device_table.get(name) else {
return; return;
}; };
let callbacks = callbacks.disable_irq().lock(); let callbacks = callbacks.lock();
for callback in callbacks.iter() { for callback in callbacks.iter() {
callback(); callback();
} }
} }
pub fn all_devices() -> Vec<(String, NetworkDeviceRef)> { pub fn all_devices() -> Vec<(String, NetworkDeviceRef)> {
let network_devs = COMPONENT let network_devs = COMPONENT.get().unwrap().network_device_table.lock();
.get()
.unwrap()
.network_device_table
.disable_irq()
.lock();
network_devs network_devs
.iter() .iter()
.map(|(name, (_, device))| (name.clone(), device.clone())) .map(|(name, (_, device))| (name.clone(), device.clone()))
@ -124,7 +106,7 @@ pub fn all_devices() -> Vec<(String, NetworkDeviceRef)> {
static COMPONENT: Once<Component> = Once::new(); static COMPONENT: Once<Component> = Once::new();
pub(crate) static NETWORK_IRQ_HANDLERS: Once< pub(crate) static NETWORK_IRQ_HANDLERS: Once<
SpinLock<Vec<Arc<dyn NetDeviceIrqHandler>>, PreemptDisabled>, SpinLock<Vec<Arc<dyn NetDeviceIrqHandler>>, LocalIrqDisabled>,
> = Once::new(); > = Once::new();
#[init_component] #[init_component]
@ -136,13 +118,16 @@ fn init() -> Result<(), ComponentInitError> {
Ok(()) Ok(())
} }
type NetDeviceIrqHandlerListRef = Arc<SpinLock<Vec<Arc<dyn NetDeviceIrqHandler>>, PreemptDisabled>>; type NetDeviceIrqHandlerListRef =
type NetworkDeviceRef = Arc<SpinLock<dyn AnyNetworkDevice, PreemptDisabled>>; Arc<SpinLock<Vec<Arc<dyn NetDeviceIrqHandler>>, LocalIrqDisabled>>;
type NetworkDeviceRef = Arc<SpinLock<dyn AnyNetworkDevice, LocalIrqDisabled>>;
struct Component { struct Component {
/// Device list, the key is device name, value is (callbacks, device); /// Device list, the key is device name, value is (callbacks, device);
network_device_table: network_device_table: SpinLock<
SpinLock<BTreeMap<String, (NetDeviceIrqHandlerListRef, NetworkDeviceRef)>, PreemptDisabled>, BTreeMap<String, (NetDeviceIrqHandlerListRef, NetworkDeviceRef)>,
LocalIrqDisabled,
>,
} }
impl Component { impl Component {

View File

@ -25,11 +25,11 @@ use crate::{
}; };
pub struct IfaceCommon<E> { pub struct IfaceCommon<E> {
interface: SpinLock<smoltcp::iface::Interface>, interface: SpinLock<smoltcp::iface::Interface, LocalIrqDisabled>,
sockets: SpinLock<SocketSet<'static>>, sockets: SpinLock<SocketSet<'static>, LocalIrqDisabled>,
used_ports: RwLock<BTreeMap<u16, usize>>, used_ports: RwLock<BTreeMap<u16, usize>>,
bound_sockets: RwLock<BTreeSet<KeyableArc<AnyBoundSocketInner<E>>>>, bound_sockets: RwLock<BTreeSet<KeyableArc<AnyBoundSocketInner<E>>>>,
closing_sockets: SpinLock<BTreeSet<KeyableArc<AnyBoundSocketInner<E>>>>, closing_sockets: SpinLock<BTreeSet<KeyableArc<AnyBoundSocketInner<E>>>, LocalIrqDisabled>,
ext: E, ext: E,
} }
@ -51,7 +51,7 @@ impl<E> IfaceCommon<E> {
/// ///
/// *Lock ordering:* [`Self::sockets`] first, [`Self::interface`] second. /// *Lock ordering:* [`Self::sockets`] first, [`Self::interface`] second.
pub(crate) fn interface(&self) -> SpinLockGuard<smoltcp::iface::Interface, LocalIrqDisabled> { pub(crate) fn interface(&self) -> SpinLockGuard<smoltcp::iface::Interface, LocalIrqDisabled> {
self.interface.disable_irq().lock() self.interface.lock()
} }
/// Acuqires the lock to the sockets. /// Acuqires the lock to the sockets.
@ -60,11 +60,11 @@ impl<E> IfaceCommon<E> {
pub(crate) fn sockets( pub(crate) fn sockets(
&self, &self,
) -> SpinLockGuard<smoltcp::iface::SocketSet<'static>, LocalIrqDisabled> { ) -> SpinLockGuard<smoltcp::iface::SocketSet<'static>, LocalIrqDisabled> {
self.sockets.disable_irq().lock() self.sockets.lock()
} }
pub(super) fn ipv4_addr(&self) -> Option<Ipv4Address> { pub(super) fn ipv4_addr(&self) -> Option<Ipv4Address> {
self.interface.disable_irq().lock().ipv4_addr() self.interface.lock().ipv4_addr()
} }
/// Alloc an unused port range from 49152 ~ 65535 (According to smoltcp docs) /// Alloc an unused port range from 49152 ~ 65535 (According to smoltcp docs)
@ -125,12 +125,12 @@ impl<E> IfaceCommon<E> {
let (handle, socket_family, observer) = match socket.into_raw() { let (handle, socket_family, observer) = match socket.into_raw() {
(AnyRawSocket::Tcp(tcp_socket), observer) => ( (AnyRawSocket::Tcp(tcp_socket), observer) => (
self.sockets.disable_irq().lock().add(tcp_socket), self.sockets.lock().add(tcp_socket),
SocketFamily::Tcp, SocketFamily::Tcp,
observer, observer,
), ),
(AnyRawSocket::Udp(udp_socket), observer) => ( (AnyRawSocket::Udp(udp_socket), observer) => (
self.sockets.disable_irq().lock().add(udp_socket), self.sockets.lock().add(udp_socket),
SocketFamily::Udp, SocketFamily::Udp,
observer, observer,
), ),
@ -143,13 +143,13 @@ impl<E> IfaceCommon<E> {
/// Remove a socket from the interface /// Remove a socket from the interface
pub(crate) fn remove_socket(&self, handle: SocketHandle) { pub(crate) fn remove_socket(&self, handle: SocketHandle) {
self.sockets.disable_irq().lock().remove(handle); self.sockets.lock().remove(handle);
} }
#[must_use] #[must_use]
pub(super) fn poll<D: Device + ?Sized>(&self, device: &mut D) -> Option<u64> { pub(super) fn poll<D: Device + ?Sized>(&self, device: &mut D) -> Option<u64> {
let mut sockets = self.sockets.disable_irq().lock(); let mut sockets = self.sockets.lock();
let mut interface = self.interface.disable_irq().lock(); let mut interface = self.interface.lock();
let timestamp = get_network_timestamp(); let timestamp = get_network_timestamp();
let (has_events, poll_at) = { let (has_events, poll_at) = {
@ -192,7 +192,6 @@ impl<E> IfaceCommon<E> {
let closed_sockets = self let closed_sockets = self
.closing_sockets .closing_sockets
.disable_irq()
.lock() .lock()
.extract_if(|closing_socket| closing_socket.is_closed()) .extract_if(|closing_socket| closing_socket.is_closed())
.collect::<Vec<_>>(); .collect::<Vec<_>>();
@ -235,7 +234,7 @@ impl<E> IfaceCommon<E> {
.remove(&keyable_socket); .remove(&keyable_socket);
assert!(removed); assert!(removed);
let mut closing_sockets = self.closing_sockets.disable_irq().lock(); let mut closing_sockets = self.closing_sockets.lock();
// Check `is_closed` after holding the lock to avoid race conditions. // Check `is_closed` after holding the lock to avoid race conditions.
if keyable_socket.is_closed() { if keyable_socket.is_closed() {

View File

@ -3,7 +3,7 @@
use alloc::{borrow::ToOwned, sync::Arc}; use alloc::{borrow::ToOwned, sync::Arc};
use aster_bigtcp::device::WithDevice; use aster_bigtcp::device::WithDevice;
use ostd::sync::PreemptDisabled; use ostd::sync::LocalIrqDisabled;
use spin::Once; use spin::Once;
use super::{poll_ifaces, Iface}; use super::{poll_ifaces, Iface};
@ -39,9 +39,9 @@ fn new_virtio() -> Arc<Iface> {
let virtio_net = aster_network::get_device(DEVICE_NAME).unwrap(); let virtio_net = aster_network::get_device(DEVICE_NAME).unwrap();
let ether_addr = virtio_net.disable_irq().lock().mac_addr().0; let ether_addr = virtio_net.lock().mac_addr().0;
struct Wrapper(Arc<SpinLock<dyn AnyNetworkDevice, PreemptDisabled>>); struct Wrapper(Arc<SpinLock<dyn AnyNetworkDevice, LocalIrqDisabled>>);
impl WithDevice for Wrapper { impl WithDevice for Wrapper {
type Device = dyn AnyNetworkDevice; type Device = dyn AnyNetworkDevice;
@ -50,7 +50,7 @@ fn new_virtio() -> Arc<Iface> {
where where
F: FnOnce(&mut Self::Device) -> R, F: FnOnce(&mut Self::Device) -> R,
{ {
let mut device = self.0.disable_irq().lock(); let mut device = self.0.lock();
f(&mut *device) f(&mut *device)
} }
} }