diff --git a/OSDK.toml b/OSDK.toml index cab5e9ae2..2a5439d62 100644 --- a/OSDK.toml +++ b/OSDK.toml @@ -65,6 +65,7 @@ qemu.args = """\ -device virtio-net-pci,netdev=mynet0,disable-legacy=on,disable-modern=off \ -device virtio-keyboard-pci,disable-legacy=on,disable-modern=off \ -device virtio-blk-pci,bus=pcie.0,addr=0x6,drive=x0,disable-legacy=on,disable-modern=off \ + -device vhost-vsock-pci,id=vhost-vsock-pci0,guest-cid=3,disable-legacy=on,disable-modern=off \ -drive file=fs.img,if=none,format=raw,id=x0 \ -netdev user,id=mynet0,hostfwd=tcp::10027-:22,hostfwd=tcp::54136-:8090 \ -chardev stdio,id=mux,mux=on,logfile=./$(date '+%Y-%m-%dT%H%M%S').log \ @@ -72,4 +73,4 @@ qemu.args = """\ -device virtconsole,chardev=mux \ -monitor chardev:mux \ -serial chardev:mux \ -""" \ No newline at end of file +""" diff --git a/kernel/aster-nix/Cargo.toml b/kernel/aster-nix/Cargo.toml index 679a39350..1858077e4 100644 --- a/kernel/aster-nix/Cargo.toml +++ b/kernel/aster-nix/Cargo.toml @@ -6,6 +6,8 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +# FIXME: used for test in driver mod +component = {path="../libs/comp-sys/component"} aster-frame = { path = "../../framework/aster-frame" } align_ext = { path = "../../framework/libs/align_ext" } pod = { git = "https://github.com/asterinas/pod", rev = "d7dba56" } diff --git a/kernel/aster-nix/src/driver/mod.rs b/kernel/aster-nix/src/driver/mod.rs index 750803b5e..69f2ef093 100644 --- a/kernel/aster-nix/src/driver/mod.rs +++ b/kernel/aster-nix/src/driver/mod.rs @@ -1,10 +1,87 @@ // SPDX-License-Identifier: MPL-2.0 -use log::info; +use log::{info, debug}; +use alloc::string::ToString; +use component::ComponentInitError; +use aster_virtio::{self, device::socket::{header::VsockAddr, device::SocketDevice, manager::VsockConnectionManager, DEVICE_NAME}}; pub fn init() { // print all the input device to make sure input crate will compile for (name, _) in aster_input::all_devices() { info!("Found Input device, name:{}", name); } + // let _ = socket_device_client_test(); + // let _ = socket_device_server_test(); +} + + +fn socket_device_client_test() -> Result<(),ComponentInitError> { + let host_cid = 2; + let guest_cid = 3; + let host_port = 1234; + let guest_port = 4321; + let host_address = VsockAddr { + cid: host_cid, + port: host_port, + }; + let hello_from_guest = "Hello from guest"; + let hello_from_host = "Hello from host"; + + let device = aster_virtio::device::socket::get_device(DEVICE_NAME).unwrap(); + assert_eq!(device.lock().guest_cid(),guest_cid); + let mut socket = VsockConnectionManager::new(device); + + socket.connect(host_address, guest_port).unwrap(); + socket.wait_for_event().unwrap(); // wait for connect response + socket.send(host_address,guest_port,hello_from_guest.as_bytes()).unwrap(); + debug!("The buffer {:?} is sent, start receiving",hello_from_guest.as_bytes()); + socket.wait_for_event().unwrap(); // wait for recv + let mut buffer = [0u8; 64]; + let event = socket.recv(host_address, guest_port,&mut buffer).unwrap(); + assert_eq!( + &buffer[0..hello_from_host.len()], + hello_from_host.as_bytes() + ); + + socket.force_close(host_address,guest_port).unwrap(); + + debug!("The final event: {:?}",event); + Ok(()) +} + +pub fn socket_device_server_test() -> Result<(),ComponentInitError>{ + let host_cid = 2; + let guest_cid = 3; + let host_port = 1234; + let guest_port = 4321; + let host_address = VsockAddr { + cid: host_cid, + port: host_port, + }; + let hello_from_guest = "Hello from guest"; + let hello_from_host = "Hello from host"; + + let device = aster_virtio::device::socket::get_device(DEVICE_NAME).unwrap(); + assert_eq!(device.lock().guest_cid(),guest_cid); + let mut socket = VsockConnectionManager::new(device); + + socket.listen(4321); + socket.wait_for_event().unwrap(); // wait for connect request + socket.wait_for_event().unwrap(); // wait for recv + let mut buffer = [0u8; 64]; + let event = socket.recv(host_address, guest_port,&mut buffer).unwrap(); + assert_eq!( + &buffer[0..hello_from_host.len()], + hello_from_host.as_bytes() + ); + + debug!("The buffer {:?} is received, start sending {:?}", &buffer[0..hello_from_host.len()],hello_from_guest.as_bytes()); + socket.send(host_address,guest_port,hello_from_guest.as_bytes()).unwrap(); + + socket.shutdown(host_address,guest_port).unwrap(); + let event = socket.wait_for_event().unwrap(); // wait for rst/shutdown + + debug!("The final event: {:?}",event); + Ok(()) + } diff --git a/kernel/comps/virtio/src/device/mod.rs b/kernel/comps/virtio/src/device/mod.rs index 3ebf79e6d..2f8640654 100644 --- a/kernel/comps/virtio/src/device/mod.rs +++ b/kernel/comps/virtio/src/device/mod.rs @@ -8,6 +8,7 @@ pub mod block; pub mod console; pub mod input; pub mod network; +pub mod socket; #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, TryFromInt)] #[repr(u8)] diff --git a/kernel/comps/virtio/src/device/socket/buffer.rs b/kernel/comps/virtio/src/device/socket/buffer.rs new file mode 100644 index 000000000..20f400414 --- /dev/null +++ b/kernel/comps/virtio/src/device/socket/buffer.rs @@ -0,0 +1,99 @@ +//! This module is adapted from network/buffer.rs +use align_ext::AlignExt; +use bytes::BytesMut; +use pod::Pod; + +use crate::device::socket::header::VIRTIO_VSOCK_HDR_LEN; + +use super::header::VirtioVsockHdr; + +/// Buffer for receive packet +#[derive(Debug)] +pub struct RxBuffer { + /// Packet Buffer, length align 8. + buf: BytesMut, + /// Packet len + packet_len: usize, +} + +impl RxBuffer { + pub fn new(len: usize) -> Self { + let len = len.align_up(8); + let buf = BytesMut::zeroed(len); + Self { buf, packet_len: 0 } + } + + pub const fn packet_len(&self) -> usize { + self.packet_len + } + + pub fn set_packet_len(&mut self, packet_len: usize) { + self.packet_len = packet_len; + } + + pub fn buf(&self) -> &[u8] { + &self.buf + } + + pub fn buf_mut(&mut self) -> &mut [u8] { + &mut self.buf + } + + /// Packet payload slice, which is inner buffer excluding VirtioVsockHdr. + pub fn packet(&self) -> &[u8] { + debug_assert!(VIRTIO_VSOCK_HDR_LEN + self.packet_len <= self.buf.len()); + &self.buf[VIRTIO_VSOCK_HDR_LEN..VIRTIO_VSOCK_HDR_LEN + self.packet_len] + } + + /// Mutable packet payload slice. + pub fn packet_mut(&mut self) -> &mut [u8] { + debug_assert!(VIRTIO_VSOCK_HDR_LEN + self.packet_len <= self.buf.len()); + &mut self.buf[VIRTIO_VSOCK_HDR_LEN..VIRTIO_VSOCK_HDR_LEN + self.packet_len] + } + + pub fn virtio_vsock_header(&self) -> VirtioVsockHdr { + VirtioVsockHdr::from_bytes(&self.buf[..VIRTIO_VSOCK_HDR_LEN]) + } +} + +/// Buffer for transmit packet +#[derive(Debug)] +pub struct TxBuffer { + buf: BytesMut, +} + +impl TxBuffer { + pub fn with_len(buf_len: usize) -> Self { + Self { + buf: BytesMut::zeroed(buf_len), + } + } + + pub fn new(buf: &[u8]) -> Self { + Self { + buf: BytesMut::from(buf), + } + } + + pub fn buf(&self) -> &[u8] { + &self.buf + } + + pub fn buf_mut(&mut self) -> &mut [u8] { + &mut self.buf + } +} + +/// Buffer for event buffer +#[derive(Debug)] +pub struct EventBuffer{ + id: u32, +} + +#[repr(u32)] +#[derive(Debug, Clone, Copy, Default)] +#[allow(non_camel_case_types)] +pub enum EventIDType { + #[default] + VIRTIO_VSOCK_EVENT_TRANSPORT_RESET = 0, +} \ No newline at end of file diff --git a/kernel/comps/virtio/src/device/socket/config.rs b/kernel/comps/virtio/src/device/socket/config.rs new file mode 100644 index 000000000..af625d4bb --- /dev/null +++ b/kernel/comps/virtio/src/device/socket/config.rs @@ -0,0 +1,44 @@ +use aster_frame::io_mem::IoMem; +use pod::Pod; +use aster_util::{safe_ptr::SafePtr}; +use bitflags::bitflags; + +use crate::transport::{self, VirtioTransport}; + +bitflags!{ + /// Vsock feature bits since v1.2 + /// If no feature bit is set, only stream socket type is supported. + /// If VIRTIO_VSOCK_F_SEQPACKET has been negotiated, the device MAY act as if VIRTIO_VSOCK_F_STREAM has also been negotiated. + pub struct VsockFeatures: u64 { + const VIRTIO_VSOCK_F_STREAM = 1 << 0; // stream socket type is supported. + const VIRTIO_VSOCK_F_SEQPACKET = 1 << 1; //seqpacket socket type is supported. + } +} + +impl VsockFeatures { + pub fn support_features() -> Self { + VsockFeatures::VIRTIO_VSOCK_F_STREAM + } +} + +#[derive(Debug, Clone, Copy, Pod)] +#[repr(C)] +pub struct VirtioVsockConfig{ + /// The guest_cid field contains the guest’s context ID, which uniquely identifies + /// the device for its lifetime. The upper 32 bits of the CID are reserved and zeroed. + /// + /// According to virtio spec v1.1 2.4.1 Driver Requirements: Device Configuration Space, + /// drivers MUST NOT assume reads from fields greater than 32 bits wide are atomic. + /// So we need to split the u64 guest_cid into two parts. + // read only + pub guest_cid_low: u32, + // read only + pub guest_cid_high: u32, +} + +impl VirtioVsockConfig { + pub(crate) fn new(transport: &dyn VirtioTransport) -> SafePtr { + let memory = transport.device_config_memory(); + SafePtr::new(memory, 0) + } +} \ No newline at end of file diff --git a/kernel/comps/virtio/src/device/socket/connect.rs b/kernel/comps/virtio/src/device/socket/connect.rs new file mode 100644 index 000000000..411197c4c --- /dev/null +++ b/kernel/comps/virtio/src/device/socket/connect.rs @@ -0,0 +1,187 @@ +use log::debug; + +use super::{header::{VsockAddr, VirtioVsockHdr, VirtioVsockOp}, error::SocketError}; + + +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct VsockBufferStatus { + pub buffer_allocation: u32, + pub forward_count: u32, +} + +/// The reason why a vsock connection was closed. +#[derive(Copy, Clone, Debug, Eq, PartialEq)] +pub enum DisconnectReason { + /// The peer has either closed the connection in response to our shutdown request, or forcibly + /// closed it of its own accord. + Reset, + /// The peer asked to shut down the connection. + Shutdown, +} + +/// Details of the type of an event received from a VirtIO socket. +#[derive(Clone, Debug, Eq, PartialEq)] +pub enum VsockEventType { + /// The peer requests to establish a connection with us. + ConnectionRequest, + /// The connection was successfully established. + Connected, + /// The connection was closed. + Disconnected { + /// The reason for the disconnection. + reason: DisconnectReason, + }, + /// Data was received on the connection. + Received { + /// The length of the data in bytes. + length: usize, + }, + /// The peer requests us to send a credit update. + CreditRequest, + /// The peer just sent us a credit update with nothing else. + CreditUpdate, +} + +/// An event received from a VirtIO socket device. +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct VsockEvent { + /// The source of the event, i.e. the peer who sent it. + pub source: VsockAddr, + /// The destination of the event, i.e. the CID and port on our side. + pub destination: VsockAddr, + /// The peer's buffer status for the connection. + pub buffer_status: VsockBufferStatus, + /// The type of event. + pub event_type: VsockEventType, +} + +impl VsockEvent { + /// Returns whether the event matches the given connection. + pub fn matches_connection(&self, connection_info: &ConnectionInfo, guest_cid: u64) -> bool { + self.source == connection_info.dst + && self.destination.cid == guest_cid + && self.destination.port == connection_info.src_port + } + + pub fn from_header(header: &VirtioVsockHdr) -> Result { + let op = header.op()?; + let buffer_status = VsockBufferStatus { + buffer_allocation: header.buf_alloc, + forward_count: header.fwd_cnt, + }; + let source = header.source(); + let destination = header.destination(); + + let event_type = match op { + VirtioVsockOp::Request => { + header.check_data_is_empty()?; + VsockEventType::ConnectionRequest + } + VirtioVsockOp::Response => { + header.check_data_is_empty()?; + VsockEventType::Connected + } + VirtioVsockOp::CreditUpdate => { + header.check_data_is_empty()?; + VsockEventType::CreditUpdate + } + VirtioVsockOp::Rst | VirtioVsockOp::Shutdown => { + header.check_data_is_empty()?; + debug!("Disconnected from the peer"); + let reason = if op == VirtioVsockOp::Rst { + DisconnectReason::Reset + } else { + DisconnectReason::Shutdown + }; + VsockEventType::Disconnected { reason } + } + VirtioVsockOp::Rw => VsockEventType::Received { + length: header.len() as usize, + }, + VirtioVsockOp::CreditRequest => { + header.check_data_is_empty()?; + VsockEventType::CreditRequest + } + VirtioVsockOp::Invalid => return Err(SocketError::InvalidOperation), + }; + + Ok(VsockEvent { + source, + destination, + buffer_status, + event_type, + }) + } +} + + +#[derive(Clone, Debug, Default, PartialEq, Eq)] +pub struct ConnectionInfo { + pub dst: VsockAddr, + pub src_port: u32, + /// The last `buf_alloc` value the peer sent to us, indicating how much receive buffer space in + /// bytes it has allocated for packet bodies. + peer_buf_alloc: u32, + /// The last `fwd_cnt` value the peer sent to us, indicating how many bytes of packet bodies it + /// has finished processing. + peer_fwd_cnt: u32, + /// The number of bytes of packet bodies which we have sent to the peer. + pub tx_cnt: u32, + /// The number of bytes of buffer space we have allocated to receive packet bodies from the + /// peer. + pub buf_alloc: u32, + /// The number of bytes of packet bodies which we have received from the peer and handled. + pub fwd_cnt: u32, + /// Whether we have recently requested credit from the peer. + /// + /// This is set to true when we send a `VIRTIO_VSOCK_OP_CREDIT_REQUEST`, and false when we + /// receive a `VIRTIO_VSOCK_OP_CREDIT_UPDATE`. + pub has_pending_credit_request: bool, +} + +impl ConnectionInfo { + pub fn new(destination: VsockAddr, src_port: u32) -> Self { + Self { + dst: destination, + src_port, + ..Default::default() + } + } + + /// Updates this connection info with the peer buffer allocation and forwarded count from the + /// given event. + pub fn update_for_event(&mut self, event: &VsockEvent) { + self.peer_buf_alloc = event.buffer_status.buffer_allocation; + self.peer_fwd_cnt = event.buffer_status.forward_count; + + if let VsockEventType::CreditUpdate = event.event_type { + self.has_pending_credit_request = false; + } + } + + /// Increases the forwarded count recorded for this connection by the given number of bytes. + /// + /// This should be called once received data has been passed to the client, so there is buffer + /// space available for more. + pub fn done_forwarding(&mut self, length: usize) { + self.fwd_cnt += length as u32; + } + + /// Returns the number of bytes of RX buffer space the peer has available to receive packet body + /// data from us. + pub fn peer_free(&self) -> u32 { + self.peer_buf_alloc - (self.tx_cnt - self.peer_fwd_cnt) + } + + pub fn new_header(&self, src_cid: u64) -> VirtioVsockHdr { + VirtioVsockHdr { + src_cid, + dst_cid: self.dst.cid, + src_port: self.src_port, + dst_port: self.dst.port, + buf_alloc: self.buf_alloc, + fwd_cnt: self.fwd_cnt, + ..Default::default() + } + } +} \ No newline at end of file diff --git a/kernel/comps/virtio/src/device/socket/device.rs b/kernel/comps/virtio/src/device/socket/device.rs new file mode 100644 index 000000000..39b91ad67 --- /dev/null +++ b/kernel/comps/virtio/src/device/socket/device.rs @@ -0,0 +1,327 @@ +use core::hint::spin_loop; +use core::fmt::Debug; +use alloc::{vec::Vec, boxed::Box, string::ToString, sync::Arc}; +use aster_frame::{offset_of, trap::TrapFrame, sync::SpinLock}; +use aster_util::{slot_vec::SlotVec, field_ptr}; +use log::debug; +use pod::Pod; + +use crate::{queue::{VirtQueue, QueueError}, device::{VirtioDeviceError, socket::{register_device, DEVICE_NAME}}, transport::{VirtioTransport}}; + +use super::{buffer::RxBuffer, config::{VirtioVsockConfig, VsockFeatures}, connect::{ConnectionInfo, VsockEvent}, header::{VirtioVsockHdr, VirtioVsockOp, VIRTIO_VSOCK_HDR_LEN}, error::SocketError, VsockDeviceIrqHandler}; + +const QUEUE_SIZE: u16 = 64; +const QUEUE_RECV: u16 = 0; +const QUEUE_SEND: u16 = 1; +const QUEUE_EVENT: u16 = 2; + +/// The size in bytes of each buffer used in the RX virtqueue. This must be bigger than size_of::(). +const RX_BUFFER_SIZE: usize = 512; + +/// Low-level driver for a Virtio socket device. +pub struct SocketDevice { + config: VirtioVsockConfig, + guest_cid: u64, + + /// Virtqueue to receive packets. + send_queue: VirtQueue, + recv_queue: VirtQueue, + event_queue: VirtQueue, + + rx_buffers: SlotVec, + transport: Box, + callbacks: Vec>, +} + +impl SocketDevice { + pub fn init(mut transport: Box) -> Result<(), VirtioDeviceError> { + let virtio_vsock_config = VirtioVsockConfig::new(transport.as_mut()); + debug!("virtio_vsock_config = {:?}", virtio_vsock_config); + let guest_cid = + field_ptr!(&virtio_vsock_config, VirtioVsockConfig, guest_cid_low).read().unwrap() as u64 + | (field_ptr!(&virtio_vsock_config, VirtioVsockConfig, guest_cid_high).read().unwrap() as u64) << 32; + + let mut recv_queue = VirtQueue::new(QUEUE_RECV,QUEUE_SIZE,transport.as_mut()) + .expect("createing recv queue fails"); + let send_queue = VirtQueue::new(QUEUE_SEND,QUEUE_SIZE,transport.as_mut()) + .expect("creating send queue fails"); + let event_queue = VirtQueue::new(QUEUE_EVENT,QUEUE_SIZE,transport.as_mut()) + .expect("creating event queue fails"); + + // Allocate and add buffers for the RX queue. + let mut rx_buffers = SlotVec::new(); + for i in 0..QUEUE_SIZE { + let mut rx_buffer = RxBuffer::new(RX_BUFFER_SIZE); + let token = recv_queue.add_buf(&[], &mut [rx_buffer.buf_mut()])?; + assert_eq!(i, token); + assert_eq!(rx_buffers.put(rx_buffer) as u16, i); + } + + if recv_queue.should_notify() { + debug!("notify receive queue"); + recv_queue.notify(); + } + + let mut device = Self{ + config: virtio_vsock_config.read().unwrap(), + guest_cid, + send_queue, + recv_queue, + event_queue, + rx_buffers, + transport, + callbacks: Vec::new(), + }; + + // Interrupt handler if vsock device config space changes + fn config_space_change(_: &TrapFrame){ + debug!("vsock device config space change"); + } + + // Interrupt handler if vsock device receives some packet. + // TODO: This will be handled by vsock socket layer. + fn handle_vsock_event(_: &TrapFrame){ + debug!("Packet received. This will be solved by socket layer"); + } + + device + .transport + .register_cfg_callback(Box::new(config_space_change)) + .unwrap(); + device + .transport + .register_queue_callback(QUEUE_RECV, Box::new(handle_vsock_event), false) + .unwrap(); + + device.transport.finish_init(); + + register_device( + super::DEVICE_NAME.to_string(), + Arc::new(SpinLock::new(device)), + ); + + Ok(()) + } + + /// Returns the CID which has been assigned to this guest. + pub fn guest_cid(&self) -> u64 { + self.guest_cid + } + + /// Sends a request to connect to the given destination. + /// + /// This returns as soon as the request is sent; you should wait until `poll` returns a + /// [`VsockEventType::Connected`] event indicating that the peer has accepted the connection + /// before sending data. + pub fn connect(&mut self, connection_info: &ConnectionInfo) -> Result<(),SocketError> { + let header = VirtioVsockHdr { + op: VirtioVsockOp::Request as u16, + ..connection_info.new_header(self.guest_cid) + }; + // Sends a header only packet to the TX queue to connect the device to the listening socket + // at the given destination. + self.send_packet_to_tx_queue(&header, &[]) + } + + /// Accepts the given connection from a peer. + pub fn accept(&mut self, connection_info: &ConnectionInfo) -> Result<(),SocketError> { + let header = VirtioVsockHdr { + op: VirtioVsockOp::Response as u16, + ..connection_info.new_header(self.guest_cid) + }; + self.send_packet_to_tx_queue(&header, &[]) + } + + /// Requests the peer to send us a credit update for the given connection. + fn request_credit(&mut self, connection_info: &ConnectionInfo) -> Result<(),SocketError> { + let header = VirtioVsockHdr { + op: VirtioVsockOp::CreditRequest as u16, + ..connection_info.new_header(self.guest_cid) + }; + self.send_packet_to_tx_queue(&header, &[]) + } + + /// Tells the peer how much buffer space we have to receive data. + pub fn credit_update(&mut self, connection_info: &ConnectionInfo) -> Result<(),SocketError> { + let header = VirtioVsockHdr { + op: VirtioVsockOp::CreditUpdate as u16, + ..connection_info.new_header(self.guest_cid) + }; + self.send_packet_to_tx_queue(&header, &[]) + } + + /// Requests to shut down the connection cleanly. + /// + /// This returns as soon as the request is sent; you should wait until `poll` returns a + /// `VsockEventType::Disconnected` event if you want to know that the peer has acknowledged the + /// shutdown. + pub fn shutdown(&mut self, connection_info: &ConnectionInfo) -> Result<(),SocketError> { + let header = VirtioVsockHdr { + op: VirtioVsockOp::Shutdown as u16, + ..connection_info.new_header(self.guest_cid) + }; + self.send_packet_to_tx_queue(&header, &[]) + } + + /// Forcibly closes the connection without waiting for the peer. + pub fn force_close(&mut self, connection_info: &ConnectionInfo) -> Result<(),SocketError> { + let header = VirtioVsockHdr { + op: VirtioVsockOp::Rst as u16, + ..connection_info.new_header(self.guest_cid) + }; + self.send_packet_to_tx_queue(&header, &[]) + } + + fn send_packet_to_tx_queue(&mut self, header: &VirtioVsockHdr, buffer: &[u8]) -> Result<(), SocketError> { + // let (_token, _len) = self.send_queue.add_notify_wait_pop( + // &[header.as_bytes(), buffer], + // &mut [], + // )?; + + let _token = self + .send_queue + .add_buf(&[header.as_bytes(), buffer], &[])?; + + if self.send_queue.should_notify() { + self.send_queue.notify(); + } + + // Wait until the buffer is used + while !self.send_queue.can_pop() { + spin_loop(); + } + + self.send_queue.pop_used()?; + + // FORDEBUG + // debug!("buffer in send_packet_to_tx_queue: {:?}",buffer); + Ok(()) + } + + fn check_peer_buffer_is_sufficient( + &mut self, + connection_info: &mut ConnectionInfo, + buffer_len: usize, + ) -> Result<(), SocketError> { + if connection_info.peer_free() as usize >= buffer_len { + Ok(()) + } else { + // Request an update of the cached peer credit, if we haven't already done so, and tell + // the caller to try again later. + if !connection_info.has_pending_credit_request { + self.request_credit(connection_info)?; + connection_info.has_pending_credit_request = true; + } + Err(SocketError::InsufficientBufferSpaceInPeer) + } + } + + /// Sends the buffer to the destination. + pub fn send(&mut self, buffer: &[u8], connection_info: &mut ConnectionInfo) -> Result<(), SocketError> { + self.check_peer_buffer_is_sufficient(connection_info, buffer.len())?; + + let len = buffer.len() as u32; + let header = VirtioVsockHdr { + op: VirtioVsockOp::Rw as u16, + len: len, + ..connection_info.new_header(self.guest_cid) + }; + connection_info.tx_cnt += len; + self.send_packet_to_tx_queue(&header, buffer) + } + + /// Polls the RX virtqueue for the next event, and calls the given handler function to handle it. + pub fn poll(&mut self, handler: impl FnOnce(VsockEvent, &[u8]) -> Result,SocketError> + ) -> Result, SocketError> { + // Return None if there is no pending packet. + if !self.recv_queue.can_pop(){ + return Ok(None); + } + let (token, len) = self.recv_queue.pop_used()?; + + let mut buffer = self + .rx_buffers + .remove(token as usize) + .ok_or(QueueError::WrongToken)?; + + let header = buffer.virtio_vsock_header(); + // The length written should be equal to len(header)+len(packet) + assert_eq!(len, header.len() + VIRTIO_VSOCK_HDR_LEN as u32); + + buffer.set_packet_len(RX_BUFFER_SIZE); + + + let head_result = read_header_and_body(&buffer.buf()); + + let Ok((header,body)) = head_result else { + let ret = match head_result { + Err(e) => Err(e), + _ => Ok(None) //FIXME: this clause is never reached. + }; + self.add_rx_buffer(buffer, token)?; + return ret; + }; + + debug!("Received packet {:?}. Op {:?}", header, header.op()); + debug!("body is {:?}",body); + + let result = VsockEvent::from_header(&header).and_then(|event| handler(event, body)); + + // reuse the buffer and give it back to recv_queue. + self.add_rx_buffer(buffer, token)?; + + result + + } + + /// Add a used rx buffer to recv queue,@index is only to check the correctness + fn add_rx_buffer(&mut self, mut rx_buffer: RxBuffer, index: u16) -> Result<(), SocketError> { + let token = self.recv_queue.add_buf(&[], &mut [rx_buffer.buf_mut()])?; + assert_eq!(index,token); + assert!(self.rx_buffers.put_at(token as usize, rx_buffer).is_none()); + if self.recv_queue.should_notify() { + self.recv_queue.notify(); + } + Ok(()) + } + + /// Negotiate features for the device specified bits 0~23 + pub(crate) fn negotiate_features(features: u64) -> u64 { + let device_features = VsockFeatures::from_bits_truncate(features); + let supported_features = VsockFeatures::support_features(); + let vsock_features = device_features & supported_features; + debug!("features negotiated: {:?}",vsock_features); + vsock_features.bits() + } + + +} + +impl Debug for SocketDevice { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.debug_struct("SocketDevice") + .field("config", &self.config) + .field("guest_cid", &self.guest_cid) + .field("send_queue", &self.send_queue) + .field("recv_queue", &self.recv_queue) + .field("event_queue", &self.event_queue) + .field("transport", &self.transport) + .finish() + } +} + +fn read_header_and_body(buffer: &[u8]) -> Result<(VirtioVsockHdr, &[u8]),SocketError> { + // Shouldn't panic, because we know `RX_BUFFER_SIZE > size_of::()`. + let header = VirtioVsockHdr::from_bytes(&buffer[..VIRTIO_VSOCK_HDR_LEN]); + let body_length = header.len() as usize; + + // This could fail if the device returns an unreasonably long body length. + let data_end = VIRTIO_VSOCK_HDR_LEN + .checked_add(body_length) + .ok_or(SocketError::InvalidNumber)?; + // This could fail if the device returns a body length longer than the buffer we gave it. + let data = buffer + .get(VIRTIO_VSOCK_HDR_LEN..data_end) + .ok_or(SocketError::BufferTooShort)?; + Ok((header, data)) +} \ No newline at end of file diff --git a/kernel/comps/virtio/src/device/socket/error.rs b/kernel/comps/virtio/src/device/socket/error.rs new file mode 100644 index 000000000..42a96993d --- /dev/null +++ b/kernel/comps/virtio/src/device/socket/error.rs @@ -0,0 +1,105 @@ +//! This file comes from virtio-drivers project +//! This module contains the error from the VirtIO socket driver. + +use core::{fmt, result}; + +use smoltcp::socket::dhcpv4::Socket; + +use crate::queue::QueueError; + +/// The error type of VirtIO socket driver. +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub enum SocketError { + /// There is an existing connection. + ConnectionExists, + /// Failed to establish the connection. + ConnectionFailed, + /// The device is not connected to any peer. + NotConnected, + /// Peer socket is shutdown. + PeerSocketShutdown, + /// No response received. + NoResponseReceived, + /// The given buffer is shorter than expected. + BufferTooShort, + /// The given buffer for output is shorter than expected. + OutputBufferTooShort(usize), + /// The given buffer has exceeded the maximum buffer size. + BufferTooLong(usize, usize), + /// Unknown operation. + UnknownOperation(u16), + /// Invalid operation, + InvalidOperation, + /// Invalid number. + InvalidNumber, + /// Unexpected data in packet. + UnexpectedDataInPacket, + /// Peer has insufficient buffer space, try again later. + InsufficientBufferSpaceInPeer, + /// Recycled a wrong buffer. + RecycledWrongBuffer, + /// Queue Error + QueueError(SocketQueueError), +} + +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +enum SocketQueueError { + InvalidArgs, + BufferTooSmall, + NotReady, + AlreadyUsed, + WrongToken, +} + +impl From for SocketQueueError { + fn from(value: QueueError) -> Self { + match value { + QueueError::InvalidArgs => Self::InvalidArgs, + QueueError::BufferTooSmall => Self::BufferTooSmall, + QueueError::NotReady => Self::NotReady, + QueueError::AlreadyUsed => Self::AlreadyUsed, + QueueError::WrongToken => Self::WrongToken, + } + } +} + + +impl From for SocketError { + fn from(value: QueueError) -> Self { + Self::QueueError(SocketQueueError::from(value)) + } +} + +impl fmt::Display for SocketError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Self::ConnectionExists => write!( + f, + "There is an existing connection. Please close the current connection before attempting to connect again."), + Self::ConnectionFailed => write!( + f, "Failed to establish the connection. The packet sent may have an unknown type value" + ), + Self::NotConnected => write!(f, "The device is not connected to any peer. Please connect it to a peer first."), + Self::PeerSocketShutdown => write!(f, "The peer socket is shutdown."), + Self::NoResponseReceived => write!(f, "No response received"), + Self::BufferTooShort => write!(f, "The given buffer is shorter than expected"), + Self::BufferTooLong(actual, max) => { + write!(f, "The given buffer length '{actual}' has exceeded the maximum allowed buffer length '{max}'") + } + Self::OutputBufferTooShort(expected) => { + write!(f, "The given output buffer is too short. '{expected}' bytes is needed for the output buffer.") + } + Self::UnknownOperation(op) => { + write!(f, "The operation code '{op}' is unknown") + } + Self::InvalidOperation => write!(f, "Invalid operation"), + Self::InvalidNumber => write!(f, "Invalid number"), + Self::UnexpectedDataInPacket => write!(f, "No data is expected in the packet"), + Self::InsufficientBufferSpaceInPeer => write!(f, "Peer has insufficient buffer space, try again later"), + Self::RecycledWrongBuffer => write!(f, "Recycled a wrong buffer"), + Self::QueueError(_) => write!(f,"Error encounted out of vsock itself!"), + } + } +} + +pub type Result = result::Result; diff --git a/kernel/comps/virtio/src/device/socket/header.rs b/kernel/comps/virtio/src/device/socket/header.rs new file mode 100644 index 000000000..c89e2b3a8 --- /dev/null +++ b/kernel/comps/virtio/src/device/socket/header.rs @@ -0,0 +1,151 @@ +use pod::Pod; +use bitflags::bitflags; +use super::error::{self, SocketError}; + +pub const VIRTIO_VSOCK_HDR_LEN: usize = core::mem::size_of::(); + +/// Socket address. +#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)] +pub struct VsockAddr { + /// Context Identifier. + pub cid: u64, + /// Port number. + pub port: u32, +} + +/// VirtioVsock header precedes the payload in each packet. +// #[repr(packed)] +#[repr(C,packed)] +#[derive(Debug, Clone, Copy, Pod)] +pub struct VirtioVsockHdr{ + pub src_cid: u64, + pub dst_cid: u64, + pub src_port: u32, + pub dst_port: u32, + + pub len: u32, + pub socket_type: u16, + pub op: u16, //TOASK: why mark Pod and can I mark OpType Pod and replace u16 into OpType. + pub flags: u32, + /// Total receive buffer space for this socket. This includes both free and in-use buffers. + pub buf_alloc: u32, + /// Free-running bytes received counter. + pub fwd_cnt: u32, +} + +impl Default for VirtioVsockHdr { + fn default() -> Self { + Self { + src_cid: 0, + dst_cid: 0, + src_port: 0, + dst_port: 0, + len: 0, + socket_type: VsockType::Stream as u16, + op: 0, + flags: 0, + buf_alloc: 0, + fwd_cnt: 0, + } + } +} + + +impl VirtioVsockHdr { + /// Returns the length of the data. + pub fn len(&self) -> u32 { + self.len + } + + pub fn op(&self) -> error::Result { + self.op.try_into() + } + + pub fn source(&self) -> VsockAddr { + VsockAddr{ + cid: self.src_cid, + port: self.src_port, + } + } + + pub fn destination(&self) -> VsockAddr { + VsockAddr { + cid: self.dst_cid, + port: self.dst_port, + } + } + + pub fn check_data_is_empty(&self) -> error::Result<()> { + if self.len() == 0 { + Ok(()) + } else { + Err(SocketError::UnexpectedDataInPacket) + } + } +} + +#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)] +#[repr(u16)] +#[allow(non_camel_case_types)] +pub enum VirtioVsockOp{ + #[default] + Invalid = 0, + + /* Connect operations */ + Request = 1, + Response = 2, + Rst = 3, + Shutdown = 4, + + /* To send payload */ + Rw = 5, + + /* Tell the peer our credit info */ + CreditUpdate = 6, + /* Request the peer to send the credit info to us */ + CreditRequest = 7, +} + +/// TODO: This could be optimized by upgrading [int_to_c_enum::TryFromIntError] to carrying the invalid int number +impl TryFrom for VirtioVsockOp { + type Error = SocketError; + + fn try_from(v: u16) -> Result { + let op = match v { + 0 => Self::Invalid, + 1 => Self::Request, + 2 => Self::Response, + 3 => Self::Rst, + 4 => Self::Shutdown, + 5 => Self::Rw, + 6 => Self::CreditUpdate, + 7 => Self::CreditRequest, + _ => return Err(SocketError::UnknownOperation(v)), + }; + Ok(op) + } +} + +bitflags! { + #[repr(C)] + #[derive(Default, Pod)] + /// Header flags field type makes sense when connected socket receives VIRTIO_VSOCK_OP_SHUTDOWN. + pub struct ShutdownFlags: u32{ + /// The peer will not receive any more data. + const VIRTIO_VSOCK_SHUTDOWN_RCV = 1 << 0; + /// The peer will not send any more data. + const VIRTIO_VSOCK_SHUTDOWN_SEND = 1 << 1; + /// The peer will not send or receive any more data. + const VIRTIO_VSOCK_SHUTDOWN_ALL = Self::VIRTIO_VSOCK_SHUTDOWN_RCV.bits | Self::VIRTIO_VSOCK_SHUTDOWN_SEND.bits; + } +} + +/// Currently only stream sockets are supported. type is 1 for stream socket types. +#[derive(Copy, Clone, Debug)] +#[repr(u16)] +pub enum VsockType { + /// Stream sockets provide in-order, guaranteed, connection-oriented delivery without message boundaries. + Stream = 1, + /// seqpacket socket type introduced in virtio-v1.2. + SeqPacket = 2, +} \ No newline at end of file diff --git a/kernel/comps/virtio/src/device/socket/manager.rs b/kernel/comps/virtio/src/device/socket/manager.rs new file mode 100644 index 000000000..8da27b61e --- /dev/null +++ b/kernel/comps/virtio/src/device/socket/manager.rs @@ -0,0 +1,369 @@ +use core::{cmp::min, hint::spin_loop}; + +use alloc::{vec::Vec, boxed::Box, sync::Arc}; +use aster_frame::sync::SpinLock; +use log::debug; + +use crate::device::socket::error::SocketError; + +use super::{device::SocketDevice, connect::{ConnectionInfo, VsockEvent, VsockEventType, DisconnectReason}, header::VsockAddr}; + + +const PER_CONNECTION_BUFFER_CAPACITY: usize = 1024; + +/// TODO: A higher level interface for VirtIO socket (vsock) devices. +/// +/// This keeps track of multiple vsock connections. +/// +/// # Example +/// +/// ``` +/// +/// let mut socket = VsockConnectionManager::new(SocketDevice); +/// +/// // Start a thread to call `socket.poll()` and handle events. +/// +/// let remote_address = VsockAddr { cid: 2, port: 4321 }; +/// let local_port = 1234; +/// socket.connect(remote_address, local_port)?; +/// +/// // Wait until `socket.poll()` returns an event indicating that the socket is connected. +/// +/// socket.send(remote_address, local_port, "Hello world".as_bytes())?; +/// +/// socket.shutdown(remote_address, local_port)?; +/// # Ok(()) +/// # } +/// `` +pub struct VsockConnectionManager { + driver: Arc>, + connections: Vec, + listening_ports: Vec, +} + +impl VsockConnectionManager { + /// Construct a new connection manager wrapping the given low-level VirtIO socket driver. + pub fn new(driver: Arc>) -> Self { + Self { + driver, + connections: Vec::new(), + listening_ports: Vec::new(), + } + } + + /// Returns the CID which has been assigned to this guest. + pub fn guest_cid(&self) -> u64 { + self.driver.lock().guest_cid() + } + + /// Allows incoming connections on the given port number. + pub fn listen(&mut self, port: u32) { + if !self.listening_ports.contains(&port) { + self.listening_ports.push(port); + } + } + + /// Stops allowing incoming connections on the given port number. + pub fn unlisten(&mut self, port: u32) { + self.listening_ports.retain(|p| *p != port) + } + /// Sends a request to connect to the given destination. + /// + /// This returns as soon as the request is sent; you should wait until `poll` returns a + /// `VsockEventType::Connected` event indicating that the peer has accepted the connection + /// before sending data. + pub fn connect(&mut self, destination: VsockAddr, src_port: u32) -> Result<(),SocketError> { + if self.connections.iter().any(|connection| { + connection.info.dst == destination && connection.info.src_port == src_port + }) { + return Err(SocketError::ConnectionExists.into()); + } + + let new_connection = Connection::new(destination, src_port); + + self.driver.lock().connect(&new_connection.info)?; + debug!("Connection requested: {:?}", new_connection.info); + self.connections.push(new_connection); + Ok(()) + } + + /// Sends the buffer to the destination. + pub fn send(&mut self, destination: VsockAddr, src_port: u32, buffer: &[u8]) -> Result<(),SocketError> { + let (_, connection) = get_connection(&mut self.connections, destination, src_port)?; + + self.driver.lock().send(buffer, &mut connection.info) + } + + /// Polls the vsock device to receive data or other updates. + pub fn poll(&mut self) -> Result,SocketError> { + let guest_cid = self.driver.lock().guest_cid(); + let connections = &mut self.connections; + + let result = self.driver.lock().poll(|event, body| { + let connection = get_connection_for_event(connections, &event, guest_cid); + + // Skip events which don't match any connection we know about, unless they are a + // connection request. + let connection = if let Some((_, connection)) = connection { + connection + } else if let VsockEventType::ConnectionRequest = event.event_type { + // If the requested connection already exists or the CID isn't ours, ignore it. + if connection.is_some() || event.destination.cid != guest_cid { + return Ok(None); + } + // Add the new connection to our list, at least for now. It will be removed again + // below if we weren't listening on the port. + connections.push(Connection::new(event.source, event.destination.port)); + connections.last_mut().unwrap() + } else { + return Ok(None); + }; + + // Update stored connection info. + connection.info.update_for_event(&event); + + if let VsockEventType::Received { length } = event.event_type { + // Copy to buffer + if !connection.buffer.add(body) { + return Err(SocketError::OutputBufferTooShort(length)); + } + } + + Ok(Some(event)) + })?; + + let Some(event) = result else { + return Ok(None); + }; + + // The connection must exist because we found it above in the callback. + let (connection_index, connection) = + get_connection_for_event(connections, &event, guest_cid).unwrap(); + + match event.event_type { + VsockEventType::ConnectionRequest => { + if self.listening_ports.contains(&event.destination.port) { + self.driver.lock().accept(&connection.info)?; + } else { + // Reject the connection request and remove it from our list. + self.driver.lock().force_close(&connection.info)?; + self.connections.swap_remove(connection_index); + + // No need to pass the request on to the client, as we've already rejected it. + return Ok(None); + } + } + VsockEventType::Connected => {} + VsockEventType::Disconnected { reason } => { + // Wait until client reads all data before removing connection. + if connection.buffer.is_empty() { + if reason == DisconnectReason::Shutdown { + self.driver.lock().force_close(&connection.info)?; + } + self.connections.swap_remove(connection_index); + } else { + connection.peer_requested_shutdown = true; + } + } + VsockEventType::Received { .. } => { + // Already copied the buffer in the callback above. + } + VsockEventType::CreditRequest => { + // If the peer requested credit, send an update. + self.driver.lock().credit_update(&connection.info)?; + // No need to pass the request on to the client, we've already handled it. + return Ok(None); + } + VsockEventType::CreditUpdate => {} + } + + Ok(Some(event)) + } + + /// Reads data received from the given connection. + pub fn recv(&mut self, peer: VsockAddr, src_port: u32, buffer: &mut [u8]) -> Result { + debug!("connections is {:?}",self.connections); + let (connection_index, connection) = get_connection(&mut self.connections, peer, src_port)?; + + // Copy from ring buffer + let bytes_read = connection.buffer.drain(buffer); + + connection.info.done_forwarding(bytes_read); + + // If buffer is now empty and the peer requested shutdown, finish shutting down the + // connection. + if connection.peer_requested_shutdown && connection.buffer.is_empty() { + self.driver.lock().force_close(&connection.info)?; + self.connections.swap_remove(connection_index); + } + + Ok(bytes_read) + } + + /// Blocks until we get some event from the vsock device. + pub fn wait_for_event(&mut self) -> Result { + loop { + if let Some(event) = self.poll()? { + return Ok(event); + } else { + spin_loop(); + } + } + } + + /// Requests to shut down the connection cleanly. + /// + /// This returns as soon as the request is sent; you should wait until `poll` returns a + /// `VsockEventType::Disconnected` event if you want to know that the peer has acknowledged the + /// shutdown. + pub fn shutdown(&mut self, destination: VsockAddr, src_port: u32) -> Result<(),SocketError> { + let (_, connection) = get_connection(&mut self.connections, destination, src_port)?; + + self.driver.lock().shutdown(&connection.info) + } + + /// Forcibly closes the connection without waiting for the peer. + pub fn force_close(&mut self, destination: VsockAddr, src_port: u32) -> Result<(),SocketError> { + let (index, connection) = get_connection(&mut self.connections, destination, src_port)?; + + self.driver.lock().force_close(&connection.info)?; + + self.connections.swap_remove(index); + Ok(()) + } +} + +/// Returns the connection from the given list matching the given peer address and local port, and +/// its index. +/// +/// Returns `Err(SocketError::NotConnected)` if there is no matching connection in the list. +fn get_connection( + connections: &mut [Connection], + peer: VsockAddr, + local_port: u32, +) -> core::result::Result<(usize, &mut Connection), SocketError> { + connections + .iter_mut() + .enumerate() + .find(|(_, connection)| { + connection.info.dst == peer && connection.info.src_port == local_port + }) + .ok_or(SocketError::NotConnected) +} + +/// Returns the connection from the given list matching the event, if any, and its index. +fn get_connection_for_event<'a>( + connections: &'a mut [Connection], + event: &VsockEvent, + local_cid: u64, +) -> Option<(usize, &'a mut Connection)> { + connections + .iter_mut() + .enumerate() + .find(|(_, connection)| event.matches_connection(&connection.info, local_cid)) +} + + +#[derive(Debug)] +struct Connection { + info: ConnectionInfo, + buffer: RingBuffer, + /// The peer sent a SHUTDOWN request, but we haven't yet responded with a RST because there is + /// still data in the buffer. + peer_requested_shutdown: bool, +} + +impl Connection { + fn new(peer: VsockAddr, local_port: u32) -> Self { + let mut info = ConnectionInfo::new(peer, local_port); + info.buf_alloc = PER_CONNECTION_BUFFER_CAPACITY.try_into().unwrap(); + Self { + info, + buffer: RingBuffer::new(PER_CONNECTION_BUFFER_CAPACITY), + peer_requested_shutdown: false, + } + } +} + +#[derive(Debug)] +struct RingBuffer { + buffer: Box<[u8]>, + /// The number of bytes currently in the buffer. + used: usize, + /// The index of the first used byte in the buffer. + start: usize, +} + +impl RingBuffer { + pub fn new(capacity: usize) -> Self { + // TODO: can be optimized. + let mut temp = Vec::with_capacity(capacity); + temp.resize(capacity,0); + Self { + // FIXME: if the capacity is excessive, elements move will be executed. + buffer: temp.into_boxed_slice(), + used: 0, + start: 0, + } + } + /// Returns the number of bytes currently used in the buffer. + pub fn used(&self) -> usize { + self.used + } + + /// Returns true iff there are currently no bytes in the buffer. + pub fn is_empty(&self) -> bool { + self.used == 0 + } + + /// Returns the number of bytes currently free in the buffer. + pub fn available(&self) -> usize { + self.buffer.len() - self.used + } + + /// Adds the given bytes to the buffer if there is enough capacity for them all. + /// + /// Returns true if they were added, or false if they were not. + pub fn add(&mut self, bytes: &[u8]) -> bool { + if bytes.len() > self.available() { + return false; + } + + // The index of the first available position in the buffer. + let first_available = (self.start + self.used) % self.buffer.len(); + // The number of bytes to copy from `bytes` to `buffer` between `first_available` and + // `buffer.len()`. + let copy_length_before_wraparound = min(bytes.len(), self.buffer.len() - first_available); + self.buffer[first_available..first_available + copy_length_before_wraparound] + .copy_from_slice(&bytes[0..copy_length_before_wraparound]); + if let Some(bytes_after_wraparound) = bytes.get(copy_length_before_wraparound..) { + self.buffer[0..bytes_after_wraparound.len()].copy_from_slice(bytes_after_wraparound); + } + self.used += bytes.len(); + + true + } + + /// Reads and removes as many bytes as possible from the buffer, up to the length of the given + /// buffer. + pub fn drain(&mut self, out: &mut [u8]) -> usize { + let bytes_read = min(self.used, out.len()); + + // The number of bytes to copy out between `start` and the end of the buffer. + let read_before_wraparound = min(bytes_read, self.buffer.len() - self.start); + // The number of bytes to copy out from the beginning of the buffer after wrapping around. + let read_after_wraparound = bytes_read + .checked_sub(read_before_wraparound) + .unwrap_or_default(); + + out[0..read_before_wraparound] + .copy_from_slice(&self.buffer[self.start..self.start + read_before_wraparound]); + out[read_before_wraparound..bytes_read] + .copy_from_slice(&self.buffer[0..read_after_wraparound]); + + self.used -= bytes_read; + self.start = (self.start + bytes_read) % self.buffer.len(); + + bytes_read + } +} \ No newline at end of file diff --git a/kernel/comps/virtio/src/device/socket/mod.rs b/kernel/comps/virtio/src/device/socket/mod.rs new file mode 100644 index 000000000..316957003 --- /dev/null +++ b/kernel/comps/virtio/src/device/socket/mod.rs @@ -0,0 +1,67 @@ +//! This mod is modified from virtio-drivers project. + +use alloc::{sync::Arc, collections::BTreeMap, string::String, vec::Vec}; +use component::{ComponentInitError, init_component}; +use aster_frame::sync::SpinLock; +use smoltcp::socket::dhcpv4::Socket; +use spin::Once; +use core::fmt::Debug; +use self::device::SocketDevice; +pub mod buffer; +pub mod config; +pub mod device; +pub mod header; +pub mod connect; +pub mod error; +pub mod manager; + +pub static DEVICE_NAME: &str = "Virtio-Vsock"; +pub type VsockDeviceIrqHandler = dyn Fn() + Send + Sync; + + +pub fn register_device(name: String, device: Arc>) { + COMPONENT + .get() + .unwrap() + .vsock_device_table + .lock() + .insert(name, device); +} + +pub fn get_device(str: &str) -> Option>> { + let lock = COMPONENT.get().unwrap().vsock_device_table.lock(); + let Some(device) = lock.get(str) else { + return None; + }; + Some(device.clone()) +} + +pub fn all_devices() -> Vec<(String, Arc>)> { + let vsock_devs = COMPONENT.get().unwrap().vsock_device_table.lock(); + vsock_devs + .iter() + .map(|(name, device)| (name.clone(), device.clone())) + .collect() +} + +static COMPONENT: Once = Once::new(); + + +pub fn component_init() -> Result<(), ComponentInitError>{ + let a = Component::init()?; + COMPONENT.call_once(|| a); + Ok(()) +} + + +struct Component { + vsock_device_table: SpinLock>>>, +} + +impl Component { + pub fn init() -> Result { + Ok(Self { + vsock_device_table: SpinLock::new(BTreeMap::new()), + }) + } +} \ No newline at end of file diff --git a/kernel/comps/virtio/src/lib.rs b/kernel/comps/virtio/src/lib.rs index 240149423..f95754e04 100644 --- a/kernel/comps/virtio/src/lib.rs +++ b/kernel/comps/virtio/src/lib.rs @@ -14,7 +14,7 @@ use bitflags::bitflags; use component::{init_component, ComponentInitError}; use device::{ block::device::BlockDevice, console::device::ConsoleDevice, input::device::InputDevice, - network::device::NetworkDevice, VirtioDeviceType, + network::device::NetworkDevice, socket::{device::SocketDevice, self}, VirtioDeviceType, }; use log::{error, warn}; use transport::{mmio::VIRTIO_MMIO_DRIVER, pci::VIRTIO_PCI_DRIVER, DeviceStatus}; @@ -30,6 +30,8 @@ mod transport; fn virtio_component_init() -> Result<(), ComponentInitError> { // Find all devices and register them to the corresponding crate transport::init(); + // For vsock cmponent + socket::component_init()?; while let Some(mut transport) = pop_device_transport() { // Reset device transport.set_device_status(DeviceStatus::empty()).unwrap(); @@ -52,6 +54,7 @@ fn virtio_component_init() -> Result<(), ComponentInitError> { VirtioDeviceType::Input => InputDevice::init(transport), VirtioDeviceType::Network => NetworkDevice::init(transport), VirtioDeviceType::Console => ConsoleDevice::init(transport), + VirtioDeviceType::Socket => SocketDevice::init(transport), _ => { warn!("[Virtio]: Found unimplemented device:{:?}", device_type); Ok(()) @@ -86,6 +89,7 @@ fn negotiate_features(transport: &mut Box) { VirtioDeviceType::Block => BlockDevice::negotiate_features(device_specified_features), VirtioDeviceType::Input => InputDevice::negotiate_features(device_specified_features), VirtioDeviceType::Console => ConsoleDevice::negotiate_features(device_specified_features), + VirtioDeviceType::Socket => SocketDevice::negotiate_features(device_specified_features), _ => device_specified_features, }; let mut support_feature = Feature::from_bits_truncate(features); diff --git a/test_vsock/vsock_client.py b/test_vsock/vsock_client.py new file mode 100644 index 000000000..3dae3b8e7 --- /dev/null +++ b/test_vsock/vsock_client.py @@ -0,0 +1,16 @@ +import socket + +client_socket = socket.socket(socket.AF_VSOCK, socket.SOCK_STREAM) +CID = socket.VMADDR_CID_HOST +PORT = 1234 +vm_cid = 3 +server_port = 4321 +client_socket.bind((CID, PORT)) +client_socket.connect((vm_cid, server_port)) + +client_socket.sendall(b'Hello from host') + +response = client_socket.recv(4096) +print(f'Received: {response.decode()}') + +client_socket.close() diff --git a/test_vsock/vsock_server.py b/test_vsock/vsock_server.py new file mode 100644 index 000000000..7b4baf8d4 --- /dev/null +++ b/test_vsock/vsock_server.py @@ -0,0 +1,22 @@ +#!/usr/bin/env python3 + +import socket + +CID = socket.VMADDR_CID_HOST +PORT = 1234 + +s = socket.socket(socket.AF_VSOCK, socket.SOCK_STREAM) +s.bind((CID, PORT)) +s.listen() +(conn, (remote_cid, remote_port)) = s.accept() + +print(f"Connection opened by cid={remote_cid} port={remote_port}") + +while True: + buf = conn.recv(64) + if not buf: + break + + print(f"Received bytes: {buf}") + + conn.send(b'Hello from host') \ No newline at end of file diff --git a/tools/qemu_args.sh b/tools/qemu_args.sh index e940b0f41..11695a793 100755 --- a/tools/qemu_args.sh +++ b/tools/qemu_args.sh @@ -44,6 +44,7 @@ QEMU_ARGS="\ -device virtio-keyboard-pci,disable-legacy=on,disable-modern=off$IOMMU_DEV_EXTRA \ -device virtio-net-pci,netdev=net01,disable-legacy=on,disable-modern=off$IOMMU_DEV_EXTRA \ -device virtio-serial-pci,disable-legacy=on,disable-modern=off$IOMMU_DEV_EXTRA \ + -device vhost-vsock-pci,id=vhost-vsock-pci0,guest-cid=3$IOMMU_DEV_EXTRA \ -device virtconsole,chardev=mux \ $IOMMU_EXTRA_ARGS \ " @@ -59,6 +60,7 @@ MICROVM_QEMU_ARGS="\ -device virtio-net-device,netdev=net01 \ -device virtio-serial-device \ -device virtconsole,chardev=mux \ + -device vhost-vsock-pci,id=vhost-vsock-pci0,guest-cid=3 \ " if [ "$1" = "microvm" ]; then