diff --git a/Makefile b/Makefile index 09d0b757f..d78f91ea9 100644 --- a/Makefile +++ b/Makefile @@ -33,7 +33,9 @@ CARGO_OSDK_ARGS += --init-args="/opt/syscall_test/run_syscall_test.sh" else ifeq ($(AUTO_TEST), regression) CARGO_OSDK_ARGS += --init-args="/regression/run_regression_test.sh" else ifeq ($(AUTO_TEST), boot) CARGO_OSDK_ARGS += --init-args="/regression/boot_hello.sh" +else ifeq ($(AUTO_TEST), vsock) +CARGO_OSDK_ARGS += --init-args="/regression/run_vsock_test.sh" endif ifeq ($(RELEASE_LTO), 1) @@ -68,16 +70,6 @@ ifeq ($(ENABLE_KVM), 1) CARGO_OSDK_ARGS += --qemu-args="--enable-kvm" endif -ifeq ($(VSOCK),1) -ifeq ($(QEMU_MACHINE), microvm) -CARGO_OSDK_ARGS += --qumu.args="-device vhost-vsock-pci,id=vhost-vsock-pci0,guest-cid=3" -else ifeq ($(EMULATE_IOMMU), 1) -CARGO_OSDK_ARGS += --qemu.args="-device vhost-vsock-pci,id=vhost-vsock-pci0,guest-cid=3,disable-legacy=on,disable-modern=off,iommu_platform=on,ats=on" -else -CARGO_OSDK_ARGS += --qemu.args="-device vhost-vsock-pci,id=vhost-vsock-pci0,guest-cid=3,disable-legacy=on,disable-modern=off" -endif -endif - # Pass make variables to all subdirectory makes export @@ -152,6 +144,8 @@ else ifeq ($(AUTO_TEST), regression) @tail --lines 100 qemu.log | grep -q "^All regression tests passed." || (echo "Regression test failed" && exit 1) else ifeq ($(AUTO_TEST), boot) @tail --lines 100 qemu.log | grep -q "^Successfully booted." || (echo "Boot test failed" && exit 1) +else ifeq ($(AUTO_TEST), vsock) + @tail --lines 100 qemu.log | grep -q "^Vsock test passed." 
|| (echo "Vsock test failed" && exit 1) endif gdb_server: initramfs $(CARGO_OSDK) diff --git a/kernel/aster-nix/Cargo.toml b/kernel/aster-nix/Cargo.toml index 1858077e4..679a39350 100644 --- a/kernel/aster-nix/Cargo.toml +++ b/kernel/aster-nix/Cargo.toml @@ -6,8 +6,6 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -# FIXME: used for test in driver mod -component = {path="../libs/comp-sys/component"} aster-frame = { path = "../../framework/aster-frame" } align_ext = { path = "../../framework/libs/align_ext" } pod = { git = "https://github.com/asterinas/pod", rev = "d7dba56" } diff --git a/kernel/aster-nix/src/driver/mod.rs b/kernel/aster-nix/src/driver/mod.rs index 186f981ad..750803b5e 100644 --- a/kernel/aster-nix/src/driver/mod.rs +++ b/kernel/aster-nix/src/driver/mod.rs @@ -1,98 +1,10 @@ // SPDX-License-Identifier: MPL-2.0 -use aster_virtio::{ - self, - device::socket::{header::VsockAddr, manager::VsockConnectionManager, DEVICE_NAME}, -}; -use component::ComponentInitError; -use log::{debug, info}; +use log::info; pub fn init() { // print all the input device to make sure input crate will compile for (name, _) in aster_input::all_devices() { info!("Found Input device, name:{}", name); } - // let _ = socket_device_client_test(); - // let _ = socket_device_server_test(); -} - -fn socket_device_client_test() -> Result<(), ComponentInitError> { - let host_cid = 2; - let guest_cid = 3; - let host_port = 1234; - let guest_port = 4321; - let host_address = VsockAddr { - cid: host_cid, - port: host_port, - }; - let hello_from_guest = "Hello from guest"; - let hello_from_host = "Hello from host"; - - let device = aster_virtio::device::socket::get_device(DEVICE_NAME).unwrap(); - assert_eq!(device.lock().guest_cid(), guest_cid); - let mut socket = VsockConnectionManager::new(device); - - socket.connect(host_address, guest_port).unwrap(); - socket.wait_for_event().unwrap(); // wait for 
connect response - socket - .send(host_address, guest_port, hello_from_guest.as_bytes()) - .unwrap(); - debug!( - "The buffer {:?} is sent, start receiving", - hello_from_guest.as_bytes() - ); - socket.wait_for_event().unwrap(); // wait for recv - let mut buffer = [0u8; 64]; - let event = socket.recv(host_address, guest_port, &mut buffer).unwrap(); - assert_eq!( - &buffer[0..hello_from_host.len()], - hello_from_host.as_bytes() - ); - - socket.force_close(host_address, guest_port).unwrap(); - - debug!("The final event: {:?}", event); - Ok(()) -} - -pub fn socket_device_server_test() -> Result<(), ComponentInitError> { - let host_cid = 2; - let guest_cid = 3; - let host_port = 1234; - let guest_port = 4321; - let host_address = VsockAddr { - cid: host_cid, - port: host_port, - }; - let hello_from_guest = "Hello from guest"; - let hello_from_host = "Hello from host"; - - let device = aster_virtio::device::socket::get_device(DEVICE_NAME).unwrap(); - assert_eq!(device.lock().guest_cid(), guest_cid); - let mut socket = VsockConnectionManager::new(device); - - socket.listen(4321); - socket.wait_for_event().unwrap(); // wait for connect request - socket.wait_for_event().unwrap(); // wait for recv - let mut buffer = [0u8; 64]; - let event = socket.recv(host_address, guest_port, &mut buffer).unwrap(); - assert_eq!( - &buffer[0..hello_from_host.len()], - hello_from_host.as_bytes() - ); - - debug!( - "The buffer {:?} is received, start sending {:?}", - &buffer[0..hello_from_host.len()], - hello_from_guest.as_bytes() - ); - socket - .send(host_address, guest_port, hello_from_guest.as_bytes()) - .unwrap(); - - socket.shutdown(host_address, guest_port).unwrap(); - let event = socket.wait_for_event().unwrap(); // wait for rst/shutdown - - debug!("The final event: {:?}", event); - Ok(()) } diff --git a/kernel/aster-nix/src/net/mod.rs b/kernel/aster-nix/src/net/mod.rs index 72054353b..6a7d4b665 100644 --- a/kernel/aster-nix/src/net/mod.rs +++ b/kernel/aster-nix/src/net/mod.rs @@ 
-2,7 +2,7 @@ use spin::Once; -use self::iface::spawn_background_poll_thread; +use self::{iface::spawn_background_poll_thread, socket::vsock}; use crate::{ net::iface::{Iface, IfaceLoopback, IfaceVirtio}, prelude::*, @@ -28,6 +28,7 @@ pub fn init() { }) } poll_ifaces(); + vsock::init(); } /// Lazy init should be called after spawning init thread. diff --git a/kernel/aster-nix/src/net/socket/mod.rs b/kernel/aster-nix/src/net/socket/mod.rs index 64f7a3a09..f2d875bb8 100644 --- a/kernel/aster-nix/src/net/socket/mod.rs +++ b/kernel/aster-nix/src/net/socket/mod.rs @@ -13,6 +13,7 @@ pub mod ip; pub mod options; pub mod unix; mod util; +pub mod vsock; /// Operations defined on a socket. pub trait Socket: FileLike + Send + Sync { diff --git a/kernel/aster-nix/src/net/socket/util/socket_addr.rs b/kernel/aster-nix/src/net/socket/util/socket_addr.rs index 8373b45ba..e7a41f722 100644 --- a/kernel/aster-nix/src/net/socket/util/socket_addr.rs +++ b/kernel/aster-nix/src/net/socket/util/socket_addr.rs @@ -15,6 +15,7 @@ pub enum SocketAddr { Unix(UnixSocketAddr), IPv4(Ipv4Address, PortNum), IPv6, + Vsock(u32, u32), } impl TryFrom for IpEndpoint { diff --git a/kernel/aster-nix/src/net/socket/vsock/addr.rs b/kernel/aster-nix/src/net/socket/vsock/addr.rs new file mode 100644 index 000000000..d6d5d7415 --- /dev/null +++ b/kernel/aster-nix/src/net/socket/vsock/addr.rs @@ -0,0 +1,71 @@ +// SPDX-License-Identifier: MPL-2.0 + +use aster_virtio::device::socket::header::VsockAddr; + +use crate::{net::socket::SocketAddr, prelude::*}; + +#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord)] +pub struct VsockSocketAddr { + pub cid: u32, + pub port: u32, +} + +impl VsockSocketAddr { + pub fn new(cid: u32, port: u32) -> Self { + Self { cid, port } + } + + pub fn any_addr() -> Self { + Self { + cid: VMADDR_CID_ANY, + port: VMADDR_PORT_ANY, + } + } +} + +impl TryFrom for VsockSocketAddr { + type Error = Error; + + fn try_from(value: SocketAddr) -> Result { + let (cid, port) = if let 
SocketAddr::Vsock(cid, port) = value { + (cid, port) + } else { + return_errno_with_message!(Errno::EINVAL, "invalid vsock socket addr"); + }; + Ok(Self { cid, port }) + } +} + +impl From for SocketAddr { + fn from(value: VsockSocketAddr) -> Self { + SocketAddr::Vsock(value.cid, value.port) + } +} + +impl From for VsockSocketAddr { + fn from(value: VsockAddr) -> Self { + VsockSocketAddr { + cid: value.cid as u32, + port: value.port, + } + } +} + +impl From for VsockAddr { + fn from(value: VsockSocketAddr) -> Self { + VsockAddr { + cid: value.cid as u64, + port: value.port, + } + } +} + +/// The vSocket equivalent of INADDR_ANY. +pub const VMADDR_CID_ANY: u32 = u32::MAX; +/// Use this as the destination CID in an address when referring to the local communication (loopback). +/// This was VMADDR_CID_RESERVED +pub const VMADDR_CID_LOCAL: u32 = 1; +/// Use this as the destination CID in an address when referring to the host (any process other than the hypervisor). +pub const VMADDR_CID_HOST: u32 = 2; +/// Bind to any available port. 
+pub const VMADDR_PORT_ANY: u32 = u32::MAX; diff --git a/kernel/aster-nix/src/net/socket/vsock/common.rs b/kernel/aster-nix/src/net/socket/vsock/common.rs new file mode 100644 index 000000000..be6a327f0 --- /dev/null +++ b/kernel/aster-nix/src/net/socket/vsock/common.rs @@ -0,0 +1,206 @@ +// SPDX-License-Identifier: MPL-2.0 + +use alloc::collections::BTreeSet; + +use aster_virtio::device::socket::{ + connect::{VsockEvent, VsockEventType}, + device::SocketDevice, + error::SocketError, + get_device, DEVICE_NAME, +}; + +use super::{ + addr::VsockSocketAddr, + stream::{ + connected::{Connected, ConnectionID}, + listen::Listen, + }, +}; +use crate::{events::IoEvents, prelude::*, return_errno_with_message}; + +/// Manage all active sockets +pub struct VsockSpace { + pub driver: Arc>, + // (key, value) = (local_addr, connecting) + pub connecting_sockets: SpinLock>>, + // (key, value) = (local_addr, listen) + pub listen_sockets: SpinLock>>, + // (key, value) = (id(local_addr,peer_addr), connected) + pub connected_sockets: SpinLock>>, + // Used ports + pub used_ports: SpinLock>, +} + +impl VsockSpace { + /// Create a new global VsockSpace + pub fn new() -> Self { + let driver = get_device(DEVICE_NAME).unwrap(); + Self { + driver, + connecting_sockets: SpinLock::new(BTreeMap::new()), + listen_sockets: SpinLock::new(BTreeMap::new()), + connected_sockets: SpinLock::new(BTreeMap::new()), + used_ports: SpinLock::new(BTreeSet::new()), + } + } + /// Poll for each event from the driver + pub fn poll(&self) -> Result> { + let mut driver = self.driver.lock_irq_disabled(); + let guest_cid: u32 = driver.guest_cid() as u32; + + // match the socket and store the buffer body (if valid) + let result = driver + .poll(|event, body| { + if !self.is_event_for_socket(&event) { + return Ok(None); + } + + // Deal with Received before the buffer are recycled. 
+ if let VsockEventType::Received { length } = event.event_type { + // Only consider the connected socket and copy body to buffer + if let Some(connected) = self + .connected_sockets + .lock_irq_disabled() + .get(&event.into()) + { + debug!("Rw matches a connection with id {:?}", connected.id()); + if !connected.connection_buffer_add(body) { + return Err(SocketError::BufferTooShort); + } + } else { + return Ok(None); + } + } + Ok(Some(event)) + }) + .map_err(|e| { + Error::with_message(Errno::EAGAIN, "driver poll failed, please try again") + })?; + + let Some(event) = result else { + return Ok(None); + }; + + // The socket must be stored in the VsockSpace. + if let Some(connected) = self + .connected_sockets + .lock_irq_disabled() + .get(&event.into()) + { + connected.update_for_event(&event); + } + + // Response to the event + match event.event_type { + VsockEventType::ConnectionRequest => { + // Queue the connection for the listen socket's `accept`; a full backlog is returned as an error, not a panic + if let Some(listen) = self + .listen_sockets + .lock_irq_disabled() + .get(&event.destination.into()) + { + let peer = event.source; + let connected = Arc::new(Connected::new(peer.into(), listen.addr())); + connected.update_for_event(&event); + listen.push_incoming(connected)?; + } else { + return_errno_with_message!( + Errno::EINVAL, + "Connecion request can only be handled by listening socket" + ) + } + } + VsockEventType::Connected => { + if let Some(connecting) = self + .connecting_sockets + .lock_irq_disabled() + .get(&event.destination.into()) + { + // debug!("match a connecting socket. 
Peer{:?}; local{:?}",connecting.peer_addr(),connecting.local_addr()); + connecting.update_for_event(&event); + connecting.add_events(IoEvents::IN); + } + } + VsockEventType::Disconnected { reason } => { + if let Some(connected) = self + .connected_sockets + .lock_irq_disabled() + .get(&event.into()) + { + connected.peer_requested_shutdown(); + } else { + return_errno_with_message!(Errno::ENOTCONN, "The socket hasn't connected"); + } + } + VsockEventType::Received { length } => { + if let Some(connected) = self + .connected_sockets + .lock_irq_disabled() + .get(&event.into()) + { + connected.add_events(IoEvents::IN); + } else { + return_errno_with_message!(Errno::ENOTCONN, "The socket hasn't connected"); + } + } + VsockEventType::CreditRequest => { + if let Some(connected) = self + .connected_sockets + .lock_irq_disabled() + .get(&event.into()) + { + driver.credit_update(&connected.get_info()).map_err(|_| { + Error::with_message(Errno::EINVAL, "can not send credit update") + })?; + } + } + VsockEventType::CreditUpdate => { + if let Some(connected) = self + .connected_sockets + .lock_irq_disabled() + .get(&event.into()) + { + connected.update_for_event(&event); + } else { + return_errno_with_message!( + Errno::EINVAL, + "CreditUpdate is only valid in connected sockets" + ) + } + } + } + Ok(Some(event)) + } + /// Check whether the event is for this socket space + fn is_event_for_socket(&self, event: &VsockEvent) -> bool { + // debug!("The event is for connection with id {:?}",ConnectionID::from(*event)); + self.connecting_sockets + .lock_irq_disabled() + .contains_key(&event.destination.into()) + || self + .listen_sockets + .lock_irq_disabled() + .contains_key(&event.destination.into()) + || self + .connected_sockets + .lock_irq_disabled() + .contains_key(&(*event).into()) + } + /// Alloc an unused ephemeral port; u32::MAX is excluded because it is the "any port" wildcard (VMADDR_PORT_ANY) + pub fn alloc_ephemeral_port(&self) -> Result { + let mut used_ports = self.used_ports.lock_irq_disabled(); + for port in 1024..u32::MAX { + if 
!used_ports.contains(&port) { + used_ports.insert(port); + return Ok(port); + } + } + return_errno_with_message!(Errno::EAGAIN, "cannot find unused high port"); + } +} + +impl Default for VsockSpace { + fn default() -> Self { + Self::new() + } +} diff --git a/kernel/aster-nix/src/net/socket/vsock/mod.rs b/kernel/aster-nix/src/net/socket/vsock/mod.rs new file mode 100644 index 000000000..3a03761e0 --- /dev/null +++ b/kernel/aster-nix/src/net/socket/vsock/mod.rs @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: MPL-2.0 + +use alloc::sync::Arc; + +use aster_virtio::device::socket::{register_recv_callback, DEVICE_NAME}; +use common::VsockSpace; +use spin::Once; + +pub mod addr; +pub mod common; +pub mod stream; +pub use stream::VsockStreamSocket; + +// init static driver +pub static VSOCK_GLOBAL: Once> = Once::new(); + +pub fn init() { + VSOCK_GLOBAL.call_once(|| Arc::new(VsockSpace::new())); + register_recv_callback(DEVICE_NAME, || { + let vsockspace = VSOCK_GLOBAL.get().unwrap(); + let _ = vsockspace.poll(); + }) +} diff --git a/kernel/aster-nix/src/net/socket/vsock/stream/connected.rs b/kernel/aster-nix/src/net/socket/vsock/stream/connected.rs new file mode 100644 index 000000000..c8b211962 --- /dev/null +++ b/kernel/aster-nix/src/net/socket/vsock/stream/connected.rs @@ -0,0 +1,267 @@ +// SPDX-License-Identifier: MPL-2.0 + +use alloc::boxed::Box; +use core::cmp::min; + +use aster_virtio::device::socket::connect::{ConnectionInfo, VsockEvent}; + +use crate::{ + events::IoEvents, + net::socket::{ + vsock::{addr::VsockSocketAddr, VSOCK_GLOBAL}, + SendRecvFlags, SockShutdownCmd, + }, + prelude::*, + process::signal::{Pollee, Poller}, +}; + +const PER_CONNECTION_BUFFER_CAPACITY: usize = 4096; + +pub struct Connected { + connection: SpinLock, + id: ConnectionID, + pollee: Pollee, +} + +impl Connected { + pub fn new(peer_addr: VsockSocketAddr, local_addr: VsockSocketAddr) -> Self { + Self { + connection: SpinLock::new(Connection::new(peer_addr, local_addr.port)), + id: 
ConnectionID::new(local_addr, peer_addr), + pollee: Pollee::new(IoEvents::empty()), + } + } + pub fn peer_addr(&self) -> VsockSocketAddr { + self.id.peer_addr + } + + pub fn local_addr(&self) -> VsockSocketAddr { + self.id.local_addr + } + + pub fn id(&self) -> ConnectionID { + self.id + } + + pub fn recv(&self, buf: &mut [u8]) -> Result { + let poller = Poller::new(); + if !self + .poll(IoEvents::IN, Some(&poller)) + .contains(IoEvents::IN) + { + poller.wait()?; + } + + let mut connection = self.connection.lock_irq_disabled(); + let bytes_read = connection.buffer.drain(buf); + + connection.info.done_forwarding(bytes_read); + + Ok(bytes_read) + } + + pub fn send(&self, buf: &[u8], flags: SendRecvFlags) -> Result { + let mut connection = self.connection.lock_irq_disabled(); + debug_assert!(flags.is_all_supported()); + let buf_len = buf.len(); + VSOCK_GLOBAL + .get() + .unwrap() + .driver + .lock_irq_disabled() + .send(buf, &mut connection.info) + .map_err(|e| Error::with_message(Errno::ENOBUFS, "cannot send packet"))?; + Ok(buf_len) + } + + pub fn should_close(&self) -> bool { + let connection = self.connection.lock_irq_disabled(); + // If buffer is now empty and the peer requested shutdown, finish shutting down the + // connection. 
+ connection.peer_requested_shutdown && connection.buffer.is_empty() + } + + pub fn shutdown(&self, cmd: SockShutdownCmd) -> Result<()> { + let connection = self.connection.lock_irq_disabled(); + // TODO: deal with cmd. Check the held guard directly: should_close() would re-lock self.connection and deadlock + if connection.peer_requested_shutdown && connection.buffer.is_empty() { + let vsockspace = VSOCK_GLOBAL.get().unwrap(); + vsockspace + .driver + .lock_irq_disabled() + .reset(&connection.info) + .map_err(|e| Error::with_message(Errno::ENOMEM, "can not send close packet"))?; + vsockspace + .connected_sockets + .lock_irq_disabled() + .remove(&self.id()) + .unwrap(); + } + Ok(()) + } + pub fn update_for_event(&self, event: &VsockEvent) { + let mut connection = self.connection.lock_irq_disabled(); + connection.update_for_event(event) + } + + pub fn get_info(&self) -> ConnectionInfo { + let connection = self.connection.lock_irq_disabled(); + connection.info.clone() + } + + pub fn connection_buffer_add(&self, bytes: &[u8]) -> bool { + let mut connection = self.connection.lock_irq_disabled(); + self.add_events(IoEvents::IN); + connection.add(bytes) + } + + pub fn peer_requested_shutdown(&self) { + self.connection.lock_irq_disabled().peer_requested_shutdown = true + } + + pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents { + self.pollee.poll(mask, poller) + } + pub fn add_events(&self, events: IoEvents) { + self.pollee.add_events(events) + } +} + +impl Drop for Connected { + fn drop(&mut self) { + let vsockspace = VSOCK_GLOBAL.get().unwrap(); + vsockspace + .used_ports + .lock_irq_disabled() + .remove(&self.local_addr().port); + } +} + +#[derive(Debug)] +pub struct Connection { + info: ConnectionInfo, + buffer: RingBuffer, + /// The peer sent a SHUTDOWN request, but we haven't yet responded with a RST because there is + /// still data in the buffer. 
+ pub peer_requested_shutdown: bool, +} + +impl Connection { + pub fn new(peer: VsockSocketAddr, local_port: u32) -> Self { + let mut info = ConnectionInfo::new(peer.into(), local_port); + info.buf_alloc = PER_CONNECTION_BUFFER_CAPACITY.try_into().unwrap(); + Self { + info, + buffer: RingBuffer::new(PER_CONNECTION_BUFFER_CAPACITY), + peer_requested_shutdown: false, + } + } + pub fn update_for_event(&mut self, event: &VsockEvent) { + self.info.update_for_event(event) + } + pub fn add(&mut self, bytes: &[u8]) -> bool { + self.buffer.add(bytes) + } +} + +#[derive(Debug, Copy, Clone, Ord, PartialOrd, Eq, PartialEq)] +pub struct ConnectionID { + pub local_addr: VsockSocketAddr, + pub peer_addr: VsockSocketAddr, +} +impl ConnectionID { + pub fn new(local_addr: VsockSocketAddr, peer_addr: VsockSocketAddr) -> Self { + Self { + local_addr, + peer_addr, + } + } +} + +impl From for ConnectionID { + fn from(event: VsockEvent) -> Self { + Self::new(event.destination.into(), event.source.into()) + } +} + +#[derive(Debug)] +struct RingBuffer { + buffer: Box<[u8]>, + /// The number of bytes currently in the buffer. + used: usize, + /// The index of the first used byte in the buffer. + start: usize, +} +//TODO: ringbuf +impl RingBuffer { + pub fn new(capacity: usize) -> Self { + // TODO: can be optimized. + let temp = vec![0; capacity]; + Self { + // FIXME: if the capacity is excessive, elements move will be executed. + buffer: temp.into_boxed_slice(), + used: 0, + start: 0, + } + } + /// Returns the number of bytes currently used in the buffer. + pub fn used(&self) -> usize { + self.used + } + + /// Returns true iff there are currently no bytes in the buffer. + pub fn is_empty(&self) -> bool { + self.used == 0 + } + + /// Returns the number of bytes currently free in the buffer. + pub fn available(&self) -> usize { + self.buffer.len() - self.used + } + + /// Adds the given bytes to the buffer if there is enough capacity for them all. 
+ /// + /// Returns true if they were added, or false if they were not. + pub fn add(&mut self, bytes: &[u8]) -> bool { + if bytes.len() > self.available() { + return false; + } + + // The index of the first available position in the buffer. + let first_available = (self.start + self.used) % self.buffer.len(); + // The number of bytes to copy from `bytes` to `buffer` between `first_available` and + // `buffer.len()`. + let copy_length_before_wraparound = min(bytes.len(), self.buffer.len() - first_available); + self.buffer[first_available..first_available + copy_length_before_wraparound] + .copy_from_slice(&bytes[0..copy_length_before_wraparound]); + if let Some(bytes_after_wraparound) = bytes.get(copy_length_before_wraparound..) { + self.buffer[0..bytes_after_wraparound.len()].copy_from_slice(bytes_after_wraparound); + } + self.used += bytes.len(); + + true + } + + /// Reads and removes as many bytes as possible from the buffer, up to the length of the given + /// buffer. + pub fn drain(&mut self, out: &mut [u8]) -> usize { + let bytes_read = min(self.used, out.len()); + + // The number of bytes to copy out between `start` and the end of the buffer. + let read_before_wraparound = min(bytes_read, self.buffer.len() - self.start); + // The number of bytes to copy out from the beginning of the buffer after wrapping around. 
+ let read_after_wraparound = bytes_read + .checked_sub(read_before_wraparound) + .unwrap_or_default(); + + out[0..read_before_wraparound] + .copy_from_slice(&self.buffer[self.start..self.start + read_before_wraparound]); + out[read_before_wraparound..bytes_read] + .copy_from_slice(&self.buffer[0..read_after_wraparound]); + + self.used -= bytes_read; + self.start = (self.start + bytes_read) % self.buffer.len(); + + bytes_read + } +} diff --git a/kernel/aster-nix/src/net/socket/vsock/stream/init.rs b/kernel/aster-nix/src/net/socket/vsock/stream/init.rs new file mode 100644 index 000000000..47ccfc106 --- /dev/null +++ b/kernel/aster-nix/src/net/socket/vsock/stream/init.rs @@ -0,0 +1,95 @@ +// SPDX-License-Identifier: MPL-2.0 + +use crate::{ + events::IoEvents, + net::socket::vsock::{ + addr::{VsockSocketAddr, VMADDR_CID_ANY, VMADDR_PORT_ANY}, + VSOCK_GLOBAL, + }, + prelude::*, + process::signal::{Pollee, Poller}, +}; + +pub struct Init { + bind_addr: SpinLock>, + pollee: Pollee, +} + +impl Init { + pub fn new() -> Self { + Self { + bind_addr: SpinLock::new(None), + pollee: Pollee::new(IoEvents::empty()), + } + } + + pub fn bind(&self, addr: VsockSocketAddr) -> Result<()> { + if self.bind_addr.lock().is_some() { + return_errno_with_message!(Errno::EINVAL, "the socket is already bound"); + } + let vsockspace = VSOCK_GLOBAL.get().unwrap(); + + // check correctness of cid + let local_cid = vsockspace.driver.lock_irq_disabled().guest_cid(); + if addr.cid != VMADDR_CID_ANY && addr.cid != local_cid as u32 { + return_errno_with_message!(Errno::EADDRNOTAVAIL, "The cid in address is incorrect"); + } + let mut new_addr = addr; + new_addr.cid = local_cid as u32; + + // check and assign a port + if addr.port == VMADDR_PORT_ANY { + if let Ok(port) = vsockspace.alloc_ephemeral_port() { + new_addr.port = port; + } else { + return_errno_with_message!(Errno::EAGAIN, "cannot find unused high port"); + } + } else if vsockspace + .used_ports + .lock_irq_disabled() + 
.contains(&new_addr.port) + { + return_errno_with_message!(Errno::EADDRNOTAVAIL, "the port in address is occupied"); + } else { + vsockspace + .used_ports + .lock_irq_disabled() + .insert(new_addr.port); + } + + //TODO: The privileged port isn't checked + *self.bind_addr.lock() = Some(new_addr); + Ok(()) + } + + pub fn is_bound(&self) -> bool { + self.bind_addr.lock().is_some() + } + + pub fn bound_addr(&self) -> Option { + *self.bind_addr.lock() + } + + pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents { + self.pollee.poll(mask, poller) + } + + pub fn add_events(&self, events: IoEvents) { + self.pollee.add_events(events) + } +} + +impl Drop for Init { + fn drop(&mut self) { + if let Some(addr) = *self.bind_addr.lock() { + let vsockspace = VSOCK_GLOBAL.get().unwrap(); + vsockspace.used_ports.lock_irq_disabled().remove(&addr.port); + } + } +} + +impl Default for Init { + fn default() -> Self { + Self::new() + } +} diff --git a/kernel/aster-nix/src/net/socket/vsock/stream/listen.rs b/kernel/aster-nix/src/net/socket/vsock/stream/listen.rs new file mode 100644 index 000000000..4a572f939 --- /dev/null +++ b/kernel/aster-nix/src/net/socket/vsock/stream/listen.rs @@ -0,0 +1,75 @@ +// SPDX-License-Identifier: MPL-2.0 + +use super::connected::Connected; +use crate::{ + events::IoEvents, + net::socket::vsock::{addr::VsockSocketAddr, VSOCK_GLOBAL}, + prelude::*, + process::signal::{Pollee, Poller}, +}; +pub struct Listen { + addr: VsockSocketAddr, + pollee: Pollee, + backlog: usize, + incoming_connection: SpinLock>>, +} + +impl Listen { + pub fn new(addr: VsockSocketAddr, backlog: usize) -> Self { + Self { + addr, + pollee: Pollee::new(IoEvents::empty()), + backlog, + incoming_connection: SpinLock::new(VecDeque::with_capacity(backlog)), + } + } + + pub fn addr(&self) -> VsockSocketAddr { + self.addr + } + pub fn push_incoming(&self, connect: Arc) -> Result<()> { + let mut incoming_connections = self.incoming_connection.lock_irq_disabled(); + if 
incoming_connections.len() >= self.backlog { + return_errno_with_message!(Errno::ENOMEM, "Queue in listenging socket is full") + } + incoming_connections.push_back(connect); + self.add_events(IoEvents::IN); + Ok(()) + } + pub fn accept(&self) -> Result> { + // Block until IN is signalled; IN is not cleared anywhere in Listen, so the queue may still be empty below. + let poller = Poller::new(); + if !self + .poll(IoEvents::IN, Some(&poller)) + .contains(IoEvents::IN) + { + poller.wait()?; + } + + let connection = self + .incoming_connection + .lock_irq_disabled() + .pop_front() + .ok_or_else(|| Error::with_message(Errno::EAGAIN, "no pending connection"))?; + + Ok(connection) + } + + pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents { + self.pollee.poll(mask, poller) + } + pub fn add_events(&self, events: IoEvents) { + self.pollee.add_events(events) + } +} + +impl Drop for Listen { + fn drop(&mut self) { + VSOCK_GLOBAL + .get() + .unwrap() + .used_ports + .lock_irq_disabled() + .remove(&self.addr.port); + } +} diff --git a/kernel/aster-nix/src/net/socket/vsock/stream/mod.rs b/kernel/aster-nix/src/net/socket/vsock/stream/mod.rs new file mode 100644 index 000000000..5fc3337a2 --- /dev/null +++ b/kernel/aster-nix/src/net/socket/vsock/stream/mod.rs @@ -0,0 +1,8 @@ +// SPDX-License-Identifier: MPL-2.0 + +pub mod connected; +pub mod init; +pub mod listen; + +pub mod socket; +pub use socket::VsockStreamSocket; diff --git a/kernel/aster-nix/src/net/socket/vsock/stream/socket.rs b/kernel/aster-nix/src/net/socket/vsock/stream/socket.rs new file mode 100644 index 000000000..2af956d03 --- /dev/null +++ b/kernel/aster-nix/src/net/socket/vsock/stream/socket.rs @@ -0,0 +1,291 @@ +// SPDX-License-Identifier: MPL-2.0 + +use super::{connected::Connected, init::Init, listen::Listen}; +use crate::{ + events::IoEvents, + fs::file_handle::FileLike, + net::socket::{ + vsock::{addr::VsockSocketAddr, VSOCK_GLOBAL}, + SendRecvFlags, SockShutdownCmd, Socket, SocketAddr, + }, + prelude::*, + process::signal::Poller, +}; + +pub struct VsockStreamSocket(RwLock); + +impl VsockStreamSocket { + 
pub(super) fn new_from_init(init: Arc) -> Self { + Self(RwLock::new(Status::Init(init))) + } + + pub(super) fn new_from_listen(listen: Arc) -> Self { + Self(RwLock::new(Status::Listen(listen))) + } + + pub(super) fn new_from_connected(connected: Arc) -> Self { + Self(RwLock::new(Status::Connected(connected))) + } +} + +pub enum Status { + Init(Arc), + Listen(Arc), + Connected(Arc), +} + +impl VsockStreamSocket { + pub fn new() -> Self { + let init = Arc::new(Init::new()); + Self(RwLock::new(Status::Init(init))) + } +} + +impl FileLike for VsockStreamSocket { + fn as_socket(self: Arc) -> Option> { + Some(self) + } + + fn read(&self, buf: &mut [u8]) -> Result { + self.recvfrom(buf, SendRecvFlags::empty()) + .map(|(read_size, _)| read_size) + } + + fn write(&self, buf: &[u8]) -> Result { + self.sendto(buf, None, SendRecvFlags::empty()) + } + + fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents { + let inner = self.0.read(); + match &*inner { + Status::Init(init) => init.poll(mask, poller), + Status::Listen(listen) => listen.poll(mask, poller), + Status::Connected(connect) => connect.poll(mask, poller), + } + } +} + +impl Socket for VsockStreamSocket { + fn bind(&self, sockaddr: SocketAddr) -> Result<()> { + let addr = VsockSocketAddr::try_from(sockaddr)?; + let inner = self.0.read(); + match &*inner { + Status::Init(init) => init.bind(addr), + Status::Listen(_) | Status::Connected(_) => { + return_errno_with_message!( + Errno::EINVAL, + "cannot bind a listening or connected socket" + ) + } + } + } + + fn connect(&self, sockaddr: SocketAddr) -> Result<()> { + let init = match &*self.0.read() { + Status::Init(init) => init.clone(), + Status::Listen(_) => { + return_errno_with_message!(Errno::EINVAL, "The socket is listened"); + } + Status::Connected(_) => { + return_errno_with_message!(Errno::EINVAL, "The socket is connected"); + } + }; + let remote_addr = VsockSocketAddr::try_from(sockaddr)?; + let local_addr = init.bound_addr(); + + if let Some(addr) 
= local_addr { + if addr == remote_addr { + return_errno_with_message!(Errno::EINVAL, "try to connect to self is invalid"); + } + } else { + init.bind(VsockSocketAddr::any_addr())?; + } + + let connecting = Arc::new(Connected::new(remote_addr, init.bound_addr().unwrap())); + let vsockspace = VSOCK_GLOBAL.get().unwrap(); + vsockspace + .connecting_sockets + .lock_irq_disabled() + .insert(connecting.local_addr(), connecting.clone()); + + // Send request + vsockspace + .driver + .lock_irq_disabled() + .request(&connecting.get_info()) + .map_err(|e| Error::with_message(Errno::EAGAIN, "can not send connect packet"))?; + + // wait for response from driver + // TODO: add timeout + let poller = Poller::new(); + if !connecting + .poll(IoEvents::IN, Some(&poller)) + .contains(IoEvents::IN) + { + poller.wait()?; + } + + *self.0.write() = Status::Connected(connecting.clone()); + // move connecting socket map to connected sockmap + vsockspace + .connecting_sockets + .lock_irq_disabled() + .remove(&connecting.local_addr()) + .unwrap(); + vsockspace + .connected_sockets + .lock_irq_disabled() + .insert(connecting.id(), connecting); + + Ok(()) + } + + fn listen(&self, backlog: usize) -> Result<()> { + let init = match &*self.0.read() { + Status::Init(init) => init.clone(), + Status::Listen(_) => { + return_errno_with_message!(Errno::EINVAL, "The socket is already listened"); + } + Status::Connected(_) => { + return_errno_with_message!(Errno::EISCONN, "The socket is already connected"); + } + }; + let addr = init.bound_addr().ok_or(Error::with_message( + Errno::EINVAL, + "The socket is not bound", + ))?; + let listen = Arc::new(Listen::new(addr, backlog)); + *self.0.write() = Status::Listen(listen.clone()); + + // push listen socket into vsockspace + VSOCK_GLOBAL + .get() + .unwrap() + .listen_sockets + .lock_irq_disabled() + .insert(listen.addr(), listen); + + Ok(()) + } + + fn accept(&self) -> Result<(Arc, SocketAddr)> { + let listen = match &*self.0.read() { + 
Status::Listen(listen) => listen.clone(), + Status::Init(_) | Status::Connected(_) => { + return_errno_with_message!(Errno::EINVAL, "The socket is not listening"); + } + }; + let connected = listen.accept()?; + let peer_addr = connected.peer_addr(); + + VSOCK_GLOBAL + .get() + .unwrap() + .connected_sockets + .lock_irq_disabled() + .insert(connected.id(), connected.clone()); + + VSOCK_GLOBAL + .get() + .unwrap() + .driver + .lock_irq_disabled() + .response(&connected.get_info()) + .map_err(|e| Error::with_message(Errno::EAGAIN, "can not send response packet"))?; + + let socket = Arc::new(VsockStreamSocket::new_from_connected(connected)); + Ok((socket, peer_addr.into())) + } + + fn shutdown(&self, cmd: SockShutdownCmd) -> Result<()> { + let inner = self.0.read(); + if let Status::Connected(connected) = &*inner { + let result = connected.shutdown(cmd); + if result.is_ok() { + let vsockspace = VSOCK_GLOBAL.get().unwrap(); + vsockspace + .used_ports + .lock_irq_disabled() + .remove(&connected.local_addr().port); + vsockspace + .connected_sockets + .lock_irq_disabled() + .remove(&connected.id()); + } + result + } else { + return_errno_with_message!(Errno::EINVAL, "The socket is not connected."); + } + } + + fn recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, SocketAddr)> { + let connected = match &*self.0.read() { + Status::Connected(connected) => connected.clone(), + Status::Init(_) | Status::Listen(_) => { + return_errno_with_message!(Errno::EINVAL, "the socket is not connected"); + } + }; + let read_size = connected.recv(buf)?; + let peer_addr = self.peer_addr()?; + // If buffer is now empty and the peer requested shutdown, finish shutting down the + // connection. 
+ if connected.should_close() { + VSOCK_GLOBAL + .get() + .unwrap() + .driver + .lock_irq_disabled() + .reset(&connected.get_info()) + .map_err(|e| Error::with_message(Errno::EAGAIN, "can not send close packet"))?; + } + Ok((read_size, peer_addr)) + } + + fn sendto( + &self, + buf: &[u8], + remote: Option, + flags: SendRecvFlags, + ) -> Result { + debug_assert!(remote.is_none()); + if remote.is_some() { + return_errno_with_message!(Errno::EINVAL, "vsock should not provide remote addr"); + } + let inner = self.0.read(); + match &*inner { + Status::Connected(connected) => connected.send(buf, flags), + Status::Init(_) | Status::Listen(_) => { + return_errno_with_message!(Errno::EINVAL, "The socket is not connected"); + } + } + } + + fn addr(&self) -> Result { + let inner = self.0.read(); + let addr = match &*inner { + Status::Init(init) => init.bound_addr(), + Status::Listen(listen) => Some(listen.addr()), + Status::Connected(connected) => Some(connected.local_addr()), + }; + addr.map(Into::::into) + .ok_or(Error::with_message( + Errno::EINVAL, + "The socket does not bind to addr", + )) + } + + fn peer_addr(&self) -> Result { + let inner = self.0.read(); + if let Status::Connected(connected) = &*inner { + Ok(connected.peer_addr().into()) + } else { + return_errno_with_message!(Errno::EINVAL, "the socket is not connected"); + } + } +} + +impl Default for VsockStreamSocket { + fn default() -> Self { + Self::new() + } +} diff --git a/kernel/aster-nix/src/syscall/socket.rs b/kernel/aster-nix/src/syscall/socket.rs index 716917506..747d072c2 100644 --- a/kernel/aster-nix/src/syscall/socket.rs +++ b/kernel/aster-nix/src/syscall/socket.rs @@ -6,6 +6,7 @@ use crate::{ net::socket::{ ip::{DatagramSocket, StreamSocket}, unix::UnixStreamSocket, + vsock::VsockStreamSocket, }, prelude::*, util::net::{CSocketAddrFamily, Protocol, SockFlags, SockType, SOCK_TYPE_MASK}, @@ -35,6 +36,9 @@ pub fn sys_socket(domain: i32, type_: i32, protocol: i32) -> Result 
DatagramSocket::new(nonblocking) as Arc, + (CSocketAddrFamily::AF_VSOCK, SockType::SOCK_STREAM, _) => { + Arc::new(VsockStreamSocket::new()) as Arc + } _ => return_errno_with_message!(Errno::EAFNOSUPPORT, "unsupported domain"), }; let fd = { diff --git a/kernel/aster-nix/src/util/net/addr.rs b/kernel/aster-nix/src/util/net/addr.rs index 79c0bab5d..135c0ca78 100644 --- a/kernel/aster-nix/src/util/net/addr.rs +++ b/kernel/aster-nix/src/util/net/addr.rs @@ -55,6 +55,11 @@ pub fn read_socket_addr_from_user(addr: Vaddr, addr_len: usize) -> Result { + debug_assert!(addr_len >= core::mem::size_of::()); + let sock_addr_vm: CSocketAddrVm = read_val_from_user(addr)?; + SocketAddr::Vsock(sock_addr_vm.svm_cid, sock_addr_vm.svm_port) + } _ => { return_errno_with_message!(Errno::EAFNOSUPPORT, "cannot support address for the family") } @@ -89,6 +94,12 @@ pub fn write_socket_addr_to_user( write_size as i32 } SocketAddr::IPv6 => todo!(), + SocketAddr::Vsock(cid, port) => { + let vm_addr = CSocketAddrVm::new(*cid, *port); + let write_size = core::mem::size_of::(); + write_val_to_user(dest, &vm_addr)?; + write_size as i32 + } }; if addrlen_ptr != 0 { write_val_to_user(addrlen_ptr, &write_size)?; @@ -210,6 +221,34 @@ pub struct CSocketAddrInet6 { sin6_scope_id: u32, } +/// vm socket address +#[repr(C)] +#[derive(Debug, Clone, Copy, Pod)] +pub struct CSocketAddrVm { + /// always [SaFamily::AF_VSOCK] + svm_family: u16, + /// always 0 + svm_reserved1: u16, + /// Port number in host byte order. + svm_port: u32, + /// Address in host byte order. + svm_cid: u32, + /// Pad to size of [SockAddr] structure (16 bytes), must be zero-filled + svm_zero: [u8; 4], +} + +impl CSocketAddrVm { + pub fn new(cid: u32, port: u32) -> Self { + Self { + svm_family: CSocketAddrFamily::AF_VSOCK as _, + svm_reserved1: 0, + svm_port: port, + svm_cid: cid, + svm_zero: [0u8; 4], + } + } +} + /// Address family. The definition is from https://elixir.bootlin.com/linux/v6.0.9/source/include/linux/socket.h. 
#[repr(i32)] #[derive(Debug, Clone, Copy, TryFromInt, PartialEq, Eq)] diff --git a/kernel/comps/virtio/src/device/socket/buffer.rs b/kernel/comps/virtio/src/device/socket/buffer.rs index e3041616e..00e2e6f5f 100644 --- a/kernel/comps/virtio/src/device/socket/buffer.rs +++ b/kernel/comps/virtio/src/device/socket/buffer.rs @@ -2,10 +2,6 @@ use align_ext::AlignExt; use bytes::BytesMut; -use pod::Pod; - -use super::header::VirtioVsockHdr; -use crate::device::socket::header::VIRTIO_VSOCK_HDR_LEN; /// Buffer for receive packet #[derive(Debug)] @@ -38,22 +34,6 @@ impl RxBuffer { pub fn buf_mut(&mut self) -> &mut [u8] { &mut self.buf } - - /// Packet payload slice, which is inner buffer excluding VirtioVsockHdr. - pub fn packet(&self) -> &[u8] { - debug_assert!(VIRTIO_VSOCK_HDR_LEN + self.packet_len <= self.buf.len()); - &self.buf[VIRTIO_VSOCK_HDR_LEN..VIRTIO_VSOCK_HDR_LEN + self.packet_len] - } - - /// Mutable packet payload slice. - pub fn packet_mut(&mut self) -> &mut [u8] { - debug_assert!(VIRTIO_VSOCK_HDR_LEN + self.packet_len <= self.buf.len()); - &mut self.buf[VIRTIO_VSOCK_HDR_LEN..VIRTIO_VSOCK_HDR_LEN + self.packet_len] - } - - pub fn virtio_vsock_header(&self) -> VirtioVsockHdr { - VirtioVsockHdr::from_bytes(&self.buf[..VIRTIO_VSOCK_HDR_LEN]) - } } /// Buffer for transmit packet diff --git a/kernel/comps/virtio/src/device/socket/config.rs b/kernel/comps/virtio/src/device/socket/config.rs index 0172c2dfc..b931aef48 100644 --- a/kernel/comps/virtio/src/device/socket/config.rs +++ b/kernel/comps/virtio/src/device/socket/config.rs @@ -8,17 +8,14 @@ use pod::Pod; use crate::transport::VirtioTransport; bitflags! { - /// Vsock feature bits since v1.2 - /// If no feature bit is set, only stream socket type is supported. - /// If VIRTIO_VSOCK_F_SEQPACKET has been negotiated, the device MAY act as if VIRTIO_VSOCK_F_STREAM has also been negotiated. pub struct VsockFeatures: u64 { const VIRTIO_VSOCK_F_STREAM = 1 << 0; // stream socket type is supported. 
- const VIRTIO_VSOCK_F_SEQPACKET = 1 << 1; //seqpacket socket type is supported. + const VIRTIO_VSOCK_F_SEQPACKET = 1 << 1; //seqpacket socket type is not supported now. } } impl VsockFeatures { - pub fn support_features() -> Self { + pub const fn supported_features() -> Self { VsockFeatures::VIRTIO_VSOCK_F_STREAM } } @@ -32,9 +29,7 @@ pub struct VirtioVsockConfig { /// According to virtio spec v1.1 2.4.1 Driver Requirements: Device Configuration Space, /// drivers MUST NOT assume reads from fields greater than 32 bits wide are atomic. /// So we need to split the u64 guest_cid into two parts. - // read only pub guest_cid_low: u32, - // read only pub guest_cid_high: u32, } diff --git a/kernel/comps/virtio/src/device/socket/connect.rs b/kernel/comps/virtio/src/device/socket/connect.rs index 874cba6e3..c924489af 100644 --- a/kernel/comps/virtio/src/device/socket/connect.rs +++ b/kernel/comps/virtio/src/device/socket/connect.rs @@ -1,5 +1,30 @@ // SPDX-License-Identifier: MPL-2.0 +// Modified from vsock.rs in virtio-drivers project +// +// MIT License +// +// Copyright (c) 2022-2023 Ant Group +// Copyright (c) 2019-2020 rCore Developers +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. +// use log::debug; use super::{ @@ -7,7 +32,7 @@ use super::{ header::{VirtioVsockHdr, VirtioVsockOp, VsockAddr}, }; -#[derive(Clone, Debug, Eq, PartialEq)] +#[derive(Clone, Copy, Debug, Eq, PartialEq)] pub struct VsockBufferStatus { pub buffer_allocation: u32, pub forward_count: u32, @@ -24,7 +49,7 @@ pub enum DisconnectReason { } /// Details of the type of an event received from a VirtIO socket. -#[derive(Clone, Debug, Eq, PartialEq)] +#[derive(Clone, Copy, Debug, Eq, PartialEq)] pub enum VsockEventType { /// The peer requests to establish a connection with us. ConnectionRequest, @@ -47,7 +72,7 @@ pub enum VsockEventType { } /// An event received from a VirtIO socket device. -#[derive(Clone, Debug, Eq, PartialEq)] +#[derive(Clone, Copy, Debug, Eq, PartialEq)] pub struct VsockEvent { /// The source of the event, i.e. the peer who sent it. pub source: VsockAddr, diff --git a/kernel/comps/virtio/src/device/socket/device.rs b/kernel/comps/virtio/src/device/socket/device.rs index bd34ddf40..a24ba3df8 100644 --- a/kernel/comps/virtio/src/device/socket/device.rs +++ b/kernel/comps/virtio/src/device/socket/device.rs @@ -17,7 +17,10 @@ use super::{ VsockDeviceIrqHandler, }; use crate::{ - device::{socket::register_device, VirtioDeviceError}, + device::{ + socket::{handle_recv_irq, register_device}, + VirtioDeviceError, + }, queue::{QueueError, VirtQueue}, transport::VirtioTransport, }; @@ -27,10 +30,10 @@ const QUEUE_RECV: u16 = 0; const QUEUE_SEND: u16 = 1; const QUEUE_EVENT: u16 = 2; -/// The size in bytes of each buffer used in the RX virtqueue. This must be bigger than size_of::(). +/// The size in bytes of each buffer used in the RX virtqueue. This must be bigger than `size_of::()`. 
const RX_BUFFER_SIZE: usize = 512; -/// Low-level driver for a Virtio socket device. +/// Vsock device driver pub struct SocketDevice { config: VirtioVsockConfig, guest_cid: u64, @@ -46,6 +49,7 @@ pub struct SocketDevice { } impl SocketDevice { + /// Create a new vsock device pub fn init(mut transport: Box) -> Result<(), VirtioDeviceError> { let virtio_vsock_config = VirtioVsockConfig::new(transport.as_mut()); debug!("virtio_vsock_config = {:?}", virtio_vsock_config); @@ -95,9 +99,8 @@ impl SocketDevice { } // Interrupt handler if vsock device receives some packet. - // TODO: This will be handled by vsock socket layer. fn handle_vsock_event(_: &TrapFrame) { - debug!("Packet received. This will be solved by socket layer"); + handle_recv_irq(super::DEVICE_NAME); } device @@ -119,28 +122,23 @@ impl SocketDevice { Ok(()) } - /// Returns the CID which has been assigned to this guest. + /// Return the CID which has been assigned to this guest. pub fn guest_cid(&self) -> u64 { self.guest_cid } - /// Sends a request to connect to the given destination. - /// - /// This returns as soon as the request is sent; you should wait until `poll` returns a - /// [`VsockEventType::Connected`] event indicating that the peer has accepted the connection - /// before sending data. - pub fn connect(&mut self, connection_info: &ConnectionInfo) -> Result<(), SocketError> { + /// Send a connection request + pub fn request(&mut self, connection_info: &ConnectionInfo) -> Result<(), SocketError> { let header = VirtioVsockHdr { op: VirtioVsockOp::Request as u16, ..connection_info.new_header(self.guest_cid) }; - // Sends a header only packet to the TX queue to connect the device to the listening socket - // at the given destination. + self.send_packet_to_tx_queue(&header, &[]) } - /// Accepts the given connection from a peer. 
- pub fn accept(&mut self, connection_info: &ConnectionInfo) -> Result<(), SocketError> { + /// Send a response to peer, if peer start a sending request + pub fn response(&mut self, connection_info: &ConnectionInfo) -> Result<(), SocketError> { let header = VirtioVsockHdr { op: VirtioVsockOp::Response as u16, ..connection_info.new_header(self.guest_cid) @@ -148,29 +146,7 @@ impl SocketDevice { self.send_packet_to_tx_queue(&header, &[]) } - /// Requests the peer to send us a credit update for the given connection. - fn request_credit(&mut self, connection_info: &ConnectionInfo) -> Result<(), SocketError> { - let header = VirtioVsockHdr { - op: VirtioVsockOp::CreditRequest as u16, - ..connection_info.new_header(self.guest_cid) - }; - self.send_packet_to_tx_queue(&header, &[]) - } - - /// Tells the peer how much buffer space we have to receive data. - pub fn credit_update(&mut self, connection_info: &ConnectionInfo) -> Result<(), SocketError> { - let header = VirtioVsockHdr { - op: VirtioVsockOp::CreditUpdate as u16, - ..connection_info.new_header(self.guest_cid) - }; - self.send_packet_to_tx_queue(&header, &[]) - } - - /// Requests to shut down the connection cleanly. - /// - /// This returns as soon as the request is sent; you should wait until `poll` returns a - /// `VsockEventType::Disconnected` event if you want to know that the peer has acknowledged the - /// shutdown. + /// Send a shutdown request pub fn shutdown(&mut self, connection_info: &ConnectionInfo) -> Result<(), SocketError> { let header = VirtioVsockHdr { op: VirtioVsockOp::Shutdown as u16, @@ -179,8 +155,8 @@ impl SocketDevice { self.send_packet_to_tx_queue(&header, &[]) } - /// Forcibly closes the connection without waiting for the peer. 
- pub fn force_close(&mut self, connection_info: &ConnectionInfo) -> Result<(), SocketError> { + /// Send a reset request to peer + pub fn reset(&mut self, connection_info: &ConnectionInfo) -> Result<(), SocketError> { let header = VirtioVsockHdr { op: VirtioVsockOp::Rst as u16, ..connection_info.new_header(self.guest_cid) @@ -188,16 +164,29 @@ impl SocketDevice { self.send_packet_to_tx_queue(&header, &[]) } + /// Request the peer to send the credit info to us + pub fn credit_request(&mut self, connection_info: &ConnectionInfo) -> Result<(), SocketError> { + let header = VirtioVsockHdr { + op: VirtioVsockOp::CreditRequest as u16, + ..connection_info.new_header(self.guest_cid) + }; + self.send_packet_to_tx_queue(&header, &[]) + } + + /// Tell the peer our credit info + pub fn credit_update(&mut self, connection_info: &ConnectionInfo) -> Result<(), SocketError> { + let header = VirtioVsockHdr { + op: VirtioVsockOp::CreditUpdate as u16, + ..connection_info.new_header(self.guest_cid) + }; + self.send_packet_to_tx_queue(&header, &[]) + } + fn send_packet_to_tx_queue( &mut self, header: &VirtioVsockHdr, buffer: &[u8], ) -> Result<(), SocketError> { - // let (_token, _len) = self.send_queue.add_notify_wait_pop( - // &[header.as_bytes(), buffer], - // &mut [], - // )?; - let _token = self.send_queue.add_buf(&[header.as_bytes(), buffer], &[])?; if self.send_queue.should_notify() { @@ -211,8 +200,7 @@ impl SocketDevice { self.send_queue.pop_used()?; - // FORDEBUG - // debug!("buffer in send_packet_to_tx_queue: {:?}",buffer); + debug!("buffer in send_packet_to_tx_queue: {:?}", buffer); Ok(()) } @@ -221,13 +209,19 @@ impl SocketDevice { connection_info: &mut ConnectionInfo, buffer_len: usize, ) -> Result<(), SocketError> { + debug!("connectin info {:?}", connection_info); + debug!( + "peer free from peer: {:?}, buffer len : {:?}", + connection_info.peer_free(), + buffer_len + ); if connection_info.peer_free() as usize >= buffer_len { Ok(()) } else { // Request an update of the 
cached peer credit, if we haven't already done so, and tell // the caller to try again later. if !connection_info.has_pending_credit_request { - self.request_credit(connection_info)?; + self.credit_request(connection_info)?; connection_info.has_pending_credit_request = true; } Err(SocketError::InsufficientBufferSpaceInPeer) @@ -252,6 +246,36 @@ impl SocketDevice { self.send_packet_to_tx_queue(&header, buffer) } + /// Receive bytes from peer, returns the header + pub fn receive( + &mut self, + buffer: &mut [u8], + // connection_info: &mut ConnectionInfo, + ) -> Result { + let (token, len) = self.recv_queue.pop_used()?; + debug!( + "receive packet in rx_queue: token = {}, len = {}", + token, len + ); + let mut rx_buffer = self + .rx_buffers + .remove(token as usize) + .ok_or(QueueError::WrongToken)?; + rx_buffer.set_packet_len(RX_BUFFER_SIZE); + + let (header, payload) = read_header_and_body(rx_buffer.buf())?; + // The length written should be equal to len(header)+len(packet) + assert_eq!(len, header.len() + VIRTIO_VSOCK_HDR_LEN as u32); + debug!("Received packet {:?}. Op {:?}", header, header.op()); + debug!("body is {:?}", payload); + + assert!(buffer.len() >= payload.len()); + buffer[..payload.len()].copy_from_slice(payload); + + self.add_rx_buffer(rx_buffer, token)?; + Ok(header) + } + /// Polls the RX virtqueue for the next event, and calls the given handler function to handle it. 
pub fn poll( &mut self, @@ -261,39 +285,10 @@ impl SocketDevice { if !self.recv_queue.can_pop() { return Ok(None); } - let (token, len) = self.recv_queue.pop_used()?; + let mut body = RxBuffer::new(RX_BUFFER_SIZE); + let header = self.receive(body.buf_mut())?; - let mut buffer = self - .rx_buffers - .remove(token as usize) - .ok_or(QueueError::WrongToken)?; - - let header = buffer.virtio_vsock_header(); - // The length written should be equal to len(header)+len(packet) - assert_eq!(len, header.len() + VIRTIO_VSOCK_HDR_LEN as u32); - - buffer.set_packet_len(RX_BUFFER_SIZE); - - let head_result = read_header_and_body(buffer.buf()); - - let Ok((header, body)) = head_result else { - let ret = match head_result { - Err(e) => Err(e), - _ => Ok(None), //FIXME: this clause is never reached. - }; - self.add_rx_buffer(buffer, token)?; - return ret; - }; - - debug!("Received packet {:?}. Op {:?}", header, header.op()); - debug!("body is {:?}", body); - - let result = VsockEvent::from_header(&header).and_then(|event| handler(event, body)); - - // reuse the buffer and give it back to recv_queue. 
- self.add_rx_buffer(buffer, token)?; - - result + VsockEvent::from_header(&header).and_then(|event| handler(event, body.buf())) } /// Add a used rx buffer to recv queue,@index is only to check the correctness @@ -310,7 +305,7 @@ impl SocketDevice { /// Negotiate features for the device specified bits 0~23 pub(crate) fn negotiate_features(features: u64) -> u64 { let device_features = VsockFeatures::from_bits_truncate(features); - let supported_features = VsockFeatures::support_features(); + let supported_features = VsockFeatures::supported_features(); let vsock_features = device_features & supported_features; debug!("features negotiated: {:?}", vsock_features); vsock_features.bits() diff --git a/kernel/comps/virtio/src/device/socket/error.rs b/kernel/comps/virtio/src/device/socket/error.rs index 4134d55b0..25b427e06 100644 --- a/kernel/comps/virtio/src/device/socket/error.rs +++ b/kernel/comps/virtio/src/device/socket/error.rs @@ -1,14 +1,36 @@ // SPDX-License-Identifier: MPL-2.0 -//! This file comes from virtio-drivers project -//! This module contains the error from the VirtIO socket driver. - +// Modified from error.rs in virtio-drivers project +// +// MIT License +// +// Copyright (c) 2022-2023 Ant Group +// Copyright (c) 2019-2020 rCore Developers +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. 
+// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. +// use core::{fmt, result}; use crate::queue::QueueError; /// The error type of VirtIO socket driver. -#[derive(Clone, Copy, Debug, Eq, PartialEq)] +#[derive(Debug)] pub enum SocketError { /// There is an existing connection. ConnectionExists, @@ -39,33 +61,18 @@ pub enum SocketError { /// Recycled a wrong buffer. RecycledWrongBuffer, /// Queue Error - QueueError(SocketQueueError), -} - -#[derive(Clone, Copy, Debug, Eq, PartialEq)] -pub enum SocketQueueError { - InvalidArgs, - BufferTooSmall, - NotReady, - AlreadyUsed, - WrongToken, -} - -impl From for SocketQueueError { - fn from(value: QueueError) -> Self { - match value { - QueueError::InvalidArgs => Self::InvalidArgs, - QueueError::BufferTooSmall => Self::BufferTooSmall, - QueueError::NotReady => Self::NotReady, - QueueError::AlreadyUsed => Self::AlreadyUsed, - QueueError::WrongToken => Self::WrongToken, - } - } + QueueError(QueueError), } impl From for SocketError { fn from(value: QueueError) -> Self { - Self::QueueError(SocketQueueError::from(value)) + Self::QueueError(value) + } +} + +impl From for SocketError { + fn from(_e: int_to_c_enum::TryFromIntError) -> Self { + Self::InvalidNumber } } diff --git a/kernel/comps/virtio/src/device/socket/header.rs b/kernel/comps/virtio/src/device/socket/header.rs index 702e372f0..96d7ae7c0 100644 --- a/kernel/comps/virtio/src/device/socket/header.rs +++ b/kernel/comps/virtio/src/device/socket/header.rs @@ -1,6 +1,32 @@ // SPDX-License-Identifier: MPL-2.0 +// Modified from protocol.rs 
in virtio-drivers project +// +// MIT License +// +// Copyright (c) 2022-2023 Ant Group +// Copyright (c) 2019-2020 rCore Developers +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. 
+// use bitflags::bitflags; +use int_to_c_enum::TryFromInt; use pod::Pod; use super::error::{self, SocketError}; @@ -64,7 +90,7 @@ impl VirtioVsockHdr { } pub fn op(&self) -> error::Result { - self.op.try_into() + VirtioVsockOp::try_from(self.op).map_err(|err| err.into()) } pub fn source(&self) -> VsockAddr { @@ -90,7 +116,7 @@ impl VirtioVsockHdr { } } -#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, TryFromInt)] #[repr(u16)] #[allow(non_camel_case_types)] pub enum VirtioVsockOp { @@ -112,26 +138,6 @@ pub enum VirtioVsockOp { CreditRequest = 7, } -/// TODO: This could be optimized by upgrading [int_to_c_enum::TryFromIntError] to carrying the invalid int number -impl TryFrom for VirtioVsockOp { - type Error = SocketError; - - fn try_from(v: u16) -> Result { - let op = match v { - 0 => Self::Invalid, - 1 => Self::Request, - 2 => Self::Response, - 3 => Self::Rst, - 4 => Self::Shutdown, - 5 => Self::Rw, - 6 => Self::CreditUpdate, - 7 => Self::CreditRequest, - _ => return Err(SocketError::UnknownOperation(v)), - }; - Ok(op) - } -} - bitflags! 
{ #[repr(C)] #[derive(Default, Pod)] diff --git a/kernel/comps/virtio/src/device/socket/manager.rs b/kernel/comps/virtio/src/device/socket/manager.rs deleted file mode 100644 index a4a58bf23..000000000 --- a/kernel/comps/virtio/src/device/socket/manager.rs +++ /dev/null @@ -1,359 +0,0 @@ -// SPDX-License-Identifier: MPL-2.0 - -use alloc::{boxed::Box, sync::Arc, vec, vec::Vec}; -use core::{cmp::min, hint::spin_loop}; - -use aster_frame::sync::SpinLock; -use log::debug; - -use super::{ - connect::{ConnectionInfo, DisconnectReason, VsockEvent, VsockEventType}, - device::SocketDevice, - header::VsockAddr, -}; -use crate::device::socket::error::SocketError; - -const PER_CONNECTION_BUFFER_CAPACITY: usize = 1024; - -pub struct VsockConnectionManager { - driver: Arc>, - connections: Vec, - listening_ports: Vec, -} - -impl VsockConnectionManager { - /// Construct a new connection manager wrapping the given low-level VirtIO socket driver. - pub fn new(driver: Arc>) -> Self { - Self { - driver, - connections: Vec::new(), - listening_ports: Vec::new(), - } - } - - /// Returns the CID which has been assigned to this guest. - pub fn guest_cid(&self) -> u64 { - self.driver.lock().guest_cid() - } - - /// Allows incoming connections on the given port number. - pub fn listen(&mut self, port: u32) { - if !self.listening_ports.contains(&port) { - self.listening_ports.push(port); - } - } - - /// Stops allowing incoming connections on the given port number. - pub fn unlisten(&mut self, port: u32) { - self.listening_ports.retain(|p| *p != port) - } - /// Sends a request to connect to the given destination. - /// - /// This returns as soon as the request is sent; you should wait until `poll` returns a - /// `VsockEventType::Connected` event indicating that the peer has accepted the connection - /// before sending data. 
- pub fn connect(&mut self, destination: VsockAddr, src_port: u32) -> Result<(), SocketError> { - if self.connections.iter().any(|connection| { - connection.info.dst == destination && connection.info.src_port == src_port - }) { - return Err(SocketError::ConnectionExists); - } - - let new_connection = Connection::new(destination, src_port); - - self.driver.lock().connect(&new_connection.info)?; - debug!("Connection requested: {:?}", new_connection.info); - self.connections.push(new_connection); - Ok(()) - } - - /// Sends the buffer to the destination. - pub fn send( - &mut self, - destination: VsockAddr, - src_port: u32, - buffer: &[u8], - ) -> Result<(), SocketError> { - let (_, connection) = get_connection(&mut self.connections, destination, src_port)?; - - self.driver.lock().send(buffer, &mut connection.info) - } - - /// Polls the vsock device to receive data or other updates. - pub fn poll(&mut self) -> Result, SocketError> { - let guest_cid = self.driver.lock().guest_cid(); - let connections = &mut self.connections; - - let result = self.driver.lock().poll(|event, body| { - let connection = get_connection_for_event(connections, &event, guest_cid); - - // Skip events which don't match any connection we know about, unless they are a - // connection request. - let connection = if let Some((_, connection)) = connection { - connection - } else if let VsockEventType::ConnectionRequest = event.event_type { - // If the requested connection already exists or the CID isn't ours, ignore it. - if connection.is_some() || event.destination.cid != guest_cid { - return Ok(None); - } - // Add the new connection to our list, at least for now. It will be removed again - // below if we weren't listening on the port. - connections.push(Connection::new(event.source, event.destination.port)); - connections.last_mut().unwrap() - } else { - return Ok(None); - }; - - // Update stored connection info. 
- connection.info.update_for_event(&event); - - if let VsockEventType::Received { length } = event.event_type { - // Copy to buffer - if !connection.buffer.add(body) { - return Err(SocketError::OutputBufferTooShort(length)); - } - } - - Ok(Some(event)) - })?; - - let Some(event) = result else { - return Ok(None); - }; - - // The connection must exist because we found it above in the callback. - let (connection_index, connection) = - get_connection_for_event(connections, &event, guest_cid).unwrap(); - - match event.event_type { - VsockEventType::ConnectionRequest => { - if self.listening_ports.contains(&event.destination.port) { - self.driver.lock().accept(&connection.info)?; - } else { - // Reject the connection request and remove it from our list. - self.driver.lock().force_close(&connection.info)?; - self.connections.swap_remove(connection_index); - - // No need to pass the request on to the client, as we've already rejected it. - return Ok(None); - } - } - VsockEventType::Connected => {} - VsockEventType::Disconnected { reason } => { - // Wait until client reads all data before removing connection. - if connection.buffer.is_empty() { - if reason == DisconnectReason::Shutdown { - self.driver.lock().force_close(&connection.info)?; - } - self.connections.swap_remove(connection_index); - } else { - connection.peer_requested_shutdown = true; - } - } - VsockEventType::Received { .. } => { - // Already copied the buffer in the callback above. - } - VsockEventType::CreditRequest => { - // If the peer requested credit, send an update. - self.driver.lock().credit_update(&connection.info)?; - // No need to pass the request on to the client, we've already handled it. - return Ok(None); - } - VsockEventType::CreditUpdate => {} - } - - Ok(Some(event)) - } - - /// Reads data received from the given connection. 
- pub fn recv( - &mut self, - peer: VsockAddr, - src_port: u32, - buffer: &mut [u8], - ) -> Result { - debug!("connections is {:?}", self.connections); - let (connection_index, connection) = get_connection(&mut self.connections, peer, src_port)?; - - // Copy from ring buffer - let bytes_read = connection.buffer.drain(buffer); - - connection.info.done_forwarding(bytes_read); - - // If buffer is now empty and the peer requested shutdown, finish shutting down the - // connection. - if connection.peer_requested_shutdown && connection.buffer.is_empty() { - self.driver.lock().force_close(&connection.info)?; - self.connections.swap_remove(connection_index); - } - - Ok(bytes_read) - } - - /// Blocks until we get some event from the vsock device. - pub fn wait_for_event(&mut self) -> Result { - loop { - if let Some(event) = self.poll()? { - return Ok(event); - } else { - spin_loop(); - } - } - } - - /// Requests to shut down the connection cleanly. - /// - /// This returns as soon as the request is sent; you should wait until `poll` returns a - /// `VsockEventType::Disconnected` event if you want to know that the peer has acknowledged the - /// shutdown. - pub fn shutdown(&mut self, destination: VsockAddr, src_port: u32) -> Result<(), SocketError> { - let (_, connection) = get_connection(&mut self.connections, destination, src_port)?; - - self.driver.lock().shutdown(&connection.info) - } - - /// Forcibly closes the connection without waiting for the peer. - pub fn force_close( - &mut self, - destination: VsockAddr, - src_port: u32, - ) -> Result<(), SocketError> { - let (index, connection) = get_connection(&mut self.connections, destination, src_port)?; - - self.driver.lock().force_close(&connection.info)?; - - self.connections.swap_remove(index); - Ok(()) - } -} - -/// Returns the connection from the given list matching the given peer address and local port, and -/// its index. 
-/// -/// Returns `Err(SocketError::NotConnected)` if there is no matching connection in the list. -fn get_connection( - connections: &mut [Connection], - peer: VsockAddr, - local_port: u32, -) -> core::result::Result<(usize, &mut Connection), SocketError> { - connections - .iter_mut() - .enumerate() - .find(|(_, connection)| { - connection.info.dst == peer && connection.info.src_port == local_port - }) - .ok_or(SocketError::NotConnected) -} - -/// Returns the connection from the given list matching the event, if any, and its index. -fn get_connection_for_event<'a>( - connections: &'a mut [Connection], - event: &VsockEvent, - local_cid: u64, -) -> Option<(usize, &'a mut Connection)> { - connections - .iter_mut() - .enumerate() - .find(|(_, connection)| event.matches_connection(&connection.info, local_cid)) -} - -#[derive(Debug)] -struct Connection { - info: ConnectionInfo, - buffer: RingBuffer, - /// The peer sent a SHUTDOWN request, but we haven't yet responded with a RST because there is - /// still data in the buffer. - peer_requested_shutdown: bool, -} - -impl Connection { - fn new(peer: VsockAddr, local_port: u32) -> Self { - let mut info = ConnectionInfo::new(peer, local_port); - info.buf_alloc = PER_CONNECTION_BUFFER_CAPACITY.try_into().unwrap(); - Self { - info, - buffer: RingBuffer::new(PER_CONNECTION_BUFFER_CAPACITY), - peer_requested_shutdown: false, - } - } -} - -#[derive(Debug)] -struct RingBuffer { - buffer: Box<[u8]>, - /// The number of bytes currently in the buffer. - used: usize, - /// The index of the first used byte in the buffer. - start: usize, -} - -impl RingBuffer { - pub fn new(capacity: usize) -> Self { - Self { - // FIXME: if the capacity is excessive, elements move will be executed. - buffer: vec![0; capacity].into_boxed_slice(), - used: 0, - start: 0, - } - } - /// Returns the number of bytes currently used in the buffer. - pub fn used(&self) -> usize { - self.used - } - - /// Returns true iff there are currently no bytes in the buffer. 
- pub fn is_empty(&self) -> bool { - self.used == 0 - } - - /// Returns the number of bytes currently free in the buffer. - pub fn available(&self) -> usize { - self.buffer.len() - self.used - } - - /// Adds the given bytes to the buffer if there is enough capacity for them all. - /// - /// Returns true if they were added, or false if they were not. - pub fn add(&mut self, bytes: &[u8]) -> bool { - if bytes.len() > self.available() { - return false; - } - - // The index of the first available position in the buffer. - let first_available = (self.start + self.used) % self.buffer.len(); - // The number of bytes to copy from `bytes` to `buffer` between `first_available` and - // `buffer.len()`. - let copy_length_before_wraparound = min(bytes.len(), self.buffer.len() - first_available); - self.buffer[first_available..first_available + copy_length_before_wraparound] - .copy_from_slice(&bytes[0..copy_length_before_wraparound]); - if let Some(bytes_after_wraparound) = bytes.get(copy_length_before_wraparound..) { - self.buffer[0..bytes_after_wraparound.len()].copy_from_slice(bytes_after_wraparound); - } - self.used += bytes.len(); - - true - } - - /// Reads and removes as many bytes as possible from the buffer, up to the length of the given - /// buffer. - pub fn drain(&mut self, out: &mut [u8]) -> usize { - let bytes_read = min(self.used, out.len()); - - // The number of bytes to copy out between `start` and the end of the buffer. - let read_before_wraparound = min(bytes_read, self.buffer.len() - self.start); - // The number of bytes to copy out from the beginning of the buffer after wrapping around. 
- let read_after_wraparound = bytes_read - .checked_sub(read_before_wraparound) - .unwrap_or_default(); - - out[0..read_before_wraparound] - .copy_from_slice(&self.buffer[self.start..self.start + read_before_wraparound]); - out[read_before_wraparound..bytes_read] - .copy_from_slice(&self.buffer[0..read_after_wraparound]); - - self.used -= bytes_read; - self.start = (self.start + bytes_read) % self.buffer.len(); - - bytes_read - } -} diff --git a/kernel/comps/virtio/src/device/socket/mod.rs b/kernel/comps/virtio/src/device/socket/mod.rs index bdc4c82d5..953c5a47a 100644 --- a/kernel/comps/virtio/src/device/socket/mod.rs +++ b/kernel/comps/virtio/src/device/socket/mod.rs @@ -4,7 +4,6 @@ use alloc::{collections::BTreeMap, string::String, sync::Arc, vec::Vec}; use aster_frame::sync::SpinLock; -use component::ComponentInitError; use spin::Once; use self::device::SocketDevice; @@ -14,50 +13,59 @@ pub mod connect; pub mod device; pub mod error; pub mod header; -pub mod manager; pub static DEVICE_NAME: &str = "Virtio-Vsock"; pub trait VsockDeviceIrqHandler = Fn() + Send + Sync + 'static; pub fn register_device(name: String, device: Arc>) { - COMPONENT + VSOCK_DEVICE_TABLE .get() .unwrap() - .vsock_device_table .lock() - .insert(name, device); + .insert(name, (Arc::new(SpinLock::new(Vec::new())), device)); } pub fn get_device(str: &str) -> Option>> { - let lock = COMPONENT.get().unwrap().vsock_device_table.lock(); - let device = lock.get(str)?; + let lock = VSOCK_DEVICE_TABLE.get().unwrap().lock(); + let (_, device) = lock.get(str)?; Some(device.clone()) } pub fn all_devices() -> Vec<(String, Arc>)> { - let vsock_devs = COMPONENT.get().unwrap().vsock_device_table.lock(); + let vsock_devs = VSOCK_DEVICE_TABLE.get().unwrap().lock(); vsock_devs .iter() - .map(|(name, device)| (name.clone(), device.clone())) + .map(|(name, (_, device))| (name.clone(), device.clone())) .collect() } -static COMPONENT: Once = Once::new(); - -pub fn component_init() -> Result<(), 
ComponentInitError> { - let a = Component::init()?; - COMPONENT.call_once(|| a); - Ok(()) +pub fn register_recv_callback(name: &str, callback: impl VsockDeviceIrqHandler) { + let lock = VSOCK_DEVICE_TABLE.get().unwrap().lock(); + let Some((callbacks, _)) = lock.get(name) else { + return; + }; + callbacks.lock().push(Arc::new(callback)); } -struct Component { - vsock_device_table: SpinLock>>>, -} - -impl Component { - pub fn init() -> Result { - Ok(Self { - vsock_device_table: SpinLock::new(BTreeMap::new()), - }) +pub fn handle_recv_irq(name: &str) { + let lock = VSOCK_DEVICE_TABLE.get().unwrap().lock(); + let Some((callbacks, _)) = lock.get(name) else { + return; + }; + let callbacks = callbacks.clone(); + let lock = callbacks.lock(); + for callback in lock.iter() { + callback.call(()) } } + +pub fn init() { + VSOCK_DEVICE_TABLE.call_once(|| SpinLock::new(BTreeMap::new())); +} + +type VsockDeviceIrqHandlerListRef = Arc>>>; +type VsockDeviceRef = Arc>; + +pub static VSOCK_DEVICE_TABLE: Once< + SpinLock>, +> = Once::new(); diff --git a/kernel/comps/virtio/src/lib.rs b/kernel/comps/virtio/src/lib.rs index 513decc9c..f83393092 100644 --- a/kernel/comps/virtio/src/lib.rs +++ b/kernel/comps/virtio/src/lib.rs @@ -35,8 +35,8 @@ mod transport; fn virtio_component_init() -> Result<(), ComponentInitError> { // Find all devices and register them to the corresponding crate transport::init(); - // For vsock cmponent - socket::component_init()?; + // For vsock table static init + socket::init(); while let Some(mut transport) = pop_device_transport() { // Reset device transport.set_device_status(DeviceStatus::empty()).unwrap(); diff --git a/regression/apps/Makefile b/regression/apps/Makefile index dfb60041e..00c2a439b 100644 --- a/regression/apps/Makefile +++ b/regression/apps/Makefile @@ -31,6 +31,7 @@ TEST_APPS := \ pthread \ pty \ signal_c \ + vsock \ # The C head and source files of all the apps, excluding the downloaded mongoose files C_SOURCES := $(shell find . 
-type f \( -name "*.c" -or -name "*.h" \) ! -name "mongoose.c" ! -name "mongoose.h") diff --git a/regression/apps/scripts/run_vsock_test.sh b/regression/apps/scripts/run_vsock_test.sh new file mode 100644 index 000000000..ce2d1c56c --- /dev/null +++ b/regression/apps/scripts/run_vsock_test.sh @@ -0,0 +1,13 @@ +#!/bin/sh + +# SPDX-License-Identifier: MPL-2.0 + +set -e + +VSOCK_DIR=/regression/vsock +cd ${VSOCK_DIR} + +echo "Start vsock test......" +# ./vsock_server +./vsock_client +echo "Vsock test passed." diff --git a/regression/apps/vsock/Makefile b/regression/apps/vsock/Makefile new file mode 100644 index 000000000..d428635d2 --- /dev/null +++ b/regression/apps/vsock/Makefile @@ -0,0 +1,5 @@ +# SPDX-License-Identifier: MPL-2.0 + +include ../test_common.mk + +EXTRA_C_FLAGS := diff --git a/regression/apps/vsock/vsock_client.c b/regression/apps/vsock/vsock_client.c new file mode 100644 index 000000000..1b260be3b --- /dev/null +++ b/regression/apps/vsock/vsock_client.c @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: MPL-2.0 + +#include +#include +#include +#include +#include +#include + +#define PORT 1234 + +int main() +{ + int sock; + char *hello = "Hello from client"; + char buffer[1024] = { 0 }; + struct sockaddr_vm serv_addr; + + if ((sock = socket(AF_VSOCK, SOCK_STREAM, 0)) < 0) { + printf("\n Socket creation error\n"); + return -1; + } + printf("\n Create socket successfully!\n"); + serv_addr.svm_family = AF_VSOCK; + serv_addr.svm_cid = VMADDR_CID_HOST; + serv_addr.svm_port = PORT; + + if (connect(sock, (struct sockaddr *)&serv_addr, sizeof(serv_addr)) < + 0) { + printf("\nConnection Failed \n"); + return -1; + } + printf("\n Socket connect successfully!\n"); + + // Send message to the server and receive the reply + if (send(sock, hello, strlen(hello), 0) < 0) { + printf("\nSend Failed\n"); + return -1; + } + printf("Hello message sent\n"); + if (read(sock, buffer, 1024) < 0) { + printf("\nRead Failed\n"); + return -1; + } + printf("Server: %s\n", buffer); + 
return 0; +} \ No newline at end of file diff --git a/regression/apps/vsock/vsock_server.c b/regression/apps/vsock/vsock_server.c new file mode 100644 index 000000000..73aaf861d --- /dev/null +++ b/regression/apps/vsock/vsock_server.c @@ -0,0 +1,61 @@ +// SPDX-License-Identifier: MPL-2.0 + +#include +#include +#include +#include +#include +#include + +#define CID 3 +#define PORT 4321 + +int main() +{ + int sock, new_sock; + char *hello = "Hello from client"; + char buffer[1024] = { 0 }; + struct sockaddr_vm serv_addr, client_addr; + int addrlen = sizeof(client_addr); + + if ((sock = socket(AF_VSOCK, SOCK_STREAM, 0)) < 0) { + printf("\n Socket creation error\n"); + return -1; + } + printf("\nCreate socket successfully\n"); + serv_addr.svm_family = AF_VSOCK; + serv_addr.svm_cid = CID; + serv_addr.svm_port = PORT; + + if (bind(sock, (struct sockaddr *)&serv_addr, sizeof(serv_addr)) < 0) { + printf("\nBind Failed \n"); + return -1; + } + printf("\nBind socket successfully\n"); + + if (listen(sock, 3) < 0) { + printf("\nListen Failed\n"); + return -1; + } + printf("\nListen socket successfully\n"); + + if ((new_sock = accept(sock, (struct sockaddr *)&client_addr, + (socklen_t *)&addrlen)) < 0) { + printf("\nAccept Failed\n"); + return -1; + } + printf("\nAccept socket successfully\n"); + + // Send message to the server and receive the reply + if (read(new_sock, buffer, 1024) < 0) { + printf("\nRead Failed\n"); + return -1; + } + printf("Client: %s\n", buffer); + if (send(new_sock, hello, strlen(hello), 0) < 0) { + printf("\nSend Failed\n"); + return -1; + } + printf("Hello message sent\n"); + return 0; +} \ No newline at end of file diff --git a/regression/benchmark/README.md b/regression/benchmark/README.md index 808e69de8..f033e8696 100644 --- a/regression/benchmark/README.md +++ b/regression/benchmark/README.md @@ -55,3 +55,4 @@ export HOST_PORT=8888 iperf3 -s -B $HOST_ADDR -p $HOST_PORT -D # Start the server as a daemon iperf3 -c $HOST_ADDR -p $HOST_PORT # Start the 
client ``` +Note that [a variant of iperf3](https://github.com/stefano-garzarella/iperf-vsock) can measure the performance of `vsock`. \ No newline at end of file diff --git a/test_vsock/vsock_client.py b/test_vsock/vsock_client.py deleted file mode 100644 index 3dae3b8e7..000000000 --- a/test_vsock/vsock_client.py +++ /dev/null @@ -1,16 +0,0 @@ -import socket - -client_socket = socket.socket(socket.AF_VSOCK, socket.SOCK_STREAM) -CID = socket.VMADDR_CID_HOST -PORT = 1234 -vm_cid = 3 -server_port = 4321 -client_socket.bind((CID, PORT)) -client_socket.connect((vm_cid, server_port)) - -client_socket.sendall(b'Hello from host') - -response = client_socket.recv(4096) -print(f'Received: {response.decode()}') - -client_socket.close() diff --git a/test_vsock/vsock_server.py b/test_vsock/vsock_server.py deleted file mode 100644 index 7b4baf8d4..000000000 --- a/test_vsock/vsock_server.py +++ /dev/null @@ -1,22 +0,0 @@ -#!/usr/bin/env python3 - -import socket - -CID = socket.VMADDR_CID_HOST -PORT = 1234 - -s = socket.socket(socket.AF_VSOCK, socket.SOCK_STREAM) -s.bind((CID, PORT)) -s.listen() -(conn, (remote_cid, remote_port)) = s.accept() - -print(f"Connection opened by cid={remote_cid} port={remote_port}") - -while True: - buf = conn.recv(64) - if not buf: - break - - print(f"Received bytes: {buf}") - - conn.send(b'Hello from host') \ No newline at end of file diff --git a/tools/qemu_args.sh b/tools/qemu_args.sh index 11695a793..89b5dbb41 100755 --- a/tools/qemu_args.sh +++ b/tools/qemu_args.sh @@ -60,7 +60,7 @@ MICROVM_QEMU_ARGS="\ -device virtio-net-device,netdev=net01 \ -device virtio-serial-device \ -device virtconsole,chardev=mux \ - -device vhost-vsock-pci,id=vhost-vsock-pci0,guest-cid=3 \ + -device vhost-vsock-device,guest-cid=3 \ " if [ "$1" = "microvm" ]; then