mirror of
https://github.com/asterinas/asterinas.git
synced 2025-06-23 01:13:23 +00:00
Move packet dispatch out of smoltcp
This commit is contained in:
committed by
Tate, Hongliang Tian
parent
f793259512
commit
ee1656ba35
@ -7,42 +7,45 @@ use alloc::{
|
||||
btree_set::BTreeSet,
|
||||
},
|
||||
sync::Arc,
|
||||
vec::Vec,
|
||||
};
|
||||
|
||||
use keyable_arc::KeyableArc;
|
||||
use ostd::sync::{LocalIrqDisabled, PreemptDisabled, RwLock, SpinLock, SpinLockGuard};
|
||||
use ostd::sync::{LocalIrqDisabled, PreemptDisabled, SpinLock, SpinLockGuard};
|
||||
use smoltcp::{
|
||||
iface::{SocketHandle, SocketSet},
|
||||
iface::{packet::Packet, Context},
|
||||
phy::Device,
|
||||
wire::Ipv4Address,
|
||||
wire::{Ipv4Address, Ipv4Packet},
|
||||
};
|
||||
|
||||
use super::{port::BindPortConfig, time::get_network_timestamp, Iface};
|
||||
use super::{
|
||||
poll::{FnHelper, PollContext},
|
||||
port::BindPortConfig,
|
||||
time::get_network_timestamp,
|
||||
Iface,
|
||||
};
|
||||
use crate::{
|
||||
errors::BindError,
|
||||
socket::{AnyBoundSocket, AnyBoundSocketInner, AnyRawSocket, AnyUnboundSocket, SocketFamily},
|
||||
socket::{
|
||||
BoundTcpSocket, BoundTcpSocketInner, BoundUdpSocket, BoundUdpSocketInner, UnboundTcpSocket,
|
||||
UnboundUdpSocket,
|
||||
},
|
||||
};
|
||||
|
||||
pub struct IfaceCommon<E> {
|
||||
interface: SpinLock<smoltcp::iface::Interface, LocalIrqDisabled>,
|
||||
sockets: SpinLock<SocketSet<'static>, LocalIrqDisabled>,
|
||||
used_ports: SpinLock<BTreeMap<u16, usize>, PreemptDisabled>,
|
||||
bound_sockets: RwLock<BTreeSet<KeyableArc<AnyBoundSocketInner<E>>>>,
|
||||
closing_sockets: SpinLock<BTreeSet<KeyableArc<AnyBoundSocketInner<E>>>, LocalIrqDisabled>,
|
||||
tcp_sockets: SpinLock<BTreeSet<KeyableArc<BoundTcpSocketInner<E>>>, LocalIrqDisabled>,
|
||||
udp_sockets: SpinLock<BTreeSet<KeyableArc<BoundUdpSocketInner<E>>>, LocalIrqDisabled>,
|
||||
ext: E,
|
||||
}
|
||||
|
||||
impl<E> IfaceCommon<E> {
|
||||
pub(super) fn new(interface: smoltcp::iface::Interface, ext: E) -> Self {
|
||||
let socket_set = SocketSet::new(Vec::new());
|
||||
let used_ports = BTreeMap::new();
|
||||
Self {
|
||||
interface: SpinLock::new(interface),
|
||||
sockets: SpinLock::new(socket_set),
|
||||
used_ports: SpinLock::new(used_ports),
|
||||
bound_sockets: RwLock::new(BTreeSet::new()),
|
||||
closing_sockets: SpinLock::new(BTreeSet::new()),
|
||||
used_ports: SpinLock::new(BTreeMap::new()),
|
||||
tcp_sockets: SpinLock::new(BTreeSet::new()),
|
||||
udp_sockets: SpinLock::new(BTreeSet::new()),
|
||||
ext,
|
||||
}
|
||||
}
|
||||
@ -58,58 +61,57 @@ impl<E> IfaceCommon<E> {
|
||||
|
||||
impl<E> IfaceCommon<E> {
|
||||
/// Acquires the lock to the interface.
|
||||
///
|
||||
/// *Lock ordering:* [`Self::sockets`] first, [`Self::interface`] second.
|
||||
pub(crate) fn interface(&self) -> SpinLockGuard<smoltcp::iface::Interface, LocalIrqDisabled> {
|
||||
self.interface.lock()
|
||||
}
|
||||
|
||||
/// Acuqires the lock to the sockets.
|
||||
///
|
||||
/// *Lock ordering:* [`Self::sockets`] first, [`Self::interface`] second.
|
||||
pub(crate) fn sockets(
|
||||
&self,
|
||||
) -> SpinLockGuard<smoltcp::iface::SocketSet<'static>, LocalIrqDisabled> {
|
||||
self.sockets.lock()
|
||||
}
|
||||
}
|
||||
|
||||
const IP_LOCAL_PORT_START: u16 = 32768;
|
||||
const IP_LOCAL_PORT_END: u16 = 60999;
|
||||
|
||||
impl<E> IfaceCommon<E> {
|
||||
pub(super) fn bind_socket(
|
||||
pub(super) fn bind_tcp(
|
||||
&self,
|
||||
iface: Arc<dyn Iface<E>>,
|
||||
socket: Box<AnyUnboundSocket>,
|
||||
socket: Box<UnboundTcpSocket>,
|
||||
config: BindPortConfig,
|
||||
) -> core::result::Result<AnyBoundSocket<E>, (BindError, Box<AnyUnboundSocket>)> {
|
||||
let port = if let Some(port) = config.port() {
|
||||
port
|
||||
} else {
|
||||
match self.alloc_ephemeral_port() {
|
||||
Some(port) => port,
|
||||
None => return Err((BindError::Exhausted, socket)),
|
||||
}
|
||||
) -> core::result::Result<BoundTcpSocket<E>, (BindError, Box<UnboundTcpSocket>)> {
|
||||
let port = match self.bind_port(config) {
|
||||
Ok(port) => port,
|
||||
Err(err) => return Err((err, socket)),
|
||||
};
|
||||
if !self.bind_port(port, config.can_reuse()) {
|
||||
return Err((BindError::InUse, socket));
|
||||
}
|
||||
|
||||
let (handle, socket_family, observer) = match socket.into_raw() {
|
||||
(AnyRawSocket::Tcp(tcp_socket), observer) => (
|
||||
self.sockets.lock().add(tcp_socket),
|
||||
SocketFamily::Tcp,
|
||||
observer,
|
||||
),
|
||||
(AnyRawSocket::Udp(udp_socket), observer) => (
|
||||
self.sockets.lock().add(udp_socket),
|
||||
SocketFamily::Udp,
|
||||
observer,
|
||||
),
|
||||
let (raw_socket, observer) = socket.into_raw();
|
||||
let bound_socket = BoundTcpSocket::new(iface, port, raw_socket, observer);
|
||||
|
||||
let inserted = self
|
||||
.tcp_sockets
|
||||
.lock()
|
||||
.insert(KeyableArc::from(bound_socket.inner().clone()));
|
||||
assert!(inserted);
|
||||
|
||||
Ok(bound_socket)
|
||||
}
|
||||
|
||||
pub(super) fn bind_udp(
|
||||
&self,
|
||||
iface: Arc<dyn Iface<E>>,
|
||||
socket: Box<UnboundUdpSocket>,
|
||||
config: BindPortConfig,
|
||||
) -> core::result::Result<BoundUdpSocket<E>, (BindError, Box<UnboundUdpSocket>)> {
|
||||
let port = match self.bind_port(config) {
|
||||
Ok(port) => port,
|
||||
Err(err) => return Err((err, socket)),
|
||||
};
|
||||
let bound_socket = AnyBoundSocket::new(iface, handle, port, socket_family, observer);
|
||||
self.insert_bound_socket(bound_socket.inner());
|
||||
|
||||
let (raw_socket, observer) = socket.into_raw();
|
||||
let bound_socket = BoundUdpSocket::new(iface, port, raw_socket, observer);
|
||||
|
||||
let inserted = self
|
||||
.udp_sockets
|
||||
.lock()
|
||||
.insert(KeyableArc::from(bound_socket.inner().clone()));
|
||||
assert!(inserted);
|
||||
|
||||
Ok(bound_socket)
|
||||
}
|
||||
@ -130,36 +132,57 @@ impl<E> IfaceCommon<E> {
|
||||
None
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
fn bind_port(&self, port: u16, can_reuse: bool) -> bool {
|
||||
fn bind_port(&self, config: BindPortConfig) -> Result<u16, BindError> {
|
||||
let port = if let Some(port) = config.port() {
|
||||
port
|
||||
} else {
|
||||
match self.alloc_ephemeral_port() {
|
||||
Some(port) => port,
|
||||
None => return Err(BindError::Exhausted),
|
||||
}
|
||||
};
|
||||
|
||||
let mut used_ports = self.used_ports.lock();
|
||||
|
||||
if let Some(used_times) = used_ports.get_mut(&port) {
|
||||
if *used_times == 0 || can_reuse {
|
||||
if *used_times == 0 || config.can_reuse() {
|
||||
// FIXME: Check if the previous socket was bound with SO_REUSEADDR.
|
||||
*used_times += 1;
|
||||
} else {
|
||||
return false;
|
||||
return Err(BindError::InUse);
|
||||
}
|
||||
} else {
|
||||
used_ports.insert(port, 1);
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
fn insert_bound_socket(&self, socket: &Arc<AnyBoundSocketInner<E>>) {
|
||||
let keyable_socket = KeyableArc::from(socket.clone());
|
||||
|
||||
let inserted = self
|
||||
.bound_sockets
|
||||
.write_irq_disabled()
|
||||
.insert(keyable_socket);
|
||||
assert!(inserted);
|
||||
Ok(port)
|
||||
}
|
||||
}
|
||||
|
||||
impl<E> IfaceCommon<E> {
|
||||
#[allow(clippy::mutable_key_type)]
|
||||
fn remove_dead_tcp_sockets(&self, sockets: &mut BTreeSet<KeyableArc<BoundTcpSocketInner<E>>>) {
|
||||
sockets.retain(|socket| {
|
||||
if socket.is_dead() {
|
||||
self.release_port(socket.port());
|
||||
false
|
||||
} else {
|
||||
true
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
pub(crate) fn remove_udp_socket(&self, socket: &Arc<BoundUdpSocketInner<E>>) {
|
||||
let keyable_socket = KeyableArc::from(socket.clone());
|
||||
|
||||
let removed = self.udp_sockets.lock().remove(&keyable_socket);
|
||||
assert!(removed);
|
||||
|
||||
self.release_port(keyable_socket.port());
|
||||
}
|
||||
|
||||
/// Releases the port so that it can be used again (if it is not being reused).
|
||||
pub(crate) fn release_port(&self, port: u16) {
|
||||
fn release_port(&self, port: u16) {
|
||||
let mut used_ports = self.used_ports.lock();
|
||||
if let Some(used_times) = used_ports.remove(&port) {
|
||||
if used_times != 1 {
|
||||
@ -167,96 +190,63 @@ impl<E> IfaceCommon<E> {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Removes a socket from the interface.
|
||||
pub(crate) fn remove_socket(&self, handle: SocketHandle) {
|
||||
self.sockets.lock().remove(handle);
|
||||
}
|
||||
|
||||
pub(crate) fn remove_bound_socket_now(&self, socket: &Arc<AnyBoundSocketInner<E>>) {
|
||||
let keyable_socket = KeyableArc::from(socket.clone());
|
||||
|
||||
let removed = self
|
||||
.bound_sockets
|
||||
.write_irq_disabled()
|
||||
.remove(&keyable_socket);
|
||||
assert!(removed);
|
||||
}
|
||||
|
||||
pub(crate) fn remove_bound_socket_when_closed(&self, socket: &Arc<AnyBoundSocketInner<E>>) {
|
||||
let keyable_socket = KeyableArc::from(socket.clone());
|
||||
|
||||
let removed = self
|
||||
.bound_sockets
|
||||
.write_irq_disabled()
|
||||
.remove(&keyable_socket);
|
||||
assert!(removed);
|
||||
|
||||
let mut closing_sockets = self.closing_sockets.lock();
|
||||
|
||||
// Check `is_closed` after holding the lock to avoid race conditions.
|
||||
if keyable_socket.is_closed() {
|
||||
return;
|
||||
}
|
||||
|
||||
let inserted = closing_sockets.insert(keyable_socket);
|
||||
assert!(inserted);
|
||||
}
|
||||
}
|
||||
|
||||
impl<E> IfaceCommon<E> {
|
||||
#[must_use]
|
||||
pub(super) fn poll<D: Device + ?Sized>(&self, device: &mut D) -> Option<u64> {
|
||||
let mut sockets = self.sockets.lock();
|
||||
let mut interface = self.interface.lock();
|
||||
pub(super) fn poll<D, P, Q>(
|
||||
&self,
|
||||
device: &mut D,
|
||||
process_phy: P,
|
||||
mut dispatch_phy: Q,
|
||||
) -> Option<u64>
|
||||
where
|
||||
D: Device + ?Sized,
|
||||
P: for<'pkt, 'cx, 'tx> FnHelper<
|
||||
&'pkt [u8],
|
||||
&'cx mut Context,
|
||||
D::TxToken<'tx>,
|
||||
Option<(Ipv4Packet<&'pkt [u8]>, D::TxToken<'tx>)>,
|
||||
>,
|
||||
Q: FnMut(&Packet, &mut Context, D::TxToken<'_>),
|
||||
{
|
||||
let mut interface = self.interface();
|
||||
interface.context().now = get_network_timestamp();
|
||||
|
||||
let timestamp = get_network_timestamp();
|
||||
let (has_events, poll_at) = {
|
||||
let mut has_events = false;
|
||||
let mut poll_at;
|
||||
let mut tcp_sockets = self.tcp_sockets.lock();
|
||||
let udp_sockets = self.udp_sockets.lock();
|
||||
|
||||
loop {
|
||||
// `poll` transmits and receives a bounded number of packets. This loop ensures
|
||||
// that all packets are transmitted and received. For details, see
|
||||
// <https://github.com/smoltcp-rs/smoltcp/blob/8e3ea5c7f09a76f0a4988fda20cadc74eacdc0d8/src/iface/interface/mod.rs#L400-L405>.
|
||||
while interface.poll(timestamp, device, &mut sockets) {
|
||||
has_events = true;
|
||||
}
|
||||
let mut context = PollContext::new(interface.context(), &tcp_sockets, &udp_sockets);
|
||||
context.poll_ingress(device, process_phy, &mut dispatch_phy);
|
||||
context.poll_egress(device, dispatch_phy);
|
||||
|
||||
// `poll_at` can return `Some(Instant::from_millis(0))`, which means `PollAt::Now`.
|
||||
// For details, see
|
||||
// <https://github.com/smoltcp-rs/smoltcp/blob/8e3ea5c7f09a76f0a4988fda20cadc74eacdc0d8/src/iface/interface/mod.rs#L478>.
|
||||
poll_at = interface.poll_at(timestamp, &sockets);
|
||||
let Some(instant) = poll_at else {
|
||||
break;
|
||||
};
|
||||
if instant > timestamp {
|
||||
break;
|
||||
}
|
||||
tcp_sockets.iter().for_each(|socket| {
|
||||
if socket.has_new_events() {
|
||||
socket.on_iface_events();
|
||||
}
|
||||
});
|
||||
udp_sockets.iter().for_each(|socket| {
|
||||
if socket.has_new_events() {
|
||||
socket.on_iface_events();
|
||||
}
|
||||
});
|
||||
|
||||
(has_events, poll_at)
|
||||
};
|
||||
self.remove_dead_tcp_sockets(&mut tcp_sockets);
|
||||
|
||||
// drop sockets here to avoid deadlock
|
||||
drop(interface);
|
||||
drop(sockets);
|
||||
|
||||
if has_events {
|
||||
// We never try to hold the write lock in the IRQ context, and we disable IRQ when
|
||||
// holding the write lock. So we don't need to disable IRQ when holding the read lock.
|
||||
self.bound_sockets.read().iter().for_each(|bound_socket| {
|
||||
bound_socket.on_iface_events();
|
||||
});
|
||||
|
||||
let closed_sockets = self
|
||||
.closing_sockets
|
||||
.lock()
|
||||
.extract_if(|closing_socket| closing_socket.is_closed())
|
||||
.collect::<Vec<_>>();
|
||||
drop(closed_sockets);
|
||||
match (
|
||||
tcp_sockets
|
||||
.iter()
|
||||
.map(|socket| socket.next_poll_at_ms())
|
||||
.min(),
|
||||
udp_sockets
|
||||
.iter()
|
||||
.map(|socket| socket.next_poll_at_ms())
|
||||
.min(),
|
||||
) {
|
||||
(Some(tcp_poll_at), Some(udp_poll_at)) if tcp_poll_at <= udp_poll_at => {
|
||||
Some(tcp_poll_at)
|
||||
}
|
||||
(tcp_poll_at, None) => tcp_poll_at,
|
||||
(_, udp_poll_at) => udp_poll_at,
|
||||
}
|
||||
|
||||
poll_at.map(|at| smoltcp::time::Instant::total_millis(&at) as u64)
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user