Support nonblocking ip sockets

This commit is contained in:
Jianfeng Jiang 2023-06-15 16:22:10 +08:00 committed by Tate, Hongliang Tian
parent 809e477bdf
commit 445fb8eb76
6 changed files with 47 additions and 36 deletions

View File

@ -128,11 +128,11 @@ impl Inner {
} }
impl DatagramSocket { impl DatagramSocket {
pub fn new() -> Self { pub fn new(nonblocking: bool) -> Self {
let udp_socket = AnyUnboundSocket::new_udp(); let udp_socket = AnyUnboundSocket::new_udp();
Self { Self {
inner: RwLock::new(Inner::Unbound(AlwaysSome::new(udp_socket))), inner: RwLock::new(Inner::Unbound(AlwaysSome::new(udp_socket))),
nonblocking: AtomicBool::new(false), nonblocking: AtomicBool::new(nonblocking),
} }
} }
@ -259,6 +259,9 @@ impl Socket for DatagramSocket {
} }
let events = self.inner.read().poll(IoEvents::IN, Some(&poller)); let events = self.inner.read().poll(IoEvents::IN, Some(&poller));
if !events.contains(IoEvents::IN) { if !events.contains(IoEvents::IN) {
if self.nonblocking() {
return_errno_with_message!(Errno::EAGAIN, "try to receive again");
}
poller.wait(); poller.wait();
} }
} }

View File

@ -19,12 +19,12 @@ pub struct ConnectedStream {
impl ConnectedStream { impl ConnectedStream {
pub fn new( pub fn new(
nonblocking: bool, is_nonblocking: bool,
bound_socket: Arc<AnyBoundSocket>, bound_socket: Arc<AnyBoundSocket>,
remote_endpoint: IpEndpoint, remote_endpoint: IpEndpoint,
) -> Self { ) -> Self {
Self { Self {
nonblocking: AtomicBool::new(nonblocking), nonblocking: AtomicBool::new(is_nonblocking),
bound_socket, bound_socket,
remote_endpoint, remote_endpoint,
} }
@ -54,6 +54,9 @@ impl ConnectedStream {
return_errno_with_message!(Errno::ENOTCONN, "recv packet fails"); return_errno_with_message!(Errno::ENOTCONN, "recv packet fails");
} }
if !events.contains(IoEvents::IN) { if !events.contains(IoEvents::IN) {
if self.is_nonblocking() {
return_errno_with_message!(Errno::EAGAIN, "try to recv again");
}
poller.wait(); poller.wait();
} }
} }
@ -101,11 +104,11 @@ impl ConnectedStream {
self.bound_socket.poll(mask, poller) self.bound_socket.poll(mask, poller)
} }
pub fn nonblocking(&self) -> bool { pub fn is_nonblocking(&self) -> bool {
self.nonblocking.load(Ordering::SeqCst) self.nonblocking.load(Ordering::Relaxed)
} }
pub fn set_nonblocking(&self, nonblocking: bool) { pub fn set_nonblocking(&self, nonblocking: bool) {
self.nonblocking.store(nonblocking, Ordering::SeqCst); self.nonblocking.store(nonblocking, Ordering::Relaxed);
} }
} }

View File

