Implement vsock socket layer

This commit is contained in:
Anmin Liu
2024-04-07 15:08:18 +00:00
committed by Tate, Hongliang Tian
parent 83a7937334
commit ad140cec3c
34 changed files with 1421 additions and 688 deletions

View File

@ -0,0 +1,267 @@
// SPDX-License-Identifier: MPL-2.0
use alloc::boxed::Box;
use core::cmp::min;
use aster_virtio::device::socket::connect::{ConnectionInfo, VsockEvent};
use crate::{
events::IoEvents,
net::socket::{
vsock::{addr::VsockSocketAddr, VSOCK_GLOBAL},
SendRecvFlags, SockShutdownCmd,
},
prelude::*,
process::signal::{Pollee, Poller},
};
const PER_CONNECTION_BUFFER_CAPACITY: usize = 4096;
pub struct Connected {
connection: SpinLock<Connection>,
id: ConnectionID,
pollee: Pollee,
}
impl Connected {
pub fn new(peer_addr: VsockSocketAddr, local_addr: VsockSocketAddr) -> Self {
Self {
connection: SpinLock::new(Connection::new(peer_addr, local_addr.port)),
id: ConnectionID::new(local_addr, peer_addr),
pollee: Pollee::new(IoEvents::empty()),
}
}
pub fn peer_addr(&self) -> VsockSocketAddr {
self.id.peer_addr
}
pub fn local_addr(&self) -> VsockSocketAddr {
self.id.local_addr
}
pub fn id(&self) -> ConnectionID {
self.id
}
pub fn recv(&self, buf: &mut [u8]) -> Result<usize> {
let poller = Poller::new();
if !self
.poll(IoEvents::IN, Some(&poller))
.contains(IoEvents::IN)
{
poller.wait()?;
}
let mut connection = self.connection.lock_irq_disabled();
let bytes_read = connection.buffer.drain(buf);
connection.info.done_forwarding(bytes_read);
Ok(bytes_read)
}
pub fn send(&self, buf: &[u8], flags: SendRecvFlags) -> Result<usize> {
let mut connection = self.connection.lock_irq_disabled();
debug_assert!(flags.is_all_supported());
let buf_len = buf.len();
VSOCK_GLOBAL
.get()
.unwrap()
.driver
.lock_irq_disabled()
.send(buf, &mut connection.info)
.map_err(|e| Error::with_message(Errno::ENOBUFS, "cannot send packet"))?;
Ok(buf_len)
}
pub fn should_close(&self) -> bool {
let connection = self.connection.lock_irq_disabled();
// If buffer is now empty and the peer requested shutdown, finish shutting down the
// connection.
connection.peer_requested_shutdown && connection.buffer.is_empty()
}
pub fn shutdown(&self, cmd: SockShutdownCmd) -> Result<()> {
let connection = self.connection.lock_irq_disabled();
// TODO: deal with cmd
if self.should_close() {
let vsockspace = VSOCK_GLOBAL.get().unwrap();
vsockspace
.driver
.lock_irq_disabled()
.reset(&connection.info)
.map_err(|e| Error::with_message(Errno::ENOMEM, "can not send close packet"))?;
vsockspace
.connected_sockets
.lock_irq_disabled()
.remove(&self.id())
.unwrap();
}
Ok(())
}
pub fn update_for_event(&self, event: &VsockEvent) {
let mut connection = self.connection.lock_irq_disabled();
connection.update_for_event(event)
}
pub fn get_info(&self) -> ConnectionInfo {
let connection = self.connection.lock_irq_disabled();
connection.info.clone()
}
pub fn connection_buffer_add(&self, bytes: &[u8]) -> bool {
let mut connection = self.connection.lock_irq_disabled();
self.add_events(IoEvents::IN);
connection.add(bytes)
}
pub fn peer_requested_shutdown(&self) {
self.connection.lock_irq_disabled().peer_requested_shutdown = true
}
pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
self.pollee.poll(mask, poller)
}
pub fn add_events(&self, events: IoEvents) {
self.pollee.add_events(events)
}
}
impl Drop for Connected {
fn drop(&mut self) {
let vsockspace = VSOCK_GLOBAL.get().unwrap();
vsockspace
.used_ports
.lock_irq_disabled()
.remove(&self.local_addr().port);
}
}
#[derive(Debug)]
pub struct Connection {
info: ConnectionInfo,
buffer: RingBuffer,
/// The peer sent a SHUTDOWN request, but we haven't yet responded with a RST because there is
/// still data in the buffer.
pub peer_requested_shutdown: bool,
}
impl Connection {
pub fn new(peer: VsockSocketAddr, local_port: u32) -> Self {
let mut info = ConnectionInfo::new(peer.into(), local_port);
info.buf_alloc = PER_CONNECTION_BUFFER_CAPACITY.try_into().unwrap();
Self {
info,
buffer: RingBuffer::new(PER_CONNECTION_BUFFER_CAPACITY),
peer_requested_shutdown: false,
}
}
pub fn update_for_event(&mut self, event: &VsockEvent) {
self.info.update_for_event(event)
}
pub fn add(&mut self, bytes: &[u8]) -> bool {
self.buffer.add(bytes)
}
}
#[derive(Debug, Copy, Clone, Ord, PartialOrd, Eq, PartialEq)]
pub struct ConnectionID {
pub local_addr: VsockSocketAddr,
pub peer_addr: VsockSocketAddr,
}
impl ConnectionID {
pub fn new(local_addr: VsockSocketAddr, peer_addr: VsockSocketAddr) -> Self {
Self {
local_addr,
peer_addr,
}
}
}
impl From<VsockEvent> for ConnectionID {
fn from(event: VsockEvent) -> Self {
Self::new(event.destination.into(), event.source.into())
}
}
#[derive(Debug)]
struct RingBuffer {
buffer: Box<[u8]>,
/// The number of bytes currently in the buffer.
used: usize,
/// The index of the first used byte in the buffer.
start: usize,
}
//TODO: ringbuf
impl RingBuffer {
pub fn new(capacity: usize) -> Self {
// TODO: can be optimized.
let temp = vec![0; capacity];
Self {
// FIXME: if the capacity is excessive, elements move will be executed.
buffer: temp.into_boxed_slice(),
used: 0,
start: 0,
}
}
/// Returns the number of bytes currently used in the buffer.
pub fn used(&self) -> usize {
self.used
}
/// Returns true iff there are currently no bytes in the buffer.
pub fn is_empty(&self) -> bool {
self.used == 0
}
/// Returns the number of bytes currently free in the buffer.
pub fn available(&self) -> usize {
self.buffer.len() - self.used
}
/// Adds the given bytes to the buffer if there is enough capacity for them all.
///
/// Returns true if they were added, or false if they were not.
pub fn add(&mut self, bytes: &[u8]) -> bool {
if bytes.len() > self.available() {
return false;
}
// The index of the first available position in the buffer.
let first_available = (self.start + self.used) % self.buffer.len();
// The number of bytes to copy from `bytes` to `buffer` between `first_available` and
// `buffer.len()`.
let copy_length_before_wraparound = min(bytes.len(), self.buffer.len() - first_available);
self.buffer[first_available..first_available + copy_length_before_wraparound]
.copy_from_slice(&bytes[0..copy_length_before_wraparound]);
if let Some(bytes_after_wraparound) = bytes.get(copy_length_before_wraparound..) {
self.buffer[0..bytes_after_wraparound.len()].copy_from_slice(bytes_after_wraparound);
}
self.used += bytes.len();
true
}
/// Reads and removes as many bytes as possible from the buffer, up to the length of the given
/// buffer.
pub fn drain(&mut self, out: &mut [u8]) -> usize {
let bytes_read = min(self.used, out.len());
// The number of bytes to copy out between `start` and the end of the buffer.
let read_before_wraparound = min(bytes_read, self.buffer.len() - self.start);
// The number of bytes to copy out from the beginning of the buffer after wrapping around.
let read_after_wraparound = bytes_read
.checked_sub(read_before_wraparound)
.unwrap_or_default();
out[0..read_before_wraparound]
.copy_from_slice(&self.buffer[self.start..self.start + read_before_wraparound]);
out[read_before_wraparound..bytes_read]
.copy_from_slice(&self.buffer[0..read_after_wraparound]);
self.used -= bytes_read;
self.start = (self.start + bytes_read) % self.buffer.len();
bytes_read
}
}

