From 3c9ab308e14f1e4cee9839ba3a6a6400c22b79ea Mon Sep 17 00:00:00 2001 From: Zhang Junyang Date: Wed, 12 Jun 2024 09:31:33 +0000 Subject: [PATCH] Add the protect functionality in the boot page table for TDX --- framework/aster-frame/src/arch/x86/mm/mod.rs | 7 -- .../aster-frame/src/arch/x86/tdx_guest.rs | 59 +++++++---- framework/aster-frame/src/arch/x86/trap.rs | 6 +- .../aster-frame/src/mm/page_table/boot_pt.rs | 97 +++++++++++++++++-- 4 files changed, 131 insertions(+), 38 deletions(-) diff --git a/framework/aster-frame/src/arch/x86/mm/mod.rs b/framework/aster-frame/src/arch/x86/mm/mod.rs index f46fe6d93..20ac40780 100644 --- a/framework/aster-frame/src/arch/x86/mm/mod.rs +++ b/framework/aster-frame/src/arch/x86/mm/mod.rs @@ -161,13 +161,6 @@ impl PageTableEntryTrait for PageTableEntry { let flags = PageTableFlags::PRESENT.bits() | PageTableFlags::WRITABLE.bits() | PageTableFlags::USER.bits(); - #[cfg(feature = "intel_tdx")] - let flags = flags - | parse_flags!( - prop.priv_flags.bits(), - PrivFlags::SHARED, - PageTableFlags::SHARED - ); Self(paddr & Self::PHYS_ADDR_MASK | flags) } diff --git a/framework/aster-frame/src/arch/x86/tdx_guest.rs b/framework/aster-frame/src/arch/x86/tdx_guest.rs index 3d9fa70bc..879d9df12 100644 --- a/framework/aster-frame/src/arch/x86/tdx_guest.rs +++ b/framework/aster-frame/src/arch/x86/tdx_guest.rs @@ -11,15 +11,12 @@ use tdx_guest::{ }; use trapframe::TrapFrame; -use crate::{ - arch::mm::PageTableFlags, - mm::{ - kspace::KERNEL_PAGE_TABLE, - paddr_to_vaddr, - page_prop::{CachePolicy, PageProperty, PrivilegedPageFlags as PrivFlags}, - page_table::PageTableError, - KERNEL_BASE_VADDR, KERNEL_END_VADDR, PAGE_SIZE, - }, +use crate::mm::{ + kspace::{BOOT_PAGE_TABLE, KERNEL_BASE_VADDR, KERNEL_END_VADDR, KERNEL_PAGE_TABLE}, + paddr_to_vaddr, + page_prop::{PageProperty, PrivilegedPageFlags as PrivFlags}, + page_table::PageTableError, + PAGE_SIZE, }; const SHARED_BIT: u8 = 51; @@ -416,16 +413,28 @@ pub unsafe fn unprotect_gpa_range(gpa: 
TdxGpa, page_num: usize) -> Result<(), Pa if gpa & PAGE_MASK != 0 { warn!("Misaligned address: {:x}", gpa); } - let vaddr = paddr_to_vaddr(gpa); + // Protect the page in the kernel page table. let pt = KERNEL_PAGE_TABLE.get().unwrap(); - pt.protect(&(vaddr..page_num * PAGE_SIZE), |prop| { - prop = PageProperty { + let protect_op = |prop: &mut PageProperty| { + *prop = PageProperty { flags: prop.flags, cache: prop.cache, priv_flags: prop.priv_flags | PrivFlags::SHARED, } - }) - .map_err(PageConvertError::PageTableError)?; + }; + let vaddr = paddr_to_vaddr(gpa); + pt.protect(&(vaddr..page_num * PAGE_SIZE), protect_op) + .map_err(PageConvertError::PageTableError)?; + // Protect the page in the boot page table if in the boot phase. + { + let mut boot_pt_lock = BOOT_PAGE_TABLE.lock(); + if let Some(boot_pt) = boot_pt_lock.as_mut() { + for i in 0..page_num { + let vaddr = paddr_to_vaddr(gpa + i * PAGE_SIZE); + boot_pt.protect_base_page(vaddr, protect_op); + } + } + } map_gpa( (gpa & (!PAGE_MASK)) as u64 | SHARED_MASK, (page_num * PAGE_SIZE) as u64, @@ -452,16 +461,28 @@ pub unsafe fn protect_gpa_range(gpa: TdxGpa, page_num: usize) -> Result<(), Page if gpa & !PAGE_MASK == 0 { warn!("Misaligned address: {:x}", gpa); } - let vaddr = paddr_to_vaddr(gpa); + // Protect the page in the kernel page table. let pt = KERNEL_PAGE_TABLE.get().unwrap(); - pt.protect(&(vaddr..page_num * PAGE_SIZE), |prop| { - prop = PageProperty { + let protect_op = |prop: &mut PageProperty| { + *prop = PageProperty { flags: prop.flags, cache: prop.cache, priv_flags: prop.priv_flags - PrivFlags::SHARED, } - }) - .map_err(PageConvertError::PageTableError)?; + }; + let vaddr = paddr_to_vaddr(gpa); + pt.protect(&(vaddr..page_num * PAGE_SIZE), protect_op) + .map_err(PageConvertError::PageTableError)?; + // Protect the page in the boot page table if in the boot phase. 
+ { + let mut boot_pt_lock = BOOT_PAGE_TABLE.lock(); + if let Some(boot_pt) = boot_pt_lock.as_mut() { + for i in 0..page_num { + let vaddr = paddr_to_vaddr(gpa + i * PAGE_SIZE); + boot_pt.protect_base_page(vaddr, protect_op); + } + } + } map_gpa((gpa & PAGE_MASK) as u64, (page_num * PAGE_SIZE) as u64) .map_err(PageConvertError::TdVmcallError)?; for i in 0..page_num { diff --git a/framework/aster-frame/src/arch/x86/trap.rs b/framework/aster-frame/src/arch/x86/trap.rs index c34c7e41a..0e0aad5f4 100644 --- a/framework/aster-frame/src/arch/x86/trap.rs +++ b/framework/aster-frame/src/arch/x86/trap.rs @@ -11,11 +11,7 @@ use tdx_guest::tdcall; use trapframe::TrapFrame; #[cfg(feature = "intel_tdx")] -use crate::arch::{ - cpu::VIRTUALIZATION_EXCEPTION, - mm::PageTableFlags, - tdx_guest::{handle_virtual_exception, TdxTrapFrame}, -}; +use crate::arch::{cpu::VIRTUALIZATION_EXCEPTION, tdx_guest::handle_virtual_exception}; use crate::{ cpu::{CpuException, PageFaultErrorCode, PAGE_FAULT}, cpu_local, diff --git a/framework/aster-frame/src/mm/page_table/boot_pt.rs b/framework/aster-frame/src/mm/page_table/boot_pt.rs index f59ac5764..96c82c9ee 100644 --- a/framework/aster-frame/src/mm/page_table/boot_pt.rs +++ b/framework/aster-frame/src/mm/page_table/boot_pt.rs @@ -10,8 +10,8 @@ use super::{pte_index, PageTableEntryTrait}; use crate::{ arch::mm::{PageTableEntry, PagingConsts}, mm::{ - paddr_to_vaddr, page::allocator::FRAME_ALLOCATOR, PageProperty, PagingConstsTrait, Vaddr, - PAGE_SIZE, + nr_subpage_per_huge, paddr_to_vaddr, page::allocator::FRAME_ALLOCATOR, PageProperty, + PagingConstsTrait, Vaddr, PAGE_SIZE, }, }; @@ -44,7 +44,15 @@ impl BootPageTable { } /// Maps a base page to a frame. + /// + /// # Panics + /// /// This function will panic if the page is already mapped. + /// + /// # Safety + /// + /// This function is unsafe because it can cause undefined behavior if the caller + /// maps a page in the kernel address space. 
pub unsafe fn map_base_page(&mut self, from: Vaddr, to: FrameNumber, prop: PageProperty) { let mut pt = self.root_pt; let mut level = C::NR_LEVELS; @@ -74,6 +82,67 @@ impl BootPageTable { unsafe { pte_ptr.write(E::new_frame(to * C::BASE_PAGE_SIZE, 1, prop)) }; } + /// Sets the protection of a base page. + /// + /// This function may split a huge page into base pages, causing page allocations + /// if the original mapping is a huge page. + /// + /// # Panics + /// + /// This function will panic if the page is not mapped. + /// + /// # Safety + /// + /// This function is unsafe because it can cause undefined behavior if the caller + /// protects a page in the kernel address space. + pub unsafe fn protect_base_page( + &mut self, + virt_addr: Vaddr, + mut op: impl FnMut(&mut PageProperty), + ) { + let mut pt = self.root_pt; + let mut level = C::NR_LEVELS; + // Walk to the last level of the page table. + while level > 1 { + let index = pte_index::<C>(virt_addr, level); + let pte_ptr = unsafe { (paddr_to_vaddr(pt * C::BASE_PAGE_SIZE) as *mut E).add(index) }; + let pte = unsafe { pte_ptr.read() }; + pt = if !pte.is_present() { + panic!("protecting an unmapped page in the boot page table"); + } else if pte.is_last(level) { + // Split the huge page. + let frame = self.alloc_frame(); + let huge_pa = pte.paddr(); + for i in 0..nr_subpage_per_huge::<C>() { + let nxt_ptr = + unsafe { (paddr_to_vaddr(frame * C::BASE_PAGE_SIZE) as *mut E).add(i) }; + unsafe { + nxt_ptr.write(E::new_frame( + huge_pa + i * C::BASE_PAGE_SIZE, + level - 1, + pte.prop(), + )) + }; + } + unsafe { pte_ptr.write(E::new_pt(frame * C::BASE_PAGE_SIZE)) }; + frame + } else { + pte.paddr() / C::BASE_PAGE_SIZE + }; + level -= 1; + } + // Do protection in the last level page table.
+ let index = pte_index::<C>(virt_addr, 1); + let pte_ptr = unsafe { (paddr_to_vaddr(pt * C::BASE_PAGE_SIZE) as *mut E).add(index) }; + let pte = unsafe { pte_ptr.read() }; + if !pte.is_present() { + panic!("protecting an unmapped page in the boot page table"); + } + let mut prop = pte.prop(); + op(&mut prop); + unsafe { pte_ptr.write(E::new_frame(pte.paddr(), 1, prop)) }; + } + fn alloc_frame(&mut self) -> FrameNumber { let frame = FRAME_ALLOCATOR.get().unwrap().lock().alloc(1).unwrap(); self.frames.push(frame); @@ -94,7 +163,7 @@ impl Drop for BootPageTable #[cfg(ktest)] #[ktest] -fn test_boot_pt() { +fn test_boot_pt_map_protect() { use super::page_walk; use crate::{ arch::mm::{PageTableEntry, PagingConsts}, @@ -113,20 +182,34 @@ let from1 = 0x1000; let to1 = 0x2; let prop1 = PageProperty::new(PageFlags::RW, CachePolicy::Writeback); - boot_pt.map_base_page(from1, to1, prop1); + unsafe { boot_pt.map_base_page(from1, to1, prop1) }; assert_eq!( unsafe { page_walk::<PageTableEntry, PagingConsts>(root_paddr, from1 + 1) }, Some((to1 * PAGE_SIZE + 1, prop1)) ); + unsafe { boot_pt.protect_base_page(from1, |prop| prop.flags = PageFlags::RX) }; + assert_eq!( + unsafe { page_walk::<PageTableEntry, PagingConsts>(root_paddr, from1 + 1) }, + Some(( + to1 * PAGE_SIZE + 1, + PageProperty::new(PageFlags::RX, CachePolicy::Writeback) + )) + ); let from2 = 0x2000; let to2 = 0x3; let prop2 = PageProperty::new(PageFlags::RX, CachePolicy::Uncacheable); - boot_pt.map_base_page(from2, to2, prop2); + unsafe { boot_pt.map_base_page(from2, to2, prop2) }; assert_eq!( unsafe { page_walk::<PageTableEntry, PagingConsts>(root_paddr, from2 + 2) }, Some((to2 * PAGE_SIZE + 2, prop2)) ); - - unsafe { boot_pt.retire() } + unsafe { boot_pt.protect_base_page(from2, |prop| prop.flags = PageFlags::RW) }; + assert_eq!( + unsafe { page_walk::<PageTableEntry, PagingConsts>(root_paddr, from2 + 2) }, + Some(( + to2 * PAGE_SIZE + 2, + PageProperty::new(PageFlags::RW, CachePolicy::Uncacheable) + )) + ); }