Implement ReadCString for MultiRead

This commit is contained in:
jiangjianfeng
2025-04-09 06:36:21 +00:00
committed by Ruihan Li
parent 3c2c31ceb0
commit b833ec6ede
2 changed files with 148 additions and 53 deletions

View File

@ -17,6 +17,7 @@ use crate::{
Process, Process,
}, },
thread::Thread, thread::Thread,
util::{MultiRead, VmReaderArray},
vm::vmar::Vmar, vm::vmar::Vmar,
}; };
@ -185,52 +186,95 @@ impl ReadCString for VmReader<'_, Fallible> {
let max_len = self.remain(); let max_len = self.remain();
let mut buffer: Vec<u8> = Vec::with_capacity(max_len); let mut buffer: Vec<u8> = Vec::with_capacity(max_len);
macro_rules! read_one_byte_at_a_time_while { if read_until_nul_byte(self, &mut buffer, max_len)? {
($cond:expr) => { return Ok(CString::from_vec_with_nul(buffer).unwrap());
while $cond {
let byte = self.read_val::<u8>()?;
buffer.push(byte);
if byte == 0 {
return Ok(CString::from_vec_with_nul(buffer)
.expect("We provided 0 but no 0 is found"));
}
}
};
} }
// Handle the first few bytes to make `cur_addr` aligned with `size_of::<usize>` return_errno_with_message!(
read_one_byte_at_a_time_while!( Errno::EFAULT,
!is_addr_aligned(self.cursor() as usize) && buffer.len() < max_len "no nul terminator is present before reaching the buffer limit"
); );
}
}
// Handle the rest of the bytes in bulk impl ReadCString for VmReaderArray<'_> {
let mut cloned_reader = self.clone(); fn read_cstring(&mut self) -> Result<CString> {
while (buffer.len() + mem::size_of::<usize>()) <= max_len { let max_len = self.sum_lens();
let Ok(word) = cloned_reader.read_val::<usize>() else { let mut buffer: Vec<u8> = Vec::with_capacity(max_len);
break;
};
if has_zero(word) { for reader in self.readers_mut() {
for byte in word.to_ne_bytes() { if read_until_nul_byte(reader, &mut buffer, max_len)? {
self.skip(1); return Ok(CString::from_vec_with_nul(buffer).unwrap());
buffer.push(byte);
if byte == 0 {
return Ok(CString::from_vec_with_nul(buffer)
.expect("We provided 0 but no 0 is found"));
}
}
unreachable!("The branch should never be reached unless `has_zero` has bugs.")
} }
self.skip(size_of::<usize>());
buffer.extend_from_slice(&word.to_ne_bytes());
} }
// Handle the last few bytes that are not enough for a word return_errno_with_message!(
read_one_byte_at_a_time_while!(buffer.len() < max_len); Errno::EFAULT,
"no nul terminator is present before reaching the buffer limit"
);
}
}
// Maximum length exceeded before finding the null terminator /// Reads bytes from `reader` into `buffer` until a nul byte is found.
return_errno_with_message!(Errno::EFAULT, "Fails to read CString from user"); ///
/// This method returns the following values:
/// 1. `Ok(true)`: If a nul byte is found in the reader;
/// 2. `Ok(false)`: If no nul byte is found and the `reader` is exhausted;
/// 3. `Err(_)`: If an error occurs while reading from the `reader`.
fn read_until_nul_byte(
reader: &mut VmReader,
buffer: &mut Vec<u8>,
max_len: usize,
) -> Result<bool> {
macro_rules! read_one_byte_at_a_time_while {
($cond:expr) => {
while $cond {
let byte = reader.read_val::<u8>()?;
buffer.push(byte);
if byte == 0 {
return Ok(true);
}
}
};
}
// Handle the first few bytes to make `cur_addr` aligned with `size_of::<usize>`
read_one_byte_at_a_time_while!(
!is_addr_aligned(reader.cursor() as usize) && buffer.len() < max_len && reader.has_remain()
);
// Handle the rest of the bytes in bulk
let mut cloned_reader = reader.clone();
while (buffer.len() + mem::size_of::<usize>()) <= max_len {
let Ok(word) = cloned_reader.read_val::<usize>() else {
break;
};
if has_zero(word) {
for byte in word.to_ne_bytes() {
reader.skip(1);
buffer.push(byte);
if byte == 0 {
return Ok(true);
}
}
unreachable!("The branch should never be reached unless `has_zero` has bugs.")
}
reader.skip(size_of::<usize>());
buffer.extend_from_slice(&word.to_ne_bytes());
}
// Handle the last few bytes that are not enough for a word
read_one_byte_at_a_time_while!(buffer.len() < max_len && reader.has_remain());
if buffer.len() >= max_len {
return_errno_with_message!(
Errno::EFAULT,
"no nul terminator is present before exceeding the maximum length"
);
} else {
Ok(false)
} }
} }
@ -277,22 +321,67 @@ mod test {
use super::*; use super::*;
#[ktest] fn init_buffer(cstrs: &[CString]) -> Vec<u8> {
fn read_multiple_cstring() { let mut buffer = vec![255u8; 100];
let mut buffer = vec![0u8; 100];
let str1 = CString::new("hello").unwrap();
let str2 = CString::new("world!").unwrap();
let mut writer = VmWriter::from(buffer.as_mut_slice()); let mut writer = VmWriter::from(buffer.as_mut_slice());
writer.write(&mut VmReader::from(str1.as_bytes_with_nul()));
writer.write(&mut VmReader::from(str2.as_bytes_with_nul())); for cstr in cstrs {
drop(writer); writer.write(&mut VmReader::from(cstr.as_bytes_with_nul()));
}
buffer
}
#[ktest]
fn read_multiple_cstring() {
let strs = {
let str1 = CString::new("hello").unwrap();
let str2 = CString::new("world!").unwrap();
vec![str1, str2]
};
let buffer = init_buffer(&strs);
let mut reader = VmReader::from(buffer.as_slice()).to_fallible(); let mut reader = VmReader::from(buffer.as_slice()).to_fallible();
let read_str1 = reader.read_cstring().unwrap(); let read_str1 = reader.read_cstring().unwrap();
assert_eq!(read_str1, str1); assert_eq!(read_str1, strs[0]);
let read_str2 = reader.read_cstring().unwrap(); let read_str2 = reader.read_cstring().unwrap();
assert_eq!(read_str2, str2); assert_eq!(read_str2, strs[1]);
assert!(reader
.read_cstring()
.is_err_and(|err| err.error() == Errno::EFAULT));
}
#[ktest]
fn read_cstring_from_multiread() {
let strs = {
let str1 = CString::new("hello").unwrap();
let str2 = CString::new("world!").unwrap();
let str3 = CString::new("asterinas").unwrap();
vec![str1, str2, str3]
};
let buffer = init_buffer(&strs);
let mut readers = {
let reader1 = VmReader::from(&buffer[0..20]).to_fallible();
let reader2 = VmReader::from(&buffer[20..40]).to_fallible();
let reader3 = VmReader::from(&buffer[40..60]).to_fallible();
VmReaderArray::new(vec![reader1, reader2, reader3].into_boxed_slice())
};
let multiread = &mut readers as &mut dyn MultiRead;
let read_str1 = multiread.read_cstring().unwrap();
assert_eq!(read_str1, strs[0]);
let read_str2 = multiread.read_cstring().unwrap();
assert_eq!(read_str2, strs[1]);
let read_str3 = multiread.read_cstring().unwrap();
assert_eq!(read_str3, strs[2]);
assert!(multiread
.read_cstring()
.is_err_and(|err| err.error() == Errno::EFAULT));
} }
} }