@ -11,8 +11,7 @@ use crate::prelude::*;
pub struct InitStream { pub struct InitStream {
inner: RwLock<Inner>, inner: RwLock<Inner>,
// TODO: deal with nonblocking is_nonblocking: AtomicBool,
nonblocking: AtomicBool,
} }
enum Inner { enum Inner {
@ -114,11 +113,11 @@ impl Inner {
} }
impl InitStream { impl InitStream {
pub fn new() -> Self { pub fn new(nonblocking: bool) -> Self {
let socket = AnyUnboundSocket::new_tcp(); let socket = AnyUnboundSocket::new_tcp();
let inner = Inner::Unbound(AlwaysSome::new(socket)); let inner = Inner::Unbound(AlwaysSome::new(socket));
Self { Self {
nonblocking: AtomicBool::new(false), is_nonblocking: AtomicBool::new(nonblocking),
inner: RwLock::new(inner), inner: RwLock::new(inner),
} }
} }
@ -149,7 +148,9 @@ impl InitStream {
if events.contains(IoEvents::IN) || events.contains(IoEvents::OUT) { if events.contains(IoEvents::IN) || events.contains(IoEvents::OUT) {
return Ok(()); return Ok(());
} else if !events.is_empty() { } else if !events.is_empty() {
return_errno_with_message!(Errno::ECONNREFUSED, "connect refused") return_errno_with_message!(Errno::ECONNREFUSED, "connect refused");
} else if self.is_nonblocking() {
return_errno_with_message!(Errno::EAGAIN, "try connect again");
} else { } else {
poller.wait(); poller.wait();
} }
@ -178,11 +179,11 @@ impl InitStream {
self.inner.read().bound_socket().map(Clone::clone) self.inner.read().bound_socket().map(Clone::clone)
} }
pub fn nonblocking(&self) -> bool { pub fn is_nonblocking(&self) -> bool {
self.nonblocking.load(Ordering::SeqCst) self.is_nonblocking.load(Ordering::Relaxed)
} }
pub fn set_nonblocking(&self, nonblocking: bool) { pub fn set_nonblocking(&self, nonblocking: bool) {
self.nonblocking.store(nonblocking, Ordering::SeqCst); self.is_nonblocking.store(nonblocking, Ordering::Relaxed);
} }
} }

View File

@ -9,7 +9,7 @@ use crate::{net::poll_ifaces, prelude::*};
use super::connected::ConnectedStream; use super::connected::ConnectedStream;
pub struct ListenStream { pub struct ListenStream {
nonblocking: AtomicBool, is_nonblocking: AtomicBool,
backlog: usize, backlog: usize,
/// Sockets also listening at LocalEndPoint when called `listen` /// Sockets also listening at LocalEndPoint when called `listen`
backlog_sockets: RwLock<Vec<BacklogSocket>>, backlog_sockets: RwLock<Vec<BacklogSocket>>,
@ -24,7 +24,7 @@ impl ListenStream {
debug_assert!(backlog >= 1); debug_assert!(backlog >= 1);
let backlog_socket = BacklogSocket::new(&bound_socket)?; let backlog_socket = BacklogSocket::new(&bound_socket)?;
let listen_stream = Self { let listen_stream = Self {
nonblocking: AtomicBool::new(nonblocking), is_nonblocking: AtomicBool::new(nonblocking),
backlog, backlog,
backlog_sockets: RwLock::new(vec![backlog_socket]), backlog_sockets: RwLock::new(vec![backlog_socket]),
}; };
@ -42,6 +42,9 @@ impl ListenStream {
} else { } else {
let events = self.poll(IoEvents::IN | IoEvents::OUT, Some(&poller)); let events = self.poll(IoEvents::IN | IoEvents::OUT, Some(&poller));
if !events.contains(IoEvents::IN) && !events.contains(IoEvents::OUT) { if !events.contains(IoEvents::IN) && !events.contains(IoEvents::OUT) {
if self.is_nonblocking() {
return_errno_with_message!(Errno::EAGAIN, "try accept again");
}
poller.wait(); poller.wait();
} }
continue; continue;
@ -51,8 +54,7 @@ impl ListenStream {
let BacklogSocket { let BacklogSocket {
bound_socket: backlog_socket, bound_socket: backlog_socket,
} = accepted_socket; } = accepted_socket;
let nonblocking = self.nonblocking(); ConnectedStream::new(false, backlog_socket, remote_endpoint)
ConnectedStream::new(nonblocking, backlog_socket, remote_endpoint)
}; };
return Ok((connected_stream, remote_endpoint)); return Ok((connected_stream, remote_endpoint));
} }
@ -110,12 +112,12 @@ impl ListenStream {
self.backlog_sockets.read()[0].bound_socket.clone() self.backlog_sockets.read()[0].bound_socket.clone()
} }
pub fn nonblocking(&self) -> bool { pub fn is_nonblocking(&self) -> bool {
self.nonblocking.load(Ordering::SeqCst) self.is_nonblocking.load(Ordering::Relaxed)
} }
pub fn set_nonblocking(&self, nonblocking: bool) { pub fn set_nonblocking(&self, nonblocking: bool) {
self.nonblocking.store(nonblocking, Ordering::SeqCst); self.is_nonblocking.store(nonblocking, Ordering::Relaxed);
} }
} }

View File

