diff --git a/framework/aster-frame/src/vm/dma/dma_stream.rs b/framework/aster-frame/src/vm/dma/dma_stream.rs index cee09336..b348847b 100644 --- a/framework/aster-frame/src/vm/dma/dma_stream.rs +++ b/framework/aster-frame/src/vm/dma/dma_stream.rs @@ -38,7 +38,7 @@ struct DmaStreamInner { /// `DmaDirection` limits the data flow direction of `DmaStream` and /// prevents users from reading and writing to `DmaStream` unexpectedly. -#[derive(Debug, PartialEq, Clone)] +#[derive(Debug, PartialEq, Clone, Copy)] pub enum DmaDirection { ToDevice, FromDevice, @@ -108,6 +108,10 @@ impl DmaStream { &self.inner.vm_segment } + pub fn nframes(&self) -> usize { + self.inner.vm_segment.nframes() + } + pub fn nbytes(&self) -> usize { self.inner.vm_segment.nbytes() } diff --git a/framework/aster-frame/src/vm/frame.rs b/framework/aster-frame/src/vm/frame.rs index 23f10c0d..983e148f 100644 --- a/framework/aster-frame/src/vm/frame.rs +++ b/framework/aster-frame/src/vm/frame.rs @@ -595,6 +595,21 @@ impl<'a> VmReader<'a> { } copy_len } + + /// Read a value of `Pod` type. + /// + /// # Panic + /// + /// If the length of the `Pod` type exceeds `self.remain()`, then this method will panic. + pub fn read_val(&mut self) -> T { + assert!(self.remain() >= core::mem::size_of::()); + + let mut val = T::new_uninit(); + let mut writer = VmWriter::from(val.as_bytes_mut()); + let read_len = self.read(&mut writer); + + val + } } impl<'a> From<&'a [u8]> for VmReader<'a> { diff --git a/framework/aster-frame/src/vm/mod.rs b/framework/aster-frame/src/vm/mod.rs index 2c14639c..c3ef6455 100644 --- a/framework/aster-frame/src/vm/mod.rs +++ b/framework/aster-frame/src/vm/mod.rs @@ -25,7 +25,7 @@ use core::ops::Range; use spin::Once; pub use self::{ - dma::{DmaCoherent, DmaDirection, DmaStream, HasDaddr}, + dma::{Daddr, DmaCoherent, DmaDirection, DmaStream, HasDaddr}, frame::{VmFrame, VmFrameVec, VmFrameVecIter, VmReader, VmSegment, VmWriter}, io::VmIo, memory_set::{MapArea, MemorySet}, diff --git a/kernel/comps/virtio/src/dma_buf.rs b/kernel/comps/virtio/src/dma_buf.rs new file mode 100644 index 00000000..0395cd90 --- /dev/null +++ b/kernel/comps/virtio/src/dma_buf.rs @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: MPL-2.0 + +use aster_frame::vm::{DmaCoherent, DmaStream, HasDaddr}; + +/// A DMA-capable buffer. +/// +/// Any type implements this trait should also implements `HasDaddr` trait, +/// and provides the exact length of DMA area. +#[allow(clippy::len_without_is_empty)] +pub trait DmaBuf: HasDaddr { + /// The length of Dma area, in bytes + fn len(&self) -> usize; +} + +impl DmaBuf for DmaStream { + fn len(&self) -> usize { + self.nbytes() + } +} + +impl DmaBuf for DmaCoherent { + fn len(&self) -> usize { + self.nbytes() + } +} diff --git a/kernel/comps/virtio/src/lib.rs b/kernel/comps/virtio/src/lib.rs index 09fc483b..b2faa7b6 100644 --- a/kernel/comps/virtio/src/lib.rs +++ b/kernel/comps/virtio/src/lib.rs @@ -22,6 +22,7 @@ use transport::{mmio::VIRTIO_MMIO_DRIVER, pci::VIRTIO_PCI_DRIVER, DeviceStatus}; use crate::transport::VirtioTransport; pub mod device; +mod dma_buf; pub mod queue; mod transport; diff --git a/kernel/comps/virtio/src/queue.rs b/kernel/comps/virtio/src/queue.rs index eef4d14e..40e34496 100644 --- a/kernel/comps/virtio/src/queue.rs +++ b/kernel/comps/virtio/src/queue.rs @@ -19,7 +19,7 @@ use bitflags::bitflags; use log::debug; use pod::Pod; -use crate::transport::VirtioTransport; +use crate::{dma_buf::DmaBuf, transport::VirtioTransport}; #[derive(Debug)] pub enum QueueError { @@ -176,6 +176,76 @@ impl VirtQueue { }) } + /// Add dma buffers to the virtqueue, return a token. + /// + /// Ref: linux virtio_ring.c virtqueue_add + pub fn add_dma_buf( + &mut self, + inputs: &[&T], + outputs: &[&T], + ) -> Result { + if inputs.is_empty() && outputs.is_empty() { + return Err(QueueError::InvalidArgs); + } + if inputs.len() + outputs.len() + self.num_used as usize > self.queue_size as usize { + return Err(QueueError::BufferTooSmall); + } + + // allocate descriptors from free list + let head = self.free_head; + let mut last = self.free_head; + for input in inputs.iter() { + let desc = &self.descs[self.free_head as usize]; + set_dma_buf(&desc.borrow_vm().restrict::(), *input); + field_ptr!(desc, Descriptor, flags) + .write(&DescFlags::NEXT) + .unwrap(); + last = self.free_head; + self.free_head = field_ptr!(desc, Descriptor, next).read().unwrap(); + } + for output in outputs.iter() { + let desc = &mut self.descs[self.free_head as usize]; + set_dma_buf( + &desc.borrow_vm().restrict::(), + *output, + ); + field_ptr!(desc, Descriptor, flags) + .write(&(DescFlags::NEXT | DescFlags::WRITE)) + .unwrap(); + last = self.free_head; + self.free_head = field_ptr!(desc, Descriptor, next).read().unwrap(); + } + // set last_elem.next = NULL + { + let desc = &mut self.descs[last as usize]; + let mut flags: DescFlags = field_ptr!(desc, Descriptor, flags).read().unwrap(); + flags.remove(DescFlags::NEXT); + field_ptr!(desc, Descriptor, flags).write(&flags).unwrap(); + } + self.num_used += (inputs.len() + outputs.len()) as u16; + + let avail_slot = self.avail_idx & (self.queue_size - 1); + + { + let ring_ptr: SafePtr<[u16; 64], &DmaCoherent> = + field_ptr!(&self.avail, AvailRing, ring); + let mut ring_slot_ptr = ring_ptr.cast::(); + ring_slot_ptr.add(avail_slot as usize); + ring_slot_ptr.write(&head).unwrap(); + } + // write barrier + fence(Ordering::SeqCst); + + // increase head of avail ring + self.avail_idx = self.avail_idx.wrapping_add(1); + field_ptr!(&self.avail, AvailRing, idx) + .write(&self.avail_idx) + .unwrap(); + + fence(Ordering::SeqCst); + Ok(head) + } + /// Add buffers to the virtqueue, return a token. **This function will be removed in the future.** /// /// Ref: linux virtio_ring.c virtqueue_add @@ -439,6 +509,17 @@ pub struct Descriptor { type DescriptorPtr<'a> = SafePtr>; +#[inline] +fn set_dma_buf(desc_ptr: &DescriptorPtr, buf: &T) { + let daddr = buf.daddr(); + field_ptr!(desc_ptr, Descriptor, addr) + .write(&(daddr as u64)) + .unwrap(); + field_ptr!(desc_ptr, Descriptor, len) + .write(&(buf.len() as u32)) + .unwrap(); +} + #[inline] #[allow(clippy::type_complexity)] fn set_buf_slice(desc_ptr: &DescriptorPtr, buf: &[u8]) {