Notify socket events directly

This commit is contained in:
Ruihan Li
2025-03-18 20:55:17 +08:00
committed by Tate, Hongliang Tian
parent 2f66f5d234
commit d9f3a7761a
5 changed files with 12 additions and 65 deletions

View File

@ -202,23 +202,6 @@ impl<E: Ext> IfaceCommon<E> {
} }
} }
// Notify all socket events.
for socket in sockets.tcp_listener_iter() {
if socket.has_events() {
socket.on_events();
}
}
for socket in sockets.tcp_conn_iter() {
if socket.has_events() {
socket.on_events();
}
}
for socket in sockets.udp_socket_iter() {
if socket.has_events() {
socket.on_events();
}
}
// Note that only TCP connections can have timers set, so as far as the time to poll is // Note that only TCP connections can have timers set, so as far as the time to poll is
// concerned, we only need to consider TCP connections. // concerned, we only need to consider TCP connections.
interface.next_poll_at_ms() interface.next_poll_at_ms()

View File

@ -1,7 +1,6 @@
// SPDX-License-Identifier: MPL-2.0 // SPDX-License-Identifier: MPL-2.0
use alloc::sync::{Arc, Weak}; use alloc::sync::{Arc, Weak};
use core::sync::atomic::{AtomicU8, Ordering};
use smoltcp::wire::IpEndpoint; use smoltcp::wire::IpEndpoint;
use spin::once::Once; use spin::once::Once;
@ -45,7 +44,6 @@ pub struct SocketBg<T: Inner<E>, E: Ext> {
pub(super) bound: BoundPort<E>, pub(super) bound: BoundPort<E>,
pub(super) inner: T, pub(super) inner: T,
observer: Once<T::Observer>, observer: Once<T::Observer>,
events: AtomicU8,
} }
impl<T: Inner<E>, E: Ext> Drop for Socket<T, E> { impl<T: Inner<E>, E: Ext> Drop for Socket<T, E> {
@ -62,7 +60,6 @@ impl<T: Inner<E>, E: Ext> Socket<T, E> {
bound, bound,
inner, inner,
observer: Once::new(), observer: Once::new(),
events: AtomicU8::new(0),
}))) })))
} }
@ -74,7 +71,6 @@ impl<T: Inner<E>, E: Ext> Socket<T, E> {
bound, bound,
inner: inner_fn(weak), inner: inner_fn(weak),
observer: Once::new(), observer: Once::new(),
events: AtomicU8::new(0),
}))) })))
} }
@ -110,44 +106,24 @@ define_boolean_value!(
); );
impl<T: Inner<E>, E: Ext> SocketBg<T, E> { impl<T: Inner<E>, E: Ext> SocketBg<T, E> {
pub(crate) fn has_events(&self) -> bool { pub(crate) fn notify_dead_events(self: Arc<Self>)
self.events.load(Ordering::Relaxed) != 0
}
pub(crate) fn on_events(&self) {
// This method can only be called to process network events, so we assume we are holding the
// poll lock and no race conditions can occur.
let events = self.events.load(Ordering::Relaxed);
self.events.store(0, Ordering::Relaxed);
if let Some(observer) = self.observer.get() {
observer.on_events(SocketEvents::from_bits_truncate(events));
}
}
pub(crate) fn on_dead_events(self: Arc<Self>)
where where
T::Observer: Clone, T::Observer: Clone,
{ {
// There is no need to clear the events because the socket is dead.
let events = self.events.load(Ordering::Relaxed);
let observer = self.observer.get().cloned(); let observer = self.observer.get().cloned();
drop(self); drop(self);
// Notify dead events after the `Arc` is dropped to ensure the observer sees this event // Notify dead events after the `Arc` is dropped to ensure the observer sees this event
// with the expected reference count. See `TcpConnection::connect_state` for an example. // with the expected reference count. See `TcpConnection::connect_state` for an example.
if let Some(ref observer) = observer { if let Some(ref observer) = observer {
observer.on_events(SocketEvents::from_bits_truncate(events)); observer.on_events(SocketEvents::CLOSED_SEND | SocketEvents::CLOSED_RECV);
} }
} }
pub(super) fn add_events(&self, new_events: SocketEvents) { pub(super) fn notify_events(&self, new_events: SocketEvents) {
// This method can only be called to add network events, so we assume we are holding the if let Some(observer) = self.observer.get() {
// poll lock and no race conditions can occur. observer.on_events(new_events);
let events = self.events.load(Ordering::Relaxed); }
self.events
.store(events | new_events.bits(), Ordering::Relaxed);
} }
} }

