From b1ea422efaf6b0becd7d6cd99d270ae01fcd12de Mon Sep 17 00:00:00 2001 From: Ruihan Li Date: Tue, 6 Aug 2024 00:06:02 +1000 Subject: [PATCH] Fix accesses to VirtIO queue DMA --- kernel/comps/virtio/src/queue.rs | 60 +++++++++++++++----------- kernel/libs/aster-util/src/safe_ptr.rs | 24 ++++++++++- 2 files changed, 58 insertions(+), 26 deletions(-) diff --git a/kernel/comps/virtio/src/queue.rs b/kernel/comps/virtio/src/queue.rs index 4a45f4a2b..869e9c0c3 100644 --- a/kernel/comps/virtio/src/queue.rs +++ b/kernel/comps/virtio/src/queue.rs @@ -129,17 +129,21 @@ impl VirtQueue { let mut desc = descs.get(i as usize).unwrap().clone(); let next_i = i + 1; if next_i != size { - field_ptr!(&desc, Descriptor, next).write(&next_i).unwrap(); + field_ptr!(&desc, Descriptor, next) + .write_once(&next_i) + .unwrap(); desc.add(1); descs.push(desc); } else { - field_ptr!(&desc, Descriptor, next).write(&(0u16)).unwrap(); + field_ptr!(&desc, Descriptor, next) + .write_once(&(0u16)) + .unwrap(); } } let notify = transport.get_notify_ptr(idx).unwrap(); field_ptr!(&avail_ring_ptr, AvailRing, flags) - .write(&(0u16)) + .write_once(&(0u16)) .unwrap(); Ok(VirtQueue { descs, @@ -177,10 +181,10 @@ impl VirtQueue { let desc = &self.descs[self.free_head as usize]; set_dma_buf(&desc.borrow_vm().restrict::(), *input); field_ptr!(desc, Descriptor, flags) - .write(&DescFlags::NEXT) + .write_once(&DescFlags::NEXT) .unwrap(); last = self.free_head; - self.free_head = field_ptr!(desc, Descriptor, next).read().unwrap(); + self.free_head = field_ptr!(desc, Descriptor, next).read_once().unwrap(); } for output in outputs.iter() { let desc = &mut self.descs[self.free_head as usize]; @@ -189,17 +193,19 @@ impl VirtQueue { *output, ); field_ptr!(desc, Descriptor, flags) - .write(&(DescFlags::NEXT | DescFlags::WRITE)) + .write_once(&(DescFlags::NEXT | DescFlags::WRITE)) .unwrap(); last = self.free_head; - self.free_head = field_ptr!(desc, Descriptor, next).read().unwrap(); + self.free_head = field_ptr!(desc, Descriptor, next).read_once().unwrap(); } // set last_elem.next = NULL { let desc = &mut self.descs[last as usize]; - let mut flags: DescFlags = field_ptr!(desc, Descriptor, flags).read().unwrap(); + let mut flags: DescFlags = field_ptr!(desc, Descriptor, flags).read_once().unwrap(); flags.remove(DescFlags::NEXT); - field_ptr!(desc, Descriptor, flags).write(&flags).unwrap(); + field_ptr!(desc, Descriptor, flags) + .write_once(&flags) + .unwrap(); } self.num_used += (inputs.len() + outputs.len()) as u16; @@ -210,7 +216,7 @@ impl VirtQueue { field_ptr!(&self.avail, AvailRing, ring); let mut ring_slot_ptr = ring_ptr.cast::(); ring_slot_ptr.add(avail_slot as usize); - ring_slot_ptr.write(&head).unwrap(); + ring_slot_ptr.write_once(&head).unwrap(); } // write barrier fence(Ordering::SeqCst); @@ -218,7 +224,7 @@ impl VirtQueue { // increase head of avail ring self.avail_idx = self.avail_idx.wrapping_add(1); field_ptr!(&self.avail, AvailRing, idx) - .write(&self.avail_idx) + .write_once(&self.avail_idx) .unwrap(); fence(Ordering::SeqCst); @@ -230,7 +236,7 @@ impl VirtQueue { // read barrier fence(Ordering::SeqCst); - self.last_used_idx != field_ptr!(&self.used, UsedRing, idx).read().unwrap() + self.last_used_idx != field_ptr!(&self.used, UsedRing, idx).read_once().unwrap() } /// The number of free descriptors. @@ -247,19 +253,23 @@ impl VirtQueue { loop { let desc = &mut self.descs[head as usize]; // Sets the buffer address and length to 0 - field_ptr!(desc, Descriptor, addr).write(&(0u64)).unwrap(); - field_ptr!(desc, Descriptor, len).write(&(0u32)).unwrap(); + field_ptr!(desc, Descriptor, addr) + .write_once(&(0u64)) + .unwrap(); + field_ptr!(desc, Descriptor, len) + .write_once(&(0u32)) + .unwrap(); self.num_used -= 1; - let flags: DescFlags = field_ptr!(desc, Descriptor, flags).read().unwrap(); + let flags: DescFlags = field_ptr!(desc, Descriptor, flags).read_once().unwrap(); if flags.contains(DescFlags::NEXT) { field_ptr!(desc, Descriptor, flags) - .write(&DescFlags::empty()) + .write_once(&DescFlags::empty()) .unwrap(); - head = field_ptr!(desc, Descriptor, next).read().unwrap(); + head = field_ptr!(desc, Descriptor, next).read_once().unwrap(); } else { field_ptr!(desc, Descriptor, next) - .write(&origin_free_head) + .write_once(&origin_free_head) .unwrap(); break; } @@ -280,8 +290,8 @@ impl VirtQueue { ptr.byte_add(offset_of!(UsedRing, ring) as usize + last_used_slot as usize * 8); ptr.cast::() }; - let index = field_ptr!(&element_ptr, UsedElem, id).read().unwrap(); - let len = field_ptr!(&element_ptr, UsedElem, len).read().unwrap(); + let index = field_ptr!(&element_ptr, UsedElem, id).read_once().unwrap(); + let len = field_ptr!(&element_ptr, UsedElem, len).read_once().unwrap(); self.recycle_descriptors(index as u16); self.last_used_idx = self.last_used_idx.wrapping_add(1); @@ -304,8 +314,8 @@ impl VirtQueue { ptr.byte_add(offset_of!(UsedRing, ring) as usize + last_used_slot as usize * 8); ptr.cast::() }; - let index = field_ptr!(&element_ptr, UsedElem, id).read().unwrap(); - let len = field_ptr!(&element_ptr, UsedElem, len).read().unwrap(); + let index = field_ptr!(&element_ptr, UsedElem, id).read_once().unwrap(); + let len = field_ptr!(&element_ptr, UsedElem, len).read_once().unwrap(); if index as u16 != token { return Err(QueueError::WrongToken); @@ -326,7 +336,7 @@ impl VirtQueue { pub fn should_notify(&self) -> bool { // read barrier fence(Ordering::SeqCst); - let flags = field_ptr!(&self.used, UsedRing, flags).read().unwrap(); + let flags = field_ptr!(&self.used, UsedRing, flags).read_once().unwrap(); flags & 0x0001u16 == 0u16 } @@ -353,10 +363,10 @@ fn set_dma_buf(desc_ptr: &DescriptorPtr, buf: &T) { debug_assert_ne!(buf.len(), 0); let daddr = buf.daddr(); field_ptr!(desc_ptr, Descriptor, addr) - .write(&(daddr as u64)) + .write_once(&(daddr as u64)) .unwrap(); field_ptr!(desc_ptr, Descriptor, len) - .write(&(buf.len() as u32)) + .write_once(&(buf.len() as u32)) .unwrap(); } diff --git a/kernel/libs/aster-util/src/safe_ptr.rs b/kernel/libs/aster-util/src/safe_ptr.rs index 0444a9f5d..6167b01d2 100644 --- a/kernel/libs/aster-util/src/safe_ptr.rs +++ b/kernel/libs/aster-util/src/safe_ptr.rs @@ -7,7 +7,7 @@ use aster_rights_proc::require; use inherit_methods_macro::inherit_methods; pub use ostd::Pod; use ostd::{ - mm::{Daddr, DmaStream, HasDaddr, HasPaddr, Paddr, VmIo}, + mm::{Daddr, DmaStream, HasDaddr, HasPaddr, Paddr, PodOnce, VmIo, VmIoOnce}, Result, }; pub use typeflags_util::SetContain; @@ -324,6 +324,28 @@ impl SafePtr> { } } +impl SafePtr> { + /// Reads the value from the pointer using one non-tearing instruction. + /// + /// # Access rights + /// + /// This method requires the `Read` right. + #[require(R > Read)] + pub fn read_once(&self) -> Result { + self.vm_obj.read_once(self.offset) + } + + /// Overwrites the value at the pointer using one non-tearing instruction. + /// + /// # Access rights + /// + /// This method requires the `Write` right. + #[require(R > Write)] + pub fn write_once(&self, val: &T) -> Result<()> { + self.vm_obj.write_once(self.offset, val) + } +} + impl HasDaddr for SafePtr { fn daddr(&self) -> Daddr { self.offset + self.vm_obj.daddr()