Replace Poisoned state by takeable crate

This commit is contained in:
Ruihan Li
2024-02-20 23:10:06 +08:00
committed by Tate, Hongliang Tian
parent a10d04c5f9
commit 595c6ab288
4 changed files with 129 additions and 133 deletions

7
Cargo.lock generated
View File

@ -226,6 +226,7 @@ dependencies = [
"smoltcp", "smoltcp",
"spin 0.9.8", "spin 0.9.8",
"static_assertions", "static_assertions",
"takeable",
"tdx-guest", "tdx-guest",
"time", "time",
"typeflags", "typeflags",
@ -1342,6 +1343,12 @@ dependencies = [
"unicode-ident", "unicode-ident",
] ]
[[package]]
name = "takeable"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7efe11d01772262041364ebf855e0c6c9e2bd0dd4f059083e540a5f1cda69477"
[[package]] [[package]]
name = "tap" name = "tap"
version = "1.0.1" version = "1.0.1"

View File

@ -68,6 +68,7 @@ getset = "0.1.2"
atomic = "0.6" atomic = "0.6"
bytemuck = "1.14.3" bytemuck = "1.14.3"
bytemuck_derive = "1.5.0" bytemuck_derive = "1.5.0"
takeable = "0.2.2"
[dependencies.lazy_static] [dependencies.lazy_static]
version = "1.0" version = "1.0"

View File

