Fix accesses to VirtIO queue DMA

This commit is contained in:
Ruihan Li
2024-08-06 00:06:02 +10:00
committed by Tate, Hongliang Tian
parent 3deff2e842
commit b1ea422efa
2 changed files with 58 additions and 26 deletions

View File

@ -129,17 +129,21 @@ impl VirtQueue {
let mut desc = descs.get(i as usize).unwrap().clone(); let mut desc = descs.get(i as usize).unwrap().clone();
let next_i = i + 1; let next_i = i + 1;
if next_i != size { if next_i != size {
field_ptr!(&desc, Descriptor, next).write(&next_i).unwrap(); field_ptr!(&desc, Descriptor, next)
.write_once(&next_i)
.unwrap();
desc.add(1); desc.add(1);
descs.push(desc); descs.push(desc);
} else { } else {
field_ptr!(&desc, Descriptor, next).write(&(0u16)).unwrap(); field_ptr!(&desc, Descriptor, next)
.write_once(&(0u16))
.unwrap();
} }
} }
let notify = transport.get_notify_ptr(idx).unwrap(); let notify = transport.get_notify_ptr(idx).unwrap();
field_ptr!(&avail_ring_ptr, AvailRing, flags) field_ptr!(&avail_ring_ptr, AvailRing, flags)
.write(&(0u16)) .write_once(&(0u16))
.unwrap(); .unwrap();
Ok(VirtQueue { Ok(VirtQueue {
descs, descs,
@ -177,10 +181,10 @@ impl VirtQueue {
let desc = &self.descs[self.free_head as usize]; let desc = &self.descs[self.free_head as usize];
set_dma_buf(&desc.borrow_vm().restrict::<TRights![Write, Dup]>(), *input); set_dma_buf(&desc.borrow_vm().restrict::<TRights![Write, Dup]>(), *input);
field_ptr!(desc, Descriptor, flags) field_ptr!(desc, Descriptor, flags)
.write(&DescFlags::NEXT) .write_once(&DescFlags::NEXT)
.unwrap(); .unwrap();
last = self.free_head; last = self.free_head;
self.free_head = field_ptr!(desc, Descriptor, next).read().unwrap(); self.free_head = field_ptr!(desc, Descriptor, next).read_once().unwrap();
} }
for output in outputs.iter() { for output in outputs.iter() {
let desc = &mut self.descs[self.free_head as usize]; let desc = &mut self.descs[self.free_head as usize];
@ -189,17 +193,19 @@ impl VirtQueue {
*output, *output,
); );
field_ptr!(desc, Descriptor, flags) field_ptr!(desc, Descriptor, flags)
.write(&(DescFlags::NEXT | DescFlags::WRITE)) .write_once(&(DescFlags::NEXT | DescFlags::WRITE))
.unwrap(); .unwrap();
last = self.free_head; last = self.free_head;
self.free_head = field_ptr!(desc, Descriptor, next).read().unwrap(); self.free_head = field_ptr!(desc, Descriptor, next).read_once().unwrap();
} }
// set last_elem.next = NULL // set last_elem.next = NULL
{ {
let desc = &mut self.descs[last as usize]; let desc = &mut self.descs[last as usize];
let mut flags: DescFlags = field_ptr!(desc, Descriptor, flags).read().unwrap(); let mut flags: DescFlags = field_ptr!(desc, Descriptor, flags).read_once().unwrap();
flags.remove(DescFlags::NEXT); flags.remove(DescFlags::NEXT);
field_ptr!(desc, Descriptor, flags).write(&flags).unwrap(); field_ptr!(desc, Descriptor, flags)
.write_once(&flags)
.unwrap();
} }
self.num_used += (inputs.len() + outputs.len()) as u16; self.num_used += (inputs.len() + outputs.len()) as u16;
@ -210,7 +216,7 @@ impl VirtQueue {
field_ptr!(&self.avail, AvailRing, ring); field_ptr!(&self.avail, AvailRing, ring);
let mut ring_slot_ptr = ring_ptr.cast::<u16>(); let mut ring_slot_ptr = ring_ptr.cast::<u16>();
ring_slot_ptr.add(avail_slot as usize); ring_slot_ptr.add(avail_slot as usize);
ring_slot_ptr.write(&head).unwrap(); ring_slot_ptr.write_once(&head).unwrap();
} }
// write barrier // write barrier
fence(Ordering::SeqCst); fence(Ordering::SeqCst);
@ -218,7 +224,7 @@ impl VirtQueue {
// increase head of avail ring // increase head of avail ring
self.avail_idx = self.avail_idx.wrapping_add(1); self.avail_idx = self.avail_idx.wrapping_add(1);
field_ptr!(&self.avail, AvailRing, idx) field_ptr!(&self.avail, AvailRing, idx)
.write(&self.avail_idx) .write_once(&self.avail_idx)
.unwrap(); .unwrap();
fence(Ordering::SeqCst); fence(Ordering::SeqCst);
@ -230,7 +236,7 @@ impl VirtQueue {
// read barrier // read barrier
fence(Ordering::SeqCst); fence(Ordering::SeqCst);
self.last_used_idx != field_ptr!(&self.used, UsedRing, idx).read().unwrap() self.last_used_idx != field_ptr!(&self.used, UsedRing, idx).read_once().unwrap()
} }
/// The number of free descriptors. /// The number of free descriptors.
@ -247,19 +253,23 @@ impl VirtQueue {
loop { loop {
let desc = &mut self.descs[head as usize]; let desc = &mut self.descs[head as usize];
// Sets the buffer address and length to 0 // Sets the buffer address and length to 0
field_ptr!(desc, Descriptor, addr).write(&(0u64)).unwrap(); field_ptr!(desc, Descriptor, addr)
field_ptr!(desc, Descriptor, len).write(&(0u32)).unwrap(); .write_once(&(0u64))
.unwrap();
field_ptr!(desc, Descriptor, len)
.write_once(&(0u32))
.unwrap();
self.num_used -= 1; self.num_used -= 1;
let flags: DescFlags = field_ptr!(desc, Descriptor, flags).read().unwrap(); let flags: DescFlags = field_ptr!(desc, Descriptor, flags).read_once().unwrap();
if flags.contains(DescFlags::NEXT) { if flags.contains(DescFlags::NEXT) {
field_ptr!(desc, Descriptor, flags) field_ptr!(desc, Descriptor, flags)
.write(&DescFlags::empty()) .write_once(&DescFlags::empty())
.unwrap(); .unwrap();
head = field_ptr!(desc, Descriptor, next).read().unwrap(); head = field_ptr!(desc, Descriptor, next).read_once().unwrap();
} else { } else {
field_ptr!(desc, Descriptor, next) field_ptr!(desc, Descriptor, next)
.write(&origin_free_head) .write_once(&origin_free_head)
.unwrap(); .unwrap();
break; break;
} }
@ -280,8 +290,8 @@ impl VirtQueue {
ptr.byte_add(offset_of!(UsedRing, ring) as usize + last_used_slot as usize * 8); ptr.byte_add(offset_of!(UsedRing, ring) as usize + last_used_slot as usize * 8);
ptr.cast::<UsedElem>() ptr.cast::<UsedElem>()
}; };
let index = field_ptr!(&element_ptr, UsedElem, id).read().unwrap(); let index = field_ptr!(&element_ptr, UsedElem, id).read_once().unwrap();
let len = field_ptr!(&element_ptr, UsedElem, len).read().unwrap(); let len = field_ptr!(&element_ptr, UsedElem, len).read_once().unwrap();
self.recycle_descriptors(index as u16); self.recycle_descriptors(index as u16);
self.last_used_idx = self.last_used_idx.wrapping_add(1); self.last_used_idx = self.last_used_idx.wrapping_add(1);
@ -304,8 +314,8 @@ impl VirtQueue {
ptr.byte_add(offset_of!(UsedRing, ring) as usize + last_used_slot as usize * 8); ptr.byte_add(offset_of!(UsedRing, ring) as usize + last_used_slot as usize * 8);
ptr.cast::<UsedElem>() ptr.cast::<UsedElem>()
}; };
let index = field_ptr!(&element_ptr, UsedElem, id).read().unwrap(); let index = field_ptr!(&element_ptr, UsedElem, id).read_once().unwrap();
let len = field_ptr!(&element_ptr, UsedElem, len).read().unwrap(); let len = field_ptr!(&element_ptr, UsedElem, len).read_once().unwrap();
if index as u16 != token { if index as u16 != token {
return Err(QueueError::WrongToken); return Err(QueueError::WrongToken);
@ -326,7 +336,7 @@ impl VirtQueue {
pub fn should_notify(&self) -> bool { pub fn should_notify(&self) -> bool {
// read barrier // read barrier
fence(Ordering::SeqCst); fence(Ordering::SeqCst);
let flags = field_ptr!(&self.used, UsedRing, flags).read().unwrap(); let flags = field_ptr!(&self.used, UsedRing, flags).read_once().unwrap();
flags & 0x0001u16 == 0u16 flags & 0x0001u16 == 0u16
} }
@ -353,10 +363,10 @@ fn set_dma_buf<T: DmaBuf>(desc_ptr: &DescriptorPtr, buf: &T) {
debug_assert_ne!(buf.len(), 0); debug_assert_ne!(buf.len(), 0);
let daddr = buf.daddr(); let daddr = buf.daddr();
field_ptr!(desc_ptr, Descriptor, addr) field_ptr!(desc_ptr, Descriptor, addr)
.write(&(daddr as u64)) .write_once(&(daddr as u64))
.unwrap(); .unwrap();
field_ptr!(desc_ptr, Descriptor, len) field_ptr!(desc_ptr, Descriptor, len)
.write(&(buf.len() as u32)) .write_once(&(buf.len() as u32))
.unwrap(); .unwrap();
} }

View File

@ -7,7 +7,7 @@ use aster_rights_proc::require;
use inherit_methods_macro::inherit_methods; use inherit_methods_macro::inherit_methods;
pub use ostd::Pod; pub use ostd::Pod;
use ostd::{ use ostd::{
mm::{Daddr, DmaStream, HasDaddr, HasPaddr, Paddr, VmIo}, mm::{Daddr, DmaStream, HasDaddr, HasPaddr, Paddr, PodOnce, VmIo, VmIoOnce},
Result, Result,
}; };
pub use typeflags_util::SetContain; pub use typeflags_util::SetContain;
@ -324,6 +324,28 @@ impl<T: Pod, M: VmIo, R: TRights> SafePtr<T, M, TRightSet<R>> {
} }
} }
impl<T: PodOnce, M: VmIoOnce, R: TRights> SafePtr<T, M, TRightSet<R>> {
/// Reads the value from the pointer using one non-tearing instruction.
///
/// # Access rights
///
/// This method requires the `Read` right.
#[require(R > Read)]
pub fn read_once(&self) -> Result<T> {
self.vm_obj.read_once(self.offset)
}
/// Overwrites the value at the pointer using one non-tearing instruction.
///
/// # Access rights
///
/// This method requires the `Write` right.
#[require(R > Write)]
pub fn write_once(&self, val: &T) -> Result<()> {
self.vm_obj.write_once(self.offset, val)
}
}
impl<T, M: HasDaddr, R> HasDaddr for SafePtr<T, M, R> { impl<T, M: HasDaddr, R> HasDaddr for SafePtr<T, M, R> {
fn daddr(&self) -> Daddr { fn daddr(&self) -> Daddr {
self.offset + self.vm_obj.daddr() self.offset + self.vm_obj.daddr()