Poll interfaces at the right time

This commit is contained in:
Ruihan Li
2024-04-21 19:56:37 -07:00
committed by Tate, Hongliang Tian
parent dbc234ada6
commit 4e1d98a323
2 changed files with 66 additions and 33 deletions

View File

@ -125,9 +125,18 @@ impl DatagramSocket {
return_errno_with_message!(Errno::EAGAIN, "the socket is not bound"); return_errno_with_message!(Errno::EAGAIN, "the socket is not bound");
}; };
let (recv_bytes, remote_endpoint) = bound_datagram.try_recvfrom(buf, flags)?; let received =
bound_datagram.update_io_events(&self.pollee); bound_datagram
Ok((recv_bytes, remote_endpoint.into())) .try_recvfrom(buf, flags)
.map(|(recv_bytes, remote_endpoint)| {
bound_datagram.update_io_events(&self.pollee);
(recv_bytes, remote_endpoint.into())
});
drop(inner);
poll_ifaces();
received
} }
fn try_sendto(&self, buf: &[u8], remote: &IpEndpoint, flags: SendRecvFlags) -> Result<usize> { fn try_sendto(&self, buf: &[u8], remote: &IpEndpoint, flags: SendRecvFlags) -> Result<usize> {
@ -137,9 +146,17 @@ impl DatagramSocket {
return_errno_with_message!(Errno::EAGAIN, "the socket is not bound") return_errno_with_message!(Errno::EAGAIN, "the socket is not bound")
}; };
let sent_bytes = bound_datagram.try_sendto(buf, remote, flags)?; let sent_bytes = bound_datagram
bound_datagram.update_io_events(&self.pollee); .try_sendto(buf, remote, flags)
Ok(sent_bytes) .map(|sent_bytes| {
bound_datagram.update_io_events(&self.pollee);
sent_bytes
});
drop(inner);
poll_ifaces();
sent_bytes
} }
// TODO: Support timeout // TODO: Support timeout
@ -278,7 +295,6 @@ impl Socket for DatagramSocket {
fn recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, SocketAddr)> { fn recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, SocketAddr)> {
debug_assert!(flags.is_all_supported()); debug_assert!(flags.is_all_supported());
poll_ifaces();
if self.is_nonblocking() { if self.is_nonblocking() {
self.try_recvfrom(buf, flags) self.try_recvfrom(buf, flags)
} else { } else {
@ -309,9 +325,7 @@ impl Socket for DatagramSocket {
}; };
// TODO: Block if the send buffer is full // TODO: Block if the send buffer is full
let sent_bytes = self.try_sendto(buf, &remote_endpoint, flags)?; self.try_sendto(buf, &remote_endpoint, flags)
poll_ifaces();
Ok(sent_bytes)
} }
} }

View File

