mirror of
https://github.com/asterinas/asterinas.git
synced 2025-06-22 08:53:29 +00:00
Implement vsock socket layer
This commit is contained in:
committed by
Tate, Hongliang Tian
parent
83a7937334
commit
ad140cec3c
@ -6,8 +6,6 @@ edition = "2021"
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
# FIXME: used for test in driver mod
|
||||
component = {path="../libs/comp-sys/component"}
|
||||
aster-frame = { path = "../../framework/aster-frame" }
|
||||
align_ext = { path = "../../framework/libs/align_ext" }
|
||||
pod = { git = "https://github.com/asterinas/pod", rev = "d7dba56" }
|
||||
|
@ -1,98 +1,10 @@
|
||||
// SPDX-License-Identifier: MPL-2.0
|
||||
|
||||
use aster_virtio::{
|
||||
self,
|
||||
device::socket::{header::VsockAddr, manager::VsockConnectionManager, DEVICE_NAME},
|
||||
};
|
||||
use component::ComponentInitError;
|
||||
use log::{debug, info};
|
||||
use log::info;
|
||||
|
||||
pub fn init() {
|
||||
// print all the input device to make sure input crate will compile
|
||||
for (name, _) in aster_input::all_devices() {
|
||||
info!("Found Input device, name:{}", name);
|
||||
}
|
||||
// let _ = socket_device_client_test();
|
||||
// let _ = socket_device_server_test();
|
||||
}
|
||||
|
||||
fn socket_device_client_test() -> Result<(), ComponentInitError> {
|
||||
let host_cid = 2;
|
||||
let guest_cid = 3;
|
||||
let host_port = 1234;
|
||||
let guest_port = 4321;
|
||||
let host_address = VsockAddr {
|
||||
cid: host_cid,
|
||||
port: host_port,
|
||||
};
|
||||
let hello_from_guest = "Hello from guest";
|
||||
let hello_from_host = "Hello from host";
|
||||
|
||||
let device = aster_virtio::device::socket::get_device(DEVICE_NAME).unwrap();
|
||||
assert_eq!(device.lock().guest_cid(), guest_cid);
|
||||
let mut socket = VsockConnectionManager::new(device);
|
||||
|
||||
socket.connect(host_address, guest_port).unwrap();
|
||||
socket.wait_for_event().unwrap(); // wait for connect response
|
||||
socket
|
||||
.send(host_address, guest_port, hello_from_guest.as_bytes())
|
||||
.unwrap();
|
||||
debug!(
|
||||
"The buffer {:?} is sent, start receiving",
|
||||
hello_from_guest.as_bytes()
|
||||
);
|
||||
socket.wait_for_event().unwrap(); // wait for recv
|
||||
let mut buffer = [0u8; 64];
|
||||
let event = socket.recv(host_address, guest_port, &mut buffer).unwrap();
|
||||
assert_eq!(
|
||||
&buffer[0..hello_from_host.len()],
|
||||
hello_from_host.as_bytes()
|
||||
);
|
||||
|
||||
socket.force_close(host_address, guest_port).unwrap();
|
||||
|
||||
debug!("The final event: {:?}", event);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn socket_device_server_test() -> Result<(), ComponentInitError> {
|
||||
let host_cid = 2;
|
||||
let guest_cid = 3;
|
||||
let host_port = 1234;
|
||||
let guest_port = 4321;
|
||||
let host_address = VsockAddr {
|
||||
cid: host_cid,
|
||||
port: host_port,
|
||||
};
|
||||
let hello_from_guest = "Hello from guest";
|
||||
let hello_from_host = "Hello from host";
|
||||
|
||||
let device = aster_virtio::device::socket::get_device(DEVICE_NAME).unwrap();
|
||||
assert_eq!(device.lock().guest_cid(), guest_cid);
|
||||
let mut socket = VsockConnectionManager::new(device);
|
||||
|
||||
socket.listen(4321);
|
||||
socket.wait_for_event().unwrap(); // wait for connect request
|
||||
socket.wait_for_event().unwrap(); // wait for recv
|
||||
let mut buffer = [0u8; 64];
|
||||
let event = socket.recv(host_address, guest_port, &mut buffer).unwrap();
|
||||
assert_eq!(
|
||||
&buffer[0..hello_from_host.len()],
|
||||
hello_from_host.as_bytes()
|
||||
);
|
||||
|
||||
debug!(
|
||||
"The buffer {:?} is received, start sending {:?}",
|
||||
&buffer[0..hello_from_host.len()],
|
||||
hello_from_guest.as_bytes()
|
||||
);
|
||||
socket
|
||||
.send(host_address, guest_port, hello_from_guest.as_bytes())
|
||||
.unwrap();
|
||||
|
||||
socket.shutdown(host_address, guest_port).unwrap();
|
||||
let event = socket.wait_for_event().unwrap(); // wait for rst/shutdown
|
||||
|
||||
debug!("The final event: {:?}", event);
|
||||
Ok(())
|
||||
}
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
use spin::Once;
|
||||
|
||||
use self::iface::spawn_background_poll_thread;
|
||||
use self::{iface::spawn_background_poll_thread, socket::vsock};
|
||||
use crate::{
|
||||
net::iface::{Iface, IfaceLoopback, IfaceVirtio},
|
||||
prelude::*,
|
||||
@ -28,6 +28,7 @@ pub fn init() {
|
||||
})
|
||||
}
|
||||
poll_ifaces();
|
||||
vsock::init();
|
||||
}
|
||||
|
||||
/// Lazy init should be called after spawning init thread.
|
||||
|
@ -13,6 +13,7 @@ pub mod ip;
|
||||
pub mod options;
|
||||
pub mod unix;
|
||||
mod util;
|
||||
pub mod vsock;
|
||||
|
||||
/// Operations defined on a socket.
|
||||
pub trait Socket: FileLike + Send + Sync {
|
||||
|
@ -15,6 +15,7 @@ pub enum SocketAddr {
|
||||
Unix(UnixSocketAddr),
|
||||
IPv4(Ipv4Address, PortNum),
|
||||
IPv6,
|
||||
Vsock(u32, u32),
|
||||
}
|
||||
|
||||
impl TryFrom<SocketAddr> for IpEndpoint {
|
||||
|
71
kernel/aster-nix/src/net/socket/vsock/addr.rs
Normal file
71
kernel/aster-nix/src/net/socket/vsock/addr.rs
Normal file
@ -0,0 +1,71 @@
|
||||
// SPDX-License-Identifier: MPL-2.0
|
||||
|
||||
use aster_virtio::device::socket::header::VsockAddr;
|
||||
|
||||
use crate::{net::socket::SocketAddr, prelude::*};
|
||||
|
||||
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
|
||||
pub struct VsockSocketAddr {
|
||||
pub cid: u32,
|
||||
pub port: u32,
|
||||
}
|
||||
|
||||
impl VsockSocketAddr {
|
||||
pub fn new(cid: u32, port: u32) -> Self {
|
||||
Self { cid, port }
|
||||
}
|
||||
|
||||
pub fn any_addr() -> Self {
|
||||
Self {
|
||||
cid: VMADDR_CID_ANY,
|
||||
port: VMADDR_PORT_ANY,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<SocketAddr> for VsockSocketAddr {
|
||||
type Error = Error;
|
||||
|
||||
fn try_from(value: SocketAddr) -> Result<Self> {
|
||||
let (cid, port) = if let SocketAddr::Vsock(cid, port) = value {
|
||||
(cid, port)
|
||||
} else {
|
||||
return_errno_with_message!(Errno::EINVAL, "invalid vsock socket addr");
|
||||
};
|
||||
Ok(Self { cid, port })
|
||||
}
|
||||
}
|
||||
|
||||
impl From<VsockSocketAddr> for SocketAddr {
|
||||
fn from(value: VsockSocketAddr) -> Self {
|
||||
SocketAddr::Vsock(value.cid, value.port)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<VsockAddr> for VsockSocketAddr {
|
||||
fn from(value: VsockAddr) -> Self {
|
||||
VsockSocketAddr {
|
||||
cid: value.cid as u32,
|
||||
port: value.port,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<VsockSocketAddr> for VsockAddr {
|
||||
fn from(value: VsockSocketAddr) -> Self {
|
||||
VsockAddr {
|
||||
cid: value.cid as u64,
|
||||
port: value.port,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The vSocket equivalent of INADDR_ANY.
|
||||
pub const VMADDR_CID_ANY: u32 = u32::MAX;
|
||||
/// Use this as the destination CID in an address when referring to the local communication (loopback).
|
||||
/// This was VMADDR_CID_RESERVED
|
||||
pub const VMADDR_CID_LOCAL: u32 = 1;
|
||||
/// Use this as the destination CID in an address when referring to the host (any process other than the hypervisor).
|
||||
pub const VMADDR_CID_HOST: u32 = 2;
|
||||
/// Bind to any available port.
|
||||
pub const VMADDR_PORT_ANY: u32 = u32::MAX;
|
206
kernel/aster-nix/src/net/socket/vsock/common.rs
Normal file
206
kernel/aster-nix/src/net/socket/vsock/common.rs
Normal file
@ -0,0 +1,206 @@
|
||||
// SPDX-License-Identifier: MPL-2.0
|
||||
|
||||
use alloc::collections::BTreeSet;
|
||||
|
||||
use aster_virtio::device::socket::{
|
||||
connect::{VsockEvent, VsockEventType},
|
||||
device::SocketDevice,
|
||||
error::SocketError,
|
||||
get_device, DEVICE_NAME,
|
||||
};
|
||||
|
||||
use super::{
|
||||
addr::VsockSocketAddr,
|
||||
stream::{
|
||||
connected::{Connected, ConnectionID},
|
||||
listen::Listen,
|
||||
},
|
||||
};
|
||||
use crate::{events::IoEvents, prelude::*, return_errno_with_message};
|
||||
|
||||
/// Manage all active sockets
|
||||
pub struct VsockSpace {
|
||||
pub driver: Arc<SpinLock<SocketDevice>>,
|
||||
// (key, value) = (local_addr, connecting)
|
||||
pub connecting_sockets: SpinLock<BTreeMap<VsockSocketAddr, Arc<Connected>>>,
|
||||
// (key, value) = (local_addr, listen)
|
||||
pub listen_sockets: SpinLock<BTreeMap<VsockSocketAddr, Arc<Listen>>>,
|
||||
// (key, value) = (id(local_addr,peer_addr), connected)
|
||||
pub connected_sockets: SpinLock<BTreeMap<ConnectionID, Arc<Connected>>>,
|
||||
// Used ports
|
||||
pub used_ports: SpinLock<BTreeSet<u32>>,
|
||||
}
|
||||
|
||||
impl VsockSpace {
|
||||
/// Create a new global VsockSpace
|
||||
pub fn new() -> Self {
|
||||
let driver = get_device(DEVICE_NAME).unwrap();
|
||||
Self {
|
||||
driver,
|
||||
connecting_sockets: SpinLock::new(BTreeMap::new()),
|
||||
listen_sockets: SpinLock::new(BTreeMap::new()),
|
||||
connected_sockets: SpinLock::new(BTreeMap::new()),
|
||||
used_ports: SpinLock::new(BTreeSet::new()),
|
||||
}
|
||||
}
|
||||
/// Poll for each event from the driver
|
||||
pub fn poll(&self) -> Result<Option<VsockEvent>> {
|
||||
let mut driver = self.driver.lock_irq_disabled();
|
||||
let guest_cid: u32 = driver.guest_cid() as u32;
|
||||
|
||||
// match the socket and store the buffer body (if valid)
|
||||
let result = driver
|
||||
.poll(|event, body| {
|
||||
if !self.is_event_for_socket(&event) {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// Deal with Received before the buffer are recycled.
|
||||
if let VsockEventType::Received { length } = event.event_type {
|
||||
// Only consider the connected socket and copy body to buffer
|
||||
if let Some(connected) = self
|
||||
.connected_sockets
|
||||
.lock_irq_disabled()
|
||||
.get(&event.into())
|
||||
{
|
||||
debug!("Rw matches a connection with id {:?}", connected.id());
|
||||
if !connected.connection_buffer_add(body) {
|
||||
return Err(SocketError::BufferTooShort);
|
||||
}
|
||||
} else {
|
||||
return Ok(None);
|
||||
}
|
||||
}
|
||||
Ok(Some(event))
|
||||
})
|
||||
.map_err(|e| {
|
||||
Error::with_message(Errno::EAGAIN, "driver poll failed, please try again")
|
||||
})?;
|
||||
|
||||
let Some(event) = result else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
// The socket must be stored in the VsockSpace.
|
||||
if let Some(connected) = self
|
||||
.connected_sockets
|
||||
.lock_irq_disabled()
|
||||
.get(&event.into())
|
||||
{
|
||||
connected.update_for_event(&event);
|
||||
}
|
||||
|
||||
// Response to the event
|
||||
match event.event_type {
|
||||
VsockEventType::ConnectionRequest => {
|
||||
// Preparation for listen socket `accept`
|
||||
if let Some(listen) = self
|
||||
.listen_sockets
|
||||
.lock_irq_disabled()
|
||||
.get(&event.destination.into())
|
||||
{
|
||||
let peer = event.source;
|
||||
let connected = Arc::new(Connected::new(peer.into(), listen.addr()));
|
||||
connected.update_for_event(&event);
|
||||
listen.push_incoming(connected).unwrap();
|
||||
} else {
|
||||
return_errno_with_message!(
|
||||
Errno::EINVAL,
|
||||
"Connecion request can only be handled by listening socket"
|
||||
)
|
||||
}
|
||||
}
|
||||
VsockEventType::Connected => {
|
||||
if let Some(connecting) = self
|
||||
.connecting_sockets
|
||||
.lock_irq_disabled()
|
||||
.get(&event.destination.into())
|
||||
{
|
||||
// debug!("match a connecting socket. Peer{:?}; local{:?}",connecting.peer_addr(),connecting.local_addr());
|
||||
connecting.update_for_event(&event);
|
||||
connecting.add_events(IoEvents::IN);
|
||||
}
|
||||
}
|
||||
VsockEventType::Disconnected { reason } => {
|
||||
if let Some(connected) = self
|
||||
.connected_sockets
|
||||
.lock_irq_disabled()
|
||||
.get(&event.into())
|
||||
{
|
||||
connected.peer_requested_shutdown();
|
||||
} else {
|
||||
return_errno_with_message!(Errno::ENOTCONN, "The socket hasn't connected");
|
||||
}
|
||||
}
|
||||
VsockEventType::Received { length } => {
|
||||
if let Some(connected) = self
|
||||
.connected_sockets
|
||||
.lock_irq_disabled()
|
||||
.get(&event.into())
|
||||
{
|
||||
connected.add_events(IoEvents::IN);
|
||||
} else {
|
||||
return_errno_with_message!(Errno::ENOTCONN, "The socket hasn't connected");
|
||||
}
|
||||
}
|
||||
VsockEventType::CreditRequest => {
|
||||
if let Some(connected) = self
|
||||
.connected_sockets
|
||||
.lock_irq_disabled()
|
||||
.get(&event.into())
|
||||
{
|
||||
driver.credit_update(&connected.get_info()).map_err(|_| {
|
||||
Error::with_message(Errno::EINVAL, "can not send credit update")
|
||||
})?;
|
||||
}
|
||||
}
|
||||
VsockEventType::CreditUpdate => {
|
||||
if let Some(connected) = self
|
||||
.connected_sockets
|
||||
.lock_irq_disabled()
|
||||
.get(&event.into())
|
||||
{
|
||||
connected.update_for_event(&event);
|
||||
} else {
|
||||
return_errno_with_message!(
|
||||
Errno::EINVAL,
|
||||
"CreditUpdate is only valid in connected sockets"
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(Some(event))
|
||||
}
|
||||
/// Check whether the event is for this socket space
|
||||
fn is_event_for_socket(&self, event: &VsockEvent) -> bool {
|
||||
// debug!("The event is for connection with id {:?}",ConnectionID::from(*event));
|
||||
self.connecting_sockets
|
||||
.lock_irq_disabled()
|
||||
.contains_key(&event.destination.into())
|
||||
|| self
|
||||
.listen_sockets
|
||||
.lock_irq_disabled()
|
||||
.contains_key(&event.destination.into())
|
||||
|| self
|
||||
.connected_sockets
|
||||
.lock_irq_disabled()
|
||||
.contains_key(&(*event).into())
|
||||
}
|
||||
/// Alloc an unused port range
|
||||
pub fn alloc_ephemeral_port(&self) -> Result<u32> {
|
||||
let mut used_ports = self.used_ports.lock_irq_disabled();
|
||||
for port in 1024..=u32::MAX {
|
||||
if !used_ports.contains(&port) {
|
||||
used_ports.insert(port);
|
||||
return Ok(port);
|
||||
}
|
||||
}
|
||||
return_errno_with_message!(Errno::EAGAIN, "cannot find unused high port");
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for VsockSpace {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
23
kernel/aster-nix/src/net/socket/vsock/mod.rs
Normal file
23
kernel/aster-nix/src/net/socket/vsock/mod.rs
Normal file
@ -0,0 +1,23 @@
|
||||
// SPDX-License-Identifier: MPL-2.0
|
||||
|
||||
use alloc::sync::Arc;
|
||||
|
||||
use aster_virtio::device::socket::{register_recv_callback, DEVICE_NAME};
|
||||
use common::VsockSpace;
|
||||
use spin::Once;
|
||||
|
||||
pub mod addr;
|
||||
pub mod common;
|
||||
pub mod stream;
|
||||
pub use stream::VsockStreamSocket;
|
||||
|
||||
// init static driver
|
||||
pub static VSOCK_GLOBAL: Once<Arc<VsockSpace>> = Once::new();
|
||||
|
||||
pub fn init() {
|
||||
VSOCK_GLOBAL.call_once(|| Arc::new(VsockSpace::new()));
|
||||
register_recv_callback(DEVICE_NAME, || {
|
||||
let vsockspace = VSOCK_GLOBAL.get().unwrap();
|
||||
let _ = vsockspace.poll();
|
||||
})
|
||||
}
|
267
kernel/aster-nix/src/net/socket/vsock/stream/connected.rs
Normal file
267
kernel/aster-nix/src/net/socket/vsock/stream/connected.rs
Normal file
@ -0,0 +1,267 @@
|
||||
// SPDX-License-Identifier: MPL-2.0
|
||||
|
||||
use alloc::boxed::Box;
|
||||
use core::cmp::min;
|
||||
|
||||
use aster_virtio::device::socket::connect::{ConnectionInfo, VsockEvent};
|
||||
|
||||
use crate::{
|
||||
events::IoEvents,
|
||||
net::socket::{
|
||||
vsock::{addr::VsockSocketAddr, VSOCK_GLOBAL},
|
||||
SendRecvFlags, SockShutdownCmd,
|
||||
},
|
||||
prelude::*,
|
||||
process::signal::{Pollee, Poller},
|
||||
};
|
||||
|
||||
const PER_CONNECTION_BUFFER_CAPACITY: usize = 4096;
|
||||
|
||||
pub struct Connected {
|
||||
connection: SpinLock<Connection>,
|
||||
id: ConnectionID,
|
||||
pollee: Pollee,
|
||||
}
|
||||
|
||||
impl Connected {
|
||||
pub fn new(peer_addr: VsockSocketAddr, local_addr: VsockSocketAddr) -> Self {
|
||||
Self {
|
||||
connection: SpinLock::new(Connection::new(peer_addr, local_addr.port)),
|
||||
id: ConnectionID::new(local_addr, peer_addr),
|
||||
pollee: Pollee::new(IoEvents::empty()),
|
||||
}
|
||||
}
|
||||
pub fn peer_addr(&self) -> VsockSocketAddr {
|
||||
self.id.peer_addr
|
||||
}
|
||||
|
||||
pub fn local_addr(&self) -> VsockSocketAddr {
|
||||
self.id.local_addr
|
||||
}
|
||||
|
||||
pub fn id(&self) -> ConnectionID {
|
||||
self.id
|
||||
}
|
||||
|
||||
pub fn recv(&self, buf: &mut [u8]) -> Result<usize> {
|
||||
let poller = Poller::new();
|
||||
if !self
|
||||
.poll(IoEvents::IN, Some(&poller))
|
||||
.contains(IoEvents::IN)
|
||||
{
|
||||
poller.wait()?;
|
||||
}
|
||||
|
||||
let mut connection = self.connection.lock_irq_disabled();
|
||||
let bytes_read = connection.buffer.drain(buf);
|
||||
|
||||
connection.info.done_forwarding(bytes_read);
|
||||
|
||||
Ok(bytes_read)
|
||||
}
|
||||
|
||||
pub fn send(&self, buf: &[u8], flags: SendRecvFlags) -> Result<usize> {
|
||||
let mut connection = self.connection.lock_irq_disabled();
|
||||
debug_assert!(flags.is_all_supported());
|
||||
let buf_len = buf.len();
|
||||
VSOCK_GLOBAL
|
||||
.get()
|
||||
.unwrap()
|
||||
.driver
|
||||
.lock_irq_disabled()
|
||||
.send(buf, &mut connection.info)
|
||||
.map_err(|e| Error::with_message(Errno::ENOBUFS, "cannot send packet"))?;
|
||||
Ok(buf_len)
|
||||
}
|
||||
|
||||
pub fn should_close(&self) -> bool {
|
||||
let connection = self.connection.lock_irq_disabled();
|
||||
// If buffer is now empty and the peer requested shutdown, finish shutting down the
|
||||
// connection.
|
||||
connection.peer_requested_shutdown && connection.buffer.is_empty()
|
||||
}
|
||||
|
||||
pub fn shutdown(&self, cmd: SockShutdownCmd) -> Result<()> {
|
||||
let connection = self.connection.lock_irq_disabled();
|
||||
// TODO: deal with cmd
|
||||
if self.should_close() {
|
||||
let vsockspace = VSOCK_GLOBAL.get().unwrap();
|
||||
vsockspace
|
||||
.driver
|
||||
.lock_irq_disabled()
|
||||
.reset(&connection.info)
|
||||
.map_err(|e| Error::with_message(Errno::ENOMEM, "can not send close packet"))?;
|
||||
vsockspace
|
||||
.connected_sockets
|
||||
.lock_irq_disabled()
|
||||
.remove(&self.id())
|
||||
.unwrap();
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
pub fn update_for_event(&self, event: &VsockEvent) {
|
||||
let mut connection = self.connection.lock_irq_disabled();
|
||||
connection.update_for_event(event)
|
||||
}
|
||||
|
||||
pub fn get_info(&self) -> ConnectionInfo {
|
||||
let connection = self.connection.lock_irq_disabled();
|
||||
connection.info.clone()
|
||||
}
|
||||
|
||||
pub fn connection_buffer_add(&self, bytes: &[u8]) -> bool {
|
||||
let mut connection = self.connection.lock_irq_disabled();
|
||||
self.add_events(IoEvents::IN);
|
||||
connection.add(bytes)
|
||||
}
|
||||
|
||||
pub fn peer_requested_shutdown(&self) {
|
||||
self.connection.lock_irq_disabled().peer_requested_shutdown = true
|
||||
}
|
||||
|
||||
pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
|
||||
self.pollee.poll(mask, poller)
|
||||
}
|
||||
pub fn add_events(&self, events: IoEvents) {
|
||||
self.pollee.add_events(events)
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for Connected {
|
||||
fn drop(&mut self) {
|
||||
let vsockspace = VSOCK_GLOBAL.get().unwrap();
|
||||
vsockspace
|
||||
.used_ports
|
||||
.lock_irq_disabled()
|
||||
.remove(&self.local_addr().port);
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Connection {
|
||||
info: ConnectionInfo,
|
||||
buffer: RingBuffer,
|
||||
/// The peer sent a SHUTDOWN request, but we haven't yet responded with a RST because there is
|
||||
/// still data in the buffer.
|
||||
pub peer_requested_shutdown: bool,
|
||||
}
|
||||
|
||||
impl Connection {
|
||||
pub fn new(peer: VsockSocketAddr, local_port: u32) -> Self {
|
||||
let mut info = ConnectionInfo::new(peer.into(), local_port);
|
||||
info.buf_alloc = PER_CONNECTION_BUFFER_CAPACITY.try_into().unwrap();
|
||||
Self {
|
||||
info,
|
||||
buffer: RingBuffer::new(PER_CONNECTION_BUFFER_CAPACITY),
|
||||
peer_requested_shutdown: false,
|
||||
}
|
||||
}
|
||||
pub fn update_for_event(&mut self, event: &VsockEvent) {
|
||||
self.info.update_for_event(event)
|
||||
}
|
||||
pub fn add(&mut self, bytes: &[u8]) -> bool {
|
||||
self.buffer.add(bytes)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Copy, Clone, Ord, PartialOrd, Eq, PartialEq)]
|
||||
pub struct ConnectionID {
|
||||
pub local_addr: VsockSocketAddr,
|
||||
pub peer_addr: VsockSocketAddr,
|
||||
}
|
||||
impl ConnectionID {
|
||||
pub fn new(local_addr: VsockSocketAddr, peer_addr: VsockSocketAddr) -> Self {
|
||||
Self {
|
||||
local_addr,
|
||||
peer_addr,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<VsockEvent> for ConnectionID {
|
||||
fn from(event: VsockEvent) -> Self {
|
||||
Self::new(event.destination.into(), event.source.into())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct RingBuffer {
|
||||
buffer: Box<[u8]>,
|
||||
/// The number of bytes currently in the buffer.
|
||||
used: usize,
|
||||
/// The index of the first used byte in the buffer.
|
||||
start: usize,
|
||||
}
|
||||
//TODO: ringbuf
|
||||
impl RingBuffer {
|
||||
pub fn new(capacity: usize) -> Self {
|
||||
// TODO: can be optimized.
|
||||
let temp = vec![0; capacity];
|
||||
Self {
|
||||
// FIXME: if the capacity is excessive, elements move will be executed.
|
||||
buffer: temp.into_boxed_slice(),
|
||||
used: 0,
|
||||
start: 0,
|
||||
}
|
||||
}
|
||||
/// Returns the number of bytes currently used in the buffer.
|
||||
pub fn used(&self) -> usize {
|
||||
self.used
|
||||
}
|
||||
|
||||
/// Returns true iff there are currently no bytes in the buffer.
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.used == 0
|
||||
}
|
||||
|
||||
/// Returns the number of bytes currently free in the buffer.
|
||||
pub fn available(&self) -> usize {
|
||||
self.buffer.len() - self.used
|
||||
}
|
||||
|
||||
/// Adds the given bytes to the buffer if there is enough capacity for them all.
|
||||
///
|
||||
/// Returns true if they were added, or false if they were not.
|
||||
pub fn add(&mut self, bytes: &[u8]) -> bool {
|
||||
if bytes.len() > self.available() {
|
||||
return false;
|
||||
}
|
||||
|
||||
// The index of the first available position in the buffer.
|
||||
let first_available = (self.start + self.used) % self.buffer.len();
|
||||
// The number of bytes to copy from `bytes` to `buffer` between `first_available` and
|
||||
// `buffer.len()`.
|
||||
let copy_length_before_wraparound = min(bytes.len(), self.buffer.len() - first_available);
|
||||
self.buffer[first_available..first_available + copy_length_before_wraparound]
|
||||
.copy_from_slice(&bytes[0..copy_length_before_wraparound]);
|
||||
if let Some(bytes_after_wraparound) = bytes.get(copy_length_before_wraparound..) {
|
||||
self.buffer[0..bytes_after_wraparound.len()].copy_from_slice(bytes_after_wraparound);
|
||||
}
|
||||
self.used += bytes.len();
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
/// Reads and removes as many bytes as possible from the buffer, up to the length of the given
|
||||
/// buffer.
|
||||
pub fn drain(&mut self, out: &mut [u8]) -> usize {
|
||||
let bytes_read = min(self.used, out.len());
|
||||
|
||||
// The number of bytes to copy out between `start` and the end of the buffer.
|
||||
let read_before_wraparound = min(bytes_read, self.buffer.len() - self.start);
|
||||
// The number of bytes to copy out from the beginning of the buffer after wrapping around.
|
||||
let read_after_wraparound = bytes_read
|
||||
.checked_sub(read_before_wraparound)
|
||||
.unwrap_or_default();
|
||||
|
||||
out[0..read_before_wraparound]
|
||||
.copy_from_slice(&self.buffer[self.start..self.start + read_before_wraparound]);
|
||||
out[read_before_wraparound..bytes_read]
|
||||
.copy_from_slice(&self.buffer[0..read_after_wraparound]);
|
||||
|
||||
self.used -= bytes_read;
|
||||
self.start = (self.start + bytes_read) % self.buffer.len();
|
||||
|
||||
bytes_read
|
||||
}
|
||||
}
|
95
kernel/aster-nix/src/net/socket/vsock/stream/init.rs
Normal file
95
kernel/aster-nix/src/net/socket/vsock/stream/init.rs
Normal file
@ -0,0 +1,95 @@
|
||||
// SPDX-License-Identifier: MPL-2.0
|
||||
|
||||
use crate::{
|
||||
events::IoEvents,
|
||||
net::socket::vsock::{
|
||||
addr::{VsockSocketAddr, VMADDR_CID_ANY, VMADDR_PORT_ANY},
|
||||
VSOCK_GLOBAL,
|
||||
},
|
||||
prelude::*,
|
||||
process::signal::{Pollee, Poller},
|
||||
};
|
||||
|
||||
pub struct Init {
|
||||
bind_addr: SpinLock<Option<VsockSocketAddr>>,
|
||||
pollee: Pollee,
|
||||
}
|
||||
|
||||
impl Init {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
bind_addr: SpinLock::new(None),
|
||||
pollee: Pollee::new(IoEvents::empty()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn bind(&self, addr: VsockSocketAddr) -> Result<()> {
|
||||
if self.bind_addr.lock().is_some() {
|
||||
return_errno_with_message!(Errno::EINVAL, "the socket is already bound");
|
||||
}
|
||||
let vsockspace = VSOCK_GLOBAL.get().unwrap();
|
||||
|
||||
// check correctness of cid
|
||||
let local_cid = vsockspace.driver.lock_irq_disabled().guest_cid();
|
||||
if addr.cid != VMADDR_CID_ANY && addr.cid != local_cid as u32 {
|
||||
return_errno_with_message!(Errno::EADDRNOTAVAIL, "The cid in address is incorrect");
|
||||
}
|
||||
let mut new_addr = addr;
|
||||
new_addr.cid = local_cid as u32;
|
||||
|
||||
// check and assign a port
|
||||
if addr.port == VMADDR_PORT_ANY {
|
||||
if let Ok(port) = vsockspace.alloc_ephemeral_port() {
|
||||
new_addr.port = port;
|
||||
} else {
|
||||
return_errno_with_message!(Errno::EAGAIN, "cannot find unused high port");
|
||||
}
|
||||
} else if vsockspace
|
||||
.used_ports
|
||||
.lock_irq_disabled()
|
||||
.contains(&new_addr.port)
|
||||
{
|
||||
return_errno_with_message!(Errno::EADDRNOTAVAIL, "the port in address is occupied");
|
||||
} else {
|
||||
vsockspace
|
||||
.used_ports
|
||||
.lock_irq_disabled()
|
||||
.insert(new_addr.port);
|
||||
}
|
||||
|
||||
//TODO: The privileged port isn't checked
|
||||
*self.bind_addr.lock() = Some(new_addr);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn is_bound(&self) -> bool {
|
||||
self.bind_addr.lock().is_some()
|
||||
}
|
||||
|
||||
pub fn bound_addr(&self) -> Option<VsockSocketAddr> {
|
||||
*self.bind_addr.lock()
|
||||
}
|
||||
|
||||
pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
|
||||
self.pollee.poll(mask, poller)
|
||||
}
|
||||
|
||||
pub fn add_events(&self, events: IoEvents) {
|
||||
self.pollee.add_events(events)
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for Init {
|
||||
fn drop(&mut self) {
|
||||
if let Some(addr) = *self.bind_addr.lock() {
|
||||
let vsockspace = VSOCK_GLOBAL.get().unwrap();
|
||||
vsockspace.used_ports.lock_irq_disabled().remove(&addr.port);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Init {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
75
kernel/aster-nix/src/net/socket/vsock/stream/listen.rs
Normal file
75
kernel/aster-nix/src/net/socket/vsock/stream/listen.rs
Normal file
@ -0,0 +1,75 @@
|
||||
// SPDX-License-Identifier: MPL-2.0
|
||||
|
||||
use super::connected::Connected;
|
||||
use crate::{
|
||||
events::IoEvents,
|
||||
net::socket::vsock::{addr::VsockSocketAddr, VSOCK_GLOBAL},
|
||||
prelude::*,
|
||||
process::signal::{Pollee, Poller},
|
||||
};
|
||||
pub struct Listen {
|
||||
addr: VsockSocketAddr,
|
||||
pollee: Pollee,
|
||||
backlog: usize,
|
||||
incoming_connection: SpinLock<VecDeque<Arc<Connected>>>,
|
||||
}
|
||||
|
||||
impl Listen {
|
||||
pub fn new(addr: VsockSocketAddr, backlog: usize) -> Self {
|
||||
Self {
|
||||
addr,
|
||||
pollee: Pollee::new(IoEvents::empty()),
|
||||
backlog,
|
||||
incoming_connection: SpinLock::new(VecDeque::with_capacity(backlog)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn addr(&self) -> VsockSocketAddr {
|
||||
self.addr
|
||||
}
|
||||
pub fn push_incoming(&self, connect: Arc<Connected>) -> Result<()> {
|
||||
let mut incoming_connections = self.incoming_connection.lock_irq_disabled();
|
||||
if incoming_connections.len() >= self.backlog {
|
||||
return_errno_with_message!(Errno::ENOMEM, "Queue in listenging socket is full")
|
||||
}
|
||||
incoming_connections.push_back(connect);
|
||||
self.add_events(IoEvents::IN);
|
||||
Ok(())
|
||||
}
|
||||
pub fn accept(&self) -> Result<Arc<Connected>> {
|
||||
// block waiting connection if no existing connection.
|
||||
let poller = Poller::new();
|
||||
if !self
|
||||
.poll(IoEvents::IN, Some(&poller))
|
||||
.contains(IoEvents::IN)
|
||||
{
|
||||
poller.wait()?;
|
||||
}
|
||||
|
||||
let connection = self
|
||||
.incoming_connection
|
||||
.lock_irq_disabled()
|
||||
.pop_front()
|
||||
.unwrap();
|
||||
|
||||
Ok(connection)
|
||||
}
|
||||
|
||||
pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
|
||||
self.pollee.poll(mask, poller)
|
||||
}
|
||||
pub fn add_events(&self, events: IoEvents) {
|
||||
self.pollee.add_events(events)
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for Listen {
|
||||
fn drop(&mut self) {
|
||||
VSOCK_GLOBAL
|
||||
.get()
|
||||
.unwrap()
|
||||
.used_ports
|
||||
.lock_irq_disabled()
|
||||
.remove(&self.addr.port);
|
||||
}
|
||||
}
|
8
kernel/aster-nix/src/net/socket/vsock/stream/mod.rs
Normal file
8
kernel/aster-nix/src/net/socket/vsock/stream/mod.rs
Normal file
@ -0,0 +1,8 @@
|
||||
// SPDX-License-Identifier: MPL-2.0
|
||||
|
||||
pub mod connected;
|
||||
pub mod init;
|
||||
pub mod listen;
|
||||
|
||||
pub mod socket;
|
||||
pub use socket::VsockStreamSocket;
|
291
kernel/aster-nix/src/net/socket/vsock/stream/socket.rs
Normal file
291
kernel/aster-nix/src/net/socket/vsock/stream/socket.rs
Normal file
@ -0,0 +1,291 @@
|
||||
// SPDX-License-Identifier: MPL-2.0
|
||||
|
||||
use super::{connected::Connected, init::Init, listen::Listen};
|
||||
use crate::{
|
||||
events::IoEvents,
|
||||
fs::file_handle::FileLike,
|
||||
net::socket::{
|
||||
vsock::{addr::VsockSocketAddr, VSOCK_GLOBAL},
|
||||
SendRecvFlags, SockShutdownCmd, Socket, SocketAddr,
|
||||
},
|
||||
prelude::*,
|
||||
process::signal::Poller,
|
||||
};
|
||||
|
||||
pub struct VsockStreamSocket(RwLock<Status>);
|
||||
|
||||
impl VsockStreamSocket {
|
||||
pub(super) fn new_from_init(init: Arc<Init>) -> Self {
|
||||
Self(RwLock::new(Status::Init(init)))
|
||||
}
|
||||
|
||||
pub(super) fn new_from_listen(listen: Arc<Listen>) -> Self {
|
||||
Self(RwLock::new(Status::Listen(listen)))
|
||||
}
|
||||
|
||||
pub(super) fn new_from_connected(connected: Arc<Connected>) -> Self {
|
||||
Self(RwLock::new(Status::Connected(connected)))
|
||||
}
|
||||
}
|
||||
|
||||
pub enum Status {
|
||||
Init(Arc<Init>),
|
||||
Listen(Arc<Listen>),
|
||||
Connected(Arc<Connected>),
|
||||
}
|
||||
|
||||
impl VsockStreamSocket {
|
||||
pub fn new() -> Self {
|
||||
let init = Arc::new(Init::new());
|
||||
Self(RwLock::new(Status::Init(init)))
|
||||
}
|
||||
}
|
||||
|
||||
impl FileLike for VsockStreamSocket {
|
||||
fn as_socket(self: Arc<Self>) -> Option<Arc<dyn Socket>> {
|
||||
Some(self)
|
||||
}
|
||||
|
||||
fn read(&self, buf: &mut [u8]) -> Result<usize> {
|
||||
self.recvfrom(buf, SendRecvFlags::empty())
|
||||
.map(|(read_size, _)| read_size)
|
||||
}
|
||||
|
||||
fn write(&self, buf: &[u8]) -> Result<usize> {
|
||||
self.sendto(buf, None, SendRecvFlags::empty())
|
||||
}
|
||||
|
||||
fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
|
||||
let inner = self.0.read();
|
||||
match &*inner {
|
||||
Status::Init(init) => init.poll(mask, poller),
|
||||
Status::Listen(listen) => listen.poll(mask, poller),
|
||||
Status::Connected(connect) => connect.poll(mask, poller),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Socket for VsockStreamSocket {
|
||||
fn bind(&self, sockaddr: SocketAddr) -> Result<()> {
|
||||
let addr = VsockSocketAddr::try_from(sockaddr)?;
|
||||
let inner = self.0.read();
|
||||
match &*inner {
|
||||
Status::Init(init) => init.bind(addr),
|
||||
Status::Listen(_) | Status::Connected(_) => {
|
||||
return_errno_with_message!(
|
||||
Errno::EINVAL,
|
||||
"cannot bind a listening or connected socket"
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn connect(&self, sockaddr: SocketAddr) -> Result<()> {
|
||||
let init = match &*self.0.read() {
|
||||
Status::Init(init) => init.clone(),
|
||||
Status::Listen(_) => {
|
||||
return_errno_with_message!(Errno::EINVAL, "The socket is listened");
|
||||
}
|
||||
Status::Connected(_) => {
|
||||
return_errno_with_message!(Errno::EINVAL, "The socket is connected");
|
||||
}
|
||||
};
|
||||
let remote_addr = VsockSocketAddr::try_from(sockaddr)?;
|
||||
let local_addr = init.bound_addr();
|
||||
|
||||
if let Some(addr) = local_addr {
|
||||
if addr == remote_addr {
|
||||
return_errno_with_message!(Errno::EINVAL, "try to connect to self is invalid");
|
||||
}
|
||||
} else {
|
||||
init.bind(VsockSocketAddr::any_addr())?;
|
||||
}
|
||||
|
||||
let connecting = Arc::new(Connected::new(remote_addr, init.bound_addr().unwrap()));
|
||||
let vsockspace = VSOCK_GLOBAL.get().unwrap();
|
||||
vsockspace
|
||||
.connecting_sockets
|
||||
.lock_irq_disabled()
|
||||
.insert(connecting.local_addr(), connecting.clone());
|
||||
|
||||
// Send request
|
||||
vsockspace
|
||||
.driver
|
||||
.lock_irq_disabled()
|
||||
.request(&connecting.get_info())
|
||||
.map_err(|e| Error::with_message(Errno::EAGAIN, "can not send connect packet"))?;
|
||||
|
||||
// wait for response from driver
|
||||
// TODO: add timeout
|
||||
let poller = Poller::new();
|
||||
if !connecting
|
||||
.poll(IoEvents::IN, Some(&poller))
|
||||
.contains(IoEvents::IN)
|
||||
{
|
||||
poller.wait()?;
|
||||
}
|
||||
|
||||
*self.0.write() = Status::Connected(connecting.clone());
|
||||
// move connecting socket map to connected sockmap
|
||||
vsockspace
|
||||
.connecting_sockets
|
||||
.lock_irq_disabled()
|
||||
.remove(&connecting.local_addr())
|
||||
.unwrap();
|
||||
vsockspace
|
||||
.connected_sockets
|
||||
.lock_irq_disabled()
|
||||
.insert(connecting.id(), connecting);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn listen(&self, backlog: usize) -> Result<()> {
|
||||
let init = match &*self.0.read() {
|
||||
Status::Init(init) => init.clone(),
|
||||
Status::Listen(_) => {
|
||||
return_errno_with_message!(Errno::EINVAL, "The socket is already listened");
|
||||
}
|
||||
Status::Connected(_) => {
|
||||
return_errno_with_message!(Errno::EISCONN, "The socket is already connected");
|
||||
}
|
||||
};
|
||||
let addr = init.bound_addr().ok_or(Error::with_message(
|
||||
Errno::EINVAL,
|
||||
"The socket is not bound",
|
||||
))?;
|
||||
let listen = Arc::new(Listen::new(addr, backlog));
|
||||
*self.0.write() = Status::Listen(listen.clone());
|
||||
|
||||
// push listen socket into vsockspace
|
||||
VSOCK_GLOBAL
|
||||
.get()
|
||||
.unwrap()
|
||||
.listen_sockets
|
||||
.lock_irq_disabled()
|
||||
.insert(listen.addr(), listen);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn accept(&self) -> Result<(Arc<dyn FileLike>, SocketAddr)> {
|
||||
let listen = match &*self.0.read() {
|
||||
Status::Listen(listen) => listen.clone(),
|
||||
Status::Init(_) | Status::Connected(_) => {
|
||||
return_errno_with_message!(Errno::EINVAL, "The socket is not listening");
|
||||
}
|
||||
};
|
||||
let connected = listen.accept()?;
|
||||
let peer_addr = connected.peer_addr();
|
||||
|
||||
VSOCK_GLOBAL
|
||||
.get()
|
||||
.unwrap()
|
||||
.connected_sockets
|
||||
.lock_irq_disabled()
|
||||
.insert(connected.id(), connected.clone());
|
||||
|
||||
VSOCK_GLOBAL
|
||||
.get()
|
||||
.unwrap()
|
||||
.driver
|
||||
.lock_irq_disabled()
|
||||
.response(&connected.get_info())
|
||||
.map_err(|e| Error::with_message(Errno::EAGAIN, "can not send response packet"))?;
|
||||
|
||||
let socket = Arc::new(VsockStreamSocket::new_from_connected(connected));
|
||||
Ok((socket, peer_addr.into()))
|
||||
}
|
||||
|
||||
fn shutdown(&self, cmd: SockShutdownCmd) -> Result<()> {
|
||||
let inner = self.0.read();
|
||||
if let Status::Connected(connected) = &*inner {
|
||||
let result = connected.shutdown(cmd);
|
||||
if result.is_ok() {
|
||||
let vsockspace = VSOCK_GLOBAL.get().unwrap();
|
||||
vsockspace
|
||||
.used_ports
|
||||
.lock_irq_disabled()
|
||||
.remove(&connected.local_addr().port);
|
||||
vsockspace
|
||||
.connected_sockets
|
||||
.lock_irq_disabled()
|
||||
.remove(&connected.id());
|
||||
}
|
||||
result
|
||||
} else {
|
||||
return_errno_with_message!(Errno::EINVAL, "The socket is not connected.");
|
||||
}
|
||||
}
|
||||
|
||||
fn recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, SocketAddr)> {
|
||||
let connected = match &*self.0.read() {
|
||||
Status::Connected(connected) => connected.clone(),
|
||||
Status::Init(_) | Status::Listen(_) => {
|
||||
return_errno_with_message!(Errno::EINVAL, "the socket is not connected");
|
||||
}
|
||||
};
|
||||
let read_size = connected.recv(buf)?;
|
||||
let peer_addr = self.peer_addr()?;
|
||||
// If buffer is now empty and the peer requested shutdown, finish shutting down the
|
||||
// connection.
|
||||
if connected.should_close() {
|
||||
VSOCK_GLOBAL
|
||||
.get()
|
||||
.unwrap()
|
||||
.driver
|
||||
.lock_irq_disabled()
|
||||
.reset(&connected.get_info())
|
||||
.map_err(|e| Error::with_message(Errno::EAGAIN, "can not send close packet"))?;
|
||||
}
|
||||
Ok((read_size, peer_addr))
|
||||
}
|
||||
|
||||
fn sendto(
|
||||
&self,
|
||||
buf: &[u8],
|
||||
remote: Option<SocketAddr>,
|
||||
flags: SendRecvFlags,
|
||||
) -> Result<usize> {
|
||||
debug_assert!(remote.is_none());
|
||||
if remote.is_some() {
|
||||
return_errno_with_message!(Errno::EINVAL, "vsock should not provide remote addr");
|
||||
}
|
||||
let inner = self.0.read();
|
||||
match &*inner {
|
||||
Status::Connected(connected) => connected.send(buf, flags),
|
||||
Status::Init(_) | Status::Listen(_) => {
|
||||
return_errno_with_message!(Errno::EINVAL, "The socket is not connected");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn addr(&self) -> Result<SocketAddr> {
|
||||
let inner = self.0.read();
|
||||
let addr = match &*inner {
|
||||
Status::Init(init) => init.bound_addr(),
|
||||
Status::Listen(listen) => Some(listen.addr()),
|
||||
Status::Connected(connected) => Some(connected.local_addr()),
|
||||
};
|
||||
addr.map(Into::<SocketAddr>::into)
|
||||
.ok_or(Error::with_message(
|
||||
Errno::EINVAL,
|
||||
"The socket does not bind to addr",
|
||||
))
|
||||
}
|
||||
|
||||
fn peer_addr(&self) -> Result<SocketAddr> {
|
||||
let inner = self.0.read();
|
||||
if let Status::Connected(connected) = &*inner {
|
||||
Ok(connected.peer_addr().into())
|
||||
} else {
|
||||
return_errno_with_message!(Errno::EINVAL, "the socket is not connected");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for VsockStreamSocket {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
@ -6,6 +6,7 @@ use crate::{
|
||||
net::socket::{
|
||||
ip::{DatagramSocket, StreamSocket},
|
||||
unix::UnixStreamSocket,
|
||||
vsock::VsockStreamSocket,
|
||||
},
|
||||
prelude::*,
|
||||
util::net::{CSocketAddrFamily, Protocol, SockFlags, SockType, SOCK_TYPE_MASK},
|
||||
@ -35,6 +36,9 @@ pub fn sys_socket(domain: i32, type_: i32, protocol: i32) -> Result<SyscallRetur
|
||||
SockType::SOCK_DGRAM,
|
||||
Protocol::IPPROTO_IP | Protocol::IPPROTO_UDP,
|
||||
) => DatagramSocket::new(nonblocking) as Arc<dyn FileLike>,
|
||||
(CSocketAddrFamily::AF_VSOCK, SockType::SOCK_STREAM, _) => {
|
||||
Arc::new(VsockStreamSocket::new()) as Arc<dyn FileLike>
|
||||
}
|
||||
_ => return_errno_with_message!(Errno::EAFNOSUPPORT, "unsupported domain"),
|
||||
};
|
||||
let fd = {
|
||||
|
@ -55,6 +55,11 @@ pub fn read_socket_addr_from_user(addr: Vaddr, addr_len: usize) -> Result<Socket
|
||||
let sock_addr_in6: CSocketAddrInet6 = read_val_from_user(addr)?;
|
||||
todo!()
|
||||
}
|
||||
CSocketAddrFamily::AF_VSOCK => {
|
||||
debug_assert!(addr_len >= core::mem::size_of::<CSocketAddrVm>());
|
||||
let sock_addr_vm: CSocketAddrVm = read_val_from_user(addr)?;
|
||||
SocketAddr::Vsock(sock_addr_vm.svm_cid, sock_addr_vm.svm_port)
|
||||
}
|
||||
_ => {
|
||||
return_errno_with_message!(Errno::EAFNOSUPPORT, "cannot support address for the family")
|
||||
}
|
||||
@ -89,6 +94,12 @@ pub fn write_socket_addr_to_user(
|
||||
write_size as i32
|
||||
}
|
||||
SocketAddr::IPv6 => todo!(),
|
||||
SocketAddr::Vsock(cid, port) => {
|
||||
let vm_addr = CSocketAddrVm::new(*cid, *port);
|
||||
let write_size = core::mem::size_of::<CSocketAddrVm>();
|
||||
write_val_to_user(dest, &vm_addr)?;
|
||||
write_size as i32
|
||||
}
|
||||
};
|
||||
if addrlen_ptr != 0 {
|
||||
write_val_to_user(addrlen_ptr, &write_size)?;
|
||||
@ -210,6 +221,34 @@ pub struct CSocketAddrInet6 {
|
||||
sin6_scope_id: u32,
|
||||
}
|
||||
|
||||
/// vm socket address
|
||||
#[repr(C)]
|
||||
#[derive(Debug, Clone, Copy, Pod)]
|
||||
pub struct CSocketAddrVm {
|
||||
/// always [SaFamily::AF_VSOCK]
|
||||
svm_family: u16,
|
||||
/// always 0
|
||||
svm_reserved1: u16,
|
||||
/// Port number in host byte order.
|
||||
svm_port: u32,
|
||||
/// Address in host byte order.
|
||||
svm_cid: u32,
|
||||
/// Pad to size of [SockAddr] structure (16 bytes), must be zero-filled
|
||||
svm_zero: [u8; 4],
|
||||
}
|
||||
|
||||
impl CSocketAddrVm {
|
||||
pub fn new(cid: u32, port: u32) -> Self {
|
||||
Self {
|
||||
svm_family: CSocketAddrFamily::AF_VSOCK as _,
|
||||
svm_reserved1: 0,
|
||||
svm_port: port,
|
||||
svm_cid: cid,
|
||||
svm_zero: [0u8; 4],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Address family. The definition is from https://elixir.bootlin.com/linux/v6.0.9/source/include/linux/socket.h.
|
||||
#[repr(i32)]
|
||||
#[derive(Debug, Clone, Copy, TryFromInt, PartialEq, Eq)]
|
||||
|
Reference in New Issue
Block a user