diff --git a/services/libs/jinux-std/src/net/socket/ip/stream/connecting.rs b/services/libs/jinux-std/src/net/socket/ip/stream/connecting.rs new file mode 100644 index 000000000..fc98515a4 --- /dev/null +++ b/services/libs/jinux-std/src/net/socket/ip/stream/connecting.rs @@ -0,0 +1,83 @@ +use core::sync::atomic::{AtomicBool, Ordering}; + +use alloc::sync::Arc; + +use crate::events::IoEvents; +use crate::net::poll_ifaces; +use crate::prelude::*; + +use crate::net::iface::{AnyBoundSocket, IpEndpoint}; +use crate::process::signal::Poller; + +use super::connected::ConnectedStream; +use super::init::InitStream; + +pub struct ConnectingStream { + nonblocking: AtomicBool, + bound_socket: Arc, + remote_endpoint: IpEndpoint, +} + +impl ConnectingStream { + pub fn new( + nonblocking: bool, + bound_socket: Arc, + remote_endpoint: IpEndpoint, + ) -> Result { + bound_socket.do_connect(remote_endpoint)?; + + Ok(Self { + nonblocking: AtomicBool::new(nonblocking), + bound_socket, + remote_endpoint, + }) + } + + pub fn wait_conn(&self) -> core::result::Result { + debug_assert!(!self.is_nonblocking()); + + let poller = Poller::new(); + loop { + poll_ifaces(); + + let events = self.poll(IoEvents::OUT | IoEvents::IN, Some(&poller)); + if events.contains(IoEvents::IN) || events.contains(IoEvents::OUT) { + return Ok(ConnectedStream::new( + self.is_nonblocking(), + self.bound_socket.clone(), + self.remote_endpoint, + )); + } else if !events.is_empty() { + return Err(( + Error::with_message(Errno::ECONNREFUSED, "connection refused"), + InitStream::new_bound(self.is_nonblocking(), self.bound_socket.clone()), + )); + } else { + // FIXME: deal with nonblocking mode & connecting timeout + poller.wait().expect("async connect() not implemented"); + } + } + } + + pub fn local_endpoint(&self) -> Result { + self.bound_socket + .local_endpoint() + .ok_or_else(|| Error::with_message(Errno::EINVAL, "no local endpoint")) + } + + pub fn remote_endpoint(&self) -> Result { + Ok(self.remote_endpoint) + } + + pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents { + self.bound_socket.poll(mask, poller) + } + + pub fn is_nonblocking(&self) -> bool { + self.nonblocking.load(Ordering::Relaxed) + } + + pub fn set_nonblocking(&self, nonblocking: bool) { + self.nonblocking.store(nonblocking, Ordering::Relaxed); + } +} diff --git a/services/libs/jinux-std/src/net/socket/ip/stream/init.rs b/services/libs/jinux-std/src/net/socket/ip/stream/init.rs index 2da70549b..d82c16982 100644 --- a/services/libs/jinux-std/src/net/socket/ip/stream/init.rs +++ b/services/libs/jinux-std/src/net/socket/ip/stream/init.rs @@ -4,12 +4,14 @@ use crate::events::IoEvents; use crate::net::iface::Iface; use crate::net::iface::IpEndpoint; use crate::net::iface::{AnyBoundSocket, AnyUnboundSocket}; -use crate::net::poll_ifaces; use crate::net::socket::ip::always_some::AlwaysSome; use crate::net::socket::ip::common::{bind_socket, get_ephemeral_endpoint}; use crate::prelude::*; use crate::process::signal::Poller; +use super::connecting::ConnectingStream; +use super::listen::ListenStream; + pub struct InitStream { inner: RwLock, is_nonblocking: AtomicBool, @@ -18,17 +20,13 @@ pub struct InitStream { enum Inner { Unbound(AlwaysSome>), Bound(AlwaysSome>), - Connecting { - bound_socket: Arc, - remote_endpoint: IpEndpoint, - }, } impl Inner { fn is_bound(&self) -> bool { match self { Self::Unbound(_) => false, - Self::Bound(..) | Self::Connecting { .. } => true, + Self::Bound(_) => true, } } @@ -50,39 +48,16 @@ impl Inner { self.bind(endpoint) } - fn do_connect(&mut self, new_remote_endpoint: IpEndpoint) -> Result<()> { - match self { - Inner::Unbound(_) => return_errno_with_message!(Errno::EINVAL, "the socket is invalid"), - Inner::Connecting { - bound_socket, - remote_endpoint, - } => { - *remote_endpoint = new_remote_endpoint; - bound_socket.do_connect(new_remote_endpoint)?; - } - Inner::Bound(bound_socket) => { - bound_socket.do_connect(new_remote_endpoint)?; - *self = Inner::Connecting { - bound_socket: bound_socket.take(), - remote_endpoint: new_remote_endpoint, - }; - } - } - Ok(()) - } - fn bound_socket(&self) -> Option<&Arc> { match self { Inner::Bound(bound_socket) => Some(bound_socket), - Inner::Connecting { bound_socket, .. } => Some(bound_socket), - _ => None, + Inner::Unbound(_) => None, } } fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents { match self { Inner::Bound(bound_socket) => bound_socket.poll(mask, poller), - Inner::Connecting { bound_socket, .. } => bound_socket.poll(mask, poller), Inner::Unbound(unbound_socket) => unbound_socket.poll(mask, poller), } } @@ -90,8 +65,7 @@ impl Inner { fn iface(&self) -> Option> { match self { Inner::Bound(bound_socket) => Some(bound_socket.iface().clone()), - Inner::Connecting { bound_socket, .. } => Some(bound_socket.iface().clone()), - _ => None, + Inner::Unbound(_) => None, } } @@ -99,17 +73,6 @@ impl Inner { self.bound_socket() .and_then(|socket| socket.local_endpoint()) } - - fn remote_endpoint(&self) -> Option { - if let Inner::Connecting { - remote_endpoint, .. - } = self - { - Some(*remote_endpoint) - } else { - None - } - } } impl InitStream { @@ -122,40 +85,38 @@ impl InitStream { } } - pub fn is_bound(&self) -> bool { - self.inner.read().is_bound() + pub fn new_bound(nonblocking: bool, bound_socket: Arc) -> Self { + let inner = Inner::Bound(AlwaysSome::new(bound_socket)); + Self { + is_nonblocking: AtomicBool::new(nonblocking), + inner: RwLock::new(inner), + } } pub fn bind(&self, endpoint: IpEndpoint) -> Result<()> { self.inner.write().bind(endpoint) } - pub fn connect(&self, remote_endpoint: &IpEndpoint) -> Result<()> { - if !self.is_bound() { + pub fn connect(&self, remote_endpoint: &IpEndpoint) -> Result { + if !self.inner.read().is_bound() { self.inner .write() .bind_to_ephemeral_endpoint(remote_endpoint)? } - self.inner.write().do_connect(*remote_endpoint)?; - // Wait until building connection - let poller = Poller::new(); - loop { - poll_ifaces(); - let events = self - .inner - .read() - .poll(IoEvents::OUT | IoEvents::IN, Some(&poller)); - if events.contains(IoEvents::IN) || events.contains(IoEvents::OUT) { - return Ok(()); - } else if !events.is_empty() { - return_errno_with_message!(Errno::ECONNREFUSED, "connect refused"); - } else if self.is_nonblocking() { - return_errno_with_message!(Errno::EAGAIN, "try connect again"); - } else { - // FIXME: deal with connecting timeout - poller.wait()?; - } - } + ConnectingStream::new( + self.is_nonblocking(), + self.inner.read().bound_socket().unwrap().clone(), + *remote_endpoint, + ) + } + + pub fn listen(&self, backlog: usize) -> Result { + let bound_socket = if let Some(bound_socket) = self.inner.read().bound_socket() { + bound_socket.clone() + } else { + return_errno_with_message!(Errno::EINVAL, "cannot listen without bound") + }; + ListenStream::new(self.is_nonblocking(), bound_socket, backlog) } pub fn local_endpoint(&self) -> Result { @@ -165,21 +126,10 @@ impl InitStream { .ok_or_else(|| Error::with_message(Errno::EINVAL, "does not has local endpoint")) } - pub fn remote_endpoint(&self) -> Result { - self.inner - .read() - .remote_endpoint() - .ok_or_else(|| Error::with_message(Errno::EINVAL, "does not has remote endpoint")) - } - - pub(super) fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents { + pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents { self.inner.read().poll(mask, poller) } - pub fn bound_socket(&self) -> Option> { - self.inner.read().bound_socket().map(Clone::clone) - } - pub fn is_nonblocking(&self) -> bool { self.is_nonblocking.load(Ordering::Relaxed) } diff --git a/services/libs/jinux-std/src/net/socket/ip/stream/mod.rs b/services/libs/jinux-std/src/net/socket/ip/stream/mod.rs index b34af7488..c364449ec 100644 --- a/services/libs/jinux-std/src/net/socket/ip/stream/mod.rs +++ b/services/libs/jinux-std/src/net/socket/ip/stream/mod.rs @@ -1,5 +1,6 @@ use crate::events::IoEvents; use crate::fs::{file_handle::FileLike, utils::StatusFlags}; +use crate::net::iface::IpEndpoint; use crate::net::socket::{ util::{ send_recv_flags::SendRecvFlags, shutdown_cmd::SockShutdownCmd, @@ -10,9 +11,13 @@ use crate::net::socket::{ use crate::prelude::*; use crate::process::signal::Poller; -use self::{connected::ConnectedStream, init::InitStream, listen::ListenStream}; +use self::{ + connected::ConnectedStream, connecting::ConnectingStream, init::InitStream, + listen::ListenStream, +}; mod connected; +mod connecting; mod init; mod listen; @@ -23,6 +28,8 @@ pub struct StreamSocket { enum State { // Start state Init(Arc), + // Intermediate state + Connecting(Arc), // Final State 1 Connected(Arc), // Final State 2 @@ -40,6 +47,7 @@ impl StreamSocket { fn is_nonblocking(&self) -> bool { match &*self.state.read() { State::Init(init) => init.is_nonblocking(), + State::Connecting(connecting) => connecting.is_nonblocking(), State::Connected(connected) => connected.is_nonblocking(), State::Listen(listen) => listen.is_nonblocking(), } @@ -48,10 +56,25 @@ impl StreamSocket { fn set_nonblocking(&self, nonblocking: bool) { match &*self.state.read() { State::Init(init) => init.set_nonblocking(nonblocking), + State::Connecting(connecting) => connecting.set_nonblocking(nonblocking), State::Connected(connected) => connected.set_nonblocking(nonblocking), State::Listen(listen) => listen.set_nonblocking(nonblocking), } } + + fn do_connect(&self, remote_endpoint: &IpEndpoint) -> Result> { + let mut state = self.state.write(); + let init_stream = match &*state { + State::Init(init_stream) => init_stream, + State::Listen(_) | State::Connecting(_) | State::Connected(_) => { + return_errno_with_message!(Errno::EINVAL, "cannot connect") + } + }; + + let connecting = Arc::new(init_stream.connect(remote_endpoint)?); + *state = State::Connecting(connecting.clone()); + Ok(connecting) + } } impl FileLike for StreamSocket { @@ -72,6 +95,7 @@ impl FileLike for StreamSocket { let state = self.state.read(); match &*state { State::Init(init) => init.poll(mask, poller), + State::Connecting(connecting) => connecting.poll(mask, poller), State::Connected(connected) => connected.poll(mask, poller), State::Listen(listen) => listen.poll(mask, poller), } @@ -112,44 +136,37 @@ impl Socket for StreamSocket { fn connect(&self, sockaddr: SocketAddr) -> Result<()> { let remote_endpoint = sockaddr.try_into()?; - let init_stream = match &*self.state.read() { - State::Init(init_stream) => init_stream.clone(), - _ => return_errno_with_message!(Errno::EINVAL, "cannot connect"), - }; - - init_stream.connect(&remote_endpoint)?; - - let connected_stream = { - let nonblocking = init_stream.is_nonblocking(); - let bound_socket = init_stream.bound_socket().unwrap(); - Arc::new(ConnectedStream::new( - nonblocking, - bound_socket, - remote_endpoint, - )) - }; - *self.state.write() = State::Connected(connected_stream); - Ok(()) + let connecting_stream = self.do_connect(&remote_endpoint)?; + match connecting_stream.wait_conn() { + Ok(connected_stream) => { + let connected_stream = Arc::new(connected_stream); + *self.state.write() = State::Connected(connected_stream); + Ok(()) + } + Err((err, init_stream)) => { + let init_stream = Arc::new(init_stream); + *self.state.write() = State::Init(init_stream); + Err(err) + } + } } fn listen(&self, backlog: usize) -> Result<()> { let mut state = self.state.write(); - match &*state { - State::Init(init_stream) => { - if !init_stream.is_bound() { - return_errno_with_message!(Errno::EINVAL, "cannot listen without bound"); - } - let nonblocking = init_stream.is_nonblocking(); - let bound_socket = init_stream.bound_socket().unwrap(); - let listener = Arc::new(ListenStream::new(nonblocking, bound_socket, backlog)?); - *state = State::Listen(listener); - Ok(()) + let init_stream = match &*state { + State::Init(init_stream) => init_stream, + State::Connecting(connecting_stream) => { + return_errno_with_message!(Errno::EINVAL, "cannot listen for a connecting stream") } State::Listen(listen_stream) => { return_errno_with_message!(Errno::EINVAL, "cannot listen for a listening stream") } - _ => return_errno_with_message!(Errno::EINVAL, "cannot listen"), - } + State::Connected(_) => return_errno_with_message!(Errno::EINVAL, "cannot listen"), + }; + + let listener = Arc::new(init_stream.listen(backlog)?); + *state = State::Listen(listener); + Ok(()) } fn accept(&self) -> Result<(Arc, SocketAddr)> { @@ -185,6 +202,7 @@ impl Socket for StreamSocket { let state = self.state.read(); let local_endpoint = match &*state { State::Init(init_stream) => init_stream.local_endpoint(), + State::Connecting(connecting_stream) => connecting_stream.local_endpoint(), State::Listen(listen_stream) => listen_stream.local_endpoint(), State::Connected(connected_stream) => connected_stream.local_endpoint(), }?; @@ -194,7 +212,10 @@ impl Socket for StreamSocket { fn peer_addr(&self) -> Result { let state = self.state.read(); let remote_endpoint = match &*state { - State::Init(init_stream) => init_stream.remote_endpoint(), + State::Init(init_stream) => { + return_errno_with_message!(Errno::EINVAL, "init socket does not have peer") + } + State::Connecting(connecting_stream) => connecting_stream.remote_endpoint(), State::Listen(listen_stream) => { return_errno_with_message!(Errno::EINVAL, "listening socket does not have peer") }