From 4e1d98a32392b4be01c2f6b3e0ce29e592c1fbad Mon Sep 17 00:00:00 2001 From: Ruihan Li Date: Sun, 21 Apr 2024 19:56:37 -0700 Subject: [PATCH] Poll interfaces at the right time --- .../src/net/socket/ip/datagram/mod.rs | 34 +++++++--- .../aster-nix/src/net/socket/ip/stream/mod.rs | 65 ++++++++++++------- 2 files changed, 66 insertions(+), 33 deletions(-) diff --git a/kernel/aster-nix/src/net/socket/ip/datagram/mod.rs b/kernel/aster-nix/src/net/socket/ip/datagram/mod.rs index 0e9270d82..384440a7f 100644 --- a/kernel/aster-nix/src/net/socket/ip/datagram/mod.rs +++ b/kernel/aster-nix/src/net/socket/ip/datagram/mod.rs @@ -125,9 +125,18 @@ impl DatagramSocket { return_errno_with_message!(Errno::EAGAIN, "the socket is not bound"); }; - let (recv_bytes, remote_endpoint) = bound_datagram.try_recvfrom(buf, flags)?; - bound_datagram.update_io_events(&self.pollee); - Ok((recv_bytes, remote_endpoint.into())) + let received = + bound_datagram + .try_recvfrom(buf, flags) + .map(|(recv_bytes, remote_endpoint)| { + bound_datagram.update_io_events(&self.pollee); + (recv_bytes, remote_endpoint.into()) + }); + + drop(inner); + poll_ifaces(); + + received } fn try_sendto(&self, buf: &[u8], remote: &IpEndpoint, flags: SendRecvFlags) -> Result { @@ -137,9 +146,17 @@ impl DatagramSocket { return_errno_with_message!(Errno::EAGAIN, "the socket is not bound") }; - let sent_bytes = bound_datagram.try_sendto(buf, remote, flags)?; - bound_datagram.update_io_events(&self.pollee); - Ok(sent_bytes) + let sent_bytes = bound_datagram + .try_sendto(buf, remote, flags) + .map(|sent_bytes| { + bound_datagram.update_io_events(&self.pollee); + sent_bytes + }); + + drop(inner); + poll_ifaces(); + + sent_bytes } // TODO: Support timeout @@ -278,7 +295,6 @@ impl Socket for DatagramSocket { fn recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, SocketAddr)> { debug_assert!(flags.is_all_supported()); - poll_ifaces(); if self.is_nonblocking() { self.try_recvfrom(buf, flags) } else { @@ -309,9 +325,7 @@ impl Socket for DatagramSocket { }; // TODO: Block if the send buffer is full - let sent_bytes = self.try_sendto(buf, &remote_endpoint, flags)?; - poll_ifaces(); - Ok(sent_bytes) + self.try_sendto(buf, &remote_endpoint, flags) } } diff --git a/kernel/aster-nix/src/net/socket/ip/stream/mod.rs b/kernel/aster-nix/src/net/socket/ip/stream/mod.rs index 5787b17a9..7f8129fd3 100644 --- a/kernel/aster-nix/src/net/socket/ip/stream/mod.rs +++ b/kernel/aster-nix/src/net/socket/ip/stream/mod.rs @@ -120,7 +120,7 @@ impl StreamSocket { let is_nonblocking = self.is_nonblocking(); let mut state = self.state.write(); - state.borrow_result(|mut owned_state| { + let result_or_block = state.borrow_result(|mut owned_state| { let init_stream = match owned_state { State::Init(init_stream) => init_stream, State::Connecting(_) if is_nonblocking => { @@ -169,7 +169,12 @@ impl StreamSocket { None }, ) - }) + }); + + drop(state); + poll_ifaces(); + + result_or_block } fn finish_connect(&self) -> Result<()> { @@ -228,12 +233,18 @@ impl StreamSocket { return_errno_with_message!(Errno::EINVAL, "the socket is not listening"); }; - let connected_stream = listen_stream.try_accept()?; - listen_stream.update_io_events(&self.pollee); + let accepted = listen_stream.try_accept().map(|connected_stream| { + listen_stream.update_io_events(&self.pollee); - let remote_endpoint = connected_stream.remote_endpoint(); - let accepted_socket = Self::new_connected(connected_stream); - Ok((accepted_socket, remote_endpoint.into())) + let remote_endpoint = connected_stream.remote_endpoint(); + let accepted_socket = Self::new_connected(connected_stream); + (accepted_socket as _, remote_endpoint.into()) + }); + + drop(state); + poll_ifaces(); + + accepted } fn try_recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, SocketAddr)> { @@ -249,9 +260,17 @@ impl StreamSocket { } }; - let recv_bytes = connected_stream.try_recvfrom(buf, flags)?; - connected_stream.update_io_events(&self.pollee); - Ok((recv_bytes, connected_stream.remote_endpoint().into())) + let received = connected_stream.try_recvfrom(buf, flags).map(|recv_bytes| { + connected_stream.update_io_events(&self.pollee); + + let remote_endpoint = connected_stream.remote_endpoint(); + (recv_bytes, remote_endpoint.into()) + }); + + drop(state); + poll_ifaces(); + + received } fn try_sendto(&self, buf: &[u8], flags: SendRecvFlags) -> Result { @@ -270,9 +289,15 @@ impl StreamSocket { } }; - let sent_bytes = connected_stream.try_sendto(buf, flags)?; - connected_stream.update_io_events(&self.pollee); - Ok(sent_bytes) + let sent_bytes = connected_stream.try_sendto(buf, flags).map(|sent_bytes| { + connected_stream.update_io_events(&self.pollee); + sent_bytes + }); + + drop(state); + poll_ifaces(); + + sent_bytes } // TODO: Support timeout @@ -403,10 +428,8 @@ impl Socket for StreamSocket { let remote_endpoint = socket_addr.try_into()?; if let Some(result) = self.start_connect(&remote_endpoint) { - poll_ifaces(); return result; } - poll_ifaces(); self.wait_events(IoEvents::OUT, || self.check_connect()) } @@ -444,7 +467,6 @@ impl Socket for StreamSocket { } fn accept(&self) -> Result<(Arc, SocketAddr)> { - poll_ifaces(); if self.is_nonblocking() { self.try_accept() } else { @@ -489,7 +511,6 @@ impl Socket for StreamSocket { fn recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, SocketAddr)> { debug_assert!(flags.is_all_supported()); - poll_ifaces(); if self.is_nonblocking() { self.try_recvfrom(buf, flags) } else { @@ -509,13 +530,11 @@ impl Socket for StreamSocket { // address is specified for a connection-mode socket. In practice, the destination address // is simply ignored. We follow the same behavior as the Linux implementation to ignore it. - let sent_bytes = if self.is_nonblocking() { - self.try_sendto(buf, flags)? + if self.is_nonblocking() { + self.try_sendto(buf, flags) } else { - self.wait_events(IoEvents::OUT, || self.try_sendto(buf, flags))? - }; - poll_ifaces(); - Ok(sent_bytes) + self.wait_events(IoEvents::OUT, || self.try_sendto(buf, flags)) + } } fn get_option(&self, option: &mut dyn SocketOption) -> Result<()> {