diff --git a/kernel/libs/aster-bigtcp/src/socket/bound.rs b/kernel/libs/aster-bigtcp/src/socket/bound.rs index 9620ff02..1fe45fd6 100644 --- a/kernel/libs/aster-bigtcp/src/socket/bound.rs +++ b/kernel/libs/aster-bigtcp/src/socket/bound.rs @@ -256,8 +256,25 @@ pub enum ConnectState { Refused, } +#[derive(Debug, Clone, Copy)] +pub struct NeedIfacePoll(bool); + +impl NeedIfacePoll { + pub const FALSE: Self = Self(false); +} + +impl Deref for NeedIfacePoll { + type Target = bool; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + impl BoundTcpSocket { /// Connects to a remote endpoint. + /// + /// Polling the iface is _always_ required after this method succeeds. pub fn connect( &self, remote_endpoint: IpEndpoint, @@ -270,8 +287,7 @@ impl BoundTcpSocket { socket.connect(iface.context(), remote_endpoint, self.0.port)?; socket.has_connected = false; - self.0 - .update_next_poll_at_ms(socket.poll_at(iface.context())); + self.0.update_next_poll_at_ms(PollAt::Now); Ok(()) } @@ -290,6 +306,8 @@ impl BoundTcpSocket { } /// Listens at a specified endpoint. + /// + /// Polling the iface is _not_ required after this method succeeds. pub fn listen( &self, local_endpoint: IpEndpoint, @@ -299,30 +317,49 @@ impl BoundTcpSocket { socket.listen(local_endpoint) } - pub fn send(&self, f: F) -> Result + /// Sends some data. + /// + /// Polling the iface _may_ be required after this method succeeds. + pub fn send(&self, f: F) -> Result<(R, NeedIfacePoll), smoltcp::socket::tcp::SendError> where F: FnOnce(&mut [u8]) -> (usize, R), { + let common = self.iface().common(); + let mut iface = common.interface(); + let mut socket = self.0.socket.lock(); - let result = socket.send(f); - self.0.update_next_poll_at_ms(PollAt::Now); + let result = socket.send(f)?; + let need_poll = self + .0 + .update_next_poll_at_ms(socket.poll_at(iface.context())); - result + Ok((result, need_poll)) } - pub fn recv(&self, f: F) -> Result + /// Receives some data. + /// + /// Polling the iface _may_ be required after this method succeeds. + pub fn recv(&self, f: F) -> Result<(R, NeedIfacePoll), smoltcp::socket::tcp::RecvError> where F: FnOnce(&mut [u8]) -> (usize, R), { + let common = self.iface().common(); + let mut iface = common.interface(); + let mut socket = self.0.socket.lock(); - let result = socket.recv(f); - self.0.update_next_poll_at_ms(PollAt::Now); + let result = socket.recv(f)?; + let need_poll = self + .0 + .update_next_poll_at_ms(socket.poll_at(iface.context())); - result + Ok((result, need_poll)) } + /// Closes the connection. + /// + /// Polling the iface is _always_ required after this method succeeds. pub fn close(&self) { let mut socket = self.0.socket.lock(); @@ -345,12 +382,17 @@ impl BoundTcpSocket { impl BoundUdpSocket { /// Binds to a specified endpoint. + /// + /// Polling the iface is _not_ required after this method succeeds. pub fn bind(&self, local_endpoint: IpEndpoint) -> Result<(), smoltcp::socket::udp::BindError> { let mut socket = self.0.socket.lock(); socket.bind(local_endpoint) } + /// Sends some data. + /// + /// Polling the iface is _always_ required after this method succeeds. pub fn send( &self, size: usize, @@ -381,6 +423,9 @@ impl BoundUdpSocket { Ok(result) } + /// Receives some data. + /// + /// Polling the iface is _not_ required after this method succeeds. pub fn recv(&self, f: F) -> Result where F: FnOnce(&[u8], UdpMetadata) -> R, @@ -389,7 +434,6 @@ impl BoundUdpSocket { let (data, meta) = socket.recv()?; let result = f(data, meta); - self.0.update_next_poll_at_ms(PollAt::Now); Ok(result) } @@ -448,13 +492,25 @@ impl BoundSocketInner { /// The update is typically needed after new network or user events have been handled, so this /// method also marks that there may be new events, so that the event observer provided by /// [`BoundSocket::set_observer`] can be notified later. - fn update_next_poll_at_ms(&self, poll_at: PollAt) { + fn update_next_poll_at_ms(&self, poll_at: PollAt) -> NeedIfacePoll { match poll_at { - PollAt::Now => self.next_poll_at_ms.store(0, Ordering::Relaxed), - PollAt::Time(instant) => self - .next_poll_at_ms - .store(instant.total_millis() as u64, Ordering::Relaxed), - PollAt::Ingress => self.next_poll_at_ms.store(u64::MAX, Ordering::Relaxed), + PollAt::Now => { + self.next_poll_at_ms.store(0, Ordering::Relaxed); + NeedIfacePoll(true) + } + PollAt::Time(instant) => { + let old_total_millis = self.next_poll_at_ms.load(Ordering::Relaxed); + let new_total_millis = instant.total_millis() as u64; + + self.next_poll_at_ms + .store(new_total_millis, Ordering::Relaxed); + + NeedIfacePoll(new_total_millis < old_total_millis) + } + PollAt::Ingress => { + self.next_poll_at_ms.store(u64::MAX, Ordering::Relaxed); + NeedIfacePoll(false) + } } } } diff --git a/kernel/libs/aster-bigtcp/src/socket/mod.rs b/kernel/libs/aster-bigtcp/src/socket/mod.rs index 36feec33..4495afba 100644 --- a/kernel/libs/aster-bigtcp/src/socket/mod.rs +++ b/kernel/libs/aster-bigtcp/src/socket/mod.rs @@ -5,7 +5,7 @@ mod event; mod state; mod unbound; -pub use bound::{BoundTcpSocket, BoundUdpSocket, ConnectState}; +pub use bound::{BoundTcpSocket, BoundUdpSocket, ConnectState, NeedIfacePoll}; pub(crate) use bound::{BoundTcpSocketInner, BoundUdpSocketInner, TcpProcessResult}; pub use event::{SocketEventObserver, SocketEvents}; pub use state::TcpStateCheck; diff --git a/kernel/src/net/iface/init.rs b/kernel/src/net/iface/init.rs index e3bb2f98..49668fd2 100644 --- a/kernel/src/net/iface/init.rs +++ b/kernel/src/net/iface/init.rs @@ -6,7 +6,7 @@ use aster_bigtcp::device::WithDevice; use ostd::sync::LocalIrqDisabled; use spin::Once; -use super::{poll_ifaces, Iface}; +use super::{poll::poll_ifaces, Iface}; use crate::{ net::iface::ext::{IfaceEx, IfaceExt}, prelude::*, diff --git a/kernel/src/net/iface/mod.rs b/kernel/src/net/iface/mod.rs index d7a411d4..55438a2f 100644 --- a/kernel/src/net/iface/mod.rs +++ b/kernel/src/net/iface/mod.rs @@ -4,8 +4,9 @@ mod ext; mod init; mod poll; +pub use ext::IfaceEx; pub use init::{init, IFACES}; -pub use poll::{lazy_init, poll_ifaces}; +pub use poll::lazy_init; pub type Iface = dyn aster_bigtcp::iface::Iface; pub type BoundTcpSocket = aster_bigtcp::socket::BoundTcpSocket; diff --git a/kernel/src/net/iface/poll.rs b/kernel/src/net/iface/poll.rs index fb783f08..c2de2cb1 100644 --- a/kernel/src/net/iface/poll.rs +++ b/kernel/src/net/iface/poll.rs @@ -19,7 +19,7 @@ pub fn lazy_init() { } } -pub fn poll_ifaces() { +pub(super) fn poll_ifaces() { let ifaces = IFACES.get().unwrap(); for iface in ifaces.iter() { diff --git a/kernel/src/net/socket/ip/datagram/bound.rs b/kernel/src/net/socket/ip/datagram/bound.rs index 03b3d64d..8cb82b06 100644 --- a/kernel/src/net/socket/ip/datagram/bound.rs +++ b/kernel/src/net/socket/ip/datagram/bound.rs @@ -7,7 +7,10 @@ use aster_bigtcp::{ use crate::{ events::IoEvents, - net::{iface::BoundUdpSocket, socket::util::send_recv_flags::SendRecvFlags}, + net::{ + iface::{BoundUdpSocket, Iface}, + socket::util::send_recv_flags::SendRecvFlags, + }, prelude::*, util::{MultiRead, MultiWrite}, }; @@ -37,6 +40,10 @@ impl BoundDatagram { self.remote_endpoint = Some(*endpoint) } + pub fn iface(&self) -> &Arc { + self.bound_socket.iface() + } + pub fn try_recv( &self, writer: &mut dyn MultiWrite, diff --git a/kernel/src/net/socket/ip/datagram/mod.rs b/kernel/src/net/socket/ip/datagram/mod.rs index 86538d3d..7618fcdc 100644 --- a/kernel/src/net/socket/ip/datagram/mod.rs +++ b/kernel/src/net/socket/ip/datagram/mod.rs @@ -18,7 +18,7 @@ use crate::{ }, match_sock_option_mut, net::{ - iface::poll_ifaces, + iface::IfaceEx, socket::{ options::{Error as SocketError, SocketOption}, util::{ @@ -157,14 +157,9 @@ impl DatagramSocket { return_errno_with_message!(Errno::EAGAIN, "the socket is not bound"); }; - let received = bound_datagram + bound_datagram .try_recv(writer, flags) - .map(|(recv_bytes, remote_endpoint)| (recv_bytes, remote_endpoint.into())); - - drop(inner); - poll_ifaces(); - - received + .map(|(recv_bytes, remote_endpoint)| (recv_bytes, remote_endpoint.into())) } fn recv( @@ -191,12 +186,13 @@ impl DatagramSocket { return_errno_with_message!(Errno::EAGAIN, "the socket is not bound") }; - let sent_bytes = bound_datagram.try_send(reader, remote, flags); + let sent_bytes = bound_datagram.try_send(reader, remote, flags)?; + let iface_to_poll = bound_datagram.iface().clone(); drop(inner); - poll_ifaces(); + iface_to_poll.poll(); - sent_bytes + Ok(sent_bytes) } fn check_io_events(&self) -> IoEvents { diff --git a/kernel/src/net/socket/ip/stream/connected.rs b/kernel/src/net/socket/ip/stream/connected.rs index ca8a13c6..c6cd05fc 100644 --- a/kernel/src/net/socket/ip/stream/connected.rs +++ b/kernel/src/net/socket/ip/stream/connected.rs @@ -5,14 +5,14 @@ use core::sync::atomic::{AtomicBool, Ordering}; use aster_bigtcp::{ errors::tcp::{RecvError, SendError}, - socket::{SocketEventObserver, TcpStateCheck}, + socket::{NeedIfacePoll, SocketEventObserver, TcpStateCheck}, wire::IpEndpoint, }; use crate::{ events::IoEvents, net::{ - iface::BoundTcpSocket, + iface::{BoundTcpSocket, Iface}, socket::util::{send_recv_flags::SendRecvFlags, shutdown_cmd::SockShutdownCmd}, }, prelude::*, @@ -79,7 +79,11 @@ impl ConnectedStream { Ok(()) } - pub fn try_recv(&self, writer: &mut dyn MultiWrite, _flags: SendRecvFlags) -> Result { + pub fn try_recv( + &self, + writer: &mut dyn MultiWrite, + _flags: SendRecvFlags, + ) -> Result<(usize, NeedIfacePoll)> { let result = self.bound_socket.recv(|socket_buffer| { match writer.write(&mut VmReader::from(&*socket_buffer)) { Ok(len) => (len, Ok(len)), @@ -88,18 +92,30 @@ impl ConnectedStream { }); match result { - Ok(Ok(0)) if self.is_receiving_closed.load(Ordering::Relaxed) => Ok(0), - Ok(Ok(0)) => return_errno_with_message!(Errno::EAGAIN, "the receive buffer is empty"), - Ok(Ok(recv_bytes)) => Ok(recv_bytes), - Ok(Err(e)) => Err(e), - Err(RecvError::Finished) => Ok(0), + Ok((Ok(0), need_poll)) if self.is_receiving_closed.load(Ordering::Relaxed) => { + Ok((0, need_poll)) + } + Ok((Ok(0), need_poll)) => { + debug_assert!(!*need_poll); + return_errno_with_message!(Errno::EAGAIN, "the receive buffer is empty") + } + Ok((Ok(recv_bytes), need_poll)) => Ok((recv_bytes, need_poll)), + Ok((Err(e), need_poll)) => { + debug_assert!(!*need_poll); + Err(e) + } + Err(RecvError::Finished) => Ok((0, NeedIfacePoll::FALSE)), Err(RecvError::InvalidState) => { return_errno_with_message!(Errno::ECONNRESET, "the connection is reset") } } } - pub fn try_send(&self, reader: &mut dyn MultiRead, _flags: SendRecvFlags) -> Result { + pub fn try_send( + &self, + reader: &mut dyn MultiRead, + _flags: SendRecvFlags, + ) -> Result<(usize, NeedIfacePoll)> { let result = self.bound_socket.send(|socket_buffer| { match reader.read(&mut VmWriter::from(socket_buffer)) { Ok(len) => (len, Ok(len)), @@ -108,9 +124,15 @@ impl ConnectedStream { }); match result { - Ok(Ok(0)) => return_errno_with_message!(Errno::EAGAIN, "the send buffer is full"), - Ok(Ok(sent_bytes)) => Ok(sent_bytes), - Ok(Err(e)) => Err(e), + Ok((Ok(0), need_poll)) => { + debug_assert!(!*need_poll); + return_errno_with_message!(Errno::EAGAIN, "the send buffer is full") + } + Ok((Ok(sent_bytes), need_poll)) => Ok((sent_bytes, need_poll)), + Ok((Err(e), need_poll)) => { + debug_assert!(!*need_poll); + Err(e) + } Err(SendError::InvalidState) => { // FIXME: `EPIPE` is another possibility, which means that the socket is shut down // for writing. In that case, we should also trigger a `SIGPIPE` if `MSG_NOSIGNAL` @@ -128,6 +150,10 @@ impl ConnectedStream { self.remote_endpoint } + pub fn iface(&self) -> &Arc { + self.bound_socket.iface() + } + pub fn check_new(&mut self) -> Result<()> { if !self.is_new_connection { return_errno_with_message!(Errno::EISCONN, "the socket is already connected"); diff --git a/kernel/src/net/socket/ip/stream/connecting.rs b/kernel/src/net/socket/ip/stream/connecting.rs index 4448dabf..c6b1eba7 100644 --- a/kernel/src/net/socket/ip/stream/connecting.rs +++ b/kernel/src/net/socket/ip/stream/connecting.rs @@ -3,7 +3,11 @@ use aster_bigtcp::{socket::ConnectState, wire::IpEndpoint}; use super::{connected::ConnectedStream, init::InitStream}; -use crate::{events::IoEvents, net::iface::BoundTcpSocket, prelude::*}; +use crate::{ + events::IoEvents, + net::iface::{BoundTcpSocket, Iface}, + prelude::*, +}; pub struct ConnectingStream { bound_socket: BoundTcpSocket, @@ -72,6 +76,10 @@ impl ConnectingStream { self.remote_endpoint } + pub fn iface(&self) -> &Arc { + self.bound_socket.iface() + } + pub(super) fn check_io_events(&self) -> IoEvents { IoEvents::empty() } diff --git a/kernel/src/net/socket/ip/stream/listen.rs b/kernel/src/net/socket/ip/stream/listen.rs index a26f0e92..fba1163e 100644 --- a/kernel/src/net/socket/ip/stream/listen.rs +++ b/kernel/src/net/socket/ip/stream/listen.rs @@ -5,7 +5,11 @@ use aster_bigtcp::{ }; use super::connected::ConnectedStream; -use crate::{events::IoEvents, net::iface::BoundTcpSocket, prelude::*}; +use crate::{ + events::IoEvents, + net::iface::{BoundTcpSocket, Iface}, + prelude::*, +}; pub struct ListenStream { backlog: usize, @@ -80,6 +84,10 @@ impl ListenStream { self.bound_socket.local_endpoint().unwrap() } + pub fn iface(&self) -> &Arc { + self.bound_socket.iface() + } + pub(super) fn check_io_events(&self) -> IoEvents { let backlog_sockets = self.backlog_sockets.read(); diff --git a/kernel/src/net/socket/ip/stream/mod.rs b/kernel/src/net/socket/ip/stream/mod.rs index 9c14d3c9..3a317767 100644 --- a/kernel/src/net/socket/ip/stream/mod.rs +++ b/kernel/src/net/socket/ip/stream/mod.rs @@ -24,7 +24,7 @@ use crate::{ }, match_sock_option_mut, match_sock_option_ref, net::{ - iface::poll_ifaces, + iface::IfaceEx, socket::{ options::{Error as SocketError, SocketOption}, util::{ @@ -190,32 +190,36 @@ impl StreamSocket { let is_nonblocking = self.is_nonblocking(); let mut state = self.write_updated_state(); - let result_or_block = state.borrow_result(|mut owned_state| { + let (result_or_block, iface_to_poll) = state.borrow_result(|mut owned_state| { let init_stream = match owned_state { State::Init(init_stream) => init_stream, State::Connecting(_) if is_nonblocking => { return ( owned_state, - Some(Err(Error::with_message( - Errno::EALREADY, - "the socket is connecting", - ))), + ( + Some(Err(Error::with_message( + Errno::EALREADY, + "the socket is connecting", + ))), + None, + ), ); } - State::Connecting(_) => { - return (owned_state, None); - } + State::Connecting(_) => return (owned_state, (None, None)), State::Connected(ref mut connected_stream) => { let err = connected_stream.check_new(); - return (owned_state, Some(err)); + return (owned_state, (Some(err), None)); } State::Listen(_) => { return ( owned_state, - Some(Err(Error::with_message( - Errno::EISCONN, - "the socket is listening", - ))), + ( + Some(Err(Error::with_message( + Errno::EISCONN, + "the socket is listening", + ))), + None, + ), ); } }; @@ -223,25 +227,30 @@ impl StreamSocket { let connecting_stream = match init_stream.connect(remote_endpoint) { Ok(connecting_stream) => connecting_stream, Err((err, init_stream)) => { - return (State::Init(init_stream), Some(Err(err))); + return (State::Init(init_stream), (Some(Err(err)), None)); } }; + let result_or_block = if is_nonblocking { + Some(Err(Error::with_message( + Errno::EINPROGRESS, + "the socket is connecting", + ))) + } else { + None + }; + let iface_to_poll = connecting_stream.iface().clone(); + ( State::Connecting(connecting_stream), - if is_nonblocking { - Some(Err(Error::with_message( - Errno::EINPROGRESS, - "the socket is connecting", - ))) - } else { - None - }, + (result_or_block, Some(iface_to_poll)), ) }); drop(state); - poll_ifaces(); + if let Some(iface) = iface_to_poll { + iface.poll(); + } result_or_block } @@ -274,9 +283,10 @@ impl StreamSocket { let accepted_socket = Self::new_connected(connected_stream); (accepted_socket as _, remote_endpoint.into()) }); + let iface_to_poll = listen_stream.iface().clone(); drop(state); - poll_ifaces(); + iface_to_poll.poll(); accepted } @@ -298,15 +308,16 @@ impl StreamSocket { } }; - let received = connected_stream.try_recv(writer, flags).map(|recv_bytes| { - let remote_endpoint = connected_stream.remote_endpoint(); - (recv_bytes, remote_endpoint.into()) - }); + let (recv_bytes, need_poll) = connected_stream.try_recv(writer, flags)?; + let iface_to_poll = need_poll.then(|| connected_stream.iface().clone()); + let remote_endpoint = connected_stream.remote_endpoint(); drop(state); - poll_ifaces(); + if let Some(iface) = iface_to_poll { + iface.poll(); + } - received + Ok((recv_bytes, remote_endpoint.into())) } fn recv( @@ -337,12 +348,15 @@ impl StreamSocket { } }; - let sent_bytes = connected_stream.try_send(reader, flags); + let (sent_bytes, need_poll) = connected_stream.try_send(reader, flags)?; + let iface_to_poll = need_poll.then(|| connected_stream.iface().clone()); drop(state); - poll_ifaces(); + if let Some(iface) = iface_to_poll { + iface.poll(); + } - sent_bytes + Ok(sent_bytes) } fn send(&self, reader: &mut dyn MultiRead, flags: SendRecvFlags) -> Result { @@ -498,16 +512,20 @@ impl Socket for StreamSocket { fn shutdown(&self, cmd: SockShutdownCmd) -> Result<()> { let state = self.read_updated_state(); - let res = match state.as_ref() { - State::Connected(connected_stream) => connected_stream.shutdown(cmd, &self.pollee), + + let (result, iface_to_poll) = match state.as_ref() { + State::Connected(connected_stream) => ( + connected_stream.shutdown(cmd, &self.pollee), + connected_stream.iface().clone(), + ), // TODO: shutdown listening stream _ => return_errno_with_message!(Errno::EINVAL, "cannot shutdown"), }; drop(state); - poll_ifaces(); + iface_to_poll.poll(); - res + result } fn addr(&self) -> Result { @@ -692,8 +710,18 @@ impl SocketEventObserver for StreamSocket { impl Drop for StreamSocket { fn drop(&mut self) { - self.state.write().take(); + let state = self.state.write().take(); - poll_ifaces(); + let iface_to_poll = match state { + State::Init(_) => None, + State::Connecting(ref connecting_stream) => Some(connecting_stream.iface().clone()), + State::Connected(ref connected_stream) => Some(connected_stream.iface().clone()), + State::Listen(ref listen_stream) => Some(listen_stream.iface().clone()), + }; + + drop(state); + if let Some(iface) = iface_to_poll { + iface.poll(); + } } }