mirror of
https://github.com/asterinas/asterinas.git
synced 2025-06-27 03:13:23 +00:00
Replace Poisoned
state by takeable
crate
This commit is contained in:
committed by
Tate, Hongliang Tian
parent
a10d04c5f9
commit
595c6ab288
7
Cargo.lock
generated
7
Cargo.lock
generated
@ -226,6 +226,7 @@ dependencies = [
|
||||
"smoltcp",
|
||||
"spin 0.9.8",
|
||||
"static_assertions",
|
||||
"takeable",
|
||||
"tdx-guest",
|
||||
"time",
|
||||
"typeflags",
|
||||
@ -1342,6 +1343,12 @@ dependencies = [
|
||||
"unicode-ident",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "takeable"
|
||||
version = "0.2.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7efe11d01772262041364ebf855e0c6c9e2bd0dd4f059083e540a5f1cda69477"
|
||||
|
||||
[[package]]
|
||||
name = "tap"
|
||||
version = "1.0.1"
|
||||
|
@ -68,6 +68,7 @@ getset = "0.1.2"
|
||||
atomic = "0.6"
|
||||
bytemuck = "1.14.3"
|
||||
bytemuck_derive = "1.5.0"
|
||||
takeable = "0.2.2"
|
||||
|
||||
[dependencies.lazy_static]
|
||||
version = "1.0"
|
||||
|
@ -1,9 +1,8 @@
|
||||
// SPDX-License-Identifier: MPL-2.0
|
||||
|
||||
use core::{
|
||||
mem,
|
||||
sync::atomic::{AtomicBool, Ordering},
|
||||
};
|
||||
use core::sync::atomic::{AtomicBool, Ordering};
|
||||
|
||||
use takeable::Takeable;
|
||||
|
||||
use self::{bound::BoundDatagram, unbound::UnboundDatagram};
|
||||
use super::common::get_ephemeral_endpoint;
|
||||
@ -26,7 +25,7 @@ mod bound;
|
||||
mod unbound;
|
||||
|
||||
pub struct DatagramSocket {
|
||||
inner: RwLock<Inner>,
|
||||
inner: RwLock<Takeable<Inner>>,
|
||||
nonblocking: AtomicBool,
|
||||
pollee: Pollee,
|
||||
}
|
||||
@ -34,7 +33,6 @@ pub struct DatagramSocket {
|
||||
enum Inner {
|
||||
Unbound(UnboundDatagram),
|
||||
Bound(BoundDatagram),
|
||||
Poisoned,
|
||||
}
|
||||
|
||||
impl Inner {
|
||||
@ -47,12 +45,6 @@ impl Inner {
|
||||
Inner::Bound(bound_datagram),
|
||||
));
|
||||
}
|
||||
Inner::Poisoned => {
|
||||
return Err((
|
||||
Error::with_message(Errno::EINVAL, "the socket is poisoned"),
|
||||
Inner::Poisoned,
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
let bound_datagram = match unbound_datagram.bind(endpoint) {
|
||||
@ -81,7 +73,7 @@ impl DatagramSocket {
|
||||
let unbound_datagram = UnboundDatagram::new(me.clone() as _);
|
||||
let pollee = Pollee::new(IoEvents::empty());
|
||||
Self {
|
||||
inner: RwLock::new(Inner::Unbound(unbound_datagram)),
|
||||
inner: RwLock::new(Takeable::new(Inner::Unbound(unbound_datagram))),
|
||||
nonblocking: AtomicBool::new(nonblocking),
|
||||
pollee,
|
||||
}
|
||||
@ -98,29 +90,27 @@ impl DatagramSocket {
|
||||
|
||||
fn try_bind_empheral(&self, remote_endpoint: &IpEndpoint) -> Result<()> {
|
||||
// Fast path
|
||||
if let Inner::Bound(_) = &*self.inner.read() {
|
||||
if let Inner::Bound(_) = self.inner.read().as_ref() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Slow path
|
||||
let mut inner = self.inner.write();
|
||||
let owned_inner = mem::replace(&mut *inner, Inner::Poisoned);
|
||||
|
||||
inner.borrow_result(|owned_inner| {
|
||||
let bound_datagram = match owned_inner.bind_to_ephemeral_endpoint(remote_endpoint) {
|
||||
Ok(bound_datagram) => bound_datagram,
|
||||
Err((err, err_inner)) => {
|
||||
*inner = err_inner;
|
||||
return Err(err);
|
||||
return (err_inner, Err(err));
|
||||
}
|
||||
};
|
||||
bound_datagram.init_pollee(&self.pollee);
|
||||
*inner = Inner::Bound(bound_datagram);
|
||||
Ok(())
|
||||
(Inner::Bound(bound_datagram), Ok(()))
|
||||
})
|
||||
}
|
||||
|
||||
fn try_recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, SocketAddr)> {
|
||||
let inner = self.inner.read();
|
||||
let Inner::Bound(bound_datagram) = &*inner else {
|
||||
let Inner::Bound(bound_datagram) = inner.as_ref() else {
|
||||
return_errno_with_message!(Errno::EINVAL, "the socket is not bound");
|
||||
};
|
||||
let (recv_bytes, remote_endpoint) = bound_datagram.try_recvfrom(buf, flags)?;
|
||||
@ -135,7 +125,7 @@ impl DatagramSocket {
|
||||
flags: SendRecvFlags,
|
||||
) -> Result<usize> {
|
||||
let inner = self.inner.read();
|
||||
let Inner::Bound(bound_datagram) = &*inner else {
|
||||
let Inner::Bound(bound_datagram) = inner.as_ref() else {
|
||||
return_errno_with_message!(Errno::EINVAL, "the socket is not bound");
|
||||
};
|
||||
let sent_bytes = bound_datagram.try_sendto(buf, remote, flags)?;
|
||||
@ -167,7 +157,7 @@ impl DatagramSocket {
|
||||
|
||||
fn update_io_events(&self) {
|
||||
let inner = self.inner.read();
|
||||
let Inner::Bound(bound_datagram) = &*inner else {
|
||||
let Inner::Bound(bound_datagram) = inner.as_ref() else {
|
||||
return;
|
||||
};
|
||||
bound_datagram.update_io_events(&self.pollee);
|
||||
@ -219,18 +209,16 @@ impl Socket for DatagramSocket {
|
||||
let endpoint = socket_addr.try_into()?;
|
||||
|
||||
let mut inner = self.inner.write();
|
||||
let owned_inner = mem::replace(&mut *inner, Inner::Poisoned);
|
||||
|
||||
inner.borrow_result(|owned_inner| {
|
||||
let bound_datagram = match owned_inner.bind(&endpoint) {
|
||||
Ok(bound_datagram) => bound_datagram,
|
||||
Err((err, err_inner)) => {
|
||||
*inner = err_inner;
|
||||
return Err(err);
|
||||
return (err_inner, Err(err));
|
||||
}
|
||||
};
|
||||
bound_datagram.init_pollee(&self.pollee);
|
||||
*inner = Inner::Bound(bound_datagram);
|
||||
Ok(())
|
||||
(Inner::Bound(bound_datagram), Ok(()))
|
||||
})
|
||||
}
|
||||
|
||||
fn connect(&self, socket_addr: SocketAddr) -> Result<()> {
|
||||
@ -239,7 +227,7 @@ impl Socket for DatagramSocket {
|
||||
self.try_bind_empheral(&endpoint)?;
|
||||
|
||||
let mut inner = self.inner.write();
|
||||
let Inner::Bound(bound_datagram) = &mut *inner else {
|
||||
let Inner::Bound(bound_datagram) = inner.as_mut() else {
|
||||
return_errno_with_message!(Errno::EINVAL, "the socket is not bound")
|
||||
};
|
||||
bound_datagram.set_remote_endpoint(&endpoint);
|
||||
@ -249,7 +237,7 @@ impl Socket for DatagramSocket {
|
||||
|
||||
fn addr(&self) -> Result<SocketAddr> {
|
||||
let inner = self.inner.read();
|
||||
let Inner::Bound(bound_datagram) = &*inner else {
|
||||
let Inner::Bound(bound_datagram) = inner.as_ref() else {
|
||||
return_errno_with_message!(Errno::EINVAL, "the socket is not bound");
|
||||
};
|
||||
bound_datagram.local_endpoint().try_into()
|
||||
@ -257,7 +245,7 @@ impl Socket for DatagramSocket {
|
||||
|
||||
fn peer_addr(&self) -> Result<SocketAddr> {
|
||||
let inner = self.inner.read();
|
||||
let Inner::Bound(bound_datagram) = &*inner else {
|
||||
let Inner::Bound(bound_datagram) = inner.as_ref() else {
|
||||
return_errno_with_message!(Errno::EINVAL, "the socket is not bound");
|
||||
};
|
||||
bound_datagram.remote_endpoint()?.try_into()
|
||||
|
@ -1,9 +1,6 @@
|
||||
// SPDX-License-Identifier: MPL-2.0
|
||||
|
||||
use core::{
|
||||
mem,
|
||||
sync::atomic::{AtomicBool, Ordering},
|
||||
};
|
||||
use core::sync::atomic::{AtomicBool, Ordering};
|
||||
|
||||
use connected::ConnectedStream;
|
||||
use connecting::ConnectingStream;
|
||||
@ -11,6 +8,7 @@ use init::InitStream;
|
||||
use listen::ListenStream;
|
||||
use options::{Congestion, MaxSegment, NoDelay, WindowClamp};
|
||||
use smoltcp::wire::IpEndpoint;
|
||||
use takeable::Takeable;
|
||||
use util::{TcpOptionSet, DEFAULT_MAXSEG};
|
||||
|
||||
use crate::{
|
||||
@ -20,7 +18,9 @@ use crate::{
|
||||
net::{
|
||||
poll_ifaces,
|
||||
socket::{
|
||||
options::{Error, Linger, RecvBuf, ReuseAddr, ReusePort, SendBuf, SocketOption},
|
||||
options::{
|
||||
Error as SocketError, Linger, RecvBuf, ReuseAddr, ReusePort, SendBuf, SocketOption,
|
||||
},
|
||||
util::{
|
||||
options::{SocketOptionSet, MIN_RECVBUF, MIN_SENDBUF},
|
||||
send_recv_flags::SendRecvFlags,
|
||||
@ -46,7 +46,7 @@ pub use self::util::CongestionControl;
|
||||
|
||||
pub struct StreamSocket {
|
||||
options: RwLock<OptionSet>,
|
||||
state: RwLock<State>,
|
||||
state: RwLock<Takeable<State>>,
|
||||
is_nonblocking: AtomicBool,
|
||||
pollee: Pollee,
|
||||
}
|
||||
@ -60,8 +60,6 @@ enum State {
|
||||
Connected(ConnectedStream),
|
||||
// Final State 2
|
||||
Listen(ListenStream),
|
||||
// Poisoned state
|
||||
Poisoned,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
@ -85,7 +83,7 @@ impl StreamSocket {
|
||||
let pollee = Pollee::new(IoEvents::empty());
|
||||
Self {
|
||||
options: RwLock::new(OptionSet::new()),
|
||||
state: RwLock::new(State::Init(init_stream)),
|
||||
state: RwLock::new(Takeable::new(State::Init(init_stream))),
|
||||
is_nonblocking: AtomicBool::new(nonblocking),
|
||||
pollee,
|
||||
}
|
||||
@ -99,7 +97,7 @@ impl StreamSocket {
|
||||
connected_stream.init_pollee(&pollee);
|
||||
Self {
|
||||
options: RwLock::new(OptionSet::new()),
|
||||
state: RwLock::new(State::Connected(connected_stream)),
|
||||
state: RwLock::new(Takeable::new(State::Connected(connected_stream))),
|
||||
is_nonblocking: AtomicBool::new(false),
|
||||
pollee,
|
||||
}
|
||||
@ -117,56 +115,60 @@ impl StreamSocket {
|
||||
fn start_connect(&self, remote_endpoint: &IpEndpoint) -> Result<()> {
|
||||
let mut state = self.state.write();
|
||||
|
||||
let owned_state = mem::replace(&mut *state, State::Poisoned);
|
||||
state.borrow_result(|owned_state| {
|
||||
let State::Init(init_stream) = owned_state else {
|
||||
*state = owned_state;
|
||||
return_errno_with_message!(Errno::EINVAL, "cannot connect")
|
||||
return (
|
||||
owned_state,
|
||||
Err(Error::with_message(Errno::EINVAL, "cannot connect")),
|
||||
);
|
||||
};
|
||||
|
||||
let connecting_stream = match init_stream.connect(remote_endpoint) {
|
||||
Ok(connecting_stream) => connecting_stream,
|
||||
Err((err, init_stream)) => {
|
||||
*state = State::Init(init_stream);
|
||||
return Err(err);
|
||||
return (State::Init(init_stream), Err(err));
|
||||
}
|
||||
};
|
||||
connecting_stream.init_pollee(&self.pollee);
|
||||
*state = State::Connecting(connecting_stream);
|
||||
|
||||
Ok(())
|
||||
(State::Connecting(connecting_stream), Ok(()))
|
||||
})
|
||||
}
|
||||
|
||||
fn finish_connect(&self) -> Result<()> {
|
||||
let mut state = self.state.write();
|
||||
|
||||
let owned_state = mem::replace(&mut *state, State::Poisoned);
|
||||
state.borrow_result(|owned_state| {
|
||||
let State::Connecting(connecting_stream) = owned_state else {
|
||||
*state = owned_state;
|
||||
debug_assert!(false, "the socket unexpectedly left the connecting state");
|
||||
return_errno_with_message!(Errno::EINVAL, "the socket is not connecting");
|
||||
return (
|
||||
owned_state,
|
||||
Err(Error::with_message(
|
||||
Errno::EINVAL,
|
||||
"the socket is not connecting",
|
||||
)),
|
||||
);
|
||||
};
|
||||
|
||||
let connected_stream = match connecting_stream.into_result() {
|
||||
Ok(connected_stream) => connected_stream,
|
||||
Err((err, NonConnectedStream::Init(init_stream))) => {
|
||||
*state = State::Init(init_stream);
|
||||
return Err(err);
|
||||
return (State::Init(init_stream), Err(err));
|
||||
}
|
||||
Err((err, NonConnectedStream::Connecting(connecting_stream))) => {
|
||||
*state = State::Connecting(connecting_stream);
|
||||
return Err(err);
|
||||
return (State::Connecting(connecting_stream), Err(err));
|
||||
}
|
||||
};
|
||||
connected_stream.init_pollee(&self.pollee);
|
||||
*state = State::Connected(connected_stream);
|
||||
|
||||
Ok(())
|
||||
(State::Connected(connected_stream), Ok(()))
|
||||
})
|
||||
}
|
||||
|
||||
fn try_accept(&self) -> Result<(Arc<dyn FileLike>, SocketAddr)> {
|
||||
let state = self.state.read();
|
||||
|
||||
let State::Listen(listen_stream) = &*state else {
|
||||
let State::Listen(listen_stream) = state.as_ref() else {
|
||||
return_errno_with_message!(Errno::EINVAL, "the socket is not listening");
|
||||
};
|
||||
|
||||
@ -181,7 +183,7 @@ impl StreamSocket {
|
||||
fn try_recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, SocketAddr)> {
|
||||
let state = self.state.read();
|
||||
|
||||
let State::Connected(connected_stream) = &*state else {
|
||||
let State::Connected(connected_stream) = state.as_ref() else {
|
||||
return_errno_with_message!(Errno::EINVAL, "the socket is not connected");
|
||||
};
|
||||
let recv_bytes = connected_stream.try_recvfrom(buf, flags)?;
|
||||
@ -192,7 +194,7 @@ impl StreamSocket {
|
||||
fn try_sendto(&self, buf: &[u8], flags: SendRecvFlags) -> Result<usize> {
|
||||
let state = self.state.read();
|
||||
|
||||
let State::Connected(connected_stream) = &*state else {
|
||||
let State::Connected(connected_stream) = state.as_ref() else {
|
||||
return_errno_with_message!(Errno::EINVAL, "the socket is not connected");
|
||||
};
|
||||
let sent_bytes = connected_stream.try_sendto(buf, flags)?;
|
||||
@ -224,8 +226,8 @@ impl StreamSocket {
|
||||
|
||||
fn update_io_events(&self) {
|
||||
let state = self.state.read();
|
||||
match &*state {
|
||||
State::Init(_) | State::Poisoned => (),
|
||||
match state.as_ref() {
|
||||
State::Init(_) => (),
|
||||
State::Connecting(connecting_stream) => {
|
||||
connecting_stream.update_io_events(&self.pollee)
|
||||
}
|
||||
@ -281,22 +283,23 @@ impl Socket for StreamSocket {
|
||||
|
||||
let mut state = self.state.write();
|
||||
|
||||
let owned_state = mem::replace(&mut *state, State::Poisoned);
|
||||
state.borrow_result(|owned_state| {
|
||||
let State::Init(init_stream) = owned_state else {
|
||||
*state = owned_state;
|
||||
return_errno_with_message!(Errno::EINVAL, "cannot bind");
|
||||
return (
|
||||
owned_state,
|
||||
Err(Error::with_message(Errno::EINVAL, "cannot bind")),
|
||||
);
|
||||
};
|
||||
|
||||
let bound_socket = match init_stream.bind(&endpoint) {
|
||||
Ok(bound_socket) => bound_socket,
|
||||
Err((err, init_stream)) => {
|
||||
*state = State::Init(init_stream);
|
||||
return Err(err);
|
||||
return (State::Init(init_stream), Err(err));
|
||||
}
|
||||
};
|
||||
*state = State::Init(InitStream::new_bound(bound_socket));
|
||||
|
||||
Ok(())
|
||||
(State::Init(InitStream::new_bound(bound_socket)), Ok(()))
|
||||
})
|
||||
}
|
||||
|
||||
// TODO: Support nonblocking mode
|
||||
@ -311,23 +314,24 @@ impl Socket for StreamSocket {
|
||||
fn listen(&self, backlog: usize) -> Result<()> {
|
||||
let mut state = self.state.write();
|
||||
|
||||
let owned_state = mem::replace(&mut *state, State::Poisoned);
|
||||
state.borrow_result(|owned_state| {
|
||||
let State::Init(init_stream) = owned_state else {
|
||||
*state = owned_state;
|
||||
return_errno_with_message!(Errno::EINVAL, "cannot listen");
|
||||
return (
|
||||
owned_state,
|
||||
Err(Error::with_message(Errno::EINVAL, "cannot listen")),
|
||||
);
|
||||
};
|
||||
|
||||
let listen_stream = match init_stream.listen(backlog) {
|
||||
Ok(listen_stream) => listen_stream,
|
||||
Err((err, init_stream)) => {
|
||||
*state = State::Init(init_stream);
|
||||
return Err(err);
|
||||
return (State::Init(init_stream), Err(err));
|
||||
}
|
||||
};
|
||||
listen_stream.init_pollee(&self.pollee);
|
||||
*state = State::Listen(listen_stream);
|
||||
|
||||
Ok(())
|
||||
(State::Listen(listen_stream), Ok(()))
|
||||
})
|
||||
}
|
||||
|
||||
fn accept(&self) -> Result<(Arc<dyn FileLike>, SocketAddr)> {
|
||||
@ -341,7 +345,7 @@ impl Socket for StreamSocket {
|
||||
|
||||
fn shutdown(&self, cmd: SockShutdownCmd) -> Result<()> {
|
||||
let state = self.state.read();
|
||||
match &*state {
|
||||
match state.as_ref() {
|
||||
State::Connected(connected_stream) => connected_stream.shutdown(cmd),
|
||||
// TDOD: shutdown listening stream
|
||||
_ => return_errno_with_message!(Errno::EINVAL, "cannot shutdown"),
|
||||
@ -350,19 +354,18 @@ impl Socket for StreamSocket {
|
||||
|
||||
fn addr(&self) -> Result<SocketAddr> {
|
||||
let state = self.state.read();
|
||||
let local_endpoint = match &*state {
|
||||
let local_endpoint = match state.as_ref() {
|
||||
State::Init(init_stream) => init_stream.local_endpoint()?,
|
||||
State::Connecting(connecting_stream) => connecting_stream.local_endpoint(),
|
||||
State::Listen(listen_stream) => listen_stream.local_endpoint(),
|
||||
State::Connected(connected_stream) => connected_stream.local_endpoint(),
|
||||
State::Poisoned => return_errno_with_message!(Errno::EINVAL, "socket is poisoned"),
|
||||
};
|
||||
local_endpoint.try_into()
|
||||
}
|
||||
|
||||
fn peer_addr(&self) -> Result<SocketAddr> {
|
||||
let state = self.state.read();
|
||||
let remote_endpoint = match &*state {
|
||||
let remote_endpoint = match state.as_ref() {
|
||||
State::Init(init_stream) => {
|
||||
return_errno_with_message!(Errno::EINVAL, "init socket does not have peer")
|
||||
}
|
||||
@ -371,7 +374,6 @@ impl Socket for StreamSocket {
|
||||
return_errno_with_message!(Errno::EINVAL, "listening socket does not have peer")
|
||||
}
|
||||
State::Connected(connected_stream) => connected_stream.remote_endpoint(),
|
||||
State::Poisoned => return_errno_with_message!(Errno::EINVAL, "socket is poisoned"),
|
||||
};
|
||||
remote_endpoint.try_into()
|
||||
}
|
||||
@ -412,7 +414,7 @@ impl Socket for StreamSocket {
|
||||
let options = self.options.read();
|
||||
match_sock_option_mut!(option, {
|
||||
// Socket Options
|
||||
socket_errors: Error => {
|
||||
socket_errors: SocketError => {
|
||||
let sock_errors = options.socket.sock_errors();
|
||||
socket_errors.set(sock_errors);
|
||||
},
|
||||
@ -446,10 +448,8 @@ impl Socket for StreamSocket {
|
||||
// and always return the actual current MSS for a connected one.
|
||||
|
||||
// FIXME: how to get the current MSS?
|
||||
let maxseg = match &*self.state.read() {
|
||||
State::Init(_) | State::Listen(_) | State::Connecting(_) | State::Poisoned => {
|
||||
DEFAULT_MAXSEG
|
||||
}
|
||||
let maxseg = match self.state.read().as_ref() {
|
||||
State::Init(_) | State::Listen(_) | State::Connecting(_) => DEFAULT_MAXSEG,
|
||||
State::Connected(_) => options.tcp.maxseg(),
|
||||
};
|
||||
tcp_maxseg.set(maxseg);
|
||||
|
Reference in New Issue
Block a user