Support nonblocking option

This commit is contained in:
Anmin Liu
2024-04-21 17:08:21 +00:00
committed by Tate, Hongliang Tian
parent 7f6ef5e12d
commit 48f69c25a9
12 changed files with 301 additions and 159 deletions

View File

@ -1,31 +1,25 @@
// SPDX-License-Identifier: MPL-2.0
use core::sync::atomic::AtomicBool;
use atomic::Ordering;
use super::{connected::Connected, connecting::Connecting, init::Init, listen::Listen};
use crate::{
events::IoEvents,
fs::file_handle::FileLike,
fs::{file_handle::FileLike, utils::StatusFlags},
net::socket::{
vsock::{addr::VsockSocketAddr, VSOCK_GLOBAL},
SendRecvFlags, SockShutdownCmd, Socket, SocketAddr,
},
prelude::*,
process::signal::Poller,
process::signal::{Pollee, Poller},
};
pub struct VsockStreamSocket(RwLock<Status>);
impl VsockStreamSocket {
pub(super) fn new_from_init(init: Arc<Init>) -> Self {
Self(RwLock::new(Status::Init(init)))
}
pub(super) fn new_from_listen(listen: Arc<Listen>) -> Self {
Self(RwLock::new(Status::Listen(listen)))
}
pub(super) fn new_from_connected(connected: Arc<Connected>) -> Self {
Self(RwLock::new(Status::Connected(connected)))
}
pub struct VsockStreamSocket {
status: RwLock<Status>,
is_nonblocking: AtomicBool,
pollee: Pollee,
}
pub enum Status {
@ -35,9 +29,82 @@ pub enum Status {
}
impl VsockStreamSocket {
pub fn new() -> Self {
pub fn new(nonblocking: bool) -> Self {
let init = Arc::new(Init::new());
Self(RwLock::new(Status::Init(init)))
Self {
status: RwLock::new(Status::Init(init)),
is_nonblocking: AtomicBool::new(nonblocking),
pollee: Pollee::new(IoEvents::empty()),
}
}
pub(super) fn new_from_connected(connected: Arc<Connected>) -> Self {
Self {
status: RwLock::new(Status::Connected(connected)),
is_nonblocking: AtomicBool::new(false),
pollee: Pollee::new(IoEvents::empty()),
}
}
fn is_nonblocking(&self) -> bool {
self.is_nonblocking.load(Ordering::Relaxed)
}
fn set_nonblocking(&self, nonblocking: bool) {
self.is_nonblocking.store(nonblocking, Ordering::Relaxed);
}
fn try_accept(&self) -> Result<(Arc<dyn FileLike>, SocketAddr)> {
let listen = match &*self.status.read() {
Status::Listen(listen) => listen.clone(),
Status::Init(_) | Status::Connected(_) => {
return_errno_with_message!(Errno::EINVAL, "The socket is not listening");
}
};
let connected = listen.try_accept()?;
listen.update_io_events();
let peer_addr = connected.peer_addr();
VSOCK_GLOBAL
.get()
.unwrap()
.connected_sockets
.lock_irq_disabled()
.insert(connected.id(), connected.clone());
VSOCK_GLOBAL
.get()
.unwrap()
.driver
.lock_irq_disabled()
.response(&connected.get_info())
.map_err(|e| Error::with_message(Errno::EAGAIN, "can not send response packet"))?;
let socket = Arc::new(VsockStreamSocket::new_from_connected(connected));
Ok((socket, peer_addr.into()))
}
fn try_recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, SocketAddr)> {
let connected = match &*self.status.read() {
Status::Connected(connected) => connected.clone(),
Status::Init(_) | Status::Listen(_) => {
return_errno_with_message!(Errno::EINVAL, "the socket is not connected");
}
};
let read_size = connected.try_recv(buf)?;
connected.update_io_events();
let peer_addr = self.peer_addr()?;
// If buffer is now empty and the peer requested shutdown, finish shutting down the
// connection.
// TODO: properly place the close request
if connected.should_close() {
if let Err(e) = self.shutdown(SockShutdownCmd::SHUT_RDWR) {
debug!("The error is {:?}", e);
}
}
Ok((read_size, peer_addr))
}
}
@ -56,11 +123,47 @@ impl FileLike for VsockStreamSocket {
}
fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
let inner = self.0.read();
match &*inner {
Status::Init(init) => init.poll(mask, poller),
Status::Listen(listen) => listen.poll(mask, poller),
Status::Connected(connect) => connect.poll(mask, poller),
self.pollee.poll(mask, poller)
}
fn status_flags(&self) -> StatusFlags {
if self.is_nonblocking() {
StatusFlags::O_NONBLOCK
} else {
StatusFlags::empty()
}
}
fn set_status_flags(&self, new_flags: StatusFlags) -> Result<()> {
if new_flags.contains(StatusFlags::O_NONBLOCK) {
self.set_nonblocking(true);
} else {
self.set_nonblocking(false);
}
Ok(())
}
fn register_observer(
&self,
observer: Weak<dyn crate::events::Observer<IoEvents>>,
mask: IoEvents,
) -> Result<()> {
match &*self.status.read() {
Status::Init(init) => init.register_observer(&self.pollee, observer, mask),
Status::Listen(listen) => listen.register_observer(&self.pollee, observer, mask),
Status::Connected(connected) => {
connected.register_observer(&self.pollee, observer, mask)
}
}
}
fn unregister_observer(
&self,
observer: &Weak<dyn crate::events::Observer<IoEvents>>,
) -> Result<Weak<dyn crate::events::Observer<IoEvents>>> {
match &*self.status.read() {
Status::Init(init) => init.unregister_observer(&self.pollee, observer),
Status::Listen(listen) => listen.unregister_observer(&self.pollee, observer),
Status::Connected(connected) => connected.unregister_observer(&self.pollee, observer),
}
}
}
@ -68,7 +171,7 @@ impl FileLike for VsockStreamSocket {
impl Socket for VsockStreamSocket {
fn bind(&self, sockaddr: SocketAddr) -> Result<()> {
let addr = VsockSocketAddr::try_from(sockaddr)?;
let inner = self.0.read();
let inner = self.status.read();
match &*inner {
Status::Init(init) => init.bind(addr),
Status::Listen(_) | Status::Connected(_) => {
@ -80,8 +183,10 @@ impl Socket for VsockStreamSocket {
}
}
// Since blocking mode is supported, there is no need to store the connecting status.
// TODO: Refactor when blocking mode is supported.
fn connect(&self, sockaddr: SocketAddr) -> Result<()> {
let init = match &*self.0.read() {
let init = match &*self.status.read() {
Status::Init(init) => init.clone(),
Status::Listen(_) => {
return_errno_with_message!(Errno::EINVAL, "The socket is listened");
@ -131,7 +236,7 @@ impl Socket for VsockStreamSocket {
.unwrap();
let connected = Arc::new(Connected::from_connecting(connecting));
*self.0.write() = Status::Connected(connected.clone());
*self.status.write() = Status::Connected(connected.clone());
// move connecting socket map to connected sockmap
vsockspace
.connected_sockets
@ -142,7 +247,7 @@ impl Socket for VsockStreamSocket {
}
fn listen(&self, backlog: usize) -> Result<()> {
let init = match &*self.0.read() {
let init = match &*self.status.read() {
Status::Init(init) => init.clone(),
Status::Listen(_) => {
return_errno_with_message!(Errno::EINVAL, "The socket is already listened");
@ -156,7 +261,7 @@ impl Socket for VsockStreamSocket {
"The socket is not bound",
))?;
let listen = Arc::new(Listen::new(addr, backlog));
*self.0.write() = Status::Listen(listen.clone());
*self.status.write() = Status::Listen(listen.clone());
// push listen socket into vsockspace
VSOCK_GLOBAL
@ -170,76 +275,30 @@ impl Socket for VsockStreamSocket {
}
fn accept(&self) -> Result<(Arc<dyn FileLike>, SocketAddr)> {
let listen = match &*self.0.read() {
Status::Listen(listen) => listen.clone(),
Status::Init(_) | Status::Connected(_) => {
return_errno_with_message!(Errno::EINVAL, "The socket is not listening");
}
};
let connected = listen.accept()?;
let peer_addr = connected.peer_addr();
VSOCK_GLOBAL
.get()
.unwrap()
.connected_sockets
.lock_irq_disabled()
.insert(connected.id(), connected.clone());
VSOCK_GLOBAL
.get()
.unwrap()
.driver
.lock_irq_disabled()
.response(&connected.get_info())
.map_err(|e| Error::with_message(Errno::EAGAIN, "can not send response packet"))?;
let socket = Arc::new(VsockStreamSocket::new_from_connected(connected));
Ok((socket, peer_addr.into()))
if self.is_nonblocking() {
self.try_accept()
} else {
wait_events(self, IoEvents::IN, || self.try_accept())
}
}
fn shutdown(&self, cmd: SockShutdownCmd) -> Result<()> {
let inner = self.0.read();
if let Status::Connected(connected) = &*inner {
let result = connected.shutdown(cmd);
if result.is_ok() {
let vsockspace = VSOCK_GLOBAL.get().unwrap();
vsockspace
.used_ports
.lock_irq_disabled()
.remove(&connected.local_addr().port);
vsockspace
.connected_sockets
.lock_irq_disabled()
.remove(&connected.id());
match &*self.status.read() {
Status::Connected(connected) => connected.shutdown(cmd),
Status::Init(_) | Status::Listen(_) => {
return_errno_with_message!(Errno::EINVAL, "The socket is not connected");
}
result
} else {
return_errno_with_message!(Errno::EINVAL, "The socket is not connected.");
}
}
fn recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, SocketAddr)> {
let connected = match &*self.0.read() {
Status::Connected(connected) => connected.clone(),
Status::Init(_) | Status::Listen(_) => {
return_errno_with_message!(Errno::EINVAL, "the socket is not connected");
}
};
let read_size = connected.recv(buf)?;
let peer_addr = self.peer_addr()?;
// If buffer is now empty and the peer requested shutdown, finish shutting down the
// connection.
if connected.should_close() {
VSOCK_GLOBAL
.get()
.unwrap()
.driver
.lock_irq_disabled()
.reset(&connected.get_info())
.map_err(|e| Error::with_message(Errno::EAGAIN, "can not send close packet"))?;
debug_assert!(flags.is_all_supported());
if self.is_nonblocking() {
self.try_recvfrom(buf, flags)
} else {
wait_events(self, IoEvents::IN, || self.try_recvfrom(buf, flags))
}
Ok((read_size, peer_addr))
}
fn sendto(
@ -252,7 +311,7 @@ impl Socket for VsockStreamSocket {
if remote.is_some() {
return_errno_with_message!(Errno::EINVAL, "vsock should not provide remote addr");
}
let inner = self.0.read();
let inner = self.status.read();
match &*inner {
Status::Connected(connected) => connected.send(buf, flags),
Status::Init(_) | Status::Listen(_) => {
@ -262,7 +321,7 @@ impl Socket for VsockStreamSocket {
}
fn addr(&self) -> Result<SocketAddr> {
let inner = self.0.read();
let inner = self.status.read();
let addr = match &*inner {
Status::Init(init) => init.bound_addr(),
Status::Listen(listen) => Some(listen.addr()),
@ -276,7 +335,7 @@ impl Socket for VsockStreamSocket {
}
fn peer_addr(&self) -> Result<SocketAddr> {
let inner = self.0.read();
let inner = self.status.read();
if let Status::Connected(connected) = &*inner {
Ok(connected.peer_addr().into())
} else {
@ -285,8 +344,36 @@ impl Socket for VsockStreamSocket {
}
}
impl Default for VsockStreamSocket {
fn default() -> Self {
Self::new()
// TODO: Support timeout
fn wait_events<F, R>(socket: &VsockStreamSocket, mask: IoEvents, mut cond: F) -> Result<R>
where
F: FnMut() -> Result<R>,
{
let poller = Poller::new();
loop {
match cond() {
Err(err) if err.error() == Errno::EAGAIN => (),
result => {
if let Err(e) = result {
debug!("The result of cond() is Error: {:?}", e);
}
return result;
}
};
let events = match &*socket.status.read() {
Status::Init(init) => init.poll(mask, Some(&poller)),
Status::Listen(listen) => listen.poll(mask, Some(&poller)),
Status::Connected(connected) => connected.poll(mask, Some(&poller)),
};
debug!("events: {:?}", events);
if !events.is_empty() {
continue;
}
poller.wait()?;
debug!("pass the poller wait");
}
}