View File

@ -0,0 +1,95 @@
// SPDX-License-Identifier: MPL-2.0
use crate::{
events::IoEvents,
net::socket::vsock::{
addr::{VsockSocketAddr, VMADDR_CID_ANY, VMADDR_PORT_ANY},
VSOCK_GLOBAL,
},
prelude::*,
process::signal::{Pollee, Poller},
};
pub struct Init {
bind_addr: SpinLock<Option<VsockSocketAddr>>,
pollee: Pollee,
}
impl Init {
pub fn new() -> Self {
Self {
bind_addr: SpinLock::new(None),
pollee: Pollee::new(IoEvents::empty()),
}
}
pub fn bind(&self, addr: VsockSocketAddr) -> Result<()> {
if self.bind_addr.lock().is_some() {
return_errno_with_message!(Errno::EINVAL, "the socket is already bound");
}
let vsockspace = VSOCK_GLOBAL.get().unwrap();
// check correctness of cid
let local_cid = vsockspace.driver.lock_irq_disabled().guest_cid();
if addr.cid != VMADDR_CID_ANY && addr.cid != local_cid as u32 {
return_errno_with_message!(Errno::EADDRNOTAVAIL, "The cid in address is incorrect");
}
let mut new_addr = addr;
new_addr.cid = local_cid as u32;
// check and assign a port
if addr.port == VMADDR_PORT_ANY {
if let Ok(port) = vsockspace.alloc_ephemeral_port() {
new_addr.port = port;
} else {
return_errno_with_message!(Errno::EAGAIN, "cannot find unused high port");
}
} else if vsockspace
.used_ports
.lock_irq_disabled()
.contains(&new_addr.port)
{
return_errno_with_message!(Errno::EADDRNOTAVAIL, "the port in address is occupied");
} else {
vsockspace
.used_ports
.lock_irq_disabled()
.insert(new_addr.port);
}
//TODO: The privileged port isn't checked
*self.bind_addr.lock() = Some(new_addr);
Ok(())
}
pub fn is_bound(&self) -> bool {
self.bind_addr.lock().is_some()
}
pub fn bound_addr(&self) -> Option<VsockSocketAddr> {
*self.bind_addr.lock()
}
pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
self.pollee.poll(mask, poller)
}
pub fn add_events(&self, events: IoEvents) {
self.pollee.add_events(events)
}
}
impl Drop for Init {
fn drop(&mut self) {
if let Some(addr) = *self.bind_addr.lock() {
let vsockspace = VSOCK_GLOBAL.get().unwrap();
vsockspace.used_ports.lock_irq_disabled().remove(&addr.port);
}
}
}
impl Default for Init {
fn default() -> Self {
Self::new()
}
}

