diff --git a/framework/aster-frame/src/vm/dma/dma_stream.rs b/framework/aster-frame/src/vm/dma/dma_stream.rs index b348847b..1301b168 100644 --- a/framework/aster-frame/src/vm/dma/dma_stream.rs +++ b/framework/aster-frame/src/vm/dma/dma_stream.rs @@ -222,6 +222,73 @@ impl HasPaddr for DmaStream { } } +/// A slice of streaming DMA mapping. +#[derive(Debug)] +pub struct DmaStreamSlice<'a> { + stream: &'a DmaStream, + offset: usize, + len: usize, +} + +impl<'a> DmaStreamSlice<'a> { + /// Constructs a `DmaStreamSlice` from the `DmaStream`. + /// + /// # Panic + /// + /// If the `offset` is greater than or equal to the length of the stream, + /// this method will panic. + /// If the `offset + len` is greater than the length of the stream, + /// this method will panic. + pub fn new(stream: &'a DmaStream, offset: usize, len: usize) -> Self { + assert!(offset < stream.nbytes()); + assert!(offset + len <= stream.nbytes()); + + Self { + stream, + offset, + len, + } + } + + /// Returns the number of bytes. + pub fn nbytes(&self) -> usize { + self.len + } + + /// Synchronizes the slice of streaming DMA mapping with the device. + pub fn sync(&self) -> Result<(), Error> { + self.stream.sync(self.offset..self.offset + self.len) + } +} + +impl VmIo for DmaStreamSlice<'_> { + fn read_bytes(&self, offset: usize, buf: &mut [u8]) -> Result<(), Error> { + if buf.len() + offset > self.len { + return Err(Error::InvalidArgs); + } + self.stream.read_bytes(self.offset + offset, buf) + } + + fn write_bytes(&self, offset: usize, buf: &[u8]) -> Result<(), Error> { + if buf.len() + offset > self.len { + return Err(Error::InvalidArgs); + } + self.stream.write_bytes(self.offset + offset, buf) + } +} + +impl HasDaddr for DmaStreamSlice<'_> { + fn daddr(&self) -> Daddr { + self.stream.daddr() + self.offset + } +} + +impl HasPaddr for DmaStreamSlice<'_> { + fn paddr(&self) -> Paddr { + self.stream.paddr() + self.offset + } +} + #[cfg(ktest)] mod test { use alloc::vec; diff --git a/framework/aster-frame/src/vm/dma/mod.rs b/framework/aster-frame/src/vm/dma/mod.rs index ba456c09..f38fd7c0 100644 --- a/framework/aster-frame/src/vm/dma/mod.rs +++ b/framework/aster-frame/src/vm/dma/mod.rs @@ -6,7 +6,7 @@ mod dma_stream; use alloc::collections::BTreeSet; pub use dma_coherent::DmaCoherent; -pub use dma_stream::{DmaDirection, DmaStream}; +pub use dma_stream::{DmaDirection, DmaStream, DmaStreamSlice}; use spin::Once; use super::Paddr; diff --git a/framework/aster-frame/src/vm/mod.rs b/framework/aster-frame/src/vm/mod.rs index c3ef6455..2e7b7d1a 100644 --- a/framework/aster-frame/src/vm/mod.rs +++ b/framework/aster-frame/src/vm/mod.rs @@ -25,7 +25,7 @@ use core::ops::Range; use spin::Once; pub use self::{ - dma::{Daddr, DmaCoherent, DmaDirection, DmaStream, HasDaddr}, + dma::{Daddr, DmaCoherent, DmaDirection, DmaStream, DmaStreamSlice, HasDaddr}, frame::{VmFrame, VmFrameVec, VmFrameVecIter, VmReader, VmSegment, VmWriter}, io::VmIo, memory_set::{MapArea, MemorySet}, diff --git a/kernel/comps/virtio/src/device/block/device.rs b/kernel/comps/virtio/src/device/block/device.rs index d442887e..02bb1496 100644 --- a/kernel/comps/virtio/src/device/block/device.rs +++ b/kernel/comps/virtio/src/device/block/device.rs @@ -12,7 +12,7 @@ use aster_frame::{ io_mem::IoMem, sync::SpinLock, trap::TrapFrame, - vm::{VmAllocOptions, VmFrame, VmIo, VmReader, VmWriter}, + vm::{DmaDirection, DmaStream, DmaStreamSlice, VmAllocOptions, VmIo}, }; use aster_util::safe_ptr::SafePtr; use log::info; @@ -67,42 +67,60 @@ impl BlockDevice { fn do_read(&self, request: &BioRequest) { let start_sid = request.sid_range().start; + let dma_streams: Vec<(DmaStream, usize, usize)> = request + .bios() + .flat_map(|bio| { + bio.segments().iter().map(|segment| { + let dma_stream = + DmaStream::map(segment.pages().clone(), DmaDirection::FromDevice, false) + .unwrap(); + (dma_stream, segment.offset(), segment.nbytes()) + }) + }) + .collect(); + let dma_slices: Vec = dma_streams + .iter() + .map(|(stream, offset, len)| DmaStreamSlice::new(stream, *offset, *len)) + .collect(); - let writers = { - let mut writers = Vec::new(); - for bio in request.bios() { - for segment in bio.segments() { - writers.push(segment.writer()); - } - } - writers - }; + self.device.read(start_sid, &dma_slices); - self.device.read(start_sid, writers.as_slice()); + dma_slices.iter().for_each(|dma_slice| { + dma_slice.sync().unwrap(); + }); + drop(dma_slices); + drop(dma_streams); - for bio in request.bios() { + request.bios().for_each(|bio| { bio.complete(BioStatus::Complete); - } + }); } fn do_write(&self, request: &BioRequest) { let start_sid = request.sid_range().start; + let dma_streams: Vec<(DmaStream, usize, usize)> = request + .bios() + .flat_map(|bio| { + bio.segments().iter().map(|segment| { + let dma_stream = + DmaStream::map(segment.pages().clone(), DmaDirection::ToDevice, false) + .unwrap(); + (dma_stream, segment.offset(), segment.nbytes()) + }) + }) + .collect(); + let dma_slices: Vec = dma_streams + .iter() + .map(|(stream, offset, len)| DmaStreamSlice::new(stream, *offset, *len)) + .collect(); - let readers = { - let mut readers = Vec::new(); - for bio in request.bios() { - for segment in bio.segments() { - readers.push(segment.reader()); - } - } - readers - }; + self.device.write(start_sid, &dma_slices); + drop(dma_slices); + drop(dma_streams); - self.device.write(start_sid, readers.as_slice()); - - for bio in request.bios() { + request.bios().for_each(|bio| { bio.complete(BioStatus::Complete); - } + }); } /// Negotiate features for the device specified bits 0~23 @@ -128,12 +146,8 @@ struct DeviceInner { config: SafePtr, queue: SpinLock, transport: Box, - /// Block requests, we use VmFrame to store the requests so that - /// it can pass to the `add_vm` function - block_requests: VmFrame, - /// Block responses, we use VmFrame to store the requests so that - /// it can pass to the `add_vm` function - block_responses: VmFrame, + block_requests: DmaStream, + block_responses: DmaStream, id_allocator: SpinLock>, device_id: Option, } @@ -148,12 +162,26 @@ impl DeviceInner { } let queue = VirtQueue::new(0, 64, transport.as_mut()).expect("create virtqueue failed"); + let block_requests = { + let vm_segment = VmAllocOptions::new(1) + .is_contiguous(true) + .alloc_contiguous() + .unwrap(); + DmaStream::map(vm_segment, DmaDirection::Bidirectional, false).unwrap() + }; + let block_responses = { + let vm_segment = VmAllocOptions::new(1) + .is_contiguous(true) + .alloc_contiguous() + .unwrap(); + DmaStream::map(vm_segment, DmaDirection::Bidirectional, false).unwrap() + }; let mut device = Self { config, queue: SpinLock::new(queue), transport, - block_requests: VmAllocOptions::new(1).alloc_single().unwrap(), - block_responses: VmAllocOptions::new(1).alloc_single().unwrap(), + block_requests, + block_responses, id_allocator: SpinLock::new((0..64).collect()), device_id: None, }; @@ -190,30 +218,38 @@ impl DeviceInner { // TODO: Should return an Err instead of panic if the device fails. fn get_id(&self) -> String { let id = self.id_allocator.lock().pop().unwrap() as usize; - let req = BlockReq { - type_: ReqType::GetId as _, - reserved: 0, - sector: 0, + let req_slice = { + let req_slice = DmaStreamSlice::new(&self.block_requests, id * REQ_SIZE, REQ_SIZE); + let req = BlockReq { + type_: ReqType::GetId as _, + reserved: 0, + sector: 0, + }; + req_slice.write_val(0, &req).unwrap(); + req_slice.sync().unwrap(); + req_slice }; - self.block_requests - .write_val(id * size_of::(), &req) - .unwrap(); - - let req_reader = self - .block_requests - .reader() - .skip(id * size_of::()) - .limit(size_of::()); - + let resp_slice = { + let resp_slice = DmaStreamSlice::new(&self.block_responses, id * RESP_SIZE, RESP_SIZE); + resp_slice.write_val(0, &BlockResp::default()).unwrap(); + resp_slice + }; const MAX_ID_LENGTH: usize = 20; - - let page = VmAllocOptions::new(1).uninit(true).alloc_single().unwrap(); - let writer = page.writer().limit(MAX_ID_LENGTH); + let device_id_stream = { + let segment = VmAllocOptions::new(1) + .is_contiguous(true) + .uninit(true) + .alloc_contiguous() + .unwrap(); + DmaStream::map(segment, DmaDirection::FromDevice, false).unwrap() + }; + let device_id_slice = DmaStreamSlice::new(&device_id_stream, 0, MAX_ID_LENGTH); + let outputs = vec![&device_id_slice, &resp_slice]; let mut queue = self.queue.lock_irq_disabled(); let token = queue - .add_vm(&[&req_reader], &[&writer]) + .add_dma_buf(&[&req_slice], outputs.as_slice()) .expect("add queue failed"); queue.notify(); while !queue.can_pop() { @@ -221,61 +257,69 @@ impl DeviceInner { } queue.pop_used_with_token(token).expect("pop used failed"); + resp_slice.sync().unwrap(); self.id_allocator.lock().push(id as u8); + let resp: BlockResp = resp_slice.read_val(0).unwrap(); + match RespStatus::try_from(resp.status).unwrap() { + RespStatus::Ok => {} + _ => panic!("io error in block device"), + }; - //Add an extra 0, so that the array must end with 0. - let mut device_id = vec![0; MAX_ID_LENGTH + 1]; - let _ = page.read_bytes(0, &mut device_id); - - device_id.resize(device_id.iter().position(|&x| x == 0).unwrap(), 0); + let device_id = { + device_id_slice.sync().unwrap(); + let mut device_id = vec![0u8; MAX_ID_LENGTH]; + let _ = device_id_slice.read_bytes(0, &mut device_id); + let len = device_id + .iter() + .position(|&b| b == 0) + .unwrap_or(MAX_ID_LENGTH); + device_id.truncate(len); + device_id + }; String::from_utf8(device_id).unwrap() - - //The device is not initialized yet, so the response must be not_ready. } /// Reads data from the block device, this function is blocking. - /// FIEME: replace slice with a more secure data structure to use dma mapping. - pub fn read(&self, sector_id: Sid, buf: &[VmWriter]) { + pub fn read(&self, sector_id: Sid, bufs: &[DmaStreamSlice]) { // FIXME: Handling cases without id. let id = self.id_allocator.lock().pop().unwrap() as usize; - let req = BlockReq { - type_: ReqType::In as _, - reserved: 0, - sector: sector_id.to_raw(), - }; - let resp = BlockResp::default(); - self.block_requests - .write_val(id * size_of::(), &req) - .unwrap(); - self.block_responses - .write_val(id * size_of::(), &resp) - .unwrap(); - let req_reader = self - .block_requests - .reader() - .skip(id * size_of::()) - .limit(size_of::()); - let resp_writer = self - .block_responses - .writer() - .skip(id * size_of::()) - .limit(size_of::()); - let mut outputs: Vec<&VmWriter<'_>> = buf.iter().collect(); - outputs.push(&resp_writer); + let req_slice = { + let req_slice = DmaStreamSlice::new(&self.block_requests, id * REQ_SIZE, REQ_SIZE); + let req = BlockReq { + type_: ReqType::In as _, + reserved: 0, + sector: sector_id.to_raw(), + }; + req_slice.write_val(0, &req).unwrap(); + req_slice.sync().unwrap(); + req_slice + }; + + let resp_slice = { + let resp_slice = DmaStreamSlice::new(&self.block_responses, id * RESP_SIZE, RESP_SIZE); + resp_slice.write_val(0, &BlockResp::default()).unwrap(); + resp_slice + }; + + let outputs = { + let mut outputs: Vec<&DmaStreamSlice> = bufs.iter().collect(); + outputs.push(&resp_slice); + outputs + }; + let mut queue = self.queue.lock_irq_disabled(); let token = queue - .add_vm(&[&req_reader], outputs.as_slice()) + .add_dma_buf(&[&req_slice], outputs.as_slice()) .expect("add queue failed"); queue.notify(); while !queue.can_pop() { spin_loop(); } queue.pop_used_with_token(token).expect("pop used failed"); - let resp: BlockResp = self - .block_responses - .read_val(id * size_of::()) - .unwrap(); + + resp_slice.sync().unwrap(); + let resp: BlockResp = resp_slice.read_val(0).unwrap(); self.id_allocator.lock().push(id as u8); match RespStatus::try_from(resp.status).unwrap() { RespStatus::Ok => {} @@ -284,48 +328,46 @@ impl DeviceInner { } /// Writes data to the block device, this function is blocking. - /// FIEME: replace slice with a more secure data structure to use dma mapping. - pub fn write(&self, sector_id: Sid, buf: &[VmReader]) { + pub fn write(&self, sector_id: Sid, bufs: &[DmaStreamSlice]) { // FIXME: Handling cases without id. let id = self.id_allocator.lock().pop().unwrap() as usize; - let req = BlockReq { - type_: ReqType::Out as _, - reserved: 0, - sector: sector_id.to_raw(), + + let req_slice = { + let req_slice = DmaStreamSlice::new(&self.block_requests, id * REQ_SIZE, REQ_SIZE); + let req = BlockReq { + type_: ReqType::Out as _, + reserved: 0, + sector: sector_id.to_raw(), + }; + req_slice.write_val(0, &req).unwrap(); + req_slice.sync().unwrap(); + req_slice + }; + + let resp_slice = { + let resp_slice = DmaStreamSlice::new(&self.block_responses, id * RESP_SIZE, RESP_SIZE); + resp_slice.write_val(0, &BlockResp::default()).unwrap(); + resp_slice + }; + + let inputs = { + let mut inputs: Vec<&DmaStreamSlice> = bufs.iter().collect(); + inputs.insert(0, &req_slice); + inputs }; - let resp = BlockResp::default(); - self.block_requests - .write_val(id * size_of::(), &req) - .unwrap(); - self.block_responses - .write_val(id * size_of::(), &resp) - .unwrap(); - let req_reader = self - .block_requests - .reader() - .skip(id * size_of::()) - .limit(size_of::()); - let resp_writer = self - .block_responses - .writer() - .skip(id * size_of::()) - .limit(size_of::()); let mut queue = self.queue.lock_irq_disabled(); - let mut inputs: Vec<&VmReader<'_>> = buf.iter().collect(); - inputs.insert(0, &req_reader); let token = queue - .add_vm(inputs.as_slice(), &[&resp_writer]) + .add_dma_buf(inputs.as_slice(), &[&resp_slice]) .expect("add queue failed"); queue.notify(); while !queue.can_pop() { spin_loop(); } queue.pop_used_with_token(token).expect("pop used failed"); - let resp: BlockResp = self - .block_responses - .read_val(id * size_of::()) - .unwrap(); + + resp_slice.sync().unwrap(); + let resp: BlockResp = resp_slice.read_val(0).unwrap(); self.id_allocator.lock().push(id as u8); match RespStatus::try_from(resp.status).unwrap() { RespStatus::Ok => {} @@ -342,6 +384,8 @@ struct BlockReq { pub sector: u64, } +const REQ_SIZE: usize = size_of::(); + /// Response of a VirtIOBlock request. #[repr(C)] #[derive(Debug, Copy, Clone, Pod)] @@ -349,6 +393,8 @@ struct BlockResp { pub status: u8, } +const RESP_SIZE: usize = size_of::(); + impl Default for BlockResp { fn default() -> Self { Self { diff --git a/kernel/comps/virtio/src/dma_buf.rs b/kernel/comps/virtio/src/dma_buf.rs index 0395cd90..ad517e44 100644 --- a/kernel/comps/virtio/src/dma_buf.rs +++ b/kernel/comps/virtio/src/dma_buf.rs @@ -1,6 +1,6 @@ // SPDX-License-Identifier: MPL-2.0 -use aster_frame::vm::{DmaCoherent, DmaStream, HasDaddr}; +use aster_frame::vm::{DmaCoherent, DmaStream, DmaStreamSlice, HasDaddr}; /// A DMA-capable buffer. /// @@ -18,6 +18,12 @@ impl DmaBuf for DmaStream { } } +impl DmaBuf for DmaStreamSlice<'_> { + fn len(&self) -> usize { + self.nbytes() + } +} + impl DmaBuf for DmaCoherent { fn len(&self) -> usize { self.nbytes() diff --git a/kernel/comps/virtio/src/queue.rs b/kernel/comps/virtio/src/queue.rs index 40e34496..3945d415 100644 --- a/kernel/comps/virtio/src/queue.rs +++ b/kernel/comps/virtio/src/queue.rs @@ -11,7 +11,7 @@ use core::{ use aster_frame::{ io_mem::IoMem, offset_of, - vm::{DmaCoherent, VmAllocOptions, VmReader, VmWriter}, + vm::{DmaCoherent, VmAllocOptions}, }; use aster_rights::{Dup, TRightSet, TRights, Write}; use aster_util::{field_ptr, safe_ptr::SafePtr}; @@ -317,73 +317,6 @@ impl VirtQueue { Ok(head) } - /// Add VmReader/VmWriter to the virtqueue, return a token. - /// - /// Ref: linux virtio_ring.c virtqueue_add - pub fn add_vm( - &mut self, - inputs: &[&VmReader], - outputs: &[&VmWriter], - ) -> Result { - if inputs.is_empty() && outputs.is_empty() { - return Err(QueueError::InvalidArgs); - } - if inputs.len() + outputs.len() + self.num_used as usize > self.queue_size as usize { - return Err(QueueError::BufferTooSmall); - } - - // allocate descriptors from free list - let head = self.free_head; - let mut last = self.free_head; - for input in inputs.iter() { - let desc = &self.descs[self.free_head as usize]; - set_buf_reader(&desc.borrow_vm().restrict::(), input); - field_ptr!(desc, Descriptor, flags) - .write(&DescFlags::NEXT) - .unwrap(); - last = self.free_head; - self.free_head = field_ptr!(desc, Descriptor, next).read().unwrap(); - } - for output in outputs.iter() { - let desc = &mut self.descs[self.free_head as usize]; - set_buf_writer(&desc.borrow_vm().restrict::(), output); - field_ptr!(desc, Descriptor, flags) - .write(&(DescFlags::NEXT | DescFlags::WRITE)) - .unwrap(); - last = self.free_head; - self.free_head = field_ptr!(desc, Descriptor, next).read().unwrap(); - } - // set last_elem.next = NULL - { - let desc = &mut self.descs[last as usize]; - let mut flags: DescFlags = field_ptr!(desc, Descriptor, flags).read().unwrap(); - flags.remove(DescFlags::NEXT); - field_ptr!(desc, Descriptor, flags).write(&flags).unwrap(); - } - self.num_used += (inputs.len() + outputs.len()) as u16; - - let avail_slot = self.avail_idx & (self.queue_size - 1); - - { - let ring_ptr: SafePtr<[u16; 64], &DmaCoherent> = - field_ptr!(&self.avail, AvailRing, ring); - let mut ring_slot_ptr = ring_ptr.cast::(); - ring_slot_ptr.add(avail_slot as usize); - ring_slot_ptr.write(&head).unwrap(); - } - // write barrier - fence(Ordering::SeqCst); - - // increase head of avail ring - self.avail_idx = self.avail_idx.wrapping_add(1); - field_ptr!(&self.avail, AvailRing, idx) - .write(&self.avail_idx) - .unwrap(); - - fence(Ordering::SeqCst); - Ok(head) - } - /// Whether there is a used element that can pop. pub fn can_pop(&self) -> bool { self.last_used_idx != field_ptr!(&self.used, UsedRing, idx).read().unwrap() @@ -535,32 +468,6 @@ fn set_buf_slice(desc_ptr: &DescriptorPtr, buf: &[u8]) { .unwrap(); } -#[inline] -#[allow(clippy::type_complexity)] -fn set_buf_reader(desc_ptr: &DescriptorPtr, reader: &VmReader) { - let va = reader.cursor() as usize; - let pa = aster_frame::vm::vaddr_to_paddr(va).unwrap(); - field_ptr!(desc_ptr, Descriptor, addr) - .write(&(pa as u64)) - .unwrap(); - field_ptr!(desc_ptr, Descriptor, len) - .write(&(reader.remain() as u32)) - .unwrap(); -} - -#[inline] -#[allow(clippy::type_complexity)] -fn set_buf_writer(desc_ptr: &DescriptorPtr, writer: &VmWriter) { - let va = writer.cursor() as usize; - let pa = aster_frame::vm::vaddr_to_paddr(va).unwrap(); - field_ptr!(desc_ptr, Descriptor, addr) - .write(&(pa as u64)) - .unwrap(); - field_ptr!(desc_ptr, Descriptor, len) - .write(&(writer.avail() as u32)) - .unwrap(); -} - bitflags! { /// Descriptor flags #[derive(Pod, Default)]