diff --git a/.github/workflows/kernel_test.yml b/.github/workflows/kernel_test.yml index c02320a6f..b2699492e 100644 --- a/.github/workflows/kernel_test.yml +++ b/.github/workflows/kernel_test.yml @@ -82,4 +82,29 @@ jobs: - name: Regression Test (Linux EFI Handover Boot Protocol) id: regression_test_linux - run: make run AUTO_TEST=regression ENABLE_KVM=0 BOOT_PROTOCOL=linux-efi-handover64 RELEASE=1 + run: make run AUTO_TEST=regression ENABLE_KVM=0 BOOT_PROTOCOL=multiboot2 RELEASE=1 + + vsock-test: + runs-on: ubuntu-latest + timeout-minutes: 15 + steps: + - uses: actions/checkout@v4 + + - name: Run Vsock Server on Host + id: host_vsock_server + run: | + sudo modprobe vhost_vsock + sudo apt-get install socat + echo "Run vsock server on host...." + socat -dd vsock-listen:1234 SYSTEM:"sleep 3s; echo 'Hello from host';sleep 2s;kill -INT $$" & + - name: Run Vsock Client and Server on Guest + id: guest_vsock_client_server + run: | + docker run --privileged --network=host --device=/dev/kvm -v ./:/root/asterinas asterinas/asterinas:0.4.2 \ + make run AUTO_TEST=vsock ENABLE_KVM=0 QEMU_MACHINE=microvm RELEASE_MODE=1 & + - name: Run Vsock Client on Host + id: host_vsock_client + run: | + sleep 5m + echo "Run vsock client on host...." + echo "Hello from host" | socat -dd - vsock-connect:3:4321 diff --git a/.github/workflows/vsock_interaction.yml b/.github/workflows/vsock_interaction.yml new file mode 100644 index 000000000..4115d9be7 --- /dev/null +++ b/.github/workflows/vsock_interaction.yml @@ -0,0 +1,28 @@ +name: Vsock Interaction + +on: + pull_request: + push: + branches: + - main + +jobs: + vsock-test: + runs-on: ubuntu-latest + timeout-minutes: 15 + steps: + - uses: actions/checkout@v4 + + - name: Run Vsock Server on Host + id: host_vsock_server + run: | + sudo modprobe vhost_vsock + sudo apt-get install socat + echo "Run vsock server on host...." + socat -ddd VSOCK-LISTEN:1234,fork SYSTEM:'read cmd; result=\$(eval \"\$cmd\" 2>&1); echo \"\$result\"' & + - name: Run Vsock Client on Guest + id: guest_vsock_client_server + run: | + docker run --privileged --network=host --device=/dev/kvm -v ./:/root/asterinas asterinas/asterinas:0.4.2 \ + make run AUTO_TEST=vsock ENABLE_KVM=0 SCHEME=microvm RELEASE_MODE=1 + \ No newline at end of file diff --git a/Makefile b/Makefile index 065e9948e..9e30ed611 100644 --- a/Makefile +++ b/Makefile @@ -33,15 +33,15 @@ CARGO_OSDK_ARGS += --init-args="/opt/syscall_test/run_syscall_test.sh" else ifeq ($(AUTO_TEST), regression) CARGO_OSDK_ARGS += --init-args="/regression/run_regression_test.sh" else ifeq ($(AUTO_TEST), boot) -CARGO_OSDK_ARGS += --init_args="/regression/boot_hello.sh" +CARGO_OSDK_ARGS += --init-args="/regression/boot_hello.sh" else ifeq ($(AUTO_TEST), vsock) -CARGO_OSDK_ARGS += --init_args="/regression/run_vsock_test.sh" -ifeq ($(QEMU_MACHINE), microvm) -CARGO_OSDK_ARGS += --qemu.args="-device vhost-vsock-device,guest-cid=3" -else ifeq ($(EMULATE_IOMMU), 1) -CARGO_OSDK_ARGS += --qemu.args="-device vhost-vsock-pci,id=vhost-vsock-pci0,guest-cid=3,disable-legacy=on,disable-modern=off,iommu_platform=on,ats=on" +CARGO_OSDK_ARGS += --init-args="/regression/run_vsock_test.sh" +ifeq ($(SCHEME), microvm) +CARGO_OSDK_ARGS += --qemu-args="-device vhost-vsock-device,guest-cid=3" +else ifeq ($(SCHEME), iommu) +CARGO_OSDK_ARGS += --qemu-args="-device vhost-vsock-pci,id=vhost-vsock-pci0,guest-cid=3,disable-legacy=on,disable-modern=off,iommu_platform=on,ats=on" else -CARGO_OSDK_ARGS += --qemu.args="-device vhost-vsock-pci,id=vhost-vsock-pci0,guest-cid=3,disable-legacy=on,disable-modern=off" +CARGO_OSDK_ARGS += --qemu-args="-device vhost-vsock-pci,id=vhost-vsock-pci0,guest-cid=3,disable-legacy=on,disable-modern=off" endif endif diff --git a/kernel/aster-nix/src/net/socket/util/socket_addr.rs b/kernel/aster-nix/src/net/socket/util/socket_addr.rs index e7a41f722..72169ca3d 100644 --- a/kernel/aster-nix/src/net/socket/util/socket_addr.rs +++ b/kernel/aster-nix/src/net/socket/util/socket_addr.rs @@ -3,7 +3,7 @@ use crate::{ net::{ iface::{IpAddress, IpEndpoint, Ipv4Address}, - socket::unix::UnixSocketAddr, + socket::{unix::UnixSocketAddr, vsock::addr::VsockSocketAddr}, }, prelude::*, }; @@ -15,7 +15,7 @@ pub enum SocketAddr { Unix(UnixSocketAddr), IPv4(Ipv4Address, PortNum), IPv6, - Vsock(u32, u32), + Vsock(VsockSocketAddr), } impl TryFrom for IpEndpoint { diff --git a/kernel/aster-nix/src/net/socket/vsock/addr.rs b/kernel/aster-nix/src/net/socket/vsock/addr.rs index d6d5d7415..78a74b5a4 100644 --- a/kernel/aster-nix/src/net/socket/vsock/addr.rs +++ b/kernel/aster-nix/src/net/socket/vsock/addr.rs @@ -1,6 +1,6 @@ // SPDX-License-Identifier: MPL-2.0 -use aster_virtio::device::socket::header::VsockAddr; +use aster_virtio::device::socket::header::VsockDeviceAddr; use crate::{net::socket::SocketAddr, prelude::*}; @@ -27,23 +27,21 @@ impl TryFrom for VsockSocketAddr { type Error = Error; fn try_from(value: SocketAddr) -> Result { - let (cid, port) = if let SocketAddr::Vsock(cid, port) = value { - (cid, port) - } else { + let SocketAddr::Vsock(vsock_addr) = value else { return_errno_with_message!(Errno::EINVAL, "invalid vsock socket addr"); }; - Ok(Self { cid, port }) + Ok(vsock_addr) } } impl From for SocketAddr { fn from(value: VsockSocketAddr) -> Self { - SocketAddr::Vsock(value.cid, value.port) + SocketAddr::Vsock(value) } } -impl From for VsockSocketAddr { - fn from(value: VsockAddr) -> Self { +impl From for VsockSocketAddr { + fn from(value: VsockDeviceAddr) -> Self { VsockSocketAddr { cid: value.cid as u32, port: value.port, @@ -51,9 +49,9 @@ impl From for VsockSocketAddr { } } -impl From for VsockAddr { +impl From for VsockDeviceAddr { fn from(value: VsockSocketAddr) -> Self { - VsockAddr { + VsockDeviceAddr { cid: value.cid as u64, port: value.port, } diff --git a/kernel/aster-nix/src/net/socket/vsock/common.rs b/kernel/aster-nix/src/net/socket/vsock/common.rs index f78faefb8..613e7f3d4 100644 --- a/kernel/aster-nix/src/net/socket/vsock/common.rs +++ b/kernel/aster-nix/src/net/socket/vsock/common.rs @@ -3,7 +3,7 @@ use alloc::collections::BTreeSet; use aster_virtio::device::socket::{ - connect::{VsockEvent, VsockEventType}, + connect::{ConnectionInfo, VsockEvent, VsockEventType}, device::SocketDevice, error::SocketError, get_device, DEVICE_NAME, @@ -21,15 +21,15 @@ use crate::{events::IoEvents, prelude::*, return_errno_with_message}; /// Manage all active sockets pub struct VsockSpace { - pub driver: Arc>, + driver: Arc>, // (key, value) = (local_addr, connecting) - pub connecting_sockets: SpinLock>>, + connecting_sockets: SpinLock>>, // (key, value) = (local_addr, listen) - pub listen_sockets: SpinLock>>, + listen_sockets: SpinLock>>, // (key, value) = (id(local_addr,peer_addr), connected) - pub connected_sockets: SpinLock>>, + connected_sockets: RwLock>>, // Used ports - pub used_ports: SpinLock>, + used_ports: SpinLock>, } impl VsockSpace { @@ -40,136 +40,12 @@ impl VsockSpace { driver, connecting_sockets: SpinLock::new(BTreeMap::new()), listen_sockets: SpinLock::new(BTreeMap::new()), - connected_sockets: SpinLock::new(BTreeMap::new()), + connected_sockets: RwLock::new(BTreeMap::new()), used_ports: SpinLock::new(BTreeSet::new()), } } - /// Poll for each event from the driver - pub fn poll(&self) -> Result> { - let mut driver = self.driver.lock_irq_disabled(); - let guest_cid: u32 = driver.guest_cid() as u32; - // match the socket and store the buffer body (if valid) - let result = driver - .poll(|event, body| { - if !self.is_event_for_socket(&event) { - return Ok(None); - } - - // Deal with Received before the buffer are recycled. - if let VsockEventType::Received { length } = event.event_type { - // Only consider the connected socket and copy body to buffer - if let Some(connected) = self - .connected_sockets - .lock_irq_disabled() - .get(&event.into()) - { - debug!("Rw matches a connection with id {:?}", connected.id()); - if !connected.connection_buffer_add(body) { - return Err(SocketError::BufferTooShort); - } - connected.update_io_events(); - } else { - return Ok(None); - } - } - Ok(Some(event)) - }) - .map_err(|e| { - Error::with_message(Errno::EAGAIN, "driver poll failed, please try again") - })?; - - let Some(event) = result else { - return Ok(None); - }; - - // The socket must be stored in the VsockSpace. - if let Some(connected) = self - .connected_sockets - .lock_irq_disabled() - .get(&event.into()) - { - connected.update_for_event(&event); - } - - // Response to the event - match event.event_type { - VsockEventType::ConnectionRequest => { - // Preparation for listen socket `accept` - if let Some(listen) = self - .listen_sockets - .lock_irq_disabled() - .get(&event.destination.into()) - { - let peer = event.source; - let connected = Arc::new(Connected::new(peer.into(), listen.addr())); - connected.update_for_event(&event); - listen.push_incoming(connected).unwrap(); - listen.update_io_events(); - } else { - return_errno_with_message!( - Errno::EINVAL, - "Connecion request can only be handled by listening socket" - ) - } - } - VsockEventType::Connected => { - if let Some(connecting) = self - .connecting_sockets - .lock_irq_disabled() - .get(&event.destination.into()) - { - debug!( - "match a connecting socket. Peer{:?}; local{:?}", - connecting.peer_addr(), - connecting.local_addr() - ); - connecting.update_for_event(&event); - connecting.add_events(IoEvents::IN); - } - } - VsockEventType::Disconnected { reason } => { - if let Some(connected) = self - .connected_sockets - .lock_irq_disabled() - .get(&event.into()) - { - connected.peer_requested_shutdown(); - } else { - return_errno_with_message!(Errno::ENOTCONN, "The socket hasn't connected"); - } - } - VsockEventType::Received { length } => {} - VsockEventType::CreditRequest => { - if let Some(connected) = self - .connected_sockets - .lock_irq_disabled() - .get(&event.into()) - { - driver.credit_update(&connected.get_info()).map_err(|_| { - Error::with_message(Errno::EINVAL, "can not send credit update") - })?; - } - } - VsockEventType::CreditUpdate => { - if let Some(connected) = self - .connected_sockets - .lock_irq_disabled() - .get(&event.into()) - { - connected.update_for_event(&event); - } else { - return_errno_with_message!( - Errno::EINVAL, - "CreditUpdate is only valid in connected sockets" - ) - } - } - } - Ok(Some(event)) - } /// Check whether the event is for this socket space fn is_event_for_socket(&self, event: &VsockEvent) -> bool { - // debug!("The event is for connection with id {:?}",ConnectionID::from(*event)); self.connecting_sockets .lock_irq_disabled() .contains_key(&event.destination.into()) @@ -179,7 +55,7 @@ impl VsockSpace { .contains_key(&event.destination.into()) || self .connected_sockets - .lock_irq_disabled() + .read_irq_disabled() .contains_key(&(*event).into()) } /// Alloc an unused port range @@ -193,6 +69,213 @@ impl VsockSpace { } return_errno_with_message!(Errno::EAGAIN, "cannot find unused high port"); } + pub fn insert_port(&self, port: u32) -> bool { + let mut used_ports = self.used_ports.lock_irq_disabled(); + used_ports.insert(port) + } + pub fn recycle_port(&self, port: &u32) -> bool { + let mut used_ports = self.used_ports.lock_irq_disabled(); + used_ports.remove(port) + } + + pub fn insert_connected_socket( + &self, + id: ConnectionID, + connected: Arc, + ) -> Option> { + let mut connected_sockets = self.connected_sockets.write_irq_disabled(); + connected_sockets.insert(id, connected) + } + pub fn remove_connected_socket(&self, id: &ConnectionID) -> Option> { + let mut connected_sockets = self.connected_sockets.write_irq_disabled(); + connected_sockets.remove(id) + } + pub fn insert_connecting_socket( + &self, + addr: VsockSocketAddr, + connecting: Arc, + ) -> Option> { + let mut connecting_sockets = self.connecting_sockets.lock_irq_disabled(); + connecting_sockets.insert(addr, connecting) + } + pub fn remove_connecting_socket(&self, addr: &VsockSocketAddr) -> Option> { + let mut connecting_sockets = self.connecting_sockets.lock_irq_disabled(); + connecting_sockets.remove(addr) + } + pub fn insert_listen_socket( + &self, + addr: VsockSocketAddr, + listen: Arc, + ) -> Option> { + let mut listen_sockets = self.listen_sockets.lock_irq_disabled(); + listen_sockets.insert(addr, listen) + } + pub fn remove_listen_socket(&self, addr: &VsockSocketAddr) -> Option> { + let mut listen_sockets = self.listen_sockets.lock_irq_disabled(); + listen_sockets.remove(addr) + } +} + +impl VsockSpace { + pub fn guest_cid(&self) -> u32 { + let driver = self.driver.lock_irq_disabled(); + driver.guest_cid() as u32 + } + + pub fn request(&self, info: &ConnectionInfo) -> Result<()> { + let mut driver = self.driver.lock_irq_disabled(); + driver + .request(info) + .map_err(|_| Error::with_message(Errno::EIO, "can not send connect packet")) + } + + pub fn response(&self, info: &ConnectionInfo) -> Result<()> { + let mut driver = self.driver.lock_irq_disabled(); + driver + .response(info) + .map_err(|_| Error::with_message(Errno::EIO, "can not send response packet")) + } + + pub fn shutdown(&self, info: &ConnectionInfo) -> Result<()> { + let mut driver = self.driver.lock_irq_disabled(); + driver + .shutdown(info) + .map_err(|_| Error::with_message(Errno::EIO, "can not send shutdown packet")) + } + + pub fn reset(&self, info: &ConnectionInfo) -> Result<()> { + let mut driver = self.driver.lock_irq_disabled(); + driver + .reset(info) + .map_err(|_| Error::with_message(Errno::EIO, "can not send reset packet")) + } + + pub fn request_credit(&self, info: &ConnectionInfo) -> Result<()> { + let mut driver = self.driver.lock_irq_disabled(); + driver + .credit_request(info) + .map_err(|_| Error::with_message(Errno::EIO, "can not send credit request packet")) + } + + pub fn update_credit(&self, info: &ConnectionInfo) -> Result<()> { + let mut driver = self.driver.lock_irq_disabled(); + driver + .credit_update(info) + .map_err(|_| Error::with_message(Errno::EIO, "can not send credit update packet")) + } + + pub fn send(&self, buffer: &[u8], info: &mut ConnectionInfo) -> Result<()> { + let mut driver = self.driver.lock_irq_disabled(); + driver + .send(buffer, info) + .map_err(|_| Error::with_message(Errno::EIO, "can not send data packet")) + } + + /// Poll for each event from the driver + pub fn poll(&self) -> Result> { + let mut driver = self.driver.lock_irq_disabled(); + let guest_cid = driver.guest_cid() as u32; + // match the socket and store the buffer body (if valid) + let result = driver + .poll(|event, body| { + if !self.is_event_for_socket(&event) { + return Ok(None); + } + + // Deal with Received before the buffer are recycled. + if let VsockEventType::Received { length } = event.event_type { + // Only consider the connected socket and copy body to buffer + if let Some(connected) = self + .connected_sockets + .read_irq_disabled() + .get(&event.into()) + { + debug!("Rw matches a connection with id {:?}", connected.id()); + if !connected.add_connection_buffer(body) { + return Err(SocketError::BufferTooShort); + } + connected.update_io_events(); + } else { + return Ok(None); + } + } + Ok(Some(event)) + }) + .map_err(|e| Error::with_message(Errno::EIO, "driver poll failed, please try again"))?; + + let Some(event) = result else { + return Ok(None); + }; + debug!("vsock receive event: {:?}", event); + // The socket must be stored in the VsockSpace. + if let Some(connected) = self + .connected_sockets + .read_irq_disabled() + .get(&event.into()) + { + connected.update_info(&event); + } + + // Response to the event + match event.event_type { + VsockEventType::ConnectionRequest => { + // Preparation for listen socket `accept` + let listen_sockets = self.listen_sockets.lock_irq_disabled(); + let Some(listen) = listen_sockets.get(&event.destination.into()) else { + return_errno_with_message!( + Errno::EINVAL, + "connecion request can only be handled by listening socket" + ); + }; + let peer = event.source; + let connected = Arc::new(Connected::new(peer.into(), listen.addr())); + connected.update_info(&event); + listen.push_incoming(connected).unwrap(); + listen.update_io_events(); + } + VsockEventType::Connected => { + let connecting_sockets = self.connecting_sockets.lock_irq_disabled(); + let Some(connecting) = connecting_sockets.get(&event.destination.into()) else { + return_errno_with_message!( + Errno::EINVAL, + "connected event can only be handled by connecting socket" + ); + }; + debug!( + "match a connecting socket. Peer{:?}; local{:?}", + connecting.peer_addr(), + connecting.local_addr() + ); + connecting.update_info(&event); + connecting.add_events(IoEvents::IN); + } + VsockEventType::Disconnected { reason } => { + let connected_sockets = self.connected_sockets.read_irq_disabled(); + let Some(connected) = connected_sockets.get(&event.into()) else { + return_errno_with_message!(Errno::ENOTCONN, "the socket hasn't connected"); + }; + connected.peer_requested_shutdown(); + } + VsockEventType::Received { length } => {} + VsockEventType::CreditRequest => { + let connected_sockets = self.connected_sockets.read_irq_disabled(); + let Some(connected) = connected_sockets.get(&event.into()) else { + return_errno_with_message!(Errno::ENOTCONN, "the socket hasn't connected"); + }; + driver + .credit_update(&connected.get_info()) + .map_err(|_| Error::with_message(Errno::EIO, "cannot send credit update"))?; + } + VsockEventType::CreditUpdate => { + let connected_sockets = self.connected_sockets.read_irq_disabled(); + let Some(connected) = connected_sockets.get(&event.into()) else { + return_errno_with_message!(Errno::ENOTCONN, "the socket hasn't connected"); + }; + connected.update_info(&event); + } + } + Ok(Some(event)) + } } impl Default for VsockSpace { diff --git a/kernel/aster-nix/src/net/socket/vsock/mod.rs b/kernel/aster-nix/src/net/socket/vsock/mod.rs index a514cf469..ced39c12d 100644 --- a/kernel/aster-nix/src/net/socket/vsock/mod.rs +++ b/kernel/aster-nix/src/net/socket/vsock/mod.rs @@ -9,6 +9,7 @@ use spin::Once; pub mod addr; pub mod common; pub mod stream; +pub use addr::VsockSocketAddr; pub use stream::VsockStreamSocket; // init static driver @@ -19,7 +20,7 @@ pub fn init() { VSOCK_GLOBAL.call_once(|| Arc::new(VsockSpace::new())); register_recv_callback(DEVICE_NAME, || { let vsockspace = VSOCK_GLOBAL.get().unwrap(); - let _ = vsockspace.poll(); + vsockspace.poll().unwrap(); }) } } diff --git a/kernel/aster-nix/src/net/socket/vsock/stream/connected.rs b/kernel/aster-nix/src/net/socket/vsock/stream/connected.rs index 18a961db0..c03c32b59 100644 --- a/kernel/aster-nix/src/net/socket/vsock/stream/connected.rs +++ b/kernel/aster-nix/src/net/socket/vsock/stream/connected.rs @@ -1,13 +1,11 @@ // SPDX-License-Identifier: MPL-2.0 -use alloc::boxed::Box; -use core::cmp::min; - use aster_virtio::device::socket::connect::{ConnectionInfo, VsockEvent}; +use ringbuf::{ring_buffer::RbBase, HeapRb, Rb}; use super::connecting::Connecting; use crate::{ - events::{IoEvents, Observer}, + events::IoEvents, net::socket::{ vsock::{addr::VsockSocketAddr, VSOCK_GLOBAL}, SendRecvFlags, SockShutdownCmd, @@ -54,8 +52,8 @@ impl Connected { pub fn try_recv(&self, buf: &mut [u8]) -> Result { let mut connection = self.connection.lock_irq_disabled(); - let bytes_read = connection.buffer.drain(buf); - + let bytes_read = connection.buffer.len().min(buf.len()); + connection.buffer.pop_slice(&mut buf[..bytes_read]); connection.info.done_forwarding(bytes_read); match bytes_read { @@ -77,10 +75,8 @@ impl Connected { VSOCK_GLOBAL .get() .unwrap() - .driver - .lock_irq_disabled() - .send(buf, &mut connection.info) - .map_err(|e| Error::with_message(Errno::ENOBUFS, "cannot send packet"))?; + .send(buf, &mut connection.info)?; + Ok(buf_len) } @@ -96,23 +92,13 @@ impl Connected { if self.should_close() { let connection = self.connection.lock_irq_disabled(); let vsockspace = VSOCK_GLOBAL.get().unwrap(); - vsockspace - .driver - .lock_irq_disabled() - .reset(&connection.info) - .map_err(|e| Error::with_message(Errno::ENOMEM, "can not send close packet"))?; - vsockspace - .connected_sockets - .lock_irq_disabled() - .remove(&self.id()); - vsockspace - .used_ports - .lock_irq_disabled() - .remove(&self.local_addr().port); + vsockspace.reset(&connection.info).unwrap(); + vsockspace.remove_connected_socket(&self.id()); + vsockspace.recycle_port(&self.local_addr().port); } Ok(()) } - pub fn update_for_event(&self, event: &VsockEvent) { + pub fn update_info(&self, event: &VsockEvent) { let mut connection = self.connection.lock_irq_disabled(); connection.update_for_event(event) } @@ -122,7 +108,7 @@ impl Connected { connection.info.clone() } - pub fn connection_buffer_add(&self, bytes: &[u8]) -> bool { + pub fn add_connection_buffer(&self, bytes: &[u8]) -> bool { let mut connection = self.connection.lock_irq_disabled(); connection.add(bytes) } @@ -144,42 +130,18 @@ impl Connected { self.pollee.del_events(IoEvents::IN); } } - - pub fn register_observer( - &self, - pollee: &Pollee, - observer: Weak>, - mask: IoEvents, - ) -> Result<()> { - pollee.register_observer(observer, mask); - Ok(()) - } - - pub fn unregister_observer( - &self, - pollee: &Pollee, - observer: &Weak>, - ) -> Result>> { - pollee - .unregister_observer(observer) - .ok_or_else(|| Error::with_message(Errno::EINVAL, "fails to unregister observer")) - } } impl Drop for Connected { fn drop(&mut self) { let vsockspace = VSOCK_GLOBAL.get().unwrap(); - vsockspace - .used_ports - .lock_irq_disabled() - .remove(&self.local_addr().port); + vsockspace.recycle_port(&self.local_addr().port); } } -#[derive(Debug)] pub struct Connection { info: ConnectionInfo, - buffer: RingBuffer, + buffer: HeapRb, /// The peer sent a SHUTDOWN request, but we haven't yet responded with a RST because there is /// still data in the buffer. pub peer_requested_shutdown: bool, @@ -191,16 +153,15 @@ impl Connection { info.buf_alloc = PER_CONNECTION_BUFFER_CAPACITY.try_into().unwrap(); Self { info, - buffer: RingBuffer::new(PER_CONNECTION_BUFFER_CAPACITY), + buffer: HeapRb::new(PER_CONNECTION_BUFFER_CAPACITY), peer_requested_shutdown: false, } } - pub fn from_info(info: ConnectionInfo) -> Self { - let mut info = info.clone(); + pub fn from_info(mut info: ConnectionInfo) -> Self { info.buf_alloc = PER_CONNECTION_BUFFER_CAPACITY.try_into().unwrap(); Self { info, - buffer: RingBuffer::new(PER_CONNECTION_BUFFER_CAPACITY), + buffer: HeapRb::new(PER_CONNECTION_BUFFER_CAPACITY), peer_requested_shutdown: false, } } @@ -208,7 +169,11 @@ impl Connection { self.info.update_for_event(event) } pub fn add(&mut self, bytes: &[u8]) -> bool { - self.buffer.add(bytes) + if bytes.len() > self.buffer.free_len() { + return false; + } + self.buffer.push_slice(bytes); + true } } @@ -231,85 +196,3 @@ impl From for ConnectionID { Self::new(event.destination.into(), event.source.into()) } } - -#[derive(Debug)] -struct RingBuffer { - buffer: Box<[u8]>, - /// The number of bytes currently in the buffer. - used: usize, - /// The index of the first used byte in the buffer. - start: usize, -} -//TODO: ringbuf -impl RingBuffer { - pub fn new(capacity: usize) -> Self { - // TODO: can be optimized. - let temp = vec![0; capacity]; - Self { - // FIXME: if the capacity is excessive, elements move will be executed. - buffer: temp.into_boxed_slice(), - used: 0, - start: 0, - } - } - /// Returns the number of bytes currently used in the buffer. - pub fn used(&self) -> usize { - self.used - } - - /// Returns true iff there are currently no bytes in the buffer. - pub fn is_empty(&self) -> bool { - self.used == 0 - } - - /// Returns the number of bytes currently free in the buffer. - pub fn available(&self) -> usize { - self.buffer.len() - self.used - } - - /// Adds the given bytes to the buffer if there is enough capacity for them all. - /// - /// Returns true if they were added, or false if they were not. - pub fn add(&mut self, bytes: &[u8]) -> bool { - if bytes.len() > self.available() { - return false; - } - - // The index of the first available position in the buffer. - let first_available = (self.start + self.used) % self.buffer.len(); - // The number of bytes to copy from `bytes` to `buffer` between `first_available` and - // `buffer.len()`. - let copy_length_before_wraparound = min(bytes.len(), self.buffer.len() - first_available); - self.buffer[first_available..first_available + copy_length_before_wraparound] - .copy_from_slice(&bytes[0..copy_length_before_wraparound]); - if let Some(bytes_after_wraparound) = bytes.get(copy_length_before_wraparound..) { - self.buffer[0..bytes_after_wraparound.len()].copy_from_slice(bytes_after_wraparound); - } - self.used += bytes.len(); - - true - } - - /// Reads and removes as many bytes as possible from the buffer, up to the length of the given - /// buffer. - pub fn drain(&mut self, out: &mut [u8]) -> usize { - let bytes_read = min(self.used, out.len()); - - // The number of bytes to copy out between `start` and the end of the buffer. - let read_before_wraparound = min(bytes_read, self.buffer.len() - self.start); - // The number of bytes to copy out from the beginning of the buffer after wrapping around. - let read_after_wraparound = bytes_read - .checked_sub(read_before_wraparound) - .unwrap_or_default(); - - out[0..read_before_wraparound] - .copy_from_slice(&self.buffer[self.start..self.start + read_before_wraparound]); - out[read_before_wraparound..bytes_read] - .copy_from_slice(&self.buffer[0..read_after_wraparound]); - - self.used -= bytes_read; - self.start = (self.start + bytes_read) % self.buffer.len(); - - bytes_read - } -} diff --git a/kernel/aster-nix/src/net/socket/vsock/stream/connecting.rs b/kernel/aster-nix/src/net/socket/vsock/stream/connecting.rs index b0b2ddb00..ab9c76dea 100644 --- a/kernel/aster-nix/src/net/socket/vsock/stream/connecting.rs +++ b/kernel/aster-nix/src/net/socket/vsock/stream/connecting.rs @@ -38,7 +38,7 @@ impl Connecting { pub fn info(&self) -> ConnectionInfo { self.info.lock_irq_disabled().clone() } - pub fn update_for_event(&self, event: &VsockEvent) { + pub fn update_info(&self, event: &VsockEvent) { self.info.lock_irq_disabled().update_for_event(event) } pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents { @@ -52,9 +52,6 @@ impl Connecting { impl Drop for Connecting { fn drop(&mut self) { let vsockspace = VSOCK_GLOBAL.get().unwrap(); - vsockspace - .used_ports - .lock_irq_disabled() - .remove(&self.local_addr().port); + vsockspace.recycle_port(&self.local_addr().port); } } diff --git a/kernel/aster-nix/src/net/socket/vsock/stream/init.rs b/kernel/aster-nix/src/net/socket/vsock/stream/init.rs index c42bf2508..ee7df3a18 100644 --- a/kernel/aster-nix/src/net/socket/vsock/stream/init.rs +++ b/kernel/aster-nix/src/net/socket/vsock/stream/init.rs @@ -1,7 +1,7 @@ // SPDX-License-Identifier: MPL-2.0 use crate::{ - events::{IoEvents, Observer}, + events::IoEvents, net::socket::vsock::{ addr::{VsockSocketAddr, VMADDR_CID_ANY, VMADDR_PORT_ANY}, VSOCK_GLOBAL, @@ -11,31 +11,31 @@ use crate::{ }; pub struct Init { - bind_addr: SpinLock>, + bound_addr: Mutex>, pollee: Pollee, } impl Init { pub fn new() -> Self { Self { - bind_addr: SpinLock::new(None), + bound_addr: Mutex::new(None), pollee: Pollee::new(IoEvents::empty()), } } pub fn bind(&self, addr: VsockSocketAddr) -> Result<()> { - if self.bind_addr.lock().is_some() { + if self.bound_addr.lock().is_some() { return_errno_with_message!(Errno::EINVAL, "the socket is already bound"); } let vsockspace = VSOCK_GLOBAL.get().unwrap(); // check correctness of cid - let local_cid = vsockspace.driver.lock_irq_disabled().guest_cid(); - if addr.cid != VMADDR_CID_ANY && addr.cid != local_cid as u32 { - return_errno_with_message!(Errno::EADDRNOTAVAIL, "The cid in address is incorrect"); + let local_cid = vsockspace.guest_cid(); + if addr.cid != VMADDR_CID_ANY && addr.cid != local_cid { + return_errno_with_message!(Errno::EADDRNOTAVAIL, "the cid in address is incorrect"); } let mut new_addr = addr; - new_addr.cid = local_cid as u32; + new_addr.cid = local_cid; // check and assign a port if addr.port == VMADDR_PORT_ANY { @@ -44,62 +44,33 @@ impl Init { } else { return_errno_with_message!(Errno::EAGAIN, "cannot find unused high port"); } - } else if vsockspace - .used_ports - .lock_irq_disabled() - .contains(&new_addr.port) - { + } else if !vsockspace.insert_port(new_addr.port) { return_errno_with_message!(Errno::EADDRNOTAVAIL, "the port in address is occupied"); - } else { - vsockspace - .used_ports - .lock_irq_disabled() - .insert(new_addr.port); } //TODO: The privileged port isn't checked - *self.bind_addr.lock() = Some(new_addr); + *self.bound_addr.lock() = Some(new_addr); Ok(()) } pub fn is_bound(&self) -> bool { - self.bind_addr.lock().is_some() + self.bound_addr.lock().is_some() } pub fn bound_addr(&self) -> Option { - *self.bind_addr.lock() + *self.bound_addr.lock() } pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents { self.pollee.poll(mask, poller) } - - pub fn register_observer( - &self, - pollee: &Pollee, - observer: Weak>, - mask: IoEvents, - ) -> Result<()> { - pollee.register_observer(observer, mask); - Ok(()) - } - - pub fn unregister_observer( - &self, - pollee: &Pollee, - observer: &Weak>, - ) -> Result>> { - pollee - .unregister_observer(observer) - .ok_or_else(|| Error::with_message(Errno::EINVAL, "fails to unregister observer")) - } } impl Drop for Init { fn drop(&mut self) { - if let Some(addr) = *self.bind_addr.lock() { + if let Some(addr) = *self.bound_addr.lock() { let vsockspace = VSOCK_GLOBAL.get().unwrap(); - vsockspace.used_ports.lock_irq_disabled().remove(&addr.port); + vsockspace.recycle_port(&addr.port); } } } diff --git a/kernel/aster-nix/src/net/socket/vsock/stream/listen.rs b/kernel/aster-nix/src/net/socket/vsock/stream/listen.rs index 37d132ecc..2b3a30bae 100644 --- a/kernel/aster-nix/src/net/socket/vsock/stream/listen.rs +++ b/kernel/aster-nix/src/net/socket/vsock/stream/listen.rs @@ -2,7 +2,7 @@ use super::connected::Connected; use crate::{ - events::{IoEvents, Observer}, + events::IoEvents, net::socket::vsock::{addr::VsockSocketAddr, VSOCK_GLOBAL}, prelude::*, process::signal::{Pollee, Poller}, @@ -30,7 +30,7 @@ impl Listen { pub fn push_incoming(&self, connect: Arc) -> Result<()> { let mut incoming_connections = self.incoming_connection.lock_irq_disabled(); if incoming_connections.len() >= self.backlog { - return_errno_with_message!(Errno::ENOMEM, "Queue in listenging socket is full") + return_errno_with_message!(Errno::ENOMEM, "queue in listenging socket is full") } incoming_connections.push_back(connect); Ok(()) @@ -59,34 +59,10 @@ impl Listen { self.pollee.del_events(IoEvents::IN); } } - pub fn register_observer( - &self, - pollee: &Pollee, - observer: Weak>, - mask: IoEvents, - ) -> Result<()> { - pollee.register_observer(observer, mask); - Ok(()) - } - - pub fn unregister_observer( - &self, - pollee: &Pollee, - observer: &Weak>, - ) -> Result>> { - pollee - .unregister_observer(observer) - .ok_or_else(|| Error::with_message(Errno::EINVAL, "fails to unregister observer")) - } } impl Drop for Listen { fn drop(&mut self) { - VSOCK_GLOBAL - .get() - .unwrap() - .used_ports - .lock_irq_disabled() - .remove(&self.addr.port); + VSOCK_GLOBAL.get().unwrap().recycle_port(&self.addr.port); } } diff --git a/kernel/aster-nix/src/net/socket/vsock/stream/socket.rs b/kernel/aster-nix/src/net/socket/vsock/stream/socket.rs index 52b3e05b4..abbe0da19 100644 --- a/kernel/aster-nix/src/net/socket/vsock/stream/socket.rs +++ b/kernel/aster-nix/src/net/socket/vsock/stream/socket.rs @@ -13,13 +13,12 @@ use crate::{ SendRecvFlags, SockShutdownCmd, Socket, SocketAddr, }, prelude::*, - process::signal::{Pollee, Poller}, + process::signal::Poller, }; pub struct VsockStreamSocket { status: RwLock, is_nonblocking: AtomicBool, - pollee: Pollee, } pub enum Status { @@ -34,14 +33,12 @@ impl VsockStreamSocket { Self { status: RwLock::new(Status::Init(init)), is_nonblocking: AtomicBool::new(nonblocking), - pollee: Pollee::new(IoEvents::empty()), } } pub(super) fn new_from_connected(connected: Arc) -> Self { Self { status: RwLock::new(Status::Connected(connected)), is_nonblocking: AtomicBool::new(false), - pollee: Pollee::new(IoEvents::empty()), } } fn is_nonblocking(&self) -> bool { @@ -52,11 +49,44 @@ impl VsockStreamSocket { self.is_nonblocking.store(nonblocking, Ordering::Relaxed); } + // TODO: Support timeout + fn wait_events(&self, mask: IoEvents, mut cond: F) -> Result + where + F: FnMut() -> Result, + { + let poller = Poller::new(); + + loop { + match cond() { + Err(err) if err.error() == Errno::EAGAIN => (), + result => { + if let Err(e) = result { + debug!("The result of cond() is Error: {:?}", e); + } + return result; + } + }; + + let events = match &*self.status.read() { + Status::Init(init) => init.poll(mask, Some(&poller)), + Status::Listen(listen) => listen.poll(mask, Some(&poller)), + Status::Connected(connected) => connected.poll(mask, Some(&poller)), + }; + + debug!("events: {:?}", events); + if !events.is_empty() { + continue; + } + + poller.wait()?; + } + } + fn try_accept(&self) -> Result<(Arc, SocketAddr)> { let listen = match &*self.status.read() { Status::Listen(listen) => listen.clone(), Status::Init(_) | Status::Connected(_) => { - return_errno_with_message!(Errno::EINVAL, "The socket is not listening"); + return_errno_with_message!(Errno::EINVAL, "the socket is not listening"); } }; @@ -68,17 +98,12 @@ impl VsockStreamSocket { VSOCK_GLOBAL .get() .unwrap() - .connected_sockets - .lock_irq_disabled() - .insert(connected.id(), connected.clone()); + .insert_connected_socket(connected.id(), connected.clone()); VSOCK_GLOBAL .get() .unwrap() - .driver - .lock_irq_disabled() - .response(&connected.get_info()) - .map_err(|e| Error::with_message(Errno::EAGAIN, "can not send response packet"))?; + .response(&connected.get_info())?; let socket = Arc::new(VsockStreamSocket::new_from_connected(connected)); Ok((socket, peer_addr.into())) @@ -123,7 +148,11 @@ impl FileLike for VsockStreamSocket { } fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents { - self.pollee.poll(mask, poller) + match &*self.status.read() { + Status::Init(init) => init.poll(mask, poller), + Status::Listen(listen) => listen.poll(mask, poller), + Status::Connected(connected) => connected.poll(mask, poller), + } } fn status_flags(&self) -> StatusFlags { @@ -142,30 +171,6 @@ impl FileLike for VsockStreamSocket { } Ok(()) } - fn register_observer( - &self, - observer: Weak>, - mask: IoEvents, - ) -> Result<()> { - match &*self.status.read() { - Status::Init(init) => init.register_observer(&self.pollee, observer, mask), - Status::Listen(listen) => listen.register_observer(&self.pollee, observer, mask), - Status::Connected(connected) => { - connected.register_observer(&self.pollee, observer, mask) - } - } - } - - fn unregister_observer( - &self, - observer: &Weak>, - ) -> Result>> { - match &*self.status.read() { - Status::Init(init) => init.unregister_observer(&self.pollee, observer), - Status::Listen(listen) => listen.unregister_observer(&self.pollee, observer), - Status::Connected(connected) => connected.unregister_observer(&self.pollee, observer), - } - } } impl Socket for VsockStreamSocket { @@ -189,15 +194,14 @@ impl Socket for VsockStreamSocket { let init = match &*self.status.read() { Status::Init(init) => init.clone(), Status::Listen(_) => { - return_errno_with_message!(Errno::EINVAL, "The socket is listened"); + return_errno_with_message!(Errno::EINVAL, "the socket is listened"); } Status::Connected(_) => { - return_errno_with_message!(Errno::EINVAL, "The socket is connected"); + return_errno_with_message!(Errno::EINVAL, "the socket is connected"); } }; let remote_addr = VsockSocketAddr::try_from(sockaddr)?; let local_addr = init.bound_addr(); - if let Some(addr) = local_addr { if addr == remote_addr { return_errno_with_message!(Errno::EINVAL, "try to connect to self is invalid"); @@ -208,18 +212,10 @@ impl Socket for VsockStreamSocket { let connecting = Arc::new(Connecting::new(remote_addr, init.bound_addr().unwrap())); let vsockspace = VSOCK_GLOBAL.get().unwrap(); - vsockspace - .connecting_sockets - .lock_irq_disabled() - .insert(connecting.local_addr(), connecting.clone()); + vsockspace.insert_connecting_socket(connecting.local_addr(), connecting.clone()); // Send request - vsockspace - .driver - .lock_irq_disabled() - .request(&connecting.info()) - .map_err(|e| Error::with_message(Errno::EAGAIN, "can not send connect packet"))?; - + vsockspace.request(&connecting.info()).unwrap(); // wait for response from driver // TODO: add timeout let poller = Poller::new(); @@ -229,19 +225,14 @@ impl Socket for VsockStreamSocket { { poller.wait()?; } - vsockspace - .connecting_sockets - .lock_irq_disabled() - .remove(&connecting.local_addr()) - .unwrap(); + vsockspace + .remove_connecting_socket(&connecting.local_addr()) + .unwrap(); let connected = Arc::new(Connected::from_connecting(connecting)); *self.status.write() = Status::Connected(connected.clone()); // move connecting socket map to connected sockmap - vsockspace - .connected_sockets - .lock_irq_disabled() - .insert(connected.id(), connected); + vsockspace.insert_connected_socket(connected.id(), connected); Ok(()) } @@ -250,15 +241,15 @@ impl Socket for VsockStreamSocket { let init = match &*self.status.read() { Status::Init(init) => init.clone(), Status::Listen(_) => { - return_errno_with_message!(Errno::EINVAL, "The socket is already listened"); + return_errno_with_message!(Errno::EINVAL, "the socket is already listened"); } Status::Connected(_) => { - return_errno_with_message!(Errno::EISCONN, "The socket is already connected"); + return_errno_with_message!(Errno::EISCONN, "the socket is already connected"); } }; let addr = init.bound_addr().ok_or(Error::with_message( Errno::EINVAL, - "The socket is not bound", + "the socket is not bound", ))?; let listen = Arc::new(Listen::new(addr, backlog)); *self.status.write() = Status::Listen(listen.clone()); @@ -267,9 +258,7 @@ impl Socket for VsockStreamSocket { VSOCK_GLOBAL .get() .unwrap() - .listen_sockets - .lock_irq_disabled() - .insert(listen.addr(), listen); + .insert_listen_socket(listen.addr(), listen); Ok(()) } @@ -278,7 +267,7 @@ impl Socket for VsockStreamSocket { if self.is_nonblocking() { self.try_accept() } else { - wait_events(self, IoEvents::IN, || self.try_accept()) + self.wait_events(IoEvents::IN, || self.try_accept()) } } @@ -286,7 +275,7 @@ impl Socket for VsockStreamSocket { match &*self.status.read() { Status::Connected(connected) => connected.shutdown(cmd), Status::Init(_) | Status::Listen(_) => { - return_errno_with_message!(Errno::EINVAL, "The socket is not connected"); + return_errno_with_message!(Errno::EINVAL, "the socket is not connected"); } } } @@ -297,7 +286,7 @@ impl Socket for VsockStreamSocket { if self.is_nonblocking() { self.try_recvfrom(buf, flags) } else { - wait_events(self, IoEvents::IN, || self.try_recvfrom(buf, flags)) + self.wait_events(IoEvents::IN, || self.try_recvfrom(buf, flags)) } } @@ -315,7 +304,7 @@ impl Socket for VsockStreamSocket { match &*inner { Status::Connected(connected) => connected.send(buf, flags), Status::Init(_) | Status::Listen(_) => { - return_errno_with_message!(Errno::EINVAL, "The socket is not connected"); + return_errno_with_message!(Errno::EINVAL, "the socket is not connected"); } } } @@ -343,37 +332,3 @@ impl Socket for VsockStreamSocket { } } } - -// TODO: Support timeout -fn wait_events(socket: &VsockStreamSocket, mask: IoEvents, mut cond: F) -> Result -where - F: FnMut() -> Result, -{ - let poller = Poller::new(); - - loop { - match cond() { - Err(err) if err.error() == Errno::EAGAIN => (), - result => { - if let Err(e) = result { - debug!("The result of cond() is Error: {:?}", e); - } - return result; - } - }; - - let events = match &*socket.status.read() { - Status::Init(init) => init.poll(mask, Some(&poller)), - Status::Listen(listen) => listen.poll(mask, Some(&poller)), - Status::Connected(connected) => connected.poll(mask, Some(&poller)), - }; - - debug!("events: {:?}", events); - if !events.is_empty() { - continue; - } - - poller.wait()?; - debug!("pass the poller wait"); - } -} diff --git a/kernel/aster-nix/src/util/net/addr.rs b/kernel/aster-nix/src/util/net/addr.rs index 135c0ca78..56a4c1101 100644 --- a/kernel/aster-nix/src/util/net/addr.rs +++ b/kernel/aster-nix/src/util/net/addr.rs @@ -6,7 +6,7 @@ use crate::{ net::{ iface::Ipv4Address, - socket::{unix::UnixSocketAddr, SocketAddr}, + socket::{unix::UnixSocketAddr, vsock::VsockSocketAddr, SocketAddr}, }, prelude::*, util::{read_bytes_from_user, read_val_from_user, write_val_to_user}, @@ -58,7 +58,10 @@ pub fn read_socket_addr_from_user(addr: Vaddr, addr_len: usize) -> Result { debug_assert!(addr_len >= core::mem::size_of::()); let sock_addr_vm: CSocketAddrVm = read_val_from_user(addr)?; - SocketAddr::Vsock(sock_addr_vm.svm_cid, sock_addr_vm.svm_port) + SocketAddr::Vsock(VsockSocketAddr::new( + sock_addr_vm.svm_cid, + sock_addr_vm.svm_port, + )) } _ => { return_errno_with_message!(Errno::EAFNOSUPPORT, "cannot support address for the family") @@ -94,8 +97,8 @@ pub fn write_socket_addr_to_user( write_size as i32 } SocketAddr::IPv6 => todo!(), - SocketAddr::Vsock(cid, port) => { - let vm_addr = CSocketAddrVm::new(*cid, *port); + SocketAddr::Vsock(addr) => { + let vm_addr = CSocketAddrVm::new(addr.cid, addr.port); let write_size = core::mem::size_of::(); write_val_to_user(dest, &vm_addr)?; write_size as i32 diff --git a/kernel/comps/network/src/lib.rs b/kernel/comps/network/src/lib.rs index 1e9bb6d8b..d6721ce43 100644 --- a/kernel/comps/network/src/lib.rs +++ b/kernel/comps/network/src/lib.rs @@ -7,7 +7,7 @@ #![feature(linked_list_cursors)] mod buffer; -mod dma_pool; +pub mod dma_pool; mod driver; extern crate alloc; diff --git a/kernel/comps/virtio/src/device/socket/buffer.rs b/kernel/comps/virtio/src/device/socket/buffer.rs index 00e2e6f5f..7ae03f8ce 100644 --- a/kernel/comps/virtio/src/device/socket/buffer.rs +++ b/kernel/comps/virtio/src/device/socket/buffer.rs @@ -1,22 +1,89 @@ // SPDX-License-Identifier: MPL-2.0 -use align_ext::AlignExt; -use bytes::BytesMut; +use alloc::{collections::LinkedList, sync::Arc}; + +use align_ext::AlignExt; +use aster_frame::{ + sync::SpinLock, + vm::{Daddr, DmaDirection, DmaStream, HasDaddr, VmAllocOptions, VmReader, VmWriter, PAGE_SIZE}, +}; +use aster_network::dma_pool::{DmaPool, DmaSegment}; +use pod::Pod; +use spin::Once; + +pub struct TxBuffer { + dma_stream: DmaStream, + nbytes: usize, +} + +impl TxBuffer { + pub fn new(header: &H, packet: &[u8]) -> Self { + let header = header.as_bytes(); + let nbytes = header.len() + packet.len(); + + let dma_stream = if let Some(stream) = get_tx_stream_from_pool(nbytes) { + stream + } else { + let segment = { + let nframes = (nbytes.align_up(PAGE_SIZE)) / PAGE_SIZE; + VmAllocOptions::new(nframes).alloc_contiguous().unwrap() + }; + DmaStream::map(segment, DmaDirection::ToDevice, false).unwrap() + }; + + let mut writer = dma_stream.writer().unwrap(); + writer.write(&mut VmReader::from(header)); + writer.write(&mut VmReader::from(packet)); + + let tx_buffer = Self { dma_stream, nbytes }; + tx_buffer.sync(); + tx_buffer + } + + pub fn writer(&self) -> VmWriter<'_> { + self.dma_stream.writer().unwrap().limit(self.nbytes) + } + + fn sync(&self) { + self.dma_stream.sync(0..self.nbytes).unwrap(); + } + + pub fn nbytes(&self) -> usize { + self.nbytes + } +} + +impl HasDaddr for TxBuffer { + fn daddr(&self) -> Daddr { + self.dma_stream.daddr() + } +} + +impl Drop for TxBuffer { + fn drop(&mut self) { + TX_BUFFER_POOL + .get() + .unwrap() + .lock_irq_disabled() + .push_back(self.dma_stream.clone()); + } +} -/// Buffer for receive packet -#[derive(Debug)] pub struct RxBuffer { - /// Packet Buffer, length align 8. - buf: BytesMut, - /// Packet len + segment: DmaSegment, + header_len: usize, packet_len: usize, } impl RxBuffer { - pub fn new(len: usize) -> Self { - let len = len.align_up(8); - let buf = BytesMut::zeroed(len); - Self { buf, packet_len: 0 } + pub fn new(header_len: usize) -> Self { + assert!(header_len <= RX_BUFFER_LEN); + let segment = RX_BUFFER_POOL.get().unwrap().alloc_segment().unwrap(); + Self { + segment, + header_len, + packet_len: 0, + } } pub const fn packet_len(&self) -> usize { @@ -24,56 +91,70 @@ impl RxBuffer { } pub fn set_packet_len(&mut self, packet_len: usize) { + assert!(self.header_len + packet_len <= RX_BUFFER_LEN); self.packet_len = packet_len; } - pub fn buf(&self) -> &[u8] { - &self.buf + pub fn packet(&self) -> VmReader<'_> { + self.segment + .sync(self.header_len..self.header_len + self.packet_len) + .unwrap(); + self.segment + .reader() + .unwrap() + .skip(self.header_len) + .limit(self.packet_len) } - pub fn buf_mut(&mut self) -> &mut [u8] { - &mut self.buf + pub fn buf(&self) -> VmReader<'_> { + self.segment + .sync(0..self.header_len + self.packet_len) + .unwrap(); + self.segment + .reader() + .unwrap() + .limit(self.header_len + self.packet_len) + } + + pub const fn buf_len(&self) -> usize { + self.segment.size() } } -/// Buffer for transmit packet -#[derive(Debug)] -pub struct TxBuffer { - buf: BytesMut, +impl HasDaddr for RxBuffer { + fn daddr(&self) -> Daddr { + self.segment.daddr() + } } -impl TxBuffer { - pub fn with_len(buf_len: usize) -> Self { - Self { - buf: BytesMut::zeroed(buf_len), +pub const RX_BUFFER_LEN: usize = 4096; +static RX_BUFFER_POOL: Once> = Once::new(); +static TX_BUFFER_POOL: Once>> = Once::new(); + +fn get_tx_stream_from_pool(nbytes: usize) -> Option { + let mut pool = TX_BUFFER_POOL.get().unwrap().lock_irq_disabled(); + let mut cursor = pool.cursor_front_mut(); + while let Some(current) = cursor.current() { + if current.nbytes() >= nbytes { + return cursor.remove_current(); } + cursor.move_next(); } - pub fn new(buf: &[u8]) -> Self { - Self { - buf: BytesMut::from(buf), - } - } - - pub fn buf(&self) -> &[u8] { - &self.buf - } - - pub fn buf_mut(&mut self) -> &mut [u8] { - &mut self.buf - } + None } -/// Buffer for event buffer -#[derive(Debug)] -pub struct EventBuffer { - id: u32, -} - -#[repr(u32)] -#[derive(Debug, Clone, Copy, Default)] -#[allow(non_camel_case_types)] -pub enum EventIDType { - #[default] - VIRTIO_VSOCK_EVENT_TRANSPORT_RESET = 0, +pub fn init() { + const POOL_INIT_SIZE: usize = 32; + const POOL_HIGH_WATERMARK: usize = 64; + RX_BUFFER_POOL.call_once(|| { + DmaPool::new( + RX_BUFFER_LEN, + POOL_INIT_SIZE, + POOL_HIGH_WATERMARK, + DmaDirection::FromDevice, + false, + ) + }); + TX_BUFFER_POOL.call_once(|| SpinLock::new(LinkedList::new())); } diff --git a/kernel/comps/virtio/src/device/socket/connect.rs b/kernel/comps/virtio/src/device/socket/connect.rs index de27345a8..6fa581bd1 100644 --- a/kernel/comps/virtio/src/device/socket/connect.rs +++ b/kernel/comps/virtio/src/device/socket/connect.rs @@ -28,7 +28,7 @@ use super::{ error::SocketError, - header::{VirtioVsockHdr, VirtioVsockOp, VsockAddr}, + header::{VirtioVsockHdr, VirtioVsockOp, VsockDeviceAddr}, }; #[derive(Clone, Copy, Debug, Eq, PartialEq)] @@ -74,9 +74,9 @@ pub enum VsockEventType { #[derive(Clone, Copy, Debug, Eq, PartialEq)] pub struct VsockEvent { /// The source of the event, i.e. the peer who sent it. - pub source: VsockAddr, + pub source: VsockDeviceAddr, /// The destination of the event, i.e. the CID and port on our side. - pub destination: VsockAddr, + pub destination: VsockDeviceAddr, /// The peer's buffer status for the connection. pub buffer_status: VsockBufferStatus, /// The type of event. @@ -143,7 +143,7 @@ impl VsockEvent { #[derive(Clone, Debug, Default, PartialEq, Eq)] pub struct ConnectionInfo { - pub dst: VsockAddr, + pub dst: VsockDeviceAddr, pub src_port: u32, /// The last `buf_alloc` value the peer sent to us, indicating how much receive buffer space in /// bytes it has allocated for packet bodies. @@ -166,7 +166,7 @@ pub struct ConnectionInfo { } impl ConnectionInfo { - pub fn new(destination: VsockAddr, src_port: u32) -> Self { + pub fn new(destination: VsockDeviceAddr, src_port: u32) -> Self { Self { dst: destination, src_port, diff --git a/kernel/comps/virtio/src/device/socket/device.rs b/kernel/comps/virtio/src/device/socket/device.rs index 517e0af70..36b8f64e9 100644 --- a/kernel/comps/virtio/src/device/socket/device.rs +++ b/kernel/comps/virtio/src/device/socket/device.rs @@ -1,15 +1,15 @@ // SPDX-License-Identifier: MPL-2.0 -use alloc::{boxed::Box, string::ToString, sync::Arc, vec::Vec}; -use core::{fmt::Debug, hint::spin_loop}; +use alloc::{boxed::Box, string::ToString, sync::Arc, vec, vec::Vec}; +use core::{fmt::Debug, hint::spin_loop, mem::size_of}; -use aster_frame::{offset_of, sync::SpinLock, trap::TrapFrame}; +use aster_frame::{offset_of, sync::SpinLock, trap::TrapFrame, vm::VmWriter}; use aster_util::{field_ptr, slot_vec::SlotVec}; use log::debug; use pod::Pod; use super::{ - buffer::RxBuffer, + buffer::{RxBuffer, RX_BUFFER_LEN}, config::{VirtioVsockConfig, VsockFeatures}, connect::{ConnectionInfo, VsockEvent}, error::SocketError, @@ -18,7 +18,7 @@ use super::{ }; use crate::{ device::{ - socket::{handle_recv_irq, register_device}, + socket::{buffer::TxBuffer, handle_recv_irq, register_device}, VirtioDeviceError, }, queue::{QueueError, VirtQueue}, @@ -30,9 +30,6 @@ const QUEUE_RECV: u16 = 0; const QUEUE_SEND: u16 = 1; const QUEUE_EVENT: u16 = 2; -/// The size in bytes of each buffer used in the RX virtqueue. This must be bigger than `size_of::()`. -const RX_BUFFER_SIZE: usize = 512; - /// Vsock device driver pub struct SocketDevice { config: VirtioVsockConfig, @@ -71,8 +68,8 @@ impl SocketDevice { // Allocate and add buffers for the RX queue. let mut rx_buffers = SlotVec::new(); for i in 0..QUEUE_SIZE { - let mut rx_buffer = RxBuffer::new(RX_BUFFER_SIZE); - let token = recv_queue.add_buf(&[], &[rx_buffer.buf_mut()])?; + let rx_buffer = RxBuffer::new(size_of::()); + let token = recv_queue.add_dma_buf(&[], &[&rx_buffer])?; assert_eq!(i, token); assert_eq!(rx_buffers.put(rx_buffer) as u16, i); } @@ -187,7 +184,10 @@ impl SocketDevice { header: &VirtioVsockHdr, buffer: &[u8], ) -> Result<(), SocketError> { - let _token = self.send_queue.add_buf(&[header.as_bytes(), buffer], &[])?; + debug!("buffer in send_packet_to_tx_queue: {:?}", buffer); + let tx_buffer = TxBuffer::new(header, buffer); + + let token = self.send_queue.add_dma_buf(&[&tx_buffer], &[])?; if self.send_queue.should_notify() { self.send_queue.notify(); @@ -198,9 +198,13 @@ impl SocketDevice { spin_loop(); } - self.send_queue.pop_used()?; - - debug!("buffer in send_packet_to_tx_queue: {:?}", buffer); + // Pop out the buffer, so we can reuse the send queue further + let (pop_token, _) = self.send_queue.pop_used()?; + debug_assert!(pop_token == token); + if pop_token != token { + return Err(SocketError::QueueError(QueueError::WrongToken)); + } + debug!("send packet succeeds"); Ok(()) } @@ -223,6 +227,7 @@ impl SocketDevice { if !connection_info.has_pending_credit_request { self.credit_request(connection_info)?; connection_info.has_pending_credit_request = true; + //TODO check if the update needed } Err(SocketError::InsufficientBufferSpaceInPeer) } @@ -261,9 +266,13 @@ impl SocketDevice { .rx_buffers .remove(token as usize) .ok_or(QueueError::WrongToken)?; - rx_buffer.set_packet_len(RX_BUFFER_SIZE); + rx_buffer.set_packet_len(len as usize); - let (header, payload) = read_header_and_body(rx_buffer.buf())?; + let mut buf_reader = rx_buffer.buf(); + let mut temp_buffer = vec![0u8; buf_reader.remain()]; + buf_reader.read(&mut VmWriter::from(&mut temp_buffer as &mut [u8])); + + let (header, payload) = read_header_and_body(&temp_buffer)?; // The length written should be equal to len(header)+len(packet) assert_eq!(len, header.len() + VIRTIO_VSOCK_HDR_LEN as u32); debug!("Received packet {:?}. Op {:?}", header, header.op()); @@ -285,15 +294,15 @@ impl SocketDevice { if !self.recv_queue.can_pop() { return Ok(None); } - let mut body = RxBuffer::new(RX_BUFFER_SIZE); - let header = self.receive(body.buf_mut())?; + let mut body = vec![0u8; RX_BUFFER_LEN]; + let header = self.receive(&mut body)?; - VsockEvent::from_header(&header).and_then(|event| handler(event, body.buf())) + VsockEvent::from_header(&header).and_then(|event| handler(event, &body)) } /// Add a used rx buffer to recv queue,@index is only to check the correctness - fn add_rx_buffer(&mut self, mut rx_buffer: RxBuffer, index: u16) -> Result<(), SocketError> { - let token = self.recv_queue.add_buf(&[], &[rx_buffer.buf_mut()])?; + fn add_rx_buffer(&mut self, rx_buffer: RxBuffer, index: u16) -> Result<(), SocketError> { + let token = self.recv_queue.add_dma_buf(&[], &[&rx_buffer])?; assert_eq!(index, token); assert!(self.rx_buffers.put_at(token as usize, rx_buffer).is_none()); if self.recv_queue.should_notify() { diff --git a/kernel/comps/virtio/src/device/socket/header.rs b/kernel/comps/virtio/src/device/socket/header.rs index 96d7ae7c0..d66f37a6d 100644 --- a/kernel/comps/virtio/src/device/socket/header.rs +++ b/kernel/comps/virtio/src/device/socket/header.rs @@ -35,7 +35,7 @@ pub const VIRTIO_VSOCK_HDR_LEN: usize = core::mem::size_of::(); /// Socket address. #[derive(Copy, Clone, Debug, Default, PartialEq, Eq)] -pub struct VsockAddr { +pub struct VsockDeviceAddr { /// Context Identifier. pub cid: u64, /// Port number. @@ -93,15 +93,15 @@ impl VirtioVsockHdr { VirtioVsockOp::try_from(self.op).map_err(|err| err.into()) } - pub fn source(&self) -> VsockAddr { - VsockAddr { + pub fn source(&self) -> VsockDeviceAddr { + VsockDeviceAddr { cid: self.src_cid, port: self.src_port, } } - pub fn destination(&self) -> VsockAddr { - VsockAddr { + pub fn destination(&self) -> VsockDeviceAddr { + VsockDeviceAddr { cid: self.dst_cid, port: self.dst_port, } diff --git a/kernel/comps/virtio/src/device/socket/mod.rs b/kernel/comps/virtio/src/device/socket/mod.rs index 953c5a47a..0350195ca 100644 --- a/kernel/comps/virtio/src/device/socket/mod.rs +++ b/kernel/comps/virtio/src/device/socket/mod.rs @@ -1,6 +1,6 @@ // SPDX-License-Identifier: MPL-2.0 -//! This mod is modified from virtio-drivers project. +// ! #![feature(linked_list_cursors)] use alloc::{collections::BTreeMap, string::String, sync::Arc, vec::Vec}; use aster_frame::sync::SpinLock; @@ -61,6 +61,7 @@ pub fn handle_recv_irq(name: &str) { pub fn init() { VSOCK_DEVICE_TABLE.call_once(|| SpinLock::new(BTreeMap::new())); + buffer::init(); } type VsockDeviceIrqHandlerListRef = Arc>>>; diff --git a/kernel/comps/virtio/src/dma_buf.rs b/kernel/comps/virtio/src/dma_buf.rs index c24c90b78..08850d2ad 100644 --- a/kernel/comps/virtio/src/dma_buf.rs +++ b/kernel/comps/virtio/src/dma_buf.rs @@ -3,6 +3,8 @@ use aster_frame::mm::{DmaCoherent, DmaStream, DmaStreamSlice, HasDaddr}; use aster_network::{DmaSegment, RxBuffer, TxBuffer}; +use crate::device; + /// A DMA-capable buffer. /// /// Any type implements this trait should also implements `HasDaddr` trait, @@ -48,3 +50,15 @@ impl DmaBuf for RxBuffer { self.buf_len() } } + +impl DmaBuf for device::socket::buffer::TxBuffer { + fn len(&self) -> usize { + self.nbytes() + } +} + +impl DmaBuf for device::socket::buffer::RxBuffer { + fn len(&self) -> usize { + self.buf_len() + } +} diff --git a/kernel/comps/virtio/src/lib.rs b/kernel/comps/virtio/src/lib.rs index f83393092..f59c5e267 100644 --- a/kernel/comps/virtio/src/lib.rs +++ b/kernel/comps/virtio/src/lib.rs @@ -6,6 +6,7 @@ #![allow(dead_code)] #![feature(trait_alias)] #![feature(fn_traits)] +#![feature(linked_list_cursors)] extern crate alloc; diff --git a/regression/apps/scripts/run_vsock_test.sh b/regression/apps/scripts/run_vsock_test.sh index a07f3dc41..87cb0321a 100644 --- a/regression/apps/scripts/run_vsock_test.sh +++ b/regression/apps/scripts/run_vsock_test.sh @@ -9,5 +9,5 @@ cd ${VSOCK_DIR} echo "Start vsock test......" ./vsock_client -./vsock_server +# ./vsock_server echo "Vsock test passed." diff --git a/regression/apps/vsock/vsock_client.c b/regression/apps/vsock/vsock_client.c index f6cbddd29..0bcbaf02e 100644 --- a/regression/apps/vsock/vsock_client.c +++ b/regression/apps/vsock/vsock_client.c @@ -12,7 +12,7 @@ int main() { int sock; - char *hello = "Hello from asterinas"; + char *hello = "echo 'Hello from host'\n"; char buffer[1024] = { 0 }; struct sockaddr_vm serv_addr; diff --git a/tools/qemu_args.sh b/tools/qemu_args.sh index 89b5dbb41..e940b0f41 100755 --- a/tools/qemu_args.sh +++ b/tools/qemu_args.sh @@ -44,7 +44,6 @@ QEMU_ARGS="\ -device virtio-keyboard-pci,disable-legacy=on,disable-modern=off$IOMMU_DEV_EXTRA \ -device virtio-net-pci,netdev=net01,disable-legacy=on,disable-modern=off$IOMMU_DEV_EXTRA \ -device virtio-serial-pci,disable-legacy=on,disable-modern=off$IOMMU_DEV_EXTRA \ - -device vhost-vsock-pci,id=vhost-vsock-pci0,guest-cid=3$IOMMU_DEV_EXTRA \ -device virtconsole,chardev=mux \ $IOMMU_EXTRA_ARGS \ " @@ -60,7 +59,6 @@ MICROVM_QEMU_ARGS="\ -device virtio-net-device,netdev=net01 \ -device virtio-serial-device \ -device virtconsole,chardev=mux \ - -device vhost-vsock-device,guest-cid=3 \ " if [ "$1" = "microvm" ]; then