Make Pollee stateless

This commit is contained in:
Ruihan Li
2024-11-13 23:39:55 +08:00
committed by Tate, Hongliang Tian
parent 5450d0bd71
commit fab61f5f66
30 changed files with 514 additions and 430 deletions

View File

@ -220,13 +220,13 @@ impl<E> IfaceCommon<E> {
context.poll_egress(device, dispatch_phy);
tcp_sockets.iter().for_each(|socket| {
if socket.has_new_events() {
socket.on_iface_events();
if socket.has_events() {
socket.on_events();
}
});
udp_sockets.iter().for_each(|socket| {
if socket.has_new_events() {
socket.on_iface_events();
if socket.has_events() {
socket.on_events();
}
});

View File

@ -6,7 +6,7 @@ use alloc::{
};
use core::{
ops::{Deref, DerefMut},
sync::atomic::{AtomicBool, AtomicU64, Ordering},
sync::atomic::{AtomicBool, AtomicU64, AtomicU8, Ordering},
};
use ostd::sync::{LocalIrqDisabled, RwLock, SpinLock, SpinLockGuard};
@ -17,7 +17,10 @@ use smoltcp::{
wire::{IpAddress, IpEndpoint, IpRepr, TcpControl, TcpRepr, UdpRepr},
};
use super::{event::SocketEventObserver, RawTcpSocket, RawUdpSocket};
use super::{
event::{SocketEventObserver, SocketEvents},
RawTcpSocket, RawUdpSocket, TcpStateCheck,
};
use crate::iface::Iface;
pub struct BoundSocket<T: AnySocket, E>(Arc<BoundSocketInner<T, E>>);
@ -44,8 +47,8 @@ pub struct BoundSocketInner<T, E> {
port: u16,
socket: T,
observer: RwLock<Weak<dyn SocketEventObserver>>,
events: AtomicU8,
next_poll_at_ms: AtomicU64,
has_new_events: AtomicBool,
}
/// States needed by [`BoundTcpSocketInner`] but not [`BoundUdpSocketInner`].
@ -56,6 +59,7 @@ pub struct TcpSocket {
struct RawTcpSocketExt {
socket: Box<RawTcpSocket>,
has_connected: bool,
/// Whether the socket is in the background.
///
/// A background socket is a socket with its corresponding [`BoundSocket`] dropped. This means
@ -79,6 +83,22 @@ impl DerefMut for RawTcpSocketExt {
}
}
impl RawTcpSocketExt {
fn on_new_state(&mut self) -> SocketEvents {
if self.may_send() {
self.has_connected = true;
}
if self.is_peer_closed() {
SocketEvents::PEER_CLOSED
} else if self.is_closed() {
SocketEvents::CLOSED
} else {
SocketEvents::empty()
}
}
}
impl TcpSocket {
fn lock(&self) -> SpinLockGuard<RawTcpSocketExt, LocalIrqDisabled> {
self.socket.lock()
@ -123,6 +143,7 @@ impl AnySocket for TcpSocket {
fn new(socket: Box<Self::RawSocket>) -> Self {
let socket_ext = RawTcpSocketExt {
socket,
has_connected: false,
in_background: false,
};
@ -184,8 +205,8 @@ impl<T: AnySocket, E> BoundSocket<T, E> {
port,
socket: T::new(socket),
observer: RwLock::new(observer),
events: AtomicU8::new(0),
next_poll_at_ms: AtomicU64::new(u64::MAX),
has_new_events: AtomicBool::new(false),
}))
}
@ -204,7 +225,7 @@ impl<T: AnySocket, E> BoundSocket<T, E> {
pub fn set_observer(&self, new_observer: Weak<dyn SocketEventObserver>) {
*self.0.observer.write_irq_disabled() = new_observer;
self.0.on_iface_events();
self.0.on_events();
}
/// Returns the observer.
@ -229,6 +250,12 @@ impl<T: AnySocket, E> BoundSocket<T, E> {
}
}
pub enum ConnectState {
Connecting,
Connected,
Refused,
}
impl<E> BoundTcpSocket<E> {
/// Connects to a remote endpoint.
pub fn connect(
@ -240,11 +267,26 @@ impl<E> BoundTcpSocket<E> {
let mut socket = self.0.socket.lock();
let result = socket.connect(iface.context(), remote_endpoint, self.0.port);
socket.connect(iface.context(), remote_endpoint, self.0.port)?;
socket.has_connected = false;
self.0
.update_next_poll_at_ms(socket.poll_at(iface.context()));
result
Ok(())
}
/// Returns the state of the connecting procedure.
pub fn connect_state(&self) -> ConnectState {
let socket = self.0.socket.lock();
if socket.state() == State::SynSent || socket.state() == State::SynReceived {
ConnectState::Connecting
} else if socket.has_connected {
ConnectState::Connected
} else {
ConnectState::Refused
}
}
/// Listens at a specified endpoint.
@ -366,22 +408,33 @@ impl<E> BoundUdpSocket<E> {
}
impl<T, E> BoundSocketInner<T, E> {
pub(crate) fn has_new_events(&self) -> bool {
self.has_new_events.load(Ordering::Relaxed)
pub(crate) fn has_events(&self) -> bool {
self.events.load(Ordering::Relaxed) != 0
}
pub(crate) fn on_iface_events(&self) {
self.has_new_events.store(false, Ordering::Relaxed);
pub(crate) fn on_events(&self) {
// This method can only be called to process network events, so we assume we are holding the
// poll lock and no race conditions can occur.
let events = self.events.load(Ordering::Relaxed);
self.events.store(0, Ordering::Relaxed);
// We never hold the write lock in IRQ handlers, so we don't need to disable IRQs when we
// get the read lock.
let observer = Weak::upgrade(&*self.observer.read());
if let Some(inner) = observer {
inner.on_events();
inner.on_events(SocketEvents::from_bits_truncate(events));
}
}
fn add_events(&self, new_events: SocketEvents) {
// This method can only be called to add network events, so we assume we are holding the
// poll lock and no race conditions can occur.
let events = self.events.load(Ordering::Relaxed);
self.events
.store(events | new_events.bits(), Ordering::Relaxed);
}
/// Returns the next polling time.
///
/// Note: a zero means polling should be done now and a `u64::MAX` means no polling is required
@ -396,8 +449,6 @@ impl<T, E> BoundSocketInner<T, E> {
/// method also marks that there may be new events, so that the event observer provided by
/// [`BoundSocket::set_observer`] can be notified later.
fn update_next_poll_at_ms(&self, poll_at: PollAt) {
self.has_new_events.store(true, Ordering::Relaxed);
match poll_at {
PollAt::Now => self.next_poll_at_ms.store(0, Ordering::Relaxed),
PollAt::Time(instant) => self
@ -484,11 +535,21 @@ impl<E> BoundTcpSocketInner<E> {
return TcpProcessResult::NotProcessed;
}
let old_state = socket.state();
// For TCP, receiving an ACK packet can free up space in the queue, allowing more packets
// to be queued.
let mut events = SocketEvents::CAN_RECV | SocketEvents::CAN_SEND;
let result = match socket.process(cx, ip_repr, tcp_repr) {
None => TcpProcessResult::Processed,
Some((ip_repr, tcp_repr)) => TcpProcessResult::ProcessedWithReply(ip_repr, tcp_repr),
};
if socket.state() != old_state {
events |= socket.on_new_state();
}
self.add_events(events);
self.update_next_poll_at_ms(socket.poll_at(cx));
self.socket.update_dead(&socket);
@ -506,6 +567,9 @@ impl<E> BoundTcpSocketInner<E> {
{
let mut socket = self.socket.lock();
let old_state = socket.state();
let mut events = SocketEvents::empty();
let mut reply = None;
socket
.dispatch(cx, |cx, (ip_repr, tcp_repr)| {
@ -521,8 +585,14 @@ impl<E> BoundTcpSocketInner<E> {
break;
}
reply = socket.process(cx, ip_repr, tcp_repr);
events |= SocketEvents::CAN_RECV | SocketEvents::CAN_SEND;
}
if socket.state() != old_state {
events |= socket.on_new_state();
}
self.add_events(events);
self.update_next_poll_at_ms(socket.poll_at(cx));
self.socket.update_dead(&socket);
@ -552,6 +622,8 @@ impl<E> BoundUdpSocketInner<E> {
udp_repr,
udp_payload,
);
self.add_events(SocketEvents::CAN_RECV);
self.update_next_poll_at_ms(socket.poll_at(cx));
true
@ -570,6 +642,9 @@ impl<E> BoundUdpSocketInner<E> {
Ok::<(), ()>(())
})
.unwrap();
// For UDP, dequeuing a packet means that we can queue more packets.
self.add_events(SocketEvents::CAN_SEND);
self.update_next_poll_at_ms(socket.poll_at(cx));
}
}

View File

@ -3,9 +3,19 @@
/// An observer that will be invoked whenever events occur on the socket.
pub trait SocketEventObserver: Send + Sync {
/// Notifies that events occurred on the socket.
fn on_events(&self);
fn on_events(&self, events: SocketEvents);
}
impl SocketEventObserver for () {
fn on_events(&self) {}
fn on_events(&self, _events: SocketEvents) {}
}
bitflags::bitflags! {
/// Socket events caused by the _network_.
pub struct SocketEvents: u8 {
const CAN_RECV = 1;
const CAN_SEND = 2;
const PEER_CLOSED = 4;
const CLOSED = 8;
}
}

View File

@ -2,12 +2,13 @@
mod bound;
mod event;
mod state;
mod unbound;
pub use bound::{BoundTcpSocket, BoundUdpSocket};
pub use bound::{BoundTcpSocket, BoundUdpSocket, ConnectState};
pub(crate) use bound::{BoundTcpSocketInner, BoundUdpSocketInner, TcpProcessResult};
pub use event::SocketEventObserver;
pub use smoltcp::socket::tcp::State as TcpState;
pub use event::{SocketEventObserver, SocketEvents};
pub use state::TcpStateCheck;
pub use unbound::{
UnboundTcpSocket, UnboundUdpSocket, TCP_RECV_BUF_LEN, TCP_SEND_BUF_LEN, UDP_RECV_PAYLOAD_LEN,
UDP_SEND_PAYLOAD_LEN,

View File

@ -0,0 +1,31 @@
// SPDX-License-Identifier: MPL-2.0
use smoltcp::socket::tcp::State as TcpState;
use super::RawTcpSocket;
pub trait TcpStateCheck {
/// Checks if the peer socket has closed its sending side.
///
/// If the sending side of this socket is also closed, this method will return `false`.
/// In such cases, you should verify using [`is_closed`].
fn is_peer_closed(&self) -> bool;
/// Checks if the socket is fully closed.
///
/// This function returns `true` if both this socket and the peer have closed their sending sides.
///
/// This TCP state corresponds to the `Normal Close Sequence` and `Simultaneous Close Sequence`
/// as outlined in RFC 793 (https://datatracker.ietf.org/doc/html/rfc793#page-39).
fn is_closed(&self) -> bool;
}
impl TcpStateCheck for RawTcpSocket {
fn is_peer_closed(&self) -> bool {
self.state() == TcpState::CloseWait
}
fn is_closed(&self) -> bool {
!self.is_open() || self.state() == TcpState::Closing || self.state() == TcpState::LastAck
}
}