mirror of
https://github.com/asterinas/asterinas.git
synced 2025-06-14 07:46:48 +00:00
Implement vsock socket layer
This commit is contained in:
parent
83a7937334
commit
ad140cec3c
16
Makefile
16
Makefile
@ -33,7 +33,9 @@ CARGO_OSDK_ARGS += --init-args="/opt/syscall_test/run_syscall_test.sh"
|
|||||||
else ifeq ($(AUTO_TEST), regression)
|
else ifeq ($(AUTO_TEST), regression)
|
||||||
CARGO_OSDK_ARGS += --init-args="/regression/run_regression_test.sh"
|
CARGO_OSDK_ARGS += --init-args="/regression/run_regression_test.sh"
|
||||||
else ifeq ($(AUTO_TEST), boot)
|
else ifeq ($(AUTO_TEST), boot)
|
||||||
CARGO_OSDK_ARGS += --init-args="/regression/boot_hello.sh"
|
CARGO_OSDK_ARGS += --init_args="/regression/boot_hello.sh"
|
||||||
|
else ifeq ($(AUTO_TEST), vsock)
|
||||||
|
CARGO_OSDK_ARGS += --init_args="/regression/run_vsock_test.sh"
|
||||||
endif
|
endif
|
||||||
|
|
||||||
ifeq ($(RELEASE_LTO), 1)
|
ifeq ($(RELEASE_LTO), 1)
|
||||||
@ -68,16 +70,6 @@ ifeq ($(ENABLE_KVM), 1)
|
|||||||
CARGO_OSDK_ARGS += --qemu-args="--enable-kvm"
|
CARGO_OSDK_ARGS += --qemu-args="--enable-kvm"
|
||||||
endif
|
endif
|
||||||
|
|
||||||
ifeq ($(VSOCK),1)
|
|
||||||
ifeq ($(QEMU_MACHINE), microvm)
|
|
||||||
CARGO_OSDK_ARGS += --qumu.args="-device vhost-vsock-pci,id=vhost-vsock-pci0,guest-cid=3"
|
|
||||||
else ifeq ($(EMULATE_IOMMU), 1)
|
|
||||||
CARGO_OSDK_ARGS += --qemu.args="-device vhost-vsock-pci,id=vhost-vsock-pci0,guest-cid=3,disable-legacy=on,disable-modern=off,iommu_platform=on,ats=on"
|
|
||||||
else
|
|
||||||
CARGO_OSDK_ARGS += --qemu.args="-device vhost-vsock-pci,id=vhost-vsock-pci0,guest-cid=3,disable-legacy=on,disable-modern=off"
|
|
||||||
endif
|
|
||||||
endif
|
|
||||||
|
|
||||||
# Pass make variables to all subdirectory makes
|
# Pass make variables to all subdirectory makes
|
||||||
export
|
export
|
||||||
|
|
||||||
@ -152,6 +144,8 @@ else ifeq ($(AUTO_TEST), regression)
|
|||||||
@tail --lines 100 qemu.log | grep -q "^All regression tests passed." || (echo "Regression test failed" && exit 1)
|
@tail --lines 100 qemu.log | grep -q "^All regression tests passed." || (echo "Regression test failed" && exit 1)
|
||||||
else ifeq ($(AUTO_TEST), boot)
|
else ifeq ($(AUTO_TEST), boot)
|
||||||
@tail --lines 100 qemu.log | grep -q "^Successfully booted." || (echo "Boot test failed" && exit 1)
|
@tail --lines 100 qemu.log | grep -q "^Successfully booted." || (echo "Boot test failed" && exit 1)
|
||||||
|
else ifeq ($(AUTO_TEST), vsock)
|
||||||
|
@tail --lines 100 qemu.log | grep -q "^Vsock test passed." || (echo "Vsock test failed" && exit 1)
|
||||||
endif
|
endif
|
||||||
|
|
||||||
gdb_server: initramfs $(CARGO_OSDK)
|
gdb_server: initramfs $(CARGO_OSDK)
|
||||||
|
@ -6,8 +6,6 @@ edition = "2021"
|
|||||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
# FIXME: used for test in driver mod
|
|
||||||
component = {path="../libs/comp-sys/component"}
|
|
||||||
aster-frame = { path = "../../framework/aster-frame" }
|
aster-frame = { path = "../../framework/aster-frame" }
|
||||||
align_ext = { path = "../../framework/libs/align_ext" }
|
align_ext = { path = "../../framework/libs/align_ext" }
|
||||||
pod = { git = "https://github.com/asterinas/pod", rev = "d7dba56" }
|
pod = { git = "https://github.com/asterinas/pod", rev = "d7dba56" }
|
||||||
|
@ -1,98 +1,10 @@
|
|||||||
// SPDX-License-Identifier: MPL-2.0
|
// SPDX-License-Identifier: MPL-2.0
|
||||||
|
|
||||||
use aster_virtio::{
|
use log::info;
|
||||||
self,
|
|
||||||
device::socket::{header::VsockAddr, manager::VsockConnectionManager, DEVICE_NAME},
|
|
||||||
};
|
|
||||||
use component::ComponentInitError;
|
|
||||||
use log::{debug, info};
|
|
||||||
|
|
||||||
pub fn init() {
|
pub fn init() {
|
||||||
// print all the input device to make sure input crate will compile
|
// print all the input device to make sure input crate will compile
|
||||||
for (name, _) in aster_input::all_devices() {
|
for (name, _) in aster_input::all_devices() {
|
||||||
info!("Found Input device, name:{}", name);
|
info!("Found Input device, name:{}", name);
|
||||||
}
|
}
|
||||||
// let _ = socket_device_client_test();
|
|
||||||
// let _ = socket_device_server_test();
|
|
||||||
}
|
|
||||||
|
|
||||||
fn socket_device_client_test() -> Result<(), ComponentInitError> {
|
|
||||||
let host_cid = 2;
|
|
||||||
let guest_cid = 3;
|
|
||||||
let host_port = 1234;
|
|
||||||
let guest_port = 4321;
|
|
||||||
let host_address = VsockAddr {
|
|
||||||
cid: host_cid,
|
|
||||||
port: host_port,
|
|
||||||
};
|
|
||||||
let hello_from_guest = "Hello from guest";
|
|
||||||
let hello_from_host = "Hello from host";
|
|
||||||
|
|
||||||
let device = aster_virtio::device::socket::get_device(DEVICE_NAME).unwrap();
|
|
||||||
assert_eq!(device.lock().guest_cid(), guest_cid);
|
|
||||||
let mut socket = VsockConnectionManager::new(device);
|
|
||||||
|
|
||||||
socket.connect(host_address, guest_port).unwrap();
|
|
||||||
socket.wait_for_event().unwrap(); // wait for connect response
|
|
||||||
socket
|
|
||||||
.send(host_address, guest_port, hello_from_guest.as_bytes())
|
|
||||||
.unwrap();
|
|
||||||
debug!(
|
|
||||||
"The buffer {:?} is sent, start receiving",
|
|
||||||
hello_from_guest.as_bytes()
|
|
||||||
);
|
|
||||||
socket.wait_for_event().unwrap(); // wait for recv
|
|
||||||
let mut buffer = [0u8; 64];
|
|
||||||
let event = socket.recv(host_address, guest_port, &mut buffer).unwrap();
|
|
||||||
assert_eq!(
|
|
||||||
&buffer[0..hello_from_host.len()],
|
|
||||||
hello_from_host.as_bytes()
|
|
||||||
);
|
|
||||||
|
|
||||||
socket.force_close(host_address, guest_port).unwrap();
|
|
||||||
|
|
||||||
debug!("The final event: {:?}", event);
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn socket_device_server_test() -> Result<(), ComponentInitError> {
|
|
||||||
let host_cid = 2;
|
|
||||||
let guest_cid = 3;
|
|
||||||
let host_port = 1234;
|
|
||||||
let guest_port = 4321;
|
|
||||||
let host_address = VsockAddr {
|
|
||||||
cid: host_cid,
|
|
||||||
port: host_port,
|
|
||||||
};
|
|
||||||
let hello_from_guest = "Hello from guest";
|
|
||||||
let hello_from_host = "Hello from host";
|
|
||||||
|
|
||||||
let device = aster_virtio::device::socket::get_device(DEVICE_NAME).unwrap();
|
|
||||||
assert_eq!(device.lock().guest_cid(), guest_cid);
|
|
||||||
let mut socket = VsockConnectionManager::new(device);
|
|
||||||
|
|
||||||
socket.listen(4321);
|
|
||||||
socket.wait_for_event().unwrap(); // wait for connect request
|
|
||||||
socket.wait_for_event().unwrap(); // wait for recv
|
|
||||||
let mut buffer = [0u8; 64];
|
|
||||||
let event = socket.recv(host_address, guest_port, &mut buffer).unwrap();
|
|
||||||
assert_eq!(
|
|
||||||
&buffer[0..hello_from_host.len()],
|
|
||||||
hello_from_host.as_bytes()
|
|
||||||
);
|
|
||||||
|
|
||||||
debug!(
|
|
||||||
"The buffer {:?} is received, start sending {:?}",
|
|
||||||
&buffer[0..hello_from_host.len()],
|
|
||||||
hello_from_guest.as_bytes()
|
|
||||||
);
|
|
||||||
socket
|
|
||||||
.send(host_address, guest_port, hello_from_guest.as_bytes())
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
socket.shutdown(host_address, guest_port).unwrap();
|
|
||||||
let event = socket.wait_for_event().unwrap(); // wait for rst/shutdown
|
|
||||||
|
|
||||||
debug!("The final event: {:?}", event);
|
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
use spin::Once;
|
use spin::Once;
|
||||||
|
|
||||||
use self::iface::spawn_background_poll_thread;
|
use self::{iface::spawn_background_poll_thread, socket::vsock};
|
||||||
use crate::{
|
use crate::{
|
||||||
net::iface::{Iface, IfaceLoopback, IfaceVirtio},
|
net::iface::{Iface, IfaceLoopback, IfaceVirtio},
|
||||||
prelude::*,
|
prelude::*,
|
||||||
@ -28,6 +28,7 @@ pub fn init() {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
poll_ifaces();
|
poll_ifaces();
|
||||||
|
vsock::init();
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Lazy init should be called after spawning init thread.
|
/// Lazy init should be called after spawning init thread.
|
||||||
|
@ -13,6 +13,7 @@ pub mod ip;
|
|||||||
pub mod options;
|
pub mod options;
|
||||||
pub mod unix;
|
pub mod unix;
|
||||||
mod util;
|
mod util;
|
||||||
|
pub mod vsock;
|
||||||
|
|
||||||
/// Operations defined on a socket.
|
/// Operations defined on a socket.
|
||||||
pub trait Socket: FileLike + Send + Sync {
|
pub trait Socket: FileLike + Send + Sync {
|
||||||
|
@ -15,6 +15,7 @@ pub enum SocketAddr {
|
|||||||
Unix(UnixSocketAddr),
|
Unix(UnixSocketAddr),
|
||||||
IPv4(Ipv4Address, PortNum),
|
IPv4(Ipv4Address, PortNum),
|
||||||
IPv6,
|
IPv6,
|
||||||
|
Vsock(u32, u32),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl TryFrom<SocketAddr> for IpEndpoint {
|
impl TryFrom<SocketAddr> for IpEndpoint {
|
||||||
|
71
kernel/aster-nix/src/net/socket/vsock/addr.rs
Normal file
71
kernel/aster-nix/src/net/socket/vsock/addr.rs
Normal file
@ -0,0 +1,71 @@
|
|||||||
|
// SPDX-License-Identifier: MPL-2.0
|
||||||
|
|
||||||
|
use aster_virtio::device::socket::header::VsockAddr;
|
||||||
|
|
||||||
|
use crate::{net::socket::SocketAddr, prelude::*};
|
||||||
|
|
||||||
|
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
|
||||||
|
pub struct VsockSocketAddr {
|
||||||
|
pub cid: u32,
|
||||||
|
pub port: u32,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl VsockSocketAddr {
|
||||||
|
pub fn new(cid: u32, port: u32) -> Self {
|
||||||
|
Self { cid, port }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn any_addr() -> Self {
|
||||||
|
Self {
|
||||||
|
cid: VMADDR_CID_ANY,
|
||||||
|
port: VMADDR_PORT_ANY,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TryFrom<SocketAddr> for VsockSocketAddr {
|
||||||
|
type Error = Error;
|
||||||
|
|
||||||
|
fn try_from(value: SocketAddr) -> Result<Self> {
|
||||||
|
let (cid, port) = if let SocketAddr::Vsock(cid, port) = value {
|
||||||
|
(cid, port)
|
||||||
|
} else {
|
||||||
|
return_errno_with_message!(Errno::EINVAL, "invalid vsock socket addr");
|
||||||
|
};
|
||||||
|
Ok(Self { cid, port })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<VsockSocketAddr> for SocketAddr {
|
||||||
|
fn from(value: VsockSocketAddr) -> Self {
|
||||||
|
SocketAddr::Vsock(value.cid, value.port)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<VsockAddr> for VsockSocketAddr {
|
||||||
|
fn from(value: VsockAddr) -> Self {
|
||||||
|
VsockSocketAddr {
|
||||||
|
cid: value.cid as u32,
|
||||||
|
port: value.port,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<VsockSocketAddr> for VsockAddr {
|
||||||
|
fn from(value: VsockSocketAddr) -> Self {
|
||||||
|
VsockAddr {
|
||||||
|
cid: value.cid as u64,
|
||||||
|
port: value.port,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The vSocket equivalent of INADDR_ANY.
|
||||||
|
pub const VMADDR_CID_ANY: u32 = u32::MAX;
|
||||||
|
/// Use this as the destination CID in an address when referring to the local communication (loopback).
|
||||||
|
/// This was VMADDR_CID_RESERVED
|
||||||
|
pub const VMADDR_CID_LOCAL: u32 = 1;
|
||||||
|
/// Use this as the destination CID in an address when referring to the host (any process other than the hypervisor).
|
||||||
|
pub const VMADDR_CID_HOST: u32 = 2;
|
||||||
|
/// Bind to any available port.
|
||||||
|
pub const VMADDR_PORT_ANY: u32 = u32::MAX;
|
206
kernel/aster-nix/src/net/socket/vsock/common.rs
Normal file
206
kernel/aster-nix/src/net/socket/vsock/common.rs
Normal file
@ -0,0 +1,206 @@
|
|||||||
|
// SPDX-License-Identifier: MPL-2.0
|
||||||
|
|
||||||
|
use alloc::collections::BTreeSet;
|
||||||
|
|
||||||
|
use aster_virtio::device::socket::{
|
||||||
|
connect::{VsockEvent, VsockEventType},
|
||||||
|
device::SocketDevice,
|
||||||
|
error::SocketError,
|
||||||
|
get_device, DEVICE_NAME,
|
||||||
|
};
|
||||||
|
|
||||||
|
use super::{
|
||||||
|
addr::VsockSocketAddr,
|
||||||
|
stream::{
|
||||||
|
connected::{Connected, ConnectionID},
|
||||||
|
listen::Listen,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
use crate::{events::IoEvents, prelude::*, return_errno_with_message};
|
||||||
|
|
||||||
|
/// Manage all active sockets
|
||||||
|
pub struct VsockSpace {
|
||||||
|
pub driver: Arc<SpinLock<SocketDevice>>,
|
||||||
|
// (key, value) = (local_addr, connecting)
|
||||||
|
pub connecting_sockets: SpinLock<BTreeMap<VsockSocketAddr, Arc<Connected>>>,
|
||||||
|
// (key, value) = (local_addr, listen)
|
||||||
|
pub listen_sockets: SpinLock<BTreeMap<VsockSocketAddr, Arc<Listen>>>,
|
||||||
|
// (key, value) = (id(local_addr,peer_addr), connected)
|
||||||
|
pub connected_sockets: SpinLock<BTreeMap<ConnectionID, Arc<Connected>>>,
|
||||||
|
// Used ports
|
||||||
|
pub used_ports: SpinLock<BTreeSet<u32>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl VsockSpace {
|
||||||
|
/// Create a new global VsockSpace
|
||||||
|
pub fn new() -> Self {
|
||||||
|
let driver = get_device(DEVICE_NAME).unwrap();
|
||||||
|
Self {
|
||||||
|
driver,
|
||||||
|
connecting_sockets: SpinLock::new(BTreeMap::new()),
|
||||||
|
listen_sockets: SpinLock::new(BTreeMap::new()),
|
||||||
|
connected_sockets: SpinLock::new(BTreeMap::new()),
|
||||||
|
used_ports: SpinLock::new(BTreeSet::new()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
/// Poll for each event from the driver
|
||||||
|
pub fn poll(&self) -> Result<Option<VsockEvent>> {
|
||||||
|
let mut driver = self.driver.lock_irq_disabled();
|
||||||
|
let guest_cid: u32 = driver.guest_cid() as u32;
|
||||||
|
|
||||||
|
// match the socket and store the buffer body (if valid)
|
||||||
|
let result = driver
|
||||||
|
.poll(|event, body| {
|
||||||
|
if !self.is_event_for_socket(&event) {
|
||||||
|
return Ok(None);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deal with Received before the buffer are recycled.
|
||||||
|
if let VsockEventType::Received { length } = event.event_type {
|
||||||
|
// Only consider the connected socket and copy body to buffer
|
||||||
|
if let Some(connected) = self
|
||||||
|
.connected_sockets
|
||||||
|
.lock_irq_disabled()
|
||||||
|
.get(&event.into())
|
||||||
|
{
|
||||||
|
debug!("Rw matches a connection with id {:?}", connected.id());
|
||||||
|
if !connected.connection_buffer_add(body) {
|
||||||
|
return Err(SocketError::BufferTooShort);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return Ok(None);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(Some(event))
|
||||||
|
})
|
||||||
|
.map_err(|e| {
|
||||||
|
Error::with_message(Errno::EAGAIN, "driver poll failed, please try again")
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let Some(event) = result else {
|
||||||
|
return Ok(None);
|
||||||
|
};
|
||||||
|
|
||||||
|
// The socket must be stored in the VsockSpace.
|
||||||
|
if let Some(connected) = self
|
||||||
|
.connected_sockets
|
||||||
|
.lock_irq_disabled()
|
||||||
|
.get(&event.into())
|
||||||
|
{
|
||||||
|
connected.update_for_event(&event);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Response to the event
|
||||||
|
match event.event_type {
|
||||||
|
VsockEventType::ConnectionRequest => {
|
||||||
|
// Preparation for listen socket `accept`
|
||||||
|
if let Some(listen) = self
|
||||||
|
.listen_sockets
|
||||||
|
.lock_irq_disabled()
|
||||||
|
.get(&event.destination.into())
|
||||||
|
{
|
||||||
|
let peer = event.source;
|
||||||
|
let connected = Arc::new(Connected::new(peer.into(), listen.addr()));
|
||||||
|
connected.update_for_event(&event);
|
||||||
|
listen.push_incoming(connected).unwrap();
|
||||||
|
} else {
|
||||||
|
return_errno_with_message!(
|
||||||
|
Errno::EINVAL,
|
||||||
|
"Connecion request can only be handled by listening socket"
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
VsockEventType::Connected => {
|
||||||
|
if let Some(connecting) = self
|
||||||
|
.connecting_sockets
|
||||||
|
.lock_irq_disabled()
|
||||||
|
.get(&event.destination.into())
|
||||||
|
{
|
||||||
|
// debug!("match a connecting socket. Peer{:?}; local{:?}",connecting.peer_addr(),connecting.local_addr());
|
||||||
|
connecting.update_for_event(&event);
|
||||||
|
connecting.add_events(IoEvents::IN);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
VsockEventType::Disconnected { reason } => {
|
||||||
|
if let Some(connected) = self
|
||||||
|
.connected_sockets
|
||||||
|
.lock_irq_disabled()
|
||||||
|
.get(&event.into())
|
||||||
|
{
|
||||||
|
connected.peer_requested_shutdown();
|
||||||
|
} else {
|
||||||
|
return_errno_with_message!(Errno::ENOTCONN, "The socket hasn't connected");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
VsockEventType::Received { length } => {
|
||||||
|
if let Some(connected) = self
|
||||||
|
.connected_sockets
|
||||||
|
.lock_irq_disabled()
|
||||||
|
.get(&event.into())
|
||||||
|
{
|
||||||
|
connected.add_events(IoEvents::IN);
|
||||||
|
} else {
|
||||||
|
return_errno_with_message!(Errno::ENOTCONN, "The socket hasn't connected");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
VsockEventType::CreditRequest => {
|
||||||
|
if let Some(connected) = self
|
||||||
|
.connected_sockets
|
||||||
|
.lock_irq_disabled()
|
||||||
|
.get(&event.into())
|
||||||
|
{
|
||||||
|
driver.credit_update(&connected.get_info()).map_err(|_| {
|
||||||
|
Error::with_message(Errno::EINVAL, "can not send credit update")
|
||||||
|
})?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
VsockEventType::CreditUpdate => {
|
||||||
|
if let Some(connected) = self
|
||||||
|
.connected_sockets
|
||||||
|
.lock_irq_disabled()
|
||||||
|
.get(&event.into())
|
||||||
|
{
|
||||||
|
connected.update_for_event(&event);
|
||||||
|
} else {
|
||||||
|
return_errno_with_message!(
|
||||||
|
Errno::EINVAL,
|
||||||
|
"CreditUpdate is only valid in connected sockets"
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(Some(event))
|
||||||
|
}
|
||||||
|
/// Check whether the event is for this socket space
|
||||||
|
fn is_event_for_socket(&self, event: &VsockEvent) -> bool {
|
||||||
|
// debug!("The event is for connection with id {:?}",ConnectionID::from(*event));
|
||||||
|
self.connecting_sockets
|
||||||
|
.lock_irq_disabled()
|
||||||
|
.contains_key(&event.destination.into())
|
||||||
|
|| self
|
||||||
|
.listen_sockets
|
||||||
|
.lock_irq_disabled()
|
||||||
|
.contains_key(&event.destination.into())
|
||||||
|
|| self
|
||||||
|
.connected_sockets
|
||||||
|
.lock_irq_disabled()
|
||||||
|
.contains_key(&(*event).into())
|
||||||
|
}
|
||||||
|
/// Alloc an unused port range
|
||||||
|
pub fn alloc_ephemeral_port(&self) -> Result<u32> {
|
||||||
|
let mut used_ports = self.used_ports.lock_irq_disabled();
|
||||||
|
for port in 1024..=u32::MAX {
|
||||||
|
if !used_ports.contains(&port) {
|
||||||
|
used_ports.insert(port);
|
||||||
|
return Ok(port);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return_errno_with_message!(Errno::EAGAIN, "cannot find unused high port");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for VsockSpace {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
23
kernel/aster-nix/src/net/socket/vsock/mod.rs
Normal file
23
kernel/aster-nix/src/net/socket/vsock/mod.rs
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
// SPDX-License-Identifier: MPL-2.0
|
||||||
|
|
||||||
|
use alloc::sync::Arc;
|
||||||
|
|
||||||
|
use aster_virtio::device::socket::{register_recv_callback, DEVICE_NAME};
|
||||||
|
use common::VsockSpace;
|
||||||
|
use spin::Once;
|
||||||
|
|
||||||
|
pub mod addr;
|
||||||
|
pub mod common;
|
||||||
|
pub mod stream;
|
||||||
|
pub use stream::VsockStreamSocket;
|
||||||
|
|
||||||
|
// init static driver
|
||||||
|
pub static VSOCK_GLOBAL: Once<Arc<VsockSpace>> = Once::new();
|
||||||
|
|
||||||
|
pub fn init() {
|
||||||
|
VSOCK_GLOBAL.call_once(|| Arc::new(VsockSpace::new()));
|
||||||
|
register_recv_callback(DEVICE_NAME, || {
|
||||||
|
let vsockspace = VSOCK_GLOBAL.get().unwrap();
|
||||||
|
let _ = vsockspace.poll();
|
||||||
|
})
|
||||||
|
}
|
267
kernel/aster-nix/src/net/socket/vsock/stream/connected.rs
Normal file
267
kernel/aster-nix/src/net/socket/vsock/stream/connected.rs
Normal file
@ -0,0 +1,267 @@
|
|||||||
|
// SPDX-License-Identifier: MPL-2.0
|
||||||
|
|
||||||
|
use alloc::boxed::Box;
|
||||||
|
use core::cmp::min;
|
||||||
|
|
||||||
|
use aster_virtio::device::socket::connect::{ConnectionInfo, VsockEvent};
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
events::IoEvents,
|
||||||
|
net::socket::{
|
||||||
|
vsock::{addr::VsockSocketAddr, VSOCK_GLOBAL},
|
||||||
|
SendRecvFlags, SockShutdownCmd,
|
||||||
|
},
|
||||||
|
prelude::*,
|
||||||
|
process::signal::{Pollee, Poller},
|
||||||
|
};
|
||||||
|
|
||||||
|
const PER_CONNECTION_BUFFER_CAPACITY: usize = 4096;
|
||||||
|
|
||||||
|
pub struct Connected {
|
||||||
|
connection: SpinLock<Connection>,
|
||||||
|
id: ConnectionID,
|
||||||
|
pollee: Pollee,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Connected {
|
||||||
|
pub fn new(peer_addr: VsockSocketAddr, local_addr: VsockSocketAddr) -> Self {
|
||||||
|
Self {
|
||||||
|
connection: SpinLock::new(Connection::new(peer_addr, local_addr.port)),
|
||||||
|
id: ConnectionID::new(local_addr, peer_addr),
|
||||||
|
pollee: Pollee::new(IoEvents::empty()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
pub fn peer_addr(&self) -> VsockSocketAddr {
|
||||||
|
self.id.peer_addr
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn local_addr(&self) -> VsockSocketAddr {
|
||||||
|
self.id.local_addr
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn id(&self) -> ConnectionID {
|
||||||
|
self.id
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn recv(&self, buf: &mut [u8]) -> Result<usize> {
|
||||||
|
let poller = Poller::new();
|
||||||
|
if !self
|
||||||
|
.poll(IoEvents::IN, Some(&poller))
|
||||||
|
.contains(IoEvents::IN)
|
||||||
|
{
|
||||||
|
poller.wait()?;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut connection = self.connection.lock_irq_disabled();
|
||||||
|
let bytes_read = connection.buffer.drain(buf);
|
||||||
|
|
||||||
|
connection.info.done_forwarding(bytes_read);
|
||||||
|
|
||||||
|
Ok(bytes_read)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn send(&self, buf: &[u8], flags: SendRecvFlags) -> Result<usize> {
|
||||||
|
let mut connection = self.connection.lock_irq_disabled();
|
||||||
|
debug_assert!(flags.is_all_supported());
|
||||||
|
let buf_len = buf.len();
|
||||||
|
VSOCK_GLOBAL
|
||||||
|
.get()
|
||||||
|
.unwrap()
|
||||||
|
.driver
|
||||||
|
.lock_irq_disabled()
|
||||||
|
.send(buf, &mut connection.info)
|
||||||
|
.map_err(|e| Error::with_message(Errno::ENOBUFS, "cannot send packet"))?;
|
||||||
|
Ok(buf_len)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn should_close(&self) -> bool {
|
||||||
|
let connection = self.connection.lock_irq_disabled();
|
||||||
|
// If buffer is now empty and the peer requested shutdown, finish shutting down the
|
||||||
|
// connection.
|
||||||
|
connection.peer_requested_shutdown && connection.buffer.is_empty()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn shutdown(&self, cmd: SockShutdownCmd) -> Result<()> {
|
||||||
|
let connection = self.connection.lock_irq_disabled();
|
||||||
|
// TODO: deal with cmd
|
||||||
|
if self.should_close() {
|
||||||
|
let vsockspace = VSOCK_GLOBAL.get().unwrap();
|
||||||
|
vsockspace
|
||||||
|
.driver
|
||||||
|
.lock_irq_disabled()
|
||||||
|
.reset(&connection.info)
|
||||||
|
.map_err(|e| Error::with_message(Errno::ENOMEM, "can not send close packet"))?;
|
||||||
|
vsockspace
|
||||||
|
.connected_sockets
|
||||||
|
.lock_irq_disabled()
|
||||||
|
.remove(&self.id())
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
pub fn update_for_event(&self, event: &VsockEvent) {
|
||||||
|
let mut connection = self.connection.lock_irq_disabled();
|
||||||
|
connection.update_for_event(event)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_info(&self) -> ConnectionInfo {
|
||||||
|
let connection = self.connection.lock_irq_disabled();
|
||||||
|
connection.info.clone()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn connection_buffer_add(&self, bytes: &[u8]) -> bool {
|
||||||
|
let mut connection = self.connection.lock_irq_disabled();
|
||||||
|
self.add_events(IoEvents::IN);
|
||||||
|
connection.add(bytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn peer_requested_shutdown(&self) {
|
||||||
|
self.connection.lock_irq_disabled().peer_requested_shutdown = true
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
|
||||||
|
self.pollee.poll(mask, poller)
|
||||||
|
}
|
||||||
|
pub fn add_events(&self, events: IoEvents) {
|
||||||
|
self.pollee.add_events(events)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Drop for Connected {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
let vsockspace = VSOCK_GLOBAL.get().unwrap();
|
||||||
|
vsockspace
|
||||||
|
.used_ports
|
||||||
|
.lock_irq_disabled()
|
||||||
|
.remove(&self.local_addr().port);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct Connection {
|
||||||
|
info: ConnectionInfo,
|
||||||
|
buffer: RingBuffer,
|
||||||
|
/// The peer sent a SHUTDOWN request, but we haven't yet responded with a RST because there is
|
||||||
|
/// still data in the buffer.
|
||||||
|
pub peer_requested_shutdown: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Connection {
|
||||||
|
pub fn new(peer: VsockSocketAddr, local_port: u32) -> Self {
|
||||||
|
let mut info = ConnectionInfo::new(peer.into(), local_port);
|
||||||
|
info.buf_alloc = PER_CONNECTION_BUFFER_CAPACITY.try_into().unwrap();
|
||||||
|
Self {
|
||||||
|
info,
|
||||||
|
buffer: RingBuffer::new(PER_CONNECTION_BUFFER_CAPACITY),
|
||||||
|
peer_requested_shutdown: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
pub fn update_for_event(&mut self, event: &VsockEvent) {
|
||||||
|
self.info.update_for_event(event)
|
||||||
|
}
|
||||||
|
pub fn add(&mut self, bytes: &[u8]) -> bool {
|
||||||
|
self.buffer.add(bytes)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Copy, Clone, Ord, PartialOrd, Eq, PartialEq)]
|
||||||
|
pub struct ConnectionID {
|
||||||
|
pub local_addr: VsockSocketAddr,
|
||||||
|
pub peer_addr: VsockSocketAddr,
|
||||||
|
}
|
||||||
|
impl ConnectionID {
|
||||||
|
pub fn new(local_addr: VsockSocketAddr, peer_addr: VsockSocketAddr) -> Self {
|
||||||
|
Self {
|
||||||
|
local_addr,
|
||||||
|
peer_addr,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<VsockEvent> for ConnectionID {
|
||||||
|
fn from(event: VsockEvent) -> Self {
|
||||||
|
Self::new(event.destination.into(), event.source.into())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct RingBuffer {
|
||||||
|
buffer: Box<[u8]>,
|
||||||
|
/// The number of bytes currently in the buffer.
|
||||||
|
used: usize,
|
||||||
|
/// The index of the first used byte in the buffer.
|
||||||
|
start: usize,
|
||||||
|
}
|
||||||
|
//TODO: ringbuf
|
||||||
|
impl RingBuffer {
|
||||||
|
pub fn new(capacity: usize) -> Self {
|
||||||
|
// TODO: can be optimized.
|
||||||
|
let temp = vec![0; capacity];
|
||||||
|
Self {
|
||||||
|
// FIXME: if the capacity is excessive, elements move will be executed.
|
||||||
|
buffer: temp.into_boxed_slice(),
|
||||||
|
used: 0,
|
||||||
|
start: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
/// Returns the number of bytes currently used in the buffer.
|
||||||
|
pub fn used(&self) -> usize {
|
||||||
|
self.used
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns true iff there are currently no bytes in the buffer.
|
||||||
|
pub fn is_empty(&self) -> bool {
|
||||||
|
self.used == 0
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the number of bytes currently free in the buffer.
|
||||||
|
pub fn available(&self) -> usize {
|
||||||
|
self.buffer.len() - self.used
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Adds the given bytes to the buffer if there is enough capacity for them all.
|
||||||
|
///
|
||||||
|
/// Returns true if they were added, or false if they were not.
|
||||||
|
pub fn add(&mut self, bytes: &[u8]) -> bool {
|
||||||
|
if bytes.len() > self.available() {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// The index of the first available position in the buffer.
|
||||||
|
let first_available = (self.start + self.used) % self.buffer.len();
|
||||||
|
// The number of bytes to copy from `bytes` to `buffer` between `first_available` and
|
||||||
|
// `buffer.len()`.
|
||||||
|
let copy_length_before_wraparound = min(bytes.len(), self.buffer.len() - first_available);
|
||||||
|
self.buffer[first_available..first_available + copy_length_before_wraparound]
|
||||||
|
.copy_from_slice(&bytes[0..copy_length_before_wraparound]);
|
||||||
|
if let Some(bytes_after_wraparound) = bytes.get(copy_length_before_wraparound..) {
|
||||||
|
self.buffer[0..bytes_after_wraparound.len()].copy_from_slice(bytes_after_wraparound);
|
||||||
|
}
|
||||||
|
self.used += bytes.len();
|
||||||
|
|
||||||
|
true
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Reads and removes as many bytes as possible from the buffer, up to the length of the given
|
||||||
|
/// buffer.
|
||||||
|
pub fn drain(&mut self, out: &mut [u8]) -> usize {
|
||||||
|
let bytes_read = min(self.used, out.len());
|
||||||
|
|
||||||
|
// The number of bytes to copy out between `start` and the end of the buffer.
|
||||||
|
let read_before_wraparound = min(bytes_read, self.buffer.len() - self.start);
|
||||||
|
// The number of bytes to copy out from the beginning of the buffer after wrapping around.
|
||||||
|
let read_after_wraparound = bytes_read
|
||||||
|
.checked_sub(read_before_wraparound)
|
||||||
|
.unwrap_or_default();
|
||||||
|
|
||||||
|
out[0..read_before_wraparound]
|
||||||
|
.copy_from_slice(&self.buffer[self.start..self.start + read_before_wraparound]);
|
||||||
|
out[read_before_wraparound..bytes_read]
|
||||||
|
.copy_from_slice(&self.buffer[0..read_after_wraparound]);
|
||||||
|
|
||||||
|
self.used -= bytes_read;
|
||||||
|
self.start = (self.start + bytes_read) % self.buffer.len();
|
||||||
|
|
||||||
|
bytes_read
|
||||||
|
}
|
||||||
|
}
|
95
kernel/aster-nix/src/net/socket/vsock/stream/init.rs
Normal file
95
kernel/aster-nix/src/net/socket/vsock/stream/init.rs
Normal file
@ -0,0 +1,95 @@
|
|||||||
|
// SPDX-License-Identifier: MPL-2.0
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
events::IoEvents,
|
||||||
|
net::socket::vsock::{
|
||||||
|
addr::{VsockSocketAddr, VMADDR_CID_ANY, VMADDR_PORT_ANY},
|
||||||
|
VSOCK_GLOBAL,
|
||||||
|
},
|
||||||
|
prelude::*,
|
||||||
|
process::signal::{Pollee, Poller},
|
||||||
|
};
|
||||||
|
|
||||||
|
pub struct Init {
|
||||||
|
bind_addr: SpinLock<Option<VsockSocketAddr>>,
|
||||||
|
pollee: Pollee,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Init {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
bind_addr: SpinLock::new(None),
|
||||||
|
pollee: Pollee::new(IoEvents::empty()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn bind(&self, addr: VsockSocketAddr) -> Result<()> {
|
||||||
|
if self.bind_addr.lock().is_some() {
|
||||||
|
return_errno_with_message!(Errno::EINVAL, "the socket is already bound");
|
||||||
|
}
|
||||||
|
let vsockspace = VSOCK_GLOBAL.get().unwrap();
|
||||||
|
|
||||||
|
// check correctness of cid
|
||||||
|
let local_cid = vsockspace.driver.lock_irq_disabled().guest_cid();
|
||||||
|
if addr.cid != VMADDR_CID_ANY && addr.cid != local_cid as u32 {
|
||||||
|
return_errno_with_message!(Errno::EADDRNOTAVAIL, "The cid in address is incorrect");
|
||||||
|
}
|
||||||
|
let mut new_addr = addr;
|
||||||
|
new_addr.cid = local_cid as u32;
|
||||||
|
|
||||||
|
// check and assign a port
|
||||||
|
if addr.port == VMADDR_PORT_ANY {
|
||||||
|
if let Ok(port) = vsockspace.alloc_ephemeral_port() {
|
||||||
|
new_addr.port = port;
|
||||||
|
} else {
|
||||||
|
return_errno_with_message!(Errno::EAGAIN, "cannot find unused high port");
|
||||||
|
}
|
||||||
|
} else if vsockspace
|
||||||
|
.used_ports
|
||||||
|
.lock_irq_disabled()
|
||||||
|
.contains(&new_addr.port)
|
||||||
|
{
|
||||||
|
return_errno_with_message!(Errno::EADDRNOTAVAIL, "the port in address is occupied");
|
||||||
|
} else {
|
||||||
|
vsockspace
|
||||||
|
.used_ports
|
||||||
|
.lock_irq_disabled()
|
||||||
|
.insert(new_addr.port);
|
||||||
|
}
|
||||||
|
|
||||||
|
//TODO: The privileged port isn't checked
|
||||||
|
*self.bind_addr.lock() = Some(new_addr);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn is_bound(&self) -> bool {
|
||||||
|
self.bind_addr.lock().is_some()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn bound_addr(&self) -> Option<VsockSocketAddr> {
|
||||||
|
*self.bind_addr.lock()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
|
||||||
|
self.pollee.poll(mask, poller)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn add_events(&self, events: IoEvents) {
|
||||||
|
self.pollee.add_events(events)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Drop for Init {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
if let Some(addr) = *self.bind_addr.lock() {
|
||||||
|
let vsockspace = VSOCK_GLOBAL.get().unwrap();
|
||||||
|
vsockspace.used_ports.lock_irq_disabled().remove(&addr.port);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for Init {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
75
kernel/aster-nix/src/net/socket/vsock/stream/listen.rs
Normal file
75
kernel/aster-nix/src/net/socket/vsock/stream/listen.rs
Normal file
@ -0,0 +1,75 @@
|
|||||||
|
// SPDX-License-Identifier: MPL-2.0
|
||||||
|
|
||||||
|
use super::connected::Connected;
|
||||||
|
use crate::{
|
||||||
|
events::IoEvents,
|
||||||
|
net::socket::vsock::{addr::VsockSocketAddr, VSOCK_GLOBAL},
|
||||||
|
prelude::*,
|
||||||
|
process::signal::{Pollee, Poller},
|
||||||
|
};
|
||||||
|
pub struct Listen {
|
||||||
|
addr: VsockSocketAddr,
|
||||||
|
pollee: Pollee,
|
||||||
|
backlog: usize,
|
||||||
|
incoming_connection: SpinLock<VecDeque<Arc<Connected>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Listen {
|
||||||
|
pub fn new(addr: VsockSocketAddr, backlog: usize) -> Self {
|
||||||
|
Self {
|
||||||
|
addr,
|
||||||
|
pollee: Pollee::new(IoEvents::empty()),
|
||||||
|
backlog,
|
||||||
|
incoming_connection: SpinLock::new(VecDeque::with_capacity(backlog)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn addr(&self) -> VsockSocketAddr {
|
||||||
|
self.addr
|
||||||
|
}
|
||||||
|
pub fn push_incoming(&self, connect: Arc<Connected>) -> Result<()> {
|
||||||
|
let mut incoming_connections = self.incoming_connection.lock_irq_disabled();
|
||||||
|
if incoming_connections.len() >= self.backlog {
|
||||||
|
return_errno_with_message!(Errno::ENOMEM, "Queue in listenging socket is full")
|
||||||
|
}
|
||||||
|
incoming_connections.push_back(connect);
|
||||||
|
self.add_events(IoEvents::IN);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
pub fn accept(&self) -> Result<Arc<Connected>> {
|
||||||
|
// block waiting connection if no existing connection.
|
||||||
|
let poller = Poller::new();
|
||||||
|
if !self
|
||||||
|
.poll(IoEvents::IN, Some(&poller))
|
||||||
|
.contains(IoEvents::IN)
|
||||||
|
{
|
||||||
|
poller.wait()?;
|
||||||
|
}
|
||||||
|
|
||||||
|
let connection = self
|
||||||
|
.incoming_connection
|
||||||
|
.lock_irq_disabled()
|
||||||
|
.pop_front()
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
Ok(connection)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
|
||||||
|
self.pollee.poll(mask, poller)
|
||||||
|
}
|
||||||
|
pub fn add_events(&self, events: IoEvents) {
|
||||||
|
self.pollee.add_events(events)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Drop for Listen {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
VSOCK_GLOBAL
|
||||||
|
.get()
|
||||||
|
.unwrap()
|
||||||
|
.used_ports
|
||||||
|
.lock_irq_disabled()
|
||||||
|
.remove(&self.addr.port);
|
||||||
|
}
|
||||||
|
}
|
8
kernel/aster-nix/src/net/socket/vsock/stream/mod.rs
Normal file
8
kernel/aster-nix/src/net/socket/vsock/stream/mod.rs
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
// SPDX-License-Identifier: MPL-2.0
|
||||||
|
|
||||||
|
pub mod connected;
|
||||||
|
pub mod init;
|
||||||
|
pub mod listen;
|
||||||
|
|
||||||
|
pub mod socket;
|
||||||
|
pub use socket::VsockStreamSocket;
|
291
kernel/aster-nix/src/net/socket/vsock/stream/socket.rs
Normal file
291
kernel/aster-nix/src/net/socket/vsock/stream/socket.rs
Normal file
@ -0,0 +1,291 @@
|
|||||||
|
// SPDX-License-Identifier: MPL-2.0
|
||||||
|
|
||||||
|
use super::{connected::Connected, init::Init, listen::Listen};
|
||||||
|
use crate::{
|
||||||
|
events::IoEvents,
|
||||||
|
fs::file_handle::FileLike,
|
||||||
|
net::socket::{
|
||||||
|
vsock::{addr::VsockSocketAddr, VSOCK_GLOBAL},
|
||||||
|
SendRecvFlags, SockShutdownCmd, Socket, SocketAddr,
|
||||||
|
},
|
||||||
|
prelude::*,
|
||||||
|
process::signal::Poller,
|
||||||
|
};
|
||||||
|
|
||||||
|
pub struct VsockStreamSocket(RwLock<Status>);
|
||||||
|
|
||||||
|
impl VsockStreamSocket {
|
||||||
|
pub(super) fn new_from_init(init: Arc<Init>) -> Self {
|
||||||
|
Self(RwLock::new(Status::Init(init)))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(super) fn new_from_listen(listen: Arc<Listen>) -> Self {
|
||||||
|
Self(RwLock::new(Status::Listen(listen)))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(super) fn new_from_connected(connected: Arc<Connected>) -> Self {
|
||||||
|
Self(RwLock::new(Status::Connected(connected)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub enum Status {
|
||||||
|
Init(Arc<Init>),
|
||||||
|
Listen(Arc<Listen>),
|
||||||
|
Connected(Arc<Connected>),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl VsockStreamSocket {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
let init = Arc::new(Init::new());
|
||||||
|
Self(RwLock::new(Status::Init(init)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FileLike for VsockStreamSocket {
|
||||||
|
fn as_socket(self: Arc<Self>) -> Option<Arc<dyn Socket>> {
|
||||||
|
Some(self)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn read(&self, buf: &mut [u8]) -> Result<usize> {
|
||||||
|
self.recvfrom(buf, SendRecvFlags::empty())
|
||||||
|
.map(|(read_size, _)| read_size)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn write(&self, buf: &[u8]) -> Result<usize> {
|
||||||
|
self.sendto(buf, None, SendRecvFlags::empty())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
|
||||||
|
let inner = self.0.read();
|
||||||
|
match &*inner {
|
||||||
|
Status::Init(init) => init.poll(mask, poller),
|
||||||
|
Status::Listen(listen) => listen.poll(mask, poller),
|
||||||
|
Status::Connected(connect) => connect.poll(mask, poller),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Socket for VsockStreamSocket {
|
||||||
|
fn bind(&self, sockaddr: SocketAddr) -> Result<()> {
|
||||||
|
let addr = VsockSocketAddr::try_from(sockaddr)?;
|
||||||
|
let inner = self.0.read();
|
||||||
|
match &*inner {
|
||||||
|
Status::Init(init) => init.bind(addr),
|
||||||
|
Status::Listen(_) | Status::Connected(_) => {
|
||||||
|
return_errno_with_message!(
|
||||||
|
Errno::EINVAL,
|
||||||
|
"cannot bind a listening or connected socket"
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn connect(&self, sockaddr: SocketAddr) -> Result<()> {
|
||||||
|
let init = match &*self.0.read() {
|
||||||
|
Status::Init(init) => init.clone(),
|
||||||
|
Status::Listen(_) => {
|
||||||
|
return_errno_with_message!(Errno::EINVAL, "The socket is listened");
|
||||||
|
}
|
||||||
|
Status::Connected(_) => {
|
||||||
|
return_errno_with_message!(Errno::EINVAL, "The socket is connected");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let remote_addr = VsockSocketAddr::try_from(sockaddr)?;
|
||||||
|
let local_addr = init.bound_addr();
|
||||||
|
|
||||||
|
if let Some(addr) = local_addr {
|
||||||
|
if addr == remote_addr {
|
||||||
|
return_errno_with_message!(Errno::EINVAL, "try to connect to self is invalid");
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
init.bind(VsockSocketAddr::any_addr())?;
|
||||||
|
}
|
||||||
|
|
||||||
|
let connecting = Arc::new(Connected::new(remote_addr, init.bound_addr().unwrap()));
|
||||||
|
let vsockspace = VSOCK_GLOBAL.get().unwrap();
|
||||||
|
vsockspace
|
||||||
|
.connecting_sockets
|
||||||
|
.lock_irq_disabled()
|
||||||
|
.insert(connecting.local_addr(), connecting.clone());
|
||||||
|
|
||||||
|
// Send request
|
||||||
|
vsockspace
|
||||||
|
.driver
|
||||||
|
.lock_irq_disabled()
|
||||||
|
.request(&connecting.get_info())
|
||||||
|
.map_err(|e| Error::with_message(Errno::EAGAIN, "can not send connect packet"))?;
|
||||||
|
|
||||||
|
// wait for response from driver
|
||||||
|
// TODO: add timeout
|
||||||
|
let poller = Poller::new();
|
||||||
|
if !connecting
|
||||||
|
.poll(IoEvents::IN, Some(&poller))
|
||||||
|
.contains(IoEvents::IN)
|
||||||
|
{
|
||||||
|
poller.wait()?;
|
||||||
|
}
|
||||||
|
|
||||||
|
*self.0.write() = Status::Connected(connecting.clone());
|
||||||
|
// move connecting socket map to connected sockmap
|
||||||
|
vsockspace
|
||||||
|
.connecting_sockets
|
||||||
|
.lock_irq_disabled()
|
||||||
|
.remove(&connecting.local_addr())
|
||||||
|
.unwrap();
|
||||||
|
vsockspace
|
||||||
|
.connected_sockets
|
||||||
|
.lock_irq_disabled()
|
||||||
|
.insert(connecting.id(), connecting);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn listen(&self, backlog: usize) -> Result<()> {
|
||||||
|
let init = match &*self.0.read() {
|
||||||
|
Status::Init(init) => init.clone(),
|
||||||
|
Status::Listen(_) => {
|
||||||
|
return_errno_with_message!(Errno::EINVAL, "The socket is already listened");
|
||||||
|
}
|
||||||
|
Status::Connected(_) => {
|
||||||
|
return_errno_with_message!(Errno::EISCONN, "The socket is already connected");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let addr = init.bound_addr().ok_or(Error::with_message(
|
||||||
|
Errno::EINVAL,
|
||||||
|
"The socket is not bound",
|
||||||
|
))?;
|
||||||
|
let listen = Arc::new(Listen::new(addr, backlog));
|
||||||
|
*self.0.write() = Status::Listen(listen.clone());
|
||||||
|
|
||||||
|
// push listen socket into vsockspace
|
||||||
|
VSOCK_GLOBAL
|
||||||
|
.get()
|
||||||
|
.unwrap()
|
||||||
|
.listen_sockets
|
||||||
|
.lock_irq_disabled()
|
||||||
|
.insert(listen.addr(), listen);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn accept(&self) -> Result<(Arc<dyn FileLike>, SocketAddr)> {
|
||||||
|
let listen = match &*self.0.read() {
|
||||||
|
Status::Listen(listen) => listen.clone(),
|
||||||
|
Status::Init(_) | Status::Connected(_) => {
|
||||||
|
return_errno_with_message!(Errno::EINVAL, "The socket is not listening");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let connected = listen.accept()?;
|
||||||
|
let peer_addr = connected.peer_addr();
|
||||||
|
|
||||||
|
VSOCK_GLOBAL
|
||||||
|
.get()
|
||||||
|
.unwrap()
|
||||||
|
.connected_sockets
|
||||||
|
.lock_irq_disabled()
|
||||||
|
.insert(connected.id(), connected.clone());
|
||||||
|
|
||||||
|
VSOCK_GLOBAL
|
||||||
|
.get()
|
||||||
|
.unwrap()
|
||||||
|
.driver
|
||||||
|
.lock_irq_disabled()
|
||||||
|
.response(&connected.get_info())
|
||||||
|
.map_err(|e| Error::with_message(Errno::EAGAIN, "can not send response packet"))?;
|
||||||
|
|
||||||
|
let socket = Arc::new(VsockStreamSocket::new_from_connected(connected));
|
||||||
|
Ok((socket, peer_addr.into()))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn shutdown(&self, cmd: SockShutdownCmd) -> Result<()> {
|
||||||
|
let inner = self.0.read();
|
||||||
|
if let Status::Connected(connected) = &*inner {
|
||||||
|
let result = connected.shutdown(cmd);
|
||||||
|
if result.is_ok() {
|
||||||
|
let vsockspace = VSOCK_GLOBAL.get().unwrap();
|
||||||
|
vsockspace
|
||||||
|
.used_ports
|
||||||
|
.lock_irq_disabled()
|
||||||
|
.remove(&connected.local_addr().port);
|
||||||
|
vsockspace
|
||||||
|
.connected_sockets
|
||||||
|
.lock_irq_disabled()
|
||||||
|
.remove(&connected.id());
|
||||||
|
}
|
||||||
|
result
|
||||||
|
} else {
|
||||||
|
return_errno_with_message!(Errno::EINVAL, "The socket is not connected.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, SocketAddr)> {
|
||||||
|
let connected = match &*self.0.read() {
|
||||||
|
Status::Connected(connected) => connected.clone(),
|
||||||
|
Status::Init(_) | Status::Listen(_) => {
|
||||||
|
return_errno_with_message!(Errno::EINVAL, "the socket is not connected");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let read_size = connected.recv(buf)?;
|
||||||
|
let peer_addr = self.peer_addr()?;
|
||||||
|
// If buffer is now empty and the peer requested shutdown, finish shutting down the
|
||||||
|
// connection.
|
||||||
|
if connected.should_close() {
|
||||||
|
VSOCK_GLOBAL
|
||||||
|
.get()
|
||||||
|
.unwrap()
|
||||||
|
.driver
|
||||||
|
.lock_irq_disabled()
|
||||||
|
.reset(&connected.get_info())
|
||||||
|
.map_err(|e| Error::with_message(Errno::EAGAIN, "can not send close packet"))?;
|
||||||
|
}
|
||||||
|
Ok((read_size, peer_addr))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn sendto(
|
||||||
|
&self,
|
||||||
|
buf: &[u8],
|
||||||
|
remote: Option<SocketAddr>,
|
||||||
|
flags: SendRecvFlags,
|
||||||
|
) -> Result<usize> {
|
||||||
|
debug_assert!(remote.is_none());
|
||||||
|
if remote.is_some() {
|
||||||
|
return_errno_with_message!(Errno::EINVAL, "vsock should not provide remote addr");
|
||||||
|
}
|
||||||
|
let inner = self.0.read();
|
||||||
|
match &*inner {
|
||||||
|
Status::Connected(connected) => connected.send(buf, flags),
|
||||||
|
Status::Init(_) | Status::Listen(_) => {
|
||||||
|
return_errno_with_message!(Errno::EINVAL, "The socket is not connected");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn addr(&self) -> Result<SocketAddr> {
|
||||||
|
let inner = self.0.read();
|
||||||
|
let addr = match &*inner {
|
||||||
|
Status::Init(init) => init.bound_addr(),
|
||||||
|
Status::Listen(listen) => Some(listen.addr()),
|
||||||
|
Status::Connected(connected) => Some(connected.local_addr()),
|
||||||
|
};
|
||||||
|
addr.map(Into::<SocketAddr>::into)
|
||||||
|
.ok_or(Error::with_message(
|
||||||
|
Errno::EINVAL,
|
||||||
|
"The socket does not bind to addr",
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn peer_addr(&self) -> Result<SocketAddr> {
|
||||||
|
let inner = self.0.read();
|
||||||
|
if let Status::Connected(connected) = &*inner {
|
||||||
|
Ok(connected.peer_addr().into())
|
||||||
|
} else {
|
||||||
|
return_errno_with_message!(Errno::EINVAL, "the socket is not connected");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for VsockStreamSocket {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
@ -6,6 +6,7 @@ use crate::{
|
|||||||
net::socket::{
|
net::socket::{
|
||||||
ip::{DatagramSocket, StreamSocket},
|
ip::{DatagramSocket, StreamSocket},
|
||||||
unix::UnixStreamSocket,
|
unix::UnixStreamSocket,
|
||||||
|
vsock::VsockStreamSocket,
|
||||||
},
|
},
|
||||||
prelude::*,
|
prelude::*,
|
||||||
util::net::{CSocketAddrFamily, Protocol, SockFlags, SockType, SOCK_TYPE_MASK},
|
util::net::{CSocketAddrFamily, Protocol, SockFlags, SockType, SOCK_TYPE_MASK},
|
||||||
@ -35,6 +36,9 @@ pub fn sys_socket(domain: i32, type_: i32, protocol: i32) -> Result<SyscallRetur
|
|||||||
SockType::SOCK_DGRAM,
|
SockType::SOCK_DGRAM,
|
||||||
Protocol::IPPROTO_IP | Protocol::IPPROTO_UDP,
|
Protocol::IPPROTO_IP | Protocol::IPPROTO_UDP,
|
||||||
) => DatagramSocket::new(nonblocking) as Arc<dyn FileLike>,
|
) => DatagramSocket::new(nonblocking) as Arc<dyn FileLike>,
|
||||||
|
(CSocketAddrFamily::AF_VSOCK, SockType::SOCK_STREAM, _) => {
|
||||||
|
Arc::new(VsockStreamSocket::new()) as Arc<dyn FileLike>
|
||||||
|
}
|
||||||
_ => return_errno_with_message!(Errno::EAFNOSUPPORT, "unsupported domain"),
|
_ => return_errno_with_message!(Errno::EAFNOSUPPORT, "unsupported domain"),
|
||||||
};
|
};
|
||||||
let fd = {
|
let fd = {
|
||||||
|
@ -55,6 +55,11 @@ pub fn read_socket_addr_from_user(addr: Vaddr, addr_len: usize) -> Result<Socket
|
|||||||
let sock_addr_in6: CSocketAddrInet6 = read_val_from_user(addr)?;
|
let sock_addr_in6: CSocketAddrInet6 = read_val_from_user(addr)?;
|
||||||
todo!()
|
todo!()
|
||||||
}
|
}
|
||||||
|
CSocketAddrFamily::AF_VSOCK => {
|
||||||
|
debug_assert!(addr_len >= core::mem::size_of::<CSocketAddrVm>());
|
||||||
|
let sock_addr_vm: CSocketAddrVm = read_val_from_user(addr)?;
|
||||||
|
SocketAddr::Vsock(sock_addr_vm.svm_cid, sock_addr_vm.svm_port)
|
||||||
|
}
|
||||||
_ => {
|
_ => {
|
||||||
return_errno_with_message!(Errno::EAFNOSUPPORT, "cannot support address for the family")
|
return_errno_with_message!(Errno::EAFNOSUPPORT, "cannot support address for the family")
|
||||||
}
|
}
|
||||||
@ -89,6 +94,12 @@ pub fn write_socket_addr_to_user(
|
|||||||
write_size as i32
|
write_size as i32
|
||||||
}
|
}
|
||||||
SocketAddr::IPv6 => todo!(),
|
SocketAddr::IPv6 => todo!(),
|
||||||
|
SocketAddr::Vsock(cid, port) => {
|
||||||
|
let vm_addr = CSocketAddrVm::new(*cid, *port);
|
||||||
|
let write_size = core::mem::size_of::<CSocketAddrVm>();
|
||||||
|
write_val_to_user(dest, &vm_addr)?;
|
||||||
|
write_size as i32
|
||||||
|
}
|
||||||
};
|
};
|
||||||
if addrlen_ptr != 0 {
|
if addrlen_ptr != 0 {
|
||||||
write_val_to_user(addrlen_ptr, &write_size)?;
|
write_val_to_user(addrlen_ptr, &write_size)?;
|
||||||
@ -210,6 +221,34 @@ pub struct CSocketAddrInet6 {
|
|||||||
sin6_scope_id: u32,
|
sin6_scope_id: u32,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// vm socket address
|
||||||
|
#[repr(C)]
|
||||||
|
#[derive(Debug, Clone, Copy, Pod)]
|
||||||
|
pub struct CSocketAddrVm {
|
||||||
|
/// always [SaFamily::AF_VSOCK]
|
||||||
|
svm_family: u16,
|
||||||
|
/// always 0
|
||||||
|
svm_reserved1: u16,
|
||||||
|
/// Port number in host byte order.
|
||||||
|
svm_port: u32,
|
||||||
|
/// Address in host byte order.
|
||||||
|
svm_cid: u32,
|
||||||
|
/// Pad to size of [SockAddr] structure (16 bytes), must be zero-filled
|
||||||
|
svm_zero: [u8; 4],
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CSocketAddrVm {
|
||||||
|
pub fn new(cid: u32, port: u32) -> Self {
|
||||||
|
Self {
|
||||||
|
svm_family: CSocketAddrFamily::AF_VSOCK as _,
|
||||||
|
svm_reserved1: 0,
|
||||||
|
svm_port: port,
|
||||||
|
svm_cid: cid,
|
||||||
|
svm_zero: [0u8; 4],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Address family. The definition is from https://elixir.bootlin.com/linux/v6.0.9/source/include/linux/socket.h.
|
/// Address family. The definition is from https://elixir.bootlin.com/linux/v6.0.9/source/include/linux/socket.h.
|
||||||
#[repr(i32)]
|
#[repr(i32)]
|
||||||
#[derive(Debug, Clone, Copy, TryFromInt, PartialEq, Eq)]
|
#[derive(Debug, Clone, Copy, TryFromInt, PartialEq, Eq)]
|
||||||
|
@ -2,10 +2,6 @@
|
|||||||
|
|
||||||
use align_ext::AlignExt;
|
use align_ext::AlignExt;
|
||||||
use bytes::BytesMut;
|
use bytes::BytesMut;
|
||||||
use pod::Pod;
|
|
||||||
|
|
||||||
use super::header::VirtioVsockHdr;
|
|
||||||
use crate::device::socket::header::VIRTIO_VSOCK_HDR_LEN;
|
|
||||||
|
|
||||||
/// Buffer for receive packet
|
/// Buffer for receive packet
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
@ -38,22 +34,6 @@ impl RxBuffer {
|
|||||||
pub fn buf_mut(&mut self) -> &mut [u8] {
|
pub fn buf_mut(&mut self) -> &mut [u8] {
|
||||||
&mut self.buf
|
&mut self.buf
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Packet payload slice, which is inner buffer excluding VirtioVsockHdr.
|
|
||||||
pub fn packet(&self) -> &[u8] {
|
|
||||||
debug_assert!(VIRTIO_VSOCK_HDR_LEN + self.packet_len <= self.buf.len());
|
|
||||||
&self.buf[VIRTIO_VSOCK_HDR_LEN..VIRTIO_VSOCK_HDR_LEN + self.packet_len]
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Mutable packet payload slice.
|
|
||||||
pub fn packet_mut(&mut self) -> &mut [u8] {
|
|
||||||
debug_assert!(VIRTIO_VSOCK_HDR_LEN + self.packet_len <= self.buf.len());
|
|
||||||
&mut self.buf[VIRTIO_VSOCK_HDR_LEN..VIRTIO_VSOCK_HDR_LEN + self.packet_len]
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn virtio_vsock_header(&self) -> VirtioVsockHdr {
|
|
||||||
VirtioVsockHdr::from_bytes(&self.buf[..VIRTIO_VSOCK_HDR_LEN])
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Buffer for transmit packet
|
/// Buffer for transmit packet
|
||||||
|
@ -8,17 +8,14 @@ use pod::Pod;
|
|||||||
use crate::transport::VirtioTransport;
|
use crate::transport::VirtioTransport;
|
||||||
|
|
||||||
bitflags! {
|
bitflags! {
|
||||||
/// Vsock feature bits since v1.2
|
|
||||||
/// If no feature bit is set, only stream socket type is supported.
|
|
||||||
/// If VIRTIO_VSOCK_F_SEQPACKET has been negotiated, the device MAY act as if VIRTIO_VSOCK_F_STREAM has also been negotiated.
|
|
||||||
pub struct VsockFeatures: u64 {
|
pub struct VsockFeatures: u64 {
|
||||||
const VIRTIO_VSOCK_F_STREAM = 1 << 0; // stream socket type is supported.
|
const VIRTIO_VSOCK_F_STREAM = 1 << 0; // stream socket type is supported.
|
||||||
const VIRTIO_VSOCK_F_SEQPACKET = 1 << 1; //seqpacket socket type is supported.
|
const VIRTIO_VSOCK_F_SEQPACKET = 1 << 1; //seqpacket socket type is not supported now.
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl VsockFeatures {
|
impl VsockFeatures {
|
||||||
pub fn support_features() -> Self {
|
pub const fn supported_features() -> Self {
|
||||||
VsockFeatures::VIRTIO_VSOCK_F_STREAM
|
VsockFeatures::VIRTIO_VSOCK_F_STREAM
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -32,9 +29,7 @@ pub struct VirtioVsockConfig {
|
|||||||
/// According to virtio spec v1.1 2.4.1 Driver Requirements: Device Configuration Space,
|
/// According to virtio spec v1.1 2.4.1 Driver Requirements: Device Configuration Space,
|
||||||
/// drivers MUST NOT assume reads from fields greater than 32 bits wide are atomic.
|
/// drivers MUST NOT assume reads from fields greater than 32 bits wide are atomic.
|
||||||
/// So we need to split the u64 guest_cid into two parts.
|
/// So we need to split the u64 guest_cid into two parts.
|
||||||
// read only
|
|
||||||
pub guest_cid_low: u32,
|
pub guest_cid_low: u32,
|
||||||
// read only
|
|
||||||
pub guest_cid_high: u32,
|
pub guest_cid_high: u32,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,5 +1,30 @@
|
|||||||
// SPDX-License-Identifier: MPL-2.0
|
// SPDX-License-Identifier: MPL-2.0
|
||||||
|
|
||||||
|
// Modified from vsock.rs in virtio-drivers project
|
||||||
|
//
|
||||||
|
// MIT License
|
||||||
|
//
|
||||||
|
// Copyright (c) 2022-2023 Ant Group
|
||||||
|
// Copyright (c) 2019-2020 rCore Developers
|
||||||
|
//
|
||||||
|
// Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
// of this software and associated documentation files (the "Software"), to deal
|
||||||
|
// in the Software without restriction, including without limitation the rights
|
||||||
|
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
// copies of the Software, and to permit persons to whom the Software is
|
||||||
|
// furnished to do so, subject to the following conditions:
|
||||||
|
//
|
||||||
|
// The above copyright notice and this permission notice shall be included in all
|
||||||
|
// copies or substantial portions of the Software.
|
||||||
|
//
|
||||||
|
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
// SOFTWARE.
|
||||||
|
//
|
||||||
use log::debug;
|
use log::debug;
|
||||||
|
|
||||||
use super::{
|
use super::{
|
||||||
@ -7,7 +32,7 @@ use super::{
|
|||||||
header::{VirtioVsockHdr, VirtioVsockOp, VsockAddr},
|
header::{VirtioVsockHdr, VirtioVsockOp, VsockAddr},
|
||||||
};
|
};
|
||||||
|
|
||||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
|
||||||
pub struct VsockBufferStatus {
|
pub struct VsockBufferStatus {
|
||||||
pub buffer_allocation: u32,
|
pub buffer_allocation: u32,
|
||||||
pub forward_count: u32,
|
pub forward_count: u32,
|
||||||
@ -24,7 +49,7 @@ pub enum DisconnectReason {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Details of the type of an event received from a VirtIO socket.
|
/// Details of the type of an event received from a VirtIO socket.
|
||||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
|
||||||
pub enum VsockEventType {
|
pub enum VsockEventType {
|
||||||
/// The peer requests to establish a connection with us.
|
/// The peer requests to establish a connection with us.
|
||||||
ConnectionRequest,
|
ConnectionRequest,
|
||||||
@ -47,7 +72,7 @@ pub enum VsockEventType {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// An event received from a VirtIO socket device.
|
/// An event received from a VirtIO socket device.
|
||||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
|
||||||
pub struct VsockEvent {
|
pub struct VsockEvent {
|
||||||
/// The source of the event, i.e. the peer who sent it.
|
/// The source of the event, i.e. the peer who sent it.
|
||||||
pub source: VsockAddr,
|
pub source: VsockAddr,
|
||||||
|
@ -17,7 +17,10 @@ use super::{
|
|||||||
VsockDeviceIrqHandler,
|
VsockDeviceIrqHandler,
|
||||||
};
|
};
|
||||||
use crate::{
|
use crate::{
|
||||||
device::{socket::register_device, VirtioDeviceError},
|
device::{
|
||||||
|
socket::{handle_recv_irq, register_device},
|
||||||
|
VirtioDeviceError,
|
||||||
|
},
|
||||||
queue::{QueueError, VirtQueue},
|
queue::{QueueError, VirtQueue},
|
||||||
transport::VirtioTransport,
|
transport::VirtioTransport,
|
||||||
};
|
};
|
||||||
@ -27,10 +30,10 @@ const QUEUE_RECV: u16 = 0;
|
|||||||
const QUEUE_SEND: u16 = 1;
|
const QUEUE_SEND: u16 = 1;
|
||||||
const QUEUE_EVENT: u16 = 2;
|
const QUEUE_EVENT: u16 = 2;
|
||||||
|
|
||||||
/// The size in bytes of each buffer used in the RX virtqueue. This must be bigger than size_of::<VirtioVsockHdr>().
|
/// The size in bytes of each buffer used in the RX virtqueue. This must be bigger than `size_of::<VirtioVsockHdr>()`.
|
||||||
const RX_BUFFER_SIZE: usize = 512;
|
const RX_BUFFER_SIZE: usize = 512;
|
||||||
|
|
||||||
/// Low-level driver for a Virtio socket device.
|
/// Vsock device driver
|
||||||
pub struct SocketDevice {
|
pub struct SocketDevice {
|
||||||
config: VirtioVsockConfig,
|
config: VirtioVsockConfig,
|
||||||
guest_cid: u64,
|
guest_cid: u64,
|
||||||
@ -46,6 +49,7 @@ pub struct SocketDevice {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl SocketDevice {
|
impl SocketDevice {
|
||||||
|
/// Create a new vsock device
|
||||||
pub fn init(mut transport: Box<dyn VirtioTransport>) -> Result<(), VirtioDeviceError> {
|
pub fn init(mut transport: Box<dyn VirtioTransport>) -> Result<(), VirtioDeviceError> {
|
||||||
let virtio_vsock_config = VirtioVsockConfig::new(transport.as_mut());
|
let virtio_vsock_config = VirtioVsockConfig::new(transport.as_mut());
|
||||||
debug!("virtio_vsock_config = {:?}", virtio_vsock_config);
|
debug!("virtio_vsock_config = {:?}", virtio_vsock_config);
|
||||||
@ -95,9 +99,8 @@ impl SocketDevice {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Interrupt handler if vsock device receives some packet.
|
// Interrupt handler if vsock device receives some packet.
|
||||||
// TODO: This will be handled by vsock socket layer.
|
|
||||||
fn handle_vsock_event(_: &TrapFrame) {
|
fn handle_vsock_event(_: &TrapFrame) {
|
||||||
debug!("Packet received. This will be solved by socket layer");
|
handle_recv_irq(super::DEVICE_NAME);
|
||||||
}
|
}
|
||||||
|
|
||||||
device
|
device
|
||||||
@ -119,28 +122,23 @@ impl SocketDevice {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the CID which has been assigned to this guest.
|
/// Return the CID which has been assigned to this guest.
|
||||||
pub fn guest_cid(&self) -> u64 {
|
pub fn guest_cid(&self) -> u64 {
|
||||||
self.guest_cid
|
self.guest_cid
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Sends a request to connect to the given destination.
|
/// Send a connection request
|
||||||
///
|
pub fn request(&mut self, connection_info: &ConnectionInfo) -> Result<(), SocketError> {
|
||||||
/// This returns as soon as the request is sent; you should wait until `poll` returns a
|
|
||||||
/// [`VsockEventType::Connected`] event indicating that the peer has accepted the connection
|
|
||||||
/// before sending data.
|
|
||||||
pub fn connect(&mut self, connection_info: &ConnectionInfo) -> Result<(), SocketError> {
|
|
||||||
let header = VirtioVsockHdr {
|
let header = VirtioVsockHdr {
|
||||||
op: VirtioVsockOp::Request as u16,
|
op: VirtioVsockOp::Request as u16,
|
||||||
..connection_info.new_header(self.guest_cid)
|
..connection_info.new_header(self.guest_cid)
|
||||||
};
|
};
|
||||||
// Sends a header only packet to the TX queue to connect the device to the listening socket
|
|
||||||
// at the given destination.
|
|
||||||
self.send_packet_to_tx_queue(&header, &[])
|
self.send_packet_to_tx_queue(&header, &[])
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Accepts the given connection from a peer.
|
/// Send a response to peer, if peer start a sending request
|
||||||
pub fn accept(&mut self, connection_info: &ConnectionInfo) -> Result<(), SocketError> {
|
pub fn response(&mut self, connection_info: &ConnectionInfo) -> Result<(), SocketError> {
|
||||||
let header = VirtioVsockHdr {
|
let header = VirtioVsockHdr {
|
||||||
op: VirtioVsockOp::Response as u16,
|
op: VirtioVsockOp::Response as u16,
|
||||||
..connection_info.new_header(self.guest_cid)
|
..connection_info.new_header(self.guest_cid)
|
||||||
@ -148,29 +146,7 @@ impl SocketDevice {
|
|||||||
self.send_packet_to_tx_queue(&header, &[])
|
self.send_packet_to_tx_queue(&header, &[])
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Requests the peer to send us a credit update for the given connection.
|
/// Send a shutdown request
|
||||||
fn request_credit(&mut self, connection_info: &ConnectionInfo) -> Result<(), SocketError> {
|
|
||||||
let header = VirtioVsockHdr {
|
|
||||||
op: VirtioVsockOp::CreditRequest as u16,
|
|
||||||
..connection_info.new_header(self.guest_cid)
|
|
||||||
};
|
|
||||||
self.send_packet_to_tx_queue(&header, &[])
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Tells the peer how much buffer space we have to receive data.
|
|
||||||
pub fn credit_update(&mut self, connection_info: &ConnectionInfo) -> Result<(), SocketError> {
|
|
||||||
let header = VirtioVsockHdr {
|
|
||||||
op: VirtioVsockOp::CreditUpdate as u16,
|
|
||||||
..connection_info.new_header(self.guest_cid)
|
|
||||||
};
|
|
||||||
self.send_packet_to_tx_queue(&header, &[])
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Requests to shut down the connection cleanly.
|
|
||||||
///
|
|
||||||
/// This returns as soon as the request is sent; you should wait until `poll` returns a
|
|
||||||
/// `VsockEventType::Disconnected` event if you want to know that the peer has acknowledged the
|
|
||||||
/// shutdown.
|
|
||||||
pub fn shutdown(&mut self, connection_info: &ConnectionInfo) -> Result<(), SocketError> {
|
pub fn shutdown(&mut self, connection_info: &ConnectionInfo) -> Result<(), SocketError> {
|
||||||
let header = VirtioVsockHdr {
|
let header = VirtioVsockHdr {
|
||||||
op: VirtioVsockOp::Shutdown as u16,
|
op: VirtioVsockOp::Shutdown as u16,
|
||||||
@ -179,8 +155,8 @@ impl SocketDevice {
|
|||||||
self.send_packet_to_tx_queue(&header, &[])
|
self.send_packet_to_tx_queue(&header, &[])
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Forcibly closes the connection without waiting for the peer.
|
/// Send a reset request to peer
|
||||||
pub fn force_close(&mut self, connection_info: &ConnectionInfo) -> Result<(), SocketError> {
|
pub fn reset(&mut self, connection_info: &ConnectionInfo) -> Result<(), SocketError> {
|
||||||
let header = VirtioVsockHdr {
|
let header = VirtioVsockHdr {
|
||||||
op: VirtioVsockOp::Rst as u16,
|
op: VirtioVsockOp::Rst as u16,
|
||||||
..connection_info.new_header(self.guest_cid)
|
..connection_info.new_header(self.guest_cid)
|
||||||
@ -188,16 +164,29 @@ impl SocketDevice {
|
|||||||
self.send_packet_to_tx_queue(&header, &[])
|
self.send_packet_to_tx_queue(&header, &[])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Request the peer to send the credit info to us
|
||||||
|
pub fn credit_request(&mut self, connection_info: &ConnectionInfo) -> Result<(), SocketError> {
|
||||||
|
let header = VirtioVsockHdr {
|
||||||
|
op: VirtioVsockOp::CreditRequest as u16,
|
||||||
|
..connection_info.new_header(self.guest_cid)
|
||||||
|
};
|
||||||
|
self.send_packet_to_tx_queue(&header, &[])
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Tell the peer our credit info
|
||||||
|
pub fn credit_update(&mut self, connection_info: &ConnectionInfo) -> Result<(), SocketError> {
|
||||||
|
let header = VirtioVsockHdr {
|
||||||
|
op: VirtioVsockOp::CreditUpdate as u16,
|
||||||
|
..connection_info.new_header(self.guest_cid)
|
||||||
|
};
|
||||||
|
self.send_packet_to_tx_queue(&header, &[])
|
||||||
|
}
|
||||||
|
|
||||||
fn send_packet_to_tx_queue(
|
fn send_packet_to_tx_queue(
|
||||||
&mut self,
|
&mut self,
|
||||||
header: &VirtioVsockHdr,
|
header: &VirtioVsockHdr,
|
||||||
buffer: &[u8],
|
buffer: &[u8],
|
||||||
) -> Result<(), SocketError> {
|
) -> Result<(), SocketError> {
|
||||||
// let (_token, _len) = self.send_queue.add_notify_wait_pop(
|
|
||||||
// &[header.as_bytes(), buffer],
|
|
||||||
// &mut [],
|
|
||||||
// )?;
|
|
||||||
|
|
||||||
let _token = self.send_queue.add_buf(&[header.as_bytes(), buffer], &[])?;
|
let _token = self.send_queue.add_buf(&[header.as_bytes(), buffer], &[])?;
|
||||||
|
|
||||||
if self.send_queue.should_notify() {
|
if self.send_queue.should_notify() {
|
||||||
@ -211,8 +200,7 @@ impl SocketDevice {
|
|||||||
|
|
||||||
self.send_queue.pop_used()?;
|
self.send_queue.pop_used()?;
|
||||||
|
|
||||||
// FORDEBUG
|
debug!("buffer in send_packet_to_tx_queue: {:?}", buffer);
|
||||||
// debug!("buffer in send_packet_to_tx_queue: {:?}",buffer);
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -221,13 +209,19 @@ impl SocketDevice {
|
|||||||
connection_info: &mut ConnectionInfo,
|
connection_info: &mut ConnectionInfo,
|
||||||
buffer_len: usize,
|
buffer_len: usize,
|
||||||
) -> Result<(), SocketError> {
|
) -> Result<(), SocketError> {
|
||||||
|
debug!("connectin info {:?}", connection_info);
|
||||||
|
debug!(
|
||||||
|
"peer free from peer: {:?}, buffer len : {:?}",
|
||||||
|
connection_info.peer_free(),
|
||||||
|
buffer_len
|
||||||
|
);
|
||||||
if connection_info.peer_free() as usize >= buffer_len {
|
if connection_info.peer_free() as usize >= buffer_len {
|
||||||
Ok(())
|
Ok(())
|
||||||
} else {
|
} else {
|
||||||
// Request an update of the cached peer credit, if we haven't already done so, and tell
|
// Request an update of the cached peer credit, if we haven't already done so, and tell
|
||||||
// the caller to try again later.
|
// the caller to try again later.
|
||||||
if !connection_info.has_pending_credit_request {
|
if !connection_info.has_pending_credit_request {
|
||||||
self.request_credit(connection_info)?;
|
self.credit_request(connection_info)?;
|
||||||
connection_info.has_pending_credit_request = true;
|
connection_info.has_pending_credit_request = true;
|
||||||
}
|
}
|
||||||
Err(SocketError::InsufficientBufferSpaceInPeer)
|
Err(SocketError::InsufficientBufferSpaceInPeer)
|
||||||
@ -252,6 +246,36 @@ impl SocketDevice {
|
|||||||
self.send_packet_to_tx_queue(&header, buffer)
|
self.send_packet_to_tx_queue(&header, buffer)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Receive bytes from peer, returns the header
|
||||||
|
pub fn receive(
|
||||||
|
&mut self,
|
||||||
|
buffer: &mut [u8],
|
||||||
|
// connection_info: &mut ConnectionInfo,
|
||||||
|
) -> Result<VirtioVsockHdr, SocketError> {
|
||||||
|
let (token, len) = self.recv_queue.pop_used()?;
|
||||||
|
debug!(
|
||||||
|
"receive packet in rx_queue: token = {}, len = {}",
|
||||||
|
token, len
|
||||||
|
);
|
||||||
|
let mut rx_buffer = self
|
||||||
|
.rx_buffers
|
||||||
|
.remove(token as usize)
|
||||||
|
.ok_or(QueueError::WrongToken)?;
|
||||||
|
rx_buffer.set_packet_len(RX_BUFFER_SIZE);
|
||||||
|
|
||||||
|
let (header, payload) = read_header_and_body(rx_buffer.buf())?;
|
||||||
|
// The length written should be equal to len(header)+len(packet)
|
||||||
|
assert_eq!(len, header.len() + VIRTIO_VSOCK_HDR_LEN as u32);
|
||||||
|
debug!("Received packet {:?}. Op {:?}", header, header.op());
|
||||||
|
debug!("body is {:?}", payload);
|
||||||
|
|
||||||
|
assert!(buffer.len() >= payload.len());
|
||||||
|
buffer[..payload.len()].copy_from_slice(payload);
|
||||||
|
|
||||||
|
self.add_rx_buffer(rx_buffer, token)?;
|
||||||
|
Ok(header)
|
||||||
|
}
|
||||||
|
|
||||||
/// Polls the RX virtqueue for the next event, and calls the given handler function to handle it.
|
/// Polls the RX virtqueue for the next event, and calls the given handler function to handle it.
|
||||||
pub fn poll(
|
pub fn poll(
|
||||||
&mut self,
|
&mut self,
|
||||||
@ -261,39 +285,10 @@ impl SocketDevice {
|
|||||||
if !self.recv_queue.can_pop() {
|
if !self.recv_queue.can_pop() {
|
||||||
return Ok(None);
|
return Ok(None);
|
||||||
}
|
}
|
||||||
let (token, len) = self.recv_queue.pop_used()?;
|
let mut body = RxBuffer::new(RX_BUFFER_SIZE);
|
||||||
|
let header = self.receive(body.buf_mut())?;
|
||||||
|
|
||||||
let mut buffer = self
|
VsockEvent::from_header(&header).and_then(|event| handler(event, body.buf()))
|
||||||
.rx_buffers
|
|
||||||
.remove(token as usize)
|
|
||||||
.ok_or(QueueError::WrongToken)?;
|
|
||||||
|
|
||||||
let header = buffer.virtio_vsock_header();
|
|
||||||
// The length written should be equal to len(header)+len(packet)
|
|
||||||
assert_eq!(len, header.len() + VIRTIO_VSOCK_HDR_LEN as u32);
|
|
||||||
|
|
||||||
buffer.set_packet_len(RX_BUFFER_SIZE);
|
|
||||||
|
|
||||||
let head_result = read_header_and_body(buffer.buf());
|
|
||||||
|
|
||||||
let Ok((header, body)) = head_result else {
|
|
||||||
let ret = match head_result {
|
|
||||||
Err(e) => Err(e),
|
|
||||||
_ => Ok(None), //FIXME: this clause is never reached.
|
|
||||||
};
|
|
||||||
self.add_rx_buffer(buffer, token)?;
|
|
||||||
return ret;
|
|
||||||
};
|
|
||||||
|
|
||||||
debug!("Received packet {:?}. Op {:?}", header, header.op());
|
|
||||||
debug!("body is {:?}", body);
|
|
||||||
|
|
||||||
let result = VsockEvent::from_header(&header).and_then(|event| handler(event, body));
|
|
||||||
|
|
||||||
// reuse the buffer and give it back to recv_queue.
|
|
||||||
self.add_rx_buffer(buffer, token)?;
|
|
||||||
|
|
||||||
result
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Add a used rx buffer to recv queue,@index is only to check the correctness
|
/// Add a used rx buffer to recv queue,@index is only to check the correctness
|
||||||
@ -310,7 +305,7 @@ impl SocketDevice {
|
|||||||
/// Negotiate features for the device specified bits 0~23
|
/// Negotiate features for the device specified bits 0~23
|
||||||
pub(crate) fn negotiate_features(features: u64) -> u64 {
|
pub(crate) fn negotiate_features(features: u64) -> u64 {
|
||||||
let device_features = VsockFeatures::from_bits_truncate(features);
|
let device_features = VsockFeatures::from_bits_truncate(features);
|
||||||
let supported_features = VsockFeatures::support_features();
|
let supported_features = VsockFeatures::supported_features();
|
||||||
let vsock_features = device_features & supported_features;
|
let vsock_features = device_features & supported_features;
|
||||||
debug!("features negotiated: {:?}", vsock_features);
|
debug!("features negotiated: {:?}", vsock_features);
|
||||||
vsock_features.bits()
|
vsock_features.bits()
|
||||||
|
@ -1,14 +1,36 @@
|
|||||||
// SPDX-License-Identifier: MPL-2.0
|
// SPDX-License-Identifier: MPL-2.0
|
||||||
|
|
||||||
//! This file comes from virtio-drivers project
|
// Modified from error.rs in virtio-drivers project
|
||||||
//! This module contains the error from the VirtIO socket driver.
|
//
|
||||||
|
// MIT License
|
||||||
|
//
|
||||||
|
// Copyright (c) 2022-2023 Ant Group
|
||||||
|
// Copyright (c) 2019-2020 rCore Developers
|
||||||
|
//
|
||||||
|
// Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
// of this software and associated documentation files (the "Software"), to deal
|
||||||
|
// in the Software without restriction, including without limitation the rights
|
||||||
|
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
// copies of the Software, and to permit persons to whom the Software is
|
||||||
|
// furnished to do so, subject to the following conditions:
|
||||||
|
//
|
||||||
|
// The above copyright notice and this permission notice shall be included in all
|
||||||
|
// copies or substantial portions of the Software.
|
||||||
|
//
|
||||||
|
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
// SOFTWARE.
|
||||||
|
//
|
||||||
use core::{fmt, result};
|
use core::{fmt, result};
|
||||||
|
|
||||||
use crate::queue::QueueError;
|
use crate::queue::QueueError;
|
||||||
|
|
||||||
/// The error type of VirtIO socket driver.
|
/// The error type of VirtIO socket driver.
|
||||||
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
|
#[derive(Debug)]
|
||||||
pub enum SocketError {
|
pub enum SocketError {
|
||||||
/// There is an existing connection.
|
/// There is an existing connection.
|
||||||
ConnectionExists,
|
ConnectionExists,
|
||||||
@ -39,33 +61,18 @@ pub enum SocketError {
|
|||||||
/// Recycled a wrong buffer.
|
/// Recycled a wrong buffer.
|
||||||
RecycledWrongBuffer,
|
RecycledWrongBuffer,
|
||||||
/// Queue Error
|
/// Queue Error
|
||||||
QueueError(SocketQueueError),
|
QueueError(QueueError),
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
|
|
||||||
pub enum SocketQueueError {
|
|
||||||
InvalidArgs,
|
|
||||||
BufferTooSmall,
|
|
||||||
NotReady,
|
|
||||||
AlreadyUsed,
|
|
||||||
WrongToken,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl From<QueueError> for SocketQueueError {
|
|
||||||
fn from(value: QueueError) -> Self {
|
|
||||||
match value {
|
|
||||||
QueueError::InvalidArgs => Self::InvalidArgs,
|
|
||||||
QueueError::BufferTooSmall => Self::BufferTooSmall,
|
|
||||||
QueueError::NotReady => Self::NotReady,
|
|
||||||
QueueError::AlreadyUsed => Self::AlreadyUsed,
|
|
||||||
QueueError::WrongToken => Self::WrongToken,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<QueueError> for SocketError {
|
impl From<QueueError> for SocketError {
|
||||||
fn from(value: QueueError) -> Self {
|
fn from(value: QueueError) -> Self {
|
||||||
Self::QueueError(SocketQueueError::from(value))
|
Self::QueueError(value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<int_to_c_enum::TryFromIntError> for SocketError {
|
||||||
|
fn from(_e: int_to_c_enum::TryFromIntError) -> Self {
|
||||||
|
Self::InvalidNumber
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,6 +1,32 @@
|
|||||||
// SPDX-License-Identifier: MPL-2.0
|
// SPDX-License-Identifier: MPL-2.0
|
||||||
|
|
||||||
|
// Modified from protocol.rs in virtio-drivers project
|
||||||
|
//
|
||||||
|
// MIT License
|
||||||
|
//
|
||||||
|
// Copyright (c) 2022-2023 Ant Group
|
||||||
|
// Copyright (c) 2019-2020 rCore Developers
|
||||||
|
//
|
||||||
|
// Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
// of this software and associated documentation files (the "Software"), to deal
|
||||||
|
// in the Software without restriction, including without limitation the rights
|
||||||
|
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
// copies of the Software, and to permit persons to whom the Software is
|
||||||
|
// furnished to do so, subject to the following conditions:
|
||||||
|
//
|
||||||
|
// The above copyright notice and this permission notice shall be included in all
|
||||||
|
// copies or substantial portions of the Software.
|
||||||
|
//
|
||||||
|
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
// SOFTWARE.
|
||||||
|
//
|
||||||
use bitflags::bitflags;
|
use bitflags::bitflags;
|
||||||
|
use int_to_c_enum::TryFromInt;
|
||||||
use pod::Pod;
|
use pod::Pod;
|
||||||
|
|
||||||
use super::error::{self, SocketError};
|
use super::error::{self, SocketError};
|
||||||
@ -64,7 +90,7 @@ impl VirtioVsockHdr {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn op(&self) -> error::Result<VirtioVsockOp> {
|
pub fn op(&self) -> error::Result<VirtioVsockOp> {
|
||||||
self.op.try_into()
|
VirtioVsockOp::try_from(self.op).map_err(|err| err.into())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn source(&self) -> VsockAddr {
|
pub fn source(&self) -> VsockAddr {
|
||||||
@ -90,7 +116,7 @@ impl VirtioVsockHdr {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
|
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, TryFromInt)]
|
||||||
#[repr(u16)]
|
#[repr(u16)]
|
||||||
#[allow(non_camel_case_types)]
|
#[allow(non_camel_case_types)]
|
||||||
pub enum VirtioVsockOp {
|
pub enum VirtioVsockOp {
|
||||||
@ -112,26 +138,6 @@ pub enum VirtioVsockOp {
|
|||||||
CreditRequest = 7,
|
CreditRequest = 7,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// TODO: This could be optimized by upgrading [int_to_c_enum::TryFromIntError] to carrying the invalid int number
|
|
||||||
impl TryFrom<u16> for VirtioVsockOp {
|
|
||||||
type Error = SocketError;
|
|
||||||
|
|
||||||
fn try_from(v: u16) -> Result<Self, Self::Error> {
|
|
||||||
let op = match v {
|
|
||||||
0 => Self::Invalid,
|
|
||||||
1 => Self::Request,
|
|
||||||
2 => Self::Response,
|
|
||||||
3 => Self::Rst,
|
|
||||||
4 => Self::Shutdown,
|
|
||||||
5 => Self::Rw,
|
|
||||||
6 => Self::CreditUpdate,
|
|
||||||
7 => Self::CreditRequest,
|
|
||||||
_ => return Err(SocketError::UnknownOperation(v)),
|
|
||||||
};
|
|
||||||
Ok(op)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
bitflags! {
|
bitflags! {
|
||||||
#[repr(C)]
|
#[repr(C)]
|
||||||
#[derive(Default, Pod)]
|
#[derive(Default, Pod)]
|
||||||
|
@ -1,359 +0,0 @@
|
|||||||
// SPDX-License-Identifier: MPL-2.0
|
|
||||||
|
|
||||||
use alloc::{boxed::Box, sync::Arc, vec, vec::Vec};
|
|
||||||
use core::{cmp::min, hint::spin_loop};
|
|
||||||
|
|
||||||
use aster_frame::sync::SpinLock;
|
|
||||||
use log::debug;
|
|
||||||
|
|
||||||
use super::{
|
|
||||||
connect::{ConnectionInfo, DisconnectReason, VsockEvent, VsockEventType},
|
|
||||||
device::SocketDevice,
|
|
||||||
header::VsockAddr,
|
|
||||||
};
|
|
||||||
use crate::device::socket::error::SocketError;
|
|
||||||
|
|
||||||
const PER_CONNECTION_BUFFER_CAPACITY: usize = 1024;
|
|
||||||
|
|
||||||
pub struct VsockConnectionManager {
|
|
||||||
driver: Arc<SpinLock<SocketDevice>>,
|
|
||||||
connections: Vec<Connection>,
|
|
||||||
listening_ports: Vec<u32>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl VsockConnectionManager {
|
|
||||||
/// Construct a new connection manager wrapping the given low-level VirtIO socket driver.
|
|
||||||
pub fn new(driver: Arc<SpinLock<SocketDevice>>) -> Self {
|
|
||||||
Self {
|
|
||||||
driver,
|
|
||||||
connections: Vec::new(),
|
|
||||||
listening_ports: Vec::new(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns the CID which has been assigned to this guest.
|
|
||||||
pub fn guest_cid(&self) -> u64 {
|
|
||||||
self.driver.lock().guest_cid()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Allows incoming connections on the given port number.
|
|
||||||
pub fn listen(&mut self, port: u32) {
|
|
||||||
if !self.listening_ports.contains(&port) {
|
|
||||||
self.listening_ports.push(port);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Stops allowing incoming connections on the given port number.
|
|
||||||
pub fn unlisten(&mut self, port: u32) {
|
|
||||||
self.listening_ports.retain(|p| *p != port)
|
|
||||||
}
|
|
||||||
/// Sends a request to connect to the given destination.
|
|
||||||
///
|
|
||||||
/// This returns as soon as the request is sent; you should wait until `poll` returns a
|
|
||||||
/// `VsockEventType::Connected` event indicating that the peer has accepted the connection
|
|
||||||
/// before sending data.
|
|
||||||
pub fn connect(&mut self, destination: VsockAddr, src_port: u32) -> Result<(), SocketError> {
|
|
||||||
if self.connections.iter().any(|connection| {
|
|
||||||
connection.info.dst == destination && connection.info.src_port == src_port
|
|
||||||
}) {
|
|
||||||
return Err(SocketError::ConnectionExists);
|
|
||||||
}
|
|
||||||
|
|
||||||
let new_connection = Connection::new(destination, src_port);
|
|
||||||
|
|
||||||
self.driver.lock().connect(&new_connection.info)?;
|
|
||||||
debug!("Connection requested: {:?}", new_connection.info);
|
|
||||||
self.connections.push(new_connection);
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Sends the buffer to the destination.
|
|
||||||
pub fn send(
|
|
||||||
&mut self,
|
|
||||||
destination: VsockAddr,
|
|
||||||
src_port: u32,
|
|
||||||
buffer: &[u8],
|
|
||||||
) -> Result<(), SocketError> {
|
|
||||||
let (_, connection) = get_connection(&mut self.connections, destination, src_port)?;
|
|
||||||
|
|
||||||
self.driver.lock().send(buffer, &mut connection.info)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Polls the vsock device to receive data or other updates.
|
|
||||||
pub fn poll(&mut self) -> Result<Option<VsockEvent>, SocketError> {
|
|
||||||
let guest_cid = self.driver.lock().guest_cid();
|
|
||||||
let connections = &mut self.connections;
|
|
||||||
|
|
||||||
let result = self.driver.lock().poll(|event, body| {
|
|
||||||
let connection = get_connection_for_event(connections, &event, guest_cid);
|
|
||||||
|
|
||||||
// Skip events which don't match any connection we know about, unless they are a
|
|
||||||
// connection request.
|
|
||||||
let connection = if let Some((_, connection)) = connection {
|
|
||||||
connection
|
|
||||||
} else if let VsockEventType::ConnectionRequest = event.event_type {
|
|
||||||
// If the requested connection already exists or the CID isn't ours, ignore it.
|
|
||||||
if connection.is_some() || event.destination.cid != guest_cid {
|
|
||||||
return Ok(None);
|
|
||||||
}
|
|
||||||
// Add the new connection to our list, at least for now. It will be removed again
|
|
||||||
// below if we weren't listening on the port.
|
|
||||||
connections.push(Connection::new(event.source, event.destination.port));
|
|
||||||
connections.last_mut().unwrap()
|
|
||||||
} else {
|
|
||||||
return Ok(None);
|
|
||||||
};
|
|
||||||
|
|
||||||
// Update stored connection info.
|
|
||||||
connection.info.update_for_event(&event);
|
|
||||||
|
|
||||||
if let VsockEventType::Received { length } = event.event_type {
|
|
||||||
// Copy to buffer
|
|
||||||
if !connection.buffer.add(body) {
|
|
||||||
return Err(SocketError::OutputBufferTooShort(length));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(Some(event))
|
|
||||||
})?;
|
|
||||||
|
|
||||||
let Some(event) = result else {
|
|
||||||
return Ok(None);
|
|
||||||
};
|
|
||||||
|
|
||||||
// The connection must exist because we found it above in the callback.
|
|
||||||
let (connection_index, connection) =
|
|
||||||
get_connection_for_event(connections, &event, guest_cid).unwrap();
|
|
||||||
|
|
||||||
match event.event_type {
|
|
||||||
VsockEventType::ConnectionRequest => {
|
|
||||||
if self.listening_ports.contains(&event.destination.port) {
|
|
||||||
self.driver.lock().accept(&connection.info)?;
|
|
||||||
} else {
|
|
||||||
// Reject the connection request and remove it from our list.
|
|
||||||
self.driver.lock().force_close(&connection.info)?;
|
|
||||||
self.connections.swap_remove(connection_index);
|
|
||||||
|
|
||||||
// No need to pass the request on to the client, as we've already rejected it.
|
|
||||||
return Ok(None);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
VsockEventType::Connected => {}
|
|
||||||
VsockEventType::Disconnected { reason } => {
|
|
||||||
// Wait until client reads all data before removing connection.
|
|
||||||
if connection.buffer.is_empty() {
|
|
||||||
if reason == DisconnectReason::Shutdown {
|
|
||||||
self.driver.lock().force_close(&connection.info)?;
|
|
||||||
}
|
|
||||||
self.connections.swap_remove(connection_index);
|
|
||||||
} else {
|
|
||||||
connection.peer_requested_shutdown = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
VsockEventType::Received { .. } => {
|
|
||||||
// Already copied the buffer in the callback above.
|
|
||||||
}
|
|
||||||
VsockEventType::CreditRequest => {
|
|
||||||
// If the peer requested credit, send an update.
|
|
||||||
self.driver.lock().credit_update(&connection.info)?;
|
|
||||||
// No need to pass the request on to the client, we've already handled it.
|
|
||||||
return Ok(None);
|
|
||||||
}
|
|
||||||
VsockEventType::CreditUpdate => {}
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(Some(event))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Reads data received from the given connection.
|
|
||||||
pub fn recv(
|
|
||||||
&mut self,
|
|
||||||
peer: VsockAddr,
|
|
||||||
src_port: u32,
|
|
||||||
buffer: &mut [u8],
|
|
||||||
) -> Result<usize, SocketError> {
|
|
||||||
debug!("connections is {:?}", self.connections);
|
|
||||||
let (connection_index, connection) = get_connection(&mut self.connections, peer, src_port)?;
|
|
||||||
|
|
||||||
// Copy from ring buffer
|
|
||||||
let bytes_read = connection.buffer.drain(buffer);
|
|
||||||
|
|
||||||
connection.info.done_forwarding(bytes_read);
|
|
||||||
|
|
||||||
// If buffer is now empty and the peer requested shutdown, finish shutting down the
|
|
||||||
// connection.
|
|
||||||
if connection.peer_requested_shutdown && connection.buffer.is_empty() {
|
|
||||||
self.driver.lock().force_close(&connection.info)?;
|
|
||||||
self.connections.swap_remove(connection_index);
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(bytes_read)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Blocks until we get some event from the vsock device.
|
|
||||||
pub fn wait_for_event(&mut self) -> Result<VsockEvent, SocketError> {
|
|
||||||
loop {
|
|
||||||
if let Some(event) = self.poll()? {
|
|
||||||
return Ok(event);
|
|
||||||
} else {
|
|
||||||
spin_loop();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Requests to shut down the connection cleanly.
|
|
||||||
///
|
|
||||||
/// This returns as soon as the request is sent; you should wait until `poll` returns a
|
|
||||||
/// `VsockEventType::Disconnected` event if you want to know that the peer has acknowledged the
|
|
||||||
/// shutdown.
|
|
||||||
pub fn shutdown(&mut self, destination: VsockAddr, src_port: u32) -> Result<(), SocketError> {
|
|
||||||
let (_, connection) = get_connection(&mut self.connections, destination, src_port)?;
|
|
||||||
|
|
||||||
self.driver.lock().shutdown(&connection.info)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Forcibly closes the connection without waiting for the peer.
|
|
||||||
pub fn force_close(
|
|
||||||
&mut self,
|
|
||||||
destination: VsockAddr,
|
|
||||||
src_port: u32,
|
|
||||||
) -> Result<(), SocketError> {
|
|
||||||
let (index, connection) = get_connection(&mut self.connections, destination, src_port)?;
|
|
||||||
|
|
||||||
self.driver.lock().force_close(&connection.info)?;
|
|
||||||
|
|
||||||
self.connections.swap_remove(index);
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns the connection from the given list matching the given peer address and local port, and
|
|
||||||
/// its index.
|
|
||||||
///
|
|
||||||
/// Returns `Err(SocketError::NotConnected)` if there is no matching connection in the list.
|
|
||||||
fn get_connection(
|
|
||||||
connections: &mut [Connection],
|
|
||||||
peer: VsockAddr,
|
|
||||||
local_port: u32,
|
|
||||||
) -> core::result::Result<(usize, &mut Connection), SocketError> {
|
|
||||||
connections
|
|
||||||
.iter_mut()
|
|
||||||
.enumerate()
|
|
||||||
.find(|(_, connection)| {
|
|
||||||
connection.info.dst == peer && connection.info.src_port == local_port
|
|
||||||
})
|
|
||||||
.ok_or(SocketError::NotConnected)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns the connection from the given list matching the event, if any, and its index.
|
|
||||||
fn get_connection_for_event<'a>(
|
|
||||||
connections: &'a mut [Connection],
|
|
||||||
event: &VsockEvent,
|
|
||||||
local_cid: u64,
|
|
||||||
) -> Option<(usize, &'a mut Connection)> {
|
|
||||||
connections
|
|
||||||
.iter_mut()
|
|
||||||
.enumerate()
|
|
||||||
.find(|(_, connection)| event.matches_connection(&connection.info, local_cid))
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
struct Connection {
|
|
||||||
info: ConnectionInfo,
|
|
||||||
buffer: RingBuffer,
|
|
||||||
/// The peer sent a SHUTDOWN request, but we haven't yet responded with a RST because there is
|
|
||||||
/// still data in the buffer.
|
|
||||||
peer_requested_shutdown: bool,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Connection {
|
|
||||||
fn new(peer: VsockAddr, local_port: u32) -> Self {
|
|
||||||
let mut info = ConnectionInfo::new(peer, local_port);
|
|
||||||
info.buf_alloc = PER_CONNECTION_BUFFER_CAPACITY.try_into().unwrap();
|
|
||||||
Self {
|
|
||||||
info,
|
|
||||||
buffer: RingBuffer::new(PER_CONNECTION_BUFFER_CAPACITY),
|
|
||||||
peer_requested_shutdown: false,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
struct RingBuffer {
|
|
||||||
buffer: Box<[u8]>,
|
|
||||||
/// The number of bytes currently in the buffer.
|
|
||||||
used: usize,
|
|
||||||
/// The index of the first used byte in the buffer.
|
|
||||||
start: usize,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl RingBuffer {
|
|
||||||
pub fn new(capacity: usize) -> Self {
|
|
||||||
Self {
|
|
||||||
// FIXME: if the capacity is excessive, elements move will be executed.
|
|
||||||
buffer: vec![0; capacity].into_boxed_slice(),
|
|
||||||
used: 0,
|
|
||||||
start: 0,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
/// Returns the number of bytes currently used in the buffer.
|
|
||||||
pub fn used(&self) -> usize {
|
|
||||||
self.used
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns true iff there are currently no bytes in the buffer.
|
|
||||||
pub fn is_empty(&self) -> bool {
|
|
||||||
self.used == 0
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns the number of bytes currently free in the buffer.
|
|
||||||
pub fn available(&self) -> usize {
|
|
||||||
self.buffer.len() - self.used
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Adds the given bytes to the buffer if there is enough capacity for them all.
|
|
||||||
///
|
|
||||||
/// Returns true if they were added, or false if they were not.
|
|
||||||
pub fn add(&mut self, bytes: &[u8]) -> bool {
|
|
||||||
if bytes.len() > self.available() {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
// The index of the first available position in the buffer.
|
|
||||||
let first_available = (self.start + self.used) % self.buffer.len();
|
|
||||||
// The number of bytes to copy from `bytes` to `buffer` between `first_available` and
|
|
||||||
// `buffer.len()`.
|
|
||||||
let copy_length_before_wraparound = min(bytes.len(), self.buffer.len() - first_available);
|
|
||||||
self.buffer[first_available..first_available + copy_length_before_wraparound]
|
|
||||||
.copy_from_slice(&bytes[0..copy_length_before_wraparound]);
|
|
||||||
if let Some(bytes_after_wraparound) = bytes.get(copy_length_before_wraparound..) {
|
|
||||||
self.buffer[0..bytes_after_wraparound.len()].copy_from_slice(bytes_after_wraparound);
|
|
||||||
}
|
|
||||||
self.used += bytes.len();
|
|
||||||
|
|
||||||
true
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Reads and removes as many bytes as possible from the buffer, up to the length of the given
|
|
||||||
/// buffer.
|
|
||||||
pub fn drain(&mut self, out: &mut [u8]) -> usize {
|
|
||||||
let bytes_read = min(self.used, out.len());
|
|
||||||
|
|
||||||
// The number of bytes to copy out between `start` and the end of the buffer.
|
|
||||||
let read_before_wraparound = min(bytes_read, self.buffer.len() - self.start);
|
|
||||||
// The number of bytes to copy out from the beginning of the buffer after wrapping around.
|
|
||||||
let read_after_wraparound = bytes_read
|
|
||||||
.checked_sub(read_before_wraparound)
|
|
||||||
.unwrap_or_default();
|
|
||||||
|
|
||||||
out[0..read_before_wraparound]
|
|
||||||
.copy_from_slice(&self.buffer[self.start..self.start + read_before_wraparound]);
|
|
||||||
out[read_before_wraparound..bytes_read]
|
|
||||||
.copy_from_slice(&self.buffer[0..read_after_wraparound]);
|
|
||||||
|
|
||||||
self.used -= bytes_read;
|
|
||||||
self.start = (self.start + bytes_read) % self.buffer.len();
|
|
||||||
|
|
||||||
bytes_read
|
|
||||||
}
|
|
||||||
}
|
|
@ -4,7 +4,6 @@
|
|||||||
use alloc::{collections::BTreeMap, string::String, sync::Arc, vec::Vec};
|
use alloc::{collections::BTreeMap, string::String, sync::Arc, vec::Vec};
|
||||||
|
|
||||||
use aster_frame::sync::SpinLock;
|
use aster_frame::sync::SpinLock;
|
||||||
use component::ComponentInitError;
|
|
||||||
use spin::Once;
|
use spin::Once;
|
||||||
|
|
||||||
use self::device::SocketDevice;
|
use self::device::SocketDevice;
|
||||||
@ -14,50 +13,59 @@ pub mod connect;
|
|||||||
pub mod device;
|
pub mod device;
|
||||||
pub mod error;
|
pub mod error;
|
||||||
pub mod header;
|
pub mod header;
|
||||||
pub mod manager;
|
|
||||||
|
|
||||||
pub static DEVICE_NAME: &str = "Virtio-Vsock";
|
pub static DEVICE_NAME: &str = "Virtio-Vsock";
|
||||||
pub trait VsockDeviceIrqHandler = Fn() + Send + Sync + 'static;
|
pub trait VsockDeviceIrqHandler = Fn() + Send + Sync + 'static;
|
||||||
|
|
||||||
pub fn register_device(name: String, device: Arc<SpinLock<SocketDevice>>) {
|
pub fn register_device(name: String, device: Arc<SpinLock<SocketDevice>>) {
|
||||||
COMPONENT
|
VSOCK_DEVICE_TABLE
|
||||||
.get()
|
.get()
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.vsock_device_table
|
|
||||||
.lock()
|
.lock()
|
||||||
.insert(name, device);
|
.insert(name, (Arc::new(SpinLock::new(Vec::new())), device));
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_device(str: &str) -> Option<Arc<SpinLock<SocketDevice>>> {
|
pub fn get_device(str: &str) -> Option<Arc<SpinLock<SocketDevice>>> {
|
||||||
let lock = COMPONENT.get().unwrap().vsock_device_table.lock();
|
let lock = VSOCK_DEVICE_TABLE.get().unwrap().lock();
|
||||||
let device = lock.get(str)?;
|
let (_, device) = lock.get(str)?;
|
||||||
Some(device.clone())
|
Some(device.clone())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn all_devices() -> Vec<(String, Arc<SpinLock<SocketDevice>>)> {
|
pub fn all_devices() -> Vec<(String, Arc<SpinLock<SocketDevice>>)> {
|
||||||
let vsock_devs = COMPONENT.get().unwrap().vsock_device_table.lock();
|
let vsock_devs = VSOCK_DEVICE_TABLE.get().unwrap().lock();
|
||||||
vsock_devs
|
vsock_devs
|
||||||
.iter()
|
.iter()
|
||||||
.map(|(name, device)| (name.clone(), device.clone()))
|
.map(|(name, (_, device))| (name.clone(), device.clone()))
|
||||||
.collect()
|
.collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
static COMPONENT: Once<Component> = Once::new();
|
pub fn register_recv_callback(name: &str, callback: impl VsockDeviceIrqHandler) {
|
||||||
|
let lock = VSOCK_DEVICE_TABLE.get().unwrap().lock();
|
||||||
pub fn component_init() -> Result<(), ComponentInitError> {
|
let Some((callbacks, _)) = lock.get(name) else {
|
||||||
let a = Component::init()?;
|
return;
|
||||||
COMPONENT.call_once(|| a);
|
};
|
||||||
Ok(())
|
callbacks.lock().push(Arc::new(callback));
|
||||||
}
|
}
|
||||||
|
|
||||||
struct Component {
|
pub fn handle_recv_irq(name: &str) {
|
||||||
vsock_device_table: SpinLock<BTreeMap<String, Arc<SpinLock<SocketDevice>>>>,
|
let lock = VSOCK_DEVICE_TABLE.get().unwrap().lock();
|
||||||
}
|
let Some((callbacks, _)) = lock.get(name) else {
|
||||||
|
return;
|
||||||
impl Component {
|
};
|
||||||
pub fn init() -> Result<Self, ComponentInitError> {
|
let callbacks = callbacks.clone();
|
||||||
Ok(Self {
|
let lock = callbacks.lock();
|
||||||
vsock_device_table: SpinLock::new(BTreeMap::new()),
|
for callback in lock.iter() {
|
||||||
})
|
callback.call(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn init() {
|
||||||
|
VSOCK_DEVICE_TABLE.call_once(|| SpinLock::new(BTreeMap::new()));
|
||||||
|
}
|
||||||
|
|
||||||
|
type VsockDeviceIrqHandlerListRef = Arc<SpinLock<Vec<Arc<dyn VsockDeviceIrqHandler>>>>;
|
||||||
|
type VsockDeviceRef = Arc<SpinLock<SocketDevice>>;
|
||||||
|
|
||||||
|
pub static VSOCK_DEVICE_TABLE: Once<
|
||||||
|
SpinLock<BTreeMap<String, (VsockDeviceIrqHandlerListRef, VsockDeviceRef)>>,
|
||||||
|
> = Once::new();
|
||||||
|
@ -35,8 +35,8 @@ mod transport;
|
|||||||
fn virtio_component_init() -> Result<(), ComponentInitError> {
|
fn virtio_component_init() -> Result<(), ComponentInitError> {
|
||||||
// Find all devices and register them to the corresponding crate
|
// Find all devices and register them to the corresponding crate
|
||||||
transport::init();
|
transport::init();
|
||||||
// For vsock cmponent
|
// For vsock table static init
|
||||||
socket::component_init()?;
|
socket::init();
|
||||||
while let Some(mut transport) = pop_device_transport() {
|
while let Some(mut transport) = pop_device_transport() {
|
||||||
// Reset device
|
// Reset device
|
||||||
transport.set_device_status(DeviceStatus::empty()).unwrap();
|
transport.set_device_status(DeviceStatus::empty()).unwrap();
|
||||||
|
@ -31,6 +31,7 @@ TEST_APPS := \
|
|||||||
pthread \
|
pthread \
|
||||||
pty \
|
pty \
|
||||||
signal_c \
|
signal_c \
|
||||||
|
vsock \
|
||||||
|
|
||||||
# The C head and source files of all the apps, excluding the downloaded mongoose files
|
# The C head and source files of all the apps, excluding the downloaded mongoose files
|
||||||
C_SOURCES := $(shell find . -type f \( -name "*.c" -or -name "*.h" \) ! -name "mongoose.c" ! -name "mongoose.h")
|
C_SOURCES := $(shell find . -type f \( -name "*.c" -or -name "*.h" \) ! -name "mongoose.c" ! -name "mongoose.h")
|
||||||
|
13
regression/apps/scripts/run_vsock_test.sh
Normal file
13
regression/apps/scripts/run_vsock_test.sh
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
#!/bin/sh
|
||||||
|
|
||||||
|
# SPDX-License-Identifier: MPL-2.0
|
||||||
|
|
||||||
|
set -e
|
||||||
|
|
||||||
|
VSOCK_DIR=/regression/vsock
|
||||||
|
cd ${VSOCK_DIR}
|
||||||
|
|
||||||
|
echo "Start vsock test......"
|
||||||
|
# ./vsock_server
|
||||||
|
./vsock_client
|
||||||
|
echo "Vsock test passed."
|
5
regression/apps/vsock/Makefile
Normal file
5
regression/apps/vsock/Makefile
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
# SPDX-License-Identifier: MPL-2.0
|
||||||
|
|
||||||
|
include ../test_common.mk
|
||||||
|
|
||||||
|
EXTRA_C_FLAGS :=
|
47
regression/apps/vsock/vsock_client.c
Normal file
47
regression/apps/vsock/vsock_client.c
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
// SPDX-License-Identifier: MPL-2.0
|
||||||
|
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <sys/socket.h>
|
||||||
|
#include <linux/vm_sockets.h>
|
||||||
|
#include <unistd.h>
|
||||||
|
#include <string.h>
|
||||||
|
|
||||||
|
#define PORT 1234
|
||||||
|
|
||||||
|
int main()
|
||||||
|
{
|
||||||
|
int sock;
|
||||||
|
char *hello = "Hello from client";
|
||||||
|
char buffer[1024] = { 0 };
|
||||||
|
struct sockaddr_vm serv_addr;
|
||||||
|
|
||||||
|
if ((sock = socket(AF_VSOCK, SOCK_STREAM, 0)) < 0) {
|
||||||
|
printf("\n Socket creation error\n");
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
printf("\n Create socket successfully!\n");
|
||||||
|
serv_addr.svm_family = AF_VSOCK;
|
||||||
|
serv_addr.svm_cid = VMADDR_CID_HOST;
|
||||||
|
serv_addr.svm_port = PORT;
|
||||||
|
|
||||||
|
if (connect(sock, (struct sockaddr *)&serv_addr, sizeof(serv_addr)) <
|
||||||
|
0) {
|
||||||
|
printf("\nConnection Failed \n");
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
printf("\n Socket connect successfully!\n");
|
||||||
|
|
||||||
|
// Send message to the server and receive the reply
|
||||||
|
if (send(sock, hello, strlen(hello), 0) < 0) {
|
||||||
|
printf("\nSend Failed\n");
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
printf("Hello message sent\n");
|
||||||
|
if (read(sock, buffer, 1024) < 0) {
|
||||||
|
printf("\nRead Failed\n");
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
printf("Server: %s\n", buffer);
|
||||||
|
return 0;
|
||||||
|
}
|
61
regression/apps/vsock/vsock_server.c
Normal file
61
regression/apps/vsock/vsock_server.c
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
// SPDX-License-Identifier: MPL-2.0
|
||||||
|
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <sys/socket.h>
|
||||||
|
#include <linux/vm_sockets.h>
|
||||||
|
#include <unistd.h>
|
||||||
|
#include <string.h>
|
||||||
|
|
||||||
|
#define CID 3
|
||||||
|
#define PORT 4321
|
||||||
|
|
||||||
|
int main()
|
||||||
|
{
|
||||||
|
int sock, new_sock;
|
||||||
|
char *hello = "Hello from client";
|
||||||
|
char buffer[1024] = { 0 };
|
||||||
|
struct sockaddr_vm serv_addr, client_addr;
|
||||||
|
int addrlen = sizeof(client_addr);
|
||||||
|
|
||||||
|
if ((sock = socket(AF_VSOCK, SOCK_STREAM, 0)) < 0) {
|
||||||
|
printf("\n Socket creation error\n");
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
printf("\nCreate socket successfully\n");
|
||||||
|
serv_addr.svm_family = AF_VSOCK;
|
||||||
|
serv_addr.svm_cid = CID;
|
||||||
|
serv_addr.svm_port = PORT;
|
||||||
|
|
||||||
|
if (bind(sock, (struct sockaddr *)&serv_addr, sizeof(serv_addr)) < 0) {
|
||||||
|
printf("\nBind Failed \n");
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
printf("\nBind socket successfully\n");
|
||||||
|
|
||||||
|
if (listen(sock, 3) < 0) {
|
||||||
|
printf("\nListen Failed\n");
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
printf("\nListen socket successfully\n");
|
||||||
|
|
||||||
|
if ((new_sock = accept(sock, (struct sockaddr *)&client_addr,
|
||||||
|
(socklen_t *)&addrlen)) < 0) {
|
||||||
|
printf("\nAccept Failed\n");
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
printf("\nAccept socket successfully\n");
|
||||||
|
|
||||||
|
// Send message to the server and receive the reply
|
||||||
|
if (read(new_sock, buffer, 1024) < 0) {
|
||||||
|
printf("\nRead Failed\n");
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
printf("Client: %s\n", buffer);
|
||||||
|
if (send(new_sock, hello, strlen(hello), 0) < 0) {
|
||||||
|
printf("\nSend Failed\n");
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
printf("Hello message sent\n");
|
||||||
|
return 0;
|
||||||
|
}
|
@ -55,3 +55,4 @@ export HOST_PORT=8888
|
|||||||
iperf3 -s -B $HOST_ADDR -p $HOST_PORT -D # Start the server as a daemon
|
iperf3 -s -B $HOST_ADDR -p $HOST_PORT -D # Start the server as a daemon
|
||||||
iperf3 -c $HOST_ADDR -p $HOST_PORT # Start the client
|
iperf3 -c $HOST_ADDR -p $HOST_PORT # Start the client
|
||||||
```
|
```
|
||||||
|
Note that [a variant of iperf3](https://github.com/stefano-garzarella/iperf-vsock) can measure the performance of `vsock`.
|
@ -1,16 +0,0 @@
|
|||||||
import socket
|
|
||||||
|
|
||||||
client_socket = socket.socket(socket.AF_VSOCK, socket.SOCK_STREAM)
|
|
||||||
CID = socket.VMADDR_CID_HOST
|
|
||||||
PORT = 1234
|
|
||||||
vm_cid = 3
|
|
||||||
server_port = 4321
|
|
||||||
client_socket.bind((CID, PORT))
|
|
||||||
client_socket.connect((vm_cid, server_port))
|
|
||||||
|
|
||||||
client_socket.sendall(b'Hello from host')
|
|
||||||
|
|
||||||
response = client_socket.recv(4096)
|
|
||||||
print(f'Received: {response.decode()}')
|
|
||||||
|
|
||||||
client_socket.close()
|
|
@ -1,22 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
|
|
||||||
import socket
|
|
||||||
|
|
||||||
CID = socket.VMADDR_CID_HOST
|
|
||||||
PORT = 1234
|
|
||||||
|
|
||||||
s = socket.socket(socket.AF_VSOCK, socket.SOCK_STREAM)
|
|
||||||
s.bind((CID, PORT))
|
|
||||||
s.listen()
|
|
||||||
(conn, (remote_cid, remote_port)) = s.accept()
|
|
||||||
|
|
||||||
print(f"Connection opened by cid={remote_cid} port={remote_port}")
|
|
||||||
|
|
||||||
while True:
|
|
||||||
buf = conn.recv(64)
|
|
||||||
if not buf:
|
|
||||||
break
|
|
||||||
|
|
||||||
print(f"Received bytes: {buf}")
|
|
||||||
|
|
||||||
conn.send(b'Hello from host')
|
|
@ -60,7 +60,7 @@ MICROVM_QEMU_ARGS="\
|
|||||||
-device virtio-net-device,netdev=net01 \
|
-device virtio-net-device,netdev=net01 \
|
||||||
-device virtio-serial-device \
|
-device virtio-serial-device \
|
||||||
-device virtconsole,chardev=mux \
|
-device virtconsole,chardev=mux \
|
||||||
-device vhost-vsock-pci,id=vhost-vsock-pci0,guest-cid=3 \
|
-device vhost-vsock-device,guest-cid=3 \
|
||||||
"
|
"
|
||||||
|
|
||||||
if [ "$1" = "microvm" ]; then
|
if [ "$1" = "microvm" ]; then
|
||||||
|
Loading…
x
Reference in New Issue
Block a user