Implement vsock driver

This commit is contained in:
Anmin Liu
2024-03-28 02:55:16 +00:00
committed by Tate, Hongliang Tian
parent 39c2e17f75
commit 52f808e315
16 changed files with 1477 additions and 3 deletions

View File

@ -65,6 +65,7 @@ qemu.args = """\
-device virtio-net-pci,netdev=mynet0,disable-legacy=on,disable-modern=off \ -device virtio-net-pci,netdev=mynet0,disable-legacy=on,disable-modern=off \
-device virtio-keyboard-pci,disable-legacy=on,disable-modern=off \ -device virtio-keyboard-pci,disable-legacy=on,disable-modern=off \
-device virtio-blk-pci,bus=pcie.0,addr=0x6,drive=x0,disable-legacy=on,disable-modern=off \ -device virtio-blk-pci,bus=pcie.0,addr=0x6,drive=x0,disable-legacy=on,disable-modern=off \
-device vhost-vsock-pci,id=vhost-vsock-pci0,guest-cid=3,disable-legacy=on,disable-modern=off \
-drive file=fs.img,if=none,format=raw,id=x0 \ -drive file=fs.img,if=none,format=raw,id=x0 \
-netdev user,id=mynet0,hostfwd=tcp::10027-:22,hostfwd=tcp::54136-:8090 \ -netdev user,id=mynet0,hostfwd=tcp::10027-:22,hostfwd=tcp::54136-:8090 \
-chardev stdio,id=mux,mux=on,logfile=./$(date '+%Y-%m-%dT%H%M%S').log \ -chardev stdio,id=mux,mux=on,logfile=./$(date '+%Y-%m-%dT%H%M%S').log \
@ -72,4 +73,4 @@ qemu.args = """\
-device virtconsole,chardev=mux \ -device virtconsole,chardev=mux \
-monitor chardev:mux \ -monitor chardev:mux \
-serial chardev:mux \ -serial chardev:mux \
""" """

View File

@ -6,6 +6,8 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies] [dependencies]
# FIXME: used for test in driver mod
component = {path="../libs/comp-sys/component"}
aster-frame = { path = "../../framework/aster-frame" } aster-frame = { path = "../../framework/aster-frame" }
align_ext = { path = "../../framework/libs/align_ext" } align_ext = { path = "../../framework/libs/align_ext" }
pod = { git = "https://github.com/asterinas/pod", rev = "d7dba56" } pod = { git = "https://github.com/asterinas/pod", rev = "d7dba56" }

View File

@ -1,10 +1,87 @@
// SPDX-License-Identifier: MPL-2.0 // SPDX-License-Identifier: MPL-2.0
use log::info; use log::{info, debug};
use alloc::string::ToString;
use component::ComponentInitError;
use aster_virtio::{self, device::socket::{header::VsockAddr, device::SocketDevice, manager::VsockConnectionManager, DEVICE_NAME}};
pub fn init() { pub fn init() {
// print all the input device to make sure input crate will compile // print all the input device to make sure input crate will compile
for (name, _) in aster_input::all_devices() { for (name, _) in aster_input::all_devices() {
info!("Found Input device, name:{}", name); info!("Found Input device, name:{}", name);
} }
// let _ = socket_device_client_test();
// let _ = socket_device_server_test();
}
fn socket_device_client_test() -> Result<(),ComponentInitError> {
let host_cid = 2;
let guest_cid = 3;
let host_port = 1234;
let guest_port = 4321;
let host_address = VsockAddr {
cid: host_cid,
port: host_port,
};
let hello_from_guest = "Hello from guest";
let hello_from_host = "Hello from host";
let device = aster_virtio::device::socket::get_device(DEVICE_NAME).unwrap();
assert_eq!(device.lock().guest_cid(),guest_cid);
let mut socket = VsockConnectionManager::new(device);
socket.connect(host_address, guest_port).unwrap();
socket.wait_for_event().unwrap(); // wait for connect response
socket.send(host_address,guest_port,hello_from_guest.as_bytes()).unwrap();
debug!("The buffer {:?} is sent, start receiving",hello_from_guest.as_bytes());
socket.wait_for_event().unwrap(); // wait for recv
let mut buffer = [0u8; 64];
let event = socket.recv(host_address, guest_port,&mut buffer).unwrap();
assert_eq!(
&buffer[0..hello_from_host.len()],
hello_from_host.as_bytes()
);
socket.force_close(host_address,guest_port).unwrap();
debug!("The final event: {:?}",event);
Ok(())
}
pub fn socket_device_server_test() -> Result<(),ComponentInitError>{
let host_cid = 2;
let guest_cid = 3;
let host_port = 1234;
let guest_port = 4321;
let host_address = VsockAddr {
cid: host_cid,
port: host_port,
};
let hello_from_guest = "Hello from guest";
let hello_from_host = "Hello from host";
let device = aster_virtio::device::socket::get_device(DEVICE_NAME).unwrap();
assert_eq!(device.lock().guest_cid(),guest_cid);
let mut socket = VsockConnectionManager::new(device);
socket.listen(4321);
socket.wait_for_event().unwrap(); // wait for connect request
socket.wait_for_event().unwrap(); // wait for recv
let mut buffer = [0u8; 64];
let event = socket.recv(host_address, guest_port,&mut buffer).unwrap();
assert_eq!(
&buffer[0..hello_from_host.len()],
hello_from_host.as_bytes()
);
debug!("The buffer {:?} is received, start sending {:?}", &buffer[0..hello_from_host.len()],hello_from_guest.as_bytes());
socket.send(host_address,guest_port,hello_from_guest.as_bytes()).unwrap();
socket.shutdown(host_address,guest_port).unwrap();
let event = socket.wait_for_event().unwrap(); // wait for rst/shutdown
debug!("The final event: {:?}",event);
Ok(())
} }

View File

@ -8,6 +8,7 @@ pub mod block;
pub mod console; pub mod console;
pub mod input; pub mod input;
pub mod network; pub mod network;
pub mod socket;
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, TryFromInt)] #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, TryFromInt)]
#[repr(u8)] #[repr(u8)]

View File

