// SPDX-License-Identifier: MPL-2.0 use alloc::sync::Arc; use core::sync::atomic::{AtomicU64, AtomicU8, Ordering}; use smoltcp::{socket::PollAt, time::Instant, wire::IpEndpoint}; use spin::once::Once; use takeable::Takeable; use crate::{ define_boolean_value, ext::Ext, iface::{BoundPort, Iface}, socket::event::{SocketEventObserver, SocketEvents}, }; pub struct Socket, E: Ext>(pub(super) Takeable>>); /// [`TcpConnectionInner`], [`TcpListenerInner`], or [`UdpSocketInner`]. /// /// [`TcpConnectionInner`]: super::tcp_conn::TcpConnectionInner /// [`TcpListenerInner`]: super::tcp_listen::TcpListenerInner /// [`UdpSocketInner`]: super::udp::UdpSocketInner pub trait Inner { type Observer: SocketEventObserver; /// Called by [`Socket::drop`]. fn on_drop(this: &Arc>) where E: Ext, Self: Sized; } /// Common states shared by [`TcpConnectionBg`], [`TcpListenerBg`], and [`UdpSocketBg`]. /// /// In the type name, `Bg` means "background". Its meaning is described below: /// - A foreground socket (e.g., [`TcpConnection`]) handles system calls from the user program. /// - A background socket (e.g., [`TcpConnectionBg`]) handles packets from the network. /// /// [`TcpConnectionBg`]: super::tcp_conn::TcpConnectionBg /// [`TcpListenerBg`]: super::tcp_listen::TcpListenerBg /// [`UdpSocketBg`]: super::udp::UdpSocketBg /// [`TcpConnection`]: super::tcp_conn::TcpConnection pub struct SocketBg, E: Ext> { pub(super) bound: BoundPort, pub(super) inner: T, observer: Once, events: AtomicU8, next_poll_at_ms: AtomicU64, } impl, E: Ext> Drop for Socket { fn drop(&mut self) { if self.0.is_usable() { T::on_drop(&self.0); } } } impl, E: Ext> Socket { pub(crate) fn new(bound: BoundPort, inner: T) -> Self { Self(Takeable::new(Arc::new(SocketBg { bound, inner, observer: Once::new(), events: AtomicU8::new(0), next_poll_at_ms: AtomicU64::new(u64::MAX), }))) } pub(crate) fn inner(&self) -> &Arc> { &self.0 } } impl, E: Ext> Socket { /// Initializes the observer whose `on_events` will be called when certain iface events happen. /// /// The caller needs to be responsible for race conditions if network events can occur /// simultaneously. /// /// Calling this method on a socket whose observer has already been initialized will have no /// effect. pub fn init_observer(&self, new_observer: T::Observer) { self.0.observer.call_once(|| new_observer); } pub fn local_endpoint(&self) -> Option { self.0.bound.endpoint() } pub fn iface(&self) -> &Arc> { self.0.bound.iface() } } define_boolean_value!( /// Whether the iface needs to be polled NeedIfacePoll ); impl, E: Ext> SocketBg { pub(crate) fn has_events(&self) -> bool { self.events.load(Ordering::Relaxed) != 0 } pub(crate) fn on_events(&self) { // This method can only be called to process network events, so we assume we are holding the // poll lock and no race conditions can occur. let events = self.events.load(Ordering::Relaxed); self.events.store(0, Ordering::Relaxed); if let Some(observer) = self.observer.get() { observer.on_events(SocketEvents::from_bits_truncate(events)); } } pub(crate) fn on_dead_events(self: Arc) where T::Observer: Clone, { // This method can only be called to process network events, so we assume we are holding the // poll lock and no race conditions can occur. let events = self.events.load(Ordering::Relaxed); self.events.store(0, Ordering::Relaxed); let observer = self.observer.get().cloned(); drop(self); // Notify dead events after the `Arc` is dropped to ensure the observer sees this event // with the expected reference count. See `TcpConnection::connect_state` for an example. if let Some(ref observer) = observer { observer.on_events(SocketEvents::from_bits_truncate(events)); } } pub(super) fn add_events(&self, new_events: SocketEvents) { // This method can only be called to add network events, so we assume we are holding the // poll lock and no race conditions can occur. let events = self.events.load(Ordering::Relaxed); self.events .store(events | new_events.bits(), Ordering::Relaxed); } /// Returns the next polling time. /// /// Note: a zero means polling should be done now and a `u64::MAX` means no polling is required /// before new network or user events. pub(crate) fn next_poll_at_ms(&self) -> u64 { self.next_poll_at_ms.load(Ordering::Relaxed) } /// Updates the next polling time according to `poll_at`. /// /// The update is typically needed after new network or user events have been handled, so this /// method also marks that there may be new events, so that the event observer provided by /// [`Socket::init_observer`] can be notified later. pub(super) fn update_next_poll_at_ms(&self, poll_at: PollAt) -> NeedIfacePoll { match poll_at { PollAt::Now => { self.next_poll_at_ms.store(0, Ordering::Relaxed); NeedIfacePoll::TRUE } PollAt::Time(instant) => { let old_total_millis = self.next_poll_at_ms.load(Ordering::Relaxed); let new_total_millis = instant.total_millis() as u64; self.next_poll_at_ms .store(new_total_millis, Ordering::Relaxed); NeedIfacePoll(new_total_millis < old_total_millis) } PollAt::Ingress => { self.next_poll_at_ms.store(u64::MAX, Ordering::Relaxed); NeedIfacePoll::FALSE } } } } impl, E: Ext> SocketBg { /// Returns whether an incoming packet _may_ be processed by the socket. /// /// The check is intended to be lock-free and fast, but may have false positives. pub(crate) fn can_process(&self, dst_port: u16) -> bool { self.bound.port() == dst_port } /// Returns whether the socket _may_ generate an outgoing packet. /// /// The check is intended to be lock-free and fast, but may have false positives. pub(crate) fn need_dispatch(&self, now: Instant) -> bool { now.total_millis() as u64 >= self.next_poll_at_ms.load(Ordering::Relaxed) } }