diff --git a/kernel/aster-nix/src/context.rs b/kernel/aster-nix/src/context.rs index fb4e84751..aa43ba50b 100644 --- a/kernel/aster-nix/src/context.rs +++ b/kernel/aster-nix/src/context.rs @@ -5,7 +5,7 @@ use core::mem; use ostd::{ - mm::{UserSpace, VmReader, VmSpace, VmWriter}, + mm::{Fallible, Infallible, VmReader, VmSpace, VmWriter}, task::Task, }; @@ -56,14 +56,14 @@ impl CurrentUserSpace { /// Creates a reader to read data from the user space of the current task. /// /// Returns `Err` if the `vaddr` and `len` do not represent a user space memory range. - pub fn reader(&self, vaddr: Vaddr, len: usize) -> Result> { + pub fn reader(&self, vaddr: Vaddr, len: usize) -> Result> { Ok(self.0.reader(vaddr, len)?) } /// Creates a writer to write data into the user space. /// /// Returns `Err` if the `vaddr` and `len` do not represent a user space memory range. - pub fn writer(&self, vaddr: Vaddr, len: usize) -> Result> { + pub fn writer(&self, vaddr: Vaddr, len: usize) -> Result> { Ok(self.0.writer(vaddr, len)?) } @@ -76,7 +76,7 @@ impl CurrentUserSpace { /// If the destination `VmWriter` (`dest`) is empty, this function still /// checks if the current task and user space are available. If they are, /// it returns `Ok`. - pub fn read_bytes(&self, src: Vaddr, dest: &mut VmWriter<'_>) -> Result<()> { + pub fn read_bytes(&self, src: Vaddr, dest: &mut VmWriter<'_, Infallible>) -> Result<()> { let copy_len = dest.avail(); if copy_len > 0 { @@ -107,7 +107,7 @@ impl CurrentUserSpace { /// If the source `VmReader` (`src`) is empty, this function still checks if /// the current task and user space are available. If they are, it returns /// `Ok`. - pub fn write_bytes(&self, dest: Vaddr, src: &mut VmReader<'_>) -> Result<()> { + pub fn write_bytes(&self, dest: Vaddr, src: &mut VmReader<'_, Infallible>) -> Result<()> { let copy_len = src.remain(); if copy_len > 0 { @@ -150,7 +150,7 @@ pub trait ReadCString { fn read_cstring(&mut self) -> Result; } -impl<'a> ReadCString for VmReader<'a, UserSpace> { +impl<'a> ReadCString for VmReader<'a, Fallible> { /// This implementation is inspired by /// the `do_strncpy_from_user` function in Linux kernel. /// The original Linux implementation can be found at: diff --git a/kernel/aster-nix/src/device/tty/driver.rs b/kernel/aster-nix/src/device/tty/driver.rs index b53c6404b..5c510cc35 100644 --- a/kernel/aster-nix/src/device/tty/driver.rs +++ b/kernel/aster-nix/src/device/tty/driver.rs @@ -2,7 +2,7 @@ #![allow(dead_code)] -use ostd::mm::VmReader; +use ostd::mm::{Infallible, VmReader}; use spin::Once; use crate::{ @@ -78,7 +78,7 @@ impl Default for TtyDriver { } } -fn console_input_callback(mut reader: VmReader) { +fn console_input_callback(mut reader: VmReader) { let tty_driver = get_tty_driver(); while reader.remain() > 0 { let ch = reader.read_val().unwrap(); diff --git a/kernel/aster-nix/src/prelude.rs b/kernel/aster-nix/src/prelude.rs index abfa8524d..fc3c22e55 100644 --- a/kernel/aster-nix/src/prelude.rs +++ b/kernel/aster-nix/src/prelude.rs @@ -17,7 +17,7 @@ pub(crate) use bitflags::bitflags; pub(crate) use int_to_c_enum::TryFromInt; pub(crate) use log::{debug, error, info, log_enabled, trace, warn}; pub(crate) use ostd::{ - mm::{Vaddr, VmReader, VmWriter, PAGE_SIZE}, + mm::{FallibleVmRead, FallibleVmWrite, Vaddr, VmReader, VmWriter, PAGE_SIZE}, sync::{Mutex, MutexGuard, RwLock, RwMutex, SpinLock, SpinLockGuard}, Pod, }; diff --git a/kernel/aster-nix/src/vm/vmo/mod.rs b/kernel/aster-nix/src/vm/vmo/mod.rs index 96c296797..4c1ce5c90 100644 --- a/kernel/aster-nix/src/vm/vmo/mod.rs +++ b/kernel/aster-nix/src/vm/vmo/mod.rs @@ -11,7 +11,7 @@ use align_ext::AlignExt; use aster_rights::Rights; use ostd::{ collections::xarray::{CursorMut, XArray}, - mm::{Frame, FrameAllocOptions, VmReader, VmWriter}, + mm::{Frame, FrameAllocOptions, Infallible, VmReader, VmWriter}, }; use crate::prelude::*; @@ -304,7 +304,7 @@ impl Vmo_ { let read_len = buf.len(); let read_range = offset..(offset + read_len); let mut read_offset = offset % PAGE_SIZE; - let mut buf_writer: VmWriter = buf.into(); + let mut buf_writer: VmWriter = buf.into(); let read = move |page: Frame| { page.reader().skip(read_offset).read(&mut buf_writer); @@ -319,7 +319,7 @@ impl Vmo_ { let write_len = buf.len(); let write_range = offset..(offset + write_len); let mut write_offset = offset % PAGE_SIZE; - let mut buf_reader: VmReader = buf.into(); + let mut buf_reader: VmReader = buf.into(); let mut write = move |page: Frame| { page.writer().skip(write_offset).write(&mut buf_reader); diff --git a/kernel/comps/block/src/bio.rs b/kernel/comps/block/src/bio.rs index b17f747e8..57e6e5b3b 100644 --- a/kernel/comps/block/src/bio.rs +++ b/kernel/comps/block/src/bio.rs @@ -3,7 +3,7 @@ use align_ext::AlignExt; use int_to_c_enum::TryFromInt; use ostd::{ - mm::{Frame, Segment, VmReader, VmWriter}, + mm::{Frame, Infallible, Segment, VmReader, VmWriter}, sync::WaitQueue, }; @@ -423,7 +423,7 @@ impl<'a> BioSegment { } /// Returns a reader to read data from it. - pub fn reader(&'a self) -> VmReader<'a> { + pub fn reader(&'a self) -> VmReader<'a, Infallible> { self.pages .reader() .skip(self.offset.value()) @@ -431,7 +431,7 @@ impl<'a> BioSegment { } /// Returns a writer to write data into it. - pub fn writer(&'a self) -> VmWriter<'a> { + pub fn writer(&'a self) -> VmWriter<'a, Infallible> { self.pages .writer() .skip(self.offset.value()) diff --git a/kernel/comps/console/src/lib.rs b/kernel/comps/console/src/lib.rs index f748dd869..11955f6e2 100644 --- a/kernel/comps/console/src/lib.rs +++ b/kernel/comps/console/src/lib.rs @@ -11,10 +11,13 @@ use alloc::{collections::BTreeMap, fmt::Debug, string::String, sync::Arc, vec::V use core::any::Any; use component::{init_component, ComponentInitError}; -use ostd::{mm::VmReader, sync::SpinLock}; +use ostd::{ + mm::{Infallible, VmReader}, + sync::SpinLock, +}; use spin::Once; -pub type ConsoleCallback = dyn Fn(VmReader) + Send + Sync; +pub type ConsoleCallback = dyn Fn(VmReader) + Send + Sync; pub trait AnyConsoleDevice: Send + Sync + Any + Debug { fn send(&self, buf: &[u8]); diff --git a/kernel/comps/network/src/buffer.rs b/kernel/comps/network/src/buffer.rs index 62d8ba8a8..7d26b4e42 100644 --- a/kernel/comps/network/src/buffer.rs +++ b/kernel/comps/network/src/buffer.rs @@ -5,7 +5,8 @@ use alloc::{collections::LinkedList, sync::Arc}; use align_ext::AlignExt; use ostd::{ mm::{ - Daddr, DmaDirection, DmaStream, FrameAllocOptions, HasDaddr, VmReader, VmWriter, PAGE_SIZE, + Daddr, DmaDirection, DmaStream, FrameAllocOptions, HasDaddr, Infallible, VmReader, + VmWriter, PAGE_SIZE, }, sync::SpinLock, Pod, @@ -52,7 +53,7 @@ impl TxBuffer { tx_buffer } - pub fn writer(&self) -> VmWriter<'_> { + pub fn writer(&self) -> VmWriter<'_, Infallible> { self.dma_stream.writer().unwrap().limit(self.nbytes) } @@ -106,7 +107,7 @@ impl RxBuffer { self.packet_len = packet_len; } - pub fn packet(&self) -> VmReader<'_> { + pub fn packet(&self) -> VmReader<'_, Infallible> { self.segment .sync(self.header_len..self.header_len + self.packet_len) .unwrap(); @@ -117,7 +118,7 @@ impl RxBuffer { .limit(self.packet_len) } - pub fn buf(&self) -> VmReader<'_> { + pub fn buf(&self) -> VmReader<'_, Infallible> { self.segment .sync(0..self.header_len + self.packet_len) .unwrap(); diff --git a/kernel/comps/network/src/dma_pool.rs b/kernel/comps/network/src/dma_pool.rs index ee8917286..072abd4bf 100644 --- a/kernel/comps/network/src/dma_pool.rs +++ b/kernel/comps/network/src/dma_pool.rs @@ -11,7 +11,8 @@ use core::ops::Range; use bitvec::{array::BitArray, prelude::Lsb0}; use ostd::{ mm::{ - Daddr, DmaDirection, DmaStream, FrameAllocOptions, HasDaddr, VmReader, VmWriter, PAGE_SIZE, + Daddr, DmaDirection, DmaStream, FrameAllocOptions, HasDaddr, Infallible, VmReader, + VmWriter, PAGE_SIZE, }, sync::{RwLock, SpinLock}, }; @@ -233,12 +234,12 @@ impl DmaSegment { self.size } - pub fn reader(&self) -> Result, ostd::Error> { + pub fn reader(&self) -> Result, ostd::Error> { let offset = self.start_addr - self.dma_stream.daddr(); Ok(self.dma_stream.reader()?.skip(offset).limit(self.size)) } - pub fn writer(&self) -> Result, ostd::Error> { + pub fn writer(&self) -> Result, ostd::Error> { let offset = self.start_addr - self.dma_stream.daddr(); Ok(self.dma_stream.writer()?.skip(offset).limit(self.size)) } diff --git a/osdk/tests/examples_in_book/write_a_kernel_in_100_lines_templates/lib.rs b/osdk/tests/examples_in_book/write_a_kernel_in_100_lines_templates/lib.rs index 21f60b1b6..9fbfde511 100644 --- a/osdk/tests/examples_in_book/write_a_kernel_in_100_lines_templates/lib.rs +++ b/osdk/tests/examples_in_book/write_a_kernel_in_100_lines_templates/lib.rs @@ -13,8 +13,8 @@ use alloc::vec; use ostd::arch::qemu::{exit_qemu, QemuExitCode}; use ostd::cpu::UserContext; use ostd::mm::{ - CachePolicy, FrameAllocOptions, PageFlags, PageProperty, Vaddr, VmIo, VmSpace, VmWriter, - PAGE_SIZE, + CachePolicy, FallibleVmRead, FallibleVmWrite, FrameAllocOptions, PageFlags, PageProperty, + Vaddr, VmIo, VmSpace, VmWriter, PAGE_SIZE, }; use ostd::prelude::*; use ostd::task::{Task, TaskOptions}; diff --git a/ostd/src/mm/dma/dma_coherent.rs b/ostd/src/mm/dma/dma_coherent.rs index 885c14794..1baaf4394 100644 --- a/ostd/src/mm/dma/dma_coherent.rs +++ b/ostd/src/mm/dma/dma_coherent.rs @@ -13,7 +13,7 @@ use crate::{ io::VmIoOnce, kspace::{paddr_to_vaddr, KERNEL_PAGE_TABLE}, page_prop::CachePolicy, - HasPaddr, Paddr, PodOnce, Segment, VmIo, VmReader, VmWriter, PAGE_SIZE, + HasPaddr, Infallible, Paddr, PodOnce, Segment, VmIo, VmReader, VmWriter, PAGE_SIZE, }, prelude::*, }; @@ -192,12 +192,12 @@ impl VmIoOnce for DmaCoherent { impl<'a> DmaCoherent { /// Returns a reader to read data from it. - pub fn reader(&'a self) -> VmReader<'a> { + pub fn reader(&'a self) -> VmReader<'a, Infallible> { self.inner.vm_segment.reader() } /// Returns a writer to write data into it. - pub fn writer(&'a self) -> VmWriter<'a> { + pub fn writer(&'a self) -> VmWriter<'a, Infallible> { self.inner.vm_segment.writer() } } diff --git a/ostd/src/mm/dma/dma_stream.rs b/ostd/src/mm/dma/dma_stream.rs index d20b01e56..ed32062fe 100644 --- a/ostd/src/mm/dma/dma_stream.rs +++ b/ostd/src/mm/dma/dma_stream.rs @@ -11,7 +11,7 @@ use crate::{ error::Error, mm::{ dma::{dma_type, Daddr, DmaType}, - HasPaddr, Paddr, Segment, VmIo, VmReader, VmWriter, PAGE_SIZE, + HasPaddr, Infallible, Paddr, Segment, VmIo, VmReader, VmWriter, PAGE_SIZE, }, }; @@ -220,7 +220,7 @@ impl VmIo for DmaStream { impl<'a> DmaStream { /// Returns a reader to read data from it. - pub fn reader(&'a self) -> Result, Error> { + pub fn reader(&'a self) -> Result, Error> { if self.inner.direction == DmaDirection::ToDevice { return Err(Error::AccessDenied); } @@ -228,7 +228,7 @@ impl<'a> DmaStream { } /// Returns a writer to write data into it. - pub fn writer(&'a self) -> Result, Error> { + pub fn writer(&'a self) -> Result, Error> { if self.inner.direction == DmaDirection::FromDevice { return Err(Error::AccessDenied); } diff --git a/ostd/src/mm/frame/mod.rs b/ostd/src/mm/frame/mod.rs index 17ec2399c..01f1ad844 100644 --- a/ostd/src/mm/frame/mod.rs +++ b/ostd/src/mm/frame/mod.rs @@ -15,9 +15,12 @@ use core::mem::ManuallyDrop; pub use segment::Segment; -use super::page::{ - meta::{FrameMeta, MetaSlot, PageMeta, PageUsage}, - DynPage, Page, +use super::{ + page::{ + meta::{FrameMeta, MetaSlot, PageMeta, PageUsage}, + DynPage, Page, + }, + Infallible, }; use crate::{ mm::{ @@ -111,7 +114,7 @@ impl HasPaddr for Frame { impl<'a> Frame { /// Returns a reader to read data from it. - pub fn reader(&'a self) -> VmReader<'a> { + pub fn reader(&'a self) -> VmReader<'a, Infallible> { // SAFETY: // - The memory range points to untyped memory. // - The frame is alive during the lifetime `'a`. @@ -120,7 +123,7 @@ impl<'a> Frame { } /// Returns a writer to write data into it. - pub fn writer(&'a self) -> VmWriter<'a> { + pub fn writer(&'a self) -> VmWriter<'a, Infallible> { // SAFETY: // - The memory range points to untyped memory. // - The frame is alive during the lifetime `'a`. @@ -163,7 +166,7 @@ impl VmIo for alloc::vec::Vec { let num_skip_pages = offset / PAGE_SIZE; let mut start = offset % PAGE_SIZE; - let mut buf_writer: VmWriter = buf.into(); + let mut buf_writer: VmWriter = buf.into(); for frame in self.iter().skip(num_skip_pages) { let read_len = frame.reader().skip(start).read(&mut buf_writer); if read_len == 0 { @@ -183,7 +186,7 @@ impl VmIo for alloc::vec::Vec { let num_skip_pages = offset / PAGE_SIZE; let mut start = offset % PAGE_SIZE; - let mut buf_reader: VmReader = buf.into(); + let mut buf_reader: VmReader = buf.into(); for frame in self.iter().skip(num_skip_pages) { let write_len = frame.writer().skip(start).write(&mut buf_reader); if write_len == 0 { diff --git a/ostd/src/mm/frame/segment.rs b/ostd/src/mm/frame/segment.rs index 414daae7c..53b44a6eb 100644 --- a/ostd/src/mm/frame/segment.rs +++ b/ostd/src/mm/frame/segment.rs @@ -9,7 +9,7 @@ use super::Frame; use crate::{ mm::{ page::{cont_pages::ContPages, meta::FrameMeta, Page}, - HasPaddr, Paddr, VmIo, VmReader, VmWriter, PAGE_SIZE, + HasPaddr, Infallible, Paddr, VmIo, VmReader, VmWriter, PAGE_SIZE, }, Error, Result, }; @@ -95,7 +95,7 @@ impl Segment { impl<'a> Segment { /// Returns a reader to read data from it. - pub fn reader(&'a self) -> VmReader<'a> { + pub fn reader(&'a self) -> VmReader<'a, Infallible> { // SAFETY: // - The memory range points to untyped memory. // - The segment is alive during the lifetime `'a`. @@ -104,7 +104,7 @@ impl<'a> Segment { } /// Returns a writer to write data into it. - pub fn writer(&'a self) -> VmWriter<'a> { + pub fn writer(&'a self) -> VmWriter<'a, Infallible> { // SAFETY: // - The memory range points to untyped memory. // - The segment is alive during the lifetime `'a`. diff --git a/ostd/src/mm/io.rs b/ostd/src/mm/io.rs index e67b9db7d..31c94a427 100644 --- a/ostd/src/mm/io.rs +++ b/ostd/src/mm/io.rs @@ -250,12 +250,11 @@ impl_vm_io_once_pointer!(Box, "(**self)"); impl_vm_io_once_pointer!(Arc, "(**self)"); /// A marker structure used for [`VmReader`] and [`VmWriter`], -/// representing their operated memory scope is in user space. -pub struct UserSpace; - +/// representing whether reads or writes on the underlying memory region are fallible. +pub struct Fallible; /// A marker structure used for [`VmReader`] and [`VmWriter`], -/// representing their operated memory scope is in kernel space. -pub struct KernelSpace; +/// representing whether reads or writes on the underlying memory region are infallible. +pub struct Infallible; /// Copies `len` bytes from `src` to `dst`. /// @@ -300,13 +299,44 @@ unsafe fn memcpy_fallible(dst: *mut u8, src: *const u8, len: usize) -> usize { len - failed_bytes } +/// Fallible memory read from a `VmWriter`. +pub trait FallibleVmRead { + /// Reads all data into the writer until one of the three conditions is met: + /// 1. The reader has no remaining data. + /// 2. The writer has no available space. + /// 3. The reader/writer encounters some error. + /// + /// On success, the number of bytes read is returned; + /// On error, both the error and the number of bytes read so far are returned. + fn read_fallible( + &mut self, + writer: &mut VmWriter<'_, F>, + ) -> core::result::Result; +} + +/// Fallible memory write from a `VmReader`. +pub trait FallibleVmWrite { + /// Writes all data from the reader until one of the three conditions is met: + /// 1. The reader has no remaining data. + /// 2. The writer has no available space. + /// 3. The reader/writer encounters some error. + /// + /// On success, the number of bytes written is returned; + /// On error, both the error and the number of bytes written so far are returned. + fn write_fallible( + &mut self, + reader: &mut VmReader<'_, F>, + ) -> core::result::Result; +} + /// `VmReader` is a reader for reading data from a contiguous range of memory. /// /// The memory range read by `VmReader` can be in either kernel space or user space. /// When the operating range is in kernel space, the memory within that range -/// is guaranteed to be valid. +/// is guaranteed to be valid, and the corresponding memory reads are infallible. /// When the operating range is in user space, it is ensured that the page table of -/// the process creating the `VmReader` is active for the duration of `'a`. +/// the process creating the `VmReader` is active for the duration of `'a`, +/// and the corresponding memory reads are considered fallible. /// /// When perform reading with a `VmWriter`, if one of them represents typed memory, /// it can ensure that the reading range in this reader and writing range in the @@ -316,25 +346,18 @@ unsafe fn memcpy_fallible(dst: *mut u8, src: *const u8, len: usize) -> usize { /// and physical address level. There is not guarantee for the operation results /// of `VmReader` and `VmWriter` in overlapping untyped addresses, and it is /// the user's responsibility to handle this situation. -pub struct VmReader<'a, Space = KernelSpace> { +pub struct VmReader<'a, Fallibility = Fallible> { cursor: *const u8, end: *const u8, - phantom: PhantomData<(&'a [u8], Space)>, + phantom: PhantomData<(&'a [u8], Fallibility)>, } macro_rules! impl_read_fallible { - ($read_space:ty, $write_space:ty) => { - impl<'a> VmReader<'a, $read_space> { - /// Reads all data into the writer until one of the three conditions is met: - /// 1. The reader has no remaining data. - /// 2. The writer has no available space. - /// 3. The reader/writer encounters some error. - /// - /// On success, the number of bytes read is returned; - /// On error, both the error and the number of bytes read so far are returned. - pub fn read_fallible( + ($reader_fallibility:ty, $writer_fallibility:ty) => { + impl<'a> FallibleVmRead<$writer_fallibility> for VmReader<'a, $reader_fallibility> { + fn read_fallible( &mut self, - writer: &mut VmWriter<'_, $write_space>, + writer: &mut VmWriter<'_, $writer_fallibility>, ) -> core::result::Result { let copy_len = self.remain().min(writer.avail()); if copy_len == 0 { @@ -361,18 +384,11 @@ macro_rules! impl_read_fallible { } macro_rules! impl_write_fallible { - ($read_space:ty, $write_space:ty) => { - impl<'a> VmWriter<'a, $write_space> { - /// Writes all data from the reader until one of the three conditions is met: - /// 1. The reader has no remaining data. - /// 2. The writer has no available space. - /// 3. The reader/writer encounters some error. - /// - /// On success, the number of bytes written is returned; - /// On error, both the error and the number of bytes written so far are returned. - pub fn write_fallible( + ($writer_fallibility:ty, $reader_fallibility:ty) => { + impl<'a> FallibleVmWrite<$reader_fallibility> for VmWriter<'a, $writer_fallibility> { + fn write_fallible( &mut self, - reader: &mut VmReader<'_, $read_space>, + reader: &mut VmReader<'_, $reader_fallibility>, ) -> core::result::Result { reader.read_fallible(self) } @@ -380,14 +396,14 @@ macro_rules! impl_write_fallible { }; } -// TODO: implement an additional function `memcpy_nonoverlapping_fallible` -// to implement read/write instruction from user space to user space. -impl_read_fallible!(UserSpace, KernelSpace); -impl_read_fallible!(KernelSpace, UserSpace); -impl_write_fallible!(UserSpace, KernelSpace); -impl_write_fallible!(KernelSpace, UserSpace); +impl_read_fallible!(Fallible, Infallible); +impl_read_fallible!(Fallible, Fallible); +impl_read_fallible!(Infallible, Fallible); +impl_write_fallible!(Fallible, Infallible); +impl_write_fallible!(Fallible, Fallible); +impl_write_fallible!(Infallible, Fallible); -impl<'a> VmReader<'a, KernelSpace> { +impl<'a> VmReader<'a, Infallible> { /// Constructs a `VmReader` from a pointer and a length, which represents /// a memory range in kernel space. /// @@ -397,8 +413,9 @@ impl<'a> VmReader<'a, KernelSpace> { /// /// [valid]: crate::mm::io#safety pub unsafe fn from_kernel_space(ptr: *const u8, len: usize) -> Self { - // If casting a zero sized slice to a pointer, the pointer may be null - // and does not reside in our kernel space range. + // Rust is allowed to give the reference to a zero-sized object a very small address, + // falling out of the kernel virtual address space range. + // So when `len` is zero, we should not and need not to check `ptr`. debug_assert!(len == 0 || KERNEL_BASE_VADDR <= ptr as usize); debug_assert!(len == 0 || ptr.add(len) as usize <= KERNEL_END_VADDR); @@ -414,7 +431,7 @@ impl<'a> VmReader<'a, KernelSpace> { /// 2. The writer has no available space. /// /// Returns the number of bytes read. - pub fn read(&mut self, writer: &mut VmWriter<'_, KernelSpace>) -> usize { + pub fn read(&mut self, writer: &mut VmWriter<'_, Infallible>) -> usize { let copy_len = self.remain().min(writer.avail()); if copy_len == 0 { return 0; @@ -474,9 +491,17 @@ impl<'a> VmReader<'a, KernelSpace> { Ok(val) } + + /// Converts to a fallible reader. + pub fn to_fallible(self) -> VmReader<'a, Fallible> { + // SAFETY: It is safe to transmute to a fallible reader since + // 1. the fallibility is a zero-sized marker type, + // 2. an infallible reader covers the capabilities of a fallible reader. + unsafe { core::mem::transmute(self) } + } } -impl<'a> VmReader<'a, UserSpace> { +impl<'a> VmReader<'a, Fallible> { /// Constructs a `VmReader` from a pointer and a length, which represents /// a memory range in user space. /// @@ -498,6 +523,10 @@ impl<'a> VmReader<'a, UserSpace> { /// If the length of the `Pod` type exceeds `self.remain()`, /// or the value can not be read completely, /// this method will return `Err`. + /// + /// If the memory read failed, this method will return `Err` + /// and the current reader's cursor remains pointing to + /// the original starting position. pub fn read_val(&mut self) -> Result { if self.remain() < core::mem::size_of::() { return Err(Error::InvalidArgs); @@ -506,12 +535,19 @@ impl<'a> VmReader<'a, UserSpace> { let mut val = T::new_uninit(); let mut writer = VmWriter::from(val.as_bytes_mut()); self.read_fallible(&mut writer) - .map(|_| val) - .map_err(|err| err.0) + .map_err(|(err, copied_len)| { + // SAFETY: The `copied_len` is the number of bytes read so far. + // So the `cursor` can be moved back to the original position. + unsafe { + self.cursor = self.cursor.sub(copied_len); + } + err + })?; + Ok(val) } } -impl<'a, Space> VmReader<'a, Space> { +impl<'a, Fallibility> VmReader<'a, Fallibility> { /// Returns the number of bytes for the remaining data. pub const fn remain(&self) -> usize { // SAFETY: the end is equal to or greater than the cursor. @@ -554,7 +590,7 @@ impl<'a, Space> VmReader<'a, Space> { } } -impl<'a> From<&'a [u8]> for VmReader<'a> { +impl<'a> From<&'a [u8]> for VmReader<'a, Infallible> { fn from(slice: &'a [u8]) -> Self { // SAFETY: // - The memory range points to typed memory. @@ -569,9 +605,10 @@ impl<'a> From<&'a [u8]> for VmReader<'a> { /// /// The memory range write by `VmWriter` can be in either kernel space or user space. /// When the operating range is in kernel space, the memory within that range -/// is guaranteed to be valid. +/// is guaranteed to be valid, and the corresponding memory writes are infallible. /// When the operating range is in user space, it is ensured that the page table of -/// the process creating the `VmWriter` is active for the duration of `'a`. +/// the process creating the `VmWriter` is active for the duration of `'a`, +/// and the corresponding memory writes are considered fallible. /// /// When perform writing with a `VmReader`, if one of them represents typed memory, /// it can ensure that the writing range in this writer and reading range in the @@ -581,13 +618,13 @@ impl<'a> From<&'a [u8]> for VmReader<'a> { /// and physical address level. There is not guarantee for the operation results /// of `VmReader` and `VmWriter` in overlapping untyped addresses, and it is /// the user's responsibility to handle this situation. -pub struct VmWriter<'a, Space = KernelSpace> { +pub struct VmWriter<'a, Fallibility = Fallible> { cursor: *mut u8, end: *mut u8, - phantom: PhantomData<(&'a mut [u8], Space)>, + phantom: PhantomData<(&'a mut [u8], Fallibility)>, } -impl<'a> VmWriter<'a, KernelSpace> { +impl<'a> VmWriter<'a, Infallible> { /// Constructs a `VmWriter` from a pointer and a length, which represents /// a memory range in kernel space. /// @@ -614,7 +651,7 @@ impl<'a> VmWriter<'a, KernelSpace> { /// 2. The writer has no available space. /// /// Returns the number of bytes written. - pub fn write(&mut self, reader: &mut VmReader<'_, KernelSpace>) -> usize { + pub fn write(&mut self, reader: &mut VmReader<'_, Infallible>) -> usize { reader.read(self) } @@ -686,9 +723,17 @@ impl<'a> VmWriter<'a, KernelSpace> { self.cursor = self.end; written_num } + + /// Converts to a fallible writer. + pub fn to_fallible(self) -> VmWriter<'a, Fallible> { + // SAFETY: It is safe to transmute to a fallible writer since + // 1. the fallibility is a zero-sized marker type, + // 2. an infallible reader covers the capabilities of a fallible reader. + unsafe { core::mem::transmute(self) } + } } -impl<'a> VmWriter<'a, UserSpace> { +impl<'a> VmWriter<'a, Fallible> { /// Constructs a `VmWriter` from a pointer and a length, which represents /// a memory range in user space. /// @@ -713,18 +758,30 @@ impl<'a> VmWriter<'a, UserSpace> { /// If the length of the `Pod` type exceeds `self.avail()`, /// or the value can not be write completely, /// this method will return `Err`. + /// + /// If the memory write failed, this method will return `Err` + /// and the current writer's cursor remains pointing to + /// the original starting position. pub fn write_val(&mut self, new_val: &T) -> Result<()> { if self.avail() < core::mem::size_of::() { return Err(Error::InvalidArgs); } let mut reader = VmReader::from(new_val.as_bytes()); - self.write_fallible(&mut reader).map_err(|err| err.0)?; + self.write_fallible(&mut reader) + .map_err(|(err, copied_len)| { + // SAFETY: The `copied_len` is the number of bytes written so far. + // So the `cursor` can be moved back to the original position. + unsafe { + self.cursor = self.cursor.sub(copied_len); + } + err + })?; Ok(()) } } -impl<'a, Space> VmWriter<'a, Space> { +impl<'a, Fallibility> VmWriter<'a, Fallibility> { /// Returns the number of bytes for the available space. pub const fn avail(&self) -> usize { // SAFETY: the end is equal to or greater than the cursor. @@ -767,7 +824,7 @@ impl<'a, Space> VmWriter<'a, Space> { } } -impl<'a> From<&'a mut [u8]> for VmWriter<'a> { +impl<'a> From<&'a mut [u8]> for VmWriter<'a, Infallible> { fn from(slice: &'a mut [u8]) -> Self { // SAFETY: // - The memory range points to typed memory. diff --git a/ostd/src/mm/mod.rs b/ostd/src/mm/mod.rs index d3aa68326..2becf76a3 100644 --- a/ostd/src/mm/mod.rs +++ b/ostd/src/mm/mod.rs @@ -28,7 +28,10 @@ use spin::Once; pub use self::{ dma::{Daddr, DmaCoherent, DmaDirection, DmaStream, DmaStreamSlice, HasDaddr}, frame::{options::FrameAllocOptions, Frame, Segment}, - io::{KernelSpace, PodOnce, UserSpace, VmIo, VmIoOnce, VmReader, VmWriter}, + io::{ + Fallible, FallibleVmRead, FallibleVmWrite, Infallible, PodOnce, VmIo, VmIoOnce, VmReader, + VmWriter, + }, page_prop::{CachePolicy, PageFlags, PageProperty}, vm_space::VmSpace, }; diff --git a/ostd/src/mm/vm_space.rs b/ostd/src/mm/vm_space.rs index 574d8efb3..a84111347 100644 --- a/ostd/src/mm/vm_space.rs +++ b/ostd/src/mm/vm_space.rs @@ -14,7 +14,7 @@ use core::ops::Range; use spin::Once; use super::{ - io::UserSpace, + io::Fallible, kspace::KERNEL_PAGE_TABLE, page_table::{PageTable, UserMode}, PageFlags, PageProperty, VmReader, VmWriter, @@ -192,7 +192,7 @@ impl VmSpace { /// /// Returns `Err` if this `VmSpace` is not belonged to the user space of the current task /// or the `vaddr` and `len` do not represent a user space memory range. - pub fn reader(&self, vaddr: Vaddr, len: usize) -> Result> { + pub fn reader(&self, vaddr: Vaddr, len: usize) -> Result> { if current_page_table_paddr() != unsafe { self.pt.root_paddr() } { return Err(Error::AccessDenied); } @@ -206,14 +206,14 @@ impl VmSpace { // the `VmReader`. // // SAFETY: The memory range is in user space, as checked above. - Ok(unsafe { VmReader::::from_user_space(vaddr as *const u8, len) }) + Ok(unsafe { VmReader::::from_user_space(vaddr as *const u8, len) }) } /// Creates a writer to write data into the user space. /// /// Returns `Err` if this `VmSpace` is not belonged to the user space of the current task /// or the `vaddr` and `len` do not represent a user space memory range. - pub fn writer(&self, vaddr: Vaddr, len: usize) -> Result> { + pub fn writer(&self, vaddr: Vaddr, len: usize) -> Result> { if current_page_table_paddr() != unsafe { self.pt.root_paddr() } { return Err(Error::AccessDenied); } @@ -227,7 +227,7 @@ impl VmSpace { // the `VmWriter`. // // SAFETY: The memory range is in user space, as checked above. - Ok(unsafe { VmWriter::::from_user_space(vaddr as *mut u8, len) }) + Ok(unsafe { VmWriter::::from_user_space(vaddr as *mut u8, len) }) } }