@ -0,0 +1,99 @@
//! This module is adapted from network/buffer.rs
use align_ext::AlignExt;
use bytes::BytesMut;
use pod::Pod;
use crate::device::socket::header::VIRTIO_VSOCK_HDR_LEN;
use super::header::VirtioVsockHdr;
/// Buffer for receive packet
#[derive(Debug)]
pub struct RxBuffer {
/// Packet Buffer, length align 8.
buf: BytesMut,
/// Packet len
packet_len: usize,
}
impl RxBuffer {
pub fn new(len: usize) -> Self {
let len = len.align_up(8);
let buf = BytesMut::zeroed(len);
Self { buf, packet_len: 0 }
}
pub const fn packet_len(&self) -> usize {
self.packet_len
}
pub fn set_packet_len(&mut self, packet_len: usize) {
self.packet_len = packet_len;
}
pub fn buf(&self) -> &[u8] {
&self.buf
}
pub fn buf_mut(&mut self) -> &mut [u8] {
&mut self.buf
}
/// Packet payload slice, which is inner buffer excluding VirtioVsockHdr.
pub fn packet(&self) -> &[u8] {
debug_assert!(VIRTIO_VSOCK_HDR_LEN + self.packet_len <= self.buf.len());
&self.buf[VIRTIO_VSOCK_HDR_LEN..VIRTIO_VSOCK_HDR_LEN + self.packet_len]
}
/// Mutable packet payload slice.
pub fn packet_mut(&mut self) -> &mut [u8] {
debug_assert!(VIRTIO_VSOCK_HDR_LEN + self.packet_len <= self.buf.len());
&mut self.buf[VIRTIO_VSOCK_HDR_LEN..VIRTIO_VSOCK_HDR_LEN + self.packet_len]
}
pub fn virtio_vsock_header(&self) -> VirtioVsockHdr {
VirtioVsockHdr::from_bytes(&self.buf[..VIRTIO_VSOCK_HDR_LEN])
}
}
/// Buffer for transmit packet
#[derive(Debug)]
pub struct TxBuffer {
buf: BytesMut,
}
impl TxBuffer {
pub fn with_len(buf_len: usize) -> Self {
Self {
buf: BytesMut::zeroed(buf_len),
}
}
pub fn new(buf: &[u8]) -> Self {
Self {
buf: BytesMut::from(buf),
}
}
pub fn buf(&self) -> &[u8] {
&self.buf
}
pub fn buf_mut(&mut self) -> &mut [u8] {
&mut self.buf
}
}
/// Buffer for event buffer
#[derive(Debug)]
pub struct EventBuffer{
id: u32,
}
#[repr(u32)]
#[derive(Debug, Clone, Copy, Default)]
#[allow(non_camel_case_types)]
pub enum EventIDType {
#[default]
VIRTIO_VSOCK_EVENT_TRANSPORT_RESET = 0,
}

View File

@ -0,0 +1,44 @@
use aster_frame::io_mem::IoMem;
use pod::Pod;
use aster_util::{safe_ptr::SafePtr};
use bitflags::bitflags;
use crate::transport::{self, VirtioTransport};
bitflags!{
/// Vsock feature bits since v1.2
/// If no feature bit is set, only stream socket type is supported.
/// If VIRTIO_VSOCK_F_SEQPACKET has been negotiated, the device MAY act as if VIRTIO_VSOCK_F_STREAM has also been negotiated.
pub struct VsockFeatures: u64 {
const VIRTIO_VSOCK_F_STREAM = 1 << 0; // stream socket type is supported.
const VIRTIO_VSOCK_F_SEQPACKET = 1 << 1; //seqpacket socket type is supported.
}
}
impl VsockFeatures {
pub fn support_features() -> Self {
VsockFeatures::VIRTIO_VSOCK_F_STREAM
}
}
#[derive(Debug, Clone, Copy, Pod)]
#[repr(C)]
pub struct VirtioVsockConfig{
/// The guest_cid field contains the guests context ID, which uniquely identifies
/// the device for its lifetime. The upper 32 bits of the CID are reserved and zeroed.
///
/// According to virtio spec v1.1 2.4.1 Driver Requirements: Device Configuration Space,
/// drivers MUST NOT assume reads from fields greater than 32 bits wide are atomic.
/// So we need to split the u64 guest_cid into two parts.
// read only
pub guest_cid_low: u32,
// read only
pub guest_cid_high: u32,
}
impl VirtioVsockConfig {
pub(crate) fn new(transport: &dyn VirtioTransport) -> SafePtr<Self, IoMem> {
let memory = transport.device_config_memory();
SafePtr::new(memory, 0)
}
}

View File

@ -0,0 +1,187 @@
use log::debug;
use super::{header::{VsockAddr, VirtioVsockHdr, VirtioVsockOp}, error::SocketError};
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct VsockBufferStatus {
pub buffer_allocation: u32,
pub forward_count: u32,
}
/// The reason why a vsock connection was closed.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum DisconnectReason {
/// The peer has either closed the connection in response to our shutdown request, or forcibly
/// closed it of its own accord.
Reset,
/// The peer asked to shut down the connection.
Shutdown,
}
/// Details of the type of an event received from a VirtIO socket.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum VsockEventType {
/// The peer requests to establish a connection with us.
ConnectionRequest,
/// The connection was successfully established.
Connected,
/// The connection was closed.
Disconnected {
/// The reason for the disconnection.
reason: DisconnectReason,
},
/// Data was received on the connection.
Received {
/// The length of the data in bytes.
length: usize,
},
/// The peer requests us to send a credit update.
CreditRequest,
/// The peer just sent us a credit update with nothing else.
CreditUpdate,
}
/// An event received from a VirtIO socket device.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct VsockEvent {
/// The source of the event, i.e. the peer who sent it.
pub source: VsockAddr,
/// The destination of the event, i.e. the CID and port on our side.
pub destination: VsockAddr,
/// The peer's buffer status for the connection.
pub buffer_status: VsockBufferStatus,
/// The type of event.
pub event_type: VsockEventType,
}
impl VsockEvent {
/// Returns whether the event matches the given connection.
pub fn matches_connection(&self, connection_info: &ConnectionInfo, guest_cid: u64) -> bool {
self.source == connection_info.dst
&& self.destination.cid == guest_cid
&& self.destination.port == connection_info.src_port
}
pub fn from_header(header: &VirtioVsockHdr) -> Result<Self,SocketError> {
let op = header.op()?;
let buffer_status = VsockBufferStatus {
buffer_allocation: header.buf_alloc,
forward_count: header.fwd_cnt,
};
let source = header.source();
let destination = header.destination();
let event_type = match op {
VirtioVsockOp::Request => {
header.check_data_is_empty()?;
VsockEventType::ConnectionRequest
}
VirtioVsockOp::Response => {
header.check_data_is_empty()?;
VsockEventType::Connected
}
VirtioVsockOp::CreditUpdate => {
header.check_data_is_empty()?;
VsockEventType::CreditUpdate
}
VirtioVsockOp::Rst | VirtioVsockOp::Shutdown => {
header.check_data_is_empty()?;
debug!("Disconnected from the peer");
let reason = if op == VirtioVsockOp::Rst {
DisconnectReason::Reset
} else {
DisconnectReason::Shutdown
};
VsockEventType::Disconnected { reason }
}
VirtioVsockOp::Rw => VsockEventType::Received {
length: header.len() as usize,
},
VirtioVsockOp::CreditRequest => {
header.check_data_is_empty()?;
VsockEventType::CreditRequest
}
VirtioVsockOp::Invalid => return Err(SocketError::InvalidOperation),
};
Ok(VsockEvent {
source,
destination,
buffer_status,
event_type,
})
}
}
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct ConnectionInfo {
pub dst: VsockAddr,
pub src_port: u32,
/// The last `buf_alloc` value the peer sent to us, indicating how much receive buffer space in
/// bytes it has allocated for packet bodies.
peer_buf_alloc: u32,
/// The last `fwd_cnt` value the peer sent to us, indicating how many bytes of packet bodies it
/// has finished processing.
peer_fwd_cnt: u32,
/// The number of bytes of packet bodies which we have sent to the peer.
pub tx_cnt: u32,
/// The number of bytes of buffer space we have allocated to receive packet bodies from the
/// peer.
pub buf_alloc: u32,
/// The number of bytes of packet bodies which we have received from the peer and handled.
pub fwd_cnt: u32,
/// Whether we have recently requested credit from the peer.
///
/// This is set to true when we send a `VIRTIO_VSOCK_OP_CREDIT_REQUEST`, and false when we
/// receive a `VIRTIO_VSOCK_OP_CREDIT_UPDATE`.
pub has_pending_credit_request: bool,
}
impl ConnectionInfo {
pub fn new(destination: VsockAddr, src_port: u32) -> Self {
Self {
dst: destination,
src_port,
..Default::default()
}
}
/// Updates this connection info with the peer buffer allocation and forwarded count from the
/// given event.
pub fn update_for_event(&mut self, event: &VsockEvent) {
self.peer_buf_alloc = event.buffer_status.buffer_allocation;
self.peer_fwd_cnt = event.buffer_status.forward_count;
if let VsockEventType::CreditUpdate = event.event_type {
self.has_pending_credit_request = false;
}
}
/// Increases the forwarded count recorded for this connection by the given number of bytes.
///
/// This should be called once received data has been passed to the client, so there is buffer
/// space available for more.
pub fn done_forwarding(&mut self, length: usize) {
self.fwd_cnt += length as u32;
}
/// Returns the number of bytes of RX buffer space the peer has available to receive packet body
/// data from us.
pub fn peer_free(&self) -> u32 {
self.peer_buf_alloc - (self.tx_cnt - self.peer_fwd_cnt)
}
pub fn new_header(&self, src_cid: u64) -> VirtioVsockHdr {
VirtioVsockHdr {
src_cid,
dst_cid: self.dst.cid,
src_port: self.src_port,
dst_port: self.dst.port,
buf_alloc: self.buf_alloc,
fwd_cnt: self.fwd_cnt,
..Default::default()
}
}
}