View File

@ -0,0 +1,75 @@
// SPDX-License-Identifier: MPL-2.0
use super::connected::Connected;
use crate::{
events::IoEvents,
net::socket::vsock::{addr::VsockSocketAddr, VSOCK_GLOBAL},
prelude::*,
process::signal::{Pollee, Poller},
};
pub struct Listen {
addr: VsockSocketAddr,
pollee: Pollee,
backlog: usize,
incoming_connection: SpinLock<VecDeque<Arc<Connected>>>,
}
impl Listen {
pub fn new(addr: VsockSocketAddr, backlog: usize) -> Self {
Self {
addr,
pollee: Pollee::new(IoEvents::empty()),
backlog,
incoming_connection: SpinLock::new(VecDeque::with_capacity(backlog)),
}
}
pub fn addr(&self) -> VsockSocketAddr {
self.addr
}
pub fn push_incoming(&self, connect: Arc<Connected>) -> Result<()> {
let mut incoming_connections = self.incoming_connection.lock_irq_disabled();
if incoming_connections.len() >= self.backlog {
return_errno_with_message!(Errno::ENOMEM, "Queue in listenging socket is full")
}
incoming_connections.push_back(connect);
self.add_events(IoEvents::IN);
Ok(())
}
pub fn accept(&self) -> Result<Arc<Connected>> {
// block waiting connection if no existing connection.
let poller = Poller::new();
if !self
.poll(IoEvents::IN, Some(&poller))
.contains(IoEvents::IN)
{
poller.wait()?;
}
let connection = self
.incoming_connection
.lock_irq_disabled()
.pop_front()
.unwrap();
Ok(connection)
}
pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
self.pollee.poll(mask, poller)
}
pub fn add_events(&self, events: IoEvents) {
self.pollee.add_events(events)
}
}
impl Drop for Listen {
fn drop(&mut self) {
VSOCK_GLOBAL
.get()
.unwrap()
.used_ports
.lock_irq_disabled()
.remove(&self.addr.port);
}
}

