mirror of
https://github.com/asterinas/asterinas.git
synced 2025-06-19 12:36:46 +00:00
Move packet dispatch out of smoltcp
This commit is contained in:
committed by
Tate, Hongliang Tian
parent
f793259512
commit
ee1656ba35
2
Cargo.lock
generated
2
Cargo.lock
generated
@ -1383,7 +1383,7 @@ checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67"
|
||||
[[package]]
|
||||
name = "smoltcp"
|
||||
version = "0.11.0"
|
||||
source = "git+https://github.com/smoltcp-rs/smoltcp?rev=469dccd#469dccdbaece66696494f8540615152fc5123b17"
|
||||
source = "git+https://github.com/asterinas/smoltcp?rev=37716bf#37716bff5ed5b16aba1b8a37e788ee5a6bf32cab"
|
||||
dependencies = [
|
||||
"bitflags 1.3.2",
|
||||
"byteorder",
|
||||
|
@ -8,7 +8,7 @@ edition = "2021"
|
||||
[dependencies]
|
||||
keyable-arc = { path = "../keyable-arc" }
|
||||
ostd = { path = "../../../ostd" }
|
||||
smoltcp = { git = "https://github.com/smoltcp-rs/smoltcp", rev = "469dccd", default-features = false, features = [
|
||||
smoltcp = { git = "https://github.com/asterinas/smoltcp", rev = "37716bf", default-features = false, features = [
|
||||
"alloc",
|
||||
"log",
|
||||
"medium-ethernet",
|
||||
|
@ -14,5 +14,15 @@ pub mod tcp {
|
||||
}
|
||||
|
||||
pub mod udp {
|
||||
pub use smoltcp::socket::udp::{RecvError, SendError};
|
||||
pub use smoltcp::socket::udp::RecvError;
|
||||
|
||||
/// An error returned by [`BoundTcpSocket::recv`].
|
||||
///
|
||||
/// [`BoundTcpSocket::recv`]: crate::socket::BoundTcpSocket::recv
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
|
||||
pub enum SendError {
|
||||
TooLarge,
|
||||
Unaddressable,
|
||||
BufferFull,
|
||||
}
|
||||
}
|
||||
|
@ -7,42 +7,45 @@ use alloc::{
|
||||
btree_set::BTreeSet,
|
||||
},
|
||||
sync::Arc,
|
||||
vec::Vec,
|
||||
};
|
||||
|
||||
use keyable_arc::KeyableArc;
|
||||
use ostd::sync::{LocalIrqDisabled, PreemptDisabled, RwLock, SpinLock, SpinLockGuard};
|
||||
use ostd::sync::{LocalIrqDisabled, PreemptDisabled, SpinLock, SpinLockGuard};
|
||||
use smoltcp::{
|
||||
iface::{SocketHandle, SocketSet},
|
||||
iface::{packet::Packet, Context},
|
||||
phy::Device,
|
||||
wire::Ipv4Address,
|
||||
wire::{Ipv4Address, Ipv4Packet},
|
||||
};
|
||||
|
||||
use super::{port::BindPortConfig, time::get_network_timestamp, Iface};
|
||||
use super::{
|
||||
poll::{FnHelper, PollContext},
|
||||
port::BindPortConfig,
|
||||
time::get_network_timestamp,
|
||||
Iface,
|
||||
};
|
||||
use crate::{
|
||||
errors::BindError,
|
||||
socket::{AnyBoundSocket, AnyBoundSocketInner, AnyRawSocket, AnyUnboundSocket, SocketFamily},
|
||||
socket::{
|
||||
BoundTcpSocket, BoundTcpSocketInner, BoundUdpSocket, BoundUdpSocketInner, UnboundTcpSocket,
|
||||
UnboundUdpSocket,
|
||||
},
|
||||
};
|
||||
|
||||
pub struct IfaceCommon<E> {
|
||||
interface: SpinLock<smoltcp::iface::Interface, LocalIrqDisabled>,
|
||||
sockets: SpinLock<SocketSet<'static>, LocalIrqDisabled>,
|
||||
used_ports: SpinLock<BTreeMap<u16, usize>, PreemptDisabled>,
|
||||
bound_sockets: RwLock<BTreeSet<KeyableArc<AnyBoundSocketInner<E>>>>,
|
||||
closing_sockets: SpinLock<BTreeSet<KeyableArc<AnyBoundSocketInner<E>>>, LocalIrqDisabled>,
|
||||
tcp_sockets: SpinLock<BTreeSet<KeyableArc<BoundTcpSocketInner<E>>>, LocalIrqDisabled>,
|
||||
udp_sockets: SpinLock<BTreeSet<KeyableArc<BoundUdpSocketInner<E>>>, LocalIrqDisabled>,
|
||||
ext: E,
|
||||
}
|
||||
|
||||
impl<E> IfaceCommon<E> {
|
||||
pub(super) fn new(interface: smoltcp::iface::Interface, ext: E) -> Self {
|
||||
let socket_set = SocketSet::new(Vec::new());
|
||||
let used_ports = BTreeMap::new();
|
||||
Self {
|
||||
interface: SpinLock::new(interface),
|
||||
sockets: SpinLock::new(socket_set),
|
||||
used_ports: SpinLock::new(used_ports),
|
||||
bound_sockets: RwLock::new(BTreeSet::new()),
|
||||
closing_sockets: SpinLock::new(BTreeSet::new()),
|
||||
used_ports: SpinLock::new(BTreeMap::new()),
|
||||
tcp_sockets: SpinLock::new(BTreeSet::new()),
|
||||
udp_sockets: SpinLock::new(BTreeSet::new()),
|
||||
ext,
|
||||
}
|
||||
}
|
||||
@ -58,58 +61,57 @@ impl<E> IfaceCommon<E> {
|
||||
|
||||
impl<E> IfaceCommon<E> {
|
||||
/// Acquires the lock to the interface.
|
||||
///
|
||||
/// *Lock ordering:* [`Self::sockets`] first, [`Self::interface`] second.
|
||||
pub(crate) fn interface(&self) -> SpinLockGuard<smoltcp::iface::Interface, LocalIrqDisabled> {
|
||||
self.interface.lock()
|
||||
}
|
||||
|
||||
/// Acuqires the lock to the sockets.
|
||||
///
|
||||
/// *Lock ordering:* [`Self::sockets`] first, [`Self::interface`] second.
|
||||
pub(crate) fn sockets(
|
||||
&self,
|
||||
) -> SpinLockGuard<smoltcp::iface::SocketSet<'static>, LocalIrqDisabled> {
|
||||
self.sockets.lock()
|
||||
}
|
||||
}
|
||||
|
||||
const IP_LOCAL_PORT_START: u16 = 32768;
|
||||
const IP_LOCAL_PORT_END: u16 = 60999;
|
||||
|
||||
impl<E> IfaceCommon<E> {
|
||||
pub(super) fn bind_socket(
|
||||
pub(super) fn bind_tcp(
|
||||
&self,
|
||||
iface: Arc<dyn Iface<E>>,
|
||||
socket: Box<AnyUnboundSocket>,
|
||||
socket: Box<UnboundTcpSocket>,
|
||||
config: BindPortConfig,
|
||||
) -> core::result::Result<AnyBoundSocket<E>, (BindError, Box<AnyUnboundSocket>)> {
|
||||
let port = if let Some(port) = config.port() {
|
||||
port
|
||||
} else {
|
||||
match self.alloc_ephemeral_port() {
|
||||
Some(port) => port,
|
||||
None => return Err((BindError::Exhausted, socket)),
|
||||
}
|
||||
) -> core::result::Result<BoundTcpSocket<E>, (BindError, Box<UnboundTcpSocket>)> {
|
||||
let port = match self.bind_port(config) {
|
||||
Ok(port) => port,
|
||||
Err(err) => return Err((err, socket)),
|
||||
};
|
||||
if !self.bind_port(port, config.can_reuse()) {
|
||||
return Err((BindError::InUse, socket));
|
||||
}
|
||||
|
||||
let (handle, socket_family, observer) = match socket.into_raw() {
|
||||
(AnyRawSocket::Tcp(tcp_socket), observer) => (
|
||||
self.sockets.lock().add(tcp_socket),
|
||||
SocketFamily::Tcp,
|
||||
observer,
|
||||
),
|
||||
(AnyRawSocket::Udp(udp_socket), observer) => (
|
||||
self.sockets.lock().add(udp_socket),
|
||||
SocketFamily::Udp,
|
||||
observer,
|
||||
),
|
||||
let (raw_socket, observer) = socket.into_raw();
|
||||
let bound_socket = BoundTcpSocket::new(iface, port, raw_socket, observer);
|
||||
|
||||
let inserted = self
|
||||
.tcp_sockets
|
||||
.lock()
|
||||
.insert(KeyableArc::from(bound_socket.inner().clone()));
|
||||
assert!(inserted);
|
||||
|
||||
Ok(bound_socket)
|
||||
}
|
||||
|
||||
pub(super) fn bind_udp(
|
||||
&self,
|
||||
iface: Arc<dyn Iface<E>>,
|
||||
socket: Box<UnboundUdpSocket>,
|
||||
config: BindPortConfig,
|
||||
) -> core::result::Result<BoundUdpSocket<E>, (BindError, Box<UnboundUdpSocket>)> {
|
||||
let port = match self.bind_port(config) {
|
||||
Ok(port) => port,
|
||||
Err(err) => return Err((err, socket)),
|
||||
};
|
||||
let bound_socket = AnyBoundSocket::new(iface, handle, port, socket_family, observer);
|
||||
self.insert_bound_socket(bound_socket.inner());
|
||||
|
||||
let (raw_socket, observer) = socket.into_raw();
|
||||
let bound_socket = BoundUdpSocket::new(iface, port, raw_socket, observer);
|
||||
|
||||
let inserted = self
|
||||
.udp_sockets
|
||||
.lock()
|
||||
.insert(KeyableArc::from(bound_socket.inner().clone()));
|
||||
assert!(inserted);
|
||||
|
||||
Ok(bound_socket)
|
||||
}
|
||||
@ -130,36 +132,57 @@ impl<E> IfaceCommon<E> {
|
||||
None
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
fn bind_port(&self, port: u16, can_reuse: bool) -> bool {
|
||||
fn bind_port(&self, config: BindPortConfig) -> Result<u16, BindError> {
|
||||
let port = if let Some(port) = config.port() {
|
||||
port
|
||||
} else {
|
||||
match self.alloc_ephemeral_port() {
|
||||
Some(port) => port,
|
||||
None => return Err(BindError::Exhausted),
|
||||
}
|
||||
};
|
||||
|
||||
let mut used_ports = self.used_ports.lock();
|
||||
|
||||
if let Some(used_times) = used_ports.get_mut(&port) {
|
||||
if *used_times == 0 || can_reuse {
|
||||
if *used_times == 0 || config.can_reuse() {
|
||||
// FIXME: Check if the previous socket was bound with SO_REUSEADDR.
|
||||
*used_times += 1;
|
||||
} else {
|
||||
return false;
|
||||
return Err(BindError::InUse);
|
||||
}
|
||||
} else {
|
||||
used_ports.insert(port, 1);
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
fn insert_bound_socket(&self, socket: &Arc<AnyBoundSocketInner<E>>) {
|
||||
let keyable_socket = KeyableArc::from(socket.clone());
|
||||
|
||||
let inserted = self
|
||||
.bound_sockets
|
||||
.write_irq_disabled()
|
||||
.insert(keyable_socket);
|
||||
assert!(inserted);
|
||||
Ok(port)
|
||||
}
|
||||
}
|
||||
|
||||
impl<E> IfaceCommon<E> {
|
||||
#[allow(clippy::mutable_key_type)]
|
||||
fn remove_dead_tcp_sockets(&self, sockets: &mut BTreeSet<KeyableArc<BoundTcpSocketInner<E>>>) {
|
||||
sockets.retain(|socket| {
|
||||
if socket.is_dead() {
|
||||
self.release_port(socket.port());
|
||||
false
|
||||
} else {
|
||||
true
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
pub(crate) fn remove_udp_socket(&self, socket: &Arc<BoundUdpSocketInner<E>>) {
|
||||
let keyable_socket = KeyableArc::from(socket.clone());
|
||||
|
||||
let removed = self.udp_sockets.lock().remove(&keyable_socket);
|
||||
assert!(removed);
|
||||
|
||||
self.release_port(keyable_socket.port());
|
||||
}
|
||||
|
||||
/// Releases the port so that it can be used again (if it is not being reused).
|
||||
pub(crate) fn release_port(&self, port: u16) {
|
||||
fn release_port(&self, port: u16) {
|
||||
let mut used_ports = self.used_ports.lock();
|
||||
if let Some(used_times) = used_ports.remove(&port) {
|
||||
if used_times != 1 {
|
||||
@ -167,96 +190,63 @@ impl<E> IfaceCommon<E> {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Removes a socket from the interface.
|
||||
pub(crate) fn remove_socket(&self, handle: SocketHandle) {
|
||||
self.sockets.lock().remove(handle);
|
||||
}
|
||||
|
||||
pub(crate) fn remove_bound_socket_now(&self, socket: &Arc<AnyBoundSocketInner<E>>) {
|
||||
let keyable_socket = KeyableArc::from(socket.clone());
|
||||
|
||||
let removed = self
|
||||
.bound_sockets
|
||||
.write_irq_disabled()
|
||||
.remove(&keyable_socket);
|
||||
assert!(removed);
|
||||
}
|
||||
|
||||
pub(crate) fn remove_bound_socket_when_closed(&self, socket: &Arc<AnyBoundSocketInner<E>>) {
|
||||
let keyable_socket = KeyableArc::from(socket.clone());
|
||||
|
||||
let removed = self
|
||||
.bound_sockets
|
||||
.write_irq_disabled()
|
||||
.remove(&keyable_socket);
|
||||
assert!(removed);
|
||||
|
||||
let mut closing_sockets = self.closing_sockets.lock();
|
||||
|
||||
// Check `is_closed` after holding the lock to avoid race conditions.
|
||||
if keyable_socket.is_closed() {
|
||||
return;
|
||||
}
|
||||
|
||||
let inserted = closing_sockets.insert(keyable_socket);
|
||||
assert!(inserted);
|
||||
}
|
||||
}
|
||||
|
||||
impl<E> IfaceCommon<E> {
|
||||
#[must_use]
|
||||
pub(super) fn poll<D: Device + ?Sized>(&self, device: &mut D) -> Option<u64> {
|
||||
let mut sockets = self.sockets.lock();
|
||||
let mut interface = self.interface.lock();
|
||||
pub(super) fn poll<D, P, Q>(
|
||||
&self,
|
||||
device: &mut D,
|
||||
process_phy: P,
|
||||
mut dispatch_phy: Q,
|
||||
) -> Option<u64>
|
||||
where
|
||||
D: Device + ?Sized,
|
||||
P: for<'pkt, 'cx, 'tx> FnHelper<
|
||||
&'pkt [u8],
|
||||
&'cx mut Context,
|
||||
D::TxToken<'tx>,
|
||||
Option<(Ipv4Packet<&'pkt [u8]>, D::TxToken<'tx>)>,
|
||||
>,
|
||||
Q: FnMut(&Packet, &mut Context, D::TxToken<'_>),
|
||||
{
|
||||
let mut interface = self.interface();
|
||||
interface.context().now = get_network_timestamp();
|
||||
|
||||
let timestamp = get_network_timestamp();
|
||||
let (has_events, poll_at) = {
|
||||
let mut has_events = false;
|
||||
let mut poll_at;
|
||||
let mut tcp_sockets = self.tcp_sockets.lock();
|
||||
let udp_sockets = self.udp_sockets.lock();
|
||||
|
||||
loop {
|
||||
// `poll` transmits and receives a bounded number of packets. This loop ensures
|
||||
// that all packets are transmitted and received. For details, see
|
||||
// <https://github.com/smoltcp-rs/smoltcp/blob/8e3ea5c7f09a76f0a4988fda20cadc74eacdc0d8/src/iface/interface/mod.rs#L400-L405>.
|
||||
while interface.poll(timestamp, device, &mut sockets) {
|
||||
has_events = true;
|
||||
}
|
||||
let mut context = PollContext::new(interface.context(), &tcp_sockets, &udp_sockets);
|
||||
context.poll_ingress(device, process_phy, &mut dispatch_phy);
|
||||
context.poll_egress(device, dispatch_phy);
|
||||
|
||||
// `poll_at` can return `Some(Instant::from_millis(0))`, which means `PollAt::Now`.
|
||||
// For details, see
|
||||
// <https://github.com/smoltcp-rs/smoltcp/blob/8e3ea5c7f09a76f0a4988fda20cadc74eacdc0d8/src/iface/interface/mod.rs#L478>.
|
||||
poll_at = interface.poll_at(timestamp, &sockets);
|
||||
let Some(instant) = poll_at else {
|
||||
break;
|
||||
};
|
||||
if instant > timestamp {
|
||||
break;
|
||||
}
|
||||
tcp_sockets.iter().for_each(|socket| {
|
||||
if socket.has_new_events() {
|
||||
socket.on_iface_events();
|
||||
}
|
||||
});
|
||||
udp_sockets.iter().for_each(|socket| {
|
||||
if socket.has_new_events() {
|
||||
socket.on_iface_events();
|
||||
}
|
||||
});
|
||||
|
||||
(has_events, poll_at)
|
||||
};
|
||||
self.remove_dead_tcp_sockets(&mut tcp_sockets);
|
||||
|
||||
// drop sockets here to avoid deadlock
|
||||
drop(interface);
|
||||
drop(sockets);
|
||||
|
||||
if has_events {
|
||||
// We never try to hold the write lock in the IRQ context, and we disable IRQ when
|
||||
// holding the write lock. So we don't need to disable IRQ when holding the read lock.
|
||||
self.bound_sockets.read().iter().for_each(|bound_socket| {
|
||||
bound_socket.on_iface_events();
|
||||
});
|
||||
|
||||
let closed_sockets = self
|
||||
.closing_sockets
|
||||
.lock()
|
||||
.extract_if(|closing_socket| closing_socket.is_closed())
|
||||
.collect::<Vec<_>>();
|
||||
drop(closed_sockets);
|
||||
match (
|
||||
tcp_sockets
|
||||
.iter()
|
||||
.map(|socket| socket.next_poll_at_ms())
|
||||
.min(),
|
||||
udp_sockets
|
||||
.iter()
|
||||
.map(|socket| socket.next_poll_at_ms())
|
||||
.min(),
|
||||
) {
|
||||
(Some(tcp_poll_at), Some(udp_poll_at)) if tcp_poll_at <= udp_poll_at => {
|
||||
Some(tcp_poll_at)
|
||||
}
|
||||
(tcp_poll_at, None) => tcp_poll_at,
|
||||
(_, udp_poll_at) => udp_poll_at,
|
||||
}
|
||||
|
||||
poll_at.map(|at| smoltcp::time::Instant::total_millis(&at) as u64)
|
||||
}
|
||||
}
|
||||
|
@ -7,7 +7,7 @@ use smoltcp::wire::Ipv4Address;
|
||||
use super::port::BindPortConfig;
|
||||
use crate::{
|
||||
errors::BindError,
|
||||
socket::{AnyBoundSocket, AnyUnboundSocket},
|
||||
socket::{BoundTcpSocket, BoundUdpSocket, UnboundTcpSocket, UnboundUdpSocket},
|
||||
};
|
||||
|
||||
/// A network interface.
|
||||
@ -42,13 +42,22 @@ impl<E> dyn Iface<E> {
|
||||
/// FIXME: The reason for binding the socket and the iface together is because there are
|
||||
/// limitations inside smoltcp. See discussion at
|
||||
/// <https://github.com/smoltcp-rs/smoltcp/issues/779>.
|
||||
pub fn bind_socket(
|
||||
pub fn bind_tcp(
|
||||
self: &Arc<Self>,
|
||||
socket: Box<AnyUnboundSocket>,
|
||||
socket: Box<UnboundTcpSocket>,
|
||||
config: BindPortConfig,
|
||||
) -> core::result::Result<AnyBoundSocket<E>, (BindError, Box<AnyUnboundSocket>)> {
|
||||
) -> core::result::Result<BoundTcpSocket<E>, (BindError, Box<UnboundTcpSocket>)> {
|
||||
let common = self.common();
|
||||
common.bind_socket(self.clone(), socket, config)
|
||||
common.bind_tcp(self.clone(), socket, config)
|
||||
}
|
||||
|
||||
pub fn bind_udp(
|
||||
self: &Arc<Self>,
|
||||
socket: Box<UnboundUdpSocket>,
|
||||
config: BindPortConfig,
|
||||
) -> core::result::Result<BoundUdpSocket<E>, (BindError, Box<UnboundUdpSocket>)> {
|
||||
let common = self.common();
|
||||
common.bind_udp(self.clone(), socket, config)
|
||||
}
|
||||
|
||||
/// Gets the IPv4 address of the iface, if any.
|
||||
|
@ -4,6 +4,7 @@ mod common;
|
||||
#[allow(clippy::module_inception)]
|
||||
mod iface;
|
||||
mod phy;
|
||||
mod poll;
|
||||
mod port;
|
||||
mod time;
|
||||
|
||||
|
@ -1,10 +1,15 @@
|
||||
// SPDX-License-Identifier: MPL-2.0
|
||||
|
||||
use alloc::sync::Arc;
|
||||
use alloc::{collections::btree_map::BTreeMap, sync::Arc};
|
||||
|
||||
use ostd::sync::{LocalIrqDisabled, SpinLock};
|
||||
use smoltcp::{
|
||||
iface::Config,
|
||||
wire::{self, EthernetAddress, Ipv4Address, Ipv4Cidr},
|
||||
iface::{packet::Packet, Config, Context},
|
||||
phy::{DeviceCapabilities, TxToken},
|
||||
wire::{
|
||||
self, ArpOperation, ArpPacket, ArpRepr, EthernetAddress, EthernetFrame, EthernetProtocol,
|
||||
EthernetRepr, IpAddress, Ipv4Address, Ipv4Cidr, Ipv4Packet,
|
||||
},
|
||||
};
|
||||
|
||||
use crate::{
|
||||
@ -14,9 +19,11 @@ use crate::{
|
||||
},
|
||||
};
|
||||
|
||||
pub struct EtherIface<D: WithDevice, E> {
|
||||
pub struct EtherIface<D, E> {
|
||||
driver: D,
|
||||
common: IfaceCommon<E>,
|
||||
ether_addr: EthernetAddress,
|
||||
arp_table: SpinLock<BTreeMap<Ipv4Address, EthernetAddress>, LocalIrqDisabled>,
|
||||
}
|
||||
|
||||
impl<D: WithDevice, E> EtherIface<D, E> {
|
||||
@ -45,21 +52,224 @@ impl<D: WithDevice, E> EtherIface<D, E> {
|
||||
|
||||
let common = IfaceCommon::new(interface, ext);
|
||||
|
||||
Arc::new(Self { driver, common })
|
||||
Arc::new(Self {
|
||||
driver,
|
||||
common,
|
||||
ether_addr,
|
||||
arp_table: SpinLock::new(BTreeMap::new()),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: WithDevice, E> IfaceInternal<E> for EtherIface<D, E> {
|
||||
impl<D, E> IfaceInternal<E> for EtherIface<D, E> {
|
||||
fn common(&self) -> &IfaceCommon<E> {
|
||||
&self.common
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: WithDevice, E: Send + Sync> Iface<E> for EtherIface<D, E> {
|
||||
impl<D: WithDevice + 'static, E: Send + Sync> Iface<E> for EtherIface<D, E> {
|
||||
fn raw_poll(&self, schedule_next_poll: &dyn Fn(Option<u64>)) {
|
||||
self.driver.with(|device| {
|
||||
let next_poll = self.common.poll(&mut *device);
|
||||
let next_poll = self.common.poll(
|
||||
&mut *device,
|
||||
|data, iface_cx, tx_token| self.process(data, iface_cx, tx_token),
|
||||
|pkt, iface_cx, tx_token| self.dispatch(pkt, iface_cx, tx_token),
|
||||
);
|
||||
schedule_next_poll(next_poll);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
impl<D, E> EtherIface<D, E> {
|
||||
fn process<'pkt, T: TxToken>(
|
||||
&self,
|
||||
data: &'pkt [u8],
|
||||
iface_cx: &mut Context,
|
||||
tx_token: T,
|
||||
) -> Option<(Ipv4Packet<&'pkt [u8]>, T)> {
|
||||
match self.parse_ip_or_process_arp(data, iface_cx) {
|
||||
Ok(pkt) => Some((pkt, tx_token)),
|
||||
Err(Some(arp)) => {
|
||||
Self::emit_arp(&arp, tx_token);
|
||||
None
|
||||
}
|
||||
Err(None) => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_ip_or_process_arp<'pkt>(
|
||||
&self,
|
||||
data: &'pkt [u8],
|
||||
iface_cx: &mut Context,
|
||||
) -> Result<Ipv4Packet<&'pkt [u8]>, Option<ArpRepr>> {
|
||||
// Parse the Ethernet header. Ignore the packet if the header is ill-formed.
|
||||
let frame = EthernetFrame::new_checked(data).map_err(|_| None)?;
|
||||
let repr = EthernetRepr::parse(&frame).map_err(|_| None)?;
|
||||
|
||||
// Ignore the Ethernet frame if it is not sent to us.
|
||||
if !repr.dst_addr.is_broadcast() && repr.dst_addr != self.ether_addr {
|
||||
return Err(None);
|
||||
}
|
||||
|
||||
// Ignore the Ethernet frame if the protocol is not supported.
|
||||
match repr.ethertype {
|
||||
EthernetProtocol::Ipv4 => {
|
||||
Ok(Ipv4Packet::new_checked(frame.payload()).map_err(|_| None)?)
|
||||
}
|
||||
EthernetProtocol::Arp => {
|
||||
let pkt = ArpPacket::new_checked(frame.payload()).map_err(|_| None)?;
|
||||
let arp = ArpRepr::parse(&pkt).map_err(|_| None)?;
|
||||
Err(self.process_arp(&arp, iface_cx))
|
||||
}
|
||||
_ => Err(None),
|
||||
}
|
||||
}
|
||||
|
||||
fn process_arp(&self, arp_repr: &ArpRepr, iface_cx: &mut Context) -> Option<ArpRepr> {
|
||||
match arp_repr {
|
||||
ArpRepr::EthernetIpv4 {
|
||||
operation: ArpOperation::Reply,
|
||||
source_hardware_addr,
|
||||
source_protocol_addr,
|
||||
..
|
||||
} => {
|
||||
// Ignore the ARP packet if the source addresses are not unicast or not local.
|
||||
if !source_hardware_addr.is_unicast()
|
||||
|| !iface_cx.in_same_network(&IpAddress::Ipv4(*source_protocol_addr))
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
// Insert the mapping between the Ethernet address and the IP address.
|
||||
//
|
||||
// TODO: Remove the mapping if it expires.
|
||||
self.arp_table
|
||||
.lock()
|
||||
.insert(*source_protocol_addr, *source_hardware_addr);
|
||||
|
||||
None
|
||||
}
|
||||
ArpRepr::EthernetIpv4 {
|
||||
operation: ArpOperation::Request,
|
||||
source_hardware_addr,
|
||||
source_protocol_addr,
|
||||
target_protocol_addr,
|
||||
..
|
||||
} => {
|
||||
// Ignore the ARP packet if the source addresses are not unicast.
|
||||
if !source_hardware_addr.is_unicast() || !source_protocol_addr.is_unicast() {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Ignore the ARP packet if we do not own the target address.
|
||||
if !iface_cx
|
||||
.ipv4_addr()
|
||||
.is_some_and(|addr| addr == *target_protocol_addr)
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(ArpRepr::EthernetIpv4 {
|
||||
operation: ArpOperation::Reply,
|
||||
source_hardware_addr: self.ether_addr,
|
||||
source_protocol_addr: *target_protocol_addr,
|
||||
target_hardware_addr: *source_hardware_addr,
|
||||
target_protocol_addr: *source_protocol_addr,
|
||||
})
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn dispatch<T: TxToken>(&self, pkt: &Packet, iface_cx: &mut Context, tx_token: T) {
|
||||
match self.resolve_ether_or_generate_arp(pkt, iface_cx) {
|
||||
Ok(ether) => Self::emit_ip(ðer, pkt, &iface_cx.caps, tx_token),
|
||||
Err(Some(arp)) => Self::emit_arp(&arp, tx_token),
|
||||
Err(None) => (),
|
||||
}
|
||||
}
|
||||
|
||||
fn resolve_ether_or_generate_arp(
|
||||
&self,
|
||||
pkt: &Packet,
|
||||
iface_cx: &mut Context,
|
||||
) -> Result<EthernetRepr, Option<ArpRepr>> {
|
||||
// Resolve the next-hop IP address.
|
||||
let next_hop_ip = match iface_cx.route(&pkt.ip_repr().dst_addr(), iface_cx.now()) {
|
||||
Some(IpAddress::Ipv4(next_hop_ip)) => next_hop_ip,
|
||||
None => return Err(None),
|
||||
};
|
||||
|
||||
// Resolve the next-hop Ethernet address.
|
||||
let next_hop_ether = if next_hop_ip.is_broadcast() {
|
||||
EthernetAddress::BROADCAST
|
||||
} else if let Some(next_hop_ether) = self.arp_table.lock().get(&next_hop_ip) {
|
||||
*next_hop_ether
|
||||
} else {
|
||||
// If the next-hop Ethernet address cannot be resolved, we drop the original packet and
|
||||
// send an ARP packet instead. The upper layer should be responsible for detecting the
|
||||
// packet loss and retrying later to see if the Ethernet address is ready.
|
||||
return Err(Some(ArpRepr::EthernetIpv4 {
|
||||
operation: ArpOperation::Request,
|
||||
source_hardware_addr: self.ether_addr,
|
||||
source_protocol_addr: iface_cx.ipv4_addr().unwrap_or(Ipv4Address::UNSPECIFIED),
|
||||
target_hardware_addr: EthernetAddress::BROADCAST,
|
||||
target_protocol_addr: next_hop_ip,
|
||||
}));
|
||||
};
|
||||
|
||||
Ok(EthernetRepr {
|
||||
src_addr: self.ether_addr,
|
||||
dst_addr: next_hop_ether,
|
||||
ethertype: EthernetProtocol::Ipv4,
|
||||
})
|
||||
}
|
||||
|
||||
/// Consumes the token and emits an IP packet.
|
||||
fn emit_ip<T: TxToken>(
|
||||
ether_repr: &EthernetRepr,
|
||||
ip_pkt: &Packet,
|
||||
caps: &DeviceCapabilities,
|
||||
tx_token: T,
|
||||
) {
|
||||
tx_token.consume(
|
||||
ether_repr.buffer_len() + ip_pkt.ip_repr().buffer_len(),
|
||||
|buffer| {
|
||||
let mut frame = EthernetFrame::new_unchecked(buffer);
|
||||
ether_repr.emit(&mut frame);
|
||||
|
||||
let ip_repr = ip_pkt.ip_repr();
|
||||
ip_repr.emit(frame.payload_mut(), &caps.checksum);
|
||||
ip_pkt.emit_payload(
|
||||
&ip_repr,
|
||||
&mut frame.payload_mut()[ip_repr.header_len()..],
|
||||
caps,
|
||||
);
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
/// Consumes the token and emits an ARP packet.
|
||||
fn emit_arp<T: TxToken>(arp_repr: &ArpRepr, tx_token: T) {
|
||||
let ether_repr = match arp_repr {
|
||||
ArpRepr::EthernetIpv4 {
|
||||
source_hardware_addr,
|
||||
target_hardware_addr,
|
||||
..
|
||||
} => EthernetRepr {
|
||||
src_addr: *source_hardware_addr,
|
||||
dst_addr: *target_hardware_addr,
|
||||
ethertype: EthernetProtocol::Arp,
|
||||
},
|
||||
_ => return,
|
||||
};
|
||||
|
||||
tx_token.consume(ether_repr.buffer_len() + arp_repr.buffer_len(), |buffer| {
|
||||
let mut frame = EthernetFrame::new_unchecked(buffer);
|
||||
ether_repr.emit(&mut frame);
|
||||
|
||||
let mut pkt = ArpPacket::new_unchecked(frame.payload_mut());
|
||||
arp_repr.emit(&mut pkt);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
@ -4,7 +4,8 @@ use alloc::sync::Arc;
|
||||
|
||||
use smoltcp::{
|
||||
iface::Config,
|
||||
wire::{self, Ipv4Cidr},
|
||||
phy::TxToken,
|
||||
wire::{self, Ipv4Cidr, Ipv4Packet},
|
||||
};
|
||||
|
||||
use crate::{
|
||||
@ -14,7 +15,7 @@ use crate::{
|
||||
},
|
||||
};
|
||||
|
||||
pub struct IpIface<D: WithDevice, E> {
|
||||
pub struct IpIface<D, E> {
|
||||
driver: D,
|
||||
common: IfaceCommon<E>,
|
||||
}
|
||||
@ -39,16 +40,30 @@ impl<D: WithDevice, E> IpIface<D, E> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: WithDevice, E> IfaceInternal<E> for IpIface<D, E> {
|
||||
impl<D, E> IfaceInternal<E> for IpIface<D, E> {
|
||||
fn common(&self) -> &IfaceCommon<E> {
|
||||
&self.common
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: WithDevice, E: Send + Sync> Iface<E> for IpIface<D, E> {
|
||||
impl<D: WithDevice + 'static, E: Send + Sync> Iface<E> for IpIface<D, E> {
|
||||
fn raw_poll(&self, schedule_next_poll: &dyn Fn(Option<u64>)) {
|
||||
self.driver.with(|device| {
|
||||
let next_poll = self.common.poll(device);
|
||||
let next_poll = self.common.poll(
|
||||
device,
|
||||
|data, _iface_cx, tx_token| Some((Ipv4Packet::new_checked(data).ok()?, tx_token)),
|
||||
|pkt, iface_cx, tx_token| {
|
||||
let ip_repr = pkt.ip_repr();
|
||||
tx_token.consume(ip_repr.buffer_len(), |buffer| {
|
||||
ip_repr.emit(&mut buffer[..], &iface_cx.checksum_caps());
|
||||
pkt.emit_payload(
|
||||
&ip_repr,
|
||||
&mut buffer[ip_repr.header_len()..],
|
||||
&iface_cx.caps,
|
||||
);
|
||||
});
|
||||
},
|
||||
);
|
||||
schedule_next_poll(next_poll);
|
||||
});
|
||||
}
|
||||
|
476
kernel/libs/aster-bigtcp/src/iface/poll.rs
Normal file
476
kernel/libs/aster-bigtcp/src/iface/poll.rs
Normal file
@ -0,0 +1,476 @@
|
||||
// SPDX-License-Identifier: MPL-2.0
|
||||
|
||||
use alloc::{collections::btree_set::BTreeSet, vec};
|
||||
|
||||
use keyable_arc::KeyableArc;
|
||||
use smoltcp::{
|
||||
iface::{
|
||||
packet::{icmp_reply_payload_len, IpPayload, Packet},
|
||||
Context,
|
||||
},
|
||||
phy::{ChecksumCapabilities, Device, RxToken, TxToken},
|
||||
wire::{
|
||||
Icmpv4DstUnreachable, Icmpv4Repr, IpAddress, IpProtocol, IpRepr, Ipv4Address, Ipv4Packet,
|
||||
Ipv4Repr, TcpControl, TcpPacket, TcpRepr, UdpPacket, UdpRepr, IPV4_HEADER_LEN,
|
||||
IPV4_MIN_MTU,
|
||||
},
|
||||
};
|
||||
|
||||
use crate::socket::{BoundTcpSocketInner, BoundUdpSocketInner, TcpProcessResult};
|
||||
|
||||
pub(super) struct PollContext<'a, E> {
|
||||
iface_cx: &'a mut Context,
|
||||
tcp_sockets: &'a BTreeSet<KeyableArc<BoundTcpSocketInner<E>>>,
|
||||
udp_sockets: &'a BTreeSet<KeyableArc<BoundUdpSocketInner<E>>>,
|
||||
}
|
||||
|
||||
impl<'a, E> PollContext<'a, E> {
|
||||
#[allow(clippy::mutable_key_type)]
|
||||
pub(super) fn new(
|
||||
iface_cx: &'a mut Context,
|
||||
tcp_sockets: &'a BTreeSet<KeyableArc<BoundTcpSocketInner<E>>>,
|
||||
udp_sockets: &'a BTreeSet<KeyableArc<BoundUdpSocketInner<E>>>,
|
||||
) -> Self {
|
||||
Self {
|
||||
iface_cx,
|
||||
tcp_sockets,
|
||||
udp_sockets,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// This works around <https://github.com/rust-lang/rust/issues/49601>.
|
||||
// See the issue above for details.
|
||||
pub(super) trait FnHelper<A, B, C, O>: FnMut(A, B, C) -> O {}
|
||||
impl<A, B, C, O, F> FnHelper<A, B, C, O> for F where F: FnMut(A, B, C) -> O {}
|
||||
|
||||
impl<'a, E> PollContext<'a, E> {
|
||||
pub(super) fn poll_ingress<D, P, Q>(
|
||||
&mut self,
|
||||
device: &mut D,
|
||||
mut process_phy: P,
|
||||
dispatch_phy: &mut Q,
|
||||
) where
|
||||
D: Device + ?Sized,
|
||||
P: for<'pkt, 'cx, 'tx> FnHelper<
|
||||
&'pkt [u8],
|
||||
&'cx mut Context,
|
||||
D::TxToken<'tx>,
|
||||
Option<(Ipv4Packet<&'pkt [u8]>, D::TxToken<'tx>)>,
|
||||
>,
|
||||
Q: FnMut(&Packet, &mut Context, D::TxToken<'_>),
|
||||
{
|
||||
while let Some((rx_token, tx_token)) = device.receive(self.iface_cx.now()) {
|
||||
rx_token.consume(|data| {
|
||||
let Some((pkt, tx_token)) = process_phy(data, self.iface_cx, tx_token) else {
|
||||
return;
|
||||
};
|
||||
|
||||
let Some(reply) = self.parse_and_process_ipv4(pkt) else {
|
||||
return;
|
||||
};
|
||||
|
||||
dispatch_phy(&reply, self.iface_cx, tx_token);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_and_process_ipv4<'pkt>(
|
||||
&mut self,
|
||||
pkt: Ipv4Packet<&'pkt [u8]>,
|
||||
) -> Option<Packet<'pkt>> {
|
||||
// Parse the IP header. Ignore the packet if the header is ill-formed.
|
||||
let repr = Ipv4Repr::parse(&pkt, &self.iface_cx.checksum_caps()).ok()?;
|
||||
|
||||
if !repr.dst_addr.is_broadcast() && !self.is_unicast_local(IpAddress::Ipv4(repr.dst_addr)) {
|
||||
return self.generate_icmp_unreachable(
|
||||
&IpRepr::Ipv4(repr),
|
||||
pkt.payload(),
|
||||
Icmpv4DstUnreachable::HostUnreachable,
|
||||
);
|
||||
}
|
||||
|
||||
match repr.next_header {
|
||||
IpProtocol::Tcp => self.parse_and_process_tcp(
|
||||
&IpRepr::Ipv4(repr),
|
||||
pkt.payload(),
|
||||
&self.iface_cx.checksum_caps(),
|
||||
),
|
||||
IpProtocol::Udp => self.parse_and_process_udp(
|
||||
&IpRepr::Ipv4(repr),
|
||||
pkt.payload(),
|
||||
&self.iface_cx.checksum_caps(),
|
||||
),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_and_process_tcp<'pkt>(
|
||||
&mut self,
|
||||
ip_repr: &IpRepr,
|
||||
ip_payload: &'pkt [u8],
|
||||
checksum_caps: &ChecksumCapabilities,
|
||||
) -> Option<Packet<'pkt>> {
|
||||
// TCP connections can only be established between unicast addresses. Ignore the packet if
|
||||
// this is not the case. See
|
||||
// <https://datatracker.ietf.org/doc/html/rfc9293#section-3.9.2.3>.
|
||||
if !ip_repr.src_addr().is_unicast() || !ip_repr.dst_addr().is_unicast() {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Parse the TCP header. Ignore the packet if the header is ill-formed.
|
||||
let tcp_pkt = TcpPacket::new_checked(ip_payload).ok()?;
|
||||
let tcp_repr = TcpRepr::parse(
|
||||
&tcp_pkt,
|
||||
&ip_repr.src_addr(),
|
||||
&ip_repr.dst_addr(),
|
||||
checksum_caps,
|
||||
)
|
||||
.ok()?;
|
||||
|
||||
self.process_tcp_until_outgoing(ip_repr, &tcp_repr)
|
||||
.map(|(ip_repr, tcp_repr)| Packet::new(ip_repr, IpPayload::Tcp(tcp_repr)))
|
||||
}
|
||||
|
||||
fn process_tcp_until_outgoing(
|
||||
&mut self,
|
||||
ip_repr: &IpRepr,
|
||||
tcp_repr: &TcpRepr,
|
||||
) -> Option<(IpRepr, TcpRepr<'static>)> {
|
||||
let (mut ip_repr, mut tcp_repr) = self.process_tcp(ip_repr, tcp_repr)?;
|
||||
|
||||
loop {
|
||||
if !self.is_unicast_local(ip_repr.dst_addr()) {
|
||||
return Some((ip_repr, tcp_repr));
|
||||
}
|
||||
|
||||
let (new_ip_repr, new_tcp_repr) = self.process_tcp(&ip_repr, &tcp_repr)?;
|
||||
ip_repr = new_ip_repr;
|
||||
tcp_repr = new_tcp_repr;
|
||||
}
|
||||
}
|
||||
|
||||
fn process_tcp(
|
||||
&mut self,
|
||||
ip_repr: &IpRepr,
|
||||
tcp_repr: &TcpRepr,
|
||||
) -> Option<(IpRepr, TcpRepr<'static>)> {
|
||||
for socket in self.tcp_sockets.iter() {
|
||||
if !socket.can_process(tcp_repr.dst_port) {
|
||||
continue;
|
||||
}
|
||||
|
||||
match socket.process(self.iface_cx, ip_repr, tcp_repr) {
|
||||
TcpProcessResult::NotProcessed => continue,
|
||||
TcpProcessResult::Processed => return None,
|
||||
TcpProcessResult::ProcessedWithReply(ip_repr, tcp_repr) => {
|
||||
return Some((ip_repr, tcp_repr))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// "In no case does receipt of a segment containing RST give rise to a RST in response."
|
||||
// See <https://datatracker.ietf.org/doc/html/rfc9293#section-4-1.64>.
|
||||
if tcp_repr.control == TcpControl::Rst {
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(smoltcp::socket::tcp::Socket::rst_reply(ip_repr, tcp_repr))
|
||||
}
|
||||
|
||||
fn parse_and_process_udp<'pkt>(
|
||||
&mut self,
|
||||
ip_repr: &IpRepr,
|
||||
ip_payload: &'pkt [u8],
|
||||
checksum_caps: &ChecksumCapabilities,
|
||||
) -> Option<Packet<'pkt>> {
|
||||
// Parse the UDP header. Ignore the packet if the header is ill-formed.
|
||||
let udp_pkt = UdpPacket::new_checked(ip_payload).ok()?;
|
||||
let udp_repr = UdpRepr::parse(
|
||||
&udp_pkt,
|
||||
&ip_repr.src_addr(),
|
||||
&ip_repr.dst_addr(),
|
||||
checksum_caps,
|
||||
)
|
||||
.ok()?;
|
||||
|
||||
if !self.process_udp(ip_repr, &udp_repr, udp_pkt.payload()) {
|
||||
return self.generate_icmp_unreachable(
|
||||
ip_repr,
|
||||
ip_payload,
|
||||
Icmpv4DstUnreachable::PortUnreachable,
|
||||
);
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
fn process_udp(&mut self, ip_repr: &IpRepr, udp_repr: &UdpRepr, udp_payload: &[u8]) -> bool {
|
||||
let mut processed = false;
|
||||
|
||||
for socket in self.udp_sockets.iter() {
|
||||
if !socket.can_process(udp_repr.dst_port) {
|
||||
continue;
|
||||
}
|
||||
|
||||
processed |= socket.process(self.iface_cx, ip_repr, udp_repr, udp_payload);
|
||||
if processed && ip_repr.dst_addr().is_unicast() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
processed
|
||||
}
|
||||
|
||||
fn generate_icmp_unreachable<'pkt>(
|
||||
&self,
|
||||
ip_repr: &IpRepr,
|
||||
ip_payload: &'pkt [u8],
|
||||
reason: Icmpv4DstUnreachable,
|
||||
) -> Option<Packet<'pkt>> {
|
||||
if !ip_repr.src_addr().is_unicast() || !ip_repr.dst_addr().is_unicast() {
|
||||
return None;
|
||||
}
|
||||
|
||||
if self.is_unicast_local(ip_repr.src_addr()) {
|
||||
// In this case, the generating ICMP message will have a local IP address as the
|
||||
// destination. However, since we don't have the ability to handle ICMP messages, we'll
|
||||
// just skip the generation.
|
||||
//
|
||||
// TODO: Generate the ICMP message here once we're able to handle incoming ICMP
|
||||
// messages.
|
||||
return None;
|
||||
}
|
||||
|
||||
let IpRepr::Ipv4(ipv4_repr) = ip_repr;
|
||||
|
||||
let reply_len = icmp_reply_payload_len(ip_payload.len(), IPV4_MIN_MTU, IPV4_HEADER_LEN);
|
||||
let icmp_repr = Icmpv4Repr::DstUnreachable {
|
||||
reason,
|
||||
header: *ipv4_repr,
|
||||
data: &ip_payload[..reply_len],
|
||||
};
|
||||
|
||||
Some(Packet::new_ipv4(
|
||||
Ipv4Repr {
|
||||
src_addr: self
|
||||
.iface_cx
|
||||
.ipv4_addr()
|
||||
.unwrap_or(Ipv4Address::UNSPECIFIED),
|
||||
dst_addr: ipv4_repr.src_addr,
|
||||
next_header: IpProtocol::Icmp,
|
||||
payload_len: icmp_repr.buffer_len(),
|
||||
hop_limit: 64,
|
||||
},
|
||||
IpPayload::Icmpv4(icmp_repr),
|
||||
))
|
||||
}
|
||||
|
||||
/// Returns whether the destination address is the unicast address of a local interface.
|
||||
///
|
||||
/// Note: "local" means that the IP address belongs to the local interface, not to be confused
|
||||
/// with the localhost IP (127.0.0.1).
|
||||
fn is_unicast_local(&self, dst_addr: IpAddress) -> bool {
|
||||
match dst_addr {
|
||||
IpAddress::Ipv4(dst_addr) => self
|
||||
.iface_cx
|
||||
.ipv4_addr()
|
||||
.is_some_and(|addr| addr == dst_addr),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, E> PollContext<'a, E> {
|
||||
pub(super) fn poll_egress<D, Q>(&mut self, device: &mut D, mut dispatch_phy: Q)
|
||||
where
|
||||
D: Device + ?Sized,
|
||||
Q: FnMut(&Packet, &mut Context, D::TxToken<'_>),
|
||||
{
|
||||
while let Some(tx_token) = device.transmit(self.iface_cx.now()) {
|
||||
if !self.dispatch_ipv4(tx_token, &mut dispatch_phy) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn dispatch_ipv4<T, Q>(&mut self, tx_token: T, dispatch_phy: &mut Q) -> bool
|
||||
where
|
||||
T: TxToken,
|
||||
Q: FnMut(&Packet, &mut Context, T),
|
||||
{
|
||||
let (did_something_tcp, tx_token) = self.dispatch_tcp(tx_token, dispatch_phy);
|
||||
|
||||
let Some(tx_token) = tx_token else {
|
||||
return did_something_tcp;
|
||||
};
|
||||
|
||||
let (did_something_udp, _tx_token) = self.dispatch_udp(tx_token, dispatch_phy);
|
||||
|
||||
did_something_tcp || did_something_udp
|
||||
}
|
||||
|
||||
fn dispatch_tcp<T, Q>(&mut self, tx_token: T, dispatch_phy: &mut Q) -> (bool, Option<T>)
|
||||
where
|
||||
T: TxToken,
|
||||
Q: FnMut(&Packet, &mut Context, T),
|
||||
{
|
||||
let mut tx_token = Some(tx_token);
|
||||
let mut did_something = false;
|
||||
|
||||
for socket in self.tcp_sockets.iter() {
|
||||
if !socket.need_dispatch(self.iface_cx.now()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// We set `did_something` even if no packets are actually generated. This is because a
|
||||
// timer can expire, but no packets are actually generated.
|
||||
did_something = true;
|
||||
|
||||
let mut deferred = None;
|
||||
|
||||
let reply = socket.dispatch(self.iface_cx, |cx, ip_repr, tcp_repr| {
|
||||
let mut this = PollContext::new(cx, self.tcp_sockets, self.udp_sockets);
|
||||
|
||||
if !this.is_unicast_local(ip_repr.dst_addr()) {
|
||||
dispatch_phy(
|
||||
&Packet::new(ip_repr.clone(), IpPayload::Tcp(*tcp_repr)),
|
||||
this.iface_cx,
|
||||
tx_token.take().unwrap(),
|
||||
);
|
||||
return None;
|
||||
}
|
||||
|
||||
if !socket.can_process(tcp_repr.dst_port) {
|
||||
return this.process_tcp(ip_repr, tcp_repr);
|
||||
}
|
||||
|
||||
// We cannot call `process_tcp` now because it may cause deadlocks. We will copy
|
||||
// the packet and call `process_tcp` after releasing the socket lock.
|
||||
deferred = Some((ip_repr.clone(), {
|
||||
let mut data = vec![0; tcp_repr.buffer_len()];
|
||||
tcp_repr.emit(
|
||||
&mut TcpPacket::new_unchecked(data.as_mut_slice()),
|
||||
&ip_repr.src_addr(),
|
||||
&ip_repr.dst_addr(),
|
||||
&ChecksumCapabilities::ignored(),
|
||||
);
|
||||
data
|
||||
}));
|
||||
|
||||
None
|
||||
});
|
||||
|
||||
match (deferred, reply) {
|
||||
(None, None) => (),
|
||||
(Some((ip_repr, ip_payload)), None) => {
|
||||
if let Some(reply) = self.parse_and_process_tcp(
|
||||
&ip_repr,
|
||||
&ip_payload,
|
||||
&ChecksumCapabilities::ignored(),
|
||||
) {
|
||||
dispatch_phy(&reply, self.iface_cx, tx_token.take().unwrap());
|
||||
}
|
||||
}
|
||||
(None, Some((ip_repr, tcp_repr))) if !self.is_unicast_local(ip_repr.dst_addr()) => {
|
||||
dispatch_phy(
|
||||
&Packet::new(ip_repr, IpPayload::Tcp(tcp_repr)),
|
||||
self.iface_cx,
|
||||
tx_token.take().unwrap(),
|
||||
);
|
||||
}
|
||||
(None, Some((ip_repr, tcp_repr))) => {
|
||||
if let Some((new_ip_repr, new_tcp_repr)) =
|
||||
self.process_tcp_until_outgoing(&ip_repr, &tcp_repr)
|
||||
{
|
||||
dispatch_phy(
|
||||
&Packet::new(new_ip_repr, IpPayload::Tcp(new_tcp_repr)),
|
||||
self.iface_cx,
|
||||
tx_token.take().unwrap(),
|
||||
);
|
||||
}
|
||||
}
|
||||
(Some(_), Some(_)) => unreachable!(),
|
||||
}
|
||||
|
||||
if tx_token.is_none() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
(did_something, tx_token)
|
||||
}
|
||||
|
||||
fn dispatch_udp<T, Q>(&mut self, tx_token: T, dispatch_phy: &mut Q) -> (bool, Option<T>)
|
||||
where
|
||||
T: TxToken,
|
||||
Q: FnMut(&Packet, &mut Context, T),
|
||||
{
|
||||
let mut tx_token = Some(tx_token);
|
||||
let mut did_something = false;
|
||||
|
||||
for socket in self.udp_sockets.iter() {
|
||||
if !socket.need_dispatch(self.iface_cx.now()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// We set `did_something` even if no packets are actually generated. This is because a
|
||||
// timer can expire, but no packets are actually generated.
|
||||
did_something = true;
|
||||
|
||||
let mut deferred = None;
|
||||
|
||||
socket.dispatch(self.iface_cx, |cx, ip_repr, udp_repr, udp_payload| {
|
||||
let mut this = PollContext::new(cx, self.tcp_sockets, self.udp_sockets);
|
||||
|
||||
if ip_repr.dst_addr().is_broadcast() || !this.is_unicast_local(ip_repr.dst_addr()) {
|
||||
dispatch_phy(
|
||||
&Packet::new(ip_repr.clone(), IpPayload::Udp(*udp_repr, udp_payload)),
|
||||
this.iface_cx,
|
||||
tx_token.take().unwrap(),
|
||||
);
|
||||
if !ip_repr.dst_addr().is_broadcast() {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if !socket.can_process(udp_repr.dst_port) {
|
||||
// TODO: Generate the ICMP message here once we're able to handle incoming ICMP
|
||||
// messages.
|
||||
let _ = this.process_udp(ip_repr, udp_repr, udp_payload);
|
||||
return;
|
||||
}
|
||||
|
||||
// We cannot call `process_udp` now because it may cause deadlocks. We will copy
|
||||
// the packet and call `process_udp` after releasing the socket lock.
|
||||
deferred = Some((ip_repr.clone(), {
|
||||
let mut data = vec![0; udp_repr.header_len() + udp_payload.len()];
|
||||
udp_repr.emit(
|
||||
&mut UdpPacket::new_unchecked(&mut data),
|
||||
&ip_repr.src_addr(),
|
||||
&ip_repr.dst_addr(),
|
||||
udp_payload.len(),
|
||||
|payload| payload.copy_from_slice(udp_payload),
|
||||
&ChecksumCapabilities::ignored(),
|
||||
);
|
||||
data
|
||||
}));
|
||||
});
|
||||
|
||||
if let Some((ip_repr, ip_payload)) = deferred {
|
||||
if let Some(reply) = self.parse_and_process_udp(
|
||||
&ip_repr,
|
||||
&ip_payload,
|
||||
&ChecksumCapabilities::ignored(),
|
||||
) {
|
||||
dispatch_phy(&reply, self.iface_cx, tx_token.take().unwrap());
|
||||
}
|
||||
}
|
||||
|
||||
if tx_token.is_none() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
(did_something, tx_token)
|
||||
}
|
||||
}
|
@ -1,44 +1,188 @@
|
||||
// SPDX-License-Identifier: MPL-2.0
|
||||
|
||||
use alloc::sync::{Arc, Weak};
|
||||
use alloc::{
|
||||
boxed::Box,
|
||||
sync::{Arc, Weak},
|
||||
};
|
||||
use core::{
|
||||
ops::{Deref, DerefMut},
|
||||
sync::atomic::{AtomicBool, AtomicU64, Ordering},
|
||||
};
|
||||
|
||||
use ostd::sync::RwLock;
|
||||
use ostd::sync::{LocalIrqDisabled, RwLock, SpinLock, SpinLockGuard};
|
||||
use smoltcp::{
|
||||
socket::tcp::ConnectError,
|
||||
wire::{IpAddress, IpEndpoint},
|
||||
iface::Context,
|
||||
socket::{udp::UdpMetadata, PollAt},
|
||||
time::Instant,
|
||||
wire::{IpAddress, IpEndpoint, IpRepr, TcpRepr, UdpRepr},
|
||||
};
|
||||
|
||||
use super::{event::SocketEventObserver, RawTcpSocket, RawUdpSocket};
|
||||
use crate::iface::Iface;
|
||||
|
||||
pub(crate) enum SocketFamily {
|
||||
Tcp,
|
||||
Udp,
|
||||
pub struct BoundSocket<T: AnySocket, E>(Arc<BoundSocketInner<T, E>>);
|
||||
|
||||
/// [`TcpSocket`] or [`UdpSocket`].
|
||||
pub trait AnySocket {
|
||||
type RawSocket;
|
||||
|
||||
/// Called by [`BoundSocket::new`].
|
||||
fn new(socket: Box<Self::RawSocket>) -> Self;
|
||||
|
||||
/// Called by [`BoundSocket::drop`].
|
||||
fn on_drop<E>(this: &Arc<BoundSocketInner<Self, E>>)
|
||||
where
|
||||
Self: Sized;
|
||||
}
|
||||
|
||||
pub struct AnyBoundSocket<E>(Arc<AnyBoundSocketInner<E>>);
|
||||
pub type BoundTcpSocket<E> = BoundSocket<TcpSocket, E>;
|
||||
pub type BoundUdpSocket<E> = BoundSocket<UdpSocket, E>;
|
||||
|
||||
impl<E> AnyBoundSocket<E> {
|
||||
/// Common states shared by [`BoundTcpSocketInner`] and [`BoundUdpSocketInner`].
|
||||
pub struct BoundSocketInner<T, E> {
|
||||
iface: Arc<dyn Iface<E>>,
|
||||
port: u16,
|
||||
socket: T,
|
||||
observer: RwLock<Weak<dyn SocketEventObserver>>,
|
||||
next_poll_at_ms: AtomicU64,
|
||||
has_new_events: AtomicBool,
|
||||
}
|
||||
|
||||
/// States needed by [`BoundTcpSocketInner`] but not [`BoundUdpSocketInner`].
|
||||
pub struct TcpSocket {
|
||||
socket: SpinLock<RawTcpSocketExt, LocalIrqDisabled>,
|
||||
is_dead: AtomicBool,
|
||||
}
|
||||
|
||||
struct RawTcpSocketExt {
|
||||
socket: Box<RawTcpSocket>,
|
||||
/// Whether the socket is in the background.
|
||||
///
|
||||
/// A background socket is a socket with its corresponding [`BoundSocket`] dropped. This means
|
||||
/// that no more user events (like `send`/`recv`) can reach the socket, but it can be in a
|
||||
/// state of waiting for certain network events (e.g., remote FIN/ACK packets), so
|
||||
/// [`BoundSocketInner`] may still be alive for a while.
|
||||
in_background: bool,
|
||||
}
|
||||
|
||||
impl Deref for RawTcpSocketExt {
|
||||
type Target = RawTcpSocket;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.socket
|
||||
}
|
||||
}
|
||||
|
||||
impl DerefMut for RawTcpSocketExt {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
&mut self.socket
|
||||
}
|
||||
}
|
||||
|
||||
impl TcpSocket {
|
||||
fn lock(&self) -> SpinLockGuard<RawTcpSocketExt, LocalIrqDisabled> {
|
||||
self.socket.lock()
|
||||
}
|
||||
|
||||
/// Returns whether the TCP socket is dead.
|
||||
///
|
||||
/// See [`BoundTcpSocketInner::is_dead`] for the definition of dead TCP sockets.
|
||||
fn is_dead(&self) -> bool {
|
||||
self.is_dead.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
/// Updates whether the TCP socket is dead.
|
||||
///
|
||||
/// See [`BoundTcpSocketInner::is_dead`] for the definition of dead TCP sockets.
|
||||
///
|
||||
/// This method must be called after handling network events. However, it is not necessary to
|
||||
/// call this method after handling non-closing user events, because the socket can never be
|
||||
/// dead if user events can reach the socket.
|
||||
fn update_dead(&self, socket: &RawTcpSocketExt) {
|
||||
if socket.in_background && socket.state() == smoltcp::socket::tcp::State::Closed {
|
||||
self.is_dead.store(true, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AnySocket for TcpSocket {
|
||||
type RawSocket = RawTcpSocket;
|
||||
|
||||
fn new(socket: Box<Self::RawSocket>) -> Self {
|
||||
let socket_ext = RawTcpSocketExt {
|
||||
socket,
|
||||
in_background: false,
|
||||
};
|
||||
|
||||
Self {
|
||||
socket: SpinLock::new(socket_ext),
|
||||
is_dead: AtomicBool::new(false),
|
||||
}
|
||||
}
|
||||
|
||||
fn on_drop<E>(this: &Arc<BoundSocketInner<Self, E>>) {
|
||||
let mut socket = this.socket.lock();
|
||||
|
||||
socket.in_background = true;
|
||||
socket.close();
|
||||
|
||||
// A TCP socket may not be appropriate for immediate removal. We leave the removal decision
|
||||
// to the polling logic.
|
||||
this.update_next_poll_at_ms(PollAt::Now);
|
||||
this.socket.update_dead(&socket);
|
||||
}
|
||||
}
|
||||
|
||||
/// States needed by [`BoundUdpSocketInner`] but not [`BoundTcpSocketInner`].
|
||||
type UdpSocket = SpinLock<Box<RawUdpSocket>, LocalIrqDisabled>;
|
||||
|
||||
impl AnySocket for UdpSocket {
|
||||
type RawSocket = RawUdpSocket;
|
||||
|
||||
fn new(socket: Box<Self::RawSocket>) -> Self {
|
||||
Self::new(socket)
|
||||
}
|
||||
|
||||
fn on_drop<E>(this: &Arc<BoundSocketInner<Self, E>>) {
|
||||
this.socket.lock().close();
|
||||
|
||||
// A UDP socket can be removed immediately.
|
||||
this.iface.common().remove_udp_socket(this);
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: AnySocket, E> Drop for BoundSocket<T, E> {
|
||||
fn drop(&mut self) {
|
||||
T::on_drop(&self.0);
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) type BoundTcpSocketInner<E> = BoundSocketInner<TcpSocket, E>;
|
||||
pub(crate) type BoundUdpSocketInner<E> = BoundSocketInner<UdpSocket, E>;
|
||||
|
||||
impl<T: AnySocket, E> BoundSocket<T, E> {
|
||||
pub(crate) fn new(
|
||||
iface: Arc<dyn Iface<E>>,
|
||||
handle: smoltcp::iface::SocketHandle,
|
||||
port: u16,
|
||||
socket_family: SocketFamily,
|
||||
socket: Box<T::RawSocket>,
|
||||
observer: Weak<dyn SocketEventObserver>,
|
||||
) -> Self {
|
||||
Self(Arc::new(AnyBoundSocketInner {
|
||||
Self(Arc::new(BoundSocketInner {
|
||||
iface,
|
||||
handle,
|
||||
port,
|
||||
socket_family,
|
||||
socket: T::new(socket),
|
||||
observer: RwLock::new(observer),
|
||||
next_poll_at_ms: AtomicU64::new(u64::MAX),
|
||||
has_new_events: AtomicBool::new(false),
|
||||
}))
|
||||
}
|
||||
|
||||
pub(crate) fn inner(&self) -> &Arc<AnyBoundSocketInner<E>> {
|
||||
pub(crate) fn inner(&self) -> &Arc<BoundSocketInner<T, E>> {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: AnySocket, E> BoundSocket<T, E> {
|
||||
/// Sets the observer whose `on_events` will be called when certain iface events happen. After
|
||||
/// setting, the new observer will fire once immediately to avoid missing any events.
|
||||
///
|
||||
@ -51,6 +195,15 @@ impl<E> AnyBoundSocket<E> {
|
||||
self.0.on_iface_events();
|
||||
}
|
||||
|
||||
/// Returns the observer.
|
||||
///
|
||||
/// See also [`Self::set_observer`].
|
||||
pub fn observer(&self) -> Weak<dyn SocketEventObserver> {
|
||||
// We never hold the write lock in IRQ handlers, so we don't need to disable IRQs when we
|
||||
// get the read lock.
|
||||
self.0.observer.read().clone()
|
||||
}
|
||||
|
||||
pub fn local_endpoint(&self) -> Option<IpEndpoint> {
|
||||
let ip_addr = {
|
||||
let ipv4_addr = self.0.iface.ipv4_addr()?;
|
||||
@ -59,58 +212,155 @@ impl<E> AnyBoundSocket<E> {
|
||||
Some(IpEndpoint::new(ip_addr, self.0.port))
|
||||
}
|
||||
|
||||
pub fn raw_with<T: smoltcp::socket::AnySocket<'static>, R, F: FnMut(&mut T) -> R>(
|
||||
&self,
|
||||
f: F,
|
||||
) -> R {
|
||||
self.0.raw_with(f)
|
||||
}
|
||||
|
||||
/// Connects to a remote endpoint.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// This method will panic if the socket is not a TCP socket.
|
||||
pub fn do_connect(&self, remote_endpoint: IpEndpoint) -> Result<(), ConnectError> {
|
||||
let common = self.iface().common();
|
||||
|
||||
let mut sockets = common.sockets();
|
||||
let socket = sockets.get_mut::<RawTcpSocket>(self.0.handle);
|
||||
|
||||
let mut iface = common.interface();
|
||||
let cx = iface.context();
|
||||
|
||||
socket.connect(cx, remote_endpoint, self.0.port)
|
||||
}
|
||||
|
||||
pub fn iface(&self) -> &Arc<dyn Iface<E>> {
|
||||
&self.0.iface
|
||||
}
|
||||
}
|
||||
|
||||
impl<E> Drop for AnyBoundSocket<E> {
|
||||
fn drop(&mut self) {
|
||||
if self.0.start_closing() {
|
||||
self.0.iface.common().remove_bound_socket_now(&self.0);
|
||||
} else {
|
||||
self.0
|
||||
.iface
|
||||
.common()
|
||||
.remove_bound_socket_when_closed(&self.0);
|
||||
}
|
||||
impl<E> BoundTcpSocket<E> {
|
||||
/// Connects to a remote endpoint.
|
||||
pub fn connect(
|
||||
&self,
|
||||
remote_endpoint: IpEndpoint,
|
||||
) -> Result<(), smoltcp::socket::tcp::ConnectError> {
|
||||
let common = self.iface().common();
|
||||
let mut iface = common.interface();
|
||||
|
||||
let mut socket = self.0.socket.lock();
|
||||
|
||||
let result = socket.connect(iface.context(), remote_endpoint, self.0.port);
|
||||
self.0
|
||||
.update_next_poll_at_ms(socket.poll_at(iface.context()));
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
/// Listens at a specified endpoint.
|
||||
pub fn listen(
|
||||
&self,
|
||||
local_endpoint: IpEndpoint,
|
||||
) -> Result<(), smoltcp::socket::tcp::ListenError> {
|
||||
let mut socket = self.0.socket.lock();
|
||||
|
||||
socket.listen(local_endpoint)
|
||||
}
|
||||
|
||||
pub fn send<F, R>(&self, f: F) -> Result<R, smoltcp::socket::tcp::SendError>
|
||||
where
|
||||
F: FnOnce(&mut [u8]) -> (usize, R),
|
||||
{
|
||||
let mut socket = self.0.socket.lock();
|
||||
|
||||
let result = socket.send(f);
|
||||
self.0.update_next_poll_at_ms(PollAt::Now);
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
pub fn recv<F, R>(&self, f: F) -> Result<R, smoltcp::socket::tcp::RecvError>
|
||||
where
|
||||
F: FnOnce(&mut [u8]) -> (usize, R),
|
||||
{
|
||||
let mut socket = self.0.socket.lock();
|
||||
|
||||
let result = socket.recv(f);
|
||||
self.0.update_next_poll_at_ms(PollAt::Now);
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
pub fn close(&self) {
|
||||
let mut socket = self.0.socket.lock();
|
||||
|
||||
socket.close();
|
||||
self.0.update_next_poll_at_ms(PollAt::Now);
|
||||
}
|
||||
|
||||
/// Calls `f` with an immutable reference to the associated [`RawTcpSocket`].
|
||||
//
|
||||
// NOTE: If a mutable reference is required, add a method above that correctly updates the next
|
||||
// polling time.
|
||||
pub fn raw_with<F, R>(&self, f: F) -> R
|
||||
where
|
||||
F: FnOnce(&RawTcpSocket) -> R,
|
||||
{
|
||||
let socket = self.0.socket.lock();
|
||||
f(&socket)
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct AnyBoundSocketInner<E> {
|
||||
iface: Arc<dyn Iface<E>>,
|
||||
handle: smoltcp::iface::SocketHandle,
|
||||
port: u16,
|
||||
socket_family: SocketFamily,
|
||||
observer: RwLock<Weak<dyn SocketEventObserver>>,
|
||||
impl<E> BoundUdpSocket<E> {
|
||||
/// Binds to a specified endpoint.
|
||||
pub fn bind(&self, local_endpoint: IpEndpoint) -> Result<(), smoltcp::socket::udp::BindError> {
|
||||
let mut socket = self.0.socket.lock();
|
||||
|
||||
socket.bind(local_endpoint)
|
||||
}
|
||||
|
||||
pub fn send<F, R>(
|
||||
&self,
|
||||
size: usize,
|
||||
meta: impl Into<UdpMetadata>,
|
||||
f: F,
|
||||
) -> Result<R, crate::errors::udp::SendError>
|
||||
where
|
||||
F: FnOnce(&mut [u8]) -> R,
|
||||
{
|
||||
use smoltcp::socket::udp::SendError as SendErrorInner;
|
||||
|
||||
use crate::errors::udp::SendError;
|
||||
|
||||
let mut socket = self.0.socket.lock();
|
||||
|
||||
if size > socket.packet_send_capacity() {
|
||||
return Err(SendError::TooLarge);
|
||||
}
|
||||
|
||||
let buffer = match socket.send(size, meta) {
|
||||
Ok(data) => data,
|
||||
Err(SendErrorInner::Unaddressable) => return Err(SendError::Unaddressable),
|
||||
Err(SendErrorInner::BufferFull) => return Err(SendError::BufferFull),
|
||||
};
|
||||
let result = f(buffer);
|
||||
self.0.update_next_poll_at_ms(PollAt::Now);
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
pub fn recv<F, R>(&self, f: F) -> Result<R, smoltcp::socket::udp::RecvError>
|
||||
where
|
||||
F: FnOnce(&[u8], UdpMetadata) -> R,
|
||||
{
|
||||
let mut socket = self.0.socket.lock();
|
||||
|
||||
let (data, meta) = socket.recv()?;
|
||||
let result = f(data, meta);
|
||||
self.0.update_next_poll_at_ms(PollAt::Now);
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Calls `f` with an immutable reference to the associated [`RawUdpSocket`].
|
||||
//
|
||||
// NOTE: If a mutable reference is required, add a method above that correctly updates the next
|
||||
// polling time.
|
||||
pub fn raw_with<F, R>(&self, f: F) -> R
|
||||
where
|
||||
F: FnOnce(&RawUdpSocket) -> R,
|
||||
{
|
||||
let socket = self.0.socket.lock();
|
||||
f(&socket)
|
||||
}
|
||||
}
|
||||
|
||||
impl<E> AnyBoundSocketInner<E> {
|
||||
impl<T, E> BoundSocketInner<T, E> {
|
||||
pub(crate) fn has_new_events(&self) -> bool {
|
||||
self.has_new_events.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
pub(crate) fn on_iface_events(&self) {
|
||||
self.has_new_events.store(false, Ordering::Relaxed);
|
||||
|
||||
// We never hold the write lock in IRQ handlers, so we don't need to disable IRQs when we
|
||||
// get the read lock.
|
||||
let observer = Weak::upgrade(&*self.observer.read());
|
||||
@ -120,51 +370,173 @@ impl<E> AnyBoundSocketInner<E> {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn is_closed(&self) -> bool {
|
||||
match self.socket_family {
|
||||
SocketFamily::Tcp => self.raw_with(|socket: &mut RawTcpSocket| {
|
||||
socket.state() == smoltcp::socket::tcp::State::Closed
|
||||
}),
|
||||
SocketFamily::Udp => true,
|
||||
}
|
||||
/// Returns the next polling time.
|
||||
///
|
||||
/// Note: a zero means polling should be done now and a `u64::MAX` means no polling is required
|
||||
/// before new network or user events.
|
||||
pub(crate) fn next_poll_at_ms(&self) -> u64 {
|
||||
self.next_poll_at_ms.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
/// Starts closing the socket and returns whether the socket is closed.
|
||||
/// Updates the next polling time according to `poll_at`.
|
||||
///
|
||||
/// For sockets that can be closed immediately, such as UDP sockets and TCP listening sockets,
|
||||
/// this method will always return `true`.
|
||||
///
|
||||
/// For other sockets, such as TCP connected sockets, they cannot be closed immediately because
|
||||
/// we at least need to send the FIN packet and wait for the remote end to send an ACK packet.
|
||||
/// In this case, this method will return `false` and [`Self::is_closed`] can be used to
|
||||
/// determine if the closing process is complete.
|
||||
fn start_closing(&self) -> bool {
|
||||
match self.socket_family {
|
||||
SocketFamily::Tcp => self.raw_with(|socket: &mut RawTcpSocket| {
|
||||
socket.close();
|
||||
socket.state() == smoltcp::socket::tcp::State::Closed
|
||||
}),
|
||||
SocketFamily::Udp => {
|
||||
self.raw_with(|socket: &mut RawUdpSocket| socket.close());
|
||||
true
|
||||
}
|
||||
/// The update is typically needed after new network or user events have been handled, so this
|
||||
/// method also marks that there may be new events, so that the event observer provided by
|
||||
/// [`BoundSocket::set_observer`] can be notified later.
|
||||
fn update_next_poll_at_ms(&self, poll_at: PollAt) {
|
||||
self.has_new_events.store(true, Ordering::Relaxed);
|
||||
|
||||
match poll_at {
|
||||
PollAt::Now => self.next_poll_at_ms.store(0, Ordering::Relaxed),
|
||||
PollAt::Time(instant) => self
|
||||
.next_poll_at_ms
|
||||
.store(instant.total_millis() as u64, Ordering::Relaxed),
|
||||
PollAt::Ingress => self.next_poll_at_ms.store(u64::MAX, Ordering::Relaxed),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn raw_with<T: smoltcp::socket::AnySocket<'static>, R, F: FnMut(&mut T) -> R>(
|
||||
impl<T, E> BoundSocketInner<T, E> {
|
||||
pub(crate) fn port(&self) -> u16 {
|
||||
self.port
|
||||
}
|
||||
}
|
||||
|
||||
impl<E> BoundTcpSocketInner<E> {
|
||||
/// Returns whether the TCP socket is dead.
|
||||
///
|
||||
/// A TCP socket is considered dead if and only if the following two conditions are met:
|
||||
/// 1. The TCP connection is closed, so this socket cannot process any network events.
|
||||
/// 2. The socket handle [`BoundTcpSocket`] is dropped, which means that this
|
||||
/// [`BoundSocketInner`] is in background and no more user events can reach it.
|
||||
pub(crate) fn is_dead(&self) -> bool {
|
||||
self.socket.is_dead()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, E> BoundSocketInner<T, E> {
|
||||
/// Returns whether an incoming packet _may_ be processed by the socket.
|
||||
///
|
||||
/// The check is intended to be lock-free and fast, but may have false positives.
|
||||
pub(crate) fn can_process(&self, dst_port: u16) -> bool {
|
||||
self.port == dst_port
|
||||
}
|
||||
|
||||
/// Returns whether the socket _may_ generate an outgoing packet.
|
||||
///
|
||||
/// The check is intended to be lock-free and fast, but may have false positives.
|
||||
pub(crate) fn need_dispatch(&self, now: Instant) -> bool {
|
||||
now.total_millis() as u64 >= self.next_poll_at_ms.load(Ordering::Relaxed)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone)]
|
||||
pub(crate) enum TcpProcessResult {
|
||||
NotProcessed,
|
||||
Processed,
|
||||
ProcessedWithReply(IpRepr, TcpRepr<'static>),
|
||||
}
|
||||
|
||||
impl<E> BoundTcpSocketInner<E> {
|
||||
/// Tries to process an incoming packet and returns whether the packet is processed.
|
||||
pub(crate) fn process(
|
||||
&self,
|
||||
mut f: F,
|
||||
) -> R {
|
||||
let mut sockets = self.iface.common().sockets();
|
||||
let socket = sockets.get_mut::<T>(self.handle);
|
||||
f(socket)
|
||||
cx: &mut Context,
|
||||
ip_repr: &IpRepr,
|
||||
tcp_repr: &TcpRepr,
|
||||
) -> TcpProcessResult {
|
||||
let mut socket = self.socket.lock();
|
||||
|
||||
if !socket.accepts(cx, ip_repr, tcp_repr) {
|
||||
return TcpProcessResult::NotProcessed;
|
||||
}
|
||||
|
||||
let result = match socket.process(cx, ip_repr, tcp_repr) {
|
||||
None => TcpProcessResult::Processed,
|
||||
Some((ip_repr, tcp_repr)) => TcpProcessResult::ProcessedWithReply(ip_repr, tcp_repr),
|
||||
};
|
||||
|
||||
self.update_next_poll_at_ms(socket.poll_at(cx));
|
||||
self.socket.update_dead(&socket);
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
/// Tries to generate an outgoing packet and dispatches the generated packet.
|
||||
pub(crate) fn dispatch<D>(
|
||||
&self,
|
||||
cx: &mut Context,
|
||||
dispatch: D,
|
||||
) -> Option<(IpRepr, TcpRepr<'static>)>
|
||||
where
|
||||
D: FnOnce(&mut Context, &IpRepr, &TcpRepr) -> Option<(IpRepr, TcpRepr<'static>)>,
|
||||
{
|
||||
let mut socket = self.socket.lock();
|
||||
|
||||
let mut reply = None;
|
||||
socket
|
||||
.dispatch(cx, |cx, (ip_repr, tcp_repr)| {
|
||||
reply = dispatch(cx, &ip_repr, &tcp_repr);
|
||||
Ok::<(), ()>(())
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
// `dispatch` can return a packet in response to the generated packet. If the socket
|
||||
// accepts the packet, we can process it directly.
|
||||
while let Some((ref ip_repr, ref tcp_repr)) = reply {
|
||||
if !socket.accepts(cx, ip_repr, tcp_repr) {
|
||||
break;
|
||||
}
|
||||
reply = socket.process(cx, ip_repr, tcp_repr);
|
||||
}
|
||||
|
||||
self.update_next_poll_at_ms(socket.poll_at(cx));
|
||||
self.socket.update_dead(&socket);
|
||||
|
||||
reply
|
||||
}
|
||||
}
|
||||
|
||||
impl<E> Drop for AnyBoundSocketInner<E> {
|
||||
fn drop(&mut self) {
|
||||
let iface_common = self.iface.common();
|
||||
iface_common.remove_socket(self.handle);
|
||||
iface_common.release_port(self.port);
|
||||
impl<E> BoundUdpSocketInner<E> {
|
||||
/// Tries to process an incoming packet and returns whether the packet is processed.
|
||||
pub(crate) fn process(
|
||||
&self,
|
||||
cx: &mut Context,
|
||||
ip_repr: &IpRepr,
|
||||
udp_repr: &UdpRepr,
|
||||
udp_payload: &[u8],
|
||||
) -> bool {
|
||||
let mut socket = self.socket.lock();
|
||||
|
||||
if !socket.accepts(cx, ip_repr, udp_repr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
socket.process(
|
||||
cx,
|
||||
smoltcp::phy::PacketMeta::default(),
|
||||
ip_repr,
|
||||
udp_repr,
|
||||
udp_payload,
|
||||
);
|
||||
self.update_next_poll_at_ms(socket.poll_at(cx));
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
/// Tries to generate an outgoing packet and dispatches the generated packet.
|
||||
pub(crate) fn dispatch<D>(&self, cx: &mut Context, dispatch: D)
|
||||
where
|
||||
D: FnOnce(&mut Context, &IpRepr, &UdpRepr, &[u8]),
|
||||
{
|
||||
let mut socket = self.socket.lock();
|
||||
|
||||
socket
|
||||
.dispatch(cx, |cx, _meta, (ip_repr, udp_repr, udp_payload)| {
|
||||
dispatch(cx, &ip_repr, &udp_repr, udp_payload);
|
||||
Ok::<(), ()>(())
|
||||
})
|
||||
.unwrap();
|
||||
self.update_next_poll_at_ms(socket.poll_at(cx));
|
||||
}
|
||||
}
|
||||
|
@ -4,12 +4,11 @@ mod bound;
|
||||
mod event;
|
||||
mod unbound;
|
||||
|
||||
pub use bound::AnyBoundSocket;
|
||||
pub(crate) use bound::{AnyBoundSocketInner, SocketFamily};
|
||||
pub use bound::{BoundTcpSocket, BoundUdpSocket};
|
||||
pub(crate) use bound::{BoundTcpSocketInner, BoundUdpSocketInner, TcpProcessResult};
|
||||
pub use event::SocketEventObserver;
|
||||
pub(crate) use unbound::AnyRawSocket;
|
||||
pub use unbound::{
|
||||
AnyUnboundSocket, TCP_RECV_BUF_LEN, TCP_SEND_BUF_LEN, UDP_RECV_PAYLOAD_LEN,
|
||||
UnboundTcpSocket, UnboundUdpSocket, TCP_RECV_BUF_LEN, TCP_SEND_BUF_LEN, UDP_RECV_PAYLOAD_LEN,
|
||||
UDP_SEND_PAYLOAD_LEN,
|
||||
};
|
||||
|
||||
|
@ -1,34 +1,33 @@
|
||||
// SPDX-License-Identifier: MPL-2.0
|
||||
|
||||
use alloc::{sync::Weak, vec};
|
||||
use alloc::{boxed::Box, sync::Weak, vec};
|
||||
|
||||
use super::{event::SocketEventObserver, RawTcpSocket, RawUdpSocket};
|
||||
|
||||
pub struct AnyUnboundSocket {
|
||||
socket_family: AnyRawSocket,
|
||||
pub struct UnboundSocket<T> {
|
||||
socket: Box<T>,
|
||||
observer: Weak<dyn SocketEventObserver>,
|
||||
}
|
||||
|
||||
#[allow(clippy::large_enum_variant)]
|
||||
pub(crate) enum AnyRawSocket {
|
||||
Tcp(RawTcpSocket),
|
||||
Udp(RawUdpSocket),
|
||||
}
|
||||
pub type UnboundTcpSocket = UnboundSocket<RawTcpSocket>;
|
||||
pub type UnboundUdpSocket = UnboundSocket<RawUdpSocket>;
|
||||
|
||||
impl AnyUnboundSocket {
|
||||
pub fn new_tcp(observer: Weak<dyn SocketEventObserver>) -> Self {
|
||||
impl UnboundTcpSocket {
|
||||
pub fn new(observer: Weak<dyn SocketEventObserver>) -> Self {
|
||||
let raw_tcp_socket = {
|
||||
let rx_buffer = smoltcp::socket::tcp::SocketBuffer::new(vec![0u8; TCP_RECV_BUF_LEN]);
|
||||
let tx_buffer = smoltcp::socket::tcp::SocketBuffer::new(vec![0u8; TCP_SEND_BUF_LEN]);
|
||||
RawTcpSocket::new(rx_buffer, tx_buffer)
|
||||
};
|
||||
AnyUnboundSocket {
|
||||
socket_family: AnyRawSocket::Tcp(raw_tcp_socket),
|
||||
Self {
|
||||
socket: Box::new(raw_tcp_socket),
|
||||
observer,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_udp(observer: Weak<dyn SocketEventObserver>) -> Self {
|
||||
impl UnboundUdpSocket {
|
||||
pub fn new(observer: Weak<dyn SocketEventObserver>) -> Self {
|
||||
let raw_udp_socket = {
|
||||
let metadata = smoltcp::socket::udp::PacketMetadata::EMPTY;
|
||||
let rx_buffer = smoltcp::socket::udp::PacketBuffer::new(
|
||||
@ -41,14 +40,16 @@ impl AnyUnboundSocket {
|
||||
);
|
||||
RawUdpSocket::new(rx_buffer, tx_buffer)
|
||||
};
|
||||
AnyUnboundSocket {
|
||||
socket_family: AnyRawSocket::Udp(raw_udp_socket),
|
||||
Self {
|
||||
socket: Box::new(raw_udp_socket),
|
||||
observer,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn into_raw(self) -> (AnyRawSocket, Weak<dyn SocketEventObserver>) {
|
||||
(self.socket_family, self.observer)
|
||||
impl<T> UnboundSocket<T> {
|
||||
pub(crate) fn into_raw(self) -> (Box<T>, Weak<dyn SocketEventObserver>) {
|
||||
(self.socket, self.observer)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -8,4 +8,5 @@ pub use init::{init, IFACES};
|
||||
pub use poll::{lazy_init, poll_ifaces};
|
||||
|
||||
pub type Iface = dyn aster_bigtcp::iface::Iface<ext::IfaceExt>;
|
||||
pub type AnyBoundSocket = aster_bigtcp::socket::AnyBoundSocket<ext::IfaceExt>;
|
||||
pub type BoundTcpSocket = aster_bigtcp::socket::BoundTcpSocket<ext::IfaceExt>;
|
||||
pub type BoundUdpSocket = aster_bigtcp::socket::BoundUdpSocket<ext::IfaceExt>;
|
||||
|
@ -3,12 +3,11 @@
|
||||
use aster_bigtcp::{
|
||||
errors::BindError,
|
||||
iface::BindPortConfig,
|
||||
socket::AnyUnboundSocket,
|
||||
wire::{IpAddress, IpEndpoint},
|
||||
};
|
||||
|
||||
use crate::{
|
||||
net::iface::{AnyBoundSocket, Iface, IFACES},
|
||||
net::iface::{Iface, IFACES},
|
||||
prelude::*,
|
||||
};
|
||||
|
||||
@ -46,11 +45,16 @@ fn get_ephemeral_iface(remote_ip_addr: &IpAddress) -> Arc<Iface> {
|
||||
ifaces[0].clone()
|
||||
}
|
||||
|
||||
pub(super) fn bind_socket(
|
||||
unbound_socket: Box<AnyUnboundSocket>,
|
||||
pub(super) fn bind_socket<S, T>(
|
||||
unbound_socket: Box<S>,
|
||||
endpoint: &IpEndpoint,
|
||||
can_reuse: bool,
|
||||
) -> core::result::Result<AnyBoundSocket, (Error, Box<AnyUnboundSocket>)> {
|
||||
bind: impl FnOnce(
|
||||
Arc<Iface>,
|
||||
Box<S>,
|
||||
BindPortConfig,
|
||||
) -> core::result::Result<T, (BindError, Box<S>)>,
|
||||
) -> core::result::Result<T, (Error, Box<S>)> {
|
||||
let iface = match get_iface_to_bind(&endpoint.addr) {
|
||||
Some(iface) => iface,
|
||||
None => {
|
||||
@ -64,9 +68,7 @@ pub(super) fn bind_socket(
|
||||
|
||||
let bind_port_config = BindPortConfig::new(endpoint.port, can_reuse);
|
||||
|
||||
iface
|
||||
.bind_socket(unbound_socket, bind_port_config)
|
||||
.map_err(|(err, unbound)| (err.into(), unbound))
|
||||
bind(iface, unbound_socket, bind_port_config).map_err(|(err, unbound)| (err.into(), unbound))
|
||||
}
|
||||
|
||||
impl From<BindError> for Error {
|
||||
|
@ -2,25 +2,24 @@
|
||||
|
||||
use aster_bigtcp::{
|
||||
errors::udp::{RecvError, SendError},
|
||||
socket::RawUdpSocket,
|
||||
wire::IpEndpoint,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
events::IoEvents,
|
||||
net::{iface::AnyBoundSocket, socket::util::send_recv_flags::SendRecvFlags},
|
||||
net::{iface::BoundUdpSocket, socket::util::send_recv_flags::SendRecvFlags},
|
||||
prelude::*,
|
||||
process::signal::Pollee,
|
||||
util::{MultiRead, MultiWrite},
|
||||
};
|
||||
|
||||
pub struct BoundDatagram {
|
||||
bound_socket: AnyBoundSocket,
|
||||
bound_socket: BoundUdpSocket,
|
||||
remote_endpoint: Option<IpEndpoint>,
|
||||
}
|
||||
|
||||
impl BoundDatagram {
|
||||
pub fn new(bound_socket: AnyBoundSocket) -> Self {
|
||||
pub fn new(bound_socket: BoundUdpSocket) -> Self {
|
||||
Self {
|
||||
bound_socket,
|
||||
remote_endpoint: None,
|
||||
@ -44,12 +43,10 @@ impl BoundDatagram {
|
||||
writer: &mut dyn MultiWrite,
|
||||
_flags: SendRecvFlags,
|
||||
) -> Result<(usize, IpEndpoint)> {
|
||||
let result = self.bound_socket.raw_with(|socket: &mut RawUdpSocket| {
|
||||
socket.recv().map(|(packet, udp_metadata)| {
|
||||
let copied_res = writer.write(&mut VmReader::from(packet));
|
||||
let endpoint = udp_metadata.endpoint;
|
||||
(copied_res, endpoint)
|
||||
})
|
||||
let result = self.bound_socket.recv(|packet, udp_metadata| {
|
||||
let copied_res = writer.write(&mut VmReader::from(packet));
|
||||
let endpoint = udp_metadata.endpoint;
|
||||
(copied_res, endpoint)
|
||||
});
|
||||
|
||||
match result {
|
||||
@ -59,7 +56,7 @@ impl BoundDatagram {
|
||||
return_errno_with_message!(Errno::EAGAIN, "the receive buffer is empty")
|
||||
}
|
||||
Err(RecvError::Truncated) => {
|
||||
unreachable!("`Socket::recv` should never fail with `RecvError::Truncated`")
|
||||
unreachable!("`recv` should never fail with `RecvError::Truncated`")
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -70,32 +67,31 @@ impl BoundDatagram {
|
||||
remote: &IpEndpoint,
|
||||
_flags: SendRecvFlags,
|
||||
) -> Result<usize> {
|
||||
let reader_len = reader.sum_lens();
|
||||
let result = self
|
||||
.bound_socket
|
||||
.send(reader.sum_lens(), *remote, |socket_buffer| {
|
||||
// FIXME: If copy failed, we should not send any packet.
|
||||
// But current smoltcp API seems not to support this behavior.
|
||||
reader
|
||||
.read(&mut VmWriter::from(socket_buffer))
|
||||
.map_err(|e| {
|
||||
warn!("unexpected UDP packet will be sent");
|
||||
e
|
||||
})
|
||||
});
|
||||
|
||||
self.bound_socket.raw_with(|socket: &mut RawUdpSocket| {
|
||||
if socket.payload_send_capacity() < reader_len {
|
||||
match result {
|
||||
Ok(inner) => inner,
|
||||
Err(SendError::TooLarge) => {
|
||||
return_errno_with_message!(Errno::EMSGSIZE, "the message is too large");
|
||||
}
|
||||
|
||||
let socket_buffer = match socket.send(reader_len, *remote) {
|
||||
Ok(socket_buffer) => socket_buffer,
|
||||
Err(SendError::BufferFull) => {
|
||||
return_errno_with_message!(Errno::EAGAIN, "the send buffer is full")
|
||||
}
|
||||
Err(SendError::Unaddressable) => {
|
||||
return_errno_with_message!(Errno::EINVAL, "the destination address is invalid")
|
||||
}
|
||||
};
|
||||
|
||||
// FIXME: If copy failed, we should not send any packet.
|
||||
// But current smoltcp API seems not to support this behavior.
|
||||
reader
|
||||
.read(&mut VmWriter::from(socket_buffer))
|
||||
.map_err(|e| {
|
||||
warn!("unexpected UDP packet will be sent");
|
||||
e
|
||||
})
|
||||
})
|
||||
Err(SendError::Unaddressable) => {
|
||||
return_errno_with_message!(Errno::EINVAL, "the destination address is invalid");
|
||||
}
|
||||
Err(SendError::BufferFull) => {
|
||||
return_errno_with_message!(Errno::EAGAIN, "the send buffer is full");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn init_pollee(&self, pollee: &Pollee) {
|
||||
@ -104,7 +100,7 @@ impl BoundDatagram {
|
||||
}
|
||||
|
||||
pub(super) fn update_io_events(&self, pollee: &Pollee) {
|
||||
self.bound_socket.raw_with(|socket: &mut RawUdpSocket| {
|
||||
self.bound_socket.raw_with(|socket| {
|
||||
if socket.can_recv() {
|
||||
pollee.add_events(IoEvents::IN);
|
||||
} else {
|
||||
|
@ -154,13 +154,9 @@ impl DatagramSocket {
|
||||
return_errno_with_message!(Errno::EAGAIN, "the socket is not bound");
|
||||
};
|
||||
|
||||
let received =
|
||||
bound_datagram
|
||||
.try_recv(writer, flags)
|
||||
.map(|(recv_bytes, remote_endpoint)| {
|
||||
bound_datagram.update_io_events(&self.pollee);
|
||||
(recv_bytes, remote_endpoint.into())
|
||||
});
|
||||
let received = bound_datagram
|
||||
.try_recv(writer, flags)
|
||||
.map(|(recv_bytes, remote_endpoint)| (recv_bytes, remote_endpoint.into()));
|
||||
|
||||
drop(inner);
|
||||
poll_ifaces();
|
||||
@ -192,12 +188,7 @@ impl DatagramSocket {
|
||||
return_errno_with_message!(Errno::EAGAIN, "the socket is not bound")
|
||||
};
|
||||
|
||||
let sent_bytes = bound_datagram
|
||||
.try_send(reader, remote, flags)
|
||||
.map(|sent_bytes| {
|
||||
bound_datagram.update_io_events(&self.pollee);
|
||||
sent_bytes
|
||||
});
|
||||
let sent_bytes = bound_datagram.try_send(reader, remote, flags);
|
||||
|
||||
drop(inner);
|
||||
poll_ifaces();
|
||||
|
@ -3,7 +3,7 @@
|
||||
use alloc::sync::Weak;
|
||||
|
||||
use aster_bigtcp::{
|
||||
socket::{AnyUnboundSocket, RawUdpSocket, SocketEventObserver},
|
||||
socket::{SocketEventObserver, UnboundUdpSocket},
|
||||
wire::IpEndpoint,
|
||||
};
|
||||
|
||||
@ -13,13 +13,13 @@ use crate::{
|
||||
};
|
||||
|
||||
pub struct UnboundDatagram {
|
||||
unbound_socket: Box<AnyUnboundSocket>,
|
||||
unbound_socket: Box<UnboundUdpSocket>,
|
||||
}
|
||||
|
||||
impl UnboundDatagram {
|
||||
pub fn new(observer: Weak<dyn SocketEventObserver>) -> Self {
|
||||
Self {
|
||||
unbound_socket: Box::new(AnyUnboundSocket::new_udp(observer)),
|
||||
unbound_socket: Box::new(UnboundUdpSocket::new(observer)),
|
||||
}
|
||||
}
|
||||
|
||||
@ -28,15 +28,18 @@ impl UnboundDatagram {
|
||||
endpoint: &IpEndpoint,
|
||||
can_reuse: bool,
|
||||
) -> core::result::Result<BoundDatagram, (Error, Self)> {
|
||||
let bound_socket = match bind_socket(self.unbound_socket, endpoint, can_reuse) {
|
||||
let bound_socket = match bind_socket(
|
||||
self.unbound_socket,
|
||||
endpoint,
|
||||
can_reuse,
|
||||
|iface, socket, config| iface.bind_udp(socket, config),
|
||||
) {
|
||||
Ok(bound_socket) => bound_socket,
|
||||
Err((err, unbound_socket)) => return Err((err, Self { unbound_socket })),
|
||||
};
|
||||
|
||||
let bound_endpoint = bound_socket.local_endpoint().unwrap();
|
||||
bound_socket.raw_with(|socket: &mut RawUdpSocket| {
|
||||
socket.bind(bound_endpoint).unwrap();
|
||||
});
|
||||
bound_socket.bind(bound_endpoint).unwrap();
|
||||
|
||||
Ok(BoundDatagram::new(bound_socket))
|
||||
}
|
||||
|
@ -4,14 +4,14 @@ use alloc::sync::Weak;
|
||||
|
||||
use aster_bigtcp::{
|
||||
errors::tcp::{RecvError, SendError},
|
||||
socket::{RawTcpSocket, SocketEventObserver},
|
||||
socket::SocketEventObserver,
|
||||
wire::IpEndpoint,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
events::IoEvents,
|
||||
net::{
|
||||
iface::AnyBoundSocket,
|
||||
iface::BoundTcpSocket,
|
||||
socket::util::{send_recv_flags::SendRecvFlags, shutdown_cmd::SockShutdownCmd},
|
||||
},
|
||||
prelude::*,
|
||||
@ -20,7 +20,7 @@ use crate::{
|
||||
};
|
||||
|
||||
pub struct ConnectedStream {
|
||||
bound_socket: AnyBoundSocket,
|
||||
bound_socket: BoundTcpSocket,
|
||||
remote_endpoint: IpEndpoint,
|
||||
/// Indicates whether this connection is "new" in a `connect()` system call.
|
||||
///
|
||||
@ -37,7 +37,7 @@ pub struct ConnectedStream {
|
||||
|
||||
impl ConnectedStream {
|
||||
pub fn new(
|
||||
bound_socket: AnyBoundSocket,
|
||||
bound_socket: BoundTcpSocket,
|
||||
remote_endpoint: IpEndpoint,
|
||||
is_new_connection: bool,
|
||||
) -> Self {
|
||||
@ -50,20 +50,16 @@ impl ConnectedStream {
|
||||
|
||||
pub fn shutdown(&self, _cmd: SockShutdownCmd) -> Result<()> {
|
||||
// TODO: deal with cmd
|
||||
self.bound_socket.raw_with(|socket: &mut RawTcpSocket| {
|
||||
socket.close();
|
||||
});
|
||||
self.bound_socket.close();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn try_recv(&self, writer: &mut dyn MultiWrite, _flags: SendRecvFlags) -> Result<usize> {
|
||||
let result = self.bound_socket.raw_with(|socket: &mut RawTcpSocket| {
|
||||
socket.recv(
|
||||
|socket_buffer| match writer.write(&mut VmReader::from(&*socket_buffer)) {
|
||||
Ok(len) => (len, Ok(len)),
|
||||
Err(e) => (0, Err(e)),
|
||||
},
|
||||
)
|
||||
let result = self.bound_socket.recv(|socket_buffer| {
|
||||
match writer.write(&mut VmReader::from(&*socket_buffer)) {
|
||||
Ok(len) => (len, Ok(len)),
|
||||
Err(e) => (0, Err(e)),
|
||||
}
|
||||
});
|
||||
|
||||
match result {
|
||||
@ -78,13 +74,11 @@ impl ConnectedStream {
|
||||
}
|
||||
|
||||
pub fn try_send(&self, reader: &mut dyn MultiRead, _flags: SendRecvFlags) -> Result<usize> {
|
||||
let result = self.bound_socket.raw_with(|socket: &mut RawTcpSocket| {
|
||||
socket.send(
|
||||
|socket_buffer| match reader.read(&mut VmWriter::from(socket_buffer)) {
|
||||
Ok(len) => (len, Ok(len)),
|
||||
Err(e) => (0, Err(e)),
|
||||
},
|
||||
)
|
||||
let result = self.bound_socket.send(|socket_buffer| {
|
||||
match reader.read(&mut VmWriter::from(socket_buffer)) {
|
||||
Ok(len) => (len, Ok(len)),
|
||||
Err(e) => (0, Err(e)),
|
||||
}
|
||||
});
|
||||
|
||||
match result {
|
||||
@ -123,7 +117,7 @@ impl ConnectedStream {
|
||||
}
|
||||
|
||||
pub(super) fn update_io_events(&self, pollee: &Pollee) {
|
||||
self.bound_socket.raw_with(|socket: &mut RawTcpSocket| {
|
||||
self.bound_socket.raw_with(|socket| {
|
||||
if socket.can_recv() {
|
||||
pollee.add_events(IoEvents::IN);
|
||||
} else {
|
||||
|
@ -1,12 +1,12 @@
|
||||
// SPDX-License-Identifier: MPL-2.0
|
||||
|
||||
use aster_bigtcp::{socket::RawTcpSocket, wire::IpEndpoint};
|
||||
use aster_bigtcp::wire::IpEndpoint;
|
||||
|
||||
use super::{connected::ConnectedStream, init::InitStream};
|
||||
use crate::{net::iface::AnyBoundSocket, prelude::*, process::signal::Pollee};
|
||||
use crate::{net::iface::BoundTcpSocket, prelude::*, process::signal::Pollee};
|
||||
|
||||
pub struct ConnectingStream {
|
||||
bound_socket: AnyBoundSocket,
|
||||
bound_socket: BoundTcpSocket,
|
||||
remote_endpoint: IpEndpoint,
|
||||
conn_result: RwLock<Option<ConnResult>>,
|
||||
}
|
||||
@ -24,15 +24,15 @@ pub enum NonConnectedStream {
|
||||
|
||||
impl ConnectingStream {
|
||||
pub fn new(
|
||||
bound_socket: AnyBoundSocket,
|
||||
bound_socket: BoundTcpSocket,
|
||||
remote_endpoint: IpEndpoint,
|
||||
) -> core::result::Result<Self, (Error, AnyBoundSocket)> {
|
||||
) -> core::result::Result<Self, (Error, BoundTcpSocket)> {
|
||||
// The only reason this method might fail is because we're trying to connect to an
|
||||
// unspecified address (i.e. 0.0.0.0). We currently have no support for binding to,
|
||||
// listening on, or connecting to the unspecified address.
|
||||
//
|
||||
// We assume the remote will just refuse to connect, so we return `ECONNREFUSED`.
|
||||
if bound_socket.do_connect(remote_endpoint).is_err() {
|
||||
if bound_socket.connect(remote_endpoint).is_err() {
|
||||
return Err((
|
||||
Error::with_message(
|
||||
Errno::ECONNREFUSED,
|
||||
@ -91,7 +91,7 @@ impl ConnectingStream {
|
||||
return false;
|
||||
}
|
||||
|
||||
self.bound_socket.raw_with(|socket: &mut RawTcpSocket| {
|
||||
self.bound_socket.raw_with(|socket| {
|
||||
let mut result = self.conn_result.write();
|
||||
if result.is_some() {
|
||||
return false;
|
||||
|
@ -3,7 +3,7 @@
|
||||
use alloc::sync::Weak;
|
||||
|
||||
use aster_bigtcp::{
|
||||
socket::{AnyUnboundSocket, SocketEventObserver},
|
||||
socket::{SocketEventObserver, UnboundTcpSocket},
|
||||
wire::IpEndpoint,
|
||||
};
|
||||
|
||||
@ -11,7 +11,7 @@ use super::{connecting::ConnectingStream, listen::ListenStream};
|
||||
use crate::{
|
||||
events::IoEvents,
|
||||
net::{
|
||||
iface::AnyBoundSocket,
|
||||
iface::BoundTcpSocket,
|
||||
socket::ip::common::{bind_socket, get_ephemeral_endpoint},
|
||||
},
|
||||
prelude::*,
|
||||
@ -19,16 +19,16 @@ use crate::{
|
||||
};
|
||||
|
||||
pub enum InitStream {
|
||||
Unbound(Box<AnyUnboundSocket>),
|
||||
Bound(AnyBoundSocket),
|
||||
Unbound(Box<UnboundTcpSocket>),
|
||||
Bound(BoundTcpSocket),
|
||||
}
|
||||
|
||||
impl InitStream {
|
||||
pub fn new(observer: Weak<dyn SocketEventObserver>) -> Self {
|
||||
InitStream::Unbound(Box::new(AnyUnboundSocket::new_tcp(observer)))
|
||||
InitStream::Unbound(Box::new(UnboundTcpSocket::new(observer)))
|
||||
}
|
||||
|
||||
pub fn new_bound(bound_socket: AnyBoundSocket) -> Self {
|
||||
pub fn new_bound(bound_socket: BoundTcpSocket) -> Self {
|
||||
InitStream::Bound(bound_socket)
|
||||
}
|
||||
|
||||
@ -36,7 +36,7 @@ impl InitStream {
|
||||
self,
|
||||
endpoint: &IpEndpoint,
|
||||
can_reuse: bool,
|
||||
) -> core::result::Result<AnyBoundSocket, (Error, Self)> {
|
||||
) -> core::result::Result<BoundTcpSocket, (Error, Self)> {
|
||||
let unbound_socket = match self {
|
||||
InitStream::Unbound(unbound_socket) => unbound_socket,
|
||||
InitStream::Bound(bound_socket) => {
|
||||
@ -46,7 +46,12 @@ impl InitStream {
|
||||
));
|
||||
}
|
||||
};
|
||||
let bound_socket = match bind_socket(unbound_socket, endpoint, can_reuse) {
|
||||
let bound_socket = match bind_socket(
|
||||
unbound_socket,
|
||||
endpoint,
|
||||
can_reuse,
|
||||
|iface, socket, config| iface.bind_tcp(socket, config),
|
||||
) {
|
||||
Ok(bound_socket) => bound_socket,
|
||||
Err((err, unbound_socket)) => return Err((err, InitStream::Unbound(unbound_socket))),
|
||||
};
|
||||
@ -56,7 +61,7 @@ impl InitStream {
|
||||
fn bind_to_ephemeral_endpoint(
|
||||
self,
|
||||
remote_endpoint: &IpEndpoint,
|
||||
) -> core::result::Result<AnyBoundSocket, (Error, Self)> {
|
||||
) -> core::result::Result<BoundTcpSocket, (Error, Self)> {
|
||||
let endpoint = get_ephemeral_endpoint(remote_endpoint);
|
||||
self.bind(&endpoint, false)
|
||||
}
|
||||
|
@ -1,28 +1,25 @@
|
||||
// SPDX-License-Identifier: MPL-2.0
|
||||
|
||||
use aster_bigtcp::{
|
||||
errors::tcp::ListenError,
|
||||
iface::BindPortConfig,
|
||||
socket::{AnyUnboundSocket, RawTcpSocket},
|
||||
wire::IpEndpoint,
|
||||
errors::tcp::ListenError, iface::BindPortConfig, socket::UnboundTcpSocket, wire::IpEndpoint,
|
||||
};
|
||||
|
||||
use super::connected::ConnectedStream;
|
||||
use crate::{events::IoEvents, net::iface::AnyBoundSocket, prelude::*, process::signal::Pollee};
|
||||
use crate::{events::IoEvents, net::iface::BoundTcpSocket, prelude::*, process::signal::Pollee};
|
||||
|
||||
pub struct ListenStream {
|
||||
backlog: usize,
|
||||
/// A bound socket held to ensure the TCP port cannot be released
|
||||
bound_socket: AnyBoundSocket,
|
||||
bound_socket: BoundTcpSocket,
|
||||
/// Backlog sockets listening at the local endpoint
|
||||
backlog_sockets: RwLock<Vec<BacklogSocket>>,
|
||||
}
|
||||
|
||||
impl ListenStream {
|
||||
pub fn new(
|
||||
bound_socket: AnyBoundSocket,
|
||||
bound_socket: BoundTcpSocket,
|
||||
backlog: usize,
|
||||
) -> core::result::Result<Self, (Error, AnyBoundSocket)> {
|
||||
) -> core::result::Result<Self, (Error, BoundTcpSocket)> {
|
||||
const SOMAXCONN: usize = 4096;
|
||||
let somaxconn = SOMAXCONN.min(backlog);
|
||||
|
||||
@ -102,30 +99,28 @@ impl ListenStream {
|
||||
}
|
||||
|
||||
struct BacklogSocket {
|
||||
bound_socket: AnyBoundSocket,
|
||||
bound_socket: BoundTcpSocket,
|
||||
}
|
||||
|
||||
impl BacklogSocket {
|
||||
// FIXME: All of the error codes below seem to have no Linux equivalents, and I see no reason
|
||||
// why the error may occur. Perhaps it is better to call `unwrap()` directly?
|
||||
fn new(bound_socket: &AnyBoundSocket) -> Result<Self> {
|
||||
fn new(bound_socket: &BoundTcpSocket) -> Result<Self> {
|
||||
let local_endpoint = bound_socket.local_endpoint().ok_or(Error::with_message(
|
||||
Errno::EINVAL,
|
||||
"the socket is not bound",
|
||||
))?;
|
||||
|
||||
let unbound_socket = Box::new(AnyUnboundSocket::new_tcp(Weak::<()>::new()));
|
||||
let unbound_socket = Box::new(UnboundTcpSocket::new(bound_socket.observer()));
|
||||
let bound_socket = {
|
||||
let iface = bound_socket.iface();
|
||||
let bind_port_config = BindPortConfig::new(local_endpoint.port, true);
|
||||
iface
|
||||
.bind_socket(unbound_socket, bind_port_config)
|
||||
.bind_tcp(unbound_socket, bind_port_config)
|
||||
.map_err(|(err, _)| err)?
|
||||
};
|
||||
|
||||
let result = bound_socket
|
||||
.raw_with(|raw_tcp_socket: &mut RawTcpSocket| raw_tcp_socket.listen(local_endpoint));
|
||||
match result {
|
||||
match bound_socket.listen(local_endpoint) {
|
||||
Ok(()) => Ok(Self { bound_socket }),
|
||||
Err(ListenError::Unaddressable) => {
|
||||
return_errno_with_message!(Errno::EINVAL, "the listening address is invalid")
|
||||
@ -137,16 +132,15 @@ impl BacklogSocket {
|
||||
}
|
||||
|
||||
fn is_active(&self) -> bool {
|
||||
self.bound_socket
|
||||
.raw_with(|socket: &mut RawTcpSocket| socket.is_active())
|
||||
self.bound_socket.raw_with(|socket| socket.is_active())
|
||||
}
|
||||
|
||||
fn remote_endpoint(&self) -> Option<IpEndpoint> {
|
||||
self.bound_socket
|
||||
.raw_with(|socket: &mut RawTcpSocket| socket.remote_endpoint())
|
||||
.raw_with(|socket| socket.remote_endpoint())
|
||||
}
|
||||
|
||||
fn into_bound_socket(self) -> AnyBoundSocket {
|
||||
fn into_bound_socket(self) -> BoundTcpSocket {
|
||||
self.bound_socket
|
||||
}
|
||||
}
|
||||
|
@ -230,18 +230,13 @@ impl StreamSocket {
|
||||
return_errno_with_message!(Errno::EINVAL, "the socket is not listening");
|
||||
};
|
||||
|
||||
let accepted = listen_stream.try_accept().map(|connected_stream| {
|
||||
listen_stream.try_accept().map(|connected_stream| {
|
||||
listen_stream.update_io_events(&self.pollee);
|
||||
|
||||
let remote_endpoint = connected_stream.remote_endpoint();
|
||||
let accepted_socket = Self::new_connected(connected_stream);
|
||||
(accepted_socket as _, remote_endpoint.into())
|
||||
});
|
||||
|
||||
drop(state);
|
||||
poll_ifaces();
|
||||
|
||||
accepted
|
||||
})
|
||||
}
|
||||
|
||||
fn try_recv(
|
||||
@ -262,8 +257,6 @@ impl StreamSocket {
|
||||
};
|
||||
|
||||
let received = connected_stream.try_recv(writer, flags).map(|recv_bytes| {
|
||||
connected_stream.update_io_events(&self.pollee);
|
||||
|
||||
let remote_endpoint = connected_stream.remote_endpoint();
|
||||
(recv_bytes, remote_endpoint.into())
|
||||
});
|
||||
@ -302,10 +295,7 @@ impl StreamSocket {
|
||||
}
|
||||
};
|
||||
|
||||
let sent_bytes = connected_stream.try_send(reader, flags).map(|sent_bytes| {
|
||||
connected_stream.update_io_events(&self.pollee);
|
||||
sent_bytes
|
||||
});
|
||||
let sent_bytes = connected_stream.try_send(reader, flags);
|
||||
|
||||
drop(state);
|
||||
poll_ifaces();
|
||||
@ -658,3 +648,11 @@ impl SocketEventObserver for StreamSocket {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for StreamSocket {
|
||||
fn drop(&mut self) {
|
||||
self.state.write().take();
|
||||
|
||||
poll_ifaces();
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user