View File

@ -0,0 +1,327 @@
use core::hint::spin_loop;
use core::fmt::Debug;
use alloc::{vec::Vec, boxed::Box, string::ToString, sync::Arc};
use aster_frame::{offset_of, trap::TrapFrame, sync::SpinLock};
use aster_util::{slot_vec::SlotVec, field_ptr};
use log::debug;
use pod::Pod;
use crate::{queue::{VirtQueue, QueueError}, device::{VirtioDeviceError, socket::{register_device, DEVICE_NAME}}, transport::{VirtioTransport}};
use super::{buffer::RxBuffer, config::{VirtioVsockConfig, VsockFeatures}, connect::{ConnectionInfo, VsockEvent}, header::{VirtioVsockHdr, VirtioVsockOp, VIRTIO_VSOCK_HDR_LEN}, error::SocketError, VsockDeviceIrqHandler};
const QUEUE_SIZE: u16 = 64;
const QUEUE_RECV: u16 = 0;
const QUEUE_SEND: u16 = 1;
const QUEUE_EVENT: u16 = 2;
/// The size in bytes of each buffer used in the RX virtqueue. This must be bigger than size_of::<VirtioVsockHdr>().
const RX_BUFFER_SIZE: usize = 512;
/// Low-level driver for a Virtio socket device.
pub struct SocketDevice {
config: VirtioVsockConfig,
guest_cid: u64,
/// Virtqueue to receive packets.
send_queue: VirtQueue,
recv_queue: VirtQueue,
event_queue: VirtQueue,
rx_buffers: SlotVec<RxBuffer>,
transport: Box<dyn VirtioTransport>,
callbacks: Vec<Box<&'static VsockDeviceIrqHandler>>,
}
impl SocketDevice {
pub fn init(mut transport: Box<dyn VirtioTransport>) -> Result<(), VirtioDeviceError> {
let virtio_vsock_config = VirtioVsockConfig::new(transport.as_mut());
debug!("virtio_vsock_config = {:?}", virtio_vsock_config);
let guest_cid =
field_ptr!(&virtio_vsock_config, VirtioVsockConfig, guest_cid_low).read().unwrap() as u64
| (field_ptr!(&virtio_vsock_config, VirtioVsockConfig, guest_cid_high).read().unwrap() as u64) << 32;
let mut recv_queue = VirtQueue::new(QUEUE_RECV,QUEUE_SIZE,transport.as_mut())
.expect("createing recv queue fails");
let send_queue = VirtQueue::new(QUEUE_SEND,QUEUE_SIZE,transport.as_mut())
.expect("creating send queue fails");
let event_queue = VirtQueue::new(QUEUE_EVENT,QUEUE_SIZE,transport.as_mut())
.expect("creating event queue fails");
// Allocate and add buffers for the RX queue.
let mut rx_buffers = SlotVec::new();
for i in 0..QUEUE_SIZE {
let mut rx_buffer = RxBuffer::new(RX_BUFFER_SIZE);
let token = recv_queue.add_buf(&[], &mut [rx_buffer.buf_mut()])?;
assert_eq!(i, token);
assert_eq!(rx_buffers.put(rx_buffer) as u16, i);
}
if recv_queue.should_notify() {
debug!("notify receive queue");
recv_queue.notify();
}
let mut device = Self{
config: virtio_vsock_config.read().unwrap(),
guest_cid,
send_queue,
recv_queue,
event_queue,
rx_buffers,
transport,
callbacks: Vec::new(),
};
// Interrupt handler if vsock device config space changes
fn config_space_change(_: &TrapFrame){
debug!("vsock device config space change");
}
// Interrupt handler if vsock device receives some packet.
// TODO: This will be handled by vsock socket layer.
fn handle_vsock_event(_: &TrapFrame){
debug!("Packet received. This will be solved by socket layer");
}
device
.transport
.register_cfg_callback(Box::new(config_space_change))
.unwrap();
device
.transport
.register_queue_callback(QUEUE_RECV, Box::new(handle_vsock_event), false)
.unwrap();
device.transport.finish_init();
register_device(
super::DEVICE_NAME.to_string(),
Arc::new(SpinLock::new(device)),
);
Ok(())
}
/// Returns the CID which has been assigned to this guest.
pub fn guest_cid(&self) -> u64 {
self.guest_cid
}
/// Sends a request to connect to the given destination.
///
/// This returns as soon as the request is sent; you should wait until `poll` returns a
/// [`VsockEventType::Connected`] event indicating that the peer has accepted the connection
/// before sending data.
pub fn connect(&mut self, connection_info: &ConnectionInfo) -> Result<(),SocketError> {
let header = VirtioVsockHdr {
op: VirtioVsockOp::Request as u16,
..connection_info.new_header(self.guest_cid)
};
// Sends a header only packet to the TX queue to connect the device to the listening socket
// at the given destination.
self.send_packet_to_tx_queue(&header, &[])
}
/// Accepts the given connection from a peer.
pub fn accept(&mut self, connection_info: &ConnectionInfo) -> Result<(),SocketError> {
let header = VirtioVsockHdr {
op: VirtioVsockOp::Response as u16,
..connection_info.new_header(self.guest_cid)
};
self.send_packet_to_tx_queue(&header, &[])
}
/// Requests the peer to send us a credit update for the given connection.
fn request_credit(&mut self, connection_info: &ConnectionInfo) -> Result<(),SocketError> {
let header = VirtioVsockHdr {
op: VirtioVsockOp::CreditRequest as u16,
..connection_info.new_header(self.guest_cid)
};
self.send_packet_to_tx_queue(&header, &[])
}
/// Tells the peer how much buffer space we have to receive data.
pub fn credit_update(&mut self, connection_info: &ConnectionInfo) -> Result<(),SocketError> {
let header = VirtioVsockHdr {
op: VirtioVsockOp::CreditUpdate as u16,
..connection_info.new_header(self.guest_cid)
};
self.send_packet_to_tx_queue(&header, &[])
}
/// Requests to shut down the connection cleanly.
///
/// This returns as soon as the request is sent; you should wait until `poll` returns a
/// `VsockEventType::Disconnected` event if you want to know that the peer has acknowledged the
/// shutdown.
pub fn shutdown(&mut self, connection_info: &ConnectionInfo) -> Result<(),SocketError> {
let header = VirtioVsockHdr {
op: VirtioVsockOp::Shutdown as u16,
..connection_info.new_header(self.guest_cid)
};
self.send_packet_to_tx_queue(&header, &[])
}
/// Forcibly closes the connection without waiting for the peer.
pub fn force_close(&mut self, connection_info: &ConnectionInfo) -> Result<(),SocketError> {
let header = VirtioVsockHdr {
op: VirtioVsockOp::Rst as u16,
..connection_info.new_header(self.guest_cid)
};
self.send_packet_to_tx_queue(&header, &[])
}
fn send_packet_to_tx_queue(&mut self, header: &VirtioVsockHdr, buffer: &[u8]) -> Result<(), SocketError> {
// let (_token, _len) = self.send_queue.add_notify_wait_pop(
// &[header.as_bytes(), buffer],
// &mut [],
// )?;
let _token = self
.send_queue
.add_buf(&[header.as_bytes(), buffer], &[])?;
if self.send_queue.should_notify() {
self.send_queue.notify();
}
// Wait until the buffer is used
while !self.send_queue.can_pop() {
spin_loop();
}
self.send_queue.pop_used()?;
// FORDEBUG
// debug!("buffer in send_packet_to_tx_queue: {:?}",buffer);
Ok(())
}
fn check_peer_buffer_is_sufficient(
&mut self,
connection_info: &mut ConnectionInfo,
buffer_len: usize,
) -> Result<(), SocketError> {
if connection_info.peer_free() as usize >= buffer_len {
Ok(())
} else {
// Request an update of the cached peer credit, if we haven't already done so, and tell
// the caller to try again later.
if !connection_info.has_pending_credit_request {
self.request_credit(connection_info)?;
connection_info.has_pending_credit_request = true;
}
Err(SocketError::InsufficientBufferSpaceInPeer)
}
}
/// Sends the buffer to the destination.
pub fn send(&mut self, buffer: &[u8], connection_info: &mut ConnectionInfo) -> Result<(), SocketError> {
self.check_peer_buffer_is_sufficient(connection_info, buffer.len())?;
let len = buffer.len() as u32;
let header = VirtioVsockHdr {
op: VirtioVsockOp::Rw as u16,
len: len,
..connection_info.new_header(self.guest_cid)
};
connection_info.tx_cnt += len;
self.send_packet_to_tx_queue(&header, buffer)
}
/// Polls the RX virtqueue for the next event, and calls the given handler function to handle it.
pub fn poll(&mut self, handler: impl FnOnce(VsockEvent, &[u8]) -> Result<Option<VsockEvent>,SocketError>
) -> Result<Option<VsockEvent>, SocketError> {
// Return None if there is no pending packet.
if !self.recv_queue.can_pop(){
return Ok(None);
}
let (token, len) = self.recv_queue.pop_used()?;
let mut buffer = self
.rx_buffers
.remove(token as usize)
.ok_or(QueueError::WrongToken)?;
let header = buffer.virtio_vsock_header();
// The length written should be equal to len(header)+len(packet)
assert_eq!(len, header.len() + VIRTIO_VSOCK_HDR_LEN as u32);
buffer.set_packet_len(RX_BUFFER_SIZE);
let head_result = read_header_and_body(&buffer.buf());
let Ok((header,body)) = head_result else {
let ret = match head_result {
Err(e) => Err(e),
_ => Ok(None) //FIXME: this clause is never reached.
};
self.add_rx_buffer(buffer, token)?;
return ret;
};
debug!("Received packet {:?}. Op {:?}", header, header.op());
debug!("body is {:?}",body);
let result = VsockEvent::from_header(&header).and_then(|event| handler(event, body));
// reuse the buffer and give it back to recv_queue.
self.add_rx_buffer(buffer, token)?;
result
}
/// Add a used rx buffer to recv queue,@index is only to check the correctness
fn add_rx_buffer(&mut self, mut rx_buffer: RxBuffer, index: u16) -> Result<(), SocketError> {
let token = self.recv_queue.add_buf(&[], &mut [rx_buffer.buf_mut()])?;
assert_eq!(index,token);
assert!(self.rx_buffers.put_at(token as usize, rx_buffer).is_none());
if self.recv_queue.should_notify() {
self.recv_queue.notify();
}
Ok(())
}
/// Negotiate features for the device specified bits 0~23
pub(crate) fn negotiate_features(features: u64) -> u64 {
let device_features = VsockFeatures::from_bits_truncate(features);
let supported_features = VsockFeatures::support_features();
let vsock_features = device_features & supported_features;
debug!("features negotiated: {:?}",vsock_features);
vsock_features.bits()
}
}
impl Debug for SocketDevice {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
f.debug_struct("SocketDevice")
.field("config", &self.config)
.field("guest_cid", &self.guest_cid)
.field("send_queue", &self.send_queue)
.field("recv_queue", &self.recv_queue)
.field("event_queue", &self.event_queue)
.field("transport", &self.transport)
.finish()
}
}
fn read_header_and_body(buffer: &[u8]) -> Result<(VirtioVsockHdr, &[u8]),SocketError> {
// Shouldn't panic, because we know `RX_BUFFER_SIZE > size_of::<VirtioVsockHdr>()`.
let header = VirtioVsockHdr::from_bytes(&buffer[..VIRTIO_VSOCK_HDR_LEN]);
let body_length = header.len() as usize;
// This could fail if the device returns an unreasonably long body length.
let data_end = VIRTIO_VSOCK_HDR_LEN
.checked_add(body_length)
.ok_or(SocketError::InvalidNumber)?;
// This could fail if the device returns a body length longer than the buffer we gave it.
let data = buffer
.get(VIRTIO_VSOCK_HDR_LEN..data_end)
.ok_or(SocketError::BufferTooShort)?;
Ok((header, data))
}

