diff --git a/kernel/aster-nix/src/driver/mod.rs b/kernel/aster-nix/src/driver/mod.rs index 69f2ef093..186f981ad 100644 --- a/kernel/aster-nix/src/driver/mod.rs +++ b/kernel/aster-nix/src/driver/mod.rs @@ -1,9 +1,11 @@ // SPDX-License-Identifier: MPL-2.0 -use log::{info, debug}; -use alloc::string::ToString; +use aster_virtio::{ + self, + device::socket::{header::VsockAddr, manager::VsockConnectionManager, DEVICE_NAME}, +}; use component::ComponentInitError; -use aster_virtio::{self, device::socket::{header::VsockAddr, device::SocketDevice, manager::VsockConnectionManager, DEVICE_NAME}}; +use log::{debug, info}; pub fn init() { // print all the input device to make sure input crate will compile @@ -14,8 +16,7 @@ pub fn init() { // let _ = socket_device_server_test(); } - -fn socket_device_client_test() -> Result<(),ComponentInitError> { +fn socket_device_client_test() -> Result<(), ComponentInitError> { let host_cid = 2; let guest_cid = 3; let host_port = 1234; @@ -28,28 +29,33 @@ fn socket_device_client_test() -> Result<(),ComponentInitError> { let hello_from_host = "Hello from host"; let device = aster_virtio::device::socket::get_device(DEVICE_NAME).unwrap(); - assert_eq!(device.lock().guest_cid(),guest_cid); + assert_eq!(device.lock().guest_cid(), guest_cid); let mut socket = VsockConnectionManager::new(device); socket.connect(host_address, guest_port).unwrap(); socket.wait_for_event().unwrap(); // wait for connect response - socket.send(host_address,guest_port,hello_from_guest.as_bytes()).unwrap(); - debug!("The buffer {:?} is sent, start receiving",hello_from_guest.as_bytes()); + socket + .send(host_address, guest_port, hello_from_guest.as_bytes()) + .unwrap(); + debug!( + "The buffer {:?} is sent, start receiving", + hello_from_guest.as_bytes() + ); socket.wait_for_event().unwrap(); // wait for recv let mut buffer = [0u8; 64]; - let event = socket.recv(host_address, guest_port,&mut buffer).unwrap(); + let event = socket.recv(host_address, guest_port, &mut buffer).unwrap(); assert_eq!( &buffer[0..hello_from_host.len()], hello_from_host.as_bytes() ); - socket.force_close(host_address,guest_port).unwrap(); + socket.force_close(host_address, guest_port).unwrap(); - debug!("The final event: {:?}",event); + debug!("The final event: {:?}", event); Ok(()) } -pub fn socket_device_server_test() -> Result<(),ComponentInitError>{ +pub fn socket_device_server_test() -> Result<(), ComponentInitError> { let host_cid = 2; let guest_cid = 3; let host_port = 1234; @@ -62,26 +68,31 @@ pub fn socket_device_server_test() -> Result<(),ComponentInitError>{ let hello_from_host = "Hello from host"; let device = aster_virtio::device::socket::get_device(DEVICE_NAME).unwrap(); - assert_eq!(device.lock().guest_cid(),guest_cid); + assert_eq!(device.lock().guest_cid(), guest_cid); let mut socket = VsockConnectionManager::new(device); socket.listen(4321); socket.wait_for_event().unwrap(); // wait for connect request socket.wait_for_event().unwrap(); // wait for recv let mut buffer = [0u8; 64]; - let event = socket.recv(host_address, guest_port,&mut buffer).unwrap(); + let event = socket.recv(host_address, guest_port, &mut buffer).unwrap(); assert_eq!( &buffer[0..hello_from_host.len()], hello_from_host.as_bytes() ); - debug!("The buffer {:?} is received, start sending {:?}", &buffer[0..hello_from_host.len()],hello_from_guest.as_bytes()); - socket.send(host_address,guest_port,hello_from_guest.as_bytes()).unwrap(); + debug!( + "The buffer {:?} is received, start sending {:?}", + &buffer[0..hello_from_host.len()], + hello_from_guest.as_bytes() + ); + socket + .send(host_address, guest_port, hello_from_guest.as_bytes()) + .unwrap(); - socket.shutdown(host_address,guest_port).unwrap(); + socket.shutdown(host_address, guest_port).unwrap(); let event = socket.wait_for_event().unwrap(); // wait for rst/shutdown - debug!("The final event: {:?}",event); + debug!("The final event: {:?}", event); Ok(()) - } diff --git a/kernel/comps/virtio/src/device/socket/buffer.rs b/kernel/comps/virtio/src/device/socket/buffer.rs index 20f400414..e3041616e 100644 --- a/kernel/comps/virtio/src/device/socket/buffer.rs +++ b/kernel/comps/virtio/src/device/socket/buffer.rs @@ -1,11 +1,11 @@ -//! This module is adapted from network/buffer.rs +// SPDX-License-Identifier: MPL-2.0 + use align_ext::AlignExt; use bytes::BytesMut; use pod::Pod; -use crate::device::socket::header::VIRTIO_VSOCK_HDR_LEN; - use super::header::VirtioVsockHdr; +use crate::device::socket::header::VIRTIO_VSOCK_HDR_LEN; /// Buffer for receive packet #[derive(Debug)] @@ -86,7 +86,7 @@ impl TxBuffer { /// Buffer for event buffer #[derive(Debug)] -pub struct EventBuffer{ +pub struct EventBuffer { id: u32, } @@ -96,4 +96,4 @@ pub struct EventBuffer{ pub enum EventIDType { #[default] VIRTIO_VSOCK_EVENT_TRANSPORT_RESET = 0, -} \ No newline at end of file +} diff --git a/kernel/comps/virtio/src/device/socket/config.rs b/kernel/comps/virtio/src/device/socket/config.rs index af625d4bb..0172c2dfc 100644 --- a/kernel/comps/virtio/src/device/socket/config.rs +++ b/kernel/comps/virtio/src/device/socket/config.rs @@ -1,13 +1,15 @@ +// SPDX-License-Identifier: MPL-2.0 + use aster_frame::io_mem::IoMem; -use pod::Pod; -use aster_util::{safe_ptr::SafePtr}; +use aster_util::safe_ptr::SafePtr; use bitflags::bitflags; +use pod::Pod; -use crate::transport::{self, VirtioTransport}; +use crate::transport::VirtioTransport; -bitflags!{ +bitflags! { /// Vsock feature bits since v1.2 - /// If no feature bit is set, only stream socket type is supported. + /// If no feature bit is set, only stream socket type is supported. /// If VIRTIO_VSOCK_F_SEQPACKET has been negotiated, the device MAY act as if VIRTIO_VSOCK_F_STREAM has also been negotiated. pub struct VsockFeatures: u64 { const VIRTIO_VSOCK_F_STREAM = 1 << 0; // stream socket type is supported. @@ -23,7 +25,7 @@ impl VsockFeatures { #[derive(Debug, Clone, Copy, Pod)] #[repr(C)] -pub struct VirtioVsockConfig{ +pub struct VirtioVsockConfig { /// The guest_cid field contains the guest’s context ID, which uniquely identifies /// the device for its lifetime. The upper 32 bits of the CID are reserved and zeroed. /// @@ -41,4 +43,4 @@ impl VirtioVsockConfig { let memory = transport.device_config_memory(); SafePtr::new(memory, 0) } -} \ No newline at end of file +} diff --git a/kernel/comps/virtio/src/device/socket/connect.rs b/kernel/comps/virtio/src/device/socket/connect.rs index 411197c4c..874cba6e3 100644 --- a/kernel/comps/virtio/src/device/socket/connect.rs +++ b/kernel/comps/virtio/src/device/socket/connect.rs @@ -1,7 +1,11 @@ +// SPDX-License-Identifier: MPL-2.0 + use log::debug; -use super::{header::{VsockAddr, VirtioVsockHdr, VirtioVsockOp}, error::SocketError}; - +use super::{ + error::SocketError, + header::{VirtioVsockHdr, VirtioVsockOp, VsockAddr}, +}; #[derive(Clone, Debug, Eq, PartialEq)] pub struct VsockBufferStatus { @@ -63,7 +67,7 @@ impl VsockEvent { && self.destination.port == connection_info.src_port } - pub fn from_header(header: &VirtioVsockHdr) -> Result { + pub fn from_header(header: &VirtioVsockHdr) -> Result { let op = header.op()?; let buffer_status = VsockBufferStatus { buffer_allocation: header.buf_alloc, @@ -114,7 +118,6 @@ impl VsockEvent { } } - #[derive(Clone, Debug, Default, PartialEq, Eq)] pub struct ConnectionInfo { pub dst: VsockAddr, @@ -184,4 +187,4 @@ impl ConnectionInfo { ..Default::default() } } -} \ No newline at end of file +} diff --git a/kernel/comps/virtio/src/device/socket/device.rs b/kernel/comps/virtio/src/device/socket/device.rs index 39b91ad67..bd34ddf40 100644 --- a/kernel/comps/virtio/src/device/socket/device.rs +++ b/kernel/comps/virtio/src/device/socket/device.rs @@ -1,14 +1,26 @@ -use core::hint::spin_loop; -use core::fmt::Debug; -use alloc::{vec::Vec, boxed::Box, string::ToString, sync::Arc}; -use aster_frame::{offset_of, trap::TrapFrame, sync::SpinLock}; -use aster_util::{slot_vec::SlotVec, field_ptr}; +// SPDX-License-Identifier: MPL-2.0 + +use alloc::{boxed::Box, string::ToString, sync::Arc, vec::Vec}; +use core::{fmt::Debug, hint::spin_loop}; + +use aster_frame::{offset_of, sync::SpinLock, trap::TrapFrame}; +use aster_util::{field_ptr, slot_vec::SlotVec}; use log::debug; use pod::Pod; -use crate::{queue::{VirtQueue, QueueError}, device::{VirtioDeviceError, socket::{register_device, DEVICE_NAME}}, transport::{VirtioTransport}}; - -use super::{buffer::RxBuffer, config::{VirtioVsockConfig, VsockFeatures}, connect::{ConnectionInfo, VsockEvent}, header::{VirtioVsockHdr, VirtioVsockOp, VIRTIO_VSOCK_HDR_LEN}, error::SocketError, VsockDeviceIrqHandler}; +use super::{ + buffer::RxBuffer, + config::{VirtioVsockConfig, VsockFeatures}, + connect::{ConnectionInfo, VsockEvent}, + error::SocketError, + header::{VirtioVsockHdr, VirtioVsockOp, VIRTIO_VSOCK_HDR_LEN}, + VsockDeviceIrqHandler, +}; +use crate::{ + device::{socket::register_device, VirtioDeviceError}, + queue::{QueueError, VirtQueue}, + transport::VirtioTransport, +}; const QUEUE_SIZE: u16 = 64; const QUEUE_RECV: u16 = 0; @@ -30,39 +42,43 @@ pub struct SocketDevice { rx_buffers: SlotVec, transport: Box, - callbacks: Vec>, + callbacks: Vec>, } impl SocketDevice { pub fn init(mut transport: Box) -> Result<(), VirtioDeviceError> { let virtio_vsock_config = VirtioVsockConfig::new(transport.as_mut()); debug!("virtio_vsock_config = {:?}", virtio_vsock_config); - let guest_cid = - field_ptr!(&virtio_vsock_config, VirtioVsockConfig, guest_cid_low).read().unwrap() as u64 - | (field_ptr!(&virtio_vsock_config, VirtioVsockConfig, guest_cid_high).read().unwrap() as u64) << 32; + let guest_cid = field_ptr!(&virtio_vsock_config, VirtioVsockConfig, guest_cid_low) + .read() + .unwrap() as u64 + | (field_ptr!(&virtio_vsock_config, VirtioVsockConfig, guest_cid_high) + .read() + .unwrap() as u64) + << 32; - let mut recv_queue = VirtQueue::new(QUEUE_RECV,QUEUE_SIZE,transport.as_mut()) + let mut recv_queue = VirtQueue::new(QUEUE_RECV, QUEUE_SIZE, transport.as_mut()) .expect("createing recv queue fails"); - let send_queue = VirtQueue::new(QUEUE_SEND,QUEUE_SIZE,transport.as_mut()) + let send_queue = VirtQueue::new(QUEUE_SEND, QUEUE_SIZE, transport.as_mut()) .expect("creating send queue fails"); - let event_queue = VirtQueue::new(QUEUE_EVENT,QUEUE_SIZE,transport.as_mut()) + let event_queue = VirtQueue::new(QUEUE_EVENT, QUEUE_SIZE, transport.as_mut()) .expect("creating event queue fails"); // Allocate and add buffers for the RX queue. let mut rx_buffers = SlotVec::new(); for i in 0..QUEUE_SIZE { let mut rx_buffer = RxBuffer::new(RX_BUFFER_SIZE); - let token = recv_queue.add_buf(&[], &mut [rx_buffer.buf_mut()])?; + let token = recv_queue.add_buf(&[], &[rx_buffer.buf_mut()])?; assert_eq!(i, token); assert_eq!(rx_buffers.put(rx_buffer) as u16, i); } - + if recv_queue.should_notify() { debug!("notify receive queue"); recv_queue.notify(); } - let mut device = Self{ + let mut device = Self { config: virtio_vsock_config.read().unwrap(), guest_cid, send_queue, @@ -74,13 +90,13 @@ impl SocketDevice { }; // Interrupt handler if vsock device config space changes - fn config_space_change(_: &TrapFrame){ + fn config_space_change(_: &TrapFrame) { debug!("vsock device config space change"); } // Interrupt handler if vsock device receives some packet. // TODO: This will be handled by vsock socket layer. - fn handle_vsock_event(_: &TrapFrame){ + fn handle_vsock_event(_: &TrapFrame) { debug!("Packet received. This will be solved by socket layer"); } @@ -88,7 +104,7 @@ impl SocketDevice { .transport .register_cfg_callback(Box::new(config_space_change)) .unwrap(); - device + device .transport .register_queue_callback(QUEUE_RECV, Box::new(handle_vsock_event), false) .unwrap(); @@ -96,7 +112,7 @@ impl SocketDevice { device.transport.finish_init(); register_device( - super::DEVICE_NAME.to_string(), + super::DEVICE_NAME.to_string(), Arc::new(SpinLock::new(device)), ); @@ -113,7 +129,7 @@ impl SocketDevice { /// This returns as soon as the request is sent; you should wait until `poll` returns a /// [`VsockEventType::Connected`] event indicating that the peer has accepted the connection /// before sending data. - pub fn connect(&mut self, connection_info: &ConnectionInfo) -> Result<(),SocketError> { + pub fn connect(&mut self, connection_info: &ConnectionInfo) -> Result<(), SocketError> { let header = VirtioVsockHdr { op: VirtioVsockOp::Request as u16, ..connection_info.new_header(self.guest_cid) @@ -124,7 +140,7 @@ impl SocketDevice { } /// Accepts the given connection from a peer. - pub fn accept(&mut self, connection_info: &ConnectionInfo) -> Result<(),SocketError> { + pub fn accept(&mut self, connection_info: &ConnectionInfo) -> Result<(), SocketError> { let header = VirtioVsockHdr { op: VirtioVsockOp::Response as u16, ..connection_info.new_header(self.guest_cid) @@ -133,7 +149,7 @@ impl SocketDevice { } /// Requests the peer to send us a credit update for the given connection. - fn request_credit(&mut self, connection_info: &ConnectionInfo) -> Result<(),SocketError> { + fn request_credit(&mut self, connection_info: &ConnectionInfo) -> Result<(), SocketError> { let header = VirtioVsockHdr { op: VirtioVsockOp::CreditRequest as u16, ..connection_info.new_header(self.guest_cid) @@ -142,7 +158,7 @@ impl SocketDevice { } /// Tells the peer how much buffer space we have to receive data. - pub fn credit_update(&mut self, connection_info: &ConnectionInfo) -> Result<(),SocketError> { + pub fn credit_update(&mut self, connection_info: &ConnectionInfo) -> Result<(), SocketError> { let header = VirtioVsockHdr { op: VirtioVsockOp::CreditUpdate as u16, ..connection_info.new_header(self.guest_cid) @@ -155,7 +171,7 @@ impl SocketDevice { /// This returns as soon as the request is sent; you should wait until `poll` returns a /// `VsockEventType::Disconnected` event if you want to know that the peer has acknowledged the /// shutdown. - pub fn shutdown(&mut self, connection_info: &ConnectionInfo) -> Result<(),SocketError> { + pub fn shutdown(&mut self, connection_info: &ConnectionInfo) -> Result<(), SocketError> { let header = VirtioVsockHdr { op: VirtioVsockOp::Shutdown as u16, ..connection_info.new_header(self.guest_cid) @@ -164,7 +180,7 @@ impl SocketDevice { } /// Forcibly closes the connection without waiting for the peer. - pub fn force_close(&mut self, connection_info: &ConnectionInfo) -> Result<(),SocketError> { + pub fn force_close(&mut self, connection_info: &ConnectionInfo) -> Result<(), SocketError> { let header = VirtioVsockHdr { op: VirtioVsockOp::Rst as u16, ..connection_info.new_header(self.guest_cid) @@ -172,15 +188,17 @@ impl SocketDevice { self.send_packet_to_tx_queue(&header, &[]) } - fn send_packet_to_tx_queue(&mut self, header: &VirtioVsockHdr, buffer: &[u8]) -> Result<(), SocketError> { + fn send_packet_to_tx_queue( + &mut self, + header: &VirtioVsockHdr, + buffer: &[u8], + ) -> Result<(), SocketError> { // let (_token, _len) = self.send_queue.add_notify_wait_pop( // &[header.as_bytes(), buffer], // &mut [], // )?; - let _token = self - .send_queue - .add_buf(&[header.as_bytes(), buffer], &[])?; + let _token = self.send_queue.add_buf(&[header.as_bytes(), buffer], &[])?; if self.send_queue.should_notify() { self.send_queue.notify(); @@ -189,7 +207,7 @@ impl SocketDevice { // Wait until the buffer is used while !self.send_queue.can_pop() { spin_loop(); - } + } self.send_queue.pop_used()?; @@ -217,13 +235,17 @@ impl SocketDevice { } /// Sends the buffer to the destination. - pub fn send(&mut self, buffer: &[u8], connection_info: &mut ConnectionInfo) -> Result<(), SocketError> { + pub fn send( + &mut self, + buffer: &[u8], + connection_info: &mut ConnectionInfo, + ) -> Result<(), SocketError> { self.check_peer_buffer_is_sufficient(connection_info, buffer.len())?; let len = buffer.len() as u32; let header = VirtioVsockHdr { op: VirtioVsockOp::Rw as u16, - len: len, + len, ..connection_info.new_header(self.guest_cid) }; connection_info.tx_cnt += len; @@ -231,10 +253,12 @@ impl SocketDevice { } /// Polls the RX virtqueue for the next event, and calls the given handler function to handle it. - pub fn poll(&mut self, handler: impl FnOnce(VsockEvent, &[u8]) -> Result,SocketError> + pub fn poll( + &mut self, + handler: impl FnOnce(VsockEvent, &[u8]) -> Result, SocketError>, ) -> Result, SocketError> { // Return None if there is no pending packet. - if !self.recv_queue.can_pop(){ + if !self.recv_queue.can_pop() { return Ok(None); } let (token, len) = self.recv_queue.pop_used()?; @@ -250,20 +274,19 @@ impl SocketDevice { buffer.set_packet_len(RX_BUFFER_SIZE); + let head_result = read_header_and_body(buffer.buf()); - let head_result = read_header_and_body(&buffer.buf()); - - let Ok((header,body)) = head_result else { + let Ok((header, body)) = head_result else { let ret = match head_result { Err(e) => Err(e), - _ => Ok(None) //FIXME: this clause is never reached. + _ => Ok(None), //FIXME: this clause is never reached. }; self.add_rx_buffer(buffer, token)?; return ret; }; debug!("Received packet {:?}. Op {:?}", header, header.op()); - debug!("body is {:?}",body); + debug!("body is {:?}", body); let result = VsockEvent::from_header(&header).and_then(|event| handler(event, body)); @@ -271,13 +294,12 @@ impl SocketDevice { self.add_rx_buffer(buffer, token)?; result - } /// Add a used rx buffer to recv queue,@index is only to check the correctness fn add_rx_buffer(&mut self, mut rx_buffer: RxBuffer, index: u16) -> Result<(), SocketError> { - let token = self.recv_queue.add_buf(&[], &mut [rx_buffer.buf_mut()])?; - assert_eq!(index,token); + let token = self.recv_queue.add_buf(&[], &[rx_buffer.buf_mut()])?; + assert_eq!(index, token); assert!(self.rx_buffers.put_at(token as usize, rx_buffer).is_none()); if self.recv_queue.should_notify() { self.recv_queue.notify(); @@ -290,11 +312,9 @@ impl SocketDevice { let device_features = VsockFeatures::from_bits_truncate(features); let supported_features = VsockFeatures::support_features(); let vsock_features = device_features & supported_features; - debug!("features negotiated: {:?}",vsock_features); + debug!("features negotiated: {:?}", vsock_features); vsock_features.bits() } - - } impl Debug for SocketDevice { @@ -310,7 +330,7 @@ impl Debug for SocketDevice { } } -fn read_header_and_body(buffer: &[u8]) -> Result<(VirtioVsockHdr, &[u8]),SocketError> { +fn read_header_and_body(buffer: &[u8]) -> Result<(VirtioVsockHdr, &[u8]), SocketError> { // Shouldn't panic, because we know `RX_BUFFER_SIZE > size_of::()`. let header = VirtioVsockHdr::from_bytes(&buffer[..VIRTIO_VSOCK_HDR_LEN]); let body_length = header.len() as usize; @@ -324,4 +344,4 @@ fn read_header_and_body(buffer: &[u8]) -> Result<(VirtioVsockHdr, &[u8]),SocketE .get(VIRTIO_VSOCK_HDR_LEN..data_end) .ok_or(SocketError::BufferTooShort)?; Ok((header, data)) -} \ No newline at end of file +} diff --git a/kernel/comps/virtio/src/device/socket/error.rs b/kernel/comps/virtio/src/device/socket/error.rs index 42a96993d..4134d55b0 100644 --- a/kernel/comps/virtio/src/device/socket/error.rs +++ b/kernel/comps/virtio/src/device/socket/error.rs @@ -1,10 +1,10 @@ +// SPDX-License-Identifier: MPL-2.0 + //! This file comes from virtio-drivers project //! This module contains the error from the VirtIO socket driver. use core::{fmt, result}; -use smoltcp::socket::dhcpv4::Socket; - use crate::queue::QueueError; /// The error type of VirtIO socket driver. @@ -43,7 +43,7 @@ pub enum SocketError { } #[derive(Clone, Copy, Debug, Eq, PartialEq)] -enum SocketQueueError { +pub enum SocketQueueError { InvalidArgs, BufferTooSmall, NotReady, @@ -63,7 +63,6 @@ impl From for SocketQueueError { } } - impl From for SocketError { fn from(value: QueueError) -> Self { Self::QueueError(SocketQueueError::from(value)) diff --git a/kernel/comps/virtio/src/device/socket/header.rs b/kernel/comps/virtio/src/device/socket/header.rs index c89e2b3a8..702e372f0 100644 --- a/kernel/comps/virtio/src/device/socket/header.rs +++ b/kernel/comps/virtio/src/device/socket/header.rs @@ -1,5 +1,8 @@ -use pod::Pod; +// SPDX-License-Identifier: MPL-2.0 + use bitflags::bitflags; +use pod::Pod; + use super::error::{self, SocketError}; pub const VIRTIO_VSOCK_HDR_LEN: usize = core::mem::size_of::(); @@ -15,9 +18,9 @@ pub struct VsockAddr { /// VirtioVsock header precedes the payload in each packet. // #[repr(packed)] -#[repr(C,packed)] +#[repr(C, packed)] #[derive(Debug, Clone, Copy, Pod)] -pub struct VirtioVsockHdr{ +pub struct VirtioVsockHdr { pub src_cid: u64, pub dst_cid: u64, pub src_port: u32, @@ -25,7 +28,7 @@ pub struct VirtioVsockHdr{ pub len: u32, pub socket_type: u16, - pub op: u16, //TOASK: why mark Pod and can I mark OpType Pod and replace u16 into OpType. + pub op: u16, pub flags: u32, /// Total receive buffer space for this socket. This includes both free and in-use buffers. pub buf_alloc: u32, @@ -50,19 +53,22 @@ impl Default for VirtioVsockHdr { } } - impl VirtioVsockHdr { /// Returns the length of the data. pub fn len(&self) -> u32 { self.len } + pub fn is_empty(&self) -> bool { + self.len == 0 + } + pub fn op(&self) -> error::Result { self.op.try_into() } pub fn source(&self) -> VsockAddr { - VsockAddr{ + VsockAddr { cid: self.src_cid, port: self.src_port, } @@ -76,7 +82,7 @@ impl VirtioVsockHdr { } pub fn check_data_is_empty(&self) -> error::Result<()> { - if self.len() == 0 { + if self.is_empty() { Ok(()) } else { Err(SocketError::UnexpectedDataInPacket) @@ -87,7 +93,7 @@ impl VirtioVsockHdr { #[derive(Default, Debug, Clone, Copy, PartialEq, Eq)] #[repr(u16)] #[allow(non_camel_case_types)] -pub enum VirtioVsockOp{ +pub enum VirtioVsockOp { #[default] Invalid = 0, @@ -123,7 +129,7 @@ impl TryFrom for VirtioVsockOp { _ => return Err(SocketError::UnknownOperation(v)), }; Ok(op) - } + } } bitflags! { @@ -136,7 +142,7 @@ bitflags! { /// The peer will not send any more data. const VIRTIO_VSOCK_SHUTDOWN_SEND = 1 << 1; /// The peer will not send or receive any more data. - const VIRTIO_VSOCK_SHUTDOWN_ALL = Self::VIRTIO_VSOCK_SHUTDOWN_RCV.bits | Self::VIRTIO_VSOCK_SHUTDOWN_SEND.bits; + const VIRTIO_VSOCK_SHUTDOWN_ALL = Self::VIRTIO_VSOCK_SHUTDOWN_RCV.bits | Self::VIRTIO_VSOCK_SHUTDOWN_SEND.bits; } } @@ -148,4 +154,4 @@ pub enum VsockType { Stream = 1, /// seqpacket socket type introduced in virtio-v1.2. SeqPacket = 2, -} \ No newline at end of file +} diff --git a/kernel/comps/virtio/src/device/socket/manager.rs b/kernel/comps/virtio/src/device/socket/manager.rs index 8da27b61e..30d6ee0ab 100644 --- a/kernel/comps/virtio/src/device/socket/manager.rs +++ b/kernel/comps/virtio/src/device/socket/manager.rs @@ -1,14 +1,17 @@ +// SPDX-License-Identifier: MPL-2.0 +use alloc::{boxed::Box, sync::Arc, vec, vec::Vec}; use core::{cmp::min, hint::spin_loop}; -use alloc::{vec::Vec, boxed::Box, sync::Arc}; use aster_frame::sync::SpinLock; use log::debug; +use super::{ + connect::{ConnectionInfo, DisconnectReason, VsockEvent, VsockEventType}, + device::SocketDevice, + header::VsockAddr, +}; use crate::device::socket::error::SocketError; -use super::{device::SocketDevice, connect::{ConnectionInfo, VsockEvent, VsockEventType, DisconnectReason}, header::VsockAddr}; - - const PER_CONNECTION_BUFFER_CAPACITY: usize = 1024; /// TODO: A higher level interface for VirtIO socket (vsock) devices. @@ -72,11 +75,11 @@ impl VsockConnectionManager { /// This returns as soon as the request is sent; you should wait until `poll` returns a /// `VsockEventType::Connected` event indicating that the peer has accepted the connection /// before sending data. - pub fn connect(&mut self, destination: VsockAddr, src_port: u32) -> Result<(),SocketError> { + pub fn connect(&mut self, destination: VsockAddr, src_port: u32) -> Result<(), SocketError> { if self.connections.iter().any(|connection| { connection.info.dst == destination && connection.info.src_port == src_port }) { - return Err(SocketError::ConnectionExists.into()); + return Err(SocketError::ConnectionExists); } let new_connection = Connection::new(destination, src_port); @@ -88,14 +91,19 @@ impl VsockConnectionManager { } /// Sends the buffer to the destination. - pub fn send(&mut self, destination: VsockAddr, src_port: u32, buffer: &[u8]) -> Result<(),SocketError> { + pub fn send( + &mut self, + destination: VsockAddr, + src_port: u32, + buffer: &[u8], + ) -> Result<(), SocketError> { let (_, connection) = get_connection(&mut self.connections, destination, src_port)?; self.driver.lock().send(buffer, &mut connection.info) } /// Polls the vsock device to receive data or other updates. - pub fn poll(&mut self) -> Result,SocketError> { + pub fn poll(&mut self) -> Result, SocketError> { let guest_cid = self.driver.lock().guest_cid(); let connections = &mut self.connections; @@ -181,8 +189,13 @@ impl VsockConnectionManager { } /// Reads data received from the given connection. - pub fn recv(&mut self, peer: VsockAddr, src_port: u32, buffer: &mut [u8]) -> Result { - debug!("connections is {:?}",self.connections); + pub fn recv( + &mut self, + peer: VsockAddr, + src_port: u32, + buffer: &mut [u8], + ) -> Result { + debug!("connections is {:?}", self.connections); let (connection_index, connection) = get_connection(&mut self.connections, peer, src_port)?; // Copy from ring buffer @@ -197,11 +210,11 @@ impl VsockConnectionManager { self.connections.swap_remove(connection_index); } - Ok(bytes_read) + Ok(bytes_read) } /// Blocks until we get some event from the vsock device. - pub fn wait_for_event(&mut self) -> Result { + pub fn wait_for_event(&mut self) -> Result { loop { if let Some(event) = self.poll()? { return Ok(event); @@ -216,14 +229,18 @@ impl VsockConnectionManager { /// This returns as soon as the request is sent; you should wait until `poll` returns a /// `VsockEventType::Disconnected` event if you want to know that the peer has acknowledged the /// shutdown. - pub fn shutdown(&mut self, destination: VsockAddr, src_port: u32) -> Result<(),SocketError> { + pub fn shutdown(&mut self, destination: VsockAddr, src_port: u32) -> Result<(), SocketError> { let (_, connection) = get_connection(&mut self.connections, destination, src_port)?; self.driver.lock().shutdown(&connection.info) } /// Forcibly closes the connection without waiting for the peer. - pub fn force_close(&mut self, destination: VsockAddr, src_port: u32) -> Result<(),SocketError> { + pub fn force_close( + &mut self, + destination: VsockAddr, + src_port: u32, + ) -> Result<(), SocketError> { let (index, connection) = get_connection(&mut self.connections, destination, src_port)?; self.driver.lock().force_close(&connection.info)?; @@ -263,7 +280,6 @@ fn get_connection_for_event<'a>( .find(|(_, connection)| event.matches_connection(&connection.info, local_cid)) } - #[derive(Debug)] struct Connection { info: ConnectionInfo, @@ -296,12 +312,9 @@ struct RingBuffer { impl RingBuffer { pub fn new(capacity: usize) -> Self { - // TODO: can be optimized. - let mut temp = Vec::with_capacity(capacity); - temp.resize(capacity,0); Self { // FIXME: if the capacity is excessive, elements move will be executed. - buffer: temp.into_boxed_slice(), + buffer: vec![0; capacity].into_boxed_slice(), used: 0, start: 0, } @@ -366,4 +379,4 @@ impl RingBuffer { bytes_read } -} \ No newline at end of file +} diff --git a/kernel/comps/virtio/src/device/socket/mod.rs b/kernel/comps/virtio/src/device/socket/mod.rs index 316957003..bdc4c82d5 100644 --- a/kernel/comps/virtio/src/device/socket/mod.rs +++ b/kernel/comps/virtio/src/device/socket/mod.rs @@ -1,23 +1,23 @@ -//! This mod is modified from virtio-drivers project. +// SPDX-License-Identifier: MPL-2.0 + +//! This mod is modified from virtio-drivers project. +use alloc::{collections::BTreeMap, string::String, sync::Arc, vec::Vec}; -use alloc::{sync::Arc, collections::BTreeMap, string::String, vec::Vec}; -use component::{ComponentInitError, init_component}; use aster_frame::sync::SpinLock; -use smoltcp::socket::dhcpv4::Socket; +use component::ComponentInitError; use spin::Once; -use core::fmt::Debug; + use self::device::SocketDevice; pub mod buffer; pub mod config; -pub mod device; -pub mod header; pub mod connect; +pub mod device; pub mod error; +pub mod header; pub mod manager; pub static DEVICE_NAME: &str = "Virtio-Vsock"; -pub type VsockDeviceIrqHandler = dyn Fn() + Send + Sync; - +pub trait VsockDeviceIrqHandler = Fn() + Send + Sync + 'static; pub fn register_device(name: String, device: Arc>) { COMPONENT @@ -30,9 +30,7 @@ pub fn register_device(name: String, device: Arc>) { pub fn get_device(str: &str) -> Option>> { let lock = COMPONENT.get().unwrap().vsock_device_table.lock(); - let Some(device) = lock.get(str) else { - return None; - }; + let device = lock.get(str)?; Some(device.clone()) } @@ -46,14 +44,12 @@ pub fn all_devices() -> Vec<(String, Arc>)> { static COMPONENT: Once = Once::new(); - -pub fn component_init() -> Result<(), ComponentInitError>{ +pub fn component_init() -> Result<(), ComponentInitError> { let a = Component::init()?; COMPONENT.call_once(|| a); Ok(()) } - struct Component { vsock_device_table: SpinLock>>>, } @@ -64,4 +60,4 @@ impl Component { vsock_device_table: SpinLock::new(BTreeMap::new()), }) } -} \ No newline at end of file +} diff --git a/kernel/comps/virtio/src/lib.rs b/kernel/comps/virtio/src/lib.rs index f95754e04..513decc9c 100644 --- a/kernel/comps/virtio/src/lib.rs +++ b/kernel/comps/virtio/src/lib.rs @@ -4,6 +4,7 @@ #![no_std] #![deny(unsafe_code)] #![allow(dead_code)] +#![feature(trait_alias)] #![feature(fn_traits)] extern crate alloc; @@ -13,8 +14,12 @@ use alloc::boxed::Box; use bitflags::bitflags; use component::{init_component, ComponentInitError}; use device::{ - block::device::BlockDevice, console::device::ConsoleDevice, input::device::InputDevice, - network::device::NetworkDevice, socket::{device::SocketDevice, self}, VirtioDeviceType, + block::device::BlockDevice, + console::device::ConsoleDevice, + input::device::InputDevice, + network::device::NetworkDevice, + socket::{self, device::SocketDevice}, + VirtioDeviceType, }; use log::{error, warn}; use transport::{mmio::VIRTIO_MMIO_DRIVER, pci::VIRTIO_PCI_DRIVER, DeviceStatus};