From b2b00bdfd2746f6a3f9f35cb12a45b149425198a Mon Sep 17 00:00:00 2001 From: Zhang Junyang Date: Mon, 4 Nov 2024 10:16:06 +0800 Subject: [PATCH] Lock-free cursor creation --- ostd/src/mm/page_table/cursor.rs | 59 ++++++++++++++++-------------- ostd/src/mm/page_table/node/mod.rs | 2 +- 2 files changed, 33 insertions(+), 28 deletions(-) diff --git a/ostd/src/mm/page_table/cursor.rs b/ostd/src/mm/page_table/cursor.rs index a92cd4519..561dc00e5 100644 --- a/ostd/src/mm/page_table/cursor.rs +++ b/ostd/src/mm/page_table/cursor.rs @@ -65,17 +65,18 @@ //! table cursor should add additional entry point checks to prevent these defined //! behaviors if they are not wanted. -use core::{any::TypeId, marker::PhantomData, ops::Range}; +use core::{any::TypeId, marker::PhantomData, mem::ManuallyDrop, ops::Range}; use align_ext::AlignExt; use super::{ page_size, pte_index, Child, Entry, KernelMode, PageTable, PageTableEntryTrait, PageTableError, - PageTableMode, PageTableNode, PagingConstsTrait, PagingLevel, UserMode, + PageTableMode, PageTableNode, PagingConstsTrait, PagingLevel, RawPageTableNode, UserMode, }; use crate::{ mm::{ kspace::should_map_as_tracked, + paddr_to_vaddr, page::{meta::MapTrackingStatus, DynPage}, Paddr, PageProperty, Vaddr, }, @@ -160,16 +161,8 @@ where return Err(PageTableError::UnalignedVaddr); } - // Create a guard array that only hold the root node lock. - let guards = core::array::from_fn(|i| { - if i == (C::NR_LEVELS - 1) as usize { - Some(pt.root.clone_shallow().lock()) - } else { - None - } - }); let mut cursor = Self { - guards, + guards: core::array::from_fn(|_| None), level: C::NR_LEVELS, guard_level: C::NR_LEVELS, va: va.start, @@ -178,35 +171,47 @@ where _phantom: PhantomData, }; + let mut cur_pt_addr = pt.root.paddr(); + // Go down and get proper locks. The cursor should hold a lock of a // page table node containing the virtual address range. // // While going down, previous guards of too-high levels will be released. 
loop { + let start_idx = pte_index::<C>(va.start, cursor.level); let level_too_high = { - let start_idx = pte_index::<C>(va.start, cursor.level); let end_idx = pte_index::<C>(va.end - 1, cursor.level); - start_idx == end_idx + cursor.level > 1 && start_idx == end_idx }; if !level_too_high { break; } - let entry = cursor.cur_entry(); - if !entry.is_node() { + let cur_pt_ptr = paddr_to_vaddr(cur_pt_addr) as *mut E; + // SAFETY: The pointer and index are valid since the root page table + // does not short-live it. The child page table node won't be + // recycled by another thread while we are using it. + let cur_pte = unsafe { cur_pt_ptr.add(start_idx).read() }; + if cur_pte.is_present() { + if cur_pte.is_last(cursor.level) { + break; + } else { + cur_pt_addr = cur_pte.paddr(); + } + } else { break; } - let Child::PageTable(child_pt) = entry.to_owned() else { - unreachable!("Already checked"); - }; - - cursor.push_level(child_pt.lock()); - - // Release the guard of the previous (upper) level. - cursor.guards[cursor.level as usize] = None; - cursor.guard_level -= 1; + cursor.level -= 1; } + // SAFETY: The address and level correspond to a child converted into + // a PTE and we clone it to get a new handle to the node. + let raw = unsafe { RawPageTableNode::<E, C>::from_raw_parts(cur_pt_addr, cursor.level) }; + let _inc_ref = ManuallyDrop::new(raw.clone_shallow()); + let lock = raw.lock(); + cursor.guards[cursor.level as usize - 1] = Some(lock); + cursor.guard_level = cursor.level; + Ok(cursor) } @@ -306,7 +311,7 @@ where /// This method requires locks acquired before calling it. The discarded /// level will be unlocked. fn pop_level(&mut self) { - self.guards[(self.level - 1) as usize] = None; + self.guards[self.level as usize - 1] = None; self.level += 1; // TODO: Drop page tables if page tables become empty. 
@@ -316,7 +321,7 @@ where fn push_level(&mut self, child_pt: PageTableNode<E, C>) { self.level -= 1; debug_assert_eq!(self.level, child_pt.level()); - self.guards[(self.level - 1) as usize] = Some(child_pt); + self.guards[self.level as usize - 1] = Some(child_pt); } fn should_map_as_tracked(&self) -> bool { @@ -326,7 +331,7 @@ where } fn cur_entry(&mut self) -> Entry<'_, E, C> { - let node = self.guards[(self.level - 1) as usize].as_mut().unwrap(); + let node = self.guards[self.level as usize - 1].as_mut().unwrap(); node.entry(pte_index::<C>(self.va, self.level)) } } diff --git a/ostd/src/mm/page_table/node/mod.rs b/ostd/src/mm/page_table/node/mod.rs index 80d8df824..f5ebf46ba 100644 --- a/ostd/src/mm/page_table/node/mod.rs +++ b/ostd/src/mm/page_table/node/mod.rs @@ -176,7 +176,7 @@ where /// The caller must ensure that the physical address is valid and points to /// a forgotten page table node. A forgotten page table node can only be /// restored once. The level must match the level of the page table node. - unsafe fn from_raw_parts(paddr: Paddr, level: PagingLevel) -> Self { + pub(super) unsafe fn from_raw_parts(paddr: Paddr, level: PagingLevel) -> Self { Self { raw: paddr, level,