View File

@ -0,0 +1,105 @@
//! This file comes from virtio-drivers project
//! This module contains the error from the VirtIO socket driver.
use core::{fmt, result};
use smoltcp::socket::dhcpv4::Socket;
use crate::queue::QueueError;
/// The error type of VirtIO socket driver.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SocketError {
/// There is an existing connection.
ConnectionExists,
/// Failed to establish the connection.
ConnectionFailed,
/// The device is not connected to any peer.
NotConnected,
/// Peer socket is shutdown.
PeerSocketShutdown,
/// No response received.
NoResponseReceived,
/// The given buffer is shorter than expected.
BufferTooShort,
/// The given buffer for output is shorter than expected.
OutputBufferTooShort(usize),
/// The given buffer has exceeded the maximum buffer size.
BufferTooLong(usize, usize),
/// Unknown operation.
UnknownOperation(u16),
/// Invalid operation,
InvalidOperation,
/// Invalid number.
InvalidNumber,
/// Unexpected data in packet.
UnexpectedDataInPacket,
/// Peer has insufficient buffer space, try again later.
InsufficientBufferSpaceInPeer,
/// Recycled a wrong buffer.
RecycledWrongBuffer,
/// Queue Error
QueueError(SocketQueueError),
}
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum SocketQueueError {
InvalidArgs,
BufferTooSmall,
NotReady,
AlreadyUsed,
WrongToken,
}
impl From<QueueError> for SocketQueueError {
fn from(value: QueueError) -> Self {
match value {
QueueError::InvalidArgs => Self::InvalidArgs,
QueueError::BufferTooSmall => Self::BufferTooSmall,
QueueError::NotReady => Self::NotReady,
QueueError::AlreadyUsed => Self::AlreadyUsed,
QueueError::WrongToken => Self::WrongToken,
}
}
}
impl From<QueueError> for SocketError {
fn from(value: QueueError) -> Self {
Self::QueueError(SocketQueueError::from(value))
}
}
impl fmt::Display for SocketError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Self::ConnectionExists => write!(
f,
"There is an existing connection. Please close the current connection before attempting to connect again."),
Self::ConnectionFailed => write!(
f, "Failed to establish the connection. The packet sent may have an unknown type value"
),
Self::NotConnected => write!(f, "The device is not connected to any peer. Please connect it to a peer first."),
Self::PeerSocketShutdown => write!(f, "The peer socket is shutdown."),
Self::NoResponseReceived => write!(f, "No response received"),
Self::BufferTooShort => write!(f, "The given buffer is shorter than expected"),
Self::BufferTooLong(actual, max) => {
write!(f, "The given buffer length '{actual}' has exceeded the maximum allowed buffer length '{max}'")
}
Self::OutputBufferTooShort(expected) => {
write!(f, "The given output buffer is too short. '{expected}' bytes is needed for the output buffer.")
}
Self::UnknownOperation(op) => {
write!(f, "The operation code '{op}' is unknown")
}
Self::InvalidOperation => write!(f, "Invalid operation"),
Self::InvalidNumber => write!(f, "Invalid number"),
Self::UnexpectedDataInPacket => write!(f, "No data is expected in the packet"),
Self::InsufficientBufferSpaceInPeer => write!(f, "Peer has insufficient buffer space, try again later"),
Self::RecycledWrongBuffer => write!(f, "Recycled a wrong buffer"),
Self::QueueError(_) => write!(f,"Error encounted out of vsock itself!"),
}
}
}
pub type Result<T> = result::Result<T, SocketError>;

