diff --git a/services/libs/jinux-std/src/net/socket/ip/datagram.rs b/services/libs/jinux-std/src/net/socket/ip/datagram.rs index 8382dd6e2..902715884 100644 --- a/services/libs/jinux-std/src/net/socket/ip/datagram.rs +++ b/services/libs/jinux-std/src/net/socket/ip/datagram.rs @@ -1,3 +1,6 @@ +use core::sync::atomic::{AtomicBool, Ordering}; + +use crate::fs::utils::StatusFlags; use crate::net::iface::IpEndpoint; use crate::{ @@ -20,6 +23,7 @@ use super::always_some::AlwaysSome; use super::common::{bind_socket, get_ephemeral_endpoint}; pub struct DatagramSocket { + nonblocking: AtomicBool, inner: RwLock, } @@ -128,6 +132,7 @@ impl DatagramSocket { let udp_socket = AnyUnboundSocket::new_udp(); Self { inner: RwLock::new(Inner::Unbound(AlwaysSome::new(udp_socket))), + nonblocking: AtomicBool::new(false), } } @@ -156,6 +161,14 @@ impl DatagramSocket { "udp should provide remote addr", )) } + + pub fn nonblocking(&self) -> bool { + self.nonblocking.load(Ordering::SeqCst) + } + + pub fn set_nonblocking(&self, nonblocking: bool) { + self.nonblocking.store(nonblocking, Ordering::SeqCst); + } } impl FileLike for DatagramSocket { @@ -179,6 +192,23 @@ impl FileLike for DatagramSocket { fn as_socket(&self) -> Option<&dyn Socket> { Some(self) } + + fn status_flags(&self) -> StatusFlags { + if self.nonblocking() { + StatusFlags::O_NONBLOCK + } else { + StatusFlags::empty() + } + } + + fn set_status_flags(&self, new_flags: StatusFlags) -> Result<()> { + if new_flags.contains(StatusFlags::O_NONBLOCK) { + self.set_nonblocking(true); + } else { + self.set_nonblocking(false); + } + Ok(()) + } } impl Socket for DatagramSocket { diff --git a/services/libs/jinux-std/src/net/socket/ip/stream/connected.rs b/services/libs/jinux-std/src/net/socket/ip/stream/connected.rs index 5006bff9e..125b9e19e 100644 --- a/services/libs/jinux-std/src/net/socket/ip/stream/connected.rs +++ b/services/libs/jinux-std/src/net/socket/ip/stream/connected.rs @@ -1,3 +1,5 @@ +use core::sync::atomic::{AtomicBool, Ordering}; + use crate::net::iface::IpEndpoint; use crate::{ fs::utils::{IoEvents, Poller}, @@ -10,13 +12,19 @@ use crate::{ }; pub struct ConnectedStream { + nonblocking: AtomicBool, bound_socket: Arc, remote_endpoint: IpEndpoint, } impl ConnectedStream { - pub fn new(bound_socket: Arc, remote_endpoint: IpEndpoint) -> Self { + pub fn new( + nonblocking: bool, + bound_socket: Arc, + remote_endpoint: IpEndpoint, + ) -> Self { Self { + nonblocking: AtomicBool::new(nonblocking), bound_socket, remote_endpoint, } @@ -92,4 +100,12 @@ impl ConnectedStream { pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents { self.bound_socket.poll(mask, poller) } + + pub fn nonblocking(&self) -> bool { + self.nonblocking.load(Ordering::SeqCst) + } + + pub fn set_nonblocking(&self, nonblocking: bool) { + self.nonblocking.store(nonblocking, Ordering::SeqCst); + } } diff --git a/services/libs/jinux-std/src/net/socket/ip/stream/init.rs b/services/libs/jinux-std/src/net/socket/ip/stream/init.rs index bc4511c39..279f25347 100644 --- a/services/libs/jinux-std/src/net/socket/ip/stream/init.rs +++ b/services/libs/jinux-std/src/net/socket/ip/stream/init.rs @@ -1,3 +1,5 @@ +use core::sync::atomic::{AtomicBool, Ordering}; + use crate::fs::utils::{IoEvents, Poller}; use crate::net::iface::Iface; use crate::net::iface::IpEndpoint; @@ -9,6 +11,8 @@ use crate::prelude::*; pub struct InitStream { inner: RwLock, + // TODO: deal with nonblocking + nonblocking: AtomicBool, } enum Inner { @@ -114,6 +118,7 @@ impl InitStream { let socket = AnyUnboundSocket::new_tcp(); let inner = Inner::Unbound(AlwaysSome::new(socket)); Self { + nonblocking: AtomicBool::new(false), inner: RwLock::new(inner), } } @@ -172,4 +177,12 @@ impl InitStream { pub fn bound_socket(&self) -> Option> { self.inner.read().bound_socket().map(Clone::clone) } + + pub fn nonblocking(&self) -> bool { + self.nonblocking.load(Ordering::SeqCst) + } + + pub fn set_nonblocking(&self, nonblocking: bool) { + self.nonblocking.store(nonblocking, Ordering::SeqCst); + } } diff --git a/services/libs/jinux-std/src/net/socket/ip/stream/listen.rs b/services/libs/jinux-std/src/net/socket/ip/stream/listen.rs index 550fd4047..662f7878b 100644 --- a/services/libs/jinux-std/src/net/socket/ip/stream/listen.rs +++ b/services/libs/jinux-std/src/net/socket/ip/stream/listen.rs @@ -1,3 +1,5 @@ +use core::sync::atomic::{AtomicBool, Ordering}; + use crate::net::iface::{AnyUnboundSocket, BindPortConfig, IpEndpoint}; use crate::fs::utils::{IoEvents, Poller}; @@ -7,16 +9,22 @@ use crate::{net::poll_ifaces, prelude::*}; use super::connected::ConnectedStream; pub struct ListenStream { + nonblocking: AtomicBool, backlog: usize, /// Sockets also listening at LocalEndPoint when called `listen` backlog_sockets: RwLock>, } impl ListenStream { - pub fn new(bound_socket: Arc, backlog: usize) -> Result { + pub fn new( + nonblocking: bool, + bound_socket: Arc, + backlog: usize, + ) -> Result { debug_assert!(backlog >= 1); let backlog_socket = BacklogSocket::new(&bound_socket)?; let listen_stream = Self { + nonblocking: AtomicBool::new(nonblocking), backlog, backlog_sockets: RwLock::new(vec![backlog_socket]), }; @@ -43,7 +51,8 @@ impl ListenStream { let BacklogSocket { bound_socket: backlog_socket, } = accepted_socket; - ConnectedStream::new(backlog_socket, remote_endpoint) + let nonblocking = self.nonblocking(); + ConnectedStream::new(nonblocking, backlog_socket, remote_endpoint) }; return Ok((connected_stream, remote_endpoint)); } @@ -100,6 +109,14 @@ impl ListenStream { fn bound_socket(&self) -> Arc { self.backlog_sockets.read()[0].bound_socket.clone() } + + pub fn nonblocking(&self) -> bool { + self.nonblocking.load(Ordering::SeqCst) + } + + pub fn set_nonblocking(&self, nonblocking: bool) { + self.nonblocking.store(nonblocking, Ordering::SeqCst); + } } struct BacklogSocket { diff --git a/services/libs/jinux-std/src/net/socket/ip/stream/mod.rs b/services/libs/jinux-std/src/net/socket/ip/stream/mod.rs index 8b714eb9f..ac659c013 100644 --- a/services/libs/jinux-std/src/net/socket/ip/stream/mod.rs +++ b/services/libs/jinux-std/src/net/socket/ip/stream/mod.rs @@ -1,6 +1,6 @@ use crate::fs::{ file_handle::FileLike, - utils::{IoEvents, Poller}, + utils::{IoEvents, Poller, StatusFlags}, }; use crate::net::socket::{ util::{ @@ -37,6 +37,22 @@ impl StreamSocket { state: RwLock::new(state), } } + + fn nonblocking(&self) -> bool { + match &*self.state.read() { + State::Init(init) => init.nonblocking(), + State::Connected(connected) => connected.nonblocking(), + State::Listen(listen) => listen.nonblocking(), + } + } + + fn set_nonblocking(&self, nonblocking: bool) { + match &*self.state.read() { + State::Init(init) => init.set_nonblocking(nonblocking), + State::Connected(connected) => connected.set_nonblocking(nonblocking), + State::Listen(listen) => listen.set_nonblocking(nonblocking), + } + } } impl FileLike for StreamSocket { @@ -62,6 +78,23 @@ impl FileLike for StreamSocket { } } + fn status_flags(&self) -> StatusFlags { + if self.nonblocking() { + StatusFlags::O_NONBLOCK + } else { + StatusFlags::empty() + } + } + + fn set_status_flags(&self, new_flags: StatusFlags) -> Result<()> { + if new_flags.contains(StatusFlags::O_NONBLOCK) { + self.set_nonblocking(true); + } else { + self.set_nonblocking(false); + } + Ok(()) + } + fn as_socket(&self) -> Option<&dyn Socket> { Some(self) } @@ -84,9 +117,13 @@ impl Socket for StreamSocket { match &*state { State::Init(init_stream) => { init_stream.connect(&remote_endpoint)?; + let nonblocking = init_stream.nonblocking(); let bound_socket = init_stream.bound_socket().unwrap(); - let connected_stream = - Arc::new(ConnectedStream::new(bound_socket, remote_endpoint)); + let connected_stream = Arc::new(ConnectedStream::new( + nonblocking, + bound_socket, + remote_endpoint, + )); *state = State::Connected(connected_stream); Ok(()) } @@ -101,8 +138,9 @@ impl Socket for StreamSocket { if !init_stream.is_bound() { return_errno_with_message!(Errno::EINVAL, "cannot listen without bound"); } + let nonblocking = init_stream.nonblocking(); let bound_socket = init_stream.bound_socket().unwrap(); - let listener = Arc::new(ListenStream::new(bound_socket, backlog)?); + let listener = Arc::new(ListenStream::new(nonblocking, bound_socket, backlog)?); *state = State::Listen(listener); Ok(()) }