Replace Poisoned state by takeable crate

This commit is contained in:
Ruihan Li
2024-02-20 23:10:06 +08:00
committed by Tate, Hongliang Tian
parent a10d04c5f9
commit 595c6ab288
4 changed files with 129 additions and 133 deletions

View File

@ -1,9 +1,8 @@
// SPDX-License-Identifier: MPL-2.0
use core::{
mem,
sync::atomic::{AtomicBool, Ordering},
};
use core::sync::atomic::{AtomicBool, Ordering};
use takeable::Takeable;
use self::{bound::BoundDatagram, unbound::UnboundDatagram};
use super::common::get_ephemeral_endpoint;
@ -26,7 +25,7 @@ mod bound;
mod unbound;
pub struct DatagramSocket {
inner: RwLock<Inner>,
inner: RwLock<Takeable<Inner>>,
nonblocking: AtomicBool,
pollee: Pollee,
}
@ -34,7 +33,6 @@ pub struct DatagramSocket {
enum Inner {
Unbound(UnboundDatagram),
Bound(BoundDatagram),
Poisoned,
}
impl Inner {
@ -47,12 +45,6 @@ impl Inner {
Inner::Bound(bound_datagram),
));
}
Inner::Poisoned => {
return Err((
Error::with_message(Errno::EINVAL, "the socket is poisoned"),
Inner::Poisoned,
));
}
};
let bound_datagram = match unbound_datagram.bind(endpoint) {
@ -81,7 +73,7 @@ impl DatagramSocket {
let unbound_datagram = UnboundDatagram::new(me.clone() as _);
let pollee = Pollee::new(IoEvents::empty());
Self {
inner: RwLock::new(Inner::Unbound(unbound_datagram)),
inner: RwLock::new(Takeable::new(Inner::Unbound(unbound_datagram))),
nonblocking: AtomicBool::new(nonblocking),
pollee,
}
@ -98,29 +90,27 @@ impl DatagramSocket {
fn try_bind_empheral(&self, remote_endpoint: &IpEndpoint) -> Result<()> {
// Fast path
if let Inner::Bound(_) = &*self.inner.read() {
if let Inner::Bound(_) = self.inner.read().as_ref() {
return Ok(());
}
// Slow path
let mut inner = self.inner.write();
let owned_inner = mem::replace(&mut *inner, Inner::Poisoned);
let bound_datagram = match owned_inner.bind_to_ephemeral_endpoint(remote_endpoint) {
Ok(bound_datagram) => bound_datagram,
Err((err, err_inner)) => {
*inner = err_inner;
return Err(err);
}
};
bound_datagram.init_pollee(&self.pollee);
*inner = Inner::Bound(bound_datagram);
Ok(())
inner.borrow_result(|owned_inner| {
let bound_datagram = match owned_inner.bind_to_ephemeral_endpoint(remote_endpoint) {
Ok(bound_datagram) => bound_datagram,
Err((err, err_inner)) => {
return (err_inner, Err(err));
}
};
bound_datagram.init_pollee(&self.pollee);
(Inner::Bound(bound_datagram), Ok(()))
})
}
fn try_recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, SocketAddr)> {
let inner = self.inner.read();
let Inner::Bound(bound_datagram) = &*inner else {
let Inner::Bound(bound_datagram) = inner.as_ref() else {
return_errno_with_message!(Errno::EINVAL, "the socket is not bound");
};
let (recv_bytes, remote_endpoint) = bound_datagram.try_recvfrom(buf, flags)?;
@ -135,7 +125,7 @@ impl DatagramSocket {
flags: SendRecvFlags,
) -> Result<usize> {
let inner = self.inner.read();
let Inner::Bound(bound_datagram) = &*inner else {
let Inner::Bound(bound_datagram) = inner.as_ref() else {
return_errno_with_message!(Errno::EINVAL, "the socket is not bound");
};
let sent_bytes = bound_datagram.try_sendto(buf, remote, flags)?;
@ -167,7 +157,7 @@ impl DatagramSocket {
fn update_io_events(&self) {
let inner = self.inner.read();
let Inner::Bound(bound_datagram) = &*inner else {
let Inner::Bound(bound_datagram) = inner.as_ref() else {
return;
};
bound_datagram.update_io_events(&self.pollee);
@ -219,18 +209,16 @@ impl Socket for DatagramSocket {
let endpoint = socket_addr.try_into()?;
let mut inner = self.inner.write();
let owned_inner = mem::replace(&mut *inner, Inner::Poisoned);
let bound_datagram = match owned_inner.bind(&endpoint) {
Ok(bound_datagram) => bound_datagram,
Err((err, err_inner)) => {
*inner = err_inner;
return Err(err);
}
};
bound_datagram.init_pollee(&self.pollee);
*inner = Inner::Bound(bound_datagram);
Ok(())
inner.borrow_result(|owned_inner| {
let bound_datagram = match owned_inner.bind(&endpoint) {
Ok(bound_datagram) => bound_datagram,
Err((err, err_inner)) => {
return (err_inner, Err(err));
}
};
bound_datagram.init_pollee(&self.pollee);
(Inner::Bound(bound_datagram), Ok(()))
})
}
fn connect(&self, socket_addr: SocketAddr) -> Result<()> {
@ -239,7 +227,7 @@ impl Socket for DatagramSocket {
self.try_bind_empheral(&endpoint)?;
let mut inner = self.inner.write();
let Inner::Bound(bound_datagram) = &mut *inner else {
let Inner::Bound(bound_datagram) = inner.as_mut() else {
return_errno_with_message!(Errno::EINVAL, "the socket is not bound")
};
bound_datagram.set_remote_endpoint(&endpoint);
@ -249,7 +237,7 @@ impl Socket for DatagramSocket {
fn addr(&self) -> Result<SocketAddr> {
let inner = self.inner.read();
let Inner::Bound(bound_datagram) = &*inner else {
let Inner::Bound(bound_datagram) = inner.as_ref() else {
return_errno_with_message!(Errno::EINVAL, "the socket is not bound");
};
bound_datagram.local_endpoint().try_into()
@ -257,7 +245,7 @@ impl Socket for DatagramSocket {
fn peer_addr(&self) -> Result<SocketAddr> {
let inner = self.inner.read();
let Inner::Bound(bound_datagram) = &*inner else {
let Inner::Bound(bound_datagram) = inner.as_ref() else {
return_errno_with_message!(Errno::EINVAL, "the socket is not bound");
};
bound_datagram.remote_endpoint()?.try_into()

View File

@ -1,9 +1,6 @@
// SPDX-License-Identifier: MPL-2.0
use core::{
mem,
sync::atomic::{AtomicBool, Ordering},
};
use core::sync::atomic::{AtomicBool, Ordering};
use connected::ConnectedStream;
use connecting::ConnectingStream;
@ -11,6 +8,7 @@ use init::InitStream;
use listen::ListenStream;
use options::{Congestion, MaxSegment, NoDelay, WindowClamp};
use smoltcp::wire::IpEndpoint;
use takeable::Takeable;
use util::{TcpOptionSet, DEFAULT_MAXSEG};
use crate::{
@ -20,7 +18,9 @@ use crate::{
net::{
poll_ifaces,
socket::{
options::{Error, Linger, RecvBuf, ReuseAddr, ReusePort, SendBuf, SocketOption},
options::{
Error as SocketError, Linger, RecvBuf, ReuseAddr, ReusePort, SendBuf, SocketOption,
},
util::{
options::{SocketOptionSet, MIN_RECVBUF, MIN_SENDBUF},
send_recv_flags::SendRecvFlags,
@ -46,7 +46,7 @@ pub use self::util::CongestionControl;
pub struct StreamSocket {
options: RwLock<OptionSet>,
state: RwLock<State>,
state: RwLock<Takeable<State>>,
is_nonblocking: AtomicBool,
pollee: Pollee,
}
@ -60,8 +60,6 @@ enum State {
Connected(ConnectedStream),
// Final State 2
Listen(ListenStream),
// Poisoned state
Poisoned,
}
#[derive(Debug, Clone)]
@ -85,7 +83,7 @@ impl StreamSocket {
let pollee = Pollee::new(IoEvents::empty());
Self {
options: RwLock::new(OptionSet::new()),
state: RwLock::new(State::Init(init_stream)),
state: RwLock::new(Takeable::new(State::Init(init_stream))),
is_nonblocking: AtomicBool::new(nonblocking),
pollee,
}
@ -99,7 +97,7 @@ impl StreamSocket {
connected_stream.init_pollee(&pollee);
Self {
options: RwLock::new(OptionSet::new()),
state: RwLock::new(State::Connected(connected_stream)),
state: RwLock::new(Takeable::new(State::Connected(connected_stream))),
is_nonblocking: AtomicBool::new(false),
pollee,
}
@ -117,56 +115,60 @@ impl StreamSocket {
fn start_connect(&self, remote_endpoint: &IpEndpoint) -> Result<()> {
let mut state = self.state.write();
let owned_state = mem::replace(&mut *state, State::Poisoned);
let State::Init(init_stream) = owned_state else {
*state = owned_state;
return_errno_with_message!(Errno::EINVAL, "cannot connect")
};
state.borrow_result(|owned_state| {
let State::Init(init_stream) = owned_state else {
return (
owned_state,
Err(Error::with_message(Errno::EINVAL, "cannot connect")),
);
};
let connecting_stream = match init_stream.connect(remote_endpoint) {
Ok(connecting_stream) => connecting_stream,
Err((err, init_stream)) => {
*state = State::Init(init_stream);
return Err(err);
}
};
connecting_stream.init_pollee(&self.pollee);
*state = State::Connecting(connecting_stream);
let connecting_stream = match init_stream.connect(remote_endpoint) {
Ok(connecting_stream) => connecting_stream,
Err((err, init_stream)) => {
return (State::Init(init_stream), Err(err));
}
};
connecting_stream.init_pollee(&self.pollee);
Ok(())
(State::Connecting(connecting_stream), Ok(()))
})
}
fn finish_connect(&self) -> Result<()> {
let mut state = self.state.write();
let owned_state = mem::replace(&mut *state, State::Poisoned);
let State::Connecting(connecting_stream) = owned_state else {
*state = owned_state;
debug_assert!(false, "the socket unexpectedly left the connecting state");
return_errno_with_message!(Errno::EINVAL, "the socket is not connecting");
};
state.borrow_result(|owned_state| {
let State::Connecting(connecting_stream) = owned_state else {
debug_assert!(false, "the socket unexpectedly left the connecting state");
return (
owned_state,
Err(Error::with_message(
Errno::EINVAL,
"the socket is not connecting",
)),
);
};
let connected_stream = match connecting_stream.into_result() {
Ok(connected_stream) => connected_stream,
Err((err, NonConnectedStream::Init(init_stream))) => {
*state = State::Init(init_stream);
return Err(err);
}
Err((err, NonConnectedStream::Connecting(connecting_stream))) => {
*state = State::Connecting(connecting_stream);
return Err(err);
}
};
connected_stream.init_pollee(&self.pollee);
*state = State::Connected(connected_stream);
let connected_stream = match connecting_stream.into_result() {
Ok(connected_stream) => connected_stream,
Err((err, NonConnectedStream::Init(init_stream))) => {
return (State::Init(init_stream), Err(err));
}
Err((err, NonConnectedStream::Connecting(connecting_stream))) => {
return (State::Connecting(connecting_stream), Err(err));
}
};
connected_stream.init_pollee(&self.pollee);
Ok(())
(State::Connected(connected_stream), Ok(()))
})
}
fn try_accept(&self) -> Result<(Arc<dyn FileLike>, SocketAddr)> {
let state = self.state.read();
let State::Listen(listen_stream) = &*state else {
let State::Listen(listen_stream) = state.as_ref() else {
return_errno_with_message!(Errno::EINVAL, "the socket is not listening");
};
@ -181,7 +183,7 @@ impl StreamSocket {
fn try_recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, SocketAddr)> {
let state = self.state.read();
let State::Connected(connected_stream) = &*state else {
let State::Connected(connected_stream) = state.as_ref() else {
return_errno_with_message!(Errno::EINVAL, "the socket is not connected");
};
let recv_bytes = connected_stream.try_recvfrom(buf, flags)?;
@ -192,7 +194,7 @@ impl StreamSocket {
fn try_sendto(&self, buf: &[u8], flags: SendRecvFlags) -> Result<usize> {
let state = self.state.read();
let State::Connected(connected_stream) = &*state else {
let State::Connected(connected_stream) = state.as_ref() else {
return_errno_with_message!(Errno::EINVAL, "the socket is not connected");
};
let sent_bytes = connected_stream.try_sendto(buf, flags)?;
@ -224,8 +226,8 @@ impl StreamSocket {
fn update_io_events(&self) {
let state = self.state.read();
match &*state {
State::Init(_) | State::Poisoned => (),
match state.as_ref() {
State::Init(_) => (),
State::Connecting(connecting_stream) => {
connecting_stream.update_io_events(&self.pollee)
}
@ -281,22 +283,23 @@ impl Socket for StreamSocket {
let mut state = self.state.write();
let owned_state = mem::replace(&mut *state, State::Poisoned);
let State::Init(init_stream) = owned_state else {
*state = owned_state;
return_errno_with_message!(Errno::EINVAL, "cannot bind");
};
state.borrow_result(|owned_state| {
let State::Init(init_stream) = owned_state else {
return (
owned_state,
Err(Error::with_message(Errno::EINVAL, "cannot bind")),
);
};
let bound_socket = match init_stream.bind(&endpoint) {
Ok(bound_socket) => bound_socket,
Err((err, init_stream)) => {
*state = State::Init(init_stream);
return Err(err);
}
};
*state = State::Init(InitStream::new_bound(bound_socket));
let bound_socket = match init_stream.bind(&endpoint) {
Ok(bound_socket) => bound_socket,
Err((err, init_stream)) => {
return (State::Init(init_stream), Err(err));
}
};
Ok(())
(State::Init(InitStream::new_bound(bound_socket)), Ok(()))
})
}
// TODO: Support nonblocking mode
@ -311,23 +314,24 @@ impl Socket for StreamSocket {
fn listen(&self, backlog: usize) -> Result<()> {
let mut state = self.state.write();
let owned_state = mem::replace(&mut *state, State::Poisoned);
let State::Init(init_stream) = owned_state else {
*state = owned_state;
return_errno_with_message!(Errno::EINVAL, "cannot listen");
};
state.borrow_result(|owned_state| {
let State::Init(init_stream) = owned_state else {
return (
owned_state,
Err(Error::with_message(Errno::EINVAL, "cannot listen")),
);
};
let listen_stream = match init_stream.listen(backlog) {
Ok(listen_stream) => listen_stream,
Err((err, init_stream)) => {
*state = State::Init(init_stream);
return Err(err);
}
};
listen_stream.init_pollee(&self.pollee);
*state = State::Listen(listen_stream);
let listen_stream = match init_stream.listen(backlog) {
Ok(listen_stream) => listen_stream,
Err((err, init_stream)) => {
return (State::Init(init_stream), Err(err));
}
};
listen_stream.init_pollee(&self.pollee);
Ok(())
(State::Listen(listen_stream), Ok(()))
})
}
fn accept(&self) -> Result<(Arc<dyn FileLike>, SocketAddr)> {
@ -341,7 +345,7 @@ impl Socket for StreamSocket {
fn shutdown(&self, cmd: SockShutdownCmd) -> Result<()> {
let state = self.state.read();
match &*state {
match state.as_ref() {
State::Connected(connected_stream) => connected_stream.shutdown(cmd),
// TDOD: shutdown listening stream
_ => return_errno_with_message!(Errno::EINVAL, "cannot shutdown"),
@ -350,19 +354,18 @@ impl Socket for StreamSocket {
fn addr(&self) -> Result<SocketAddr> {
let state = self.state.read();
let local_endpoint = match &*state {
let local_endpoint = match state.as_ref() {
State::Init(init_stream) => init_stream.local_endpoint()?,
State::Connecting(connecting_stream) => connecting_stream.local_endpoint(),
State::Listen(listen_stream) => listen_stream.local_endpoint(),
State::Connected(connected_stream) => connected_stream.local_endpoint(),
State::Poisoned => return_errno_with_message!(Errno::EINVAL, "socket is poisoned"),
};
local_endpoint.try_into()
}
fn peer_addr(&self) -> Result<SocketAddr> {
let state = self.state.read();
let remote_endpoint = match &*state {
let remote_endpoint = match state.as_ref() {
State::Init(init_stream) => {
return_errno_with_message!(Errno::EINVAL, "init socket does not have peer")
}
@ -371,7 +374,6 @@ impl Socket for StreamSocket {
return_errno_with_message!(Errno::EINVAL, "listening socket does not have peer")
}
State::Connected(connected_stream) => connected_stream.remote_endpoint(),
State::Poisoned => return_errno_with_message!(Errno::EINVAL, "socket is poisoned"),
};
remote_endpoint.try_into()
}
@ -412,7 +414,7 @@ impl Socket for StreamSocket {
let options = self.options.read();
match_sock_option_mut!(option, {
// Socket Options
socket_errors: Error => {
socket_errors: SocketError => {
let sock_errors = options.socket.sock_errors();
socket_errors.set(sock_errors);
},
@ -446,10 +448,8 @@ impl Socket for StreamSocket {
// and always return the actual current MSS for a connected one.
// FIXME: how to get the current MSS?
let maxseg = match &*self.state.read() {
State::Init(_) | State::Listen(_) | State::Connecting(_) | State::Poisoned => {
DEFAULT_MAXSEG
}
let maxseg = match self.state.read().as_ref() {
State::Init(_) | State::Listen(_) | State::Connecting(_) => DEFAULT_MAXSEG,
State::Connected(_) => options.tcp.maxseg(),
};
tcp_maxseg.set(maxseg);