diff --git a/Cargo.lock b/Cargo.lock index 2fc1f5f07..30a55d29d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -649,7 +649,6 @@ dependencies = [ "jinux-framebuffer", "jinux-std", "jinux-time", - "tdx-guest", "x86_64", ] diff --git a/Cargo.toml b/Cargo.toml index 5afcdc0fd..4208ebe50 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,7 +11,6 @@ path = "kernel/main.rs" jinux-frame = { path = "framework/jinux-frame" } jinux-std = { path = "services/libs/jinux-std" } component = { path = "services/libs/comp-sys/component" } -tdx-guest = { path = "framework/libs/tdx-guest", optional = true } [dev-dependencies] x86_64 = "0.14.2" @@ -41,4 +40,4 @@ members = [ exclude = ["services/libs/comp-sys/controlled", "services/libs/comp-sys/cargo-component"] [features] -intel_tdx = ["dep:tdx-guest", "jinux-frame/intel_tdx", "jinux-std/intel_tdx"] +intel_tdx = ["jinux-frame/intel_tdx", "jinux-std/intel_tdx"] diff --git a/Makefile b/Makefile index cf2ae2e70..b176588ac 100644 --- a/Makefile +++ b/Makefile @@ -3,6 +3,7 @@ AUTO_SYSCALL_TEST ?= 0 BUILD_SYSCALL_TEST ?= 0 EMULATE_IOMMU ?= 0 ENABLE_KVM ?= 1 +INTEL_TDX ?= 0 # End of Make arguments KERNEL_CMDLINE := SHELL="/bin/sh" LOGNAME="root" HOME="/" USER="root" PATH="/bin" init=/usr/bin/busybox -- sh -l @@ -10,8 +11,15 @@ ifeq ($(AUTO_SYSCALL_TEST), 1) KERNEL_CMDLINE += /opt/syscall_test/run_syscall_test.sh endif +CARGO_KBUILD_ARGS := + CARGO_KRUN_ARGS := -- '$(KERNEL_CMDLINE)' +ifeq ($(INTEL_TDX), 1) +CARGO_KBUILD_ARGS += --features intel_tdx +CARGO_KRUN_ARGS += --features intel_tdx +endif + ifeq ($(ENABLE_KVM), 1) CARGO_KRUN_ARGS += --enable-kvm endif @@ -39,11 +47,7 @@ setup: build: @make --no-print-directory -C regression - @cargo kbuild - -build_td: - @make --no-print-directory -C regression - @cargo kbuild --features intel_tdx + @cargo kbuild $(CARGO_KBUILD_ARGS) tools: @cd services/libs/comp-sys && cargo install --path cargo-component diff --git a/framework/jinux-frame/src/arch/x86/cpu.rs b/framework/jinux-frame/src/arch/x86/cpu.rs index 61389ea2e..69afce1dd 100644 --- a/framework/jinux-frame/src/arch/x86/cpu.rs +++ b/framework/jinux-frame/src/arch/x86/cpu.rs @@ -6,7 +6,7 @@ use core::fmt::Debug; use trapframe::{GeneralRegs, UserContext as RawUserContext}; #[cfg(feature = "intel_tdx")] -use crate::arch::tdx_guest::{virtual_exception_handler, TdxTrapFrame}; +use crate::arch::tdx_guest::handle_virtual_exception; use log::debug; #[cfg(feature = "intel_tdx")] use tdx_guest::tdcall; @@ -43,55 +43,6 @@ pub struct TrapInformation { pub err: usize, } -#[cfg(feature = "intel_tdx")] -struct VeGeneralRegs<'a>(&'a mut GeneralRegs); - -#[cfg(feature = "intel_tdx")] -impl TdxTrapFrame for VeGeneralRegs<'_> { - fn rax(&self) -> usize { - self.0.rax - } - fn set_rax(&mut self, rax: usize) { - self.0.rax = rax; - } - fn rbx(&self) -> usize { - self.0.rbx - } - fn set_rbx(&mut self, rbx: usize) { - self.0.rbx = rbx; - } - fn rcx(&self) -> usize { - self.0.rcx - } - fn set_rcx(&mut self, rcx: usize) { - self.0.rcx = rcx; - } - fn rdx(&self) -> usize { - self.0.rdx - } - fn set_rdx(&mut self, rdx: usize) { - self.0.rdx = rdx; - } - fn rsi(&self) -> usize { - self.0.rsi - } - fn set_rsi(&mut self, rsi: usize) { - self.0.rsi = rsi; - } - fn rdi(&self) -> usize { - self.0.rdi - } - fn set_rdi(&mut self, rdi: usize) { - self.0.rdi = rdi; - } - fn rip(&self) -> usize { - self.0.rip - } - fn set_rip(&mut self, rip: usize) { - self.0.rip = rip; - } -} - impl UserContext { pub fn general_regs(&self) -> &GeneralRegs { &self.user_context.general @@ -130,8 +81,7 @@ impl UserContextApiInternal for UserContext { if *exception == VIRTUALIZATION_EXCEPTION { let ve_info = tdcall::get_veinfo().expect("#VE handler: fail to get VE info\n"); - let mut ve_f = VeGeneralRegs(self.general_regs_mut()); - virtual_exception_handler(&mut ve_f, &ve_info); + handle_virtual_exception(&mut (*self.general_regs_mut()).into(), &ve_info); continue; } if exception.typ == CpuExceptionType::FaultOrTrap diff --git a/framework/jinux-frame/src/arch/x86/tdx_guest.rs b/framework/jinux-frame/src/arch/x86/tdx_guest.rs index 9828a52e8..ad62dab19 100644 --- a/framework/jinux-frame/src/arch/x86/tdx_guest.rs +++ b/framework/jinux-frame/src/arch/x86/tdx_guest.rs @@ -1,31 +1,88 @@ use tdx_guest::{ tdcall::TdgVeInfo, - tdvmcall::{cpuid, hlt, rdmsr, wrmsr}, + tdvmcall::{cpuid, hlt, rdmsr, wrmsr, IoSize}, {serial_println, tdcall, tdvmcall, TdxVirtualExceptionType}, }; +use trapframe::{GeneralRegs, TrapFrame}; -pub trait TdxTrapFrame { - fn rax(&self) -> usize; - fn set_rax(&mut self, rax: usize); - fn rbx(&self) -> usize; - fn set_rbx(&mut self, rbx: usize); - fn rcx(&self) -> usize; - fn set_rcx(&mut self, rcx: usize); - fn rdx(&self) -> usize; - fn set_rdx(&mut self, rdx: usize); - fn rsi(&self) -> usize; - fn set_rsi(&mut self, rsi: usize); - fn rdi(&self) -> usize; - fn set_rdi(&mut self, rdi: usize); - fn rip(&self) -> usize; - fn set_rip(&mut self, rip: usize); +pub struct TdxTrapFrame { + rax: usize, + rbx: usize, + rcx: usize, + rdx: usize, + rsi: usize, + rdi: usize, + rip: usize, } -fn io_handler(trapframe: &mut dyn TdxTrapFrame, ve_info: &tdcall::TdgVeInfo) -> bool { +impl From for TdxTrapFrame { + fn from(tf: TrapFrame) -> Self { + Self { + rax: tf.rax, + rbx: tf.rbx, + rcx: tf.rcx, + rdx: tf.rdx, + rsi: tf.rsi, + rdi: tf.rdi, + rip: tf.rip, + } + } +} + +impl From for TdxTrapFrame { + fn from(gr: GeneralRegs) -> Self { + Self { + rax: gr.rax, + rbx: gr.rbx, + rcx: gr.rcx, + rdx: gr.rdx, + rsi: gr.rsi, + rdi: gr.rdi, + rip: gr.rip, + } + } +} + +pub fn handle_virtual_exception(trapframe: &mut TdxTrapFrame, ve_info: &TdgVeInfo) { + match ve_info.exit_reason.into() { + TdxVirtualExceptionType::Hlt => { + serial_println!("Ready to halt"); + hlt(); + } + TdxVirtualExceptionType::Io => { + if !handle_io(trapframe, ve_info) { + serial_println!("Handle tdx ioexit errors, ready to halt"); + hlt(); + } + } + TdxVirtualExceptionType::MsrRead => { + let msr = unsafe { rdmsr(trapframe.rcx as u32).unwrap() }; + trapframe.rax = (msr as u32 & u32::MAX) as usize; + trapframe.rdx = ((msr >> 32) as u32 & u32::MAX) as usize; + } + TdxVirtualExceptionType::MsrWrite => { + let data = trapframe.rax as u64 | ((trapframe.rdx as u64) << 32); + unsafe { wrmsr(trapframe.rcx as u32, data).unwrap() }; + } + TdxVirtualExceptionType::CpuId => { + let cpuid_info = cpuid(trapframe.rax as u32, trapframe.rcx as u32).unwrap(); + let mask = 0xFFFF_FFFF_0000_0000_usize; + trapframe.rax = (trapframe.rax & mask) | cpuid_info.eax; + trapframe.rbx = (trapframe.rbx & mask) | cpuid_info.ebx; + trapframe.rcx = (trapframe.rcx & mask) | cpuid_info.ecx; + trapframe.rdx = (trapframe.rdx & mask) | cpuid_info.edx; + } + TdxVirtualExceptionType::Other => panic!("Unknown TDX vitrual exception type"), + _ => return, + } + trapframe.rip += ve_info.exit_instruction_length as usize; +} + +fn handle_io(trapframe: &mut TdxTrapFrame, ve_info: &tdcall::TdgVeInfo) -> bool { let size = match ve_info.exit_qualification & 0x3 { - 0 => 1, - 1 => 2, - 3 => 4, + 0 => IoSize::Size1, + 1 => IoSize::Size2, + 3 => IoSize::Size4, _ => panic!("Invalid size value"), }; let direction = if (ve_info.exit_qualification >> 3) & 0x1 == 0 { @@ -33,8 +90,6 @@ fn io_handler(trapframe: &mut dyn TdxTrapFrame, ve_info: &tdcall::TdgVeInfo) -> } else { tdvmcall::Direction::In }; - let string = (ve_info.exit_qualification >> 4) & 0x1 == 1; - let repeat = (ve_info.exit_qualification >> 5) & 0x1 == 1; let operand = if (ve_info.exit_qualification >> 6) & 0x1 == 0 { tdvmcall::Operand::Dx } else { @@ -44,46 +99,11 @@ fn io_handler(trapframe: &mut dyn TdxTrapFrame, ve_info: &tdcall::TdgVeInfo) -> match direction { tdvmcall::Direction::In => { - trapframe.set_rax(tdvmcall::io_read(size, port).unwrap() as usize); + trapframe.rax = tdvmcall::io_read(size, port).unwrap() as usize; } tdvmcall::Direction::Out => { - tdvmcall::io_write(size, port, trapframe.rax() as u32).unwrap(); + tdvmcall::io_write(size, port, trapframe.rax as u32).unwrap(); } }; true } - -pub fn virtual_exception_handler(trapframe: &mut impl TdxTrapFrame, ve_info: &TdgVeInfo) { - match ve_info.exit_reason.into() { - TdxVirtualExceptionType::Hlt => { - serial_println!("Ready to halt"); - hlt(); - } - TdxVirtualExceptionType::Io => { - if !io_handler(trapframe, ve_info) { - serial_println!("Handle tdx ioexit errors, ready to halt"); - hlt(); - } - } - TdxVirtualExceptionType::MsrRead => { - let msr = rdmsr(trapframe.rcx() as u32).unwrap(); - trapframe.set_rax((msr as u32 & u32::MAX) as usize); - trapframe.set_rdx(((msr >> 32) as u32 & u32::MAX) as usize); - } - TdxVirtualExceptionType::MsrWrite => { - let data = trapframe.rax() as u64 | ((trapframe.rdx() as u64) << 32); - wrmsr(trapframe.rcx() as u32, data).unwrap(); - } - TdxVirtualExceptionType::CpuId => { - let cpuid_info = cpuid(trapframe.rax() as u32, trapframe.rcx() as u32).unwrap(); - let mask = 0xFFFF_FFFF_0000_0000_usize; - trapframe.set_rax((trapframe.rax() & mask) | cpuid_info.eax); - trapframe.set_rbx((trapframe.rbx() & mask) | cpuid_info.ebx); - trapframe.set_rcx((trapframe.rcx() & mask) | cpuid_info.ecx); - trapframe.set_rdx((trapframe.rdx() & mask) | cpuid_info.edx); - } - TdxVirtualExceptionType::Other => panic!("Unknown TDX vitrual exception type"), - _ => return, - } - trapframe.set_rip(trapframe.rip() + ve_info.exit_instruction_length as usize); -} diff --git a/framework/jinux-frame/src/lib.rs b/framework/jinux-frame/src/lib.rs index 6cea09da6..b8e930d74 100644 --- a/framework/jinux-frame/src/lib.rs +++ b/framework/jinux-frame/src/lib.rs @@ -41,7 +41,7 @@ use alloc::vec::Vec; use arch::irq::{IrqCallbackHandle, IrqLine}; use core::{mem, panic::PanicInfo}; #[cfg(feature = "intel_tdx")] -use tdx_guest::tdx_early_init; +use tdx_guest::init_tdx; use trapframe::TrapFrame; static mut IRQ_CALLBACK_LIST: Vec = Vec::new(); @@ -50,7 +50,7 @@ pub fn init() { arch::before_all_init(); logger::init(); #[cfg(feature = "intel_tdx")] - let td_info = tdx_early_init().unwrap(); + let td_info = init_tdx().unwrap(); #[cfg(feature = "intel_tdx")] println!( "td gpaw: {}, td attributes: {:?}\nTDX guest is initialized", diff --git a/framework/jinux-frame/src/trap/handler.rs b/framework/jinux-frame/src/trap/handler.rs index 4c81eac6b..e26430152 100644 --- a/framework/jinux-frame/src/trap/handler.rs +++ b/framework/jinux-frame/src/trap/handler.rs @@ -1,60 +1,11 @@ use crate::{arch::irq::IRQ_LIST, cpu::CpuException}; #[cfg(feature = "intel_tdx")] -use crate::arch::tdx_guest::{virtual_exception_handler, TdxTrapFrame}; +use crate::arch::tdx_guest::handle_virtual_exception; #[cfg(feature = "intel_tdx")] use tdx_guest::tdcall; use trapframe::TrapFrame; -#[cfg(feature = "intel_tdx")] -struct VeTrapFrame<'a>(&'a mut TrapFrame); - -#[cfg(feature = "intel_tdx")] -impl TdxTrapFrame for VeTrapFrame<'_> { - fn rax(&self) -> usize { - self.0.rax - } - fn set_rax(&mut self, rax: usize) { - self.0.rax = rax; - } - fn rbx(&self) -> usize { - self.0.rbx - } - fn set_rbx(&mut self, rbx: usize) { - self.0.rbx = rbx; - } - fn rcx(&self) -> usize { - self.0.rcx - } - fn set_rcx(&mut self, rcx: usize) { - self.0.rcx = rcx; - } - fn rdx(&self) -> usize { - self.0.rdx - } - fn set_rdx(&mut self, rdx: usize) { - self.0.rdx = rdx; - } - fn rsi(&self) -> usize { - self.0.rsi - } - fn set_rsi(&mut self, rsi: usize) { - self.0.rsi = rsi; - } - fn rdi(&self) -> usize { - self.0.rdi - } - fn set_rdi(&mut self, rdi: usize) { - self.0.rdi = rdi; - } - fn rip(&self) -> usize { - self.0.rip - } - fn set_rip(&mut self, rip: usize) { - self.0.rip = rip; - } -} - /// Only from kernel #[no_mangle] extern "sysv64" fn trap_handler(f: &mut TrapFrame) { @@ -62,8 +13,7 @@ extern "sysv64" fn trap_handler(f: &mut TrapFrame) { #[cfg(feature = "intel_tdx")] if f.trap_num as u16 == 20 { let ve_info = tdcall::get_veinfo().expect("#VE handler: fail to get VE info\n"); - let mut ve_f = VeTrapFrame(f); - virtual_exception_handler(&mut ve_f, &ve_info); + handle_virtual_exception(&mut (*f).into(), &ve_info); } #[cfg(not(feature = "intel_tdx"))] panic!("cannot handle kernel cpu fault now, information:{:#x?}", f); diff --git a/framework/libs/tdx-guest/src/lib.rs b/framework/libs/tdx-guest/src/lib.rs index 8ed1f0a1f..68cd6e29b 100644 --- a/framework/libs/tdx-guest/src/lib.rs +++ b/framework/libs/tdx-guest/src/lib.rs @@ -12,10 +12,9 @@ pub use self::tdcall::{get_veinfo, TdxVirtualExceptionType}; pub use self::tdvmcall::print; use raw_cpuid::{native_cpuid::cpuid_count, CpuIdResult}; -use tdcall::{InitError, TdgVeInfo, TdgVpInfo}; -use tdvmcall::*; +use tdcall::{InitError, TdgVpInfo}; -pub fn tdx_early_init() -> Result { +pub fn init_tdx() -> Result { check_tdx_guest()?; Ok(tdcall::get_tdinfo()?) } diff --git a/framework/libs/tdx-guest/src/tdcall.rs b/framework/libs/tdx-guest/src/tdcall.rs index 609afedd6..409e34f35 100644 --- a/framework/libs/tdx-guest/src/tdcall.rs +++ b/framework/libs/tdx-guest/src/tdcall.rs @@ -364,20 +364,16 @@ pub fn get_tdinfo() -> Result { rax: TdcallNum::VpInfo as u64, ..Default::default() }; - match td_call(&mut args) { - Ok(()) => { - let td_info = TdgVpInfo { - gpaw: Gpaw::from(args.rcx), - attributes: GuestTdAttributes::from_bits_truncate(args.rdx), - num_vcpus: args.r8 as u32, - max_vcpus: (args.r8 >> 32) as u32, - vcpu_index: args.r9 as u32, - sys_rd: args.r10 as u32, - }; - Ok(td_info) - } - Err(res) => Err(res), - } + td_call(&mut args)?; + let td_info = TdgVpInfo { + gpaw: Gpaw::from(args.rcx), + attributes: GuestTdAttributes::from_bits_truncate(args.rdx), + num_vcpus: args.r8 as u32, + max_vcpus: (args.r8 >> 32) as u32, + vcpu_index: args.r9 as u32, + sys_rd: args.r10 as u32, + }; + Ok(td_info) } /// Get Virtualization Exception Information for the recent #VE exception. @@ -386,20 +382,16 @@ pub fn get_veinfo() -> Result { rax: TdcallNum::VpVeinfoGet as u64, ..Default::default() }; - match td_call(&mut args) { - Ok(()) => { - let ve_info = TdgVeInfo { - exit_reason: args.rcx as u32, - exit_qualification: args.rdx, - guest_linear_address: args.r8, - guest_physical_address: args.r9, - exit_instruction_length: args.r10 as u32, - exit_instruction_info: (args.r10 >> 32) as u32, - }; - Ok(ve_info) - } - Err(res) => Err(res), - } + td_call(&mut args)?; + let ve_info = TdgVeInfo { + exit_reason: args.rcx as u32, + exit_qualification: args.rdx, + guest_linear_address: args.r8, + guest_physical_address: args.r9, + exit_instruction_length: args.r10 as u32, + exit_instruction_info: (args.r10 >> 32) as u32, + }; + Ok(ve_info) } /// Extend a TDCS.RTMR measurement register. @@ -408,72 +400,58 @@ pub fn extend_rtmr() -> Result<(), TdCallError> { rax: TdcallNum::MrRtmrExtend as u64, ..Default::default() }; - match td_call(&mut args) { - Ok(()) => Ok(()), - Err(res) => Err(res), - } + td_call(&mut args) } /// TDG.MR.REPORT creates a TDREPORT_STRUCT structure that contains the measurements/configuration /// information of the guest TD that called the function, measurements/configuration information /// of the Intel TDX module and a REPORTMACSTRUCT. -pub fn get_report(report_addr: u64, data_addr: u64) -> Result<(), TdCallError> { +pub fn get_report(report_gpa: &[u8], data_gpa: &[u8]) -> Result<(), TdCallError> { let mut args = TdcallArgs { rax: TdcallNum::MrReport as u64, - rcx: report_addr, - rdx: data_addr, + rcx: report_gpa.as_ptr() as u64, + rdx: data_gpa.as_ptr() as u64, ..Default::default() }; - match td_call(&mut args) { - Ok(()) => Ok(()), - Err(res) => Err(res), - } + td_call(&mut args) } /// Verify a cryptographic REPORTMACSTRUCT that describes the contents of a TD, /// to determine that it was created on the current TEE on the current platform. -pub fn verify_report(report_mac_addr: u64) -> Result<(), TdCallError> { +pub fn verify_report(report_mac_gpa: &[u8]) -> Result<(), TdCallError> { let mut args = TdcallArgs { rax: TdcallNum::MrVerifyreport as u64, - rcx: report_mac_addr, + rcx: report_mac_gpa.as_ptr() as u64, ..Default::default() }; - match td_call(&mut args) { - Ok(()) => Ok(()), - Err(res) => Err(res), - } + td_call(&mut args) } /// Accept a pending private page and initialize it to all-0 using the TD ephemeral private key. -pub fn accept_page(sept_level: u64, addr: u64) -> Result<(), TdCallError> { +/// # Safety +/// The 'gpa' parameter must be a valid address. +pub unsafe fn accept_page(sept_level: u64, gpa: &[u8]) -> Result<(), TdCallError> { let mut args = TdcallArgs { rax: TdcallNum::MemPageAccept as u64, - rcx: sept_level | (addr << 12), + rcx: sept_level | ((gpa.as_ptr() as u64) << 12), ..Default::default() }; - match td_call(&mut args) { - Ok(()) => Ok(()), - Err(res) => Err(res), - } + td_call(&mut args) } /// Read the GPA mapping and attributes of a TD private page. -pub fn read_page_attr(phy_addr: u64) -> Result { +pub fn read_page_attr(gpa: &[u8]) -> Result { let mut args = TdcallArgs { rax: TdcallNum::MemPageAttrRd as u64, - rcx: phy_addr, + rcx: gpa.as_ptr() as u64, ..Default::default() }; - match td_call(&mut args) { - Ok(()) => { - let page_attr = PageAttr { - gpa_mapping: args.rcx, - gpa_attr: GpaAttrAll::from(args.rdx), - }; - Ok(page_attr) - } - Err(res) => Err(res), - } + td_call(&mut args)?; + let page_attr = PageAttr { + gpa_mapping: args.rcx, + gpa_attr: GpaAttrAll::from(args.rdx), + }; + Ok(page_attr) } /// Write the attributes of a private page. Create or remove L2 page aliases as required. @@ -485,16 +463,12 @@ pub fn write_page_attr(page_attr: PageAttr, attr_flags: u64) -> Result { - let page_attr = PageAttr { - gpa_mapping: args.rcx, - gpa_attr: GpaAttrAll::from(args.rdx), - }; - Ok(page_attr) - } - Err(res) => Err(res), - } + td_call(&mut args)?; + let page_attr = PageAttr { + gpa_mapping: args.rcx, + gpa_attr: GpaAttrAll::from(args.rdx), + }; + Ok(page_attr) } /// Read a TD-scope metadata field (control structure field) of a TD. @@ -504,10 +478,8 @@ pub fn read_td_metadata(field_identifier: u64) -> Result { rdx: field_identifier, ..Default::default() }; - match td_call(&mut args) { - Ok(()) => Ok(args.r8), - Err(res) => Err(res), - } + td_call(&mut args)?; + Ok(args.r8) } /// Write a TD-scope metadata field (control structure field) of a TD. @@ -545,17 +517,13 @@ pub fn set_cpuidve(cpuidve_flag: u64) -> Result<(), TdCallError> { rcx: cpuidve_flag, ..Default::default() }; - match td_call(&mut args) { - Ok(()) => Ok(()), - Err(res) => Err(res), - } + td_call(&mut args) } fn td_call(args: &mut TdcallArgs) -> Result<(), TdCallError> { - let td_call_result = unsafe { asm_td_call(args) }; - if td_call_result == 0 { - Ok(()) - } else { - Err(td_call_result.into()) + let result = unsafe { asm_td_call(args) }; + match result { + 0 => Ok(()), + _ => Err(result.into()), } } diff --git a/framework/libs/tdx-guest/src/tdvmcall.rs b/framework/libs/tdx-guest/src/tdvmcall.rs index d85f810c9..7ccc418a9 100644 --- a/framework/libs/tdx-guest/src/tdvmcall.rs +++ b/framework/libs/tdx-guest/src/tdvmcall.rs @@ -5,7 +5,7 @@ //! resumes the TD via a SEAMCALL [TDH.VP.ENTER] invocation. extern crate alloc; -use crate::{asm::asm_td_vmcall, tdcall::TdgVeInfo}; +use crate::asm::asm_td_vmcall; use alloc::fmt; use bitflags::bitflags; use core::fmt::Write; @@ -16,7 +16,7 @@ use x86_64::{ /// TDVMCALL Instruction Leaf Numbers Definition. #[repr(u64)] -pub enum TdvmcallNum { +pub enum TdVmcallNum { Cpuid = 0x0000a, Hlt = 0x0000c, Io = 0x0001e, @@ -88,132 +88,121 @@ pub enum Operand { Immediate, } +pub enum IoSize { + Size1 = 1, + Size2 = 2, + Size4 = 4, + Size8 = 8, +} + pub fn cpuid(eax: u32, ecx: u32) -> Result { let mut args = TdVmcallArgs { - r11: TdvmcallNum::Cpuid as u64, + r11: TdVmcallNum::Cpuid as u64, r12: eax as u64, r13: ecx as u64, ..Default::default() }; - match td_vmcall(&mut args) { - Ok(()) => Ok(CpuIdInfo { - eax: args.r12 as usize, - ebx: args.r13 as usize, - ecx: args.r14 as usize, - edx: args.r15 as usize, - }), - Err(res) => Err(res), - } + td_vmcall(&mut args)?; + Ok(CpuIdInfo { + eax: args.r12 as usize, + ebx: args.r13 as usize, + ecx: args.r14 as usize, + edx: args.r15 as usize, + }) } pub fn hlt() { let interrupt_blocked = !rflags::read().contains(RFlags::INTERRUPT_FLAG); let mut args = TdVmcallArgs { - r11: TdvmcallNum::Hlt as u64, + r11: TdVmcallNum::Hlt as u64, r12: interrupt_blocked as u64, ..Default::default() }; let _ = td_vmcall(&mut args); } -pub fn rdmsr(index: u32) -> Result { +/// # Safety +/// Make sure the index is valid. +pub unsafe fn rdmsr(index: u32) -> Result { let mut args = TdVmcallArgs { - r11: TdvmcallNum::Rdmsr as u64, + r11: TdVmcallNum::Rdmsr as u64, r12: index as u64, ..Default::default() }; - match td_vmcall(&mut args) { - Ok(()) => Ok(args.r11), - Err(res) => Err(res), - } + td_vmcall(&mut args)?; + Ok(args.r11) } -pub fn wrmsr(index: u32, value: u64) -> Result<(), TdVmcallError> { +/// # Safety +/// Make sure the index and the corresponding value are valid. +pub unsafe fn wrmsr(index: u32, value: u64) -> Result<(), TdVmcallError> { let mut args = TdVmcallArgs { - r11: TdvmcallNum::Wrmsr as u64, + r11: TdVmcallNum::Wrmsr as u64, r12: index as u64, r13: value, ..Default::default() }; - match td_vmcall(&mut args) { - Ok(()) => Ok(()), - Err(res) => Err(res), - } + td_vmcall(&mut args) } -/// Used to help perform WBINVD operation. -pub fn wbinvd(wbinvd: u64) -> Result<(), TdVmcallError> { +/// Used to help perform WBINVD or WBNOINVD operation. +/// - cache_operation: 0: WBINVD, 1: WBNOINVD +pub fn perform_cache_operation(cache_operation: u64) -> Result<(), TdVmcallError> { let mut args = TdVmcallArgs { - r11: TdvmcallNum::Wbinvd as u64, - r12: wbinvd, + r11: TdVmcallNum::Wbinvd as u64, + r12: cache_operation, ..Default::default() }; - match td_vmcall(&mut args) { - Ok(()) => Ok(()), - Err(res) => Err(res), - } + td_vmcall(&mut args) } -pub fn read_mmio(size: u64, mmio_addr: u64) -> Result { - match size { - 1 | 2 | 4 | 8 => {} - _ => return Err(TdVmcallError::TdxInvalidOperand), - } +/// # Safety +/// Make sure the mmio address is valid. +pub unsafe fn read_mmio(size: IoSize, mmio_gpa: &[u8]) -> Result { let mut args = TdVmcallArgs { - r11: TdvmcallNum::RequestMmio as u64, - r12: size, + r11: TdVmcallNum::RequestMmio as u64, + r12: size as u64, r13: 0, - r14: mmio_addr, + r14: mmio_gpa.as_ptr() as u64, ..Default::default() }; - match td_vmcall(&mut args) { - Ok(()) => Ok(args.r11), - Err(res) => Err(res), - } + td_vmcall(&mut args)?; + Ok(args.r11) } -pub fn write_mmio(size: u64, mmio_addr: u64, data: u64) -> Result<(), TdVmcallError> { - match size { - 1 | 2 | 4 | 8 => {} - _ => { - return Err(TdVmcallError::TdxInvalidOperand); - } - } +/// # Safety +/// Make sure the mmio address is valid. +pub unsafe fn write_mmio(size: IoSize, mmio_gpa: &[u8], data: u64) -> Result<(), TdVmcallError> { let mut args = TdVmcallArgs { - r11: TdvmcallNum::RequestMmio as u64, - r12: size, + r11: TdVmcallNum::RequestMmio as u64, + r12: size as u64, r13: 1, - r14: mmio_addr, + r14: mmio_gpa.as_ptr() as u64, r15: data, ..Default::default() }; - match td_vmcall(&mut args) { - Ok(()) => Ok(()), - Err(res) => Err(res), - } + td_vmcall(&mut args) } macro_rules! io_read { ($port:expr, $ty:ty) => {{ let mut args = TdVmcallArgs { - r11: TdvmcallNum::Io as u64, + r11: TdVmcallNum::Io as u64, r12: core::mem::size_of::<$ty>() as u64, r13: IO_READ, r14: $port as u64, ..Default::default() }; - match td_vmcall(&mut args) { - Ok(()) => Ok(args.r11 as u32), - Err(res) => Err(res), - } + td_vmcall(&mut args)?; + Ok(args.r11 as u32) }}; } -pub fn io_read(size: usize, port: u16) -> Result { +pub fn io_read(size: IoSize, port: u16) -> Result { match size { - 1 => io_read!(port, u8), - 2 => io_read!(port, u16), - 4 => io_read!(port, u32), + IoSize::Size1 => io_read!(port, u8), + IoSize::Size2 => io_read!(port, u16), + IoSize::Size4 => io_read!(port, u32), _ => unreachable!(), } } @@ -221,39 +210,36 @@ pub fn io_read(size: usize, port: u16) -> Result { macro_rules! io_write { ($port:expr, $byte:expr, $size:expr) => {{ let mut args = TdVmcallArgs { - r11: TdvmcallNum::Io as u64, + r11: TdVmcallNum::Io as u64, r12: core::mem::size_of_val(&$byte) as u64, r13: IO_WRITE, r14: $port as u64, r15: $byte as u64, ..Default::default() }; - match td_vmcall(&mut args) { - Ok(()) => Ok(()), - Err(res) => Err(res), - } + td_vmcall(&mut args) }}; } -pub fn io_write(size: usize, port: u16, byte: u32) -> Result<(), TdVmcallError> { +pub fn io_write(size: IoSize, port: u16, byte: u32) -> Result<(), TdVmcallError> { match size { - 1 => io_write!(port, byte, u8), - 2 => io_write!(port, byte, u16), - 4 => io_write!(port, byte, u32), + IoSize::Size1 => io_write!(port, byte, u8), + IoSize::Size2 => io_write!(port, byte, u16), + IoSize::Size4 => io_write!(port, byte, u32), _ => unreachable!(), } } fn td_vmcall(args: &mut TdVmcallArgs) -> Result<(), TdVmcallError> { - let td_vmcall_result = unsafe { asm_td_vmcall(args) }; - if td_vmcall_result == 0 { - Ok(()) - } else { - Err(td_vmcall_result.into()) + let result = unsafe { asm_td_vmcall(args) }; + match result { + 0 => Ok(()), + _ => Err(result.into()), } } bitflags! { + /// LineSts: Line Status struct LineSts: u8 { const INPUT_FULL = 1; const OUTPUT_EMPTY = 1 << 5;