diff --git a/services/libs/jinux-std/src/net/socket/ip/datagram.rs b/services/libs/jinux-std/src/net/socket/ip/datagram.rs index 902715884..c790af97b 100644 --- a/services/libs/jinux-std/src/net/socket/ip/datagram.rs +++ b/services/libs/jinux-std/src/net/socket/ip/datagram.rs @@ -128,11 +128,11 @@ impl Inner { } impl DatagramSocket { - pub fn new() -> Self { + pub fn new(nonblocking: bool) -> Self { let udp_socket = AnyUnboundSocket::new_udp(); Self { inner: RwLock::new(Inner::Unbound(AlwaysSome::new(udp_socket))), - nonblocking: AtomicBool::new(false), + nonblocking: AtomicBool::new(nonblocking), } } @@ -259,6 +259,9 @@ impl Socket for DatagramSocket { } let events = self.inner.read().poll(IoEvents::IN, Some(&poller)); if !events.contains(IoEvents::IN) { + if self.nonblocking() { + return_errno_with_message!(Errno::EAGAIN, "try to receive again"); + } poller.wait(); } } diff --git a/services/libs/jinux-std/src/net/socket/ip/stream/connected.rs b/services/libs/jinux-std/src/net/socket/ip/stream/connected.rs index 125b9e19e..858796cbe 100644 --- a/services/libs/jinux-std/src/net/socket/ip/stream/connected.rs +++ b/services/libs/jinux-std/src/net/socket/ip/stream/connected.rs @@ -19,12 +19,12 @@ pub struct ConnectedStream { impl ConnectedStream { pub fn new( - nonblocking: bool, + is_nonblocking: bool, bound_socket: Arc, remote_endpoint: IpEndpoint, ) -> Self { Self { - nonblocking: AtomicBool::new(nonblocking), + nonblocking: AtomicBool::new(is_nonblocking), bound_socket, remote_endpoint, } @@ -54,6 +54,9 @@ impl ConnectedStream { return_errno_with_message!(Errno::ENOTCONN, "recv packet fails"); } if !events.contains(IoEvents::IN) { + if self.is_nonblocking() { + return_errno_with_message!(Errno::EAGAIN, "try to recv again"); + } poller.wait(); } } @@ -101,11 +104,11 @@ impl ConnectedStream { self.bound_socket.poll(mask, poller) } - pub fn nonblocking(&self) -> bool { - self.nonblocking.load(Ordering::SeqCst) + pub fn is_nonblocking(&self) -> bool { + self.nonblocking.load(Ordering::Relaxed) } pub fn set_nonblocking(&self, nonblocking: bool) { - self.nonblocking.store(nonblocking, Ordering::SeqCst); + self.nonblocking.store(nonblocking, Ordering::Relaxed); } } diff --git a/services/libs/jinux-std/src/net/socket/ip/stream/init.rs b/services/libs/jinux-std/src/net/socket/ip/stream/init.rs index 279f25347..92578dfa3 100644 --- a/services/libs/jinux-std/src/net/socket/ip/stream/init.rs +++ b/services/libs/jinux-std/src/net/socket/ip/stream/init.rs @@ -11,8 +11,7 @@ use crate::prelude::*; pub struct InitStream { inner: RwLock, - // TODO: deal with nonblocking - nonblocking: AtomicBool, + is_nonblocking: AtomicBool, } enum Inner { @@ -114,11 +113,11 @@ impl Inner { } impl InitStream { - pub fn new() -> Self { + pub fn new(nonblocking: bool) -> Self { let socket = AnyUnboundSocket::new_tcp(); let inner = Inner::Unbound(AlwaysSome::new(socket)); Self { - nonblocking: AtomicBool::new(false), + is_nonblocking: AtomicBool::new(nonblocking), inner: RwLock::new(inner), } } @@ -149,7 +148,9 @@ impl InitStream { if events.contains(IoEvents::IN) || events.contains(IoEvents::OUT) { return Ok(()); } else if !events.is_empty() { - return_errno_with_message!(Errno::ECONNREFUSED, "connect refused") + return_errno_with_message!(Errno::ECONNREFUSED, "connect refused"); + } else if self.is_nonblocking() { + return_errno_with_message!(Errno::EAGAIN, "try connect again"); } else { poller.wait(); } @@ -178,11 +179,11 @@ impl InitStream { self.inner.read().bound_socket().map(Clone::clone) } - pub fn nonblocking(&self) -> bool { - self.nonblocking.load(Ordering::SeqCst) + pub fn is_nonblocking(&self) -> bool { + self.is_nonblocking.load(Ordering::Relaxed) } pub fn set_nonblocking(&self, nonblocking: bool) { - self.nonblocking.store(nonblocking, Ordering::SeqCst); + self.is_nonblocking.store(nonblocking, Ordering::Relaxed); } } diff --git a/services/libs/jinux-std/src/net/socket/ip/stream/listen.rs b/services/libs/jinux-std/src/net/socket/ip/stream/listen.rs index 662f7878b..9bbf8ae19 100644 --- a/services/libs/jinux-std/src/net/socket/ip/stream/listen.rs +++ b/services/libs/jinux-std/src/net/socket/ip/stream/listen.rs @@ -9,7 +9,7 @@ use crate::{net::poll_ifaces, prelude::*}; use super::connected::ConnectedStream; pub struct ListenStream { - nonblocking: AtomicBool, + is_nonblocking: AtomicBool, backlog: usize, /// Sockets also listening at LocalEndPoint when called `listen` backlog_sockets: RwLock>, @@ -24,7 +24,7 @@ impl ListenStream { debug_assert!(backlog >= 1); let backlog_socket = BacklogSocket::new(&bound_socket)?; let listen_stream = Self { - nonblocking: AtomicBool::new(nonblocking), + is_nonblocking: AtomicBool::new(nonblocking), backlog, backlog_sockets: RwLock::new(vec![backlog_socket]), }; @@ -42,6 +42,9 @@ impl ListenStream { } else { let events = self.poll(IoEvents::IN | IoEvents::OUT, Some(&poller)); if !events.contains(IoEvents::IN) && !events.contains(IoEvents::OUT) { + if self.is_nonblocking() { + return_errno_with_message!(Errno::EAGAIN, "try accept again"); + } poller.wait(); } continue; @@ -51,8 +54,7 @@ impl ListenStream { let BacklogSocket { bound_socket: backlog_socket, } = accepted_socket; - let nonblocking = self.nonblocking(); - ConnectedStream::new(nonblocking, backlog_socket, remote_endpoint) + ConnectedStream::new(false, backlog_socket, remote_endpoint) }; return Ok((connected_stream, remote_endpoint)); } @@ -110,12 +112,12 @@ impl ListenStream { self.backlog_sockets.read()[0].bound_socket.clone() } - pub fn nonblocking(&self) -> bool { - self.nonblocking.load(Ordering::SeqCst) + pub fn is_nonblocking(&self) -> bool { + self.is_nonblocking.load(Ordering::Relaxed) } pub fn set_nonblocking(&self, nonblocking: bool) { - self.nonblocking.store(nonblocking, Ordering::SeqCst); + self.is_nonblocking.store(nonblocking, Ordering::Relaxed); } } diff --git a/services/libs/jinux-std/src/net/socket/ip/stream/mod.rs b/services/libs/jinux-std/src/net/socket/ip/stream/mod.rs index ac659c013..0ad89de04 100644 --- a/services/libs/jinux-std/src/net/socket/ip/stream/mod.rs +++ b/services/libs/jinux-std/src/net/socket/ip/stream/mod.rs @@ -31,18 +31,18 @@ enum State { } impl StreamSocket { - pub fn new() -> Self { - let state = State::Init(Arc::new(InitStream::new())); + pub fn new(nonblocking: bool) -> Self { + let state = State::Init(Arc::new(InitStream::new(nonblocking))); Self { state: RwLock::new(state), } } - fn nonblocking(&self) -> bool { + fn is_nonblocking(&self) -> bool { match &*self.state.read() { - State::Init(init) => init.nonblocking(), - State::Connected(connected) => connected.nonblocking(), - State::Listen(listen) => listen.nonblocking(), + State::Init(init) => init.is_nonblocking(), + State::Connected(connected) => connected.is_nonblocking(), + State::Listen(listen) => listen.is_nonblocking(), } } @@ -79,7 +79,7 @@ impl FileLike for StreamSocket { } fn status_flags(&self) -> StatusFlags { - if self.nonblocking() { + if self.is_nonblocking() { StatusFlags::O_NONBLOCK } else { StatusFlags::empty() @@ -112,19 +112,20 @@ impl Socket for StreamSocket { fn connect(&self, sockaddr: SocketAddr) -> Result<()> { let remote_endpoint = sockaddr.try_into()?; - let mut state = self.state.write(); - // FIXME: The rwlock is held when trying to connect, which may cause dead lock. + let state = self.state.read(); match &*state { State::Init(init_stream) => { + let init_stream = init_stream.clone(); + drop(state); init_stream.connect(&remote_endpoint)?; - let nonblocking = init_stream.nonblocking(); + let nonblocking = init_stream.is_nonblocking(); let bound_socket = init_stream.bound_socket().unwrap(); let connected_stream = Arc::new(ConnectedStream::new( nonblocking, bound_socket, remote_endpoint, )); - *state = State::Connected(connected_stream); + *self.state.write() = State::Connected(connected_stream); Ok(()) } _ => return_errno_with_message!(Errno::EINVAL, "cannot connect"), @@ -138,7 +139,7 @@ impl Socket for StreamSocket { if !init_stream.is_bound() { return_errno_with_message!(Errno::EINVAL, "cannot listen without bound"); } - let nonblocking = init_stream.nonblocking(); + let nonblocking = init_stream.is_nonblocking(); let bound_socket = init_stream.bound_socket().unwrap(); let listener = Arc::new(ListenStream::new(nonblocking, bound_socket, backlog)?); *state = State::Listen(listener); diff --git a/services/libs/jinux-std/src/syscall/socket.rs b/services/libs/jinux-std/src/syscall/socket.rs index 38b3b64ab..c70df3156 100644 --- a/services/libs/jinux-std/src/syscall/socket.rs +++ b/services/libs/jinux-std/src/syscall/socket.rs @@ -17,14 +17,15 @@ pub fn sys_socket(domain: i32, type_: i32, protocol: i32) -> Result Arc::new(StreamSocket::new()) as Arc, + ) => Arc::new(StreamSocket::new(nonblocking)) as Arc, (SaFamily::AF_INET, SockType::SOCK_DGRAM, Protocol::IPPROTO_IP | Protocol::IPPROTO_UDP) => { - Arc::new(DatagramSocket::new()) as Arc + Arc::new(DatagramSocket::new(nonblocking)) as Arc } _ => return_errno_with_message!(Errno::EAFNOSUPPORT, "unsupported domain"), };