mirror of
https://github.com/asterinas/asterinas.git
synced 2025-06-22 17:03:23 +00:00
Implement ReadCString for MultiRead
This commit is contained in:
@ -17,6 +17,7 @@ use crate::{
|
||||
Process,
|
||||
},
|
||||
thread::Thread,
|
||||
util::{MultiRead, VmReaderArray},
|
||||
vm::vmar::Vmar,
|
||||
};
|
||||
|
||||
@ -185,52 +186,95 @@ impl ReadCString for VmReader<'_, Fallible> {
|
||||
let max_len = self.remain();
|
||||
let mut buffer: Vec<u8> = Vec::with_capacity(max_len);
|
||||
|
||||
macro_rules! read_one_byte_at_a_time_while {
|
||||
($cond:expr) => {
|
||||
while $cond {
|
||||
let byte = self.read_val::<u8>()?;
|
||||
buffer.push(byte);
|
||||
if byte == 0 {
|
||||
return Ok(CString::from_vec_with_nul(buffer)
|
||||
.expect("We provided 0 but no 0 is found"));
|
||||
}
|
||||
}
|
||||
};
|
||||
if read_until_nul_byte(self, &mut buffer, max_len)? {
|
||||
return Ok(CString::from_vec_with_nul(buffer).unwrap());
|
||||
}
|
||||
|
||||
// Handle the first few bytes to make `cur_addr` aligned with `size_of::<usize>`
|
||||
read_one_byte_at_a_time_while!(
|
||||
!is_addr_aligned(self.cursor() as usize) && buffer.len() < max_len
|
||||
return_errno_with_message!(
|
||||
Errno::EFAULT,
|
||||
"no nul terminator is present before reaching the buffer limit"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Handle the rest of the bytes in bulk
|
||||
let mut cloned_reader = self.clone();
|
||||
while (buffer.len() + mem::size_of::<usize>()) <= max_len {
|
||||
let Ok(word) = cloned_reader.read_val::<usize>() else {
|
||||
break;
|
||||
};
|
||||
impl ReadCString for VmReaderArray<'_> {
|
||||
fn read_cstring(&mut self) -> Result<CString> {
|
||||
let max_len = self.sum_lens();
|
||||
let mut buffer: Vec<u8> = Vec::with_capacity(max_len);
|
||||
|
||||
if has_zero(word) {
|
||||
for byte in word.to_ne_bytes() {
|
||||
self.skip(1);
|
||||
buffer.push(byte);
|
||||
if byte == 0 {
|
||||
return Ok(CString::from_vec_with_nul(buffer)
|
||||
.expect("We provided 0 but no 0 is found"));
|
||||
}
|
||||
}
|
||||
unreachable!("The branch should never be reached unless `has_zero` has bugs.")
|
||||
for reader in self.readers_mut() {
|
||||
if read_until_nul_byte(reader, &mut buffer, max_len)? {
|
||||
return Ok(CString::from_vec_with_nul(buffer).unwrap());
|
||||
}
|
||||
|
||||
self.skip(size_of::<usize>());
|
||||
buffer.extend_from_slice(&word.to_ne_bytes());
|
||||
}
|
||||
|
||||
// Handle the last few bytes that are not enough for a word
|
||||
read_one_byte_at_a_time_while!(buffer.len() < max_len);
|
||||
return_errno_with_message!(
|
||||
Errno::EFAULT,
|
||||
"no nul terminator is present before reaching the buffer limit"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Maximum length exceeded before finding the null terminator
|
||||
return_errno_with_message!(Errno::EFAULT, "Fails to read CString from user");
|
||||
/// Reads bytes from `reader` into `buffer` until a nul byte is found.
|
||||
///
|
||||
/// This method returns the following values:
|
||||
/// 1. `Ok(true)`: If a nul byte is found in the reader;
|
||||
/// 2. `Ok(false)`: If no nul byte is found and the `reader` is exhausted;
|
||||
/// 3. `Err(_)`: If an error occurs while reading from the `reader`.
|
||||
fn read_until_nul_byte(
|
||||
reader: &mut VmReader,
|
||||
buffer: &mut Vec<u8>,
|
||||
max_len: usize,
|
||||
) -> Result<bool> {
|
||||
macro_rules! read_one_byte_at_a_time_while {
|
||||
($cond:expr) => {
|
||||
while $cond {
|
||||
let byte = reader.read_val::<u8>()?;
|
||||
buffer.push(byte);
|
||||
if byte == 0 {
|
||||
return Ok(true);
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
// Handle the first few bytes to make `cur_addr` aligned with `size_of::<usize>`
|
||||
read_one_byte_at_a_time_while!(
|
||||
!is_addr_aligned(reader.cursor() as usize) && buffer.len() < max_len && reader.has_remain()
|
||||
);
|
||||
|
||||
// Handle the rest of the bytes in bulk
|
||||
let mut cloned_reader = reader.clone();
|
||||
while (buffer.len() + mem::size_of::<usize>()) <= max_len {
|
||||
let Ok(word) = cloned_reader.read_val::<usize>() else {
|
||||
break;
|
||||
};
|
||||
|
||||
if has_zero(word) {
|
||||
for byte in word.to_ne_bytes() {
|
||||
reader.skip(1);
|
||||
buffer.push(byte);
|
||||
if byte == 0 {
|
||||
return Ok(true);
|
||||
}
|
||||
}
|
||||
unreachable!("The branch should never be reached unless `has_zero` has bugs.")
|
||||
}
|
||||
|
||||
reader.skip(size_of::<usize>());
|
||||
buffer.extend_from_slice(&word.to_ne_bytes());
|
||||
}
|
||||
|
||||
// Handle the last few bytes that are not enough for a word
|
||||
read_one_byte_at_a_time_while!(buffer.len() < max_len && reader.has_remain());
|
||||
|
||||
if buffer.len() >= max_len {
|
||||
return_errno_with_message!(
|
||||
Errno::EFAULT,
|
||||
"no nul terminator is present before exceeding the maximum length"
|
||||
);
|
||||
} else {
|
||||
Ok(false)
|
||||
}
|
||||
}
|
||||
|
||||
@ -277,22 +321,67 @@ mod test {
|
||||
|
||||
use super::*;
|
||||
|
||||
#[ktest]
|
||||
fn read_multiple_cstring() {
|
||||
let mut buffer = vec![0u8; 100];
|
||||
|
||||
let str1 = CString::new("hello").unwrap();
|
||||
let str2 = CString::new("world!").unwrap();
|
||||
fn init_buffer(cstrs: &[CString]) -> Vec<u8> {
|
||||
let mut buffer = vec![255u8; 100];
|
||||
|
||||
let mut writer = VmWriter::from(buffer.as_mut_slice());
|
||||
writer.write(&mut VmReader::from(str1.as_bytes_with_nul()));
|
||||
writer.write(&mut VmReader::from(str2.as_bytes_with_nul()));
|
||||
drop(writer);
|
||||
|
||||
for cstr in cstrs {
|
||||
writer.write(&mut VmReader::from(cstr.as_bytes_with_nul()));
|
||||
}
|
||||
|
||||
buffer
|
||||
}
|
||||
|
||||
#[ktest]
|
||||
fn read_multiple_cstring() {
|
||||
let strs = {
|
||||
let str1 = CString::new("hello").unwrap();
|
||||
let str2 = CString::new("world!").unwrap();
|
||||
vec![str1, str2]
|
||||
};
|
||||
|
||||
let buffer = init_buffer(&strs);
|
||||
|
||||
let mut reader = VmReader::from(buffer.as_slice()).to_fallible();
|
||||
let read_str1 = reader.read_cstring().unwrap();
|
||||
assert_eq!(read_str1, str1);
|
||||
assert_eq!(read_str1, strs[0]);
|
||||
let read_str2 = reader.read_cstring().unwrap();
|
||||
assert_eq!(read_str2, str2);
|
||||
assert_eq!(read_str2, strs[1]);
|
||||
|
||||
assert!(reader
|
||||
.read_cstring()
|
||||
.is_err_and(|err| err.error() == Errno::EFAULT));
|
||||
}
|
||||
|
||||
#[ktest]
|
||||
fn read_cstring_from_multiread() {
|
||||
let strs = {
|
||||
let str1 = CString::new("hello").unwrap();
|
||||
let str2 = CString::new("world!").unwrap();
|
||||
let str3 = CString::new("asterinas").unwrap();
|
||||
vec![str1, str2, str3]
|
||||
};
|
||||
|
||||
let buffer = init_buffer(&strs);
|
||||
|
||||
let mut readers = {
|
||||
let reader1 = VmReader::from(&buffer[0..20]).to_fallible();
|
||||
let reader2 = VmReader::from(&buffer[20..40]).to_fallible();
|
||||
let reader3 = VmReader::from(&buffer[40..60]).to_fallible();
|
||||
VmReaderArray::new(vec![reader1, reader2, reader3].into_boxed_slice())
|
||||
};
|
||||
|
||||
let multiread = &mut readers as &mut dyn MultiRead;
|
||||
let read_str1 = multiread.read_cstring().unwrap();
|
||||
assert_eq!(read_str1, strs[0]);
|
||||
let read_str2 = multiread.read_cstring().unwrap();
|
||||
assert_eq!(read_str2, strs[1]);
|
||||
let read_str3 = multiread.read_cstring().unwrap();
|
||||
assert_eq!(read_str3, strs[2]);
|
||||
|
||||
assert!(multiread
|
||||
.read_cstring()
|
||||
.is_err_and(|err| err.error() == Errno::EFAULT));
|
||||
}
|
||||
}
|
||||
|
@ -94,7 +94,7 @@ pub struct VmReaderArray<'a>(Box<[VmReader<'a>]>);
|
||||
pub struct VmWriterArray<'a>(Box<[VmWriter<'a>]>);
|
||||
|
||||
impl<'a> VmReaderArray<'a> {
|
||||
/// Creates a new `IoVecReader` from user-provided io vec buffer.
|
||||
/// Creates a new `VmReaderArray` from user-provided io vec buffer.
|
||||
pub fn from_user_io_vecs(
|
||||
user_space: &'a CurrentUserSpace<'a>,
|
||||
start_addr: Vaddr,
|
||||
@ -105,13 +105,19 @@ impl<'a> VmReaderArray<'a> {
|
||||
}
|
||||
|
||||
/// Returns mutable reference to [`VmReader`]s.
|
||||
pub fn readers_mut(&'a mut self) -> &'a mut [VmReader<'a>] {
|
||||
pub fn readers_mut(&mut self) -> &mut [VmReader<'a>] {
|
||||
&mut self.0
|
||||
}
|
||||
|
||||
/// Creates a new `VmReaderArray`.
|
||||
#[cfg(ktest)]
|
||||
pub const fn new(readers: Box<[VmReader<'a>]>) -> Self {
|
||||
Self(readers)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> VmWriterArray<'a> {
|
||||
/// Creates a new `IoVecWriter` from user-provided io vec buffer.
|
||||
/// Creates a new `VmWriterArray` from user-provided io vec buffer.
|
||||
pub fn from_user_io_vecs(
|
||||
user_space: &'a CurrentUserSpace<'a>,
|
||||
start_addr: Vaddr,
|
||||
@ -122,13 +128,13 @@ impl<'a> VmWriterArray<'a> {
|
||||
}
|
||||
|
||||
/// Returns mutable reference to [`VmWriter`]s.
|
||||
pub fn writers_mut(&'a mut self) -> &'a mut [VmWriter<'a>] {
|
||||
pub fn writers_mut(&mut self) -> &mut [VmWriter<'a>] {
|
||||
&mut self.0
|
||||
}
|
||||
}
|
||||
|
||||
/// Trait defining the read behavior for a collection of [`VmReader`]s.
|
||||
pub trait MultiRead {
|
||||
pub trait MultiRead: ReadCString {
|
||||
/// Reads the exact number of bytes required to exhaust `self` or fill `writer`,
|
||||
/// accumulating total bytes read.
|
||||
///
|
||||
|
Reference in New Issue
Block a user