diff --git a/Cargo.lock b/Cargo.lock index 32b84e77e..b1c21d916 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -109,6 +109,7 @@ dependencies = [ "buddy_system_allocator", "cfg-if", "gimli", + "iced-x86", "inherit-methods-macro", "int-to-c-enum", "intrusive-collections", @@ -725,6 +726,15 @@ dependencies = [ "stable_deref_trait", ] +[[package]] +name = "iced-x86" +version = "1.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c447cff8c7f384a7d4f741cfcff32f75f3ad02b406432e8d6c878d56b1edf6b" +dependencies = [ + "lazy_static", +] + [[package]] name = "ident_case" version = "1.0.1" diff --git a/framework/aster-frame/Cargo.toml b/framework/aster-frame/Cargo.toml index 5c33d6fb5..4262cb2bb 100644 --- a/framework/aster-frame/Cargo.toml +++ b/framework/aster-frame/Cargo.toml @@ -36,6 +36,7 @@ acpi = "4.1.1" aml = "0.16.3" multiboot2 = "0.16.0" rsdp = "2.0.0" +iced-x86 = { version = "1.21.0", default-features = false, features = [ "no_std", "decoder", "gas" ], optional = true } [features] -intel_tdx = ["dep:tdx-guest"] +intel_tdx = ["dep:tdx-guest", "dep:iced-x86"] diff --git a/framework/aster-frame/src/arch/x86/cpu.rs b/framework/aster-frame/src/arch/x86/cpu.rs index 6dffb3ba4..16997064a 100644 --- a/framework/aster-frame/src/arch/x86/cpu.rs +++ b/framework/aster-frame/src/arch/x86/cpu.rs @@ -155,6 +155,60 @@ impl TdxTrapFrame for GeneralRegs { fn set_rip(&mut self, rip: usize) { self.rip = rip; } + fn r8(&self) -> usize { + self.r8 + } + fn set_r8(&mut self, r8: usize) { + self.r8 = r8; + } + fn r9(&self) -> usize { + self.r9 + } + fn set_r9(&mut self, r9: usize) { + self.r9 = r9; + } + fn r10(&self) -> usize { + self.r10 + } + fn set_r10(&mut self, r10: usize) { + self.r10 = r10; + } + fn r11(&self) -> usize { + self.r11 + } + fn set_r11(&mut self, r11: usize) { + self.r11 = r11; + } + fn r12(&self) -> usize { + self.r12 + } + fn set_r12(&mut self, r12: usize) { + self.r12 = r12; + } + fn r13(&self) -> usize { + self.r13 + } + fn 
set_r13(&mut self, r13: usize) { + self.r13 = r13; + } + fn r14(&self) -> usize { + self.r14 + } + fn set_r14(&mut self, r14: usize) { + self.r14 = r14; + } + fn r15(&self) -> usize { + self.r15 + } + fn set_r15(&mut self, r15: usize) { + self.r15 = r15; + } + fn rbp(&self) -> usize { + self.rbp + } + fn set_rbp(&mut self, rbp: usize) { + self.rbp = rbp; + } } impl UserContext { diff --git a/framework/aster-frame/src/arch/x86/kernel/apic/ioapic.rs b/framework/aster-frame/src/arch/x86/kernel/apic/ioapic.rs index 0834a16dc..ae0ddbb5a 100644 --- a/framework/aster-frame/src/arch/x86/kernel/apic/ioapic.rs +++ b/framework/aster-frame/src/arch/x86/kernel/apic/ioapic.rs @@ -2,11 +2,15 @@ use alloc::{vec, vec::Vec}; +#[cfg(feature = "intel_tdx")] +use ::tdx_guest::tdx_is_enabled; use acpi::PlatformInfo; use bit_field::BitField; use log::info; use spin::Once; +#[cfg(feature = "intel_tdx")] +use crate::arch::tdx_guest; use crate::{ arch::x86::kernel::acpi::ACPI_TABLES, sync::SpinLock, trap::IrqLine, vm::paddr_to_vaddr, Error, Result, @@ -151,6 +155,17 @@ pub fn init() { // FIXME: Is it possible to have an address that is not the default 0xFEC0_0000? // Need to find a way to determine if it is a valid address or not. const IO_APIC_DEFAULT_ADDRESS: usize = 0xFEC0_0000; + #[cfg(feature = "intel_tdx")] + // Safety: + // This is safe because we are ensuring that the `IO_APIC_DEFAULT_ADDRESS` is a valid MMIO address before this operation. + // The `IO_APIC_DEFAULT_ADDRESS` is a well-known address used for IO APICs in x86 systems, and it is page-aligned, which is a requirement for the `unprotect_gpa_range` function. + // We are also ensuring that we are only unprotecting a single page. + // Therefore, we are not causing any undefined behavior or violating any of the requirements of the `unprotect_gpa_range` function. 
+ if tdx_is_enabled() { + unsafe { + tdx_guest::unprotect_gpa_range(IO_APIC_DEFAULT_ADDRESS, 1).unwrap(); + } + } let mut io_apic = unsafe { IoApicAccess::new(IO_APIC_DEFAULT_ADDRESS) }; io_apic.set_id(0); let id = io_apic.id(); diff --git a/framework/aster-frame/src/arch/x86/kernel/apic/x2apic.rs b/framework/aster-frame/src/arch/x86/kernel/apic/x2apic.rs index da31bf41f..eaf4f2275 100644 --- a/framework/aster-frame/src/arch/x86/kernel/apic/x2apic.rs +++ b/framework/aster-frame/src/arch/x86/kernel/apic/x2apic.rs @@ -25,12 +25,25 @@ impl X2Apic { } pub fn enable(&mut self) { - // Enable + const X2APIC_ENABLE_BITS: u64 = { + // IA32_APIC_BASE MSR's EN bit: xAPIC global enable/disable + const EN_BIT_IDX: u8 = 11; + // IA32_APIC_BASE MSR's EXTD bit: Enable x2APIC mode + const EXTD_BIT_IDX: u8 = 10; + (1 << EN_BIT_IDX) | (1 << EXTD_BIT_IDX) + }; + // Safety: + // This is safe because we are ensuring that the operations are performed on valid MSRs. + // We are using them to read and write to the `IA32_APIC_BASE` and `IA32_X2APIC_SIVR` MSRs, which are well-defined and valid MSRs in x86 systems. + // Therefore, we are not causing any undefined behavior or violating any of the requirements of the `rdmsr` and `wrmsr` functions. unsafe { // Enable x2APIC mode globally let mut base = rdmsr(IA32_APIC_BASE); - base |= 0b1100_0000_0000; // Enable x2APIC and xAPIC - wrmsr(IA32_APIC_BASE, base); + // Enable x2APIC and xAPIC if they are not enabled by default + if base & X2APIC_ENABLE_BITS != X2APIC_ENABLE_BITS { + base |= X2APIC_ENABLE_BITS; + wrmsr(IA32_APIC_BASE, base); + } // Set SVR, Enable APIC and set Spurious Vector to 15 (Reserved irq number) let svr: u64 = 1 << 8 | 15; diff --git a/framework/aster-frame/src/arch/x86/mm/mod.rs b/framework/aster-frame/src/arch/x86/mm/mod.rs index d32d0ccc3..bb99e15f3 100644 --- a/framework/aster-frame/src/arch/x86/mm/mod.rs +++ b/framework/aster-frame/src/arch/x86/mm/mod.rs @@ -39,6 +39,9 @@ bitflags::bitflags! 
{ /// Indicates that the mapping is present in all address spaces, so it isn't flushed from /// the TLB on an address space switch. const GLOBAL = 1 << 8; + /// TDX shared bit. + #[cfg(feature = "intel_tdx")] + const SHARED = 1 << 51; /// Forbid execute codes on the page. The NXE bits in EFER msr must be set. const NO_EXECUTE = 1 << 63; } @@ -176,7 +179,10 @@ impl PageTableFlagsTrait for PageTableFlags { impl PageTableEntry { /// 51:12 + #[cfg(not(feature = "intel_tdx"))] const PHYS_ADDR_MASK: usize = 0xF_FFFF_FFFF_F000; + #[cfg(feature = "intel_tdx")] + const PHYS_ADDR_MASK: usize = 0x7_FFFF_FFFF_F000; } impl PageTableEntryTrait for PageTableEntry { diff --git a/framework/aster-frame/src/arch/x86/mod.rs b/framework/aster-frame/src/arch/x86/mod.rs index ed2766238..b9f189cd7 100644 --- a/framework/aster-frame/src/arch/x86/mod.rs +++ b/framework/aster-frame/src/arch/x86/mod.rs @@ -16,6 +16,8 @@ pub(crate) mod timer; use core::{arch::x86_64::_rdtsc, sync::atomic::Ordering}; +#[cfg(feature = "intel_tdx")] +use ::tdx_guest::tdx_is_enabled; use kernel::apic::ioapic; use log::{info, warn}; @@ -39,6 +41,14 @@ pub(crate) fn after_all_init() { } console::callback_init(); timer::init(); + #[cfg(feature = "intel_tdx")] + if !tdx_is_enabled() { + match iommu::init() { + Ok(_) => {} + Err(err) => warn!("IOMMU initialization error:{:?}", err), + } + } + #[cfg(not(feature = "intel_tdx"))] match iommu::init() { Ok(_) => {} Err(err) => warn!("IOMMU initialization error:{:?}", err), diff --git a/framework/aster-frame/src/arch/x86/tdx_guest.rs b/framework/aster-frame/src/arch/x86/tdx_guest.rs index 33250b824..beaaa0b89 100644 --- a/framework/aster-frame/src/arch/x86/tdx_guest.rs +++ b/framework/aster-frame/src/arch/x86/tdx_guest.rs @@ -1,13 +1,30 @@ // SPDX-License-Identifier: MPL-2.0 +use iced_x86::{Code, Decoder, DecoderOptions, Instruction, Register}; +use log::warn; use tdx_guest::{ serial_println, tdcall, - tdcall::TdgVeInfo, + tdcall::{accept_page, TdgVeInfo}, tdvmcall, - 
tdvmcall::{cpuid, hlt, rdmsr, wrmsr, IoSize}, + tdvmcall::{cpuid, hlt, map_gpa, rdmsr, read_mmio, write_mmio, wrmsr, IoSize}, TdxVirtualExceptionType, }; +use crate::{ + arch::mm::{is_kernel_vaddr, PageTableFlags}, + config::PAGE_SIZE, + vm::{ + paddr_to_vaddr, + page_table::{PageTableError, KERNEL_PAGE_TABLE}, + }, +}; + +const SHARED_BIT: u8 = 51; +const SHARED_MASK: u64 = 1u64 << SHARED_BIT; + +// Intel TDX guest physical address. Maybe protected(private) gpa or unprotected(shared) gpa. +pub type TdxGpa = usize; + pub trait TdxTrapFrame { fn rax(&self) -> usize; fn set_rax(&mut self, rax: usize); @@ -23,9 +40,54 @@ pub trait TdxTrapFrame { fn set_rdi(&mut self, rdi: usize); fn rip(&self) -> usize; fn set_rip(&mut self, rip: usize); + fn r8(&self) -> usize; + fn set_r8(&mut self, r8: usize); + fn r9(&self) -> usize; + fn set_r9(&mut self, r9: usize); + fn r10(&self) -> usize; + fn set_r10(&mut self, r10: usize); + fn r11(&self) -> usize; + fn set_r11(&mut self, r11: usize); + fn r12(&self) -> usize; + fn set_r12(&mut self, r12: usize); + fn r13(&self) -> usize; + fn set_r13(&mut self, r13: usize); + fn r14(&self) -> usize; + fn set_r14(&mut self, r14: usize); + fn r15(&self) -> usize; + fn set_r15(&mut self, r15: usize); + fn rbp(&self) -> usize; + fn set_rbp(&mut self, rbp: usize); +} + +enum InstrMmioType { + Write, + WriteImm, + Read, + ReadZeroExtend, + ReadSignExtend, + Movs, +} + +#[derive(Debug)] +enum MmioError { + Unimplemented, + InvalidInstruction, + InvalidAddress, + DecodeFailed, + TdVmcallError(tdvmcall::TdVmcallError), +} + +#[derive(Debug)] +pub enum PageConvertError { + TdxPageStatusMismatch, + PageTableError(PageTableError), + TdCallError(tdcall::TdCallError), + TdVmcallError((u64, tdvmcall::TdVmcallError)), } pub fn handle_virtual_exception(trapframe: &mut dyn TdxTrapFrame, ve_info: &TdgVeInfo) { + let mut instr_len = ve_info.exit_instruction_length; match ve_info.exit_reason.into() { TdxVirtualExceptionType::Hlt => { serial_println!("Ready to 
halt"); @@ -54,10 +116,20 @@ pub fn handle_virtual_exception(trapframe: &mut dyn TdxTrapFrame, ve_info: &TdgV trapframe.set_rcx((trapframe.rcx() & mask) | cpuid_info.ecx); trapframe.set_rdx((trapframe.rdx() & mask) | cpuid_info.edx); } - TdxVirtualExceptionType::Other => panic!("Unknown TDX vitrual exception type"), + TdxVirtualExceptionType::EptViolation => { + if is_protected_gpa(ve_info.guest_physical_address as TdxGpa) { + serial_println!("Unexpected EPT-violation on private memory"); + hlt(); + } + instr_len = handle_mmio(trapframe, &ve_info).unwrap() as u32; + } + TdxVirtualExceptionType::Other => { + serial_println!("Unknown TDX virtual exception type"); + hlt(); + } _ => return, } - trapframe.set_rip(trapframe.rip() + ve_info.exit_instruction_length as usize); + trapframe.set_rip(trapframe.rip() + instr_len as usize); } fn handle_io(trapframe: &mut dyn TdxTrapFrame, ve_info: &tdcall::TdgVeInfo) -> bool { @@ -89,3 +161,313 @@ fn handle_io(trapframe: &mut dyn TdxTrapFrame, ve_info: &tdcall::TdgVeInfo) -> b }; true } + +fn is_protected_gpa(gpa: TdxGpa) -> bool { + (gpa as u64 & SHARED_MASK) == 0 +} + +fn handle_mmio(trapframe: &mut dyn TdxTrapFrame, ve_info: &TdgVeInfo) -> Result { + // Get instruction + let instr = decode_instr(trapframe.rip())?; + + // Decode MMIO instruction + match decode_mmio(&instr) { + Some((mmio, size)) => { + match mmio { + InstrMmioType::Write => { + let value = match instr.op1_register() { + Register::RCX => trapframe.rcx() as u64, + Register::ECX => (trapframe.rcx() & 0xFFFF_FFFF) as u64, + Register::CX => (trapframe.rcx() & 0xFFFF) as u64, + Register::CL => (trapframe.rcx() & 0xFF) as u64, + _ => todo!(), + }; + // Safety: The mmio_gpa obtained from `ve_info` is valid, and the value and size parsed from the instruction are valid. + unsafe { + write_mmio(size, ve_info.guest_physical_address, value) + .map_err(|e| MmioError::TdVmcallError(e))? 
+ } + } + InstrMmioType::WriteImm => { + let value = instr.immediate(0); + // Safety: The mmio_gpa obtained from `ve_info` is valid, and the value and size parsed from the instruction are valid. + unsafe { + write_mmio(size, ve_info.guest_physical_address, value) + .map_err(|e| MmioError::TdVmcallError(e))? + } + } + InstrMmioType::Read => + // Safety: The mmio_gpa obtained from `ve_info` is valid, and the size parsed from the instruction is valid. + unsafe { + let read_res = read_mmio(size, ve_info.guest_physical_address) + .map_err(|e| MmioError::TdVmcallError(e))? + as usize; + match instr.op0_register() { + Register::RAX => trapframe.set_rax(read_res), + Register::EAX => { + trapframe.set_rax((trapframe.rax() & 0xFFFF_FFFF_0000_0000) | read_res) + } + Register::AX => { + trapframe.set_rax((trapframe.rax() & 0xFFFF_FFFF_FFFF_0000) | read_res) + } + Register::AL => { + trapframe.set_rax((trapframe.rax() & 0xFFFF_FFFF_FFFF_FF00) | read_res) + } + Register::RBX => trapframe.set_rbx(read_res), + Register::EBX => { + trapframe.set_rbx((trapframe.rbx() & 0xFFFF_FFFF_0000_0000) | read_res) + } + Register::BX => { + trapframe.set_rbx((trapframe.rbx() & 0xFFFF_FFFF_FFFF_0000) | read_res) + } + Register::BL => { + trapframe.set_rbx((trapframe.rbx() & 0xFFFF_FFFF_FFFF_FF00) | read_res) + } + Register::RCX => trapframe.set_rcx(read_res), + Register::ECX => { + trapframe.set_rcx((trapframe.rcx() & 0xFFFF_FFFF_0000_0000) | read_res) + } + Register::CX => { + trapframe.set_rcx((trapframe.rcx() & 0xFFFF_FFFF_FFFF_0000) | read_res) + } + Register::CL => { + trapframe.set_rcx((trapframe.rcx() & 0xFFFF_FFFF_FFFF_FF00) | read_res) + } + Register::RDX => trapframe.set_rdx(read_res), + Register::EDX => { + trapframe.set_rdx((trapframe.rdx() & 0xFFFF_FFFF_0000_0000) | read_res) + } + Register::DX => { + trapframe.set_rdx((trapframe.rdx() & 0xFFFF_FFFF_FFFF_0000) | read_res) + } + Register::DL => { + trapframe.set_rdx((trapframe.rdx() & 0xFFFF_FFFF_FFFF_FF00) | read_res) + } + 
Register::SIL => { + trapframe.set_rsi((trapframe.rsi() & 0xFFFF_FFFF_FFFF_FF00) | read_res) + } + Register::DIL => { + trapframe.set_rdi((trapframe.rdi() & 0xFFFF_FFFF_FFFF_FF00) | read_res) + } + Register::R8L => { + trapframe.set_r8((trapframe.r8() & 0xFFFF_FFFF_FFFF_FF00) | read_res) + } + Register::R9L => { + trapframe.set_r9((trapframe.r9() & 0xFFFF_FFFF_FFFF_FF00) | read_res) + } + Register::R10L => { + trapframe.set_r10((trapframe.r10() & 0xFFFF_FFFF_FFFF_FF00) | read_res) + } + Register::R11L => { + trapframe.set_r11((trapframe.r11() & 0xFFFF_FFFF_FFFF_FF00) | read_res) + } + Register::R11W => { + trapframe.set_r11((trapframe.r11() & 0xFFFF_FFFF_FFFF_0000) | read_res) + } + Register::R12L => { + trapframe.set_r12((trapframe.r12() & 0xFFFF_FFFF_FFFF_FF00) | read_res) + } + Register::R13L => { + trapframe.set_r13((trapframe.r13() & 0xFFFF_FFFF_FFFF_FF00) | read_res) + } + Register::R13W => { + trapframe.set_r13((trapframe.r13() & 0xFFFF_FFFF_FFFF_0000) | read_res) + } + Register::R14L => { + trapframe.set_r14((trapframe.r14() & 0xFFFF_FFFF_FFFF_FF00) | read_res) + } + Register::R14D => { + trapframe.set_r14((trapframe.r14() & 0xFFFF_FFFF_0000_0000) | read_res) + } + Register::R15L => { + trapframe.set_r15((trapframe.r15() & 0xFFFF_FFFF_FFFF_FF00) | read_res) + } + Register::BP => { + trapframe.set_rbp((trapframe.rbp() & 0xFFFF_FFFF_FFFF_0000) | read_res) + } + Register::BPL => { + trapframe.set_rbp((trapframe.rbp() & 0xFFFF_FFFF_FFFF_FF00) | read_res) + } + _ => return Err(MmioError::Unimplemented), + } + }, + InstrMmioType::ReadZeroExtend => + // Safety: The mmio_gpa obtained from `ve_info` is valid, and the size parsed from the instruction is valid. + unsafe { + let read_res = read_mmio(size, ve_info.guest_physical_address) + .map_err(|e| MmioError::TdVmcallError(e))? 
+ as usize; + match instr.op0_register() { + Register::RAX | Register::EAX | Register::AX | Register::AL => { + trapframe.set_rax(read_res) + } + Register::RBX | Register::EBX | Register::BX | Register::BL => { + trapframe.set_rbx(read_res) + } + Register::RCX | Register::ECX | Register::CX | Register::CL => { + trapframe.set_rcx(read_res) + } + _ => return Err(MmioError::Unimplemented), + } + }, + InstrMmioType::ReadSignExtend => return Err(MmioError::Unimplemented), + // MMIO was accessed with an instruction that could not be decoded or handled properly. + InstrMmioType::Movs => return Err(MmioError::InvalidInstruction), + } + } + None => { + return Err(MmioError::DecodeFailed); + } + } + Ok(instr.len()) +} + +fn decode_instr(rip: usize) -> Result { + if !is_kernel_vaddr(rip) { + return Err(MmioError::InvalidAddress); + } + let code_data = { + const MAX_X86_INSTR_LEN: usize = 15; + let mut data = [0u8; MAX_X86_INSTR_LEN]; + // Safety: + // This is safe because we are ensuring that 'rip' is a valid kernel virtual address before this operation. + // We are also ensuring that the size of the data we are copying does not exceed 'MAX_X86_INSTR_LEN'. + // Therefore, we are not reading any memory that we shouldn't be, and we are not causing any undefined behavior. 
+ unsafe { + core::ptr::copy_nonoverlapping(rip as *const u8, data.as_mut_ptr(), data.len()); + } + data + }; + let mut decoder = Decoder::with_ip(64, &code_data, rip as u64, DecoderOptions::NONE); + let mut instr = Instruction::default(); + decoder.decode_out(&mut instr); + if instr.is_invalid() { + return Err(MmioError::InvalidInstruction); + } + Ok(instr) +} + +fn decode_mmio(instr: &Instruction) -> Option<(InstrMmioType, IoSize)> { + match instr.code() { + // 0x88 + Code::Mov_rm8_r8 => Some((InstrMmioType::Write, IoSize::Size1)), + // 0x89 + Code::Mov_rm16_r16 => Some((InstrMmioType::Write, IoSize::Size2)), + Code::Mov_rm32_r32 => Some((InstrMmioType::Write, IoSize::Size4)), + Code::Mov_rm64_r64 => Some((InstrMmioType::Write, IoSize::Size8)), + // 0xc6 + Code::Mov_rm8_imm8 => Some((InstrMmioType::WriteImm, IoSize::Size1)), + // 0xc7 + Code::Mov_rm16_imm16 => Some((InstrMmioType::WriteImm, IoSize::Size2)), + Code::Mov_rm32_imm32 => Some((InstrMmioType::WriteImm, IoSize::Size4)), + Code::Mov_rm64_imm32 => Some((InstrMmioType::WriteImm, IoSize::Size8)), + // 0x8a + Code::Mov_r8_rm8 => Some((InstrMmioType::Read, IoSize::Size1)), + // 0x8b + Code::Mov_r16_rm16 => Some((InstrMmioType::Read, IoSize::Size2)), + Code::Mov_r32_rm32 => Some((InstrMmioType::Read, IoSize::Size4)), + Code::Mov_r64_rm64 => Some((InstrMmioType::Read, IoSize::Size8)), + // 0xa4 + Code::Movsb_m8_m8 => Some((InstrMmioType::Movs, IoSize::Size1)), + // 0xa5 + Code::Movsw_m16_m16 => Some((InstrMmioType::Movs, IoSize::Size2)), + Code::Movsd_m32_m32 => Some((InstrMmioType::Movs, IoSize::Size4)), + Code::Movsq_m64_m64 => Some((InstrMmioType::Movs, IoSize::Size8)), + // 0x0f 0xb6 + Code::Movzx_r16_rm8 | Code::Movzx_r32_rm8 | Code::Movzx_r64_rm8 => { + Some((InstrMmioType::ReadZeroExtend, IoSize::Size1)) + } + // 0x0f 0xb7 + Code::Movzx_r16_rm16 | Code::Movzx_r32_rm16 | Code::Movzx_r64_rm16 => { + Some((InstrMmioType::ReadZeroExtend, IoSize::Size2)) + } + // 0x0f 0xbe + Code::Movsx_r16_rm8 | 
Code::Movsx_r32_rm8 | Code::Movsx_r64_rm8 => { + Some((InstrMmioType::ReadSignExtend, IoSize::Size1)) + } + // 0x0f 0xbf + Code::Movsx_r16_rm16 | Code::Movsx_r32_rm16 | Code::Movsx_r64_rm16 => { + Some((InstrMmioType::ReadSignExtend, IoSize::Size2)) + } + _ => None, + } +} + +/// Sets the given physical address range to Intel TDX shared pages. +/// Clears the data within the given address range. +/// Make sure the provided physical address is page size aligned. +/// +/// # Safety +/// +/// To safely use this function, the caller must ensure that: +/// - The given guest physical address range is currently mapped in the page table. +/// - The `page_num` argument represents a valid number of pages. +/// - This function will erase any valid data in the range and should not assume that the data will still be there after the operation. +pub unsafe fn unprotect_gpa_range(gpa: TdxGpa, page_num: usize) -> Result<(), PageConvertError> { + const PAGE_MASK: usize = PAGE_SIZE - 1; + for i in 0..page_num { + if !is_protected_gpa(gpa + (i * PAGE_SIZE)) { + return Err(PageConvertError::TdxPageStatusMismatch); + } + } + if gpa & PAGE_MASK != 0 { + warn!("Misaligned address: {:x}", gpa); + } + let vaddr = paddr_to_vaddr(gpa); + let mut pt = KERNEL_PAGE_TABLE.get().unwrap().lock(); + unsafe { + for i in 0..page_num { + pt.protect( + vaddr + (i * PAGE_SIZE), + PageTableFlags::SHARED | PageTableFlags::WRITABLE | PageTableFlags::PRESENT, + ) + .map_err(|e| PageConvertError::PageTableError(e))?; + } + }; + map_gpa( + (gpa & (!PAGE_MASK)) as u64 | SHARED_MASK, + (page_num * PAGE_SIZE) as u64, + ) + .map_err(|e| PageConvertError::TdVmcallError(e)) +} + +/// Sets the given physical address range to Intel TDX private pages. +/// Make sure the provided physical address is page size aligned. +/// +/// # Safety +/// +/// To safely use this function, the caller must ensure that: +/// - The given guest physical address range is currently mapped in the page table. 
+/// - The `page_num` argument represents a valid number of pages. +/// +pub unsafe fn protect_gpa_range(gpa: TdxGpa, page_num: usize) -> Result<(), PageConvertError> { + const PAGE_MASK: usize = PAGE_SIZE - 1; + for i in 0..page_num { + if is_protected_gpa(gpa + (i * PAGE_SIZE)) { + return Err(PageConvertError::TdxPageStatusMismatch); + } + } + if gpa & PAGE_MASK != 0 { + warn!("Misaligned address: {:x}", gpa); + } + let vaddr = paddr_to_vaddr(gpa); + let mut pt = KERNEL_PAGE_TABLE.get().unwrap().lock(); + unsafe { + for i in 0..page_num { + pt.protect( + vaddr + (i * PAGE_SIZE), + PageTableFlags::WRITABLE | PageTableFlags::PRESENT, + ) + .map_err(|e| PageConvertError::PageTableError(e))?; + } + }; + map_gpa((gpa & (!PAGE_MASK)) as u64, (page_num * PAGE_SIZE) as u64) + .map_err(|e| PageConvertError::TdVmcallError(e))?; + for i in 0..page_num { + unsafe { + accept_page(0, (gpa + i * PAGE_SIZE) as u64) + .map_err(|e| PageConvertError::TdCallError(e))?; + } + } + Ok(()) +} diff --git a/framework/aster-frame/src/bus/mmio/mod.rs b/framework/aster-frame/src/bus/mmio/mod.rs index a33c274b9..c4941c5ef 100644 --- a/framework/aster-frame/src/bus/mmio/mod.rs +++ b/framework/aster-frame/src/bus/mmio/mod.rs @@ -8,9 +8,13 @@ pub mod device; use alloc::vec::Vec; use core::ops::Range; +#[cfg(feature = "intel_tdx")] +use ::tdx_guest::tdx_is_enabled; use log::debug; use self::bus::MmioBus; +#[cfg(feature = "intel_tdx")] +use crate::arch::tdx_guest; use crate::{ arch::kernel::IO_APIC, bus::mmio::device::MmioCommonDevice, sync::SpinLock, trap::IrqLine, vm::paddr_to_vaddr, @@ -22,6 +26,17 @@ pub static MMIO_BUS: SpinLock = SpinLock::new(MmioBus::new()); static IRQS: SpinLock> = SpinLock::new(Vec::new()); pub fn init() { + #[cfg(feature = "intel_tdx")] + // Safety: + // This is safe because we are ensuring that the address range 0xFEB0_0000 to 0xFEB0_4000 is valid before this operation. 
+ // The address range is page-aligned and falls within the MMIO range, which is a requirement for the `unprotect_gpa_range` function. + // We are also ensuring that we are only unprotecting four pages. + // Therefore, we are not causing any undefined behavior or violating any of the requirements of the `unprotect_gpa_range` function. + if tdx_is_enabled() { + unsafe { + tdx_guest::unprotect_gpa_range(0xFEB0_0000, 4).unwrap(); + } + } // FIXME: The address 0xFEB0_0000 is obtained from an instance of microvm, and it may not work in other architecture. iter_range(0xFEB0_0000..0xFEB0_4000); } diff --git a/framework/aster-frame/src/bus/pci/capability/msix.rs b/framework/aster-frame/src/bus/pci/capability/msix.rs index 5305afc16..7bf4d621d 100644 --- a/framework/aster-frame/src/bus/pci/capability/msix.rs +++ b/framework/aster-frame/src/bus/pci/capability/msix.rs @@ -2,6 +2,11 @@ use alloc::{sync::Arc, vec::Vec}; +#[cfg(feature = "intel_tdx")] +use ::tdx_guest::tdx_is_enabled; + +#[cfg(feature = "intel_tdx")] +use crate::arch::tdx_guest; use crate::{ bus::pci::{ cfg_space::{Bar, Command, MemoryBar}, @@ -90,6 +95,20 @@ impl CapabilityMsixData { // Set message address 0xFEE0_0000 for i in 0..table_size { + #[cfg(feature = "intel_tdx")] + // Safety: + // This is safe because we are ensuring that the physical address of the MSI-X table is valid before this operation. + // We are also ensuring that we are only unprotecting a single page. + // The MSI-X table will not exceed one page size, because the size of an MSI-X entry is 16 bytes, and 256 entries are required to fill a page, + // which is just equal to the number of all the interrupt numbers on the x86 platform. + // It is better to add a judgment here in case the device deliberately uses so many interrupt numbers. + // In addition, due to granularity, the minimum value that can be set here is only one page. 
+ // Therefore, we are not causing any undefined behavior or violating any of the requirements of the `unprotect_gpa_range` function. + if tdx_is_enabled() { + unsafe { + tdx_guest::unprotect_gpa_range(table_bar.io_mem().paddr(), 1).unwrap(); + } + } // Set message address and disable this msix entry table_bar .io_mem() diff --git a/framework/aster-frame/src/trap/handler.rs b/framework/aster-frame/src/trap/handler.rs index 7d58981b4..dbea33f89 100644 --- a/framework/aster-frame/src/trap/handler.rs +++ b/framework/aster-frame/src/trap/handler.rs @@ -8,6 +8,10 @@ use trapframe::TrapFrame; #[cfg(feature = "intel_tdx")] use crate::arch::tdx_guest::{handle_virtual_exception, TdxTrapFrame}; +#[cfg(feature = "intel_tdx")] +use crate::arch::mm::PageTableFlags; +#[cfg(feature = "intel_tdx")] +use crate::vm::{page_table::KERNEL_PAGE_TABLE, vaddr_to_paddr}; use crate::{arch::irq::IRQ_LIST, cpu::CpuException, cpu_local}; #[cfg(feature = "intel_tdx")] @@ -54,19 +58,92 @@ impl TdxTrapFrame for TrapFrame { fn set_rip(&mut self, rip: usize) { self.rip = rip; } + fn r8(&self) -> usize { + self.r8 + } + fn set_r8(&mut self, r8: usize) { + self.r8 = r8; + } + fn r9(&self) -> usize { + self.r9 + } + fn set_r9(&mut self, r9: usize) { + self.r9 = r9; + } + fn r10(&self) -> usize { + self.r10 + } + fn set_r10(&mut self, r10: usize) { + self.r10 = r10; + } + fn r11(&self) -> usize { + self.r11 + } + fn set_r11(&mut self, r11: usize) { + self.r11 = r11; + } + fn r12(&self) -> usize { + self.r12 + } + fn set_r12(&mut self, r12: usize) { + self.r12 = r12; + } + fn r13(&self) -> usize { + self.r13 + } + fn set_r13(&mut self, r13: usize) { + self.r13 = r13; + } + fn r14(&self) -> usize { + self.r14 + } + fn set_r14(&mut self, r14: usize) { + self.r14 = r14; + } + fn r15(&self) -> usize { + self.r15 + } + fn set_r15(&mut self, r15: usize) { + self.r15 = r15; + } + fn rbp(&self) -> usize { + self.rbp + } + fn set_rbp(&mut self, rbp: 
usize) { + self.rbp = rbp; + } } /// Only from kernel #[no_mangle] extern "sysv64" fn trap_handler(f: &mut TrapFrame) { if CpuException::is_cpu_exception(f.trap_num as u16) { + const VIRTUALIZATION_EXCEPTION: u16 = 20; + const PAGE_FAULT: u16 = 14; #[cfg(feature = "intel_tdx")] - if f.trap_num as u16 == 20 { + if f.trap_num as u16 == VIRTUALIZATION_EXCEPTION { let ve_info = tdcall::get_veinfo().expect("#VE handler: fail to get VE info\n"); handle_virtual_exception(f, &ve_info); return; } - panic!("cannot handle kernel cpu fault now, information:{:#x?}", f); + #[cfg(feature = "intel_tdx")] + if f.trap_num as u16 == PAGE_FAULT { + let mut pt = KERNEL_PAGE_TABLE.get().unwrap().lock(); + // Safety: Map virtio addr when set shared bit in a TD. Only add the `PageTableFlags::SHARED` flag. + unsafe { + let page_fault_vaddr = x86::controlregs::cr2(); + let _ = pt.map( + page_fault_vaddr, + vaddr_to_paddr(page_fault_vaddr).unwrap(), + PageTableFlags::SHARED | PageTableFlags::PRESENT | PageTableFlags::WRITABLE, + ); + }; + return; + } + panic!( + "cannot handle this kernel cpu fault now, information:{:#x?}", + f + ); } else { call_irq_callback_functions(f); } diff --git a/framework/aster-frame/src/vm/dma/dma_coherent.rs b/framework/aster-frame/src/vm/dma/dma_coherent.rs index 577f33dc9..53fab945d 100644 --- a/framework/aster-frame/src/vm/dma/dma_coherent.rs +++ b/framework/aster-frame/src/vm/dma/dma_coherent.rs @@ -3,7 +3,12 @@ use alloc::sync::Arc; use core::ops::Deref; +#[cfg(feature = "intel_tdx")] +use ::tdx_guest::tdx_is_enabled; + use super::{check_and_insert_dma_mapping, remove_dma_mapping, DmaError, HasDaddr}; +#[cfg(feature = "intel_tdx")] +use crate::arch::tdx_guest; use crate::{ arch::{iommu, mm::PageTableFlags}, vm::{ @@ -63,7 +68,20 @@ impl DmaCoherent { } } let start_daddr = match dma_type() { - DmaType::Direct => start_paddr as Daddr, + DmaType::Direct => { + #[cfg(feature = "intel_tdx")] + // Safety: + // This is safe because we are ensuring that the physical 
address range specified by `start_paddr` and `frame_count` is valid before these operations. + // The `check_and_insert_dma_mapping` function checks if the physical address range is already mapped. + // We are also ensuring that we are only modifying the page table entries corresponding to the physical address range specified by `start_paddr` and `frame_count`. + // Therefore, we are not causing any undefined behavior or violating any of the requirements of the 'unprotect_gpa_range' function. + if tdx_is_enabled() { + unsafe { + tdx_guest::unprotect_gpa_range(start_paddr, frame_count).unwrap(); + } + } + start_paddr as Daddr + } DmaType::Iommu => { for i in 0..frame_count { let paddr = start_paddr + (i * PAGE_SIZE); @@ -74,9 +92,6 @@ impl DmaCoherent { } start_paddr as Daddr } - DmaType::Tdx => { - todo!() - } }; Ok(Self { inner: Arc::new(DmaCoherentInner { @@ -106,16 +121,25 @@ impl Drop for DmaCoherentInner { let frame_count = self.vm_segment.nframes(); let start_paddr = self.vm_segment.start_paddr(); match dma_type() { - DmaType::Direct => {} + DmaType::Direct => { + #[cfg(feature = "intel_tdx")] + // Safety: + // This is safe because we are ensuring that the physical address range specified by `start_paddr` and `frame_count` is valid before these operations. + // The `start_paddr()` ensures the `start_paddr` is page-aligned. + // We are also ensuring that we are only modifying the page table entries corresponding to the physical address range specified by `start_paddr` and `frame_count`. + // Therefore, we are not causing any undefined behavior or violating any of the requirements of the `protect_gpa_range` function. 
+ if tdx_is_enabled() { + unsafe { + tdx_guest::protect_gpa_range(start_paddr, frame_count).unwrap(); + } + } + } DmaType::Iommu => { for i in 0..frame_count { let paddr = start_paddr + (i * PAGE_SIZE); iommu::unmap(paddr).unwrap(); } } - DmaType::Tdx => { - todo!(); - } } if !self.is_cache_coherent { let mut page_table = KERNEL_PAGE_TABLE.get().unwrap().lock(); diff --git a/framework/aster-frame/src/vm/dma/dma_stream.rs b/framework/aster-frame/src/vm/dma/dma_stream.rs index 774653562..e3153045b 100644 --- a/framework/aster-frame/src/vm/dma/dma_stream.rs +++ b/framework/aster-frame/src/vm/dma/dma_stream.rs @@ -3,7 +3,12 @@ use alloc::sync::Arc; use core::{arch::x86_64::_mm_clflush, ops::Range}; +#[cfg(feature = "intel_tdx")] +use ::tdx_guest::tdx_is_enabled; + use super::{check_and_insert_dma_mapping, remove_dma_mapping, DmaError, HasDaddr}; +#[cfg(feature = "intel_tdx")] +use crate::arch::tdx_guest; use crate::{ arch::iommu, error::Error, @@ -55,7 +60,20 @@ impl DmaStream { return Err(DmaError::AlreadyMapped); } let start_daddr = match dma_type() { - DmaType::Direct => start_paddr as Daddr, + DmaType::Direct => { + #[cfg(feature = "intel_tdx")] + // Safety: + // This is safe because we are ensuring that the physical address range specified by `start_paddr` and `frame_count` is valid before these operations. + // The `check_and_insert_dma_mapping` function checks if the physical address range is already mapped. + // We are also ensuring that we are only modifying the page table entries corresponding to the physical address range specified by `start_paddr` and `frame_count`. + // Therefore, we are not causing any undefined behavior or violating any of the requirements of the 'unprotect_gpa_range' function. 
+ if tdx_is_enabled() { + unsafe { + tdx_guest::unprotect_gpa_range(start_paddr, frame_count).unwrap(); + } + } + start_paddr as Daddr + } DmaType::Iommu => { for i in 0..frame_count { let paddr = start_paddr + (i * PAGE_SIZE); @@ -66,9 +84,6 @@ impl DmaStream { } start_paddr as Daddr } - DmaType::Tdx => { - todo!() - } }; Ok(Self { @@ -110,20 +125,15 @@ impl DmaStream { if self.inner.is_cache_coherent { return Ok(()); } - if dma_type() == DmaType::Tdx { - // copy pages. - todo!("support dma for tdx") - } else { - let start_va = self.inner.vm_segment.as_ptr(); - // TODO: Query the CPU for the cache line size via CPUID, we use 64 bytes as the cache line size here. - for i in byte_range.step_by(64) { - // Safety: the addresses is limited by a valid `byte_range`. - unsafe { - _mm_clflush(start_va.wrapping_add(i)); - } + let start_va = self.inner.vm_segment.as_ptr(); + // TODO: Query the CPU for the cache line size via CPUID, we use 64 bytes as the cache line size here. + for i in byte_range.step_by(64) { + // Safety: the addresses is limited by a valid `byte_range`. + unsafe { + _mm_clflush(start_va.wrapping_add(i)); } - Ok(()) } + Ok(()) } } @@ -138,16 +148,25 @@ impl Drop for DmaStreamInner { let frame_count = self.vm_segment.nframes(); let start_paddr = self.vm_segment.start_paddr(); match dma_type() { - DmaType::Direct => {} + DmaType::Direct => { + #[cfg(feature = "intel_tdx")] + // Safety: + // This is safe because we are ensuring that the physical address range specified by `start_paddr` and `frame_count` is valid before these operations. + // The `start_paddr()` ensures the `start_paddr` is page-aligned. + // We are also ensuring that we are only modifying the page table entries corresponding to the physical address range specified by `start_paddr` and `frame_count`. + // Therefore, we are not causing any undefined behavior or violating any of the requirements of the `protect_gpa_range` function. 
+ if tdx_is_enabled() { + unsafe { + tdx_guest::protect_gpa_range(start_paddr, frame_count).unwrap(); + } + } + } DmaType::Iommu => { for i in 0..frame_count { let paddr = start_paddr + (i * PAGE_SIZE); iommu::unmap(paddr).unwrap(); } } - DmaType::Tdx => { - todo!(); - } } remove_dma_mapping(start_paddr, frame_count); } diff --git a/framework/aster-frame/src/vm/dma/mod.rs b/framework/aster-frame/src/vm/dma/mod.rs index 74475fa8e..283aa016d 100644 --- a/framework/aster-frame/src/vm/dma/mod.rs +++ b/framework/aster-frame/src/vm/dma/mod.rs @@ -18,16 +18,10 @@ use crate::{arch::iommu::has_iommu, config::PAGE_SIZE, sync::SpinLock}; /// the address space used by device side. pub type Daddr = usize; -fn has_tdx() -> bool { - // FIXME: Support TDX - false -} - #[derive(PartialEq)] pub enum DmaType { Direct, Iommu, - Tdx, } #[derive(Debug)] @@ -48,10 +42,8 @@ static DMA_MAPPING_SET: Once>> = Once::new(); pub fn dma_type() -> DmaType { if has_iommu() { DmaType::Iommu - } else if has_tdx() { - return DmaType::Tdx; } else { - return DmaType::Direct; + DmaType::Direct } } diff --git a/framework/libs/tdx-guest/Cargo.toml b/framework/libs/tdx-guest/Cargo.toml index 02383f272..4ef0ed681 100644 --- a/framework/libs/tdx-guest/Cargo.toml +++ b/framework/libs/tdx-guest/Cargo.toml @@ -10,4 +10,3 @@ x86_64 = "0.14.10" bitflags = "1.3" raw-cpuid = "10" lazy_static = { version = "1.4.0", features = ["spin_no_std"] } - diff --git a/framework/libs/tdx-guest/src/lib.rs b/framework/libs/tdx-guest/src/lib.rs index ae1347d5d..460d0d793 100644 --- a/framework/libs/tdx-guest/src/lib.rs +++ b/framework/libs/tdx-guest/src/lib.rs @@ -11,6 +11,8 @@ mod asm; pub mod tdcall; pub mod tdvmcall; +use core::sync::atomic::{AtomicBool, Ordering::Relaxed}; + use raw_cpuid::{native_cpuid::cpuid_count, CpuIdResult}; use tdcall::{InitError, TdgVpInfo}; @@ -19,8 +21,16 @@ pub use self::{ tdvmcall::print, }; +static TDX_ENABLED: AtomicBool = AtomicBool::new(false); + +#[inline(always)] +pub fn tdx_is_enabled() -> 
bool { + TDX_ENABLED.load(Relaxed) +} + pub fn init_tdx() -> Result<TdgVpInfo, InitError> { check_tdx_guest()?; + TDX_ENABLED.store(true, Relaxed); Ok(tdcall::get_tdinfo()?) } @@ -32,7 +42,7 @@ fn check_tdx_guest() -> Result<(), InitError> { } let cpuid_result: CpuIdResult = cpuid_count(TDX_CPUID_LEAF_ID as u32, 0); if &cpuid_result.ebx.to_ne_bytes() != b"Inte" - || &cpuid_result.ebx.to_ne_bytes() != b"lTDX" + || &cpuid_result.edx.to_ne_bytes() != b"lTDX" || &cpuid_result.ecx.to_ne_bytes() != b"    " { return Err(InitError::TdxVendorIdError); diff --git a/framework/libs/tdx-guest/src/tdcall.rs b/framework/libs/tdx-guest/src/tdcall.rs index cd78e24c7..7df5f885b 100644 --- a/framework/libs/tdx-guest/src/tdcall.rs +++ b/framework/libs/tdx-guest/src/tdcall.rs @@ -327,6 +327,7 @@ pub enum TdxVirtualExceptionType { VmCall, Mwait, Monitor, + EptViolation, Wbinvd, Rdpmc, Other, @@ -344,6 +345,7 @@ impl From<u64> for TdxVirtualExceptionType { 32 => Self::MsrWrite, 36 => Self::Mwait, 39 => Self::Monitor, + 48 => Self::EptViolation, 54 => Self::Wbinvd, _ => Self::Other, } @@ -435,10 +437,10 @@ pub fn verify_report(report_mac_gpa: &[u8]) -> Result<(), TdCallError> { /// Accept a pending private page and initialize it to all-0 using the TD ephemeral private key. /// # Safety /// The 'gpa' parameter must be a valid address. 
-pub unsafe fn accept_page(sept_level: u64, gpa: &[u8]) -> Result<(), TdCallError> { +pub unsafe fn accept_page(sept_level: u64, gpa: u64) -> Result<(), TdCallError> { let mut args = TdcallArgs { rax: TdcallNum::MemPageAccept as u64, - rcx: sept_level | ((gpa.as_ptr() as u64) << 12), + rcx: sept_level | gpa, ..Default::default() }; td_call(&mut args) } diff --git a/framework/libs/tdx-guest/src/tdvmcall.rs b/framework/libs/tdx-guest/src/tdvmcall.rs index 339b6991c..54e763625 100644 --- a/framework/libs/tdx-guest/src/tdvmcall.rs +++ b/framework/libs/tdx-guest/src/tdvmcall.rs @@ -163,12 +163,12 @@ pub fn perform_cache_operation(cache_operation: u64) -> Result<(), TdVmcallError> { /// # Safety /// Make sure the mmio address is valid. -pub unsafe fn read_mmio(size: IoSize, mmio_gpa: &[u8]) -> Result<u64, TdVmcallError> { +pub unsafe fn read_mmio(size: IoSize, mmio_gpa: u64) -> Result<u64, TdVmcallError> { let mut args = TdVmcallArgs { r11: TdVmcallNum::RequestMmio as u64, r12: size as u64, r13: 0, - r14: mmio_gpa.as_ptr() as u64, + r14: mmio_gpa, ..Default::default() }; td_vmcall(&mut args)?; @@ -177,18 +177,32 @@ pub unsafe fn read_mmio(size: IoSize, mmio_gpa: &[u8]) -> Result<u64, TdVmcallError> -pub unsafe fn write_mmio(size: IoSize, mmio_gpa: &[u8], data: u64) -> Result<(), TdVmcallError> { +pub unsafe fn write_mmio(size: IoSize, mmio_gpa: u64, data: u64) -> Result<(), TdVmcallError> { let mut args = TdVmcallArgs { r11: TdVmcallNum::RequestMmio as u64, r12: size as u64, r13: 1, - r14: mmio_gpa.as_ptr() as u64, + r14: mmio_gpa, r15: data, ..Default::default() }; td_vmcall(&mut args) } +/// MapGPA TDG.VP.VMCALL is used to help request the host VMM to map a GPA range as private +/// or shared-memory mappings. This API may also be used to convert page mappings from +/// private to shared. The GPA range passed in this operation can indicate if the mapping is +/// requested for a shared or private memory via the GPA.Shared bit in the start address. 
+pub fn map_gpa(gpa: u64, size: u64) -> Result<(), (u64, TdVmcallError)> { + let mut args = TdVmcallArgs { + r11: TdVmcallNum::Mapgpa as u64, + r12: gpa, + r13: size, + ..Default::default() + }; + td_vmcall(&mut args).map_err(|e| (args.r11, e)) +} + macro_rules! io_read { ($port:expr, $ty:ty) => {{ let mut args = TdVmcallArgs {