From 1846c680fc4630b78d663c8b2475b9cde45079ea Mon Sep 17 00:00:00 2001 From: jiangjianfeng Date: Mon, 31 Mar 2025 03:01:22 +0000 Subject: [PATCH] Clone the reader to prevent cursor misplacement in `ReadCString` --- kernel/src/context.rs | 31 ++++++++++++++++++++++++++++++- ostd/src/mm/io.rs | 23 +++++++++++++++++++---- 2 files changed, 49 insertions(+), 5 deletions(-) diff --git a/kernel/src/context.rs b/kernel/src/context.rs index 850b3e63..1ccb50fb 100644 --- a/kernel/src/context.rs +++ b/kernel/src/context.rs @@ -204,13 +204,15 @@ impl ReadCString for VmReader<'_, Fallible> { ); // Handle the rest of the bytes in bulk + let mut cloned_reader = self.clone(); while (buffer.len() + mem::size_of::()) <= max_len { - let Ok(word) = self.read_val::() else { + let Ok(word) = cloned_reader.read_val::() else { break; }; if has_zero(word) { for byte in word.to_ne_bytes() { + self.skip(1); buffer.push(byte); if byte == 0 { return Ok(CString::from_vec_with_nul(buffer) @@ -220,6 +222,7 @@ impl ReadCString for VmReader<'_, Fallible> { unreachable!("The branch should never be reached unless `has_zero` has bugs.") } + self.skip(size_of::()); buffer.extend_from_slice(&word.to_ne_bytes()); } @@ -267,3 +270,29 @@ fn check_vaddr(va: Vaddr) -> Result<()> { const fn is_addr_aligned(addr: usize) -> bool { (addr & (mem::size_of::() - 1)) == 0 } + +#[cfg(ktest)] +mod test { + use ostd::prelude::*; + + use super::*; + + #[ktest] + fn read_multiple_cstring() { + let mut buffer = vec![0u8; 100]; + + let str1 = CString::new("hello").unwrap(); + let str2 = CString::new("world!").unwrap(); + + let mut writer = VmWriter::from(buffer.as_mut_slice()); + writer.write(&mut VmReader::from(str1.as_bytes_with_nul())); + writer.write(&mut VmReader::from(str2.as_bytes_with_nul())); + drop(writer); + + let mut reader = VmReader::from(buffer.as_slice()).to_fallible(); + let read_str1 = reader.read_cstring().unwrap(); + assert_eq!(read_str1, str1); + let read_str2 = reader.read_cstring().unwrap(); + assert_eq!(read_str2, str2); + } +} diff --git a/ostd/src/mm/io.rs b/ostd/src/mm/io.rs index 9f180acf..284cf11e 100644 --- a/ostd/src/mm/io.rs +++ b/ostd/src/mm/io.rs @@ -275,12 +275,13 @@ impl_vm_io_once_pointer!(&mut T, "(**self)"); impl_vm_io_once_pointer!(Box, "(**self)"); impl_vm_io_once_pointer!(Arc, "(**self)"); -/// A marker structure used for [`VmReader`] and [`VmWriter`], +/// A marker type used for [`VmReader`] and [`VmWriter`], /// representing whether reads or writes on the underlying memory region are fallible. -pub struct Fallible; -/// A marker structure used for [`VmReader`] and [`VmWriter`], +pub enum Fallible {} + +/// A marker type used for [`VmReader`] and [`VmWriter`], /// representing whether reads or writes on the underlying memory region are infallible. -pub struct Infallible; +pub enum Infallible {} /// Copies `len` bytes from `src` to `dst`. /// @@ -393,6 +394,20 @@ pub struct VmReader<'a, Fallibility = Fallible> { phantom: PhantomData<(&'a [u8], Fallibility)>, } +// `Clone` can be implemented for `VmReader` +// because it either points to untyped memory or represents immutable references. +// Note that we cannot implement `Clone` for `VmWriter` +// because it can represent mutable references, which must remain exclusive. +impl Clone for VmReader<'_, Fallibility> { + fn clone(&self) -> Self { + Self { + cursor: self.cursor, + end: self.end, + phantom: PhantomData, + } + } +} + macro_rules! impl_read_fallible { ($reader_fallibility:ty, $writer_fallibility:ty) => { impl<'a> FallibleVmRead<$writer_fallibility> for VmReader<'a, $reader_fallibility> {