@ -31,18 +31,18 @@ enum State {
} }
impl StreamSocket { impl StreamSocket {
pub fn new() -> Self { pub fn new(nonblocking: bool) -> Self {
let state = State::Init(Arc::new(InitStream::new())); let state = State::Init(Arc::new(InitStream::new(nonblocking)));
Self { Self {
state: RwLock::new(state), state: RwLock::new(state),
} }
} }
fn nonblocking(&self) -> bool { fn is_nonblocking(&self) -> bool {
match &*self.state.read() { match &*self.state.read() {
State::Init(init) => init.nonblocking(), State::Init(init) => init.is_nonblocking(),
State::Connected(connected) => connected.nonblocking(), State::Connected(connected) => connected.is_nonblocking(),
State::Listen(listen) => listen.nonblocking(), State::Listen(listen) => listen.is_nonblocking(),
} }
} }
@ -79,7 +79,7 @@ impl FileLike for StreamSocket {
} }
fn status_flags(&self) -> StatusFlags { fn status_flags(&self) -> StatusFlags {
if self.nonblocking() { if self.is_nonblocking() {
StatusFlags::O_NONBLOCK StatusFlags::O_NONBLOCK
} else { } else {
StatusFlags::empty() StatusFlags::empty()
@ -112,19 +112,20 @@ impl Socket for StreamSocket {
fn connect(&self, sockaddr: SocketAddr) -> Result<()> { fn connect(&self, sockaddr: SocketAddr) -> Result<()> {
let remote_endpoint = sockaddr.try_into()?; let remote_endpoint = sockaddr.try_into()?;
let mut state = self.state.write(); let state = self.state.read();
// FIXME: The rwlock is held when trying to connect, which may cause dead lock.
match &*state { match &*state {
State::Init(init_stream) => { State::Init(init_stream) => {
let init_stream = init_stream.clone();
drop(state);
init_stream.connect(&remote_endpoint)?; init_stream.connect(&remote_endpoint)?;
let nonblocking = init_stream.nonblocking(); let nonblocking = init_stream.is_nonblocking();
let bound_socket = init_stream.bound_socket().unwrap(); let bound_socket = init_stream.bound_socket().unwrap();
let connected_stream = Arc::new(ConnectedStream::new( let connected_stream = Arc::new(ConnectedStream::new(
nonblocking, nonblocking,
bound_socket, bound_socket,
remote_endpoint, remote_endpoint,
)); ));
*state = State::Connected(connected_stream); *self.state.write() = State::Connected(connected_stream);
Ok(()) Ok(())
} }
_ => return_errno_with_message!(Errno::EINVAL, "cannot connect"), _ => return_errno_with_message!(Errno::EINVAL, "cannot connect"),
@ -138,7 +139,7 @@ impl Socket for StreamSocket {
if !init_stream.is_bound() { if !init_stream.is_bound() {
return_errno_with_message!(Errno::EINVAL, "cannot listen without bound"); return_errno_with_message!(Errno::EINVAL, "cannot listen without bound");
} }
let nonblocking = init_stream.nonblocking(); let nonblocking = init_stream.is_nonblocking();
let bound_socket = init_stream.bound_socket().unwrap(); let bound_socket = init_stream.bound_socket().unwrap();
let listener = Arc::new(ListenStream::new(nonblocking, bound_socket, backlog)?); let listener = Arc::new(ListenStream::new(nonblocking, bound_socket, backlog)?);
*state = State::Listen(listener); *state = State::Listen(listener);

View File

@ -17,14 +17,15 @@ pub fn sys_socket(domain: i32, type_: i32, protocol: i32) -> Result<SyscallRetur
"domain = {:?}, sock_type = {:?}, sock_flags = {:?}, protocol = {:?}", "domain = {:?}, sock_type = {:?}, sock_flags = {:?}, protocol = {:?}",
domain, sock_type, sock_flags, protocol domain, sock_type, sock_flags, protocol
); );
let nonblocking = sock_flags.contains(SockFlags::SOCK_NONBLOCK);
let file_like = match (domain, sock_type, protocol) { let file_like = match (domain, sock_type, protocol) {
( (
SaFamily::AF_INET, SaFamily::AF_INET,
SockType::SOCK_STREAM, SockType::SOCK_STREAM,
Protocol::IPPROTO_IP | Protocol::IPPROTO_TCP, Protocol::IPPROTO_IP | Protocol::IPPROTO_TCP,
) => Arc::new(StreamSocket::new()) as Arc<dyn FileLike>, ) => Arc::new(StreamSocket::new(nonblocking)) as Arc<dyn FileLike>,
(SaFamily::AF_INET, SockType::SOCK_DGRAM, Protocol::IPPROTO_IP | Protocol::IPPROTO_UDP) => { (SaFamily::AF_INET, SockType::SOCK_DGRAM, Protocol::IPPROTO_IP | Protocol::IPPROTO_UDP) => {
Arc::new(DatagramSocket::new()) as Arc<dyn FileLike> Arc::new(DatagramSocket::new(nonblocking)) as Arc<dyn FileLike>
} }
_ => return_errno_with_message!(Errno::EAFNOSUPPORT, "unsupported domain"), _ => return_errno_with_message!(Errno::EAFNOSUPPORT, "unsupported domain"),
}; };