From 595c6ab288ad42e95af4a3f8d1060ae39a63b174 Mon Sep 17 00:00:00 2001 From: Ruihan Li Date: Tue, 20 Feb 2024 23:10:06 +0800 Subject: [PATCH] Replace `Poisoned` state by `takeable` crate --- Cargo.lock | 7 + kernel/aster-nix/Cargo.toml | 1 + .../src/net/socket/ip/datagram/mod.rs | 76 ++++---- .../aster-nix/src/net/socket/ip/stream/mod.rs | 178 +++++++++--------- 4 files changed, 129 insertions(+), 133 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index b1c21d916..3ab3aece4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -226,6 +226,7 @@ dependencies = [ "smoltcp", "spin 0.9.8", "static_assertions", + "takeable", "tdx-guest", "time", "typeflags", @@ -1342,6 +1343,12 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "takeable" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7efe11d01772262041364ebf855e0c6c9e2bd0dd4f059083e540a5f1cda69477" + [[package]] name = "tap" version = "1.0.1" diff --git a/kernel/aster-nix/Cargo.toml b/kernel/aster-nix/Cargo.toml index 45f88f69b..8028f8fe9 100644 --- a/kernel/aster-nix/Cargo.toml +++ b/kernel/aster-nix/Cargo.toml @@ -68,6 +68,7 @@ getset = "0.1.2" atomic = "0.6" bytemuck = "1.14.3" bytemuck_derive = "1.5.0" +takeable = "0.2.2" [dependencies.lazy_static] version = "1.0" diff --git a/kernel/aster-nix/src/net/socket/ip/datagram/mod.rs b/kernel/aster-nix/src/net/socket/ip/datagram/mod.rs index f73547309..c48535ee8 100644 --- a/kernel/aster-nix/src/net/socket/ip/datagram/mod.rs +++ b/kernel/aster-nix/src/net/socket/ip/datagram/mod.rs @@ -1,9 +1,8 @@ // SPDX-License-Identifier: MPL-2.0 -use core::{ - mem, - sync::atomic::{AtomicBool, Ordering}, -}; +use core::sync::atomic::{AtomicBool, Ordering}; + +use takeable::Takeable; use self::{bound::BoundDatagram, unbound::UnboundDatagram}; use super::common::get_ephemeral_endpoint; @@ -26,7 +25,7 @@ mod bound; mod unbound; pub struct DatagramSocket { - inner: RwLock, + inner: RwLock>, nonblocking: AtomicBool, pollee: Pollee, } @@ -34,7 +33,6 @@ pub struct DatagramSocket { enum Inner { Unbound(UnboundDatagram), Bound(BoundDatagram), - Poisoned, } impl Inner { @@ -47,12 +45,6 @@ impl Inner { Inner::Bound(bound_datagram), )); } - Inner::Poisoned => { - return Err(( - Error::with_message(Errno::EINVAL, "the socket is poisoned"), - Inner::Poisoned, - )); - } }; let bound_datagram = match unbound_datagram.bind(endpoint) { @@ -81,7 +73,7 @@ impl DatagramSocket { let unbound_datagram = UnboundDatagram::new(me.clone() as _); let pollee = Pollee::new(IoEvents::empty()); Self { - inner: RwLock::new(Inner::Unbound(unbound_datagram)), + inner: RwLock::new(Takeable::new(Inner::Unbound(unbound_datagram))), nonblocking: AtomicBool::new(nonblocking), pollee, } @@ -98,29 +90,27 @@ impl DatagramSocket { fn try_bind_empheral(&self, remote_endpoint: &IpEndpoint) -> Result<()> { // Fast path - if let Inner::Bound(_) = &*self.inner.read() { + if let Inner::Bound(_) = self.inner.read().as_ref() { return Ok(()); } // Slow path let mut inner = self.inner.write(); - let owned_inner = mem::replace(&mut *inner, Inner::Poisoned); - - let bound_datagram = match owned_inner.bind_to_ephemeral_endpoint(remote_endpoint) { - Ok(bound_datagram) => bound_datagram, - Err((err, err_inner)) => { - *inner = err_inner; - return Err(err); - } - }; - bound_datagram.init_pollee(&self.pollee); - *inner = Inner::Bound(bound_datagram); - Ok(()) + inner.borrow_result(|owned_inner| { + let bound_datagram = match owned_inner.bind_to_ephemeral_endpoint(remote_endpoint) { + Ok(bound_datagram) => bound_datagram, + Err((err, err_inner)) => { + return (err_inner, Err(err)); + } + }; + bound_datagram.init_pollee(&self.pollee); + (Inner::Bound(bound_datagram), Ok(())) + }) } fn try_recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, SocketAddr)> { let inner = self.inner.read(); - let Inner::Bound(bound_datagram) = &*inner else { + let Inner::Bound(bound_datagram) = inner.as_ref() else { return_errno_with_message!(Errno::EINVAL, "the socket is not bound"); }; let (recv_bytes, remote_endpoint) = bound_datagram.try_recvfrom(buf, flags)?; @@ -135,7 +125,7 @@ impl DatagramSocket { flags: SendRecvFlags, ) -> Result { let inner = self.inner.read(); - let Inner::Bound(bound_datagram) = &*inner else { + let Inner::Bound(bound_datagram) = inner.as_ref() else { return_errno_with_message!(Errno::EINVAL, "the socket is not bound"); }; let sent_bytes = bound_datagram.try_sendto(buf, remote, flags)?; @@ -167,7 +157,7 @@ impl DatagramSocket { fn update_io_events(&self) { let inner = self.inner.read(); - let Inner::Bound(bound_datagram) = &*inner else { + let Inner::Bound(bound_datagram) = inner.as_ref() else { return; }; bound_datagram.update_io_events(&self.pollee); @@ -219,18 +209,16 @@ impl Socket for DatagramSocket { let endpoint = socket_addr.try_into()?; let mut inner = self.inner.write(); - let owned_inner = mem::replace(&mut *inner, Inner::Poisoned); - - let bound_datagram = match owned_inner.bind(&endpoint) { - Ok(bound_datagram) => bound_datagram, - Err((err, err_inner)) => { - *inner = err_inner; - return Err(err); - } - }; - bound_datagram.init_pollee(&self.pollee); - *inner = Inner::Bound(bound_datagram); - Ok(()) + inner.borrow_result(|owned_inner| { + let bound_datagram = match owned_inner.bind(&endpoint) { + Ok(bound_datagram) => bound_datagram, + Err((err, err_inner)) => { + return (err_inner, Err(err)); + } + }; + bound_datagram.init_pollee(&self.pollee); + (Inner::Bound(bound_datagram), Ok(())) + }) } fn connect(&self, socket_addr: SocketAddr) -> Result<()> { @@ -239,7 +227,7 @@ impl Socket for DatagramSocket { self.try_bind_empheral(&endpoint)?; let mut inner = self.inner.write(); - let Inner::Bound(bound_datagram) = &mut *inner else { + let Inner::Bound(bound_datagram) = inner.as_mut() else { return_errno_with_message!(Errno::EINVAL, "the socket is not bound") }; bound_datagram.set_remote_endpoint(&endpoint); @@ -249,7 +237,7 @@ impl Socket for DatagramSocket { fn addr(&self) -> Result { let inner = self.inner.read(); - let Inner::Bound(bound_datagram) = &*inner else { + let Inner::Bound(bound_datagram) = inner.as_ref() else { return_errno_with_message!(Errno::EINVAL, "the socket is not bound"); }; bound_datagram.local_endpoint().try_into() @@ -257,7 +245,7 @@ impl Socket for DatagramSocket { fn peer_addr(&self) -> Result { let inner = self.inner.read(); - let Inner::Bound(bound_datagram) = &*inner else { + let Inner::Bound(bound_datagram) = inner.as_ref() else { return_errno_with_message!(Errno::EINVAL, "the socket is not bound"); }; bound_datagram.remote_endpoint()?.try_into() diff --git a/kernel/aster-nix/src/net/socket/ip/stream/mod.rs b/kernel/aster-nix/src/net/socket/ip/stream/mod.rs index df23c2899..32b95e52a 100644 --- a/kernel/aster-nix/src/net/socket/ip/stream/mod.rs +++ b/kernel/aster-nix/src/net/socket/ip/stream/mod.rs @@ -1,9 +1,6 @@ // SPDX-License-Identifier: MPL-2.0 -use core::{ - mem, - sync::atomic::{AtomicBool, Ordering}, -}; +use core::sync::atomic::{AtomicBool, Ordering}; use connected::ConnectedStream; use connecting::ConnectingStream; @@ -11,6 +8,7 @@ use init::InitStream; use listen::ListenStream; use options::{Congestion, MaxSegment, NoDelay, WindowClamp}; use smoltcp::wire::IpEndpoint; +use takeable::Takeable; use util::{TcpOptionSet, DEFAULT_MAXSEG}; use crate::{ @@ -20,7 +18,9 @@ use crate::{ net::{ poll_ifaces, socket::{ - options::{Error, Linger, RecvBuf, ReuseAddr, ReusePort, SendBuf, SocketOption}, + options::{ + Error as SocketError, Linger, RecvBuf, ReuseAddr, ReusePort, SendBuf, SocketOption, + }, util::{ options::{SocketOptionSet, MIN_RECVBUF, MIN_SENDBUF}, send_recv_flags::SendRecvFlags, @@ -46,7 +46,7 @@ pub use self::util::CongestionControl; pub struct StreamSocket { options: RwLock, - state: RwLock, + state: RwLock>, is_nonblocking: AtomicBool, pollee: Pollee, } @@ -60,8 +60,6 @@ enum State { Connected(ConnectedStream), // Final State 2 Listen(ListenStream), - // Poisoned state - Poisoned, } #[derive(Debug, Clone)] @@ -85,7 +83,7 @@ impl StreamSocket { let pollee = Pollee::new(IoEvents::empty()); Self { options: RwLock::new(OptionSet::new()), - state: RwLock::new(State::Init(init_stream)), + state: RwLock::new(Takeable::new(State::Init(init_stream))), is_nonblocking: AtomicBool::new(nonblocking), pollee, } @@ -99,7 +97,7 @@ impl StreamSocket { connected_stream.init_pollee(&pollee); Self { options: RwLock::new(OptionSet::new()), - state: RwLock::new(State::Connected(connected_stream)), + state: RwLock::new(Takeable::new(State::Connected(connected_stream))), is_nonblocking: AtomicBool::new(false), pollee, } @@ -117,56 +115,60 @@ impl StreamSocket { fn start_connect(&self, remote_endpoint: &IpEndpoint) -> Result<()> { let mut state = self.state.write(); - let owned_state = mem::replace(&mut *state, State::Poisoned); - let State::Init(init_stream) = owned_state else { - *state = owned_state; - return_errno_with_message!(Errno::EINVAL, "cannot connect") - }; + state.borrow_result(|owned_state| { + let State::Init(init_stream) = owned_state else { + return ( + owned_state, + Err(Error::with_message(Errno::EINVAL, "cannot connect")), + ); + }; - let connecting_stream = match init_stream.connect(remote_endpoint) { - Ok(connecting_stream) => connecting_stream, - Err((err, init_stream)) => { - *state = State::Init(init_stream); - return Err(err); - } - }; - connecting_stream.init_pollee(&self.pollee); - *state = State::Connecting(connecting_stream); + let connecting_stream = match init_stream.connect(remote_endpoint) { + Ok(connecting_stream) => connecting_stream, + Err((err, init_stream)) => { + return (State::Init(init_stream), Err(err)); + } + }; + connecting_stream.init_pollee(&self.pollee); - Ok(()) + (State::Connecting(connecting_stream), Ok(())) + }) } fn finish_connect(&self) -> Result<()> { let mut state = self.state.write(); - let owned_state = mem::replace(&mut *state, State::Poisoned); - let State::Connecting(connecting_stream) = owned_state else { - *state = owned_state; - debug_assert!(false, "the socket unexpectedly left the connecting state"); - return_errno_with_message!(Errno::EINVAL, "the socket is not connecting"); - }; + state.borrow_result(|owned_state| { + let State::Connecting(connecting_stream) = owned_state else { + debug_assert!(false, "the socket unexpectedly left the connecting state"); + return ( + owned_state, + Err(Error::with_message( + Errno::EINVAL, + "the socket is not connecting", + )), + ); + }; - let connected_stream = match connecting_stream.into_result() { - Ok(connected_stream) => connected_stream, - Err((err, NonConnectedStream::Init(init_stream))) => { - *state = State::Init(init_stream); - return Err(err); - } - Err((err, NonConnectedStream::Connecting(connecting_stream))) => { - *state = State::Connecting(connecting_stream); - return Err(err); - } - }; - connected_stream.init_pollee(&self.pollee); - *state = State::Connected(connected_stream); + let connected_stream = match connecting_stream.into_result() { + Ok(connected_stream) => connected_stream, + Err((err, NonConnectedStream::Init(init_stream))) => { + return (State::Init(init_stream), Err(err)); + } + Err((err, NonConnectedStream::Connecting(connecting_stream))) => { + return (State::Connecting(connecting_stream), Err(err)); + } + }; + connected_stream.init_pollee(&self.pollee); - Ok(()) + (State::Connected(connected_stream), Ok(())) + }) } fn try_accept(&self) -> Result<(Arc, SocketAddr)> { let state = self.state.read(); - let State::Listen(listen_stream) = &*state else { + let State::Listen(listen_stream) = state.as_ref() else { return_errno_with_message!(Errno::EINVAL, "the socket is not listening"); }; @@ -181,7 +183,7 @@ impl StreamSocket { fn try_recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, SocketAddr)> { let state = self.state.read(); - let State::Connected(connected_stream) = &*state else { + let State::Connected(connected_stream) = state.as_ref() else { return_errno_with_message!(Errno::EINVAL, "the socket is not connected"); }; let recv_bytes = connected_stream.try_recvfrom(buf, flags)?; @@ -192,7 +194,7 @@ impl StreamSocket { fn try_sendto(&self, buf: &[u8], flags: SendRecvFlags) -> Result { let state = self.state.read(); - let State::Connected(connected_stream) = &*state else { + let State::Connected(connected_stream) = state.as_ref() else { return_errno_with_message!(Errno::EINVAL, "the socket is not connected"); }; let sent_bytes = connected_stream.try_sendto(buf, flags)?; @@ -224,8 +226,8 @@ impl StreamSocket { fn update_io_events(&self) { let state = self.state.read(); - match &*state { - State::Init(_) | State::Poisoned => (), + match state.as_ref() { + State::Init(_) => (), State::Connecting(connecting_stream) => { connecting_stream.update_io_events(&self.pollee) } @@ -281,22 +283,23 @@ impl Socket for StreamSocket { let mut state = self.state.write(); - let owned_state = mem::replace(&mut *state, State::Poisoned); - let State::Init(init_stream) = owned_state else { - *state = owned_state; - return_errno_with_message!(Errno::EINVAL, "cannot bind"); - }; + state.borrow_result(|owned_state| { + let State::Init(init_stream) = owned_state else { + return ( + owned_state, + Err(Error::with_message(Errno::EINVAL, "cannot bind")), + ); + }; - let bound_socket = match init_stream.bind(&endpoint) { - Ok(bound_socket) => bound_socket, - Err((err, init_stream)) => { - *state = State::Init(init_stream); - return Err(err); - } - }; - *state = State::Init(InitStream::new_bound(bound_socket)); + let bound_socket = match init_stream.bind(&endpoint) { + Ok(bound_socket) => bound_socket, + Err((err, init_stream)) => { + return (State::Init(init_stream), Err(err)); + } + }; - Ok(()) + (State::Init(InitStream::new_bound(bound_socket)), Ok(())) + }) } // TODO: Support nonblocking mode @@ -311,23 +314,24 @@ impl Socket for StreamSocket { fn listen(&self, backlog: usize) -> Result<()> { let mut state = self.state.write(); - let owned_state = mem::replace(&mut *state, State::Poisoned); - let State::Init(init_stream) = owned_state else { - *state = owned_state; - return_errno_with_message!(Errno::EINVAL, "cannot listen"); - }; + state.borrow_result(|owned_state| { + let State::Init(init_stream) = owned_state else { + return ( + owned_state, + Err(Error::with_message(Errno::EINVAL, "cannot listen")), + ); + }; - let listen_stream = match init_stream.listen(backlog) { - Ok(listen_stream) => listen_stream, - Err((err, init_stream)) => { - *state = State::Init(init_stream); - return Err(err); - } - }; - listen_stream.init_pollee(&self.pollee); - *state = State::Listen(listen_stream); + let listen_stream = match init_stream.listen(backlog) { + Ok(listen_stream) => listen_stream, + Err((err, init_stream)) => { + return (State::Init(init_stream), Err(err)); + } + }; + listen_stream.init_pollee(&self.pollee); - Ok(()) + (State::Listen(listen_stream), Ok(())) + }) } fn accept(&self) -> Result<(Arc, SocketAddr)> { @@ -341,7 +345,7 @@ impl Socket for StreamSocket { fn shutdown(&self, cmd: SockShutdownCmd) -> Result<()> { let state = self.state.read(); - match &*state { + match state.as_ref() { State::Connected(connected_stream) => connected_stream.shutdown(cmd), // TDOD: shutdown listening stream _ => return_errno_with_message!(Errno::EINVAL, "cannot shutdown"), @@ -350,19 +354,18 @@ impl Socket for StreamSocket { fn addr(&self) -> Result { let state = self.state.read(); - let local_endpoint = match &*state { + let local_endpoint = match state.as_ref() { State::Init(init_stream) => init_stream.local_endpoint()?, State::Connecting(connecting_stream) => connecting_stream.local_endpoint(), State::Listen(listen_stream) => listen_stream.local_endpoint(), State::Connected(connected_stream) => connected_stream.local_endpoint(), - State::Poisoned => return_errno_with_message!(Errno::EINVAL, "socket is poisoned"), }; local_endpoint.try_into() } fn peer_addr(&self) -> Result { let state = self.state.read(); - let remote_endpoint = match &*state { + let remote_endpoint = match state.as_ref() { State::Init(init_stream) => { return_errno_with_message!(Errno::EINVAL, "init socket does not have peer") } @@ -371,7 +374,6 @@ impl Socket for StreamSocket { return_errno_with_message!(Errno::EINVAL, "listening socket does not have peer") } State::Connected(connected_stream) => connected_stream.remote_endpoint(), - State::Poisoned => return_errno_with_message!(Errno::EINVAL, "socket is poisoned"), }; remote_endpoint.try_into() } @@ -412,7 +414,7 @@ impl Socket for StreamSocket { let options = self.options.read(); match_sock_option_mut!(option, { // Socket Options - socket_errors: Error => { + socket_errors: SocketError => { let sock_errors = options.socket.sock_errors(); socket_errors.set(sock_errors); }, @@ -446,10 +448,8 @@ impl Socket for StreamSocket { // and always return the actual current MSS for a connected one. // FIXME: how to get the current MSS? - let maxseg = match &*self.state.read() { - State::Init(_) | State::Listen(_) | State::Connecting(_) | State::Poisoned => { - DEFAULT_MAXSEG - } + let maxseg = match self.state.read().as_ref() { + State::Init(_) | State::Listen(_) | State::Connecting(_) => DEFAULT_MAXSEG, State::Connected(_) => options.tcp.maxseg(), }; tcp_maxseg.set(maxseg);