View File

@ -0,0 +1,8 @@
// SPDX-License-Identifier: MPL-2.0
pub mod connected;
pub mod init;
pub mod listen;
pub mod socket;
pub use socket::VsockStreamSocket;

View File

@ -0,0 +1,291 @@
// SPDX-License-Identifier: MPL-2.0
use super::{connected::Connected, init::Init, listen::Listen};
use crate::{
events::IoEvents,
fs::file_handle::FileLike,
net::socket::{
vsock::{addr::VsockSocketAddr, VSOCK_GLOBAL},
SendRecvFlags, SockShutdownCmd, Socket, SocketAddr,
},
prelude::*,
process::signal::Poller,
};
pub struct VsockStreamSocket(RwLock<Status>);
impl VsockStreamSocket {
pub(super) fn new_from_init(init: Arc<Init>) -> Self {
Self(RwLock::new(Status::Init(init)))
}
pub(super) fn new_from_listen(listen: Arc<Listen>) -> Self {
Self(RwLock::new(Status::Listen(listen)))
}
pub(super) fn new_from_connected(connected: Arc<Connected>) -> Self {
Self(RwLock::new(Status::Connected(connected)))
}
}
pub enum Status {
Init(Arc<Init>),
Listen(Arc<Listen>),
Connected(Arc<Connected>),
}
impl VsockStreamSocket {
pub fn new() -> Self {
let init = Arc::new(Init::new());
Self(RwLock::new(Status::Init(init)))
}
}
impl FileLike for VsockStreamSocket {
fn as_socket(self: Arc<Self>) -> Option<Arc<dyn Socket>> {
Some(self)
}
fn read(&self, buf: &mut [u8]) -> Result<usize> {
self.recvfrom(buf, SendRecvFlags::empty())
.map(|(read_size, _)| read_size)
}
fn write(&self, buf: &[u8]) -> Result<usize> {
self.sendto(buf, None, SendRecvFlags::empty())
}
fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
let inner = self.0.read();
match &*inner {
Status::Init(init) => init.poll(mask, poller),
Status::Listen(listen) => listen.poll(mask, poller),
Status::Connected(connect) => connect.poll(mask, poller),
}
}
}
impl Socket for VsockStreamSocket {
fn bind(&self, sockaddr: SocketAddr) -> Result<()> {
let addr = VsockSocketAddr::try_from(sockaddr)?;
let inner = self.0.read();
match &*inner {
Status::Init(init) => init.bind(addr),
Status::Listen(_) | Status::Connected(_) => {
return_errno_with_message!(
Errno::EINVAL,
"cannot bind a listening or connected socket"
)
}
}
}
fn connect(&self, sockaddr: SocketAddr) -> Result<()> {
let init = match &*self.0.read() {
Status::Init(init) => init.clone(),
Status::Listen(_) => {
return_errno_with_message!(Errno::EINVAL, "The socket is listened");
}
Status::Connected(_) => {
return_errno_with_message!(Errno::EINVAL, "The socket is connected");
}
};
let remote_addr = VsockSocketAddr::try_from(sockaddr)?;
let local_addr = init.bound_addr();
if let Some(addr) = local_addr {
if addr == remote_addr {
return_errno_with_message!(Errno::EINVAL, "try to connect to self is invalid");
}
} else {
init.bind(VsockSocketAddr::any_addr())?;
}
let connecting = Arc::new(Connected::new(remote_addr, init.bound_addr().unwrap()));
let vsockspace = VSOCK_GLOBAL.get().unwrap();
vsockspace
.connecting_sockets
.lock_irq_disabled()
.insert(connecting.local_addr(), connecting.clone());
// Send request
vsockspace
.driver
.lock_irq_disabled()
.request(&connecting.get_info())
.map_err(|e| Error::with_message(Errno::EAGAIN, "can not send connect packet"))?;
// wait for response from driver
// TODO: add timeout
let poller = Poller::new();
if !connecting
.poll(IoEvents::IN, Some(&poller))
.contains(IoEvents::IN)
{
poller.wait()?;
}
*self.0.write() = Status::Connected(connecting.clone());
// move connecting socket map to connected sockmap
vsockspace
.connecting_sockets
.lock_irq_disabled()
.remove(&connecting.local_addr())
.unwrap();
vsockspace
.connected_sockets
.lock_irq_disabled()
.insert(connecting.id(), connecting);
Ok(())
}
fn listen(&self, backlog: usize) -> Result<()> {
let init = match &*self.0.read() {
Status::Init(init) => init.clone(),
Status::Listen(_) => {
return_errno_with_message!(Errno::EINVAL, "The socket is already listened");
}
Status::Connected(_) => {
return_errno_with_message!(Errno::EISCONN, "The socket is already connected");
}
};
let addr = init.bound_addr().ok_or(Error::with_message(
Errno::EINVAL,
"The socket is not bound",
))?;
let listen = Arc::new(Listen::new(addr, backlog));
*self.0.write() = Status::Listen(listen.clone());
// push listen socket into vsockspace
VSOCK_GLOBAL
.get()
.unwrap()
.listen_sockets
.lock_irq_disabled()
.insert(listen.addr(), listen);
Ok(())
}
fn accept(&self) -> Result<(Arc<dyn FileLike>, SocketAddr)> {
let listen = match &*self.0.read() {
Status::Listen(listen) => listen.clone(),
Status::Init(_) | Status::Connected(_) => {
return_errno_with_message!(Errno::EINVAL, "The socket is not listening");
}
};
let connected = listen.accept()?;
let peer_addr = connected.peer_addr();
VSOCK_GLOBAL
.get()
.unwrap()
.connected_sockets
.lock_irq_disabled()
.insert(connected.id(), connected.clone());
VSOCK_GLOBAL
.get()
.unwrap()
.driver
.lock_irq_disabled()
.response(&connected.get_info())
.map_err(|e| Error::with_message(Errno::EAGAIN, "can not send response packet"))?;
let socket = Arc::new(VsockStreamSocket::new_from_connected(connected));
Ok((socket, peer_addr.into()))
}
fn shutdown(&self, cmd: SockShutdownCmd) -> Result<()> {
let inner = self.0.read();
if let Status::Connected(connected) = &*inner {
let result = connected.shutdown(cmd);
if result.is_ok() {
let vsockspace = VSOCK_GLOBAL.get().unwrap();
vsockspace
.used_ports
.lock_irq_disabled()
.remove(&connected.local_addr().port);
vsockspace
.connected_sockets
.lock_irq_disabled()
.remove(&connected.id());
}
result
} else {
return_errno_with_message!(Errno::EINVAL, "The socket is not connected.");
}
}
fn recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, SocketAddr)> {
let connected = match &*self.0.read() {
Status::Connected(connected) => connected.clone(),
Status::Init(_) | Status::Listen(_) => {
return_errno_with_message!(Errno::EINVAL, "the socket is not connected");
}
};
let read_size = connected.recv(buf)?;
let peer_addr = self.peer_addr()?;
// If buffer is now empty and the peer requested shutdown, finish shutting down the
// connection.
if connected.should_close() {
VSOCK_GLOBAL
.get()
.unwrap()
.driver
.lock_irq_disabled()
.reset(&connected.get_info())
.map_err(|e| Error::with_message(Errno::EAGAIN, "can not send close packet"))?;
}
Ok((read_size, peer_addr))
}
fn sendto(
&self,
buf: &[u8],
remote: Option<SocketAddr>,
flags: SendRecvFlags,
) -> Result<usize> {
debug_assert!(remote.is_none());
if remote.is_some() {
return_errno_with_message!(Errno::EINVAL, "vsock should not provide remote addr");
}
let inner = self.0.read();
match &*inner {
Status::Connected(connected) => connected.send(buf, flags),
Status::Init(_) | Status::Listen(_) => {
return_errno_with_message!(Errno::EINVAL, "The socket is not connected");
}
}
}
fn addr(&self) -> Result<SocketAddr> {
let inner = self.0.read();
let addr = match &*inner {
Status::Init(init) => init.bound_addr(),
Status::Listen(listen) => Some(listen.addr()),
Status::Connected(connected) => Some(connected.local_addr()),
};
addr.map(Into::<SocketAddr>::into)
.ok_or(Error::with_message(
Errno::EINVAL,
"The socket does not bind to addr",
))
}
fn peer_addr(&self) -> Result<SocketAddr> {
let inner = self.0.read();
if let Status::Connected(connected) = &*inner {
Ok(connected.peer_addr().into())
} else {
return_errno_with_message!(Errno::EINVAL, "the socket is not connected");
}
}
}
impl Default for VsockStreamSocket {
fn default() -> Self {
Self::new()
}
}