@ -120,7 +120,7 @@ impl StreamSocket {
let is_nonblocking = self.is_nonblocking(); let is_nonblocking = self.is_nonblocking();
let mut state = self.state.write(); let mut state = self.state.write();
state.borrow_result(|mut owned_state| { let result_or_block = state.borrow_result(|mut owned_state| {
let init_stream = match owned_state { let init_stream = match owned_state {
State::Init(init_stream) => init_stream, State::Init(init_stream) => init_stream,
State::Connecting(_) if is_nonblocking => { State::Connecting(_) if is_nonblocking => {
@ -169,7 +169,12 @@ impl StreamSocket {
None None
}, },
) )
}) });
drop(state);
poll_ifaces();
result_or_block
} }
fn finish_connect(&self) -> Result<()> { fn finish_connect(&self) -> Result<()> {
@ -228,12 +233,18 @@ impl StreamSocket {
return_errno_with_message!(Errno::EINVAL, "the socket is not listening"); return_errno_with_message!(Errno::EINVAL, "the socket is not listening");
}; };
let connected_stream = listen_stream.try_accept()?; let accepted = listen_stream.try_accept().map(|connected_stream| {
listen_stream.update_io_events(&self.pollee); listen_stream.update_io_events(&self.pollee);
let remote_endpoint = connected_stream.remote_endpoint(); let remote_endpoint = connected_stream.remote_endpoint();
let accepted_socket = Self::new_connected(connected_stream); let accepted_socket = Self::new_connected(connected_stream);
Ok((accepted_socket, remote_endpoint.into())) (accepted_socket as _, remote_endpoint.into())
});
drop(state);
poll_ifaces();
accepted
} }
fn try_recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, SocketAddr)> { fn try_recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, SocketAddr)> {
@ -249,9 +260,17 @@ impl StreamSocket {
} }
}; };
let recv_bytes = connected_stream.try_recvfrom(buf, flags)?; let received = connected_stream.try_recvfrom(buf, flags).map(|recv_bytes| {
connected_stream.update_io_events(&self.pollee); connected_stream.update_io_events(&self.pollee);
Ok((recv_bytes, connected_stream.remote_endpoint().into()))
let remote_endpoint = connected_stream.remote_endpoint();
(recv_bytes, remote_endpoint.into())
});
drop(state);
poll_ifaces();
received
} }
fn try_sendto(&self, buf: &[u8], flags: SendRecvFlags) -> Result<usize> { fn try_sendto(&self, buf: &[u8], flags: SendRecvFlags) -> Result<usize> {
@ -270,9 +289,15 @@ impl StreamSocket {
} }
}; };
let sent_bytes = connected_stream.try_sendto(buf, flags)?; let sent_bytes = connected_stream.try_sendto(buf, flags).map(|sent_bytes| {
connected_stream.update_io_events(&self.pollee); connected_stream.update_io_events(&self.pollee);
Ok(sent_bytes) sent_bytes
});
drop(state);
poll_ifaces();
sent_bytes
} }
// TODO: Support timeout // TODO: Support timeout
@ -403,10 +428,8 @@ impl Socket for StreamSocket {
let remote_endpoint = socket_addr.try_into()?; let remote_endpoint = socket_addr.try_into()?;
if let Some(result) = self.start_connect(&remote_endpoint) { if let Some(result) = self.start_connect(&remote_endpoint) {
poll_ifaces();
return result; return result;
} }
poll_ifaces();
self.wait_events(IoEvents::OUT, || self.check_connect()) self.wait_events(IoEvents::OUT, || self.check_connect())
} }
@ -444,7 +467,6 @@ impl Socket for StreamSocket {
} }
fn accept(&self) -> Result<(Arc<dyn FileLike>, SocketAddr)> { fn accept(&self) -> Result<(Arc<dyn FileLike>, SocketAddr)> {
poll_ifaces();
if self.is_nonblocking() { if self.is_nonblocking() {
self.try_accept() self.try_accept()
} else { } else {
@ -489,7 +511,6 @@ impl Socket for StreamSocket {
fn recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, SocketAddr)> { fn recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, SocketAddr)> {
debug_assert!(flags.is_all_supported()); debug_assert!(flags.is_all_supported());
poll_ifaces();
if self.is_nonblocking() { if self.is_nonblocking() {
self.try_recvfrom(buf, flags) self.try_recvfrom(buf, flags)
} else { } else {
@ -509,13 +530,11 @@ impl Socket for StreamSocket {
// address is specified for a connection-mode socket. In practice, the destination address // address is specified for a connection-mode socket. In practice, the destination address
// is simply ignored. We follow the same behavior as the Linux implementation to ignore it. // is simply ignored. We follow the same behavior as the Linux implementation to ignore it.
let sent_bytes = if self.is_nonblocking() { if self.is_nonblocking() {
self.try_sendto(buf, flags)? self.try_sendto(buf, flags)
} else { } else {
self.wait_events(IoEvents::OUT, || self.try_sendto(buf, flags))? self.wait_events(IoEvents::OUT, || self.try_sendto(buf, flags))
}; }
poll_ifaces();
Ok(sent_bytes)
} }
fn get_option(&self, option: &mut dyn SocketOption) -> Result<()> { fn get_option(&self, option: &mut dyn SocketOption) -> Result<()> {