Add unsafe with trivial cleanups

This commit is contained in:
Ruihan Li 2025-05-07 23:26:38 +08:00 committed by Junyang Zhang
parent 3bc4424a5b
commit 3f8dbe6990
4 changed files with 66 additions and 54 deletions

View File

@ -78,8 +78,9 @@ impl RootTable {
return Err(ContextTableError::InvalidDeviceId); return Err(ContextTableError::InvalidDeviceId);
} }
self.get_or_create_context_table(device) let context_table = self.get_or_create_context_table(device);
.map(device, daddr, paddr)?; // SAFETY: The safety is upheld by the caller.
unsafe { context_table.map(device, daddr, paddr)? };
Ok(()) Ok(())
} }
@ -93,8 +94,8 @@ impl RootTable {
return Err(ContextTableError::InvalidDeviceId); return Err(ContextTableError::InvalidDeviceId);
} }
self.get_or_create_context_table(device) let context_table = self.get_or_create_context_table(device);
.unmap(device, daddr)?; context_table.unmap(device, daddr)?;
Ok(()) Ok(())
} }
@ -298,23 +299,26 @@ impl ContextTable {
if device.device >= 32 || device.function >= 8 { if device.device >= 32 || device.function >= 8 {
return Err(ContextTableError::InvalidDeviceId); return Err(ContextTableError::InvalidDeviceId);
} }
trace!( trace!(
"Mapping Daddr: {:x?} to Paddr: {:x?} for device: {:x?}", "Mapping Daddr: {:x?} to Paddr: {:x?} for device: {:x?}",
daddr, daddr,
paddr, paddr,
device device
); );
self.get_or_create_page_table(device)
.map( let from = daddr..daddr + PAGE_SIZE;
&(daddr..daddr + PAGE_SIZE), let to = paddr..paddr + PAGE_SIZE;
&(paddr..paddr + PAGE_SIZE), let prop = PageProperty {
PageProperty {
flags: PageFlags::RW, flags: PageFlags::RW,
cache: CachePolicy::Uncacheable, cache: CachePolicy::Uncacheable,
priv_flags: PrivFlags::empty(), priv_flags: PrivFlags::empty(),
}, };
)
.unwrap(); let pt = self.get_or_create_page_table(device);
// SAFETY: The safety is upheld by the caller.
unsafe { pt.map(&from, &to, prop).unwrap() };
Ok(()) Ok(())
} }
@ -322,16 +326,19 @@ impl ContextTable {
if device.device >= 32 || device.function >= 8 { if device.device >= 32 || device.function >= 8 {
return Err(ContextTableError::InvalidDeviceId); return Err(ContextTableError::InvalidDeviceId);
} }
trace!("Unmapping Daddr: {:x?} for device: {:x?}", daddr, device); trace!("Unmapping Daddr: {:x?} for device: {:x?}", daddr, device);
let pt = self.get_or_create_page_table(device); let pt = self.get_or_create_page_table(device);
let preempt_guard = disable_preempt(); let preempt_guard = disable_preempt();
let mut cursor = pt let mut cursor = pt
.cursor_mut(&preempt_guard, &(daddr..daddr + PAGE_SIZE)) .cursor_mut(&preempt_guard, &(daddr..daddr + PAGE_SIZE))
.unwrap(); .unwrap();
unsafe {
let result = cursor.take_next(PAGE_SIZE); // SAFETY: This unmaps a page from the context table, which is always safe.
debug_assert!(matches!(result, PageTableItem::MappedUntracked { .. })); let item = unsafe { cursor.take_next(PAGE_SIZE) };
} debug_assert!(matches!(item, PageTableItem::MappedUntracked { .. }));
Ok(()) Ok(())
} }
} }

View File