View File

@ -165,7 +165,7 @@ impl<E: Ext> RawTcpSocketExt<E> {
if let Some(value) = backlog.connecting.remove(this.connection_key()) { if let Some(value) = backlog.connecting.remove(this.connection_key()) {
backlog.connected.push(value); backlog.connected.push(value);
} }
listener.add_events(SocketEvents::CAN_RECV); listener.notify_events(SocketEvents::CAN_RECV);
} }
} }
@ -621,7 +621,7 @@ impl<E: Ext> TcpConnectionBg<E> {
socket.check_state(self, old_state, old_recv_queue, is_rst); socket.check_state(self, old_state, old_recv_queue, is_rst);
events |= state_events; events |= state_events;
self.add_events(events); self.notify_events(events);
let poll_at = socket.poll_at(iface.context_mut()); let poll_at = socket.poll_at(iface.context_mut());
iface.update_next_poll_at_ms(self, poll_at); iface.update_next_poll_at_ms(self, poll_at);
@ -669,7 +669,7 @@ impl<E: Ext> TcpConnectionBg<E> {
socket.check_state(self, old_state, old_recv_queue, is_rst); socket.check_state(self, old_state, old_recv_queue, is_rst);
events |= state_events; events |= state_events;
self.add_events(events); self.notify_events(events);
let poll_at = socket.poll_at(iface.context_mut()); let poll_at = socket.poll_at(iface.context_mut());
iface.update_next_poll_at_ms(self, poll_at); iface.update_next_poll_at_ms(self, poll_at);

View File

@ -62,7 +62,7 @@ impl<E: Ext> UdpSocketBg<E> {
udp_payload, udp_payload,
); );
self.add_events(SocketEvents::CAN_RECV); self.notify_events(SocketEvents::CAN_RECV);
true true
} }
@ -82,7 +82,7 @@ impl<E: Ext> UdpSocketBg<E> {
.unwrap(); .unwrap();
// For UDP, dequeuing a packet means that we can queue more packets. // For UDP, dequeuing a packet means that we can queue more packets.
self.add_events(SocketEvents::CAN_SEND); self.notify_events(SocketEvents::CAN_SEND);
self.inner self.inner
.need_dispatch .need_dispatch

View File

@ -295,7 +295,7 @@ impl<E: Ext> SocketTable<E> {
"there should be no need to poll a dead TCP connection", "there should be no need to poll a dead TCP connection",
); );
connection.on_dead_events(); connection.notify_dead_events();
} }
pub(crate) fn remove_udp_socket( pub(crate) fn remove_udp_socket(
@ -309,18 +309,6 @@ impl<E: Ext> SocketTable<E> {
Some(self.udp_sockets.swap_remove(index)) Some(self.udp_sockets.swap_remove(index))
} }
pub(crate) fn tcp_listener_iter(&self) -> impl Iterator<Item = &Arc<TcpListenerBg<E>>> {
self.listener_buckets
.iter()
.flat_map(|bucket| bucket.listeners.iter())
}
pub(crate) fn tcp_conn_iter(&self) -> impl Iterator<Item = &Arc<TcpConnectionBg<E>>> {
self.connection_buckets
.iter()
.flat_map(|bucket| bucket.connections.iter())
}
pub(crate) fn udp_socket_iter(&self) -> impl Iterator<Item = &Arc<UdpSocketBg<E>>> { pub(crate) fn udp_socket_iter(&self) -> impl Iterator<Item = &Arc<UdpSocketBg<E>>> {
self.udp_sockets.iter() self.udp_sockets.iter()
} }