View File

@ -94,7 +94,7 @@ pub struct VmReaderArray<'a>(Box<[VmReader<'a>]>);
pub struct VmWriterArray<'a>(Box<[VmWriter<'a>]>); pub struct VmWriterArray<'a>(Box<[VmWriter<'a>]>);
impl<'a> VmReaderArray<'a> { impl<'a> VmReaderArray<'a> {
/// Creates a new `IoVecReader` from user-provided io vec buffer. /// Creates a new `VmReaderArray` from user-provided io vec buffer.
pub fn from_user_io_vecs( pub fn from_user_io_vecs(
user_space: &'a CurrentUserSpace<'a>, user_space: &'a CurrentUserSpace<'a>,
start_addr: Vaddr, start_addr: Vaddr,
@ -105,13 +105,19 @@ impl<'a> VmReaderArray<'a> {
} }
/// Returns mutable reference to [`VmReader`]s. /// Returns mutable reference to [`VmReader`]s.
pub fn readers_mut(&'a mut self) -> &'a mut [VmReader<'a>] { pub fn readers_mut(&mut self) -> &mut [VmReader<'a>] {
&mut self.0 &mut self.0
} }
/// Creates a new `VmReaderArray`.
#[cfg(ktest)]
pub const fn new(readers: Box<[VmReader<'a>]>) -> Self {
Self(readers)
}
} }
impl<'a> VmWriterArray<'a> { impl<'a> VmWriterArray<'a> {
/// Creates a new `IoVecWriter` from user-provided io vec buffer. /// Creates a new `VmWriterArray` from user-provided io vec buffer.
pub fn from_user_io_vecs( pub fn from_user_io_vecs(
user_space: &'a CurrentUserSpace<'a>, user_space: &'a CurrentUserSpace<'a>,
start_addr: Vaddr, start_addr: Vaddr,
@ -122,13 +128,13 @@ impl<'a> VmWriterArray<'a> {
} }
/// Returns mutable reference to [`VmWriter`]s. /// Returns mutable reference to [`VmWriter`]s.
pub fn writers_mut(&'a mut self) -> &'a mut [VmWriter<'a>] { pub fn writers_mut(&mut self) -> &mut [VmWriter<'a>] {
&mut self.0 &mut self.0
} }
} }
/// Trait defining the read behavior for a collection of [`VmReader`]s. /// Trait defining the read behavior for a collection of [`VmReader`]s.
pub trait MultiRead { pub trait MultiRead: ReadCString {
/// Reads the exact number of bytes required to exhaust `self` or fill `writer`, /// Reads the exact number of bytes required to exhaust `self` or fill `writer`,
/// accumulating total bytes read. /// accumulating total bytes read.
/// ///