View File

@ -0,0 +1,151 @@
use pod::Pod;
use bitflags::bitflags;
use super::error::{self, SocketError};
pub const VIRTIO_VSOCK_HDR_LEN: usize = core::mem::size_of::<VirtioVsockHdr>();
/// Socket address.
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
pub struct VsockAddr {
/// Context Identifier.
pub cid: u64,
/// Port number.
pub port: u32,
}
/// VirtioVsock header precedes the payload in each packet.
// #[repr(packed)]
#[repr(C,packed)]
#[derive(Debug, Clone, Copy, Pod)]
pub struct VirtioVsockHdr{
pub src_cid: u64,
pub dst_cid: u64,
pub src_port: u32,
pub dst_port: u32,
pub len: u32,
pub socket_type: u16,
pub op: u16, //TOASK: why mark Pod and can I mark OpType Pod and replace u16 into OpType.
pub flags: u32,
/// Total receive buffer space for this socket. This includes both free and in-use buffers.
pub buf_alloc: u32,
/// Free-running bytes received counter.
pub fwd_cnt: u32,
}
impl Default for VirtioVsockHdr {
fn default() -> Self {
Self {
src_cid: 0,
dst_cid: 0,
src_port: 0,
dst_port: 0,
len: 0,
socket_type: VsockType::Stream as u16,
op: 0,
flags: 0,
buf_alloc: 0,
fwd_cnt: 0,
}
}
}
impl VirtioVsockHdr {
/// Returns the length of the data.
pub fn len(&self) -> u32 {
self.len
}
pub fn op(&self) -> error::Result<VirtioVsockOp> {
self.op.try_into()
}
pub fn source(&self) -> VsockAddr {
VsockAddr{
cid: self.src_cid,
port: self.src_port,
}
}
pub fn destination(&self) -> VsockAddr {
VsockAddr {
cid: self.dst_cid,
port: self.dst_port,
}
}
pub fn check_data_is_empty(&self) -> error::Result<()> {
if self.len() == 0 {
Ok(())
} else {
Err(SocketError::UnexpectedDataInPacket)
}
}
}
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u16)]
#[allow(non_camel_case_types)]
pub enum VirtioVsockOp{
#[default]
Invalid = 0,
/* Connect operations */
Request = 1,
Response = 2,
Rst = 3,
Shutdown = 4,
/* To send payload */
Rw = 5,
/* Tell the peer our credit info */
CreditUpdate = 6,
/* Request the peer to send the credit info to us */
CreditRequest = 7,
}
/// TODO: This could be optimized by upgrading [int_to_c_enum::TryFromIntError] to carrying the invalid int number
impl TryFrom<u16> for VirtioVsockOp {
type Error = SocketError;
fn try_from(v: u16) -> Result<Self, Self::Error> {
let op = match v {
0 => Self::Invalid,
1 => Self::Request,
2 => Self::Response,
3 => Self::Rst,
4 => Self::Shutdown,
5 => Self::Rw,
6 => Self::CreditUpdate,
7 => Self::CreditRequest,
_ => return Err(SocketError::UnknownOperation(v)),
};
Ok(op)
}
}
bitflags! {
#[repr(C)]
#[derive(Default, Pod)]
/// Header flags field type makes sense when connected socket receives VIRTIO_VSOCK_OP_SHUTDOWN.
pub struct ShutdownFlags: u32{
/// The peer will not receive any more data.
const VIRTIO_VSOCK_SHUTDOWN_RCV = 1 << 0;
/// The peer will not send any more data.
const VIRTIO_VSOCK_SHUTDOWN_SEND = 1 << 1;
/// The peer will not send or receive any more data.
const VIRTIO_VSOCK_SHUTDOWN_ALL = Self::VIRTIO_VSOCK_SHUTDOWN_RCV.bits | Self::VIRTIO_VSOCK_SHUTDOWN_SEND.bits;
}
}
/// Currently only stream sockets are supported. type is 1 for stream socket types.
#[derive(Copy, Clone, Debug)]
#[repr(u16)]
pub enum VsockType {
/// Stream sockets provide in-order, guaranteed, connection-oriented delivery without message boundaries.
Stream = 1,
/// seqpacket socket type introduced in virtio-v1.2.
SeqPacket = 2,
}

