Support calling from inside via vsock

This commit is contained in:
Anmin Liu
2024-05-06 15:00:48 +00:00
committed by Tate, Hongliang Tian
parent 48f69c25a9
commit 60dd17fdd3
24 changed files with 582 additions and 558 deletions

View File

@ -13,13 +13,12 @@ use crate::{
SendRecvFlags, SockShutdownCmd, Socket, SocketAddr,
},
prelude::*,
process::signal::{Pollee, Poller},
process::signal::Poller,
};
pub struct VsockStreamSocket {
status: RwLock<Status>,
is_nonblocking: AtomicBool,
pollee: Pollee,
}
pub enum Status {
@ -34,14 +33,12 @@ impl VsockStreamSocket {
Self {
status: RwLock::new(Status::Init(init)),
is_nonblocking: AtomicBool::new(nonblocking),
pollee: Pollee::new(IoEvents::empty()),
}
}
pub(super) fn new_from_connected(connected: Arc<Connected>) -> Self {
Self {
status: RwLock::new(Status::Connected(connected)),
is_nonblocking: AtomicBool::new(false),
pollee: Pollee::new(IoEvents::empty()),
}
}
fn is_nonblocking(&self) -> bool {
@ -52,11 +49,44 @@ impl VsockStreamSocket {
self.is_nonblocking.store(nonblocking, Ordering::Relaxed);
}
// TODO: Support timeout
fn wait_events<F, R>(&self, mask: IoEvents, mut cond: F) -> Result<R>
where
F: FnMut() -> Result<R>,
{
let poller = Poller::new();
loop {
match cond() {
Err(err) if err.error() == Errno::EAGAIN => (),
result => {
if let Err(e) = result {
debug!("The result of cond() is Error: {:?}", e);
}
return result;
}
};
let events = match &*self.status.read() {
Status::Init(init) => init.poll(mask, Some(&poller)),
Status::Listen(listen) => listen.poll(mask, Some(&poller)),
Status::Connected(connected) => connected.poll(mask, Some(&poller)),
};
debug!("events: {:?}", events);
if !events.is_empty() {
continue;
}
poller.wait()?;
}
}
fn try_accept(&self) -> Result<(Arc<dyn FileLike>, SocketAddr)> {
let listen = match &*self.status.read() {
Status::Listen(listen) => listen.clone(),
Status::Init(_) | Status::Connected(_) => {
return_errno_with_message!(Errno::EINVAL, "The socket is not listening");
return_errno_with_message!(Errno::EINVAL, "the socket is not listening");
}
};
@ -68,17 +98,12 @@ impl VsockStreamSocket {
VSOCK_GLOBAL
.get()
.unwrap()
.connected_sockets
.lock_irq_disabled()
.insert(connected.id(), connected.clone());
.insert_connected_socket(connected.id(), connected.clone());
VSOCK_GLOBAL
.get()
.unwrap()
.driver
.lock_irq_disabled()
.response(&connected.get_info())
.map_err(|e| Error::with_message(Errno::EAGAIN, "can not send response packet"))?;
.response(&connected.get_info())?;
let socket = Arc::new(VsockStreamSocket::new_from_connected(connected));
Ok((socket, peer_addr.into()))
@ -123,7 +148,11 @@ impl FileLike for VsockStreamSocket {
}
fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
self.pollee.poll(mask, poller)
match &*self.status.read() {
Status::Init(init) => init.poll(mask, poller),
Status::Listen(listen) => listen.poll(mask, poller),
Status::Connected(connected) => connected.poll(mask, poller),
}
}
fn status_flags(&self) -> StatusFlags {
@ -142,30 +171,6 @@ impl FileLike for VsockStreamSocket {
}
Ok(())
}
fn register_observer(
&self,
observer: Weak<dyn crate::events::Observer<IoEvents>>,
mask: IoEvents,
) -> Result<()> {
match &*self.status.read() {
Status::Init(init) => init.register_observer(&self.pollee, observer, mask),
Status::Listen(listen) => listen.register_observer(&self.pollee, observer, mask),
Status::Connected(connected) => {
connected.register_observer(&self.pollee, observer, mask)
}
}
}
fn unregister_observer(
&self,
observer: &Weak<dyn crate::events::Observer<IoEvents>>,
) -> Result<Weak<dyn crate::events::Observer<IoEvents>>> {
match &*self.status.read() {
Status::Init(init) => init.unregister_observer(&self.pollee, observer),
Status::Listen(listen) => listen.unregister_observer(&self.pollee, observer),
Status::Connected(connected) => connected.unregister_observer(&self.pollee, observer),
}
}
}
impl Socket for VsockStreamSocket {
@ -189,15 +194,14 @@ impl Socket for VsockStreamSocket {
let init = match &*self.status.read() {
Status::Init(init) => init.clone(),
Status::Listen(_) => {
return_errno_with_message!(Errno::EINVAL, "The socket is listened");
return_errno_with_message!(Errno::EINVAL, "the socket is listened");
}
Status::Connected(_) => {
return_errno_with_message!(Errno::EINVAL, "The socket is connected");
return_errno_with_message!(Errno::EINVAL, "the socket is connected");
}
};
let remote_addr = VsockSocketAddr::try_from(sockaddr)?;
let local_addr = init.bound_addr();
if let Some(addr) = local_addr {
if addr == remote_addr {
return_errno_with_message!(Errno::EINVAL, "try to connect to self is invalid");
@ -208,18 +212,10 @@ impl Socket for VsockStreamSocket {
let connecting = Arc::new(Connecting::new(remote_addr, init.bound_addr().unwrap()));
let vsockspace = VSOCK_GLOBAL.get().unwrap();
vsockspace
.connecting_sockets
.lock_irq_disabled()
.insert(connecting.local_addr(), connecting.clone());
vsockspace.insert_connecting_socket(connecting.local_addr(), connecting.clone());
// Send request
vsockspace
.driver
.lock_irq_disabled()
.request(&connecting.info())
.map_err(|e| Error::with_message(Errno::EAGAIN, "can not send connect packet"))?;
vsockspace.request(&connecting.info()).unwrap();
// wait for response from driver
// TODO: add timeout
let poller = Poller::new();
@ -229,19 +225,14 @@ impl Socket for VsockStreamSocket {
{
poller.wait()?;
}
vsockspace
.connecting_sockets
.lock_irq_disabled()
.remove(&connecting.local_addr())
.unwrap();
vsockspace
.remove_connecting_socket(&connecting.local_addr())
.unwrap();
let connected = Arc::new(Connected::from_connecting(connecting));
*self.status.write() = Status::Connected(connected.clone());
// move connecting socket map to connected sockmap
vsockspace
.connected_sockets
.lock_irq_disabled()
.insert(connected.id(), connected);
vsockspace.insert_connected_socket(connected.id(), connected);
Ok(())
}
@ -250,15 +241,15 @@ impl Socket for VsockStreamSocket {
let init = match &*self.status.read() {
Status::Init(init) => init.clone(),
Status::Listen(_) => {
return_errno_with_message!(Errno::EINVAL, "The socket is already listened");
return_errno_with_message!(Errno::EINVAL, "the socket is already listened");
}
Status::Connected(_) => {
return_errno_with_message!(Errno::EISCONN, "The socket is already connected");
return_errno_with_message!(Errno::EISCONN, "the socket is already connected");
}
};
let addr = init.bound_addr().ok_or(Error::with_message(
Errno::EINVAL,
"The socket is not bound",
"the socket is not bound",
))?;
let listen = Arc::new(Listen::new(addr, backlog));
*self.status.write() = Status::Listen(listen.clone());
@ -267,9 +258,7 @@ impl Socket for VsockStreamSocket {
VSOCK_GLOBAL
.get()
.unwrap()
.listen_sockets
.lock_irq_disabled()
.insert(listen.addr(), listen);
.insert_listen_socket(listen.addr(), listen);
Ok(())
}
@ -278,7 +267,7 @@ impl Socket for VsockStreamSocket {
if self.is_nonblocking() {
self.try_accept()
} else {
wait_events(self, IoEvents::IN, || self.try_accept())
self.wait_events(IoEvents::IN, || self.try_accept())
}
}
@ -286,7 +275,7 @@ impl Socket for VsockStreamSocket {
match &*self.status.read() {
Status::Connected(connected) => connected.shutdown(cmd),
Status::Init(_) | Status::Listen(_) => {
return_errno_with_message!(Errno::EINVAL, "The socket is not connected");
return_errno_with_message!(Errno::EINVAL, "the socket is not connected");
}
}
}
@ -297,7 +286,7 @@ impl Socket for VsockStreamSocket {
if self.is_nonblocking() {
self.try_recvfrom(buf, flags)
} else {
wait_events(self, IoEvents::IN, || self.try_recvfrom(buf, flags))
self.wait_events(IoEvents::IN, || self.try_recvfrom(buf, flags))
}
}
@ -315,7 +304,7 @@ impl Socket for VsockStreamSocket {
match &*inner {
Status::Connected(connected) => connected.send(buf, flags),
Status::Init(_) | Status::Listen(_) => {
return_errno_with_message!(Errno::EINVAL, "The socket is not connected");
return_errno_with_message!(Errno::EINVAL, "the socket is not connected");
}
}
}
@ -343,37 +332,3 @@ impl Socket for VsockStreamSocket {
}
}
}
// TODO: Support timeout
fn wait_events<F, R>(socket: &VsockStreamSocket, mask: IoEvents, mut cond: F) -> Result<R>
where
F: FnMut() -> Result<R>,
{
let poller = Poller::new();
loop {
match cond() {
Err(err) if err.error() == Errno::EAGAIN => (),
result => {
if let Err(e) = result {
debug!("The result of cond() is Error: {:?}", e);
}
return result;
}
};
let events = match &*socket.status.read() {
Status::Init(init) => init.poll(mask, Some(&poller)),
Status::Listen(listen) => listen.poll(mask, Some(&poller)),
Status::Connected(connected) => connected.poll(mask, Some(&poller)),
};
debug!("events: {:?}", events);
if !events.is_empty() {
continue;
}
poller.wait()?;
debug!("pass the poller wait");
}
}