diff --git a/kernel/comps/virtio/src/device/block/device.rs b/kernel/comps/virtio/src/device/block/device.rs index 51c06143b..4a7ceecb1 100644 --- a/kernel/comps/virtio/src/device/block/device.rs +++ b/kernel/comps/virtio/src/device/block/device.rs @@ -81,10 +81,9 @@ impl aster_block::BlockDevice for BlockDevice { } fn metadata(&self) -> BlockDeviceMeta { - let device_config = self.device.config.read().unwrap(); BlockDeviceMeta { max_nr_segments_per_bio: self.queue.max_nr_segments_per_bio(), - nr_sectors: device_config.capacity_sectors(), + nr_sectors: VirtioBlockConfig::read_capacity_sectors(&self.device.config).unwrap(), } } } @@ -107,7 +106,7 @@ impl DeviceInner { pub fn init(mut transport: Box) -> Result, VirtioDeviceError> { let config = VirtioBlockConfig::new(transport.as_mut()); assert_eq!( - config.read().unwrap().block_size(), + VirtioBlockConfig::read_block_size(&config).unwrap(), VirtioBlockConfig::sector_size(), "currently not support customized device logical block size" ); diff --git a/kernel/comps/virtio/src/device/block/mod.rs b/kernel/comps/virtio/src/device/block/mod.rs index 9b9b8807f..d5de196fb 100644 --- a/kernel/comps/virtio/src/device/block/mod.rs +++ b/kernel/comps/virtio/src/device/block/mod.rs @@ -3,7 +3,7 @@ pub mod device; use aster_block::SECTOR_SIZE; -use aster_util::safe_ptr::SafePtr; +use aster_util::{field_ptr, safe_ptr::SafePtr}; use bitflags::bitflags; use int_to_c_enum::TryFromInt; use ostd::{io_mem::IoMem, Pod}; @@ -119,15 +119,17 @@ impl VirtioBlockConfig { SECTOR_SIZE } - pub(self) fn block_size(&self) -> usize { - self.blk_size as usize + pub(self) fn read_block_size(this: &SafePtr) -> ostd::prelude::Result { + field_ptr!(this, Self, blk_size) + .read_once() + .map(|val| val as usize) } - pub(self) fn capacity_sectors(&self) -> usize { - self.capacity as usize - } - - pub(self) fn capacity_bytes(&self) -> usize { - self.capacity_sectors() * Self::sector_size() + pub(self) fn read_capacity_sectors( + this: &SafePtr, + ) -> ostd::prelude::Result { + field_ptr!(this, Self, capacity) + .read_once() + .map(|val| val as usize) } } diff --git a/kernel/comps/virtio/src/device/input/device.rs b/kernel/comps/virtio/src/device/input/device.rs index 4f6121e97..a91cab04c 100644 --- a/kernel/comps/virtio/src/device/input/device.rs +++ b/kernel/comps/virtio/src/device/input/device.rs @@ -171,15 +171,17 @@ impl InputDevice { /// result to `out`, return the result size. pub fn query_config_select(&self, select: InputConfigSelect, subsel: u8, out: &mut [u8]) -> u8 { field_ptr!(&self.config, VirtioInputConfig, select) - .write(&(select as u8)) + .write_once(&(select as u8)) .unwrap(); field_ptr!(&self.config, VirtioInputConfig, subsel) - .write(&subsel) + .write_once(&subsel) .unwrap(); let size = field_ptr!(&self.config, VirtioInputConfig, size) - .read() + .read_once() .unwrap(); let data: [u8; 128] = field_ptr!(&self.config, VirtioInputConfig, data) + // FIXME: It is impossible to call `read_once` on `[u8; 128]`. What's the proper way to + // read this field out? .read() .unwrap(); out[..size as usize].copy_from_slice(&data[..size as usize]); diff --git a/kernel/comps/virtio/src/device/network/config.rs b/kernel/comps/virtio/src/device/network/config.rs index 278d89265..f29794963 100644 --- a/kernel/comps/virtio/src/device/network/config.rs +++ b/kernel/comps/virtio/src/device/network/config.rs @@ -1,7 +1,7 @@ // SPDX-License-Identifier: MPL-2.0 use aster_network::EthernetAddr; -use aster_util::safe_ptr::SafePtr; +use aster_util::{field_ptr, safe_ptr::SafePtr}; use bitflags::bitflags; use ostd::{io_mem::IoMem, Pod}; @@ -76,4 +76,25 @@ impl VirtioNetConfig { let memory = transport.device_config_memory(); SafePtr::new(memory, 0) } + + pub(super) fn read(this: &SafePtr) -> ostd::prelude::Result { + Ok(Self { + // FIXME: It is impossible to call `read_once` on `EthernetAddr`. What's the proper way + // to read this field out? + mac: field_ptr!(this, Self, mac).read()?, + status: field_ptr!(this, Self, status).read_once()?, + max_virtqueue_pairs: field_ptr!(this, Self, max_virtqueue_pairs).read_once()?, + mtu: field_ptr!(this, Self, mtu).read_once()?, + speed: field_ptr!(this, Self, speed).read_once()?, + duplex: field_ptr!(this, Self, duplex).read_once()?, + rss_max_key_size: field_ptr!(this, Self, rss_max_key_size).read_once()?, + rss_max_indirection_table_length: field_ptr!( + this, + Self, + rss_max_indirection_table_length + ) + .read_once()?, + supported_hash_types: field_ptr!(this, Self, supported_hash_types).read_once()?, + }) + } } diff --git a/kernel/comps/virtio/src/device/network/device.rs b/kernel/comps/virtio/src/device/network/device.rs index 22bc8330b..9d594a374 100644 --- a/kernel/comps/virtio/src/device/network/device.rs +++ b/kernel/comps/virtio/src/device/network/device.rs @@ -7,9 +7,9 @@ use aster_network::{ AnyNetworkDevice, EthernetAddr, RxBuffer, TxBuffer, VirtioNetError, RX_BUFFER_POOL, TX_BUFFER_POOL, }; -use aster_util::{field_ptr, slot_vec::SlotVec}; +use aster_util::slot_vec::SlotVec; use log::debug; -use ostd::{offset_of, sync::SpinLock, trap::TrapFrame}; +use ostd::{sync::SpinLock, trap::TrapFrame}; use smoltcp::phy::{DeviceCapabilities, Medium}; use super::{config::VirtioNetConfig, header::VirtioNetHdr}; @@ -44,13 +44,11 @@ impl NetworkDevice { )); debug!("virtio_net_config = {:?}", virtio_net_config); debug!("features = {:?}", features); - let mac_addr = field_ptr!(&virtio_net_config, VirtioNetConfig, mac) - .read() - .unwrap(); - let status = field_ptr!(&virtio_net_config, VirtioNetConfig, status) - .read() - .unwrap(); - debug!("mac addr = {:x?}, status = {:?}", mac_addr, status); + + let config = VirtioNetConfig::read(&virtio_net_config).unwrap(); + let mac_addr = config.mac; + debug!("mac addr = {:x?}, status = {:?}", mac_addr, config.status); + let mut recv_queue = VirtQueue::new(QUEUE_RECV, QUEUE_SIZE, transport.as_mut()) .expect("creating recv queue fails"); let send_queue = VirtQueue::new(QUEUE_SEND, QUEUE_SIZE, transport.as_mut()) @@ -71,7 +69,7 @@ impl NetworkDevice { recv_queue.notify(); } let mut device = Self { - config: virtio_net_config.read().unwrap(), + config, mac_addr, send_queue, recv_queue, diff --git a/kernel/comps/virtio/src/device/socket/device.rs b/kernel/comps/virtio/src/device/socket/device.rs index 4153d44f7..9f11412c3 100644 --- a/kernel/comps/virtio/src/device/socket/device.rs +++ b/kernel/comps/virtio/src/device/socket/device.rs @@ -53,10 +53,10 @@ impl SocketDevice { let virtio_vsock_config = VirtioVsockConfig::new(transport.as_mut()); debug!("virtio_vsock_config = {:?}", virtio_vsock_config); let guest_cid = field_ptr!(&virtio_vsock_config, VirtioVsockConfig, guest_cid_low) - .read() + .read_once() .unwrap() as u64 | (field_ptr!(&virtio_vsock_config, VirtioVsockConfig, guest_cid_high) - .read() + .read_once() .unwrap() as u64) << 32; @@ -83,7 +83,7 @@ impl SocketDevice { } let mut device = Self { - config: virtio_vsock_config.read().unwrap(), + config: virtio_vsock_config.read_once().unwrap(), guest_cid, send_queue, recv_queue, diff --git a/kernel/comps/virtio/src/queue.rs b/kernel/comps/virtio/src/queue.rs index d3e295a7a..38e5e0652 100644 --- a/kernel/comps/virtio/src/queue.rs +++ b/kernel/comps/virtio/src/queue.rs @@ -342,7 +342,7 @@ impl VirtQueue { /// notify that there are available rings pub fn notify(&mut self) { - self.notify.write(&self.queue_idx).unwrap(); + self.notify.write_once(&self.queue_idx).unwrap(); } } diff --git a/kernel/comps/virtio/src/transport/mmio/device.rs b/kernel/comps/virtio/src/transport/mmio/device.rs index c0847ad76..a9db5b07a 100644 --- a/kernel/comps/virtio/src/transport/mmio/device.rs +++ b/kernel/comps/virtio/src/transport/mmio/device.rs @@ -79,7 +79,7 @@ impl VirtioMmioTransport { }; if device.common_device.version() == VirtioMmioVersion::Legacy { field_ptr!(&device.layout, VirtioMmioLayout, legacy_guest_page_size) - .write(&(PAGE_SIZE as u32)) + .write_once(&(PAGE_SIZE as u32)) .unwrap(); } device @@ -100,11 +100,11 @@ impl VirtioTransport for VirtioMmioTransport { device_ptr: &SafePtr, ) -> Result<(), VirtioTransportError> { field_ptr!(&self.layout, VirtioMmioLayout, queue_sel) - .write(&(idx as u32)) + .write_once(&(idx as u32)) .unwrap(); let queue_num_max: u32 = field_ptr!(&self.layout, VirtioMmioLayout, queue_num_max) - .read() + .read_once() .unwrap(); if queue_size as u32 > queue_num_max { @@ -117,7 +117,7 @@ impl VirtioTransport for VirtioMmioTransport { let device_paddr = device_ptr.paddr(); field_ptr!(&self.layout, VirtioMmioLayout, queue_num) - .write(&(queue_size as u32)) + .write_once(&(queue_size as u32)) .unwrap(); match self.common_device.version() { @@ -131,39 +131,39 @@ impl VirtioTransport for VirtioMmioTransport { assert_eq!(descriptor_paddr % PAGE_SIZE, 0); let pfn = (descriptor_paddr / PAGE_SIZE) as u32; field_ptr!(&self.layout, VirtioMmioLayout, legacy_queue_align) - .write(&(PAGE_SIZE as u32)) + .write_once(&(PAGE_SIZE as u32)) .unwrap(); field_ptr!(&self.layout, VirtioMmioLayout, legacy_queue_pfn) - .write(&pfn) + .write_once(&pfn) .unwrap(); } VirtioMmioVersion::Modern => { field_ptr!(&self.layout, VirtioMmioLayout, queue_desc_low) - .write(&(descriptor_paddr as u32)) + .write_once(&(descriptor_paddr as u32)) .unwrap(); field_ptr!(&self.layout, VirtioMmioLayout, queue_desc_high) - .write(&((descriptor_paddr >> 32) as u32)) + .write_once(&((descriptor_paddr >> 32) as u32)) .unwrap(); field_ptr!(&self.layout, VirtioMmioLayout, queue_driver_low) - .write(&(driver_paddr as u32)) + .write_once(&(driver_paddr as u32)) .unwrap(); field_ptr!(&self.layout, VirtioMmioLayout, queue_driver_high) - .write(&((driver_paddr >> 32) as u32)) + .write_once(&((driver_paddr >> 32) as u32)) .unwrap(); field_ptr!(&self.layout, VirtioMmioLayout, queue_device_low) - .write(&(device_paddr as u32)) + .write_once(&(device_paddr as u32)) .unwrap(); field_ptr!(&self.layout, VirtioMmioLayout, queue_device_high) - .write(&((device_paddr >> 32) as u32)) + .write_once(&((device_paddr >> 32) as u32)) .unwrap(); // enable queue field_ptr!(&self.layout, VirtioMmioLayout, queue_sel) - .write(&(idx as u32)) + .write_once(&(idx as u32)) .unwrap(); field_ptr!(&self.layout, VirtioMmioLayout, queue_ready) - .write(&1u32) + .write_once(&1u32) .unwrap(); } }; @@ -182,10 +182,10 @@ impl VirtioTransport for VirtioMmioTransport { const MAX_QUEUES: u32 = 512; while num_queues < MAX_QUEUES { field_ptr!(&self.layout, VirtioMmioLayout, queue_sel) - .write(&num_queues) + .write_once(&num_queues) .unwrap(); if field_ptr!(&self.layout, VirtioMmioLayout, queue_num_max) - .read() + .read_once() .unwrap() == 0u32 { @@ -207,17 +207,17 @@ impl VirtioTransport for VirtioMmioTransport { fn device_features(&self) -> u64 { // select low field_ptr!(&self.layout, VirtioMmioLayout, device_features_select) - .write(&0u32) + .write_once(&0u32) .unwrap(); let device_feature_low = field_ptr!(&self.layout, VirtioMmioLayout, device_features) - .read() + .read_once() .unwrap(); // select high field_ptr!(&self.layout, VirtioMmioLayout, device_features_select) - .write(&1u32) + .write_once(&1u32) .unwrap(); let device_feature_high = field_ptr!(&self.layout, VirtioMmioLayout, device_features) - .read() + .read_once() .unwrap() as u64; device_feature_high << 32 | device_feature_low as u64 } @@ -226,16 +226,16 @@ impl VirtioTransport for VirtioMmioTransport { let low = features as u32; let high = (features >> 32) as u32; field_ptr!(&self.layout, VirtioMmioLayout, driver_features_select) - .write(&0u32) + .write_once(&0u32) .unwrap(); field_ptr!(&self.layout, VirtioMmioLayout, driver_features) - .write(&low) + .write_once(&low) .unwrap(); field_ptr!(&self.layout, VirtioMmioLayout, driver_features_select) - .write(&1u32) + .write_once(&1u32) .unwrap(); field_ptr!(&self.layout, VirtioMmioLayout, driver_features) - .write(&high) + .write_once(&high) .unwrap(); Ok(()) } @@ -243,7 +243,7 @@ impl VirtioTransport for VirtioMmioTransport { fn device_status(&self) -> DeviceStatus { DeviceStatus::from_bits( field_ptr!(&self.layout, VirtioMmioLayout, status) - .read() + .read_once() .unwrap() as u8, ) .unwrap() @@ -251,7 +251,7 @@ impl VirtioTransport for VirtioMmioTransport { fn set_device_status(&mut self, status: DeviceStatus) -> Result<(), VirtioTransportError> { field_ptr!(&self.layout, VirtioMmioLayout, status) - .write(&(status.bits() as u32)) + .write_once(&(status.bits() as u32)) .unwrap(); Ok(()) } @@ -262,10 +262,10 @@ impl VirtioTransport for VirtioMmioTransport { fn max_queue_size(&self, idx: u16) -> Result { field_ptr!(&self.layout, VirtioMmioLayout, queue_sel) - .write(&(idx as u32)) + .write_once(&(idx as u32)) .unwrap(); Ok(field_ptr!(&self.layout, VirtioMmioLayout, queue_num_max) - .read() + .read_once() .unwrap() as u16) } diff --git a/kernel/comps/virtio/src/transport/mmio/multiplex.rs b/kernel/comps/virtio/src/transport/mmio/multiplex.rs index 26ec2ed18..1d361d1fb 100644 --- a/kernel/comps/virtio/src/transport/mmio/multiplex.rs +++ b/kernel/comps/virtio/src/transport/mmio/multiplex.rs @@ -44,7 +44,7 @@ impl MultiplexIrq { return; }; let irq = multiplex_irq.read(); - let interrupt_status = irq.interrupt_status.read().unwrap(); + let interrupt_status = irq.interrupt_status.read_once().unwrap(); let callbacks = if interrupt_status & 0x01 == 1 { // Used buffer notification &irq.queue_callbacks @@ -55,7 +55,7 @@ impl MultiplexIrq { for callback in callbacks.iter() { callback.call((trap_frame,)); } - irq.interrupt_ack.write(&interrupt_status).unwrap(); + irq.interrupt_ack.write_once(&interrupt_status).unwrap(); }; lock.irq.on_active(callback); drop(lock); diff --git a/kernel/comps/virtio/src/transport/pci/device.rs b/kernel/comps/virtio/src/transport/pci/device.rs index 5567b7c34..6d90e09ad 100644 --- a/kernel/comps/virtio/src/transport/pci/device.rs +++ b/kernel/comps/virtio/src/transport/pci/device.rs @@ -80,30 +80,30 @@ impl VirtioTransport for VirtioPciTransport { return Err(VirtioTransportError::InvalidArgs); } field_ptr!(&self.common_cfg, VirtioPciCommonCfg, queue_select) - .write(&idx) + .write_once(&idx) .unwrap(); debug_assert_eq!( field_ptr!(&self.common_cfg, VirtioPciCommonCfg, queue_select) - .read() + .read_once() .unwrap(), idx ); field_ptr!(&self.common_cfg, VirtioPciCommonCfg, queue_size) - .write(&queue_size) + .write_once(&queue_size) .unwrap(); field_ptr!(&self.common_cfg, VirtioPciCommonCfg, queue_desc) - .write(&(descriptor_ptr.paddr() as u64)) + .write_once(&(descriptor_ptr.paddr() as u64)) .unwrap(); field_ptr!(&self.common_cfg, VirtioPciCommonCfg, queue_driver) - .write(&(avail_ring_ptr.paddr() as u64)) + .write_once(&(avail_ring_ptr.paddr() as u64)) .unwrap(); field_ptr!(&self.common_cfg, VirtioPciCommonCfg, queue_device) - .write(&(used_ring_ptr.paddr() as u64)) + .write_once(&(used_ring_ptr.paddr() as u64)) .unwrap(); // Enable queue field_ptr!(&self.common_cfg, VirtioPciCommonCfg, queue_enable) - .write(&1u16) + .write_once(&1u16) .unwrap(); Ok(()) } @@ -120,7 +120,7 @@ impl VirtioTransport for VirtioPciTransport { fn num_queues(&self) -> u16 { field_ptr!(&self.common_cfg, VirtioPciCommonCfg, num_queues) - .read() + .read_once() .unwrap() } @@ -142,17 +142,17 @@ impl VirtioTransport for VirtioPciTransport { fn device_features(&self) -> u64 { // select low field_ptr!(&self.common_cfg, VirtioPciCommonCfg, device_feature_select) - .write(&0u32) + .write_once(&0u32) .unwrap(); let device_feature_low = field_ptr!(&self.common_cfg, VirtioPciCommonCfg, device_features) - .read() + .read_once() .unwrap(); // select high field_ptr!(&self.common_cfg, VirtioPciCommonCfg, device_feature_select) - .write(&1u32) + .write_once(&1u32) .unwrap(); let device_feature_high = field_ptr!(&self.common_cfg, VirtioPciCommonCfg, device_features) - .read() + .read_once() .unwrap() as u64; device_feature_high << 32 | device_feature_low as u64 } @@ -161,47 +161,47 @@ impl VirtioTransport for VirtioPciTransport { let low = features as u32; let high = (features >> 32) as u32; field_ptr!(&self.common_cfg, VirtioPciCommonCfg, driver_feature_select) - .write(&0u32) + .write_once(&0u32) .unwrap(); field_ptr!(&self.common_cfg, VirtioPciCommonCfg, driver_features) - .write(&low) + .write_once(&low) .unwrap(); field_ptr!(&self.common_cfg, VirtioPciCommonCfg, driver_feature_select) - .write(&1u32) + .write_once(&1u32) .unwrap(); field_ptr!(&self.common_cfg, VirtioPciCommonCfg, driver_features) - .write(&high) + .write_once(&high) .unwrap(); Ok(()) } fn device_status(&self) -> DeviceStatus { let status = field_ptr!(&self.common_cfg, VirtioPciCommonCfg, device_status) - .read() + .read_once() .unwrap(); DeviceStatus::from_bits(status).unwrap() } fn set_device_status(&mut self, status: DeviceStatus) -> Result<(), VirtioTransportError> { field_ptr!(&self.common_cfg, VirtioPciCommonCfg, device_status) - .write(&(status.bits())) + .write_once(&(status.bits())) .unwrap(); Ok(()) } fn max_queue_size(&self, idx: u16) -> Result { field_ptr!(&self.common_cfg, VirtioPciCommonCfg, queue_select) - .write(&idx) + .write_once(&idx) .unwrap(); debug_assert_eq!( field_ptr!(&self.common_cfg, VirtioPciCommonCfg, queue_select) - .read() + .read_once() .unwrap(), idx ); Ok(field_ptr!(&self.common_cfg, VirtioPciCommonCfg, queue_size) - .read() + .read_once() .unwrap()) } @@ -223,16 +223,16 @@ impl VirtioTransport for VirtioPciTransport { }; irq.on_active(func); field_ptr!(&self.common_cfg, VirtioPciCommonCfg, queue_select) - .write(&index) + .write_once(&index) .unwrap(); debug_assert_eq!( field_ptr!(&self.common_cfg, VirtioPciCommonCfg, queue_select) - .read() + .read_once() .unwrap(), index ); field_ptr!(&self.common_cfg, VirtioPciCommonCfg, queue_msix_vector) - .write(&vector) + .write_once(&vector) .unwrap(); Ok(()) } diff --git a/ostd/src/bus/mmio/common_device.rs b/ostd/src/bus/mmio/common_device.rs index d83a2abb4..4e5504c15 100644 --- a/ostd/src/bus/mmio/common_device.rs +++ b/ostd/src/bus/mmio/common_device.rs @@ -8,7 +8,7 @@ use log::info; use super::VIRTIO_MMIO_MAGIC; use crate::{ io_mem::IoMem, - mm::{paddr_to_vaddr, Paddr, VmIo}, + mm::{paddr_to_vaddr, Paddr, VmIoOnce}, trap::IrqLine, }; @@ -55,12 +55,12 @@ impl MmioCommonDevice { /// Device ID pub fn device_id(&self) -> u32 { - self.io_mem.read_val::(8).unwrap() + self.io_mem.read_once::(8).unwrap() } /// Version of the MMIO device. pub fn version(&self) -> VirtioMmioVersion { - VirtioMmioVersion::try_from(self.io_mem.read_val::(4).unwrap()).unwrap() + VirtioMmioVersion::try_from(self.io_mem.read_once::(4).unwrap()).unwrap() } /// Interrupt line diff --git a/ostd/src/bus/pci/capability/msix.rs b/ostd/src/bus/pci/capability/msix.rs index 420aa9cac..00907c0b4 100644 --- a/ostd/src/bus/pci/capability/msix.rs +++ b/ostd/src/bus/pci/capability/msix.rs @@ -15,7 +15,7 @@ use crate::{ common_device::PciCommonDevice, device_info::PciDeviceLocation, }, - mm::VmIo, + mm::VmIoOnce, trap::IrqLine, }; @@ -121,15 +121,15 @@ impl CapabilityMsixData { // Set message address and disable this msix entry table_bar .io_mem() - .write_val((16 * i) as usize + table_offset, &message_address) + .write_once((16 * i) as usize + table_offset, &message_address) .unwrap(); table_bar .io_mem() - .write_val((16 * i + 4) as usize + table_offset, &message_upper_address) + .write_once((16 * i + 4) as usize + table_offset, &message_upper_address) .unwrap(); table_bar .io_mem() - .write_val((16 * i + 12) as usize + table_offset, &1_u32) + .write_once((16 * i + 12) as usize + table_offset, &1_u32) .unwrap(); } @@ -169,7 +169,7 @@ impl CapabilityMsixData { } self.table_bar .io_mem() - .write_val( + .write_once( (16 * index + 8) as usize + self.table_offset, &(handle.num() as u32), ) @@ -178,7 +178,7 @@ impl CapabilityMsixData { // Enable this msix vector self.table_bar .io_mem() - .write_val((16 * index + 12) as usize + self.table_offset, &0_u32) + .write_once((16 * index + 12) as usize + self.table_offset, &0_u32) .unwrap(); } diff --git a/ostd/src/io_mem.rs b/ostd/src/io_mem.rs index 54ecd8299..1e2f58c6e 100644 --- a/ostd/src/io_mem.rs +++ b/ostd/src/io_mem.rs @@ -2,14 +2,15 @@ //! I/O memory. -use core::{mem::size_of, ops::Range}; +use core::ops::Range; use crate::{ mm::{ - kspace::LINEAR_MAPPING_BASE_VADDR, paddr_to_vaddr, HasPaddr, Paddr, Vaddr, VmIo, VmReader, - VmWriter, + kspace::LINEAR_MAPPING_BASE_VADDR, paddr_to_vaddr, FallibleVmRead, FallibleVmWrite, + HasPaddr, Infallible, Paddr, PodOnce, Vaddr, VmIo, VmIoOnce, VmReader, VmWriter, }, - Error, Pod, Result, + prelude::*, + Error, }; /// I/O memory. @@ -19,45 +20,6 @@ pub struct IoMem { limit: usize, } -impl VmIo for IoMem { - fn read(&self, offset: usize, writer: &mut VmWriter) -> crate::Result<()> { - let read_len = writer.avail(); - self.check_range(offset, read_len)?; - unsafe { - core::ptr::copy( - (self.virtual_address + offset) as *const u8, - writer.cursor(), - read_len, - ); - } - Ok(()) - } - - fn write(&self, offset: usize, reader: &mut VmReader) -> crate::Result<()> { - let write_len = reader.remain(); - self.check_range(offset, write_len)?; - unsafe { - core::ptr::copy( - reader.cursor(), - (self.virtual_address + offset) as *mut u8, - write_len, - ); - } - Ok(()) - } - - fn read_val(&self, offset: usize) -> crate::Result { - self.check_range(offset, size_of::())?; - Ok(unsafe { core::ptr::read_volatile((self.virtual_address + offset) as *const T) }) - } - - fn write_val(&self, offset: usize, new_val: &T) -> crate::Result<()> { - self.check_range(offset, size_of::())?; - unsafe { core::ptr::write_volatile((self.virtual_address + offset) as *mut T, *new_val) }; - Ok(()) - } -} - impl HasPaddr for IoMem { fn paddr(&self) -> Paddr { self.virtual_address - LINEAR_MAPPING_BASE_VADDR @@ -69,7 +31,9 @@ impl IoMem { /// /// # Safety /// - /// User must ensure the given physical range is in the I/O memory region. + /// - The given physical address range must be in the I/O memory region. + /// - Reading from or writing to I/O memory regions may have side effects. Those side effects + /// must not cause soundness problems (e.g., they must not corrupt the kernel memory). pub(crate) unsafe fn new(range: Range) -> IoMem { IoMem { virtual_address: paddr_to_vaddr(range.start), @@ -111,18 +75,72 @@ impl IoMem { self.limit = range.len(); Ok(()) } +} - fn check_range(&self, offset: usize, len: usize) -> Result<()> { - let sum = offset.checked_add(len).ok_or(Error::InvalidArgs)?; - if sum > self.limit { - log::error!( - "attempt to access address out of bounds, limit:0x{:x}, access position:0x{:x}", - self.limit, - sum - ); - Err(Error::InvalidArgs) - } else { - Ok(()) - } +// For now, we reuse `VmReader` and `VmWriter` to access I/O memory. +// +// Note that I/O memory is not normal typed or untyped memory. Strictly speaking, it is not +// "memory", but rather I/O ports that communicate directly with the hardware. However, this code +// is in OSTD, so we can rely on the implementation details of `VmReader` and `VmWriter`, which we +// know are also suitable for accessing I/O memory. +impl IoMem { + fn reader(&self) -> VmReader<'_, Infallible> { + // SAFETY: The safety conditions of `IoMem::new` guarantee we can read from the I/O memory + // safely. + unsafe { VmReader::from_kernel_space(self.virtual_address as *mut u8, self.limit) } + } + + fn writer(&self) -> VmWriter<'_, Infallible> { + // SAFETY: The safety conditions of `IoMem::new` guarantee we can read from the I/O memory + // safely. + unsafe { VmWriter::from_kernel_space(self.virtual_address as *mut u8, self.limit) } + } +} + +impl VmIo for IoMem { + fn read(&self, offset: usize, writer: &mut VmWriter) -> Result<()> { + if self + .limit + .checked_sub(offset) + .is_none_or(|remain| remain < writer.avail()) + { + return Err(Error::InvalidArgs); + } + + self.reader() + .skip(offset) + .read_fallible(writer) + .map_err(|(e, _)| e)?; + debug_assert!(!writer.has_avail()); + + Ok(()) + } + + fn write(&self, offset: usize, reader: &mut VmReader) -> Result<()> { + if self + .limit + .checked_sub(offset) + .is_none_or(|avail| avail < reader.remain()) + { + return Err(Error::InvalidArgs); + } + + self.writer() + .skip(offset) + .write_fallible(reader) + .map_err(|(e, _)| e)?; + debug_assert!(!reader.has_remain()); + + Ok(()) + } +} + +impl VmIoOnce for IoMem { + fn read_once(&self, offset: usize) -> Result { + self.reader().skip(offset).read_once() + } + + fn write_once(&self, offset: usize, new_val: &T) -> Result<()> { + self.writer().skip(offset).write_once(new_val) } } diff --git a/ostd/src/lib.rs b/ostd/src/lib.rs index 745668ab7..d2a12c523 100644 --- a/ostd/src/lib.rs +++ b/ostd/src/lib.rs @@ -9,6 +9,7 @@ #![feature(coroutines)] #![feature(fn_traits)] #![feature(generic_const_exprs)] +#![feature(is_none_or)] #![feature(iter_from_coroutine)] #![feature(let_chains)] #![feature(min_specialization)]