diff --git a/framework/jinux-frame/src/vm/dma/dma_coherent.rs b/framework/jinux-frame/src/vm/dma/dma_coherent.rs index 8c6cf6544..872de9788 100644 --- a/framework/jinux-frame/src/vm/dma/dma_coherent.rs +++ b/framework/jinux-frame/src/vm/dma/dma_coherent.rs @@ -158,3 +158,82 @@ impl HasPaddr for DmaCoherent { self.inner.vm_segment.start_paddr() } } + +#[if_cfg_ktest] +mod test { + use super::*; + use crate::vm::VmAllocOptions; + use alloc::vec; + + #[ktest] + fn map_with_coherent_device() { + let vm_segment = VmAllocOptions::new(1) + .is_contiguous(true) + .alloc_contiguous() + .unwrap(); + let dma_coherent = DmaCoherent::map(vm_segment.clone(), true).unwrap(); + assert!(dma_coherent.paddr() == vm_segment.paddr()); + } + + #[ktest] + fn map_with_incoherent_device() { + let vm_segment = VmAllocOptions::new(1) + .is_contiguous(true) + .alloc_contiguous() + .unwrap(); + let dma_coherent = DmaCoherent::map(vm_segment.clone(), false).unwrap(); + assert!(dma_coherent.paddr() == vm_segment.paddr()); + let mut page_table = KERNEL_PAGE_TABLE.get().unwrap().lock(); + assert!(page_table + .flags(paddr_to_vaddr(vm_segment.paddr())) + .unwrap() + .contains(PageTableFlags::NO_CACHE)) + } + + #[ktest] + fn duplicate_map() { + let vm_segment_parent = VmAllocOptions::new(2) + .is_contiguous(true) + .alloc_contiguous() + .unwrap(); + let vm_segment_child = vm_segment_parent.range(0..1); + let dma_coherent_parent = DmaCoherent::map(vm_segment_parent, false); + let dma_coherent_child = DmaCoherent::map(vm_segment_child, false); + assert!(dma_coherent_child.is_err()); + } + + #[ktest] + fn read_and_write() { + let vm_segment = VmAllocOptions::new(2) + .is_contiguous(true) + .alloc_contiguous() + .unwrap(); + let dma_coherent = DmaCoherent::map(vm_segment, false).unwrap(); + + let buf_write = vec![1u8; 2 * PAGE_SIZE]; + dma_coherent.write_bytes(0, &buf_write).unwrap(); + let mut buf_read = vec![0u8; 2 * PAGE_SIZE]; + dma_coherent.read_bytes(0, &mut buf_read).unwrap(); + assert_eq!(buf_write, buf_read); + } + + #[ktest] + fn reader_and_wirter() { + let vm_segment = VmAllocOptions::new(2) + .is_contiguous(true) + .alloc_contiguous() + .unwrap(); + let dma_coherent = DmaCoherent::map(vm_segment, false).unwrap(); + + let buf_write = vec![1u8; PAGE_SIZE]; + let mut writer = dma_coherent.writer(); + writer.write(&mut buf_write.as_slice().into()); + writer.write(&mut buf_write.as_slice().into()); + + let mut buf_read = vec![0u8; 2 * PAGE_SIZE]; + let buf_write = vec![1u8; 2 * PAGE_SIZE]; + let mut reader = dma_coherent.reader(); + reader.read(&mut buf_read.as_mut_slice().into()); + assert_eq!(buf_read, buf_write); + } +} diff --git a/framework/jinux-frame/src/vm/dma/dma_stream.rs b/framework/jinux-frame/src/vm/dma/dma_stream.rs new file mode 100644 index 000000000..f78aadfa5 --- /dev/null +++ b/framework/jinux-frame/src/vm/dma/dma_stream.rs @@ -0,0 +1,262 @@ +use alloc::sync::Arc; +use core::arch::x86_64::_mm_clflush; +use core::ops::Range; + +use crate::arch::iommu; +use crate::error::Error; +use crate::vm::{ + dma::{dma_type, Daddr, DmaType}, + HasPaddr, Paddr, VmSegment, PAGE_SIZE, +}; +use crate::vm::{VmIo, VmReader, VmWriter}; + +use super::{check_and_insert_dma_mapping, remove_dma_mapping, DmaError, HasDaddr}; + +/// A streaming DMA mapping. Users must synchronize data +/// before reading or after writing to ensure consistency. +/// +/// The mapping is automatically destroyed when this object +/// is dropped. +#[derive(Debug, Clone)] +pub struct DmaStream { + inner: Arc, +} + +#[derive(Debug)] +struct DmaStreamInner { + vm_segment: VmSegment, + start_daddr: Daddr, + is_cache_coherent: bool, + direction: DmaDirection, +} + +/// `DmaDirection` limits the data flow direction of `DmaStream` and +/// prevents users from reading and writing to `DmaStream` unexpectedly. +#[derive(Debug, PartialEq, Clone)] +pub enum DmaDirection { + ToDevice, + FromDevice, + Bidirectional, +} + +impl DmaStream { + /// Establish DMA stream mapping for a given `VmSegment`. + /// + /// The method fails if the segment already belongs to a DMA mapping. + pub fn map( + vm_segment: VmSegment, + direction: DmaDirection, + is_cache_coherent: bool, + ) -> Result { + let frame_count = vm_segment.nframes(); + let start_paddr = vm_segment.start_paddr(); + if !check_and_insert_dma_mapping(start_paddr, frame_count) { + return Err(DmaError::AlreadyMapped); + } + let start_daddr = match dma_type() { + DmaType::Direct => start_paddr as Daddr, + DmaType::Iommu => { + for i in 0..frame_count { + let paddr = start_paddr + (i * PAGE_SIZE); + // Safety: the `paddr` is restricted by the `start_paddr` and `frame_count` of the `vm_segment`. + unsafe { + iommu::map(paddr as Daddr, paddr).unwrap(); + } + } + start_paddr as Daddr + } + DmaType::Tdx => { + todo!() + } + }; + + Ok(Self { + inner: Arc::new(DmaStreamInner { + vm_segment, + start_daddr, + is_cache_coherent, + direction, + }), + }) + } + + /// Get the underlying VM segment. + /// + /// Usually, the CPU side should not access the memory + /// after the DMA mapping is established because + /// there is a chance that the device is updating + /// the memory. Do this at your own risk. + pub fn vm_segment(&self) -> &VmSegment { + &self.inner.vm_segment + } + + pub fn nbytes(&self) -> usize { + self.inner.vm_segment.nbytes() + } + + /// Synchronize the streaming DMA mapping with the device. + /// + /// This method should be called under one of the two conditions: + /// 1. The data of the stream DMA mapping has been updated by the device side. + /// The CPU side needs to call the `sync` method before reading data (e.g., using `read_bytes`). + /// 2. The data of the stream DMA mapping has been updated by the CPU side + /// (e.g., using `write_bytes`). + /// Before the CPU side notifies the device side to read, it must call the `sync` method first. + pub fn sync(&self, byte_range: Range) -> Result<(), Error> { + if byte_range.end > self.nbytes() { + return Err(Error::InvalidArgs); + } + if self.inner.is_cache_coherent { + return Ok(()); + } + if dma_type() == DmaType::Tdx { + // copy pages. + todo!("support dma for tdx") + } else { + let start_va = self.inner.vm_segment.as_ptr(); + // TODO: Query the CPU for the cache line size via CPUID, we use 64 bytes as the cache line size here. + for i in byte_range.step_by(64) { + // Safety: the addresses is limited by a valid `byte_range`. + unsafe { + _mm_clflush(start_va.wrapping_add(i)); + } + } + Ok(()) + } + } +} + +impl HasDaddr for DmaStream { + fn daddr(&self) -> Daddr { + self.inner.start_daddr + } +} + +impl Drop for DmaStreamInner { + fn drop(&mut self) { + let frame_count = self.vm_segment.nframes(); + let start_paddr = self.vm_segment.start_paddr(); + match dma_type() { + DmaType::Direct => {} + DmaType::Iommu => { + for i in 0..frame_count { + let paddr = start_paddr + (i * PAGE_SIZE); + iommu::unmap(paddr).unwrap(); + } + } + DmaType::Tdx => { + todo!(); + } + } + remove_dma_mapping(start_paddr, frame_count); + } +} + +impl VmIo for DmaStream { + /// Read data into the buffer. + fn read_bytes(&self, offset: usize, buf: &mut [u8]) -> Result<(), Error> { + if self.inner.direction == DmaDirection::ToDevice { + return Err(Error::AccessDenied); + } + self.inner.vm_segment.read_bytes(offset, buf) + } + + /// Write data from the buffer. + fn write_bytes(&self, offset: usize, buf: &[u8]) -> Result<(), Error> { + if self.inner.direction == DmaDirection::FromDevice { + return Err(Error::AccessDenied); + } + self.inner.vm_segment.write_bytes(offset, buf) + } +} + +impl<'a> DmaStream { + /// Returns a reader to read data from it. + pub fn reader(&'a self) -> Result, Error> { + if self.inner.direction == DmaDirection::ToDevice { + return Err(Error::AccessDenied); + } + Ok(self.inner.vm_segment.reader()) + } + + /// Returns a writer to write data into it. + pub fn writer(&'a self) -> Result, Error> { + if self.inner.direction == DmaDirection::FromDevice { + return Err(Error::AccessDenied); + } + Ok(self.inner.vm_segment.writer()) + } +} + +impl HasPaddr for DmaStream { + fn paddr(&self) -> Paddr { + self.inner.vm_segment.start_paddr() + } +} + +#[if_cfg_ktest] +mod test { + use super::*; + use crate::vm::VmAllocOptions; + use alloc::vec; + + #[ktest] + fn streaming_map() { + let vm_segment = VmAllocOptions::new(1) + .is_contiguous(true) + .alloc_contiguous() + .unwrap(); + let dma_stream = + DmaStream::map(vm_segment.clone(), DmaDirection::Bidirectional, true).unwrap(); + assert!(dma_stream.paddr() == vm_segment.paddr()); + } + + #[ktest] + fn duplicate_map() { + let vm_segment_parent = VmAllocOptions::new(2) + .is_contiguous(true) + .alloc_contiguous() + .unwrap(); + let vm_segment_child = vm_segment_parent.range(0..1); + let dma_stream_parent = + DmaStream::map(vm_segment_parent, DmaDirection::Bidirectional, false); + let dma_stream_child = DmaStream::map(vm_segment_child, DmaDirection::Bidirectional, false); + assert!(dma_stream_child.is_err()); + } + + #[ktest] + fn read_and_write() { + let vm_segment = VmAllocOptions::new(2) + .is_contiguous(true) + .alloc_contiguous() + .unwrap(); + let dma_stream = DmaStream::map(vm_segment, DmaDirection::Bidirectional, false).unwrap(); + + let buf_write = vec![1u8; 2 * PAGE_SIZE]; + dma_stream.write_bytes(0, &buf_write).unwrap(); + dma_stream.sync(0..2 * PAGE_SIZE).unwrap(); + let mut buf_read = vec![0u8; 2 * PAGE_SIZE]; + dma_stream.read_bytes(0, &mut buf_read).unwrap(); + assert_eq!(buf_write, buf_read); + } + + #[ktest] + fn reader_and_wirter() { + let vm_segment = VmAllocOptions::new(2) + .is_contiguous(true) + .alloc_contiguous() + .unwrap(); + let dma_stream = DmaStream::map(vm_segment, DmaDirection::Bidirectional, false).unwrap(); + + let buf_write = vec![1u8; PAGE_SIZE]; + let mut writer = dma_stream.writer().unwrap(); + writer.write(&mut buf_write.as_slice().into()); + writer.write(&mut buf_write.as_slice().into()); + dma_stream.sync(0..2 * PAGE_SIZE).unwrap(); + let mut buf_read = vec![0u8; 2 * PAGE_SIZE]; + let buf_write = vec![1u8; 2 * PAGE_SIZE]; + let mut reader = dma_stream.reader().unwrap(); + reader.read(&mut buf_read.as_mut_slice().into()); + assert_eq!(buf_read, buf_write); + } +} diff --git a/framework/jinux-frame/src/vm/dma/mod.rs b/framework/jinux-frame/src/vm/dma/mod.rs index a7d86fa84..4f13e4055 100644 --- a/framework/jinux-frame/src/vm/dma/mod.rs +++ b/framework/jinux-frame/src/vm/dma/mod.rs @@ -1,4 +1,5 @@ mod dma_coherent; +mod dma_stream; use alloc::collections::BTreeSet; use spin::Once; @@ -8,6 +9,7 @@ use crate::{arch::iommu::has_iommu, config::PAGE_SIZE, sync::SpinLock}; use super::Paddr; pub use dma_coherent::DmaCoherent; +pub use dma_stream::{DmaDirection, DmaStream}; /// If a device performs DMA to read or write system /// memory, the addresses used by the device are device addresses. diff --git a/framework/jinux-frame/src/vm/mod.rs b/framework/jinux-frame/src/vm/mod.rs index baad9f500..4a86befa2 100644 --- a/framework/jinux-frame/src/vm/mod.rs +++ b/framework/jinux-frame/src/vm/mod.rs @@ -19,7 +19,7 @@ mod space; use crate::config::{KERNEL_OFFSET, PAGE_SIZE, PHYS_OFFSET}; -pub use self::dma::{DmaCoherent, HasDaddr}; +pub use self::dma::{DmaCoherent, DmaDirection, DmaStream, HasDaddr}; pub use self::frame::{VmFrame, VmFrameVec, VmFrameVecIter, VmReader, VmSegment, VmWriter}; pub use self::io::VmIo; pub use self::options::VmAllocOptions; diff --git a/services/comps/virtio/src/device/block/device.rs b/services/comps/virtio/src/device/block/device.rs index 9b9773a95..925a71aba 100644 --- a/services/comps/virtio/src/device/block/device.rs +++ b/services/comps/virtio/src/device/block/device.rs @@ -24,6 +24,7 @@ pub struct BlockDevice { impl BlockDevice { /// read data from block device, this function is blocking + /// FIEME: replace slice with a more secure data structure to use dma mapping. pub fn read(&self, block_id: usize, buf: &mut [u8]) { assert_eq!(buf.len(), BLK_SIZE); let req = BlkReq { @@ -47,6 +48,7 @@ impl BlockDevice { }; } /// write data to block device, this function is blocking + /// FIEME: replace slice with a more secure data structure to use dma mapping. pub fn write(&self, block_id: usize, buf: &[u8]) { assert_eq!(buf.len(), BLK_SIZE); let req = BlkReq { diff --git a/services/comps/virtio/src/device/input/device.rs b/services/comps/virtio/src/device/input/device.rs index 70ecb2e59..dc5b48c16 100644 --- a/services/comps/virtio/src/device/input/device.rs +++ b/services/comps/virtio/src/device/input/device.rs @@ -79,6 +79,7 @@ impl InputDevice { .expect("create status virtqueue failed"); for (i, event) in event_buf.as_mut().iter_mut().enumerate() { + // FIEME: replace slice with a more secure data structure to use dma mapping. let token = event_queue.add(&[], &[event.as_bytes_mut()]); match token { Ok(value) => { @@ -144,6 +145,7 @@ impl InputDevice { } let event = &mut self.event_buf.lock()[token as usize]; // requeue + // FIEME: replace slice with a more secure data structure to use dma mapping. if let Ok(new_token) = lock.add(&[], &[event.as_bytes_mut()]) { // This only works because nothing happen between `pop_used` and `add` that affects // the list of free descriptors in the queue, so `add` reuses the descriptor which diff --git a/services/comps/virtio/src/device/network/device.rs b/services/comps/virtio/src/device/network/device.rs index b4b320c8f..3faecadc1 100644 --- a/services/comps/virtio/src/device/network/device.rs +++ b/services/comps/virtio/src/device/network/device.rs @@ -60,6 +60,7 @@ impl NetworkDevice { let mut rx_buffers = SlotVec::new(); for i in 0..QUEUE_SIZE { let mut rx_buffer = RxBuffer::new(RX_BUFFER_LEN, size_of::()); + // FIEME: Replace rx_buffer with VM segment-based data structure to use dma mapping. let token = recv_queue.add(&[], &[rx_buffer.buf_mut()])?; assert_eq!(i, token); assert_eq!(rx_buffers.put(rx_buffer) as u16, i); @@ -106,6 +107,7 @@ impl NetworkDevice { } /// Add a rx buffer to recv queue + /// FIEME: Replace rx_buffer with VM segment-based data structure to use dma mapping. fn add_rx_buffer(&mut self, mut rx_buffer: RxBuffer) -> Result<(), VirtioNetError> { let token = self .recv_queue @@ -136,6 +138,7 @@ impl NetworkDevice { } /// Send a packet to network. Return until the request completes. + /// FIEME: Replace tx_buffer with VM segment-based data structure to use dma mapping. fn send(&mut self, tx_buffer: TxBuffer) -> Result<(), VirtioNetError> { let header = VirtioNetHdr::default(); let token = self