mirror of
https://github.com/asterinas/asterinas.git
synced 2025-06-21 00:06:34 +00:00
Move packet dispatch out of smoltcp
This commit is contained in:
committed by
Tate, Hongliang Tian
parent
f793259512
commit
ee1656ba35
@ -8,7 +8,7 @@ edition = "2021"
|
||||
[dependencies]
|
||||
keyable-arc = { path = "../keyable-arc" }
|
||||
ostd = { path = "../../../ostd" }
|
||||
smoltcp = { git = "https://github.com/smoltcp-rs/smoltcp", rev = "469dccd", default-features = false, features = [
|
||||
smoltcp = { git = "https://github.com/asterinas/smoltcp", rev = "37716bf", default-features = false, features = [
|
||||
"alloc",
|
||||
"log",
|
||||
"medium-ethernet",
|
||||
|
@ -14,5 +14,15 @@ pub mod tcp {
|
||||
}
|
||||
|
||||
pub mod udp {
|
||||
pub use smoltcp::socket::udp::{RecvError, SendError};
|
||||
pub use smoltcp::socket::udp::RecvError;
|
||||
|
||||
/// An error returned by [`BoundTcpSocket::recv`].
|
||||
///
|
||||
/// [`BoundTcpSocket::recv`]: crate::socket::BoundTcpSocket::recv
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
|
||||
pub enum SendError {
|
||||
TooLarge,
|
||||
Unaddressable,
|
||||
BufferFull,
|
||||
}
|
||||
}
|
||||
|
@ -7,42 +7,45 @@ use alloc::{
|
||||
btree_set::BTreeSet,
|
||||
},
|
||||
sync::Arc,
|
||||
vec::Vec,
|
||||
};
|
||||
|
||||
use keyable_arc::KeyableArc;
|
||||
use ostd::sync::{LocalIrqDisabled, PreemptDisabled, RwLock, SpinLock, SpinLockGuard};
|
||||
use ostd::sync::{LocalIrqDisabled, PreemptDisabled, SpinLock, SpinLockGuard};
|
||||
use smoltcp::{
|
||||
iface::{SocketHandle, SocketSet},
|
||||
iface::{packet::Packet, Context},
|
||||
phy::Device,
|
||||
wire::Ipv4Address,
|
||||
wire::{Ipv4Address, Ipv4Packet},
|
||||
};
|
||||
|
||||
use super::{port::BindPortConfig, time::get_network_timestamp, Iface};
|
||||
use super::{
|
||||
poll::{FnHelper, PollContext},
|
||||
port::BindPortConfig,
|
||||
time::get_network_timestamp,
|
||||
Iface,
|
||||
};
|
||||
use crate::{
|
||||
errors::BindError,
|
||||
socket::{AnyBoundSocket, AnyBoundSocketInner, AnyRawSocket, AnyUnboundSocket, SocketFamily},
|
||||
socket::{
|
||||
BoundTcpSocket, BoundTcpSocketInner, BoundUdpSocket, BoundUdpSocketInner, UnboundTcpSocket,
|
||||
UnboundUdpSocket,
|
||||
},
|
||||
};
|
||||
|
||||
pub struct IfaceCommon<E> {
|
||||
interface: SpinLock<smoltcp::iface::Interface, LocalIrqDisabled>,
|
||||
sockets: SpinLock<SocketSet<'static>, LocalIrqDisabled>,
|
||||
used_ports: SpinLock<BTreeMap<u16, usize>, PreemptDisabled>,
|
||||
bound_sockets: RwLock<BTreeSet<KeyableArc<AnyBoundSocketInner<E>>>>,
|
||||
closing_sockets: SpinLock<BTreeSet<KeyableArc<AnyBoundSocketInner<E>>>, LocalIrqDisabled>,
|
||||
tcp_sockets: SpinLock<BTreeSet<KeyableArc<BoundTcpSocketInner<E>>>, LocalIrqDisabled>,
|
||||
udp_sockets: SpinLock<BTreeSet<KeyableArc<BoundUdpSocketInner<E>>>, LocalIrqDisabled>,
|
||||
ext: E,
|
||||
}
|
||||
|
||||
impl<E> IfaceCommon<E> {
|
||||
pub(super) fn new(interface: smoltcp::iface::Interface, ext: E) -> Self {
|
||||
let socket_set = SocketSet::new(Vec::new());
|
||||
let used_ports = BTreeMap::new();
|
||||
Self {
|
||||
interface: SpinLock::new(interface),
|
||||
sockets: SpinLock::new(socket_set),
|
||||
used_ports: SpinLock::new(used_ports),
|
||||
bound_sockets: RwLock::new(BTreeSet::new()),
|
||||
closing_sockets: SpinLock::new(BTreeSet::new()),
|
||||
used_ports: SpinLock::new(BTreeMap::new()),
|
||||
tcp_sockets: SpinLock::new(BTreeSet::new()),
|
||||
udp_sockets: SpinLock::new(BTreeSet::new()),
|
||||
ext,
|
||||
}
|
||||
}
|
||||
@ -58,58 +61,57 @@ impl<E> IfaceCommon<E> {
|
||||
|
||||
impl<E> IfaceCommon<E> {
|
||||
/// Acquires the lock to the interface.
|
||||
///
|
||||
/// *Lock ordering:* [`Self::sockets`] first, [`Self::interface`] second.
|
||||
pub(crate) fn interface(&self) -> SpinLockGuard<smoltcp::iface::Interface, LocalIrqDisabled> {
|
||||
self.interface.lock()
|
||||
}
|
||||
|
||||
/// Acuqires the lock to the sockets.
|
||||
///
|
||||
/// *Lock ordering:* [`Self::sockets`] first, [`Self::interface`] second.
|
||||
pub(crate) fn sockets(
|
||||
&self,
|
||||
) -> SpinLockGuard<smoltcp::iface::SocketSet<'static>, LocalIrqDisabled> {
|
||||
self.sockets.lock()
|
||||
}
|
||||
}
|
||||
|
||||
const IP_LOCAL_PORT_START: u16 = 32768;
|
||||
const IP_LOCAL_PORT_END: u16 = 60999;
|
||||
|
||||
impl<E> IfaceCommon<E> {
|
||||
pub(super) fn bind_socket(
|
||||
pub(super) fn bind_tcp(
|
||||
&self,
|
||||
iface: Arc<dyn Iface<E>>,
|
||||
socket: Box<AnyUnboundSocket>,
|
||||
socket: Box<UnboundTcpSocket>,
|
||||
config: BindPortConfig,
|
||||
) -> core::result::Result<AnyBoundSocket<E>, (BindError, Box<AnyUnboundSocket>)> {
|
||||
let port = if let Some(port) = config.port() {
|
||||
port
|
||||
} else {
|
||||
match self.alloc_ephemeral_port() {
|
||||
Some(port) => port,
|
||||
None => return Err((BindError::Exhausted, socket)),
|
||||
}
|
||||
) -> core::result::Result<BoundTcpSocket<E>, (BindError, Box<UnboundTcpSocket>)> {
|
||||
let port = match self.bind_port(config) {
|
||||
Ok(port) => port,
|
||||
Err(err) => return Err((err, socket)),
|
||||
};
|
||||
if !self.bind_port(port, config.can_reuse()) {
|
||||
return Err((BindError::InUse, socket));
|
||||
}
|
||||
|
||||
let (handle, socket_family, observer) = match socket.into_raw() {
|
||||
(AnyRawSocket::Tcp(tcp_socket), observer) => (
|
||||
self.sockets.lock().add(tcp_socket),
|
||||
SocketFamily::Tcp,
|
||||
observer,
|
||||
),
|
||||
(AnyRawSocket::Udp(udp_socket), observer) => (
|
||||
self.sockets.lock().add(udp_socket),
|
||||
SocketFamily::Udp,
|
||||
observer,
|
||||
),
|
||||
let (raw_socket, observer) = socket.into_raw();
|
||||
let bound_socket = BoundTcpSocket::new(iface, port, raw_socket, observer);
|
||||
|
||||
let inserted = self
|
||||
.tcp_sockets
|
||||
.lock()
|
||||
.insert(KeyableArc::from(bound_socket.inner().clone()));
|
||||
assert!(inserted);
|
||||
|
||||
Ok(bound_socket)
|
||||
}
|
||||
|
||||
pub(super) fn bind_udp(
|
||||
&self,
|
||||
iface: Arc<dyn Iface<E>>,
|
||||
socket: Box<UnboundUdpSocket>,
|
||||
config: BindPortConfig,
|
||||
) -> core::result::Result<BoundUdpSocket<E>, (BindError, Box<UnboundUdpSocket>)> {
|
||||
let port = match self.bind_port(config) {
|
||||
Ok(port) => port,
|
||||
Err(err) => return Err((err, socket)),
|
||||
};
|
||||
let bound_socket = AnyBoundSocket::new(iface, handle, port, socket_family, observer);
|
||||
self.insert_bound_socket(bound_socket.inner());
|
||||
|
||||
let (raw_socket, observer) = socket.into_raw();
|
||||
let bound_socket = BoundUdpSocket::new(iface, port, raw_socket, observer);
|
||||
|
||||
let inserted = self
|
||||
.udp_sockets
|
||||
.lock()
|
||||
.insert(KeyableArc::from(bound_socket.inner().clone()));
|
||||
assert!(inserted);
|
||||
|
||||
Ok(bound_socket)
|
||||
}
|
||||
@ -130,36 +132,57 @@ impl<E> IfaceCommon<E> {
|
||||
None
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
fn bind_port(&self, port: u16, can_reuse: bool) -> bool {
|
||||
fn bind_port(&self, config: BindPortConfig) -> Result<u16, BindError> {
|
||||
let port = if let Some(port) = config.port() {
|
||||
port
|
||||
} else {
|
||||
match self.alloc_ephemeral_port() {
|
||||
Some(port) => port,
|
||||
None => return Err(BindError::Exhausted),
|
||||
}
|
||||
};
|
||||
|
||||
let mut used_ports = self.used_ports.lock();
|
||||
|
||||
if let Some(used_times) = used_ports.get_mut(&port) {
|
||||
if *used_times == 0 || can_reuse {
|
||||
if *used_times == 0 || config.can_reuse() {
|
||||
// FIXME: Check if the previous socket was bound with SO_REUSEADDR.
|
||||
*used_times += 1;
|
||||
} else {
|
||||
return false;
|
||||
return Err(BindError::InUse);
|
||||
}
|
||||
} else {
|
||||
used_ports.insert(port, 1);
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
fn insert_bound_socket(&self, socket: &Arc<AnyBoundSocketInner<E>>) {
|
||||
let keyable_socket = KeyableArc::from(socket.clone());
|
||||
|
||||
let inserted = self
|
||||
.bound_sockets
|
||||
.write_irq_disabled()
|
||||
.insert(keyable_socket);
|
||||
assert!(inserted);
|
||||
Ok(port)
|
||||
}
|
||||
}
|
||||
|
||||
impl<E> IfaceCommon<E> {
|
||||
#[allow(clippy::mutable_key_type)]
|
||||
fn remove_dead_tcp_sockets(&self, sockets: &mut BTreeSet<KeyableArc<BoundTcpSocketInner<E>>>) {
|
||||
sockets.retain(|socket| {
|
||||
if socket.is_dead() {
|
||||
self.release_port(socket.port());
|
||||
false
|
||||
} else {
|
||||
true
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
pub(crate) fn remove_udp_socket(&self, socket: &Arc<BoundUdpSocketInner<E>>) {
|
||||
let keyable_socket = KeyableArc::from(socket.clone());
|
||||
|
||||
let removed = self.udp_sockets.lock().remove(&keyable_socket);
|
||||
assert!(removed);
|
||||
|
||||
self.release_port(keyable_socket.port());
|
||||
}
|
||||
|
||||
/// Releases the port so that it can be used again (if it is not being reused).
|
||||
pub(crate) fn release_port(&self, port: u16) {
|
||||
fn release_port(&self, port: u16) {
|
||||
let mut used_ports = self.used_ports.lock();
|
||||
if let Some(used_times) = used_ports.remove(&port) {
|
||||
if used_times != 1 {
|
||||
@ -167,96 +190,63 @@ impl<E> IfaceCommon<E> {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Removes a socket from the interface.
|
||||
pub(crate) fn remove_socket(&self, handle: SocketHandle) {
|
||||
self.sockets.lock().remove(handle);
|
||||
}
|
||||
|
||||
pub(crate) fn remove_bound_socket_now(&self, socket: &Arc<AnyBoundSocketInner<E>>) {
|
||||
let keyable_socket = KeyableArc::from(socket.clone());
|
||||
|
||||
let removed = self
|
||||
.bound_sockets
|
||||
.write_irq_disabled()
|
||||
.remove(&keyable_socket);
|
||||
assert!(removed);
|
||||
}
|
||||
|
||||
pub(crate) fn remove_bound_socket_when_closed(&self, socket: &Arc<AnyBoundSocketInner<E>>) {
|
||||
let keyable_socket = KeyableArc::from(socket.clone());
|
||||
|
||||
let removed = self
|
||||
.bound_sockets
|
||||
.write_irq_disabled()
|
||||
.remove(&keyable_socket);
|
||||
assert!(removed);
|
||||
|
||||
let mut closing_sockets = self.closing_sockets.lock();
|
||||
|
||||
// Check `is_closed` after holding the lock to avoid race conditions.
|
||||
if keyable_socket.is_closed() {
|
||||
return;
|
||||
}
|
||||
|
||||
let inserted = closing_sockets.insert(keyable_socket);
|
||||
assert!(inserted);
|
||||
}
|
||||
}
|
||||
|
||||
impl<E> IfaceCommon<E> {
|
||||
#[must_use]
|
||||
pub(super) fn poll<D: Device + ?Sized>(&self, device: &mut D) -> Option<u64> {
|
||||
let mut sockets = self.sockets.lock();
|
||||
let mut interface = self.interface.lock();
|
||||
pub(super) fn poll<D, P, Q>(
|
||||
&self,
|
||||
device: &mut D,
|
||||
process_phy: P,
|
||||
mut dispatch_phy: Q,
|
||||
) -> Option<u64>
|
||||
where
|
||||
D: Device + ?Sized,
|
||||
P: for<'pkt, 'cx, 'tx> FnHelper<
|
||||
&'pkt [u8],
|
||||
&'cx mut Context,
|
||||
D::TxToken<'tx>,
|
||||
Option<(Ipv4Packet<&'pkt [u8]>, D::TxToken<'tx>)>,
|
||||
>,
|
||||
Q: FnMut(&Packet, &mut Context, D::TxToken<'_>),
|
||||
{
|
||||
let mut interface = self.interface();
|
||||
interface.context().now = get_network_timestamp();
|
||||
|
||||
let timestamp = get_network_timestamp();
|
||||
let (has_events, poll_at) = {
|
||||
let mut has_events = false;
|
||||
let mut poll_at;
|
||||
let mut tcp_sockets = self.tcp_sockets.lock();
|
||||
let udp_sockets = self.udp_sockets.lock();
|
||||
|
||||
loop {
|
||||
// `poll` transmits and receives a bounded number of packets. This loop ensures
|
||||
// that all packets are transmitted and received. For details, see
|
||||
// <https://github.com/smoltcp-rs/smoltcp/blob/8e3ea5c7f09a76f0a4988fda20cadc74eacdc0d8/src/iface/interface/mod.rs#L400-L405>.
|
||||
while interface.poll(timestamp, device, &mut sockets) {
|
||||
has_events = true;
|
||||
}
|
||||
let mut context = PollContext::new(interface.context(), &tcp_sockets, &udp_sockets);
|
||||
context.poll_ingress(device, process_phy, &mut dispatch_phy);
|
||||
context.poll_egress(device, dispatch_phy);
|
||||
|
||||
// `poll_at` can return `Some(Instant::from_millis(0))`, which means `PollAt::Now`.
|
||||
// For details, see
|
||||
// <https://github.com/smoltcp-rs/smoltcp/blob/8e3ea5c7f09a76f0a4988fda20cadc74eacdc0d8/src/iface/interface/mod.rs#L478>.
|
||||
poll_at = interface.poll_at(timestamp, &sockets);
|
||||
let Some(instant) = poll_at else {
|
||||
break;
|
||||
};
|
||||
if instant > timestamp {
|
||||
break;
|
||||
}
|
||||
tcp_sockets.iter().for_each(|socket| {
|
||||
if socket.has_new_events() {
|
||||
socket.on_iface_events();
|
||||
}
|
||||
});
|
||||
udp_sockets.iter().for_each(|socket| {
|
||||
if socket.has_new_events() {
|
||||
socket.on_iface_events();
|
||||
}
|
||||
});
|
||||
|
||||
(has_events, poll_at)
|
||||
};
|
||||
self.remove_dead_tcp_sockets(&mut tcp_sockets);
|
||||
|
||||
// drop sockets here to avoid deadlock
|
||||
drop(interface);
|
||||
drop(sockets);
|
||||
|
||||
if has_events {
|
||||
// We never try to hold the write lock in the IRQ context, and we disable IRQ when
|
||||
// holding the write lock. So we don't need to disable IRQ when holding the read lock.
|
||||
self.bound_sockets.read().iter().for_each(|bound_socket| {
|
||||
bound_socket.on_iface_events();
|
||||
});
|
||||
|
||||
let closed_sockets = self
|
||||
.closing_sockets
|
||||
.lock()
|
||||
.extract_if(|closing_socket| closing_socket.is_closed())
|
||||
.collect::<Vec<_>>();
|
||||
drop(closed_sockets);
|
||||
match (
|
||||
tcp_sockets
|
||||
.iter()
|
||||
.map(|socket| socket.next_poll_at_ms())
|
||||
.min(),
|
||||
udp_sockets
|
||||
.iter()
|
||||
.map(|socket| socket.next_poll_at_ms())
|
||||
.min(),
|
||||
) {
|
||||
(Some(tcp_poll_at), Some(udp_poll_at)) if tcp_poll_at <= udp_poll_at => {
|
||||
Some(tcp_poll_at)
|
||||
}
|
||||
(tcp_poll_at, None) => tcp_poll_at,
|
||||
(_, udp_poll_at) => udp_poll_at,
|
||||
}
|
||||
|
||||
poll_at.map(|at| smoltcp::time::Instant::total_millis(&at) as u64)
|
||||
}
|
||||
}
|
||||
|
@ -7,7 +7,7 @@ use smoltcp::wire::Ipv4Address;
|
||||
use super::port::BindPortConfig;
|
||||
use crate::{
|
||||
errors::BindError,
|
||||
socket::{AnyBoundSocket, AnyUnboundSocket},
|
||||
socket::{BoundTcpSocket, BoundUdpSocket, UnboundTcpSocket, UnboundUdpSocket},
|
||||
};
|
||||
|
||||
/// A network interface.
|
||||
@ -42,13 +42,22 @@ impl<E> dyn Iface<E> {
|
||||
/// FIXME: The reason for binding the socket and the iface together is because there are
|
||||
/// limitations inside smoltcp. See discussion at
|
||||
/// <https://github.com/smoltcp-rs/smoltcp/issues/779>.
|
||||
pub fn bind_socket(
|
||||
pub fn bind_tcp(
|
||||
self: &Arc<Self>,
|
||||
socket: Box<AnyUnboundSocket>,
|
||||
socket: Box<UnboundTcpSocket>,
|
||||
config: BindPortConfig,
|
||||
) -> core::result::Result<AnyBoundSocket<E>, (BindError, Box<AnyUnboundSocket>)> {
|
||||
) -> core::result::Result<BoundTcpSocket<E>, (BindError, Box<UnboundTcpSocket>)> {
|
||||
let common = self.common();
|
||||
common.bind_socket(self.clone(), socket, config)
|
||||
common.bind_tcp(self.clone(), socket, config)
|
||||
}
|
||||
|
||||
pub fn bind_udp(
|
||||
self: &Arc<Self>,
|
||||
socket: Box<UnboundUdpSocket>,
|
||||
config: BindPortConfig,
|
||||
) -> core::result::Result<BoundUdpSocket<E>, (BindError, Box<UnboundUdpSocket>)> {
|
||||
let common = self.common();
|
||||
common.bind_udp(self.clone(), socket, config)
|
||||
}
|
||||
|
||||
/// Gets the IPv4 address of the iface, if any.
|
||||
|
@ -4,6 +4,7 @@ mod common;
|
||||
#[allow(clippy::module_inception)]
|
||||
mod iface;
|
||||
mod phy;
|
||||
mod poll;
|
||||
mod port;
|
||||
mod time;
|
||||
|
||||
|
@ -1,10 +1,15 @@
|
||||
// SPDX-License-Identifier: MPL-2.0
|
||||
|
||||
use alloc::sync::Arc;
|
||||
use alloc::{collections::btree_map::BTreeMap, sync::Arc};
|
||||
|
||||
use ostd::sync::{LocalIrqDisabled, SpinLock};
|
||||
use smoltcp::{
|
||||
iface::Config,
|
||||
wire::{self, EthernetAddress, Ipv4Address, Ipv4Cidr},
|
||||
iface::{packet::Packet, Config, Context},
|
||||
phy::{DeviceCapabilities, TxToken},
|
||||
wire::{
|
||||
self, ArpOperation, ArpPacket, ArpRepr, EthernetAddress, EthernetFrame, EthernetProtocol,
|
||||
EthernetRepr, IpAddress, Ipv4Address, Ipv4Cidr, Ipv4Packet,
|
||||
},
|
||||
};
|
||||
|
||||
use crate::{
|
||||
@ -14,9 +19,11 @@ use crate::{
|
||||
},
|
||||
};
|
||||
|
||||
pub struct EtherIface<D: WithDevice, E> {
|
||||
pub struct EtherIface<D, E> {
|
||||
driver: D,
|
||||
common: IfaceCommon<E>,
|
||||
ether_addr: EthernetAddress,
|
||||
arp_table: SpinLock<BTreeMap<Ipv4Address, EthernetAddress>, LocalIrqDisabled>,
|
||||
}
|
||||
|
||||
impl<D: WithDevice, E> EtherIface<D, E> {
|
||||
@ -45,21 +52,224 @@ impl<D: WithDevice, E> EtherIface<D, E> {
|
||||
|
||||
let common = IfaceCommon::new(interface, ext);
|
||||
|
||||
Arc::new(Self { driver, common })
|
||||
Arc::new(Self {
|
||||
driver,
|
||||
common,
|
||||
ether_addr,
|
||||
arp_table: SpinLock::new(BTreeMap::new()),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: WithDevice, E> IfaceInternal<E> for EtherIface<D, E> {
|
||||
impl<D, E> IfaceInternal<E> for EtherIface<D, E> {
|
||||
fn common(&self) -> &IfaceCommon<E> {
|
||||
&self.common
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: WithDevice, E: Send + Sync> Iface<E> for EtherIface<D, E> {
|
||||
impl<D: WithDevice + 'static, E: Send + Sync> Iface<E> for EtherIface<D, E> {
|
||||
fn raw_poll(&self, schedule_next_poll: &dyn Fn(Option<u64>)) {
|
||||
self.driver.with(|device| {
|
||||
let next_poll = self.common.poll(&mut *device);
|
||||
let next_poll = self.common.poll(
|
||||
&mut *device,
|
||||
|data, iface_cx, tx_token| self.process(data, iface_cx, tx_token),
|
||||
|pkt, iface_cx, tx_token| self.dispatch(pkt, iface_cx, tx_token),
|
||||
);
|
||||
schedule_next_poll(next_poll);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
impl<D, E> EtherIface<D, E> {
|
||||
fn process<'pkt, T: TxToken>(
|
||||
&self,
|
||||
data: &'pkt [u8],
|
||||
iface_cx: &mut Context,
|
||||
tx_token: T,
|
||||
) -> Option<(Ipv4Packet<&'pkt [u8]>, T)> {
|
||||
match self.parse_ip_or_process_arp(data, iface_cx) {
|
||||
Ok(pkt) => Some((pkt, tx_token)),
|
||||
Err(Some(arp)) => {
|
||||
Self::emit_arp(&arp, tx_token);
|
||||
None
|
||||
}
|
||||
Err(None) => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_ip_or_process_arp<'pkt>(
|
||||
&self,
|
||||
data: &'pkt [u8],
|
||||
iface_cx: &mut Context,
|
||||
) -> Result<Ipv4Packet<&'pkt [u8]>, Option<ArpRepr>> {
|
||||
// Parse the Ethernet header. Ignore the packet if the header is ill-formed.
|
||||
let frame = EthernetFrame::new_checked(data).map_err(|_| None)?;
|
||||
let repr = EthernetRepr::parse(&frame).map_err(|_| None)?;
|
||||
|
||||
// Ignore the Ethernet frame if it is not sent to us.
|
||||
if !repr.dst_addr.is_broadcast() && repr.dst_addr != self.ether_addr {
|
||||
return Err(None);
|
||||
}
|
||||
|
||||
// Ignore the Ethernet frame if the protocol is not supported.
|
||||
match repr.ethertype {
|
||||
EthernetProtocol::Ipv4 => {
|
||||
Ok(Ipv4Packet::new_checked(frame.payload()).map_err(|_| None)?)
|
||||
}
|
||||
EthernetProtocol::Arp => {
|
||||
let pkt = ArpPacket::new_checked(frame.payload()).map_err(|_| None)?;
|
||||
let arp = ArpRepr::parse(&pkt).map_err(|_| None)?;
|
||||
Err(self.process_arp(&arp, iface_cx))
|
||||
}
|
||||
_ => Err(None),
|
||||
}
|
||||
}
|
||||
|
||||
fn process_arp(&self, arp_repr: &ArpRepr, iface_cx: &mut Context) -> Option<ArpRepr> {
|
||||
match arp_repr {
|
||||
ArpRepr::EthernetIpv4 {
|
||||
operation: ArpOperation::Reply,
|
||||
source_hardware_addr,
|
||||
source_protocol_addr,
|
||||
..
|
||||
} => {
|
||||
// Ignore the ARP packet if the source addresses are not unicast or not local.
|
||||
if !source_hardware_addr.is_unicast()
|
||||
|| !iface_cx.in_same_network(&IpAddress::Ipv4(*source_protocol_addr))
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
// Insert the mapping between the Ethernet address and the IP address.
|
||||
//
|
||||
// TODO: Remove the mapping if it expires.
|
||||
self.arp_table
|
||||
.lock()
|
||||
.insert(*source_protocol_addr, *source_hardware_addr);
|
||||
|
||||
None
|
||||
}
|
||||
ArpRepr::EthernetIpv4 {
|
||||
operation: ArpOperation::Request,
|
||||
source_hardware_addr,
|
||||
source_protocol_addr,
|
||||
target_protocol_addr,
|
||||
..
|
||||
} => {
|
||||
// Ignore the ARP packet if the source addresses are not unicast.
|
||||
if !source_hardware_addr.is_unicast() || !source_protocol_addr.is_unicast() {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Ignore the ARP packet if we do not own the target address.
|
||||
if !iface_cx
|
||||
.ipv4_addr()
|
||||
.is_some_and(|addr| addr == *target_protocol_addr)
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(ArpRepr::EthernetIpv4 {
|
||||
operation: ArpOperation::Reply,
|
||||
source_hardware_addr: self.ether_addr,
|
||||
source_protocol_addr: *target_protocol_addr,
|
||||
target_hardware_addr: *source_hardware_addr,
|
||||
target_protocol_addr: *source_protocol_addr,
|
||||
})
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn dispatch<T: TxToken>(&self, pkt: &Packet, iface_cx: &mut Context, tx_token: T) {
|
||||
match self.resolve_ether_or_generate_arp(pkt, iface_cx) {
|
||||
Ok(ether) => Self::emit_ip(ðer, pkt, &iface_cx.caps, tx_token),
|
||||
Err(Some(arp)) => Self::emit_arp(&arp, tx_token),
|
||||
Err(None) => (),
|
||||
}
|
||||
}
|
||||
|
||||
fn resolve_ether_or_generate_arp(
|
||||
&self,
|
||||
pkt: &Packet,
|
||||
iface_cx: &mut Context,
|
||||
) -> Result<EthernetRepr, Option<ArpRepr>> {
|
||||
// Resolve the next-hop IP address.
|
||||
let next_hop_ip = match iface_cx.route(&pkt.ip_repr().dst_addr(), iface_cx.now()) {
|
||||
Some(IpAddress::Ipv4(next_hop_ip)) => next_hop_ip,
|
||||
None => return Err(None),
|
||||
};
|
||||
|
||||
// Resolve the next-hop Ethernet address.
|
||||
let next_hop_ether = if next_hop_ip.is_broadcast() {
|
||||
EthernetAddress::BROADCAST
|
||||
} else if let Some(next_hop_ether) = self.arp_table.lock().get(&next_hop_ip) {
|
||||
*next_hop_ether
|
||||
} else {
|
||||
// If the next-hop Ethernet address cannot be resolved, we drop the original packet and
|
||||
// send an ARP packet instead. The upper layer should be responsible for detecting the
|
||||
// packet loss and retrying later to see if the Ethernet address is ready.
|
||||
return Err(Some(ArpRepr::EthernetIpv4 {
|
||||
operation: ArpOperation::Request,
|
||||
source_hardware_addr: self.ether_addr,
|
||||
source_protocol_addr: iface_cx.ipv4_addr().unwrap_or(Ipv4Address::UNSPECIFIED),
|
||||
target_hardware_addr: EthernetAddress::BROADCAST,
|
||||
target_protocol_addr: next_hop_ip,
|
||||
}));
|
||||
};
|
||||
|
||||
Ok(EthernetRepr {
|
||||
src_addr: self.ether_addr,
|
||||
dst_addr: next_hop_ether,
|
||||
ethertype: EthernetProtocol::Ipv4,
|
||||
})
|
||||
}
|
||||
|
||||
/// Consumes the token and emits an IP packet.
|
||||
fn emit_ip<T: TxToken>(
|
||||
ether_repr: &EthernetRepr,
|
||||
ip_pkt: &Packet,
|
||||
caps: &DeviceCapabilities,
|
||||
tx_token: T,
|
||||
) {
|
||||
tx_token.consume(
|
||||
ether_repr.buffer_len() + ip_pkt.ip_repr().buffer_len(),
|
||||
|buffer| {
|
||||
let mut frame = EthernetFrame::new_unchecked(buffer);
|
||||
ether_repr.emit(&mut frame);
|
||||
|
||||
let ip_repr = ip_pkt.ip_repr();
|
||||
ip_repr.emit(frame.payload_mut(), &caps.checksum);
|
||||
ip_pkt.emit_payload(
|
||||
&ip_repr,
|
||||
&mut frame.payload_mut()[ip_repr.header_len()..],
|
||||
caps,
|
||||
);
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
/// Consumes the token and emits an ARP packet.
|
||||
fn emit_arp<T: TxToken>(arp_repr: &ArpRepr, tx_token: T) {
|
||||
let ether_repr = match arp_repr {
|
||||
ArpRepr::EthernetIpv4 {
|
||||
source_hardware_addr,
|
||||
target_hardware_addr,
|
||||
..
|
||||
} => EthernetRepr {
|
||||
src_addr: *source_hardware_addr,
|
||||
dst_addr: *target_hardware_addr,
|
||||
ethertype: EthernetProtocol::Arp,
|
||||
},
|
||||
_ => return,
|
||||
};
|
||||
|
||||
tx_token.consume(ether_repr.buffer_len() + arp_repr.buffer_len(), |buffer| {
|
||||
let mut frame = EthernetFrame::new_unchecked(buffer);
|
||||
ether_repr.emit(&mut frame);
|
||||
|
||||
let mut pkt = ArpPacket::new_unchecked(frame.payload_mut());
|
||||
arp_repr.emit(&mut pkt);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
@ -4,7 +4,8 @@ use alloc::sync::Arc;
|
||||
|
||||
use smoltcp::{
|
||||
iface::Config,
|
||||
wire::{self, Ipv4Cidr},
|
||||
phy::TxToken,
|
||||
wire::{self, Ipv4Cidr, Ipv4Packet},
|
||||
};
|
||||
|
||||
use crate::{
|
||||
@ -14,7 +15,7 @@ use crate::{
|
||||
},
|
||||
};
|
||||
|
||||
pub struct IpIface<D: WithDevice, E> {
|
||||
pub struct IpIface<D, E> {
|
||||
driver: D,
|
||||
common: IfaceCommon<E>,
|
||||
}
|
||||
@ -39,16 +40,30 @@ impl<D: WithDevice, E> IpIface<D, E> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: WithDevice, E> IfaceInternal<E> for IpIface<D, E> {
|
||||
impl<D, E> IfaceInternal<E> for IpIface<D, E> {
|
||||
fn common(&self) -> &IfaceCommon<E> {
|
||||
&self.common
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: WithDevice, E: Send + Sync> Iface<E> for IpIface<D, E> {
|
||||
impl<D: WithDevice + 'static, E: Send + Sync> Iface<E> for IpIface<D, E> {
|
||||
fn raw_poll(&self, schedule_next_poll: &dyn Fn(Option<u64>)) {
|
||||
self.driver.with(|device| {
|
||||
let next_poll = self.common.poll(device);
|
||||
let next_poll = self.common.poll(
|
||||
device,
|
||||
|data, _iface_cx, tx_token| Some((Ipv4Packet::new_checked(data).ok()?, tx_token)),
|
||||
|pkt, iface_cx, tx_token| {
|
||||
let ip_repr = pkt.ip_repr();
|
||||
tx_token.consume(ip_repr.buffer_len(), |buffer| {
|
||||
ip_repr.emit(&mut buffer[..], &iface_cx.checksum_caps());
|
||||
pkt.emit_payload(
|
||||
&ip_repr,
|
||||
&mut buffer[ip_repr.header_len()..],
|
||||
&iface_cx.caps,
|
||||
);
|
||||
});
|
||||
},
|
||||
);
|
||||
schedule_next_poll(next_poll);
|
||||
});
|
||||
}
|
||||
|
476
kernel/libs/aster-bigtcp/src/iface/poll.rs
Normal file
476
kernel/libs/aster-bigtcp/src/iface/poll.rs
Normal file
@ -0,0 +1,476 @@
|
||||
// SPDX-License-Identifier: MPL-2.0
|
||||
|
||||
use alloc::{collections::btree_set::BTreeSet, vec};
|
||||
|
||||
use keyable_arc::KeyableArc;
|
||||
use smoltcp::{
|
||||
iface::{
|
||||
packet::{icmp_reply_payload_len, IpPayload, Packet},
|
||||
Context,
|
||||
},
|
||||
phy::{ChecksumCapabilities, Device, RxToken, TxToken},
|
||||
wire::{
|
||||
Icmpv4DstUnreachable, Icmpv4Repr, IpAddress, IpProtocol, IpRepr, Ipv4Address, Ipv4Packet,
|
||||
Ipv4Repr, TcpControl, TcpPacket, TcpRepr, UdpPacket, UdpRepr, IPV4_HEADER_LEN,
|
||||
IPV4_MIN_MTU,
|
||||
},
|
||||
};
|
||||
|
||||
use crate::socket::{BoundTcpSocketInner, BoundUdpSocketInner, TcpProcessResult};
|
||||
|
||||
pub(super) struct PollContext<'a, E> {
|
||||
iface_cx: &'a mut Context,
|
||||
tcp_sockets: &'a BTreeSet<KeyableArc<BoundTcpSocketInner<E>>>,
|
||||
udp_sockets: &'a BTreeSet<KeyableArc<BoundUdpSocketInner<E>>>,
|
||||
}
|
||||
|
||||
impl<'a, E> PollContext<'a, E> {
|
||||
#[allow(clippy::mutable_key_type)]
|
||||
pub(super) fn new(
|
||||
iface_cx: &'a mut Context,
|
||||
tcp_sockets: &'a BTreeSet<KeyableArc<BoundTcpSocketInner<E>>>,
|
||||
udp_sockets: &'a BTreeSet<KeyableArc<BoundUdpSocketInner<E>>>,
|
||||
) -> Self {
|
||||
Self {
|
||||
iface_cx,
|
||||
tcp_sockets,
|
||||
udp_sockets,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// This works around <https://github.com/rust-lang/rust/issues/49601>.
|
||||
// See the issue above for details.
|
||||
pub(super) trait FnHelper<A, B, C, O>: FnMut(A, B, C) -> O {}
|
||||
impl<A, B, C, O, F> FnHelper<A, B, C, O> for F where F: FnMut(A, B, C) -> O {}
|
||||
|
||||
impl<'a, E> PollContext<'a, E> {
|
||||
pub(super) fn poll_ingress<D, P, Q>(
|
||||
&mut self,
|
||||
device: &mut D,
|
||||
mut process_phy: P,
|
||||
dispatch_phy: &mut Q,
|
||||
) where
|
||||
D: Device + ?Sized,
|
||||
P: for<'pkt, 'cx, 'tx> FnHelper<
|
||||
&'pkt [u8],
|
||||
&'cx mut Context,
|
||||
D::TxToken<'tx>,
|
||||
Option<(Ipv4Packet<&'pkt [u8]>, D::TxToken<'tx>)>,
|
||||
>,
|
||||
Q: FnMut(&Packet, &mut Context, D::TxToken<'_>),
|
||||
{
|
||||
while let Some((rx_token, tx_token)) = device.receive(self.iface_cx.now()) {
|
||||
rx_token.consume(|data| {
|
||||
let Some((pkt, tx_token)) = process_phy(data, self.iface_cx, tx_token) else {
|
||||
return;
|
||||
};
|
||||
|
||||
let Some(reply) = self.parse_and_process_ipv4(pkt) else {
|
||||
return;
|
||||
};
|
||||
|
||||
dispatch_phy(&reply, self.iface_cx, tx_token);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_and_process_ipv4<'pkt>(
|
||||
&mut self,
|
||||
pkt: Ipv4Packet<&'pkt [u8]>,
|
||||
) -> Option<Packet<'pkt>> {
|
||||
// Parse the IP header. Ignore the packet if the header is ill-formed.
|
||||
let repr = Ipv4Repr::parse(&pkt, &self.iface_cx.checksum_caps()).ok()?;
|
||||
|
||||
if !repr.dst_addr.is_broadcast() && !self.is_unicast_local(IpAddress::Ipv4(repr.dst_addr)) {
|
||||
return self.generate_icmp_unreachable(
|
||||
&IpRepr::Ipv4(repr),
|
||||
pkt.payload(),
|
||||
Icmpv4DstUnreachable::HostUnreachable,
|
||||
);
|
||||
}
|
||||
|
||||
match repr.next_header {
|
||||
IpProtocol::Tcp => self.parse_and_process_tcp(
|
||||
&IpRepr::Ipv4(repr),
|
||||
pkt.payload(),
|
||||
&self.iface_cx.checksum_caps(),
|
||||
),
|
||||
IpProtocol::Udp => self.parse_and_process_udp(
|
||||
&IpRepr::Ipv4(repr),
|
||||
pkt.payload(),
|
||||
&self.iface_cx.checksum_caps(),
|
||||
),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_and_process_tcp<'pkt>(
|
||||
&mut self,
|
||||
ip_repr: &IpRepr,
|
||||
ip_payload: &'pkt [u8],
|
||||
checksum_caps: &ChecksumCapabilities,
|
||||
) -> Option<Packet<'pkt>> {
|
||||
// TCP connections can only be established between unicast addresses. Ignore the packet if
|
||||
// this is not the case. See
|
||||
// <https://datatracker.ietf.org/doc/html/rfc9293#section-3.9.2.3>.
|
||||
if !ip_repr.src_addr().is_unicast() || !ip_repr.dst_addr().is_unicast() {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Parse the TCP header. Ignore the packet if the header is ill-formed.
|
||||
let tcp_pkt = TcpPacket::new_checked(ip_payload).ok()?;
|
||||
let tcp_repr = TcpRepr::parse(
|
||||
&tcp_pkt,
|
||||
&ip_repr.src_addr(),
|
||||
&ip_repr.dst_addr(),
|
||||
checksum_caps,
|
||||
)
|
||||
.ok()?;
|
||||
|
||||
self.process_tcp_until_outgoing(ip_repr, &tcp_repr)
|
||||
.map(|(ip_repr, tcp_repr)| Packet::new(ip_repr, IpPayload::Tcp(tcp_repr)))
|
||||
}
|
||||
|
||||
fn process_tcp_until_outgoing(
|
||||
&mut self,
|
||||
ip_repr: &IpRepr,
|
||||
tcp_repr: &TcpRepr,
|
||||
) -> Option<(IpRepr, TcpRepr<'static>)> {
|
||||
let (mut ip_repr, mut tcp_repr) = self.process_tcp(ip_repr, tcp_repr)?;
|
||||
|
||||
loop {
|
||||
if !self.is_unicast_local(ip_repr.dst_addr()) {
|
||||
return Some((ip_repr, tcp_repr));
|
||||
}
|
||||
|
||||
let (new_ip_repr, new_tcp_repr) = self.process_tcp(&ip_repr, &tcp_repr)?;
|
||||
ip_repr = new_ip_repr;
|
||||
tcp_repr = new_tcp_repr;
|
||||
}
|
||||
}
|
||||
|
||||
fn process_tcp(
|
||||
&mut self,
|
||||
ip_repr: &IpRepr,
|
||||
tcp_repr: &TcpRepr,
|
||||
) -> Option<(IpRepr, TcpRepr<'static>)> {
|
||||
for socket in self.tcp_sockets.iter() {
|
||||
if !socket.can_process(tcp_repr.dst_port) {
|
||||
continue;
|
||||
}
|
||||
|
||||
match socket.process(self.iface_cx, ip_repr, tcp_repr) {
|
||||
TcpProcessResult::NotProcessed => continue,
|
||||
TcpProcessResult::Processed => return None,
|
||||
TcpProcessResult::ProcessedWithReply(ip_repr, tcp_repr) => {
|
||||
return Some((ip_repr, tcp_repr))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// "In no case does receipt of a segment containing RST give rise to a RST in response."
|
||||
// See <https://datatracker.ietf.org/doc/html/rfc9293#section-4-1.64>.
|
||||
if tcp_repr.control == TcpControl::Rst {
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(smoltcp::socket::tcp::Socket::rst_reply(ip_repr, tcp_repr))
|
||||
}
|
||||
|
||||
fn parse_and_process_udp<'pkt>(
|
||||
&mut self,
|
||||
ip_repr: &IpRepr,
|
||||
ip_payload: &'pkt [u8],
|
||||
checksum_caps: &ChecksumCapabilities,
|
||||
) -> Option<Packet<'pkt>> {
|
||||
// Parse the UDP header. Ignore the packet if the header is ill-formed.
|
||||
let udp_pkt = UdpPacket::new_checked(ip_payload).ok()?;
|
||||
let udp_repr = UdpRepr::parse(
|
||||
&udp_pkt,
|
||||
&ip_repr.src_addr(),
|
||||
&ip_repr.dst_addr(),
|
||||
checksum_caps,
|
||||
)
|
||||
.ok()?;
|
||||
|
||||
if !self.process_udp(ip_repr, &udp_repr, udp_pkt.payload()) {
|
||||
return self.generate_icmp_unreachable(
|
||||
ip_repr,
|
||||
ip_payload,
|
||||
Icmpv4DstUnreachable::PortUnreachable,
|
||||
);
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
fn process_udp(&mut self, ip_repr: &IpRepr, udp_repr: &UdpRepr, udp_payload: &[u8]) -> bool {
|
||||
let mut processed = false;
|
||||
|
||||
for socket in self.udp_sockets.iter() {
|
||||
if !socket.can_process(udp_repr.dst_port) {
|
||||
continue;
|
||||
}
|
||||
|
||||
processed |= socket.process(self.iface_cx, ip_repr, udp_repr, udp_payload);
|
||||
if processed && ip_repr.dst_addr().is_unicast() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
processed
|
||||
}
|
||||
|
||||
fn generate_icmp_unreachable<'pkt>(
|
||||
&self,
|
||||
ip_repr: &IpRepr,
|
||||
ip_payload: &'pkt [u8],
|
||||
reason: Icmpv4DstUnreachable,
|
||||
) -> Option<Packet<'pkt>> {
|
||||
if !ip_repr.src_addr().is_unicast() || !ip_repr.dst_addr().is_unicast() {
|
||||
return None;
|
||||
}
|
||||
|
||||
if self.is_unicast_local(ip_repr.src_addr()) {
|
||||
// In this case, the generating ICMP message will have a local IP address as the
|
||||
// destination. However, since we don't have the ability to handle ICMP messages, we'll
|
||||
// just skip the generation.
|
||||
//
|
||||
// TODO: Generate the ICMP message here once we're able to handle incoming ICMP
|
||||
// messages.
|
||||
return None;
|
||||
}
|
||||
|
||||
let IpRepr::Ipv4(ipv4_repr) = ip_repr;
|
||||
|
||||
let reply_len = icmp_reply_payload_len(ip_payload.len(), IPV4_MIN_MTU, IPV4_HEADER_LEN);
|
||||
let icmp_repr = Icmpv4Repr::DstUnreachable {
|
||||
reason,
|
||||
header: *ipv4_repr,
|
||||
data: &ip_payload[..reply_len],
|
||||
};
|
||||
|
||||
Some(Packet::new_ipv4(
|
||||
Ipv4Repr {
|
||||
src_addr: self
|
||||
.iface_cx
|
||||
.ipv4_addr()
|
||||
.unwrap_or(Ipv4Address::UNSPECIFIED),
|
||||
dst_addr: ipv4_repr.src_addr,
|
||||
next_header: IpProtocol::Icmp,
|
||||
payload_len: icmp_repr.buffer_len(),
|
||||
hop_limit: 64,
|
||||
},
|
||||
IpPayload::Icmpv4(icmp_repr),
|
||||
))
|
||||
}
|
||||
|
||||
/// Returns whether the destination address is the unicast address of a local interface.
|
||||
///
|
||||
/// Note: "local" means that the IP address belongs to the local interface, not to be confused
|
||||
/// with the localhost IP (127.0.0.1).
|
||||
fn is_unicast_local(&self, dst_addr: IpAddress) -> bool {
|
||||
match dst_addr {
|
||||
IpAddress::Ipv4(dst_addr) => self
|
||||
.iface_cx
|
||||
.ipv4_addr()
|
||||
.is_some_and(|addr| addr == dst_addr),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, E> PollContext<'a, E> {
|
||||
pub(super) fn poll_egress<D, Q>(&mut self, device: &mut D, mut dispatch_phy: Q)
|
||||
where
|
||||
D: Device + ?Sized,
|
||||
Q: FnMut(&Packet, &mut Context, D::TxToken<'_>),
|
||||
{
|
||||
while let Some(tx_token) = device.transmit(self.iface_cx.now()) {
|
||||
if !self.dispatch_ipv4(tx_token, &mut dispatch_phy) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn dispatch_ipv4<T, Q>(&mut self, tx_token: T, dispatch_phy: &mut Q) -> bool
|
||||
where
|
||||
T: TxToken,
|
||||
Q: FnMut(&Packet, &mut Context, T),
|
||||
{
|
||||
let (did_something_tcp, tx_token) = self.dispatch_tcp(tx_token, dispatch_phy);
|
||||
|
||||
let Some(tx_token) = tx_token else {
|
||||
return did_something_tcp;
|
||||
};
|
||||
|
||||
let (did_something_udp, _tx_token) = self.dispatch_udp(tx_token, dispatch_phy);
|
||||
|
||||
did_something_tcp || did_something_udp
|
||||
}
|
||||
|
||||
fn dispatch_tcp<T, Q>(&mut self, tx_token: T, dispatch_phy: &mut Q) -> (bool, Option<T>)
|
||||
where
|
||||
T: TxToken,
|
||||
Q: FnMut(&Packet, &mut Context, T),
|
||||
{
|
||||
let mut tx_token = Some(tx_token);
|
||||
let mut did_something = false;
|
||||
|
||||
for socket in self.tcp_sockets.iter() {
|
||||
if !socket.need_dispatch(self.iface_cx.now()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// We set `did_something` even if no packets are actually generated. This is because a
|
||||
// timer can expire, but no packets are actually generated.
|
||||
did_something = true;
|
||||
|
||||
let mut deferred = None;
|
||||
|
||||
let reply = socket.dispatch(self.iface_cx, |cx, ip_repr, tcp_repr| {
|
||||
let mut this = PollContext::new(cx, self.tcp_sockets, self.udp_sockets);
|
||||
|
||||
if !this.is_unicast_local(ip_repr.dst_addr()) {
|
||||
dispatch_phy(
|
||||
&Packet::new(ip_repr.clone(), IpPayload::Tcp(*tcp_repr)),
|
||||
this.iface_cx,
|
||||
tx_token.take().unwrap(),
|
||||
);
|
||||
return None;
|
||||
}
|
||||
|
||||
if !socket.can_process(tcp_repr.dst_port) {
|
||||
return this.process_tcp(ip_repr, tcp_repr);
|
||||
}
|
||||
|
||||
// We cannot call `process_tcp` now because it may cause deadlocks. We will copy
|
||||
// the packet and call `process_tcp` after releasing the socket lock.
|
||||
deferred = Some((ip_repr.clone(), {
|
||||
let mut data = vec![0; tcp_repr.buffer_len()];
|
||||
tcp_repr.emit(
|
||||
&mut TcpPacket::new_unchecked(data.as_mut_slice()),
|
||||
&ip_repr.src_addr(),
|
||||
&ip_repr.dst_addr(),
|
||||
&ChecksumCapabilities::ignored(),
|
||||
);
|
||||
data
|
||||
}));
|
||||
|
||||
None
|
||||
});
|
||||
|
||||
match (deferred, reply) {
|
||||
(None, None) => (),
|
||||
(Some((ip_repr, ip_payload)), None) => {
|
||||
if let Some(reply) = self.parse_and_process_tcp(
|
||||
&ip_repr,
|
||||
&ip_payload,
|
||||
&ChecksumCapabilities::ignored(),
|
||||
) {
|
||||
dispatch_phy(&reply, self.iface_cx, tx_token.take().unwrap());
|
||||
}
|
||||
}
|
||||
(None, Some((ip_repr, tcp_repr))) if !self.is_unicast_local(ip_repr.dst_addr()) => {
|
||||
dispatch_phy(
|
||||
&Packet::new(ip_repr, IpPayload::Tcp(tcp_repr)),
|
||||
self.iface_cx,
|
||||
tx_token.take().unwrap(),
|
||||
);
|
||||
}
|
||||
(None, Some((ip_repr, tcp_repr))) => {
|
||||
if let Some((new_ip_repr, new_tcp_repr)) =
|
||||
self.process_tcp_until_outgoing(&ip_repr, &tcp_repr)
|
||||
{
|
||||
dispatch_phy(
|
||||
&Packet::new(new_ip_repr, IpPayload::Tcp(new_tcp_repr)),
|
||||
self.iface_cx,
|
||||
tx_token.take().unwrap(),
|
||||
);
|
||||
}
|
||||
}
|
||||
(Some(_), Some(_)) => unreachable!(),
|
||||
}
|
||||
|
||||
if tx_token.is_none() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
(did_something, tx_token)
|
||||
}
|
||||
|
||||
fn dispatch_udp<T, Q>(&mut self, tx_token: T, dispatch_phy: &mut Q) -> (bool, Option<T>)
|
||||
where
|
||||
T: TxToken,
|
||||
Q: FnMut(&Packet, &mut Context, T),
|
||||
{
|
||||
let mut tx_token = Some(tx_token);
|
||||
let mut did_something = false;
|
||||
|
||||
for socket in self.udp_sockets.iter() {
|
||||
if !socket.need_dispatch(self.iface_cx.now()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// We set `did_something` even if no packets are actually generated. This is because a
|
||||
// timer can expire, but no packets are actually generated.
|
||||
did_something = true;
|
||||
|
||||
let mut deferred = None;
|
||||
|
||||
socket.dispatch(self.iface_cx, |cx, ip_repr, udp_repr, udp_payload| {
|
||||
let mut this = PollContext::new(cx, self.tcp_sockets, self.udp_sockets);
|
||||
|
||||
if ip_repr.dst_addr().is_broadcast() || !this.is_unicast_local(ip_repr.dst_addr()) {
|
||||
dispatch_phy(
|
||||
&Packet::new(ip_repr.clone(), IpPayload::Udp(*udp_repr, udp_payload)),
|
||||
this.iface_cx,
|
||||
tx_token.take().unwrap(),
|
||||
);
|
||||
if !ip_repr.dst_addr().is_broadcast() {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if !socket.can_process(udp_repr.dst_port) {
|
||||
// TODO: Generate the ICMP message here once we're able to handle incoming ICMP
|
||||
// messages.
|
||||
let _ = this.process_udp(ip_repr, udp_repr, udp_payload);
|
||||
return;
|
||||
}
|
||||
|
||||
// We cannot call `process_udp` now because it may cause deadlocks. We will copy
|
||||
// the packet and call `process_udp` after releasing the socket lock.
|
||||
deferred = Some((ip_repr.clone(), {
|
||||
let mut data = vec![0; udp_repr.header_len() + udp_payload.len()];
|
||||
udp_repr.emit(
|
||||
&mut UdpPacket::new_unchecked(&mut data),
|
||||
&ip_repr.src_addr(),
|
||||
&ip_repr.dst_addr(),
|
||||
udp_payload.len(),
|
||||
|payload| payload.copy_from_slice(udp_payload),
|
||||
&ChecksumCapabilities::ignored(),
|
||||
);
|
||||
data
|
||||
}));
|
||||
});
|
||||
|
||||
if let Some((ip_repr, ip_payload)) = deferred {
|
||||
if let Some(reply) = self.parse_and_process_udp(
|
||||
&ip_repr,
|
||||
&ip_payload,
|
||||
&ChecksumCapabilities::ignored(),
|
||||
) {
|
||||
dispatch_phy(&reply, self.iface_cx, tx_token.take().unwrap());
|
||||
}
|
||||
}
|
||||
|
||||
if tx_token.is_none() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
(did_something, tx_token)
|
||||
}
|
||||
}
|
@ -1,44 +1,188 @@
|
||||
// SPDX-License-Identifier: MPL-2.0
|
||||
|
||||
use alloc::sync::{Arc, Weak};
|
||||
use alloc::{
|
||||
boxed::Box,
|
||||
sync::{Arc, Weak},
|
||||
};
|
||||
use core::{
|
||||
ops::{Deref, DerefMut},
|
||||
sync::atomic::{AtomicBool, AtomicU64, Ordering},
|
||||
};
|
||||
|
||||
use ostd::sync::RwLock;
|
||||
use ostd::sync::{LocalIrqDisabled, RwLock, SpinLock, SpinLockGuard};
|
||||
use smoltcp::{
|
||||
socket::tcp::ConnectError,
|
||||
wire::{IpAddress, IpEndpoint},
|
||||
iface::Context,
|
||||
socket::{udp::UdpMetadata, PollAt},
|
||||
time::Instant,
|
||||
wire::{IpAddress, IpEndpoint, IpRepr, TcpRepr, UdpRepr},
|
||||
};
|
||||
|
||||
use super::{event::SocketEventObserver, RawTcpSocket, RawUdpSocket};
|
||||
use crate::iface::Iface;
|
||||
|
||||
pub(crate) enum SocketFamily {
|
||||
Tcp,
|
||||
Udp,
|
||||
pub struct BoundSocket<T: AnySocket, E>(Arc<BoundSocketInner<T, E>>);
|
||||
|
||||
/// [`TcpSocket`] or [`UdpSocket`].
|
||||
pub trait AnySocket {
|
||||
type RawSocket;
|
||||
|
||||
/// Called by [`BoundSocket::new`].
|
||||
fn new(socket: Box<Self::RawSocket>) -> Self;
|
||||
|
||||
/// Called by [`BoundSocket::drop`].
|
||||
fn on_drop<E>(this: &Arc<BoundSocketInner<Self, E>>)
|
||||
where
|
||||
Self: Sized;
|
||||
}
|
||||
|
||||
pub struct AnyBoundSocket<E>(Arc<AnyBoundSocketInner<E>>);
|
||||
pub type BoundTcpSocket<E> = BoundSocket<TcpSocket, E>;
|
||||
pub type BoundUdpSocket<E> = BoundSocket<UdpSocket, E>;
|
||||
|
||||
impl<E> AnyBoundSocket<E> {
|
||||
/// Common states shared by [`BoundTcpSocketInner`] and [`BoundUdpSocketInner`].
|
||||
pub struct BoundSocketInner<T, E> {
|
||||
iface: Arc<dyn Iface<E>>,
|
||||
port: u16,
|
||||
socket: T,
|
||||
observer: RwLock<Weak<dyn SocketEventObserver>>,
|
||||
next_poll_at_ms: AtomicU64,
|
||||
has_new_events: AtomicBool,
|
||||
}
|
||||
|
||||
/// States needed by [`BoundTcpSocketInner`] but not [`BoundUdpSocketInner`].
|
||||
pub struct TcpSocket {
|
||||
socket: SpinLock<RawTcpSocketExt, LocalIrqDisabled>,
|
||||
is_dead: AtomicBool,
|
||||
}
|
||||
|
||||
struct RawTcpSocketExt {
|
||||
socket: Box<RawTcpSocket>,
|
||||
/// Whether the socket is in the background.
|
||||
///
|
||||
/// A background socket is a socket with its corresponding [`BoundSocket`] dropped. This means
|
||||
/// that no more user events (like `send`/`recv`) can reach the socket, but it can be in a
|
||||
/// state of waiting for certain network events (e.g., remote FIN/ACK packets), so
|
||||
/// [`BoundSocketInner`] may still be alive for a while.
|
||||
in_background: bool,
|
||||
}
|
||||
|
||||
impl Deref for RawTcpSocketExt {
|
||||
type Target = RawTcpSocket;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.socket
|
||||
}
|
||||
}
|
||||
|
||||
impl DerefMut for RawTcpSocketExt {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
&mut self.socket
|
||||
}
|
||||
}
|
||||
|
||||
impl TcpSocket {
|
||||
fn lock(&self) -> SpinLockGuard<RawTcpSocketExt, LocalIrqDisabled> {
|
||||
self.socket.lock()
|
||||
}
|
||||
|
||||
/// Returns whether the TCP socket is dead.
|
||||
///
|
||||
/// See [`BoundTcpSocketInner::is_dead`] for the definition of dead TCP sockets.
|
||||
fn is_dead(&self) -> bool {
|
||||
self.is_dead.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
/// Updates whether the TCP socket is dead.
|
||||
///
|
||||
/// See [`BoundTcpSocketInner::is_dead`] for the definition of dead TCP sockets.
|
||||
///
|
||||
/// This method must be called after handling network events. However, it is not necessary to
|
||||
/// call this method after handling non-closing user events, because the socket can never be
|
||||
/// dead if user events can reach the socket.
|
||||
fn update_dead(&self, socket: &RawTcpSocketExt) {
|
||||
if socket.in_background && socket.state() == smoltcp::socket::tcp::State::Closed {
|
||||
self.is_dead.store(true, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AnySocket for TcpSocket {
|
||||
type RawSocket = RawTcpSocket;
|
||||
|
||||
fn new(socket: Box<Self::RawSocket>) -> Self {
|
||||
let socket_ext = RawTcpSocketExt {
|
||||
socket,
|
||||
in_background: false,
|
||||
};
|
||||
|
||||
Self {
|
||||
socket: SpinLock::new(socket_ext),
|
||||
is_dead: AtomicBool::new(false),
|
||||
}
|
||||
}
|
||||
|
||||
fn on_drop<E>(this: &Arc<BoundSocketInner<Self, E>>) {
|
||||
let mut socket = this.socket.lock();
|
||||
|
||||
socket.in_background = true;
|
||||
socket.close();
|
||||
|
||||
// A TCP socket may not be appropriate for immediate removal. We leave the removal decision
|
||||
// to the polling logic.
|
||||
this.update_next_poll_at_ms(PollAt::Now);
|
||||
this.socket.update_dead(&socket);
|
||||
}
|
||||
}
|
||||
|
||||
/// States needed by [`BoundUdpSocketInner`] but not [`BoundTcpSocketInner`].
|
||||
type UdpSocket = SpinLock<Box<RawUdpSocket>, LocalIrqDisabled>;
|
||||
|
||||
impl AnySocket for UdpSocket {
|
||||
type RawSocket = RawUdpSocket;
|
||||
|
||||
fn new(socket: Box<Self::RawSocket>) -> Self {
|
||||
Self::new(socket)
|
||||
}
|
||||
|
||||
fn on_drop<E>(this: &Arc<BoundSocketInner<Self, E>>) {
|
||||
this.socket.lock().close();
|
||||
|
||||
// A UDP socket can be removed immediately.
|
||||
this.iface.common().remove_udp_socket(this);
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: AnySocket, E> Drop for BoundSocket<T, E> {
|
||||
fn drop(&mut self) {
|
||||
T::on_drop(&self.0);
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) type BoundTcpSocketInner<E> = BoundSocketInner<TcpSocket, E>;
|
||||
pub(crate) type BoundUdpSocketInner<E> = BoundSocketInner<UdpSocket, E>;
|
||||
|
||||
impl<T: AnySocket, E> BoundSocket<T, E> {
|
||||
pub(crate) fn new(
|
||||
iface: Arc<dyn Iface<E>>,
|
||||
handle: smoltcp::iface::SocketHandle,
|
||||
port: u16,
|
||||
socket_family: SocketFamily,
|
||||
socket: Box<T::RawSocket>,
|
||||
observer: Weak<dyn SocketEventObserver>,
|
||||
) -> Self {
|
||||
Self(Arc::new(AnyBoundSocketInner {
|
||||
Self(Arc::new(BoundSocketInner {
|
||||
iface,
|
||||
handle,
|
||||
port,
|
||||
socket_family,
|
||||
socket: T::new(socket),
|
||||
observer: RwLock::new(observer),
|
||||
next_poll_at_ms: AtomicU64::new(u64::MAX),
|
||||
has_new_events: AtomicBool::new(false),
|
||||
}))
|
||||
}
|
||||
|
||||
pub(crate) fn inner(&self) -> &Arc<AnyBoundSocketInner<E>> {
|
||||
pub(crate) fn inner(&self) -> &Arc<BoundSocketInner<T, E>> {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: AnySocket, E> BoundSocket<T, E> {
|
||||
/// Sets the observer whose `on_events` will be called when certain iface events happen. After
|
||||
/// setting, the new observer will fire once immediately to avoid missing any events.
|
||||
///
|
||||
@ -51,6 +195,15 @@ impl<E> AnyBoundSocket<E> {
|
||||
self.0.on_iface_events();
|
||||
}
|
||||
|
||||
/// Returns the observer.
|
||||
///
|
||||
/// See also [`Self::set_observer`].
|
||||
pub fn observer(&self) -> Weak<dyn SocketEventObserver> {
|
||||
// We never hold the write lock in IRQ handlers, so we don't need to disable IRQs when we
|
||||
// get the read lock.
|
||||
self.0.observer.read().clone()
|
||||
}
|
||||
|
||||
pub fn local_endpoint(&self) -> Option<IpEndpoint> {
|
||||
let ip_addr = {
|
||||
let ipv4_addr = self.0.iface.ipv4_addr()?;
|
||||
@ -59,58 +212,155 @@ impl<E> AnyBoundSocket<E> {
|
||||
Some(IpEndpoint::new(ip_addr, self.0.port))
|
||||
}
|
||||
|
||||
pub fn raw_with<T: smoltcp::socket::AnySocket<'static>, R, F: FnMut(&mut T) -> R>(
|
||||
&self,
|
||||
f: F,
|
||||
) -> R {
|
||||
self.0.raw_with(f)
|
||||
}
|
||||
|
||||
/// Connects to a remote endpoint.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// This method will panic if the socket is not a TCP socket.
|
||||
pub fn do_connect(&self, remote_endpoint: IpEndpoint) -> Result<(), ConnectError> {
|
||||
let common = self.iface().common();
|
||||
|
||||
let mut sockets = common.sockets();
|
||||
let socket = sockets.get_mut::<RawTcpSocket>(self.0.handle);
|
||||
|
||||
let mut iface = common.interface();
|
||||
let cx = iface.context();
|
||||
|
||||
socket.connect(cx, remote_endpoint, self.0.port)
|
||||
}
|
||||
|
||||
pub fn iface(&self) -> &Arc<dyn Iface<E>> {
|
||||
&self.0.iface
|
||||
}
|
||||
}
|
||||
|
||||
impl<E> Drop for AnyBoundSocket<E> {
|
||||
fn drop(&mut self) {
|
||||
if self.0.start_closing() {
|
||||
self.0.iface.common().remove_bound_socket_now(&self.0);
|
||||
} else {
|
||||
self.0
|
||||
.iface
|
||||
.common()
|
||||
.remove_bound_socket_when_closed(&self.0);
|
||||
}
|
||||
impl<E> BoundTcpSocket<E> {
|
||||
/// Connects to a remote endpoint.
|
||||
pub fn connect(
|
||||
&self,
|
||||
remote_endpoint: IpEndpoint,
|
||||
) -> Result<(), smoltcp::socket::tcp::ConnectError> {
|
||||
let common = self.iface().common();
|
||||
let mut iface = common.interface();
|
||||
|
||||
let mut socket = self.0.socket.lock();
|
||||
|
||||
let result = socket.connect(iface.context(), remote_endpoint, self.0.port);
|
||||
self.0
|
||||
.update_next_poll_at_ms(socket.poll_at(iface.context()));
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
/// Listens at a specified endpoint.
|
||||
pub fn listen(
|
||||
&self,
|
||||
local_endpoint: IpEndpoint,
|
||||
) -> Result<(), smoltcp::socket::tcp::ListenError> {
|
||||
let mut socket = self.0.socket.lock();
|
||||
|
||||
socket.listen(local_endpoint)
|
||||
}
|
||||
|
||||
pub fn send<F, R>(&self, f: F) -> Result<R, smoltcp::socket::tcp::SendError>
|
||||
where
|
||||
F: FnOnce(&mut [u8]) -> (usize, R),
|
||||
{
|
||||
let mut socket = self.0.socket.lock();
|
||||
|
||||
let result = socket.send(f);
|
||||
self.0.update_next_poll_at_ms(PollAt::Now);
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
pub fn recv<F, R>(&self, f: F) -> Result<R, smoltcp::socket::tcp::RecvError>
|
||||
where
|
||||
F: FnOnce(&mut [u8]) -> (usize, R),
|
||||
{
|
||||
let mut socket = self.0.socket.lock();
|
||||
|
||||
let result = socket.recv(f);
|
||||
self.0.update_next_poll_at_ms(PollAt::Now);
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
pub fn close(&self) {
|
||||
let mut socket = self.0.socket.lock();
|
||||
|
||||
socket.close();
|
||||
self.0.update_next_poll_at_ms(PollAt::Now);
|
||||
}
|
||||
|
||||
/// Calls `f` with an immutable reference to the associated [`RawTcpSocket`].
|
||||
//
|
||||
// NOTE: If a mutable reference is required, add a method above that correctly updates the next
|
||||
// polling time.
|
||||
pub fn raw_with<F, R>(&self, f: F) -> R
|
||||
where
|
||||
F: FnOnce(&RawTcpSocket) -> R,
|
||||
{
|
||||
let socket = self.0.socket.lock();
|
||||
f(&socket)
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct AnyBoundSocketInner<E> {
|
||||
iface: Arc<dyn Iface<E>>,
|
||||
handle: smoltcp::iface::SocketHandle,
|
||||
port: u16,
|
||||
socket_family: SocketFamily,
|
||||
observer: RwLock<Weak<dyn SocketEventObserver>>,
|
||||
impl<E> BoundUdpSocket<E> {
|
||||
/// Binds to a specified endpoint.
|
||||
pub fn bind(&self, local_endpoint: IpEndpoint) -> Result<(), smoltcp::socket::udp::BindError> {
|
||||
let mut socket = self.0.socket.lock();
|
||||
|
||||
socket.bind(local_endpoint)
|
||||
}
|
||||
|
||||
pub fn send<F, R>(
|
||||
&self,
|
||||
size: usize,
|
||||
meta: impl Into<UdpMetadata>,
|
||||
f: F,
|
||||
) -> Result<R, crate::errors::udp::SendError>
|
||||
where
|
||||
F: FnOnce(&mut [u8]) -> R,
|
||||
{
|
||||
use smoltcp::socket::udp::SendError as SendErrorInner;
|
||||
|
||||
use crate::errors::udp::SendError;
|
||||
|
||||
let mut socket = self.0.socket.lock();
|
||||
|
||||
if size > socket.packet_send_capacity() {
|
||||
return Err(SendError::TooLarge);
|
||||
}
|
||||
|
||||
let buffer = match socket.send(size, meta) {
|
||||
Ok(data) => data,
|
||||
Err(SendErrorInner::Unaddressable) => return Err(SendError::Unaddressable),
|
||||
Err(SendErrorInner::BufferFull) => return Err(SendError::BufferFull),
|
||||
};
|
||||
let result = f(buffer);
|
||||
self.0.update_next_poll_at_ms(PollAt::Now);
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
pub fn recv<F, R>(&self, f: F) -> Result<R, smoltcp::socket::udp::RecvError>
|
||||
where
|
||||
F: FnOnce(&[u8], UdpMetadata) -> R,
|
||||
{
|
||||
let mut socket = self.0.socket.lock();
|
||||
|
||||
let (data, meta) = socket.recv()?;
|
||||
let result = f(data, meta);
|
||||
self.0.update_next_poll_at_ms(PollAt::Now);
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Calls `f` with an immutable reference to the associated [`RawUdpSocket`].
|
||||
//
|
||||
// NOTE: If a mutable reference is required, add a method above that correctly updates the next
|
||||
// polling time.
|
||||
pub fn raw_with<F, R>(&self, f: F) -> R
|
||||
where
|
||||
F: FnOnce(&RawUdpSocket) -> R,
|
||||
{
|
||||
let socket = self.0.socket.lock();
|
||||
f(&socket)
|
||||
}
|
||||
}
|
||||
|
||||
impl<E> AnyBoundSocketInner<E> {
|
||||
impl<T, E> BoundSocketInner<T, E> {
|
||||
pub(crate) fn has_new_events(&self) -> bool {
|
||||
self.has_new_events.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
pub(crate) fn on_iface_events(&self) {
|
||||
self.has_new_events.store(false, Ordering::Relaxed);
|
||||
|
||||
// We never hold the write lock in IRQ handlers, so we don't need to disable IRQs when we
|
||||
// get the read lock.
|
||||
let observer = Weak::upgrade(&*self.observer.read());
|
||||
@ -120,51 +370,173 @@ impl<E> AnyBoundSocketInner<E> {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn is_closed(&self) -> bool {
|
||||
match self.socket_family {
|
||||
SocketFamily::Tcp => self.raw_with(|socket: &mut RawTcpSocket| {
|
||||
socket.state() == smoltcp::socket::tcp::State::Closed
|
||||
}),
|
||||
SocketFamily::Udp => true,
|
||||
}
|
||||
/// Returns the next polling time.
|
||||
///
|
||||
/// Note: a zero means polling should be done now and a `u64::MAX` means no polling is required
|
||||
/// before new network or user events.
|
||||
pub(crate) fn next_poll_at_ms(&self) -> u64 {
|
||||
self.next_poll_at_ms.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
/// Starts closing the socket and returns whether the socket is closed.
|
||||
/// Updates the next polling time according to `poll_at`.
|
||||
///
|
||||
/// For sockets that can be closed immediately, such as UDP sockets and TCP listening sockets,
|
||||
/// this method will always return `true`.
|
||||
///
|
||||
/// For other sockets, such as TCP connected sockets, they cannot be closed immediately because
|
||||
/// we at least need to send the FIN packet and wait for the remote end to send an ACK packet.
|
||||
/// In this case, this method will return `false` and [`Self::is_closed`] can be used to
|
||||
/// determine if the closing process is complete.
|
||||
fn start_closing(&self) -> bool {
|
||||
match self.socket_family {
|
||||
SocketFamily::Tcp => self.raw_with(|socket: &mut RawTcpSocket| {
|
||||
socket.close();
|
||||
socket.state() == smoltcp::socket::tcp::State::Closed
|
||||
}),
|
||||
SocketFamily::Udp => {
|
||||
self.raw_with(|socket: &mut RawUdpSocket| socket.close());
|
||||
true
|
||||
}
|
||||
/// The update is typically needed after new network or user events have been handled, so this
|
||||
/// method also marks that there may be new events, so that the event observer provided by
|
||||
/// [`BoundSocket::set_observer`] can be notified later.
|
||||
fn update_next_poll_at_ms(&self, poll_at: PollAt) {
|
||||
self.has_new_events.store(true, Ordering::Relaxed);
|
||||
|
||||
match poll_at {
|
||||
PollAt::Now => self.next_poll_at_ms.store(0, Ordering::Relaxed),
|
||||
PollAt::Time(instant) => self
|
||||
.next_poll_at_ms
|
||||
.store(instant.total_millis() as u64, Ordering::Relaxed),
|
||||
PollAt::Ingress => self.next_poll_at_ms.store(u64::MAX, Ordering::Relaxed),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn raw_with<T: smoltcp::socket::AnySocket<'static>, R, F: FnMut(&mut T) -> R>(
|
||||
impl<T, E> BoundSocketInner<T, E> {
|
||||
pub(crate) fn port(&self) -> u16 {
|
||||
self.port
|
||||
}
|
||||
}
|
||||
|
||||
impl<E> BoundTcpSocketInner<E> {
|
||||
/// Returns whether the TCP socket is dead.
|
||||
///
|
||||
/// A TCP socket is considered dead if and only if the following two conditions are met:
|
||||
/// 1. The TCP connection is closed, so this socket cannot process any network events.
|
||||
/// 2. The socket handle [`BoundTcpSocket`] is dropped, which means that this
|
||||
/// [`BoundSocketInner`] is in background and no more user events can reach it.
|
||||
pub(crate) fn is_dead(&self) -> bool {
|
||||
self.socket.is_dead()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, E> BoundSocketInner<T, E> {
|
||||
/// Returns whether an incoming packet _may_ be processed by the socket.
|
||||
///
|
||||
/// The check is intended to be lock-free and fast, but may have false positives.
|
||||
pub(crate) fn can_process(&self, dst_port: u16) -> bool {
|
||||
self.port == dst_port
|
||||
}
|
||||
|
||||
/// Returns whether the socket _may_ generate an outgoing packet.
|
||||
///
|
||||
/// The check is intended to be lock-free and fast, but may have false positives.
|
||||
pub(crate) fn need_dispatch(&self, now: Instant) -> bool {
|
||||
now.total_millis() as u64 >= self.next_poll_at_ms.load(Ordering::Relaxed)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone)]
|
||||
pub(crate) enum TcpProcessResult {
|
||||
NotProcessed,
|
||||
Processed,
|
||||
ProcessedWithReply(IpRepr, TcpRepr<'static>),
|
||||
}
|
||||
|
||||
impl<E> BoundTcpSocketInner<E> {
|
||||
/// Tries to process an incoming packet and returns whether the packet is processed.
|
||||
pub(crate) fn process(
|
||||
&self,
|
||||
mut f: F,
|
||||
) -> R {
|
||||
let mut sockets = self.iface.common().sockets();
|
||||
let socket = sockets.get_mut::<T>(self.handle);
|
||||
f(socket)
|
||||
cx: &mut Context,
|
||||
ip_repr: &IpRepr,
|
||||
tcp_repr: &TcpRepr,
|
||||
) -> TcpProcessResult {
|
||||
let mut socket = self.socket.lock();
|
||||
|
||||
if !socket.accepts(cx, ip_repr, tcp_repr) {
|
||||
return TcpProcessResult::NotProcessed;
|
||||
}
|
||||
|
||||
let result = match socket.process(cx, ip_repr, tcp_repr) {
|
||||
None => TcpProcessResult::Processed,
|
||||
Some((ip_repr, tcp_repr)) => TcpProcessResult::ProcessedWithReply(ip_repr, tcp_repr),
|
||||
};
|
||||
|
||||
self.update_next_poll_at_ms(socket.poll_at(cx));
|
||||
self.socket.update_dead(&socket);
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
/// Tries to generate an outgoing packet and dispatches the generated packet.
|
||||
pub(crate) fn dispatch<D>(
|
||||
&self,
|
||||
cx: &mut Context,
|
||||
dispatch: D,
|
||||
) -> Option<(IpRepr, TcpRepr<'static>)>
|
||||
where
|
||||
D: FnOnce(&mut Context, &IpRepr, &TcpRepr) -> Option<(IpRepr, TcpRepr<'static>)>,
|
||||
{
|
||||
let mut socket = self.socket.lock();
|
||||
|
||||
let mut reply = None;
|
||||
socket
|
||||
.dispatch(cx, |cx, (ip_repr, tcp_repr)| {
|
||||
reply = dispatch(cx, &ip_repr, &tcp_repr);
|
||||
Ok::<(), ()>(())
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
// `dispatch` can return a packet in response to the generated packet. If the socket
|
||||
// accepts the packet, we can process it directly.
|
||||
while let Some((ref ip_repr, ref tcp_repr)) = reply {
|
||||
if !socket.accepts(cx, ip_repr, tcp_repr) {
|
||||
break;
|
||||
}
|
||||
reply = socket.process(cx, ip_repr, tcp_repr);
|
||||
}
|
||||
|
||||
self.update_next_poll_at_ms(socket.poll_at(cx));
|
||||
self.socket.update_dead(&socket);
|
||||
|
||||
reply
|
||||
}
|
||||
}
|
||||
|
||||
impl<E> Drop for AnyBoundSocketInner<E> {
|
||||
fn drop(&mut self) {
|
||||
let iface_common = self.iface.common();
|
||||
iface_common.remove_socket(self.handle);
|
||||
iface_common.release_port(self.port);
|
||||
impl<E> BoundUdpSocketInner<E> {
|
||||
/// Tries to process an incoming packet and returns whether the packet is processed.
|
||||
pub(crate) fn process(
|
||||
&self,
|
||||
cx: &mut Context,
|
||||
ip_repr: &IpRepr,
|
||||
udp_repr: &UdpRepr,
|
||||
udp_payload: &[u8],
|
||||
) -> bool {
|
||||
let mut socket = self.socket.lock();
|
||||
|
||||
if !socket.accepts(cx, ip_repr, udp_repr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
socket.process(
|
||||
cx,
|
||||
smoltcp::phy::PacketMeta::default(),
|
||||
ip_repr,
|
||||
udp_repr,
|
||||
udp_payload,
|
||||
);
|
||||
self.update_next_poll_at_ms(socket.poll_at(cx));
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
/// Tries to generate an outgoing packet and dispatches the generated packet.
|
||||
pub(crate) fn dispatch<D>(&self, cx: &mut Context, dispatch: D)
|
||||
where
|
||||
D: FnOnce(&mut Context, &IpRepr, &UdpRepr, &[u8]),
|
||||
{
|
||||
let mut socket = self.socket.lock();
|
||||
|
||||
socket
|
||||
.dispatch(cx, |cx, _meta, (ip_repr, udp_repr, udp_payload)| {
|
||||
dispatch(cx, &ip_repr, &udp_repr, udp_payload);
|
||||
Ok::<(), ()>(())
|
||||
})
|
||||
.unwrap();
|
||||
self.update_next_poll_at_ms(socket.poll_at(cx));
|
||||
}
|
||||
}
|
||||
|
@ -4,12 +4,11 @@ mod bound;
|
||||
mod event;
|
||||
mod unbound;
|
||||
|
||||
pub use bound::AnyBoundSocket;
|
||||
pub(crate) use bound::{AnyBoundSocketInner, SocketFamily};
|
||||
pub use bound::{BoundTcpSocket, BoundUdpSocket};
|
||||
pub(crate) use bound::{BoundTcpSocketInner, BoundUdpSocketInner, TcpProcessResult};
|
||||
pub use event::SocketEventObserver;
|
||||
pub(crate) use unbound::AnyRawSocket;
|
||||
pub use unbound::{
|
||||
AnyUnboundSocket, TCP_RECV_BUF_LEN, TCP_SEND_BUF_LEN, UDP_RECV_PAYLOAD_LEN,
|
||||
UnboundTcpSocket, UnboundUdpSocket, TCP_RECV_BUF_LEN, TCP_SEND_BUF_LEN, UDP_RECV_PAYLOAD_LEN,
|
||||
UDP_SEND_PAYLOAD_LEN,
|
||||
};
|
||||
|
||||
|
@ -1,34 +1,33 @@
|
||||
// SPDX-License-Identifier: MPL-2.0
|
||||
|
||||
use alloc::{sync::Weak, vec};
|
||||
use alloc::{boxed::Box, sync::Weak, vec};
|
||||
|
||||
use super::{event::SocketEventObserver, RawTcpSocket, RawUdpSocket};
|
||||
|
||||
pub struct AnyUnboundSocket {
|
||||
socket_family: AnyRawSocket,
|
||||
pub struct UnboundSocket<T> {
|
||||
socket: Box<T>,
|
||||
observer: Weak<dyn SocketEventObserver>,
|
||||
}
|
||||
|
||||
#[allow(clippy::large_enum_variant)]
|
||||
pub(crate) enum AnyRawSocket {
|
||||
Tcp(RawTcpSocket),
|
||||
Udp(RawUdpSocket),
|
||||
}
|
||||
pub type UnboundTcpSocket = UnboundSocket<RawTcpSocket>;
|
||||
pub type UnboundUdpSocket = UnboundSocket<RawUdpSocket>;
|
||||
|
||||
impl AnyUnboundSocket {
|
||||
pub fn new_tcp(observer: Weak<dyn SocketEventObserver>) -> Self {
|
||||
impl UnboundTcpSocket {
|
||||
pub fn new(observer: Weak<dyn SocketEventObserver>) -> Self {
|
||||
let raw_tcp_socket = {
|
||||
let rx_buffer = smoltcp::socket::tcp::SocketBuffer::new(vec![0u8; TCP_RECV_BUF_LEN]);
|
||||
let tx_buffer = smoltcp::socket::tcp::SocketBuffer::new(vec![0u8; TCP_SEND_BUF_LEN]);
|
||||
RawTcpSocket::new(rx_buffer, tx_buffer)
|
||||
};
|
||||
AnyUnboundSocket {
|
||||
socket_family: AnyRawSocket::Tcp(raw_tcp_socket),
|
||||
Self {
|
||||
socket: Box::new(raw_tcp_socket),
|
||||
observer,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_udp(observer: Weak<dyn SocketEventObserver>) -> Self {
|
||||
impl UnboundUdpSocket {
|
||||
pub fn new(observer: Weak<dyn SocketEventObserver>) -> Self {
|
||||
let raw_udp_socket = {
|
||||
let metadata = smoltcp::socket::udp::PacketMetadata::EMPTY;
|
||||
let rx_buffer = smoltcp::socket::udp::PacketBuffer::new(
|
||||
@ -41,14 +40,16 @@ impl AnyUnboundSocket {
|
||||
);
|
||||
RawUdpSocket::new(rx_buffer, tx_buffer)
|
||||
};
|
||||
AnyUnboundSocket {
|
||||
socket_family: AnyRawSocket::Udp(raw_udp_socket),
|
||||
Self {
|
||||
socket: Box::new(raw_udp_socket),
|
||||
observer,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn into_raw(self) -> (AnyRawSocket, Weak<dyn SocketEventObserver>) {
|
||||
(self.socket_family, self.observer)
|
||||
impl<T> UnboundSocket<T> {
|
||||
pub(crate) fn into_raw(self) -> (Box<T>, Weak<dyn SocketEventObserver>) {
|
||||
(self.socket, self.observer)
|
||||
}
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user