Use Pollee as the socket observer

This commit is contained in:
Ruihan Li
2024-12-02 23:53:19 +08:00
committed by Tate, Hongliang Tian
parent fa76afb3a9
commit 1716f4f324
18 changed files with 242 additions and 188 deletions

View File

@ -1,6 +1,6 @@
// SPDX-License-Identifier: MPL-2.0 // SPDX-License-Identifier: MPL-2.0
use crate::iface::ScheduleNextPoll; use crate::{iface::ScheduleNextPoll, socket::SocketEventObserver};
/// Extension to be implemented by users of this crate. /// Extension to be implemented by users of this crate.
/// ///
@ -13,4 +13,10 @@ use crate::iface::ScheduleNextPoll;
pub trait Ext { pub trait Ext {
/// The type for ifaces to schedule the next poll. /// The type for ifaces to schedule the next poll.
type ScheduleNextPoll: ScheduleNextPoll; type ScheduleNextPoll: ScheduleNextPoll;
/// The type for TCP sockets to observe events.
type TcpEventObserver: SocketEventObserver;
/// The type for UDP sockets to observe events.
type UdpEventObserver: SocketEventObserver;
} }

View File

@ -86,6 +86,7 @@ impl<E: Ext> IfaceCommon<E> {
&self, &self,
iface: Arc<dyn Iface<E>>, iface: Arc<dyn Iface<E>>,
socket: Box<UnboundTcpSocket>, socket: Box<UnboundTcpSocket>,
observer: E::TcpEventObserver,
config: BindPortConfig, config: BindPortConfig,
) -> core::result::Result<BoundTcpSocket<E>, (BindError, Box<UnboundTcpSocket>)> { ) -> core::result::Result<BoundTcpSocket<E>, (BindError, Box<UnboundTcpSocket>)> {
let port = match self.bind_port(config) { let port = match self.bind_port(config) {
@ -93,7 +94,7 @@ impl<E: Ext> IfaceCommon<E> {
Err(err) => return Err((err, socket)), Err(err) => return Err((err, socket)),
}; };
let (raw_socket, observer) = socket.into_raw(); let raw_socket = socket.into_raw();
let bound_socket = BoundTcpSocket::new(iface, port, raw_socket, observer); let bound_socket = BoundTcpSocket::new(iface, port, raw_socket, observer);
let inserted = self let inserted = self
@ -109,6 +110,7 @@ impl<E: Ext> IfaceCommon<E> {
&self, &self,
iface: Arc<dyn Iface<E>>, iface: Arc<dyn Iface<E>>,
socket: Box<UnboundUdpSocket>, socket: Box<UnboundUdpSocket>,
observer: E::UdpEventObserver,
config: BindPortConfig, config: BindPortConfig,
) -> core::result::Result<BoundUdpSocket<E>, (BindError, Box<UnboundUdpSocket>)> { ) -> core::result::Result<BoundUdpSocket<E>, (BindError, Box<UnboundUdpSocket>)> {
let port = match self.bind_port(config) { let port = match self.bind_port(config) {
@ -116,7 +118,7 @@ impl<E: Ext> IfaceCommon<E> {
Err(err) => return Err((err, socket)), Err(err) => return Err((err, socket)),
}; };
let (raw_socket, observer) = socket.into_raw(); let raw_socket = socket.into_raw();
let bound_socket = BoundUdpSocket::new(iface, port, raw_socket, observer); let bound_socket = BoundUdpSocket::new(iface, port, raw_socket, observer);
let inserted = self let inserted = self

View File

@ -37,19 +37,21 @@ impl<E: Ext> dyn Iface<E> {
pub fn bind_tcp( pub fn bind_tcp(
self: &Arc<Self>, self: &Arc<Self>,
socket: Box<UnboundTcpSocket>, socket: Box<UnboundTcpSocket>,
observer: E::TcpEventObserver,
config: BindPortConfig, config: BindPortConfig,
) -> core::result::Result<BoundTcpSocket<E>, (BindError, Box<UnboundTcpSocket>)> { ) -> core::result::Result<BoundTcpSocket<E>, (BindError, Box<UnboundTcpSocket>)> {
let common = self.common(); let common = self.common();
common.bind_tcp(self.clone(), socket, config) common.bind_tcp(self.clone(), socket, observer, config)
} }
pub fn bind_udp( pub fn bind_udp(
self: &Arc<Self>, self: &Arc<Self>,
socket: Box<UnboundUdpSocket>, socket: Box<UnboundUdpSocket>,
observer: E::UdpEventObserver,
config: BindPortConfig, config: BindPortConfig,
) -> core::result::Result<BoundUdpSocket<E>, (BindError, Box<UnboundUdpSocket>)> { ) -> core::result::Result<BoundUdpSocket<E>, (BindError, Box<UnboundUdpSocket>)> {
let common = self.common(); let common = self.common();
common.bind_udp(self.clone(), socket, config) common.bind_udp(self.clone(), socket, observer, config)
} }
/// Gets the name of the iface. /// Gets the name of the iface.

View File

@ -16,15 +16,18 @@ use smoltcp::{
}, },
}; };
use crate::socket::{BoundTcpSocketInner, BoundUdpSocketInner, TcpProcessResult}; use crate::{
ext::Ext,
socket::{BoundTcpSocketInner, BoundUdpSocketInner, TcpProcessResult},
};
pub(super) struct PollContext<'a, E> { pub(super) struct PollContext<'a, E: Ext> {
iface_cx: &'a mut Context, iface_cx: &'a mut Context,
tcp_sockets: &'a BTreeSet<KeyableArc<BoundTcpSocketInner<E>>>, tcp_sockets: &'a BTreeSet<KeyableArc<BoundTcpSocketInner<E>>>,
udp_sockets: &'a BTreeSet<KeyableArc<BoundUdpSocketInner<E>>>, udp_sockets: &'a BTreeSet<KeyableArc<BoundUdpSocketInner<E>>>,
} }
impl<'a, E> PollContext<'a, E> { impl<'a, E: Ext> PollContext<'a, E> {
#[allow(clippy::mutable_key_type)] #[allow(clippy::mutable_key_type)]
pub(super) fn new( pub(super) fn new(
iface_cx: &'a mut Context, iface_cx: &'a mut Context,
@ -44,7 +47,7 @@ impl<'a, E> PollContext<'a, E> {
pub(super) trait FnHelper<A, B, C, O>: FnMut(A, B, C) -> O {} pub(super) trait FnHelper<A, B, C, O>: FnMut(A, B, C) -> O {}
impl<A, B, C, O, F> FnHelper<A, B, C, O> for F where F: FnMut(A, B, C) -> O {} impl<A, B, C, O, F> FnHelper<A, B, C, O> for F where F: FnMut(A, B, C) -> O {}
impl<E> PollContext<'_, E> { impl<E: Ext> PollContext<'_, E> {
pub(super) fn poll_ingress<D, P, Q>( pub(super) fn poll_ingress<D, P, Q>(
&mut self, &mut self,
device: &mut D, device: &mut D,
@ -280,7 +283,7 @@ impl<E> PollContext<'_, E> {
} }
} }
impl<E> PollContext<'_, E> { impl<E: Ext> PollContext<'_, E> {
pub(super) fn poll_egress<D, Q>(&mut self, device: &mut D, mut dispatch_phy: Q) pub(super) fn poll_egress<D, Q>(&mut self, device: &mut D, mut dispatch_phy: Q)
where where
D: Device + ?Sized, D: Device + ?Sized,

View File

@ -1,9 +1,6 @@
// SPDX-License-Identifier: MPL-2.0 // SPDX-License-Identifier: MPL-2.0
use alloc::{ use alloc::{boxed::Box, sync::Arc};
boxed::Box,
sync::{Arc, Weak},
};
use core::{ use core::{
ops::{Deref, DerefMut}, ops::{Deref, DerefMut},
sync::atomic::{AtomicBool, AtomicU64, AtomicU8, Ordering}, sync::atomic::{AtomicBool, AtomicU64, AtomicU8, Ordering},
@ -23,18 +20,20 @@ use super::{
}; };
use crate::{ext::Ext, iface::Iface}; use crate::{ext::Ext, iface::Iface};
pub struct BoundSocket<T: AnySocket, E: Ext>(Arc<BoundSocketInner<T, E>>); pub struct BoundSocket<T: AnySocket<E>, E: Ext>(Arc<BoundSocketInner<T, E>>);
/// [`TcpSocket`] or [`UdpSocket`]. /// [`TcpSocket`] or [`UdpSocket`].
pub trait AnySocket { pub trait AnySocket<E> {
type RawSocket; type RawSocket;
type Observer: SocketEventObserver;
/// Called by [`BoundSocket::new`]. /// Called by [`BoundSocket::new`].
fn new(socket: Box<Self::RawSocket>) -> Self; fn new(socket: Box<Self::RawSocket>) -> Self;
/// Called by [`BoundSocket::drop`]. /// Called by [`BoundSocket::drop`].
fn on_drop<E: Ext>(this: &Arc<BoundSocketInner<Self, E>>) fn on_drop(this: &Arc<BoundSocketInner<Self, E>>)
where where
E: Ext,
Self: Sized; Self: Sized;
} }
@ -42,11 +41,11 @@ pub type BoundTcpSocket<E> = BoundSocket<TcpSocket, E>;
pub type BoundUdpSocket<E> = BoundSocket<UdpSocket, E>; pub type BoundUdpSocket<E> = BoundSocket<UdpSocket, E>;
/// Common states shared by [`BoundTcpSocketInner`] and [`BoundUdpSocketInner`]. /// Common states shared by [`BoundTcpSocketInner`] and [`BoundUdpSocketInner`].
pub struct BoundSocketInner<T, E> { pub struct BoundSocketInner<T: AnySocket<E>, E> {
iface: Arc<dyn Iface<E>>, iface: Arc<dyn Iface<E>>,
port: u16, port: u16,
socket: T, socket: T,
observer: RwLock<Weak<dyn SocketEventObserver>, WriteIrqDisabled>, observer: RwLock<T::Observer, WriteIrqDisabled>,
events: AtomicU8, events: AtomicU8,
next_poll_at_ms: AtomicU64, next_poll_at_ms: AtomicU64,
} }
@ -137,8 +136,9 @@ impl TcpSocket {
} }
} }
impl AnySocket for TcpSocket { impl<E: Ext> AnySocket<E> for TcpSocket {
type RawSocket = RawTcpSocket; type RawSocket = RawTcpSocket;
type Observer = E::TcpEventObserver;
fn new(socket: Box<Self::RawSocket>) -> Self { fn new(socket: Box<Self::RawSocket>) -> Self {
let socket_ext = RawTcpSocketExt { let socket_ext = RawTcpSocketExt {
@ -153,7 +153,7 @@ impl AnySocket for TcpSocket {
} }
} }
fn on_drop<E>(this: &Arc<BoundSocketInner<Self, E>>) { fn on_drop(this: &Arc<BoundSocketInner<Self, E>>) {
let mut socket = this.socket.lock(); let mut socket = this.socket.lock();
socket.in_background = true; socket.in_background = true;
@ -169,14 +169,18 @@ impl AnySocket for TcpSocket {
/// States needed by [`BoundUdpSocketInner`] but not [`BoundTcpSocketInner`]. /// States needed by [`BoundUdpSocketInner`] but not [`BoundTcpSocketInner`].
type UdpSocket = SpinLock<Box<RawUdpSocket>, LocalIrqDisabled>; type UdpSocket = SpinLock<Box<RawUdpSocket>, LocalIrqDisabled>;
impl AnySocket for UdpSocket { impl<E: Ext> AnySocket<E> for UdpSocket {
type RawSocket = RawUdpSocket; type RawSocket = RawUdpSocket;
type Observer = E::UdpEventObserver;
fn new(socket: Box<Self::RawSocket>) -> Self { fn new(socket: Box<Self::RawSocket>) -> Self {
Self::new(socket) Self::new(socket)
} }
fn on_drop<E: Ext>(this: &Arc<BoundSocketInner<Self, E>>) { fn on_drop(this: &Arc<BoundSocketInner<Self, E>>)
where
E: Ext,
{
this.socket.lock().close(); this.socket.lock().close();
// A UDP socket can be removed immediately. // A UDP socket can be removed immediately.
@ -184,7 +188,7 @@ impl AnySocket for UdpSocket {
} }
} }
impl<T: AnySocket, E: Ext> Drop for BoundSocket<T, E> { impl<T: AnySocket<E>, E: Ext> Drop for BoundSocket<T, E> {
fn drop(&mut self) { fn drop(&mut self) {
T::on_drop(&self.0); T::on_drop(&self.0);
} }
@ -193,12 +197,12 @@ impl<T: AnySocket, E: Ext> Drop for BoundSocket<T, E> {
pub(crate) type BoundTcpSocketInner<E> = BoundSocketInner<TcpSocket, E>; pub(crate) type BoundTcpSocketInner<E> = BoundSocketInner<TcpSocket, E>;
pub(crate) type BoundUdpSocketInner<E> = BoundSocketInner<UdpSocket, E>; pub(crate) type BoundUdpSocketInner<E> = BoundSocketInner<UdpSocket, E>;
impl<T: AnySocket, E: Ext> BoundSocket<T, E> { impl<T: AnySocket<E>, E: Ext> BoundSocket<T, E> {
pub(crate) fn new( pub(crate) fn new(
iface: Arc<dyn Iface<E>>, iface: Arc<dyn Iface<E>>,
port: u16, port: u16,
socket: Box<T::RawSocket>, socket: Box<T::RawSocket>,
observer: Weak<dyn SocketEventObserver>, observer: T::Observer,
) -> Self { ) -> Self {
Self(Arc::new(BoundSocketInner { Self(Arc::new(BoundSocketInner {
iface, iface,
@ -215,24 +219,13 @@ impl<T: AnySocket, E: Ext> BoundSocket<T, E> {
} }
} }
impl<T: AnySocket, E: Ext> BoundSocket<T, E> { impl<T: AnySocket<E>, E: Ext> BoundSocket<T, E> {
/// Sets the observer whose `on_events` will be called when certain iface events happen. After /// Sets the observer whose `on_events` will be called when certain iface events happen.
/// setting, the new observer will fire once immediately to avoid missing any events.
/// ///
/// If there is an existing observer, due to race conditions, this function does not guarantee /// The caller needs to be responsible for race conditions if network events can occur
/// that the old observer will never be called after the setting. Users should be aware of this /// simultaneously.
/// and proactively handle the race conditions if necessary. pub fn set_observer(&self, new_observer: T::Observer) {
pub fn set_observer(&self, new_observer: Weak<dyn SocketEventObserver>) {
*self.0.observer.write() = new_observer; *self.0.observer.write() = new_observer;
self.0.on_events();
}
/// Returns the observer.
///
/// See also [`Self::set_observer`].
pub fn observer(&self) -> Weak<dyn SocketEventObserver> {
self.0.observer.read().clone()
} }
pub fn local_endpoint(&self) -> Option<IpEndpoint> { pub fn local_endpoint(&self) -> Option<IpEndpoint> {
@ -449,7 +442,7 @@ impl<E: Ext> BoundUdpSocket<E> {
} }
} }
impl<T, E> BoundSocketInner<T, E> { impl<T: AnySocket<E>, E> BoundSocketInner<T, E> {
pub(crate) fn has_events(&self) -> bool { pub(crate) fn has_events(&self) -> bool {
self.events.load(Ordering::Relaxed) != 0 self.events.load(Ordering::Relaxed) != 0
} }
@ -460,13 +453,8 @@ impl<T, E> BoundSocketInner<T, E> {
let events = self.events.load(Ordering::Relaxed); let events = self.events.load(Ordering::Relaxed);
self.events.store(0, Ordering::Relaxed); self.events.store(0, Ordering::Relaxed);
// We never hold the write lock in IRQ handlers, so we don't need to disable IRQs when we let observer = self.observer.read();
// get the read lock. observer.on_events(SocketEvents::from_bits_truncate(events));
let observer = Weak::upgrade(&*self.observer.read());
if let Some(inner) = observer {
inner.on_events(SocketEvents::from_bits_truncate(events));
}
} }
fn add_events(&self, new_events: SocketEvents) { fn add_events(&self, new_events: SocketEvents) {
@ -513,13 +501,13 @@ impl<T, E> BoundSocketInner<T, E> {
} }
} }
impl<T, E> BoundSocketInner<T, E> { impl<T: AnySocket<E>, E> BoundSocketInner<T, E> {
pub(crate) fn port(&self) -> u16 { pub(crate) fn port(&self) -> u16 {
self.port self.port
} }
} }
impl<E> BoundTcpSocketInner<E> { impl<E: Ext> BoundTcpSocketInner<E> {
/// Returns whether the TCP socket is dead. /// Returns whether the TCP socket is dead.
/// ///
/// A TCP socket is considered dead if and only if the following two conditions are met: /// A TCP socket is considered dead if and only if the following two conditions are met:
@ -531,7 +519,7 @@ impl<E> BoundTcpSocketInner<E> {
} }
} }
impl<T, E> BoundSocketInner<T, E> { impl<T: AnySocket<E>, E> BoundSocketInner<T, E> {
/// Returns whether an incoming packet _may_ be processed by the socket. /// Returns whether an incoming packet _may_ be processed by the socket.
/// ///
/// The check is intended to be lock-free and fast, but may have false positives. /// The check is intended to be lock-free and fast, but may have false positives.
@ -554,7 +542,7 @@ pub(crate) enum TcpProcessResult {
ProcessedWithReply(IpRepr, TcpRepr<'static>), ProcessedWithReply(IpRepr, TcpRepr<'static>),
} }
impl<E> BoundTcpSocketInner<E> { impl<E: Ext> BoundTcpSocketInner<E> {
/// Tries to process an incoming packet and returns whether the packet is processed. /// Tries to process an incoming packet and returns whether the packet is processed.
pub(crate) fn process( pub(crate) fn process(
&self, &self,
@ -654,7 +642,7 @@ impl<E> BoundTcpSocketInner<E> {
} }
} }
impl<E> BoundUdpSocketInner<E> { impl<E: Ext> BoundUdpSocketInner<E> {
/// Tries to process an incoming packet and returns whether the packet is processed. /// Tries to process an incoming packet and returns whether the packet is processed.
pub(crate) fn process( pub(crate) fn process(
&self, &self,

View File

@ -1,19 +1,18 @@
// SPDX-License-Identifier: MPL-2.0 // SPDX-License-Identifier: MPL-2.0
use alloc::{boxed::Box, sync::Weak, vec}; use alloc::{boxed::Box, vec};
use super::{event::SocketEventObserver, RawTcpSocket, RawUdpSocket}; use super::{RawTcpSocket, RawUdpSocket};
pub struct UnboundSocket<T> { pub struct UnboundSocket<T> {
socket: Box<T>, socket: Box<T>,
observer: Weak<dyn SocketEventObserver>,
} }
pub type UnboundTcpSocket = UnboundSocket<RawTcpSocket>; pub type UnboundTcpSocket = UnboundSocket<RawTcpSocket>;
pub type UnboundUdpSocket = UnboundSocket<RawUdpSocket>; pub type UnboundUdpSocket = UnboundSocket<RawUdpSocket>;
impl UnboundTcpSocket { impl UnboundTcpSocket {
pub fn new(observer: Weak<dyn SocketEventObserver>) -> Self { pub fn new() -> Self {
let raw_tcp_socket = { let raw_tcp_socket = {
let rx_buffer = smoltcp::socket::tcp::SocketBuffer::new(vec![0u8; TCP_RECV_BUF_LEN]); let rx_buffer = smoltcp::socket::tcp::SocketBuffer::new(vec![0u8; TCP_RECV_BUF_LEN]);
let tx_buffer = smoltcp::socket::tcp::SocketBuffer::new(vec![0u8; TCP_SEND_BUF_LEN]); let tx_buffer = smoltcp::socket::tcp::SocketBuffer::new(vec![0u8; TCP_SEND_BUF_LEN]);
@ -21,13 +20,18 @@ impl UnboundTcpSocket {
}; };
Self { Self {
socket: Box::new(raw_tcp_socket), socket: Box::new(raw_tcp_socket),
observer,
} }
} }
} }
impl Default for UnboundTcpSocket {
fn default() -> Self {
Self::new()
}
}
impl UnboundUdpSocket { impl UnboundUdpSocket {
pub fn new(observer: Weak<dyn SocketEventObserver>) -> Self { pub fn new() -> Self {
let raw_udp_socket = { let raw_udp_socket = {
let metadata = smoltcp::socket::udp::PacketMetadata::EMPTY; let metadata = smoltcp::socket::udp::PacketMetadata::EMPTY;
let rx_buffer = smoltcp::socket::udp::PacketBuffer::new( let rx_buffer = smoltcp::socket::udp::PacketBuffer::new(
@ -42,14 +46,19 @@ impl UnboundUdpSocket {
}; };
Self { Self {
socket: Box::new(raw_udp_socket), socket: Box::new(raw_udp_socket),
observer,
} }
} }
} }
impl Default for UnboundUdpSocket {
fn default() -> Self {
Self::new()
}
}
impl<T> UnboundSocket<T> { impl<T> UnboundSocket<T> {
pub(crate) fn into_raw(self) -> (Box<T>, Weak<dyn SocketEventObserver>) { pub(crate) fn into_raw(self) -> Box<T> {
(self.socket, self.observer) self.socket
} }
} }

View File

@ -1,9 +1,13 @@
// SPDX-License-Identifier: MPL-2.0 // SPDX-License-Identifier: MPL-2.0
use super::sched::PollScheduler; use super::sched::PollScheduler;
use crate::net::socket::ip::{datagram::DatagramObserver, stream::StreamObserver};
pub struct BigtcpExt; pub struct BigtcpExt;
impl aster_bigtcp::ext::Ext for BigtcpExt { impl aster_bigtcp::ext::Ext for BigtcpExt {
type ScheduleNextPoll = PollScheduler; type ScheduleNextPoll = PollScheduler;
type TcpEventObserver = StreamObserver;
type UdpEventObserver = DatagramObserver;
} }

View File

@ -2,10 +2,7 @@
use core::sync::atomic::{AtomicBool, Ordering}; use core::sync::atomic::{AtomicBool, Ordering};
use aster_bigtcp::{ use aster_bigtcp::wire::IpEndpoint;
socket::{SocketEventObserver, SocketEvents},
wire::IpEndpoint,
};
use ostd::sync::WriteIrqDisabled; use ostd::sync::WriteIrqDisabled;
use takeable::Takeable; use takeable::Takeable;
@ -32,8 +29,11 @@ use crate::{
}; };
mod bound; mod bound;
mod observer;
mod unbound; mod unbound;
pub(in crate::net) use self::observer::DatagramObserver;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
struct OptionSet { struct OptionSet {
socket: SocketOptionSet, socket: SocketOptionSet,
@ -64,6 +64,7 @@ impl Inner {
self, self,
endpoint: &IpEndpoint, endpoint: &IpEndpoint,
can_reuse: bool, can_reuse: bool,
observer: DatagramObserver,
) -> core::result::Result<BoundDatagram, (Error, Self)> { ) -> core::result::Result<BoundDatagram, (Error, Self)> {
let unbound_datagram = match self { let unbound_datagram = match self {
Inner::Unbound(unbound_datagram) => unbound_datagram, Inner::Unbound(unbound_datagram) => unbound_datagram,
@ -75,7 +76,7 @@ impl Inner {
} }
}; };
let bound_datagram = match unbound_datagram.bind(endpoint, can_reuse) { let bound_datagram = match unbound_datagram.bind(endpoint, can_reuse, observer) {
Ok(bound_datagram) => bound_datagram, Ok(bound_datagram) => bound_datagram,
Err((err, unbound_datagram)) => return Err((err, Inner::Unbound(unbound_datagram))), Err((err, unbound_datagram)) => return Err((err, Inner::Unbound(unbound_datagram))),
}; };
@ -85,26 +86,25 @@ impl Inner {
fn bind_to_ephemeral_endpoint( fn bind_to_ephemeral_endpoint(
self, self,
remote_endpoint: &IpEndpoint, remote_endpoint: &IpEndpoint,
observer: DatagramObserver,
) -> core::result::Result<BoundDatagram, (Error, Self)> { ) -> core::result::Result<BoundDatagram, (Error, Self)> {
if let Inner::Bound(bound_datagram) = self { if let Inner::Bound(bound_datagram) = self {
return Ok(bound_datagram); return Ok(bound_datagram);
} }
let endpoint = get_ephemeral_endpoint(remote_endpoint); let endpoint = get_ephemeral_endpoint(remote_endpoint);
self.bind(&endpoint, false) self.bind(&endpoint, false, observer)
} }
} }
impl DatagramSocket { impl DatagramSocket {
pub fn new(nonblocking: bool) -> Arc<Self> { pub fn new(nonblocking: bool) -> Arc<Self> {
Arc::new_cyclic(|me| { let unbound_datagram = UnboundDatagram::new();
let unbound_datagram = UnboundDatagram::new(me.clone() as _); Arc::new(Self {
Self {
inner: RwLock::new(Takeable::new(Inner::Unbound(unbound_datagram))), inner: RwLock::new(Takeable::new(Inner::Unbound(unbound_datagram))),
nonblocking: AtomicBool::new(nonblocking), nonblocking: AtomicBool::new(nonblocking),
pollee: Pollee::new(), pollee: Pollee::new(),
options: RwLock::new(OptionSet::new()), options: RwLock::new(OptionSet::new()),
}
}) })
} }
@ -134,7 +134,10 @@ impl DatagramSocket {
// Slow path // Slow path
let mut inner = self.inner.write(); let mut inner = self.inner.write();
inner.borrow_result(|owned_inner| { inner.borrow_result(|owned_inner| {
let bound_datagram = match owned_inner.bind_to_ephemeral_endpoint(remote_endpoint) { let bound_datagram = match owned_inner.bind_to_ephemeral_endpoint(
remote_endpoint,
DatagramObserver::new(self.pollee.clone()),
) {
Ok(bound_datagram) => bound_datagram, Ok(bound_datagram) => bound_datagram,
Err((err, err_inner)) => { Err((err, err_inner)) => {
return (err_inner, Err(err)); return (err_inner, Err(err));
@ -277,7 +280,11 @@ impl Socket for DatagramSocket {
let can_reuse = self.options.read().socket.reuse_addr(); let can_reuse = self.options.read().socket.reuse_addr();
let mut inner = self.inner.write(); let mut inner = self.inner.write();
inner.borrow_result(|owned_inner| { inner.borrow_result(|owned_inner| {
let bound_datagram = match owned_inner.bind(&endpoint, can_reuse) { let bound_datagram = match owned_inner.bind(
&endpoint,
can_reuse,
DatagramObserver::new(self.pollee.clone()),
) {
Ok(bound_datagram) => bound_datagram, Ok(bound_datagram) => bound_datagram,
Err((err, err_inner)) => { Err((err, err_inner)) => {
return (err_inner, Err(err)); return (err_inner, Err(err));
@ -389,19 +396,3 @@ impl Socket for DatagramSocket {
self.options.write().socket.set_option(option) self.options.write().socket.set_option(option)
} }
} }
impl SocketEventObserver for DatagramSocket {
fn on_events(&self, events: SocketEvents) {
let mut io_events = IoEvents::empty();
if events.contains(SocketEvents::CAN_RECV) {
io_events |= IoEvents::IN;
}
if events.contains(SocketEvents::CAN_SEND) {
io_events |= IoEvents::OUT;
}
self.pollee.notify(io_events);
}
}

View File

@ -0,0 +1,29 @@
// SPDX-License-Identifier: MPL-2.0
use aster_bigtcp::socket::{SocketEventObserver, SocketEvents};
use crate::{events::IoEvents, process::signal::Pollee};
pub struct DatagramObserver(Pollee);
impl DatagramObserver {
pub(super) fn new(pollee: Pollee) -> Self {
Self(pollee)
}
}
impl SocketEventObserver for DatagramObserver {
fn on_events(&self, events: SocketEvents) {
let mut io_events = IoEvents::empty();
if events.contains(SocketEvents::CAN_RECV) {
io_events |= IoEvents::IN;
}
if events.contains(SocketEvents::CAN_SEND) {
io_events |= IoEvents::OUT;
}
self.0.notify(io_events);
}
}

View File

@ -1,13 +1,8 @@
// SPDX-License-Identifier: MPL-2.0 // SPDX-License-Identifier: MPL-2.0
use alloc::sync::Weak; use aster_bigtcp::{socket::UnboundUdpSocket, wire::IpEndpoint};
use aster_bigtcp::{ use super::{bound::BoundDatagram, DatagramObserver};
socket::{SocketEventObserver, UnboundUdpSocket},
wire::IpEndpoint,
};
use super::bound::BoundDatagram;
use crate::{events::IoEvents, net::socket::ip::common::bind_socket, prelude::*}; use crate::{events::IoEvents, net::socket::ip::common::bind_socket, prelude::*};
pub struct UnboundDatagram { pub struct UnboundDatagram {
@ -15,9 +10,9 @@ pub struct UnboundDatagram {
} }
impl UnboundDatagram { impl UnboundDatagram {
pub fn new(observer: Weak<dyn SocketEventObserver>) -> Self { pub fn new() -> Self {
Self { Self {
unbound_socket: Box::new(UnboundUdpSocket::new(observer)), unbound_socket: Box::new(UnboundUdpSocket::new()),
} }
} }
@ -25,12 +20,13 @@ impl UnboundDatagram {
self, self,
endpoint: &IpEndpoint, endpoint: &IpEndpoint,
can_reuse: bool, can_reuse: bool,
observer: DatagramObserver,
) -> core::result::Result<BoundDatagram, (Error, Self)> { ) -> core::result::Result<BoundDatagram, (Error, Self)> {
let bound_socket = match bind_socket( let bound_socket = match bind_socket(
self.unbound_socket, self.unbound_socket,
endpoint, endpoint,
can_reuse, can_reuse,
|iface, socket, config| iface.bind_udp(socket, config), |iface, socket, config| iface.bind_udp(socket, observer, config),
) { ) {
Ok(bound_socket) => bound_socket, Ok(bound_socket) => bound_socket,
Err((err, unbound_socket)) => return Err((err, Self { unbound_socket })), Err((err, unbound_socket)) => return Err((err, Self { unbound_socket })),

View File

@ -2,9 +2,7 @@
mod addr; mod addr;
mod common; mod common;
mod datagram; pub mod datagram;
pub mod stream; pub mod stream;
use addr::UNSPECIFIED_LOCAL_ENDPOINT; use addr::UNSPECIFIED_LOCAL_ENDPOINT;
pub use datagram::DatagramSocket;
pub use stream::StreamSocket;

View File

@ -1,14 +1,14 @@
// SPDX-License-Identifier: MPL-2.0 // SPDX-License-Identifier: MPL-2.0
use alloc::sync::Weak;
use core::sync::atomic::{AtomicBool, Ordering}; use core::sync::atomic::{AtomicBool, Ordering};
use aster_bigtcp::{ use aster_bigtcp::{
errors::tcp::{RecvError, SendError}, errors::tcp::{RecvError, SendError},
socket::{NeedIfacePoll, SocketEventObserver, TcpStateCheck}, socket::{NeedIfacePoll, TcpStateCheck},
wire::IpEndpoint, wire::IpEndpoint,
}; };
use super::StreamObserver;
use crate::{ use crate::{
events::IoEvents, events::IoEvents,
net::{ net::{
@ -202,7 +202,7 @@ impl ConnectedStream {
}) })
} }
pub(super) fn set_observer(&self, observer: Weak<dyn SocketEventObserver>) { pub(super) fn set_observer(&self, observer: StreamObserver) {
self.bound_socket.set_observer(observer) self.bound_socket.set_observer(observer)
} }
} }

View File

@ -1,13 +1,8 @@
// SPDX-License-Identifier: MPL-2.0 // SPDX-License-Identifier: MPL-2.0
use alloc::sync::Weak; use aster_bigtcp::{socket::UnboundTcpSocket, wire::IpEndpoint};
use aster_bigtcp::{ use super::{connecting::ConnectingStream, listen::ListenStream, StreamObserver};
socket::{SocketEventObserver, UnboundTcpSocket},
wire::IpEndpoint,
};
use super::{connecting::ConnectingStream, listen::ListenStream};
use crate::{ use crate::{
events::IoEvents, events::IoEvents,
net::{ net::{
@ -15,6 +10,7 @@ use crate::{
socket::ip::common::{bind_socket, get_ephemeral_endpoint}, socket::ip::common::{bind_socket, get_ephemeral_endpoint},
}, },
prelude::*, prelude::*,
process::signal::Pollee,
}; };
pub enum InitStream { pub enum InitStream {
@ -23,8 +19,8 @@ pub enum InitStream {
} }
impl InitStream { impl InitStream {
pub fn new(observer: Weak<dyn SocketEventObserver>) -> Self { pub fn new() -> Self {
InitStream::Unbound(Box::new(UnboundTcpSocket::new(observer))) InitStream::Unbound(Box::new(UnboundTcpSocket::new()))
} }
pub fn new_bound(bound_socket: BoundTcpSocket) -> Self { pub fn new_bound(bound_socket: BoundTcpSocket) -> Self {
@ -35,6 +31,7 @@ impl InitStream {
self, self,
endpoint: &IpEndpoint, endpoint: &IpEndpoint,
can_reuse: bool, can_reuse: bool,
observer: StreamObserver,
) -> core::result::Result<BoundTcpSocket, (Error, Self)> { ) -> core::result::Result<BoundTcpSocket, (Error, Self)> {
let unbound_socket = match self { let unbound_socket = match self {
InitStream::Unbound(unbound_socket) => unbound_socket, InitStream::Unbound(unbound_socket) => unbound_socket,
@ -49,7 +46,7 @@ impl InitStream {
unbound_socket, unbound_socket,
endpoint, endpoint,
can_reuse, can_reuse,
|iface, socket, config| iface.bind_tcp(socket, config), |iface, socket, config| iface.bind_tcp(socket, observer, config),
) { ) {
Ok(bound_socket) => bound_socket, Ok(bound_socket) => bound_socket,
Err((err, unbound_socket)) => return Err((err, InitStream::Unbound(unbound_socket))), Err((err, unbound_socket)) => return Err((err, InitStream::Unbound(unbound_socket))),
@ -60,25 +57,32 @@ impl InitStream {
fn bind_to_ephemeral_endpoint( fn bind_to_ephemeral_endpoint(
self, self,
remote_endpoint: &IpEndpoint, remote_endpoint: &IpEndpoint,
observer: StreamObserver,
) -> core::result::Result<BoundTcpSocket, (Error, Self)> { ) -> core::result::Result<BoundTcpSocket, (Error, Self)> {
let endpoint = get_ephemeral_endpoint(remote_endpoint); let endpoint = get_ephemeral_endpoint(remote_endpoint);
self.bind(&endpoint, false) self.bind(&endpoint, false, observer)
} }
pub fn connect( pub fn connect(
self, self,
remote_endpoint: &IpEndpoint, remote_endpoint: &IpEndpoint,
pollee: &Pollee,
) -> core::result::Result<ConnectingStream, (Error, Self)> { ) -> core::result::Result<ConnectingStream, (Error, Self)> {
let bound_socket = match self { let bound_socket = match self {
InitStream::Bound(bound_socket) => bound_socket, InitStream::Bound(bound_socket) => bound_socket,
InitStream::Unbound(_) => self.bind_to_ephemeral_endpoint(remote_endpoint)?, InitStream::Unbound(_) => self
.bind_to_ephemeral_endpoint(remote_endpoint, StreamObserver::new(pollee.clone()))?,
}; };
ConnectingStream::new(bound_socket, *remote_endpoint) ConnectingStream::new(bound_socket, *remote_endpoint)
.map_err(|(err, bound_socket)| (err, InitStream::Bound(bound_socket))) .map_err(|(err, bound_socket)| (err, InitStream::Bound(bound_socket)))
} }
pub fn listen(self, backlog: usize) -> core::result::Result<ListenStream, (Error, Self)> { pub fn listen(
self,
backlog: usize,
pollee: &Pollee,
) -> core::result::Result<ListenStream, (Error, Self)> {
let InitStream::Bound(bound_socket) = self else { let InitStream::Bound(bound_socket) = self else {
// FIXME: The socket should be bound to INADDR_ANY (i.e., 0.0.0.0) with an ephemeral // FIXME: The socket should be bound to INADDR_ANY (i.e., 0.0.0.0) with an ephemeral
// port. However, INADDR_ANY is not yet supported, so we need to return an error first. // port. However, INADDR_ANY is not yet supported, so we need to return an error first.
@ -89,7 +93,7 @@ impl InitStream {
)); ));
}; };
ListenStream::new(bound_socket, backlog) ListenStream::new(bound_socket, backlog, pollee)
.map_err(|(err, bound_socket)| (err, InitStream::Bound(bound_socket))) .map_err(|(err, bound_socket)| (err, InitStream::Bound(bound_socket)))
} }

View File

@ -5,11 +5,12 @@ use aster_bigtcp::{
}; };
use ostd::sync::WriteIrqDisabled; use ostd::sync::WriteIrqDisabled;
use super::connected::ConnectedStream; use super::{connected::ConnectedStream, StreamObserver};
use crate::{ use crate::{
events::IoEvents, events::IoEvents,
net::iface::{BoundTcpSocket, Iface}, net::iface::{BoundTcpSocket, Iface},
prelude::*, prelude::*,
process::signal::Pollee,
}; };
pub struct ListenStream { pub struct ListenStream {
@ -24,6 +25,7 @@ impl ListenStream {
pub fn new( pub fn new(
bound_socket: BoundTcpSocket, bound_socket: BoundTcpSocket,
backlog: usize, backlog: usize,
pollee: &Pollee,
) -> core::result::Result<Self, (Error, BoundTcpSocket)> { ) -> core::result::Result<Self, (Error, BoundTcpSocket)> {
const SOMAXCONN: usize = 4096; const SOMAXCONN: usize = 4096;
let somaxconn = SOMAXCONN.min(backlog); let somaxconn = SOMAXCONN.min(backlog);
@ -33,14 +35,14 @@ impl ListenStream {
bound_socket, bound_socket,
backlog_sockets: RwLock::new(Vec::new()), backlog_sockets: RwLock::new(Vec::new()),
}; };
if let Err(err) = listen_stream.fill_backlog_sockets() { if let Err(err) = listen_stream.fill_backlog_sockets(pollee) {
return Err((err, listen_stream.bound_socket)); return Err((err, listen_stream.bound_socket));
} }
Ok(listen_stream) Ok(listen_stream)
} }
/// Append sockets listening at LocalEndPoint to support backlog /// Append sockets listening at LocalEndPoint to support backlog
fn fill_backlog_sockets(&self) -> Result<()> { fn fill_backlog_sockets(&self, pollee: &Pollee) -> Result<()> {
let mut backlog_sockets = self.backlog_sockets.write(); let mut backlog_sockets = self.backlog_sockets.write();
let backlog = self.backlog; let backlog = self.backlog;
@ -51,14 +53,14 @@ impl ListenStream {
} }
for _ in current_backlog_len..backlog { for _ in current_backlog_len..backlog {
let backlog_socket = BacklogSocket::new(&self.bound_socket)?; let backlog_socket = BacklogSocket::new(&self.bound_socket, pollee)?;
backlog_sockets.push(backlog_socket); backlog_sockets.push(backlog_socket);
} }
Ok(()) Ok(())
} }
pub fn try_accept(&self) -> Result<ConnectedStream> { pub fn try_accept(&self, pollee: &Pollee) -> Result<ConnectedStream> {
let mut backlog_sockets = self.backlog_sockets.write(); let mut backlog_sockets = self.backlog_sockets.write();
let index = backlog_sockets let index = backlog_sockets
@ -69,7 +71,7 @@ impl ListenStream {
})?; })?;
let active_backlog_socket = backlog_sockets.remove(index); let active_backlog_socket = backlog_sockets.remove(index);
if let Ok(backlog_socket) = BacklogSocket::new(&self.bound_socket) { if let Ok(backlog_socket) = BacklogSocket::new(&self.bound_socket, pollee) {
backlog_sockets.push(backlog_socket); backlog_sockets.push(backlog_socket);
} }
@ -111,18 +113,22 @@ struct BacklogSocket {
impl BacklogSocket { impl BacklogSocket {
// FIXME: All of the error codes below seem to have no Linux equivalents, and I see no reason // FIXME: All of the error codes below seem to have no Linux equivalents, and I see no reason
// why the error may occur. Perhaps it is better to call `unwrap()` directly? // why the error may occur. Perhaps it is better to call `unwrap()` directly?
fn new(bound_socket: &BoundTcpSocket) -> Result<Self> { fn new(bound_socket: &BoundTcpSocket, pollee: &Pollee) -> Result<Self> {
let local_endpoint = bound_socket.local_endpoint().ok_or(Error::with_message( let local_endpoint = bound_socket.local_endpoint().ok_or(Error::with_message(
Errno::EINVAL, Errno::EINVAL,
"the socket is not bound", "the socket is not bound",
))?; ))?;
let unbound_socket = Box::new(UnboundTcpSocket::new(bound_socket.observer())); let unbound_socket = Box::new(UnboundTcpSocket::new());
let bound_socket = { let bound_socket = {
let iface = bound_socket.iface(); let iface = bound_socket.iface();
let bind_port_config = BindPortConfig::new(local_endpoint.port, true); let bind_port_config = BindPortConfig::new(local_endpoint.port, true);
iface iface
.bind_tcp(unbound_socket, bind_port_config) .bind_tcp(
unbound_socket,
StreamObserver::new(pollee.clone()),
bind_port_config,
)
.map_err(|(err, _)| err)? .map_err(|(err, _)| err)?
}; };

View File

@ -2,10 +2,7 @@
use core::sync::atomic::{AtomicBool, Ordering}; use core::sync::atomic::{AtomicBool, Ordering};
use aster_bigtcp::{ use aster_bigtcp::wire::IpEndpoint;
socket::{SocketEventObserver, SocketEvents},
wire::IpEndpoint,
};
use connected::ConnectedStream; use connected::ConnectedStream;
use connecting::{ConnResult, ConnectingStream}; use connecting::{ConnResult, ConnectingStream};
use init::InitStream; use init::InitStream;
@ -40,9 +37,11 @@ mod connected;
mod connecting; mod connecting;
mod init; mod init;
mod listen; mod listen;
mod observer;
pub mod options; pub mod options;
mod util; mod util;
pub(in crate::net) use self::observer::StreamObserver;
pub use self::util::CongestionControl; pub use self::util::CongestionControl;
pub struct StreamSocket { pub struct StreamSocket {
@ -79,26 +78,23 @@ impl OptionSet {
impl StreamSocket { impl StreamSocket {
pub fn new(nonblocking: bool) -> Arc<Self> { pub fn new(nonblocking: bool) -> Arc<Self> {
Arc::new_cyclic(|me| { let init_stream = InitStream::new();
let init_stream = InitStream::new(me.clone() as _); Arc::new(Self {
Self {
options: RwLock::new(OptionSet::new()), options: RwLock::new(OptionSet::new()),
state: RwLock::new(Takeable::new(State::Init(init_stream))), state: RwLock::new(Takeable::new(State::Init(init_stream))),
is_nonblocking: AtomicBool::new(nonblocking), is_nonblocking: AtomicBool::new(nonblocking),
pollee: Pollee::new(), pollee: Pollee::new(),
}
}) })
} }
fn new_connected(connected_stream: ConnectedStream) -> Arc<Self> { fn new_connected(connected_stream: ConnectedStream) -> Arc<Self> {
Arc::new_cyclic(move |me| { let pollee = Pollee::new();
connected_stream.set_observer(me.clone() as _); connected_stream.set_observer(StreamObserver::new(pollee.clone()));
Self { Arc::new(Self {
options: RwLock::new(OptionSet::new()), options: RwLock::new(OptionSet::new()),
state: RwLock::new(Takeable::new(State::Connected(connected_stream))), state: RwLock::new(Takeable::new(State::Connected(connected_stream))),
is_nonblocking: AtomicBool::new(false), is_nonblocking: AtomicBool::new(false),
pollee: Pollee::new(), pollee,
}
}) })
} }
@ -221,7 +217,7 @@ impl StreamSocket {
} }
}; };
let connecting_stream = match init_stream.connect(remote_endpoint) { let connecting_stream = match init_stream.connect(remote_endpoint, &self.pollee) {
Ok(connecting_stream) => connecting_stream, Ok(connecting_stream) => connecting_stream,
Err((err, init_stream)) => { Err((err, init_stream)) => {
return (State::Init(init_stream), (Some(Err(err)), None)); return (State::Init(init_stream), (Some(Err(err)), None));
@ -276,7 +272,9 @@ impl StreamSocket {
return_errno_with_message!(Errno::EINVAL, "the socket is not listening"); return_errno_with_message!(Errno::EINVAL, "the socket is not listening");
}; };
let accepted = listen_stream.try_accept().map(|connected_stream| { let accepted = listen_stream
.try_accept(&self.pollee)
.map(|connected_stream| {
let remote_endpoint = connected_stream.remote_endpoint(); let remote_endpoint = connected_stream.remote_endpoint();
let accepted_socket = Self::new_connected(connected_stream); let accepted_socket = Self::new_connected(connected_stream);
(accepted_socket as _, remote_endpoint.into()) (accepted_socket as _, remote_endpoint.into())
@ -451,7 +449,11 @@ impl Socket for StreamSocket {
); );
}; };
let bound_socket = match init_stream.bind(&endpoint, can_reuse) { let bound_socket = match init_stream.bind(
&endpoint,
can_reuse,
StreamObserver::new(self.pollee.clone()),
) {
Ok(bound_socket) => bound_socket, Ok(bound_socket) => bound_socket,
Err((err, init_stream)) => { Err((err, init_stream)) => {
return (State::Init(init_stream), Err(err)); return (State::Init(init_stream), Err(err));
@ -492,7 +494,7 @@ impl Socket for StreamSocket {
} }
}; };
let listen_stream = match init_stream.listen(backlog) { let listen_stream = match init_stream.listen(backlog, &self.pollee) {
Ok(listen_stream) => listen_stream, Ok(listen_stream) => listen_stream,
Err((err, init_stream)) => { Err((err, init_stream)) => {
return (State::Init(init_stream), Err(err)); return (State::Init(init_stream), Err(err));
@ -692,30 +694,6 @@ impl Socket for StreamSocket {
} }
} }
impl SocketEventObserver for StreamSocket {
fn on_events(&self, events: SocketEvents) {
let mut io_events = IoEvents::empty();
if events.contains(SocketEvents::CAN_RECV) {
io_events |= IoEvents::IN;
}
if events.contains(SocketEvents::CAN_SEND) {
io_events |= IoEvents::OUT;
}
if events.contains(SocketEvents::PEER_CLOSED) {
io_events |= IoEvents::IN | IoEvents::RDHUP;
}
if events.contains(SocketEvents::CLOSED) {
io_events |= IoEvents::IN | IoEvents::OUT | IoEvents::RDHUP | IoEvents::HUP;
}
self.pollee.notify(io_events);
}
}
impl Drop for StreamSocket { impl Drop for StreamSocket {
fn drop(&mut self) { fn drop(&mut self) {
let state = self.state.write().take(); let state = self.state.write().take();

View File

@ -0,0 +1,37 @@
// SPDX-License-Identifier: MPL-2.0
use aster_bigtcp::socket::{SocketEventObserver, SocketEvents};
use crate::{events::IoEvents, process::signal::Pollee};
pub struct StreamObserver(Pollee);
impl StreamObserver {
pub(super) fn new(pollee: Pollee) -> Self {
Self(pollee)
}
}
impl SocketEventObserver for StreamObserver {
fn on_events(&self, events: SocketEvents) {
let mut io_events = IoEvents::empty();
if events.contains(SocketEvents::CAN_RECV) {
io_events |= IoEvents::IN;
}
if events.contains(SocketEvents::CAN_SEND) {
io_events |= IoEvents::OUT;
}
if events.contains(SocketEvents::PEER_CLOSED) {
io_events |= IoEvents::IN | IoEvents::RDHUP;
}
if events.contains(SocketEvents::CLOSED) {
io_events |= IoEvents::IN | IoEvents::OUT | IoEvents::RDHUP | IoEvents::HUP;
}
self.0.notify(io_events);
}
}

View File

@ -28,6 +28,7 @@ use crate::{
/// ///
/// Then, [`Pollee::poll_with`] can allow you to register a [`Poller`] to wait for certain events, /// Then, [`Pollee::poll_with`] can allow you to register a [`Poller`] to wait for certain events,
/// or register a [`PollAdaptor`] to be notified when certain events occur. /// or register a [`PollAdaptor`] to be notified when certain events occur.
#[derive(Clone)]
pub struct Pollee { pub struct Pollee {
inner: Arc<PolleeInner>, inner: Arc<PolleeInner>,
} }

View File

@ -4,7 +4,7 @@ use super::SyscallReturn;
use crate::{ use crate::{
fs::{file_handle::FileLike, file_table::FdFlags}, fs::{file_handle::FileLike, file_table::FdFlags},
net::socket::{ net::socket::{
ip::{DatagramSocket, StreamSocket}, ip::{datagram::DatagramSocket, stream::StreamSocket},
unix::UnixStreamSocket, unix::UnixStreamSocket,
vsock::VsockStreamSocket, vsock::VsockStreamSocket,
}, },