Fix accesses to VirtIO queue DMA

This commit is contained in:
Ruihan Li
2024-08-06 00:06:02 +10:00
committed by Tate, Hongliang Tian
parent 3deff2e842
commit b1ea422efa
2 changed files with 58 additions and 26 deletions

View File

@ -129,17 +129,21 @@ impl VirtQueue {
let mut desc = descs.get(i as usize).unwrap().clone();
let next_i = i + 1;
if next_i != size {
field_ptr!(&desc, Descriptor, next).write(&next_i).unwrap();
field_ptr!(&desc, Descriptor, next)
.write_once(&next_i)
.unwrap();
desc.add(1);
descs.push(desc);
} else {
field_ptr!(&desc, Descriptor, next).write(&(0u16)).unwrap();
field_ptr!(&desc, Descriptor, next)
.write_once(&(0u16))
.unwrap();
}
}
let notify = transport.get_notify_ptr(idx).unwrap();
field_ptr!(&avail_ring_ptr, AvailRing, flags)
.write(&(0u16))
.write_once(&(0u16))
.unwrap();
Ok(VirtQueue {
descs,
@ -177,10 +181,10 @@ impl VirtQueue {
let desc = &self.descs[self.free_head as usize];
set_dma_buf(&desc.borrow_vm().restrict::<TRights![Write, Dup]>(), *input);
field_ptr!(desc, Descriptor, flags)
.write(&DescFlags::NEXT)
.write_once(&DescFlags::NEXT)
.unwrap();
last = self.free_head;
self.free_head = field_ptr!(desc, Descriptor, next).read().unwrap();
self.free_head = field_ptr!(desc, Descriptor, next).read_once().unwrap();
}
for output in outputs.iter() {
let desc = &mut self.descs[self.free_head as usize];
@ -189,17 +193,19 @@ impl VirtQueue {
*output,
);
field_ptr!(desc, Descriptor, flags)
.write(&(DescFlags::NEXT | DescFlags::WRITE))
.write_once(&(DescFlags::NEXT | DescFlags::WRITE))
.unwrap();
last = self.free_head;
self.free_head = field_ptr!(desc, Descriptor, next).read().unwrap();
self.free_head = field_ptr!(desc, Descriptor, next).read_once().unwrap();
}
// set last_elem.next = NULL
{
let desc = &mut self.descs[last as usize];
let mut flags: DescFlags = field_ptr!(desc, Descriptor, flags).read().unwrap();
let mut flags: DescFlags = field_ptr!(desc, Descriptor, flags).read_once().unwrap();
flags.remove(DescFlags::NEXT);
field_ptr!(desc, Descriptor, flags).write(&flags).unwrap();
field_ptr!(desc, Descriptor, flags)
.write_once(&flags)
.unwrap();
}
self.num_used += (inputs.len() + outputs.len()) as u16;
@ -210,7 +216,7 @@ impl VirtQueue {
field_ptr!(&self.avail, AvailRing, ring);
let mut ring_slot_ptr = ring_ptr.cast::<u16>();
ring_slot_ptr.add(avail_slot as usize);
ring_slot_ptr.write(&head).unwrap();
ring_slot_ptr.write_once(&head).unwrap();
}
// write barrier
fence(Ordering::SeqCst);
@ -218,7 +224,7 @@ impl VirtQueue {
// increase head of avail ring
self.avail_idx = self.avail_idx.wrapping_add(1);
field_ptr!(&self.avail, AvailRing, idx)
.write(&self.avail_idx)
.write_once(&self.avail_idx)
.unwrap();
fence(Ordering::SeqCst);
@ -230,7 +236,7 @@ impl VirtQueue {
// read barrier
fence(Ordering::SeqCst);
self.last_used_idx != field_ptr!(&self.used, UsedRing, idx).read().unwrap()
self.last_used_idx != field_ptr!(&self.used, UsedRing, idx).read_once().unwrap()
}
/// The number of free descriptors.
@ -247,19 +253,23 @@ impl VirtQueue {
loop {
let desc = &mut self.descs[head as usize];
// Sets the buffer address and length to 0
field_ptr!(desc, Descriptor, addr).write(&(0u64)).unwrap();
field_ptr!(desc, Descriptor, len).write(&(0u32)).unwrap();
field_ptr!(desc, Descriptor, addr)
.write_once(&(0u64))
.unwrap();
field_ptr!(desc, Descriptor, len)
.write_once(&(0u32))
.unwrap();
self.num_used -= 1;
let flags: DescFlags = field_ptr!(desc, Descriptor, flags).read().unwrap();
let flags: DescFlags = field_ptr!(desc, Descriptor, flags).read_once().unwrap();
if flags.contains(DescFlags::NEXT) {
field_ptr!(desc, Descriptor, flags)
.write(&DescFlags::empty())
.write_once(&DescFlags::empty())
.unwrap();
head = field_ptr!(desc, Descriptor, next).read().unwrap();
head = field_ptr!(desc, Descriptor, next).read_once().unwrap();
} else {
field_ptr!(desc, Descriptor, next)
.write(&origin_free_head)
.write_once(&origin_free_head)
.unwrap();
break;
}
@ -280,8 +290,8 @@ impl VirtQueue {
ptr.byte_add(offset_of!(UsedRing, ring) as usize + last_used_slot as usize * 8);
ptr.cast::<UsedElem>()
};
let index = field_ptr!(&element_ptr, UsedElem, id).read().unwrap();
let len = field_ptr!(&element_ptr, UsedElem, len).read().unwrap();
let index = field_ptr!(&element_ptr, UsedElem, id).read_once().unwrap();
let len = field_ptr!(&element_ptr, UsedElem, len).read_once().unwrap();
self.recycle_descriptors(index as u16);
self.last_used_idx = self.last_used_idx.wrapping_add(1);
@ -304,8 +314,8 @@ impl VirtQueue {
ptr.byte_add(offset_of!(UsedRing, ring) as usize + last_used_slot as usize * 8);
ptr.cast::<UsedElem>()
};
let index = field_ptr!(&element_ptr, UsedElem, id).read().unwrap();
let len = field_ptr!(&element_ptr, UsedElem, len).read().unwrap();
let index = field_ptr!(&element_ptr, UsedElem, id).read_once().unwrap();
let len = field_ptr!(&element_ptr, UsedElem, len).read_once().unwrap();
if index as u16 != token {
return Err(QueueError::WrongToken);
@ -326,7 +336,7 @@ impl VirtQueue {
pub fn should_notify(&self) -> bool {
// read barrier
fence(Ordering::SeqCst);
let flags = field_ptr!(&self.used, UsedRing, flags).read().unwrap();
let flags = field_ptr!(&self.used, UsedRing, flags).read_once().unwrap();
flags & 0x0001u16 == 0u16
}
@ -353,10 +363,10 @@ fn set_dma_buf<T: DmaBuf>(desc_ptr: &DescriptorPtr, buf: &T) {
debug_assert_ne!(buf.len(), 0);
let daddr = buf.daddr();
field_ptr!(desc_ptr, Descriptor, addr)
.write(&(daddr as u64))
.write_once(&(daddr as u64))
.unwrap();
field_ptr!(desc_ptr, Descriptor, len)
.write(&(buf.len() as u32))
.write_once(&(buf.len() as u32))
.unwrap();
}

View File

@ -7,7 +7,7 @@ use aster_rights_proc::require;
use inherit_methods_macro::inherit_methods;
pub use ostd::Pod;
use ostd::{
mm::{Daddr, DmaStream, HasDaddr, HasPaddr, Paddr, VmIo},
mm::{Daddr, DmaStream, HasDaddr, HasPaddr, Paddr, PodOnce, VmIo, VmIoOnce},
Result,
};
pub use typeflags_util::SetContain;
@ -324,6 +324,28 @@ impl<T: Pod, M: VmIo, R: TRights> SafePtr<T, M, TRightSet<R>> {
}
}
impl<T: PodOnce, M: VmIoOnce, R: TRights> SafePtr<T, M, TRightSet<R>> {
/// Reads the value from the pointer using one non-tearing instruction.
///
/// # Access rights
///
/// This method requires the `Read` right.
#[require(R > Read)]
pub fn read_once(&self) -> Result<T> {
self.vm_obj.read_once(self.offset)
}
/// Overwrites the value at the pointer using one non-tearing instruction.
///
/// # Access rights
///
/// This method requires the `Write` right.
#[require(R > Write)]
pub fn write_once(&self, val: &T) -> Result<()> {
self.vm_obj.write_once(self.offset, val)
}
}
impl<T, M: HasDaddr, R> HasDaddr for SafePtr<T, M, R> {
fn daddr(&self) -> Daddr {
self.offset + self.vm_obj.daddr()