Allow page table protectors to flush TLB entries precisely

Zhang Junyang 2024-08-12 08:11:45 +00:00 committed by Tate, Hongliang Tian
parent 9a6e1b03e3
commit 4844e7ca7c
8 changed files with 148 additions and 102 deletions
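In short, the commit replaces coarse "protect the whole range, then flush the whole range (or the entire TLB)" call sites with a loop that protects one mapping at a time and flushes only the address that actually changed. A condensed sketch of that loop, in the shape used by the new protect_flush_tlb and the reworked VmSpace code below (the cursor, the op closure, and the range are assumed to be set up by the caller; protect_next, virt_addr, and tlb_flush_addr are the APIs introduced or used in this commit):

// Protect one mapped page (or huge page) per iteration; `protect_next`
// returns exactly the range it modified, or `None` once nothing mapped
// remains in `range`.
// SAFETY: the caller must guarantee the protection cannot break memory safety.
while let Some(protected) =
    unsafe { cursor.protect_next(range.end - cursor.virt_addr(), &mut op) }
{
    // Flush only the TLB entry that was actually changed.
    crate::arch::mm::tlb_flush_addr(protected.start);
}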

View File

@ -538,7 +538,7 @@ impl VmMappingInner {
debug_assert!(range.start % PAGE_SIZE == 0);
debug_assert!(range.end % PAGE_SIZE == 0);
let mut cursor = vm_space.cursor_mut(&range).unwrap();
cursor.protect(range.len(), |p| p.flags = perms.into(), true)?;
cursor.protect(range.len(), |p| p.flags = perms.into());
Ok(())
}

View File

@ -39,8 +39,8 @@ pub unsafe fn unprotect_gpa_range(gpa: Paddr, page_num: usize) -> Result<(), Pag
if gpa & PAGE_MASK != 0 {
warn!("Misaligned address: {:x}", gpa);
}
// Protect the page in the kernel page table.
let pt = KERNEL_PAGE_TABLE.get().unwrap();
// Protect the page in the boot page table if in the boot phase.
let protect_op = |prop: &mut PageProperty| {
*prop = PageProperty {
flags: prop.flags,
@ -48,10 +48,6 @@ pub unsafe fn unprotect_gpa_range(gpa: Paddr, page_num: usize) -> Result<(), Pag
priv_flags: prop.priv_flags | PrivFlags::SHARED,
}
};
let vaddr = paddr_to_vaddr(gpa);
pt.protect(&(vaddr..vaddr + page_num * PAGE_SIZE), protect_op)
.map_err(|_| PageConvertError::PageTable)?;
// Protect the page in the boot page table if in the boot phase.
{
let mut boot_pt_lock = BOOT_PAGE_TABLE.lock();
if let Some(boot_pt) = boot_pt_lock.as_mut() {
@ -61,6 +57,12 @@ pub unsafe fn unprotect_gpa_range(gpa: Paddr, page_num: usize) -> Result<(), Pag
}
}
}
// Protect the page in the kernel page table.
let pt = KERNEL_PAGE_TABLE.get().unwrap();
let vaddr = paddr_to_vaddr(gpa);
pt.protect_flush_tlb(&(vaddr..vaddr + page_num * PAGE_SIZE), protect_op)
.map_err(|_| PageConvertError::PageTable)?;
map_gpa(
(gpa & (!PAGE_MASK)) as u64 | SHARED_MASK,
(page_num * PAGE_SIZE) as u64,
@ -82,8 +84,8 @@ pub unsafe fn protect_gpa_range(gpa: Paddr, page_num: usize) -> Result<(), PageC
if gpa & !PAGE_MASK == 0 {
warn!("Misaligned address: {:x}", gpa);
}
// Protect the page in the kernel page table.
let pt = KERNEL_PAGE_TABLE.get().unwrap();
// Protect the page in the boot page table if in the boot phase.
let protect_op = |prop: &mut PageProperty| {
*prop = PageProperty {
flags: prop.flags,
@ -91,10 +93,6 @@ pub unsafe fn protect_gpa_range(gpa: Paddr, page_num: usize) -> Result<(), PageC
priv_flags: prop.priv_flags - PrivFlags::SHARED,
}
};
let vaddr = paddr_to_vaddr(gpa);
pt.protect(&(vaddr..vaddr + page_num * PAGE_SIZE), protect_op)
.map_err(|_| PageConvertError::PageTable)?;
// Protect the page in the boot page table if in the boot phase.
{
let mut boot_pt_lock = BOOT_PAGE_TABLE.lock();
if let Some(boot_pt) = boot_pt_lock.as_mut() {
@ -104,6 +102,12 @@ pub unsafe fn protect_gpa_range(gpa: Paddr, page_num: usize) -> Result<(), PageC
}
}
}
// Protect the page in the kernel page table.
let pt = KERNEL_PAGE_TABLE.get().unwrap();
let vaddr = paddr_to_vaddr(gpa);
pt.protect_flush_tlb(&(vaddr..vaddr + page_num * PAGE_SIZE), protect_op)
.map_err(|_| PageConvertError::PageTable)?;
map_gpa((gpa & PAGE_MASK) as u64, (page_num * PAGE_SIZE) as u64)
.map_err(|_| PageConvertError::TdVmcall)?;
for i in 0..page_num {

View File

@ -7,7 +7,7 @@ use cfg_if::cfg_if;
use super::{check_and_insert_dma_mapping, remove_dma_mapping, DmaError, HasDaddr};
use crate::{
arch::{iommu, mm::tlb_flush_addr_range},
arch::iommu,
mm::{
dma::{dma_type, Daddr, DmaType},
io::VmIoOnce,
@ -71,10 +71,9 @@ impl DmaCoherent {
// SAFETY: the physical mapping is only used by DMA, so protecting it is safe.
unsafe {
page_table
.protect(&va_range, |p| p.cache = CachePolicy::Uncacheable)
.protect_flush_tlb(&va_range, |p| p.cache = CachePolicy::Uncacheable)
.unwrap();
}
tlb_flush_addr_range(&va_range);
}
let start_daddr = match dma_type() {
DmaType::Direct => {
@ -159,10 +158,9 @@ impl Drop for DmaCoherentInner {
// SAFETY: the physical mapping is only used by DMA, so protecting it is safe.
unsafe {
page_table
.protect(&va_range, |p| p.cache = CachePolicy::Writeback)
.protect_flush_tlb(&va_range, |p| p.cache = CachePolicy::Writeback)
.unwrap();
}
tlb_flush_addr_range(&va_range);
}
remove_dma_mapping(start_paddr, frame_count);
}
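This file is representative of most call sites in the commit: the separate tlb_flush_addr_range after protect disappears because the flush now happens inside the page table API. A rough before/after sketch, using the same page_table and va_range as in the hunks above:

// Before this commit (sketch): protect the range, then flush every page in it.
unsafe {
    page_table
        .protect(&va_range, |p| p.cache = CachePolicy::Uncacheable)
        .unwrap();
}
tlb_flush_addr_range(&va_range);

// After (sketch): one call; only the entries that were actually mapped get flushed.
unsafe {
    page_table
        .protect_flush_tlb(&va_range, |p| p.cache = CachePolicy::Uncacheable)
        .unwrap();
}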

View File

@ -629,34 +629,41 @@ where
PageTableItem::NotMapped { va: start, len }
}
/// Applies the given operation to all the mappings within the range.
/// Applies the operation to the next slot of mapping within the range.
///
/// The function will return an error if it is not allowed to protect an invalid range and
/// it does so, or if the range to be protected only covers a part of a page.
/// The range to search in starts at the cursor's current virtual address
/// and spans the provided length.
///
/// The function stops and yields the actually protected range once it has
/// protected a page, even if the following pages also need to be
/// protected.
///
/// It also moves the cursor forward to the next page after the
/// protected one. If no mapped pages exist in the following range, the
/// cursor will stop at the end of the range and return [`None`].
///
/// # Safety
///
/// The caller should ensure that the range being protected does not affect kernel's memory safety.
/// The caller should ensure that the range being protected with the
/// operation does not affect the kernel's memory safety.
///
/// # Panics
///
/// This function will panic if:
/// - the range to be protected is out of the range where the cursor is required to operate.
pub unsafe fn protect(
/// - the range to be protected is out of the range where the cursor
/// is required to operate;
/// - the specified virtual address range only covers a part of a page.
pub unsafe fn protect_next(
&mut self,
len: usize,
mut op: impl FnMut(&mut PageProperty),
allow_protect_absent: bool,
) -> Result<(), PageTableError> {
op: &mut impl FnMut(&mut PageProperty),
) -> Option<Range<Vaddr>> {
let end = self.0.va + len;
assert!(end <= self.0.barrier_va.end);
while self.0.va < end {
let cur_pte = self.0.read_cur_pte();
if !cur_pte.is_present() {
if !allow_protect_absent {
return Err(PageTableError::ProtectingAbsent);
}
self.0.move_forward();
continue;
}
@ -664,18 +671,33 @@ where
// Go down if it's not a last node.
if !cur_pte.is_last(self.0.level) {
self.0.level_down();
// We have gone down a level. If there are no mapped PTEs in
// the current node, we can go back up and skip it to save time.
if self.0.guards[(self.0.level - 1) as usize]
.as_ref()
.unwrap()
.nr_children()
== 0
{
self.0.level_up();
self.0.move_forward();
}
continue;
}
// Go down if the page size is too big and we are protecting part
// of untracked huge pages.
let vaddr_not_fit = self.0.va % page_size::<C>(self.0.level) != 0
|| self.0.va + page_size::<C>(self.0.level) > end;
if !self.0.in_tracked_range() && vaddr_not_fit {
self.level_down_split();
continue;
} else if vaddr_not_fit {
return Err(PageTableError::ProtectingPartial);
if self.0.va % page_size::<C>(self.0.level) != 0
|| self.0.va + page_size::<C>(self.0.level) > end
{
if self.0.in_tracked_range() {
panic!("protecting part of a huge page");
} else {
self.level_down_split();
continue;
}
}
let mut pte_prop = cur_pte.prop();
@ -683,10 +705,14 @@ where
let idx = self.0.cur_idx();
self.cur_node_mut().protect(idx, pte_prop);
let protected_va = self.0.va..self.0.va + page_size::<C>(self.0.level);
self.0.move_forward();
return Some(protected_va);
}
Ok(())
None
}
/// Consumes itself and leaks the root guard for the caller if it locked the root level.

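To make the return contract of protect_next concrete, here is a hypothetical walkthrough (the cursor, the op closure, and the addresses are illustrative only; it assumes 4 KiB base pages, a cursor created over 0x0..0x4000, and mappings present only at 0x0..0x2000):

// Each call protects at most one mapping and reports exactly what it touched.
let r1 = unsafe { cursor.protect_next(0x4000, &mut op) }; // Some(0x0000..0x1000)
let r2 = unsafe { cursor.protect_next(0x3000, &mut op) }; // Some(0x1000..0x2000)
// Nothing else is mapped: the cursor walks to the end of the range and yields None.
let r3 = unsafe { cursor.protect_next(0x2000, &mut op) }; // None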
View File

@ -3,9 +3,8 @@
use core::{fmt::Debug, marker::PhantomData, ops::Range};
use super::{
nr_subpage_per_huge,
page_prop::{PageFlags, PageProperty},
page_size, Paddr, PagingConstsTrait, PagingLevel, Vaddr,
nr_subpage_per_huge, page_prop::PageProperty, page_size, Paddr, PagingConstsTrait, PagingLevel,
Vaddr,
};
use crate::{
arch::mm::{PageTableEntry, PagingConsts},
@ -29,10 +28,6 @@ pub enum PageTableError {
InvalidVaddr(Vaddr),
/// Using a virtual address that is not aligned.
UnalignedVaddr,
/// Protecting a mapping that does not exist.
ProtectingAbsent,
/// Protecting a part of an already mapped page.
ProtectingPartial,
}
/// This is a compile-time technique to force the frame developers to distinguish
@ -98,24 +93,18 @@ impl PageTable<UserMode> {
}
}
/// Remove all write permissions from the user page table and create a cloned
/// new page table.
/// Create a cloned new page table.
///
/// This method takes a mutable cursor to the old page table that locks the
/// entire virtual address range. The caller may implement the copy-on-write
/// mechanism by first protecting the old page table and then cloning it using
/// this method.
///
/// TODO: We may consider making the page table itself copy-on-write.
pub fn fork_copy_on_write(&self) -> Self {
let mut cursor = self.cursor_mut(&UserMode::VADDR_RANGE).unwrap();
// SAFETY: Protecting the user page table is safe.
unsafe {
cursor
.protect(
UserMode::VADDR_RANGE.len(),
|p: &mut PageProperty| p.flags -= PageFlags::W,
true,
)
.unwrap();
};
pub fn clone_with(
&self,
cursor: CursorMut<'_, UserMode, PageTableEntry, PagingConsts>,
) -> Self {
let root_node = cursor.leak_root_guard().unwrap();
const NR_PTES_PER_NODE: usize = nr_subpage_per_huge::<PagingConsts>();
@ -177,6 +166,26 @@ impl PageTable<KernelMode> {
}
}
}
/// Protect the given virtual address range in the kernel page table.
///
/// This method flushes the TLB entries when doing protection.
///
/// # Safety
///
/// The caller must ensure that the protection operation does not affect
/// the memory safety of the kernel.
pub unsafe fn protect_flush_tlb(
&self,
vaddr: &Range<Vaddr>,
mut op: impl FnMut(&mut PageProperty),
) -> Result<(), PageTableError> {
let mut cursor = CursorMut::new(self, vaddr)?;
while let Some(range) = cursor.protect_next(vaddr.end - cursor.virt_addr(), &mut op) {
crate::arch::mm::tlb_flush_addr(range.start);
}
Ok(())
}
}
impl<'a, M: PageTableMode, E: PageTableEntryTrait, C: PagingConstsTrait> PageTable<M, E, C>
@ -213,17 +222,6 @@ where
Ok(())
}
pub unsafe fn protect(
&self,
vaddr: &Range<Vaddr>,
op: impl FnMut(&mut PageProperty),
) -> Result<(), PageTableError> {
self.cursor_mut(vaddr)?
.protect(vaddr.len(), op, true)
.unwrap();
Ok(())
}
/// Query about the mapping of a single byte at the given virtual address.
///
/// Note that this function may fail to reflect an accurate result if there are

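Putting the new pieces together, the copy-on-write fork that the clone_with comment alludes to looks roughly like this (a condensed sketch of VmSpace::fork_copy_on_write further down in this commit; pt is a PageTable<UserMode> and child_pt is just an illustrative name, while MAX_USERSPACE_VADDR, PageFlags, PageProperty, cursor_mut, protect_next, and clone_with are the names used in the commit):

// 1. Downgrade every mapped user page to read-only, flushing precisely.
let end = MAX_USERSPACE_VADDR;
let mut cursor = pt.cursor_mut(&(0..end)).unwrap();
let mut op = |prop: &mut PageProperty| prop.flags -= PageFlags::W;
// SAFETY: making user mappings read-only does not affect the kernel's memory safety.
while let Some(range) = unsafe { cursor.protect_next(end - cursor.virt_addr(), &mut op) } {
    crate::arch::mm::tlb_flush_addr(range.start);
}
// 2. Hand the same cursor (still holding the lock on the whole range) to
//    `clone_with`, so the child is cloned from the read-only state just created.
let child_pt = pt.clone_with(cursor);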
View File

@ -8,6 +8,7 @@ use crate::{
kspace::LINEAR_MAPPING_BASE_VADDR,
page::{allocator, meta::FrameMeta},
page_prop::{CachePolicy, PageFlags},
MAX_USERSPACE_VADDR,
},
prelude::*,
};
@ -95,7 +96,7 @@ fn test_user_copy_on_write() {
unsafe { pt.cursor_mut(&from).unwrap().map(page.clone().into(), prop) };
assert_eq!(pt.query(from.start + 10).unwrap().0, start_paddr + 10);
let child_pt = pt.fork_copy_on_write();
let child_pt = pt.clone_with(pt.cursor_mut(&(0..MAX_USERSPACE_VADDR)).unwrap());
assert_eq!(pt.query(from.start + 10).unwrap().0, start_paddr + 10);
assert_eq!(child_pt.query(from.start + 10).unwrap().0, start_paddr + 10);
assert!(matches!(
@ -105,7 +106,7 @@ fn test_user_copy_on_write() {
assert!(pt.query(from.start + 10).is_none());
assert_eq!(child_pt.query(from.start + 10).unwrap().0, start_paddr + 10);
let sibling_pt = pt.fork_copy_on_write();
let sibling_pt = pt.clone_with(pt.cursor_mut(&(0..MAX_USERSPACE_VADDR)).unwrap());
assert!(sibling_pt.query(from.start + 10).is_none());
assert_eq!(child_pt.query(from.start + 10).unwrap().0, start_paddr + 10);
drop(pt);
@ -139,6 +140,25 @@ impl PagingConstsTrait for BasePagingConsts {
const PTE_SIZE: usize = core::mem::size_of::<PageTableEntry>();
}
impl<M: PageTableMode, E: PageTableEntryTrait, C: PagingConstsTrait> PageTable<M, E, C>
where
[(); C::NR_LEVELS as usize]:,
{
fn protect(&self, range: &Range<Vaddr>, mut op: impl FnMut(&mut PageProperty)) {
let mut cursor = self.cursor_mut(range).unwrap();
loop {
unsafe {
if cursor
.protect_next(range.end - cursor.virt_addr(), &mut op)
.is_none()
{
break;
}
};
}
}
}
#[ktest]
fn test_base_protect_query() {
let pt = PageTable::<UserMode>::empty();
@ -162,7 +182,7 @@ fn test_base_protect_query() {
assert_eq!(va..va + page.size(), i * PAGE_SIZE..(i + 1) * PAGE_SIZE);
}
let prot = PAGE_SIZE * 18..PAGE_SIZE * 20;
unsafe { pt.protect(&prot, |p| p.flags -= PageFlags::W).unwrap() };
pt.protect(&prot, |p| p.flags -= PageFlags::W);
for (item, i) in pt.cursor(&prot).unwrap().zip(18..20) {
let PageTableItem::Mapped { va, page, prop } = item else {
panic!("Expected Mapped, got {:#x?}", item);
@ -225,7 +245,7 @@ fn test_untracked_large_protect_query() {
}
let ppn = from_ppn.start + 18..from_ppn.start + 20;
let va = UNTRACKED_OFFSET + PAGE_SIZE * ppn.start..UNTRACKED_OFFSET + PAGE_SIZE * ppn.end;
unsafe { pt.protect(&va, |p| p.flags -= PageFlags::W).unwrap() };
pt.protect(&va, |p| p.flags -= PageFlags::W);
for (item, i) in pt
.cursor(&(va.start - PAGE_SIZE..va.start))
.unwrap()

View File

@ -17,12 +17,12 @@ use super::{
io::UserSpace,
kspace::KERNEL_PAGE_TABLE,
page_table::{PageTable, UserMode},
PageProperty, VmReader, VmWriter,
PageFlags, PageProperty, VmReader, VmWriter,
};
use crate::{
arch::mm::{
current_page_table_paddr, tlb_flush_addr, tlb_flush_addr_range,
tlb_flush_all_excluding_global, PageTableEntry, PagingConsts,
current_page_table_paddr, tlb_flush_addr, tlb_flush_addr_range, PageTableEntry,
PagingConsts,
},
cpu::CpuExceptionInfo,
mm::{
@ -129,6 +129,18 @@ impl VmSpace {
/// read-only. Both VM spaces will take handles to the same
/// physical memory pages.
pub fn fork_copy_on_write(&self) -> Self {
// Protect the parent VM space as read-only.
let end = MAX_USERSPACE_VADDR;
let mut cursor = self.pt.cursor_mut(&(0..end)).unwrap();
let mut op = |prop: &mut PageProperty| {
prop.flags -= PageFlags::W;
};
// SAFETY: It is safe to protect memory in the userspace.
while let Some(range) = unsafe { cursor.protect_next(end - cursor.virt_addr(), &mut op) } {
tlb_flush_addr(range.start);
}
let page_fault_handler = {
let new_handler = Once::new();
if let Some(handler) = self.page_fault_handler.get() {
@ -136,12 +148,11 @@ impl VmSpace {
}
new_handler
};
let new_space = Self {
pt: self.pt.fork_copy_on_write(),
Self {
pt: self.pt.clone_with(cursor),
page_fault_handler,
};
tlb_flush_all_excluding_global();
new_space
}
}
/// Creates a reader to read data from the user space of the current task.
@ -319,22 +330,14 @@ impl CursorMut<'_> {
/// # Panics
///
/// This method will panic if `len` is not page-aligned.
pub fn protect(
&mut self,
len: usize,
op: impl FnMut(&mut PageProperty),
allow_protect_absent: bool,
) -> Result<()> {
pub fn protect(&mut self, len: usize, mut op: impl FnMut(&mut PageProperty)) {
assert!(len % super::PAGE_SIZE == 0);
let start_va = self.virt_addr();
let end_va = start_va + len;
let end = self.0.virt_addr() + len;
// SAFETY: It is safe to protect memory in the userspace.
let result = unsafe { self.0.protect(len, op, allow_protect_absent) };
tlb_flush_addr_range(&(start_va..end_va));
Ok(result?)
while let Some(range) = unsafe { self.0.protect_next(end - self.0.virt_addr(), &mut op) } {
tlb_flush_addr(range.start);
}
}
}

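For completeness, the reworked safe CursorMut::protect above is what the VmMapping call site at the top of this commit now relies on; a minimal usage sketch (vm_space and range are assumed to be set up elsewhere, with the range already mapped):

let mut cursor = vm_space.cursor_mut(&range).unwrap();
// No error to handle and no manual flush: unmapped pages are simply skipped,
// and each page that is actually protected gets its TLB entry flushed inside.
cursor.protect(range.len(), |p| p.flags -= PageFlags::W);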
View File

@ -16,7 +16,6 @@ use super::{
};
pub(crate) use crate::arch::task::{context_switch, TaskContext};
use crate::{
arch::mm::tlb_flush_addr_range,
cpu::CpuSet,
mm::{kspace::KERNEL_PAGE_TABLE, FrameAllocOptions, Paddr, PageFlags, Segment, PAGE_SIZE},
prelude::*,
@ -70,9 +69,8 @@ impl KernelStack {
unsafe {
let vaddr_range = guard_page_vaddr..guard_page_vaddr + PAGE_SIZE;
page_table
.protect(&vaddr_range, |p| p.flags -= PageFlags::RW)
.protect_flush_tlb(&vaddr_range, |p| p.flags -= PageFlags::RW)
.unwrap();
tlb_flush_addr_range(&vaddr_range);
}
Ok(Self {
segment: stack_segment,
@ -98,9 +96,8 @@ impl Drop for KernelStack {
unsafe {
let vaddr_range = guard_page_vaddr..guard_page_vaddr + PAGE_SIZE;
page_table
.protect(&vaddr_range, |p| p.flags |= PageFlags::RW)
.protect_flush_tlb(&vaddr_range, |p| p.flags |= PageFlags::RW)
.unwrap();
tlb_flush_addr_range(&vaddr_range);
}
}
}