From 6b903d0c101f72bbf5f5fc79a7db64f3eb17313d Mon Sep 17 00:00:00 2001 From: Ruihan Li Date: Thu, 30 Nov 2023 01:39:04 +0800 Subject: [PATCH] Seperate `ConnectingStream` from `InitStream` For TCP streams we used to have three states, e.g. `InitStream`, `ListenStream`, `ConnectedStream`. If the socket is not bound, it is in the `InitStream` state. If the socket is bound, it is still in that state. Most seriously, if the socket is connecting to the remote peer, but the connection has not been established, it is also in the `InitStream` state. While the socket is trying to connect to its peer, it needs to handle iface events to update its internal state. So it is expected to implement the trait that observes such events for this later. However, the reality that sockets connecting to their peers are mixed in with other unbound and bound sockets in the `InitStream` will complicate things. In fact, the connecting socket should belong to an independent state. It does not share too much logic with unbound and bound sockets in the `InitStream`. So in this commit we will decouple that and create a new `ConnectingStream` state. --- .../src/net/socket/ip/stream/connecting.rs | 83 ++++++++++++++ .../src/net/socket/ip/stream/init.rs | 108 +++++------------- .../jinux-std/src/net/socket/ip/stream/mod.rs | 85 ++++++++------ 3 files changed, 165 insertions(+), 111 deletions(-) create mode 100644 services/libs/jinux-std/src/net/socket/ip/stream/connecting.rs diff --git a/services/libs/jinux-std/src/net/socket/ip/stream/connecting.rs b/services/libs/jinux-std/src/net/socket/ip/stream/connecting.rs new file mode 100644 index 000000000..fc98515a4 --- /dev/null +++ b/services/libs/jinux-std/src/net/socket/ip/stream/connecting.rs @@ -0,0 +1,83 @@ +use core::sync::atomic::{AtomicBool, Ordering}; + +use alloc::sync::Arc; + +use crate::events::IoEvents; +use crate::net::poll_ifaces; +use crate::prelude::*; + +use crate::net::iface::{AnyBoundSocket, IpEndpoint}; +use crate::process::signal::Poller; + +use super::connected::ConnectedStream; +use super::init::InitStream; + +pub struct ConnectingStream { + nonblocking: AtomicBool, + bound_socket: Arc, + remote_endpoint: IpEndpoint, +} + +impl ConnectingStream { + pub fn new( + nonblocking: bool, + bound_socket: Arc, + remote_endpoint: IpEndpoint, + ) -> Result { + bound_socket.do_connect(remote_endpoint)?; + + Ok(Self { + nonblocking: AtomicBool::new(nonblocking), + bound_socket, + remote_endpoint, + }) + } + + pub fn wait_conn(&self) -> core::result::Result { + debug_assert!(!self.is_nonblocking()); + + let poller = Poller::new(); + loop { + poll_ifaces(); + + let events = self.poll(IoEvents::OUT | IoEvents::IN, Some(&poller)); + if events.contains(IoEvents::IN) || events.contains(IoEvents::OUT) { + return Ok(ConnectedStream::new( + self.is_nonblocking(), + self.bound_socket.clone(), + self.remote_endpoint, + )); + } else if !events.is_empty() { + return Err(( + Error::with_message(Errno::ECONNREFUSED, "connection refused"), + InitStream::new_bound(self.is_nonblocking(), self.bound_socket.clone()), + )); + } else { + // FIXME: deal with nonblocking mode & connecting timeout + poller.wait().expect("async connect() not implemented"); + } + } + } + + pub fn local_endpoint(&self) -> Result { + self.bound_socket + .local_endpoint() + .ok_or_else(|| Error::with_message(Errno::EINVAL, "no local endpoint")) + } + + pub fn remote_endpoint(&self) -> Result { + Ok(self.remote_endpoint) + } + + pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents { + self.bound_socket.poll(mask, poller) + } + + pub fn is_nonblocking(&self) -> bool { + self.nonblocking.load(Ordering::Relaxed) + } + + pub fn set_nonblocking(&self, nonblocking: bool) { + self.nonblocking.store(nonblocking, Ordering::Relaxed); + } +} diff --git a/services/libs/jinux-std/src/net/socket/ip/stream/init.rs b/services/libs/jinux-std/src/net/socket/ip/stream/init.rs index 2da70549b..d82c16982 100644 --- a/services/libs/jinux-std/src/net/socket/ip/stream/init.rs +++ b/services/libs/jinux-std/src/net/socket/ip/stream/init.rs @@ -4,12 +4,14 @@ use crate::events::IoEvents; use crate::net::iface::Iface; use crate::net::iface::IpEndpoint; use crate::net::iface::{AnyBoundSocket, AnyUnboundSocket}; -use crate::net::poll_ifaces; use crate::net::socket::ip::always_some::AlwaysSome; use crate::net::socket::ip::common::{bind_socket, get_ephemeral_endpoint}; use crate::prelude::*; use crate::process::signal::Poller; +use super::connecting::ConnectingStream; +use super::listen::ListenStream; + pub struct InitStream { inner: RwLock, is_nonblocking: AtomicBool, @@ -18,17 +20,13 @@ pub struct InitStream { enum Inner { Unbound(AlwaysSome>), Bound(AlwaysSome>), - Connecting { - bound_socket: Arc, - remote_endpoint: IpEndpoint, - }, } impl Inner { fn is_bound(&self) -> bool { match self { Self::Unbound(_) => false, - Self::Bound(..) | Self::Connecting { .. } => true, + Self::Bound(_) => true, } } @@ -50,39 +48,16 @@ impl Inner { self.bind(endpoint) } - fn do_connect(&mut self, new_remote_endpoint: IpEndpoint) -> Result<()> { - match self { - Inner::Unbound(_) => return_errno_with_message!(Errno::EINVAL, "the socket is invalid"), - Inner::Connecting { - bound_socket, - remote_endpoint, - } => { - *remote_endpoint = new_remote_endpoint; - bound_socket.do_connect(new_remote_endpoint)?; - } - Inner::Bound(bound_socket) => { - bound_socket.do_connect(new_remote_endpoint)?; - *self = Inner::Connecting { - bound_socket: bound_socket.take(), - remote_endpoint: new_remote_endpoint, - }; - } - } - Ok(()) - } - fn bound_socket(&self) -> Option<&Arc> { match self { Inner::Bound(bound_socket) => Some(bound_socket), - Inner::Connecting { bound_socket, .. } => Some(bound_socket), - _ => None, + Inner::Unbound(_) => None, } } fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents { match self { Inner::Bound(bound_socket) => bound_socket.poll(mask, poller), - Inner::Connecting { bound_socket, .. } => bound_socket.poll(mask, poller), Inner::Unbound(unbound_socket) => unbound_socket.poll(mask, poller), } } @@ -90,8 +65,7 @@ impl Inner { fn iface(&self) -> Option> { match self { Inner::Bound(bound_socket) => Some(bound_socket.iface().clone()), - Inner::Connecting { bound_socket, .. } => Some(bound_socket.iface().clone()), - _ => None, + Inner::Unbound(_) => None, } } @@ -99,17 +73,6 @@ impl Inner { self.bound_socket() .and_then(|socket| socket.local_endpoint()) } - - fn remote_endpoint(&self) -> Option { - if let Inner::Connecting { - remote_endpoint, .. - } = self - { - Some(*remote_endpoint) - } else { - None - } - } } impl InitStream { @@ -122,40 +85,38 @@ impl InitStream { } } - pub fn is_bound(&self) -> bool { - self.inner.read().is_bound() + pub fn new_bound(nonblocking: bool, bound_socket: Arc) -> Self { + let inner = Inner::Bound(AlwaysSome::new(bound_socket)); + Self { + is_nonblocking: AtomicBool::new(nonblocking), + inner: RwLock::new(inner), + } } pub fn bind(&self, endpoint: IpEndpoint) -> Result<()> { self.inner.write().bind(endpoint) } - pub fn connect(&self, remote_endpoint: &IpEndpoint) -> Result<()> { - if !self.is_bound() { + pub fn connect(&self, remote_endpoint: &IpEndpoint) -> Result { + if !self.inner.read().is_bound() { self.inner .write() .bind_to_ephemeral_endpoint(remote_endpoint)? } - self.inner.write().do_connect(*remote_endpoint)?; - // Wait until building connection - let poller = Poller::new(); - loop { - poll_ifaces(); - let events = self - .inner - .read() - .poll(IoEvents::OUT | IoEvents::IN, Some(&poller)); - if events.contains(IoEvents::IN) || events.contains(IoEvents::OUT) { - return Ok(()); - } else if !events.is_empty() { - return_errno_with_message!(Errno::ECONNREFUSED, "connect refused"); - } else if self.is_nonblocking() { - return_errno_with_message!(Errno::EAGAIN, "try connect again"); - } else { - // FIXME: deal with connecting timeout - poller.wait()?; - } - } + ConnectingStream::new( + self.is_nonblocking(), + self.inner.read().bound_socket().unwrap().clone(), + *remote_endpoint, + ) + } + + pub fn listen(&self, backlog: usize) -> Result { + let bound_socket = if let Some(bound_socket) = self.inner.read().bound_socket() { + bound_socket.clone() + } else { + return_errno_with_message!(Errno::EINVAL, "cannot listen without bound") + }; + ListenStream::new(self.is_nonblocking(), bound_socket, backlog) } pub fn local_endpoint(&self) -> Result { @@ -165,21 +126,10 @@ impl InitStream { .ok_or_else(|| Error::with_message(Errno::EINVAL, "does not has local endpoint")) } - pub fn remote_endpoint(&self) -> Result { - self.inner - .read() - .remote_endpoint() - .ok_or_else(|| Error::with_message(Errno::EINVAL, "does not has remote endpoint")) - } - - pub(super) fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents { + pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents { self.inner.read().poll(mask, poller) } - pub fn bound_socket(&self) -> Option> { - self.inner.read().bound_socket().map(Clone::clone) - } - pub fn is_nonblocking(&self) -> bool { self.is_nonblocking.load(Ordering::Relaxed) } diff --git a/services/libs/jinux-std/src/net/socket/ip/stream/mod.rs b/services/libs/jinux-std/src/net/socket/ip/stream/mod.rs index b34af7488..c364449ec 100644 --- a/services/libs/jinux-std/src/net/socket/ip/stream/mod.rs +++ b/services/libs/jinux-std/src/net/socket/ip/stream/mod.rs @@ -1,5 +1,6 @@ use crate::events::IoEvents; use crate::fs::{file_handle::FileLike, utils::StatusFlags}; +use crate::net::iface::IpEndpoint; use crate::net::socket::{ util::{ send_recv_flags::SendRecvFlags, shutdown_cmd::SockShutdownCmd, @@ -10,9 +11,13 @@ use crate::net::socket::{ use crate::prelude::*; use crate::process::signal::Poller; -use self::{connected::ConnectedStream, init::InitStream, listen::ListenStream}; +use self::{ + connected::ConnectedStream, connecting::ConnectingStream, init::InitStream, + listen::ListenStream, +}; mod connected; +mod connecting; mod init; mod listen; @@ -23,6 +28,8 @@ pub struct StreamSocket { enum State { // Start state Init(Arc), + // Intermediate state + Connecting(Arc), // Final State 1 Connected(Arc), // Final State 2 @@ -40,6 +47,7 @@ impl StreamSocket { fn is_nonblocking(&self) -> bool { match &*self.state.read() { State::Init(init) => init.is_nonblocking(), + State::Connecting(connecting) => connecting.is_nonblocking(), State::Connected(connected) => connected.is_nonblocking(), State::Listen(listen) => listen.is_nonblocking(), } @@ -48,10 +56,25 @@ impl StreamSocket { fn set_nonblocking(&self, nonblocking: bool) { match &*self.state.read() { State::Init(init) => init.set_nonblocking(nonblocking), + State::Connecting(connecting) => connecting.set_nonblocking(nonblocking), State::Connected(connected) => connected.set_nonblocking(nonblocking), State::Listen(listen) => listen.set_nonblocking(nonblocking), } } + + fn do_connect(&self, remote_endpoint: &IpEndpoint) -> Result> { + let mut state = self.state.write(); + let init_stream = match &*state { + State::Init(init_stream) => init_stream, + State::Listen(_) | State::Connecting(_) | State::Connected(_) => { + return_errno_with_message!(Errno::EINVAL, "cannot connect") + } + }; + + let connecting = Arc::new(init_stream.connect(remote_endpoint)?); + *state = State::Connecting(connecting.clone()); + Ok(connecting) + } } impl FileLike for StreamSocket { @@ -72,6 +95,7 @@ impl FileLike for StreamSocket { let state = self.state.read(); match &*state { State::Init(init) => init.poll(mask, poller), + State::Connecting(connecting) => connecting.poll(mask, poller), State::Connected(connected) => connected.poll(mask, poller), State::Listen(listen) => listen.poll(mask, poller), } @@ -112,44 +136,37 @@ impl Socket for StreamSocket { fn connect(&self, sockaddr: SocketAddr) -> Result<()> { let remote_endpoint = sockaddr.try_into()?; - let init_stream = match &*self.state.read() { - State::Init(init_stream) => init_stream.clone(), - _ => return_errno_with_message!(Errno::EINVAL, "cannot connect"), - }; - - init_stream.connect(&remote_endpoint)?; - - let connected_stream = { - let nonblocking = init_stream.is_nonblocking(); - let bound_socket = init_stream.bound_socket().unwrap(); - Arc::new(ConnectedStream::new( - nonblocking, - bound_socket, - remote_endpoint, - )) - }; - *self.state.write() = State::Connected(connected_stream); - Ok(()) + let connecting_stream = self.do_connect(&remote_endpoint)?; + match connecting_stream.wait_conn() { + Ok(connected_stream) => { + let connected_stream = Arc::new(connected_stream); + *self.state.write() = State::Connected(connected_stream); + Ok(()) + } + Err((err, init_stream)) => { + let init_stream = Arc::new(init_stream); + *self.state.write() = State::Init(init_stream); + Err(err) + } + } } fn listen(&self, backlog: usize) -> Result<()> { let mut state = self.state.write(); - match &*state { - State::Init(init_stream) => { - if !init_stream.is_bound() { - return_errno_with_message!(Errno::EINVAL, "cannot listen without bound"); - } - let nonblocking = init_stream.is_nonblocking(); - let bound_socket = init_stream.bound_socket().unwrap(); - let listener = Arc::new(ListenStream::new(nonblocking, bound_socket, backlog)?); - *state = State::Listen(listener); - Ok(()) + let init_stream = match &*state { + State::Init(init_stream) => init_stream, + State::Connecting(connecting_stream) => { + return_errno_with_message!(Errno::EINVAL, "cannot listen for a connecting stream") } State::Listen(listen_stream) => { return_errno_with_message!(Errno::EINVAL, "cannot listen for a listening stream") } - _ => return_errno_with_message!(Errno::EINVAL, "cannot listen"), - } + State::Connected(_) => return_errno_with_message!(Errno::EINVAL, "cannot listen"), + }; + + let listener = Arc::new(init_stream.listen(backlog)?); + *state = State::Listen(listener); + Ok(()) } fn accept(&self) -> Result<(Arc, SocketAddr)> { @@ -185,6 +202,7 @@ impl Socket for StreamSocket { let state = self.state.read(); let local_endpoint = match &*state { State::Init(init_stream) => init_stream.local_endpoint(), + State::Connecting(connecting_stream) => connecting_stream.local_endpoint(), State::Listen(listen_stream) => listen_stream.local_endpoint(), State::Connected(connected_stream) => connected_stream.local_endpoint(), }?; @@ -194,7 +212,10 @@ impl Socket for StreamSocket { fn peer_addr(&self) -> Result { let state = self.state.read(); let remote_endpoint = match &*state { - State::Init(init_stream) => init_stream.remote_endpoint(), + State::Init(init_stream) => { + return_errno_with_message!(Errno::EINVAL, "init socket does not have peer") + } + State::Connecting(connecting_stream) => connecting_stream.remote_endpoint(), State::Listen(listen_stream) => { return_errno_with_message!(Errno::EINVAL, "listening socket does not have peer") }