Fix format and clippy errors

This commit is contained in:
Anmin Liu
2024-03-28 07:47:05 +00:00
committed by Tate, Hongliang Tian
parent 52f808e315
commit be45f0ee72
10 changed files with 194 additions and 139 deletions

View File

@ -1,9 +1,11 @@
// SPDX-License-Identifier: MPL-2.0
use log::{info, debug};
use alloc::string::ToString;
use aster_virtio::{
self,
device::socket::{header::VsockAddr, manager::VsockConnectionManager, DEVICE_NAME},
};
use component::ComponentInitError;
use aster_virtio::{self, device::socket::{header::VsockAddr, device::SocketDevice, manager::VsockConnectionManager, DEVICE_NAME}};
use log::{debug, info};
pub fn init() {
// print all the input device to make sure input crate will compile
@ -14,8 +16,7 @@ pub fn init() {
// let _ = socket_device_server_test();
}
fn socket_device_client_test() -> Result<(),ComponentInitError> {
fn socket_device_client_test() -> Result<(), ComponentInitError> {
let host_cid = 2;
let guest_cid = 3;
let host_port = 1234;
@ -28,28 +29,33 @@ fn socket_device_client_test() -> Result<(),ComponentInitError> {
let hello_from_host = "Hello from host";
let device = aster_virtio::device::socket::get_device(DEVICE_NAME).unwrap();
assert_eq!(device.lock().guest_cid(),guest_cid);
assert_eq!(device.lock().guest_cid(), guest_cid);
let mut socket = VsockConnectionManager::new(device);
socket.connect(host_address, guest_port).unwrap();
socket.wait_for_event().unwrap(); // wait for connect response
socket.send(host_address,guest_port,hello_from_guest.as_bytes()).unwrap();
debug!("The buffer {:?} is sent, start receiving",hello_from_guest.as_bytes());
socket
.send(host_address, guest_port, hello_from_guest.as_bytes())
.unwrap();
debug!(
"The buffer {:?} is sent, start receiving",
hello_from_guest.as_bytes()
);
socket.wait_for_event().unwrap(); // wait for recv
let mut buffer = [0u8; 64];
let event = socket.recv(host_address, guest_port,&mut buffer).unwrap();
let event = socket.recv(host_address, guest_port, &mut buffer).unwrap();
assert_eq!(
&buffer[0..hello_from_host.len()],
hello_from_host.as_bytes()
);
socket.force_close(host_address,guest_port).unwrap();
socket.force_close(host_address, guest_port).unwrap();
debug!("The final event: {:?}",event);
debug!("The final event: {:?}", event);
Ok(())
}
pub fn socket_device_server_test() -> Result<(),ComponentInitError>{
pub fn socket_device_server_test() -> Result<(), ComponentInitError> {
let host_cid = 2;
let guest_cid = 3;
let host_port = 1234;
@ -62,26 +68,31 @@ pub fn socket_device_server_test() -> Result<(),ComponentInitError>{
let hello_from_host = "Hello from host";
let device = aster_virtio::device::socket::get_device(DEVICE_NAME).unwrap();
assert_eq!(device.lock().guest_cid(),guest_cid);
assert_eq!(device.lock().guest_cid(), guest_cid);
let mut socket = VsockConnectionManager::new(device);
socket.listen(4321);
socket.wait_for_event().unwrap(); // wait for connect request
socket.wait_for_event().unwrap(); // wait for recv
let mut buffer = [0u8; 64];
let event = socket.recv(host_address, guest_port,&mut buffer).unwrap();
let event = socket.recv(host_address, guest_port, &mut buffer).unwrap();
assert_eq!(
&buffer[0..hello_from_host.len()],
hello_from_host.as_bytes()
);
debug!("The buffer {:?} is received, start sending {:?}", &buffer[0..hello_from_host.len()],hello_from_guest.as_bytes());
socket.send(host_address,guest_port,hello_from_guest.as_bytes()).unwrap();
debug!(
"The buffer {:?} is received, start sending {:?}",
&buffer[0..hello_from_host.len()],
hello_from_guest.as_bytes()
);
socket
.send(host_address, guest_port, hello_from_guest.as_bytes())
.unwrap();
socket.shutdown(host_address,guest_port).unwrap();
socket.shutdown(host_address, guest_port).unwrap();
let event = socket.wait_for_event().unwrap(); // wait for rst/shutdown
debug!("The final event: {:?}",event);
debug!("The final event: {:?}", event);
Ok(())
}

View File

@ -1,11 +1,11 @@
//! This module is adapted from network/buffer.rs
// SPDX-License-Identifier: MPL-2.0
use align_ext::AlignExt;
use bytes::BytesMut;
use pod::Pod;
use crate::device::socket::header::VIRTIO_VSOCK_HDR_LEN;
use super::header::VirtioVsockHdr;
use crate::device::socket::header::VIRTIO_VSOCK_HDR_LEN;
/// Buffer for receive packet
#[derive(Debug)]
@ -86,7 +86,7 @@ impl TxBuffer {
/// Buffer for event buffer
#[derive(Debug)]
pub struct EventBuffer{
pub struct EventBuffer {
id: u32,
}
@ -96,4 +96,4 @@ pub struct EventBuffer{
pub enum EventIDType {
#[default]
VIRTIO_VSOCK_EVENT_TRANSPORT_RESET = 0,
}
}

View File

@ -1,13 +1,15 @@
// SPDX-License-Identifier: MPL-2.0
use aster_frame::io_mem::IoMem;
use pod::Pod;
use aster_util::{safe_ptr::SafePtr};
use aster_util::safe_ptr::SafePtr;
use bitflags::bitflags;
use pod::Pod;
use crate::transport::{self, VirtioTransport};
use crate::transport::VirtioTransport;
bitflags!{
bitflags! {
/// Vsock feature bits since v1.2
/// If no feature bit is set, only stream socket type is supported.
/// If no feature bit is set, only stream socket type is supported.
/// If VIRTIO_VSOCK_F_SEQPACKET has been negotiated, the device MAY act as if VIRTIO_VSOCK_F_STREAM has also been negotiated.
pub struct VsockFeatures: u64 {
const VIRTIO_VSOCK_F_STREAM = 1 << 0; // stream socket type is supported.
@ -23,7 +25,7 @@ impl VsockFeatures {
#[derive(Debug, Clone, Copy, Pod)]
#[repr(C)]
pub struct VirtioVsockConfig{
pub struct VirtioVsockConfig {
/// The guest_cid field contains the guests context ID, which uniquely identifies
/// the device for its lifetime. The upper 32 bits of the CID are reserved and zeroed.
///
@ -41,4 +43,4 @@ impl VirtioVsockConfig {
let memory = transport.device_config_memory();
SafePtr::new(memory, 0)
}
}
}

View File

@ -1,7 +1,11 @@
// SPDX-License-Identifier: MPL-2.0
use log::debug;
use super::{header::{VsockAddr, VirtioVsockHdr, VirtioVsockOp}, error::SocketError};
use super::{
error::SocketError,
header::{VirtioVsockHdr, VirtioVsockOp, VsockAddr},
};
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct VsockBufferStatus {
@ -63,7 +67,7 @@ impl VsockEvent {
&& self.destination.port == connection_info.src_port
}
pub fn from_header(header: &VirtioVsockHdr) -> Result<Self,SocketError> {
pub fn from_header(header: &VirtioVsockHdr) -> Result<Self, SocketError> {
let op = header.op()?;
let buffer_status = VsockBufferStatus {
buffer_allocation: header.buf_alloc,
@ -114,7 +118,6 @@ impl VsockEvent {
}
}
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct ConnectionInfo {
pub dst: VsockAddr,
@ -184,4 +187,4 @@ impl ConnectionInfo {
..Default::default()
}
}
}
}

View File

@ -1,14 +1,26 @@
use core::hint::spin_loop;
use core::fmt::Debug;
use alloc::{vec::Vec, boxed::Box, string::ToString, sync::Arc};
use aster_frame::{offset_of, trap::TrapFrame, sync::SpinLock};
use aster_util::{slot_vec::SlotVec, field_ptr};
// SPDX-License-Identifier: MPL-2.0
use alloc::{boxed::Box, string::ToString, sync::Arc, vec::Vec};
use core::{fmt::Debug, hint::spin_loop};
use aster_frame::{offset_of, sync::SpinLock, trap::TrapFrame};
use aster_util::{field_ptr, slot_vec::SlotVec};
use log::debug;
use pod::Pod;
use crate::{queue::{VirtQueue, QueueError}, device::{VirtioDeviceError, socket::{register_device, DEVICE_NAME}}, transport::{VirtioTransport}};
use super::{buffer::RxBuffer, config::{VirtioVsockConfig, VsockFeatures}, connect::{ConnectionInfo, VsockEvent}, header::{VirtioVsockHdr, VirtioVsockOp, VIRTIO_VSOCK_HDR_LEN}, error::SocketError, VsockDeviceIrqHandler};
use super::{
buffer::RxBuffer,
config::{VirtioVsockConfig, VsockFeatures},
connect::{ConnectionInfo, VsockEvent},
error::SocketError,
header::{VirtioVsockHdr, VirtioVsockOp, VIRTIO_VSOCK_HDR_LEN},
VsockDeviceIrqHandler,
};
use crate::{
device::{socket::register_device, VirtioDeviceError},
queue::{QueueError, VirtQueue},
transport::VirtioTransport,
};
const QUEUE_SIZE: u16 = 64;
const QUEUE_RECV: u16 = 0;
@ -30,39 +42,43 @@ pub struct SocketDevice {
rx_buffers: SlotVec<RxBuffer>,
transport: Box<dyn VirtioTransport>,
callbacks: Vec<Box<&'static VsockDeviceIrqHandler>>,
callbacks: Vec<Box<dyn VsockDeviceIrqHandler>>,
}
impl SocketDevice {
pub fn init(mut transport: Box<dyn VirtioTransport>) -> Result<(), VirtioDeviceError> {
let virtio_vsock_config = VirtioVsockConfig::new(transport.as_mut());
debug!("virtio_vsock_config = {:?}", virtio_vsock_config);
let guest_cid =
field_ptr!(&virtio_vsock_config, VirtioVsockConfig, guest_cid_low).read().unwrap() as u64
| (field_ptr!(&virtio_vsock_config, VirtioVsockConfig, guest_cid_high).read().unwrap() as u64) << 32;
let guest_cid = field_ptr!(&virtio_vsock_config, VirtioVsockConfig, guest_cid_low)
.read()
.unwrap() as u64
| (field_ptr!(&virtio_vsock_config, VirtioVsockConfig, guest_cid_high)
.read()
.unwrap() as u64)
<< 32;
let mut recv_queue = VirtQueue::new(QUEUE_RECV,QUEUE_SIZE,transport.as_mut())
let mut recv_queue = VirtQueue::new(QUEUE_RECV, QUEUE_SIZE, transport.as_mut())
.expect("createing recv queue fails");
let send_queue = VirtQueue::new(QUEUE_SEND,QUEUE_SIZE,transport.as_mut())
let send_queue = VirtQueue::new(QUEUE_SEND, QUEUE_SIZE, transport.as_mut())
.expect("creating send queue fails");
let event_queue = VirtQueue::new(QUEUE_EVENT,QUEUE_SIZE,transport.as_mut())
let event_queue = VirtQueue::new(QUEUE_EVENT, QUEUE_SIZE, transport.as_mut())
.expect("creating event queue fails");
// Allocate and add buffers for the RX queue.
let mut rx_buffers = SlotVec::new();
for i in 0..QUEUE_SIZE {
let mut rx_buffer = RxBuffer::new(RX_BUFFER_SIZE);
let token = recv_queue.add_buf(&[], &mut [rx_buffer.buf_mut()])?;
let token = recv_queue.add_buf(&[], &[rx_buffer.buf_mut()])?;
assert_eq!(i, token);
assert_eq!(rx_buffers.put(rx_buffer) as u16, i);
}
if recv_queue.should_notify() {
debug!("notify receive queue");
recv_queue.notify();
}
let mut device = Self{
let mut device = Self {
config: virtio_vsock_config.read().unwrap(),
guest_cid,
send_queue,
@ -74,13 +90,13 @@ impl SocketDevice {
};
// Interrupt handler if vsock device config space changes
fn config_space_change(_: &TrapFrame){
fn config_space_change(_: &TrapFrame) {
debug!("vsock device config space change");
}
// Interrupt handler if vsock device receives some packet.
// TODO: This will be handled by vsock socket layer.
fn handle_vsock_event(_: &TrapFrame){
fn handle_vsock_event(_: &TrapFrame) {
debug!("Packet received. This will be solved by socket layer");
}
@ -88,7 +104,7 @@ impl SocketDevice {
.transport
.register_cfg_callback(Box::new(config_space_change))
.unwrap();
device
device
.transport
.register_queue_callback(QUEUE_RECV, Box::new(handle_vsock_event), false)
.unwrap();
@ -96,7 +112,7 @@ impl SocketDevice {
device.transport.finish_init();
register_device(
super::DEVICE_NAME.to_string(),
super::DEVICE_NAME.to_string(),
Arc::new(SpinLock::new(device)),
);
@ -113,7 +129,7 @@ impl SocketDevice {
/// This returns as soon as the request is sent; you should wait until `poll` returns a
/// [`VsockEventType::Connected`] event indicating that the peer has accepted the connection
/// before sending data.
pub fn connect(&mut self, connection_info: &ConnectionInfo) -> Result<(),SocketError> {
pub fn connect(&mut self, connection_info: &ConnectionInfo) -> Result<(), SocketError> {
let header = VirtioVsockHdr {
op: VirtioVsockOp::Request as u16,
..connection_info.new_header(self.guest_cid)
@ -124,7 +140,7 @@ impl SocketDevice {
}
/// Accepts the given connection from a peer.
pub fn accept(&mut self, connection_info: &ConnectionInfo) -> Result<(),SocketError> {
pub fn accept(&mut self, connection_info: &ConnectionInfo) -> Result<(), SocketError> {
let header = VirtioVsockHdr {
op: VirtioVsockOp::Response as u16,
..connection_info.new_header(self.guest_cid)
@ -133,7 +149,7 @@ impl SocketDevice {
}
/// Requests the peer to send us a credit update for the given connection.
fn request_credit(&mut self, connection_info: &ConnectionInfo) -> Result<(),SocketError> {
fn request_credit(&mut self, connection_info: &ConnectionInfo) -> Result<(), SocketError> {
let header = VirtioVsockHdr {
op: VirtioVsockOp::CreditRequest as u16,
..connection_info.new_header(self.guest_cid)
@ -142,7 +158,7 @@ impl SocketDevice {
}
/// Tells the peer how much buffer space we have to receive data.
pub fn credit_update(&mut self, connection_info: &ConnectionInfo) -> Result<(),SocketError> {
pub fn credit_update(&mut self, connection_info: &ConnectionInfo) -> Result<(), SocketError> {
let header = VirtioVsockHdr {
op: VirtioVsockOp::CreditUpdate as u16,
..connection_info.new_header(self.guest_cid)
@ -155,7 +171,7 @@ impl SocketDevice {
/// This returns as soon as the request is sent; you should wait until `poll` returns a
/// `VsockEventType::Disconnected` event if you want to know that the peer has acknowledged the
/// shutdown.
pub fn shutdown(&mut self, connection_info: &ConnectionInfo) -> Result<(),SocketError> {
pub fn shutdown(&mut self, connection_info: &ConnectionInfo) -> Result<(), SocketError> {
let header = VirtioVsockHdr {
op: VirtioVsockOp::Shutdown as u16,
..connection_info.new_header(self.guest_cid)
@ -164,7 +180,7 @@ impl SocketDevice {
}
/// Forcibly closes the connection without waiting for the peer.
pub fn force_close(&mut self, connection_info: &ConnectionInfo) -> Result<(),SocketError> {
pub fn force_close(&mut self, connection_info: &ConnectionInfo) -> Result<(), SocketError> {
let header = VirtioVsockHdr {
op: VirtioVsockOp::Rst as u16,
..connection_info.new_header(self.guest_cid)
@ -172,15 +188,17 @@ impl SocketDevice {
self.send_packet_to_tx_queue(&header, &[])
}
fn send_packet_to_tx_queue(&mut self, header: &VirtioVsockHdr, buffer: &[u8]) -> Result<(), SocketError> {
fn send_packet_to_tx_queue(
&mut self,
header: &VirtioVsockHdr,
buffer: &[u8],
) -> Result<(), SocketError> {
// let (_token, _len) = self.send_queue.add_notify_wait_pop(
// &[header.as_bytes(), buffer],
// &mut [],
// )?;
let _token = self
.send_queue
.add_buf(&[header.as_bytes(), buffer], &[])?;
let _token = self.send_queue.add_buf(&[header.as_bytes(), buffer], &[])?;
if self.send_queue.should_notify() {
self.send_queue.notify();
@ -189,7 +207,7 @@ impl SocketDevice {
// Wait until the buffer is used
while !self.send_queue.can_pop() {
spin_loop();
}
}
self.send_queue.pop_used()?;
@ -217,13 +235,17 @@ impl SocketDevice {
}
/// Sends the buffer to the destination.
pub fn send(&mut self, buffer: &[u8], connection_info: &mut ConnectionInfo) -> Result<(), SocketError> {
pub fn send(
&mut self,
buffer: &[u8],
connection_info: &mut ConnectionInfo,
) -> Result<(), SocketError> {
self.check_peer_buffer_is_sufficient(connection_info, buffer.len())?;
let len = buffer.len() as u32;
let header = VirtioVsockHdr {
op: VirtioVsockOp::Rw as u16,
len: len,
len,
..connection_info.new_header(self.guest_cid)
};
connection_info.tx_cnt += len;
@ -231,10 +253,12 @@ impl SocketDevice {
}
/// Polls the RX virtqueue for the next event, and calls the given handler function to handle it.
pub fn poll(&mut self, handler: impl FnOnce(VsockEvent, &[u8]) -> Result<Option<VsockEvent>,SocketError>
pub fn poll(
&mut self,
handler: impl FnOnce(VsockEvent, &[u8]) -> Result<Option<VsockEvent>, SocketError>,
) -> Result<Option<VsockEvent>, SocketError> {
// Return None if there is no pending packet.
if !self.recv_queue.can_pop(){
if !self.recv_queue.can_pop() {
return Ok(None);
}
let (token, len) = self.recv_queue.pop_used()?;
@ -250,20 +274,19 @@ impl SocketDevice {
buffer.set_packet_len(RX_BUFFER_SIZE);
let head_result = read_header_and_body(buffer.buf());
let head_result = read_header_and_body(&buffer.buf());
let Ok((header,body)) = head_result else {
let Ok((header, body)) = head_result else {
let ret = match head_result {
Err(e) => Err(e),
_ => Ok(None) //FIXME: this clause is never reached.
_ => Ok(None), //FIXME: this clause is never reached.
};
self.add_rx_buffer(buffer, token)?;
return ret;
};
debug!("Received packet {:?}. Op {:?}", header, header.op());
debug!("body is {:?}",body);
debug!("body is {:?}", body);
let result = VsockEvent::from_header(&header).and_then(|event| handler(event, body));
@ -271,13 +294,12 @@ impl SocketDevice {
self.add_rx_buffer(buffer, token)?;
result
}
/// Add a used rx buffer to recv queue,@index is only to check the correctness
fn add_rx_buffer(&mut self, mut rx_buffer: RxBuffer, index: u16) -> Result<(), SocketError> {
let token = self.recv_queue.add_buf(&[], &mut [rx_buffer.buf_mut()])?;
assert_eq!(index,token);
let token = self.recv_queue.add_buf(&[], &[rx_buffer.buf_mut()])?;
assert_eq!(index, token);
assert!(self.rx_buffers.put_at(token as usize, rx_buffer).is_none());
if self.recv_queue.should_notify() {
self.recv_queue.notify();
@ -290,11 +312,9 @@ impl SocketDevice {
let device_features = VsockFeatures::from_bits_truncate(features);
let supported_features = VsockFeatures::support_features();
let vsock_features = device_features & supported_features;
debug!("features negotiated: {:?}",vsock_features);
debug!("features negotiated: {:?}", vsock_features);
vsock_features.bits()
}
}
impl Debug for SocketDevice {
@ -310,7 +330,7 @@ impl Debug for SocketDevice {
}
}
fn read_header_and_body(buffer: &[u8]) -> Result<(VirtioVsockHdr, &[u8]),SocketError> {
fn read_header_and_body(buffer: &[u8]) -> Result<(VirtioVsockHdr, &[u8]), SocketError> {
// Shouldn't panic, because we know `RX_BUFFER_SIZE > size_of::<VirtioVsockHdr>()`.
let header = VirtioVsockHdr::from_bytes(&buffer[..VIRTIO_VSOCK_HDR_LEN]);
let body_length = header.len() as usize;
@ -324,4 +344,4 @@ fn read_header_and_body(buffer: &[u8]) -> Result<(VirtioVsockHdr, &[u8]),SocketE
.get(VIRTIO_VSOCK_HDR_LEN..data_end)
.ok_or(SocketError::BufferTooShort)?;
Ok((header, data))
}
}

View File

@ -1,10 +1,10 @@
// SPDX-License-Identifier: MPL-2.0
//! This file comes from virtio-drivers project
//! This module contains the error from the VirtIO socket driver.
use core::{fmt, result};
use smoltcp::socket::dhcpv4::Socket;
use crate::queue::QueueError;
/// The error type of VirtIO socket driver.
@ -43,7 +43,7 @@ pub enum SocketError {
}
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum SocketQueueError {
pub enum SocketQueueError {
InvalidArgs,
BufferTooSmall,
NotReady,
@ -63,7 +63,6 @@ impl From<QueueError> for SocketQueueError {
}
}
impl From<QueueError> for SocketError {
fn from(value: QueueError) -> Self {
Self::QueueError(SocketQueueError::from(value))

View File

@ -1,5 +1,8 @@
use pod::Pod;
// SPDX-License-Identifier: MPL-2.0
use bitflags::bitflags;
use pod::Pod;
use super::error::{self, SocketError};
pub const VIRTIO_VSOCK_HDR_LEN: usize = core::mem::size_of::<VirtioVsockHdr>();
@ -15,9 +18,9 @@ pub struct VsockAddr {
/// VirtioVsock header precedes the payload in each packet.
// #[repr(packed)]
#[repr(C,packed)]
#[repr(C, packed)]
#[derive(Debug, Clone, Copy, Pod)]
pub struct VirtioVsockHdr{
pub struct VirtioVsockHdr {
pub src_cid: u64,
pub dst_cid: u64,
pub src_port: u32,
@ -25,7 +28,7 @@ pub struct VirtioVsockHdr{
pub len: u32,
pub socket_type: u16,
pub op: u16, //TOASK: why mark Pod and can I mark OpType Pod and replace u16 into OpType.
pub op: u16,
pub flags: u32,
/// Total receive buffer space for this socket. This includes both free and in-use buffers.
pub buf_alloc: u32,
@ -50,19 +53,22 @@ impl Default for VirtioVsockHdr {
}
}
impl VirtioVsockHdr {
/// Returns the length of the data.
pub fn len(&self) -> u32 {
self.len
}
pub fn is_empty(&self) -> bool {
self.len == 0
}
pub fn op(&self) -> error::Result<VirtioVsockOp> {
self.op.try_into()
}
pub fn source(&self) -> VsockAddr {
VsockAddr{
VsockAddr {
cid: self.src_cid,
port: self.src_port,
}
@ -76,7 +82,7 @@ impl VirtioVsockHdr {
}
pub fn check_data_is_empty(&self) -> error::Result<()> {
if self.len() == 0 {
if self.is_empty() {
Ok(())
} else {
Err(SocketError::UnexpectedDataInPacket)
@ -87,7 +93,7 @@ impl VirtioVsockHdr {
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u16)]
#[allow(non_camel_case_types)]
pub enum VirtioVsockOp{
pub enum VirtioVsockOp {
#[default]
Invalid = 0,
@ -123,7 +129,7 @@ impl TryFrom<u16> for VirtioVsockOp {
_ => return Err(SocketError::UnknownOperation(v)),
};
Ok(op)
}
}
}
bitflags! {
@ -136,7 +142,7 @@ bitflags! {
/// The peer will not send any more data.
const VIRTIO_VSOCK_SHUTDOWN_SEND = 1 << 1;
/// The peer will not send or receive any more data.
const VIRTIO_VSOCK_SHUTDOWN_ALL = Self::VIRTIO_VSOCK_SHUTDOWN_RCV.bits | Self::VIRTIO_VSOCK_SHUTDOWN_SEND.bits;
const VIRTIO_VSOCK_SHUTDOWN_ALL = Self::VIRTIO_VSOCK_SHUTDOWN_RCV.bits | Self::VIRTIO_VSOCK_SHUTDOWN_SEND.bits;
}
}
@ -148,4 +154,4 @@ pub enum VsockType {
Stream = 1,
/// seqpacket socket type introduced in virtio-v1.2.
SeqPacket = 2,
}
}

View File

@ -1,14 +1,17 @@
// SPDX-License-Identifier: MPL-2.0
use alloc::{boxed::Box, sync::Arc, vec, vec::Vec};
use core::{cmp::min, hint::spin_loop};
use alloc::{vec::Vec, boxed::Box, sync::Arc};
use aster_frame::sync::SpinLock;
use log::debug;
use super::{
connect::{ConnectionInfo, DisconnectReason, VsockEvent, VsockEventType},
device::SocketDevice,
header::VsockAddr,
};
use crate::device::socket::error::SocketError;
use super::{device::SocketDevice, connect::{ConnectionInfo, VsockEvent, VsockEventType, DisconnectReason}, header::VsockAddr};
const PER_CONNECTION_BUFFER_CAPACITY: usize = 1024;
/// TODO: A higher level interface for VirtIO socket (vsock) devices.
@ -72,11 +75,11 @@ impl VsockConnectionManager {
/// This returns as soon as the request is sent; you should wait until `poll` returns a
/// `VsockEventType::Connected` event indicating that the peer has accepted the connection
/// before sending data.
pub fn connect(&mut self, destination: VsockAddr, src_port: u32) -> Result<(),SocketError> {
pub fn connect(&mut self, destination: VsockAddr, src_port: u32) -> Result<(), SocketError> {
if self.connections.iter().any(|connection| {
connection.info.dst == destination && connection.info.src_port == src_port
}) {
return Err(SocketError::ConnectionExists.into());
return Err(SocketError::ConnectionExists);
}
let new_connection = Connection::new(destination, src_port);
@ -88,14 +91,19 @@ impl VsockConnectionManager {
}
/// Sends the buffer to the destination.
pub fn send(&mut self, destination: VsockAddr, src_port: u32, buffer: &[u8]) -> Result<(),SocketError> {
pub fn send(
&mut self,
destination: VsockAddr,
src_port: u32,
buffer: &[u8],
) -> Result<(), SocketError> {
let (_, connection) = get_connection(&mut self.connections, destination, src_port)?;
self.driver.lock().send(buffer, &mut connection.info)
}
/// Polls the vsock device to receive data or other updates.
pub fn poll(&mut self) -> Result<Option<VsockEvent>,SocketError> {
pub fn poll(&mut self) -> Result<Option<VsockEvent>, SocketError> {
let guest_cid = self.driver.lock().guest_cid();
let connections = &mut self.connections;
@ -181,8 +189,13 @@ impl VsockConnectionManager {
}
/// Reads data received from the given connection.
pub fn recv(&mut self, peer: VsockAddr, src_port: u32, buffer: &mut [u8]) -> Result<usize,SocketError> {
debug!("connections is {:?}",self.connections);
pub fn recv(
&mut self,
peer: VsockAddr,
src_port: u32,
buffer: &mut [u8],
) -> Result<usize, SocketError> {
debug!("connections is {:?}", self.connections);
let (connection_index, connection) = get_connection(&mut self.connections, peer, src_port)?;
// Copy from ring buffer
@ -197,11 +210,11 @@ impl VsockConnectionManager {
self.connections.swap_remove(connection_index);
}
Ok(bytes_read)
Ok(bytes_read)
}
/// Blocks until we get some event from the vsock device.
pub fn wait_for_event(&mut self) -> Result<VsockEvent,SocketError> {
pub fn wait_for_event(&mut self) -> Result<VsockEvent, SocketError> {
loop {
if let Some(event) = self.poll()? {
return Ok(event);
@ -216,14 +229,18 @@ impl VsockConnectionManager {
/// This returns as soon as the request is sent; you should wait until `poll` returns a
/// `VsockEventType::Disconnected` event if you want to know that the peer has acknowledged the
/// shutdown.
pub fn shutdown(&mut self, destination: VsockAddr, src_port: u32) -> Result<(),SocketError> {
pub fn shutdown(&mut self, destination: VsockAddr, src_port: u32) -> Result<(), SocketError> {
let (_, connection) = get_connection(&mut self.connections, destination, src_port)?;
self.driver.lock().shutdown(&connection.info)
}
/// Forcibly closes the connection without waiting for the peer.
pub fn force_close(&mut self, destination: VsockAddr, src_port: u32) -> Result<(),SocketError> {
pub fn force_close(
&mut self,
destination: VsockAddr,
src_port: u32,
) -> Result<(), SocketError> {
let (index, connection) = get_connection(&mut self.connections, destination, src_port)?;
self.driver.lock().force_close(&connection.info)?;
@ -263,7 +280,6 @@ fn get_connection_for_event<'a>(
.find(|(_, connection)| event.matches_connection(&connection.info, local_cid))
}
#[derive(Debug)]
struct Connection {
info: ConnectionInfo,
@ -296,12 +312,9 @@ struct RingBuffer {
impl RingBuffer {
pub fn new(capacity: usize) -> Self {
// TODO: can be optimized.
let mut temp = Vec::with_capacity(capacity);
temp.resize(capacity,0);
Self {
// FIXME: if the capacity is excessive, elements move will be executed.
buffer: temp.into_boxed_slice(),
buffer: vec![0; capacity].into_boxed_slice(),
used: 0,
start: 0,
}
@ -366,4 +379,4 @@ impl RingBuffer {
bytes_read
}
}
}

View File

@ -1,23 +1,23 @@
//! This mod is modified from virtio-drivers project.
// SPDX-License-Identifier: MPL-2.0
//! This mod is modified from virtio-drivers project.
use alloc::{collections::BTreeMap, string::String, sync::Arc, vec::Vec};
use alloc::{sync::Arc, collections::BTreeMap, string::String, vec::Vec};
use component::{ComponentInitError, init_component};
use aster_frame::sync::SpinLock;
use smoltcp::socket::dhcpv4::Socket;
use component::ComponentInitError;
use spin::Once;
use core::fmt::Debug;
use self::device::SocketDevice;
pub mod buffer;
pub mod config;
pub mod device;
pub mod header;
pub mod connect;
pub mod device;
pub mod error;
pub mod header;
pub mod manager;
pub static DEVICE_NAME: &str = "Virtio-Vsock";
pub type VsockDeviceIrqHandler = dyn Fn() + Send + Sync;
pub trait VsockDeviceIrqHandler = Fn() + Send + Sync + 'static;
pub fn register_device(name: String, device: Arc<SpinLock<SocketDevice>>) {
COMPONENT
@ -30,9 +30,7 @@ pub fn register_device(name: String, device: Arc<SpinLock<SocketDevice>>) {
pub fn get_device(str: &str) -> Option<Arc<SpinLock<SocketDevice>>> {
let lock = COMPONENT.get().unwrap().vsock_device_table.lock();
let Some(device) = lock.get(str) else {
return None;
};
let device = lock.get(str)?;
Some(device.clone())
}
@ -46,14 +44,12 @@ pub fn all_devices() -> Vec<(String, Arc<SpinLock<SocketDevice>>)> {
static COMPONENT: Once<Component> = Once::new();
pub fn component_init() -> Result<(), ComponentInitError>{
pub fn component_init() -> Result<(), ComponentInitError> {
let a = Component::init()?;
COMPONENT.call_once(|| a);
Ok(())
}
struct Component {
vsock_device_table: SpinLock<BTreeMap<String, Arc<SpinLock<SocketDevice>>>>,
}
@ -64,4 +60,4 @@ impl Component {
vsock_device_table: SpinLock::new(BTreeMap::new()),
})
}
}
}

View File

@ -4,6 +4,7 @@
#![no_std]
#![deny(unsafe_code)]
#![allow(dead_code)]
#![feature(trait_alias)]
#![feature(fn_traits)]
extern crate alloc;
@ -13,8 +14,12 @@ use alloc::boxed::Box;
use bitflags::bitflags;
use component::{init_component, ComponentInitError};
use device::{
block::device::BlockDevice, console::device::ConsoleDevice, input::device::InputDevice,
network::device::NetworkDevice, socket::{device::SocketDevice, self}, VirtioDeviceType,
block::device::BlockDevice,
console::device::ConsoleDevice,
input::device::InputDevice,
network::device::NetworkDevice,
socket::{self, device::SocketDevice},
VirtioDeviceType,
};
use log::{error, warn};
use transport::{mmio::VIRTIO_MMIO_DRIVER, pci::VIRTIO_PCI_DRIVER, DeviceStatus};