Fix format and clippy errors

This commit is contained in:
Anmin Liu
2024-03-28 07:47:05 +00:00
committed by Tate, Hongliang Tian
parent 52f808e315
commit be45f0ee72
10 changed files with 194 additions and 139 deletions

View File

@ -1,9 +1,11 @@
// SPDX-License-Identifier: MPL-2.0 // SPDX-License-Identifier: MPL-2.0
use log::{info, debug}; use aster_virtio::{
use alloc::string::ToString; self,
device::socket::{header::VsockAddr, manager::VsockConnectionManager, DEVICE_NAME},
};
use component::ComponentInitError; use component::ComponentInitError;
use aster_virtio::{self, device::socket::{header::VsockAddr, device::SocketDevice, manager::VsockConnectionManager, DEVICE_NAME}}; use log::{debug, info};
pub fn init() { pub fn init() {
// print all the input device to make sure input crate will compile // print all the input device to make sure input crate will compile
@ -14,8 +16,7 @@ pub fn init() {
// let _ = socket_device_server_test(); // let _ = socket_device_server_test();
} }
fn socket_device_client_test() -> Result<(), ComponentInitError> {
fn socket_device_client_test() -> Result<(),ComponentInitError> {
let host_cid = 2; let host_cid = 2;
let guest_cid = 3; let guest_cid = 3;
let host_port = 1234; let host_port = 1234;
@ -28,28 +29,33 @@ fn socket_device_client_test() -> Result<(),ComponentInitError> {
let hello_from_host = "Hello from host"; let hello_from_host = "Hello from host";
let device = aster_virtio::device::socket::get_device(DEVICE_NAME).unwrap(); let device = aster_virtio::device::socket::get_device(DEVICE_NAME).unwrap();
assert_eq!(device.lock().guest_cid(),guest_cid); assert_eq!(device.lock().guest_cid(), guest_cid);
let mut socket = VsockConnectionManager::new(device); let mut socket = VsockConnectionManager::new(device);
socket.connect(host_address, guest_port).unwrap(); socket.connect(host_address, guest_port).unwrap();
socket.wait_for_event().unwrap(); // wait for connect response socket.wait_for_event().unwrap(); // wait for connect response
socket.send(host_address,guest_port,hello_from_guest.as_bytes()).unwrap(); socket
debug!("The buffer {:?} is sent, start receiving",hello_from_guest.as_bytes()); .send(host_address, guest_port, hello_from_guest.as_bytes())
.unwrap();
debug!(
"The buffer {:?} is sent, start receiving",
hello_from_guest.as_bytes()
);
socket.wait_for_event().unwrap(); // wait for recv socket.wait_for_event().unwrap(); // wait for recv
let mut buffer = [0u8; 64]; let mut buffer = [0u8; 64];
let event = socket.recv(host_address, guest_port,&mut buffer).unwrap(); let event = socket.recv(host_address, guest_port, &mut buffer).unwrap();
assert_eq!( assert_eq!(
&buffer[0..hello_from_host.len()], &buffer[0..hello_from_host.len()],
hello_from_host.as_bytes() hello_from_host.as_bytes()
); );
socket.force_close(host_address,guest_port).unwrap(); socket.force_close(host_address, guest_port).unwrap();
debug!("The final event: {:?}",event); debug!("The final event: {:?}", event);
Ok(()) Ok(())
} }
pub fn socket_device_server_test() -> Result<(),ComponentInitError>{ pub fn socket_device_server_test() -> Result<(), ComponentInitError> {
let host_cid = 2; let host_cid = 2;
let guest_cid = 3; let guest_cid = 3;
let host_port = 1234; let host_port = 1234;
@ -62,26 +68,31 @@ pub fn socket_device_server_test() -> Result<(),ComponentInitError>{
let hello_from_host = "Hello from host"; let hello_from_host = "Hello from host";
let device = aster_virtio::device::socket::get_device(DEVICE_NAME).unwrap(); let device = aster_virtio::device::socket::get_device(DEVICE_NAME).unwrap();
assert_eq!(device.lock().guest_cid(),guest_cid); assert_eq!(device.lock().guest_cid(), guest_cid);
let mut socket = VsockConnectionManager::new(device); let mut socket = VsockConnectionManager::new(device);
socket.listen(4321); socket.listen(4321);
socket.wait_for_event().unwrap(); // wait for connect request socket.wait_for_event().unwrap(); // wait for connect request
socket.wait_for_event().unwrap(); // wait for recv socket.wait_for_event().unwrap(); // wait for recv
let mut buffer = [0u8; 64]; let mut buffer = [0u8; 64];
let event = socket.recv(host_address, guest_port,&mut buffer).unwrap(); let event = socket.recv(host_address, guest_port, &mut buffer).unwrap();
assert_eq!( assert_eq!(
&buffer[0..hello_from_host.len()], &buffer[0..hello_from_host.len()],
hello_from_host.as_bytes() hello_from_host.as_bytes()
); );
debug!("The buffer {:?} is received, start sending {:?}", &buffer[0..hello_from_host.len()],hello_from_guest.as_bytes()); debug!(
socket.send(host_address,guest_port,hello_from_guest.as_bytes()).unwrap(); "The buffer {:?} is received, start sending {:?}",
&buffer[0..hello_from_host.len()],
hello_from_guest.as_bytes()
);
socket
.send(host_address, guest_port, hello_from_guest.as_bytes())
.unwrap();
socket.shutdown(host_address,guest_port).unwrap(); socket.shutdown(host_address, guest_port).unwrap();
let event = socket.wait_for_event().unwrap(); // wait for rst/shutdown let event = socket.wait_for_event().unwrap(); // wait for rst/shutdown
debug!("The final event: {:?}",event); debug!("The final event: {:?}", event);
Ok(()) Ok(())
} }

View File

@ -1,11 +1,11 @@
//! This module is adapted from network/buffer.rs // SPDX-License-Identifier: MPL-2.0
use align_ext::AlignExt; use align_ext::AlignExt;
use bytes::BytesMut; use bytes::BytesMut;
use pod::Pod; use pod::Pod;
use crate::device::socket::header::VIRTIO_VSOCK_HDR_LEN;
use super::header::VirtioVsockHdr; use super::header::VirtioVsockHdr;
use crate::device::socket::header::VIRTIO_VSOCK_HDR_LEN;
/// Buffer for receive packet /// Buffer for receive packet
#[derive(Debug)] #[derive(Debug)]
@ -86,7 +86,7 @@ impl TxBuffer {
/// Buffer for event buffer /// Buffer for event buffer
#[derive(Debug)] #[derive(Debug)]
pub struct EventBuffer{ pub struct EventBuffer {
id: u32, id: u32,
} }
@ -96,4 +96,4 @@ pub struct EventBuffer{
pub enum EventIDType { pub enum EventIDType {
#[default] #[default]
VIRTIO_VSOCK_EVENT_TRANSPORT_RESET = 0, VIRTIO_VSOCK_EVENT_TRANSPORT_RESET = 0,
} }

View File

@ -1,13 +1,15 @@
// SPDX-License-Identifier: MPL-2.0
use aster_frame::io_mem::IoMem; use aster_frame::io_mem::IoMem;
use pod::Pod; use aster_util::safe_ptr::SafePtr;
use aster_util::{safe_ptr::SafePtr};
use bitflags::bitflags; use bitflags::bitflags;
use pod::Pod;
use crate::transport::{self, VirtioTransport}; use crate::transport::VirtioTransport;
bitflags!{ bitflags! {
/// Vsock feature bits since v1.2 /// Vsock feature bits since v1.2
/// If no feature bit is set, only stream socket type is supported. /// If no feature bit is set, only stream socket type is supported.
/// If VIRTIO_VSOCK_F_SEQPACKET has been negotiated, the device MAY act as if VIRTIO_VSOCK_F_STREAM has also been negotiated. /// If VIRTIO_VSOCK_F_SEQPACKET has been negotiated, the device MAY act as if VIRTIO_VSOCK_F_STREAM has also been negotiated.
pub struct VsockFeatures: u64 { pub struct VsockFeatures: u64 {
const VIRTIO_VSOCK_F_STREAM = 1 << 0; // stream socket type is supported. const VIRTIO_VSOCK_F_STREAM = 1 << 0; // stream socket type is supported.
@ -23,7 +25,7 @@ impl VsockFeatures {
#[derive(Debug, Clone, Copy, Pod)] #[derive(Debug, Clone, Copy, Pod)]
#[repr(C)] #[repr(C)]
pub struct VirtioVsockConfig{ pub struct VirtioVsockConfig {
/// The guest_cid field contains the guests context ID, which uniquely identifies /// The guest_cid field contains the guests context ID, which uniquely identifies
/// the device for its lifetime. The upper 32 bits of the CID are reserved and zeroed. /// the device for its lifetime. The upper 32 bits of the CID are reserved and zeroed.
/// ///
@ -41,4 +43,4 @@ impl VirtioVsockConfig {
let memory = transport.device_config_memory(); let memory = transport.device_config_memory();
SafePtr::new(memory, 0) SafePtr::new(memory, 0)
} }
} }

View File

@ -1,7 +1,11 @@
// SPDX-License-Identifier: MPL-2.0
use log::debug; use log::debug;
use super::{header::{VsockAddr, VirtioVsockHdr, VirtioVsockOp}, error::SocketError}; use super::{
error::SocketError,
header::{VirtioVsockHdr, VirtioVsockOp, VsockAddr},
};
#[derive(Clone, Debug, Eq, PartialEq)] #[derive(Clone, Debug, Eq, PartialEq)]
pub struct VsockBufferStatus { pub struct VsockBufferStatus {
@ -63,7 +67,7 @@ impl VsockEvent {
&& self.destination.port == connection_info.src_port && self.destination.port == connection_info.src_port
} }
pub fn from_header(header: &VirtioVsockHdr) -> Result<Self,SocketError> { pub fn from_header(header: &VirtioVsockHdr) -> Result<Self, SocketError> {
let op = header.op()?; let op = header.op()?;
let buffer_status = VsockBufferStatus { let buffer_status = VsockBufferStatus {
buffer_allocation: header.buf_alloc, buffer_allocation: header.buf_alloc,
@ -114,7 +118,6 @@ impl VsockEvent {
} }
} }
#[derive(Clone, Debug, Default, PartialEq, Eq)] #[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct ConnectionInfo { pub struct ConnectionInfo {
pub dst: VsockAddr, pub dst: VsockAddr,
@ -184,4 +187,4 @@ impl ConnectionInfo {
..Default::default() ..Default::default()
} }
} }
} }

View File

@ -1,14 +1,26 @@
use core::hint::spin_loop; // SPDX-License-Identifier: MPL-2.0
use core::fmt::Debug;
use alloc::{vec::Vec, boxed::Box, string::ToString, sync::Arc}; use alloc::{boxed::Box, string::ToString, sync::Arc, vec::Vec};
use aster_frame::{offset_of, trap::TrapFrame, sync::SpinLock}; use core::{fmt::Debug, hint::spin_loop};
use aster_util::{slot_vec::SlotVec, field_ptr};
use aster_frame::{offset_of, sync::SpinLock, trap::TrapFrame};
use aster_util::{field_ptr, slot_vec::SlotVec};
use log::debug; use log::debug;
use pod::Pod; use pod::Pod;
use crate::{queue::{VirtQueue, QueueError}, device::{VirtioDeviceError, socket::{register_device, DEVICE_NAME}}, transport::{VirtioTransport}}; use super::{
buffer::RxBuffer,
use super::{buffer::RxBuffer, config::{VirtioVsockConfig, VsockFeatures}, connect::{ConnectionInfo, VsockEvent}, header::{VirtioVsockHdr, VirtioVsockOp, VIRTIO_VSOCK_HDR_LEN}, error::SocketError, VsockDeviceIrqHandler}; config::{VirtioVsockConfig, VsockFeatures},
connect::{ConnectionInfo, VsockEvent},
error::SocketError,
header::{VirtioVsockHdr, VirtioVsockOp, VIRTIO_VSOCK_HDR_LEN},
VsockDeviceIrqHandler,
};
use crate::{
device::{socket::register_device, VirtioDeviceError},
queue::{QueueError, VirtQueue},
transport::VirtioTransport,
};
const QUEUE_SIZE: u16 = 64; const QUEUE_SIZE: u16 = 64;
const QUEUE_RECV: u16 = 0; const QUEUE_RECV: u16 = 0;
@ -30,39 +42,43 @@ pub struct SocketDevice {
rx_buffers: SlotVec<RxBuffer>, rx_buffers: SlotVec<RxBuffer>,
transport: Box<dyn VirtioTransport>, transport: Box<dyn VirtioTransport>,
callbacks: Vec<Box<&'static VsockDeviceIrqHandler>>, callbacks: Vec<Box<dyn VsockDeviceIrqHandler>>,
} }
impl SocketDevice { impl SocketDevice {
pub fn init(mut transport: Box<dyn VirtioTransport>) -> Result<(), VirtioDeviceError> { pub fn init(mut transport: Box<dyn VirtioTransport>) -> Result<(), VirtioDeviceError> {
let virtio_vsock_config = VirtioVsockConfig::new(transport.as_mut()); let virtio_vsock_config = VirtioVsockConfig::new(transport.as_mut());
debug!("virtio_vsock_config = {:?}", virtio_vsock_config); debug!("virtio_vsock_config = {:?}", virtio_vsock_config);
let guest_cid = let guest_cid = field_ptr!(&virtio_vsock_config, VirtioVsockConfig, guest_cid_low)
field_ptr!(&virtio_vsock_config, VirtioVsockConfig, guest_cid_low).read().unwrap() as u64 .read()
| (field_ptr!(&virtio_vsock_config, VirtioVsockConfig, guest_cid_high).read().unwrap() as u64) << 32; .unwrap() as u64
| (field_ptr!(&virtio_vsock_config, VirtioVsockConfig, guest_cid_high)
.read()
.unwrap() as u64)
<< 32;
let mut recv_queue = VirtQueue::new(QUEUE_RECV,QUEUE_SIZE,transport.as_mut()) let mut recv_queue = VirtQueue::new(QUEUE_RECV, QUEUE_SIZE, transport.as_mut())
.expect("createing recv queue fails"); .expect("createing recv queue fails");
let send_queue = VirtQueue::new(QUEUE_SEND,QUEUE_SIZE,transport.as_mut()) let send_queue = VirtQueue::new(QUEUE_SEND, QUEUE_SIZE, transport.as_mut())
.expect("creating send queue fails"); .expect("creating send queue fails");
let event_queue = VirtQueue::new(QUEUE_EVENT,QUEUE_SIZE,transport.as_mut()) let event_queue = VirtQueue::new(QUEUE_EVENT, QUEUE_SIZE, transport.as_mut())
.expect("creating event queue fails"); .expect("creating event queue fails");
// Allocate and add buffers for the RX queue. // Allocate and add buffers for the RX queue.
let mut rx_buffers = SlotVec::new(); let mut rx_buffers = SlotVec::new();
for i in 0..QUEUE_SIZE { for i in 0..QUEUE_SIZE {
let mut rx_buffer = RxBuffer::new(RX_BUFFER_SIZE); let mut rx_buffer = RxBuffer::new(RX_BUFFER_SIZE);
let token = recv_queue.add_buf(&[], &mut [rx_buffer.buf_mut()])?; let token = recv_queue.add_buf(&[], &[rx_buffer.buf_mut()])?;
assert_eq!(i, token); assert_eq!(i, token);
assert_eq!(rx_buffers.put(rx_buffer) as u16, i); assert_eq!(rx_buffers.put(rx_buffer) as u16, i);
} }
if recv_queue.should_notify() { if recv_queue.should_notify() {
debug!("notify receive queue"); debug!("notify receive queue");
recv_queue.notify(); recv_queue.notify();
} }
let mut device = Self{ let mut device = Self {
config: virtio_vsock_config.read().unwrap(), config: virtio_vsock_config.read().unwrap(),
guest_cid, guest_cid,
send_queue, send_queue,
@ -74,13 +90,13 @@ impl SocketDevice {
}; };
// Interrupt handler if vsock device config space changes // Interrupt handler if vsock device config space changes
fn config_space_change(_: &TrapFrame){ fn config_space_change(_: &TrapFrame) {
debug!("vsock device config space change"); debug!("vsock device config space change");
} }
// Interrupt handler if vsock device receives some packet. // Interrupt handler if vsock device receives some packet.
// TODO: This will be handled by vsock socket layer. // TODO: This will be handled by vsock socket layer.
fn handle_vsock_event(_: &TrapFrame){ fn handle_vsock_event(_: &TrapFrame) {
debug!("Packet received. This will be solved by socket layer"); debug!("Packet received. This will be solved by socket layer");
} }
@ -88,7 +104,7 @@ impl SocketDevice {
.transport .transport
.register_cfg_callback(Box::new(config_space_change)) .register_cfg_callback(Box::new(config_space_change))
.unwrap(); .unwrap();
device device
.transport .transport
.register_queue_callback(QUEUE_RECV, Box::new(handle_vsock_event), false) .register_queue_callback(QUEUE_RECV, Box::new(handle_vsock_event), false)
.unwrap(); .unwrap();
@ -96,7 +112,7 @@ impl SocketDevice {
device.transport.finish_init(); device.transport.finish_init();
register_device( register_device(
super::DEVICE_NAME.to_string(), super::DEVICE_NAME.to_string(),
Arc::new(SpinLock::new(device)), Arc::new(SpinLock::new(device)),
); );
@ -113,7 +129,7 @@ impl SocketDevice {
/// This returns as soon as the request is sent; you should wait until `poll` returns a /// This returns as soon as the request is sent; you should wait until `poll` returns a
/// [`VsockEventType::Connected`] event indicating that the peer has accepted the connection /// [`VsockEventType::Connected`] event indicating that the peer has accepted the connection
/// before sending data. /// before sending data.
pub fn connect(&mut self, connection_info: &ConnectionInfo) -> Result<(),SocketError> { pub fn connect(&mut self, connection_info: &ConnectionInfo) -> Result<(), SocketError> {
let header = VirtioVsockHdr { let header = VirtioVsockHdr {
op: VirtioVsockOp::Request as u16, op: VirtioVsockOp::Request as u16,
..connection_info.new_header(self.guest_cid) ..connection_info.new_header(self.guest_cid)
@ -124,7 +140,7 @@ impl SocketDevice {
} }
/// Accepts the given connection from a peer. /// Accepts the given connection from a peer.
pub fn accept(&mut self, connection_info: &ConnectionInfo) -> Result<(),SocketError> { pub fn accept(&mut self, connection_info: &ConnectionInfo) -> Result<(), SocketError> {
let header = VirtioVsockHdr { let header = VirtioVsockHdr {
op: VirtioVsockOp::Response as u16, op: VirtioVsockOp::Response as u16,
..connection_info.new_header(self.guest_cid) ..connection_info.new_header(self.guest_cid)
@ -133,7 +149,7 @@ impl SocketDevice {
} }
/// Requests the peer to send us a credit update for the given connection. /// Requests the peer to send us a credit update for the given connection.
fn request_credit(&mut self, connection_info: &ConnectionInfo) -> Result<(),SocketError> { fn request_credit(&mut self, connection_info: &ConnectionInfo) -> Result<(), SocketError> {
let header = VirtioVsockHdr { let header = VirtioVsockHdr {
op: VirtioVsockOp::CreditRequest as u16, op: VirtioVsockOp::CreditRequest as u16,
..connection_info.new_header(self.guest_cid) ..connection_info.new_header(self.guest_cid)
@ -142,7 +158,7 @@ impl SocketDevice {
} }
/// Tells the peer how much buffer space we have to receive data. /// Tells the peer how much buffer space we have to receive data.
pub fn credit_update(&mut self, connection_info: &ConnectionInfo) -> Result<(),SocketError> { pub fn credit_update(&mut self, connection_info: &ConnectionInfo) -> Result<(), SocketError> {
let header = VirtioVsockHdr { let header = VirtioVsockHdr {
op: VirtioVsockOp::CreditUpdate as u16, op: VirtioVsockOp::CreditUpdate as u16,
..connection_info.new_header(self.guest_cid) ..connection_info.new_header(self.guest_cid)
@ -155,7 +171,7 @@ impl SocketDevice {
/// This returns as soon as the request is sent; you should wait until `poll` returns a /// This returns as soon as the request is sent; you should wait until `poll` returns a
/// `VsockEventType::Disconnected` event if you want to know that the peer has acknowledged the /// `VsockEventType::Disconnected` event if you want to know that the peer has acknowledged the
/// shutdown. /// shutdown.
pub fn shutdown(&mut self, connection_info: &ConnectionInfo) -> Result<(),SocketError> { pub fn shutdown(&mut self, connection_info: &ConnectionInfo) -> Result<(), SocketError> {
let header = VirtioVsockHdr { let header = VirtioVsockHdr {
op: VirtioVsockOp::Shutdown as u16, op: VirtioVsockOp::Shutdown as u16,
..connection_info.new_header(self.guest_cid) ..connection_info.new_header(self.guest_cid)
@ -164,7 +180,7 @@ impl SocketDevice {
} }
/// Forcibly closes the connection without waiting for the peer. /// Forcibly closes the connection without waiting for the peer.
pub fn force_close(&mut self, connection_info: &ConnectionInfo) -> Result<(),SocketError> { pub fn force_close(&mut self, connection_info: &ConnectionInfo) -> Result<(), SocketError> {
let header = VirtioVsockHdr { let header = VirtioVsockHdr {
op: VirtioVsockOp::Rst as u16, op: VirtioVsockOp::Rst as u16,
..connection_info.new_header(self.guest_cid) ..connection_info.new_header(self.guest_cid)
@ -172,15 +188,17 @@ impl SocketDevice {
self.send_packet_to_tx_queue(&header, &[]) self.send_packet_to_tx_queue(&header, &[])
} }
fn send_packet_to_tx_queue(&mut self, header: &VirtioVsockHdr, buffer: &[u8]) -> Result<(), SocketError> { fn send_packet_to_tx_queue(
&mut self,
header: &VirtioVsockHdr,
buffer: &[u8],
) -> Result<(), SocketError> {
// let (_token, _len) = self.send_queue.add_notify_wait_pop( // let (_token, _len) = self.send_queue.add_notify_wait_pop(
// &[header.as_bytes(), buffer], // &[header.as_bytes(), buffer],
// &mut [], // &mut [],
// )?; // )?;
let _token = self let _token = self.send_queue.add_buf(&[header.as_bytes(), buffer], &[])?;
.send_queue
.add_buf(&[header.as_bytes(), buffer], &[])?;
if self.send_queue.should_notify() { if self.send_queue.should_notify() {
self.send_queue.notify(); self.send_queue.notify();
@ -189,7 +207,7 @@ impl SocketDevice {
// Wait until the buffer is used // Wait until the buffer is used
while !self.send_queue.can_pop() { while !self.send_queue.can_pop() {
spin_loop(); spin_loop();
} }
self.send_queue.pop_used()?; self.send_queue.pop_used()?;
@ -217,13 +235,17 @@ impl SocketDevice {
} }
/// Sends the buffer to the destination. /// Sends the buffer to the destination.
pub fn send(&mut self, buffer: &[u8], connection_info: &mut ConnectionInfo) -> Result<(), SocketError> { pub fn send(
&mut self,
buffer: &[u8],
connection_info: &mut ConnectionInfo,
) -> Result<(), SocketError> {
self.check_peer_buffer_is_sufficient(connection_info, buffer.len())?; self.check_peer_buffer_is_sufficient(connection_info, buffer.len())?;
let len = buffer.len() as u32; let len = buffer.len() as u32;
let header = VirtioVsockHdr { let header = VirtioVsockHdr {
op: VirtioVsockOp::Rw as u16, op: VirtioVsockOp::Rw as u16,
len: len, len,
..connection_info.new_header(self.guest_cid) ..connection_info.new_header(self.guest_cid)
}; };
connection_info.tx_cnt += len; connection_info.tx_cnt += len;
@ -231,10 +253,12 @@ impl SocketDevice {
} }
/// Polls the RX virtqueue for the next event, and calls the given handler function to handle it. /// Polls the RX virtqueue for the next event, and calls the given handler function to handle it.
pub fn poll(&mut self, handler: impl FnOnce(VsockEvent, &[u8]) -> Result<Option<VsockEvent>,SocketError> pub fn poll(
&mut self,
handler: impl FnOnce(VsockEvent, &[u8]) -> Result<Option<VsockEvent>, SocketError>,
) -> Result<Option<VsockEvent>, SocketError> { ) -> Result<Option<VsockEvent>, SocketError> {
// Return None if there is no pending packet. // Return None if there is no pending packet.
if !self.recv_queue.can_pop(){ if !self.recv_queue.can_pop() {
return Ok(None); return Ok(None);
} }
let (token, len) = self.recv_queue.pop_used()?; let (token, len) = self.recv_queue.pop_used()?;
@ -250,20 +274,19 @@ impl SocketDevice {
buffer.set_packet_len(RX_BUFFER_SIZE); buffer.set_packet_len(RX_BUFFER_SIZE);
let head_result = read_header_and_body(buffer.buf());
let head_result = read_header_and_body(&buffer.buf()); let Ok((header, body)) = head_result else {
let Ok((header,body)) = head_result else {
let ret = match head_result { let ret = match head_result {
Err(e) => Err(e), Err(e) => Err(e),
_ => Ok(None) //FIXME: this clause is never reached. _ => Ok(None), //FIXME: this clause is never reached.
}; };
self.add_rx_buffer(buffer, token)?; self.add_rx_buffer(buffer, token)?;
return ret; return ret;
}; };
debug!("Received packet {:?}. Op {:?}", header, header.op()); debug!("Received packet {:?}. Op {:?}", header, header.op());
debug!("body is {:?}",body); debug!("body is {:?}", body);
let result = VsockEvent::from_header(&header).and_then(|event| handler(event, body)); let result = VsockEvent::from_header(&header).and_then(|event| handler(event, body));
@ -271,13 +294,12 @@ impl SocketDevice {
self.add_rx_buffer(buffer, token)?; self.add_rx_buffer(buffer, token)?;
result result
} }
/// Add a used rx buffer to recv queue,@index is only to check the correctness /// Add a used rx buffer to recv queue,@index is only to check the correctness
fn add_rx_buffer(&mut self, mut rx_buffer: RxBuffer, index: u16) -> Result<(), SocketError> { fn add_rx_buffer(&mut self, mut rx_buffer: RxBuffer, index: u16) -> Result<(), SocketError> {
let token = self.recv_queue.add_buf(&[], &mut [rx_buffer.buf_mut()])?; let token = self.recv_queue.add_buf(&[], &[rx_buffer.buf_mut()])?;
assert_eq!(index,token); assert_eq!(index, token);
assert!(self.rx_buffers.put_at(token as usize, rx_buffer).is_none()); assert!(self.rx_buffers.put_at(token as usize, rx_buffer).is_none());
if self.recv_queue.should_notify() { if self.recv_queue.should_notify() {
self.recv_queue.notify(); self.recv_queue.notify();
@ -290,11 +312,9 @@ impl SocketDevice {
let device_features = VsockFeatures::from_bits_truncate(features); let device_features = VsockFeatures::from_bits_truncate(features);
let supported_features = VsockFeatures::support_features(); let supported_features = VsockFeatures::support_features();
let vsock_features = device_features & supported_features; let vsock_features = device_features & supported_features;
debug!("features negotiated: {:?}",vsock_features); debug!("features negotiated: {:?}", vsock_features);
vsock_features.bits() vsock_features.bits()
} }
} }
impl Debug for SocketDevice { impl Debug for SocketDevice {
@ -310,7 +330,7 @@ impl Debug for SocketDevice {
} }
} }
fn read_header_and_body(buffer: &[u8]) -> Result<(VirtioVsockHdr, &[u8]),SocketError> { fn read_header_and_body(buffer: &[u8]) -> Result<(VirtioVsockHdr, &[u8]), SocketError> {
// Shouldn't panic, because we know `RX_BUFFER_SIZE > size_of::<VirtioVsockHdr>()`. // Shouldn't panic, because we know `RX_BUFFER_SIZE > size_of::<VirtioVsockHdr>()`.
let header = VirtioVsockHdr::from_bytes(&buffer[..VIRTIO_VSOCK_HDR_LEN]); let header = VirtioVsockHdr::from_bytes(&buffer[..VIRTIO_VSOCK_HDR_LEN]);
let body_length = header.len() as usize; let body_length = header.len() as usize;
@ -324,4 +344,4 @@ fn read_header_and_body(buffer: &[u8]) -> Result<(VirtioVsockHdr, &[u8]),SocketE
.get(VIRTIO_VSOCK_HDR_LEN..data_end) .get(VIRTIO_VSOCK_HDR_LEN..data_end)
.ok_or(SocketError::BufferTooShort)?; .ok_or(SocketError::BufferTooShort)?;
Ok((header, data)) Ok((header, data))
} }

View File

@ -1,10 +1,10 @@
// SPDX-License-Identifier: MPL-2.0
//! This file comes from virtio-drivers project //! This file comes from virtio-drivers project
//! This module contains the error from the VirtIO socket driver. //! This module contains the error from the VirtIO socket driver.
use core::{fmt, result}; use core::{fmt, result};
use smoltcp::socket::dhcpv4::Socket;
use crate::queue::QueueError; use crate::queue::QueueError;
/// The error type of VirtIO socket driver. /// The error type of VirtIO socket driver.
@ -43,7 +43,7 @@ pub enum SocketError {
} }
#[derive(Clone, Copy, Debug, Eq, PartialEq)] #[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum SocketQueueError { pub enum SocketQueueError {
InvalidArgs, InvalidArgs,
BufferTooSmall, BufferTooSmall,
NotReady, NotReady,
@ -63,7 +63,6 @@ impl From<QueueError> for SocketQueueError {
} }
} }
impl From<QueueError> for SocketError { impl From<QueueError> for SocketError {
fn from(value: QueueError) -> Self { fn from(value: QueueError) -> Self {
Self::QueueError(SocketQueueError::from(value)) Self::QueueError(SocketQueueError::from(value))

View File

@ -1,5 +1,8 @@
use pod::Pod; // SPDX-License-Identifier: MPL-2.0
use bitflags::bitflags; use bitflags::bitflags;
use pod::Pod;
use super::error::{self, SocketError}; use super::error::{self, SocketError};
pub const VIRTIO_VSOCK_HDR_LEN: usize = core::mem::size_of::<VirtioVsockHdr>(); pub const VIRTIO_VSOCK_HDR_LEN: usize = core::mem::size_of::<VirtioVsockHdr>();
@ -15,9 +18,9 @@ pub struct VsockAddr {
/// VirtioVsock header precedes the payload in each packet. /// VirtioVsock header precedes the payload in each packet.
// #[repr(packed)] // #[repr(packed)]
#[repr(C,packed)] #[repr(C, packed)]
#[derive(Debug, Clone, Copy, Pod)] #[derive(Debug, Clone, Copy, Pod)]
pub struct VirtioVsockHdr{ pub struct VirtioVsockHdr {
pub src_cid: u64, pub src_cid: u64,
pub dst_cid: u64, pub dst_cid: u64,
pub src_port: u32, pub src_port: u32,
@ -25,7 +28,7 @@ pub struct VirtioVsockHdr{
pub len: u32, pub len: u32,
pub socket_type: u16, pub socket_type: u16,
pub op: u16, //TOASK: why mark Pod and can I mark OpType Pod and replace u16 into OpType. pub op: u16,
pub flags: u32, pub flags: u32,
/// Total receive buffer space for this socket. This includes both free and in-use buffers. /// Total receive buffer space for this socket. This includes both free and in-use buffers.
pub buf_alloc: u32, pub buf_alloc: u32,
@ -50,19 +53,22 @@ impl Default for VirtioVsockHdr {
} }
} }
impl VirtioVsockHdr { impl VirtioVsockHdr {
/// Returns the length of the data. /// Returns the length of the data.
pub fn len(&self) -> u32 { pub fn len(&self) -> u32 {
self.len self.len
} }
pub fn is_empty(&self) -> bool {
self.len == 0
}
pub fn op(&self) -> error::Result<VirtioVsockOp> { pub fn op(&self) -> error::Result<VirtioVsockOp> {
self.op.try_into() self.op.try_into()
} }
pub fn source(&self) -> VsockAddr { pub fn source(&self) -> VsockAddr {
VsockAddr{ VsockAddr {
cid: self.src_cid, cid: self.src_cid,
port: self.src_port, port: self.src_port,
} }
@ -76,7 +82,7 @@ impl VirtioVsockHdr {
} }
pub fn check_data_is_empty(&self) -> error::Result<()> { pub fn check_data_is_empty(&self) -> error::Result<()> {
if self.len() == 0 { if self.is_empty() {
Ok(()) Ok(())
} else { } else {
Err(SocketError::UnexpectedDataInPacket) Err(SocketError::UnexpectedDataInPacket)
@ -87,7 +93,7 @@ impl VirtioVsockHdr {
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)] #[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u16)] #[repr(u16)]
#[allow(non_camel_case_types)] #[allow(non_camel_case_types)]
pub enum VirtioVsockOp{ pub enum VirtioVsockOp {
#[default] #[default]
Invalid = 0, Invalid = 0,
@ -123,7 +129,7 @@ impl TryFrom<u16> for VirtioVsockOp {
_ => return Err(SocketError::UnknownOperation(v)), _ => return Err(SocketError::UnknownOperation(v)),
}; };
Ok(op) Ok(op)
} }
} }
bitflags! { bitflags! {
@ -136,7 +142,7 @@ bitflags! {
/// The peer will not send any more data. /// The peer will not send any more data.
const VIRTIO_VSOCK_SHUTDOWN_SEND = 1 << 1; const VIRTIO_VSOCK_SHUTDOWN_SEND = 1 << 1;
/// The peer will not send or receive any more data. /// The peer will not send or receive any more data.
const VIRTIO_VSOCK_SHUTDOWN_ALL = Self::VIRTIO_VSOCK_SHUTDOWN_RCV.bits | Self::VIRTIO_VSOCK_SHUTDOWN_SEND.bits; const VIRTIO_VSOCK_SHUTDOWN_ALL = Self::VIRTIO_VSOCK_SHUTDOWN_RCV.bits | Self::VIRTIO_VSOCK_SHUTDOWN_SEND.bits;
} }
} }
@ -148,4 +154,4 @@ pub enum VsockType {
Stream = 1, Stream = 1,
/// seqpacket socket type introduced in virtio-v1.2. /// seqpacket socket type introduced in virtio-v1.2.
SeqPacket = 2, SeqPacket = 2,
} }

View File

@ -1,14 +1,17 @@
// SPDX-License-Identifier: MPL-2.0
use alloc::{boxed::Box, sync::Arc, vec, vec::Vec};
use core::{cmp::min, hint::spin_loop}; use core::{cmp::min, hint::spin_loop};
use alloc::{vec::Vec, boxed::Box, sync::Arc};
use aster_frame::sync::SpinLock; use aster_frame::sync::SpinLock;
use log::debug; use log::debug;
use super::{
connect::{ConnectionInfo, DisconnectReason, VsockEvent, VsockEventType},
device::SocketDevice,
header::VsockAddr,
};
use crate::device::socket::error::SocketError; use crate::device::socket::error::SocketError;
use super::{device::SocketDevice, connect::{ConnectionInfo, VsockEvent, VsockEventType, DisconnectReason}, header::VsockAddr};
const PER_CONNECTION_BUFFER_CAPACITY: usize = 1024; const PER_CONNECTION_BUFFER_CAPACITY: usize = 1024;
/// TODO: A higher level interface for VirtIO socket (vsock) devices. /// TODO: A higher level interface for VirtIO socket (vsock) devices.
@ -72,11 +75,11 @@ impl VsockConnectionManager {
/// This returns as soon as the request is sent; you should wait until `poll` returns a /// This returns as soon as the request is sent; you should wait until `poll` returns a
/// `VsockEventType::Connected` event indicating that the peer has accepted the connection /// `VsockEventType::Connected` event indicating that the peer has accepted the connection
/// before sending data. /// before sending data.
pub fn connect(&mut self, destination: VsockAddr, src_port: u32) -> Result<(),SocketError> { pub fn connect(&mut self, destination: VsockAddr, src_port: u32) -> Result<(), SocketError> {
if self.connections.iter().any(|connection| { if self.connections.iter().any(|connection| {
connection.info.dst == destination && connection.info.src_port == src_port connection.info.dst == destination && connection.info.src_port == src_port
}) { }) {
return Err(SocketError::ConnectionExists.into()); return Err(SocketError::ConnectionExists);
} }
let new_connection = Connection::new(destination, src_port); let new_connection = Connection::new(destination, src_port);
@ -88,14 +91,19 @@ impl VsockConnectionManager {
} }
/// Sends the buffer to the destination. /// Sends the buffer to the destination.
pub fn send(&mut self, destination: VsockAddr, src_port: u32, buffer: &[u8]) -> Result<(),SocketError> { pub fn send(
&mut self,
destination: VsockAddr,
src_port: u32,
buffer: &[u8],
) -> Result<(), SocketError> {
let (_, connection) = get_connection(&mut self.connections, destination, src_port)?; let (_, connection) = get_connection(&mut self.connections, destination, src_port)?;
self.driver.lock().send(buffer, &mut connection.info) self.driver.lock().send(buffer, &mut connection.info)
} }
/// Polls the vsock device to receive data or other updates. /// Polls the vsock device to receive data or other updates.
pub fn poll(&mut self) -> Result<Option<VsockEvent>,SocketError> { pub fn poll(&mut self) -> Result<Option<VsockEvent>, SocketError> {
let guest_cid = self.driver.lock().guest_cid(); let guest_cid = self.driver.lock().guest_cid();
let connections = &mut self.connections; let connections = &mut self.connections;
@ -181,8 +189,13 @@ impl VsockConnectionManager {
} }
/// Reads data received from the given connection. /// Reads data received from the given connection.
pub fn recv(&mut self, peer: VsockAddr, src_port: u32, buffer: &mut [u8]) -> Result<usize,SocketError> { pub fn recv(
debug!("connections is {:?}",self.connections); &mut self,
peer: VsockAddr,
src_port: u32,
buffer: &mut [u8],
) -> Result<usize, SocketError> {
debug!("connections is {:?}", self.connections);
let (connection_index, connection) = get_connection(&mut self.connections, peer, src_port)?; let (connection_index, connection) = get_connection(&mut self.connections, peer, src_port)?;
// Copy from ring buffer // Copy from ring buffer
@ -197,11 +210,11 @@ impl VsockConnectionManager {
self.connections.swap_remove(connection_index); self.connections.swap_remove(connection_index);
} }
Ok(bytes_read) Ok(bytes_read)
} }
/// Blocks until we get some event from the vsock device. /// Blocks until we get some event from the vsock device.
pub fn wait_for_event(&mut self) -> Result<VsockEvent,SocketError> { pub fn wait_for_event(&mut self) -> Result<VsockEvent, SocketError> {
loop { loop {
if let Some(event) = self.poll()? { if let Some(event) = self.poll()? {
return Ok(event); return Ok(event);
@ -216,14 +229,18 @@ impl VsockConnectionManager {
/// This returns as soon as the request is sent; you should wait until `poll` returns a /// This returns as soon as the request is sent; you should wait until `poll` returns a
/// `VsockEventType::Disconnected` event if you want to know that the peer has acknowledged the /// `VsockEventType::Disconnected` event if you want to know that the peer has acknowledged the
/// shutdown. /// shutdown.
pub fn shutdown(&mut self, destination: VsockAddr, src_port: u32) -> Result<(),SocketError> { pub fn shutdown(&mut self, destination: VsockAddr, src_port: u32) -> Result<(), SocketError> {
let (_, connection) = get_connection(&mut self.connections, destination, src_port)?; let (_, connection) = get_connection(&mut self.connections, destination, src_port)?;
self.driver.lock().shutdown(&connection.info) self.driver.lock().shutdown(&connection.info)
} }
/// Forcibly closes the connection without waiting for the peer. /// Forcibly closes the connection without waiting for the peer.
pub fn force_close(&mut self, destination: VsockAddr, src_port: u32) -> Result<(),SocketError> { pub fn force_close(
&mut self,
destination: VsockAddr,
src_port: u32,
) -> Result<(), SocketError> {
let (index, connection) = get_connection(&mut self.connections, destination, src_port)?; let (index, connection) = get_connection(&mut self.connections, destination, src_port)?;
self.driver.lock().force_close(&connection.info)?; self.driver.lock().force_close(&connection.info)?;
@ -263,7 +280,6 @@ fn get_connection_for_event<'a>(
.find(|(_, connection)| event.matches_connection(&connection.info, local_cid)) .find(|(_, connection)| event.matches_connection(&connection.info, local_cid))
} }
#[derive(Debug)] #[derive(Debug)]
struct Connection { struct Connection {
info: ConnectionInfo, info: ConnectionInfo,
@ -296,12 +312,9 @@ struct RingBuffer {
impl RingBuffer { impl RingBuffer {
pub fn new(capacity: usize) -> Self { pub fn new(capacity: usize) -> Self {
// TODO: can be optimized.
let mut temp = Vec::with_capacity(capacity);
temp.resize(capacity,0);
Self { Self {
// FIXME: if the capacity is excessive, elements move will be executed. // FIXME: if the capacity is excessive, elements move will be executed.
buffer: temp.into_boxed_slice(), buffer: vec![0; capacity].into_boxed_slice(),
used: 0, used: 0,
start: 0, start: 0,
} }
@ -366,4 +379,4 @@ impl RingBuffer {
bytes_read bytes_read
} }
} }

View File

@ -1,23 +1,23 @@
//! This mod is modified from virtio-drivers project. // SPDX-License-Identifier: MPL-2.0
//! This mod is modified from virtio-drivers project.
use alloc::{collections::BTreeMap, string::String, sync::Arc, vec::Vec};
use alloc::{sync::Arc, collections::BTreeMap, string::String, vec::Vec};
use component::{ComponentInitError, init_component};
use aster_frame::sync::SpinLock; use aster_frame::sync::SpinLock;
use smoltcp::socket::dhcpv4::Socket; use component::ComponentInitError;
use spin::Once; use spin::Once;
use core::fmt::Debug;
use self::device::SocketDevice; use self::device::SocketDevice;
pub mod buffer; pub mod buffer;
pub mod config; pub mod config;
pub mod device;
pub mod header;
pub mod connect; pub mod connect;
pub mod device;
pub mod error; pub mod error;
pub mod header;
pub mod manager; pub mod manager;
pub static DEVICE_NAME: &str = "Virtio-Vsock"; pub static DEVICE_NAME: &str = "Virtio-Vsock";
pub type VsockDeviceIrqHandler = dyn Fn() + Send + Sync; pub trait VsockDeviceIrqHandler = Fn() + Send + Sync + 'static;
pub fn register_device(name: String, device: Arc<SpinLock<SocketDevice>>) { pub fn register_device(name: String, device: Arc<SpinLock<SocketDevice>>) {
COMPONENT COMPONENT
@ -30,9 +30,7 @@ pub fn register_device(name: String, device: Arc<SpinLock<SocketDevice>>) {
pub fn get_device(str: &str) -> Option<Arc<SpinLock<SocketDevice>>> { pub fn get_device(str: &str) -> Option<Arc<SpinLock<SocketDevice>>> {
let lock = COMPONENT.get().unwrap().vsock_device_table.lock(); let lock = COMPONENT.get().unwrap().vsock_device_table.lock();
let Some(device) = lock.get(str) else { let device = lock.get(str)?;
return None;
};
Some(device.clone()) Some(device.clone())
} }
@ -46,14 +44,12 @@ pub fn all_devices() -> Vec<(String, Arc<SpinLock<SocketDevice>>)> {
static COMPONENT: Once<Component> = Once::new(); static COMPONENT: Once<Component> = Once::new();
pub fn component_init() -> Result<(), ComponentInitError> {
pub fn component_init() -> Result<(), ComponentInitError>{
let a = Component::init()?; let a = Component::init()?;
COMPONENT.call_once(|| a); COMPONENT.call_once(|| a);
Ok(()) Ok(())
} }
struct Component { struct Component {
vsock_device_table: SpinLock<BTreeMap<String, Arc<SpinLock<SocketDevice>>>>, vsock_device_table: SpinLock<BTreeMap<String, Arc<SpinLock<SocketDevice>>>>,
} }
@ -64,4 +60,4 @@ impl Component {
vsock_device_table: SpinLock::new(BTreeMap::new()), vsock_device_table: SpinLock::new(BTreeMap::new()),
}) })
} }
} }

View File

@ -4,6 +4,7 @@
#![no_std] #![no_std]
#![deny(unsafe_code)] #![deny(unsafe_code)]
#![allow(dead_code)] #![allow(dead_code)]
#![feature(trait_alias)]
#![feature(fn_traits)] #![feature(fn_traits)]
extern crate alloc; extern crate alloc;
@ -13,8 +14,12 @@ use alloc::boxed::Box;
use bitflags::bitflags; use bitflags::bitflags;
use component::{init_component, ComponentInitError}; use component::{init_component, ComponentInitError};
use device::{ use device::{
block::device::BlockDevice, console::device::ConsoleDevice, input::device::InputDevice, block::device::BlockDevice,
network::device::NetworkDevice, socket::{device::SocketDevice, self}, VirtioDeviceType, console::device::ConsoleDevice,
input::device::InputDevice,
network::device::NetworkDevice,
socket::{self, device::SocketDevice},
VirtioDeviceType,
}; };
use log::{error, warn}; use log::{error, warn};
use transport::{mmio::VIRTIO_MMIO_DRIVER, pci::VIRTIO_PCI_DRIVER, DeviceStatus}; use transport::{mmio::VIRTIO_MMIO_DRIVER, pci::VIRTIO_PCI_DRIVER, DeviceStatus};