View File

@ -0,0 +1,369 @@
use core::{cmp::min, hint::spin_loop};
use alloc::{vec::Vec, boxed::Box, sync::Arc};
use aster_frame::sync::SpinLock;
use log::debug;
use crate::device::socket::error::SocketError;
use super::{device::SocketDevice, connect::{ConnectionInfo, VsockEvent, VsockEventType, DisconnectReason}, header::VsockAddr};
const PER_CONNECTION_BUFFER_CAPACITY: usize = 1024;
/// TODO: A higher level interface for VirtIO socket (vsock) devices.
///
/// This keeps track of multiple vsock connections.
///
/// # Example
///
/// ```
///
/// let mut socket = VsockConnectionManager::new(SocketDevice);
///
/// // Start a thread to call `socket.poll()` and handle events.
///
/// let remote_address = VsockAddr { cid: 2, port: 4321 };
/// let local_port = 1234;
/// socket.connect(remote_address, local_port)?;
///
/// // Wait until `socket.poll()` returns an event indicating that the socket is connected.
///
/// socket.send(remote_address, local_port, "Hello world".as_bytes())?;
///
/// socket.shutdown(remote_address, local_port)?;
/// # Ok(())
/// # }
/// ``
pub struct VsockConnectionManager {
driver: Arc<SpinLock<SocketDevice>>,
connections: Vec<Connection>,
listening_ports: Vec<u32>,
}
impl VsockConnectionManager {
/// Construct a new connection manager wrapping the given low-level VirtIO socket driver.
pub fn new(driver: Arc<SpinLock<SocketDevice>>) -> Self {
Self {
driver,
connections: Vec::new(),
listening_ports: Vec::new(),
}
}
/// Returns the CID which has been assigned to this guest.
pub fn guest_cid(&self) -> u64 {
self.driver.lock().guest_cid()
}
/// Allows incoming connections on the given port number.
pub fn listen(&mut self, port: u32) {
if !self.listening_ports.contains(&port) {
self.listening_ports.push(port);
}
}
/// Stops allowing incoming connections on the given port number.
pub fn unlisten(&mut self, port: u32) {
self.listening_ports.retain(|p| *p != port)
}
/// Sends a request to connect to the given destination.
///
/// This returns as soon as the request is sent; you should wait until `poll` returns a
/// `VsockEventType::Connected` event indicating that the peer has accepted the connection
/// before sending data.
pub fn connect(&mut self, destination: VsockAddr, src_port: u32) -> Result<(),SocketError> {
if self.connections.iter().any(|connection| {
connection.info.dst == destination && connection.info.src_port == src_port
}) {
return Err(SocketError::ConnectionExists.into());
}
let new_connection = Connection::new(destination, src_port);
self.driver.lock().connect(&new_connection.info)?;
debug!("Connection requested: {:?}", new_connection.info);
self.connections.push(new_connection);
Ok(())
}
/// Sends the buffer to the destination.
pub fn send(&mut self, destination: VsockAddr, src_port: u32, buffer: &[u8]) -> Result<(),SocketError> {
let (_, connection) = get_connection(&mut self.connections, destination, src_port)?;
self.driver.lock().send(buffer, &mut connection.info)
}
/// Polls the vsock device to receive data or other updates.
pub fn poll(&mut self) -> Result<Option<VsockEvent>,SocketError> {
let guest_cid = self.driver.lock().guest_cid();
let connections = &mut self.connections;
let result = self.driver.lock().poll(|event, body| {
let connection = get_connection_for_event(connections, &event, guest_cid);
// Skip events which don't match any connection we know about, unless they are a
// connection request.
let connection = if let Some((_, connection)) = connection {
connection
} else if let VsockEventType::ConnectionRequest = event.event_type {
// If the requested connection already exists or the CID isn't ours, ignore it.
if connection.is_some() || event.destination.cid != guest_cid {
return Ok(None);
}
// Add the new connection to our list, at least for now. It will be removed again
// below if we weren't listening on the port.
connections.push(Connection::new(event.source, event.destination.port));
connections.last_mut().unwrap()
} else {
return Ok(None);
};
// Update stored connection info.
connection.info.update_for_event(&event);
if let VsockEventType::Received { length } = event.event_type {
// Copy to buffer
if !connection.buffer.add(body) {
return Err(SocketError::OutputBufferTooShort(length));
}
}
Ok(Some(event))
})?;
let Some(event) = result else {
return Ok(None);
};
// The connection must exist because we found it above in the callback.
let (connection_index, connection) =
get_connection_for_event(connections, &event, guest_cid).unwrap();
match event.event_type {
VsockEventType::ConnectionRequest => {
if self.listening_ports.contains(&event.destination.port) {
self.driver.lock().accept(&connection.info)?;
} else {
// Reject the connection request and remove it from our list.
self.driver.lock().force_close(&connection.info)?;
self.connections.swap_remove(connection_index);
// No need to pass the request on to the client, as we've already rejected it.
return Ok(None);
}
}
VsockEventType::Connected => {}
VsockEventType::Disconnected { reason } => {
// Wait until client reads all data before removing connection.
if connection.buffer.is_empty() {
if reason == DisconnectReason::Shutdown {
self.driver.lock().force_close(&connection.info)?;
}
self.connections.swap_remove(connection_index);
} else {
connection.peer_requested_shutdown = true;
}
}
VsockEventType::Received { .. } => {
// Already copied the buffer in the callback above.
}
VsockEventType::CreditRequest => {
// If the peer requested credit, send an update.
self.driver.lock().credit_update(&connection.info)?;
// No need to pass the request on to the client, we've already handled it.
return Ok(None);
}
VsockEventType::CreditUpdate => {}
}
Ok(Some(event))
}
/// Reads data received from the given connection.
pub fn recv(&mut self, peer: VsockAddr, src_port: u32, buffer: &mut [u8]) -> Result<usize,SocketError> {
debug!("connections is {:?}",self.connections);
let (connection_index, connection) = get_connection(&mut self.connections, peer, src_port)?;
// Copy from ring buffer
let bytes_read = connection.buffer.drain(buffer);
connection.info.done_forwarding(bytes_read);
// If buffer is now empty and the peer requested shutdown, finish shutting down the
// connection.
if connection.peer_requested_shutdown && connection.buffer.is_empty() {
self.driver.lock().force_close(&connection.info)?;
self.connections.swap_remove(connection_index);
}
Ok(bytes_read)
}
/// Blocks until we get some event from the vsock device.
pub fn wait_for_event(&mut self) -> Result<VsockEvent,SocketError> {
loop {
if let Some(event) = self.poll()? {
return Ok(event);
} else {
spin_loop();
}
}
}
/// Requests to shut down the connection cleanly.
///
/// This returns as soon as the request is sent; you should wait until `poll` returns a
/// `VsockEventType::Disconnected` event if you want to know that the peer has acknowledged the
/// shutdown.
pub fn shutdown(&mut self, destination: VsockAddr, src_port: u32) -> Result<(),SocketError> {
let (_, connection) = get_connection(&mut self.connections, destination, src_port)?;
self.driver.lock().shutdown(&connection.info)
}
/// Forcibly closes the connection without waiting for the peer.
pub fn force_close(&mut self, destination: VsockAddr, src_port: u32) -> Result<(),SocketError> {
let (index, connection) = get_connection(&mut self.connections, destination, src_port)?;
self.driver.lock().force_close(&connection.info)?;
self.connections.swap_remove(index);
Ok(())
}
}
/// Returns the connection from the given list matching the given peer address and local port, and
/// its index.
///
/// Returns `Err(SocketError::NotConnected)` if there is no matching connection in the list.
fn get_connection(
connections: &mut [Connection],
peer: VsockAddr,
local_port: u32,
) -> core::result::Result<(usize, &mut Connection), SocketError> {
connections
.iter_mut()
.enumerate()
.find(|(_, connection)| {
connection.info.dst == peer && connection.info.src_port == local_port
})
.ok_or(SocketError::NotConnected)
}
/// Returns the connection from the given list matching the event, if any, and its index.
fn get_connection_for_event<'a>(
connections: &'a mut [Connection],
event: &VsockEvent,
local_cid: u64,
) -> Option<(usize, &'a mut Connection)> {
connections
.iter_mut()
.enumerate()
.find(|(_, connection)| event.matches_connection(&connection.info, local_cid))
}
#[derive(Debug)]
struct Connection {
info: ConnectionInfo,
buffer: RingBuffer,
/// The peer sent a SHUTDOWN request, but we haven't yet responded with a RST because there is
/// still data in the buffer.
peer_requested_shutdown: bool,
}
impl Connection {
fn new(peer: VsockAddr, local_port: u32) -> Self {
let mut info = ConnectionInfo::new(peer, local_port);
info.buf_alloc = PER_CONNECTION_BUFFER_CAPACITY.try_into().unwrap();
Self {
info,
buffer: RingBuffer::new(PER_CONNECTION_BUFFER_CAPACITY),
peer_requested_shutdown: false,
}
}
}
#[derive(Debug)]
struct RingBuffer {
buffer: Box<[u8]>,
/// The number of bytes currently in the buffer.
used: usize,
/// The index of the first used byte in the buffer.
start: usize,
}
impl RingBuffer {
pub fn new(capacity: usize) -> Self {
// TODO: can be optimized.
let mut temp = Vec::with_capacity(capacity);
temp.resize(capacity,0);
Self {
// FIXME: if the capacity is excessive, elements move will be executed.
buffer: temp.into_boxed_slice(),
used: 0,
start: 0,
}
}
/// Returns the number of bytes currently used in the buffer.
pub fn used(&self) -> usize {
self.used
}
/// Returns true iff there are currently no bytes in the buffer.
pub fn is_empty(&self) -> bool {
self.used == 0
}
/// Returns the number of bytes currently free in the buffer.
pub fn available(&self) -> usize {
self.buffer.len() - self.used
}
/// Adds the given bytes to the buffer if there is enough capacity for them all.
///
/// Returns true if they were added, or false if they were not.
pub fn add(&mut self, bytes: &[u8]) -> bool {
if bytes.len() > self.available() {
return false;
}
// The index of the first available position in the buffer.
let first_available = (self.start + self.used) % self.buffer.len();
// The number of bytes to copy from `bytes` to `buffer` between `first_available` and
// `buffer.len()`.
let copy_length_before_wraparound = min(bytes.len(), self.buffer.len() - first_available);
self.buffer[first_available..first_available + copy_length_before_wraparound]
.copy_from_slice(&bytes[0..copy_length_before_wraparound]);
if let Some(bytes_after_wraparound) = bytes.get(copy_length_before_wraparound..) {
self.buffer[0..bytes_after_wraparound.len()].copy_from_slice(bytes_after_wraparound);
}
self.used += bytes.len();
true
}
/// Reads and removes as many bytes as possible from the buffer, up to the length of the given
/// buffer.
pub fn drain(&mut self, out: &mut [u8]) -> usize {
let bytes_read = min(self.used, out.len());
// The number of bytes to copy out between `start` and the end of the buffer.
let read_before_wraparound = min(bytes_read, self.buffer.len() - self.start);
// The number of bytes to copy out from the beginning of the buffer after wrapping around.
let read_after_wraparound = bytes_read
.checked_sub(read_before_wraparound)
.unwrap_or_default();
out[0..read_before_wraparound]
.copy_from_slice(&self.buffer[self.start..self.start + read_before_wraparound]);
out[read_before_wraparound..bytes_read]
.copy_from_slice(&self.buffer[0..read_after_wraparound]);
self.used -= bytes_read;
self.start = (self.start + bytes_read) % self.buffer.len();
bytes_read
}
}