@ -1,9 +1,8 @@
// SPDX-License-Identifier: MPL-2.0 // SPDX-License-Identifier: MPL-2.0
use core::{ use core::sync::atomic::{AtomicBool, Ordering};
mem,
sync::atomic::{AtomicBool, Ordering}, use takeable::Takeable;
};
use self::{bound::BoundDatagram, unbound::UnboundDatagram}; use self::{bound::BoundDatagram, unbound::UnboundDatagram};
use super::common::get_ephemeral_endpoint; use super::common::get_ephemeral_endpoint;
@ -26,7 +25,7 @@ mod bound;
mod unbound; mod unbound;
pub struct DatagramSocket { pub struct DatagramSocket {
inner: RwLock<Inner>, inner: RwLock<Takeable<Inner>>,
nonblocking: AtomicBool, nonblocking: AtomicBool,
pollee: Pollee, pollee: Pollee,
} }
@ -34,7 +33,6 @@ pub struct DatagramSocket {
enum Inner { enum Inner {
Unbound(UnboundDatagram), Unbound(UnboundDatagram),
Bound(BoundDatagram), Bound(BoundDatagram),
Poisoned,
} }
impl Inner { impl Inner {
@ -47,12 +45,6 @@ impl Inner {
Inner::Bound(bound_datagram), Inner::Bound(bound_datagram),
)); ));
} }
Inner::Poisoned => {
return Err((
Error::with_message(Errno::EINVAL, "the socket is poisoned"),
Inner::Poisoned,
));
}
}; };
let bound_datagram = match unbound_datagram.bind(endpoint) { let bound_datagram = match unbound_datagram.bind(endpoint) {
@ -81,7 +73,7 @@ impl DatagramSocket {
let unbound_datagram = UnboundDatagram::new(me.clone() as _); let unbound_datagram = UnboundDatagram::new(me.clone() as _);
let pollee = Pollee::new(IoEvents::empty()); let pollee = Pollee::new(IoEvents::empty());
Self { Self {
inner: RwLock::new(Inner::Unbound(unbound_datagram)), inner: RwLock::new(Takeable::new(Inner::Unbound(unbound_datagram))),
nonblocking: AtomicBool::new(nonblocking), nonblocking: AtomicBool::new(nonblocking),
pollee, pollee,
} }
@ -98,29 +90,27 @@ impl DatagramSocket {
fn try_bind_empheral(&self, remote_endpoint: &IpEndpoint) -> Result<()> { fn try_bind_empheral(&self, remote_endpoint: &IpEndpoint) -> Result<()> {
// Fast path // Fast path
if let Inner::Bound(_) = &*self.inner.read() { if let Inner::Bound(_) = self.inner.read().as_ref() {
return Ok(()); return Ok(());
} }
// Slow path // Slow path
let mut inner = self.inner.write(); let mut inner = self.inner.write();
let owned_inner = mem::replace(&mut *inner, Inner::Poisoned); inner.borrow_result(|owned_inner| {
let bound_datagram = match owned_inner.bind_to_ephemeral_endpoint(remote_endpoint) {
let bound_datagram = match owned_inner.bind_to_ephemeral_endpoint(remote_endpoint) { Ok(bound_datagram) => bound_datagram,
Ok(bound_datagram) => bound_datagram, Err((err, err_inner)) => {
Err((err, err_inner)) => { return (err_inner, Err(err));
*inner = err_inner; }
return Err(err); };
} bound_datagram.init_pollee(&self.pollee);
}; (Inner::Bound(bound_datagram), Ok(()))
bound_datagram.init_pollee(&self.pollee); })
*inner = Inner::Bound(bound_datagram);
Ok(())
} }
fn try_recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, SocketAddr)> { fn try_recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, SocketAddr)> {
let inner = self.inner.read(); let inner = self.inner.read();
let Inner::Bound(bound_datagram) = &*inner else { let Inner::Bound(bound_datagram) = inner.as_ref() else {
return_errno_with_message!(Errno::EINVAL, "the socket is not bound"); return_errno_with_message!(Errno::EINVAL, "the socket is not bound");
}; };
let (recv_bytes, remote_endpoint) = bound_datagram.try_recvfrom(buf, flags)?; let (recv_bytes, remote_endpoint) = bound_datagram.try_recvfrom(buf, flags)?;
@ -135,7 +125,7 @@ impl DatagramSocket {
flags: SendRecvFlags, flags: SendRecvFlags,
) -> Result<usize> { ) -> Result<usize> {
let inner = self.inner.read(); let inner = self.inner.read();
let Inner::Bound(bound_datagram) = &*inner else { let Inner::Bound(bound_datagram) = inner.as_ref() else {
return_errno_with_message!(Errno::EINVAL, "the socket is not bound"); return_errno_with_message!(Errno::EINVAL, "the socket is not bound");
}; };
let sent_bytes = bound_datagram.try_sendto(buf, remote, flags)?; let sent_bytes = bound_datagram.try_sendto(buf, remote, flags)?;
@ -167,7 +157,7 @@ impl DatagramSocket {
fn update_io_events(&self) { fn update_io_events(&self) {
let inner = self.inner.read(); let inner = self.inner.read();
let Inner::Bound(bound_datagram) = &*inner else { let Inner::Bound(bound_datagram) = inner.as_ref() else {
return; return;
}; };
bound_datagram.update_io_events(&self.pollee); bound_datagram.update_io_events(&self.pollee);
@ -219,18 +209,16 @@ impl Socket for DatagramSocket {
let endpoint = socket_addr.try_into()?; let endpoint = socket_addr.try_into()?;
let mut inner = self.inner.write(); let mut inner = self.inner.write();
let owned_inner = mem::replace(&mut *inner, Inner::Poisoned); inner.borrow_result(|owned_inner| {
let bound_datagram = match owned_inner.bind(&endpoint) {
let bound_datagram = match owned_inner.bind(&endpoint) { Ok(bound_datagram) => bound_datagram,
Ok(bound_datagram) => bound_datagram, Err((err, err_inner)) => {
Err((err, err_inner)) => { return (err_inner, Err(err));
*inner = err_inner; }
return Err(err); };
} bound_datagram.init_pollee(&self.pollee);
}; (Inner::Bound(bound_datagram), Ok(()))
bound_datagram.init_pollee(&self.pollee); })
*inner = Inner::Bound(bound_datagram);
Ok(())
} }
fn connect(&self, socket_addr: SocketAddr) -> Result<()> { fn connect(&self, socket_addr: SocketAddr) -> Result<()> {
@ -239,7 +227,7 @@ impl Socket for DatagramSocket {
self.try_bind_empheral(&endpoint)?; self.try_bind_empheral(&endpoint)?;
let mut inner = self.inner.write(); let mut inner = self.inner.write();
let Inner::Bound(bound_datagram) = &mut *inner else { let Inner::Bound(bound_datagram) = inner.as_mut() else {
return_errno_with_message!(Errno::EINVAL, "the socket is not bound") return_errno_with_message!(Errno::EINVAL, "the socket is not bound")
}; };
bound_datagram.set_remote_endpoint(&endpoint); bound_datagram.set_remote_endpoint(&endpoint);
@ -249,7 +237,7 @@ impl Socket for DatagramSocket {
fn addr(&self) -> Result<SocketAddr> { fn addr(&self) -> Result<SocketAddr> {
let inner = self.inner.read(); let inner = self.inner.read();
let Inner::Bound(bound_datagram) = &*inner else { let Inner::Bound(bound_datagram) = inner.as_ref() else {
return_errno_with_message!(Errno::EINVAL, "the socket is not bound"); return_errno_with_message!(Errno::EINVAL, "the socket is not bound");
}; };
bound_datagram.local_endpoint().try_into() bound_datagram.local_endpoint().try_into()
@ -257,7 +245,7 @@ impl Socket for DatagramSocket {
fn peer_addr(&self) -> Result<SocketAddr> { fn peer_addr(&self) -> Result<SocketAddr> {
let inner = self.inner.read(); let inner = self.inner.read();
let Inner::Bound(bound_datagram) = &*inner else { let Inner::Bound(bound_datagram) = inner.as_ref() else {
return_errno_with_message!(Errno::EINVAL, "the socket is not bound"); return_errno_with_message!(Errno::EINVAL, "the socket is not bound");
}; };
bound_datagram.remote_endpoint()?.try_into() bound_datagram.remote_endpoint()?.try_into()

View File

@ -1,9 +1,6 @@
// SPDX-License-Identifier: MPL-2.0 // SPDX-License-Identifier: MPL-2.0
use core::{ use core::sync::atomic::{AtomicBool, Ordering};
mem,
sync::atomic::{AtomicBool, Ordering},
};
use connected::ConnectedStream; use connected::ConnectedStream;
use connecting::ConnectingStream; use connecting::ConnectingStream;
@ -11,6 +8,7 @@ use init::InitStream;
use listen::ListenStream; use listen::ListenStream;
use options::{Congestion, MaxSegment, NoDelay, WindowClamp}; use options::{Congestion, MaxSegment, NoDelay, WindowClamp};
use smoltcp::wire::IpEndpoint; use smoltcp::wire::IpEndpoint;
use takeable::Takeable;
use util::{TcpOptionSet, DEFAULT_MAXSEG}; use util::{TcpOptionSet, DEFAULT_MAXSEG};
use crate::{ use crate::{
@ -20,7 +18,9 @@ use crate::{
net::{ net::{
poll_ifaces, poll_ifaces,
socket::{ socket::{
options::{Error, Linger, RecvBuf, ReuseAddr, ReusePort, SendBuf, SocketOption}, options::{
Error as SocketError, Linger, RecvBuf, ReuseAddr, ReusePort, SendBuf, SocketOption,
},
util::{ util::{
options::{SocketOptionSet, MIN_RECVBUF, MIN_SENDBUF}, options::{SocketOptionSet, MIN_RECVBUF, MIN_SENDBUF},
send_recv_flags::SendRecvFlags, send_recv_flags::SendRecvFlags,
@ -46,7 +46,7 @@ pub use self::util::CongestionControl;
pub struct StreamSocket { pub struct StreamSocket {
options: RwLock<OptionSet>, options: RwLock<OptionSet>,
state: RwLock<State>, state: RwLock<Takeable<State>>,
is_nonblocking: AtomicBool, is_nonblocking: AtomicBool,
pollee: Pollee, pollee: Pollee,
} }
@ -60,8 +60,6 @@ enum State {
Connected(ConnectedStream), Connected(ConnectedStream),
// Final State 2 // Final State 2
Listen(ListenStream), Listen(ListenStream),
// Poisoned state
Poisoned,
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@ -85,7 +83,7 @@ impl StreamSocket {
let pollee = Pollee::new(IoEvents::empty()); let pollee = Pollee::new(IoEvents::empty());
Self { Self {
options: RwLock::new(OptionSet::new()), options: RwLock::new(OptionSet::new()),
state: RwLock::new(State::Init(init_stream)), state: RwLock::new(Takeable::new(State::Init(init_stream))),
is_nonblocking: AtomicBool::new(nonblocking), is_nonblocking: AtomicBool::new(nonblocking),
pollee, pollee,
} }
@ -99,7 +97,7 @@ impl StreamSocket {
connected_stream.init_pollee(&pollee); connected_stream.init_pollee(&pollee);
Self { Self {
options: RwLock::new(OptionSet::new()), options: RwLock::new(OptionSet::new()),
state: RwLock::new(State::Connected(connected_stream)), state: RwLock::new(Takeable::new(State::Connected(connected_stream))),
is_nonblocking: AtomicBool::new(false), is_nonblocking: AtomicBool::new(false),
pollee, pollee,
} }
@ -117,56 +115,60 @@ impl StreamSocket {
fn start_connect(&self, remote_endpoint: &IpEndpoint) -> Result<()> { fn start_connect(&self, remote_endpoint: &IpEndpoint) -> Result<()> {
let mut state = self.state.write(); let mut state = self.state.write();
let owned_state = mem::replace(&mut *state, State::Poisoned); state.borrow_result(|owned_state| {
let State::Init(init_stream) = owned_state else { let State::Init(init_stream) = owned_state else {
*state = owned_state; return (
return_errno_with_message!(Errno::EINVAL, "cannot connect") owned_state,
}; Err(Error::with_message(Errno::EINVAL, "cannot connect")),
);
};
let connecting_stream = match init_stream.connect(remote_endpoint) { let connecting_stream = match init_stream.connect(remote_endpoint) {
Ok(connecting_stream) => connecting_stream, Ok(connecting_stream) => connecting_stream,
Err((err, init_stream)) => { Err((err, init_stream)) => {
*state = State::Init(init_stream); return (State::Init(init_stream), Err(err));
return Err(err); }
} };
}; connecting_stream.init_pollee(&self.pollee);
connecting_stream.init_pollee(&self.pollee);
*state = State::Connecting(connecting_stream);
Ok(()) (State::Connecting(connecting_stream), Ok(()))
})
} }
fn finish_connect(&self) -> Result<()> { fn finish_connect(&self) -> Result<()> {
let mut state = self.state.write(); let mut state = self.state.write();
let owned_state = mem::replace(&mut *state, State::Poisoned); state.borrow_result(|owned_state| {
let State::Connecting(connecting_stream) = owned_state else { let State::Connecting(connecting_stream) = owned_state else {
*state = owned_state; debug_assert!(false, "the socket unexpectedly left the connecting state");
debug_assert!(false, "the socket unexpectedly left the connecting state"); return (
return_errno_with_message!(Errno::EINVAL, "the socket is not connecting"); owned_state,
}; Err(Error::with_message(
Errno::EINVAL,
"the socket is not connecting",
)),
);
};
let connected_stream = match connecting_stream.into_result() { let connected_stream = match connecting_stream.into_result() {
Ok(connected_stream) => connected_stream, Ok(connected_stream) => connected_stream,
Err((err, NonConnectedStream::Init(init_stream))) => { Err((err, NonConnectedStream::Init(init_stream))) => {
*state = State::Init(init_stream); return (State::Init(init_stream), Err(err));
return Err(err); }
} Err((err, NonConnectedStream::Connecting(connecting_stream))) => {
Err((err, NonConnectedStream::Connecting(connecting_stream))) => { return (State::Connecting(connecting_stream), Err(err));
*state = State::Connecting(connecting_stream); }
return Err(err); };
} connected_stream.init_pollee(&self.pollee);
};
connected_stream.init_pollee(&self.pollee);
*state = State::Connected(connected_stream);
Ok(()) (State::Connected(connected_stream), Ok(()))
})
} }
fn try_accept(&self) -> Result<(Arc<dyn FileLike>, SocketAddr)> { fn try_accept(&self) -> Result<(Arc<dyn FileLike>, SocketAddr)> {
let state = self.state.read(); let state = self.state.read();
let State::Listen(listen_stream) = &*state else { let State::Listen(listen_stream) = state.as_ref() else {
return_errno_with_message!(Errno::EINVAL, "the socket is not listening"); return_errno_with_message!(Errno::EINVAL, "the socket is not listening");
}; };
@ -181,7 +183,7 @@ impl StreamSocket {
fn try_recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, SocketAddr)> { fn try_recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, SocketAddr)> {
let state = self.state.read(); let state = self.state.read();
let State::Connected(connected_stream) = &*state else { let State::Connected(connected_stream) = state.as_ref() else {
return_errno_with_message!(Errno::EINVAL, "the socket is not connected"); return_errno_with_message!(Errno::EINVAL, "the socket is not connected");
}; };
let recv_bytes = connected_stream.try_recvfrom(buf, flags)?; let recv_bytes = connected_stream.try_recvfrom(buf, flags)?;
@ -192,7 +194,7 @@ impl StreamSocket {
fn try_sendto(&self, buf: &[u8], flags: SendRecvFlags) -> Result<usize> { fn try_sendto(&self, buf: &[u8], flags: SendRecvFlags) -> Result<usize> {
let state = self.state.read(); let state = self.state.read();
let State::Connected(connected_stream) = &*state else { let State::Connected(connected_stream) = state.as_ref() else {
return_errno_with_message!(Errno::EINVAL, "the socket is not connected"); return_errno_with_message!(Errno::EINVAL, "the socket is not connected");
}; };
let sent_bytes = connected_stream.try_sendto(buf, flags)?; let sent_bytes = connected_stream.try_sendto(buf, flags)?;
@ -224,8 +226,8 @@ impl StreamSocket {
fn update_io_events(&self) { fn update_io_events(&self) {
let state = self.state.read(); let state = self.state.read();
match &*state { match state.as_ref() {
State::Init(_) | State::Poisoned => (), State::Init(_) => (),
State::Connecting(connecting_stream) => { State::Connecting(connecting_stream) => {
connecting_stream.update_io_events(&self.pollee) connecting_stream.update_io_events(&self.pollee)
} }
@ -281,22 +283,23 @@ impl Socket for StreamSocket {
let mut state = self.state.write(); let mut state = self.state.write();
let owned_state = mem::replace(&mut *state, State::Poisoned); state.borrow_result(|owned_state| {
let State::Init(init_stream) = owned_state else { let State::Init(init_stream) = owned_state else {
*state = owned_state; return (
return_errno_with_message!(Errno::EINVAL, "cannot bind"); owned_state,
}; Err(Error::with_message(Errno::EINVAL, "cannot bind")),
);
};
let bound_socket = match init_stream.bind(&endpoint) { let bound_socket = match init_stream.bind(&endpoint) {
Ok(bound_socket) => bound_socket, Ok(bound_socket) => bound_socket,
Err((err, init_stream)) => { Err((err, init_stream)) => {
*state = State::Init(init_stream); return (State::Init(init_stream), Err(err));
return Err(err); }
} };
};
*state = State::Init(InitStream::new_bound(bound_socket));
Ok(()) (State::Init(InitStream::new_bound(bound_socket)), Ok(()))
})
} }
// TODO: Support nonblocking mode // TODO: Support nonblocking mode
@ -311,23 +314,24 @@ impl Socket for StreamSocket {
fn listen(&self, backlog: usize) -> Result<()> { fn listen(&self, backlog: usize) -> Result<()> {
let mut state = self.state.write(); let mut state = self.state.write();
let owned_state = mem::replace(&mut *state, State::Poisoned); state.borrow_result(|owned_state| {
let State::Init(init_stream) = owned_state else { let State::Init(init_stream) = owned_state else {
*state = owned_state; return (
return_errno_with_message!(Errno::EINVAL, "cannot listen"); owned_state,
}; Err(Error::with_message(Errno::EINVAL, "cannot listen")),
);
};
let listen_stream = match init_stream.listen(backlog) { let listen_stream = match init_stream.listen(backlog) {
Ok(listen_stream) => listen_stream, Ok(listen_stream) => listen_stream,
Err((err, init_stream)) => { Err((err, init_stream)) => {
*state = State::Init(init_stream); return (State::Init(init_stream), Err(err));
return Err(err); }
} };
}; listen_stream.init_pollee(&self.pollee);
listen_stream.init_pollee(&self.pollee);
*state = State::Listen(listen_stream);
Ok(()) (State::Listen(listen_stream), Ok(()))
})
} }
fn accept(&self) -> Result<(Arc<dyn FileLike>, SocketAddr)> { fn accept(&self) -> Result<(Arc<dyn FileLike>, SocketAddr)> {
@ -341,7 +345,7 @@ impl Socket for StreamSocket {
fn shutdown(&self, cmd: SockShutdownCmd) -> Result<()> { fn shutdown(&self, cmd: SockShutdownCmd) -> Result<()> {
let state = self.state.read(); let state = self.state.read();
match &*state { match state.as_ref() {
State::Connected(connected_stream) => connected_stream.shutdown(cmd), State::Connected(connected_stream) => connected_stream.shutdown(cmd),
// TDOD: shutdown listening stream // TDOD: shutdown listening stream
_ => return_errno_with_message!(Errno::EINVAL, "cannot shutdown"), _ => return_errno_with_message!(Errno::EINVAL, "cannot shutdown"),
@ -350,19 +354,18 @@ impl Socket for StreamSocket {
fn addr(&self) -> Result<SocketAddr> { fn addr(&self) -> Result<SocketAddr> {
let state = self.state.read(); let state = self.state.read();
let local_endpoint = match &*state { let local_endpoint = match state.as_ref() {
State::Init(init_stream) => init_stream.local_endpoint()?, State::Init(init_stream) => init_stream.local_endpoint()?,
State::Connecting(connecting_stream) => connecting_stream.local_endpoint(), State::Connecting(connecting_stream) => connecting_stream.local_endpoint(),
State::Listen(listen_stream) => listen_stream.local_endpoint(), State::Listen(listen_stream) => listen_stream.local_endpoint(),
State::Connected(connected_stream) => connected_stream.local_endpoint(), State::Connected(connected_stream) => connected_stream.local_endpoint(),
State::Poisoned => return_errno_with_message!(Errno::EINVAL, "socket is poisoned"),
}; };
local_endpoint.try_into() local_endpoint.try_into()
} }
fn peer_addr(&self) -> Result<SocketAddr> { fn peer_addr(&self) -> Result<SocketAddr> {
let state = self.state.read(); let state = self.state.read();
let remote_endpoint = match &*state { let remote_endpoint = match state.as_ref() {
State::Init(init_stream) => { State::Init(init_stream) => {
return_errno_with_message!(Errno::EINVAL, "init socket does not have peer") return_errno_with_message!(Errno::EINVAL, "init socket does not have peer")
} }
@ -371,7 +374,6 @@ impl Socket for StreamSocket {
return_errno_with_message!(Errno::EINVAL, "listening socket does not have peer") return_errno_with_message!(Errno::EINVAL, "listening socket does not have peer")
} }
State::Connected(connected_stream) => connected_stream.remote_endpoint(), State::Connected(connected_stream) => connected_stream.remote_endpoint(),
State::Poisoned => return_errno_with_message!(Errno::EINVAL, "socket is poisoned"),
}; };
remote_endpoint.try_into() remote_endpoint.try_into()
} }
@ -412,7 +414,7 @@ impl Socket for StreamSocket {
let options = self.options.read(); let options = self.options.read();
match_sock_option_mut!(option, { match_sock_option_mut!(option, {
// Socket Options // Socket Options
socket_errors: Error => { socket_errors: SocketError => {
let sock_errors = options.socket.sock_errors(); let sock_errors = options.socket.sock_errors();
socket_errors.set(sock_errors); socket_errors.set(sock_errors);
}, },
@ -446,10 +448,8 @@ impl Socket for StreamSocket {
// and always return the actual current MSS for a connected one. // and always return the actual current MSS for a connected one.
// FIXME: how to get the current MSS? // FIXME: how to get the current MSS?
let maxseg = match &*self.state.read() { let maxseg = match self.state.read().as_ref() {
State::Init(_) | State::Listen(_) | State::Connecting(_) | State::Poisoned => { State::Init(_) | State::Listen(_) | State::Connecting(_) => DEFAULT_MAXSEG,
DEFAULT_MAXSEG
}
State::Connected(_) => options.tcp.maxseg(), State::Connected(_) => options.tcp.maxseg(),
}; };
tcp_maxseg.set(maxseg); tcp_maxseg.set(maxseg);