mirror of
https://github.com/asterinas/asterinas.git
synced 2025-06-29 04:13:24 +00:00
Support calling from inside via vsock
This commit is contained in:
committed by
Tate, Hongliang Tian
parent
48f69c25a9
commit
60dd17fdd3
@ -1,13 +1,11 @@
|
||||
// SPDX-License-Identifier: MPL-2.0
|
||||
|
||||
use alloc::boxed::Box;
|
||||
use core::cmp::min;
|
||||
|
||||
use aster_virtio::device::socket::connect::{ConnectionInfo, VsockEvent};
|
||||
use ringbuf::{ring_buffer::RbBase, HeapRb, Rb};
|
||||
|
||||
use super::connecting::Connecting;
|
||||
use crate::{
|
||||
events::{IoEvents, Observer},
|
||||
events::IoEvents,
|
||||
net::socket::{
|
||||
vsock::{addr::VsockSocketAddr, VSOCK_GLOBAL},
|
||||
SendRecvFlags, SockShutdownCmd,
|
||||
@ -54,8 +52,8 @@ impl Connected {
|
||||
|
||||
pub fn try_recv(&self, buf: &mut [u8]) -> Result<usize> {
|
||||
let mut connection = self.connection.lock_irq_disabled();
|
||||
let bytes_read = connection.buffer.drain(buf);
|
||||
|
||||
let bytes_read = connection.buffer.len().min(buf.len());
|
||||
connection.buffer.pop_slice(&mut buf[..bytes_read]);
|
||||
connection.info.done_forwarding(bytes_read);
|
||||
|
||||
match bytes_read {
|
||||
@ -77,10 +75,8 @@ impl Connected {
|
||||
VSOCK_GLOBAL
|
||||
.get()
|
||||
.unwrap()
|
||||
.driver
|
||||
.lock_irq_disabled()
|
||||
.send(buf, &mut connection.info)
|
||||
.map_err(|e| Error::with_message(Errno::ENOBUFS, "cannot send packet"))?;
|
||||
.send(buf, &mut connection.info)?;
|
||||
|
||||
Ok(buf_len)
|
||||
}
|
||||
|
||||
@ -96,23 +92,13 @@ impl Connected {
|
||||
if self.should_close() {
|
||||
let connection = self.connection.lock_irq_disabled();
|
||||
let vsockspace = VSOCK_GLOBAL.get().unwrap();
|
||||
vsockspace
|
||||
.driver
|
||||
.lock_irq_disabled()
|
||||
.reset(&connection.info)
|
||||
.map_err(|e| Error::with_message(Errno::ENOMEM, "can not send close packet"))?;
|
||||
vsockspace
|
||||
.connected_sockets
|
||||
.lock_irq_disabled()
|
||||
.remove(&self.id());
|
||||
vsockspace
|
||||
.used_ports
|
||||
.lock_irq_disabled()
|
||||
.remove(&self.local_addr().port);
|
||||
vsockspace.reset(&connection.info).unwrap();
|
||||
vsockspace.remove_connected_socket(&self.id());
|
||||
vsockspace.recycle_port(&self.local_addr().port);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
pub fn update_for_event(&self, event: &VsockEvent) {
|
||||
pub fn update_info(&self, event: &VsockEvent) {
|
||||
let mut connection = self.connection.lock_irq_disabled();
|
||||
connection.update_for_event(event)
|
||||
}
|
||||
@ -122,7 +108,7 @@ impl Connected {
|
||||
connection.info.clone()
|
||||
}
|
||||
|
||||
pub fn connection_buffer_add(&self, bytes: &[u8]) -> bool {
|
||||
pub fn add_connection_buffer(&self, bytes: &[u8]) -> bool {
|
||||
let mut connection = self.connection.lock_irq_disabled();
|
||||
connection.add(bytes)
|
||||
}
|
||||
@ -144,42 +130,18 @@ impl Connected {
|
||||
self.pollee.del_events(IoEvents::IN);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn register_observer(
|
||||
&self,
|
||||
pollee: &Pollee,
|
||||
observer: Weak<dyn Observer<IoEvents>>,
|
||||
mask: IoEvents,
|
||||
) -> Result<()> {
|
||||
pollee.register_observer(observer, mask);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn unregister_observer(
|
||||
&self,
|
||||
pollee: &Pollee,
|
||||
observer: &Weak<dyn Observer<IoEvents>>,
|
||||
) -> Result<Weak<dyn Observer<IoEvents>>> {
|
||||
pollee
|
||||
.unregister_observer(observer)
|
||||
.ok_or_else(|| Error::with_message(Errno::EINVAL, "fails to unregister observer"))
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for Connected {
|
||||
fn drop(&mut self) {
|
||||
let vsockspace = VSOCK_GLOBAL.get().unwrap();
|
||||
vsockspace
|
||||
.used_ports
|
||||
.lock_irq_disabled()
|
||||
.remove(&self.local_addr().port);
|
||||
vsockspace.recycle_port(&self.local_addr().port);
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Connection {
|
||||
info: ConnectionInfo,
|
||||
buffer: RingBuffer,
|
||||
buffer: HeapRb<u8>,
|
||||
/// The peer sent a SHUTDOWN request, but we haven't yet responded with a RST because there is
|
||||
/// still data in the buffer.
|
||||
pub peer_requested_shutdown: bool,
|
||||
@ -191,16 +153,15 @@ impl Connection {
|
||||
info.buf_alloc = PER_CONNECTION_BUFFER_CAPACITY.try_into().unwrap();
|
||||
Self {
|
||||
info,
|
||||
buffer: RingBuffer::new(PER_CONNECTION_BUFFER_CAPACITY),
|
||||
buffer: HeapRb::new(PER_CONNECTION_BUFFER_CAPACITY),
|
||||
peer_requested_shutdown: false,
|
||||
}
|
||||
}
|
||||
pub fn from_info(info: ConnectionInfo) -> Self {
|
||||
let mut info = info.clone();
|
||||
pub fn from_info(mut info: ConnectionInfo) -> Self {
|
||||
info.buf_alloc = PER_CONNECTION_BUFFER_CAPACITY.try_into().unwrap();
|
||||
Self {
|
||||
info,
|
||||
buffer: RingBuffer::new(PER_CONNECTION_BUFFER_CAPACITY),
|
||||
buffer: HeapRb::new(PER_CONNECTION_BUFFER_CAPACITY),
|
||||
peer_requested_shutdown: false,
|
||||
}
|
||||
}
|
||||
@ -208,7 +169,11 @@ impl Connection {
|
||||
self.info.update_for_event(event)
|
||||
}
|
||||
pub fn add(&mut self, bytes: &[u8]) -> bool {
|
||||
self.buffer.add(bytes)
|
||||
if bytes.len() > self.buffer.free_len() {
|
||||
return false;
|
||||
}
|
||||
self.buffer.push_slice(bytes);
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
@ -231,85 +196,3 @@ impl From<VsockEvent> for ConnectionID {
|
||||
Self::new(event.destination.into(), event.source.into())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct RingBuffer {
|
||||
buffer: Box<[u8]>,
|
||||
/// The number of bytes currently in the buffer.
|
||||
used: usize,
|
||||
/// The index of the first used byte in the buffer.
|
||||
start: usize,
|
||||
}
|
||||
//TODO: ringbuf
|
||||
impl RingBuffer {
|
||||
pub fn new(capacity: usize) -> Self {
|
||||
// TODO: can be optimized.
|
||||
let temp = vec![0; capacity];
|
||||
Self {
|
||||
// FIXME: if the capacity is excessive, elements move will be executed.
|
||||
buffer: temp.into_boxed_slice(),
|
||||
used: 0,
|
||||
start: 0,
|
||||
}
|
||||
}
|
||||
/// Returns the number of bytes currently used in the buffer.
|
||||
pub fn used(&self) -> usize {
|
||||
self.used
|
||||
}
|
||||
|
||||
/// Returns true iff there are currently no bytes in the buffer.
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.used == 0
|
||||
}
|
||||
|
||||
/// Returns the number of bytes currently free in the buffer.
|
||||
pub fn available(&self) -> usize {
|
||||
self.buffer.len() - self.used
|
||||
}
|
||||
|
||||
/// Adds the given bytes to the buffer if there is enough capacity for them all.
|
||||
///
|
||||
/// Returns true if they were added, or false if they were not.
|
||||
pub fn add(&mut self, bytes: &[u8]) -> bool {
|
||||
if bytes.len() > self.available() {
|
||||
return false;
|
||||
}
|
||||
|
||||
// The index of the first available position in the buffer.
|
||||
let first_available = (self.start + self.used) % self.buffer.len();
|
||||
// The number of bytes to copy from `bytes` to `buffer` between `first_available` and
|
||||
// `buffer.len()`.
|
||||
let copy_length_before_wraparound = min(bytes.len(), self.buffer.len() - first_available);
|
||||
self.buffer[first_available..first_available + copy_length_before_wraparound]
|
||||
.copy_from_slice(&bytes[0..copy_length_before_wraparound]);
|
||||
if let Some(bytes_after_wraparound) = bytes.get(copy_length_before_wraparound..) {
|
||||
self.buffer[0..bytes_after_wraparound.len()].copy_from_slice(bytes_after_wraparound);
|
||||
}
|
||||
self.used += bytes.len();
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
/// Reads and removes as many bytes as possible from the buffer, up to the length of the given
|
||||
/// buffer.
|
||||
pub fn drain(&mut self, out: &mut [u8]) -> usize {
|
||||
let bytes_read = min(self.used, out.len());
|
||||
|
||||
// The number of bytes to copy out between `start` and the end of the buffer.
|
||||
let read_before_wraparound = min(bytes_read, self.buffer.len() - self.start);
|
||||
// The number of bytes to copy out from the beginning of the buffer after wrapping around.
|
||||
let read_after_wraparound = bytes_read
|
||||
.checked_sub(read_before_wraparound)
|
||||
.unwrap_or_default();
|
||||
|
||||
out[0..read_before_wraparound]
|
||||
.copy_from_slice(&self.buffer[self.start..self.start + read_before_wraparound]);
|
||||
out[read_before_wraparound..bytes_read]
|
||||
.copy_from_slice(&self.buffer[0..read_after_wraparound]);
|
||||
|
||||
self.used -= bytes_read;
|
||||
self.start = (self.start + bytes_read) % self.buffer.len();
|
||||
|
||||
bytes_read
|
||||
}
|
||||
}
|
||||
|
@ -38,7 +38,7 @@ impl Connecting {
|
||||
pub fn info(&self) -> ConnectionInfo {
|
||||
self.info.lock_irq_disabled().clone()
|
||||
}
|
||||
pub fn update_for_event(&self, event: &VsockEvent) {
|
||||
pub fn update_info(&self, event: &VsockEvent) {
|
||||
self.info.lock_irq_disabled().update_for_event(event)
|
||||
}
|
||||
pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
|
||||
@ -52,9 +52,6 @@ impl Connecting {
|
||||
impl Drop for Connecting {
|
||||
fn drop(&mut self) {
|
||||
let vsockspace = VSOCK_GLOBAL.get().unwrap();
|
||||
vsockspace
|
||||
.used_ports
|
||||
.lock_irq_disabled()
|
||||
.remove(&self.local_addr().port);
|
||||
vsockspace.recycle_port(&self.local_addr().port);
|
||||
}
|
||||
}
|
||||
|
@ -1,7 +1,7 @@
|
||||
// SPDX-License-Identifier: MPL-2.0
|
||||
|
||||
use crate::{
|
||||
events::{IoEvents, Observer},
|
||||
events::IoEvents,
|
||||
net::socket::vsock::{
|
||||
addr::{VsockSocketAddr, VMADDR_CID_ANY, VMADDR_PORT_ANY},
|
||||
VSOCK_GLOBAL,
|
||||
@ -11,31 +11,31 @@ use crate::{
|
||||
};
|
||||
|
||||
pub struct Init {
|
||||
bind_addr: SpinLock<Option<VsockSocketAddr>>,
|
||||
bound_addr: Mutex<Option<VsockSocketAddr>>,
|
||||
pollee: Pollee,
|
||||
}
|
||||
|
||||
impl Init {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
bind_addr: SpinLock::new(None),
|
||||
bound_addr: Mutex::new(None),
|
||||
pollee: Pollee::new(IoEvents::empty()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn bind(&self, addr: VsockSocketAddr) -> Result<()> {
|
||||
if self.bind_addr.lock().is_some() {
|
||||
if self.bound_addr.lock().is_some() {
|
||||
return_errno_with_message!(Errno::EINVAL, "the socket is already bound");
|
||||
}
|
||||
let vsockspace = VSOCK_GLOBAL.get().unwrap();
|
||||
|
||||
// check correctness of cid
|
||||
let local_cid = vsockspace.driver.lock_irq_disabled().guest_cid();
|
||||
if addr.cid != VMADDR_CID_ANY && addr.cid != local_cid as u32 {
|
||||
return_errno_with_message!(Errno::EADDRNOTAVAIL, "The cid in address is incorrect");
|
||||
let local_cid = vsockspace.guest_cid();
|
||||
if addr.cid != VMADDR_CID_ANY && addr.cid != local_cid {
|
||||
return_errno_with_message!(Errno::EADDRNOTAVAIL, "the cid in address is incorrect");
|
||||
}
|
||||
let mut new_addr = addr;
|
||||
new_addr.cid = local_cid as u32;
|
||||
new_addr.cid = local_cid;
|
||||
|
||||
// check and assign a port
|
||||
if addr.port == VMADDR_PORT_ANY {
|
||||
@ -44,62 +44,33 @@ impl Init {
|
||||
} else {
|
||||
return_errno_with_message!(Errno::EAGAIN, "cannot find unused high port");
|
||||
}
|
||||
} else if vsockspace
|
||||
.used_ports
|
||||
.lock_irq_disabled()
|
||||
.contains(&new_addr.port)
|
||||
{
|
||||
} else if !vsockspace.insert_port(new_addr.port) {
|
||||
return_errno_with_message!(Errno::EADDRNOTAVAIL, "the port in address is occupied");
|
||||
} else {
|
||||
vsockspace
|
||||
.used_ports
|
||||
.lock_irq_disabled()
|
||||
.insert(new_addr.port);
|
||||
}
|
||||
|
||||
//TODO: The privileged port isn't checked
|
||||
*self.bind_addr.lock() = Some(new_addr);
|
||||
*self.bound_addr.lock() = Some(new_addr);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn is_bound(&self) -> bool {
|
||||
self.bind_addr.lock().is_some()
|
||||
self.bound_addr.lock().is_some()
|
||||
}
|
||||
|
||||
pub fn bound_addr(&self) -> Option<VsockSocketAddr> {
|
||||
*self.bind_addr.lock()
|
||||
*self.bound_addr.lock()
|
||||
}
|
||||
|
||||
pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
|
||||
self.pollee.poll(mask, poller)
|
||||
}
|
||||
|
||||
pub fn register_observer(
|
||||
&self,
|
||||
pollee: &Pollee,
|
||||
observer: Weak<dyn Observer<IoEvents>>,
|
||||
mask: IoEvents,
|
||||
) -> Result<()> {
|
||||
pollee.register_observer(observer, mask);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn unregister_observer(
|
||||
&self,
|
||||
pollee: &Pollee,
|
||||
observer: &Weak<dyn Observer<IoEvents>>,
|
||||
) -> Result<Weak<dyn Observer<IoEvents>>> {
|
||||
pollee
|
||||
.unregister_observer(observer)
|
||||
.ok_or_else(|| Error::with_message(Errno::EINVAL, "fails to unregister observer"))
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for Init {
|
||||
fn drop(&mut self) {
|
||||
if let Some(addr) = *self.bind_addr.lock() {
|
||||
if let Some(addr) = *self.bound_addr.lock() {
|
||||
let vsockspace = VSOCK_GLOBAL.get().unwrap();
|
||||
vsockspace.used_ports.lock_irq_disabled().remove(&addr.port);
|
||||
vsockspace.recycle_port(&addr.port);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
use super::connected::Connected;
|
||||
use crate::{
|
||||
events::{IoEvents, Observer},
|
||||
events::IoEvents,
|
||||
net::socket::vsock::{addr::VsockSocketAddr, VSOCK_GLOBAL},
|
||||
prelude::*,
|
||||
process::signal::{Pollee, Poller},
|
||||
@ -30,7 +30,7 @@ impl Listen {
|
||||
pub fn push_incoming(&self, connect: Arc<Connected>) -> Result<()> {
|
||||
let mut incoming_connections = self.incoming_connection.lock_irq_disabled();
|
||||
if incoming_connections.len() >= self.backlog {
|
||||
return_errno_with_message!(Errno::ENOMEM, "Queue in listenging socket is full")
|
||||
return_errno_with_message!(Errno::ENOMEM, "queue in listenging socket is full")
|
||||
}
|
||||
incoming_connections.push_back(connect);
|
||||
Ok(())
|
||||
@ -59,34 +59,10 @@ impl Listen {
|
||||
self.pollee.del_events(IoEvents::IN);
|
||||
}
|
||||
}
|
||||
pub fn register_observer(
|
||||
&self,
|
||||
pollee: &Pollee,
|
||||
observer: Weak<dyn Observer<IoEvents>>,
|
||||
mask: IoEvents,
|
||||
) -> Result<()> {
|
||||
pollee.register_observer(observer, mask);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn unregister_observer(
|
||||
&self,
|
||||
pollee: &Pollee,
|
||||
observer: &Weak<dyn Observer<IoEvents>>,
|
||||
) -> Result<Weak<dyn Observer<IoEvents>>> {
|
||||
pollee
|
||||
.unregister_observer(observer)
|
||||
.ok_or_else(|| Error::with_message(Errno::EINVAL, "fails to unregister observer"))
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for Listen {
|
||||
fn drop(&mut self) {
|
||||
VSOCK_GLOBAL
|
||||
.get()
|
||||
.unwrap()
|
||||
.used_ports
|
||||
.lock_irq_disabled()
|
||||
.remove(&self.addr.port);
|
||||
VSOCK_GLOBAL.get().unwrap().recycle_port(&self.addr.port);
|
||||
}
|
||||
}
|
||||
|
@ -13,13 +13,12 @@ use crate::{
|
||||
SendRecvFlags, SockShutdownCmd, Socket, SocketAddr,
|
||||
},
|
||||
prelude::*,
|
||||
process::signal::{Pollee, Poller},
|
||||
process::signal::Poller,
|
||||
};
|
||||
|
||||
pub struct VsockStreamSocket {
|
||||
status: RwLock<Status>,
|
||||
is_nonblocking: AtomicBool,
|
||||
pollee: Pollee,
|
||||
}
|
||||
|
||||
pub enum Status {
|
||||
@ -34,14 +33,12 @@ impl VsockStreamSocket {
|
||||
Self {
|
||||
status: RwLock::new(Status::Init(init)),
|
||||
is_nonblocking: AtomicBool::new(nonblocking),
|
||||
pollee: Pollee::new(IoEvents::empty()),
|
||||
}
|
||||
}
|
||||
pub(super) fn new_from_connected(connected: Arc<Connected>) -> Self {
|
||||
Self {
|
||||
status: RwLock::new(Status::Connected(connected)),
|
||||
is_nonblocking: AtomicBool::new(false),
|
||||
pollee: Pollee::new(IoEvents::empty()),
|
||||
}
|
||||
}
|
||||
fn is_nonblocking(&self) -> bool {
|
||||
@ -52,11 +49,44 @@ impl VsockStreamSocket {
|
||||
self.is_nonblocking.store(nonblocking, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
// TODO: Support timeout
|
||||
fn wait_events<F, R>(&self, mask: IoEvents, mut cond: F) -> Result<R>
|
||||
where
|
||||
F: FnMut() -> Result<R>,
|
||||
{
|
||||
let poller = Poller::new();
|
||||
|
||||
loop {
|
||||
match cond() {
|
||||
Err(err) if err.error() == Errno::EAGAIN => (),
|
||||
result => {
|
||||
if let Err(e) = result {
|
||||
debug!("The result of cond() is Error: {:?}", e);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
let events = match &*self.status.read() {
|
||||
Status::Init(init) => init.poll(mask, Some(&poller)),
|
||||
Status::Listen(listen) => listen.poll(mask, Some(&poller)),
|
||||
Status::Connected(connected) => connected.poll(mask, Some(&poller)),
|
||||
};
|
||||
|
||||
debug!("events: {:?}", events);
|
||||
if !events.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
poller.wait()?;
|
||||
}
|
||||
}
|
||||
|
||||
fn try_accept(&self) -> Result<(Arc<dyn FileLike>, SocketAddr)> {
|
||||
let listen = match &*self.status.read() {
|
||||
Status::Listen(listen) => listen.clone(),
|
||||
Status::Init(_) | Status::Connected(_) => {
|
||||
return_errno_with_message!(Errno::EINVAL, "The socket is not listening");
|
||||
return_errno_with_message!(Errno::EINVAL, "the socket is not listening");
|
||||
}
|
||||
};
|
||||
|
||||
@ -68,17 +98,12 @@ impl VsockStreamSocket {
|
||||
VSOCK_GLOBAL
|
||||
.get()
|
||||
.unwrap()
|
||||
.connected_sockets
|
||||
.lock_irq_disabled()
|
||||
.insert(connected.id(), connected.clone());
|
||||
.insert_connected_socket(connected.id(), connected.clone());
|
||||
|
||||
VSOCK_GLOBAL
|
||||
.get()
|
||||
.unwrap()
|
||||
.driver
|
||||
.lock_irq_disabled()
|
||||
.response(&connected.get_info())
|
||||
.map_err(|e| Error::with_message(Errno::EAGAIN, "can not send response packet"))?;
|
||||
.response(&connected.get_info())?;
|
||||
|
||||
let socket = Arc::new(VsockStreamSocket::new_from_connected(connected));
|
||||
Ok((socket, peer_addr.into()))
|
||||
@ -123,7 +148,11 @@ impl FileLike for VsockStreamSocket {
|
||||
}
|
||||
|
||||
fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
|
||||
self.pollee.poll(mask, poller)
|
||||
match &*self.status.read() {
|
||||
Status::Init(init) => init.poll(mask, poller),
|
||||
Status::Listen(listen) => listen.poll(mask, poller),
|
||||
Status::Connected(connected) => connected.poll(mask, poller),
|
||||
}
|
||||
}
|
||||
|
||||
fn status_flags(&self) -> StatusFlags {
|
||||
@ -142,30 +171,6 @@ impl FileLike for VsockStreamSocket {
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
fn register_observer(
|
||||
&self,
|
||||
observer: Weak<dyn crate::events::Observer<IoEvents>>,
|
||||
mask: IoEvents,
|
||||
) -> Result<()> {
|
||||
match &*self.status.read() {
|
||||
Status::Init(init) => init.register_observer(&self.pollee, observer, mask),
|
||||
Status::Listen(listen) => listen.register_observer(&self.pollee, observer, mask),
|
||||
Status::Connected(connected) => {
|
||||
connected.register_observer(&self.pollee, observer, mask)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn unregister_observer(
|
||||
&self,
|
||||
observer: &Weak<dyn crate::events::Observer<IoEvents>>,
|
||||
) -> Result<Weak<dyn crate::events::Observer<IoEvents>>> {
|
||||
match &*self.status.read() {
|
||||
Status::Init(init) => init.unregister_observer(&self.pollee, observer),
|
||||
Status::Listen(listen) => listen.unregister_observer(&self.pollee, observer),
|
||||
Status::Connected(connected) => connected.unregister_observer(&self.pollee, observer),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Socket for VsockStreamSocket {
|
||||
@ -189,15 +194,14 @@ impl Socket for VsockStreamSocket {
|
||||
let init = match &*self.status.read() {
|
||||
Status::Init(init) => init.clone(),
|
||||
Status::Listen(_) => {
|
||||
return_errno_with_message!(Errno::EINVAL, "The socket is listened");
|
||||
return_errno_with_message!(Errno::EINVAL, "the socket is listened");
|
||||
}
|
||||
Status::Connected(_) => {
|
||||
return_errno_with_message!(Errno::EINVAL, "The socket is connected");
|
||||
return_errno_with_message!(Errno::EINVAL, "the socket is connected");
|
||||
}
|
||||
};
|
||||
let remote_addr = VsockSocketAddr::try_from(sockaddr)?;
|
||||
let local_addr = init.bound_addr();
|
||||
|
||||
if let Some(addr) = local_addr {
|
||||
if addr == remote_addr {
|
||||
return_errno_with_message!(Errno::EINVAL, "try to connect to self is invalid");
|
||||
@ -208,18 +212,10 @@ impl Socket for VsockStreamSocket {
|
||||
|
||||
let connecting = Arc::new(Connecting::new(remote_addr, init.bound_addr().unwrap()));
|
||||
let vsockspace = VSOCK_GLOBAL.get().unwrap();
|
||||
vsockspace
|
||||
.connecting_sockets
|
||||
.lock_irq_disabled()
|
||||
.insert(connecting.local_addr(), connecting.clone());
|
||||
vsockspace.insert_connecting_socket(connecting.local_addr(), connecting.clone());
|
||||
|
||||
// Send request
|
||||
vsockspace
|
||||
.driver
|
||||
.lock_irq_disabled()
|
||||
.request(&connecting.info())
|
||||
.map_err(|e| Error::with_message(Errno::EAGAIN, "can not send connect packet"))?;
|
||||
|
||||
vsockspace.request(&connecting.info()).unwrap();
|
||||
// wait for response from driver
|
||||
// TODO: add timeout
|
||||
let poller = Poller::new();
|
||||
@ -229,19 +225,14 @@ impl Socket for VsockStreamSocket {
|
||||
{
|
||||
poller.wait()?;
|
||||
}
|
||||
vsockspace
|
||||
.connecting_sockets
|
||||
.lock_irq_disabled()
|
||||
.remove(&connecting.local_addr())
|
||||
.unwrap();
|
||||
|
||||
vsockspace
|
||||
.remove_connecting_socket(&connecting.local_addr())
|
||||
.unwrap();
|
||||
let connected = Arc::new(Connected::from_connecting(connecting));
|
||||
*self.status.write() = Status::Connected(connected.clone());
|
||||
// move connecting socket map to connected sockmap
|
||||
vsockspace
|
||||
.connected_sockets
|
||||
.lock_irq_disabled()
|
||||
.insert(connected.id(), connected);
|
||||
vsockspace.insert_connected_socket(connected.id(), connected);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@ -250,15 +241,15 @@ impl Socket for VsockStreamSocket {
|
||||
let init = match &*self.status.read() {
|
||||
Status::Init(init) => init.clone(),
|
||||
Status::Listen(_) => {
|
||||
return_errno_with_message!(Errno::EINVAL, "The socket is already listened");
|
||||
return_errno_with_message!(Errno::EINVAL, "the socket is already listened");
|
||||
}
|
||||
Status::Connected(_) => {
|
||||
return_errno_with_message!(Errno::EISCONN, "The socket is already connected");
|
||||
return_errno_with_message!(Errno::EISCONN, "the socket is already connected");
|
||||
}
|
||||
};
|
||||
let addr = init.bound_addr().ok_or(Error::with_message(
|
||||
Errno::EINVAL,
|
||||
"The socket is not bound",
|
||||
"the socket is not bound",
|
||||
))?;
|
||||
let listen = Arc::new(Listen::new(addr, backlog));
|
||||
*self.status.write() = Status::Listen(listen.clone());
|
||||
@ -267,9 +258,7 @@ impl Socket for VsockStreamSocket {
|
||||
VSOCK_GLOBAL
|
||||
.get()
|
||||
.unwrap()
|
||||
.listen_sockets
|
||||
.lock_irq_disabled()
|
||||
.insert(listen.addr(), listen);
|
||||
.insert_listen_socket(listen.addr(), listen);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@ -278,7 +267,7 @@ impl Socket for VsockStreamSocket {
|
||||
if self.is_nonblocking() {
|
||||
self.try_accept()
|
||||
} else {
|
||||
wait_events(self, IoEvents::IN, || self.try_accept())
|
||||
self.wait_events(IoEvents::IN, || self.try_accept())
|
||||
}
|
||||
}
|
||||
|
||||
@ -286,7 +275,7 @@ impl Socket for VsockStreamSocket {
|
||||
match &*self.status.read() {
|
||||
Status::Connected(connected) => connected.shutdown(cmd),
|
||||
Status::Init(_) | Status::Listen(_) => {
|
||||
return_errno_with_message!(Errno::EINVAL, "The socket is not connected");
|
||||
return_errno_with_message!(Errno::EINVAL, "the socket is not connected");
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -297,7 +286,7 @@ impl Socket for VsockStreamSocket {
|
||||
if self.is_nonblocking() {
|
||||
self.try_recvfrom(buf, flags)
|
||||
} else {
|
||||
wait_events(self, IoEvents::IN, || self.try_recvfrom(buf, flags))
|
||||
self.wait_events(IoEvents::IN, || self.try_recvfrom(buf, flags))
|
||||
}
|
||||
}
|
||||
|
||||
@ -315,7 +304,7 @@ impl Socket for VsockStreamSocket {
|
||||
match &*inner {
|
||||
Status::Connected(connected) => connected.send(buf, flags),
|
||||
Status::Init(_) | Status::Listen(_) => {
|
||||
return_errno_with_message!(Errno::EINVAL, "The socket is not connected");
|
||||
return_errno_with_message!(Errno::EINVAL, "the socket is not connected");
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -343,37 +332,3 @@ impl Socket for VsockStreamSocket {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Support timeout
|
||||
fn wait_events<F, R>(socket: &VsockStreamSocket, mask: IoEvents, mut cond: F) -> Result<R>
|
||||
where
|
||||
F: FnMut() -> Result<R>,
|
||||
{
|
||||
let poller = Poller::new();
|
||||
|
||||
loop {
|
||||
match cond() {
|
||||
Err(err) if err.error() == Errno::EAGAIN => (),
|
||||
result => {
|
||||
if let Err(e) = result {
|
||||
debug!("The result of cond() is Error: {:?}", e);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
let events = match &*socket.status.read() {
|
||||
Status::Init(init) => init.poll(mask, Some(&poller)),
|
||||
Status::Listen(listen) => listen.poll(mask, Some(&poller)),
|
||||
Status::Connected(connected) => connected.poll(mask, Some(&poller)),
|
||||
};
|
||||
|
||||
debug!("events: {:?}", events);
|
||||
if !events.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
poller.wait()?;
|
||||
debug!("pass the poller wait");
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user