diff --git a/Cargo.lock b/Cargo.lock
index 908efea7c..c4c0c7b23 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -176,9 +176,10 @@ dependencies = [
  "aster-rights",
  "aster-util",
  "bitflags 1.3.2",
- "bytes",
+ "bitvec",
  "component",
  "int-to-c-enum",
+ "ktest",
  "log",
  "pod",
  "ringbuf",
diff --git a/framework/aster-frame/src/task/task.rs b/framework/aster-frame/src/task/task.rs
index 8e7891068..90ceef391 100644
--- a/framework/aster-frame/src/task/task.rs
+++ b/framework/aster-frame/src/task/task.rs
@@ -51,9 +51,7 @@ pub struct KernelStack {
 impl KernelStack {
     pub fn new() -> Result<Self> {
         Ok(Self {
-            segment: VmAllocOptions::new(KERNEL_STACK_SIZE / PAGE_SIZE)
-                .is_contiguous(true)
-                .alloc_contiguous()?,
+            segment: VmAllocOptions::new(KERNEL_STACK_SIZE / PAGE_SIZE).alloc_contiguous()?,
             old_guard_page_flag: None,
         })
     }
@@ -61,9 +59,8 @@ impl KernelStack {
     /// Generate a kernel stack with a guard page.
     /// An additional page is allocated and regarded as a guard page, which should not be accessed.
     pub fn new_with_guard_page() -> Result<Self> {
-        let stack_segment = VmAllocOptions::new(KERNEL_STACK_SIZE / PAGE_SIZE + 1)
-            .is_contiguous(true)
-            .alloc_contiguous()?;
+        let stack_segment =
+            VmAllocOptions::new(KERNEL_STACK_SIZE / PAGE_SIZE + 1).alloc_contiguous()?;
         let unpresent_flag = PageTableFlags::empty();
         let old_guard_page_flag = Self::protect_guard_page(&stack_segment, unpresent_flag);
         Ok(Self {
diff --git a/framework/aster-frame/src/vm/options.rs b/framework/aster-frame/src/vm/options.rs
index 40d6325f6..898d58040 100644
--- a/framework/aster-frame/src/vm/options.rs
+++ b/framework/aster-frame/src/vm/options.rs
@@ -84,7 +84,8 @@ impl VmAllocOptions {
     ///
     /// The returned `VmSegment` contains at least one page frame.
     pub fn alloc_contiguous(&self) -> Result<VmSegment> {
-        if !self.is_contiguous || self.nframes == 0 {
+        // There is no need to check `self.is_contiguous` here.
+        if self.nframes == 0 {
             return Err(Error::InvalidArgs);
         }
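With `alloc_contiguous` now implying contiguity, the `.is_contiguous(true)` toggle above becomes dead weight at every call site. A minimal sketch of the resulting call shape; the `aster_frame::Result` alias and the `KERNEL_STACK_SIZE` value are assumptions for illustration:

```rust
use aster_frame::vm::{VmAllocOptions, VmSegment, PAGE_SIZE};

// Assumed value; the real constant lives in `task.rs`.
const KERNEL_STACK_SIZE: usize = 16 * PAGE_SIZE;

fn alloc_stack_segment() -> aster_frame::Result<VmSegment> {
    // Contiguity is implied by `alloc_contiguous` itself, so the former
    // `.is_contiguous(true)` builder call is no longer needed.
    VmAllocOptions::new(KERNEL_STACK_SIZE / PAGE_SIZE).alloc_contiguous()
}
```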
diff --git a/kernel/aster-nix/src/device/tty/driver.rs b/kernel/aster-nix/src/device/tty/driver.rs
index 951dd1084..d4ececf35 100644
--- a/kernel/aster-nix/src/device/tty/driver.rs
+++ b/kernel/aster-nix/src/device/tty/driver.rs
@@ -1,6 +1,7 @@
 // SPDX-License-Identifier: MPL-2.0
 
-pub use aster_frame::arch::console::register_console_input_callback;
+pub use aster_frame::arch::console;
+use aster_frame::vm::VmReader;
 use spin::Once;
 
 use crate::{
@@ -62,24 +63,25 @@ impl TtyDriver {
         Ok(())
     }
 
-    pub fn receive_char(&self, item: u8) {
+    pub fn push_char(&self, ch: u8) {
         // FIXME: should the char be sent to all ttys?
         for tty in &*self.ttys.lock_irq_disabled() {
-            tty.receive_char(item);
+            tty.push_char(ch);
         }
     }
 }
 
-fn console_input_callback(items: &[u8]) {
+fn console_input_callback(mut reader: VmReader) {
     let tty_driver = get_tty_driver();
-    for item in items {
-        tty_driver.receive_char(*item);
+    while reader.remain() > 0 {
+        let ch = reader.read_val();
+        tty_driver.push_char(ch);
     }
 }
 
 fn serial_input_callback(item: u8) {
     let tty_driver = get_tty_driver();
-    tty_driver.receive_char(item);
+    tty_driver.push_char(item);
 }
 
 fn get_tty_driver() -> &'static TtyDriver {
diff --git a/kernel/aster-nix/src/device/tty/mod.rs b/kernel/aster-nix/src/device/tty/mod.rs
index 9971ff888..eda0c574d 100644
--- a/kernel/aster-nix/src/device/tty/mod.rs
+++ b/kernel/aster-nix/src/device/tty/mod.rs
@@ -1,5 +1,6 @@
 // SPDX-License-Identifier: MPL-2.0
 
+use aster_frame::early_print;
 use spin::Once;
 
 use self::{driver::TtyDriver, line_discipline::LineDiscipline};
@@ -61,8 +62,11 @@ impl Tty {
         *self.driver.lock_irq_disabled() = driver;
     }
 
-    pub fn receive_char(&self, ch: u8) {
-        self.ldisc.push_char(ch, |content| print!("{}", content));
+    pub fn push_char(&self, ch: u8) {
+        // FIXME: Use `early_print` to avoid calling virtio-console.
+        // This is only a workaround.
+        self.ldisc
+            .push_char(ch, |content| early_print!("{}", content))
     }
 }
diff --git a/kernel/aster-nix/src/fs/ext2/fs.rs b/kernel/aster-nix/src/fs/ext2/fs.rs
index e93aaeb8a..f34a9d598 100644
--- a/kernel/aster-nix/src/fs/ext2/fs.rs
+++ b/kernel/aster-nix/src/fs/ext2/fs.rs
@@ -42,7 +42,6 @@ impl Ext2 {
             .div_ceil(BLOCK_SIZE);
         let segment = VmAllocOptions::new(npages)
             .uninit(true)
-            .is_contiguous(true)
             .alloc_contiguous()?;
         match block_device.read_blocks_sync(super_block.group_descriptors_bid(0), &segment)? {
             BioStatus::Complete => (),
diff --git a/kernel/aster-nix/src/lib.rs b/kernel/aster-nix/src/lib.rs
index 85fe56061..8fb1cee0b 100644
--- a/kernel/aster-nix/src/lib.rs
+++ b/kernel/aster-nix/src/lib.rs
@@ -80,6 +80,9 @@ fn init_thread() {
         "[kernel] Spawn init thread, tid = {}",
         current_thread!().tid()
     );
+    // The work queue should be initialized before interrupts are enabled,
+    // in case any IRQ handler uses the work queue as its bottom half.
+    thread::work_queue::init();
     // FIXME: Remove this if we move the step of mounting
     // the filesystems to be done within the init process.
     aster_frame::trap::enable_local();
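Moving `thread::work_queue::init()` ahead of `enable_local()` closes a startup race: once interrupts are on, any IRQ may try to schedule bottom-half work immediately. A self-contained toy model of the hazard (plain std Rust; all names are illustrative, not Asterinas APIs):

```rust
use std::sync::Mutex;

// Global work queue, absent until initialized.
static WORK_QUEUE: Mutex<Option<Vec<fn()>>> = Mutex::new(None);

fn work_queue_init() {
    *WORK_QUEUE.lock().unwrap() = Some(Vec::new());
}

// Stand-in for an IRQ handler scheduling its bottom half.
fn irq_bottom_half(work: fn()) {
    WORK_QUEUE
        .lock()
        .unwrap()
        .as_mut()
        .expect("IRQ fired before the work queue existed") // the race this patch fixes
        .push(work);
}

fn main() {
    work_queue_init(); // must run first...
    // enable_local();  ...only then may interrupts be enabled
    irq_bottom_half(|| println!("deferred work"));
}
```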
@@ -97,7 +100,6 @@ fn init_thread() {
         "[aster-nix/lib.rs] spawn kernel thread, tid = {}",
         thread.tid()
     );
-    thread::work_queue::init();
 
     print_banner();
diff --git a/kernel/comps/block/src/impl_block_device.rs b/kernel/comps/block/src/impl_block_device.rs
index 6f9477168..1f165b78a 100644
--- a/kernel/comps/block/src/impl_block_device.rs
+++ b/kernel/comps/block/src/impl_block_device.rs
@@ -99,7 +99,6 @@ impl VmIo for dyn BlockDevice {
         };
         let segment = VmAllocOptions::new(num_blocks as usize)
             .uninit(true)
-            .is_contiguous(true)
             .alloc_contiguous()?;
         let bio_segment = BioSegment::from_segment(segment, offset % BLOCK_SIZE, buf.len());
 
@@ -141,7 +140,6 @@ impl VmIo for dyn BlockDevice {
         };
         let segment = VmAllocOptions::new(num_blocks as usize)
             .uninit(true)
-            .is_contiguous(true)
             .alloc_contiguous()?;
         segment.write_bytes(offset % BLOCK_SIZE, buf)?;
         let len = segment
@@ -183,7 +181,6 @@ impl dyn BlockDevice {
         };
         let segment = VmAllocOptions::new(num_blocks as usize)
             .uninit(true)
-            .is_contiguous(true)
             .alloc_contiguous()?;
         segment.write_bytes(offset % BLOCK_SIZE, buf)?;
         let len = segment
diff --git a/kernel/comps/console/src/lib.rs b/kernel/comps/console/src/lib.rs
index b81e817dd..697ea8e3e 100644
--- a/kernel/comps/console/src/lib.rs
+++ b/kernel/comps/console/src/lib.rs
@@ -10,17 +10,20 @@ extern crate alloc;
 use alloc::{collections::BTreeMap, fmt::Debug, string::String, sync::Arc, vec::Vec};
 use core::any::Any;
 
-use aster_frame::sync::SpinLock;
+use aster_frame::{sync::SpinLock, vm::VmReader};
 use component::{init_component, ComponentInitError};
 use spin::Once;
 
-pub type ConsoleCallback = dyn Fn(&[u8]) + Send + Sync;
+pub type ConsoleCallback = dyn Fn(VmReader) + Send + Sync;
 
 pub trait AnyConsoleDevice: Send + Sync + Any + Debug {
     fn send(&self, buf: &[u8]);
-    fn recv(&self, buf: &mut [u8]) -> Option<usize>;
+    /// Registers a callback to the console device.
+    /// The callback will be called once the console device receives data.
+    ///
+    /// Since the callback will be called in interrupt context,
+    /// the callback should NEVER sleep.
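For illustration, a callback conforming to this contract could look like the sketch below: it drains the `VmReader` with the same `remain`/`read_val` calls the tty driver uses elsewhere in this patch, then defers any heavy lifting. The function name is hypothetical:

```rust
use aster_frame::vm::VmReader;

fn on_console_data(mut reader: VmReader) {
    // Interrupt context: copy the bytes out quickly, without blocking.
    let mut bytes = alloc::vec::Vec::with_capacity(reader.remain());
    while reader.remain() > 0 {
        let b: u8 = reader.read_val();
        bytes.push(b);
    }
    // Hand `bytes` to a bottom half (e.g. a work queue) here;
    // never sleep or take blocking locks in this function.
}
```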
fn register_callback(&self, callback: &'static ConsoleCallback); - fn handle_irq(&self); } pub fn register_device(name: String, device: Arc) { @@ -32,16 +35,6 @@ pub fn register_device(name: String, device: Arc) { .insert(name, device); } -pub fn get_device(str: &str) -> Option> { - COMPONENT - .get() - .unwrap() - .console_device_table - .lock_irq_disabled() - .get(str) - .cloned() -} - pub fn all_devices() -> Vec<(String, Arc)> { let console_devs = COMPONENT .get() diff --git a/kernel/comps/input/src/lib.rs b/kernel/comps/input/src/lib.rs index 2bfd95cc7..c6e4d4b51 100644 --- a/kernel/comps/input/src/lib.rs +++ b/kernel/comps/input/src/lib.rs @@ -23,7 +23,6 @@ pub enum InputEvent { } pub trait InputDevice: Send + Sync + Any + Debug { - fn handle_irq(&self) -> Option<()>; fn register_callbacks(&self, function: &'static (dyn Fn(InputEvent) + Send + Sync)); } diff --git a/kernel/comps/network/Cargo.toml b/kernel/comps/network/Cargo.toml index 60562e894..9825080cf 100644 --- a/kernel/comps/network/Cargo.toml +++ b/kernel/comps/network/Cargo.toml @@ -6,16 +6,17 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -component = { path = "../../libs/comp-sys/component" } +align_ext = { path = "../../../framework/libs/align_ext" } aster-frame = { path = "../../../framework/aster-frame" } aster-util = { path = "../../libs/aster-util" } aster-rights = { path = "../../libs/aster-rights" } -align_ext = { path = "../../../framework/libs/align_ext" } -int-to-c-enum = { path = "../../libs/int-to-c-enum" } -bytes = { version = "1.4.0", default-features = false } -pod = { git = "https://github.com/asterinas/pod", rev = "d7dba56" } bitflags = "1.3" -spin = "0.9.4" -ringbuf = { version = "0.3.2", default-features = false, features = ["alloc"] } +bitvec = { version = "1.0.1", default-features = false, features = ["alloc"]} +component = { path = "../../libs/comp-sys/component" } +int-to-c-enum = { path = "../../libs/int-to-c-enum" } +ktest = { path = "../../../framework/libs/ktest" } log = "0.4" -smoltcp = { version = "0.9.1", default-features = false, features = ["alloc", "log", "medium-ethernet", "medium-ip", "proto-dhcpv4", "proto-ipv4", "proto-igmp", "socket-icmp", "socket-udp", "socket-tcp", "socket-raw", "socket-dhcpv4"] } \ No newline at end of file +pod = { git = "https://github.com/asterinas/pod", rev = "d7dba56" } +ringbuf = { version = "0.3.2", default-features = false, features = ["alloc"] } +smoltcp = { version = "0.9.1", default-features = false, features = ["alloc", "log", "medium-ethernet", "medium-ip", "proto-dhcpv4", "proto-ipv4", "proto-igmp", "socket-icmp", "socket-udp", "socket-tcp", "socket-raw", "socket-dhcpv4"] } +spin = "0.9.4" \ No newline at end of file diff --git a/kernel/comps/network/src/buffer.rs b/kernel/comps/network/src/buffer.rs index e0322c273..a2ecaad92 100644 --- a/kernel/comps/network/src/buffer.rs +++ b/kernel/comps/network/src/buffer.rs @@ -1,30 +1,89 @@ // SPDX-License-Identifier: MPL-2.0 -use core::mem::size_of; +use alloc::{collections::LinkedList, sync::Arc}; use align_ext::AlignExt; -use bytes::BytesMut; +use aster_frame::{ + sync::SpinLock, + vm::{Daddr, DmaDirection, DmaStream, HasDaddr, VmAllocOptions, VmReader, VmWriter, PAGE_SIZE}, +}; use pod::Pod; +use spin::Once; + +use crate::dma_pool::{DmaPool, DmaSegment}; + +pub struct TxBuffer { + dma_stream: DmaStream, + nbytes: usize, +} + +impl TxBuffer { + pub fn new(header: &H, packet: &[u8]) -> Self { + let header = header.as_bytes(); 
+ let nbytes = header.len() + packet.len(); + + let dma_stream = if let Some(stream) = get_tx_stream_from_pool(nbytes) { + stream + } else { + let segment = { + let nframes = (nbytes.align_up(PAGE_SIZE)) / PAGE_SIZE; + VmAllocOptions::new(nframes).alloc_contiguous().unwrap() + }; + DmaStream::map(segment, DmaDirection::ToDevice, false).unwrap() + }; + + let mut writer = dma_stream.writer().unwrap(); + writer.write(&mut VmReader::from(header)); + writer.write(&mut VmReader::from(packet)); + + let tx_buffer = Self { dma_stream, nbytes }; + tx_buffer.sync(); + tx_buffer + } + + pub fn writer(&self) -> VmWriter<'_> { + self.dma_stream.writer().unwrap().limit(self.nbytes) + } + + fn sync(&self) { + self.dma_stream.sync(0..self.nbytes).unwrap(); + } + + pub fn nbytes(&self) -> usize { + self.nbytes + } +} + +impl HasDaddr for TxBuffer { + fn daddr(&self) -> Daddr { + self.dma_stream.daddr() + } +} + +impl Drop for TxBuffer { + fn drop(&mut self) { + TX_BUFFER_POOL + .get() + .unwrap() + .lock_irq_disabled() + .push_back(self.dma_stream.clone()); + } +} -/// Buffer for receive packet -#[derive(Debug)] pub struct RxBuffer { - /// Packet Buffer, length align 8. - buf: BytesMut, - /// Header len + segment: DmaSegment, header_len: usize, - /// Packet len packet_len: usize, } impl RxBuffer { - pub fn new(len: usize, header_len: usize) -> Self { - let len = len.align_up(8); - let buf = BytesMut::zeroed(len); + pub fn new(header_len: usize) -> Self { + assert!(header_len <= RX_BUFFER_LEN); + let segment = RX_BUFFER_POOL.get().unwrap().alloc_segment().unwrap(); Self { - buf, - packet_len: 0, + segment, header_len, + packet_len: 0, } } @@ -33,59 +92,60 @@ impl RxBuffer { } pub fn set_packet_len(&mut self, packet_len: usize) { + assert!(self.header_len + packet_len <= RX_BUFFER_LEN); self.packet_len = packet_len; } - pub fn buf(&self) -> &[u8] { - &self.buf + pub fn packet(&self) -> VmReader<'_> { + self.segment + .sync(self.header_len..self.header_len + self.packet_len) + .unwrap(); + self.segment + .reader() + .unwrap() + .skip(self.header_len) + .limit(self.packet_len) } - pub fn buf_mut(&mut self) -> &mut [u8] { - &mut self.buf - } - - /// Packet payload slice, which is inner buffer excluding VirtioNetHdr. - pub fn packet(&self) -> &[u8] { - debug_assert!(self.header_len + self.packet_len <= self.buf.len()); - &self.buf[self.header_len..self.header_len + self.packet_len] - } - - /// Mutable packet payload slice. 
- pub fn packet_mut(&mut self) -> &mut [u8] { - debug_assert!(self.header_len + self.packet_len <= self.buf.len()); - &mut self.buf[self.header_len..self.header_len + self.packet_len] - } - - pub fn header(&self) -> H { - debug_assert_eq!(size_of::(), self.header_len); - H::from_bytes(&self.buf[..size_of::()]) + pub const fn buf_len(&self) -> usize { + self.segment.size() } } -/// Buffer for transmit packet -#[derive(Debug)] -pub struct TxBuffer { - buf: BytesMut, +impl HasDaddr for RxBuffer { + fn daddr(&self) -> Daddr { + self.segment.daddr() + } } -impl TxBuffer { - pub fn with_len(buf_len: usize) -> Self { - Self { - buf: BytesMut::zeroed(buf_len), +const RX_BUFFER_LEN: usize = 4096; +static RX_BUFFER_POOL: Once> = Once::new(); +static TX_BUFFER_POOL: Once>> = Once::new(); + +fn get_tx_stream_from_pool(nbytes: usize) -> Option { + let mut pool = TX_BUFFER_POOL.get().unwrap().lock_irq_disabled(); + let mut cursor = pool.cursor_front_mut(); + while let Some(current) = cursor.current() { + if current.nbytes() >= nbytes { + return cursor.remove_current(); } + cursor.move_next(); } - pub fn new(buf: &[u8]) -> Self { - Self { - buf: BytesMut::from(buf), - } - } - - pub fn buf(&self) -> &[u8] { - &self.buf - } - - pub fn buf_mut(&mut self) -> &mut [u8] { - &mut self.buf - } + None +} + +pub fn init() { + const POOL_INIT_SIZE: usize = 32; + const POOL_HIGH_WATERMARK: usize = 64; + RX_BUFFER_POOL.call_once(|| { + DmaPool::new( + RX_BUFFER_LEN, + POOL_INIT_SIZE, + POOL_HIGH_WATERMARK, + DmaDirection::FromDevice, + false, + ) + }); + TX_BUFFER_POOL.call_once(|| SpinLock::new(LinkedList::new())); } diff --git a/kernel/comps/network/src/dma_pool.rs b/kernel/comps/network/src/dma_pool.rs new file mode 100644 index 000000000..de6c64862 --- /dev/null +++ b/kernel/comps/network/src/dma_pool.rs @@ -0,0 +1,367 @@ +// SPDX-License-Identifier: MPL-2.0 + +#![allow(unused)] + +use alloc::{ + collections::VecDeque, + sync::{Arc, Weak}, +}; +use core::ops::Range; + +use aster_frame::{ + sync::{RwLock, SpinLock}, + vm::{Daddr, DmaDirection, DmaStream, HasDaddr, VmAllocOptions, VmReader, VmWriter, PAGE_SIZE}, +}; +use bitvec::{array::BitArray, prelude::Lsb0}; +use ktest::ktest; + +/// `DmaPool` is responsible for allocating small streaming DMA segments +/// (equal to or smaller than PAGE_SIZE), +/// referred to as `DmaSegment`. +/// +/// A `DmaPool` can only allocate `DmaSegment` of a fixed size. +/// Once a `DmaSegment` is dropped, it will be returned to the pool. +/// If the `DmaPool` is dropped before the associated `DmaSegment`, +/// the `drop` method of the `DmaSegment` will panic. +/// +/// Therefore, as a best practice, +/// it is recommended for the `DmaPool` to have a static lifetime. +#[derive(Debug)] +pub struct DmaPool { + segment_size: usize, + direction: DmaDirection, + is_cache_coherent: bool, + high_watermark: usize, + avail_pages: SpinLock>>, + all_pages: SpinLock>>, +} + +impl DmaPool { + /// Constructs a new `DmaPool` with a specified initial capacity and a high watermark. + /// + /// The `DmaPool` starts with `init_size` DMAable pages. + /// As additional DMA blocks are requested beyond the initial capacity, + /// the pool dynamically allocates more DMAable pages. + /// To optimize performance, the pool employs a lazy deallocation strategy: + /// A DMAable page is freed only if it meets the following conditions: + /// 1. The page is currently not in use; + /// 2. The total number of allocated DMAable pages exceeds the specified `high_watermark`. 
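A self-contained model of the deallocation policy just described (plain std Rust with illustrative names; the real pool below tracks pages with `VecDeque`s and a bit array): a fully free page is only released when the page count is above `high_watermark`.

```rust
struct PoolModel {
    high_watermark: usize,
    pages: Vec<usize>, // per-page count of segments currently in use
}

impl PoolModel {
    fn on_segment_freed(&mut self, page_idx: usize) {
        self.pages[page_idx] -= 1;
        // Condition 1: the page is no longer in use.
        let page_is_free = self.pages[page_idx] == 0;
        // Condition 2: more pages are allocated than the watermark allows.
        if page_is_free && self.pages.len() > self.high_watermark {
            self.pages.swap_remove(page_idx);
        }
    }
}

fn main() {
    let mut pool = PoolModel { high_watermark: 1, pages: vec![1, 1] };
    pool.on_segment_freed(1); // 2 pages > watermark of 1: page released
    assert_eq!(pool.pages.len(), 1);
    pool.on_segment_freed(0); // at the watermark: the free page is kept cached
    assert_eq!(pool.pages.len(), 1);
    println!("watermark policy ok");
}
```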
+    ///
+    /// The returned pool can be used to allocate small segments for DMA usage.
+    /// All allocated segments will have the same DMA direction
+    /// and will either all be cache coherent or not cache coherent,
+    /// as specified in the parameters.
+    pub fn new(
+        segment_size: usize,
+        init_size: usize,
+        high_watermark: usize,
+        direction: DmaDirection,
+        is_cache_coherent: bool,
+    ) -> Arc<Self> {
+        assert!(segment_size.is_power_of_two());
+        assert!(segment_size >= 64);
+        assert!(segment_size <= PAGE_SIZE);
+        assert!(high_watermark >= init_size);
+
+        Arc::new_cyclic(|pool| {
+            let mut avail_pages = VecDeque::new();
+            let mut all_pages = VecDeque::new();
+
+            for _ in 0..init_size {
+                let page = Arc::new(
+                    DmaPage::new(
+                        segment_size,
+                        direction,
+                        is_cache_coherent,
+                        Weak::clone(pool),
+                    )
+                    .unwrap(),
+                );
+                avail_pages.push_back(page.clone());
+                all_pages.push_back(page);
+            }
+
+            Self {
+                segment_size,
+                direction,
+                is_cache_coherent,
+                high_watermark,
+                avail_pages: SpinLock::new(avail_pages),
+                all_pages: SpinLock::new(all_pages),
+            }
+        })
+    }
+
+    /// Allocates a `DmaSegment` from the pool.
+    pub fn alloc_segment(self: &Arc<Self>) -> Result<DmaSegment, aster_frame::Error> {
+        // Lock order: pool.avail_pages -> pool.all_pages
+        //             pool.avail_pages -> page.allocated_segments
+        let mut avail_pages = self.avail_pages.lock_irq_disabled();
+        if avail_pages.is_empty() {
+            // Allocate a new page.
+            let new_page = {
+                let pool = Arc::downgrade(self);
+                Arc::new(DmaPage::new(
+                    self.segment_size,
+                    self.direction,
+                    self.is_cache_coherent,
+                    pool,
+                )?)
+            };
+            let mut all_pages = self.all_pages.lock_irq_disabled();
+            avail_pages.push_back(new_page.clone());
+            all_pages.push_back(new_page);
+        }
+
+        let first_avail_page = avail_pages.front().unwrap();
+        let free_segment = first_avail_page.alloc_segment().unwrap();
+        if first_avail_page.is_full() {
+            avail_pages.pop_front();
+        }
+        Ok(free_segment)
+    }
+
+    /// Returns the number of pages in the pool.
+    fn num_pages(&self) -> usize {
+        self.all_pages.lock_irq_disabled().len()
+    }
+}
+
+#[derive(Debug)]
+struct DmaPage {
+    storage: DmaStream,
+    segment_size: usize,
+    // `BitArray` is 64 bits. Since each `DmaSegment` is at least 64 bytes,
+    // there are no more than `PAGE_SIZE` / 64 = 64 `DmaSegment`s in a `DmaPage`.
+    allocated_segments: SpinLock<BitArray>,
+    pool: Weak<DmaPool>,
+}
+
+impl DmaPage {
+    fn new(
+        segment_size: usize,
+        direction: DmaDirection,
+        is_cache_coherent: bool,
+        pool: Weak<DmaPool>,
+    ) -> Result<Self, aster_frame::Error> {
+        let dma_stream = {
+            let vm_segment = VmAllocOptions::new(1).alloc_contiguous()?;
+
+            DmaStream::map(vm_segment, direction, is_cache_coherent)
+                .map_err(|_| aster_frame::Error::AccessDenied)?
+ }; + + Ok(Self { + storage: dma_stream, + segment_size, + allocated_segments: SpinLock::new(BitArray::ZERO), + pool, + }) + } + + fn alloc_segment(self: &Arc) -> Option { + let mut segments = self.allocated_segments.lock_irq_disabled(); + let free_segment_index = get_next_free_index(&segments, self.nr_blocks_per_page())?; + segments.set(free_segment_index, true); + + let segment = DmaSegment { + size: self.segment_size, + dma_stream: self.storage.clone(), + start_addr: self.storage.daddr() + free_segment_index * self.segment_size, + page: Arc::downgrade(self), + }; + + Some(segment) + } + + fn is_free(&self) -> bool { + *self.allocated_segments.lock() == BitArray::<[usize; 1], Lsb0>::ZERO + } + + const fn nr_blocks_per_page(&self) -> usize { + PAGE_SIZE / self.segment_size + } + + fn is_full(&self) -> bool { + let segments = self.allocated_segments.lock_irq_disabled(); + get_next_free_index(&segments, self.nr_blocks_per_page()).is_none() + } +} + +fn get_next_free_index(segments: &BitArray, nr_blocks_per_page: usize) -> Option { + let free_segment_index = segments.iter_zeros().next()?; + + if free_segment_index >= nr_blocks_per_page { + None + } else { + Some(free_segment_index) + } +} + +impl HasDaddr for DmaPage { + fn daddr(&self) -> Daddr { + self.storage.daddr() + } +} + +/// A small and fixed-size segment of DMA memory. +/// +/// The size of `DmaSegment` ranges from 64 bytes to `PAGE_SIZE` and must be 2^K. +/// Each `DmaSegment`'s daddr must be aligned with its size. +#[derive(Debug)] +pub struct DmaSegment { + dma_stream: DmaStream, + start_addr: Daddr, + size: usize, + page: Weak, +} + +impl HasDaddr for DmaSegment { + fn daddr(&self) -> Daddr { + self.start_addr + } +} + +impl DmaSegment { + pub const fn size(&self) -> usize { + self.size + } + + pub fn reader(&self) -> Result, aster_frame::Error> { + let offset = self.start_addr - self.dma_stream.daddr(); + Ok(self.dma_stream.reader()?.skip(offset).limit(self.size)) + } + + pub fn writer(&self) -> Result, aster_frame::Error> { + let offset = self.start_addr - self.dma_stream.daddr(); + Ok(self.dma_stream.writer()?.skip(offset).limit(self.size)) + } + + pub fn sync(&self, byte_range: Range) -> Result<(), aster_frame::Error> { + let offset = self.daddr() - self.dma_stream.daddr(); + let range = byte_range.start + offset..byte_range.end + offset; + self.dma_stream.sync(range) + } +} + +impl Drop for DmaSegment { + fn drop(&mut self) { + let page = self.page.upgrade().unwrap(); + let pool = page.pool.upgrade().unwrap(); + + // Keep the same lock order as `pool.alloc_segment` + // Lock order: pool.avail_pages -> pool.all_pages -> page.allocated_segments + let mut avail_pages = pool.avail_pages.lock_irq_disabled(); + let mut all_pages = pool.all_pages.lock_irq_disabled(); + + let mut allocated_segments = page.allocated_segments.lock_irq_disabled(); + + let nr_blocks_per_page = PAGE_SIZE / self.size; + let became_avail = get_next_free_index(&allocated_segments, nr_blocks_per_page).is_none(); + + debug_assert!((page.daddr()..page.daddr() + PAGE_SIZE).contains(&self.daddr())); + let segment_idx = (self.daddr() - page.daddr()) / self.size; + allocated_segments.set(segment_idx, false); + + let became_free = allocated_segments.not_any(); + + if became_free && all_pages.len() > pool.high_watermark { + avail_pages.retain(|page_| !Arc::ptr_eq(page_, &page)); + all_pages.retain(|page_| !Arc::ptr_eq(page_, &page)); + return; + } + + if became_avail { + avail_pages.push_back(page.clone()); + } + } +} + +#[cfg(ktest)] +mod test { + use 
alloc::vec::Vec; + + use super::*; + + #[ktest] + fn alloc_page_size_segment() { + let pool = DmaPool::new(PAGE_SIZE, 0, 100, DmaDirection::ToDevice, false); + let segments1: Vec<_> = (0..100) + .map(|_| { + let segment = pool.alloc_segment().unwrap(); + assert_eq!(segment.size(), PAGE_SIZE); + assert!(segment.reader().is_err()); + assert!(segment.writer().is_ok()); + segment + }) + .collect(); + + assert_eq!(pool.num_pages(), 100); + drop(segments1); + } + + #[ktest] + fn write_to_dma_segment() { + let pool: Arc = DmaPool::new(PAGE_SIZE, 1, 2, DmaDirection::ToDevice, false); + let segment = pool.alloc_segment().unwrap(); + let mut writer = segment.writer().unwrap(); + let data = &[0u8, 1, 2, 3, 4] as &[u8]; + let size = writer.write(&mut VmReader::from(data)); + assert_eq!(size, data.len()); + } + + #[ktest] + fn free_pool_pages() { + let pool: Arc = DmaPool::new(PAGE_SIZE, 10, 50, DmaDirection::ToDevice, false); + let segments1: Vec<_> = (0..100) + .map(|_| { + let segment = pool.alloc_segment().unwrap(); + assert_eq!(segment.size(), PAGE_SIZE); + assert!(segment.reader().is_err()); + assert!(segment.writer().is_ok()); + segment + }) + .collect(); + assert_eq!(pool.num_pages(), 100); + drop(segments1); + assert_eq!(pool.num_pages(), 50); + } + + #[ktest] + fn alloc_small_size_segment() { + const SEGMENT_SIZE: usize = PAGE_SIZE / 4; + let pool: Arc = + DmaPool::new(SEGMENT_SIZE, 0, 10, DmaDirection::Bidirectional, false); + let segments1: Vec<_> = (0..100) + .map(|_| { + let segment = pool.alloc_segment().unwrap(); + assert_eq!(segment.size(), PAGE_SIZE / 4); + assert!(segment.reader().is_ok()); + assert!(segment.writer().is_ok()); + segment + }) + .collect(); + + assert_eq!(pool.num_pages(), 100 / 4); + drop(segments1); + assert_eq!(pool.num_pages(), 10); + } + + #[ktest] + fn read_dma_segments() { + const SEGMENT_SIZE: usize = PAGE_SIZE / 4; + let pool: Arc = + DmaPool::new(SEGMENT_SIZE, 1, 2, DmaDirection::Bidirectional, false); + let segment = pool.alloc_segment().unwrap(); + assert_eq!(pool.num_pages(), 1); + let mut writer = segment.writer().unwrap(); + let data = &[0u8, 1, 2, 3, 4] as &[u8]; + let size = writer.write(&mut VmReader::from(data)); + assert_eq!(size, data.len()); + + let mut read_buf = [0u8; 5]; + let mut reader = segment.reader().unwrap(); + reader.read(&mut VmWriter::from(&mut read_buf as &mut [u8])); + assert_eq!(&read_buf, data); + } +} diff --git a/kernel/comps/network/src/driver.rs b/kernel/comps/network/src/driver.rs index 36ddcb67a..734263af3 100644 --- a/kernel/comps/network/src/driver.rs +++ b/kernel/comps/network/src/driver.rs @@ -2,12 +2,10 @@ use alloc::vec; +use aster_frame::vm::VmWriter; use smoltcp::{phy, time::Instant}; -use crate::{ - buffer::{RxBuffer, TxBuffer}, - AnyNetworkDevice, -}; +use crate::{buffer::RxBuffer, AnyNetworkDevice}; impl phy::Device for dyn AnyNetworkDevice { type RxToken<'a> = RxToken; @@ -37,12 +35,14 @@ impl phy::Device for dyn AnyNetworkDevice { pub struct RxToken(RxBuffer); impl phy::RxToken for RxToken { - fn consume(mut self, f: F) -> R + fn consume(self, f: F) -> R where F: FnOnce(&mut [u8]) -> R, { - let packet_but = self.0.packet_mut(); - f(packet_but) + let mut packet = self.0.packet(); + let mut buffer = vec![0u8; packet.remain()]; + packet.read(&mut VmWriter::from(&mut buffer as &mut [u8])); + f(&mut buffer) } } @@ -55,8 +55,7 @@ impl<'a> phy::TxToken for TxToken<'a> { { let mut buffer = vec![0u8; len]; let res = f(&mut buffer); - let tx_buffer = TxBuffer::new(&buffer); - self.0.send(tx_buffer).expect("Send packet 
failed"); + self.0.send(&buffer).expect("Send packet failed"); res } } diff --git a/kernel/comps/network/src/lib.rs b/kernel/comps/network/src/lib.rs index adb546780..f65ae99b1 100644 --- a/kernel/comps/network/src/lib.rs +++ b/kernel/comps/network/src/lib.rs @@ -4,9 +4,11 @@ #![forbid(unsafe_code)] #![feature(trait_alias)] #![feature(fn_traits)] +#![feature(linked_list_cursors)] -pub mod buffer; -pub mod driver; +mod buffer; +mod dma_pool; +mod driver; extern crate alloc; @@ -15,8 +17,9 @@ use core::{any::Any, fmt::Debug}; use aster_frame::sync::SpinLock; use aster_util::safe_ptr::Pod; -use buffer::{RxBuffer, TxBuffer}; +pub use buffer::{RxBuffer, TxBuffer}; use component::{init_component, ComponentInitError}; +pub use dma_pool::DmaSegment; use smoltcp::phy; use spin::Once; @@ -45,7 +48,7 @@ pub trait AnyNetworkDevice: Send + Sync + Any + Debug { /// Otherwise, return NotReady error. fn receive(&mut self) -> Result; /// Send a packet to network. Return until the request completes. - fn send(&mut self, tx_buffer: TxBuffer) -> Result<(), VirtioNetError>; + fn send(&mut self, packet: &[u8]) -> Result<(), VirtioNetError>; } pub trait NetDeviceIrqHandler = Fn() + Send + Sync + 'static; @@ -55,38 +58,57 @@ pub fn register_device(name: String, device: Arc> .get() .unwrap() .network_device_table - .lock() + .lock_irq_disabled() .insert(name, (Arc::new(SpinLock::new(Vec::new())), device)); } pub fn get_device(str: &str) -> Option>> { - let lock = COMPONENT.get().unwrap().network_device_table.lock(); - let (_, device) = lock.get(str)?; + let table = COMPONENT + .get() + .unwrap() + .network_device_table + .lock_irq_disabled(); + let (_, device) = table.get(str)?; Some(device.clone()) } +/// Registers callback which will be called when receiving message. +/// +/// Since the callback will be called in interrupt context, +/// the callback function should NOT sleep. 
pub fn register_recv_callback(name: &str, callback: impl NetDeviceIrqHandler) { - let lock = COMPONENT.get().unwrap().network_device_table.lock(); - let Some((callbacks, _)) = lock.get(name) else { + let device_table = COMPONENT + .get() + .unwrap() + .network_device_table + .lock_irq_disabled(); + let Some((callbacks, _)) = device_table.get(name) else { return; }; - callbacks.lock().push(Arc::new(callback)); + callbacks.lock_irq_disabled().push(Arc::new(callback)); } pub fn handle_recv_irq(name: &str) { - let lock = COMPONENT.get().unwrap().network_device_table.lock(); - let Some((callbacks, _)) = lock.get(name) else { + let device_table = COMPONENT + .get() + .unwrap() + .network_device_table + .lock_irq_disabled(); + let Some((callbacks, _)) = device_table.get(name) else { return; }; - let callbacks = callbacks.clone(); - let lock = callbacks.lock(); - for callback in lock.iter() { - callback.call(()) + let callbacks = callbacks.lock_irq_disabled(); + for callback in callbacks.iter() { + callback(); } } pub fn all_devices() -> Vec<(String, NetworkDeviceRef)> { - let network_devs = COMPONENT.get().unwrap().network_device_table.lock(); + let network_devs = COMPONENT + .get() + .unwrap() + .network_device_table + .lock_irq_disabled(); network_devs .iter() .map(|(name, (_, device))| (name.clone(), device.clone())) @@ -102,6 +124,7 @@ fn init() -> Result<(), ComponentInitError> { let a = Component::init()?; COMPONENT.call_once(|| a); NETWORK_IRQ_HANDLERS.call_once(|| SpinLock::new(Vec::new())); + buffer::init(); Ok(()) } diff --git a/kernel/comps/virtio/Cargo.toml b/kernel/comps/virtio/Cargo.toml index cece300d5..bbdfeae16 100644 --- a/kernel/comps/virtio/Cargo.toml +++ b/kernel/comps/virtio/Cargo.toml @@ -37,4 +37,3 @@ smoltcp = { version = "0.9.1", default-features = false, features = [ "socket-raw", "socket-dhcpv4", ] } -[features] diff --git a/kernel/comps/virtio/src/device/block/device.rs b/kernel/comps/virtio/src/device/block/device.rs index 892ea1afb..ff311d162 100644 --- a/kernel/comps/virtio/src/device/block/device.rs +++ b/kernel/comps/virtio/src/device/block/device.rs @@ -99,18 +99,12 @@ impl DeviceInner { let queue = VirtQueue::new(0, Self::QUEUE_SIZE, transport.as_mut()) .expect("create virtqueue failed"); let block_requests = { - let vm_segment = VmAllocOptions::new(1) - .is_contiguous(true) - .alloc_contiguous() - .unwrap(); + let vm_segment = VmAllocOptions::new(1).alloc_contiguous().unwrap(); DmaStream::map(vm_segment, DmaDirection::Bidirectional, false).unwrap() }; assert!(Self::QUEUE_SIZE as usize * REQ_SIZE <= block_requests.nbytes()); let block_responses = { - let vm_segment = VmAllocOptions::new(1) - .is_contiguous(true) - .alloc_contiguous() - .unwrap(); + let vm_segment = VmAllocOptions::new(1).alloc_contiguous().unwrap(); DmaStream::map(vm_segment, DmaDirection::Bidirectional, false).unwrap() }; assert!(Self::QUEUE_SIZE as usize * RESP_SIZE <= block_responses.nbytes()); @@ -222,7 +216,6 @@ impl DeviceInner { const MAX_ID_LENGTH: usize = 20; let device_id_stream = { let segment = VmAllocOptions::new(1) - .is_contiguous(true) .uninit(true) .alloc_contiguous() .unwrap(); diff --git a/kernel/comps/virtio/src/device/console/device.rs b/kernel/comps/virtio/src/device/console/device.rs index b76232b12..af1ea24f4 100644 --- a/kernel/comps/virtio/src/device/console/device.rs +++ b/kernel/comps/virtio/src/device/console/device.rs @@ -4,7 +4,12 @@ use alloc::{boxed::Box, fmt::Debug, string::ToString, sync::Arc, vec::Vec}; use core::hint::spin_loop; use 
aster_console::{AnyConsoleDevice, ConsoleCallback}; -use aster_frame::{io_mem::IoMem, sync::SpinLock, trap::TrapFrame, vm::PAGE_SIZE}; +use aster_frame::{ + io_mem::IoMem, + sync::SpinLock, + trap::TrapFrame, + vm::{DmaDirection, DmaStream, DmaStreamSlice, VmAllocOptions, VmReader}, +}; use aster_util::safe_ptr::SafePtr; use log::debug; @@ -17,62 +22,39 @@ use crate::{ pub struct ConsoleDevice { config: SafePtr, - transport: Box, + transport: SpinLock>, receive_queue: SpinLock, transmit_queue: SpinLock, - buffer: SpinLock>, + send_buffer: DmaStream, + receive_buffer: DmaStream, callbacks: SpinLock>, } impl AnyConsoleDevice for ConsoleDevice { fn send(&self, value: &[u8]) { let mut transmit_queue = self.transmit_queue.lock_irq_disabled(); - transmit_queue.add_buf(&[value], &[]).unwrap(); - if transmit_queue.should_notify() { - transmit_queue.notify(); + let mut reader = VmReader::from(value); + + while reader.remain() > 0 { + let mut writer = self.send_buffer.writer().unwrap(); + let len = writer.write(&mut reader); + self.send_buffer.sync(0..len).unwrap(); + + let slice = DmaStreamSlice::new(&self.send_buffer, 0, len); + transmit_queue.add_dma_buf(&[&slice], &[]).unwrap(); + + if transmit_queue.should_notify() { + transmit_queue.notify(); + } + while !transmit_queue.can_pop() { + spin_loop(); + } + transmit_queue.pop_used().unwrap(); } - while !transmit_queue.can_pop() { - spin_loop(); - } - transmit_queue.pop_used().unwrap(); } - fn recv(&self, buf: &mut [u8]) -> Option { - let mut receive_queue = self.receive_queue.lock_irq_disabled(); - if !receive_queue.can_pop() { - return None; - } - let (_, len) = receive_queue.pop_used().unwrap(); - - let mut recv_buffer = self.buffer.lock(); - buf.copy_from_slice(&recv_buffer.as_ref()[..len as usize]); - receive_queue.add_buf(&[], &[recv_buffer.as_mut()]).unwrap(); - if receive_queue.should_notify() { - receive_queue.notify(); - } - Some(len as usize) - } - - fn register_callback(&self, callback: &'static (dyn Fn(&[u8]) + Send + Sync)) { - self.callbacks.lock().push(callback); - } - - fn handle_irq(&self) { - let mut receive_queue = self.receive_queue.lock_irq_disabled(); - if !receive_queue.can_pop() { - return; - } - let (_, len) = receive_queue.pop_used().unwrap(); - let mut recv_buffer = self.buffer.lock(); - let buffer = &recv_buffer.as_ref()[..len as usize]; - let lock = self.callbacks.lock(); - for callback in lock.iter() { - callback.call((buffer,)); - } - receive_queue.add_buf(&[], &[recv_buffer.as_mut()]).unwrap(); - if receive_queue.should_notify() { - receive_queue.notify(); - } + fn register_callback(&self, callback: &'static ConsoleCallback) { + self.callbacks.lock_irq_disabled().push(callback); } } @@ -104,41 +86,76 @@ impl ConsoleDevice { let transmit_queue = SpinLock::new(VirtQueue::new(TRANSMIT0_QUEUE_INDEX, 2, transport.as_mut()).unwrap()); - let mut device = Self { - config, - transport, - receive_queue, - transmit_queue, - buffer: SpinLock::new(Box::new([0; PAGE_SIZE])), - callbacks: SpinLock::new(Vec::new()), + let send_buffer = { + let vm_segment = VmAllocOptions::new(1).alloc_contiguous().unwrap(); + DmaStream::map(vm_segment, DmaDirection::ToDevice, false).unwrap() }; - let mut receive_queue = device.receive_queue.lock(); + let receive_buffer = { + let vm_segment = VmAllocOptions::new(1).alloc_contiguous().unwrap(); + DmaStream::map(vm_segment, DmaDirection::FromDevice, false).unwrap() + }; + + let device = Arc::new(Self { + config, + transport: SpinLock::new(transport), + receive_queue, + transmit_queue, + send_buffer, + 
receive_buffer, + callbacks: SpinLock::new(Vec::new()), + }); + + let mut receive_queue = device.receive_queue.lock_irq_disabled(); receive_queue - .add_buf(&[], &[device.buffer.lock().as_mut()]) + .add_dma_buf(&[], &[&device.receive_buffer]) .unwrap(); if receive_queue.should_notify() { receive_queue.notify(); } drop(receive_queue); - device - .transport + + // Register irq callbacks + let mut transport = device.transport.lock_irq_disabled(); + let handle_console_input = { + let device = device.clone(); + move |_: &TrapFrame| device.handle_recv_irq() + }; + transport .register_queue_callback(RECV0_QUEUE_INDEX, Box::new(handle_console_input), false) .unwrap(); - device - .transport + transport .register_cfg_callback(Box::new(config_space_change)) .unwrap(); - device.transport.finish_init(); + transport.finish_init(); + drop(transport); - aster_console::register_device(DEVICE_NAME.to_string(), Arc::new(device)); + aster_console::register_device(DEVICE_NAME.to_string(), device); Ok(()) } -} -fn handle_console_input(_: &TrapFrame) { - aster_console::get_device(DEVICE_NAME).unwrap().handle_irq(); + fn handle_recv_irq(&self) { + let mut receive_queue = self.receive_queue.lock_irq_disabled(); + if !receive_queue.can_pop() { + return; + } + let (_, len) = receive_queue.pop_used().unwrap(); + self.receive_buffer.sync(0..len as usize).unwrap(); + + let callbacks = self.callbacks.lock_irq_disabled(); + + for callback in callbacks.iter() { + let reader = self.receive_buffer.reader().unwrap().limit(len as usize); + callback(reader); + } + receive_queue + .add_dma_buf(&[], &[&self.receive_buffer]) + .unwrap(); + if receive_queue.should_notify() { + receive_queue.notify(); + } + } } fn config_space_change(_: &TrapFrame) { diff --git a/kernel/comps/virtio/src/device/input/device.rs b/kernel/comps/virtio/src/device/input/device.rs index 7ea7b3f50..d285809b8 100644 --- a/kernel/comps/virtio/src/device/input/device.rs +++ b/kernel/comps/virtio/src/device/input/device.rs @@ -6,9 +6,15 @@ use alloc::{ sync::Arc, vec::Vec, }; -use core::fmt::Debug; +use core::{fmt::Debug, mem}; -use aster_frame::{io_mem::IoMem, offset_of, sync::SpinLock, trap::TrapFrame}; +use aster_frame::{ + io_mem::IoMem, + offset_of, + sync::SpinLock, + trap::TrapFrame, + vm::{Daddr, DmaDirection, DmaStream, HasDaddr, VmAllocOptions, VmIo, VmReader, PAGE_SIZE}, +}; use aster_input::{ key::{Key, KeyStatus}, InputEvent, @@ -16,10 +22,11 @@ use aster_input::{ use aster_util::{field_ptr, safe_ptr::SafePtr}; use bitflags::bitflags; use log::{debug, info}; -use pod::Pod; use super::{InputConfigSelect, VirtioInputConfig, VirtioInputEvent, QUEUE_EVENT, QUEUE_STATUS}; -use crate::{device::VirtioDeviceError, queue::VirtQueue, transport::VirtioTransport}; +use crate::{ + device::VirtioDeviceError, dma_buf::DmaBuf, queue::VirtQueue, transport::VirtioTransport, +}; bitflags! { /// The properties of input device. @@ -67,25 +74,25 @@ pub struct InputDevice { config: SafePtr, event_queue: SpinLock, status_queue: VirtQueue, - event_buf: SpinLock>, + event_table: EventTable, #[allow(clippy::type_complexity)] callbacks: SpinLock>>, - transport: Box, + transport: SpinLock>, } impl InputDevice { /// Create a new VirtIO-Input driver. 
/// msix_vector_left should at least have one element or n elements where n is the virtqueue amount pub fn init(mut transport: Box) -> Result<(), VirtioDeviceError> { - let mut event_buf = Box::new([VirtioInputEvent::default(); QUEUE_SIZE as usize]); let mut event_queue = VirtQueue::new(QUEUE_EVENT, QUEUE_SIZE, transport.as_mut()) .expect("create event virtqueue failed"); let status_queue = VirtQueue::new(QUEUE_STATUS, QUEUE_SIZE, transport.as_mut()) .expect("create status virtqueue failed"); - for (i, event) in event_buf.as_mut().iter_mut().enumerate() { - // FIEME: replace slice with a more secure data structure to use dma mapping. - let token = event_queue.add_buf(&[], &[event.as_bytes_mut()]); + let event_table = EventTable::new(QUEUE_SIZE as usize); + for i in 0..event_table.num_events() { + let event_buf = event_table.get(i); + let token = event_queue.add_dma_buf(&[], &[&event_buf]); match token { Ok(value) => { assert_eq!(value, i as u16); @@ -96,14 +103,14 @@ impl InputDevice { } } - let mut device = Self { + let device = Arc::new(Self { config: VirtioInputConfig::new(transport.as_mut()), event_queue: SpinLock::new(event_queue), status_queue, - event_buf: SpinLock::new(event_buf), - transport, + event_table, + transport: SpinLock::new(transport), callbacks: SpinLock::new(Vec::new()), - }; + }); let mut raw_name: [u8; 128] = [0; 128]; device.query_config_select(InputConfigSelect::IdName, 0, &mut raw_name); @@ -115,51 +122,49 @@ impl InputDevice { let input_prop = InputProp::from_bits(prop[0]).unwrap(); debug!("input device prop:{:?}", input_prop); - fn handle_input(_: &TrapFrame) { - debug!("Handle Virtio input interrupt"); - let device = aster_input::get_device(super::DEVICE_NAME).unwrap(); - device.handle_irq().unwrap(); - } - + let mut transport = device.transport.lock_irq_disabled(); fn config_space_change(_: &TrapFrame) { debug!("input device config space change"); } - - device - .transport + transport .register_cfg_callback(Box::new(config_space_change)) .unwrap(); - device - .transport + + let handle_input = { + let device = device.clone(); + move |_: &TrapFrame| device.handle_irq() + }; + transport .register_queue_callback(QUEUE_EVENT, Box::new(handle_input), false) .unwrap(); - device.transport.finish_init(); + transport.finish_init(); + drop(transport); - aster_input::register_device(super::DEVICE_NAME.to_string(), Arc::new(device)); + aster_input::register_device(super::DEVICE_NAME.to_string(), device); Ok(()) } /// Pop the pending event. - pub fn pop_pending_event(&self) -> Option { - let mut lock = self.event_queue.lock(); - if let Ok((token, _)) = lock.pop_used() { - if token >= QUEUE_SIZE { - return None; - } - let event = &mut self.event_buf.lock()[token as usize]; - // requeue - // FIEME: replace slice with a more secure data structure to use dma mapping. - if let Ok(new_token) = lock.add_buf(&[], &[event.as_bytes_mut()]) { - // This only works because nothing happen between `pop_used` and `add` that affects - // the list of free descriptors in the queue, so `add` reuses the descriptor which - // was just freed by `pop_used`. 
-                assert_eq!(new_token, token);
-                return Some(*event);
+    fn pop_pending_events(&self, handle_event: &impl Fn(&EventBuf) -> bool) {
+        let mut event_queue = self.event_queue.lock_irq_disabled();
+
+        // One interrupt may contain several input events, so it should loop.
+        while let Ok((token, _)) = event_queue.pop_used() {
+            debug_assert!(token < QUEUE_SIZE);
+            let event_buf = self.event_table.get(token as usize);
+            let res = handle_event(&event_buf);
+            let new_token = event_queue.add_dma_buf(&[], &[&event_buf]).unwrap();
+            // This only works because nothing happens between `pop_used` and `add` that affects
+            // the list of free descriptors in the queue, so `add` reuses the descriptor which
+            // was just freed by `pop_used`.
+            assert_eq!(new_token, token);
+
+            if !res {
+                break;
+            }
         }
-        None
     }
 
     /// Query a specific piece of information by `select` and `subsel`, and write
@@ -181,6 +186,42 @@ impl InputDevice {
         size
     }
 
+    fn handle_irq(&self) {
+        // Returns true if there may be more events to handle.
+        let handle_event = |event: &EventBuf| -> bool {
+            let event: VirtioInputEvent = {
+                let mut reader = event.reader();
+                reader.read_val()
+            };
+
+            match event.event_type {
+                0 => return false,
+                // Keyboard
+                1 => {}
+                // TODO: Support mouse device.
+                _ => return true,
+            }
+
+            let status = match event.value {
+                1 => KeyStatus::Pressed,
+                0 => KeyStatus::Released,
+                _ => return false,
+            };
+
+            let event = InputEvent::KeyBoard(Key::try_from(event.code).unwrap(), status);
+            info!("Input Event:{:?}", event);
+
+            let callbacks = self.callbacks.lock();
+            for callback in callbacks.iter() {
+                callback(event);
+            }
+
+            true
+        };
+
+        self.pop_pending_events(&handle_event);
+    }
+
     /// Negotiate features for the device specified bits 0~23
     pub(crate) fn negotiate_features(features: u64) -> u64 {
         assert_eq!(features, 0);
@@ -188,37 +229,85 @@ impl InputDevice {
     }
 }
 
-impl aster_input::InputDevice for InputDevice {
-    fn handle_irq(&self) -> Option<()> {
-        // one interrupt may contains serval input, so it should loop
-        loop {
-            let Some(event) = self.pop_pending_event() else {
-                return Some(());
-            };
-            match event.event_type {
-                0 => return Some(()),
-                // Keyboard
-                1 => {}
-                // TODO: Support mouse device.
-                _ => continue,
-            }
-            let status = match event.value {
-                1 => KeyStatus::Pressed,
-                0 => KeyStatus::Released,
-                _ => return Some(()),
-            };
-            let event = InputEvent::KeyBoard(Key::try_from(event.code).unwrap(), status);
             info!("Input Event:{:?}", event);
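The `EventTable` introduced below carves a single DMA page into fixed-size slots, one per in-flight virtqueue descriptor. A self-contained model of the slot arithmetic (the 8-byte event size is an assumption standing in for `size_of::<VirtioInputEvent>()`):

```rust
const PAGE_SIZE: usize = 4096;
const EVENT_SIZE: usize = 8; // assumed size of one input event

// Byte range of slot `idx` within the shared DMA page.
fn slot_range(idx: usize, num_events: usize) -> core::ops::Range<usize> {
    assert!(num_events * EVENT_SIZE <= PAGE_SIZE); // same bound `EventTable::new` asserts
    assert!(idx < num_events);
    let offset = idx * EVENT_SIZE;
    offset..offset + EVENT_SIZE
}

fn main() {
    // With 64 slots, slot 3 occupies bytes 24..32 of the page, which is
    // exactly the range `EventBuf::reader` syncs before reading.
    assert_eq!(slot_range(3, 64), 24..32);
    println!("slot layout ok");
}
```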
+/// An event table consists of many event buffers,
+/// each of which is large enough to contain a `VirtioInputEvent`.
+#[derive(Debug)]
+struct EventTable {
+    stream: DmaStream,
+    num_events: usize,
+}
-            let callbacks = self.callbacks.lock();
-            for callback in callbacks.iter() {
-                callback.call((event,));
-            }
+impl EventTable {
+    fn new(num_events: usize) -> Self {
+        assert!(num_events * mem::size_of::<VirtioInputEvent>() <= PAGE_SIZE);
+
+        let vm_segment = VmAllocOptions::new(1).alloc_contiguous().unwrap();
+
+        let default_event = VirtioInputEvent::default();
+        for idx in 0..num_events {
+            let offset = idx * EVENT_SIZE;
+            vm_segment.write_val(offset, &default_event).unwrap();
+        }
+
+        let stream = DmaStream::map(vm_segment, DmaDirection::FromDevice, false).unwrap();
+        Self { stream, num_events }
+    }
+
+    fn get(&self, idx: usize) -> EventBuf<'_> {
+        assert!(idx < self.num_events);
+
+        let offset = idx * EVENT_SIZE;
+        EventBuf {
+            event_table: self,
+            offset,
+            size: EVENT_SIZE,
         }
     }
 
+    const fn num_events(&self) -> usize {
+        self.num_events
+    }
+}
+
+const EVENT_SIZE: usize = core::mem::size_of::<VirtioInputEvent>();
+
+/// A buffer that stores exactly one `VirtioInputEvent`.
+struct EventBuf<'a> {
+    event_table: &'a EventTable,
+    offset: usize,
+    size: usize,
+}
+
+impl<'a> HasDaddr for EventBuf<'a> {
+    fn daddr(&self) -> Daddr {
+        self.event_table.stream.daddr() + self.offset
+    }
+}
+
+impl<'a> DmaBuf for EventBuf<'a> {
+    fn len(&self) -> usize {
+        self.size
+    }
+}
+
+impl<'a> EventBuf<'a> {
+    fn reader(&self) -> VmReader<'a> {
+        self.event_table
+            .stream
+            .sync(self.offset..self.offset + self.size)
+            .unwrap();
+        self.event_table
+            .stream
+            .reader()
+            .unwrap()
+            .skip(self.offset)
+            .limit(self.size)
+    }
+}
+
+impl aster_input::InputDevice for InputDevice {
     fn register_callbacks(&self, function: &'static (dyn Fn(InputEvent) + Send + Sync)) {
-        self.callbacks.lock().push(Arc::new(function))
+        self.callbacks.lock_irq_disabled().push(Arc::new(function))
     }
 }
 
@@ -228,7 +317,7 @@ impl Debug for InputDevice {
             .field("config", &self.config)
             .field("event_queue", &self.event_queue)
             .field("status_queue", &self.status_queue)
-            .field("event_buf", &self.event_buf)
+            .field("event_buf", &self.event_table)
             .field("transport", &self.transport)
             .finish()
     }
diff --git a/kernel/comps/virtio/src/device/network/device.rs b/kernel/comps/virtio/src/device/network/device.rs
index 91d53e9a3..edda9a4a4 100644
--- a/kernel/comps/virtio/src/device/network/device.rs
+++ b/kernel/comps/virtio/src/device/network/device.rs
@@ -5,12 +5,10 @@ use core::{fmt::Debug, hint::spin_loop, mem::size_of};
 
 use aster_frame::{offset_of, sync::SpinLock, trap::TrapFrame};
 use aster_network::{
-    buffer::{RxBuffer, TxBuffer},
-    AnyNetworkDevice, EthernetAddr, NetDeviceIrqHandler, VirtioNetError,
+    AnyNetworkDevice, EthernetAddr, NetDeviceIrqHandler, RxBuffer, TxBuffer, VirtioNetError,
 };
 use aster_util::{field_ptr, slot_vec::SlotVec};
 use log::debug;
-use pod::Pod;
 use smoltcp::phy::{DeviceCapabilities, Medium};
 
 use super::{config::VirtioNetConfig, header::VirtioNetHdr};
@@ -60,9 +58,9 @@ impl NetworkDevice {
 
         let mut rx_buffers = SlotVec::new();
         for i in 0..QUEUE_SIZE {
-            let mut rx_buffer = RxBuffer::new(RX_BUFFER_LEN, size_of::<VirtioNetHdr>());
+            let rx_buffer = RxBuffer::new(size_of::<VirtioNetHdr>());
             // FIEME: Replace rx_buffer with VM segment-based data structure to use dma mapping.
- let token = recv_queue.add_buf(&[], &[rx_buffer.buf_mut()])?; + let token = recv_queue.add_dma_buf(&[], &[&rx_buffer])?; assert_eq!(i, token); assert_eq!(rx_buffers.put(rx_buffer) as u16, i); } @@ -80,7 +78,7 @@ impl NetworkDevice { transport, callbacks: Vec::new(), }; - device.transport.finish_init(); + /// Interrupt handler if network device config space changes fn config_space_change(_: &TrapFrame) { debug!("network device config space change"); @@ -99,6 +97,7 @@ impl NetworkDevice { .transport .register_queue_callback(QUEUE_RECV, Box::new(handle_network_event), false) .unwrap(); + device.transport.finish_init(); aster_network::register_device( super::DEVICE_NAME.to_string(), @@ -109,10 +108,10 @@ impl NetworkDevice { /// Add a rx buffer to recv queue /// FIEME: Replace rx_buffer with VM segment-based data structure to use dma mapping. - fn add_rx_buffer(&mut self, mut rx_buffer: RxBuffer) -> Result<(), VirtioNetError> { + fn add_rx_buffer(&mut self, rx_buffer: RxBuffer) -> Result<(), VirtioNetError> { let token = self .recv_queue - .add_buf(&[], &[rx_buffer.buf_mut()]) + .add_dma_buf(&[], &[&rx_buffer]) .map_err(queue_to_network_error)?; assert!(self.rx_buffers.put_at(token as usize, rx_buffer).is_none()); if self.recv_queue.should_notify() { @@ -133,18 +132,20 @@ impl NetworkDevice { rx_buffer.set_packet_len(len as usize); // FIXME: Ideally, we can reuse the returned buffer without creating new buffer. // But this requires locking device to be compatible with smoltcp interface. - let new_rx_buffer = RxBuffer::new(RX_BUFFER_LEN, size_of::()); + let new_rx_buffer = RxBuffer::new(size_of::()); self.add_rx_buffer(new_rx_buffer)?; Ok(rx_buffer) } /// Send a packet to network. Return until the request completes. /// FIEME: Replace tx_buffer with VM segment-based data structure to use dma mapping. - fn send(&mut self, tx_buffer: TxBuffer) -> Result<(), VirtioNetError> { + fn send(&mut self, packet: &[u8]) -> Result<(), VirtioNetError> { let header = VirtioNetHdr::default(); + let tx_buffer = TxBuffer::new(&header, packet); + let token = self .send_queue - .add_buf(&[header.as_bytes(), tx_buffer.buf()], &[]) + .add_dma_buf(&[&tx_buffer], &[]) .map_err(queue_to_network_error)?; if self.send_queue.should_notify() { @@ -198,8 +199,8 @@ impl AnyNetworkDevice for NetworkDevice { self.receive() } - fn send(&mut self, tx_buffer: TxBuffer) -> Result<(), VirtioNetError> { - self.send(tx_buffer) + fn send(&mut self, packet: &[u8]) -> Result<(), VirtioNetError> { + self.send(packet) } } diff --git a/kernel/comps/virtio/src/dma_buf.rs b/kernel/comps/virtio/src/dma_buf.rs index ad517e443..777527491 100644 --- a/kernel/comps/virtio/src/dma_buf.rs +++ b/kernel/comps/virtio/src/dma_buf.rs @@ -1,6 +1,7 @@ // SPDX-License-Identifier: MPL-2.0 use aster_frame::vm::{DmaCoherent, DmaStream, DmaStreamSlice, HasDaddr}; +use aster_network::{DmaSegment, RxBuffer, TxBuffer}; /// A DMA-capable buffer. 
/// @@ -29,3 +30,21 @@ impl DmaBuf for DmaCoherent { self.nbytes() } } + +impl DmaBuf for DmaSegment { + fn len(&self) -> usize { + self.size() + } +} + +impl DmaBuf for TxBuffer { + fn len(&self) -> usize { + self.nbytes() + } +} + +impl DmaBuf for RxBuffer { + fn len(&self) -> usize { + self.buf_len() + } +} diff --git a/kernel/comps/virtio/src/queue.rs b/kernel/comps/virtio/src/queue.rs index d8ba64980..23695c725 100644 --- a/kernel/comps/virtio/src/queue.rs +++ b/kernel/comps/virtio/src/queue.rs @@ -82,10 +82,7 @@ impl VirtQueue { let desc_size = size_of::() * size as usize; let (seg1, seg2) = { - let continue_segment = VmAllocOptions::new(2) - .is_contiguous(true) - .alloc_contiguous() - .unwrap(); + let continue_segment = VmAllocOptions::new(2).alloc_contiguous().unwrap(); let seg1 = continue_segment.range(0..1); let seg2 = continue_segment.range(1..2); (seg1, seg2) @@ -104,36 +101,18 @@ impl VirtQueue { } ( SafePtr::new( - DmaCoherent::map( - VmAllocOptions::new(1) - .is_contiguous(true) - .alloc_contiguous() - .unwrap(), - true, - ) - .unwrap(), + DmaCoherent::map(VmAllocOptions::new(1).alloc_contiguous().unwrap(), true) + .unwrap(), 0, ), SafePtr::new( - DmaCoherent::map( - VmAllocOptions::new(1) - .is_contiguous(true) - .alloc_contiguous() - .unwrap(), - true, - ) - .unwrap(), + DmaCoherent::map(VmAllocOptions::new(1).alloc_contiguous().unwrap(), true) + .unwrap(), 0, ), SafePtr::new( - DmaCoherent::map( - VmAllocOptions::new(1) - .is_contiguous(true) - .alloc_contiguous() - .unwrap(), - true, - ) - .unwrap(), + DmaCoherent::map(VmAllocOptions::new(1).alloc_contiguous().unwrap(), true) + .unwrap(), 0, ), )
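Stepping back, the receive path now recycles pooled DMA buffers instead of copying through plain byte slices: `receive()` pops a used token, takes that slot's buffer, and immediately refills the slot so the ring never runs dry. A self-contained model of that cycle (std Rust, illustrative names only):

```rust
struct RxRing {
    slots: Vec<Option<Vec<u8>>>, // token-indexed receive buffers
}

impl RxRing {
    fn receive(&mut self, used_token: usize, len: usize) -> Vec<u8> {
        let mut buf = self.slots[used_token].take().expect("token not in use");
        buf.truncate(len); // analogous to `rx_buffer.set_packet_len(len)`
        // Refill right away, as `receive()` does via `add_rx_buffer(RxBuffer::new(..))`.
        self.slots[used_token] = Some(vec![0u8; 4096]);
        buf
    }
}

fn main() {
    let mut ring = RxRing { slots: vec![Some(vec![0u8; 4096]); 8] };
    let packet = ring.receive(3, 60);
    assert_eq!(packet.len(), 60);
    assert!(ring.slots[3].is_some()); // the slot was refilled
    println!("rx recycle ok");
}
```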
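Finally, the new `DmaBuf` impls make any type with a device address and a length usable with `VirtQueue::add_dma_buf`. As a sketch, a further impl for a hypothetical wrapper type might look as follows; the `HasDaddr` requirement is inferred from how the queue consumes these buffers, and `Prefix` is not part of this patch (it assumes the `aster-virtio` crate context):

```rust
use aster_frame::vm::{Daddr, DmaStream, HasDaddr};
use crate::dma_buf::DmaBuf;

/// Hypothetical wrapper exposing only the first `len` bytes of a stream.
struct Prefix {
    stream: DmaStream,
    len: usize,
}

impl HasDaddr for Prefix {
    fn daddr(&self) -> Daddr {
        self.stream.daddr() // the device sees the start of the stream
    }
}

impl DmaBuf for Prefix {
    fn len(&self) -> usize {
        self.len // only this prefix is handed to the device
    }
}
```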