From f7932595125a0bba8230b5f8d3b110c687d6f3b2 Mon Sep 17 00:00:00 2001 From: jiangjianfeng Date: Wed, 25 Sep 2024 11:15:58 +0000 Subject: [PATCH] Avoiding busy loop in sending packet and optimize device caps --- kernel/comps/network/src/buffer.rs | 68 ++++------ kernel/comps/network/src/driver.rs | 2 +- kernel/comps/network/src/lib.rs | 72 ++++++++-- .../comps/virtio/src/device/network/config.rs | 2 +- .../comps/virtio/src/device/network/device.rs | 128 +++++++++++++----- .../comps/virtio/src/device/socket/buffer.rs | 7 +- .../comps/virtio/src/device/socket/device.rs | 6 +- kernel/libs/aster-bigtcp/src/device.rs | 4 +- kernel/src/net/iface/init.rs | 6 +- 9 files changed, 198 insertions(+), 97 deletions(-) diff --git a/kernel/comps/network/src/buffer.rs b/kernel/comps/network/src/buffer.rs index 7d26b4e42..b2ebea896 100644 --- a/kernel/comps/network/src/buffer.rs +++ b/kernel/comps/network/src/buffer.rs @@ -1,14 +1,13 @@ // SPDX-License-Identifier: MPL-2.0 -use alloc::{collections::LinkedList, sync::Arc}; +use alloc::{collections::linked_list::LinkedList, sync::Arc}; -use align_ext::AlignExt; use ostd::{ mm::{ Daddr, DmaDirection, DmaStream, FrameAllocOptions, HasDaddr, Infallible, VmReader, VmWriter, PAGE_SIZE, }, - sync::SpinLock, + sync::{LocalIrqDisabled, SpinLock}, Pod, }; use spin::Once; @@ -18,37 +17,40 @@ use crate::dma_pool::{DmaPool, DmaSegment}; pub struct TxBuffer { dma_stream: DmaStream, nbytes: usize, - pool: &'static SpinLock>, + pool: &'static SpinLock, LocalIrqDisabled>, } impl TxBuffer { pub fn new( header: &H, packet: &[u8], - pool: &'static SpinLock>, + pool: &'static SpinLock, LocalIrqDisabled>, ) -> Self { let header = header.as_bytes(); let nbytes = header.len() + packet.len(); - let dma_stream = if let Some(stream) = get_tx_stream_from_pool(nbytes, pool) { + assert!(nbytes <= TX_BUFFER_LEN); + + let dma_stream = if let Some(stream) = pool.lock().pop_front() { stream } else { - let segment = { - let nframes = (nbytes.align_up(PAGE_SIZE)) / PAGE_SIZE; - FrameAllocOptions::new(nframes).alloc_contiguous().unwrap() - }; + let segment = FrameAllocOptions::new(TX_BUFFER_LEN / PAGE_SIZE) + .alloc_contiguous() + .unwrap(); DmaStream::map(segment, DmaDirection::ToDevice, false).unwrap() }; - let mut writer = dma_stream.writer().unwrap(); - writer.write(&mut VmReader::from(header)); - writer.write(&mut VmReader::from(packet)); - - let tx_buffer = Self { - dma_stream, - nbytes, - pool, + let tx_buffer = { + let mut writer = dma_stream.writer().unwrap(); + writer.write(&mut VmReader::from(header)); + writer.write(&mut VmReader::from(packet)); + Self { + dma_stream, + nbytes, + pool, + } }; + tx_buffer.sync(); tx_buffer } @@ -74,10 +76,7 @@ impl HasDaddr for TxBuffer { impl Drop for TxBuffer { fn drop(&mut self) { - self.pool - .disable_irq() - .lock() - .push_back(self.dma_stream.clone()); + self.pool.lock().push_back(self.dma_stream.clone()); } } @@ -139,29 +138,13 @@ impl HasDaddr for RxBuffer { } } -const RX_BUFFER_LEN: usize = 4096; +pub const RX_BUFFER_LEN: usize = 4096; +pub const TX_BUFFER_LEN: usize = 4096; pub static RX_BUFFER_POOL: Once> = Once::new(); -pub static TX_BUFFER_POOL: Once>> = Once::new(); - -fn get_tx_stream_from_pool( - nbytes: usize, - tx_buffer_pool: &'static SpinLock>, -) -> Option { - let mut pool = tx_buffer_pool.disable_irq().lock(); - let mut cursor = pool.cursor_front_mut(); - while let Some(current) = cursor.current() { - if current.nbytes() >= nbytes { - return cursor.remove_current(); - } - cursor.move_next(); - } - - None -} pub fn init() { - const POOL_INIT_SIZE: usize = 32; - const POOL_HIGH_WATERMARK: usize = 64; + const POOL_INIT_SIZE: usize = 64; + const POOL_HIGH_WATERMARK: usize = 128; RX_BUFFER_POOL.call_once(|| { DmaPool::new( RX_BUFFER_LEN, @@ -171,5 +154,4 @@ pub fn init() { false, ) }); - TX_BUFFER_POOL.call_once(|| SpinLock::new(LinkedList::new())); } diff --git a/kernel/comps/network/src/driver.rs b/kernel/comps/network/src/driver.rs index 5dcee6049..e2684b162 100644 --- a/kernel/comps/network/src/driver.rs +++ b/kernel/comps/network/src/driver.rs @@ -12,7 +12,7 @@ impl device::Device for dyn AnyNetworkDevice { type TxToken<'a> = TxToken<'a>; fn receive(&mut self, _timestamp: Instant) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> { - if self.can_receive() { + if self.can_receive() && self.can_send() { let rx_buffer = self.receive().unwrap(); Some((RxToken(rx_buffer), TxToken(self))) } else { diff --git a/kernel/comps/network/src/lib.rs b/kernel/comps/network/src/lib.rs index b4d3855c7..d88ff5798 100644 --- a/kernel/comps/network/src/lib.rs +++ b/kernel/comps/network/src/lib.rs @@ -16,7 +16,7 @@ use alloc::{collections::BTreeMap, string::String, sync::Arc, vec::Vec}; use core::{any::Any, fmt::Debug}; use aster_bigtcp::device::DeviceCapabilities; -pub use buffer::{RxBuffer, TxBuffer, RX_BUFFER_POOL, TX_BUFFER_POOL}; +pub use buffer::{RxBuffer, TxBuffer, RX_BUFFER_POOL, TX_BUFFER_LEN}; use component::{init_component, ComponentInitError}; pub use dma_pool::DmaSegment; use ostd::{ @@ -33,6 +33,7 @@ pub struct EthernetAddr(pub [u8; 6]); pub enum VirtioNetError { NotReady, WrongToken, + Busy, Unknown, } @@ -51,6 +52,7 @@ pub trait AnyNetworkDevice: Send + Sync + Any + Debug { fn receive(&mut self) -> Result; /// Send a packet to network. Return until the request completes. fn send(&mut self, packet: &[u8]) -> Result<(), VirtioNetError>; + fn free_processed_tx_buffers(&mut self); } pub trait NetDeviceIrqHandler = Fn() + Send + Sync + 'static; @@ -64,13 +66,13 @@ pub fn register_device( .unwrap() .network_device_table .lock() - .insert(name, (Arc::new(SpinLock::new(Vec::new())), device)); + .insert(name, NetworkDeviceIrqCallbackSet::new(device)); } pub fn get_device(str: &str) -> Option>> { let table = COMPONENT.get().unwrap().network_device_table.lock(); - let (_, device) = table.get(str)?; - Some(device.clone()) + let callbacks = table.get(str)?; + Some(callbacks.device.clone()) } /// Registers callback which will be called when receiving message. @@ -79,18 +81,48 @@ pub fn get_device(str: &str) -> Option Vec<(String, NetworkDeviceRef)> { let network_devs = COMPONENT.get().unwrap().network_device_table.lock(); network_devs .iter() - .map(|(name, (_, device))| (name.clone(), device.clone())) + .map(|(name, callbacks)| (name.clone(), callbacks.device.clone())) .collect() } @@ -124,10 +156,24 @@ type NetworkDeviceRef = Arc>; struct Component { /// Device list, the key is device name, value is (callbacks, device); - network_device_table: SpinLock< - BTreeMap, - LocalIrqDisabled, - >, + network_device_table: SpinLock, LocalIrqDisabled>, +} + +/// The send callbacks and recv callbacks for a network device +struct NetworkDeviceIrqCallbackSet { + device: NetworkDeviceRef, + recv_callbacks: NetDeviceIrqHandlerListRef, + send_callbacks: NetDeviceIrqHandlerListRef, +} + +impl NetworkDeviceIrqCallbackSet { + fn new(device: NetworkDeviceRef) -> Self { + Self { + device, + recv_callbacks: Arc::new(SpinLock::new(Vec::new())), + send_callbacks: Arc::new(SpinLock::new(Vec::new())), + } + } } impl Component { diff --git a/kernel/comps/virtio/src/device/network/config.rs b/kernel/comps/virtio/src/device/network/config.rs index f29794963..d09d49847 100644 --- a/kernel/comps/virtio/src/device/network/config.rs +++ b/kernel/comps/virtio/src/device/network/config.rs @@ -63,7 +63,7 @@ pub struct VirtioNetConfig { pub mac: EthernetAddr, pub status: Status, max_virtqueue_pairs: u16, - mtu: u16, + pub mtu: u16, speed: u32, duplex: u8, rss_max_key_size: u8, diff --git a/kernel/comps/virtio/src/device/network/device.rs b/kernel/comps/virtio/src/device/network/device.rs index 75f9df9fb..890ce3021 100644 --- a/kernel/comps/virtio/src/device/network/device.rs +++ b/kernel/comps/virtio/src/device/network/device.rs @@ -1,16 +1,21 @@ // SPDX-License-Identifier: MPL-2.0 -use alloc::{boxed::Box, string::ToString, sync::Arc}; -use core::{fmt::Debug, hint::spin_loop, mem::size_of}; +use alloc::{ + boxed::Box, collections::linked_list::LinkedList, string::ToString, sync::Arc, vec::Vec, +}; +use core::{fmt::Debug, mem::size_of}; -use aster_bigtcp::device::{DeviceCapabilities, Medium}; +use aster_bigtcp::device::{Checksum, DeviceCapabilities, Medium}; use aster_network::{ AnyNetworkDevice, EthernetAddr, RxBuffer, TxBuffer, VirtioNetError, RX_BUFFER_POOL, - TX_BUFFER_POOL, }; use aster_util::slot_vec::SlotVec; use log::debug; -use ostd::{sync::SpinLock, trap::TrapFrame}; +use ostd::{ + mm::DmaStream, + sync::{LocalIrqDisabled, SpinLock}, + trap::TrapFrame, +}; use super::{config::VirtioNetConfig, header::VirtioNetHdr}; use crate::{ @@ -21,9 +26,15 @@ use crate::{ pub struct NetworkDevice { config: VirtioNetConfig, + // For smoltcp use + caps: DeviceCapabilities, mac_addr: EthernetAddr, send_queue: VirtQueue, recv_queue: VirtQueue, + // Since the virtio net header remains consistent for each sending packet, + // we store it to avoid recreating the header repeatedly. + header: VirtioNetHdr, + tx_buffers: Vec>, rx_buffers: SlotVec, transport: Box, } @@ -48,11 +59,15 @@ impl NetworkDevice { let config = VirtioNetConfig::read(&virtio_net_config).unwrap(); let mac_addr = config.mac; debug!("mac addr = {:x?}, status = {:?}", mac_addr, config.status); + let caps = init_caps(&features, &config); + + let send_queue = VirtQueue::new(QUEUE_SEND, QUEUE_SIZE, transport.as_mut()) + .expect("create send queue fails"); let mut recv_queue = VirtQueue::new(QUEUE_RECV, QUEUE_SIZE, transport.as_mut()) .expect("creating recv queue fails"); - let send_queue = VirtQueue::new(QUEUE_SEND, QUEUE_SIZE, transport.as_mut()) - .expect("create send queue fails"); + + let tx_buffers = (0..QUEUE_SIZE).map(|_| None).collect(); let mut rx_buffers = SlotVec::new(); for i in 0..QUEUE_SIZE { @@ -68,11 +83,15 @@ impl NetworkDevice { debug!("notify receive queue"); recv_queue.notify(); } + let mut device = Self { config, + caps, mac_addr, send_queue, recv_queue, + header: VirtioNetHdr::default(), + tx_buffers, rx_buffers, transport, }; @@ -82,8 +101,11 @@ impl NetworkDevice { debug!("network device config space change"); } - /// Interrupt handler if network device receives some packet - fn handle_network_event(_: &TrapFrame) { + /// Interrupt handlers if network device receives/sends some packet + fn handle_send_event(_: &TrapFrame) { + aster_network::handle_send_irq(super::DEVICE_NAME); + } + fn handle_recv_event(_: &TrapFrame) { aster_network::handle_recv_irq(super::DEVICE_NAME); } @@ -93,8 +115,13 @@ impl NetworkDevice { .unwrap(); device .transport - .register_queue_callback(QUEUE_RECV, Box::new(handle_network_event), false) + .register_queue_callback(QUEUE_SEND, Box::new(handle_send_event), false) .unwrap(); + device + .transport + .register_queue_callback(QUEUE_RECV, Box::new(handle_recv_event), false) + .unwrap(); + device.transport.finish_init(); aster_network::register_device( @@ -139,29 +166,23 @@ impl NetworkDevice { /// Send a packet to network. Return until the request completes. /// FIEME: Replace tx_buffer with VM segment-based data structure to use dma mapping. fn send(&mut self, packet: &[u8]) -> Result<(), VirtioNetError> { - let header = VirtioNetHdr::default(); - let tx_pool = TX_BUFFER_POOL.get().unwrap(); - let tx_buffer = TxBuffer::new(&header, packet, tx_pool); + if !self.can_send() { + return Err(VirtioNetError::Busy); + } + + let tx_buffer = TxBuffer::new(&self.header, packet, &TX_BUFFER_POOL); let token = self .send_queue .add_dma_buf(&[&tx_buffer], &[]) .map_err(queue_to_network_error)?; - if self.send_queue.should_notify() { self.send_queue.notify(); } - // Wait until the buffer is used - while !self.send_queue.can_pop() { - spin_loop(); - } - // Pop out the buffer, so we can reuse the send queue further - let (pop_token, _) = self.send_queue.pop_used().map_err(queue_to_network_error)?; - debug_assert!(pop_token == token); - if pop_token != token { - return Err(VirtioNetError::WrongToken); - } - debug!("send packet succeeds"); + + debug_assert!(self.tx_buffers[token as usize].is_none()); + self.tx_buffers[token as usize] = Some(tx_buffer); + Ok(()) } } @@ -170,21 +191,57 @@ fn queue_to_network_error(err: QueueError) -> VirtioNetError { match err { QueueError::NotReady => VirtioNetError::NotReady, QueueError::WrongToken => VirtioNetError::WrongToken, + QueueError::BufferTooSmall => VirtioNetError::Busy, _ => VirtioNetError::Unknown, } } +fn init_caps(features: &NetworkFeatures, config: &VirtioNetConfig) -> DeviceCapabilities { + let mut caps = DeviceCapabilities::default(); + + caps.max_burst_size = None; + caps.medium = Medium::Ethernet; + + if features.contains(NetworkFeatures::VIRTIO_NET_F_MTU) { + // If `VIRTIO_NET_F_MTU` is negotiated, the MTU is decided by the device. + caps.max_transmission_unit = config.mtu as usize; + } else { + // We do not support these features, + // so this asserts that they are _not_ negotiated. + // + // Without these features, the MTU is 1514 bytes per the virtio-net specification + // (see "5.1.6.3 Setting Up Receive Buffers" and "5.1.6.2 Packet Transmission"). + assert!( + !features.contains(NetworkFeatures::VIRTIO_NET_F_GUEST_TSO4) + && !features.contains(NetworkFeatures::VIRTIO_NET_F_GUEST_TSO6) + && !features.contains(NetworkFeatures::VIRTIO_NET_F_GUEST_UFO) + ); + caps.max_transmission_unit = 1514; + } + + // We do not support checksum offloading. + // So the features must not be negotiated, + // and we must deliver fully checksummed packets to the device + // and validate all checksums for packets from the device. + assert!( + !features.contains(NetworkFeatures::VIRTIO_NET_F_CSUM) + && !features.contains(NetworkFeatures::VIRTIO_NET_F_GUEST_CSUM) + ); + caps.checksum.tcp = Checksum::Both; + caps.checksum.udp = Checksum::Both; + caps.checksum.ipv4 = Checksum::Both; + caps.checksum.icmpv4 = Checksum::Both; + + caps +} + impl AnyNetworkDevice for NetworkDevice { fn mac_addr(&self) -> EthernetAddr { self.mac_addr } fn capabilities(&self) -> DeviceCapabilities { - let mut caps = DeviceCapabilities::default(); - caps.max_transmission_unit = 1536; - caps.max_burst_size = Some(1); - caps.medium = Medium::Ethernet; - caps + self.caps.clone() } fn can_receive(&self) -> bool { @@ -192,7 +249,7 @@ impl AnyNetworkDevice for NetworkDevice { } fn can_send(&self) -> bool { - self.send_queue.available_desc() >= 2 + self.send_queue.available_desc() >= 1 } fn receive(&mut self) -> Result { @@ -202,6 +259,12 @@ impl AnyNetworkDevice for NetworkDevice { fn send(&mut self, packet: &[u8]) -> Result<(), VirtioNetError> { self.send(packet) } + + fn free_processed_tx_buffers(&mut self) { + while let Ok((token, _)) = self.send_queue.pop_used() { + self.tx_buffers[token as usize] = None; + } + } } impl Debug for NetworkDevice { @@ -216,6 +279,9 @@ impl Debug for NetworkDevice { } } +static TX_BUFFER_POOL: SpinLock, LocalIrqDisabled> = + SpinLock::new(LinkedList::new()); + const QUEUE_RECV: u16 = 0; const QUEUE_SEND: u16 = 1; diff --git a/kernel/comps/virtio/src/device/socket/buffer.rs b/kernel/comps/virtio/src/device/socket/buffer.rs index 7463e821e..c52645816 100644 --- a/kernel/comps/virtio/src/device/socket/buffer.rs +++ b/kernel/comps/virtio/src/device/socket/buffer.rs @@ -1,17 +1,18 @@ // SPDX-License-Identifier: MPL-2.0 -use alloc::{collections::LinkedList, sync::Arc}; +use alloc::{collections::linked_list::LinkedList, sync::Arc}; use aster_network::dma_pool::DmaPool; use ostd::{ mm::{DmaDirection, DmaStream}, - sync::SpinLock, + sync::{LocalIrqDisabled, SpinLock}, }; use spin::Once; const RX_BUFFER_LEN: usize = 4096; +const TX_BUFFER_LEN: usize = 4096; pub static RX_BUFFER_POOL: Once> = Once::new(); -pub static TX_BUFFER_POOL: Once>> = Once::new(); +pub static TX_BUFFER_POOL: Once, LocalIrqDisabled>> = Once::new(); pub fn init() { const POOL_INIT_SIZE: usize = 32; diff --git a/kernel/comps/virtio/src/device/socket/device.rs b/kernel/comps/virtio/src/device/socket/device.rs index 9f11412c3..dab881ec7 100644 --- a/kernel/comps/virtio/src/device/socket/device.rs +++ b/kernel/comps/virtio/src/device/socket/device.rs @@ -190,8 +190,10 @@ impl SocketDevice { ) -> Result<(), SocketError> { debug!("Sent packet {:?}. Op {:?}", header, header.op()); debug!("buffer in send_packet_to_tx_queue: {:?}", buffer); - let tx_pool = TX_BUFFER_POOL.get().unwrap(); - let tx_buffer = TxBuffer::new(header, buffer, tx_pool); + let tx_buffer = { + let pool = TX_BUFFER_POOL.get().unwrap(); + TxBuffer::new(header, buffer, pool) + }; let token = self.send_queue.add_dma_buf(&[&tx_buffer], &[])?; diff --git a/kernel/libs/aster-bigtcp/src/device.rs b/kernel/libs/aster-bigtcp/src/device.rs index f81f5a9eb..5638d1e44 100644 --- a/kernel/libs/aster-bigtcp/src/device.rs +++ b/kernel/libs/aster-bigtcp/src/device.rs @@ -1,6 +1,8 @@ // SPDX-License-Identifier: MPL-2.0 -pub use smoltcp::phy::{Device, DeviceCapabilities, Loopback, Medium, RxToken, TxToken}; +pub use smoltcp::phy::{ + Checksum, ChecksumCapabilities, Device, DeviceCapabilities, Loopback, Medium, RxToken, TxToken, +}; /// A trait that allows to obtain a mutable reference of [`Device`]. /// diff --git a/kernel/src/net/iface/init.rs b/kernel/src/net/iface/init.rs index 246aba5b5..e3bb2f987 100644 --- a/kernel/src/net/iface/init.rs +++ b/kernel/src/net/iface/init.rs @@ -22,11 +22,13 @@ pub fn init() { }); for (name, _) in aster_network::all_devices() { - aster_network::register_recv_callback(&name, || { + let callback = || { // TODO: further check that the irq num is the same as iface's irq num let iface_virtio = &IFACES.get().unwrap()[0]; iface_virtio.poll(); - }) + }; + aster_network::register_recv_callback(&name, callback); + aster_network::register_send_callback(&name, callback); } poll_ifaces();