View File

@ -0,0 +1,67 @@
//! This mod is modified from virtio-drivers project.
use alloc::{sync::Arc, collections::BTreeMap, string::String, vec::Vec};
use component::{ComponentInitError, init_component};
use aster_frame::sync::SpinLock;
use smoltcp::socket::dhcpv4::Socket;
use spin::Once;
use core::fmt::Debug;
use self::device::SocketDevice;
pub mod buffer;
pub mod config;
pub mod device;
pub mod header;
pub mod connect;
pub mod error;
pub mod manager;
pub static DEVICE_NAME: &str = "Virtio-Vsock";
pub type VsockDeviceIrqHandler = dyn Fn() + Send + Sync;
pub fn register_device(name: String, device: Arc<SpinLock<SocketDevice>>) {
COMPONENT
.get()
.unwrap()
.vsock_device_table
.lock()
.insert(name, device);
}
pub fn get_device(str: &str) -> Option<Arc<SpinLock<SocketDevice>>> {
let lock = COMPONENT.get().unwrap().vsock_device_table.lock();
let Some(device) = lock.get(str) else {
return None;
};
Some(device.clone())
}
pub fn all_devices() -> Vec<(String, Arc<SpinLock<SocketDevice>>)> {
let vsock_devs = COMPONENT.get().unwrap().vsock_device_table.lock();
vsock_devs
.iter()
.map(|(name, device)| (name.clone(), device.clone()))
.collect()
}
static COMPONENT: Once<Component> = Once::new();
pub fn component_init() -> Result<(), ComponentInitError>{
let a = Component::init()?;
COMPONENT.call_once(|| a);
Ok(())
}
struct Component {
vsock_device_table: SpinLock<BTreeMap<String, Arc<SpinLock<SocketDevice>>>>,
}
impl Component {
pub fn init() -> Result<Self, ComponentInitError> {
Ok(Self {
vsock_device_table: SpinLock::new(BTreeMap::new()),
})
}
}

