diff --git a/kernel/src/process/process_vm/init_stack/mod.rs b/kernel/src/process/process_vm/init_stack/mod.rs index bebb31c3..55b9459d 100644 --- a/kernel/src/process/process_vm/init_stack/mod.rs +++ b/kernel/src/process/process_vm/init_stack/mod.rs @@ -20,7 +20,10 @@ use core::{ use align_ext::AlignExt; use aster_rights::Full; -use ostd::mm::{vm_space::VmItem, UntypedMem, VmIo, MAX_USERSPACE_VADDR}; +use ostd::{ + mm::{vm_space::VmItem, UntypedMem, VmIo, MAX_USERSPACE_VADDR}, + task::disable_preempt, +}; use self::aux_vec::{AuxKey, AuxVec}; use super::ProcessVmarGuard; @@ -386,7 +389,11 @@ impl InitStackReader<'_> { let page_base_addr = stack_base.align_down(PAGE_SIZE); let vm_space = self.vmar.unwrap().vm_space(); - let mut cursor = vm_space.cursor(&(page_base_addr..page_base_addr + PAGE_SIZE))?; + let preempt_guard = disable_preempt(); + let mut cursor = vm_space.cursor( + &preempt_guard, + &(page_base_addr..page_base_addr + PAGE_SIZE), + )?; let VmItem::Mapped { frame, .. } = cursor.query()? else { return_errno_with_message!(Errno::EACCES, "Page not accessible"); }; @@ -410,7 +417,11 @@ impl InitStackReader<'_> { let page_base_addr = read_offset.align_down(PAGE_SIZE); let vm_space = self.vmar.unwrap().vm_space(); - let mut cursor = vm_space.cursor(&(page_base_addr..page_base_addr + PAGE_SIZE))?; + let preempt_guard = disable_preempt(); + let mut cursor = vm_space.cursor( + &preempt_guard, + &(page_base_addr..page_base_addr + PAGE_SIZE), + )?; let VmItem::Mapped { frame, .. } = cursor.query()? else { return_errno_with_message!(Errno::EACCES, "Page not accessible"); }; @@ -450,7 +461,11 @@ impl InitStackReader<'_> { let page_base_addr = read_offset.align_down(PAGE_SIZE); let vm_space = self.vmar.unwrap().vm_space(); - let mut cursor = vm_space.cursor(&(page_base_addr..page_base_addr + PAGE_SIZE))?; + let preempt_guard = disable_preempt(); + let mut cursor = vm_space.cursor( + &preempt_guard, + &(page_base_addr..page_base_addr + PAGE_SIZE), + )?; let VmItem::Mapped { frame, .. } = cursor.query()? else { return_errno_with_message!(Errno::EACCES, "Page not accessible"); }; diff --git a/kernel/src/process/program_loader/elf/load_elf.rs b/kernel/src/process/program_loader/elf/load_elf.rs index 21113781..5b0d939c 100644 --- a/kernel/src/process/program_loader/elf/load_elf.rs +++ b/kernel/src/process/program_loader/elf/load_elf.rs @@ -7,7 +7,10 @@ use align_ext::AlignExt; use aster_rights::Full; -use ostd::mm::{CachePolicy, PageFlags, PageProperty, VmIo}; +use ostd::{ + mm::{CachePolicy, PageFlags, PageProperty, VmIo}, + task::disable_preempt, +}; use xmas_elf::program::{self, ProgramHeader64}; use super::elf_file::Elf; @@ -311,9 +314,10 @@ fn map_segment_vmo( // Tail padding: If the segment's mem_size is larger than file size, // then the bytes that are not backed up by file content should be zeros.(usually .data/.bss sections). + let preempt_guard = disable_preempt(); let mut cursor = root_vmar .vm_space() - .cursor_mut(&(map_addr..map_addr + segment_size))?; + .cursor_mut(&preempt_guard, &(map_addr..map_addr + segment_size))?; let page_flags = PageFlags::from(perms) | PageFlags::ACCESSED; // Head padding. 
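The call sites above all follow the same shape, which this patch establishes wherever a `VmSpace` or page-table cursor is created: enter atomic mode first, then lend the guard to the cursor so the cursor cannot outlive it. A minimal sketch of the read side, assuming a `vm_space: &VmSpace` and a page-aligned `page_base_addr` as in the functions above:

    let preempt_guard = disable_preempt();
    let mut cursor =
        vm_space.cursor(&preempt_guard, &(page_base_addr..page_base_addr + PAGE_SIZE))?;
    let VmItem::Mapped { frame, .. } = cursor.query()? else {
        return_errno_with_message!(Errno::EACCES, "Page not accessible");
    };
    // Read from `frame` while still in atomic mode, then drop the cursor
    // before the guard; the borrow makes the reverse order impossible.
    drop(cursor);
    drop(preempt_guard);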
diff --git a/kernel/src/vm/vmar/mod.rs b/kernel/src/vm/vmar/mod.rs index 0b59cf5a..da086ca0 100644 --- a/kernel/src/vm/vmar/mod.rs +++ b/kernel/src/vm/vmar/mod.rs @@ -11,7 +11,10 @@ use core::{num::NonZeroUsize, ops::Range}; use align_ext::AlignExt; use aster_rights::Rights; -use ostd::mm::{tlb::TlbFlushOp, PageFlags, PageProperty, VmSpace, MAX_USERSPACE_VADDR}; +use ostd::{ + mm::{tlb::TlbFlushOp, PageFlags, PageProperty, VmSpace, MAX_USERSPACE_VADDR}, + task::disable_preempt, +}; use self::{ interval_set::{Interval, IntervalSet}, @@ -355,16 +358,19 @@ impl Vmar_ { /// Clears all content of the root VMAR. fn clear_root_vmar(&self) -> Result<()> { - { - let full_range = 0..MAX_USERSPACE_VADDR; - let mut cursor = self.vm_space.cursor_mut(&full_range).unwrap(); - cursor.unmap(full_range.len()); - cursor.flusher().sync_tlb_flush(); - } - { - let mut inner = self.inner.write(); - inner.vm_mappings.clear(); - } + let mut inner = self.inner.write(); + inner.vm_mappings.clear(); + + // Keep `inner` locked to avoid race conditions. + let preempt_guard = disable_preempt(); + let full_range = 0..MAX_USERSPACE_VADDR; + let mut cursor = self + .vm_space + .cursor_mut(&preempt_guard, &full_range) + .unwrap(); + cursor.unmap(full_range.len()); + cursor.flusher().sync_tlb_flush(); + Ok(()) } @@ -428,11 +434,12 @@ impl Vmar_ { let mut new_inner = new_vmar_.inner.write(); // Clone mappings. + let preempt_guard = disable_preempt(); let new_vmspace = new_vmar_.vm_space(); let range = self.base..(self.base + self.size); - let mut new_cursor = new_vmspace.cursor_mut(&range).unwrap(); + let mut new_cursor = new_vmspace.cursor_mut(&preempt_guard, &range).unwrap(); let cur_vmspace = self.vm_space(); - let mut cur_cursor = cur_vmspace.cursor_mut(&range).unwrap(); + let mut cur_cursor = cur_vmspace.cursor_mut(&preempt_guard, &range).unwrap(); for vm_mapping in inner.vm_mappings.iter() { let base = vm_mapping.map_to_addr(); diff --git a/kernel/src/vm/vmar/vm_mapping.rs b/kernel/src/vm/vmar/vm_mapping.rs index 3c8ceb8b..3d9a5d7b 100644 --- a/kernel/src/vm/vmar/vm_mapping.rs +++ b/kernel/src/vm/vmar/vm_mapping.rs @@ -7,9 +7,12 @@ use core::{ }; use align_ext::AlignExt; -use ostd::mm::{ - tlb::TlbFlushOp, vm_space::VmItem, CachePolicy, FrameAllocOptions, PageFlags, PageProperty, - UFrame, VmSpace, +use ostd::{ + mm::{ + tlb::TlbFlushOp, vm_space::VmItem, CachePolicy, FrameAllocOptions, PageFlags, PageProperty, + UFrame, VmSpace, + }, + task::disable_preempt, }; use super::interval_set::Interval; @@ -152,8 +155,11 @@ impl VmMapping { // Errors caused by the "around" pages should be ignored, so here we // only return the error if the faulting page is still not mapped. if res.is_err() { - let mut cursor = - vm_space.cursor(&(page_aligned_addr..page_aligned_addr + PAGE_SIZE))?; + let preempt_guard = disable_preempt(); + let mut cursor = vm_space.cursor( + &preempt_guard, + &(page_aligned_addr..page_aligned_addr + PAGE_SIZE), + )?; if let VmItem::Mapped { .. 
} = cursor.query().unwrap() { return Ok(()); } @@ -163,8 +169,11 @@ impl VmMapping { } 'retry: loop { - let mut cursor = - vm_space.cursor_mut(&(page_aligned_addr..page_aligned_addr + PAGE_SIZE))?; + let preempt_guard = disable_preempt(); + let mut cursor = vm_space.cursor_mut( + &preempt_guard, + &(page_aligned_addr..page_aligned_addr + PAGE_SIZE), + )?; match cursor.query().unwrap() { VmItem::Mapped { @@ -213,6 +222,7 @@ impl VmMapping { Err(VmoCommitError::Err(e)) => return Err(e), Err(VmoCommitError::NeedIo(index)) => { drop(cursor); + drop(preempt_guard); self.vmo .as_ref() .unwrap() @@ -291,7 +301,8 @@ impl VmMapping { let vm_perms = self.perms - VmPerms::WRITE; 'retry: loop { - let mut cursor = vm_space.cursor_mut(&(start_addr..end_addr))?; + let preempt_guard = disable_preempt(); + let mut cursor = vm_space.cursor_mut(&preempt_guard, &(start_addr..end_addr))?; let operate = move |commit_fn: &mut dyn FnMut() -> core::result::Result| { @@ -317,6 +328,7 @@ impl VmMapping { match vmo.try_operate_on_range(&(start_offset..end_offset), operate) { Ok(_) => return Ok(()), Err(VmoCommitError::NeedIo(index)) => { + drop(preempt_guard); vmo.commit_on(index, CommitFlags::empty())?; start_addr = index * PAGE_SIZE + self.map_to_addr; continue 'retry; @@ -419,8 +431,10 @@ impl VmMapping { impl VmMapping { /// Unmaps the mapping from the VM space. pub(super) fn unmap(self, vm_space: &VmSpace) -> Result<()> { + let preempt_guard = disable_preempt(); let range = self.range(); - let mut cursor = vm_space.cursor_mut(&range)?; + let mut cursor = vm_space.cursor_mut(&preempt_guard, &range)?; + cursor.unmap(range.len()); cursor.flusher().dispatch_tlb_flush(); cursor.flusher().sync_tlb_flush(); @@ -430,9 +444,9 @@ impl VmMapping { /// Change the perms of the mapping. pub(super) fn protect(self, vm_space: &VmSpace, perms: VmPerms) -> Self { + let preempt_guard = disable_preempt(); let range = self.range(); - - let mut cursor = vm_space.cursor_mut(&range).unwrap(); + let mut cursor = vm_space.cursor_mut(&preempt_guard, &range).unwrap(); let op = |p: &mut PageProperty| p.flags = perms.into(); while cursor.virt_addr() < range.end { diff --git a/osdk/tests/examples_in_book/write_a_kernel_in_100_lines_templates/lib.rs b/osdk/tests/examples_in_book/write_a_kernel_in_100_lines_templates/lib.rs index f0233574..bbadd5b1 100644 --- a/osdk/tests/examples_in_book/write_a_kernel_in_100_lines_templates/lib.rs +++ b/osdk/tests/examples_in_book/write_a_kernel_in_100_lines_templates/lib.rs @@ -18,7 +18,7 @@ use ostd::mm::{ VmWriter, PAGE_SIZE, }; use ostd::prelude::*; -use ostd::task::{Task, TaskOptions}; +use ostd::task::{disable_preempt, Task, TaskOptions}; use ostd::user::{ReturnReason, UserMode}; /// The kernel's boot and initialization process is managed by OSTD. @@ -51,7 +51,10 @@ fn create_vm_space(program: &[u8]) -> VmSpace { // the `VmSpace` abstraction. 
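One detail of the `vm_mapping.rs` hunks above is worth spelling out: committing a VMO page may block on I/O, which is not allowed in atomic mode, so the retry loops drop the cursor and then the preemption guard before calling `commit_on`. A simplified sketch of that shape (the `try_map_one` helper is hypothetical; `vm_space`, `range`, and `vmo` stand for the surrounding function's state):

    'retry: loop {
        let preempt_guard = disable_preempt();
        let mut cursor = vm_space.cursor_mut(&preempt_guard, &range)?;
        match try_map_one(&mut cursor) {
            Ok(()) => return Ok(()),
            Err(VmoCommitError::NeedIo(index)) => {
                drop(cursor); // release the page-table locks first
                drop(preempt_guard); // then leave atomic mode before blocking
                vmo.commit_on(index, CommitFlags::empty())?;
                continue 'retry;
            }
            Err(VmoCommitError::Err(e)) => return Err(e),
        }
    }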
let vm_space = VmSpace::new(); const MAP_ADDR: Vaddr = 0x0040_0000; // The map addr for statically-linked executable - let mut cursor = vm_space.cursor_mut(&(MAP_ADDR..MAP_ADDR + nbytes)).unwrap(); + let preempt_guard = disable_preempt(); + let mut cursor = vm_space + .cursor_mut(&preempt_guard, &(MAP_ADDR..MAP_ADDR + nbytes)) + .unwrap(); let map_prop = PageProperty::new(PageFlags::RWX, CachePolicy::Writeback); for frame in user_pages { cursor.map(frame.into(), map_prop); diff --git a/ostd/src/arch/x86/iommu/dma_remapping/context_table.rs b/ostd/src/arch/x86/iommu/dma_remapping/context_table.rs index f5553218..6cc50862 100644 --- a/ostd/src/arch/x86/iommu/dma_remapping/context_table.rs +++ b/ostd/src/arch/x86/iommu/dma_remapping/context_table.rs @@ -17,6 +17,7 @@ use crate::{ page_table::{PageTableError, PageTableItem}, Frame, FrameAllocOptions, Paddr, PageFlags, PageTable, VmIo, PAGE_SIZE, }, + task::disable_preempt, }; /// Bit 0 is `Present` bit, indicating whether this entry is present. @@ -323,7 +324,10 @@ impl ContextTable { } trace!("Unmapping Daddr: {:x?} for device: {:x?}", daddr, device); let pt = self.get_or_create_page_table(device); - let mut cursor = pt.cursor_mut(&(daddr..daddr + PAGE_SIZE)).unwrap(); + let preempt_guard = disable_preempt(); + let mut cursor = pt + .cursor_mut(&preempt_guard, &(daddr..daddr + PAGE_SIZE)) + .unwrap(); unsafe { let result = cursor.take_next(PAGE_SIZE); debug_assert!(matches!(result, PageTableItem::MappedUntracked { .. })); diff --git a/ostd/src/mm/kspace/kvirt_area.rs b/ostd/src/mm/kspace/kvirt_area.rs index d70f7320..c85847be 100644 --- a/ostd/src/mm/kspace/kvirt_area.rs +++ b/ostd/src/mm/kspace/kvirt_area.rs @@ -12,6 +12,7 @@ use crate::{ page_table::PageTableItem, Paddr, Vaddr, PAGE_SIZE, }, + task::disable_preempt, util::range_alloc::RangeAllocator, }; @@ -89,7 +90,8 @@ impl KVirtArea { let start = addr.align_down(PAGE_SIZE); let vaddr = start..start + PAGE_SIZE; let page_table = KERNEL_PAGE_TABLE.get().unwrap(); - let mut cursor = page_table.cursor(&vaddr).unwrap(); + let preempt_guard = disable_preempt(); + let mut cursor = page_table.cursor(&preempt_guard, &vaddr).unwrap(); cursor.query().unwrap() } } @@ -117,7 +119,10 @@ impl KVirtArea { let range = Tracked::select_allocator().alloc(area_size).unwrap(); let cursor_range = range.start + map_offset..range.end; let page_table = KERNEL_PAGE_TABLE.get().unwrap(); - let mut cursor = page_table.cursor_mut(&cursor_range).unwrap(); + let preempt_guard = disable_preempt(); + let mut cursor = page_table + .cursor_mut(&preempt_guard, &cursor_range) + .unwrap(); for page in pages.into_iter() { // SAFETY: The constructor of the `KVirtArea` structure // has already ensured that this mapping does not affect kernel's @@ -187,7 +192,8 @@ impl KVirtArea { let va_range = range.start + map_offset..range.start + map_offset + pa_range.len(); let page_table = KERNEL_PAGE_TABLE.get().unwrap(); - let mut cursor = page_table.cursor_mut(&va_range).unwrap(); + let preempt_guard = disable_preempt(); + let mut cursor = page_table.cursor_mut(&preempt_guard, &va_range).unwrap(); // SAFETY: The caller of `map_untracked_pages` has ensured the safety of this mapping. unsafe { cursor.map_pa(&pa_range, prop); @@ -229,7 +235,8 @@ impl Drop for KVirtArea { // 1. unmap all mapped pages. 
let page_table = KERNEL_PAGE_TABLE.get().unwrap(); let range = self.start()..self.end(); - let mut cursor = page_table.cursor_mut(&range).unwrap(); + let preempt_guard = disable_preempt(); + let mut cursor = page_table.cursor_mut(&preempt_guard, &range).unwrap(); loop { let result = unsafe { cursor.take_next(self.end() - cursor.virt_addr()) }; if matches!(&result, PageTableItem::NotMapped { .. }) { diff --git a/ostd/src/mm/kspace/mod.rs b/ostd/src/mm/kspace/mod.rs index b291f9b8..25a90b77 100644 --- a/ostd/src/mm/kspace/mod.rs +++ b/ostd/src/mm/kspace/mod.rs @@ -57,6 +57,7 @@ use super::{ use crate::{ arch::mm::{PageTableEntry, PagingConsts}, boot::memory_region::MemoryRegionType, + task::disable_preempt, }; /// The shortest supported address width is 39 bits. And the literal @@ -134,6 +135,7 @@ pub fn init_kernel_page_table(meta_pages: Segment) { // Start to initialize the kernel page table. let kpt = PageTable::::new_kernel_page_table(); + let preempt_guard = disable_preempt(); // Do linear mappings for the kernel. { @@ -160,7 +162,7 @@ pub fn init_kernel_page_table(meta_pages: Segment) { cache: CachePolicy::Writeback, priv_flags: PrivilegedPageFlags::GLOBAL, }; - let mut cursor = kpt.cursor_mut(&from).unwrap(); + let mut cursor = kpt.cursor_mut(&preempt_guard, &from).unwrap(); for meta_page in meta_pages { // SAFETY: we are doing the metadata mappings for the kernel. unsafe { @@ -202,7 +204,7 @@ pub fn init_kernel_page_table(meta_pages: Segment) { cache: CachePolicy::Writeback, priv_flags: PrivilegedPageFlags::GLOBAL, }; - let mut cursor = kpt.cursor_mut(&from).unwrap(); + let mut cursor = kpt.cursor_mut(&preempt_guard, &from).unwrap(); for frame_paddr in to.step_by(PAGE_SIZE) { // SAFETY: They were initialized at `super::frame::meta::init`. let page = unsafe { Frame::::from_raw(frame_paddr) }; diff --git a/ostd/src/mm/page_table/cursor/locking.rs b/ostd/src/mm/page_table/cursor/locking.rs index fc734596..dbc8e218 100644 --- a/ostd/src/mm/page_table/cursor/locking.rs +++ b/ostd/src/mm/page_table/cursor/locking.rs @@ -2,7 +2,7 @@ //! Implementation of the locking protocol. -use core::{marker::PhantomData, ops::Range, sync::atomic::Ordering}; +use core::{marker::PhantomData, mem::ManuallyDrop, ops::Range, sync::atomic::Ordering}; use align_ext::AlignExt; @@ -15,19 +15,17 @@ use crate::{ PageTableEntryTrait, PageTableGuard, PageTableMode, PageTableNodeRef, PagingConstsTrait, PagingLevel, }, - Paddr, Vaddr, + Vaddr, }, - task::disable_preempt, + task::atomic_mode::InAtomicMode, }; -pub(super) fn lock_range<'a, M: PageTableMode, E: PageTableEntryTrait, C: PagingConstsTrait>( - pt: &'a PageTable, +pub(super) fn lock_range<'rcu, M: PageTableMode, E: PageTableEntryTrait, C: PagingConstsTrait>( + pt: &'rcu PageTable, + guard: &'rcu dyn InAtomicMode, va: &Range, new_pt_is_tracked: MapTrackingStatus, -) -> Cursor<'a, M, E, C> { - // Start RCU read-side critical section. - let preempt_guard = disable_preempt(); - +) -> Cursor<'rcu, M, E, C> { // The re-try loop of finding the sub-tree root. // // If we locked a stray node, we need to re-try. Otherwise, although @@ -35,7 +33,9 @@ pub(super) fn lock_range<'a, M: PageTableMode, E: PageTableEntryTrait, C: Paging // sub-tree will not see the current state and will not change the current // state, breaking serializability. 
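    // (For orientation, not part of the patch text: `lock_range` no longer calls
    // `disable_preempt()` itself. Its caller, `Cursor::new`, now receives an
    // atomic-mode guard and simply forwards it,
    //
    //     Ok(locking::lock_range(pt, guard, va, new_pt_is_tracked))
    //
    // so every node reference taken below is valid for that guard's `'rcu`
    // lifetime, and the guard cannot be dropped while the cursor is alive.)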
let mut subtree_root = loop { - if let Some(subtree_root) = try_traverse_and_lock_subtree_root(pt, va, new_pt_is_tracked) { + if let Some(subtree_root) = + try_traverse_and_lock_subtree_root(pt, guard, va, new_pt_is_tracked) + { break subtree_root; } }; @@ -44,18 +44,18 @@ pub(super) fn lock_range<'a, M: PageTableMode, E: PageTableEntryTrait, C: Paging // stray nodes in the following traversal since we must lock before reading. let guard_level = subtree_root.level(); let cur_node_va = va.start.align_down(page_size::(guard_level + 1)); - dfs_acquire_lock(&mut subtree_root, cur_node_va, va.clone()); + dfs_acquire_lock(guard, &mut subtree_root, cur_node_va, va.clone()); let mut path = core::array::from_fn(|_| None); path[guard_level as usize - 1] = Some(subtree_root); - Cursor::<'a, M, E, C> { + Cursor::<'rcu, M, E, C> { path, + rcu_guard: guard, level: guard_level, guard_level, va: va.start, barrier_va: va.clone(), - preempt_guard, _phantom: PhantomData, } } @@ -65,7 +65,7 @@ pub(super) fn unlock_range(cursor.guard_level + 1); // SAFETY: A cursor maintains that its corresponding sub-tree is locked. - unsafe { dfs_release_lock(guard_node, cur_node_va, cursor.barrier_va.clone()) }; + unsafe { + dfs_release_lock( + cursor.rcu_guard, + guard_node, + cur_node_va, + cursor.barrier_va.clone(), + ) + }; } /// Finds and locks an intermediate page table node that covers the range. @@ -86,31 +93,16 @@ pub(super) fn unlock_range( - pt: &'a PageTable, + pt: &PageTable, + guard: &'rcu dyn InAtomicMode, va: &Range, new_pt_is_tracked: MapTrackingStatus, -) -> Option> { - // # Safety - // Must be called with `cur_pt_addr` and `'a`, `E`, `E` of the residing function. - unsafe fn lock_cur_pt<'a, E: PageTableEntryTrait, C: PagingConstsTrait>( - cur_pt_addr: Paddr, - ) -> PageTableGuard<'a, E, C> { - // SAFETY: The reference is valid for `'a` because the page table - // is alive within `'a` and `'a` is under the RCU read guard. - let ptn_ref = unsafe { PageTableNodeRef::<'a, E, C>::borrow_paddr(cur_pt_addr) }; - // Forfeit a guard protecting a node that lives for `'a` rather - // than the lifetime of `ptn_ref`. - let pt_addr = ptn_ref.lock().into_raw_paddr(); - // SAFETY: The lock guard was forgotten at the above line. We manually - // ensure that the protected node lives for `'a`. - unsafe { PageTableGuard::<'a, E, C>::from_raw_paddr(pt_addr) } - } - +) -> Option> { let mut cur_node_guard: Option> = None; let mut cur_pt_addr = pt.root.start_paddr(); for cur_level in (1..=C::NR_LEVELS).rev() { @@ -141,17 +133,19 @@ fn try_traverse_and_lock_subtree_root< } // In case the child is absent, we should lock and allocate a new page table node. - // SAFETY: It is called with the required parameters. - let mut guard = cur_node_guard - .take() - .unwrap_or_else(|| unsafe { lock_cur_pt::<'a, E, C>(cur_pt_addr) }); - if *guard.stray_mut() { + let mut pt_guard = cur_node_guard.take().unwrap_or_else(|| { + // SAFETY: The node must be alive for at least `'rcu` since the + // address is read from the page table node. 
+ let node_ref = unsafe { PageTableNodeRef::<'rcu, E, C>::borrow_paddr(cur_pt_addr) }; + node_ref.lock(guard) + }); + if *pt_guard.stray_mut() { return None; } - let mut cur_entry = guard.entry(start_idx); + let mut cur_entry = pt_guard.entry(start_idx); if cur_entry.is_none() { - let allocated_guard = cur_entry.alloc_if_none(new_pt_is_tracked).unwrap(); + let allocated_guard = cur_entry.alloc_if_none(guard, new_pt_is_tracked).unwrap(); cur_pt_addr = allocated_guard.start_paddr(); cur_node_guard = Some(allocated_guard); } else if cur_entry.is_node() { @@ -165,14 +159,17 @@ fn try_traverse_and_lock_subtree_root< } } - // SAFETY: It is called with the required parameters. - let mut guard = - cur_node_guard.unwrap_or_else(|| unsafe { lock_cur_pt::<'a, E, C>(cur_pt_addr) }); - if *guard.stray_mut() { + let mut pt_guard = cur_node_guard.unwrap_or_else(|| { + // SAFETY: The node must be alive for at least `'rcu` since the + // address is read from the page table node. + let node_ref = unsafe { PageTableNodeRef::<'rcu, E, C>::borrow_paddr(cur_pt_addr) }; + node_ref.lock(guard) + }); + if *pt_guard.stray_mut() { return None; } - Some(guard) + Some(pt_guard) } /// Acquires the locks for the given range in the sub-tree rooted at the node. @@ -180,14 +177,15 @@ fn try_traverse_and_lock_subtree_root< /// `cur_node_va` must be the virtual address of the `cur_node`. The `va_range` /// must be within the range of the `cur_node`. The range must not be empty. /// -/// The function will forget all the [`PageTableGuard`] objects in the sub-tree -/// with [`PageTableGuard::into_raw_paddr`]. +/// The function will forget all the [`PageTableGuard`] objects in the sub-tree. fn dfs_acquire_lock( + guard: &dyn InAtomicMode, cur_node: &mut PageTableGuard<'_, E, C>, cur_node_va: Vaddr, va_range: Range, ) { debug_assert!(!*cur_node.stray_mut()); + let cur_level = cur_node.level(); if cur_level <= 1 { return; @@ -198,13 +196,13 @@ fn dfs_acquire_lock( let child = cur_node.entry(i); match child.to_ref() { Child::PageTableRef(pt) => { - let mut pt_guard = pt.lock(); + let mut pt_guard = pt.lock(guard); let child_node_va = cur_node_va + i * page_size::(cur_level); let child_node_va_end = child_node_va + page_size::(cur_level); let va_start = va_range.start.max(child_node_va); let va_end = va_range.end.min(child_node_va_end); - dfs_acquire_lock(&mut pt_guard, child_node_va, va_start..va_end); - let _ = pt_guard.into_raw_paddr(); + dfs_acquire_lock(guard, &mut pt_guard, child_node_va, va_start..va_end); + let _ = ManuallyDrop::new(pt_guard); } Child::None | Child::Frame(_, _) | Child::Untracked(_, _, _) | Child::PageTable(_) => {} } @@ -215,9 +213,11 @@ fn dfs_acquire_lock( /// /// # Safety /// -/// The caller must ensure that the nodes in the specified sub-tree are locked. -unsafe fn dfs_release_lock( - mut cur_node: PageTableGuard, +/// The caller must ensure that the nodes in the specified sub-tree are locked +/// and all guards are forgotten. +unsafe fn dfs_release_lock<'rcu, E: PageTableEntryTrait, C: PagingConstsTrait>( + guard: &'rcu dyn InAtomicMode, + mut cur_node: PageTableGuard<'rcu, E, C>, cur_node_va: Vaddr, va_range: Range, ) { @@ -231,16 +231,14 @@ unsafe fn dfs_release_lock( let child = cur_node.entry(i); match child.to_ref() { Child::PageTableRef(pt) => { - // SAFETY: The node was locked before and we have a - // reference to the parent node that is still alive. - let child_node = - unsafe { PageTableGuard::::from_raw_paddr(pt.start_paddr()) }; + // SAFETY: The caller ensures that the node is locked. 
+ let child_node = unsafe { pt.make_guard_unchecked(guard) }; let child_node_va = cur_node_va + i * page_size::(cur_level); let child_node_va_end = child_node_va + page_size::(cur_level); let va_start = va_range.start.max(child_node_va); let va_end = va_range.end.min(child_node_va_end); // SAFETY: The caller ensures that this sub-tree is locked. - unsafe { dfs_release_lock(child_node, child_node_va, va_start..va_end) }; + unsafe { dfs_release_lock(guard, child_node, child_node_va, va_start..va_end) }; } Child::None | Child::Frame(_, _) | Child::Untracked(_, _, _) | Child::PageTable(_) => {} } @@ -256,16 +254,18 @@ unsafe fn dfs_release_lock( /// /// # Safety /// -/// The caller must ensure that all the nodes in the sub-tree are locked. +/// The caller must ensure that all the nodes in the sub-tree are locked +/// and all guards are forgotten. /// -/// This function must not be called upon a shared node. E.g., the second- +/// This function must not be called upon a shared node, e.g., the second- /// top level nodes that the kernel space and user space share. pub(super) unsafe fn dfs_mark_stray_and_unlock( + rcu_guard: &dyn InAtomicMode, mut sub_tree: PageTableGuard, ) { *sub_tree.stray_mut() = true; - if sub_tree.level() <= 1 { + if sub_tree.level() > 1 { return; } @@ -274,8 +274,8 @@ pub(super) unsafe fn dfs_mark_stray_and_unlock { // SAFETY: The caller ensures that the node is locked. - let locked_pt = unsafe { PageTableGuard::::from_raw_paddr(pt.start_paddr()) }; - dfs_mark_stray_and_unlock(locked_pt); + let locked_pt = unsafe { pt.make_guard_unchecked(rcu_guard) }; + dfs_mark_stray_and_unlock(rcu_guard, locked_pt); } Child::None | Child::Frame(_, _) | Child::Untracked(_, _, _) | Child::PageTable(_) => {} } diff --git a/ostd/src/mm/page_table/cursor/mod.rs b/ostd/src/mm/page_table/cursor/mod.rs index 928c2bb2..0045920f 100644 --- a/ostd/src/mm/page_table/cursor/mod.rs +++ b/ostd/src/mm/page_table/cursor/mod.rs @@ -29,11 +29,7 @@ mod locking; -use core::{ - any::TypeId, - marker::PhantomData, - ops::{Deref, Range}, -}; +use core::{any::TypeId, fmt::Debug, marker::PhantomData, mem::ManuallyDrop, ops::Range}; use align_ext::AlignExt; @@ -47,7 +43,7 @@ use crate::{ frame::{meta::AnyFrameMeta, Frame}, Paddr, PageProperty, Vaddr, }, - task::DisabledPreemptGuard, + task::atomic_mode::InAtomicMode, }; /// The cursor for traversal over the page table. @@ -58,12 +54,14 @@ use crate::{ /// A cursor is able to move to the next slot, to read page properties, /// and even to jump to a virtual address directly. #[derive(Debug)] -pub struct Cursor<'a, M: PageTableMode, E: PageTableEntryTrait, C: PagingConstsTrait> { +pub struct Cursor<'rcu, M: PageTableMode, E: PageTableEntryTrait, C: PagingConstsTrait> { /// The current path of the cursor. /// /// The level 1 page table lock guard is at index 0, and the level N page /// table lock guard is at index N - 1. - path: [Option>; MAX_NR_LEVELS], + path: [Option>; MAX_NR_LEVELS], + /// The cursor should be used in a RCU read side critical section. + rcu_guard: &'rcu dyn InAtomicMode, /// The level of the page table that the cursor currently points to. level: PagingLevel, /// The top-most level that the cursor is allowed to access. @@ -74,11 +72,7 @@ pub struct Cursor<'a, M: PageTableMode, E: PageTableEntryTrait, C: PagingConstsT va: Vaddr, /// The virtual address range that is locked. barrier_va: Range, - /// This also make all the operation in the `Cursor::new` performed in a - /// RCU read-side critical section. 
- #[expect(dead_code)] - preempt_guard: DisabledPreemptGuard, - _phantom: PhantomData<&'a PageTable>, + _phantom: PhantomData<&'rcu PageTable>, } /// The maximum value of `PagingConstsTrait::NR_LEVELS`. @@ -111,13 +105,17 @@ pub enum PageTableItem { }, } -impl<'a, M: PageTableMode, E: PageTableEntryTrait, C: PagingConstsTrait> Cursor<'a, M, E, C> { +impl<'rcu, M: PageTableMode, E: PageTableEntryTrait, C: PagingConstsTrait> Cursor<'rcu, M, E, C> { /// Creates a cursor claiming exclusive access over the given range. /// /// The cursor created will only be able to query or jump within the given /// range. Out-of-bound accesses will result in panics or errors as return values, /// depending on the access method. - pub fn new(pt: &'a PageTable, va: &Range) -> Result { + pub fn new( + pt: &'rcu PageTable, + guard: &'rcu dyn InAtomicMode, + va: &Range, + ) -> Result { if !M::covers(va) || va.is_empty() { return Err(PageTableError::InvalidVaddrRange(va.start, va.end)); } @@ -133,7 +131,7 @@ impl<'a, M: PageTableMode, E: PageTableEntryTrait, C: PagingConstsTrait> Cursor< MapTrackingStatus::Untracked }; - Ok(locking::lock_range(pt, va, new_pt_is_tracked)) + Ok(locking::lock_range(pt, guard, va, new_pt_is_tracked)) } /// Gets the information of the current slot. @@ -142,6 +140,8 @@ impl<'a, M: PageTableMode, E: PageTableEntryTrait, C: PagingConstsTrait> Cursor< return Err(PageTableError::InvalidVaddr(self.va)); } + let rcu_guard = self.rcu_guard; + loop { let level = self.level; let va = self.va; @@ -150,10 +150,9 @@ impl<'a, M: PageTableMode, E: PageTableEntryTrait, C: PagingConstsTrait> Cursor< match entry.to_ref() { Child::PageTableRef(pt) => { - let paddr = pt.start_paddr(); - // SAFETY: `pt` points to a PT that is attached to a node - // in the locked sub-tree, so that it is locked and alive. - self.push_level(unsafe { PageTableGuard::::from_raw_paddr(paddr) }); + // SAFETY: The `pt` must be locked and no other guards exist. + let guard = unsafe { pt.make_guard_unchecked(rcu_guard) }; + self.push_level(guard); continue; } Child::PageTable(_) => { @@ -236,13 +235,13 @@ impl<'a, M: PageTableMode, E: PageTableEntryTrait, C: PagingConstsTrait> Cursor< let Some(taken) = self.path[self.level as usize - 1].take() else { panic!("Popping a level without a lock"); }; - let _taken = taken.into_raw_paddr(); + let _ = ManuallyDrop::new(taken); self.level += 1; } /// Goes down a level to a child page table. - fn push_level(&mut self, child_guard: PageTableGuard<'a, E, C>) { + fn push_level(&mut self, child_guard: PageTableGuard<'rcu, E, C>) { self.level -= 1; debug_assert_eq!(self.level, child_guard.level()); @@ -250,7 +249,7 @@ impl<'a, M: PageTableMode, E: PageTableEntryTrait, C: PagingConstsTrait> Cursor< debug_assert!(old.is_none()); } - fn cur_entry<'s>(&'s mut self) -> Entry<'s, 'a, E, C> { + fn cur_entry(&mut self) -> Entry<'_, 'rcu, E, C> { let node = self.path[self.level as usize - 1].as_mut().unwrap(); node.entry(pte_index::(self.va, self.level)) } @@ -283,21 +282,24 @@ impl Iterator /// in a page table can only be accessed by one cursor, regardless of the /// mutability of the cursor. 
#[derive(Debug)] -pub struct CursorMut<'a, M: PageTableMode, E: PageTableEntryTrait, C: PagingConstsTrait>( - Cursor<'a, M, E, C>, +pub struct CursorMut<'rcu, M: PageTableMode, E: PageTableEntryTrait, C: PagingConstsTrait>( + Cursor<'rcu, M, E, C>, ); -impl<'a, M: PageTableMode, E: PageTableEntryTrait, C: PagingConstsTrait> CursorMut<'a, M, E, C> { +impl<'rcu, M: PageTableMode, E: PageTableEntryTrait, C: PagingConstsTrait> + CursorMut<'rcu, M, E, C> +{ /// Creates a cursor claiming exclusive access over the given range. /// /// The cursor created will only be able to map, query or jump within the given /// range. Out-of-bound accesses will result in panics or errors as return values, /// depending on the access method. pub(super) fn new( - pt: &'a PageTable, + pt: &'rcu PageTable, + guard: &'rcu dyn InAtomicMode, va: &Range, ) -> Result { - Cursor::new(pt, va).map(|inner| Self(inner)) + Cursor::new(pt, guard, va).map(|inner| Self(inner)) } /// Jumps to the given virtual address. @@ -345,6 +347,8 @@ impl<'a, M: PageTableMode, E: PageTableEntryTrait, C: PagingConstsTrait> CursorM let end = self.0.va + frame.size(); assert!(end <= self.0.barrier_va.end); + let rcu_guard = self.0.rcu_guard; + // Go down if not applicable. while self.0.level > frame.map_level() || self.0.va % page_size::(self.0.level) != 0 @@ -354,17 +358,17 @@ impl<'a, M: PageTableMode, E: PageTableEntryTrait, C: PagingConstsTrait> CursorM let mut cur_entry = self.0.cur_entry(); match cur_entry.to_ref() { Child::PageTableRef(pt) => { - let paddr = pt.start_paddr(); - // SAFETY: `pt` points to a PT that is attached to a node - // in the locked sub-tree, so that it is locked and alive. - self.0 - .push_level(unsafe { PageTableGuard::::from_raw_paddr(paddr) }); + // SAFETY: The `pt` must be locked and no other guards exist. + let guard = unsafe { pt.make_guard_unchecked(rcu_guard) }; + self.0.push_level(guard); } Child::PageTable(_) => { unreachable!(); } Child::None => { - let child_guard = cur_entry.alloc_if_none(MapTrackingStatus::Tracked).unwrap(); + let child_guard = cur_entry + .alloc_if_none(rcu_guard, MapTrackingStatus::Tracked) + .unwrap(); self.0.push_level(child_guard); } Child::Frame(_, _) => { @@ -429,6 +433,8 @@ impl<'a, M: PageTableMode, E: PageTableEntryTrait, C: PagingConstsTrait> CursorM let mut pa = pa.start; assert!(end <= self.0.barrier_va.end); + let rcu_guard = self.0.rcu_guard; + while self.0.va < end { // We ensure not mapping in reserved kernel shared tables or releasing it. // Although it may be an invariant for all architectures and will be optimized @@ -444,18 +450,16 @@ impl<'a, M: PageTableMode, E: PageTableEntryTrait, C: PagingConstsTrait> CursorM let mut cur_entry = self.0.cur_entry(); match cur_entry.to_ref() { Child::PageTableRef(pt) => { - let paddr = pt.start_paddr(); - // SAFETY: `pt` points to a PT that is attached to a node - // in the locked sub-tree, so that it is locked and alive. - self.0 - .push_level(unsafe { PageTableGuard::::from_raw_paddr(paddr) }); + // SAFETY: The `pt` must be locked and no other guards exist. 
+ let guard = unsafe { pt.make_guard_unchecked(rcu_guard) }; + self.0.push_level(guard); } Child::PageTable(_) => { unreachable!(); } Child::None => { let child_guard = cur_entry - .alloc_if_none(MapTrackingStatus::Untracked) + .alloc_if_none(rcu_guard, MapTrackingStatus::Untracked) .unwrap(); self.0.push_level(child_guard); } @@ -463,7 +467,7 @@ impl<'a, M: PageTableMode, E: PageTableEntryTrait, C: PagingConstsTrait> CursorM panic!("Mapping a smaller page in an already mapped huge page"); } Child::Untracked(_, _, _) => { - let split_child = cur_entry.split_if_untracked_huge().unwrap(); + let split_child = cur_entry.split_if_untracked_huge(rcu_guard).unwrap(); self.0.push_level(split_child); } } @@ -513,6 +517,8 @@ impl<'a, M: PageTableMode, E: PageTableEntryTrait, C: PagingConstsTrait> CursorM let end = start + len; assert!(end <= self.0.barrier_va.end); + let rcu_guard = self.0.rcu_guard; + while self.0.va < end { let cur_va = self.0.va; let cur_level = self.0.level; @@ -533,16 +539,14 @@ impl<'a, M: PageTableMode, E: PageTableEntryTrait, C: PagingConstsTrait> CursorM let child = cur_entry.to_ref(); match child { Child::PageTableRef(pt) => { - let paddr = pt.start_paddr(); - // SAFETY: `pt` points to a PT that is attached to a node - // in the locked sub-tree, so that it is locked and alive. - let pt = unsafe { PageTableGuard::::from_raw_paddr(paddr) }; + // SAFETY: The `pt` must be locked and no other guards exist. + let pt = unsafe { pt.make_guard_unchecked(rcu_guard) }; // If there's no mapped PTEs in the next level, we can // skip to save time. if pt.nr_children() != 0 { self.0.push_level(pt); } else { - let _ = pt.into_raw_paddr(); + let _ = ManuallyDrop::new(pt); if self.0.va + page_size::(self.0.level) > end { self.0.va = end; break; @@ -560,7 +564,7 @@ impl<'a, M: PageTableMode, E: PageTableEntryTrait, C: PagingConstsTrait> CursorM panic!("Removing part of a huge page"); } Child::Untracked(_, _, _) => { - let split_child = cur_entry.split_if_untracked_huge().unwrap(); + let split_child = cur_entry.split_if_untracked_huge(rcu_guard).unwrap(); self.0.push_level(split_child); } } @@ -591,16 +595,15 @@ impl<'a, M: PageTableMode, E: PageTableEntryTrait, C: PagingConstsTrait> CursorM "Unmapping shared kernel page table nodes" ); - // SAFETY: We must have locked this node. - let locked_pt = - unsafe { PageTableGuard::::from_raw_paddr(pt.start_paddr()) }; + // SAFETY: The `pt` must be locked and no other guards exist. + let locked_pt = unsafe { pt.borrow().make_guard_unchecked(rcu_guard) }; // SAFETY: // - We checked that we are not unmapping shared kernel page table nodes. // - We must have locked the entire sub-tree since the range is locked. 
- unsafe { locking::dfs_mark_stray_and_unlock(locked_pt) }; + unsafe { locking::dfs_mark_stray_and_unlock(rcu_guard, locked_pt) }; PageTableItem::StrayPageTable { - pt: pt.deref().clone().into(), + pt: (*pt).clone().into(), va: self.0.va, len: page_size::(self.0.level), } @@ -649,6 +652,8 @@ impl<'a, M: PageTableMode, E: PageTableEntryTrait, C: PagingConstsTrait> CursorM let end = self.0.va + len; assert!(end <= self.0.barrier_va.end); + let rcu_guard = self.0.rcu_guard; + while self.0.va < end { let cur_va = self.0.va; let cur_level = self.0.level; @@ -665,16 +670,14 @@ impl<'a, M: PageTableMode, E: PageTableEntryTrait, C: PagingConstsTrait> CursorM let Child::PageTableRef(pt) = cur_entry.to_ref() else { unreachable!("Already checked"); }; - let paddr = pt.start_paddr(); - // SAFETY: `pt` points to a PT that is attached to a node - // in the locked sub-tree, so that it is locked and alive. - let pt = unsafe { PageTableGuard::::from_raw_paddr(paddr) }; + // SAFETY: The `pt` must be locked and no other guards exist. + let pt = unsafe { pt.make_guard_unchecked(rcu_guard) }; // If there's no mapped PTEs in the next level, we can // skip to save time. if pt.nr_children() != 0 { self.0.push_level(pt); } else { - pt.into_raw_paddr(); + let _ = ManuallyDrop::new(pt); self.0.move_forward(); } continue; @@ -684,7 +687,7 @@ impl<'a, M: PageTableMode, E: PageTableEntryTrait, C: PagingConstsTrait> CursorM // of untracked huge pages. if cur_va % page_size::(cur_level) != 0 || cur_va + page_size::(cur_level) > end { let split_child = cur_entry - .split_if_untracked_huge() + .split_if_untracked_huge(rcu_guard) .expect("Protecting part of a huge page"); self.0.push_level(split_child); continue; @@ -741,22 +744,22 @@ impl<'a, M: PageTableMode, E: PageTableEntryTrait, C: PagingConstsTrait> CursorM let src_end = src.0.va + len; assert!(src_end <= src.0.barrier_va.end); + let rcu_guard = self.0.rcu_guard; + while self.0.va < this_end && src.0.va < src_end { let src_va = src.0.va; let mut src_entry = src.0.cur_entry(); match src_entry.to_ref() { Child::PageTableRef(pt) => { - let paddr = pt.start_paddr(); - // SAFETY: `pt` points to a PT that is attached to a node - // in the locked sub-tree, so that it is locked and alive. - let pt = unsafe { PageTableGuard::::from_raw_paddr(paddr) }; + // SAFETY: The `pt` must be locked and no other guards exist. + let pt = unsafe { pt.make_guard_unchecked(rcu_guard) }; // If there's no mapped PTEs in the next level, we can // skip to save time. if pt.nr_children() != 0 { src.0.push_level(pt); } else { - pt.into_raw_paddr(); + let _ = ManuallyDrop::new(pt); src.0.move_forward(); } } diff --git a/ostd/src/mm/page_table/mod.rs b/ostd/src/mm/page_table/mod.rs index 372ec5f4..be540f00 100644 --- a/ostd/src/mm/page_table/mod.rs +++ b/ostd/src/mm/page_table/mod.rs @@ -14,6 +14,7 @@ use super::{ }; use crate::{ arch::mm::{PageTableEntry, PagingConsts}, + task::{atomic_mode::AsAtomicModeGuard, disable_preempt}, util::marker::SameSizeAs, Pod, }; @@ -107,12 +108,12 @@ impl PageTable { // Make shared the page tables mapped by the root table in the kernel space. 
{ + let preempt_guard = disable_preempt(); + let mut root_node = kpt.root.borrow().lock(&preempt_guard); + const NR_PTES_PER_NODE: usize = nr_subpage_per_huge::(); let kernel_space_range = NR_PTES_PER_NODE / 2..NR_PTES_PER_NODE; - let _guard = crate::task::disable_preempt(); - - let mut root_node = kpt.root.lock(); for i in kernel_space_range { let mut root_entry = root_node.entry(i); let is_tracked = if super::kspace::should_map_as_tracked( @@ -122,7 +123,9 @@ impl PageTable { } else { MapTrackingStatus::Untracked }; - let _ = root_entry.alloc_if_none(is_tracked).unwrap(); + let _ = root_entry + .alloc_if_none(&preempt_guard, is_tracked) + .unwrap(); } } @@ -134,15 +137,17 @@ impl PageTable { /// This should be the only way to create the user page table, that is to /// duplicate the kernel page table with all the kernel mappings shared. pub fn create_user_page_table(&self) -> PageTable { - let _preempt_guard = crate::task::disable_preempt(); - let mut root_node = self.root.lock(); let new_root = PageTableNode::alloc(PagingConsts::NR_LEVELS, MapTrackingStatus::NotApplicable); - let mut new_node = new_root.lock(); + + let preempt_guard = disable_preempt(); + let mut root_node = self.root.borrow().lock(&preempt_guard); + let mut new_node = new_root.borrow().lock(&preempt_guard); // Make a shallow copy of the root node in the kernel space range. // The user space range is not copied. const NR_PTES_PER_NODE: usize = nr_subpage_per_huge::(); + for i in NR_PTES_PER_NODE / 2..NR_PTES_PER_NODE { let root_entry = root_node.entry(i); let child = root_entry.to_ref(); @@ -176,7 +181,8 @@ impl PageTable { vaddr: &Range, mut op: impl FnMut(&mut PageProperty), ) -> Result<(), PageTableError> { - let mut cursor = CursorMut::new(self, vaddr)?; + let preempt_guard = disable_preempt(); + let mut cursor = CursorMut::new(self, &preempt_guard, vaddr)?; while let Some(range) = cursor.protect_next(vaddr.end - cursor.virt_addr(), &mut op) { crate::arch::mm::tlb_flush_addr(range.start); } @@ -184,7 +190,7 @@ impl PageTable { } } -impl<'a, M: PageTableMode, E: PageTableEntryTrait, C: PagingConstsTrait> PageTable { +impl PageTable { /// Create a new empty page table. /// /// Useful for the IOMMU page tables only. @@ -213,7 +219,8 @@ impl<'a, M: PageTableMode, E: PageTableEntryTrait, C: PagingConstsTrait> PageTab paddr: &Range, prop: PageProperty, ) -> Result<(), PageTableError> { - self.cursor_mut(vaddr)?.map_pa(paddr, prop); + let preempt_guard = disable_preempt(); + self.cursor_mut(&preempt_guard, vaddr)?.map_pa(paddr, prop); Ok(()) } @@ -232,11 +239,12 @@ impl<'a, M: PageTableMode, E: PageTableEntryTrait, C: PagingConstsTrait> PageTab /// /// If another cursor is already accessing the range, the new cursor may wait until the /// previous cursor is dropped. - pub fn cursor_mut( - &'a self, + pub fn cursor_mut<'rcu, G: AsAtomicModeGuard>( + &'rcu self, + guard: &'rcu G, va: &Range, - ) -> Result, PageTableError> { - CursorMut::new(self, va) + ) -> Result, PageTableError> { + CursorMut::new(self, guard.as_atomic_mode_guard(), va) } /// Create a new cursor exclusively accessing the virtual address range for querying. @@ -244,8 +252,12 @@ impl<'a, M: PageTableMode, E: PageTableEntryTrait, C: PagingConstsTrait> PageTab /// If another cursor is already accessing the range, the new cursor may wait until the /// previous cursor is dropped. The modification to the mapping by the cursor may also /// block or be overridden by the mapping of another cursor. 
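For illustration, the cursor constructors above are generic over any guard that can prove atomic mode, so a caller reuses whatever guard it already holds instead of nesting another `disable_preempt()`. A hedged sketch of such a caller (the `probe` function is hypothetical; the types are the ones used in this patch):

    fn probe(
        pt: &PageTable<UserMode, PageTableEntry, PagingConsts>,
        va: Vaddr,
    ) -> Result<(), PageTableError> {
        let guard = disable_preempt(); // any `AsAtomicModeGuard` implementor works
        let _cursor = pt.cursor(&guard, &(va..va + PAGE_SIZE))?;
        Ok(())
    }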
- pub fn cursor(&'a self, va: &Range) -> Result, PageTableError> { - Cursor::new(self, va) + pub fn cursor<'rcu, G: AsAtomicModeGuard>( + &'rcu self, + guard: &'rcu G, + va: &Range, + ) -> Result, PageTableError> { + Cursor::new(self, guard.as_atomic_mode_guard(), va) } /// Create a new reference to the same page table. diff --git a/ostd/src/mm/page_table/node/entry.rs b/ostd/src/mm/page_table/node/entry.rs index e65bd0ea..48b1b834 100644 --- a/ostd/src/mm/page_table/node/entry.rs +++ b/ostd/src/mm/page_table/node/entry.rs @@ -2,10 +2,15 @@ //! This module provides accessors to the page table entries in a node. -use super::{Child, MapTrackingStatus, PageTableEntryTrait, PageTableGuard, PageTableNode}; +use core::mem::ManuallyDrop; + +use super::{ + Child, MapTrackingStatus, PageTableEntryTrait, PageTableGuard, PageTableNode, PageTableNodeRef, +}; use crate::{ mm::{nr_subpage_per_huge, page_prop::PageProperty, page_size, PagingConstsTrait}, sync::RcuDrop, + task::atomic_mode::InAtomicMode, }; /// A view of an entry in a page table node. @@ -15,7 +20,7 @@ use crate::{ /// This is a static reference to an entry in a node that does not account for /// a dynamic reference count to the child. It can be used to create a owned /// handle, which is a [`Child`]. -pub(in crate::mm) struct Entry<'guard, 'pt, E: PageTableEntryTrait, C: PagingConstsTrait> { +pub(in crate::mm) struct Entry<'a, 'rcu, E: PageTableEntryTrait, C: PagingConstsTrait> { /// The page table entry. /// /// We store the page table entry here to optimize the number of reads from @@ -27,10 +32,10 @@ pub(in crate::mm) struct Entry<'guard, 'pt, E: PageTableEntryTrait, C: PagingCon /// The index of the entry in the node. idx: usize, /// The node that contains the entry. - node: &'guard mut PageTableGuard<'pt, E, C>, + node: &'a mut PageTableGuard<'rcu, E, C>, } -impl<'guard, 'pt, E: PageTableEntryTrait, C: PagingConstsTrait> Entry<'guard, 'pt, E, C> { +impl<'a, 'rcu, E: PageTableEntryTrait, C: PagingConstsTrait> Entry<'a, 'rcu, E, C> { /// Returns if the entry does not map to anything. pub(in crate::mm) fn is_none(&self) -> bool { !self.pte.is_present() @@ -42,7 +47,7 @@ impl<'guard, 'pt, E: PageTableEntryTrait, C: PagingConstsTrait> Entry<'guard, 'p } /// Gets a reference to the child. - pub(in crate::mm) fn to_ref(&self) -> Child<'_, E, C> { + pub(in crate::mm) fn to_ref(&self) -> Child<'rcu, E, C> { // SAFETY: The entry structure represents an existent entry with the // right node information. unsafe { Child::ref_from_pte(&self.pte, self.node.level(), self.node.is_tracked()) } @@ -115,8 +120,9 @@ impl<'guard, 'pt, E: PageTableEntryTrait, C: PagingConstsTrait> Entry<'guard, 'p /// Otherwise, the lock guard of the new child page table node is returned. pub(in crate::mm::page_table) fn alloc_if_none( &mut self, + guard: &'rcu dyn InAtomicMode, new_pt_is_tracked: MapTrackingStatus, - ) -> Option> { + ) -> Option> { if !(self.is_none() && self.node.level() > 1) { return None; } @@ -124,7 +130,8 @@ impl<'guard, 'pt, E: PageTableEntryTrait, C: PagingConstsTrait> Entry<'guard, 'p let level = self.node.level(); let new_page = PageTableNode::::alloc(level - 1, new_pt_is_tracked); - let guard_addr = new_page.lock().into_raw_paddr(); + let paddr = new_page.start_paddr(); + let _ = ManuallyDrop::new(new_page.borrow().lock(guard)); // SAFETY: // 1. The index is within the bounds. 
@@ -138,10 +145,11 @@ impl<'guard, 'pt, E: PageTableEntryTrait, C: PagingConstsTrait> Entry<'guard, 'p *self.node.nr_children_mut() += 1; - // SAFETY: The resulting guard lifetime (`'a`) is no shorter than the - // lifetime of the current entry (`'a`), because we store the allocated - // page table in the current node. - Some(unsafe { PageTableGuard::from_raw_paddr(guard_addr) }) + // SAFETY: The page table won't be dropped before the RCU grace period + // ends, so it outlives `'rcu`. + let pt_ref = unsafe { PageTableNodeRef::borrow_paddr(paddr) }; + // SAFETY: The node is locked and there are no other guards. + Some(unsafe { pt_ref.make_guard_unchecked(guard) }) } /// Splits the entry to smaller pages if it maps to a untracked huge page. @@ -154,7 +162,8 @@ impl<'guard, 'pt, E: PageTableEntryTrait, C: PagingConstsTrait> Entry<'guard, 'p /// `None`. pub(in crate::mm::page_table) fn split_if_untracked_huge( &mut self, - ) -> Option> { + guard: &'rcu dyn InAtomicMode, + ) -> Option> { let level = self.node.level(); if !(self.pte.is_last(level) @@ -168,16 +177,17 @@ impl<'guard, 'pt, E: PageTableEntryTrait, C: PagingConstsTrait> Entry<'guard, 'p let prop = self.pte.prop(); let new_page = PageTableNode::::alloc(level - 1, MapTrackingStatus::Untracked); - let mut guard = new_page.lock(); + let mut pt_lock_guard = new_page.borrow().lock(guard); for i in 0..nr_subpage_per_huge::() { let small_pa = pa + i * page_size::(level - 1); - let mut entry = guard.entry(i); + let mut entry = pt_lock_guard.entry(i); let old = entry.replace(Child::Untracked(small_pa, level - 1, prop)); debug_assert!(old.is_none()); } - let guard_addr = guard.into_raw_paddr(); + let paddr = new_page.start_paddr(); + let _ = ManuallyDrop::new(pt_lock_guard); // SAFETY: // 1. The index is within the bounds. @@ -189,10 +199,11 @@ impl<'guard, 'pt, E: PageTableEntryTrait, C: PagingConstsTrait> Entry<'guard, 'p ) }; - // SAFETY: The resulting guard lifetime (`'a`) is no shorter than the - // lifetime of the current entry (`'a`), because we store the allocated - // page table in the current node. - Some(unsafe { PageTableGuard::from_raw_paddr(guard_addr) }) + // SAFETY: The page table won't be dropped before the RCU grace period + // ends, so it outlives `'rcu`. + let pt_ref = unsafe { PageTableNodeRef::borrow_paddr(paddr) }; + // SAFETY: The node is locked and there are no other guards. + Some(unsafe { pt_ref.make_guard_unchecked(guard) }) } /// Create a new entry at the node with guard. @@ -200,7 +211,7 @@ impl<'guard, 'pt, E: PageTableEntryTrait, C: PagingConstsTrait> Entry<'guard, 'p /// # Safety /// /// The caller must ensure that the index is within the bounds of the node. - pub(super) unsafe fn new_at(guard: &'guard mut PageTableGuard<'pt, E, C>, idx: usize) -> Self { + pub(super) unsafe fn new_at(guard: &'a mut PageTableGuard<'rcu, E, C>, idx: usize) -> Self { // SAFETY: The index is within the bound. 
let pte = unsafe { guard.read_pte(idx) }; Self { diff --git a/ostd/src/mm/page_table/node/mod.rs b/ostd/src/mm/page_table/node/mod.rs index 1997cb30..ee8d2fc8 100644 --- a/ostd/src/mm/page_table/node/mod.rs +++ b/ostd/src/mm/page_table/node/mod.rs @@ -37,11 +37,14 @@ use core::{ pub(in crate::mm) use self::{child::Child, entry::Entry}; use super::{nr_subpage_per_huge, PageTableEntryTrait}; -use crate::mm::{ - frame::{meta::AnyFrameMeta, Frame, FrameRef}, - paddr_to_vaddr, - page_table::{load_pte, store_pte}, - FrameAllocOptions, Infallible, Paddr, PagingConstsTrait, PagingLevel, VmReader, +use crate::{ + mm::{ + frame::{meta::AnyFrameMeta, Frame, FrameRef}, + paddr_to_vaddr, + page_table::{load_pte, store_pte}, + FrameAllocOptions, Infallible, PagingConstsTrait, PagingLevel, VmReader, + }, + task::atomic_mode::InAtomicMode, }; /// A smart pointer to a page table node. @@ -55,9 +58,6 @@ use crate::mm::{ /// [`PageTableGuard`]. pub(super) type PageTableNode = Frame>; -/// A reference to a page table node. -pub(super) type PageTableNodeRef<'a, E, C> = FrameRef<'a, PageTablePageMeta>; - impl PageTableNode { pub(super) fn level(&self) -> PagingLevel { self.meta().level @@ -68,8 +68,6 @@ impl PageTableNode { } /// Allocates a new empty page table node. - /// - /// This function returns a locked owning guard. pub(super) fn alloc(level: PagingLevel, is_tracked: MapTrackingStatus) -> Self { let meta = PageTablePageMeta::new(level, is_tracked); let frame = FrameAllocOptions::new() @@ -82,22 +80,6 @@ impl PageTableNode { frame } - /// Locks the page table node. - pub(super) fn lock(&self) -> PageTableGuard<'_, E, C> { - while self - .meta() - .lock - .compare_exchange(0, 1, Ordering::Acquire, Ordering::Relaxed) - .is_err() - { - core::hint::spin_loop(); - } - - PageTableGuard::<'_, E, C> { - inner: self.borrow(), - } - } - /// Activates the page table assuming it is a root page table. /// /// Here we ensure not dropping an active page table by making a @@ -145,46 +127,69 @@ impl PageTableNode { } } -/// A guard that holds the lock of a page table node. -#[derive(Debug)] -pub(super) struct PageTableGuard<'a, E: PageTableEntryTrait, C: PagingConstsTrait> { - inner: PageTableNodeRef<'a, E, C>, +/// A reference to a page table node. +pub(super) type PageTableNodeRef<'a, E, C> = FrameRef<'a, PageTablePageMeta>; + +impl<'a, E: PageTableEntryTrait, C: PagingConstsTrait> PageTableNodeRef<'a, E, C> { + /// Locks the page table node. + /// + /// An atomic mode guard is required to + /// 1. prevent deadlocks; + /// 2. provide a lifetime (`'rcu`) that the nodes are guaranteed to outlive. + pub(super) fn lock<'rcu>(self, _guard: &'rcu dyn InAtomicMode) -> PageTableGuard<'rcu, E, C> + where + 'a: 'rcu, + { + while self + .meta() + .lock + .compare_exchange(0, 1, Ordering::Acquire, Ordering::Relaxed) + .is_err() + { + core::hint::spin_loop(); + } + + PageTableGuard::<'rcu, E, C> { inner: self } + } + + /// Creates a new [`PageTableGuard`] without checking if the page table lock is held. + /// + /// # Safety + /// + /// This function must be called if this task logically holds the lock. + /// + /// Calling this function when a guard is already created is undefined behavior + /// unless that guard was already forgotten. 
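    // (Sketch of the lock lifecycle that the locking methods on `PageTableNodeRef`
    // enable, with assumed bindings `guard: &'rcu dyn InAtomicMode` and
    // `node_ref: PageTableNodeRef<'rcu, E, C>`:
    //
    //     let locked = node_ref.lock(guard);   // lock under the atomic-mode guard
    //     let _ = ManuallyDrop::new(locked);   // keep it locked, run no unlock
    //     // later, a holder of the same logical lock may rebuild the guard:
    //     // let relocked = unsafe { other_ref.make_guard_unchecked(guard) };
    //
    // Here `other_ref` is another reference to the same node, e.g. one obtained
    // via `borrow_paddr`; both steps are only sound while `guard` keeps the node
    // alive, which is what the `'rcu` bound expresses.)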
+ pub(super) unsafe fn make_guard_unchecked<'rcu>( + self, + _guard: &'rcu dyn InAtomicMode, + ) -> PageTableGuard<'rcu, E, C> + where + 'a: 'rcu, + { + PageTableGuard { inner: self } + } } -impl<'a, E: PageTableEntryTrait, C: PagingConstsTrait> PageTableGuard<'a, E, C> { +/// A guard that holds the lock of a page table node. +#[derive(Debug)] +pub(super) struct PageTableGuard<'rcu, E: PageTableEntryTrait, C: PagingConstsTrait> { + inner: PageTableNodeRef<'rcu, E, C>, +} + +impl<'rcu, E: PageTableEntryTrait, C: PagingConstsTrait> PageTableGuard<'rcu, E, C> { /// Borrows an entry in the node at a given index. /// /// # Panics /// /// Panics if the index is not within the bound of /// [`nr_subpage_per_huge`]. - pub(super) fn entry<'s>(&'s mut self, idx: usize) -> Entry<'s, 'a, E, C> { + pub(super) fn entry(&mut self, idx: usize) -> Entry<'_, 'rcu, E, C> { assert!(idx < nr_subpage_per_huge::()); // SAFETY: The index is within the bound. unsafe { Entry::new_at(self, idx) } } - /// Converts the guard into a raw physical address. - /// - /// It will not release the lock. It may be paired with [`Self::from_raw_paddr`] - /// to manually manage pointers. - pub(super) fn into_raw_paddr(self) -> Paddr { - self.start_paddr() - } - - /// Converts a raw physical address to a guard. - /// - /// # Safety - /// - /// The caller must ensure that the physical address is valid and points to - /// a forgotten page table node that is locked (see [`Self::into_raw_paddr`]). - pub(super) unsafe fn from_raw_paddr(paddr: Paddr) -> Self { - Self { - // SAFETY: The caller ensures safety. - inner: unsafe { PageTableNodeRef::borrow_paddr(paddr) }, - } - } - /// Gets the number of valid PTEs in the node. pub(super) fn nr_children(&self) -> u16 { // SAFETY: The lock is held so we have an exclusive access. @@ -244,8 +249,8 @@ impl<'a, E: PageTableEntryTrait, C: PagingConstsTrait> PageTableGuard<'a, E, C> } } -impl<'a, E: PageTableEntryTrait, C: PagingConstsTrait> Deref for PageTableGuard<'a, E, C> { - type Target = PageTableNodeRef<'a, E, C>; +impl<'rcu, E: PageTableEntryTrait, C: PagingConstsTrait> Deref for PageTableGuard<'rcu, E, C> { + type Target = PageTableNodeRef<'rcu, E, C>; fn deref(&self) -> &Self::Target { &self.inner diff --git a/ostd/src/mm/page_table/test.rs b/ostd/src/mm/page_table/test.rs index 4a720bc8..08095637 100644 --- a/ostd/src/mm/page_table/test.rs +++ b/ostd/src/mm/page_table/test.rs @@ -8,6 +8,7 @@ use crate::{ FrameAllocOptions, MAX_USERSPACE_VADDR, PAGE_SIZE, }, prelude::*, + task::disable_preempt, }; mod test_utils { @@ -38,9 +39,10 @@ mod test_utils { /// Unmaps a range of virtual addresses. 
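The test-side changes that follow all have the same shape: each test or helper takes one preemption guard up front and threads it through every cursor it creates. For illustration, a minimal helper in that style (the name `map_one_page` is hypothetical, assuming the default architecture parameters of `PageTable`):

    fn map_one_page(page_table: &PageTable<UserMode>, va: Vaddr, frame: Frame<()>) {
        let preempt_guard = disable_preempt();
        let prop = PageProperty::new(PageFlags::RW, CachePolicy::Writeback);
        // SAFETY: test-only mapping of a freshly allocated frame.
        unsafe {
            page_table
                .cursor_mut(&preempt_guard, &(va..va + PAGE_SIZE))
                .unwrap()
                .map(frame.into(), prop);
        }
    }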
#[track_caller] pub fn unmap_range(page_table: &PageTable, range: Range) { + let preempt_guard = disable_preempt(); unsafe { page_table - .cursor_mut(&range) + .cursor_mut(&preempt_guard, &range) .unwrap() .take_next(range.len()); } @@ -113,7 +115,8 @@ mod test_utils { range: &Range, mut protect_op: impl FnMut(&mut PageProperty), ) { - let mut cursor = page_table.cursor_mut(range).unwrap(); + let preempt_guard = disable_preempt(); + let mut cursor = page_table.cursor_mut(&preempt_guard, range).unwrap(); loop { unsafe { if cursor @@ -133,14 +136,19 @@ mod create_page_table { #[ktest] fn init_user_page_table() { let user_pt = setup_page_table::(); - assert!(user_pt.cursor(&(0..MAX_USERSPACE_VADDR)).is_ok()); + assert!(user_pt + .cursor(&disable_preempt(), &(0..MAX_USERSPACE_VADDR)) + .is_ok()); } #[ktest] fn init_kernel_page_table() { let kernel_pt = setup_page_table::(); assert!(kernel_pt - .cursor(&(LINEAR_MAPPING_BASE_VADDR..LINEAR_MAPPING_BASE_VADDR + PAGE_SIZE)) + .cursor( + &disable_preempt(), + &(LINEAR_MAPPING_BASE_VADDR..LINEAR_MAPPING_BASE_VADDR + PAGE_SIZE) + ) .is_ok()); } @@ -148,9 +156,10 @@ mod create_page_table { fn create_user_page_table() { let kernel_pt = PageTable::::new_kernel_page_table(); let user_pt = kernel_pt.create_user_page_table(); + let guard = disable_preempt(); - let mut kernel_root = kernel_pt.root.lock(); - let mut user_root = user_pt.root.lock(); + let mut kernel_root = kernel_pt.root.borrow().lock(&guard); + let mut user_root = user_pt.root.borrow().lock(&guard); const NR_PTES_PER_NODE: usize = nr_subpage_per_huge::(); for i in NR_PTES_PER_NODE / 2..NR_PTES_PER_NODE { @@ -176,30 +185,36 @@ mod range_checks { let valid_va = 0..PAGE_SIZE; let invalid_va = 0..(PAGE_SIZE + 1); let kernel_va = LINEAR_MAPPING_BASE_VADDR..(LINEAR_MAPPING_BASE_VADDR + PAGE_SIZE); + let preempt_guard = disable_preempt(); // Valid range succeeds. - assert!(page_table.cursor_mut(&valid_va).is_ok()); + assert!(page_table.cursor_mut(&preempt_guard, &valid_va).is_ok()); // Invalid ranges fail. - assert!(page_table.cursor_mut(&invalid_va).is_err()); - assert!(page_table.cursor_mut(&kernel_va).is_err()); + assert!(page_table.cursor_mut(&preempt_guard, &invalid_va).is_err()); + assert!(page_table.cursor_mut(&preempt_guard, &kernel_va).is_err()); } #[ktest] fn boundary_conditions() { let page_table = setup_page_table::(); + let preempt_guard = disable_preempt(); // Tests an empty range. let empty_range = 0..0; - assert!(page_table.cursor_mut(&empty_range).is_err()); + assert!(page_table.cursor_mut(&preempt_guard, &empty_range).is_err()); // Tests an out-of-range virtual address. let out_of_range = MAX_USERSPACE_VADDR..(MAX_USERSPACE_VADDR + PAGE_SIZE); - assert!(page_table.cursor_mut(&out_of_range).is_err()); + assert!(page_table + .cursor_mut(&preempt_guard, &out_of_range) + .is_err()); // Tests misaligned addresses. let unaligned_range = 1..(PAGE_SIZE + 1); - assert!(page_table.cursor_mut(&unaligned_range).is_err()); + assert!(page_table + .cursor_mut(&preempt_guard, &unaligned_range) + .is_err()); } #[ktest] @@ -208,13 +223,14 @@ mod range_checks { let max_address = 0x100000; let range = 0..max_address; let page_property = PageProperty::new(PageFlags::RW, CachePolicy::Writeback); + let preempt_guard = disable_preempt(); // Allocates required frames. 
let frames = FrameAllocOptions::default() .alloc_segment_with(max_address / PAGE_SIZE, |_| ()) .unwrap(); - let mut cursor = page_table.cursor_mut(&range).unwrap(); + let mut cursor = page_table.cursor_mut(&preempt_guard, &range).unwrap(); for frame in frames { unsafe { @@ -233,11 +249,12 @@ mod range_checks { let range = 0..PAGE_SIZE; let page_property = PageProperty::new(PageFlags::RW, CachePolicy::Writeback); let frame = FrameAllocOptions::default().alloc_frame().unwrap(); + let preempt_guard = disable_preempt(); // Maps the virtual range to the physical frame. unsafe { page_table - .cursor_mut(&range) + .cursor_mut(&preempt_guard, &range) .unwrap() .map(frame.into(), page_property); } @@ -253,11 +270,12 @@ mod range_checks { let range = (MAX_USERSPACE_VADDR - PAGE_SIZE)..MAX_USERSPACE_VADDR; let page_property = PageProperty::new(PageFlags::RW, CachePolicy::Writeback); let frame = FrameAllocOptions::default().alloc_frame().unwrap(); + let preempt_guard = disable_preempt(); // Maps the virtual range to the physical frame. unsafe { page_table - .cursor_mut(&range) + .cursor_mut(&preempt_guard, &range) .unwrap() .map(frame.into(), page_property); } @@ -275,10 +293,11 @@ mod range_checks { (MAX_USERSPACE_VADDR - (PAGE_SIZE / 2))..(MAX_USERSPACE_VADDR + (PAGE_SIZE / 2)); let page_property = PageProperty::new(PageFlags::RW, CachePolicy::Writeback); let frame = FrameAllocOptions::default().alloc_frame().unwrap(); + let preempt_guard = disable_preempt(); unsafe { page_table - .cursor_mut(&range) + .cursor_mut(&preempt_guard, &range) .unwrap() .map(frame.into(), page_property); } @@ -294,9 +313,10 @@ mod page_properties { let page_table = setup_page_table::(); let range = PAGE_SIZE..(PAGE_SIZE * 2); let frame = FrameAllocOptions::default().alloc_frame().unwrap(); + let preempt_guard = disable_preempt(); unsafe { page_table - .cursor_mut(&range) + .cursor_mut(&preempt_guard, &range) .unwrap() .map(frame.into(), prop); } @@ -311,11 +331,12 @@ mod page_properties { let page_table = setup_page_table::(); let virtual_range = PAGE_SIZE..(PAGE_SIZE * 2); let frame = FrameAllocOptions::default().alloc_frame().unwrap(); + let preempt_guard = disable_preempt(); let invalid_prop = PageProperty::new(PageFlags::RW, CachePolicy::Uncacheable); unsafe { page_table - .cursor_mut(&virtual_range) + .cursor_mut(&preempt_guard, &virtual_range) .unwrap() .map(frame.into(), invalid_prop); let (_, prop) = page_table.query(virtual_range.start + 10).unwrap(); @@ -367,6 +388,7 @@ mod different_page_sizes { #[ktest] fn different_page_sizes() { let page_table = setup_page_table::(); + let preempt_guard = disable_preempt(); // 2MiB pages let virtual_range_2m = (PAGE_SIZE * 512)..(PAGE_SIZE * 512 * 2); @@ -374,7 +396,7 @@ mod different_page_sizes { let page_property = PageProperty::new(PageFlags::RW, CachePolicy::Writeback); unsafe { page_table - .cursor_mut(&virtual_range_2m) + .cursor_mut(&preempt_guard, &virtual_range_2m) .unwrap() .map(frame_2m.into(), page_property); } @@ -385,7 +407,7 @@ mod different_page_sizes { let frame_1g = FrameAllocOptions::default().alloc_frame().unwrap(); unsafe { page_table - .cursor_mut(&virtual_range_1g) + .cursor_mut(&preempt_guard, &virtual_range_1g) .unwrap() .map(frame_1g.into(), page_property); } @@ -402,6 +424,7 @@ mod overlapping_mappings { let range1 = PAGE_SIZE..(PAGE_SIZE * 2); let range2 = PAGE_SIZE..(PAGE_SIZE * 3); let page_property = PageProperty::new(PageFlags::RW, CachePolicy::Writeback); + let preempt_guard = disable_preempt(); let frame1 = 
             FrameAllocOptions::default().alloc_frame().unwrap();
         let frame2 = FrameAllocOptions::default().alloc_frame().unwrap();
@@ -409,13 +432,13 @@ mod overlapping_mappings {
         unsafe {
             // Maps the first range.
             page_table
-                .cursor_mut(&range1)
+                .cursor_mut(&preempt_guard, &range1)
                 .unwrap()
                 .map(frame1.into(), page_property);

             // Maps the second range, overlapping with the first.
             page_table
-                .cursor_mut(&range2)
+                .cursor_mut(&preempt_guard, &range2)
                 .unwrap()
                 .map(frame2.clone().into(), page_property);
         }
@@ -433,11 +456,12 @@ mod overlapping_mappings {
         let range = (PAGE_SIZE + 512)..(PAGE_SIZE * 2 + 512);
         let page_property = PageProperty::new(PageFlags::RW, CachePolicy::Writeback);
         let frame = FrameAllocOptions::default().alloc_frame().unwrap();
+        let preempt_guard = disable_preempt();

         // Attempts to map an unaligned virtual address range (expected to panic).
         unsafe {
             page_table
-                .cursor_mut(&range)
+                .cursor_mut(&preempt_guard, &range)
                 .unwrap()
                 .map(frame.into(), page_property);
         }
@@ -452,6 +476,7 @@ mod tracked_mapping {
         let page_table = setup_page_table::();
         let range = PAGE_SIZE..(PAGE_SIZE * 2);
         let page_property = PageProperty::new(PageFlags::RW, CachePolicy::Writeback);
+        let preempt_guard = disable_preempt();

         // Allocates and maps a frame.
         let frame = FrameAllocOptions::default().alloc_frame().unwrap();
@@ -460,7 +485,7 @@
         unsafe {
             page_table
-                .cursor_mut(&range)
+                .cursor_mut(&preempt_guard, &range)
                 .unwrap()
                 .map(frame.into(), page_property); // frame is moved here
         }
@@ -474,7 +499,7 @@
         // Unmaps the range and verifies the returned item.
         let unmapped_item = unsafe {
             page_table
-                .cursor_mut(&range)
+                .cursor_mut(&preempt_guard, &range)
                 .unwrap()
                 .take_next(range.len())
         };
@@ -496,12 +521,13 @@
         let range = PAGE_SIZE..(PAGE_SIZE * 2);
         let initial_prop = PageProperty::new(PageFlags::RW, CachePolicy::Writeback);
         let new_prop = PageProperty::new(PageFlags::R, CachePolicy::Writeback);
+        let preempt_guard = disable_preempt();

         // Initial mapping.
         let initial_frame = FrameAllocOptions::default().alloc_frame().unwrap();
         unsafe {
             page_table
-                .cursor_mut(&range)
+                .cursor_mut(&preempt_guard, &range)
                 .unwrap()
                 .map(initial_frame.into(), initial_prop);
         }
@@ -513,7 +539,7 @@
         let new_frame = FrameAllocOptions::default().alloc_frame().unwrap();
         unsafe {
             page_table
-                .cursor_mut(&range)
+                .cursor_mut(&preempt_guard, &range)
                 .unwrap()
                 .map(new_frame.into(), new_prop);
         }
@@ -524,6 +550,8 @@
     #[ktest]
     fn user_copy_on_write() {
+        let preempt_guard = disable_preempt();
+
         // Modifies page properties by removing the write flag.
         fn remove_write_flag(prop: &mut PageProperty) {
             prop.flags -= PageFlags::W;
@@ -542,7 +570,7 @@
         unsafe {
             page_table
-                .cursor_mut(&range)
+                .cursor_mut(&preempt_guard, &range)
                 .unwrap()
                 .map(frame.into(), page_property); // Original frame moved here
         }
@@ -557,8 +585,10 @@
         let child_pt = setup_page_table::();
         {
             let parent_range = 0..MAX_USERSPACE_VADDR;
-            let mut child_cursor = child_pt.cursor_mut(&parent_range).unwrap();
-            let mut parent_cursor = page_table.cursor_mut(&parent_range).unwrap();
+            let mut child_cursor = child_pt.cursor_mut(&preempt_guard, &parent_range).unwrap();
+            let mut parent_cursor = page_table
+                .cursor_mut(&preempt_guard, &parent_range)
+                .unwrap();
             unsafe {
                 child_cursor.copy_from(
                     &mut parent_cursor,
@@ -581,7 +611,7 @@
         // Unmaps the range from the parent and verifies.
         let unmapped_parent = unsafe {
             page_table
-                .cursor_mut(&range)
+                .cursor_mut(&preempt_guard, &range)
                 .unwrap()
                 .take_next(range.len())
         };
@@ -603,8 +633,12 @@
         let sibling_pt = setup_page_table::();
         {
             let parent_range = 0..MAX_USERSPACE_VADDR;
-            let mut sibling_cursor = sibling_pt.cursor_mut(&parent_range).unwrap();
-            let mut parent_cursor = page_table.cursor_mut(&parent_range).unwrap();
+            let mut sibling_cursor = sibling_pt
+                .cursor_mut(&preempt_guard, &parent_range)
+                .unwrap();
+            let mut parent_cursor = page_table
+                .cursor_mut(&preempt_guard, &parent_range)
+                .unwrap();
             unsafe {
                 sibling_cursor.copy_from(
                     &mut parent_cursor,
@@ -627,7 +661,12 @@
         );

         // Unmaps the range from the child and verifies.
-        let unmapped_child = unsafe { child_pt.cursor_mut(&range).unwrap().take_next(range.len()) };
+        let unmapped_child = unsafe {
+            child_pt
+                .cursor_mut(&preempt_guard, &range)
+                .unwrap()
+                .take_next(range.len())
+        };
         assert_item_is_tracked_frame(
             unmapped_child,
             range.start,
@@ -640,7 +679,7 @@
         let sibling_prop_final = PageProperty::new(PageFlags::RW, CachePolicy::Writeback);
         unsafe {
             sibling_pt
-                .cursor_mut(&range)
+                .cursor_mut(&preempt_guard, &range)
                 .unwrap()
                 .map(frame_clone_for_assert3.into(), sibling_prop_final);
         }
@@ -662,6 +701,8 @@ mod untracked_mapping {
     #[ktest]
     fn untracked_map_unmap() {
         let kernel_pt = setup_page_table::();
+        let preempt_guard = disable_preempt();
+
         const UNTRACKED_OFFSET: usize = LINEAR_MAPPING_BASE_VADDR;

         let from_ppn = 13245..(512 * 512 + 23456);
@@ -692,7 +733,9 @@
         let unmap_va_range = unmap_va_start..(unmap_va_start + PAGE_SIZE);
         let unmap_len = PAGE_SIZE;

-        let mut cursor = kernel_pt.cursor_mut(&unmap_va_range).unwrap();
+        let mut cursor = kernel_pt
+            .cursor_mut(&preempt_guard, &unmap_va_range)
+            .unwrap();
         assert_eq!(cursor.virt_addr(), unmap_va_range.start);

         // Unmaps the single page.
@@ -731,6 +774,8 @@
     #[ktest]
     fn untracked_large_protect_query() {
         let kernel_pt = PageTable::::empty();
+        let preempt_guard = disable_preempt();
+
         const UNTRACKED_OFFSET: usize = crate::mm::kspace::LINEAR_MAPPING_BASE_VADDR;
         let gmult = 512 * 512;
         let from_ppn = gmult - 512..gmult + gmult + 514;
@@ -741,7 +786,11 @@
         let mapped_pa_of_va = |va: Vaddr| va - (from.start - to.start);
         let prop = PageProperty::new(PageFlags::RW, CachePolicy::Writeback);
         map_range(&kernel_pt, from.clone(), to.clone(), prop);
-        for (item, i) in kernel_pt.cursor(&from).unwrap().zip(0..512 + 2 + 2) {
+        for (item, i) in kernel_pt
+            .cursor(&preempt_guard, &from)
+            .unwrap()
+            .zip(0..512 + 2 + 2)
+        {
             let PageTableItem::MappedUntracked { va, pa, len, prop } = item else {
                 panic!("Expected MappedUntracked, got {:#x?}", item);
             };
@@ -771,7 +820,7 @@
         // Checks the page before the protection range.
         let va_before = protect_va_range.start - PAGE_SIZE;
         let item_before = kernel_pt
-            .cursor(&(va_before..va_before + PAGE_SIZE))
+            .cursor(&preempt_guard, &(va_before..va_before + PAGE_SIZE))
             .unwrap()
             .next()
             .unwrap();
@@ -785,7 +834,7 @@
         // Checks pages within the protection range.
         for (item, i) in kernel_pt
-            .cursor(&protect_va_range)
+            .cursor(&preempt_guard, &protect_va_range)
             .unwrap()
             .zip(protect_ppn_range.clone())
         {
@@ -801,7 +850,7 @@
         // Checks the page after the protection range.
         let va_after = protect_va_range.end;
         let item_after = kernel_pt
-            .cursor(&(va_after..va_after + PAGE_SIZE))
+            .cursor(&preempt_guard, &(va_after..va_after + PAGE_SIZE))
             .unwrap()
             .next()
             .unwrap();
@@ -826,6 +875,7 @@
         let page_table = setup_page_table::();
         let range = 0..(PAGE_SIZE * 100);
         let page_property = PageProperty::new(PageFlags::RW, CachePolicy::Writeback);
+        let preempt_guard = disable_preempt();

         // Allocates and maps multiple frames.
         let frames = FrameAllocOptions::default()
             .alloc_segment_with(100, |_| ())
             .unwrap();

         unsafe {
-            let mut cursor = page_table.cursor_mut(&range).unwrap();
+            let mut cursor = page_table.cursor_mut(&preempt_guard, &range).unwrap();
             for frame in frames {
                 cursor.map(frame.into(), page_property); // Original frames moved here
             }
@@ -846,7 +896,7 @@
         // Unmaps the entire range.
         unsafe {
-            let mut cursor = page_table.cursor_mut(&range).unwrap();
+            let mut cursor = page_table.cursor_mut(&preempt_guard, &range).unwrap();
             for _ in (range.start..range.end).step_by(PAGE_SIZE) {
                 cursor.take_next(PAGE_SIZE);
             }
@@ -867,6 +917,7 @@
         let page_table = setup_page_table::();
         let from_ppn = 1..1000;
         let virtual_range = PAGE_SIZE * from_ppn.start..PAGE_SIZE * from_ppn.end;
+        let preempt_guard = disable_preempt();

         // Allocates and maps multiple frames.
         let frames = FrameAllocOptions::default()
@@ -875,7 +926,9 @@
         let page_property = PageProperty::new(PageFlags::RW, CachePolicy::Writeback);
         unsafe {
-            let mut cursor = page_table.cursor_mut(&virtual_range).unwrap();
+            let mut cursor = page_table
+                .cursor_mut(&preempt_guard, &virtual_range)
+                .unwrap();
             for frame in frames {
                 cursor.map(frame.into(), page_property); // frames are moved here
             }
@@ -914,9 +967,10 @@
     fn test_protect_next_empty_entry() {
         let page_table = PageTable::::empty();
         let range = 0x1000..0x2000;
+        let preempt_guard = disable_preempt();

         // Attempts to protect an empty range.
-        let mut cursor = page_table.cursor_mut(&range).unwrap();
+        let mut cursor = page_table.cursor_mut(&preempt_guard, &range).unwrap();
         let result =
             unsafe { cursor.protect_next(range.len(), &mut |prop| prop.flags = PageFlags::R) };
@@ -928,6 +982,7 @@
     fn test_protect_next_child_table_with_children() {
         let page_table = setup_page_table::();
         let range = 0x1000..0x3000; // Range potentially spanning intermediate tables
+        let preempt_guard = disable_preempt();

         // Maps a page within the range to create necessary intermediate tables.
         let map_range_inner = 0x1000..0x2000;
         let frame_inner = FrameAllocOptions::default().alloc_frame().unwrap();
         let page_property = PageProperty::new(PageFlags::RW, CachePolicy::Writeback);
         unsafe {
             page_table
-                .cursor_mut(&map_range_inner)
+                .cursor_mut(&preempt_guard, &map_range_inner)
                 .unwrap()
                 .map(frame_inner.into(), page_property);
         }

         // Attempts to protect the larger range. protect_next should traverse.
-        let mut cursor = page_table.cursor_mut(&range).unwrap();
+        let mut cursor = page_table.cursor_mut(&preempt_guard, &range).unwrap();
         let result =
             unsafe { cursor.protect_next(range.len(), &mut |prop| prop.flags = PageFlags::R) };
diff --git a/ostd/src/mm/test.rs b/ostd/src/mm/test.rs
index efc80a26..d6599102 100644
--- a/ostd/src/mm/test.rs
+++ b/ostd/src/mm/test.rs
@@ -14,6 +14,7 @@ use crate::{
         UFrame, VmSpace,
     },
     prelude::*,
+    task::disable_preempt,
    Error,
 };

@@ -508,7 +509,10 @@ mod vmspace {
     fn vmspace_creation() {
         let vmspace = VmSpace::new();
         let range = 0x0..0x1000;
-        let mut cursor = vmspace.cursor(&range).expect("Failed to create cursor");
+        let preempt_guard = disable_preempt();
+        let mut cursor = vmspace
+            .cursor(&preempt_guard, &range)
+            .expect("Failed to create cursor");
         assert_eq!(
             cursor.next(),
             Some(VmItem::NotMapped { va: 0, len: 0x1000 })
@@ -522,10 +526,11 @@ mod vmspace {
         let range = 0x1000..0x2000;
         let frame = create_dummy_frame();
         let prop = PageProperty::new(PageFlags::R, CachePolicy::Writeback);
+        let preempt_guard = disable_preempt();

         {
             let mut cursor_mut = vmspace
-                .cursor_mut(&range)
+                .cursor_mut(&preempt_guard, &range)
                 .expect("Failed to create mutable cursor");
             // Initially, the page should not be mapped.
             assert_eq!(
@@ -541,7 +546,9 @@ mod vmspace {
         // Queries the mapping.
         {
-            let mut cursor = vmspace.cursor(&range).expect("Failed to create cursor");
+            let mut cursor = vmspace
+                .cursor(&preempt_guard, &range)
+                .expect("Failed to create cursor");
             assert_eq!(cursor.virt_addr(), range.start);
             assert_eq!(
                 cursor.query().unwrap(),
@@ -555,14 +562,16 @@ mod vmspace {

         {
             let mut cursor_mut = vmspace
-                .cursor_mut(&range)
+                .cursor_mut(&preempt_guard, &range)
                 .expect("Failed to create mutable cursor");
             // Unmaps the frame.
             cursor_mut.unmap(range.start);
         }

         // Queries to ensure it's unmapped.
-        let mut cursor = vmspace.cursor(&range).expect("Failed to create cursor");
+        let mut cursor = vmspace
+            .cursor(&preempt_guard, &range)
+            .expect("Failed to create cursor");
         assert_eq!(
             cursor.query().unwrap(),
             VmItem::NotMapped {
@@ -579,16 +588,19 @@ mod vmspace {
         let range = 0x1000..0x2000;
         let frame = create_dummy_frame();
         let prop = PageProperty::new(PageFlags::R, CachePolicy::Writeback);
+        let preempt_guard = disable_preempt();

         {
             let mut cursor_mut = vmspace
-                .cursor_mut(&range)
+                .cursor_mut(&preempt_guard, &range)
                 .expect("Failed to create mutable cursor");
             cursor_mut.map(frame.clone(), prop);
         }

         {
-            let mut cursor = vmspace.cursor(&range).expect("Failed to create cursor");
+            let mut cursor = vmspace
+                .cursor(&preempt_guard, &range)
+                .expect("Failed to create cursor");
             assert_eq!(
                 cursor.query().unwrap(),
                 VmItem::Mapped {
@@ -601,13 +613,15 @@ mod vmspace {

         {
             let mut cursor_mut = vmspace
-                .cursor_mut(&range)
+                .cursor_mut(&preempt_guard, &range)
                 .expect("Failed to create mutable cursor");
             cursor_mut.map(frame.clone(), prop);
         }

         {
-            let mut cursor = vmspace.cursor(&range).expect("Failed to create cursor");
+            let mut cursor = vmspace
+                .cursor(&preempt_guard, &range)
+                .expect("Failed to create cursor");
             assert_eq!(
                 cursor.query().unwrap(),
                 VmItem::Mapped {
@@ -620,12 +634,14 @@ mod vmspace {

         {
             let mut cursor_mut = vmspace
-                .cursor_mut(&range)
+                .cursor_mut(&preempt_guard, &range)
                 .expect("Failed to create mutable cursor");
             cursor_mut.unmap(range.start);
         }

-        let mut cursor = vmspace.cursor(&range).expect("Failed to create cursor");
+        let mut cursor = vmspace
+            .cursor(&preempt_guard, &range)
+            .expect("Failed to create cursor");
         assert_eq!(
             cursor.query().unwrap(),
             VmItem::NotMapped {
@@ -642,29 +658,32 @@ mod vmspace {
         let range = 0x1000..0x2000;
         let frame = create_dummy_frame();
         let prop = PageProperty::new(PageFlags::R, CachePolicy::Writeback);
+        let preempt_guard = disable_preempt();

         {
             let mut cursor_mut = vmspace
-                .cursor_mut(&range)
+                .cursor_mut(&preempt_guard, &range)
                 .expect("Failed to create mutable cursor");
             cursor_mut.map(frame.clone(), prop);
         }

         {
             let mut cursor_mut = vmspace
-                .cursor_mut(&range)
+                .cursor_mut(&preempt_guard, &range)
                 .expect("Failed to create mutable cursor");
             cursor_mut.unmap(range.start);
         }

         {
             let mut cursor_mut = vmspace
-                .cursor_mut(&range)
+                .cursor_mut(&preempt_guard, &range)
                 .expect("Failed to create mutable cursor");
             cursor_mut.unmap(range.start);
         }

-        let mut cursor = vmspace.cursor(&range).expect("Failed to create cursor");
+        let mut cursor = vmspace
+            .cursor(&preempt_guard, &range)
+            .expect("Failed to create cursor");
         assert_eq!(
             cursor.query().unwrap(),
             VmItem::NotMapped {
@@ -696,17 +715,20 @@ mod vmspace {
         let range = 0x4000..0x5000;
         let frame = create_dummy_frame();
         let prop = PageProperty::new(PageFlags::R, CachePolicy::Writeback);
+        let preempt_guard = disable_preempt();

         {
             let mut cursor_mut = vmspace
-                .cursor_mut(&range)
+                .cursor_mut(&preempt_guard, &range)
                 .expect("Failed to create mutable cursor");
             cursor_mut.map(frame.clone(), prop);
         }

         {
             // Verifies that the mapping exists.
-            let mut cursor = vmspace.cursor(&range).expect("Failed to create cursor");
+            let mut cursor = vmspace
+                .cursor(&preempt_guard, &range)
+                .expect("Failed to create cursor");
             assert_eq!(
                 cursor.next(),
                 Some(VmItem::Mapped {
@@ -720,7 +742,7 @@ mod vmspace {

         {
             // Flushes the TLB using a mutable cursor.
             let mut cursor_mut = vmspace
-                .cursor_mut(&range)
+                .cursor_mut(&preempt_guard, &range)
                 .expect("Failed to create mutable cursor");
             cursor_mut.flusher().issue_tlb_flush(TlbFlushOp::All);
             cursor_mut.flusher().dispatch_tlb_flush();
@@ -728,7 +750,9 @@ mod vmspace {

         {
             // Verifies that the mapping still exists.
-            let mut cursor = vmspace.cursor(&range).expect("Failed to create cursor");
+            let mut cursor = vmspace
+                .cursor(&preempt_guard, &range)
+                .expect("Failed to create cursor");
             assert_eq!(
                 cursor.next(),
                 Some(VmItem::Mapped {
@@ -745,9 +769,10 @@ mod vmspace {
     fn vmspace_reader_writer() {
         let vmspace = Arc::new(VmSpace::new());
         let range = 0x4000..0x5000;
+        let preempt_guard = disable_preempt();
         {
             let mut cursor_mut = vmspace
-                .cursor_mut(&range)
+                .cursor_mut(&preempt_guard, &range)
                 .expect("Failed to create mutable cursor");
             let frame = create_dummy_frame();
             let prop = PageProperty::new(PageFlags::R, CachePolicy::Writeback);
@@ -790,14 +815,15 @@ mod vmspace {
         let vmspace = VmSpace::new();
         let range1 = 0x5000..0x6000;
         let range2 = 0x5800..0x6800; // Overlaps with range1.
+        let preempt_guard = disable_preempt();

         // Creates the first cursor.
         let _cursor1 = vmspace
-            .cursor(&range1)
+            .cursor(&preempt_guard, &range1)
             .expect("Failed to create first cursor");

         // Attempts to create the second overlapping cursor.
-        let cursor2_result = vmspace.cursor(&range2);
+        let cursor2_result = vmspace.cursor(&preempt_guard, &range2);
         assert!(cursor2_result.is_err());
     }

@@ -807,15 +833,18 @@ mod vmspace {
         let vmspace = VmSpace::new();
         let range = 0x6000..0x7000;
         let frame = create_dummy_frame();
+        let preempt_guard = disable_preempt();

         {
             let mut cursor_mut = vmspace
-                .cursor_mut(&range)
+                .cursor_mut(&preempt_guard, &range)
                 .expect("Failed to create mutable cursor");
             let prop = PageProperty::new(PageFlags::R, CachePolicy::Writeback);
             cursor_mut.map(frame.clone(), prop);
         }

-        let mut cursor = vmspace.cursor(&range).expect("Failed to create cursor");
+        let mut cursor = vmspace
+            .cursor(&preempt_guard, &range)
+            .expect("Failed to create cursor");
         assert!(cursor.jump(range.start).is_ok());
         let item = cursor.next();
         assert_eq!(
@@ -837,9 +866,10 @@ mod vmspace {
         let vmspace = VmSpace::new();
         let range = 0x7000..0x8000;
         let frame = create_dummy_frame();
+        let preempt_guard = disable_preempt();
         {
             let mut cursor_mut = vmspace
-                .cursor_mut(&range)
+                .cursor_mut(&preempt_guard, &range)
                 .expect("Failed to create mutable cursor");
             let prop = PageProperty::new(PageFlags::RW, CachePolicy::Writeback);
             cursor_mut.map(frame.clone(), prop);
@@ -851,7 +881,9 @@ mod vmspace {
             assert_eq!(protected_range, Some(0x7000..0x8000));
         }

         // Confirms that the property was updated.
-        let mut cursor = vmspace.cursor(&range).expect("Failed to create cursor");
+        let mut cursor = vmspace
+            .cursor(&preempt_guard, &range)
+            .expect("Failed to create cursor");
         assert_eq!(
             cursor.next(),
             Some(VmItem::Mapped {
@@ -868,8 +900,9 @@ mod vmspace {
     fn unaligned_unmap_panics() {
         let vmspace = VmSpace::new();
         let range = 0xA000..0xB000;
+        let preempt_guard = disable_preempt();
         let mut cursor_mut = vmspace
-            .cursor_mut(&range)
+            .cursor_mut(&preempt_guard, &range)
             .expect("Failed to create mutable cursor");
         cursor_mut.unmap(0x800); // Not page-aligned.
     }

@@ -880,8 +913,9 @@ mod vmspace {
     fn protect_out_range_page() {
         let vmspace = VmSpace::new();
         let range = 0xB000..0xC000;
+        let preempt_guard = disable_preempt();
         let mut cursor_mut = vmspace
-            .cursor_mut(&range)
+            .cursor_mut(&preempt_guard, &range)
             .expect("Failed to create mutable cursor");
         cursor_mut.protect_next(0x2000, |_| {}); // Not page-aligned.
     }
diff --git a/ostd/src/mm/vm_space.rs b/ostd/src/mm/vm_space.rs
index 3ff87e59..a8681c57 100644
--- a/ostd/src/mm/vm_space.rs
+++ b/ostd/src/mm/vm_space.rs
@@ -23,7 +23,7 @@ use crate::{
         PageProperty, UFrame, VmReader, VmWriter, MAX_USERSPACE_VADDR,
     },
     prelude::*,
-    task::{disable_preempt, DisabledPreemptGuard},
+    task::{atomic_mode::AsAtomicModeGuard, disable_preempt, DisabledPreemptGuard},
     Error,
 };

@@ -85,8 +85,12 @@ impl VmSpace {
     ///
     /// The creation of the cursor may block if another cursor having an
     /// overlapping range is alive.
-    pub fn cursor(&self, va: &Range) -> Result> {
-        Ok(self.pt.cursor(va).map(Cursor)?)
+    pub fn cursor<'a, G: AsAtomicModeGuard>(
+        &'a self,
+        guard: &'a G,
+        va: &Range,
+    ) -> Result> {
+        Ok(self.pt.cursor(guard, va).map(Cursor)?)
     }

     /// Gets an mutable cursor in the virtual address range.
@@ -99,8 +103,12 @@ impl VmSpace {
     /// The creation of the cursor may block if another cursor having an
     /// overlapping range is alive. The modification to the mapping by the
     /// cursor may also block or be overridden the mapping of another cursor.
-    pub fn cursor_mut(&self, va: &Range) -> Result> {
-        Ok(self.pt.cursor_mut(va).map(|pt_cursor| CursorMut {
+    pub fn cursor_mut<'a, G: AsAtomicModeGuard>(
+        &'a self,
+        guard: &'a G,
+        va: &Range,
+    ) -> Result> {
+        Ok(self.pt.cursor_mut(guard, va).map(|pt_cursor| CursorMut {
             pt_cursor,
             flusher: TlbFlusher::new(&self.cpus, disable_preempt()),
         })?)
@@ -228,14 +236,14 @@ impl Cursor<'_> {
 ///
 /// It exclusively owns a sub-tree of the page table, preventing others from
 /// reading or modifying the same sub-tree.
-pub struct CursorMut<'a, 'b> {
+pub struct CursorMut<'a> {
     pt_cursor: page_table::CursorMut<'a, UserMode, PageTableEntry, PagingConsts>,
     // We have a read lock so the CPU set in the flusher is always a superset
     // of actual activated CPUs.
-    flusher: TlbFlusher<'b, DisabledPreemptGuard>,
+    flusher: TlbFlusher<'a, DisabledPreemptGuard>,
 }

-impl<'b> CursorMut<'_, 'b> {
+impl<'a> CursorMut<'a> {
     /// Query about the current slot.
     ///
     /// This is the same as [`Cursor::query`].
@@ -262,7 +270,7 @@ impl<'b> CursorMut<'_, 'b> {
     }

     /// Get the dedicated TLB flusher for this cursor.
-    pub fn flusher(&mut self) -> &mut TlbFlusher<'b, DisabledPreemptGuard> {
+    pub fn flusher(&mut self) -> &mut TlbFlusher<'a, DisabledPreemptGuard> {
         &mut self.flusher
     }
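
Reviewer note, not part of the patch: a minimal sketch of how a call site adopts the guard-taking cursor API after this change, using only `ostd` items that appear in the diff (`VmSpace`, `FrameAllocOptions`, `PageProperty`, `disable_preempt`). The function name and the address range are hypothetical; the point is that the preemption guard is created before the cursor and must outlive it, since the cursor borrows the guard.

use ostd::{
    mm::{CachePolicy, FrameAllocOptions, PageFlags, PageProperty, UFrame, VmSpace},
    task::disable_preempt,
};

/// Illustrative only: maps one page into a fresh `VmSpace` with the new API.
fn map_one_page_sketch() {
    let vm_space = VmSpace::new();
    let range = 0x1000..0x2000; // arbitrary page-aligned user range

    // Create the guard first; the cursor borrows it, so it must stay alive
    // until the cursor is dropped.
    let preempt_guard = disable_preempt();

    // An untyped frame used as a user frame (assumed conversion, mirroring the tests).
    let frame: UFrame = FrameAllocOptions::default().alloc_frame().unwrap().into();
    let prop = PageProperty::new(PageFlags::RW, CachePolicy::Writeback);

    let mut cursor_mut = vm_space
        .cursor_mut(&preempt_guard, &range)
        .expect("failed to create mutable cursor");
    cursor_mut.map(frame, prop);
    // `cursor_mut` is dropped here, then `preempt_guard`, re-enabling preemption.
}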