mirror of
https://github.com/asterinas/asterinas.git
synced 2025-06-08 21:06:48 +00:00
Avoid O(n)
iteration when sending TCP packets
This commit is contained in:
parent
58ad43b0a9
commit
a7e718e812
@ -16,6 +16,7 @@ use smoltcp::{
|
||||
|
||||
use super::{
|
||||
poll::{FnHelper, PollContext},
|
||||
poll_iface::PollableIface,
|
||||
port::BindPortConfig,
|
||||
time::get_network_timestamp,
|
||||
Iface,
|
||||
@ -29,7 +30,7 @@ use crate::{
|
||||
|
||||
pub struct IfaceCommon<E: Ext> {
|
||||
name: String,
|
||||
interface: SpinLock<smoltcp::iface::Interface, LocalIrqDisabled>,
|
||||
interface: SpinLock<PollableIface<E>, LocalIrqDisabled>,
|
||||
used_ports: SpinLock<BTreeMap<u16, usize>, LocalIrqDisabled>,
|
||||
sockets: SpinLock<SocketTable<E>, LocalIrqDisabled>,
|
||||
sched_poll: E::ScheduleNextPoll,
|
||||
@ -41,13 +42,11 @@ impl<E: Ext> IfaceCommon<E> {
|
||||
interface: smoltcp::iface::Interface,
|
||||
sched_poll: E::ScheduleNextPoll,
|
||||
) -> Self {
|
||||
let sockets = SocketTable::new();
|
||||
|
||||
Self {
|
||||
name,
|
||||
interface: SpinLock::new(interface),
|
||||
interface: SpinLock::new(PollableIface::new(interface)),
|
||||
used_ports: SpinLock::new(BTreeMap::new()),
|
||||
sockets: SpinLock::new(sockets),
|
||||
sockets: SpinLock::new(SocketTable::new()),
|
||||
sched_poll,
|
||||
}
|
||||
}
|
||||
@ -65,10 +64,10 @@ impl<E: Ext> IfaceCommon<E> {
|
||||
}
|
||||
}
|
||||
|
||||
// Lock order: interface -> sockets
|
||||
// Lock order: `interface` -> `sockets`
|
||||
impl<E: Ext> IfaceCommon<E> {
|
||||
/// Acquires the lock to the interface.
|
||||
pub(crate) fn interface(&self) -> SpinLockGuard<smoltcp::iface::Interface, LocalIrqDisabled> {
|
||||
pub(crate) fn interface(&self) -> SpinLockGuard<'_, PollableIface<E>, LocalIrqDisabled> {
|
||||
self.interface.lock()
|
||||
}
|
||||
|
||||
@ -181,51 +180,42 @@ impl<E: Ext> IfaceCommon<E> {
|
||||
Q: FnMut(&Packet, &mut Context, D::TxToken<'_>),
|
||||
{
|
||||
let mut interface = self.interface();
|
||||
interface.context().now = get_network_timestamp();
|
||||
interface.context_mut().now = get_network_timestamp();
|
||||
|
||||
let mut sockets = self.sockets.lock();
|
||||
let mut dead_tcp_conns = Vec::new();
|
||||
|
||||
loop {
|
||||
let mut new_tcp_conns = Vec::new();
|
||||
let mut new_tcp_conns = Vec::new();
|
||||
|
||||
let mut context = PollContext::new(
|
||||
interface.context(),
|
||||
&sockets,
|
||||
&mut new_tcp_conns,
|
||||
&mut dead_tcp_conns,
|
||||
);
|
||||
context.poll_ingress(device, &mut process_phy, &mut dispatch_phy);
|
||||
context.poll_egress(device, &mut dispatch_phy);
|
||||
let mut context = PollContext::new(
|
||||
interface.as_mut(),
|
||||
&sockets,
|
||||
&mut new_tcp_conns,
|
||||
&mut dead_tcp_conns,
|
||||
);
|
||||
context.poll_ingress(device, &mut process_phy, &mut dispatch_phy);
|
||||
context.poll_egress(device, &mut dispatch_phy);
|
||||
|
||||
// New packets sent by new connections are not handled. So if there are new
|
||||
// connections, try again.
|
||||
if new_tcp_conns.is_empty() {
|
||||
break;
|
||||
} else {
|
||||
new_tcp_conns.into_iter().for_each(|tcp_conn| {
|
||||
let res = sockets.insert_connection(tcp_conn);
|
||||
debug_assert!(res.is_ok());
|
||||
});
|
||||
}
|
||||
// Insert new connections and remove dead connections.
|
||||
for new_tcp_conn in new_tcp_conns.into_iter() {
|
||||
let res = sockets.insert_connection(new_tcp_conn);
|
||||
debug_assert!(res.is_ok());
|
||||
}
|
||||
|
||||
for dead_conn_key in dead_tcp_conns.into_iter() {
|
||||
sockets.remove_dead_tcp_connection(&dead_conn_key);
|
||||
}
|
||||
|
||||
// Notify all socket events.
|
||||
for socket in sockets.tcp_listener_iter() {
|
||||
if socket.has_events() {
|
||||
socket.on_events();
|
||||
}
|
||||
}
|
||||
|
||||
for socket in sockets.tcp_conn_iter() {
|
||||
if socket.has_events() {
|
||||
socket.on_events();
|
||||
}
|
||||
}
|
||||
|
||||
for socket in sockets.udp_socket_iter() {
|
||||
if socket.has_events() {
|
||||
socket.on_events();
|
||||
@ -234,10 +224,7 @@ impl<E: Ext> IfaceCommon<E> {
|
||||
|
||||
// Note that only TCP connections can have timers set, so as far as the time to poll is
|
||||
// concerned, we only need to consider TCP connections.
|
||||
sockets
|
||||
.tcp_conn_iter()
|
||||
.map(|socket| socket.next_poll_at_ms())
|
||||
.min()
|
||||
interface.next_poll_at_ms()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -5,6 +5,7 @@ mod common;
|
||||
mod iface;
|
||||
mod phy;
|
||||
mod poll;
|
||||
mod poll_iface;
|
||||
mod port;
|
||||
mod sched;
|
||||
mod time;
|
||||
@ -12,5 +13,6 @@ mod time;
|
||||
pub use common::BoundPort;
|
||||
pub use iface::Iface;
|
||||
pub use phy::{EtherIface, IpIface};
|
||||
pub(crate) use poll_iface::{PollKey, PollableIfaceMut};
|
||||
pub use port::BindPortConfig;
|
||||
pub use sched::ScheduleNextPoll;
|
||||
|
@ -15,6 +15,7 @@ use smoltcp::{
|
||||
},
|
||||
};
|
||||
|
||||
use super::poll_iface::PollableIfaceMut;
|
||||
use crate::{
|
||||
ext::Ext,
|
||||
socket::{TcpConnectionBg, TcpProcessResult},
|
||||
@ -22,7 +23,7 @@ use crate::{
|
||||
};
|
||||
|
||||
pub(super) struct PollContext<'a, E: Ext> {
|
||||
iface_cx: &'a mut Context,
|
||||
iface: PollableIfaceMut<'a, E>,
|
||||
sockets: &'a SocketTable<E>,
|
||||
new_tcp_conns: &'a mut Vec<Arc<TcpConnectionBg<E>>>,
|
||||
dead_tcp_conns: &'a mut Vec<ConnectionKey>,
|
||||
@ -30,13 +31,13 @@ pub(super) struct PollContext<'a, E: Ext> {
|
||||
|
||||
impl<'a, E: Ext> PollContext<'a, E> {
|
||||
pub(super) fn new(
|
||||
iface_cx: &'a mut Context,
|
||||
iface: PollableIfaceMut<'a, E>,
|
||||
sockets: &'a SocketTable<E>,
|
||||
new_tcp_conns: &'a mut Vec<Arc<TcpConnectionBg<E>>>,
|
||||
dead_tcp_conns: &'a mut Vec<ConnectionKey>,
|
||||
) -> Self {
|
||||
Self {
|
||||
iface_cx,
|
||||
iface,
|
||||
sockets,
|
||||
new_tcp_conns,
|
||||
dead_tcp_conns,
|
||||
@ -65,9 +66,10 @@ impl<E: Ext> PollContext<'_, E> {
|
||||
>,
|
||||
Q: FnMut(&Packet, &mut Context, D::TxToken<'_>),
|
||||
{
|
||||
while let Some((rx_token, tx_token)) = device.receive(self.iface_cx.now()) {
|
||||
while let Some((rx_token, tx_token)) = device.receive(self.iface.context().now()) {
|
||||
rx_token.consume(|data| {
|
||||
let Some((pkt, tx_token)) = process_phy(data, self.iface_cx, tx_token) else {
|
||||
let Some((pkt, tx_token)) = process_phy(data, self.iface.context_mut(), tx_token)
|
||||
else {
|
||||
return;
|
||||
};
|
||||
|
||||
@ -75,7 +77,7 @@ impl<E: Ext> PollContext<'_, E> {
|
||||
return;
|
||||
};
|
||||
|
||||
dispatch_phy(&reply, self.iface_cx, tx_token);
|
||||
dispatch_phy(&reply, self.iface.context_mut(), tx_token);
|
||||
});
|
||||
}
|
||||
}
|
||||
@ -85,7 +87,7 @@ impl<E: Ext> PollContext<'_, E> {
|
||||
pkt: Ipv4Packet<&'pkt [u8]>,
|
||||
) -> Option<Packet<'pkt>> {
|
||||
// Parse the IP header. Ignore the packet if the header is ill-formed.
|
||||
let repr = Ipv4Repr::parse(&pkt, &self.iface_cx.checksum_caps()).ok()?;
|
||||
let repr = Ipv4Repr::parse(&pkt, &self.iface.context().checksum_caps()).ok()?;
|
||||
|
||||
if !repr.dst_addr.is_broadcast() && !self.is_unicast_local(IpAddress::Ipv4(repr.dst_addr)) {
|
||||
return self.generate_icmp_unreachable(
|
||||
@ -95,17 +97,14 @@ impl<E: Ext> PollContext<'_, E> {
|
||||
);
|
||||
}
|
||||
|
||||
let checksum_caps = self.iface.context().checksum_caps();
|
||||
match repr.next_header {
|
||||
IpProtocol::Tcp => self.parse_and_process_tcp(
|
||||
&IpRepr::Ipv4(repr),
|
||||
pkt.payload(),
|
||||
&self.iface_cx.checksum_caps(),
|
||||
),
|
||||
IpProtocol::Udp => self.parse_and_process_udp(
|
||||
&IpRepr::Ipv4(repr),
|
||||
pkt.payload(),
|
||||
&self.iface_cx.checksum_caps(),
|
||||
),
|
||||
IpProtocol::Tcp => {
|
||||
self.parse_and_process_tcp(&IpRepr::Ipv4(repr), pkt.payload(), &checksum_caps)
|
||||
}
|
||||
IpProtocol::Udp => {
|
||||
self.parse_and_process_udp(&IpRepr::Ipv4(repr), pkt.payload(), &checksum_caps)
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
@ -164,7 +163,8 @@ impl<E: Ext> PollContext<'_, E> {
|
||||
if tcp_repr.control == TcpControl::Syn && tcp_repr.ack_number.is_none() {
|
||||
let listener_key = ListenerKey::new(ip_repr.dst_addr(), tcp_repr.dst_port);
|
||||
if let Some(listener) = self.sockets.lookup_listener(&listener_key) {
|
||||
let (processed, new_tcp_conn) = listener.process(self.iface_cx, ip_repr, tcp_repr);
|
||||
let (processed, new_tcp_conn) =
|
||||
listener.process(&mut self.iface, ip_repr, tcp_repr);
|
||||
|
||||
if let Some(tcp_conn) = new_tcp_conn {
|
||||
self.new_tcp_conns.push(tcp_conn);
|
||||
@ -197,7 +197,7 @@ impl<E: Ext> PollContext<'_, E> {
|
||||
|
||||
if let Some(connection) = connection {
|
||||
let (process_result, became_dead) =
|
||||
connection.process(self.iface_cx, ip_repr, tcp_repr);
|
||||
connection.process(&mut self.iface, ip_repr, tcp_repr);
|
||||
if *became_dead {
|
||||
self.dead_tcp_conns.push(*connection.connection_key());
|
||||
}
|
||||
@ -254,7 +254,7 @@ impl<E: Ext> PollContext<'_, E> {
|
||||
continue;
|
||||
}
|
||||
|
||||
processed |= socket.process(self.iface_cx, ip_repr, udp_repr, udp_payload);
|
||||
processed |= socket.process(self.iface.context_mut(), ip_repr, udp_repr, udp_payload);
|
||||
if processed && ip_repr.dst_addr().is_unicast() {
|
||||
break;
|
||||
}
|
||||
@ -295,7 +295,8 @@ impl<E: Ext> PollContext<'_, E> {
|
||||
Some(Packet::new_ipv4(
|
||||
Ipv4Repr {
|
||||
src_addr: self
|
||||
.iface_cx
|
||||
.iface
|
||||
.context()
|
||||
.ipv4_addr()
|
||||
.unwrap_or(Ipv4Address::UNSPECIFIED),
|
||||
dst_addr: ipv4_repr.src_addr,
|
||||
@ -314,7 +315,8 @@ impl<E: Ext> PollContext<'_, E> {
|
||||
fn is_unicast_local(&self, dst_addr: IpAddress) -> bool {
|
||||
match dst_addr {
|
||||
IpAddress::Ipv4(dst_addr) => self
|
||||
.iface_cx
|
||||
.iface
|
||||
.context()
|
||||
.ipv4_addr()
|
||||
.is_some_and(|addr| addr == dst_addr),
|
||||
}
|
||||
@ -327,7 +329,7 @@ impl<E: Ext> PollContext<'_, E> {
|
||||
D: Device + ?Sized,
|
||||
Q: FnMut(&Packet, &mut Context, D::TxToken<'_>),
|
||||
{
|
||||
while let Some(tx_token) = device.transmit(self.iface_cx.now()) {
|
||||
while let Some(tx_token) = device.transmit(self.iface.context().now()) {
|
||||
if !self.dispatch_ipv4(tx_token, dispatch_phy) {
|
||||
break;
|
||||
}
|
||||
@ -359,12 +361,10 @@ impl<E: Ext> PollContext<'_, E> {
|
||||
let mut did_something = false;
|
||||
let mut dead_conns = Vec::new();
|
||||
|
||||
// We cannot dispatch packets from `new_tcp_conns` because we cannot borrow an immutable
|
||||
// reference at this point. Instead, we will retry after the entire poll is complete.
|
||||
for socket in self.sockets.tcp_conn_iter() {
|
||||
if !socket.need_dispatch(self.iface_cx.now()) {
|
||||
continue;
|
||||
}
|
||||
loop {
|
||||
let Some(socket) = self.iface.pop_pending_tcp() else {
|
||||
break;
|
||||
};
|
||||
|
||||
// We set `did_something` even if no packets are actually generated. This is because a
|
||||
// timer can expire, but no packets are actually generated.
|
||||
@ -373,14 +373,14 @@ impl<E: Ext> PollContext<'_, E> {
|
||||
let mut deferred = None;
|
||||
|
||||
let (reply, became_dead) =
|
||||
TcpConnectionBg::dispatch(socket, self.iface_cx, |cx, ip_repr, tcp_repr| {
|
||||
TcpConnectionBg::dispatch(&socket, &mut self.iface, |iface, ip_repr, tcp_repr| {
|
||||
let mut this =
|
||||
PollContext::new(cx, self.sockets, self.new_tcp_conns, &mut dead_conns);
|
||||
PollContext::new(iface, self.sockets, self.new_tcp_conns, &mut dead_conns);
|
||||
|
||||
if !this.is_unicast_local(ip_repr.dst_addr()) {
|
||||
dispatch_phy(
|
||||
&Packet::new(ip_repr.clone(), IpPayload::Tcp(*tcp_repr)),
|
||||
this.iface_cx,
|
||||
this.iface.context_mut(),
|
||||
tx_token.take().unwrap(),
|
||||
);
|
||||
return None;
|
||||
@ -418,13 +418,13 @@ impl<E: Ext> PollContext<'_, E> {
|
||||
&ip_payload,
|
||||
&ChecksumCapabilities::ignored(),
|
||||
) {
|
||||
dispatch_phy(&reply, self.iface_cx, tx_token.take().unwrap());
|
||||
dispatch_phy(&reply, self.iface.context_mut(), tx_token.take().unwrap());
|
||||
}
|
||||
}
|
||||
(None, Some((ip_repr, tcp_repr))) if !self.is_unicast_local(ip_repr.dst_addr()) => {
|
||||
dispatch_phy(
|
||||
&Packet::new(ip_repr, IpPayload::Tcp(tcp_repr)),
|
||||
self.iface_cx,
|
||||
self.iface.context_mut(),
|
||||
tx_token.take().unwrap(),
|
||||
);
|
||||
}
|
||||
@ -434,7 +434,7 @@ impl<E: Ext> PollContext<'_, E> {
|
||||
{
|
||||
dispatch_phy(
|
||||
&Packet::new(new_ip_repr, IpPayload::Tcp(new_tcp_repr)),
|
||||
self.iface_cx,
|
||||
self.iface.context_mut(),
|
||||
tx_token.take().unwrap(),
|
||||
);
|
||||
}
|
||||
@ -462,7 +462,7 @@ impl<E: Ext> PollContext<'_, E> {
|
||||
let mut dead_conns = Vec::new();
|
||||
|
||||
for socket in self.sockets.udp_socket_iter() {
|
||||
if !socket.need_dispatch(self.iface_cx.now()) {
|
||||
if !socket.need_dispatch() {
|
||||
continue;
|
||||
}
|
||||
|
||||
@ -472,14 +472,16 @@ impl<E: Ext> PollContext<'_, E> {
|
||||
|
||||
let mut deferred = None;
|
||||
|
||||
socket.dispatch(self.iface_cx, |cx, ip_repr, udp_repr, udp_payload| {
|
||||
let (cx, pending) = self.iface.inner_mut();
|
||||
socket.dispatch(cx, |cx, ip_repr, udp_repr, udp_payload| {
|
||||
let iface = PollableIfaceMut::new(cx, pending);
|
||||
let mut this =
|
||||
PollContext::new(cx, self.sockets, self.new_tcp_conns, &mut dead_conns);
|
||||
PollContext::new(iface, self.sockets, self.new_tcp_conns, &mut dead_conns);
|
||||
|
||||
if ip_repr.dst_addr().is_broadcast() || !this.is_unicast_local(ip_repr.dst_addr()) {
|
||||
dispatch_phy(
|
||||
&Packet::new(ip_repr.clone(), IpPayload::Udp(*udp_repr, udp_payload)),
|
||||
this.iface_cx,
|
||||
this.iface.context_mut(),
|
||||
tx_token.take().unwrap(),
|
||||
);
|
||||
if !ip_repr.dst_addr().is_broadcast() {
|
||||
@ -516,7 +518,7 @@ impl<E: Ext> PollContext<'_, E> {
|
||||
&ip_payload,
|
||||
&ChecksumCapabilities::ignored(),
|
||||
) {
|
||||
dispatch_phy(&reply, self.iface_cx, tx_token.take().unwrap());
|
||||
dispatch_phy(&reply, self.iface.context_mut(), tx_token.take().unwrap());
|
||||
}
|
||||
}
|
||||
|
||||
|
282
kernel/libs/aster-bigtcp/src/iface/poll_iface.rs
Normal file
282
kernel/libs/aster-bigtcp/src/iface/poll_iface.rs
Normal file
@ -0,0 +1,282 @@
|
||||
// SPDX-License-Identifier: MPL-2.0
|
||||
|
||||
use alloc::{collections::btree_set::BTreeSet, sync::Arc};
|
||||
use core::{
|
||||
borrow::Borrow,
|
||||
sync::atomic::{AtomicU64, Ordering},
|
||||
};
|
||||
|
||||
use crate::{
|
||||
ext::Ext,
|
||||
socket::{NeedIfacePoll, TcpConnectionBg},
|
||||
};
|
||||
|
||||
/// An interface with auxiliary data that makes it pollable.
|
||||
///
|
||||
/// This is used, for example, when updating a socket's next poll time and finding a socket to
|
||||
/// poll.
|
||||
pub(crate) struct PollableIface<E: Ext> {
|
||||
interface: smoltcp::iface::Interface,
|
||||
pending_conns: PendingConnSet<E>,
|
||||
}
|
||||
|
||||
impl<E: Ext> PollableIface<E> {
|
||||
pub(super) fn new(interface: smoltcp::iface::Interface) -> Self {
|
||||
Self {
|
||||
interface,
|
||||
pending_conns: PendingConnSet::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn as_mut(&mut self) -> PollableIfaceMut<E> {
|
||||
PollableIfaceMut {
|
||||
context: self.interface.context(),
|
||||
pending_conns: &mut self.pending_conns,
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn ipv4_addr(&self) -> Option<smoltcp::wire::Ipv4Address> {
|
||||
self.interface.ipv4_addr()
|
||||
}
|
||||
|
||||
/// Returns the next poll time.
|
||||
pub(super) fn next_poll_at_ms(&self) -> Option<u64> {
|
||||
self.pending_conns.next_poll_at_ms()
|
||||
}
|
||||
}
|
||||
|
||||
impl<E: Ext> PollableIface<E> {
|
||||
/// Returns the `smoltcp` context for passing to the `smoltcp` APIs.
|
||||
pub(crate) fn context_mut(&mut self) -> &mut smoltcp::iface::Context {
|
||||
self.interface.context()
|
||||
}
|
||||
|
||||
/// Updates the next poll time of `socket` to `poll_at`.
|
||||
///
|
||||
/// This method (or [`PollableIfaceMut::update_next_poll_at_ms`]) should be called after network or
|
||||
/// user events that change the poll time occur.
|
||||
pub(crate) fn update_next_poll_at_ms(
|
||||
&mut self,
|
||||
socket: &Arc<TcpConnectionBg<E>>,
|
||||
poll_at: smoltcp::socket::PollAt,
|
||||
) -> NeedIfacePoll {
|
||||
self.pending_conns.update_next_poll_at_ms(socket, poll_at)
|
||||
}
|
||||
}
|
||||
|
||||
/// A mutable reference to a [`PollableIface`].
|
||||
///
|
||||
/// This type is reconstructed from mutable references to fields in [`PollableIface`], since the fields
|
||||
/// must be broken into individual fields during interface polling due to limitations of the
|
||||
/// [`smoltcp`] APIs.
|
||||
pub(crate) struct PollableIfaceMut<'a, E: Ext> {
|
||||
context: &'a mut smoltcp::iface::Context,
|
||||
pending_conns: &'a mut PendingConnSet<E>,
|
||||
}
|
||||
|
||||
// FIXME: We provide `new()` and `inner_mut()` as `pub(crate)` methods because it's necessary to
|
||||
// allow the Rust compiler to check the lifetime for separate fields. We should find better ways to
|
||||
// avoid these `pub(crate)` methods in the future.
|
||||
impl<'a, E: Ext> PollableIfaceMut<'a, E> {
|
||||
pub(crate) fn new(
|
||||
context: &'a mut smoltcp::iface::Context,
|
||||
pending_conns: &'a mut PendingConnSet<E>,
|
||||
) -> Self {
|
||||
Self {
|
||||
context,
|
||||
pending_conns,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn inner_mut(&mut self) -> (&mut smoltcp::iface::Context, &mut PendingConnSet<E>) {
|
||||
(self.context, self.pending_conns)
|
||||
}
|
||||
}
|
||||
|
||||
impl<E: Ext> PollableIfaceMut<'_, E> {
|
||||
pub(super) fn pop_pending_tcp(&mut self) -> Option<Arc<TcpConnectionBg<E>>> {
|
||||
let now = self.context.now.total_millis() as u64;
|
||||
self.pending_conns.pop_tcp_before_now(now)
|
||||
}
|
||||
}
|
||||
|
||||
impl<E: Ext> PollableIfaceMut<'_, E> {
|
||||
/// Returns an immutable reference to the `smoltcp` context.
|
||||
pub(crate) fn context(&self) -> &smoltcp::iface::Context {
|
||||
self.context
|
||||
}
|
||||
|
||||
/// Returns the `smoltcp` context for passing to the `smoltcp` APIs.
|
||||
pub(crate) fn context_mut(&mut self) -> &mut smoltcp::iface::Context {
|
||||
self.context
|
||||
}
|
||||
|
||||
/// Updates the next poll time of `socket` to `poll_at`.
|
||||
///
|
||||
/// This method (or [`PollableIface::update_next_poll_at_ms`]) should be called after network
|
||||
/// or user events that change the poll time occur.
|
||||
pub(crate) fn update_next_poll_at_ms(
|
||||
&mut self,
|
||||
socket: &Arc<TcpConnectionBg<E>>,
|
||||
poll_at: smoltcp::socket::PollAt,
|
||||
) -> NeedIfacePoll {
|
||||
self.pending_conns.update_next_poll_at_ms(socket, poll_at)
|
||||
}
|
||||
}
|
||||
|
||||
/// A key to sort sockets by their next poll time.
|
||||
pub(crate) struct PollKey {
|
||||
next_poll_at_ms: AtomicU64,
|
||||
id: usize,
|
||||
}
|
||||
|
||||
impl PartialEq for PollKey {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.next_poll_at_ms.load(Ordering::Relaxed)
|
||||
== other.next_poll_at_ms.load(Ordering::Relaxed)
|
||||
&& self.id == other.id
|
||||
}
|
||||
}
|
||||
impl Eq for PollKey {}
|
||||
impl PartialOrd for PollKey {
|
||||
fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
|
||||
Some(self.cmp(other))
|
||||
}
|
||||
}
|
||||
impl Ord for PollKey {
|
||||
fn cmp(&self, other: &Self) -> core::cmp::Ordering {
|
||||
self.next_poll_at_ms
|
||||
.load(Ordering::Relaxed)
|
||||
.cmp(&other.next_poll_at_ms.load(Ordering::Relaxed))
|
||||
.then_with(|| self.id.cmp(&other.id))
|
||||
}
|
||||
}
|
||||
|
||||
impl PollKey {
|
||||
/// A value indicating that an immediate poll is required.
|
||||
const IMMEDIATE_VAL: u64 = 0;
|
||||
/// A value indicating that no poll is required.
|
||||
const INACTIVE_VAL: u64 = u64::MAX;
|
||||
|
||||
/// Creates a new [`PollKey`].
|
||||
///
|
||||
/// `id` must be a unique identifier for the associated socket, as it will be used to locate
|
||||
/// the socket to update its next poll time. This is usually done using the address of the
|
||||
/// [`Arc`] socket (see [`Arc::as_ptr`]).
|
||||
///
|
||||
/// [`Arc`]: alloc::sync::Arc
|
||||
/// [`Arc::as_ptr`]: alloc::sync::Arc::as_ptr
|
||||
pub(crate) fn new(id: usize) -> Self {
|
||||
Self {
|
||||
next_poll_at_ms: AtomicU64::new(Self::INACTIVE_VAL),
|
||||
id,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Sockets to poll in the future, sorted by poll time.
|
||||
pub(crate) struct PendingConnSet<E: Ext>(BTreeSet<PendingTcpConn<E>>);
|
||||
|
||||
/// A TCP socket to poll in the future.
|
||||
///
|
||||
/// Note that currently only TCP sockets can set a timer to fire in the future, so a
|
||||
/// [`PendingConnSet`] contains only [`PendingTcpConn`]s.
|
||||
struct PendingTcpConn<E: Ext>(Arc<TcpConnectionBg<E>>);
|
||||
|
||||
impl<E: Ext> PartialEq for PendingTcpConn<E> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.0.poll_key() == other.0.poll_key()
|
||||
}
|
||||
}
|
||||
impl<E: Ext> Eq for PendingTcpConn<E> {}
|
||||
impl<E: Ext> PartialOrd for PendingTcpConn<E> {
|
||||
fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
|
||||
Some(self.cmp(other))
|
||||
}
|
||||
}
|
||||
impl<E: Ext> Ord for PendingTcpConn<E> {
|
||||
fn cmp(&self, other: &Self) -> core::cmp::Ordering {
|
||||
self.0.poll_key().cmp(other.0.poll_key())
|
||||
}
|
||||
}
|
||||
|
||||
impl<E: Ext> Borrow<PollKey> for PendingTcpConn<E> {
|
||||
fn borrow(&self) -> &PollKey {
|
||||
self.0.poll_key()
|
||||
}
|
||||
}
|
||||
|
||||
impl<E: Ext> PendingConnSet<E> {
|
||||
fn new() -> Self {
|
||||
Self(BTreeSet::new())
|
||||
}
|
||||
|
||||
fn update_next_poll_at_ms(
|
||||
&mut self,
|
||||
socket: &Arc<TcpConnectionBg<E>>,
|
||||
poll_at: smoltcp::socket::PollAt,
|
||||
) -> NeedIfacePoll {
|
||||
let key = socket.poll_key();
|
||||
let old_poll_at_ms = key.next_poll_at_ms.load(Ordering::Relaxed);
|
||||
|
||||
let new_poll_at_ms = match poll_at {
|
||||
smoltcp::socket::PollAt::Now => PollKey::IMMEDIATE_VAL,
|
||||
smoltcp::socket::PollAt::Time(instant) => instant.total_millis() as u64,
|
||||
smoltcp::socket::PollAt::Ingress => PollKey::INACTIVE_VAL,
|
||||
};
|
||||
|
||||
// Fast path: There is nothing to update.
|
||||
if old_poll_at_ms == new_poll_at_ms {
|
||||
return NeedIfacePoll::FALSE;
|
||||
}
|
||||
|
||||
// Remove the socket from the pending queue if it is in the queue.
|
||||
let owned_socket = if old_poll_at_ms != PollKey::INACTIVE_VAL {
|
||||
self.0.take(key).unwrap()
|
||||
} else {
|
||||
PendingTcpConn(socket.clone())
|
||||
};
|
||||
|
||||
// Update the poll time _after_ it is removed from the queue.
|
||||
key.next_poll_at_ms.store(new_poll_at_ms, Ordering::Relaxed);
|
||||
|
||||
// If no new poll is required, do not add the socket to the pending queue.
|
||||
if new_poll_at_ms == PollKey::INACTIVE_VAL {
|
||||
return NeedIfacePoll::FALSE;
|
||||
}
|
||||
|
||||
// Add the socket back to the queue.
|
||||
let inserted = self.0.insert(owned_socket);
|
||||
debug_assert!(inserted);
|
||||
|
||||
if new_poll_at_ms < old_poll_at_ms {
|
||||
NeedIfacePoll::TRUE
|
||||
} else {
|
||||
NeedIfacePoll::FALSE
|
||||
}
|
||||
}
|
||||
|
||||
fn pop_tcp_before_now(&mut self, now_at_ms: u64) -> Option<Arc<TcpConnectionBg<E>>> {
|
||||
if self.0.first().is_some_and(|first| {
|
||||
first.0.poll_key().next_poll_at_ms.load(Ordering::Relaxed) <= now_at_ms
|
||||
}) {
|
||||
self.0.pop_first().map(|first| {
|
||||
// Reset `next_poll_at_ms` since the socket is no longer in the queue.
|
||||
first
|
||||
.0
|
||||
.poll_key()
|
||||
.next_poll_at_ms
|
||||
.store(PollKey::INACTIVE_VAL, Ordering::Relaxed);
|
||||
first.0
|
||||
})
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn next_poll_at_ms(&self) -> Option<u64> {
|
||||
self.0
|
||||
.first()
|
||||
.map(|first| first.0.poll_key().next_poll_at_ms.load(Ordering::Relaxed))
|
||||
}
|
||||
}
|
@ -1,9 +1,9 @@
|
||||
// SPDX-License-Identifier: MPL-2.0
|
||||
|
||||
use alloc::sync::Arc;
|
||||
use core::sync::atomic::{AtomicU64, AtomicU8, Ordering};
|
||||
use alloc::sync::{Arc, Weak};
|
||||
use core::sync::atomic::{AtomicU8, Ordering};
|
||||
|
||||
use smoltcp::{socket::PollAt, time::Instant, wire::IpEndpoint};
|
||||
use smoltcp::wire::IpEndpoint;
|
||||
use spin::once::Once;
|
||||
use takeable::Takeable;
|
||||
|
||||
@ -46,7 +46,6 @@ pub struct SocketBg<T: Inner<E>, E: Ext> {
|
||||
pub(super) inner: T,
|
||||
observer: Once<T::Observer>,
|
||||
events: AtomicU8,
|
||||
next_poll_at_ms: AtomicU64,
|
||||
}
|
||||
|
||||
impl<T: Inner<E>, E: Ext> Drop for Socket<T, E> {
|
||||
@ -58,13 +57,24 @@ impl<T: Inner<E>, E: Ext> Drop for Socket<T, E> {
|
||||
}
|
||||
|
||||
impl<T: Inner<E>, E: Ext> Socket<T, E> {
|
||||
pub(crate) fn new(bound: BoundPort<E>, inner: T) -> Self {
|
||||
pub(super) fn new(bound: BoundPort<E>, inner: T) -> Self {
|
||||
Self(Takeable::new(Arc::new(SocketBg {
|
||||
bound,
|
||||
inner,
|
||||
observer: Once::new(),
|
||||
events: AtomicU8::new(0),
|
||||
next_poll_at_ms: AtomicU64::new(u64::MAX),
|
||||
})))
|
||||
}
|
||||
|
||||
pub(super) fn new_cyclic<F>(bound: BoundPort<E>, inner_fn: F) -> Self
|
||||
where
|
||||
F: FnOnce(&Weak<SocketBg<T, E>>) -> T,
|
||||
{
|
||||
Self(Takeable::new(Arc::new_cyclic(|weak| SocketBg {
|
||||
bound,
|
||||
inner: inner_fn(weak),
|
||||
observer: Once::new(),
|
||||
events: AtomicU8::new(0),
|
||||
})))
|
||||
}
|
||||
|
||||
@ -119,10 +129,8 @@ impl<T: Inner<E>, E: Ext> SocketBg<T, E> {
|
||||
where
|
||||
T::Observer: Clone,
|
||||
{
|
||||
// This method can only be called to process network events, so we assume we are holding the
|
||||
// poll lock and no race conditions can occur.
|
||||
// There is no need to clear the events because the socket is dead.
|
||||
let events = self.events.load(Ordering::Relaxed);
|
||||
self.events.store(0, Ordering::Relaxed);
|
||||
|
||||
let observer = self.observer.get().cloned();
|
||||
drop(self);
|
||||
@ -141,41 +149,6 @@ impl<T: Inner<E>, E: Ext> SocketBg<T, E> {
|
||||
self.events
|
||||
.store(events | new_events.bits(), Ordering::Relaxed);
|
||||
}
|
||||
|
||||
/// Returns the next polling time.
|
||||
///
|
||||
/// Note: a zero means polling should be done now and a `u64::MAX` means no polling is required
|
||||
/// before new network or user events.
|
||||
pub(crate) fn next_poll_at_ms(&self) -> u64 {
|
||||
self.next_poll_at_ms.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
/// Updates the next polling time according to `poll_at`.
|
||||
///
|
||||
/// The update is typically needed after new network or user events have been handled, so this
|
||||
/// method also marks that there may be new events, so that the event observer provided by
|
||||
/// [`Socket::init_observer`] can be notified later.
|
||||
pub(super) fn update_next_poll_at_ms(&self, poll_at: PollAt) -> NeedIfacePoll {
|
||||
match poll_at {
|
||||
PollAt::Now => {
|
||||
self.next_poll_at_ms.store(0, Ordering::Relaxed);
|
||||
NeedIfacePoll::TRUE
|
||||
}
|
||||
PollAt::Time(instant) => {
|
||||
let old_total_millis = self.next_poll_at_ms.load(Ordering::Relaxed);
|
||||
let new_total_millis = instant.total_millis() as u64;
|
||||
|
||||
self.next_poll_at_ms
|
||||
.store(new_total_millis, Ordering::Relaxed);
|
||||
|
||||
NeedIfacePoll(new_total_millis < old_total_millis)
|
||||
}
|
||||
PollAt::Ingress => {
|
||||
self.next_poll_at_ms.store(u64::MAX, Ordering::Relaxed);
|
||||
NeedIfacePoll::FALSE
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Inner<E>, E: Ext> SocketBg<T, E> {
|
||||
@ -185,11 +158,4 @@ impl<T: Inner<E>, E: Ext> SocketBg<T, E> {
|
||||
pub(crate) fn can_process(&self, dst_port: u16) -> bool {
|
||||
self.bound.port() == dst_port
|
||||
}
|
||||
|
||||
/// Returns whether the socket _may_ generate an outgoing packet.
|
||||
///
|
||||
/// The check is intended to be lock-free and fast, but may have false positives.
|
||||
pub(crate) fn need_dispatch(&self, now: Instant) -> bool {
|
||||
now.total_millis() as u64 >= self.next_poll_at_ms.load(Ordering::Relaxed)
|
||||
}
|
||||
}
|
||||
|
@ -1,11 +1,13 @@
|
||||
// SPDX-License-Identifier: MPL-2.0
|
||||
|
||||
use alloc::{boxed::Box, sync::Arc};
|
||||
use alloc::{
|
||||
boxed::Box,
|
||||
sync::{Arc, Weak},
|
||||
};
|
||||
use core::ops::{Deref, DerefMut};
|
||||
|
||||
use ostd::sync::{LocalIrqDisabled, SpinLock, SpinLockGuard};
|
||||
use smoltcp::{
|
||||
iface::Context,
|
||||
socket::{tcp::State, PollAt},
|
||||
time::Duration,
|
||||
wire::{IpEndpoint, IpRepr, TcpControl, TcpRepr},
|
||||
@ -19,7 +21,7 @@ use crate::{
|
||||
define_boolean_value,
|
||||
errors::tcp::{ConnectError, RecvError, SendError},
|
||||
ext::Ext,
|
||||
iface::BoundPort,
|
||||
iface::{BoundPort, PollKey, PollableIfaceMut},
|
||||
socket::{
|
||||
event::SocketEvents,
|
||||
option::{RawTcpOption, RawTcpSetOption},
|
||||
@ -33,6 +35,7 @@ pub type TcpConnection<E> = Socket<TcpConnectionInner<E>, E>;
|
||||
/// States needed by [`TcpConnectionBg`].
|
||||
pub struct TcpConnectionInner<E: Ext> {
|
||||
socket: SpinLock<RawTcpSocketExt<E>, LocalIrqDisabled>,
|
||||
poll_key: PollKey,
|
||||
connection_key: ConnectionKey,
|
||||
}
|
||||
|
||||
@ -216,7 +219,11 @@ impl<E: Ext> RawTcpSocketExt<E> {
|
||||
}
|
||||
|
||||
impl<E: Ext> TcpConnectionInner<E> {
|
||||
pub(super) fn new(socket: Box<RawTcpSocket>, listener: Option<Arc<TcpListenerBg<E>>>) -> Self {
|
||||
pub(super) fn new(
|
||||
socket: Box<RawTcpSocket>,
|
||||
listener: Option<Arc<TcpListenerBg<E>>>,
|
||||
weak_self: &Weak<TcpConnectionBg<E>>,
|
||||
) -> Self {
|
||||
let connection_key = {
|
||||
// Since the socket is connected, the following unwrap can never fail
|
||||
let local_endpoint = socket.local_endpoint().unwrap();
|
||||
@ -224,6 +231,8 @@ impl<E: Ext> TcpConnectionInner<E> {
|
||||
ConnectionKey::from((local_endpoint, remote_endpoint))
|
||||
};
|
||||
|
||||
let poll_key = PollKey::new(Weak::as_ptr(weak_self).addr());
|
||||
|
||||
let socket_ext = RawTcpSocketExt {
|
||||
socket,
|
||||
listener,
|
||||
@ -234,6 +243,7 @@ impl<E: Ext> TcpConnectionInner<E> {
|
||||
|
||||
TcpConnectionInner {
|
||||
socket: SpinLock::new(socket_ext),
|
||||
poll_key,
|
||||
connection_key,
|
||||
}
|
||||
}
|
||||
@ -293,7 +303,7 @@ impl<E: Ext> TcpConnection<E> {
|
||||
};
|
||||
|
||||
let iface = bound.iface().clone();
|
||||
// We have to lock interface before locking interface
|
||||
// We have to lock `interface` before locking `sockets`
|
||||
// to avoid dead lock due to inconsistent lock orders.
|
||||
let mut interface = iface.common().interface();
|
||||
let mut sockets = iface.common().sockets();
|
||||
@ -309,18 +319,19 @@ impl<E: Ext> TcpConnection<E> {
|
||||
|
||||
option.apply(&mut socket);
|
||||
|
||||
if let Err(err) = socket.connect(interface.context(), remote_endpoint, bound.port()) {
|
||||
if let Err(err) = socket.connect(interface.context_mut(), remote_endpoint, bound.port())
|
||||
{
|
||||
return Err((bound, err.into()));
|
||||
}
|
||||
|
||||
socket
|
||||
};
|
||||
|
||||
let inner = TcpConnectionInner::new(socket, None);
|
||||
|
||||
let connection = Self::new(bound, inner);
|
||||
connection.0.update_next_poll_at_ms(PollAt::Now);
|
||||
let connection =
|
||||
Self::new_cyclic(bound, |weak| TcpConnectionInner::new(socket, None, weak));
|
||||
interface.update_next_poll_at_ms(&connection.0, PollAt::Now);
|
||||
connection.init_observer(observer);
|
||||
|
||||
let res = sockets.insert_connection(connection.inner().clone());
|
||||
debug_assert!(res.is_ok());
|
||||
|
||||
@ -378,9 +389,8 @@ impl<E: Ext> TcpConnection<E> {
|
||||
}
|
||||
let result = socket.send(f)?;
|
||||
|
||||
let need_poll = self
|
||||
.0
|
||||
.update_next_poll_at_ms(socket.poll_at(iface.context()));
|
||||
let poll_at = socket.poll_at(iface.context_mut());
|
||||
let need_poll = iface.update_next_poll_at_ms(&self.0, poll_at);
|
||||
|
||||
Ok((result, need_poll))
|
||||
}
|
||||
@ -408,9 +418,8 @@ impl<E: Ext> TcpConnection<E> {
|
||||
res => res,
|
||||
}?;
|
||||
|
||||
let need_poll = self
|
||||
.0
|
||||
.update_next_poll_at_ms(socket.poll_at(iface.context()));
|
||||
let poll_at = socket.poll_at(iface.context_mut());
|
||||
let need_poll = iface.update_next_poll_at_ms(&self.0, poll_at);
|
||||
|
||||
Ok((result, need_poll))
|
||||
}
|
||||
@ -433,6 +442,7 @@ impl<E: Ext> TcpConnection<E> {
|
||||
///
|
||||
/// Polling the iface is _always_ required after this method succeeds.
|
||||
pub fn shut_send(&self) -> bool {
|
||||
let mut iface = self.iface().common().interface();
|
||||
let mut socket = self.0.inner.lock();
|
||||
|
||||
if matches!(socket.state(), State::Closed | State::TimeWait) {
|
||||
@ -440,7 +450,9 @@ impl<E: Ext> TcpConnection<E> {
|
||||
}
|
||||
|
||||
socket.close();
|
||||
self.0.update_next_poll_at_ms(PollAt::Now);
|
||||
|
||||
let poll_at = socket.poll_at(iface.context_mut());
|
||||
iface.update_next_poll_at_ms(&self.0, poll_at);
|
||||
|
||||
true
|
||||
}
|
||||
@ -469,6 +481,7 @@ impl<E: Ext> TcpConnection<E> {
|
||||
/// Note that either this method or [`Self::reset`] must be called before dropping the TCP
|
||||
/// connection to avoid resource leakage.
|
||||
pub fn close(&self) {
|
||||
let mut iface = self.iface().common().interface();
|
||||
let mut socket = self.0.inner.lock();
|
||||
|
||||
socket.is_recv_shut = true;
|
||||
@ -479,7 +492,9 @@ impl<E: Ext> TcpConnection<E> {
|
||||
} else {
|
||||
socket.close();
|
||||
}
|
||||
self.0.update_next_poll_at_ms(PollAt::Now);
|
||||
|
||||
let poll_at = socket.poll_at(iface.context_mut());
|
||||
iface.update_next_poll_at_ms(&self.0, poll_at);
|
||||
}
|
||||
|
||||
/// Resets the connection.
|
||||
@ -489,10 +504,13 @@ impl<E: Ext> TcpConnection<E> {
|
||||
/// Note that either this method or [`Self::close`] must be called before dropping the TCP
|
||||
/// connection to avoid resource leakage.
|
||||
pub fn reset(&self) {
|
||||
let mut iface = self.iface().common().interface();
|
||||
let mut socket = self.0.inner.lock();
|
||||
|
||||
socket.abort();
|
||||
self.0.update_next_poll_at_ms(PollAt::Now);
|
||||
|
||||
let poll_at = socket.poll_at(iface.context_mut());
|
||||
iface.update_next_poll_at_ms(&self.0, poll_at);
|
||||
}
|
||||
|
||||
/// Calls `f` with an immutable reference to the associated [`RawTcpSocket`].
|
||||
@ -510,15 +528,13 @@ impl<E: Ext> TcpConnection<E> {
|
||||
|
||||
impl<E: Ext> RawTcpSetOption for TcpConnection<E> {
|
||||
fn set_keep_alive(&self, interval: Option<Duration>) -> NeedIfacePoll {
|
||||
let mut iface = self.iface().common().interface();
|
||||
let mut socket = self.0.inner.lock();
|
||||
|
||||
socket.set_keep_alive(interval);
|
||||
|
||||
if interval.is_some() {
|
||||
self.0.update_next_poll_at_ms(PollAt::Now);
|
||||
NeedIfacePoll::TRUE
|
||||
} else {
|
||||
NeedIfacePoll::FALSE
|
||||
}
|
||||
let poll_at = socket.poll_at(iface.context_mut());
|
||||
iface.update_next_poll_at_ms(&self.0, poll_at)
|
||||
}
|
||||
|
||||
fn set_nagle_enabled(&self, enabled: bool) {
|
||||
@ -528,6 +544,10 @@ impl<E: Ext> RawTcpSetOption for TcpConnection<E> {
|
||||
}
|
||||
|
||||
impl<E: Ext> TcpConnectionBg<E> {
|
||||
pub(crate) const fn poll_key(&self) -> &PollKey {
|
||||
&self.inner.poll_key
|
||||
}
|
||||
|
||||
pub(crate) const fn connection_key(&self) -> &ConnectionKey {
|
||||
&self.inner.connection_key
|
||||
}
|
||||
@ -544,13 +564,13 @@ impl<E: Ext> TcpConnectionBg<E> {
|
||||
/// Tries to process an incoming packet and returns whether the packet is processed.
|
||||
pub(crate) fn process(
|
||||
self: &Arc<Self>,
|
||||
cx: &mut Context,
|
||||
iface: &mut PollableIfaceMut<E>,
|
||||
ip_repr: &IpRepr,
|
||||
tcp_repr: &TcpRepr,
|
||||
) -> (TcpProcessResult, TcpConnBecameDead) {
|
||||
let mut socket = self.inner.lock();
|
||||
|
||||
if !socket.accepts(cx, ip_repr, tcp_repr) {
|
||||
if !socket.accepts(iface.context_mut(), ip_repr, tcp_repr) {
|
||||
return (TcpProcessResult::NotProcessed, TcpConnBecameDead::FALSE);
|
||||
}
|
||||
|
||||
@ -581,7 +601,7 @@ impl<E: Ext> TcpConnectionBg<E> {
|
||||
// to be queued.
|
||||
let mut events = SocketEvents::CAN_RECV | SocketEvents::CAN_SEND;
|
||||
|
||||
let result = match socket.process(cx, ip_repr, tcp_repr) {
|
||||
let result = match socket.process(iface.context_mut(), ip_repr, tcp_repr) {
|
||||
None => TcpProcessResult::Processed,
|
||||
Some((ip_repr, tcp_repr)) => TcpProcessResult::ProcessedWithReply(ip_repr, tcp_repr),
|
||||
};
|
||||
@ -591,7 +611,9 @@ impl<E: Ext> TcpConnectionBg<E> {
|
||||
events |= state_events;
|
||||
|
||||
self.add_events(events);
|
||||
self.update_next_poll_at_ms(socket.poll_at(cx));
|
||||
|
||||
let poll_at = socket.poll_at(iface.context_mut());
|
||||
iface.update_next_poll_at_ms(self, poll_at);
|
||||
|
||||
(result, became_dead)
|
||||
}
|
||||
@ -599,11 +621,11 @@ impl<E: Ext> TcpConnectionBg<E> {
|
||||
/// Tries to generate an outgoing packet and dispatches the generated packet.
|
||||
pub(crate) fn dispatch<D>(
|
||||
self: &Arc<Self>,
|
||||
cx: &mut Context,
|
||||
iface: &mut PollableIfaceMut<E>,
|
||||
dispatch: D,
|
||||
) -> (Option<(IpRepr, TcpRepr<'static>)>, TcpConnBecameDead)
|
||||
where
|
||||
D: FnOnce(&mut Context, &IpRepr, &TcpRepr) -> Option<(IpRepr, TcpRepr<'static>)>,
|
||||
D: FnOnce(PollableIfaceMut<E>, &IpRepr, &TcpRepr) -> Option<(IpRepr, TcpRepr<'static>)>,
|
||||
{
|
||||
let mut socket = self.inner.lock();
|
||||
|
||||
@ -613,9 +635,10 @@ impl<E: Ext> TcpConnectionBg<E> {
|
||||
let mut events = SocketEvents::empty();
|
||||
|
||||
let mut reply = None;
|
||||
let (cx, pending) = iface.inner_mut();
|
||||
socket
|
||||
.dispatch(cx, |cx, (ip_repr, tcp_repr)| {
|
||||
reply = dispatch(cx, &ip_repr, &tcp_repr);
|
||||
reply = dispatch(PollableIfaceMut::new(cx, pending), &ip_repr, &tcp_repr);
|
||||
Ok::<(), ()>(())
|
||||
})
|
||||
.unwrap();
|
||||
@ -623,12 +646,12 @@ impl<E: Ext> TcpConnectionBg<E> {
|
||||
// `dispatch` can return a packet in response to the generated packet. If the socket
|
||||
// accepts the packet, we can process it directly.
|
||||
while let Some((ref ip_repr, ref tcp_repr)) = reply {
|
||||
if !socket.accepts(cx, ip_repr, tcp_repr) {
|
||||
if !socket.accepts(iface.context_mut(), ip_repr, tcp_repr) {
|
||||
break;
|
||||
}
|
||||
is_rst |= tcp_repr.control == TcpControl::Rst;
|
||||
events |= SocketEvents::CAN_RECV | SocketEvents::CAN_SEND;
|
||||
reply = socket.process(cx, ip_repr, tcp_repr);
|
||||
reply = socket.process(iface.context_mut(), ip_repr, tcp_repr);
|
||||
}
|
||||
|
||||
let (state_events, became_dead) =
|
||||
@ -636,7 +659,9 @@ impl<E: Ext> TcpConnectionBg<E> {
|
||||
events |= state_events;
|
||||
|
||||
self.add_events(events);
|
||||
self.update_next_poll_at_ms(socket.poll_at(cx));
|
||||
|
||||
let poll_at = socket.poll_at(iface.context_mut());
|
||||
iface.update_next_poll_at_ms(self, poll_at);
|
||||
|
||||
(reply, became_dead)
|
||||
}
|
||||
|
@ -4,7 +4,6 @@ use alloc::{boxed::Box, collections::btree_map::BTreeMap, sync::Arc, vec::Vec};
|
||||
|
||||
use ostd::sync::{LocalIrqDisabled, SpinLock};
|
||||
use smoltcp::{
|
||||
iface::Context,
|
||||
socket::PollAt,
|
||||
time::Duration,
|
||||
wire::{IpEndpoint, IpRepr, TcpRepr},
|
||||
@ -17,7 +16,7 @@ use super::{
|
||||
use crate::{
|
||||
errors::tcp::ListenError,
|
||||
ext::Ext,
|
||||
iface::{BindPortConfig, BoundPort},
|
||||
iface::{BindPortConfig, BoundPort, PollableIfaceMut},
|
||||
socket::{
|
||||
option::{RawTcpOption, RawTcpSetOption},
|
||||
unbound::{new_tcp_socket, RawTcpSocket},
|
||||
@ -194,13 +193,16 @@ impl<E: Ext> TcpListenerBg<E> {
|
||||
/// Tries to process an incoming packet and returns whether the packet is processed.
|
||||
pub(crate) fn process(
|
||||
self: &Arc<Self>,
|
||||
cx: &mut Context,
|
||||
iface: &mut PollableIfaceMut<E>,
|
||||
ip_repr: &IpRepr,
|
||||
tcp_repr: &TcpRepr,
|
||||
) -> (TcpProcessResult, Option<Arc<TcpConnectionBg<E>>>) {
|
||||
let mut backlog = self.inner.backlog.lock();
|
||||
|
||||
if !backlog.socket.accepts(cx, ip_repr, tcp_repr) {
|
||||
if !backlog
|
||||
.socket
|
||||
.accepts(iface.context_mut(), ip_repr, tcp_repr)
|
||||
{
|
||||
return (TcpProcessResult::NotProcessed, None);
|
||||
}
|
||||
|
||||
@ -211,7 +213,10 @@ impl<E: Ext> TcpListenerBg<E> {
|
||||
return (TcpProcessResult::Processed, None);
|
||||
}
|
||||
|
||||
let result = match backlog.socket.process(cx, ip_repr, tcp_repr) {
|
||||
let result = match backlog
|
||||
.socket
|
||||
.process(iface.context_mut(), ip_repr, tcp_repr)
|
||||
{
|
||||
None => TcpProcessResult::Processed,
|
||||
Some((ip_repr, tcp_repr)) => TcpProcessResult::ProcessedWithReply(ip_repr, tcp_repr),
|
||||
};
|
||||
@ -227,23 +232,25 @@ impl<E: Ext> TcpListenerBg<E> {
|
||||
socket
|
||||
};
|
||||
|
||||
let inner = TcpConnectionInner::new(
|
||||
core::mem::replace(&mut backlog.socket, new_socket),
|
||||
Some(self.clone()),
|
||||
);
|
||||
let conn = TcpConnection::new(
|
||||
let conn = TcpConnection::new_cyclic(
|
||||
self.bound
|
||||
.iface()
|
||||
.bind(BindPortConfig::CanReuse(self.bound.port()))
|
||||
.unwrap(),
|
||||
inner,
|
||||
|weak| {
|
||||
TcpConnectionInner::new(
|
||||
core::mem::replace(&mut backlog.socket, new_socket),
|
||||
Some(self.clone()),
|
||||
weak,
|
||||
)
|
||||
},
|
||||
);
|
||||
let conn_bg = conn.inner().clone();
|
||||
|
||||
let old_conn = backlog.connecting.insert(*conn_bg.connection_key(), conn);
|
||||
debug_assert!(old_conn.is_none());
|
||||
|
||||
conn_bg.update_next_poll_at_ms(PollAt::Now);
|
||||
iface.update_next_poll_at_ms(&conn_bg, PollAt::Now);
|
||||
|
||||
(result, Some(conn_bg))
|
||||
}
|
||||
|
@ -1,11 +1,12 @@
|
||||
// SPDX-License-Identifier: MPL-2.0
|
||||
|
||||
use alloc::{boxed::Box, sync::Arc};
|
||||
use core::sync::atomic::{AtomicBool, Ordering};
|
||||
|
||||
use ostd::sync::{LocalIrqDisabled, SpinLock};
|
||||
use smoltcp::{
|
||||
iface::Context,
|
||||
socket::{udp::UdpMetadata, PollAt},
|
||||
socket::udp::UdpMetadata,
|
||||
wire::{IpRepr, UdpRepr},
|
||||
};
|
||||
|
||||
@ -20,13 +21,16 @@ use crate::{
|
||||
pub type UdpSocket<E> = Socket<UdpSocketInner, E>;
|
||||
|
||||
/// States needed by [`UdpSocketBg`].
|
||||
type UdpSocketInner = SpinLock<Box<RawUdpSocket>, LocalIrqDisabled>;
|
||||
pub struct UdpSocketInner {
|
||||
socket: SpinLock<Box<RawUdpSocket>, LocalIrqDisabled>,
|
||||
need_dispatch: AtomicBool,
|
||||
}
|
||||
|
||||
impl<E: Ext> Inner<E> for UdpSocketInner {
|
||||
type Observer = E::UdpEventObserver;
|
||||
|
||||
fn on_drop(this: &Arc<SocketBg<Self, E>>) {
|
||||
this.inner.lock().close();
|
||||
this.inner.socket.lock().close();
|
||||
|
||||
// A UDP socket can be removed immediately.
|
||||
this.bound.iface().common().remove_udp_socket(this);
|
||||
@ -44,7 +48,7 @@ impl<E: Ext> UdpSocketBg<E> {
|
||||
udp_repr: &UdpRepr,
|
||||
udp_payload: &[u8],
|
||||
) -> bool {
|
||||
let mut socket = self.inner.lock();
|
||||
let mut socket = self.inner.socket.lock();
|
||||
|
||||
if !socket.accepts(cx, ip_repr, udp_repr) {
|
||||
return false;
|
||||
@ -59,7 +63,6 @@ impl<E: Ext> UdpSocketBg<E> {
|
||||
);
|
||||
|
||||
self.add_events(SocketEvents::CAN_RECV);
|
||||
self.update_next_poll_at_ms(socket.poll_at(cx));
|
||||
|
||||
true
|
||||
}
|
||||
@ -69,7 +72,7 @@ impl<E: Ext> UdpSocketBg<E> {
|
||||
where
|
||||
D: FnOnce(&mut Context, &IpRepr, &UdpRepr, &[u8]),
|
||||
{
|
||||
let mut socket = self.inner.lock();
|
||||
let mut socket = self.inner.socket.lock();
|
||||
|
||||
socket
|
||||
.dispatch(cx, |cx, _meta, (ip_repr, udp_repr, udp_payload)| {
|
||||
@ -80,7 +83,17 @@ impl<E: Ext> UdpSocketBg<E> {
|
||||
|
||||
// For UDP, dequeuing a packet means that we can queue more packets.
|
||||
self.add_events(SocketEvents::CAN_SEND);
|
||||
self.update_next_poll_at_ms(socket.poll_at(cx));
|
||||
|
||||
self.inner
|
||||
.need_dispatch
|
||||
.store(socket.send_queue() > 0, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
/// Returns whether the socket _may_ generate an outgoing packet.
|
||||
///
|
||||
/// The check is intended to be lock-free and fast, but may have false positives.
|
||||
pub(crate) fn need_dispatch(&self) -> bool {
|
||||
self.inner.need_dispatch.load(Ordering::Relaxed)
|
||||
}
|
||||
}
|
||||
|
||||
@ -106,7 +119,10 @@ impl<E: Ext> UdpSocket<E> {
|
||||
socket
|
||||
};
|
||||
|
||||
let inner = UdpSocketInner::new(socket);
|
||||
let inner = UdpSocketInner {
|
||||
socket: SpinLock::new(socket),
|
||||
need_dispatch: AtomicBool::new(false),
|
||||
};
|
||||
|
||||
let socket = Self::new(bound, inner);
|
||||
socket.init_observer(observer);
|
||||
@ -130,7 +146,7 @@ impl<E: Ext> UdpSocket<E> {
|
||||
where
|
||||
F: FnOnce(&mut [u8]) -> R,
|
||||
{
|
||||
let mut socket = self.0.inner.lock();
|
||||
let mut socket = self.0.inner.socket.lock();
|
||||
|
||||
if size > socket.packet_send_capacity() {
|
||||
return Err(SendError::TooLarge);
|
||||
@ -141,7 +157,11 @@ impl<E: Ext> UdpSocket<E> {
|
||||
Err(err) => return Err(err.into()),
|
||||
};
|
||||
let result = f(buffer);
|
||||
self.0.update_next_poll_at_ms(PollAt::Now);
|
||||
|
||||
self.0
|
||||
.inner
|
||||
.need_dispatch
|
||||
.store(socket.send_queue() > 0, Ordering::Relaxed);
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
@ -153,7 +173,7 @@ impl<E: Ext> UdpSocket<E> {
|
||||
where
|
||||
F: FnOnce(&[u8], UdpMetadata) -> R,
|
||||
{
|
||||
let mut socket = self.0.inner.lock();
|
||||
let mut socket = self.0.inner.socket.lock();
|
||||
|
||||
let (data, meta) = socket.recv()?;
|
||||
let result = f(data, meta);
|
||||
@ -169,7 +189,7 @@ impl<E: Ext> UdpSocket<E> {
|
||||
where
|
||||
F: FnOnce(&RawUdpSocket) -> R,
|
||||
{
|
||||
let socket = self.0.inner.lock();
|
||||
let socket = self.0.inner.socket.lock();
|
||||
f(&socket)
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user