View File

@ -14,7 +14,7 @@ use bitflags::bitflags;
use component::{init_component, ComponentInitError}; use component::{init_component, ComponentInitError};
use device::{ use device::{
block::device::BlockDevice, console::device::ConsoleDevice, input::device::InputDevice, block::device::BlockDevice, console::device::ConsoleDevice, input::device::InputDevice,
network::device::NetworkDevice, VirtioDeviceType, network::device::NetworkDevice, socket::{device::SocketDevice, self}, VirtioDeviceType,
}; };
use log::{error, warn}; use log::{error, warn};
use transport::{mmio::VIRTIO_MMIO_DRIVER, pci::VIRTIO_PCI_DRIVER, DeviceStatus}; use transport::{mmio::VIRTIO_MMIO_DRIVER, pci::VIRTIO_PCI_DRIVER, DeviceStatus};
@ -30,6 +30,8 @@ mod transport;
fn virtio_component_init() -> Result<(), ComponentInitError> { fn virtio_component_init() -> Result<(), ComponentInitError> {
// Find all devices and register them to the corresponding crate // Find all devices and register them to the corresponding crate
transport::init(); transport::init();
// For vsock cmponent
socket::component_init()?;
while let Some(mut transport) = pop_device_transport() { while let Some(mut transport) = pop_device_transport() {
// Reset device // Reset device
transport.set_device_status(DeviceStatus::empty()).unwrap(); transport.set_device_status(DeviceStatus::empty()).unwrap();
@ -52,6 +54,7 @@ fn virtio_component_init() -> Result<(), ComponentInitError> {
VirtioDeviceType::Input => InputDevice::init(transport), VirtioDeviceType::Input => InputDevice::init(transport),
VirtioDeviceType::Network => NetworkDevice::init(transport), VirtioDeviceType::Network => NetworkDevice::init(transport),
VirtioDeviceType::Console => ConsoleDevice::init(transport), VirtioDeviceType::Console => ConsoleDevice::init(transport),
VirtioDeviceType::Socket => SocketDevice::init(transport),
_ => { _ => {
warn!("[Virtio]: Found unimplemented device:{:?}", device_type); warn!("[Virtio]: Found unimplemented device:{:?}", device_type);
Ok(()) Ok(())
@ -86,6 +89,7 @@ fn negotiate_features(transport: &mut Box<dyn VirtioTransport>) {
VirtioDeviceType::Block => BlockDevice::negotiate_features(device_specified_features), VirtioDeviceType::Block => BlockDevice::negotiate_features(device_specified_features),
VirtioDeviceType::Input => InputDevice::negotiate_features(device_specified_features), VirtioDeviceType::Input => InputDevice::negotiate_features(device_specified_features),
VirtioDeviceType::Console => ConsoleDevice::negotiate_features(device_specified_features), VirtioDeviceType::Console => ConsoleDevice::negotiate_features(device_specified_features),
VirtioDeviceType::Socket => SocketDevice::negotiate_features(device_specified_features),
_ => device_specified_features, _ => device_specified_features,
}; };
let mut support_feature = Feature::from_bits_truncate(features); let mut support_feature = Feature::from_bits_truncate(features);

View File

@ -0,0 +1,16 @@
import socket
client_socket = socket.socket(socket.AF_VSOCK, socket.SOCK_STREAM)
CID = socket.VMADDR_CID_HOST
PORT = 1234
vm_cid = 3
server_port = 4321
client_socket.bind((CID, PORT))
client_socket.connect((vm_cid, server_port))
client_socket.sendall(b'Hello from host')
response = client_socket.recv(4096)
print(f'Received: {response.decode()}')
client_socket.close()

View File

@ -0,0 +1,22 @@
#!/usr/bin/env python3
import socket
CID = socket.VMADDR_CID_HOST
PORT = 1234
s = socket.socket(socket.AF_VSOCK, socket.SOCK_STREAM)
s.bind((CID, PORT))
s.listen()
(conn, (remote_cid, remote_port)) = s.accept()
print(f"Connection opened by cid={remote_cid} port={remote_port}")
while True:
buf = conn.recv(64)
if not buf:
break
print(f"Received bytes: {buf}")
conn.send(b'Hello from host')

View File

@ -44,6 +44,7 @@ QEMU_ARGS="\
-device virtio-keyboard-pci,disable-legacy=on,disable-modern=off$IOMMU_DEV_EXTRA \ -device virtio-keyboard-pci,disable-legacy=on,disable-modern=off$IOMMU_DEV_EXTRA \
-device virtio-net-pci,netdev=net01,disable-legacy=on,disable-modern=off$IOMMU_DEV_EXTRA \ -device virtio-net-pci,netdev=net01,disable-legacy=on,disable-modern=off$IOMMU_DEV_EXTRA \
-device virtio-serial-pci,disable-legacy=on,disable-modern=off$IOMMU_DEV_EXTRA \ -device virtio-serial-pci,disable-legacy=on,disable-modern=off$IOMMU_DEV_EXTRA \
-device vhost-vsock-pci,id=vhost-vsock-pci0,guest-cid=3$IOMMU_DEV_EXTRA \
-device virtconsole,chardev=mux \ -device virtconsole,chardev=mux \
$IOMMU_EXTRA_ARGS \ $IOMMU_EXTRA_ARGS \
" "
@ -59,6 +60,7 @@ MICROVM_QEMU_ARGS="\
-device virtio-net-device,netdev=net01 \ -device virtio-net-device,netdev=net01 \
-device virtio-serial-device \ -device virtio-serial-device \
-device virtconsole,chardev=mux \ -device virtconsole,chardev=mux \
-device vhost-vsock-pci,id=vhost-vsock-pci0,guest-cid=3 \
" "
if [ "$1" = "microvm" ]; then if [ "$1" = "microvm" ]; then