@ -30,32 +30,37 @@ pub unsafe fn map(daddr: Daddr, paddr: Paddr) -> Result<(), IommuError> {
let Some(table) = PAGE_TABLE.get() else { let Some(table) = PAGE_TABLE.get() else {
return Err(IommuError::NoIommu); return Err(IommuError::NoIommu);
}; };
// The page table of all devices is the same. So we can use any device ID. // The page table of all devices is the same. So we can use any device ID.
table let mut locked_table = table.lock();
.lock() // SAFETY: The safety is upheld by the caller.
.map(PciDeviceLocation::zero(), daddr, paddr) let res = unsafe { locked_table.map(PciDeviceLocation::zero(), daddr, paddr) };
.map_err(|err| match err {
context_table::ContextTableError::InvalidDeviceId => unreachable!(), match res {
context_table::ContextTableError::ModificationError(err) => { Ok(()) => Ok(()),
IommuError::ModificationError(err) Err(context_table::ContextTableError::InvalidDeviceId) => unreachable!(),
Err(context_table::ContextTableError::ModificationError(err)) => {
Err(IommuError::ModificationError(err))
}
} }
})
} }
pub fn unmap(daddr: Daddr) -> Result<(), IommuError> { pub fn unmap(daddr: Daddr) -> Result<(), IommuError> {
let Some(table) = PAGE_TABLE.get() else { let Some(table) = PAGE_TABLE.get() else {
return Err(IommuError::NoIommu); return Err(IommuError::NoIommu);
}; };
// The page table of all devices is the same. So we can use any device ID. // The page table of all devices is the same. So we can use any device ID.
table let mut locked_table = table.lock();
.lock() let res = locked_table.unmap(PciDeviceLocation::zero(), daddr);
.unmap(PciDeviceLocation::zero(), daddr)
.map_err(|err| match err { match res {
context_table::ContextTableError::InvalidDeviceId => unreachable!(), Ok(()) => Ok(()),
context_table::ContextTableError::ModificationError(err) => { Err(context_table::ContextTableError::InvalidDeviceId) => unreachable!(),
IommuError::ModificationError(err) Err(context_table::ContextTableError::ModificationError(err)) => {
Err(IommuError::ModificationError(err))
}
} }
})
} }
pub fn init() { pub fn init() {

View File

@ -124,19 +124,16 @@ pub struct PageTableEntry(usize);
/// Changing the level 4 page table is unsafe, because it's possible to violate memory safety by /// Changing the level 4 page table is unsafe, because it's possible to violate memory safety by
/// changing the page mapping. /// changing the page mapping.
pub unsafe fn activate_page_table(root_paddr: Paddr, root_pt_cache: CachePolicy) { pub unsafe fn activate_page_table(root_paddr: Paddr, root_pt_cache: CachePolicy) {
x86_64::registers::control::Cr3::write( let addr = PhysFrame::from_start_address(x86_64::PhysAddr::new(root_paddr as u64)).unwrap();
PhysFrame::from_start_address(x86_64::PhysAddr::new(root_paddr as u64)).unwrap(), let flags = match root_pt_cache {
match root_pt_cache {
CachePolicy::Writeback => x86_64::registers::control::Cr3Flags::empty(), CachePolicy::Writeback => x86_64::registers::control::Cr3Flags::empty(),
CachePolicy::Writethrough => { CachePolicy::Writethrough => x86_64::registers::control::Cr3Flags::PAGE_LEVEL_WRITETHROUGH,
x86_64::registers::control::Cr3Flags::PAGE_LEVEL_WRITETHROUGH CachePolicy::Uncacheable => x86_64::registers::control::Cr3Flags::PAGE_LEVEL_CACHE_DISABLE,
}
CachePolicy::Uncacheable => {
x86_64::registers::control::Cr3Flags::PAGE_LEVEL_CACHE_DISABLE
}
_ => panic!("unsupported cache policy for the root page table"), _ => panic!("unsupported cache policy for the root page table"),
}, };
);
// SAFETY: The safety is upheld by the caller.
unsafe { x86_64::registers::control::Cr3::write(addr, flags) };
} }
pub fn current_page_table_paddr() -> Paddr { pub fn current_page_table_paddr() -> Paddr {

View File

@ -231,13 +231,14 @@ unsafe fn dfs_release_lock<'rcu, E: PageTableEntryTrait, C: PagingConstsTrait>(
let child = cur_node.entry(i); let child = cur_node.entry(i);
match child.to_ref() { match child.to_ref() {
Child::PageTableRef(pt) => { Child::PageTableRef(pt) => {
// SAFETY: The caller ensures that the node is locked. // SAFETY: The caller ensures that the node is locked and the new guard is unique.
let child_node = unsafe { pt.make_guard_unchecked(guard) }; let child_node = unsafe { pt.make_guard_unchecked(guard) };
let child_node_va = cur_node_va + i * page_size::<C>(cur_level); let child_node_va = cur_node_va + i * page_size::<C>(cur_level);
let child_node_va_end = child_node_va + page_size::<C>(cur_level); let child_node_va_end = child_node_va + page_size::<C>(cur_level);
let va_start = va_range.start.max(child_node_va); let va_start = va_range.start.max(child_node_va);
let va_end = va_range.end.min(child_node_va_end); let va_end = va_range.end.min(child_node_va_end);
// SAFETY: The caller ensures that this sub-tree is locked. // SAFETY: The caller ensures that all the nodes in the sub-tree are locked and all
// guards are forgotten.
unsafe { dfs_release_lock(guard, child_node, child_node_va, va_start..va_end) }; unsafe { dfs_release_lock(guard, child_node, child_node_va, va_start..va_end) };
} }
Child::None | Child::Frame(_, _) | Child::Untracked(_, _, _) | Child::PageTable(_) => {} Child::None | Child::Frame(_, _) | Child::Untracked(_, _, _) | Child::PageTable(_) => {}
@ -273,9 +274,11 @@ pub(super) unsafe fn dfs_mark_stray_and_unlock<E: PageTableEntryTrait, C: Paging
let child = sub_tree.entry(i); let child = sub_tree.entry(i);
match child.to_ref() { match child.to_ref() {
Child::PageTableRef(pt) => { Child::PageTableRef(pt) => {
// SAFETY: The caller ensures that the node is locked. // SAFETY: The caller ensures that the node is locked and the new guard is unique.
let locked_pt = unsafe { pt.make_guard_unchecked(rcu_guard) }; let locked_pt = unsafe { pt.make_guard_unchecked(rcu_guard) };
dfs_mark_stray_and_unlock(rcu_guard, locked_pt); // SAFETY: The caller ensures that all the nodes in the sub-tree are locked and all
// guards are forgotten.
unsafe { dfs_mark_stray_and_unlock(rcu_guard, locked_pt) };
} }
Child::None | Child::Frame(_, _) | Child::Untracked(_, _, _) | Child::PageTable(_) => {} Child::None | Child::Frame(_, _) | Child::Untracked(_, _, _) | Child::PageTable(_) => {}
} }