diff --git a/kernel/aster-nix/src/vm/vmar/vm_mapping.rs b/kernel/aster-nix/src/vm/vmar/vm_mapping.rs
index 04799dde..7c86ec69 100644
--- a/kernel/aster-nix/src/vm/vmar/vm_mapping.rs
+++ b/kernel/aster-nix/src/vm/vmar/vm_mapping.rs
@@ -538,7 +538,7 @@ impl VmMappingInner {
         debug_assert!(range.start % PAGE_SIZE == 0);
         debug_assert!(range.end % PAGE_SIZE == 0);
         let mut cursor = vm_space.cursor_mut(&range).unwrap();
-        cursor.protect(range.len(), |p| p.flags = perms.into(), true)?;
+        cursor.protect(range.len(), |p| p.flags = perms.into());
         Ok(())
     }
 
diff --git a/ostd/src/arch/x86/tdx_guest.rs b/ostd/src/arch/x86/tdx_guest.rs
index 82ac2c27..92694285 100644
--- a/ostd/src/arch/x86/tdx_guest.rs
+++ b/ostd/src/arch/x86/tdx_guest.rs
@@ -39,8 +39,8 @@ pub unsafe fn unprotect_gpa_range(gpa: Paddr, page_num: usize) -> Result<(), Pag
     if gpa & PAGE_MASK != 0 {
         warn!("Misaligned address: {:x}", gpa);
     }
-    // Protect the page in the kernel page table.
-    let pt = KERNEL_PAGE_TABLE.get().unwrap();
+
+    // Protect the page in the boot page table if in the boot phase.
     let protect_op = |prop: &mut PageProperty| {
         *prop = PageProperty {
             flags: prop.flags,
@@ -48,10 +48,6 @@ pub unsafe fn unprotect_gpa_range(gpa: Paddr, page_num: usize) -> Result<(), Pag
             priv_flags: prop.priv_flags | PrivFlags::SHARED,
         }
     };
-    let vaddr = paddr_to_vaddr(gpa);
-    pt.protect(&(vaddr..vaddr + page_num * PAGE_SIZE), protect_op)
-        .map_err(|_| PageConvertError::PageTable)?;
-    // Protect the page in the boot page table if in the boot phase.
     {
         let mut boot_pt_lock = BOOT_PAGE_TABLE.lock();
         if let Some(boot_pt) = boot_pt_lock.as_mut() {
@@ -61,6 +57,12 @@ pub unsafe fn unprotect_gpa_range(gpa: Paddr, page_num: usize) -> Result<(), Pag
             }
         }
     }
+    // Protect the page in the kernel page table.
+    let pt = KERNEL_PAGE_TABLE.get().unwrap();
+    let vaddr = paddr_to_vaddr(gpa);
+    pt.protect_flush_tlb(&(vaddr..vaddr + page_num * PAGE_SIZE), protect_op)
+        .map_err(|_| PageConvertError::PageTable)?;
+
     map_gpa(
         (gpa & (!PAGE_MASK)) as u64 | SHARED_MASK,
         (page_num * PAGE_SIZE) as u64,
@@ -82,8 +84,8 @@ pub unsafe fn protect_gpa_range(gpa: Paddr, page_num: usize) -> Result<(), PageC
     if gpa & !PAGE_MASK == 0 {
         warn!("Misaligned address: {:x}", gpa);
     }
-    // Protect the page in the kernel page table.
-    let pt = KERNEL_PAGE_TABLE.get().unwrap();
+
+    // Protect the page in the boot page table if in the boot phase.
     let protect_op = |prop: &mut PageProperty| {
         *prop = PageProperty {
             flags: prop.flags,
@@ -91,10 +93,6 @@ pub unsafe fn protect_gpa_range(gpa: Paddr, page_num: usize) -> Result<(), PageC
             priv_flags: prop.priv_flags - PrivFlags::SHARED,
         }
     };
-    let vaddr = paddr_to_vaddr(gpa);
-    pt.protect(&(vaddr..vaddr + page_num * PAGE_SIZE), protect_op)
-        .map_err(|_| PageConvertError::PageTable)?;
-    // Protect the page in the boot page table if in the boot phase.
     {
         let mut boot_pt_lock = BOOT_PAGE_TABLE.lock();
         if let Some(boot_pt) = boot_pt_lock.as_mut() {
@@ -104,6 +102,12 @@ pub unsafe fn protect_gpa_range(gpa: Paddr, page_num: usize) -> Result<(), PageC
             }
         }
     }
+    // Protect the page in the kernel page table.
+    let pt = KERNEL_PAGE_TABLE.get().unwrap();
+    let vaddr = paddr_to_vaddr(gpa);
+    pt.protect_flush_tlb(&(vaddr..vaddr + page_num * PAGE_SIZE), protect_op)
+        .map_err(|_| PageConvertError::PageTable)?;
+
     map_gpa((gpa & PAGE_MASK) as u64, (page_num * PAGE_SIZE) as u64)
         .map_err(|_| PageConvertError::TdVmcall)?;
     for i in 0..page_num {
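Review note: after this reordering, unprotect_gpa_range and protect_gpa_range are mirror images; they differ only in the direction of PrivFlags::SHARED and in the final map_gpa arguments. A possible follow-up (hypothetical, not part of this patch; convert_gpa_range is an invented name, and the boot-page-table loop is elided just as in the hunks) could share the skeleton:

    // Hypothetical helper sketching the shared skeleton of the two
    // conversion paths above. All identifiers except `convert_gpa_range`
    // come from this patch.
    unsafe fn convert_gpa_range(
        gpa: Paddr,
        page_num: usize,
        protect_op: impl FnMut(&mut PageProperty),
    ) -> Result<(), PageConvertError> {
        // Protect the pages in the boot page table if in the boot phase.
        // ... (per-page loop over BOOT_PAGE_TABLE, as above) ...

        // Protect the pages in the kernel page table, flushing the TLB
        // entry of each page as it is protected.
        let pt = KERNEL_PAGE_TABLE.get().unwrap();
        let vaddr = paddr_to_vaddr(gpa);
        // SAFETY: inherited from the callers, as in the functions above.
        unsafe {
            pt.protect_flush_tlb(&(vaddr..vaddr + page_num * PAGE_SIZE), protect_op)
                .map_err(|_| PageConvertError::PageTable)
        }
    }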
diff --git a/ostd/src/mm/dma/dma_coherent.rs b/ostd/src/mm/dma/dma_coherent.rs
index 26148f4f..dec990c9 100644
--- a/ostd/src/mm/dma/dma_coherent.rs
+++ b/ostd/src/mm/dma/dma_coherent.rs
@@ -7,7 +7,7 @@ use cfg_if::cfg_if;
 
 use super::{check_and_insert_dma_mapping, remove_dma_mapping, DmaError, HasDaddr};
 use crate::{
-    arch::{iommu, mm::tlb_flush_addr_range},
+    arch::iommu,
     mm::{
         dma::{dma_type, Daddr, DmaType},
         io::VmIoOnce,
@@ -71,10 +71,9 @@ impl DmaCoherent {
             // SAFETY: the physical mappings is only used by DMA so protecting it is safe.
             unsafe {
                 page_table
-                    .protect(&va_range, |p| p.cache = CachePolicy::Uncacheable)
+                    .protect_flush_tlb(&va_range, |p| p.cache = CachePolicy::Uncacheable)
                     .unwrap();
             }
-            tlb_flush_addr_range(&va_range);
         }
         let start_daddr = match dma_type() {
             DmaType::Direct => {
@@ -159,10 +158,9 @@ impl Drop for DmaCoherentInner {
             // SAFETY: the physical mappings is only used by DMA so protecting it is safe.
             unsafe {
                 page_table
-                    .protect(&va_range, |p| p.cache = CachePolicy::Writeback)
+                    .protect_flush_tlb(&va_range, |p| p.cache = CachePolicy::Writeback)
                     .unwrap();
             }
-            tlb_flush_addr_range(&va_range);
         }
         remove_dma_mapping(start_paddr, frame_count);
     }
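Review note: this is the pattern every call site in the patch follows. The separate tlb_flush_addr_range step disappears because the flush now happens inside the protection walk, page by page. In shape (both forms taken from the hunks above; shown together for comparison, not as compilable code):

    // Before: protect the whole range, then flush in a second step; stale
    // TLB entries could still serve the old cache policy in the window
    // between the two calls.
    page_table.protect(&va_range, |p| p.cache = CachePolicy::Uncacheable).unwrap();
    tlb_flush_addr_range(&va_range);

    // After: one call; each page's TLB entry is flushed as soon as its
    // properties change (see PageTable::protect_flush_tlb later in this patch).
    page_table.protect_flush_tlb(&va_range, |p| p.cache = CachePolicy::Uncacheable).unwrap();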
diff --git a/ostd/src/mm/page_table/cursor.rs b/ostd/src/mm/page_table/cursor.rs
index 1e8a1a04..2fda8f58 100644
--- a/ostd/src/mm/page_table/cursor.rs
+++ b/ostd/src/mm/page_table/cursor.rs
@@ -629,34 +629,41 @@ where
         PageTableItem::NotMapped { va: start, len }
     }
 
-    /// Applies the given operation to all the mappings within the range.
+    /// Applies the operation to the next mapped slot within the range.
     ///
-    /// The funtction will return an error if it is not allowed to protect an invalid range and
-    /// it does so, or if the range to be protected only covers a part of a page.
+    /// The search range starts at the cursor's current virtual address and
+    /// spans the provided length.
+    ///
+    /// The function stops and yields the actually protected range once it
+    /// has protected a page, regardless of whether the following pages
+    /// still need to be protected.
+    ///
+    /// It also moves the cursor forward to the next page after the
+    /// protected one. If no mapped pages exist in the following range, the
+    /// cursor will stop at the end of the range and return [`None`].
     ///
     /// # Safety
     ///
-    /// The caller should ensure that the range being protected does not affect kernel's memory safety.
+    /// The caller should ensure that the range being protected with the
+    /// operation does not affect kernel's memory safety.
     ///
     /// # Panics
     ///
     /// This function will panic if:
-    /// - the range to be protected is out of the range where the cursor is required to operate.
-    pub unsafe fn protect(
+    /// - the range to be protected is out of the range where the cursor
+    ///   is required to operate;
+    /// - the specified virtual address range only covers a part of a page.
+    pub unsafe fn protect_next(
         &mut self,
         len: usize,
-        mut op: impl FnMut(&mut PageProperty),
-        allow_protect_absent: bool,
-    ) -> Result<(), PageTableError> {
+        op: &mut impl FnMut(&mut PageProperty),
+    ) -> Option<Range<Vaddr>> {
         let end = self.0.va + len;
         assert!(end <= self.0.barrier_va.end);
         while self.0.va < end {
             let cur_pte = self.0.read_cur_pte();
             if !cur_pte.is_present() {
-                if !allow_protect_absent {
-                    return Err(PageTableError::ProtectingAbsent);
-                }
                 self.0.move_forward();
                 continue;
             }
@@ -664,18 +671,33 @@ where
             // Go down if it's not a last node.
             if !cur_pte.is_last(self.0.level) {
                 self.0.level_down();
+
+                // We have gone down a level. If there are no mapped PTEs in
+                // the current node, we can go back up and skip it to save time.
+                if self.0.guards[(self.0.level - 1) as usize]
+                    .as_ref()
+                    .unwrap()
+                    .nr_children()
+                    == 0
+                {
+                    self.0.level_up();
+                    self.0.move_forward();
+                }
+
                 continue;
             }
 
             // Go down if the page size is too big and we are protecting part
             // of untracked huge pages.
-            let vaddr_not_fit = self.0.va % page_size::<C>(self.0.level) != 0
-                || self.0.va + page_size::<C>(self.0.level) > end;
-            if !self.0.in_tracked_range() && vaddr_not_fit {
-                self.level_down_split();
-                continue;
-            } else if vaddr_not_fit {
-                return Err(PageTableError::ProtectingPartial);
+            if self.0.va % page_size::<C>(self.0.level) != 0
+                || self.0.va + page_size::<C>(self.0.level) > end
+            {
+                if self.0.in_tracked_range() {
+                    panic!("protecting part of a huge page");
+                } else {
+                    self.level_down_split();
+                    continue;
+                }
             }
 
             let mut pte_prop = cur_pte.prop();
@@ -683,10 +705,14 @@ where
             let idx = self.0.cur_idx();
             self.cur_node_mut().protect(idx, pte_prop);
+            let protected_va = self.0.va..self.0.va + page_size::<C>(self.0.level);
             self.0.move_forward();
+
+            return Some(protected_va);
         }
-        Ok(())
+
+        None
     }
 
     /// Consumes itself and leak the root guard for the caller if it locked the root level.
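To make the new contract concrete, here is a minimal, self-contained model of the protect_next behavior (simplified stand-in types, not the ostd cursor; a flat array plays the role of the page table). It shows the three properties the doc comment promises: holes are skipped rather than reported as errors, at most one page is protected per call, and None terminates the caller's loop:

    use std::ops::Range;

    const PAGE_SIZE: usize = 4096;

    #[derive(Clone, Copy, Debug)]
    struct Prop {
        writable: bool,
    }

    /// A toy "page table": index = page number, None = hole (not mapped).
    struct Table {
        pages: Vec<Option<Prop>>,
    }

    struct Cursor<'a> {
        table: &'a mut Table,
        va: usize,
    }

    impl Cursor<'_> {
        /// Models protect_next: protect at most one mapped page, yield the
        /// range it covered, and advance past it; return None once no
        /// mapped page is left in the given range.
        fn protect_next(
            &mut self,
            len: usize,
            op: &mut impl FnMut(&mut Prop),
        ) -> Option<Range<usize>> {
            let end = self.va + len;
            while self.va < end {
                match self.table.pages.get_mut(self.va / PAGE_SIZE) {
                    Some(Some(prop)) => {
                        op(prop);
                        let protected = self.va..self.va + PAGE_SIZE;
                        self.va += PAGE_SIZE; // move past the protected page
                        return Some(protected);
                    }
                    _ => self.va += PAGE_SIZE, // skip holes instead of erroring
                }
            }
            None
        }
    }

    fn main() {
        // Pages 0 and 2 are mapped; page 1 is a hole.
        let mut table = Table {
            pages: vec![Some(Prop { writable: true }), None, Some(Prop { writable: true })],
        };
        let mut cursor = Cursor { table: &mut table, va: 0 };
        let end = 3 * PAGE_SIZE;
        let mut op = |p: &mut Prop| p.writable = false;
        // The caller-driven loop used throughout this patch: one "TLB flush"
        // (here, a print) per page that was actually protected.
        while let Some(range) = cursor.protect_next(end - cursor.va, &mut op) {
            println!("flush TLB for {:#x?}", range);
        }
        assert!(!table.pages[0].unwrap().writable);
    }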
diff --git a/ostd/src/mm/page_table/mod.rs b/ostd/src/mm/page_table/mod.rs
index 47a3273c..22a9614d 100644
--- a/ostd/src/mm/page_table/mod.rs
+++ b/ostd/src/mm/page_table/mod.rs
@@ -3,9 +3,8 @@
 use core::{fmt::Debug, marker::PhantomData, ops::Range};
 
 use super::{
-    nr_subpage_per_huge,
-    page_prop::{PageFlags, PageProperty},
-    page_size, Paddr, PagingConstsTrait, PagingLevel, Vaddr,
+    nr_subpage_per_huge, page_prop::PageProperty, page_size, Paddr, PagingConstsTrait, PagingLevel,
+    Vaddr,
 };
 use crate::{
     arch::mm::{PageTableEntry, PagingConsts},
@@ -29,10 +28,6 @@ pub enum PageTableError {
     InvalidVaddr(Vaddr),
     /// Using virtual address not aligned.
     UnalignedVaddr,
-    /// Protecting a mapping that does not exist.
-    ProtectingAbsent,
-    /// Protecting a part of an already mapped page.
-    ProtectingPartial,
 }
 
 /// This is a compile-time technique to force the frame developers to distinguish
@@ -98,24 +93,18 @@ impl PageTable<UserMode> {
         }
     }
 
-    /// Remove all write permissions from the user page table and create a cloned
-    /// new page table.
+    /// Create a cloned new page table.
+    ///
+    /// This method takes a mutable cursor to the old page table that locks the
+    /// entire virtual address range. The caller may implement the copy-on-write
+    /// mechanism by first protecting the old page table and then cloning it
+    /// using this method.
     ///
     /// TODO: We may consider making the page table itself copy-on-write.
-    pub fn fork_copy_on_write(&self) -> Self {
-        let mut cursor = self.cursor_mut(&UserMode::VADDR_RANGE).unwrap();
-
-        // SAFETY: Protecting the user page table is safe.
-        unsafe {
-            cursor
-                .protect(
-                    UserMode::VADDR_RANGE.len(),
-                    |p: &mut PageProperty| p.flags -= PageFlags::W,
-                    true,
-                )
-                .unwrap();
-        };
-
+    pub fn clone_with(
+        &self,
+        cursor: CursorMut<'_, UserMode, PageTableEntry, PagingConsts>,
+    ) -> Self {
         let root_node = cursor.leak_root_guard().unwrap();
 
         const NR_PTES_PER_NODE: usize = nr_subpage_per_huge::<PagingConsts>();
@@ -177,6 +166,26 @@
             }
         }
     }
+
+    /// Protect the given virtual address range in the kernel page table.
+    ///
+    /// This method flushes the TLB entries when doing protection.
+    ///
+    /// # Safety
+    ///
+    /// The caller must ensure that the protection operation does not affect
+    /// the memory safety of the kernel.
+    pub unsafe fn protect_flush_tlb(
+        &self,
+        vaddr: &Range<Vaddr>,
+        mut op: impl FnMut(&mut PageProperty),
+    ) -> Result<(), PageTableError> {
+        let mut cursor = CursorMut::new(self, vaddr)?;
+        while let Some(range) = cursor.protect_next(vaddr.end - cursor.virt_addr(), &mut op) {
+            crate::arch::mm::tlb_flush_addr(range.start);
+        }
+        Ok(())
+    }
 }
 
 impl<'a, M: PageTableMode, E: PageTableEntryTrait, C: PagingConstsTrait> PageTable<M, E, C>
@@ -213,17 +222,6 @@ where
         Ok(())
     }
 
-    pub unsafe fn protect(
-        &self,
-        vaddr: &Range<Vaddr>,
-        op: impl FnMut(&mut PageProperty),
-    ) -> Result<(), PageTableError> {
-        self.cursor_mut(vaddr)?
-            .protect(vaddr.len(), op, true)
-            .unwrap();
-        Ok(())
-    }
-
     /// Query about the mapping of a single byte at the given virtual address.
     ///
     /// Note that this function may fail reflect an accurate result if there are
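A usage sketch for the new kernel-side entry point (the function name and range here are hypothetical; the call shape mirrors the KernelStack and DmaCoherent call sites in this patch):

    // Hypothetical call site: write-protect some kernel VA range. Only
    // protect_flush_tlb(), KERNEL_PAGE_TABLE, and PageFlags come from this
    // patch; `seal_range` and `va` are illustrative.
    unsafe fn seal_range(va: Range<Vaddr>) {
        let pt = KERNEL_PAGE_TABLE.get().unwrap();
        // SAFETY: inherited from protect_flush_tlb(); the caller must know
        // that revoking write access on `va` cannot break kernel memory
        // safety.
        unsafe {
            pt.protect_flush_tlb(&va, |p| p.flags -= PageFlags::W).unwrap();
        }
    }

Note the flush granularity: protect_flush_tlb issues one tlb_flush_addr per page actually protected, which favors sparse ranges; a dense, very large range might prefer a single full flush, which is presumably why the old fork_copy_on_write used tlb_flush_all_excluding_global.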
diff --git a/ostd/src/mm/page_table/test.rs b/ostd/src/mm/page_table/test.rs
index 1ad38aed..834289a9 100644
--- a/ostd/src/mm/page_table/test.rs
+++ b/ostd/src/mm/page_table/test.rs
@@ -8,6 +8,7 @@ use crate::{
         kspace::LINEAR_MAPPING_BASE_VADDR,
         page::{allocator, meta::FrameMeta},
         page_prop::{CachePolicy, PageFlags},
+        MAX_USERSPACE_VADDR,
     },
     prelude::*,
 };
@@ -95,7 +96,7 @@ fn test_user_copy_on_write() {
     unsafe { pt.cursor_mut(&from).unwrap().map(page.clone().into(), prop) };
     assert_eq!(pt.query(from.start + 10).unwrap().0, start_paddr + 10);
 
-    let child_pt = pt.fork_copy_on_write();
+    let child_pt = pt.clone_with(pt.cursor_mut(&(0..MAX_USERSPACE_VADDR)).unwrap());
     assert_eq!(pt.query(from.start + 10).unwrap().0, start_paddr + 10);
     assert_eq!(child_pt.query(from.start + 10).unwrap().0, start_paddr + 10);
     assert!(matches!(
@@ -105,7 +106,7 @@ fn test_user_copy_on_write() {
     assert!(pt.query(from.start + 10).is_none());
     assert_eq!(child_pt.query(from.start + 10).unwrap().0, start_paddr + 10);
 
-    let sibling_pt = pt.fork_copy_on_write();
+    let sibling_pt = pt.clone_with(pt.cursor_mut(&(0..MAX_USERSPACE_VADDR)).unwrap());
     assert!(sibling_pt.query(from.start + 10).is_none());
     assert_eq!(child_pt.query(from.start + 10).unwrap().0, start_paddr + 10);
     drop(pt);
@@ -139,6 +140,25 @@ impl PagingConstsTrait for BasePagingConsts {
     const PTE_SIZE: usize = core::mem::size_of::<PageTableEntry>();
 }
 
+impl<M: PageTableMode, E: PageTableEntryTrait, C: PagingConstsTrait> PageTable<M, E, C>
+where
+    [(); C::NR_LEVELS as usize]:,
+{
+    fn protect(&self, range: &Range<Vaddr>, mut op: impl FnMut(&mut PageProperty)) {
+        let mut cursor = self.cursor_mut(range).unwrap();
+        loop {
+            unsafe {
+                if cursor
+                    .protect_next(range.end - cursor.virt_addr(), &mut op)
+                    .is_none()
+                {
+                    break;
+                }
+            };
+        }
+    }
+}
+
 #[ktest]
 fn test_base_protect_query() {
     let pt = PageTable::<UserMode, PageTableEntry, BasePagingConsts>::empty();
@@ -162,7 +182,7 @@
         assert_eq!(va..va + page.size(), i * PAGE_SIZE..(i + 1) * PAGE_SIZE);
     }
     let prot = PAGE_SIZE * 18..PAGE_SIZE * 20;
-    unsafe { pt.protect(&prot, |p| p.flags -= PageFlags::W).unwrap() };
+    pt.protect(&prot, |p| p.flags -= PageFlags::W);
     for (item, i) in pt.cursor(&prot).unwrap().zip(18..20) {
         let PageTableItem::Mapped { va, page, prop } = item else {
             panic!("Expected Mapped, got {:#x?}", item);
         };
@@ -225,7 +245,7 @@ fn test_untracked_large_protect_query() {
     }
     let ppn = from_ppn.start + 18..from_ppn.start + 20;
     let va = UNTRACKED_OFFSET + PAGE_SIZE * ppn.start..UNTRACKED_OFFSET + PAGE_SIZE * ppn.end;
-    unsafe { pt.protect(&va, |p| p.flags -= PageFlags::W).unwrap() };
+    pt.protect(&va, |p| p.flags -= PageFlags::W);
     for (item, i) in pt
         .cursor(&(va.start - PAGE_SIZE..va.start))
         .unwrap()
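Worth noting for reviewers: the test helper above terminates through protect_next returning None, not through an error, so protecting a partially mapped range is now well-defined. For instance (hypothetical usage of the helper; not one of the assertions in this patch):

    // Hypothetical: if part of the range is unmapped, the helper's loop
    // simply skips the holes and exits via None. Under the old API, the
    // same call needed allow_protect_absent = true to avoid
    // Err(PageTableError::ProtectingAbsent), a variant this patch removes.
    pt.protect(&(0..PAGE_SIZE * 30), |p| p.flags -= PageFlags::W);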
diff --git a/ostd/src/mm/vm_space.rs b/ostd/src/mm/vm_space.rs
index 1cdb1c5e..6cc66094 100644
--- a/ostd/src/mm/vm_space.rs
+++ b/ostd/src/mm/vm_space.rs
@@ -17,12 +17,12 @@ use super::{
     io::UserSpace,
     kspace::KERNEL_PAGE_TABLE,
     page_table::{PageTable, UserMode},
-    PageProperty, VmReader, VmWriter,
+    PageFlags, PageProperty, VmReader, VmWriter,
 };
 use crate::{
     arch::mm::{
-        current_page_table_paddr, tlb_flush_addr, tlb_flush_addr_range,
-        tlb_flush_all_excluding_global, PageTableEntry, PagingConsts,
+        current_page_table_paddr, tlb_flush_addr, tlb_flush_addr_range, PageTableEntry,
+        PagingConsts,
     },
     cpu::CpuExceptionInfo,
     mm::{
@@ -129,6 +129,18 @@ impl VmSpace {
     /// read-only. And both the VM space will take handles to the same
     /// physical memory pages.
     pub fn fork_copy_on_write(&self) -> Self {
+        // Protect the parent VM space as read-only.
+        let end = MAX_USERSPACE_VADDR;
+        let mut cursor = self.pt.cursor_mut(&(0..end)).unwrap();
+        let mut op = |prop: &mut PageProperty| {
+            prop.flags -= PageFlags::W;
+        };
+
+        // SAFETY: It is safe to protect memory in the userspace.
+        while let Some(range) = unsafe { cursor.protect_next(end - cursor.virt_addr(), &mut op) } {
+            tlb_flush_addr(range.start);
+        }
+
         let page_fault_handler = {
             let new_handler = Once::new();
             if let Some(handler) = self.page_fault_handler.get() {
@@ -136,12 +148,11 @@ impl VmSpace {
             }
             new_handler
         };
-        let new_space = Self {
-            pt: self.pt.fork_copy_on_write(),
+
+        Self {
+            pt: self.pt.clone_with(cursor),
             page_fault_handler,
-        };
-        tlb_flush_all_excluding_global();
-        new_space
+        }
     }
 
     /// Creates a reader to read data from the user space of the current task.
@@ -319,22 +330,14 @@ impl CursorMut<'_> {
     /// # Panics
     ///
     /// This method will panic if `len` is not page-aligned.
-    pub fn protect(
-        &mut self,
-        len: usize,
-        op: impl FnMut(&mut PageProperty),
-        allow_protect_absent: bool,
-    ) -> Result<()> {
+    pub fn protect(&mut self, len: usize, mut op: impl FnMut(&mut PageProperty)) {
         assert!(len % super::PAGE_SIZE == 0);
-        let start_va = self.virt_addr();
-        let end_va = start_va + len;
+        let end = self.0.virt_addr() + len;
 
         // SAFETY: It is safe to protect memory in the userspace.
-        let result = unsafe { self.0.protect(len, op, allow_protect_absent) };
-
-        tlb_flush_addr_range(&(start_va..end_va));
-
-        Ok(result?)
+        while let Some(range) = unsafe { self.0.protect_next(end - self.0.virt_addr(), &mut op) } {
+            tlb_flush_addr(range.start);
+        }
     }
 }
 
diff --git a/ostd/src/task/task.rs b/ostd/src/task/task.rs
index 638591b3..e875678f 100644
--- a/ostd/src/task/task.rs
+++ b/ostd/src/task/task.rs
@@ -16,7 +16,6 @@ use super::{
 };
 pub(crate) use crate::arch::task::{context_switch, TaskContext};
 use crate::{
-    arch::mm::tlb_flush_addr_range,
     cpu::CpuSet,
     mm::{kspace::KERNEL_PAGE_TABLE, FrameAllocOptions, Paddr, PageFlags, Segment, PAGE_SIZE},
     prelude::*,
@@ -70,9 +69,8 @@ impl KernelStack {
         unsafe {
             let vaddr_range = guard_page_vaddr..guard_page_vaddr + PAGE_SIZE;
             page_table
-                .protect(&vaddr_range, |p| p.flags -= PageFlags::RW)
+                .protect_flush_tlb(&vaddr_range, |p| p.flags -= PageFlags::RW)
                 .unwrap();
-            tlb_flush_addr_range(&vaddr_range);
         }
         Ok(Self {
             segment: stack_segment,
@@ -98,9 +96,8 @@ impl Drop for KernelStack {
         unsafe {
             let vaddr_range = guard_page_vaddr..guard_page_vaddr + PAGE_SIZE;
             page_table
-                .protect(&vaddr_range, |p| p.flags |= PageFlags::RW)
+                .protect_flush_tlb(&vaddr_range, |p| p.flags |= PageFlags::RW)
                 .unwrap();
-            tlb_flush_addr_range(&vaddr_range);
         }
     }
 }
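A closing review note on the VmSpace change: the cursor is created before the protection loop and then handed to clone_with rather than being re-acquired because, per the new doc comment, the mutable cursor locks the entire 0..MAX_USERSPACE_VADDR range. Holding it across both steps means no mapping can change between write-protecting the parent and copying its entries, and the TLB flush becomes per-page instead of the old tlb_flush_all_excluding_global. Condensed (all names from the hunks above; an illustration of the locking order, not a compilable excerpt):

    // Condensed flow of VmSpace::fork_copy_on_write() after this patch.
    let mut cursor = self.pt.cursor_mut(&(0..MAX_USERSPACE_VADDR)).unwrap(); // locks the range
    while let Some(range) = unsafe { cursor.protect_next(/* remaining len */, &mut op) } {
        tlb_flush_addr(range.start); // flush only what was actually protected
    }
    // The range lock is still held here, so nothing can be remapped between
    // the protection above and the clone below.
    Self { pt: self.pt.clone_with(cursor), page_fault_handler }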