diff --git a/kernel/src/context.rs b/kernel/src/context.rs index 1ccb50fbb..42e1f1f10 100644 --- a/kernel/src/context.rs +++ b/kernel/src/context.rs @@ -17,6 +17,7 @@ use crate::{ Process, }, thread::Thread, + util::{MultiRead, VmReaderArray}, vm::vmar::Vmar, }; @@ -185,52 +186,95 @@ impl ReadCString for VmReader<'_, Fallible> { let max_len = self.remain(); let mut buffer: Vec = Vec::with_capacity(max_len); - macro_rules! read_one_byte_at_a_time_while { - ($cond:expr) => { - while $cond { - let byte = self.read_val::()?; - buffer.push(byte); - if byte == 0 { - return Ok(CString::from_vec_with_nul(buffer) - .expect("We provided 0 but no 0 is found")); - } - } - }; + if read_until_nul_byte(self, &mut buffer, max_len)? { + return Ok(CString::from_vec_with_nul(buffer).unwrap()); } - // Handle the first few bytes to make `cur_addr` aligned with `size_of::` - read_one_byte_at_a_time_while!( - !is_addr_aligned(self.cursor() as usize) && buffer.len() < max_len + return_errno_with_message!( + Errno::EFAULT, + "no nul terminator is present before reaching the buffer limit" ); + } +} - // Handle the rest of the bytes in bulk - let mut cloned_reader = self.clone(); - while (buffer.len() + mem::size_of::()) <= max_len { - let Ok(word) = cloned_reader.read_val::() else { - break; - }; +impl ReadCString for VmReaderArray<'_> { + fn read_cstring(&mut self) -> Result { + let max_len = self.sum_lens(); + let mut buffer: Vec = Vec::with_capacity(max_len); - if has_zero(word) { - for byte in word.to_ne_bytes() { - self.skip(1); - buffer.push(byte); - if byte == 0 { - return Ok(CString::from_vec_with_nul(buffer) - .expect("We provided 0 but no 0 is found")); - } - } - unreachable!("The branch should never be reached unless `has_zero` has bugs.") + for reader in self.readers_mut() { + if read_until_nul_byte(reader, &mut buffer, max_len)? { + return Ok(CString::from_vec_with_nul(buffer).unwrap()); } - - self.skip(size_of::()); - buffer.extend_from_slice(&word.to_ne_bytes()); } - // Handle the last few bytes that are not enough for a word - read_one_byte_at_a_time_while!(buffer.len() < max_len); + return_errno_with_message!( + Errno::EFAULT, + "no nul terminator is present before reaching the buffer limit" + ); + } +} - // Maximum length exceeded before finding the null terminator - return_errno_with_message!(Errno::EFAULT, "Fails to read CString from user"); +/// Reads bytes from `reader` into `buffer` until a nul byte is found. +/// +/// This method returns the following values: +/// 1. `Ok(true)`: If a nul byte is found in the reader; +/// 2. `Ok(false)`: If no nul byte is found and the `reader` is exhausted; +/// 3. `Err(_)`: If an error occurs while reading from the `reader`. +fn read_until_nul_byte( + reader: &mut VmReader, + buffer: &mut Vec, + max_len: usize, +) -> Result { + macro_rules! read_one_byte_at_a_time_while { + ($cond:expr) => { + while $cond { + let byte = reader.read_val::()?; + buffer.push(byte); + if byte == 0 { + return Ok(true); + } + } + }; + } + + // Handle the first few bytes to make `cur_addr` aligned with `size_of::` + read_one_byte_at_a_time_while!( + !is_addr_aligned(reader.cursor() as usize) && buffer.len() < max_len && reader.has_remain() + ); + + // Handle the rest of the bytes in bulk + let mut cloned_reader = reader.clone(); + while (buffer.len() + mem::size_of::()) <= max_len { + let Ok(word) = cloned_reader.read_val::() else { + break; + }; + + if has_zero(word) { + for byte in word.to_ne_bytes() { + reader.skip(1); + buffer.push(byte); + if byte == 0 { + return Ok(true); + } + } + unreachable!("The branch should never be reached unless `has_zero` has bugs.") + } + + reader.skip(size_of::()); + buffer.extend_from_slice(&word.to_ne_bytes()); + } + + // Handle the last few bytes that are not enough for a word + read_one_byte_at_a_time_while!(buffer.len() < max_len && reader.has_remain()); + + if buffer.len() >= max_len { + return_errno_with_message!( + Errno::EFAULT, + "no nul terminator is present before exceeding the maximum length" + ); + } else { + Ok(false) } } @@ -277,22 +321,67 @@ mod test { use super::*; - #[ktest] - fn read_multiple_cstring() { - let mut buffer = vec![0u8; 100]; - - let str1 = CString::new("hello").unwrap(); - let str2 = CString::new("world!").unwrap(); + fn init_buffer(cstrs: &[CString]) -> Vec { + let mut buffer = vec![255u8; 100]; let mut writer = VmWriter::from(buffer.as_mut_slice()); - writer.write(&mut VmReader::from(str1.as_bytes_with_nul())); - writer.write(&mut VmReader::from(str2.as_bytes_with_nul())); - drop(writer); + + for cstr in cstrs { + writer.write(&mut VmReader::from(cstr.as_bytes_with_nul())); + } + + buffer + } + + #[ktest] + fn read_multiple_cstring() { + let strs = { + let str1 = CString::new("hello").unwrap(); + let str2 = CString::new("world!").unwrap(); + vec![str1, str2] + }; + + let buffer = init_buffer(&strs); let mut reader = VmReader::from(buffer.as_slice()).to_fallible(); let read_str1 = reader.read_cstring().unwrap(); - assert_eq!(read_str1, str1); + assert_eq!(read_str1, strs[0]); let read_str2 = reader.read_cstring().unwrap(); - assert_eq!(read_str2, str2); + assert_eq!(read_str2, strs[1]); + + assert!(reader + .read_cstring() + .is_err_and(|err| err.error() == Errno::EFAULT)); + } + + #[ktest] + fn read_cstring_from_multiread() { + let strs = { + let str1 = CString::new("hello").unwrap(); + let str2 = CString::new("world!").unwrap(); + let str3 = CString::new("asterinas").unwrap(); + vec![str1, str2, str3] + }; + + let buffer = init_buffer(&strs); + + let mut readers = { + let reader1 = VmReader::from(&buffer[0..20]).to_fallible(); + let reader2 = VmReader::from(&buffer[20..40]).to_fallible(); + let reader3 = VmReader::from(&buffer[40..60]).to_fallible(); + VmReaderArray::new(vec![reader1, reader2, reader3].into_boxed_slice()) + }; + + let multiread = &mut readers as &mut dyn MultiRead; + let read_str1 = multiread.read_cstring().unwrap(); + assert_eq!(read_str1, strs[0]); + let read_str2 = multiread.read_cstring().unwrap(); + assert_eq!(read_str2, strs[1]); + let read_str3 = multiread.read_cstring().unwrap(); + assert_eq!(read_str3, strs[2]); + + assert!(multiread + .read_cstring() + .is_err_and(|err| err.error() == Errno::EFAULT)); } } diff --git a/kernel/src/util/iovec.rs b/kernel/src/util/iovec.rs index 983d5116c..9192534bd 100644 --- a/kernel/src/util/iovec.rs +++ b/kernel/src/util/iovec.rs @@ -94,7 +94,7 @@ pub struct VmReaderArray<'a>(Box<[VmReader<'a>]>); pub struct VmWriterArray<'a>(Box<[VmWriter<'a>]>); impl<'a> VmReaderArray<'a> { - /// Creates a new `IoVecReader` from user-provided io vec buffer. + /// Creates a new `VmReaderArray` from user-provided io vec buffer. pub fn from_user_io_vecs( user_space: &'a CurrentUserSpace<'a>, start_addr: Vaddr, @@ -105,13 +105,19 @@ impl<'a> VmReaderArray<'a> { } /// Returns mutable reference to [`VmReader`]s. - pub fn readers_mut(&'a mut self) -> &'a mut [VmReader<'a>] { + pub fn readers_mut(&mut self) -> &mut [VmReader<'a>] { &mut self.0 } + + /// Creates a new `VmReaderArray`. + #[cfg(ktest)] + pub const fn new(readers: Box<[VmReader<'a>]>) -> Self { + Self(readers) + } } impl<'a> VmWriterArray<'a> { - /// Creates a new `IoVecWriter` from user-provided io vec buffer. + /// Creates a new `VmWriterArray` from user-provided io vec buffer. pub fn from_user_io_vecs( user_space: &'a CurrentUserSpace<'a>, start_addr: Vaddr, @@ -122,13 +128,13 @@ impl<'a> VmWriterArray<'a> { } /// Returns mutable reference to [`VmWriter`]s. - pub fn writers_mut(&'a mut self) -> &'a mut [VmWriter<'a>] { + pub fn writers_mut(&mut self) -> &mut [VmWriter<'a>] { &mut self.0 } } /// Trait defining the read behavior for a collection of [`VmReader`]s. -pub trait MultiRead { +pub trait MultiRead: ReadCString { /// Reads the exact number of bytes required to exhaust `self` or fill `writer`, /// accumulating total bytes read. ///