mirror of
https://github.com/asterinas/asterinas.git
synced 2025-06-18 12:06:43 +00:00
Implement vsock socket layer
This commit is contained in:
committed by
Tate, Hongliang Tian
parent
83a7937334
commit
ad140cec3c
@ -2,10 +2,6 @@
|
||||
|
||||
use align_ext::AlignExt;
|
||||
use bytes::BytesMut;
|
||||
use pod::Pod;
|
||||
|
||||
use super::header::VirtioVsockHdr;
|
||||
use crate::device::socket::header::VIRTIO_VSOCK_HDR_LEN;
|
||||
|
||||
/// Buffer for receive packet
|
||||
#[derive(Debug)]
|
||||
@ -38,22 +34,6 @@ impl RxBuffer {
|
||||
pub fn buf_mut(&mut self) -> &mut [u8] {
|
||||
&mut self.buf
|
||||
}
|
||||
|
||||
/// Packet payload slice, which is inner buffer excluding VirtioVsockHdr.
|
||||
pub fn packet(&self) -> &[u8] {
|
||||
debug_assert!(VIRTIO_VSOCK_HDR_LEN + self.packet_len <= self.buf.len());
|
||||
&self.buf[VIRTIO_VSOCK_HDR_LEN..VIRTIO_VSOCK_HDR_LEN + self.packet_len]
|
||||
}
|
||||
|
||||
/// Mutable packet payload slice.
|
||||
pub fn packet_mut(&mut self) -> &mut [u8] {
|
||||
debug_assert!(VIRTIO_VSOCK_HDR_LEN + self.packet_len <= self.buf.len());
|
||||
&mut self.buf[VIRTIO_VSOCK_HDR_LEN..VIRTIO_VSOCK_HDR_LEN + self.packet_len]
|
||||
}
|
||||
|
||||
pub fn virtio_vsock_header(&self) -> VirtioVsockHdr {
|
||||
VirtioVsockHdr::from_bytes(&self.buf[..VIRTIO_VSOCK_HDR_LEN])
|
||||
}
|
||||
}
|
||||
|
||||
/// Buffer for transmit packet
|
||||
|
@ -8,17 +8,14 @@ use pod::Pod;
|
||||
use crate::transport::VirtioTransport;
|
||||
|
||||
bitflags! {
|
||||
/// Vsock feature bits since v1.2
|
||||
/// If no feature bit is set, only stream socket type is supported.
|
||||
/// If VIRTIO_VSOCK_F_SEQPACKET has been negotiated, the device MAY act as if VIRTIO_VSOCK_F_STREAM has also been negotiated.
|
||||
pub struct VsockFeatures: u64 {
|
||||
const VIRTIO_VSOCK_F_STREAM = 1 << 0; // stream socket type is supported.
|
||||
const VIRTIO_VSOCK_F_SEQPACKET = 1 << 1; //seqpacket socket type is supported.
|
||||
const VIRTIO_VSOCK_F_SEQPACKET = 1 << 1; //seqpacket socket type is not supported now.
|
||||
}
|
||||
}
|
||||
|
||||
impl VsockFeatures {
|
||||
pub fn support_features() -> Self {
|
||||
pub const fn supported_features() -> Self {
|
||||
VsockFeatures::VIRTIO_VSOCK_F_STREAM
|
||||
}
|
||||
}
|
||||
@ -32,9 +29,7 @@ pub struct VirtioVsockConfig {
|
||||
/// According to virtio spec v1.1 2.4.1 Driver Requirements: Device Configuration Space,
|
||||
/// drivers MUST NOT assume reads from fields greater than 32 bits wide are atomic.
|
||||
/// So we need to split the u64 guest_cid into two parts.
|
||||
// read only
|
||||
pub guest_cid_low: u32,
|
||||
// read only
|
||||
pub guest_cid_high: u32,
|
||||
}
|
||||
|
||||
|
@ -1,5 +1,30 @@
|
||||
// SPDX-License-Identifier: MPL-2.0
|
||||
|
||||
// Modified from vsock.rs in virtio-drivers project
|
||||
//
|
||||
// MIT License
|
||||
//
|
||||
// Copyright (c) 2022-2023 Ant Group
|
||||
// Copyright (c) 2019-2020 rCore Developers
|
||||
//
|
||||
// Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
// of this software and associated documentation files (the "Software"), to deal
|
||||
// in the Software without restriction, including without limitation the rights
|
||||
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
// copies of the Software, and to permit persons to whom the Software is
|
||||
// furnished to do so, subject to the following conditions:
|
||||
//
|
||||
// The above copyright notice and this permission notice shall be included in all
|
||||
// copies or substantial portions of the Software.
|
||||
//
|
||||
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
// SOFTWARE.
|
||||
//
|
||||
use log::debug;
|
||||
|
||||
use super::{
|
||||
@ -7,7 +32,7 @@ use super::{
|
||||
header::{VirtioVsockHdr, VirtioVsockOp, VsockAddr},
|
||||
};
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
|
||||
pub struct VsockBufferStatus {
|
||||
pub buffer_allocation: u32,
|
||||
pub forward_count: u32,
|
||||
@ -24,7 +49,7 @@ pub enum DisconnectReason {
|
||||
}
|
||||
|
||||
/// Details of the type of an event received from a VirtIO socket.
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
|
||||
pub enum VsockEventType {
|
||||
/// The peer requests to establish a connection with us.
|
||||
ConnectionRequest,
|
||||
@ -47,7 +72,7 @@ pub enum VsockEventType {
|
||||
}
|
||||
|
||||
/// An event received from a VirtIO socket device.
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
|
||||
pub struct VsockEvent {
|
||||
/// The source of the event, i.e. the peer who sent it.
|
||||
pub source: VsockAddr,
|
||||
|
@ -17,7 +17,10 @@ use super::{
|
||||
VsockDeviceIrqHandler,
|
||||
};
|
||||
use crate::{
|
||||
device::{socket::register_device, VirtioDeviceError},
|
||||
device::{
|
||||
socket::{handle_recv_irq, register_device},
|
||||
VirtioDeviceError,
|
||||
},
|
||||
queue::{QueueError, VirtQueue},
|
||||
transport::VirtioTransport,
|
||||
};
|
||||
@ -27,10 +30,10 @@ const QUEUE_RECV: u16 = 0;
|
||||
const QUEUE_SEND: u16 = 1;
|
||||
const QUEUE_EVENT: u16 = 2;
|
||||
|
||||
/// The size in bytes of each buffer used in the RX virtqueue. This must be bigger than size_of::<VirtioVsockHdr>().
|
||||
/// The size in bytes of each buffer used in the RX virtqueue. This must be bigger than `size_of::<VirtioVsockHdr>()`.
|
||||
const RX_BUFFER_SIZE: usize = 512;
|
||||
|
||||
/// Low-level driver for a Virtio socket device.
|
||||
/// Vsock device driver
|
||||
pub struct SocketDevice {
|
||||
config: VirtioVsockConfig,
|
||||
guest_cid: u64,
|
||||
@ -46,6 +49,7 @@ pub struct SocketDevice {
|
||||
}
|
||||
|
||||
impl SocketDevice {
|
||||
/// Create a new vsock device
|
||||
pub fn init(mut transport: Box<dyn VirtioTransport>) -> Result<(), VirtioDeviceError> {
|
||||
let virtio_vsock_config = VirtioVsockConfig::new(transport.as_mut());
|
||||
debug!("virtio_vsock_config = {:?}", virtio_vsock_config);
|
||||
@ -95,9 +99,8 @@ impl SocketDevice {
|
||||
}
|
||||
|
||||
// Interrupt handler if vsock device receives some packet.
|
||||
// TODO: This will be handled by vsock socket layer.
|
||||
fn handle_vsock_event(_: &TrapFrame) {
|
||||
debug!("Packet received. This will be solved by socket layer");
|
||||
handle_recv_irq(super::DEVICE_NAME);
|
||||
}
|
||||
|
||||
device
|
||||
@ -119,28 +122,23 @@ impl SocketDevice {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Returns the CID which has been assigned to this guest.
|
||||
/// Return the CID which has been assigned to this guest.
|
||||
pub fn guest_cid(&self) -> u64 {
|
||||
self.guest_cid
|
||||
}
|
||||
|
||||
/// Sends a request to connect to the given destination.
|
||||
///
|
||||
/// This returns as soon as the request is sent; you should wait until `poll` returns a
|
||||
/// [`VsockEventType::Connected`] event indicating that the peer has accepted the connection
|
||||
/// before sending data.
|
||||
pub fn connect(&mut self, connection_info: &ConnectionInfo) -> Result<(), SocketError> {
|
||||
/// Send a connection request
|
||||
pub fn request(&mut self, connection_info: &ConnectionInfo) -> Result<(), SocketError> {
|
||||
let header = VirtioVsockHdr {
|
||||
op: VirtioVsockOp::Request as u16,
|
||||
..connection_info.new_header(self.guest_cid)
|
||||
};
|
||||
// Sends a header only packet to the TX queue to connect the device to the listening socket
|
||||
// at the given destination.
|
||||
|
||||
self.send_packet_to_tx_queue(&header, &[])
|
||||
}
|
||||
|
||||
/// Accepts the given connection from a peer.
|
||||
pub fn accept(&mut self, connection_info: &ConnectionInfo) -> Result<(), SocketError> {
|
||||
/// Send a response to peer, if peer start a sending request
|
||||
pub fn response(&mut self, connection_info: &ConnectionInfo) -> Result<(), SocketError> {
|
||||
let header = VirtioVsockHdr {
|
||||
op: VirtioVsockOp::Response as u16,
|
||||
..connection_info.new_header(self.guest_cid)
|
||||
@ -148,29 +146,7 @@ impl SocketDevice {
|
||||
self.send_packet_to_tx_queue(&header, &[])
|
||||
}
|
||||
|
||||
/// Requests the peer to send us a credit update for the given connection.
|
||||
fn request_credit(&mut self, connection_info: &ConnectionInfo) -> Result<(), SocketError> {
|
||||
let header = VirtioVsockHdr {
|
||||
op: VirtioVsockOp::CreditRequest as u16,
|
||||
..connection_info.new_header(self.guest_cid)
|
||||
};
|
||||
self.send_packet_to_tx_queue(&header, &[])
|
||||
}
|
||||
|
||||
/// Tells the peer how much buffer space we have to receive data.
|
||||
pub fn credit_update(&mut self, connection_info: &ConnectionInfo) -> Result<(), SocketError> {
|
||||
let header = VirtioVsockHdr {
|
||||
op: VirtioVsockOp::CreditUpdate as u16,
|
||||
..connection_info.new_header(self.guest_cid)
|
||||
};
|
||||
self.send_packet_to_tx_queue(&header, &[])
|
||||
}
|
||||
|
||||
/// Requests to shut down the connection cleanly.
|
||||
///
|
||||
/// This returns as soon as the request is sent; you should wait until `poll` returns a
|
||||
/// `VsockEventType::Disconnected` event if you want to know that the peer has acknowledged the
|
||||
/// shutdown.
|
||||
/// Send a shutdown request
|
||||
pub fn shutdown(&mut self, connection_info: &ConnectionInfo) -> Result<(), SocketError> {
|
||||
let header = VirtioVsockHdr {
|
||||
op: VirtioVsockOp::Shutdown as u16,
|
||||
@ -179,8 +155,8 @@ impl SocketDevice {
|
||||
self.send_packet_to_tx_queue(&header, &[])
|
||||
}
|
||||
|
||||
/// Forcibly closes the connection without waiting for the peer.
|
||||
pub fn force_close(&mut self, connection_info: &ConnectionInfo) -> Result<(), SocketError> {
|
||||
/// Send a reset request to peer
|
||||
pub fn reset(&mut self, connection_info: &ConnectionInfo) -> Result<(), SocketError> {
|
||||
let header = VirtioVsockHdr {
|
||||
op: VirtioVsockOp::Rst as u16,
|
||||
..connection_info.new_header(self.guest_cid)
|
||||
@ -188,16 +164,29 @@ impl SocketDevice {
|
||||
self.send_packet_to_tx_queue(&header, &[])
|
||||
}
|
||||
|
||||
/// Request the peer to send the credit info to us
|
||||
pub fn credit_request(&mut self, connection_info: &ConnectionInfo) -> Result<(), SocketError> {
|
||||
let header = VirtioVsockHdr {
|
||||
op: VirtioVsockOp::CreditRequest as u16,
|
||||
..connection_info.new_header(self.guest_cid)
|
||||
};
|
||||
self.send_packet_to_tx_queue(&header, &[])
|
||||
}
|
||||
|
||||
/// Tell the peer our credit info
|
||||
pub fn credit_update(&mut self, connection_info: &ConnectionInfo) -> Result<(), SocketError> {
|
||||
let header = VirtioVsockHdr {
|
||||
op: VirtioVsockOp::CreditUpdate as u16,
|
||||
..connection_info.new_header(self.guest_cid)
|
||||
};
|
||||
self.send_packet_to_tx_queue(&header, &[])
|
||||
}
|
||||
|
||||
fn send_packet_to_tx_queue(
|
||||
&mut self,
|
||||
header: &VirtioVsockHdr,
|
||||
buffer: &[u8],
|
||||
) -> Result<(), SocketError> {
|
||||
// let (_token, _len) = self.send_queue.add_notify_wait_pop(
|
||||
// &[header.as_bytes(), buffer],
|
||||
// &mut [],
|
||||
// )?;
|
||||
|
||||
let _token = self.send_queue.add_buf(&[header.as_bytes(), buffer], &[])?;
|
||||
|
||||
if self.send_queue.should_notify() {
|
||||
@ -211,8 +200,7 @@ impl SocketDevice {
|
||||
|
||||
self.send_queue.pop_used()?;
|
||||
|
||||
// FORDEBUG
|
||||
// debug!("buffer in send_packet_to_tx_queue: {:?}",buffer);
|
||||
debug!("buffer in send_packet_to_tx_queue: {:?}", buffer);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -221,13 +209,19 @@ impl SocketDevice {
|
||||
connection_info: &mut ConnectionInfo,
|
||||
buffer_len: usize,
|
||||
) -> Result<(), SocketError> {
|
||||
debug!("connectin info {:?}", connection_info);
|
||||
debug!(
|
||||
"peer free from peer: {:?}, buffer len : {:?}",
|
||||
connection_info.peer_free(),
|
||||
buffer_len
|
||||
);
|
||||
if connection_info.peer_free() as usize >= buffer_len {
|
||||
Ok(())
|
||||
} else {
|
||||
// Request an update of the cached peer credit, if we haven't already done so, and tell
|
||||
// the caller to try again later.
|
||||
if !connection_info.has_pending_credit_request {
|
||||
self.request_credit(connection_info)?;
|
||||
self.credit_request(connection_info)?;
|
||||
connection_info.has_pending_credit_request = true;
|
||||
}
|
||||
Err(SocketError::InsufficientBufferSpaceInPeer)
|
||||
@ -252,6 +246,36 @@ impl SocketDevice {
|
||||
self.send_packet_to_tx_queue(&header, buffer)
|
||||
}
|
||||
|
||||
/// Receive bytes from peer, returns the header
|
||||
pub fn receive(
|
||||
&mut self,
|
||||
buffer: &mut [u8],
|
||||
// connection_info: &mut ConnectionInfo,
|
||||
) -> Result<VirtioVsockHdr, SocketError> {
|
||||
let (token, len) = self.recv_queue.pop_used()?;
|
||||
debug!(
|
||||
"receive packet in rx_queue: token = {}, len = {}",
|
||||
token, len
|
||||
);
|
||||
let mut rx_buffer = self
|
||||
.rx_buffers
|
||||
.remove(token as usize)
|
||||
.ok_or(QueueError::WrongToken)?;
|
||||
rx_buffer.set_packet_len(RX_BUFFER_SIZE);
|
||||
|
||||
let (header, payload) = read_header_and_body(rx_buffer.buf())?;
|
||||
// The length written should be equal to len(header)+len(packet)
|
||||
assert_eq!(len, header.len() + VIRTIO_VSOCK_HDR_LEN as u32);
|
||||
debug!("Received packet {:?}. Op {:?}", header, header.op());
|
||||
debug!("body is {:?}", payload);
|
||||
|
||||
assert!(buffer.len() >= payload.len());
|
||||
buffer[..payload.len()].copy_from_slice(payload);
|
||||
|
||||
self.add_rx_buffer(rx_buffer, token)?;
|
||||
Ok(header)
|
||||
}
|
||||
|
||||
/// Polls the RX virtqueue for the next event, and calls the given handler function to handle it.
|
||||
pub fn poll(
|
||||
&mut self,
|
||||
@ -261,39 +285,10 @@ impl SocketDevice {
|
||||
if !self.recv_queue.can_pop() {
|
||||
return Ok(None);
|
||||
}
|
||||
let (token, len) = self.recv_queue.pop_used()?;
|
||||
let mut body = RxBuffer::new(RX_BUFFER_SIZE);
|
||||
let header = self.receive(body.buf_mut())?;
|
||||
|
||||
let mut buffer = self
|
||||
.rx_buffers
|
||||
.remove(token as usize)
|
||||
.ok_or(QueueError::WrongToken)?;
|
||||
|
||||
let header = buffer.virtio_vsock_header();
|
||||
// The length written should be equal to len(header)+len(packet)
|
||||
assert_eq!(len, header.len() + VIRTIO_VSOCK_HDR_LEN as u32);
|
||||
|
||||
buffer.set_packet_len(RX_BUFFER_SIZE);
|
||||
|
||||
let head_result = read_header_and_body(buffer.buf());
|
||||
|
||||
let Ok((header, body)) = head_result else {
|
||||
let ret = match head_result {
|
||||
Err(e) => Err(e),
|
||||
_ => Ok(None), //FIXME: this clause is never reached.
|
||||
};
|
||||
self.add_rx_buffer(buffer, token)?;
|
||||
return ret;
|
||||
};
|
||||
|
||||
debug!("Received packet {:?}. Op {:?}", header, header.op());
|
||||
debug!("body is {:?}", body);
|
||||
|
||||
let result = VsockEvent::from_header(&header).and_then(|event| handler(event, body));
|
||||
|
||||
// reuse the buffer and give it back to recv_queue.
|
||||
self.add_rx_buffer(buffer, token)?;
|
||||
|
||||
result
|
||||
VsockEvent::from_header(&header).and_then(|event| handler(event, body.buf()))
|
||||
}
|
||||
|
||||
/// Add a used rx buffer to recv queue,@index is only to check the correctness
|
||||
@ -310,7 +305,7 @@ impl SocketDevice {
|
||||
/// Negotiate features for the device specified bits 0~23
|
||||
pub(crate) fn negotiate_features(features: u64) -> u64 {
|
||||
let device_features = VsockFeatures::from_bits_truncate(features);
|
||||
let supported_features = VsockFeatures::support_features();
|
||||
let supported_features = VsockFeatures::supported_features();
|
||||
let vsock_features = device_features & supported_features;
|
||||
debug!("features negotiated: {:?}", vsock_features);
|
||||
vsock_features.bits()
|
||||
|
@ -1,14 +1,36 @@
|
||||
// SPDX-License-Identifier: MPL-2.0
|
||||
|
||||
//! This file comes from virtio-drivers project
|
||||
//! This module contains the error from the VirtIO socket driver.
|
||||
|
||||
// Modified from error.rs in virtio-drivers project
|
||||
//
|
||||
// MIT License
|
||||
//
|
||||
// Copyright (c) 2022-2023 Ant Group
|
||||
// Copyright (c) 2019-2020 rCore Developers
|
||||
//
|
||||
// Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
// of this software and associated documentation files (the "Software"), to deal
|
||||
// in the Software without restriction, including without limitation the rights
|
||||
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
// copies of the Software, and to permit persons to whom the Software is
|
||||
// furnished to do so, subject to the following conditions:
|
||||
//
|
||||
// The above copyright notice and this permission notice shall be included in all
|
||||
// copies or substantial portions of the Software.
|
||||
//
|
||||
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
// SOFTWARE.
|
||||
//
|
||||
use core::{fmt, result};
|
||||
|
||||
use crate::queue::QueueError;
|
||||
|
||||
/// The error type of VirtIO socket driver.
|
||||
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
|
||||
#[derive(Debug)]
|
||||
pub enum SocketError {
|
||||
/// There is an existing connection.
|
||||
ConnectionExists,
|
||||
@ -39,33 +61,18 @@ pub enum SocketError {
|
||||
/// Recycled a wrong buffer.
|
||||
RecycledWrongBuffer,
|
||||
/// Queue Error
|
||||
QueueError(SocketQueueError),
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
|
||||
pub enum SocketQueueError {
|
||||
InvalidArgs,
|
||||
BufferTooSmall,
|
||||
NotReady,
|
||||
AlreadyUsed,
|
||||
WrongToken,
|
||||
}
|
||||
|
||||
impl From<QueueError> for SocketQueueError {
|
||||
fn from(value: QueueError) -> Self {
|
||||
match value {
|
||||
QueueError::InvalidArgs => Self::InvalidArgs,
|
||||
QueueError::BufferTooSmall => Self::BufferTooSmall,
|
||||
QueueError::NotReady => Self::NotReady,
|
||||
QueueError::AlreadyUsed => Self::AlreadyUsed,
|
||||
QueueError::WrongToken => Self::WrongToken,
|
||||
}
|
||||
}
|
||||
QueueError(QueueError),
|
||||
}
|
||||
|
||||
impl From<QueueError> for SocketError {
|
||||
fn from(value: QueueError) -> Self {
|
||||
Self::QueueError(SocketQueueError::from(value))
|
||||
Self::QueueError(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<int_to_c_enum::TryFromIntError> for SocketError {
|
||||
fn from(_e: int_to_c_enum::TryFromIntError) -> Self {
|
||||
Self::InvalidNumber
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1,6 +1,32 @@
|
||||
// SPDX-License-Identifier: MPL-2.0
|
||||
|
||||
// Modified from protocol.rs in virtio-drivers project
|
||||
//
|
||||
// MIT License
|
||||
//
|
||||
// Copyright (c) 2022-2023 Ant Group
|
||||
// Copyright (c) 2019-2020 rCore Developers
|
||||
//
|
||||
// Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
// of this software and associated documentation files (the "Software"), to deal
|
||||
// in the Software without restriction, including without limitation the rights
|
||||
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
// copies of the Software, and to permit persons to whom the Software is
|
||||
// furnished to do so, subject to the following conditions:
|
||||
//
|
||||
// The above copyright notice and this permission notice shall be included in all
|
||||
// copies or substantial portions of the Software.
|
||||
//
|
||||
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
// SOFTWARE.
|
||||
//
|
||||
use bitflags::bitflags;
|
||||
use int_to_c_enum::TryFromInt;
|
||||
use pod::Pod;
|
||||
|
||||
use super::error::{self, SocketError};
|
||||
@ -64,7 +90,7 @@ impl VirtioVsockHdr {
|
||||
}
|
||||
|
||||
pub fn op(&self) -> error::Result<VirtioVsockOp> {
|
||||
self.op.try_into()
|
||||
VirtioVsockOp::try_from(self.op).map_err(|err| err.into())
|
||||
}
|
||||
|
||||
pub fn source(&self) -> VsockAddr {
|
||||
@ -90,7 +116,7 @@ impl VirtioVsockHdr {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
|
||||
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, TryFromInt)]
|
||||
#[repr(u16)]
|
||||
#[allow(non_camel_case_types)]
|
||||
pub enum VirtioVsockOp {
|
||||
@ -112,26 +138,6 @@ pub enum VirtioVsockOp {
|
||||
CreditRequest = 7,
|
||||
}
|
||||
|
||||
/// TODO: This could be optimized by upgrading [int_to_c_enum::TryFromIntError] to carrying the invalid int number
|
||||
impl TryFrom<u16> for VirtioVsockOp {
|
||||
type Error = SocketError;
|
||||
|
||||
fn try_from(v: u16) -> Result<Self, Self::Error> {
|
||||
let op = match v {
|
||||
0 => Self::Invalid,
|
||||
1 => Self::Request,
|
||||
2 => Self::Response,
|
||||
3 => Self::Rst,
|
||||
4 => Self::Shutdown,
|
||||
5 => Self::Rw,
|
||||
6 => Self::CreditUpdate,
|
||||
7 => Self::CreditRequest,
|
||||
_ => return Err(SocketError::UnknownOperation(v)),
|
||||
};
|
||||
Ok(op)
|
||||
}
|
||||
}
|
||||
|
||||
bitflags! {
|
||||
#[repr(C)]
|
||||
#[derive(Default, Pod)]
|
||||
|
@ -1,359 +0,0 @@
|
||||
// SPDX-License-Identifier: MPL-2.0
|
||||
|
||||
use alloc::{boxed::Box, sync::Arc, vec, vec::Vec};
|
||||
use core::{cmp::min, hint::spin_loop};
|
||||
|
||||
use aster_frame::sync::SpinLock;
|
||||
use log::debug;
|
||||
|
||||
use super::{
|
||||
connect::{ConnectionInfo, DisconnectReason, VsockEvent, VsockEventType},
|
||||
device::SocketDevice,
|
||||
header::VsockAddr,
|
||||
};
|
||||
use crate::device::socket::error::SocketError;
|
||||
|
||||
const PER_CONNECTION_BUFFER_CAPACITY: usize = 1024;
|
||||
|
||||
pub struct VsockConnectionManager {
|
||||
driver: Arc<SpinLock<SocketDevice>>,
|
||||
connections: Vec<Connection>,
|
||||
listening_ports: Vec<u32>,
|
||||
}
|
||||
|
||||
impl VsockConnectionManager {
|
||||
/// Construct a new connection manager wrapping the given low-level VirtIO socket driver.
|
||||
pub fn new(driver: Arc<SpinLock<SocketDevice>>) -> Self {
|
||||
Self {
|
||||
driver,
|
||||
connections: Vec::new(),
|
||||
listening_ports: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the CID which has been assigned to this guest.
|
||||
pub fn guest_cid(&self) -> u64 {
|
||||
self.driver.lock().guest_cid()
|
||||
}
|
||||
|
||||
/// Allows incoming connections on the given port number.
|
||||
pub fn listen(&mut self, port: u32) {
|
||||
if !self.listening_ports.contains(&port) {
|
||||
self.listening_ports.push(port);
|
||||
}
|
||||
}
|
||||
|
||||
/// Stops allowing incoming connections on the given port number.
|
||||
pub fn unlisten(&mut self, port: u32) {
|
||||
self.listening_ports.retain(|p| *p != port)
|
||||
}
|
||||
/// Sends a request to connect to the given destination.
|
||||
///
|
||||
/// This returns as soon as the request is sent; you should wait until `poll` returns a
|
||||
/// `VsockEventType::Connected` event indicating that the peer has accepted the connection
|
||||
/// before sending data.
|
||||
pub fn connect(&mut self, destination: VsockAddr, src_port: u32) -> Result<(), SocketError> {
|
||||
if self.connections.iter().any(|connection| {
|
||||
connection.info.dst == destination && connection.info.src_port == src_port
|
||||
}) {
|
||||
return Err(SocketError::ConnectionExists);
|
||||
}
|
||||
|
||||
let new_connection = Connection::new(destination, src_port);
|
||||
|
||||
self.driver.lock().connect(&new_connection.info)?;
|
||||
debug!("Connection requested: {:?}", new_connection.info);
|
||||
self.connections.push(new_connection);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Sends the buffer to the destination.
|
||||
pub fn send(
|
||||
&mut self,
|
||||
destination: VsockAddr,
|
||||
src_port: u32,
|
||||
buffer: &[u8],
|
||||
) -> Result<(), SocketError> {
|
||||
let (_, connection) = get_connection(&mut self.connections, destination, src_port)?;
|
||||
|
||||
self.driver.lock().send(buffer, &mut connection.info)
|
||||
}
|
||||
|
||||
/// Polls the vsock device to receive data or other updates.
|
||||
pub fn poll(&mut self) -> Result<Option<VsockEvent>, SocketError> {
|
||||
let guest_cid = self.driver.lock().guest_cid();
|
||||
let connections = &mut self.connections;
|
||||
|
||||
let result = self.driver.lock().poll(|event, body| {
|
||||
let connection = get_connection_for_event(connections, &event, guest_cid);
|
||||
|
||||
// Skip events which don't match any connection we know about, unless they are a
|
||||
// connection request.
|
||||
let connection = if let Some((_, connection)) = connection {
|
||||
connection
|
||||
} else if let VsockEventType::ConnectionRequest = event.event_type {
|
||||
// If the requested connection already exists or the CID isn't ours, ignore it.
|
||||
if connection.is_some() || event.destination.cid != guest_cid {
|
||||
return Ok(None);
|
||||
}
|
||||
// Add the new connection to our list, at least for now. It will be removed again
|
||||
// below if we weren't listening on the port.
|
||||
connections.push(Connection::new(event.source, event.destination.port));
|
||||
connections.last_mut().unwrap()
|
||||
} else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
// Update stored connection info.
|
||||
connection.info.update_for_event(&event);
|
||||
|
||||
if let VsockEventType::Received { length } = event.event_type {
|
||||
// Copy to buffer
|
||||
if !connection.buffer.add(body) {
|
||||
return Err(SocketError::OutputBufferTooShort(length));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Some(event))
|
||||
})?;
|
||||
|
||||
let Some(event) = result else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
// The connection must exist because we found it above in the callback.
|
||||
let (connection_index, connection) =
|
||||
get_connection_for_event(connections, &event, guest_cid).unwrap();
|
||||
|
||||
match event.event_type {
|
||||
VsockEventType::ConnectionRequest => {
|
||||
if self.listening_ports.contains(&event.destination.port) {
|
||||
self.driver.lock().accept(&connection.info)?;
|
||||
} else {
|
||||
// Reject the connection request and remove it from our list.
|
||||
self.driver.lock().force_close(&connection.info)?;
|
||||
self.connections.swap_remove(connection_index);
|
||||
|
||||
// No need to pass the request on to the client, as we've already rejected it.
|
||||
return Ok(None);
|
||||
}
|
||||
}
|
||||
VsockEventType::Connected => {}
|
||||
VsockEventType::Disconnected { reason } => {
|
||||
// Wait until client reads all data before removing connection.
|
||||
if connection.buffer.is_empty() {
|
||||
if reason == DisconnectReason::Shutdown {
|
||||
self.driver.lock().force_close(&connection.info)?;
|
||||
}
|
||||
self.connections.swap_remove(connection_index);
|
||||
} else {
|
||||
connection.peer_requested_shutdown = true;
|
||||
}
|
||||
}
|
||||
VsockEventType::Received { .. } => {
|
||||
// Already copied the buffer in the callback above.
|
||||
}
|
||||
VsockEventType::CreditRequest => {
|
||||
// If the peer requested credit, send an update.
|
||||
self.driver.lock().credit_update(&connection.info)?;
|
||||
// No need to pass the request on to the client, we've already handled it.
|
||||
return Ok(None);
|
||||
}
|
||||
VsockEventType::CreditUpdate => {}
|
||||
}
|
||||
|
||||
Ok(Some(event))
|
||||
}
|
||||
|
||||
/// Reads data received from the given connection.
|
||||
pub fn recv(
|
||||
&mut self,
|
||||
peer: VsockAddr,
|
||||
src_port: u32,
|
||||
buffer: &mut [u8],
|
||||
) -> Result<usize, SocketError> {
|
||||
debug!("connections is {:?}", self.connections);
|
||||
let (connection_index, connection) = get_connection(&mut self.connections, peer, src_port)?;
|
||||
|
||||
// Copy from ring buffer
|
||||
let bytes_read = connection.buffer.drain(buffer);
|
||||
|
||||
connection.info.done_forwarding(bytes_read);
|
||||
|
||||
// If buffer is now empty and the peer requested shutdown, finish shutting down the
|
||||
// connection.
|
||||
if connection.peer_requested_shutdown && connection.buffer.is_empty() {
|
||||
self.driver.lock().force_close(&connection.info)?;
|
||||
self.connections.swap_remove(connection_index);
|
||||
}
|
||||
|
||||
Ok(bytes_read)
|
||||
}
|
||||
|
||||
/// Blocks until we get some event from the vsock device.
|
||||
pub fn wait_for_event(&mut self) -> Result<VsockEvent, SocketError> {
|
||||
loop {
|
||||
if let Some(event) = self.poll()? {
|
||||
return Ok(event);
|
||||
} else {
|
||||
spin_loop();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Requests to shut down the connection cleanly.
|
||||
///
|
||||
/// This returns as soon as the request is sent; you should wait until `poll` returns a
|
||||
/// `VsockEventType::Disconnected` event if you want to know that the peer has acknowledged the
|
||||
/// shutdown.
|
||||
pub fn shutdown(&mut self, destination: VsockAddr, src_port: u32) -> Result<(), SocketError> {
|
||||
let (_, connection) = get_connection(&mut self.connections, destination, src_port)?;
|
||||
|
||||
self.driver.lock().shutdown(&connection.info)
|
||||
}
|
||||
|
||||
/// Forcibly closes the connection without waiting for the peer.
|
||||
pub fn force_close(
|
||||
&mut self,
|
||||
destination: VsockAddr,
|
||||
src_port: u32,
|
||||
) -> Result<(), SocketError> {
|
||||
let (index, connection) = get_connection(&mut self.connections, destination, src_port)?;
|
||||
|
||||
self.driver.lock().force_close(&connection.info)?;
|
||||
|
||||
self.connections.swap_remove(index);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the connection from the given list matching the given peer address and local port, and
|
||||
/// its index.
|
||||
///
|
||||
/// Returns `Err(SocketError::NotConnected)` if there is no matching connection in the list.
|
||||
fn get_connection(
|
||||
connections: &mut [Connection],
|
||||
peer: VsockAddr,
|
||||
local_port: u32,
|
||||
) -> core::result::Result<(usize, &mut Connection), SocketError> {
|
||||
connections
|
||||
.iter_mut()
|
||||
.enumerate()
|
||||
.find(|(_, connection)| {
|
||||
connection.info.dst == peer && connection.info.src_port == local_port
|
||||
})
|
||||
.ok_or(SocketError::NotConnected)
|
||||
}
|
||||
|
||||
/// Returns the connection from the given list matching the event, if any, and its index.
|
||||
fn get_connection_for_event<'a>(
|
||||
connections: &'a mut [Connection],
|
||||
event: &VsockEvent,
|
||||
local_cid: u64,
|
||||
) -> Option<(usize, &'a mut Connection)> {
|
||||
connections
|
||||
.iter_mut()
|
||||
.enumerate()
|
||||
.find(|(_, connection)| event.matches_connection(&connection.info, local_cid))
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Connection {
|
||||
info: ConnectionInfo,
|
||||
buffer: RingBuffer,
|
||||
/// The peer sent a SHUTDOWN request, but we haven't yet responded with a RST because there is
|
||||
/// still data in the buffer.
|
||||
peer_requested_shutdown: bool,
|
||||
}
|
||||
|
||||
impl Connection {
|
||||
fn new(peer: VsockAddr, local_port: u32) -> Self {
|
||||
let mut info = ConnectionInfo::new(peer, local_port);
|
||||
info.buf_alloc = PER_CONNECTION_BUFFER_CAPACITY.try_into().unwrap();
|
||||
Self {
|
||||
info,
|
||||
buffer: RingBuffer::new(PER_CONNECTION_BUFFER_CAPACITY),
|
||||
peer_requested_shutdown: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct RingBuffer {
|
||||
buffer: Box<[u8]>,
|
||||
/// The number of bytes currently in the buffer.
|
||||
used: usize,
|
||||
/// The index of the first used byte in the buffer.
|
||||
start: usize,
|
||||
}
|
||||
|
||||
impl RingBuffer {
|
||||
pub fn new(capacity: usize) -> Self {
|
||||
Self {
|
||||
// FIXME: if the capacity is excessive, elements move will be executed.
|
||||
buffer: vec![0; capacity].into_boxed_slice(),
|
||||
used: 0,
|
||||
start: 0,
|
||||
}
|
||||
}
|
||||
/// Returns the number of bytes currently used in the buffer.
|
||||
pub fn used(&self) -> usize {
|
||||
self.used
|
||||
}
|
||||
|
||||
/// Returns true iff there are currently no bytes in the buffer.
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.used == 0
|
||||
}
|
||||
|
||||
/// Returns the number of bytes currently free in the buffer.
|
||||
pub fn available(&self) -> usize {
|
||||
self.buffer.len() - self.used
|
||||
}
|
||||
|
||||
/// Adds the given bytes to the buffer if there is enough capacity for them all.
|
||||
///
|
||||
/// Returns true if they were added, or false if they were not.
|
||||
pub fn add(&mut self, bytes: &[u8]) -> bool {
|
||||
if bytes.len() > self.available() {
|
||||
return false;
|
||||
}
|
||||
|
||||
// The index of the first available position in the buffer.
|
||||
let first_available = (self.start + self.used) % self.buffer.len();
|
||||
// The number of bytes to copy from `bytes` to `buffer` between `first_available` and
|
||||
// `buffer.len()`.
|
||||
let copy_length_before_wraparound = min(bytes.len(), self.buffer.len() - first_available);
|
||||
self.buffer[first_available..first_available + copy_length_before_wraparound]
|
||||
.copy_from_slice(&bytes[0..copy_length_before_wraparound]);
|
||||
if let Some(bytes_after_wraparound) = bytes.get(copy_length_before_wraparound..) {
|
||||
self.buffer[0..bytes_after_wraparound.len()].copy_from_slice(bytes_after_wraparound);
|
||||
}
|
||||
self.used += bytes.len();
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
/// Reads and removes as many bytes as possible from the buffer, up to the length of the given
|
||||
/// buffer.
|
||||
pub fn drain(&mut self, out: &mut [u8]) -> usize {
|
||||
let bytes_read = min(self.used, out.len());
|
||||
|
||||
// The number of bytes to copy out between `start` and the end of the buffer.
|
||||
let read_before_wraparound = min(bytes_read, self.buffer.len() - self.start);
|
||||
// The number of bytes to copy out from the beginning of the buffer after wrapping around.
|
||||
let read_after_wraparound = bytes_read
|
||||
.checked_sub(read_before_wraparound)
|
||||
.unwrap_or_default();
|
||||
|
||||
out[0..read_before_wraparound]
|
||||
.copy_from_slice(&self.buffer[self.start..self.start + read_before_wraparound]);
|
||||
out[read_before_wraparound..bytes_read]
|
||||
.copy_from_slice(&self.buffer[0..read_after_wraparound]);
|
||||
|
||||
self.used -= bytes_read;
|
||||
self.start = (self.start + bytes_read) % self.buffer.len();
|
||||
|
||||
bytes_read
|
||||
}
|
||||
}
|
@ -4,7 +4,6 @@
|
||||
use alloc::{collections::BTreeMap, string::String, sync::Arc, vec::Vec};
|
||||
|
||||
use aster_frame::sync::SpinLock;
|
||||
use component::ComponentInitError;
|
||||
use spin::Once;
|
||||
|
||||
use self::device::SocketDevice;
|
||||
@ -14,50 +13,59 @@ pub mod connect;
|
||||
pub mod device;
|
||||
pub mod error;
|
||||
pub mod header;
|
||||
pub mod manager;
|
||||
|
||||
pub static DEVICE_NAME: &str = "Virtio-Vsock";
|
||||
pub trait VsockDeviceIrqHandler = Fn() + Send + Sync + 'static;
|
||||
|
||||
pub fn register_device(name: String, device: Arc<SpinLock<SocketDevice>>) {
|
||||
COMPONENT
|
||||
VSOCK_DEVICE_TABLE
|
||||
.get()
|
||||
.unwrap()
|
||||
.vsock_device_table
|
||||
.lock()
|
||||
.insert(name, device);
|
||||
.insert(name, (Arc::new(SpinLock::new(Vec::new())), device));
|
||||
}
|
||||
|
||||
pub fn get_device(str: &str) -> Option<Arc<SpinLock<SocketDevice>>> {
|
||||
let lock = COMPONENT.get().unwrap().vsock_device_table.lock();
|
||||
let device = lock.get(str)?;
|
||||
let lock = VSOCK_DEVICE_TABLE.get().unwrap().lock();
|
||||
let (_, device) = lock.get(str)?;
|
||||
Some(device.clone())
|
||||
}
|
||||
|
||||
pub fn all_devices() -> Vec<(String, Arc<SpinLock<SocketDevice>>)> {
|
||||
let vsock_devs = COMPONENT.get().unwrap().vsock_device_table.lock();
|
||||
let vsock_devs = VSOCK_DEVICE_TABLE.get().unwrap().lock();
|
||||
vsock_devs
|
||||
.iter()
|
||||
.map(|(name, device)| (name.clone(), device.clone()))
|
||||
.map(|(name, (_, device))| (name.clone(), device.clone()))
|
||||
.collect()
|
||||
}
|
||||
|
||||
static COMPONENT: Once<Component> = Once::new();
|
||||
|
||||
pub fn component_init() -> Result<(), ComponentInitError> {
|
||||
let a = Component::init()?;
|
||||
COMPONENT.call_once(|| a);
|
||||
Ok(())
|
||||
pub fn register_recv_callback(name: &str, callback: impl VsockDeviceIrqHandler) {
|
||||
let lock = VSOCK_DEVICE_TABLE.get().unwrap().lock();
|
||||
let Some((callbacks, _)) = lock.get(name) else {
|
||||
return;
|
||||
};
|
||||
callbacks.lock().push(Arc::new(callback));
|
||||
}
|
||||
|
||||
struct Component {
|
||||
vsock_device_table: SpinLock<BTreeMap<String, Arc<SpinLock<SocketDevice>>>>,
|
||||
}
|
||||
|
||||
impl Component {
|
||||
pub fn init() -> Result<Self, ComponentInitError> {
|
||||
Ok(Self {
|
||||
vsock_device_table: SpinLock::new(BTreeMap::new()),
|
||||
})
|
||||
pub fn handle_recv_irq(name: &str) {
|
||||
let lock = VSOCK_DEVICE_TABLE.get().unwrap().lock();
|
||||
let Some((callbacks, _)) = lock.get(name) else {
|
||||
return;
|
||||
};
|
||||
let callbacks = callbacks.clone();
|
||||
let lock = callbacks.lock();
|
||||
for callback in lock.iter() {
|
||||
callback.call(())
|
||||
}
|
||||
}
|
||||
|
||||
pub fn init() {
|
||||
VSOCK_DEVICE_TABLE.call_once(|| SpinLock::new(BTreeMap::new()));
|
||||
}
|
||||
|
||||
type VsockDeviceIrqHandlerListRef = Arc<SpinLock<Vec<Arc<dyn VsockDeviceIrqHandler>>>>;
|
||||
type VsockDeviceRef = Arc<SpinLock<SocketDevice>>;
|
||||
|
||||
pub static VSOCK_DEVICE_TABLE: Once<
|
||||
SpinLock<BTreeMap<String, (VsockDeviceIrqHandlerListRef, VsockDeviceRef)>>,
|
||||
> = Once::new();
|
||||
|
@ -35,8 +35,8 @@ mod transport;
|
||||
fn virtio_component_init() -> Result<(), ComponentInitError> {
|
||||
// Find all devices and register them to the corresponding crate
|
||||
transport::init();
|
||||
// For vsock cmponent
|
||||
socket::component_init()?;
|
||||
// For vsock table static init
|
||||
socket::init();
|
||||
while let Some(mut transport) = pop_device_transport() {
|
||||
// Reset device
|
||||
transport.set_device_status(DeviceStatus::empty()).unwrap();
|
||||
|
Reference in New Issue
Block a user