Support calling from inside via vsock

This commit is contained in:
Anmin Liu
2024-05-06 15:00:48 +00:00
committed by Tate, Hongliang Tian
parent 48f69c25a9
commit 60dd17fdd3
24 changed files with 582 additions and 558 deletions

View File

@ -1,22 +1,89 @@
// SPDX-License-Identifier: MPL-2.0
use align_ext::AlignExt;
use bytes::BytesMut;
use alloc::{collections::LinkedList, sync::Arc};
use align_ext::AlignExt;
use aster_frame::{
sync::SpinLock,
vm::{Daddr, DmaDirection, DmaStream, HasDaddr, VmAllocOptions, VmReader, VmWriter, PAGE_SIZE},
};
use aster_network::dma_pool::{DmaPool, DmaSegment};
use pod::Pod;
use spin::Once;
pub struct TxBuffer {
dma_stream: DmaStream,
nbytes: usize,
}
impl TxBuffer {
pub fn new<H: Pod>(header: &H, packet: &[u8]) -> Self {
let header = header.as_bytes();
let nbytes = header.len() + packet.len();
let dma_stream = if let Some(stream) = get_tx_stream_from_pool(nbytes) {
stream
} else {
let segment = {
let nframes = (nbytes.align_up(PAGE_SIZE)) / PAGE_SIZE;
VmAllocOptions::new(nframes).alloc_contiguous().unwrap()
};
DmaStream::map(segment, DmaDirection::ToDevice, false).unwrap()
};
let mut writer = dma_stream.writer().unwrap();
writer.write(&mut VmReader::from(header));
writer.write(&mut VmReader::from(packet));
let tx_buffer = Self { dma_stream, nbytes };
tx_buffer.sync();
tx_buffer
}
pub fn writer(&self) -> VmWriter<'_> {
self.dma_stream.writer().unwrap().limit(self.nbytes)
}
fn sync(&self) {
self.dma_stream.sync(0..self.nbytes).unwrap();
}
pub fn nbytes(&self) -> usize {
self.nbytes
}
}
impl HasDaddr for TxBuffer {
fn daddr(&self) -> Daddr {
self.dma_stream.daddr()
}
}
impl Drop for TxBuffer {
fn drop(&mut self) {
TX_BUFFER_POOL
.get()
.unwrap()
.lock_irq_disabled()
.push_back(self.dma_stream.clone());
}
}
/// Buffer for receive packet
#[derive(Debug)]
pub struct RxBuffer {
/// Packet Buffer, length align 8.
buf: BytesMut,
/// Packet len
segment: DmaSegment,
header_len: usize,
packet_len: usize,
}
impl RxBuffer {
pub fn new(len: usize) -> Self {
let len = len.align_up(8);
let buf = BytesMut::zeroed(len);
Self { buf, packet_len: 0 }
pub fn new(header_len: usize) -> Self {
assert!(header_len <= RX_BUFFER_LEN);
let segment = RX_BUFFER_POOL.get().unwrap().alloc_segment().unwrap();
Self {
segment,
header_len,
packet_len: 0,
}
}
pub const fn packet_len(&self) -> usize {
@ -24,56 +91,70 @@ impl RxBuffer {
}
pub fn set_packet_len(&mut self, packet_len: usize) {
assert!(self.header_len + packet_len <= RX_BUFFER_LEN);
self.packet_len = packet_len;
}
pub fn buf(&self) -> &[u8] {
&self.buf
pub fn packet(&self) -> VmReader<'_> {
self.segment
.sync(self.header_len..self.header_len + self.packet_len)
.unwrap();
self.segment
.reader()
.unwrap()
.skip(self.header_len)
.limit(self.packet_len)
}
pub fn buf_mut(&mut self) -> &mut [u8] {
&mut self.buf
pub fn buf(&self) -> VmReader<'_> {
self.segment
.sync(0..self.header_len + self.packet_len)
.unwrap();
self.segment
.reader()
.unwrap()
.limit(self.header_len + self.packet_len)
}
pub const fn buf_len(&self) -> usize {
self.segment.size()
}
}
/// Buffer for transmit packet
#[derive(Debug)]
pub struct TxBuffer {
buf: BytesMut,
impl HasDaddr for RxBuffer {
fn daddr(&self) -> Daddr {
self.segment.daddr()
}
}
impl TxBuffer {
pub fn with_len(buf_len: usize) -> Self {
Self {
buf: BytesMut::zeroed(buf_len),
pub const RX_BUFFER_LEN: usize = 4096;
static RX_BUFFER_POOL: Once<Arc<DmaPool>> = Once::new();
static TX_BUFFER_POOL: Once<SpinLock<LinkedList<DmaStream>>> = Once::new();
fn get_tx_stream_from_pool(nbytes: usize) -> Option<DmaStream> {
let mut pool = TX_BUFFER_POOL.get().unwrap().lock_irq_disabled();
let mut cursor = pool.cursor_front_mut();
while let Some(current) = cursor.current() {
if current.nbytes() >= nbytes {
return cursor.remove_current();
}
cursor.move_next();
}
pub fn new(buf: &[u8]) -> Self {
Self {
buf: BytesMut::from(buf),
}
}
pub fn buf(&self) -> &[u8] {
&self.buf
}
pub fn buf_mut(&mut self) -> &mut [u8] {
&mut self.buf
}
None
}
/// Buffer for event buffer
#[derive(Debug)]
pub struct EventBuffer {
id: u32,
}
#[repr(u32)]
#[derive(Debug, Clone, Copy, Default)]
#[allow(non_camel_case_types)]
pub enum EventIDType {
#[default]
VIRTIO_VSOCK_EVENT_TRANSPORT_RESET = 0,
pub fn init() {
const POOL_INIT_SIZE: usize = 32;
const POOL_HIGH_WATERMARK: usize = 64;
RX_BUFFER_POOL.call_once(|| {
DmaPool::new(
RX_BUFFER_LEN,
POOL_INIT_SIZE,
POOL_HIGH_WATERMARK,
DmaDirection::FromDevice,
false,
)
});
TX_BUFFER_POOL.call_once(|| SpinLock::new(LinkedList::new()));
}

View File

@ -28,7 +28,7 @@
use super::{
error::SocketError,
header::{VirtioVsockHdr, VirtioVsockOp, VsockAddr},
header::{VirtioVsockHdr, VirtioVsockOp, VsockDeviceAddr},
};
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
@ -74,9 +74,9 @@ pub enum VsockEventType {
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct VsockEvent {
/// The source of the event, i.e. the peer who sent it.
pub source: VsockAddr,
pub source: VsockDeviceAddr,
/// The destination of the event, i.e. the CID and port on our side.
pub destination: VsockAddr,
pub destination: VsockDeviceAddr,
/// The peer's buffer status for the connection.
pub buffer_status: VsockBufferStatus,
/// The type of event.
@ -143,7 +143,7 @@ impl VsockEvent {
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct ConnectionInfo {
pub dst: VsockAddr,
pub dst: VsockDeviceAddr,
pub src_port: u32,
/// The last `buf_alloc` value the peer sent to us, indicating how much receive buffer space in
/// bytes it has allocated for packet bodies.
@ -166,7 +166,7 @@ pub struct ConnectionInfo {
}
impl ConnectionInfo {
pub fn new(destination: VsockAddr, src_port: u32) -> Self {
pub fn new(destination: VsockDeviceAddr, src_port: u32) -> Self {
Self {
dst: destination,
src_port,

View File

@ -1,15 +1,15 @@
// SPDX-License-Identifier: MPL-2.0
use alloc::{boxed::Box, string::ToString, sync::Arc, vec::Vec};
use core::{fmt::Debug, hint::spin_loop};
use alloc::{boxed::Box, string::ToString, sync::Arc, vec, vec::Vec};
use core::{fmt::Debug, hint::spin_loop, mem::size_of};
use aster_frame::{offset_of, sync::SpinLock, trap::TrapFrame};
use aster_frame::{offset_of, sync::SpinLock, trap::TrapFrame, vm::VmWriter};
use aster_util::{field_ptr, slot_vec::SlotVec};
use log::debug;
use pod::Pod;
use super::{
buffer::RxBuffer,
buffer::{RxBuffer, RX_BUFFER_LEN},
config::{VirtioVsockConfig, VsockFeatures},
connect::{ConnectionInfo, VsockEvent},
error::SocketError,
@ -18,7 +18,7 @@ use super::{
};
use crate::{
device::{
socket::{handle_recv_irq, register_device},
socket::{buffer::TxBuffer, handle_recv_irq, register_device},
VirtioDeviceError,
},
queue::{QueueError, VirtQueue},
@ -30,9 +30,6 @@ const QUEUE_RECV: u16 = 0;
const QUEUE_SEND: u16 = 1;
const QUEUE_EVENT: u16 = 2;
/// The size in bytes of each buffer used in the RX virtqueue. This must be bigger than `size_of::<VirtioVsockHdr>()`.
const RX_BUFFER_SIZE: usize = 512;
/// Vsock device driver
pub struct SocketDevice {
config: VirtioVsockConfig,
@ -71,8 +68,8 @@ impl SocketDevice {
// Allocate and add buffers for the RX queue.
let mut rx_buffers = SlotVec::new();
for i in 0..QUEUE_SIZE {
let mut rx_buffer = RxBuffer::new(RX_BUFFER_SIZE);
let token = recv_queue.add_buf(&[], &[rx_buffer.buf_mut()])?;
let rx_buffer = RxBuffer::new(size_of::<VirtioVsockHdr>());
let token = recv_queue.add_dma_buf(&[], &[&rx_buffer])?;
assert_eq!(i, token);
assert_eq!(rx_buffers.put(rx_buffer) as u16, i);
}
@ -187,7 +184,10 @@ impl SocketDevice {
header: &VirtioVsockHdr,
buffer: &[u8],
) -> Result<(), SocketError> {
let _token = self.send_queue.add_buf(&[header.as_bytes(), buffer], &[])?;
debug!("buffer in send_packet_to_tx_queue: {:?}", buffer);
let tx_buffer = TxBuffer::new(header, buffer);
let token = self.send_queue.add_dma_buf(&[&tx_buffer], &[])?;
if self.send_queue.should_notify() {
self.send_queue.notify();
@ -198,9 +198,13 @@ impl SocketDevice {
spin_loop();
}
self.send_queue.pop_used()?;
debug!("buffer in send_packet_to_tx_queue: {:?}", buffer);
// Pop out the buffer, so we can reuse the send queue further
let (pop_token, _) = self.send_queue.pop_used()?;
debug_assert!(pop_token == token);
if pop_token != token {
return Err(SocketError::QueueError(QueueError::WrongToken));
}
debug!("send packet succeeds");
Ok(())
}
@ -223,6 +227,7 @@ impl SocketDevice {
if !connection_info.has_pending_credit_request {
self.credit_request(connection_info)?;
connection_info.has_pending_credit_request = true;
//TODO check if the update needed
}
Err(SocketError::InsufficientBufferSpaceInPeer)
}
@ -261,9 +266,13 @@ impl SocketDevice {
.rx_buffers
.remove(token as usize)
.ok_or(QueueError::WrongToken)?;
rx_buffer.set_packet_len(RX_BUFFER_SIZE);
rx_buffer.set_packet_len(len as usize);
let (header, payload) = read_header_and_body(rx_buffer.buf())?;
let mut buf_reader = rx_buffer.buf();
let mut temp_buffer = vec![0u8; buf_reader.remain()];
buf_reader.read(&mut VmWriter::from(&mut temp_buffer as &mut [u8]));
let (header, payload) = read_header_and_body(&temp_buffer)?;
// The length written should be equal to len(header)+len(packet)
assert_eq!(len, header.len() + VIRTIO_VSOCK_HDR_LEN as u32);
debug!("Received packet {:?}. Op {:?}", header, header.op());
@ -285,15 +294,15 @@ impl SocketDevice {
if !self.recv_queue.can_pop() {
return Ok(None);
}
let mut body = RxBuffer::new(RX_BUFFER_SIZE);
let header = self.receive(body.buf_mut())?;
let mut body = vec![0u8; RX_BUFFER_LEN];
let header = self.receive(&mut body)?;
VsockEvent::from_header(&header).and_then(|event| handler(event, body.buf()))
VsockEvent::from_header(&header).and_then(|event| handler(event, &body))
}
/// Add a used rx buffer to recv queue,@index is only to check the correctness
fn add_rx_buffer(&mut self, mut rx_buffer: RxBuffer, index: u16) -> Result<(), SocketError> {
let token = self.recv_queue.add_buf(&[], &[rx_buffer.buf_mut()])?;
fn add_rx_buffer(&mut self, rx_buffer: RxBuffer, index: u16) -> Result<(), SocketError> {
let token = self.recv_queue.add_dma_buf(&[], &[&rx_buffer])?;
assert_eq!(index, token);
assert!(self.rx_buffers.put_at(token as usize, rx_buffer).is_none());
if self.recv_queue.should_notify() {

View File

@ -35,7 +35,7 @@ pub const VIRTIO_VSOCK_HDR_LEN: usize = core::mem::size_of::<VirtioVsockHdr>();
/// Socket address.
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
pub struct VsockAddr {
pub struct VsockDeviceAddr {
/// Context Identifier.
pub cid: u64,
/// Port number.
@ -93,15 +93,15 @@ impl VirtioVsockHdr {
VirtioVsockOp::try_from(self.op).map_err(|err| err.into())
}
pub fn source(&self) -> VsockAddr {
VsockAddr {
pub fn source(&self) -> VsockDeviceAddr {
VsockDeviceAddr {
cid: self.src_cid,
port: self.src_port,
}
}
pub fn destination(&self) -> VsockAddr {
VsockAddr {
pub fn destination(&self) -> VsockDeviceAddr {
VsockDeviceAddr {
cid: self.dst_cid,
port: self.dst_port,
}

View File

@ -1,6 +1,6 @@
// SPDX-License-Identifier: MPL-2.0
//! This mod is modified from virtio-drivers project.
// ! #![feature(linked_list_cursors)]
use alloc::{collections::BTreeMap, string::String, sync::Arc, vec::Vec};
use aster_frame::sync::SpinLock;
@ -61,6 +61,7 @@ pub fn handle_recv_irq(name: &str) {
pub fn init() {
VSOCK_DEVICE_TABLE.call_once(|| SpinLock::new(BTreeMap::new()));
buffer::init();
}
type VsockDeviceIrqHandlerListRef = Arc<SpinLock<Vec<Arc<dyn VsockDeviceIrqHandler>>>>;

View File

@ -3,6 +3,8 @@
use aster_frame::mm::{DmaCoherent, DmaStream, DmaStreamSlice, HasDaddr};
use aster_network::{DmaSegment, RxBuffer, TxBuffer};
use crate::device;
/// A DMA-capable buffer.
///
/// Any type implements this trait should also implements `HasDaddr` trait,
@ -48,3 +50,15 @@ impl DmaBuf for RxBuffer {
self.buf_len()
}
}
impl DmaBuf for device::socket::buffer::TxBuffer {
fn len(&self) -> usize {
self.nbytes()
}
}
impl DmaBuf for device::socket::buffer::RxBuffer {
fn len(&self) -> usize {
self.buf_len()
}
}

View File

@ -6,6 +6,7 @@
#![allow(dead_code)]
#![feature(trait_alias)]
#![feature(fn_traits)]
#![feature(linked_list_cursors)]
extern crate alloc;