Clone the reader to prevent cursor misplacement in ReadCString

This commit is contained in:
jiangjianfeng 2025-03-31 03:01:22 +00:00 committed by Ruihan Li
parent f4e79d99d0
commit 1846c680fc
2 changed files with 49 additions and 5 deletions

View File

@ -204,13 +204,15 @@ impl ReadCString for VmReader<'_, Fallible> {
);
// Handle the rest of the bytes in bulk
let mut cloned_reader = self.clone();
while (buffer.len() + mem::size_of::<usize>()) <= max_len {
let Ok(word) = self.read_val::<usize>() else {
let Ok(word) = cloned_reader.read_val::<usize>() else {
break;
};
if has_zero(word) {
for byte in word.to_ne_bytes() {
self.skip(1);
buffer.push(byte);
if byte == 0 {
return Ok(CString::from_vec_with_nul(buffer)
@ -220,6 +222,7 @@ impl ReadCString for VmReader<'_, Fallible> {
unreachable!("The branch should never be reached unless `has_zero` has bugs.")
}
self.skip(size_of::<usize>());
buffer.extend_from_slice(&word.to_ne_bytes());
}
@ -267,3 +270,29 @@ fn check_vaddr(va: Vaddr) -> Result<()> {
const fn is_addr_aligned(addr: usize) -> bool {
(addr & (mem::size_of::<usize>() - 1)) == 0
}
#[cfg(ktest)]
mod test {
use ostd::prelude::*;
use super::*;
#[ktest]
fn read_multiple_cstring() {
let mut buffer = vec![0u8; 100];
let str1 = CString::new("hello").unwrap();
let str2 = CString::new("world!").unwrap();
let mut writer = VmWriter::from(buffer.as_mut_slice());
writer.write(&mut VmReader::from(str1.as_bytes_with_nul()));
writer.write(&mut VmReader::from(str2.as_bytes_with_nul()));
drop(writer);
let mut reader = VmReader::from(buffer.as_slice()).to_fallible();
let read_str1 = reader.read_cstring().unwrap();
assert_eq!(read_str1, str1);
let read_str2 = reader.read_cstring().unwrap();
assert_eq!(read_str2, str2);
}
}

View File

@ -275,12 +275,13 @@ impl_vm_io_once_pointer!(&mut T, "(**self)");
impl_vm_io_once_pointer!(Box<T>, "(**self)");
impl_vm_io_once_pointer!(Arc<T>, "(**self)");
/// A marker structure used for [`VmReader`] and [`VmWriter`],
/// A marker type used for [`VmReader`] and [`VmWriter`],
/// representing whether reads or writes on the underlying memory region are fallible.
pub struct Fallible;
/// A marker structure used for [`VmReader`] and [`VmWriter`],
pub enum Fallible {}
/// A marker type used for [`VmReader`] and [`VmWriter`],
/// representing whether reads or writes on the underlying memory region are infallible.
pub struct Infallible;
pub enum Infallible {}
/// Copies `len` bytes from `src` to `dst`.
///
@ -393,6 +394,20 @@ pub struct VmReader<'a, Fallibility = Fallible> {
phantom: PhantomData<(&'a [u8], Fallibility)>,
}
// `Clone` can be implemented for `VmReader`
// because it either points to untyped memory or represents immutable references.
// Note that we cannot implement `Clone` for `VmWriter`
// because it can represent mutable references, which must remain exclusive.
impl<Fallibility> Clone for VmReader<'_, Fallibility> {
fn clone(&self) -> Self {
Self {
cursor: self.cursor,
end: self.end,
phantom: PhantomData,
}
}
}
macro_rules! impl_read_fallible {
($reader_fallibility:ty, $writer_fallibility:ty) => {
impl<'a> FallibleVmRead<$writer_fallibility> for VmReader<'a, $reader_fallibility> {