Refactor VmReader&VmWriter as given fallibility marker

This commit is contained in:
Shaowei Song
2024-08-20 02:05:25 +00:00
committed by Tate, Hongliang Tian
parent 562e644375
commit 2102107be1
16 changed files with 172 additions and 104 deletions

View File

@ -250,12 +250,11 @@ impl_vm_io_once_pointer!(Box<T>, "(**self)");
impl_vm_io_once_pointer!(Arc<T>, "(**self)");
/// A marker structure used for [`VmReader`] and [`VmWriter`],
/// representing their operated memory scope is in user space.
pub struct UserSpace;
/// representing whether reads or writes on the underlying memory region are fallible.
pub struct Fallible;
/// A marker structure used for [`VmReader`] and [`VmWriter`],
/// representing their operated memory scope is in kernel space.
pub struct KernelSpace;
/// representing whether reads or writes on the underlying memory region are infallible.
pub struct Infallible;
/// Copies `len` bytes from `src` to `dst`.
///
@ -300,13 +299,44 @@ unsafe fn memcpy_fallible(dst: *mut u8, src: *const u8, len: usize) -> usize {
len - failed_bytes
}
/// Fallible memory read from a `VmWriter`.
pub trait FallibleVmRead<F> {
/// Reads all data into the writer until one of the three conditions is met:
/// 1. The reader has no remaining data.
/// 2. The writer has no available space.
/// 3. The reader/writer encounters some error.
///
/// On success, the number of bytes read is returned;
/// On error, both the error and the number of bytes read so far are returned.
fn read_fallible(
&mut self,
writer: &mut VmWriter<'_, F>,
) -> core::result::Result<usize, (Error, usize)>;
}
/// Fallible memory write from a `VmReader`.
pub trait FallibleVmWrite<F> {
/// Writes all data from the reader until one of the three conditions is met:
/// 1. The reader has no remaining data.
/// 2. The writer has no available space.
/// 3. The reader/writer encounters some error.
///
/// On success, the number of bytes written is returned;
/// On error, both the error and the number of bytes written so far are returned.
fn write_fallible(
&mut self,
reader: &mut VmReader<'_, F>,
) -> core::result::Result<usize, (Error, usize)>;
}
/// `VmReader` is a reader for reading data from a contiguous range of memory.
///
/// The memory range read by `VmReader` can be in either kernel space or user space.
/// When the operating range is in kernel space, the memory within that range
/// is guaranteed to be valid.
/// is guaranteed to be valid, and the corresponding memory reads are infallible.
/// When the operating range is in user space, it is ensured that the page table of
/// the process creating the `VmReader` is active for the duration of `'a`.
/// the process creating the `VmReader` is active for the duration of `'a`,
/// and the corresponding memory reads are considered fallible.
///
/// When perform reading with a `VmWriter`, if one of them represents typed memory,
/// it can ensure that the reading range in this reader and writing range in the
@ -316,25 +346,18 @@ unsafe fn memcpy_fallible(dst: *mut u8, src: *const u8, len: usize) -> usize {
/// and physical address level. There is not guarantee for the operation results
/// of `VmReader` and `VmWriter` in overlapping untyped addresses, and it is
/// the user's responsibility to handle this situation.
pub struct VmReader<'a, Space = KernelSpace> {
pub struct VmReader<'a, Fallibility = Fallible> {
cursor: *const u8,
end: *const u8,
phantom: PhantomData<(&'a [u8], Space)>,
phantom: PhantomData<(&'a [u8], Fallibility)>,
}
macro_rules! impl_read_fallible {
($read_space:ty, $write_space:ty) => {
impl<'a> VmReader<'a, $read_space> {
/// Reads all data into the writer until one of the three conditions is met:
/// 1. The reader has no remaining data.
/// 2. The writer has no available space.
/// 3. The reader/writer encounters some error.
///
/// On success, the number of bytes read is returned;
/// On error, both the error and the number of bytes read so far are returned.
pub fn read_fallible(
($reader_fallibility:ty, $writer_fallibility:ty) => {
impl<'a> FallibleVmRead<$writer_fallibility> for VmReader<'a, $reader_fallibility> {
fn read_fallible(
&mut self,
writer: &mut VmWriter<'_, $write_space>,
writer: &mut VmWriter<'_, $writer_fallibility>,
) -> core::result::Result<usize, (Error, usize)> {
let copy_len = self.remain().min(writer.avail());
if copy_len == 0 {
@ -361,18 +384,11 @@ macro_rules! impl_read_fallible {
}
macro_rules! impl_write_fallible {
($read_space:ty, $write_space:ty) => {
impl<'a> VmWriter<'a, $write_space> {
/// Writes all data from the reader until one of the three conditions is met:
/// 1. The reader has no remaining data.
/// 2. The writer has no available space.
/// 3. The reader/writer encounters some error.
///
/// On success, the number of bytes written is returned;
/// On error, both the error and the number of bytes written so far are returned.
pub fn write_fallible(
($writer_fallibility:ty, $reader_fallibility:ty) => {
impl<'a> FallibleVmWrite<$reader_fallibility> for VmWriter<'a, $writer_fallibility> {
fn write_fallible(
&mut self,
reader: &mut VmReader<'_, $read_space>,
reader: &mut VmReader<'_, $reader_fallibility>,
) -> core::result::Result<usize, (Error, usize)> {
reader.read_fallible(self)
}
@ -380,14 +396,14 @@ macro_rules! impl_write_fallible {
};
}
// TODO: implement an additional function `memcpy_nonoverlapping_fallible`
// to implement read/write instruction from user space to user space.
impl_read_fallible!(UserSpace, KernelSpace);
impl_read_fallible!(KernelSpace, UserSpace);
impl_write_fallible!(UserSpace, KernelSpace);
impl_write_fallible!(KernelSpace, UserSpace);
impl_read_fallible!(Fallible, Infallible);
impl_read_fallible!(Fallible, Fallible);
impl_read_fallible!(Infallible, Fallible);
impl_write_fallible!(Fallible, Infallible);
impl_write_fallible!(Fallible, Fallible);
impl_write_fallible!(Infallible, Fallible);
impl<'a> VmReader<'a, KernelSpace> {
impl<'a> VmReader<'a, Infallible> {
/// Constructs a `VmReader` from a pointer and a length, which represents
/// a memory range in kernel space.
///
@ -397,8 +413,9 @@ impl<'a> VmReader<'a, KernelSpace> {
///
/// [valid]: crate::mm::io#safety
pub unsafe fn from_kernel_space(ptr: *const u8, len: usize) -> Self {
// If casting a zero sized slice to a pointer, the pointer may be null
// and does not reside in our kernel space range.
// Rust is allowed to give the reference to a zero-sized object a very small address,
// falling out of the kernel virtual address space range.
// So when `len` is zero, we should not and need not to check `ptr`.
debug_assert!(len == 0 || KERNEL_BASE_VADDR <= ptr as usize);
debug_assert!(len == 0 || ptr.add(len) as usize <= KERNEL_END_VADDR);
@ -414,7 +431,7 @@ impl<'a> VmReader<'a, KernelSpace> {
/// 2. The writer has no available space.
///
/// Returns the number of bytes read.
pub fn read(&mut self, writer: &mut VmWriter<'_, KernelSpace>) -> usize {
pub fn read(&mut self, writer: &mut VmWriter<'_, Infallible>) -> usize {
let copy_len = self.remain().min(writer.avail());
if copy_len == 0 {
return 0;
@ -474,9 +491,17 @@ impl<'a> VmReader<'a, KernelSpace> {
Ok(val)
}
/// Converts to a fallible reader.
pub fn to_fallible(self) -> VmReader<'a, Fallible> {
// SAFETY: It is safe to transmute to a fallible reader since
// 1. the fallibility is a zero-sized marker type,
// 2. an infallible reader covers the capabilities of a fallible reader.
unsafe { core::mem::transmute(self) }
}
}
impl<'a> VmReader<'a, UserSpace> {
impl<'a> VmReader<'a, Fallible> {
/// Constructs a `VmReader` from a pointer and a length, which represents
/// a memory range in user space.
///
@ -498,6 +523,10 @@ impl<'a> VmReader<'a, UserSpace> {
/// If the length of the `Pod` type exceeds `self.remain()`,
/// or the value can not be read completely,
/// this method will return `Err`.
///
/// If the memory read failed, this method will return `Err`
/// and the current reader's cursor remains pointing to
/// the original starting position.
pub fn read_val<T: Pod>(&mut self) -> Result<T> {
if self.remain() < core::mem::size_of::<T>() {
return Err(Error::InvalidArgs);
@ -506,12 +535,19 @@ impl<'a> VmReader<'a, UserSpace> {
let mut val = T::new_uninit();
let mut writer = VmWriter::from(val.as_bytes_mut());
self.read_fallible(&mut writer)
.map(|_| val)
.map_err(|err| err.0)
.map_err(|(err, copied_len)| {
// SAFETY: The `copied_len` is the number of bytes read so far.
// So the `cursor` can be moved back to the original position.
unsafe {
self.cursor = self.cursor.sub(copied_len);
}
err
})?;
Ok(val)
}
}
impl<'a, Space> VmReader<'a, Space> {
impl<'a, Fallibility> VmReader<'a, Fallibility> {
/// Returns the number of bytes for the remaining data.
pub const fn remain(&self) -> usize {
// SAFETY: the end is equal to or greater than the cursor.
@ -554,7 +590,7 @@ impl<'a, Space> VmReader<'a, Space> {
}
}
impl<'a> From<&'a [u8]> for VmReader<'a> {
impl<'a> From<&'a [u8]> for VmReader<'a, Infallible> {
fn from(slice: &'a [u8]) -> Self {
// SAFETY:
// - The memory range points to typed memory.
@ -569,9 +605,10 @@ impl<'a> From<&'a [u8]> for VmReader<'a> {
///
/// The memory range write by `VmWriter` can be in either kernel space or user space.
/// When the operating range is in kernel space, the memory within that range
/// is guaranteed to be valid.
/// is guaranteed to be valid, and the corresponding memory writes are infallible.
/// When the operating range is in user space, it is ensured that the page table of
/// the process creating the `VmWriter` is active for the duration of `'a`.
/// the process creating the `VmWriter` is active for the duration of `'a`,
/// and the corresponding memory writes are considered fallible.
///
/// When perform writing with a `VmReader`, if one of them represents typed memory,
/// it can ensure that the writing range in this writer and reading range in the
@ -581,13 +618,13 @@ impl<'a> From<&'a [u8]> for VmReader<'a> {
/// and physical address level. There is not guarantee for the operation results
/// of `VmReader` and `VmWriter` in overlapping untyped addresses, and it is
/// the user's responsibility to handle this situation.
pub struct VmWriter<'a, Space = KernelSpace> {
pub struct VmWriter<'a, Fallibility = Fallible> {
cursor: *mut u8,
end: *mut u8,
phantom: PhantomData<(&'a mut [u8], Space)>,
phantom: PhantomData<(&'a mut [u8], Fallibility)>,
}
impl<'a> VmWriter<'a, KernelSpace> {
impl<'a> VmWriter<'a, Infallible> {
/// Constructs a `VmWriter` from a pointer and a length, which represents
/// a memory range in kernel space.
///
@ -614,7 +651,7 @@ impl<'a> VmWriter<'a, KernelSpace> {
/// 2. The writer has no available space.
///
/// Returns the number of bytes written.
pub fn write(&mut self, reader: &mut VmReader<'_, KernelSpace>) -> usize {
pub fn write(&mut self, reader: &mut VmReader<'_, Infallible>) -> usize {
reader.read(self)
}
@ -686,9 +723,17 @@ impl<'a> VmWriter<'a, KernelSpace> {
self.cursor = self.end;
written_num
}
/// Converts to a fallible writer.
pub fn to_fallible(self) -> VmWriter<'a, Fallible> {
// SAFETY: It is safe to transmute to a fallible writer since
// 1. the fallibility is a zero-sized marker type,
// 2. an infallible reader covers the capabilities of a fallible reader.
unsafe { core::mem::transmute(self) }
}
}
impl<'a> VmWriter<'a, UserSpace> {
impl<'a> VmWriter<'a, Fallible> {
/// Constructs a `VmWriter` from a pointer and a length, which represents
/// a memory range in user space.
///
@ -713,18 +758,30 @@ impl<'a> VmWriter<'a, UserSpace> {
/// If the length of the `Pod` type exceeds `self.avail()`,
/// or the value can not be write completely,
/// this method will return `Err`.
///
/// If the memory write failed, this method will return `Err`
/// and the current writer's cursor remains pointing to
/// the original starting position.
pub fn write_val<T: Pod>(&mut self, new_val: &T) -> Result<()> {
if self.avail() < core::mem::size_of::<T>() {
return Err(Error::InvalidArgs);
}
let mut reader = VmReader::from(new_val.as_bytes());
self.write_fallible(&mut reader).map_err(|err| err.0)?;
self.write_fallible(&mut reader)
.map_err(|(err, copied_len)| {
// SAFETY: The `copied_len` is the number of bytes written so far.
// So the `cursor` can be moved back to the original position.
unsafe {
self.cursor = self.cursor.sub(copied_len);
}
err
})?;
Ok(())
}
}
impl<'a, Space> VmWriter<'a, Space> {
impl<'a, Fallibility> VmWriter<'a, Fallibility> {
/// Returns the number of bytes for the available space.
pub const fn avail(&self) -> usize {
// SAFETY: the end is equal to or greater than the cursor.
@ -767,7 +824,7 @@ impl<'a, Space> VmWriter<'a, Space> {
}
}
impl<'a> From<&'a mut [u8]> for VmWriter<'a> {
impl<'a> From<&'a mut [u8]> for VmWriter<'a, Infallible> {
fn from(slice: &'a mut [u8]) -> Self {
// SAFETY:
// - The memory range points to typed memory.