mirror of
https://github.com/asterinas/asterinas.git
synced 2025-06-30 10:53:58 +00:00
Avoid O(n)
iteration when sending TCP packets
This commit is contained in:
committed by
Tate, Hongliang Tian
parent
58ad43b0a9
commit
a7e718e812
@ -1,9 +1,9 @@
|
||||
// SPDX-License-Identifier: MPL-2.0
|
||||
|
||||
use alloc::sync::Arc;
|
||||
use core::sync::atomic::{AtomicU64, AtomicU8, Ordering};
|
||||
use alloc::sync::{Arc, Weak};
|
||||
use core::sync::atomic::{AtomicU8, Ordering};
|
||||
|
||||
use smoltcp::{socket::PollAt, time::Instant, wire::IpEndpoint};
|
||||
use smoltcp::wire::IpEndpoint;
|
||||
use spin::once::Once;
|
||||
use takeable::Takeable;
|
||||
|
||||
@ -46,7 +46,6 @@ pub struct SocketBg<T: Inner<E>, E: Ext> {
|
||||
pub(super) inner: T,
|
||||
observer: Once<T::Observer>,
|
||||
events: AtomicU8,
|
||||
next_poll_at_ms: AtomicU64,
|
||||
}
|
||||
|
||||
impl<T: Inner<E>, E: Ext> Drop for Socket<T, E> {
|
||||
@ -58,13 +57,24 @@ impl<T: Inner<E>, E: Ext> Drop for Socket<T, E> {
|
||||
}
|
||||
|
||||
impl<T: Inner<E>, E: Ext> Socket<T, E> {
|
||||
pub(crate) fn new(bound: BoundPort<E>, inner: T) -> Self {
|
||||
pub(super) fn new(bound: BoundPort<E>, inner: T) -> Self {
|
||||
Self(Takeable::new(Arc::new(SocketBg {
|
||||
bound,
|
||||
inner,
|
||||
observer: Once::new(),
|
||||
events: AtomicU8::new(0),
|
||||
next_poll_at_ms: AtomicU64::new(u64::MAX),
|
||||
})))
|
||||
}
|
||||
|
||||
pub(super) fn new_cyclic<F>(bound: BoundPort<E>, inner_fn: F) -> Self
|
||||
where
|
||||
F: FnOnce(&Weak<SocketBg<T, E>>) -> T,
|
||||
{
|
||||
Self(Takeable::new(Arc::new_cyclic(|weak| SocketBg {
|
||||
bound,
|
||||
inner: inner_fn(weak),
|
||||
observer: Once::new(),
|
||||
events: AtomicU8::new(0),
|
||||
})))
|
||||
}
|
||||
|
||||
@ -119,10 +129,8 @@ impl<T: Inner<E>, E: Ext> SocketBg<T, E> {
|
||||
where
|
||||
T::Observer: Clone,
|
||||
{
|
||||
// This method can only be called to process network events, so we assume we are holding the
|
||||
// poll lock and no race conditions can occur.
|
||||
// There is no need to clear the events because the socket is dead.
|
||||
let events = self.events.load(Ordering::Relaxed);
|
||||
self.events.store(0, Ordering::Relaxed);
|
||||
|
||||
let observer = self.observer.get().cloned();
|
||||
drop(self);
|
||||
@ -141,41 +149,6 @@ impl<T: Inner<E>, E: Ext> SocketBg<T, E> {
|
||||
self.events
|
||||
.store(events | new_events.bits(), Ordering::Relaxed);
|
||||
}
|
||||
|
||||
/// Returns the next polling time.
|
||||
///
|
||||
/// Note: a zero means polling should be done now and a `u64::MAX` means no polling is required
|
||||
/// before new network or user events.
|
||||
pub(crate) fn next_poll_at_ms(&self) -> u64 {
|
||||
self.next_poll_at_ms.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
/// Updates the next polling time according to `poll_at`.
|
||||
///
|
||||
/// The update is typically needed after new network or user events have been handled, so this
|
||||
/// method also marks that there may be new events, so that the event observer provided by
|
||||
/// [`Socket::init_observer`] can be notified later.
|
||||
pub(super) fn update_next_poll_at_ms(&self, poll_at: PollAt) -> NeedIfacePoll {
|
||||
match poll_at {
|
||||
PollAt::Now => {
|
||||
self.next_poll_at_ms.store(0, Ordering::Relaxed);
|
||||
NeedIfacePoll::TRUE
|
||||
}
|
||||
PollAt::Time(instant) => {
|
||||
let old_total_millis = self.next_poll_at_ms.load(Ordering::Relaxed);
|
||||
let new_total_millis = instant.total_millis() as u64;
|
||||
|
||||
self.next_poll_at_ms
|
||||
.store(new_total_millis, Ordering::Relaxed);
|
||||
|
||||
NeedIfacePoll(new_total_millis < old_total_millis)
|
||||
}
|
||||
PollAt::Ingress => {
|
||||
self.next_poll_at_ms.store(u64::MAX, Ordering::Relaxed);
|
||||
NeedIfacePoll::FALSE
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Inner<E>, E: Ext> SocketBg<T, E> {
|
||||
@ -185,11 +158,4 @@ impl<T: Inner<E>, E: Ext> SocketBg<T, E> {
|
||||
pub(crate) fn can_process(&self, dst_port: u16) -> bool {
|
||||
self.bound.port() == dst_port
|
||||
}
|
||||
|
||||
/// Returns whether the socket _may_ generate an outgoing packet.
|
||||
///
|
||||
/// The check is intended to be lock-free and fast, but may have false positives.
|
||||
pub(crate) fn need_dispatch(&self, now: Instant) -> bool {
|
||||
now.total_millis() as u64 >= self.next_poll_at_ms.load(Ordering::Relaxed)
|
||||
}
|
||||
}
|
||||
|
@ -1,11 +1,13 @@
|
||||
// SPDX-License-Identifier: MPL-2.0
|
||||
|
||||
use alloc::{boxed::Box, sync::Arc};
|
||||
use alloc::{
|
||||
boxed::Box,
|
||||
sync::{Arc, Weak},
|
||||
};
|
||||
use core::ops::{Deref, DerefMut};
|
||||
|
||||
use ostd::sync::{LocalIrqDisabled, SpinLock, SpinLockGuard};
|
||||
use smoltcp::{
|
||||
iface::Context,
|
||||
socket::{tcp::State, PollAt},
|
||||
time::Duration,
|
||||
wire::{IpEndpoint, IpRepr, TcpControl, TcpRepr},
|
||||
@ -19,7 +21,7 @@ use crate::{
|
||||
define_boolean_value,
|
||||
errors::tcp::{ConnectError, RecvError, SendError},
|
||||
ext::Ext,
|
||||
iface::BoundPort,
|
||||
iface::{BoundPort, PollKey, PollableIfaceMut},
|
||||
socket::{
|
||||
event::SocketEvents,
|
||||
option::{RawTcpOption, RawTcpSetOption},
|
||||
@ -33,6 +35,7 @@ pub type TcpConnection<E> = Socket<TcpConnectionInner<E>, E>;
|
||||
/// States needed by [`TcpConnectionBg`].
|
||||
pub struct TcpConnectionInner<E: Ext> {
|
||||
socket: SpinLock<RawTcpSocketExt<E>, LocalIrqDisabled>,
|
||||
poll_key: PollKey,
|
||||
connection_key: ConnectionKey,
|
||||
}
|
||||
|
||||
@ -216,7 +219,11 @@ impl<E: Ext> RawTcpSocketExt<E> {
|
||||
}
|
||||
|
||||
impl<E: Ext> TcpConnectionInner<E> {
|
||||
pub(super) fn new(socket: Box<RawTcpSocket>, listener: Option<Arc<TcpListenerBg<E>>>) -> Self {
|
||||
pub(super) fn new(
|
||||
socket: Box<RawTcpSocket>,
|
||||
listener: Option<Arc<TcpListenerBg<E>>>,
|
||||
weak_self: &Weak<TcpConnectionBg<E>>,
|
||||
) -> Self {
|
||||
let connection_key = {
|
||||
// Since the socket is connected, the following unwrap can never fail
|
||||
let local_endpoint = socket.local_endpoint().unwrap();
|
||||
@ -224,6 +231,8 @@ impl<E: Ext> TcpConnectionInner<E> {
|
||||
ConnectionKey::from((local_endpoint, remote_endpoint))
|
||||
};
|
||||
|
||||
let poll_key = PollKey::new(Weak::as_ptr(weak_self).addr());
|
||||
|
||||
let socket_ext = RawTcpSocketExt {
|
||||
socket,
|
||||
listener,
|
||||
@ -234,6 +243,7 @@ impl<E: Ext> TcpConnectionInner<E> {
|
||||
|
||||
TcpConnectionInner {
|
||||
socket: SpinLock::new(socket_ext),
|
||||
poll_key,
|
||||
connection_key,
|
||||
}
|
||||
}
|
||||
@ -293,7 +303,7 @@ impl<E: Ext> TcpConnection<E> {
|
||||
};
|
||||
|
||||
let iface = bound.iface().clone();
|
||||
// We have to lock interface before locking interface
|
||||
// We have to lock `interface` before locking `sockets`
|
||||
// to avoid dead lock due to inconsistent lock orders.
|
||||
let mut interface = iface.common().interface();
|
||||
let mut sockets = iface.common().sockets();
|
||||
@ -309,18 +319,19 @@ impl<E: Ext> TcpConnection<E> {
|
||||
|
||||
option.apply(&mut socket);
|
||||
|
||||
if let Err(err) = socket.connect(interface.context(), remote_endpoint, bound.port()) {
|
||||
if let Err(err) = socket.connect(interface.context_mut(), remote_endpoint, bound.port())
|
||||
{
|
||||
return Err((bound, err.into()));
|
||||
}
|
||||
|
||||
socket
|
||||
};
|
||||
|
||||
let inner = TcpConnectionInner::new(socket, None);
|
||||
|
||||
let connection = Self::new(bound, inner);
|
||||
connection.0.update_next_poll_at_ms(PollAt::Now);
|
||||
let connection =
|
||||
Self::new_cyclic(bound, |weak| TcpConnectionInner::new(socket, None, weak));
|
||||
interface.update_next_poll_at_ms(&connection.0, PollAt::Now);
|
||||
connection.init_observer(observer);
|
||||
|
||||
let res = sockets.insert_connection(connection.inner().clone());
|
||||
debug_assert!(res.is_ok());
|
||||
|
||||
@ -378,9 +389,8 @@ impl<E: Ext> TcpConnection<E> {
|
||||
}
|
||||
let result = socket.send(f)?;
|
||||
|
||||
let need_poll = self
|
||||
.0
|
||||
.update_next_poll_at_ms(socket.poll_at(iface.context()));
|
||||
let poll_at = socket.poll_at(iface.context_mut());
|
||||
let need_poll = iface.update_next_poll_at_ms(&self.0, poll_at);
|
||||
|
||||
Ok((result, need_poll))
|
||||
}
|
||||
@ -408,9 +418,8 @@ impl<E: Ext> TcpConnection<E> {
|
||||
res => res,
|
||||
}?;
|
||||
|
||||
let need_poll = self
|
||||
.0
|
||||
.update_next_poll_at_ms(socket.poll_at(iface.context()));
|
||||
let poll_at = socket.poll_at(iface.context_mut());
|
||||
let need_poll = iface.update_next_poll_at_ms(&self.0, poll_at);
|
||||
|
||||
Ok((result, need_poll))
|
||||
}
|
||||
@ -433,6 +442,7 @@ impl<E: Ext> TcpConnection<E> {
|
||||
///
|
||||
/// Polling the iface is _always_ required after this method succeeds.
|
||||
pub fn shut_send(&self) -> bool {
|
||||
let mut iface = self.iface().common().interface();
|
||||
let mut socket = self.0.inner.lock();
|
||||
|
||||
if matches!(socket.state(), State::Closed | State::TimeWait) {
|
||||
@ -440,7 +450,9 @@ impl<E: Ext> TcpConnection<E> {
|
||||
}
|
||||
|
||||
socket.close();
|
||||
self.0.update_next_poll_at_ms(PollAt::Now);
|
||||
|
||||
let poll_at = socket.poll_at(iface.context_mut());
|
||||
iface.update_next_poll_at_ms(&self.0, poll_at);
|
||||
|
||||
true
|
||||
}
|
||||
@ -469,6 +481,7 @@ impl<E: Ext> TcpConnection<E> {
|
||||
/// Note that either this method or [`Self::reset`] must be called before dropping the TCP
|
||||
/// connection to avoid resource leakage.
|
||||
pub fn close(&self) {
|
||||
let mut iface = self.iface().common().interface();
|
||||
let mut socket = self.0.inner.lock();
|
||||
|
||||
socket.is_recv_shut = true;
|
||||
@ -479,7 +492,9 @@ impl<E: Ext> TcpConnection<E> {
|
||||
} else {
|
||||
socket.close();
|
||||
}
|
||||
self.0.update_next_poll_at_ms(PollAt::Now);
|
||||
|
||||
let poll_at = socket.poll_at(iface.context_mut());
|
||||
iface.update_next_poll_at_ms(&self.0, poll_at);
|
||||
}
|
||||
|
||||
/// Resets the connection.
|
||||
@ -489,10 +504,13 @@ impl<E: Ext> TcpConnection<E> {
|
||||
/// Note that either this method or [`Self::close`] must be called before dropping the TCP
|
||||
/// connection to avoid resource leakage.
|
||||
pub fn reset(&self) {
|
||||
let mut iface = self.iface().common().interface();
|
||||
let mut socket = self.0.inner.lock();
|
||||
|
||||
socket.abort();
|
||||
self.0.update_next_poll_at_ms(PollAt::Now);
|
||||
|
||||
let poll_at = socket.poll_at(iface.context_mut());
|
||||
iface.update_next_poll_at_ms(&self.0, poll_at);
|
||||
}
|
||||
|
||||
/// Calls `f` with an immutable reference to the associated [`RawTcpSocket`].
|
||||
@ -510,15 +528,13 @@ impl<E: Ext> TcpConnection<E> {
|
||||
|
||||
impl<E: Ext> RawTcpSetOption for TcpConnection<E> {
|
||||
fn set_keep_alive(&self, interval: Option<Duration>) -> NeedIfacePoll {
|
||||
let mut iface = self.iface().common().interface();
|
||||
let mut socket = self.0.inner.lock();
|
||||
|
||||
socket.set_keep_alive(interval);
|
||||
|
||||
if interval.is_some() {
|
||||
self.0.update_next_poll_at_ms(PollAt::Now);
|
||||
NeedIfacePoll::TRUE
|
||||
} else {
|
||||
NeedIfacePoll::FALSE
|
||||
}
|
||||
let poll_at = socket.poll_at(iface.context_mut());
|
||||
iface.update_next_poll_at_ms(&self.0, poll_at)
|
||||
}
|
||||
|
||||
fn set_nagle_enabled(&self, enabled: bool) {
|
||||
@ -528,6 +544,10 @@ impl<E: Ext> RawTcpSetOption for TcpConnection<E> {
|
||||
}
|
||||
|
||||
impl<E: Ext> TcpConnectionBg<E> {
|
||||
pub(crate) const fn poll_key(&self) -> &PollKey {
|
||||
&self.inner.poll_key
|
||||
}
|
||||
|
||||
pub(crate) const fn connection_key(&self) -> &ConnectionKey {
|
||||
&self.inner.connection_key
|
||||
}
|
||||
@ -544,13 +564,13 @@ impl<E: Ext> TcpConnectionBg<E> {
|
||||
/// Tries to process an incoming packet and returns whether the packet is processed.
|
||||
pub(crate) fn process(
|
||||
self: &Arc<Self>,
|
||||
cx: &mut Context,
|
||||
iface: &mut PollableIfaceMut<E>,
|
||||
ip_repr: &IpRepr,
|
||||
tcp_repr: &TcpRepr,
|
||||
) -> (TcpProcessResult, TcpConnBecameDead) {
|
||||
let mut socket = self.inner.lock();
|
||||
|
||||
if !socket.accepts(cx, ip_repr, tcp_repr) {
|
||||
if !socket.accepts(iface.context_mut(), ip_repr, tcp_repr) {
|
||||
return (TcpProcessResult::NotProcessed, TcpConnBecameDead::FALSE);
|
||||
}
|
||||
|
||||
@ -581,7 +601,7 @@ impl<E: Ext> TcpConnectionBg<E> {
|
||||
// to be queued.
|
||||
let mut events = SocketEvents::CAN_RECV | SocketEvents::CAN_SEND;
|
||||
|
||||
let result = match socket.process(cx, ip_repr, tcp_repr) {
|
||||
let result = match socket.process(iface.context_mut(), ip_repr, tcp_repr) {
|
||||
None => TcpProcessResult::Processed,
|
||||
Some((ip_repr, tcp_repr)) => TcpProcessResult::ProcessedWithReply(ip_repr, tcp_repr),
|
||||
};
|
||||
@ -591,7 +611,9 @@ impl<E: Ext> TcpConnectionBg<E> {
|
||||
events |= state_events;
|
||||
|
||||
self.add_events(events);
|
||||
self.update_next_poll_at_ms(socket.poll_at(cx));
|
||||
|
||||
let poll_at = socket.poll_at(iface.context_mut());
|
||||
iface.update_next_poll_at_ms(self, poll_at);
|
||||
|
||||
(result, became_dead)
|
||||
}
|
||||
@ -599,11 +621,11 @@ impl<E: Ext> TcpConnectionBg<E> {
|
||||
/// Tries to generate an outgoing packet and dispatches the generated packet.
|
||||
pub(crate) fn dispatch<D>(
|
||||
self: &Arc<Self>,
|
||||
cx: &mut Context,
|
||||
iface: &mut PollableIfaceMut<E>,
|
||||
dispatch: D,
|
||||
) -> (Option<(IpRepr, TcpRepr<'static>)>, TcpConnBecameDead)
|
||||
where
|
||||
D: FnOnce(&mut Context, &IpRepr, &TcpRepr) -> Option<(IpRepr, TcpRepr<'static>)>,
|
||||
D: FnOnce(PollableIfaceMut<E>, &IpRepr, &TcpRepr) -> Option<(IpRepr, TcpRepr<'static>)>,
|
||||
{
|
||||
let mut socket = self.inner.lock();
|
||||
|
||||
@ -613,9 +635,10 @@ impl<E: Ext> TcpConnectionBg<E> {
|
||||
let mut events = SocketEvents::empty();
|
||||
|
||||
let mut reply = None;
|
||||
let (cx, pending) = iface.inner_mut();
|
||||
socket
|
||||
.dispatch(cx, |cx, (ip_repr, tcp_repr)| {
|
||||
reply = dispatch(cx, &ip_repr, &tcp_repr);
|
||||
reply = dispatch(PollableIfaceMut::new(cx, pending), &ip_repr, &tcp_repr);
|
||||
Ok::<(), ()>(())
|
||||
})
|
||||
.unwrap();
|
||||
@ -623,12 +646,12 @@ impl<E: Ext> TcpConnectionBg<E> {
|
||||
// `dispatch` can return a packet in response to the generated packet. If the socket
|
||||
// accepts the packet, we can process it directly.
|
||||
while let Some((ref ip_repr, ref tcp_repr)) = reply {
|
||||
if !socket.accepts(cx, ip_repr, tcp_repr) {
|
||||
if !socket.accepts(iface.context_mut(), ip_repr, tcp_repr) {
|
||||
break;
|
||||
}
|
||||
is_rst |= tcp_repr.control == TcpControl::Rst;
|
||||
events |= SocketEvents::CAN_RECV | SocketEvents::CAN_SEND;
|
||||
reply = socket.process(cx, ip_repr, tcp_repr);
|
||||
reply = socket.process(iface.context_mut(), ip_repr, tcp_repr);
|
||||
}
|
||||
|
||||
let (state_events, became_dead) =
|
||||
@ -636,7 +659,9 @@ impl<E: Ext> TcpConnectionBg<E> {
|
||||
events |= state_events;
|
||||
|
||||
self.add_events(events);
|
||||
self.update_next_poll_at_ms(socket.poll_at(cx));
|
||||
|
||||
let poll_at = socket.poll_at(iface.context_mut());
|
||||
iface.update_next_poll_at_ms(self, poll_at);
|
||||
|
||||
(reply, became_dead)
|
||||
}
|
||||
|
@ -4,7 +4,6 @@ use alloc::{boxed::Box, collections::btree_map::BTreeMap, sync::Arc, vec::Vec};
|
||||
|
||||
use ostd::sync::{LocalIrqDisabled, SpinLock};
|
||||
use smoltcp::{
|
||||
iface::Context,
|
||||
socket::PollAt,
|
||||
time::Duration,
|
||||
wire::{IpEndpoint, IpRepr, TcpRepr},
|
||||
@ -17,7 +16,7 @@ use super::{
|
||||
use crate::{
|
||||
errors::tcp::ListenError,
|
||||
ext::Ext,
|
||||
iface::{BindPortConfig, BoundPort},
|
||||
iface::{BindPortConfig, BoundPort, PollableIfaceMut},
|
||||
socket::{
|
||||
option::{RawTcpOption, RawTcpSetOption},
|
||||
unbound::{new_tcp_socket, RawTcpSocket},
|
||||
@ -194,13 +193,16 @@ impl<E: Ext> TcpListenerBg<E> {
|
||||
/// Tries to process an incoming packet and returns whether the packet is processed.
|
||||
pub(crate) fn process(
|
||||
self: &Arc<Self>,
|
||||
cx: &mut Context,
|
||||
iface: &mut PollableIfaceMut<E>,
|
||||
ip_repr: &IpRepr,
|
||||
tcp_repr: &TcpRepr,
|
||||
) -> (TcpProcessResult, Option<Arc<TcpConnectionBg<E>>>) {
|
||||
let mut backlog = self.inner.backlog.lock();
|
||||
|
||||
if !backlog.socket.accepts(cx, ip_repr, tcp_repr) {
|
||||
if !backlog
|
||||
.socket
|
||||
.accepts(iface.context_mut(), ip_repr, tcp_repr)
|
||||
{
|
||||
return (TcpProcessResult::NotProcessed, None);
|
||||
}
|
||||
|
||||
@ -211,7 +213,10 @@ impl<E: Ext> TcpListenerBg<E> {
|
||||
return (TcpProcessResult::Processed, None);
|
||||
}
|
||||
|
||||
let result = match backlog.socket.process(cx, ip_repr, tcp_repr) {
|
||||
let result = match backlog
|
||||
.socket
|
||||
.process(iface.context_mut(), ip_repr, tcp_repr)
|
||||
{
|
||||
None => TcpProcessResult::Processed,
|
||||
Some((ip_repr, tcp_repr)) => TcpProcessResult::ProcessedWithReply(ip_repr, tcp_repr),
|
||||
};
|
||||
@ -227,23 +232,25 @@ impl<E: Ext> TcpListenerBg<E> {
|
||||
socket
|
||||
};
|
||||
|
||||
let inner = TcpConnectionInner::new(
|
||||
core::mem::replace(&mut backlog.socket, new_socket),
|
||||
Some(self.clone()),
|
||||
);
|
||||
let conn = TcpConnection::new(
|
||||
let conn = TcpConnection::new_cyclic(
|
||||
self.bound
|
||||
.iface()
|
||||
.bind(BindPortConfig::CanReuse(self.bound.port()))
|
||||
.unwrap(),
|
||||
inner,
|
||||
|weak| {
|
||||
TcpConnectionInner::new(
|
||||
core::mem::replace(&mut backlog.socket, new_socket),
|
||||
Some(self.clone()),
|
||||
weak,
|
||||
)
|
||||
},
|
||||
);
|
||||
let conn_bg = conn.inner().clone();
|
||||
|
||||
let old_conn = backlog.connecting.insert(*conn_bg.connection_key(), conn);
|
||||
debug_assert!(old_conn.is_none());
|
||||
|
||||
conn_bg.update_next_poll_at_ms(PollAt::Now);
|
||||
iface.update_next_poll_at_ms(&conn_bg, PollAt::Now);
|
||||
|
||||
(result, Some(conn_bg))
|
||||
}
|
||||
|
@ -1,11 +1,12 @@
|
||||
// SPDX-License-Identifier: MPL-2.0
|
||||
|
||||
use alloc::{boxed::Box, sync::Arc};
|
||||
use core::sync::atomic::{AtomicBool, Ordering};
|
||||
|
||||
use ostd::sync::{LocalIrqDisabled, SpinLock};
|
||||
use smoltcp::{
|
||||
iface::Context,
|
||||
socket::{udp::UdpMetadata, PollAt},
|
||||
socket::udp::UdpMetadata,
|
||||
wire::{IpRepr, UdpRepr},
|
||||
};
|
||||
|
||||
@ -20,13 +21,16 @@ use crate::{
|
||||
pub type UdpSocket<E> = Socket<UdpSocketInner, E>;
|
||||
|
||||
/// States needed by [`UdpSocketBg`].
|
||||
type UdpSocketInner = SpinLock<Box<RawUdpSocket>, LocalIrqDisabled>;
|
||||
pub struct UdpSocketInner {
|
||||
socket: SpinLock<Box<RawUdpSocket>, LocalIrqDisabled>,
|
||||
need_dispatch: AtomicBool,
|
||||
}
|
||||
|
||||
impl<E: Ext> Inner<E> for UdpSocketInner {
|
||||
type Observer = E::UdpEventObserver;
|
||||
|
||||
fn on_drop(this: &Arc<SocketBg<Self, E>>) {
|
||||
this.inner.lock().close();
|
||||
this.inner.socket.lock().close();
|
||||
|
||||
// A UDP socket can be removed immediately.
|
||||
this.bound.iface().common().remove_udp_socket(this);
|
||||
@ -44,7 +48,7 @@ impl<E: Ext> UdpSocketBg<E> {
|
||||
udp_repr: &UdpRepr,
|
||||
udp_payload: &[u8],
|
||||
) -> bool {
|
||||
let mut socket = self.inner.lock();
|
||||
let mut socket = self.inner.socket.lock();
|
||||
|
||||
if !socket.accepts(cx, ip_repr, udp_repr) {
|
||||
return false;
|
||||
@ -59,7 +63,6 @@ impl<E: Ext> UdpSocketBg<E> {
|
||||
);
|
||||
|
||||
self.add_events(SocketEvents::CAN_RECV);
|
||||
self.update_next_poll_at_ms(socket.poll_at(cx));
|
||||
|
||||
true
|
||||
}
|
||||
@ -69,7 +72,7 @@ impl<E: Ext> UdpSocketBg<E> {
|
||||
where
|
||||
D: FnOnce(&mut Context, &IpRepr, &UdpRepr, &[u8]),
|
||||
{
|
||||
let mut socket = self.inner.lock();
|
||||
let mut socket = self.inner.socket.lock();
|
||||
|
||||
socket
|
||||
.dispatch(cx, |cx, _meta, (ip_repr, udp_repr, udp_payload)| {
|
||||
@ -80,7 +83,17 @@ impl<E: Ext> UdpSocketBg<E> {
|
||||
|
||||
// For UDP, dequeuing a packet means that we can queue more packets.
|
||||
self.add_events(SocketEvents::CAN_SEND);
|
||||
self.update_next_poll_at_ms(socket.poll_at(cx));
|
||||
|
||||
self.inner
|
||||
.need_dispatch
|
||||
.store(socket.send_queue() > 0, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
/// Returns whether the socket _may_ generate an outgoing packet.
|
||||
///
|
||||
/// The check is intended to be lock-free and fast, but may have false positives.
|
||||
pub(crate) fn need_dispatch(&self) -> bool {
|
||||
self.inner.need_dispatch.load(Ordering::Relaxed)
|
||||
}
|
||||
}
|
||||
|
||||
@ -106,7 +119,10 @@ impl<E: Ext> UdpSocket<E> {
|
||||
socket
|
||||
};
|
||||
|
||||
let inner = UdpSocketInner::new(socket);
|
||||
let inner = UdpSocketInner {
|
||||
socket: SpinLock::new(socket),
|
||||
need_dispatch: AtomicBool::new(false),
|
||||
};
|
||||
|
||||
let socket = Self::new(bound, inner);
|
||||
socket.init_observer(observer);
|
||||
@ -130,7 +146,7 @@ impl<E: Ext> UdpSocket<E> {
|
||||
where
|
||||
F: FnOnce(&mut [u8]) -> R,
|
||||
{
|
||||
let mut socket = self.0.inner.lock();
|
||||
let mut socket = self.0.inner.socket.lock();
|
||||
|
||||
if size > socket.packet_send_capacity() {
|
||||
return Err(SendError::TooLarge);
|
||||
@ -141,7 +157,11 @@ impl<E: Ext> UdpSocket<E> {
|
||||
Err(err) => return Err(err.into()),
|
||||
};
|
||||
let result = f(buffer);
|
||||
self.0.update_next_poll_at_ms(PollAt::Now);
|
||||
|
||||
self.0
|
||||
.inner
|
||||
.need_dispatch
|
||||
.store(socket.send_queue() > 0, Ordering::Relaxed);
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
@ -153,7 +173,7 @@ impl<E: Ext> UdpSocket<E> {
|
||||
where
|
||||
F: FnOnce(&[u8], UdpMetadata) -> R,
|
||||
{
|
||||
let mut socket = self.0.inner.lock();
|
||||
let mut socket = self.0.inner.socket.lock();
|
||||
|
||||
let (data, meta) = socket.recv()?;
|
||||
let result = f(data, meta);
|
||||
@ -169,7 +189,7 @@ impl<E: Ext> UdpSocket<E> {
|
||||
where
|
||||
F: FnOnce(&RawUdpSocket) -> R,
|
||||
{
|
||||
let socket = self.0.inner.lock();
|
||||
let socket = self.0.inner.socket.lock();
|
||||
f(&socket)
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user