Refactor virtio drivers with DMA APIs

This commit is contained in:
Jianfeng Jiang
2024-03-12 11:01:50 +00:00
committed by Tate, Hongliang Tian
parent 5e127b2da0
commit cd1575bc6d
22 changed files with 853 additions and 311 deletions

3
Cargo.lock generated
View File

@ -176,9 +176,10 @@ dependencies = [
"aster-rights", "aster-rights",
"aster-util", "aster-util",
"bitflags 1.3.2", "bitflags 1.3.2",
"bytes", "bitvec",
"component", "component",
"int-to-c-enum", "int-to-c-enum",
"ktest",
"log", "log",
"pod", "pod",
"ringbuf", "ringbuf",

View File

@ -51,9 +51,7 @@ pub struct KernelStack {
impl KernelStack { impl KernelStack {
pub fn new() -> Result<Self> { pub fn new() -> Result<Self> {
Ok(Self { Ok(Self {
segment: VmAllocOptions::new(KERNEL_STACK_SIZE / PAGE_SIZE) segment: VmAllocOptions::new(KERNEL_STACK_SIZE / PAGE_SIZE).alloc_contiguous()?,
.is_contiguous(true)
.alloc_contiguous()?,
old_guard_page_flag: None, old_guard_page_flag: None,
}) })
} }
@ -61,9 +59,8 @@ impl KernelStack {
/// Generate a kernel stack with a guard page. /// Generate a kernel stack with a guard page.
/// An additional page is allocated and be regarded as a guard page, which should not be accessed. /// An additional page is allocated and be regarded as a guard page, which should not be accessed.
pub fn new_with_guard_page() -> Result<Self> { pub fn new_with_guard_page() -> Result<Self> {
let stack_segment = VmAllocOptions::new(KERNEL_STACK_SIZE / PAGE_SIZE + 1) let stack_segment =
.is_contiguous(true) VmAllocOptions::new(KERNEL_STACK_SIZE / PAGE_SIZE + 1).alloc_contiguous()?;
.alloc_contiguous()?;
let unpresent_flag = PageTableFlags::empty(); let unpresent_flag = PageTableFlags::empty();
let old_guard_page_flag = Self::protect_guard_page(&stack_segment, unpresent_flag); let old_guard_page_flag = Self::protect_guard_page(&stack_segment, unpresent_flag);
Ok(Self { Ok(Self {

View File

@ -84,7 +84,8 @@ impl VmAllocOptions {
/// ///
/// The returned `VmSegment` contains at least one page frame. /// The returned `VmSegment` contains at least one page frame.
pub fn alloc_contiguous(&self) -> Result<VmSegment> { pub fn alloc_contiguous(&self) -> Result<VmSegment> {
if !self.is_contiguous || self.nframes == 0 { // It's no use to checking `self.is_contiguous` here.
if self.nframes == 0 {
return Err(Error::InvalidArgs); return Err(Error::InvalidArgs);
} }

View File

@ -1,6 +1,7 @@
// SPDX-License-Identifier: MPL-2.0 // SPDX-License-Identifier: MPL-2.0
pub use aster_frame::arch::console::register_console_input_callback; pub use aster_frame::arch::console;
use aster_frame::vm::VmReader;
use spin::Once; use spin::Once;
use crate::{ use crate::{
@ -62,24 +63,25 @@ impl TtyDriver {
Ok(()) Ok(())
} }
pub fn receive_char(&self, item: u8) { pub fn push_char(&self, ch: u8) {
// FIXME: should the char send to all ttys? // FIXME: should the char send to all ttys?
for tty in &*self.ttys.lock_irq_disabled() { for tty in &*self.ttys.lock_irq_disabled() {
tty.receive_char(item); tty.push_char(ch);
} }
} }
} }
fn console_input_callback(items: &[u8]) { fn console_input_callback(mut reader: VmReader) {
let tty_driver = get_tty_driver(); let tty_driver = get_tty_driver();
for item in items { while reader.remain() > 0 {
tty_driver.receive_char(*item); let ch = reader.read_val();
tty_driver.push_char(ch);
} }
} }
fn serial_input_callback(item: u8) { fn serial_input_callback(item: u8) {
let tty_driver = get_tty_driver(); let tty_driver = get_tty_driver();
tty_driver.receive_char(item); tty_driver.push_char(item);
} }
fn get_tty_driver() -> &'static TtyDriver { fn get_tty_driver() -> &'static TtyDriver {

View File

@ -1,5 +1,6 @@
// SPDX-License-Identifier: MPL-2.0 // SPDX-License-Identifier: MPL-2.0
use aster_frame::early_print;
use spin::Once; use spin::Once;
use self::{driver::TtyDriver, line_discipline::LineDiscipline}; use self::{driver::TtyDriver, line_discipline::LineDiscipline};
@ -61,8 +62,11 @@ impl Tty {
*self.driver.lock_irq_disabled() = driver; *self.driver.lock_irq_disabled() = driver;
} }
pub fn receive_char(&self, ch: u8) { pub fn push_char(&self, ch: u8) {
self.ldisc.push_char(ch, |content| print!("{}", content)); // FIXME: Use `early_print` to avoid calling virtio-console.
// This is only a workaround
self.ldisc
.push_char(ch, |content| early_print!("{}", content))
} }
} }

View File

@ -42,7 +42,6 @@ impl Ext2 {
.div_ceil(BLOCK_SIZE); .div_ceil(BLOCK_SIZE);
let segment = VmAllocOptions::new(npages) let segment = VmAllocOptions::new(npages)
.uninit(true) .uninit(true)
.is_contiguous(true)
.alloc_contiguous()?; .alloc_contiguous()?;
match block_device.read_blocks_sync(super_block.group_descriptors_bid(0), &segment)? { match block_device.read_blocks_sync(super_block.group_descriptors_bid(0), &segment)? {
BioStatus::Complete => (), BioStatus::Complete => (),

View File

@ -80,6 +80,9 @@ fn init_thread() {
"[kernel] Spawn init thread, tid = {}", "[kernel] Spawn init thread, tid = {}",
current_thread!().tid() current_thread!().tid()
); );
// Work queue should be initialized before interrupt is enabled,
// in case any irq handler uses work queue as bottom half
thread::work_queue::init();
// FIXME: Remove this if we move the step of mounting // FIXME: Remove this if we move the step of mounting
// the filesystems to be done within the init process. // the filesystems to be done within the init process.
aster_frame::trap::enable_local(); aster_frame::trap::enable_local();
@ -97,7 +100,6 @@ fn init_thread() {
"[aster-nix/lib.rs] spawn kernel thread, tid = {}", "[aster-nix/lib.rs] spawn kernel thread, tid = {}",
thread.tid() thread.tid()
); );
thread::work_queue::init();
print_banner(); print_banner();

View File

@ -99,7 +99,6 @@ impl VmIo for dyn BlockDevice {
}; };
let segment = VmAllocOptions::new(num_blocks as usize) let segment = VmAllocOptions::new(num_blocks as usize)
.uninit(true) .uninit(true)
.is_contiguous(true)
.alloc_contiguous()?; .alloc_contiguous()?;
let bio_segment = BioSegment::from_segment(segment, offset % BLOCK_SIZE, buf.len()); let bio_segment = BioSegment::from_segment(segment, offset % BLOCK_SIZE, buf.len());
@ -141,7 +140,6 @@ impl VmIo for dyn BlockDevice {
}; };
let segment = VmAllocOptions::new(num_blocks as usize) let segment = VmAllocOptions::new(num_blocks as usize)
.uninit(true) .uninit(true)
.is_contiguous(true)
.alloc_contiguous()?; .alloc_contiguous()?;
segment.write_bytes(offset % BLOCK_SIZE, buf)?; segment.write_bytes(offset % BLOCK_SIZE, buf)?;
let len = segment let len = segment
@ -183,7 +181,6 @@ impl dyn BlockDevice {
}; };
let segment = VmAllocOptions::new(num_blocks as usize) let segment = VmAllocOptions::new(num_blocks as usize)
.uninit(true) .uninit(true)
.is_contiguous(true)
.alloc_contiguous()?; .alloc_contiguous()?;
segment.write_bytes(offset % BLOCK_SIZE, buf)?; segment.write_bytes(offset % BLOCK_SIZE, buf)?;
let len = segment let len = segment

View File

@ -10,17 +10,20 @@ extern crate alloc;
use alloc::{collections::BTreeMap, fmt::Debug, string::String, sync::Arc, vec::Vec}; use alloc::{collections::BTreeMap, fmt::Debug, string::String, sync::Arc, vec::Vec};
use core::any::Any; use core::any::Any;
use aster_frame::sync::SpinLock; use aster_frame::{sync::SpinLock, vm::VmReader};
use component::{init_component, ComponentInitError}; use component::{init_component, ComponentInitError};
use spin::Once; use spin::Once;
pub type ConsoleCallback = dyn Fn(&[u8]) + Send + Sync; pub type ConsoleCallback = dyn Fn(VmReader) + Send + Sync;
pub trait AnyConsoleDevice: Send + Sync + Any + Debug { pub trait AnyConsoleDevice: Send + Sync + Any + Debug {
fn send(&self, buf: &[u8]); fn send(&self, buf: &[u8]);
fn recv(&self, buf: &mut [u8]) -> Option<usize>; /// Registers callback to the console device.
/// The callback will be called once the console device receive data.
///
/// Since the callback will be called in interrupt context,
/// the callback should NEVER sleep.
fn register_callback(&self, callback: &'static ConsoleCallback); fn register_callback(&self, callback: &'static ConsoleCallback);
fn handle_irq(&self);
} }
pub fn register_device(name: String, device: Arc<dyn AnyConsoleDevice>) { pub fn register_device(name: String, device: Arc<dyn AnyConsoleDevice>) {
@ -32,16 +35,6 @@ pub fn register_device(name: String, device: Arc<dyn AnyConsoleDevice>) {
.insert(name, device); .insert(name, device);
} }
pub fn get_device(str: &str) -> Option<Arc<dyn AnyConsoleDevice>> {
COMPONENT
.get()
.unwrap()
.console_device_table
.lock_irq_disabled()
.get(str)
.cloned()
}
pub fn all_devices() -> Vec<(String, Arc<dyn AnyConsoleDevice>)> { pub fn all_devices() -> Vec<(String, Arc<dyn AnyConsoleDevice>)> {
let console_devs = COMPONENT let console_devs = COMPONENT
.get() .get()

View File

@ -23,7 +23,6 @@ pub enum InputEvent {
} }
pub trait InputDevice: Send + Sync + Any + Debug { pub trait InputDevice: Send + Sync + Any + Debug {
fn handle_irq(&self) -> Option<()>;
fn register_callbacks(&self, function: &'static (dyn Fn(InputEvent) + Send + Sync)); fn register_callbacks(&self, function: &'static (dyn Fn(InputEvent) + Send + Sync));
} }

View File

@ -6,16 +6,17 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies] [dependencies]
component = { path = "../../libs/comp-sys/component" } align_ext = { path = "../../../framework/libs/align_ext" }
aster-frame = { path = "../../../framework/aster-frame" } aster-frame = { path = "../../../framework/aster-frame" }
aster-util = { path = "../../libs/aster-util" } aster-util = { path = "../../libs/aster-util" }
aster-rights = { path = "../../libs/aster-rights" } aster-rights = { path = "../../libs/aster-rights" }
align_ext = { path = "../../../framework/libs/align_ext" }
int-to-c-enum = { path = "../../libs/int-to-c-enum" }
bytes = { version = "1.4.0", default-features = false }
pod = { git = "https://github.com/asterinas/pod", rev = "d7dba56" }
bitflags = "1.3" bitflags = "1.3"
spin = "0.9.4" bitvec = { version = "1.0.1", default-features = false, features = ["alloc"]}
ringbuf = { version = "0.3.2", default-features = false, features = ["alloc"] } component = { path = "../../libs/comp-sys/component" }
int-to-c-enum = { path = "../../libs/int-to-c-enum" }
ktest = { path = "../../../framework/libs/ktest" }
log = "0.4" log = "0.4"
pod = { git = "https://github.com/asterinas/pod", rev = "d7dba56" }
ringbuf = { version = "0.3.2", default-features = false, features = ["alloc"] }
smoltcp = { version = "0.9.1", default-features = false, features = ["alloc", "log", "medium-ethernet", "medium-ip", "proto-dhcpv4", "proto-ipv4", "proto-igmp", "socket-icmp", "socket-udp", "socket-tcp", "socket-raw", "socket-dhcpv4"] } smoltcp = { version = "0.9.1", default-features = false, features = ["alloc", "log", "medium-ethernet", "medium-ip", "proto-dhcpv4", "proto-ipv4", "proto-igmp", "socket-icmp", "socket-udp", "socket-tcp", "socket-raw", "socket-dhcpv4"] }
spin = "0.9.4"

View File

@ -1,30 +1,89 @@
// SPDX-License-Identifier: MPL-2.0 // SPDX-License-Identifier: MPL-2.0
use core::mem::size_of; use alloc::{collections::LinkedList, sync::Arc};
use align_ext::AlignExt; use align_ext::AlignExt;
use bytes::BytesMut; use aster_frame::{
sync::SpinLock,
vm::{Daddr, DmaDirection, DmaStream, HasDaddr, VmAllocOptions, VmReader, VmWriter, PAGE_SIZE},
};
use pod::Pod; use pod::Pod;
use spin::Once;
use crate::dma_pool::{DmaPool, DmaSegment};
pub struct TxBuffer {
dma_stream: DmaStream,
nbytes: usize,
}
impl TxBuffer {
pub fn new<H: Pod>(header: &H, packet: &[u8]) -> Self {
let header = header.as_bytes();
let nbytes = header.len() + packet.len();
let dma_stream = if let Some(stream) = get_tx_stream_from_pool(nbytes) {
stream
} else {
let segment = {
let nframes = (nbytes.align_up(PAGE_SIZE)) / PAGE_SIZE;
VmAllocOptions::new(nframes).alloc_contiguous().unwrap()
};
DmaStream::map(segment, DmaDirection::ToDevice, false).unwrap()
};
let mut writer = dma_stream.writer().unwrap();
writer.write(&mut VmReader::from(header));
writer.write(&mut VmReader::from(packet));
let tx_buffer = Self { dma_stream, nbytes };
tx_buffer.sync();
tx_buffer
}
pub fn writer(&self) -> VmWriter<'_> {
self.dma_stream.writer().unwrap().limit(self.nbytes)
}
fn sync(&self) {
self.dma_stream.sync(0..self.nbytes).unwrap();
}
pub fn nbytes(&self) -> usize {
self.nbytes
}
}
impl HasDaddr for TxBuffer {
fn daddr(&self) -> Daddr {
self.dma_stream.daddr()
}
}
impl Drop for TxBuffer {
fn drop(&mut self) {
TX_BUFFER_POOL
.get()
.unwrap()
.lock_irq_disabled()
.push_back(self.dma_stream.clone());
}
}
/// Buffer for receive packet
#[derive(Debug)]
pub struct RxBuffer { pub struct RxBuffer {
/// Packet Buffer, length align 8. segment: DmaSegment,
buf: BytesMut,
/// Header len
header_len: usize, header_len: usize,
/// Packet len
packet_len: usize, packet_len: usize,
} }
impl RxBuffer { impl RxBuffer {
pub fn new(len: usize, header_len: usize) -> Self { pub fn new(header_len: usize) -> Self {
let len = len.align_up(8); assert!(header_len <= RX_BUFFER_LEN);
let buf = BytesMut::zeroed(len); let segment = RX_BUFFER_POOL.get().unwrap().alloc_segment().unwrap();
Self { Self {
buf, segment,
packet_len: 0,
header_len, header_len,
packet_len: 0,
} }
} }
@ -33,59 +92,60 @@ impl RxBuffer {
} }
pub fn set_packet_len(&mut self, packet_len: usize) { pub fn set_packet_len(&mut self, packet_len: usize) {
assert!(self.header_len + packet_len <= RX_BUFFER_LEN);
self.packet_len = packet_len; self.packet_len = packet_len;
} }
pub fn buf(&self) -> &[u8] { pub fn packet(&self) -> VmReader<'_> {
&self.buf self.segment
.sync(self.header_len..self.header_len + self.packet_len)
.unwrap();
self.segment
.reader()
.unwrap()
.skip(self.header_len)
.limit(self.packet_len)
} }
pub fn buf_mut(&mut self) -> &mut [u8] { pub const fn buf_len(&self) -> usize {
&mut self.buf self.segment.size()
}
/// Packet payload slice, which is inner buffer excluding VirtioNetHdr.
pub fn packet(&self) -> &[u8] {
debug_assert!(self.header_len + self.packet_len <= self.buf.len());
&self.buf[self.header_len..self.header_len + self.packet_len]
}
/// Mutable packet payload slice.
pub fn packet_mut(&mut self) -> &mut [u8] {
debug_assert!(self.header_len + self.packet_len <= self.buf.len());
&mut self.buf[self.header_len..self.header_len + self.packet_len]
}
pub fn header<H: Pod>(&self) -> H {
debug_assert_eq!(size_of::<H>(), self.header_len);
H::from_bytes(&self.buf[..size_of::<H>()])
} }
} }
/// Buffer for transmit packet impl HasDaddr for RxBuffer {
#[derive(Debug)] fn daddr(&self) -> Daddr {
pub struct TxBuffer { self.segment.daddr()
buf: BytesMut,
}
impl TxBuffer {
pub fn with_len(buf_len: usize) -> Self {
Self {
buf: BytesMut::zeroed(buf_len),
}
}
pub fn new(buf: &[u8]) -> Self {
Self {
buf: BytesMut::from(buf),
}
}
pub fn buf(&self) -> &[u8] {
&self.buf
}
pub fn buf_mut(&mut self) -> &mut [u8] {
&mut self.buf
} }
} }
const RX_BUFFER_LEN: usize = 4096;
static RX_BUFFER_POOL: Once<Arc<DmaPool>> = Once::new();
static TX_BUFFER_POOL: Once<SpinLock<LinkedList<DmaStream>>> = Once::new();
fn get_tx_stream_from_pool(nbytes: usize) -> Option<DmaStream> {
let mut pool = TX_BUFFER_POOL.get().unwrap().lock_irq_disabled();
let mut cursor = pool.cursor_front_mut();
while let Some(current) = cursor.current() {
if current.nbytes() >= nbytes {
return cursor.remove_current();
}
cursor.move_next();
}
None
}
pub fn init() {
const POOL_INIT_SIZE: usize = 32;
const POOL_HIGH_WATERMARK: usize = 64;
RX_BUFFER_POOL.call_once(|| {
DmaPool::new(
RX_BUFFER_LEN,
POOL_INIT_SIZE,
POOL_HIGH_WATERMARK,
DmaDirection::FromDevice,
false,
)
});
TX_BUFFER_POOL.call_once(|| SpinLock::new(LinkedList::new()));
}

View File

@ -0,0 +1,367 @@
// SPDX-License-Identifier: MPL-2.0
#![allow(unused)]
use alloc::{
collections::VecDeque,
sync::{Arc, Weak},
};
use core::ops::Range;
use aster_frame::{
sync::{RwLock, SpinLock},
vm::{Daddr, DmaDirection, DmaStream, HasDaddr, VmAllocOptions, VmReader, VmWriter, PAGE_SIZE},
};
use bitvec::{array::BitArray, prelude::Lsb0};
use ktest::ktest;
/// `DmaPool` is responsible for allocating small streaming DMA segments
/// (equal to or smaller than PAGE_SIZE),
/// referred to as `DmaSegment`.
///
/// A `DmaPool` can only allocate `DmaSegment` of a fixed size.
/// Once a `DmaSegment` is dropped, it will be returned to the pool.
/// If the `DmaPool` is dropped before the associated `DmaSegment`,
/// the `drop` method of the `DmaSegment` will panic.
///
/// Therefore, as a best practice,
/// it is recommended for the `DmaPool` to have a static lifetime.
#[derive(Debug)]
pub struct DmaPool {
segment_size: usize,
direction: DmaDirection,
is_cache_coherent: bool,
high_watermark: usize,
avail_pages: SpinLock<VecDeque<Arc<DmaPage>>>,
all_pages: SpinLock<VecDeque<Arc<DmaPage>>>,
}
impl DmaPool {
/// Constructs a new `DmaPool` with a specified initial capacity and a high watermark.
///
/// The `DmaPool` starts with `init_size` DMAable pages.
/// As additional DMA blocks are requested beyond the initial capacity,
/// the pool dynamically allocates more DMAable pages.
/// To optimize performance, the pool employs a lazy deallocation strategy:
/// A DMAable page is freed only if it meets the following conditions:
/// 1. The page is currently not in use;
/// 2. The total number of allocated DMAable pages exceeds the specified `high_watermark`.
///
/// The returned pool can be used to allocate small segments for DMA usage.
/// All allocated segments will have the same DMA direction
/// and will either all be cache coherent or not cache coherent,
/// as specified in the parameters.
pub fn new(
segment_size: usize,
init_size: usize,
high_watermark: usize,
direction: DmaDirection,
is_cache_coherent: bool,
) -> Arc<Self> {
assert!(segment_size.is_power_of_two());
assert!(segment_size >= 64);
assert!(segment_size <= PAGE_SIZE);
assert!(high_watermark >= init_size);
Arc::new_cyclic(|pool| {
let mut avail_pages = VecDeque::new();
let mut all_pages = VecDeque::new();
for _ in 0..init_size {
let page = Arc::new(
DmaPage::new(
segment_size,
direction,
is_cache_coherent,
Weak::clone(pool),
)
.unwrap(),
);
avail_pages.push_back(page.clone());
all_pages.push_back(page);
}
Self {
segment_size,
direction,
is_cache_coherent,
high_watermark,
avail_pages: SpinLock::new(avail_pages),
all_pages: SpinLock::new(all_pages),
}
})
}
/// Allocates a `DmaSegment` from the pool
pub fn alloc_segment(self: &Arc<Self>) -> Result<DmaSegment, aster_frame::Error> {
// Lock order: pool.avail_pages -> pool.all_pages
// pool.avail_pages -> page.allocated_segments
let mut avail_pages = self.avail_pages.lock_irq_disabled();
if avail_pages.is_empty() {
/// Allocate a new page
let new_page = {
let pool = Arc::downgrade(self);
Arc::new(DmaPage::new(
self.segment_size,
self.direction,
self.is_cache_coherent,
pool,
)?)
};
let mut all_pages = self.all_pages.lock_irq_disabled();
avail_pages.push_back(new_page.clone());
all_pages.push_back(new_page);
}
let first_avail_page = avail_pages.front().unwrap();
let free_segment = first_avail_page.alloc_segment().unwrap();
if first_avail_page.is_full() {
avail_pages.pop_front();
}
Ok(free_segment)
}
/// Returns the number of pages in pool
fn num_pages(&self) -> usize {
self.all_pages.lock_irq_disabled().len()
}
}
#[derive(Debug)]
struct DmaPage {
storage: DmaStream,
segment_size: usize,
// `BitArray` is 64 bits, since each `DmaSegment` is bigger than 64 bytes,
// there's no more than `PAGE_SIZE` / 64 = 64 `DmaSegment`s in a `DmaPage`.
allocated_segments: SpinLock<BitArray>,
pool: Weak<DmaPool>,
}
impl DmaPage {
fn new(
segment_size: usize,
direction: DmaDirection,
is_cache_coherent: bool,
pool: Weak<DmaPool>,
) -> Result<Self, aster_frame::Error> {
let dma_stream = {
let vm_segment = VmAllocOptions::new(1).alloc_contiguous()?;
DmaStream::map(vm_segment, direction, is_cache_coherent)
.map_err(|_| aster_frame::Error::AccessDenied)?
};
Ok(Self {
storage: dma_stream,
segment_size,
allocated_segments: SpinLock::new(BitArray::ZERO),
pool,
})
}
fn alloc_segment(self: &Arc<Self>) -> Option<DmaSegment> {
let mut segments = self.allocated_segments.lock_irq_disabled();
let free_segment_index = get_next_free_index(&segments, self.nr_blocks_per_page())?;
segments.set(free_segment_index, true);
let segment = DmaSegment {
size: self.segment_size,
dma_stream: self.storage.clone(),
start_addr: self.storage.daddr() + free_segment_index * self.segment_size,
page: Arc::downgrade(self),
};
Some(segment)
}
fn is_free(&self) -> bool {
*self.allocated_segments.lock() == BitArray::<[usize; 1], Lsb0>::ZERO
}
const fn nr_blocks_per_page(&self) -> usize {
PAGE_SIZE / self.segment_size
}
fn is_full(&self) -> bool {
let segments = self.allocated_segments.lock_irq_disabled();
get_next_free_index(&segments, self.nr_blocks_per_page()).is_none()
}
}
fn get_next_free_index(segments: &BitArray, nr_blocks_per_page: usize) -> Option<usize> {
let free_segment_index = segments.iter_zeros().next()?;
if free_segment_index >= nr_blocks_per_page {
None
} else {
Some(free_segment_index)
}
}
impl HasDaddr for DmaPage {
fn daddr(&self) -> Daddr {
self.storage.daddr()
}
}
/// A small and fixed-size segment of DMA memory.
///
/// The size of `DmaSegment` ranges from 64 bytes to `PAGE_SIZE` and must be 2^K.
/// Each `DmaSegment`'s daddr must be aligned with its size.
#[derive(Debug)]
pub struct DmaSegment {
dma_stream: DmaStream,
start_addr: Daddr,
size: usize,
page: Weak<DmaPage>,
}
impl HasDaddr for DmaSegment {
fn daddr(&self) -> Daddr {
self.start_addr
}
}
impl DmaSegment {
pub const fn size(&self) -> usize {
self.size
}
pub fn reader(&self) -> Result<VmReader<'_>, aster_frame::Error> {
let offset = self.start_addr - self.dma_stream.daddr();
Ok(self.dma_stream.reader()?.skip(offset).limit(self.size))
}
pub fn writer(&self) -> Result<VmWriter<'_>, aster_frame::Error> {
let offset = self.start_addr - self.dma_stream.daddr();
Ok(self.dma_stream.writer()?.skip(offset).limit(self.size))
}
pub fn sync(&self, byte_range: Range<usize>) -> Result<(), aster_frame::Error> {
let offset = self.daddr() - self.dma_stream.daddr();
let range = byte_range.start + offset..byte_range.end + offset;
self.dma_stream.sync(range)
}
}
impl Drop for DmaSegment {
fn drop(&mut self) {
let page = self.page.upgrade().unwrap();
let pool = page.pool.upgrade().unwrap();
// Keep the same lock order as `pool.alloc_segment`
// Lock order: pool.avail_pages -> pool.all_pages -> page.allocated_segments
let mut avail_pages = pool.avail_pages.lock_irq_disabled();
let mut all_pages = pool.all_pages.lock_irq_disabled();
let mut allocated_segments = page.allocated_segments.lock_irq_disabled();
let nr_blocks_per_page = PAGE_SIZE / self.size;
let became_avail = get_next_free_index(&allocated_segments, nr_blocks_per_page).is_none();
debug_assert!((page.daddr()..page.daddr() + PAGE_SIZE).contains(&self.daddr()));
let segment_idx = (self.daddr() - page.daddr()) / self.size;
allocated_segments.set(segment_idx, false);
let became_free = allocated_segments.not_any();
if became_free && all_pages.len() > pool.high_watermark {
avail_pages.retain(|page_| !Arc::ptr_eq(page_, &page));
all_pages.retain(|page_| !Arc::ptr_eq(page_, &page));
return;
}
if became_avail {
avail_pages.push_back(page.clone());
}
}
}
#[cfg(ktest)]
mod test {
use alloc::vec::Vec;
use super::*;
#[ktest]
fn alloc_page_size_segment() {
let pool = DmaPool::new(PAGE_SIZE, 0, 100, DmaDirection::ToDevice, false);
let segments1: Vec<_> = (0..100)
.map(|_| {
let segment = pool.alloc_segment().unwrap();
assert_eq!(segment.size(), PAGE_SIZE);
assert!(segment.reader().is_err());
assert!(segment.writer().is_ok());
segment
})
.collect();
assert_eq!(pool.num_pages(), 100);
drop(segments1);
}
#[ktest]
fn write_to_dma_segment() {
let pool: Arc<DmaPool> = DmaPool::new(PAGE_SIZE, 1, 2, DmaDirection::ToDevice, false);
let segment = pool.alloc_segment().unwrap();
let mut writer = segment.writer().unwrap();
let data = &[0u8, 1, 2, 3, 4] as &[u8];
let size = writer.write(&mut VmReader::from(data));
assert_eq!(size, data.len());
}
#[ktest]
fn free_pool_pages() {
let pool: Arc<DmaPool> = DmaPool::new(PAGE_SIZE, 10, 50, DmaDirection::ToDevice, false);
let segments1: Vec<_> = (0..100)
.map(|_| {
let segment = pool.alloc_segment().unwrap();
assert_eq!(segment.size(), PAGE_SIZE);
assert!(segment.reader().is_err());
assert!(segment.writer().is_ok());
segment
})
.collect();
assert_eq!(pool.num_pages(), 100);
drop(segments1);
assert_eq!(pool.num_pages(), 50);
}
#[ktest]
fn alloc_small_size_segment() {
const SEGMENT_SIZE: usize = PAGE_SIZE / 4;
let pool: Arc<DmaPool> =
DmaPool::new(SEGMENT_SIZE, 0, 10, DmaDirection::Bidirectional, false);
let segments1: Vec<_> = (0..100)
.map(|_| {
let segment = pool.alloc_segment().unwrap();
assert_eq!(segment.size(), PAGE_SIZE / 4);
assert!(segment.reader().is_ok());
assert!(segment.writer().is_ok());
segment
})
.collect();
assert_eq!(pool.num_pages(), 100 / 4);
drop(segments1);
assert_eq!(pool.num_pages(), 10);
}
#[ktest]
fn read_dma_segments() {
const SEGMENT_SIZE: usize = PAGE_SIZE / 4;
let pool: Arc<DmaPool> =
DmaPool::new(SEGMENT_SIZE, 1, 2, DmaDirection::Bidirectional, false);
let segment = pool.alloc_segment().unwrap();
assert_eq!(pool.num_pages(), 1);
let mut writer = segment.writer().unwrap();
let data = &[0u8, 1, 2, 3, 4] as &[u8];
let size = writer.write(&mut VmReader::from(data));
assert_eq!(size, data.len());
let mut read_buf = [0u8; 5];
let mut reader = segment.reader().unwrap();
reader.read(&mut VmWriter::from(&mut read_buf as &mut [u8]));
assert_eq!(&read_buf, data);
}
}

View File

@ -2,12 +2,10 @@
use alloc::vec; use alloc::vec;
use aster_frame::vm::VmWriter;
use smoltcp::{phy, time::Instant}; use smoltcp::{phy, time::Instant};
use crate::{ use crate::{buffer::RxBuffer, AnyNetworkDevice};
buffer::{RxBuffer, TxBuffer},
AnyNetworkDevice,
};
impl phy::Device for dyn AnyNetworkDevice { impl phy::Device for dyn AnyNetworkDevice {
type RxToken<'a> = RxToken; type RxToken<'a> = RxToken;
@ -37,12 +35,14 @@ impl phy::Device for dyn AnyNetworkDevice {
pub struct RxToken(RxBuffer); pub struct RxToken(RxBuffer);
impl phy::RxToken for RxToken { impl phy::RxToken for RxToken {
fn consume<R, F>(mut self, f: F) -> R fn consume<R, F>(self, f: F) -> R
where where
F: FnOnce(&mut [u8]) -> R, F: FnOnce(&mut [u8]) -> R,
{ {
let packet_but = self.0.packet_mut(); let mut packet = self.0.packet();
f(packet_but) let mut buffer = vec![0u8; packet.remain()];
packet.read(&mut VmWriter::from(&mut buffer as &mut [u8]));
f(&mut buffer)
} }
} }
@ -55,8 +55,7 @@ impl<'a> phy::TxToken for TxToken<'a> {
{ {
let mut buffer = vec![0u8; len]; let mut buffer = vec![0u8; len];
let res = f(&mut buffer); let res = f(&mut buffer);
let tx_buffer = TxBuffer::new(&buffer); self.0.send(&buffer).expect("Send packet failed");
self.0.send(tx_buffer).expect("Send packet failed");
res res
} }
} }

View File

@ -4,9 +4,11 @@
#![forbid(unsafe_code)] #![forbid(unsafe_code)]
#![feature(trait_alias)] #![feature(trait_alias)]
#![feature(fn_traits)] #![feature(fn_traits)]
#![feature(linked_list_cursors)]
pub mod buffer; mod buffer;
pub mod driver; mod dma_pool;
mod driver;
extern crate alloc; extern crate alloc;
@ -15,8 +17,9 @@ use core::{any::Any, fmt::Debug};
use aster_frame::sync::SpinLock; use aster_frame::sync::SpinLock;
use aster_util::safe_ptr::Pod; use aster_util::safe_ptr::Pod;
use buffer::{RxBuffer, TxBuffer}; pub use buffer::{RxBuffer, TxBuffer};
use component::{init_component, ComponentInitError}; use component::{init_component, ComponentInitError};
pub use dma_pool::DmaSegment;
use smoltcp::phy; use smoltcp::phy;
use spin::Once; use spin::Once;
@ -45,7 +48,7 @@ pub trait AnyNetworkDevice: Send + Sync + Any + Debug {
/// Otherwise, return NotReady error. /// Otherwise, return NotReady error.
fn receive(&mut self) -> Result<RxBuffer, VirtioNetError>; fn receive(&mut self) -> Result<RxBuffer, VirtioNetError>;
/// Send a packet to network. Return until the request completes. /// Send a packet to network. Return until the request completes.
fn send(&mut self, tx_buffer: TxBuffer) -> Result<(), VirtioNetError>; fn send(&mut self, packet: &[u8]) -> Result<(), VirtioNetError>;
} }
pub trait NetDeviceIrqHandler = Fn() + Send + Sync + 'static; pub trait NetDeviceIrqHandler = Fn() + Send + Sync + 'static;
@ -55,38 +58,57 @@ pub fn register_device(name: String, device: Arc<SpinLock<dyn AnyNetworkDevice>>
.get() .get()
.unwrap() .unwrap()
.network_device_table .network_device_table
.lock() .lock_irq_disabled()
.insert(name, (Arc::new(SpinLock::new(Vec::new())), device)); .insert(name, (Arc::new(SpinLock::new(Vec::new())), device));
} }
pub fn get_device(str: &str) -> Option<Arc<SpinLock<dyn AnyNetworkDevice>>> { pub fn get_device(str: &str) -> Option<Arc<SpinLock<dyn AnyNetworkDevice>>> {
let lock = COMPONENT.get().unwrap().network_device_table.lock(); let table = COMPONENT
let (_, device) = lock.get(str)?; .get()
.unwrap()
.network_device_table
.lock_irq_disabled();
let (_, device) = table.get(str)?;
Some(device.clone()) Some(device.clone())
} }
/// Registers callback which will be called when receiving message.
///
/// Since the callback will be called in interrupt context,
/// the callback function should NOT sleep.
pub fn register_recv_callback(name: &str, callback: impl NetDeviceIrqHandler) { pub fn register_recv_callback(name: &str, callback: impl NetDeviceIrqHandler) {
let lock = COMPONENT.get().unwrap().network_device_table.lock(); let device_table = COMPONENT
let Some((callbacks, _)) = lock.get(name) else { .get()
.unwrap()
.network_device_table
.lock_irq_disabled();
let Some((callbacks, _)) = device_table.get(name) else {
return; return;
}; };
callbacks.lock().push(Arc::new(callback)); callbacks.lock_irq_disabled().push(Arc::new(callback));
} }
pub fn handle_recv_irq(name: &str) { pub fn handle_recv_irq(name: &str) {
let lock = COMPONENT.get().unwrap().network_device_table.lock(); let device_table = COMPONENT
let Some((callbacks, _)) = lock.get(name) else { .get()
.unwrap()
.network_device_table
.lock_irq_disabled();
let Some((callbacks, _)) = device_table.get(name) else {
return; return;
}; };
let callbacks = callbacks.clone(); let callbacks = callbacks.lock_irq_disabled();
let lock = callbacks.lock(); for callback in callbacks.iter() {
for callback in lock.iter() { callback();
callback.call(())
} }
} }
pub fn all_devices() -> Vec<(String, NetworkDeviceRef)> { pub fn all_devices() -> Vec<(String, NetworkDeviceRef)> {
let network_devs = COMPONENT.get().unwrap().network_device_table.lock(); let network_devs = COMPONENT
.get()
.unwrap()
.network_device_table
.lock_irq_disabled();
network_devs network_devs
.iter() .iter()
.map(|(name, (_, device))| (name.clone(), device.clone())) .map(|(name, (_, device))| (name.clone(), device.clone()))
@ -102,6 +124,7 @@ fn init() -> Result<(), ComponentInitError> {
let a = Component::init()?; let a = Component::init()?;
COMPONENT.call_once(|| a); COMPONENT.call_once(|| a);
NETWORK_IRQ_HANDLERS.call_once(|| SpinLock::new(Vec::new())); NETWORK_IRQ_HANDLERS.call_once(|| SpinLock::new(Vec::new()));
buffer::init();
Ok(()) Ok(())
} }

View File

@ -37,4 +37,3 @@ smoltcp = { version = "0.9.1", default-features = false, features = [
"socket-raw", "socket-raw",
"socket-dhcpv4", "socket-dhcpv4",
] } ] }
[features]

View File

@ -99,18 +99,12 @@ impl DeviceInner {
let queue = VirtQueue::new(0, Self::QUEUE_SIZE, transport.as_mut()) let queue = VirtQueue::new(0, Self::QUEUE_SIZE, transport.as_mut())
.expect("create virtqueue failed"); .expect("create virtqueue failed");
let block_requests = { let block_requests = {
let vm_segment = VmAllocOptions::new(1) let vm_segment = VmAllocOptions::new(1).alloc_contiguous().unwrap();
.is_contiguous(true)
.alloc_contiguous()
.unwrap();
DmaStream::map(vm_segment, DmaDirection::Bidirectional, false).unwrap() DmaStream::map(vm_segment, DmaDirection::Bidirectional, false).unwrap()
}; };
assert!(Self::QUEUE_SIZE as usize * REQ_SIZE <= block_requests.nbytes()); assert!(Self::QUEUE_SIZE as usize * REQ_SIZE <= block_requests.nbytes());
let block_responses = { let block_responses = {
let vm_segment = VmAllocOptions::new(1) let vm_segment = VmAllocOptions::new(1).alloc_contiguous().unwrap();
.is_contiguous(true)
.alloc_contiguous()
.unwrap();
DmaStream::map(vm_segment, DmaDirection::Bidirectional, false).unwrap() DmaStream::map(vm_segment, DmaDirection::Bidirectional, false).unwrap()
}; };
assert!(Self::QUEUE_SIZE as usize * RESP_SIZE <= block_responses.nbytes()); assert!(Self::QUEUE_SIZE as usize * RESP_SIZE <= block_responses.nbytes());
@ -222,7 +216,6 @@ impl DeviceInner {
const MAX_ID_LENGTH: usize = 20; const MAX_ID_LENGTH: usize = 20;
let device_id_stream = { let device_id_stream = {
let segment = VmAllocOptions::new(1) let segment = VmAllocOptions::new(1)
.is_contiguous(true)
.uninit(true) .uninit(true)
.alloc_contiguous() .alloc_contiguous()
.unwrap(); .unwrap();

View File

@ -4,7 +4,12 @@ use alloc::{boxed::Box, fmt::Debug, string::ToString, sync::Arc, vec::Vec};
use core::hint::spin_loop; use core::hint::spin_loop;
use aster_console::{AnyConsoleDevice, ConsoleCallback}; use aster_console::{AnyConsoleDevice, ConsoleCallback};
use aster_frame::{io_mem::IoMem, sync::SpinLock, trap::TrapFrame, vm::PAGE_SIZE}; use aster_frame::{
io_mem::IoMem,
sync::SpinLock,
trap::TrapFrame,
vm::{DmaDirection, DmaStream, DmaStreamSlice, VmAllocOptions, VmReader},
};
use aster_util::safe_ptr::SafePtr; use aster_util::safe_ptr::SafePtr;
use log::debug; use log::debug;
@ -17,17 +22,27 @@ use crate::{
pub struct ConsoleDevice { pub struct ConsoleDevice {
config: SafePtr<VirtioConsoleConfig, IoMem>, config: SafePtr<VirtioConsoleConfig, IoMem>,
transport: Box<dyn VirtioTransport>, transport: SpinLock<Box<dyn VirtioTransport>>,
receive_queue: SpinLock<VirtQueue>, receive_queue: SpinLock<VirtQueue>,
transmit_queue: SpinLock<VirtQueue>, transmit_queue: SpinLock<VirtQueue>,
buffer: SpinLock<Box<[u8; PAGE_SIZE]>>, send_buffer: DmaStream,
receive_buffer: DmaStream,
callbacks: SpinLock<Vec<&'static ConsoleCallback>>, callbacks: SpinLock<Vec<&'static ConsoleCallback>>,
} }
impl AnyConsoleDevice for ConsoleDevice { impl AnyConsoleDevice for ConsoleDevice {
fn send(&self, value: &[u8]) { fn send(&self, value: &[u8]) {
let mut transmit_queue = self.transmit_queue.lock_irq_disabled(); let mut transmit_queue = self.transmit_queue.lock_irq_disabled();
transmit_queue.add_buf(&[value], &[]).unwrap(); let mut reader = VmReader::from(value);
while reader.remain() > 0 {
let mut writer = self.send_buffer.writer().unwrap();
let len = writer.write(&mut reader);
self.send_buffer.sync(0..len).unwrap();
let slice = DmaStreamSlice::new(&self.send_buffer, 0, len);
transmit_queue.add_dma_buf(&[&slice], &[]).unwrap();
if transmit_queue.should_notify() { if transmit_queue.should_notify() {
transmit_queue.notify(); transmit_queue.notify();
} }
@ -36,43 +51,10 @@ impl AnyConsoleDevice for ConsoleDevice {
} }
transmit_queue.pop_used().unwrap(); transmit_queue.pop_used().unwrap();
} }
fn recv(&self, buf: &mut [u8]) -> Option<usize> {
let mut receive_queue = self.receive_queue.lock_irq_disabled();
if !receive_queue.can_pop() {
return None;
}
let (_, len) = receive_queue.pop_used().unwrap();
let mut recv_buffer = self.buffer.lock();
buf.copy_from_slice(&recv_buffer.as_ref()[..len as usize]);
receive_queue.add_buf(&[], &[recv_buffer.as_mut()]).unwrap();
if receive_queue.should_notify() {
receive_queue.notify();
}
Some(len as usize)
} }
fn register_callback(&self, callback: &'static (dyn Fn(&[u8]) + Send + Sync)) { fn register_callback(&self, callback: &'static ConsoleCallback) {
self.callbacks.lock().push(callback); self.callbacks.lock_irq_disabled().push(callback);
}
fn handle_irq(&self) {
let mut receive_queue = self.receive_queue.lock_irq_disabled();
if !receive_queue.can_pop() {
return;
}
let (_, len) = receive_queue.pop_used().unwrap();
let mut recv_buffer = self.buffer.lock();
let buffer = &recv_buffer.as_ref()[..len as usize];
let lock = self.callbacks.lock();
for callback in lock.iter() {
callback.call((buffer,));
}
receive_queue.add_buf(&[], &[recv_buffer.as_mut()]).unwrap();
if receive_queue.should_notify() {
receive_queue.notify();
}
} }
} }
@ -104,41 +86,76 @@ impl ConsoleDevice {
let transmit_queue = let transmit_queue =
SpinLock::new(VirtQueue::new(TRANSMIT0_QUEUE_INDEX, 2, transport.as_mut()).unwrap()); SpinLock::new(VirtQueue::new(TRANSMIT0_QUEUE_INDEX, 2, transport.as_mut()).unwrap());
let mut device = Self { let send_buffer = {
config, let vm_segment = VmAllocOptions::new(1).alloc_contiguous().unwrap();
transport, DmaStream::map(vm_segment, DmaDirection::ToDevice, false).unwrap()
receive_queue,
transmit_queue,
buffer: SpinLock::new(Box::new([0; PAGE_SIZE])),
callbacks: SpinLock::new(Vec::new()),
}; };
let mut receive_queue = device.receive_queue.lock(); let receive_buffer = {
let vm_segment = VmAllocOptions::new(1).alloc_contiguous().unwrap();
DmaStream::map(vm_segment, DmaDirection::FromDevice, false).unwrap()
};
let device = Arc::new(Self {
config,
transport: SpinLock::new(transport),
receive_queue,
transmit_queue,
send_buffer,
receive_buffer,
callbacks: SpinLock::new(Vec::new()),
});
let mut receive_queue = device.receive_queue.lock_irq_disabled();
receive_queue receive_queue
.add_buf(&[], &[device.buffer.lock().as_mut()]) .add_dma_buf(&[], &[&device.receive_buffer])
.unwrap(); .unwrap();
if receive_queue.should_notify() { if receive_queue.should_notify() {
receive_queue.notify(); receive_queue.notify();
} }
drop(receive_queue); drop(receive_queue);
device
.transport // Register irq callbacks
let mut transport = device.transport.lock_irq_disabled();
let handle_console_input = {
let device = device.clone();
move |_: &TrapFrame| device.handle_recv_irq()
};
transport
.register_queue_callback(RECV0_QUEUE_INDEX, Box::new(handle_console_input), false) .register_queue_callback(RECV0_QUEUE_INDEX, Box::new(handle_console_input), false)
.unwrap(); .unwrap();
device transport
.transport
.register_cfg_callback(Box::new(config_space_change)) .register_cfg_callback(Box::new(config_space_change))
.unwrap(); .unwrap();
device.transport.finish_init(); transport.finish_init();
drop(transport);
aster_console::register_device(DEVICE_NAME.to_string(), Arc::new(device)); aster_console::register_device(DEVICE_NAME.to_string(), device);
Ok(()) Ok(())
} }
}
fn handle_console_input(_: &TrapFrame) { fn handle_recv_irq(&self) {
aster_console::get_device(DEVICE_NAME).unwrap().handle_irq(); let mut receive_queue = self.receive_queue.lock_irq_disabled();
if !receive_queue.can_pop() {
return;
}
let (_, len) = receive_queue.pop_used().unwrap();
self.receive_buffer.sync(0..len as usize).unwrap();
let callbacks = self.callbacks.lock_irq_disabled();
for callback in callbacks.iter() {
let reader = self.receive_buffer.reader().unwrap().limit(len as usize);
callback(reader);
}
receive_queue
.add_dma_buf(&[], &[&self.receive_buffer])
.unwrap();
if receive_queue.should_notify() {
receive_queue.notify();
}
}
} }
fn config_space_change(_: &TrapFrame) { fn config_space_change(_: &TrapFrame) {

View File

@ -6,9 +6,15 @@ use alloc::{
sync::Arc, sync::Arc,
vec::Vec, vec::Vec,
}; };
use core::fmt::Debug; use core::{fmt::Debug, mem};
use aster_frame::{io_mem::IoMem, offset_of, sync::SpinLock, trap::TrapFrame}; use aster_frame::{
io_mem::IoMem,
offset_of,
sync::SpinLock,
trap::TrapFrame,
vm::{Daddr, DmaDirection, DmaStream, HasDaddr, VmAllocOptions, VmIo, VmReader, PAGE_SIZE},
};
use aster_input::{ use aster_input::{
key::{Key, KeyStatus}, key::{Key, KeyStatus},
InputEvent, InputEvent,
@ -16,10 +22,11 @@ use aster_input::{
use aster_util::{field_ptr, safe_ptr::SafePtr}; use aster_util::{field_ptr, safe_ptr::SafePtr};
use bitflags::bitflags; use bitflags::bitflags;
use log::{debug, info}; use log::{debug, info};
use pod::Pod;
use super::{InputConfigSelect, VirtioInputConfig, VirtioInputEvent, QUEUE_EVENT, QUEUE_STATUS}; use super::{InputConfigSelect, VirtioInputConfig, VirtioInputEvent, QUEUE_EVENT, QUEUE_STATUS};
use crate::{device::VirtioDeviceError, queue::VirtQueue, transport::VirtioTransport}; use crate::{
device::VirtioDeviceError, dma_buf::DmaBuf, queue::VirtQueue, transport::VirtioTransport,
};
bitflags! { bitflags! {
/// The properties of input device. /// The properties of input device.
@ -67,25 +74,25 @@ pub struct InputDevice {
config: SafePtr<VirtioInputConfig, IoMem>, config: SafePtr<VirtioInputConfig, IoMem>,
event_queue: SpinLock<VirtQueue>, event_queue: SpinLock<VirtQueue>,
status_queue: VirtQueue, status_queue: VirtQueue,
event_buf: SpinLock<Box<[VirtioInputEvent; QUEUE_SIZE as usize]>>, event_table: EventTable,
#[allow(clippy::type_complexity)] #[allow(clippy::type_complexity)]
callbacks: SpinLock<Vec<Arc<dyn Fn(InputEvent) + Send + Sync + 'static>>>, callbacks: SpinLock<Vec<Arc<dyn Fn(InputEvent) + Send + Sync + 'static>>>,
transport: Box<dyn VirtioTransport>, transport: SpinLock<Box<dyn VirtioTransport>>,
} }
impl InputDevice { impl InputDevice {
/// Create a new VirtIO-Input driver. /// Create a new VirtIO-Input driver.
/// msix_vector_left should at least have one element or n elements where n is the virtqueue amount /// msix_vector_left should at least have one element or n elements where n is the virtqueue amount
pub fn init(mut transport: Box<dyn VirtioTransport>) -> Result<(), VirtioDeviceError> { pub fn init(mut transport: Box<dyn VirtioTransport>) -> Result<(), VirtioDeviceError> {
let mut event_buf = Box::new([VirtioInputEvent::default(); QUEUE_SIZE as usize]);
let mut event_queue = VirtQueue::new(QUEUE_EVENT, QUEUE_SIZE, transport.as_mut()) let mut event_queue = VirtQueue::new(QUEUE_EVENT, QUEUE_SIZE, transport.as_mut())
.expect("create event virtqueue failed"); .expect("create event virtqueue failed");
let status_queue = VirtQueue::new(QUEUE_STATUS, QUEUE_SIZE, transport.as_mut()) let status_queue = VirtQueue::new(QUEUE_STATUS, QUEUE_SIZE, transport.as_mut())
.expect("create status virtqueue failed"); .expect("create status virtqueue failed");
for (i, event) in event_buf.as_mut().iter_mut().enumerate() { let event_table = EventTable::new(QUEUE_SIZE as usize);
// FIEME: replace slice with a more secure data structure to use dma mapping. for i in 0..event_table.num_events() {
let token = event_queue.add_buf(&[], &[event.as_bytes_mut()]); let event_buf = event_table.get(i);
let token = event_queue.add_dma_buf(&[], &[&event_buf]);
match token { match token {
Ok(value) => { Ok(value) => {
assert_eq!(value, i as u16); assert_eq!(value, i as u16);
@ -96,14 +103,14 @@ impl InputDevice {
} }
} }
let mut device = Self { let device = Arc::new(Self {
config: VirtioInputConfig::new(transport.as_mut()), config: VirtioInputConfig::new(transport.as_mut()),
event_queue: SpinLock::new(event_queue), event_queue: SpinLock::new(event_queue),
status_queue, status_queue,
event_buf: SpinLock::new(event_buf), event_table,
transport, transport: SpinLock::new(transport),
callbacks: SpinLock::new(Vec::new()), callbacks: SpinLock::new(Vec::new()),
}; });
let mut raw_name: [u8; 128] = [0; 128]; let mut raw_name: [u8; 128] = [0; 128];
device.query_config_select(InputConfigSelect::IdName, 0, &mut raw_name); device.query_config_select(InputConfigSelect::IdName, 0, &mut raw_name);
@ -115,51 +122,49 @@ impl InputDevice {
let input_prop = InputProp::from_bits(prop[0]).unwrap(); let input_prop = InputProp::from_bits(prop[0]).unwrap();
debug!("input device prop:{:?}", input_prop); debug!("input device prop:{:?}", input_prop);
fn handle_input(_: &TrapFrame) { let mut transport = device.transport.lock_irq_disabled();
debug!("Handle Virtio input interrupt");
let device = aster_input::get_device(super::DEVICE_NAME).unwrap();
device.handle_irq().unwrap();
}
fn config_space_change(_: &TrapFrame) { fn config_space_change(_: &TrapFrame) {
debug!("input device config space change"); debug!("input device config space change");
} }
transport
device
.transport
.register_cfg_callback(Box::new(config_space_change)) .register_cfg_callback(Box::new(config_space_change))
.unwrap(); .unwrap();
device
.transport let handle_input = {
let device = device.clone();
move |_: &TrapFrame| device.handle_irq()
};
transport
.register_queue_callback(QUEUE_EVENT, Box::new(handle_input), false) .register_queue_callback(QUEUE_EVENT, Box::new(handle_input), false)
.unwrap(); .unwrap();
device.transport.finish_init(); transport.finish_init();
drop(transport);
aster_input::register_device(super::DEVICE_NAME.to_string(), Arc::new(device)); aster_input::register_device(super::DEVICE_NAME.to_string(), device);
Ok(()) Ok(())
} }
/// Pop the pending event. /// Pop the pending event.
pub fn pop_pending_event(&self) -> Option<VirtioInputEvent> { fn pop_pending_events(&self, handle_event: &impl Fn(&EventBuf) -> bool) {
let mut lock = self.event_queue.lock(); let mut event_queue = self.event_queue.lock_irq_disabled();
if let Ok((token, _)) = lock.pop_used() {
if token >= QUEUE_SIZE { // one interrupt may contain several input events, so it should loop
return None; while let Ok((token, _)) = event_queue.pop_used() {
} debug_assert!(token < QUEUE_SIZE);
let event = &mut self.event_buf.lock()[token as usize]; let event_buf = self.event_table.get(token as usize);
// requeue let res = handle_event(&event_buf);
// FIEME: replace slice with a more secure data structure to use dma mapping. let new_token = event_queue.add_dma_buf(&[], &[&event_buf]).unwrap();
if let Ok(new_token) = lock.add_buf(&[], &[event.as_bytes_mut()]) {
// This only works because nothing happen between `pop_used` and `add` that affects // This only works because nothing happen between `pop_used` and `add` that affects
// the list of free descriptors in the queue, so `add` reuses the descriptor which // the list of free descriptors in the queue, so `add` reuses the descriptor which
// was just freed by `pop_used`. // was just freed by `pop_used`.
assert_eq!(new_token, token); assert_eq!(new_token, token);
return Some(*event);
if !res {
break;
} }
} }
None
} }
/// Query a specific piece of information by `select` and `subsel`, and write /// Query a specific piece of information by `select` and `subsel`, and write
@ -181,6 +186,42 @@ impl InputDevice {
size size
} }
fn handle_irq(&self) {
// Returns ture if there may be more events to handle
let handle_event = |event: &EventBuf| -> bool {
let event: VirtioInputEvent = {
let mut reader = event.reader();
reader.read_val()
};
match event.event_type {
0 => return false,
// Keyboard
1 => {}
// TODO: Support mouse device.
_ => return true,
}
let status = match event.value {
1 => KeyStatus::Pressed,
0 => KeyStatus::Released,
_ => return false,
};
let event = InputEvent::KeyBoard(Key::try_from(event.code).unwrap(), status);
info!("Input Event:{:?}", event);
let callbacks = self.callbacks.lock();
for callback in callbacks.iter() {
callback(event);
}
true
};
self.pop_pending_events(&handle_event);
}
/// Negotiate features for the device specified bits 0~23 /// Negotiate features for the device specified bits 0~23
pub(crate) fn negotiate_features(features: u64) -> u64 { pub(crate) fn negotiate_features(features: u64) -> u64 {
assert_eq!(features, 0); assert_eq!(features, 0);
@ -188,37 +229,85 @@ impl InputDevice {
} }
} }
/// A event table consists of many event buffers,
/// each of which is large enough to contain a `VirtioInputEvent`.
#[derive(Debug)]
struct EventTable {
stream: DmaStream,
num_events: usize,
}
impl EventTable {
fn new(num_events: usize) -> Self {
assert!(num_events * mem::size_of::<VirtioInputEvent>() <= PAGE_SIZE);
let vm_segment = VmAllocOptions::new(1).alloc_contiguous().unwrap();
let default_event = VirtioInputEvent::default();
for idx in 0..num_events {
let offset = idx * EVENT_SIZE;
vm_segment.write_val(offset, &default_event).unwrap();
}
let stream = DmaStream::map(vm_segment, DmaDirection::FromDevice, false).unwrap();
Self { stream, num_events }
}
fn get(&self, idx: usize) -> EventBuf<'_> {
assert!(idx < self.num_events);
let offset = idx * EVENT_SIZE;
EventBuf {
event_table: self,
offset,
size: EVENT_SIZE,
}
}
const fn num_events(&self) -> usize {
self.num_events
}
}
const EVENT_SIZE: usize = core::mem::size_of::<VirtioInputEvent>();
/// A buffer stores exact one `VirtioInputEvent`
struct EventBuf<'a> {
event_table: &'a EventTable,
offset: usize,
size: usize,
}
impl<'a> HasDaddr for EventBuf<'a> {
fn daddr(&self) -> Daddr {
self.event_table.stream.daddr() + self.offset
}
}
impl<'a> DmaBuf for EventBuf<'a> {
fn len(&self) -> usize {
self.size
}
}
impl<'a> EventBuf<'a> {
fn reader(&self) -> VmReader<'a> {
self.event_table
.stream
.sync(self.offset..self.offset + self.size)
.unwrap();
self.event_table
.stream
.reader()
.unwrap()
.skip(self.offset)
.limit(self.size)
}
}
impl aster_input::InputDevice for InputDevice { impl aster_input::InputDevice for InputDevice {
fn handle_irq(&self) -> Option<()> {
// one interrupt may contains serval input, so it should loop
loop {
let Some(event) = self.pop_pending_event() else {
return Some(());
};
match event.event_type {
0 => return Some(()),
// Keyboard
1 => {}
// TODO: Support mouse device.
_ => continue,
}
let status = match event.value {
1 => KeyStatus::Pressed,
0 => KeyStatus::Released,
_ => return Some(()),
};
let event = InputEvent::KeyBoard(Key::try_from(event.code).unwrap(), status);
info!("Input Event:{:?}", event);
let callbacks = self.callbacks.lock();
for callback in callbacks.iter() {
callback.call((event,));
}
}
}
fn register_callbacks(&self, function: &'static (dyn Fn(InputEvent) + Send + Sync)) { fn register_callbacks(&self, function: &'static (dyn Fn(InputEvent) + Send + Sync)) {
self.callbacks.lock().push(Arc::new(function)) self.callbacks.lock_irq_disabled().push(Arc::new(function))
} }
} }
@ -228,7 +317,7 @@ impl Debug for InputDevice {
.field("config", &self.config) .field("config", &self.config)
.field("event_queue", &self.event_queue) .field("event_queue", &self.event_queue)
.field("status_queue", &self.status_queue) .field("status_queue", &self.status_queue)
.field("event_buf", &self.event_buf) .field("event_buf", &self.event_table)
.field("transport", &self.transport) .field("transport", &self.transport)
.finish() .finish()
} }

View File

@ -5,12 +5,10 @@ use core::{fmt::Debug, hint::spin_loop, mem::size_of};
use aster_frame::{offset_of, sync::SpinLock, trap::TrapFrame}; use aster_frame::{offset_of, sync::SpinLock, trap::TrapFrame};
use aster_network::{ use aster_network::{
buffer::{RxBuffer, TxBuffer}, AnyNetworkDevice, EthernetAddr, NetDeviceIrqHandler, RxBuffer, TxBuffer, VirtioNetError,
AnyNetworkDevice, EthernetAddr, NetDeviceIrqHandler, VirtioNetError,
}; };
use aster_util::{field_ptr, slot_vec::SlotVec}; use aster_util::{field_ptr, slot_vec::SlotVec};
use log::debug; use log::debug;
use pod::Pod;
use smoltcp::phy::{DeviceCapabilities, Medium}; use smoltcp::phy::{DeviceCapabilities, Medium};
use super::{config::VirtioNetConfig, header::VirtioNetHdr}; use super::{config::VirtioNetConfig, header::VirtioNetHdr};
@ -60,9 +58,9 @@ impl NetworkDevice {
let mut rx_buffers = SlotVec::new(); let mut rx_buffers = SlotVec::new();
for i in 0..QUEUE_SIZE { for i in 0..QUEUE_SIZE {
let mut rx_buffer = RxBuffer::new(RX_BUFFER_LEN, size_of::<VirtioNetHdr>()); let rx_buffer = RxBuffer::new(size_of::<VirtioNetHdr>());
// FIEME: Replace rx_buffer with VM segment-based data structure to use dma mapping. // FIEME: Replace rx_buffer with VM segment-based data structure to use dma mapping.
let token = recv_queue.add_buf(&[], &[rx_buffer.buf_mut()])?; let token = recv_queue.add_dma_buf(&[], &[&rx_buffer])?;
assert_eq!(i, token); assert_eq!(i, token);
assert_eq!(rx_buffers.put(rx_buffer) as u16, i); assert_eq!(rx_buffers.put(rx_buffer) as u16, i);
} }
@ -80,7 +78,7 @@ impl NetworkDevice {
transport, transport,
callbacks: Vec::new(), callbacks: Vec::new(),
}; };
device.transport.finish_init();
/// Interrupt handler if network device config space changes /// Interrupt handler if network device config space changes
fn config_space_change(_: &TrapFrame) { fn config_space_change(_: &TrapFrame) {
debug!("network device config space change"); debug!("network device config space change");
@ -99,6 +97,7 @@ impl NetworkDevice {
.transport .transport
.register_queue_callback(QUEUE_RECV, Box::new(handle_network_event), false) .register_queue_callback(QUEUE_RECV, Box::new(handle_network_event), false)
.unwrap(); .unwrap();
device.transport.finish_init();
aster_network::register_device( aster_network::register_device(
super::DEVICE_NAME.to_string(), super::DEVICE_NAME.to_string(),
@ -109,10 +108,10 @@ impl NetworkDevice {
/// Add a rx buffer to recv queue /// Add a rx buffer to recv queue
/// FIEME: Replace rx_buffer with VM segment-based data structure to use dma mapping. /// FIEME: Replace rx_buffer with VM segment-based data structure to use dma mapping.
fn add_rx_buffer(&mut self, mut rx_buffer: RxBuffer) -> Result<(), VirtioNetError> { fn add_rx_buffer(&mut self, rx_buffer: RxBuffer) -> Result<(), VirtioNetError> {
let token = self let token = self
.recv_queue .recv_queue
.add_buf(&[], &[rx_buffer.buf_mut()]) .add_dma_buf(&[], &[&rx_buffer])
.map_err(queue_to_network_error)?; .map_err(queue_to_network_error)?;
assert!(self.rx_buffers.put_at(token as usize, rx_buffer).is_none()); assert!(self.rx_buffers.put_at(token as usize, rx_buffer).is_none());
if self.recv_queue.should_notify() { if self.recv_queue.should_notify() {
@ -133,18 +132,20 @@ impl NetworkDevice {
rx_buffer.set_packet_len(len as usize); rx_buffer.set_packet_len(len as usize);
// FIXME: Ideally, we can reuse the returned buffer without creating new buffer. // FIXME: Ideally, we can reuse the returned buffer without creating new buffer.
// But this requires locking device to be compatible with smoltcp interface. // But this requires locking device to be compatible with smoltcp interface.
let new_rx_buffer = RxBuffer::new(RX_BUFFER_LEN, size_of::<VirtioNetHdr>()); let new_rx_buffer = RxBuffer::new(size_of::<VirtioNetHdr>());
self.add_rx_buffer(new_rx_buffer)?; self.add_rx_buffer(new_rx_buffer)?;
Ok(rx_buffer) Ok(rx_buffer)
} }
/// Send a packet to network. Return until the request completes. /// Send a packet to network. Return until the request completes.
/// FIEME: Replace tx_buffer with VM segment-based data structure to use dma mapping. /// FIEME: Replace tx_buffer with VM segment-based data structure to use dma mapping.
fn send(&mut self, tx_buffer: TxBuffer) -> Result<(), VirtioNetError> { fn send(&mut self, packet: &[u8]) -> Result<(), VirtioNetError> {
let header = VirtioNetHdr::default(); let header = VirtioNetHdr::default();
let tx_buffer = TxBuffer::new(&header, packet);
let token = self let token = self
.send_queue .send_queue
.add_buf(&[header.as_bytes(), tx_buffer.buf()], &[]) .add_dma_buf(&[&tx_buffer], &[])
.map_err(queue_to_network_error)?; .map_err(queue_to_network_error)?;
if self.send_queue.should_notify() { if self.send_queue.should_notify() {
@ -198,8 +199,8 @@ impl AnyNetworkDevice for NetworkDevice {
self.receive() self.receive()
} }
fn send(&mut self, tx_buffer: TxBuffer) -> Result<(), VirtioNetError> { fn send(&mut self, packet: &[u8]) -> Result<(), VirtioNetError> {
self.send(tx_buffer) self.send(packet)
} }
} }

View File

@ -1,6 +1,7 @@
// SPDX-License-Identifier: MPL-2.0 // SPDX-License-Identifier: MPL-2.0
use aster_frame::vm::{DmaCoherent, DmaStream, DmaStreamSlice, HasDaddr}; use aster_frame::vm::{DmaCoherent, DmaStream, DmaStreamSlice, HasDaddr};
use aster_network::{DmaSegment, RxBuffer, TxBuffer};
/// A DMA-capable buffer. /// A DMA-capable buffer.
/// ///
@ -29,3 +30,21 @@ impl DmaBuf for DmaCoherent {
self.nbytes() self.nbytes()
} }
} }
impl DmaBuf for DmaSegment {
fn len(&self) -> usize {
self.size()
}
}
impl DmaBuf for TxBuffer {
fn len(&self) -> usize {
self.nbytes()
}
}
impl DmaBuf for RxBuffer {
fn len(&self) -> usize {
self.buf_len()
}
}

View File

@ -82,10 +82,7 @@ impl VirtQueue {
let desc_size = size_of::<Descriptor>() * size as usize; let desc_size = size_of::<Descriptor>() * size as usize;
let (seg1, seg2) = { let (seg1, seg2) = {
let continue_segment = VmAllocOptions::new(2) let continue_segment = VmAllocOptions::new(2).alloc_contiguous().unwrap();
.is_contiguous(true)
.alloc_contiguous()
.unwrap();
let seg1 = continue_segment.range(0..1); let seg1 = continue_segment.range(0..1);
let seg2 = continue_segment.range(1..2); let seg2 = continue_segment.range(1..2);
(seg1, seg2) (seg1, seg2)
@ -104,35 +101,17 @@ impl VirtQueue {
} }
( (
SafePtr::new( SafePtr::new(
DmaCoherent::map( DmaCoherent::map(VmAllocOptions::new(1).alloc_contiguous().unwrap(), true)
VmAllocOptions::new(1)
.is_contiguous(true)
.alloc_contiguous()
.unwrap(),
true,
)
.unwrap(), .unwrap(),
0, 0,
), ),
SafePtr::new( SafePtr::new(
DmaCoherent::map( DmaCoherent::map(VmAllocOptions::new(1).alloc_contiguous().unwrap(), true)
VmAllocOptions::new(1)
.is_contiguous(true)
.alloc_contiguous()
.unwrap(),
true,
)
.unwrap(), .unwrap(),
0, 0,
), ),
SafePtr::new( SafePtr::new(
DmaCoherent::map( DmaCoherent::map(VmAllocOptions::new(1).alloc_contiguous().unwrap(), true)
VmAllocOptions::new(1)
.is_contiguous(true)
.alloc_contiguous()
.unwrap(),
true,
)
.unwrap(), .unwrap(),
0, 0,
), ),