Move packet dispatch out of smoltcp

This commit is contained in:
Ruihan Li
2024-09-20 11:33:20 +08:00
committed by Tate, Hongliang Tian
parent f793259512
commit ee1656ba35
22 changed files with 1483 additions and 416 deletions

2
Cargo.lock generated
View File

@ -1383,7 +1383,7 @@ checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67"
[[package]]
name = "smoltcp"
version = "0.11.0"
source = "git+https://github.com/smoltcp-rs/smoltcp?rev=469dccd#469dccdbaece66696494f8540615152fc5123b17"
source = "git+https://github.com/asterinas/smoltcp?rev=37716bf#37716bff5ed5b16aba1b8a37e788ee5a6bf32cab"
dependencies = [
"bitflags 1.3.2",
"byteorder",

View File

@ -8,7 +8,7 @@ edition = "2021"
[dependencies]
keyable-arc = { path = "../keyable-arc" }
ostd = { path = "../../../ostd" }
smoltcp = { git = "https://github.com/smoltcp-rs/smoltcp", rev = "469dccd", default-features = false, features = [
smoltcp = { git = "https://github.com/asterinas/smoltcp", rev = "37716bf", default-features = false, features = [
"alloc",
"log",
"medium-ethernet",

View File

@ -14,5 +14,15 @@ pub mod tcp {
}
pub mod udp {
pub use smoltcp::socket::udp::{RecvError, SendError};
pub use smoltcp::socket::udp::RecvError;
/// An error returned by [`BoundTcpSocket::recv`].
///
/// [`BoundTcpSocket::recv`]: crate::socket::BoundTcpSocket::recv
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum SendError {
TooLarge,
Unaddressable,
BufferFull,
}
}

View File

@ -7,42 +7,45 @@ use alloc::{
btree_set::BTreeSet,
},
sync::Arc,
vec::Vec,
};
use keyable_arc::KeyableArc;
use ostd::sync::{LocalIrqDisabled, PreemptDisabled, RwLock, SpinLock, SpinLockGuard};
use ostd::sync::{LocalIrqDisabled, PreemptDisabled, SpinLock, SpinLockGuard};
use smoltcp::{
iface::{SocketHandle, SocketSet},
iface::{packet::Packet, Context},
phy::Device,
wire::Ipv4Address,
wire::{Ipv4Address, Ipv4Packet},
};
use super::{port::BindPortConfig, time::get_network_timestamp, Iface};
use super::{
poll::{FnHelper, PollContext},
port::BindPortConfig,
time::get_network_timestamp,
Iface,
};
use crate::{
errors::BindError,
socket::{AnyBoundSocket, AnyBoundSocketInner, AnyRawSocket, AnyUnboundSocket, SocketFamily},
socket::{
BoundTcpSocket, BoundTcpSocketInner, BoundUdpSocket, BoundUdpSocketInner, UnboundTcpSocket,
UnboundUdpSocket,
},
};
pub struct IfaceCommon<E> {
interface: SpinLock<smoltcp::iface::Interface, LocalIrqDisabled>,
sockets: SpinLock<SocketSet<'static>, LocalIrqDisabled>,
used_ports: SpinLock<BTreeMap<u16, usize>, PreemptDisabled>,
bound_sockets: RwLock<BTreeSet<KeyableArc<AnyBoundSocketInner<E>>>>,
closing_sockets: SpinLock<BTreeSet<KeyableArc<AnyBoundSocketInner<E>>>, LocalIrqDisabled>,
tcp_sockets: SpinLock<BTreeSet<KeyableArc<BoundTcpSocketInner<E>>>, LocalIrqDisabled>,
udp_sockets: SpinLock<BTreeSet<KeyableArc<BoundUdpSocketInner<E>>>, LocalIrqDisabled>,
ext: E,
}
impl<E> IfaceCommon<E> {
pub(super) fn new(interface: smoltcp::iface::Interface, ext: E) -> Self {
let socket_set = SocketSet::new(Vec::new());
let used_ports = BTreeMap::new();
Self {
interface: SpinLock::new(interface),
sockets: SpinLock::new(socket_set),
used_ports: SpinLock::new(used_ports),
bound_sockets: RwLock::new(BTreeSet::new()),
closing_sockets: SpinLock::new(BTreeSet::new()),
used_ports: SpinLock::new(BTreeMap::new()),
tcp_sockets: SpinLock::new(BTreeSet::new()),
udp_sockets: SpinLock::new(BTreeSet::new()),
ext,
}
}
@ -58,58 +61,57 @@ impl<E> IfaceCommon<E> {
impl<E> IfaceCommon<E> {
/// Acquires the lock to the interface.
///
/// *Lock ordering:* [`Self::sockets`] first, [`Self::interface`] second.
pub(crate) fn interface(&self) -> SpinLockGuard<smoltcp::iface::Interface, LocalIrqDisabled> {
self.interface.lock()
}
/// Acuqires the lock to the sockets.
///
/// *Lock ordering:* [`Self::sockets`] first, [`Self::interface`] second.
pub(crate) fn sockets(
&self,
) -> SpinLockGuard<smoltcp::iface::SocketSet<'static>, LocalIrqDisabled> {
self.sockets.lock()
}
}
const IP_LOCAL_PORT_START: u16 = 32768;
const IP_LOCAL_PORT_END: u16 = 60999;
impl<E> IfaceCommon<E> {
pub(super) fn bind_socket(
pub(super) fn bind_tcp(
&self,
iface: Arc<dyn Iface<E>>,
socket: Box<AnyUnboundSocket>,
socket: Box<UnboundTcpSocket>,
config: BindPortConfig,
) -> core::result::Result<AnyBoundSocket<E>, (BindError, Box<AnyUnboundSocket>)> {
let port = if let Some(port) = config.port() {
port
} else {
match self.alloc_ephemeral_port() {
Some(port) => port,
None => return Err((BindError::Exhausted, socket)),
}
) -> core::result::Result<BoundTcpSocket<E>, (BindError, Box<UnboundTcpSocket>)> {
let port = match self.bind_port(config) {
Ok(port) => port,
Err(err) => return Err((err, socket)),
};
if !self.bind_port(port, config.can_reuse()) {
return Err((BindError::InUse, socket));
let (raw_socket, observer) = socket.into_raw();
let bound_socket = BoundTcpSocket::new(iface, port, raw_socket, observer);
let inserted = self
.tcp_sockets
.lock()
.insert(KeyableArc::from(bound_socket.inner().clone()));
assert!(inserted);
Ok(bound_socket)
}
let (handle, socket_family, observer) = match socket.into_raw() {
(AnyRawSocket::Tcp(tcp_socket), observer) => (
self.sockets.lock().add(tcp_socket),
SocketFamily::Tcp,
observer,
),
(AnyRawSocket::Udp(udp_socket), observer) => (
self.sockets.lock().add(udp_socket),
SocketFamily::Udp,
observer,
),
pub(super) fn bind_udp(
&self,
iface: Arc<dyn Iface<E>>,
socket: Box<UnboundUdpSocket>,
config: BindPortConfig,
) -> core::result::Result<BoundUdpSocket<E>, (BindError, Box<UnboundUdpSocket>)> {
let port = match self.bind_port(config) {
Ok(port) => port,
Err(err) => return Err((err, socket)),
};
let bound_socket = AnyBoundSocket::new(iface, handle, port, socket_family, observer);
self.insert_bound_socket(bound_socket.inner());
let (raw_socket, observer) = socket.into_raw();
let bound_socket = BoundUdpSocket::new(iface, port, raw_socket, observer);
let inserted = self
.udp_sockets
.lock()
.insert(KeyableArc::from(bound_socket.inner().clone()));
assert!(inserted);
Ok(bound_socket)
}
@ -130,36 +132,57 @@ impl<E> IfaceCommon<E> {
None
}
#[must_use]
fn bind_port(&self, port: u16, can_reuse: bool) -> bool {
fn bind_port(&self, config: BindPortConfig) -> Result<u16, BindError> {
let port = if let Some(port) = config.port() {
port
} else {
match self.alloc_ephemeral_port() {
Some(port) => port,
None => return Err(BindError::Exhausted),
}
};
let mut used_ports = self.used_ports.lock();
if let Some(used_times) = used_ports.get_mut(&port) {
if *used_times == 0 || can_reuse {
if *used_times == 0 || config.can_reuse() {
// FIXME: Check if the previous socket was bound with SO_REUSEADDR.
*used_times += 1;
} else {
return false;
return Err(BindError::InUse);
}
} else {
used_ports.insert(port, 1);
}
true
}
fn insert_bound_socket(&self, socket: &Arc<AnyBoundSocketInner<E>>) {
let keyable_socket = KeyableArc::from(socket.clone());
let inserted = self
.bound_sockets
.write_irq_disabled()
.insert(keyable_socket);
assert!(inserted);
Ok(port)
}
}
impl<E> IfaceCommon<E> {
#[allow(clippy::mutable_key_type)]
fn remove_dead_tcp_sockets(&self, sockets: &mut BTreeSet<KeyableArc<BoundTcpSocketInner<E>>>) {
sockets.retain(|socket| {
if socket.is_dead() {
self.release_port(socket.port());
false
} else {
true
}
});
}
pub(crate) fn remove_udp_socket(&self, socket: &Arc<BoundUdpSocketInner<E>>) {
let keyable_socket = KeyableArc::from(socket.clone());
let removed = self.udp_sockets.lock().remove(&keyable_socket);
assert!(removed);
self.release_port(keyable_socket.port());
}
/// Releases the port so that it can be used again (if it is not being reused).
pub(crate) fn release_port(&self, port: u16) {
fn release_port(&self, port: u16) {
let mut used_ports = self.used_ports.lock();
if let Some(used_times) = used_ports.remove(&port) {
if used_times != 1 {
@ -167,96 +190,63 @@ impl<E> IfaceCommon<E> {
}
}
}
/// Removes a socket from the interface.
pub(crate) fn remove_socket(&self, handle: SocketHandle) {
self.sockets.lock().remove(handle);
}
pub(crate) fn remove_bound_socket_now(&self, socket: &Arc<AnyBoundSocketInner<E>>) {
let keyable_socket = KeyableArc::from(socket.clone());
let removed = self
.bound_sockets
.write_irq_disabled()
.remove(&keyable_socket);
assert!(removed);
}
pub(crate) fn remove_bound_socket_when_closed(&self, socket: &Arc<AnyBoundSocketInner<E>>) {
let keyable_socket = KeyableArc::from(socket.clone());
let removed = self
.bound_sockets
.write_irq_disabled()
.remove(&keyable_socket);
assert!(removed);
let mut closing_sockets = self.closing_sockets.lock();
// Check `is_closed` after holding the lock to avoid race conditions.
if keyable_socket.is_closed() {
return;
}
let inserted = closing_sockets.insert(keyable_socket);
assert!(inserted);
}
}
impl<E> IfaceCommon<E> {
#[must_use]
pub(super) fn poll<D: Device + ?Sized>(&self, device: &mut D) -> Option<u64> {
let mut sockets = self.sockets.lock();
let mut interface = self.interface.lock();
pub(super) fn poll<D, P, Q>(
&self,
device: &mut D,
process_phy: P,
mut dispatch_phy: Q,
) -> Option<u64>
where
D: Device + ?Sized,
P: for<'pkt, 'cx, 'tx> FnHelper<
&'pkt [u8],
&'cx mut Context,
D::TxToken<'tx>,
Option<(Ipv4Packet<&'pkt [u8]>, D::TxToken<'tx>)>,
>,
Q: FnMut(&Packet, &mut Context, D::TxToken<'_>),
{
let mut interface = self.interface();
interface.context().now = get_network_timestamp();
let timestamp = get_network_timestamp();
let (has_events, poll_at) = {
let mut has_events = false;
let mut poll_at;
let mut tcp_sockets = self.tcp_sockets.lock();
let udp_sockets = self.udp_sockets.lock();
loop {
// `poll` transmits and receives a bounded number of packets. This loop ensures
// that all packets are transmitted and received. For details, see
// <https://github.com/smoltcp-rs/smoltcp/blob/8e3ea5c7f09a76f0a4988fda20cadc74eacdc0d8/src/iface/interface/mod.rs#L400-L405>.
while interface.poll(timestamp, device, &mut sockets) {
has_events = true;
let mut context = PollContext::new(interface.context(), &tcp_sockets, &udp_sockets);
context.poll_ingress(device, process_phy, &mut dispatch_phy);
context.poll_egress(device, dispatch_phy);
tcp_sockets.iter().for_each(|socket| {
if socket.has_new_events() {
socket.on_iface_events();
}
// `poll_at` can return `Some(Instant::from_millis(0))`, which means `PollAt::Now`.
// For details, see
// <https://github.com/smoltcp-rs/smoltcp/blob/8e3ea5c7f09a76f0a4988fda20cadc74eacdc0d8/src/iface/interface/mod.rs#L478>.
poll_at = interface.poll_at(timestamp, &sockets);
let Some(instant) = poll_at else {
break;
};
if instant > timestamp {
break;
});
udp_sockets.iter().for_each(|socket| {
if socket.has_new_events() {
socket.on_iface_events();
}
}
(has_events, poll_at)
};
// drop sockets here to avoid deadlock
drop(interface);
drop(sockets);
if has_events {
// We never try to hold the write lock in the IRQ context, and we disable IRQ when
// holding the write lock. So we don't need to disable IRQ when holding the read lock.
self.bound_sockets.read().iter().for_each(|bound_socket| {
bound_socket.on_iface_events();
});
let closed_sockets = self
.closing_sockets
.lock()
.extract_if(|closing_socket| closing_socket.is_closed())
.collect::<Vec<_>>();
drop(closed_sockets);
}
self.remove_dead_tcp_sockets(&mut tcp_sockets);
poll_at.map(|at| smoltcp::time::Instant::total_millis(&at) as u64)
match (
tcp_sockets
.iter()
.map(|socket| socket.next_poll_at_ms())
.min(),
udp_sockets
.iter()
.map(|socket| socket.next_poll_at_ms())
.min(),
) {
(Some(tcp_poll_at), Some(udp_poll_at)) if tcp_poll_at <= udp_poll_at => {
Some(tcp_poll_at)
}
(tcp_poll_at, None) => tcp_poll_at,
(_, udp_poll_at) => udp_poll_at,
}
}
}

View File

@ -7,7 +7,7 @@ use smoltcp::wire::Ipv4Address;
use super::port::BindPortConfig;
use crate::{
errors::BindError,
socket::{AnyBoundSocket, AnyUnboundSocket},
socket::{BoundTcpSocket, BoundUdpSocket, UnboundTcpSocket, UnboundUdpSocket},
};
/// A network interface.
@ -42,13 +42,22 @@ impl<E> dyn Iface<E> {
/// FIXME: The reason for binding the socket and the iface together is because there are
/// limitations inside smoltcp. See discussion at
/// <https://github.com/smoltcp-rs/smoltcp/issues/779>.
pub fn bind_socket(
pub fn bind_tcp(
self: &Arc<Self>,
socket: Box<AnyUnboundSocket>,
socket: Box<UnboundTcpSocket>,
config: BindPortConfig,
) -> core::result::Result<AnyBoundSocket<E>, (BindError, Box<AnyUnboundSocket>)> {
) -> core::result::Result<BoundTcpSocket<E>, (BindError, Box<UnboundTcpSocket>)> {
let common = self.common();
common.bind_socket(self.clone(), socket, config)
common.bind_tcp(self.clone(), socket, config)
}
pub fn bind_udp(
self: &Arc<Self>,
socket: Box<UnboundUdpSocket>,
config: BindPortConfig,
) -> core::result::Result<BoundUdpSocket<E>, (BindError, Box<UnboundUdpSocket>)> {
let common = self.common();
common.bind_udp(self.clone(), socket, config)
}
/// Gets the IPv4 address of the iface, if any.

View File

@ -4,6 +4,7 @@ mod common;
#[allow(clippy::module_inception)]
mod iface;
mod phy;
mod poll;
mod port;
mod time;

View File

@ -1,10 +1,15 @@
// SPDX-License-Identifier: MPL-2.0
use alloc::sync::Arc;
use alloc::{collections::btree_map::BTreeMap, sync::Arc};
use ostd::sync::{LocalIrqDisabled, SpinLock};
use smoltcp::{
iface::Config,
wire::{self, EthernetAddress, Ipv4Address, Ipv4Cidr},
iface::{packet::Packet, Config, Context},
phy::{DeviceCapabilities, TxToken},
wire::{
self, ArpOperation, ArpPacket, ArpRepr, EthernetAddress, EthernetFrame, EthernetProtocol,
EthernetRepr, IpAddress, Ipv4Address, Ipv4Cidr, Ipv4Packet,
},
};
use crate::{
@ -14,9 +19,11 @@ use crate::{
},
};
pub struct EtherIface<D: WithDevice, E> {
pub struct EtherIface<D, E> {
driver: D,
common: IfaceCommon<E>,
ether_addr: EthernetAddress,
arp_table: SpinLock<BTreeMap<Ipv4Address, EthernetAddress>, LocalIrqDisabled>,
}
impl<D: WithDevice, E> EtherIface<D, E> {
@ -45,21 +52,224 @@ impl<D: WithDevice, E> EtherIface<D, E> {
let common = IfaceCommon::new(interface, ext);
Arc::new(Self { driver, common })
Arc::new(Self {
driver,
common,
ether_addr,
arp_table: SpinLock::new(BTreeMap::new()),
})
}
}
impl<D: WithDevice, E> IfaceInternal<E> for EtherIface<D, E> {
impl<D, E> IfaceInternal<E> for EtherIface<D, E> {
fn common(&self) -> &IfaceCommon<E> {
&self.common
}
}
impl<D: WithDevice, E: Send + Sync> Iface<E> for EtherIface<D, E> {
impl<D: WithDevice + 'static, E: Send + Sync> Iface<E> for EtherIface<D, E> {
fn raw_poll(&self, schedule_next_poll: &dyn Fn(Option<u64>)) {
self.driver.with(|device| {
let next_poll = self.common.poll(&mut *device);
let next_poll = self.common.poll(
&mut *device,
|data, iface_cx, tx_token| self.process(data, iface_cx, tx_token),
|pkt, iface_cx, tx_token| self.dispatch(pkt, iface_cx, tx_token),
);
schedule_next_poll(next_poll);
});
}
}
impl<D, E> EtherIface<D, E> {
fn process<'pkt, T: TxToken>(
&self,
data: &'pkt [u8],
iface_cx: &mut Context,
tx_token: T,
) -> Option<(Ipv4Packet<&'pkt [u8]>, T)> {
match self.parse_ip_or_process_arp(data, iface_cx) {
Ok(pkt) => Some((pkt, tx_token)),
Err(Some(arp)) => {
Self::emit_arp(&arp, tx_token);
None
}
Err(None) => None,
}
}
fn parse_ip_or_process_arp<'pkt>(
&self,
data: &'pkt [u8],
iface_cx: &mut Context,
) -> Result<Ipv4Packet<&'pkt [u8]>, Option<ArpRepr>> {
// Parse the Ethernet header. Ignore the packet if the header is ill-formed.
let frame = EthernetFrame::new_checked(data).map_err(|_| None)?;
let repr = EthernetRepr::parse(&frame).map_err(|_| None)?;
// Ignore the Ethernet frame if it is not sent to us.
if !repr.dst_addr.is_broadcast() && repr.dst_addr != self.ether_addr {
return Err(None);
}
// Ignore the Ethernet frame if the protocol is not supported.
match repr.ethertype {
EthernetProtocol::Ipv4 => {
Ok(Ipv4Packet::new_checked(frame.payload()).map_err(|_| None)?)
}
EthernetProtocol::Arp => {
let pkt = ArpPacket::new_checked(frame.payload()).map_err(|_| None)?;
let arp = ArpRepr::parse(&pkt).map_err(|_| None)?;
Err(self.process_arp(&arp, iface_cx))
}
_ => Err(None),
}
}
fn process_arp(&self, arp_repr: &ArpRepr, iface_cx: &mut Context) -> Option<ArpRepr> {
match arp_repr {
ArpRepr::EthernetIpv4 {
operation: ArpOperation::Reply,
source_hardware_addr,
source_protocol_addr,
..
} => {
// Ignore the ARP packet if the source addresses are not unicast or not local.
if !source_hardware_addr.is_unicast()
|| !iface_cx.in_same_network(&IpAddress::Ipv4(*source_protocol_addr))
{
return None;
}
// Insert the mapping between the Ethernet address and the IP address.
//
// TODO: Remove the mapping if it expires.
self.arp_table
.lock()
.insert(*source_protocol_addr, *source_hardware_addr);
None
}
ArpRepr::EthernetIpv4 {
operation: ArpOperation::Request,
source_hardware_addr,
source_protocol_addr,
target_protocol_addr,
..
} => {
// Ignore the ARP packet if the source addresses are not unicast.
if !source_hardware_addr.is_unicast() || !source_protocol_addr.is_unicast() {
return None;
}
// Ignore the ARP packet if we do not own the target address.
if !iface_cx
.ipv4_addr()
.is_some_and(|addr| addr == *target_protocol_addr)
{
return None;
}
Some(ArpRepr::EthernetIpv4 {
operation: ArpOperation::Reply,
source_hardware_addr: self.ether_addr,
source_protocol_addr: *target_protocol_addr,
target_hardware_addr: *source_hardware_addr,
target_protocol_addr: *source_protocol_addr,
})
}
_ => None,
}
}
fn dispatch<T: TxToken>(&self, pkt: &Packet, iface_cx: &mut Context, tx_token: T) {
match self.resolve_ether_or_generate_arp(pkt, iface_cx) {
Ok(ether) => Self::emit_ip(&ether, pkt, &iface_cx.caps, tx_token),
Err(Some(arp)) => Self::emit_arp(&arp, tx_token),
Err(None) => (),
}
}
fn resolve_ether_or_generate_arp(
&self,
pkt: &Packet,
iface_cx: &mut Context,
) -> Result<EthernetRepr, Option<ArpRepr>> {
// Resolve the next-hop IP address.
let next_hop_ip = match iface_cx.route(&pkt.ip_repr().dst_addr(), iface_cx.now()) {
Some(IpAddress::Ipv4(next_hop_ip)) => next_hop_ip,
None => return Err(None),
};
// Resolve the next-hop Ethernet address.
let next_hop_ether = if next_hop_ip.is_broadcast() {
EthernetAddress::BROADCAST
} else if let Some(next_hop_ether) = self.arp_table.lock().get(&next_hop_ip) {
*next_hop_ether
} else {
// If the next-hop Ethernet address cannot be resolved, we drop the original packet and
// send an ARP packet instead. The upper layer should be responsible for detecting the
// packet loss and retrying later to see if the Ethernet address is ready.
return Err(Some(ArpRepr::EthernetIpv4 {
operation: ArpOperation::Request,
source_hardware_addr: self.ether_addr,
source_protocol_addr: iface_cx.ipv4_addr().unwrap_or(Ipv4Address::UNSPECIFIED),
target_hardware_addr: EthernetAddress::BROADCAST,
target_protocol_addr: next_hop_ip,
}));
};
Ok(EthernetRepr {
src_addr: self.ether_addr,
dst_addr: next_hop_ether,
ethertype: EthernetProtocol::Ipv4,
})
}
/// Consumes the token and emits an IP packet.
fn emit_ip<T: TxToken>(
ether_repr: &EthernetRepr,
ip_pkt: &Packet,
caps: &DeviceCapabilities,
tx_token: T,
) {
tx_token.consume(
ether_repr.buffer_len() + ip_pkt.ip_repr().buffer_len(),
|buffer| {
let mut frame = EthernetFrame::new_unchecked(buffer);
ether_repr.emit(&mut frame);
let ip_repr = ip_pkt.ip_repr();
ip_repr.emit(frame.payload_mut(), &caps.checksum);
ip_pkt.emit_payload(
&ip_repr,
&mut frame.payload_mut()[ip_repr.header_len()..],
caps,
);
},
);
}
/// Consumes the token and emits an ARP packet.
fn emit_arp<T: TxToken>(arp_repr: &ArpRepr, tx_token: T) {
let ether_repr = match arp_repr {
ArpRepr::EthernetIpv4 {
source_hardware_addr,
target_hardware_addr,
..
} => EthernetRepr {
src_addr: *source_hardware_addr,
dst_addr: *target_hardware_addr,
ethertype: EthernetProtocol::Arp,
},
_ => return,
};
tx_token.consume(ether_repr.buffer_len() + arp_repr.buffer_len(), |buffer| {
let mut frame = EthernetFrame::new_unchecked(buffer);
ether_repr.emit(&mut frame);
let mut pkt = ArpPacket::new_unchecked(frame.payload_mut());
arp_repr.emit(&mut pkt);
});
}
}

View File

@ -4,7 +4,8 @@ use alloc::sync::Arc;
use smoltcp::{
iface::Config,
wire::{self, Ipv4Cidr},
phy::TxToken,
wire::{self, Ipv4Cidr, Ipv4Packet},
};
use crate::{
@ -14,7 +15,7 @@ use crate::{
},
};
pub struct IpIface<D: WithDevice, E> {
pub struct IpIface<D, E> {
driver: D,
common: IfaceCommon<E>,
}
@ -39,16 +40,30 @@ impl<D: WithDevice, E> IpIface<D, E> {
}
}
impl<D: WithDevice, E> IfaceInternal<E> for IpIface<D, E> {
impl<D, E> IfaceInternal<E> for IpIface<D, E> {
fn common(&self) -> &IfaceCommon<E> {
&self.common
}
}
impl<D: WithDevice, E: Send + Sync> Iface<E> for IpIface<D, E> {
impl<D: WithDevice + 'static, E: Send + Sync> Iface<E> for IpIface<D, E> {
fn raw_poll(&self, schedule_next_poll: &dyn Fn(Option<u64>)) {
self.driver.with(|device| {
let next_poll = self.common.poll(device);
let next_poll = self.common.poll(
device,
|data, _iface_cx, tx_token| Some((Ipv4Packet::new_checked(data).ok()?, tx_token)),
|pkt, iface_cx, tx_token| {
let ip_repr = pkt.ip_repr();
tx_token.consume(ip_repr.buffer_len(), |buffer| {
ip_repr.emit(&mut buffer[..], &iface_cx.checksum_caps());
pkt.emit_payload(
&ip_repr,
&mut buffer[ip_repr.header_len()..],
&iface_cx.caps,
);
});
},
);
schedule_next_poll(next_poll);
});
}

View File

@ -0,0 +1,476 @@
// SPDX-License-Identifier: MPL-2.0
use alloc::{collections::btree_set::BTreeSet, vec};
use keyable_arc::KeyableArc;
use smoltcp::{
iface::{
packet::{icmp_reply_payload_len, IpPayload, Packet},
Context,
},
phy::{ChecksumCapabilities, Device, RxToken, TxToken},
wire::{
Icmpv4DstUnreachable, Icmpv4Repr, IpAddress, IpProtocol, IpRepr, Ipv4Address, Ipv4Packet,
Ipv4Repr, TcpControl, TcpPacket, TcpRepr, UdpPacket, UdpRepr, IPV4_HEADER_LEN,
IPV4_MIN_MTU,
},
};
use crate::socket::{BoundTcpSocketInner, BoundUdpSocketInner, TcpProcessResult};
pub(super) struct PollContext<'a, E> {
iface_cx: &'a mut Context,
tcp_sockets: &'a BTreeSet<KeyableArc<BoundTcpSocketInner<E>>>,
udp_sockets: &'a BTreeSet<KeyableArc<BoundUdpSocketInner<E>>>,
}
impl<'a, E> PollContext<'a, E> {
#[allow(clippy::mutable_key_type)]
pub(super) fn new(
iface_cx: &'a mut Context,
tcp_sockets: &'a BTreeSet<KeyableArc<BoundTcpSocketInner<E>>>,
udp_sockets: &'a BTreeSet<KeyableArc<BoundUdpSocketInner<E>>>,
) -> Self {
Self {
iface_cx,
tcp_sockets,
udp_sockets,
}
}
}
// This works around <https://github.com/rust-lang/rust/issues/49601>.
// See the issue above for details.
pub(super) trait FnHelper<A, B, C, O>: FnMut(A, B, C) -> O {}
impl<A, B, C, O, F> FnHelper<A, B, C, O> for F where F: FnMut(A, B, C) -> O {}
impl<'a, E> PollContext<'a, E> {
pub(super) fn poll_ingress<D, P, Q>(
&mut self,
device: &mut D,
mut process_phy: P,
dispatch_phy: &mut Q,
) where
D: Device + ?Sized,
P: for<'pkt, 'cx, 'tx> FnHelper<
&'pkt [u8],
&'cx mut Context,
D::TxToken<'tx>,
Option<(Ipv4Packet<&'pkt [u8]>, D::TxToken<'tx>)>,
>,
Q: FnMut(&Packet, &mut Context, D::TxToken<'_>),
{
while let Some((rx_token, tx_token)) = device.receive(self.iface_cx.now()) {
rx_token.consume(|data| {
let Some((pkt, tx_token)) = process_phy(data, self.iface_cx, tx_token) else {
return;
};
let Some(reply) = self.parse_and_process_ipv4(pkt) else {
return;
};
dispatch_phy(&reply, self.iface_cx, tx_token);
});
}
}
fn parse_and_process_ipv4<'pkt>(
&mut self,
pkt: Ipv4Packet<&'pkt [u8]>,
) -> Option<Packet<'pkt>> {
// Parse the IP header. Ignore the packet if the header is ill-formed.
let repr = Ipv4Repr::parse(&pkt, &self.iface_cx.checksum_caps()).ok()?;
if !repr.dst_addr.is_broadcast() && !self.is_unicast_local(IpAddress::Ipv4(repr.dst_addr)) {
return self.generate_icmp_unreachable(
&IpRepr::Ipv4(repr),
pkt.payload(),
Icmpv4DstUnreachable::HostUnreachable,
);
}
match repr.next_header {
IpProtocol::Tcp => self.parse_and_process_tcp(
&IpRepr::Ipv4(repr),
pkt.payload(),
&self.iface_cx.checksum_caps(),
),
IpProtocol::Udp => self.parse_and_process_udp(
&IpRepr::Ipv4(repr),
pkt.payload(),
&self.iface_cx.checksum_caps(),
),
_ => None,
}
}
fn parse_and_process_tcp<'pkt>(
&mut self,
ip_repr: &IpRepr,
ip_payload: &'pkt [u8],
checksum_caps: &ChecksumCapabilities,
) -> Option<Packet<'pkt>> {
// TCP connections can only be established between unicast addresses. Ignore the packet if
// this is not the case. See
// <https://datatracker.ietf.org/doc/html/rfc9293#section-3.9.2.3>.
if !ip_repr.src_addr().is_unicast() || !ip_repr.dst_addr().is_unicast() {
return None;
}
// Parse the TCP header. Ignore the packet if the header is ill-formed.
let tcp_pkt = TcpPacket::new_checked(ip_payload).ok()?;
let tcp_repr = TcpRepr::parse(
&tcp_pkt,
&ip_repr.src_addr(),
&ip_repr.dst_addr(),
checksum_caps,
)
.ok()?;
self.process_tcp_until_outgoing(ip_repr, &tcp_repr)
.map(|(ip_repr, tcp_repr)| Packet::new(ip_repr, IpPayload::Tcp(tcp_repr)))
}
fn process_tcp_until_outgoing(
&mut self,
ip_repr: &IpRepr,
tcp_repr: &TcpRepr,
) -> Option<(IpRepr, TcpRepr<'static>)> {
let (mut ip_repr, mut tcp_repr) = self.process_tcp(ip_repr, tcp_repr)?;
loop {
if !self.is_unicast_local(ip_repr.dst_addr()) {
return Some((ip_repr, tcp_repr));
}
let (new_ip_repr, new_tcp_repr) = self.process_tcp(&ip_repr, &tcp_repr)?;
ip_repr = new_ip_repr;
tcp_repr = new_tcp_repr;
}
}
fn process_tcp(
&mut self,
ip_repr: &IpRepr,
tcp_repr: &TcpRepr,
) -> Option<(IpRepr, TcpRepr<'static>)> {
for socket in self.tcp_sockets.iter() {
if !socket.can_process(tcp_repr.dst_port) {
continue;
}
match socket.process(self.iface_cx, ip_repr, tcp_repr) {
TcpProcessResult::NotProcessed => continue,
TcpProcessResult::Processed => return None,
TcpProcessResult::ProcessedWithReply(ip_repr, tcp_repr) => {
return Some((ip_repr, tcp_repr))
}
}
}
// "In no case does receipt of a segment containing RST give rise to a RST in response."
// See <https://datatracker.ietf.org/doc/html/rfc9293#section-4-1.64>.
if tcp_repr.control == TcpControl::Rst {
return None;
}
Some(smoltcp::socket::tcp::Socket::rst_reply(ip_repr, tcp_repr))
}
fn parse_and_process_udp<'pkt>(
&mut self,
ip_repr: &IpRepr,
ip_payload: &'pkt [u8],
checksum_caps: &ChecksumCapabilities,
) -> Option<Packet<'pkt>> {
// Parse the UDP header. Ignore the packet if the header is ill-formed.
let udp_pkt = UdpPacket::new_checked(ip_payload).ok()?;
let udp_repr = UdpRepr::parse(
&udp_pkt,
&ip_repr.src_addr(),
&ip_repr.dst_addr(),
checksum_caps,
)
.ok()?;
if !self.process_udp(ip_repr, &udp_repr, udp_pkt.payload()) {
return self.generate_icmp_unreachable(
ip_repr,
ip_payload,
Icmpv4DstUnreachable::PortUnreachable,
);
}
None
}
fn process_udp(&mut self, ip_repr: &IpRepr, udp_repr: &UdpRepr, udp_payload: &[u8]) -> bool {
let mut processed = false;
for socket in self.udp_sockets.iter() {
if !socket.can_process(udp_repr.dst_port) {
continue;
}
processed |= socket.process(self.iface_cx, ip_repr, udp_repr, udp_payload);
if processed && ip_repr.dst_addr().is_unicast() {
break;
}
}
processed
}
fn generate_icmp_unreachable<'pkt>(
&self,
ip_repr: &IpRepr,
ip_payload: &'pkt [u8],
reason: Icmpv4DstUnreachable,
) -> Option<Packet<'pkt>> {
if !ip_repr.src_addr().is_unicast() || !ip_repr.dst_addr().is_unicast() {
return None;
}
if self.is_unicast_local(ip_repr.src_addr()) {
// In this case, the generating ICMP message will have a local IP address as the
// destination. However, since we don't have the ability to handle ICMP messages, we'll
// just skip the generation.
//
// TODO: Generate the ICMP message here once we're able to handle incoming ICMP
// messages.
return None;
}
let IpRepr::Ipv4(ipv4_repr) = ip_repr;
let reply_len = icmp_reply_payload_len(ip_payload.len(), IPV4_MIN_MTU, IPV4_HEADER_LEN);
let icmp_repr = Icmpv4Repr::DstUnreachable {
reason,
header: *ipv4_repr,
data: &ip_payload[..reply_len],
};
Some(Packet::new_ipv4(
Ipv4Repr {
src_addr: self
.iface_cx
.ipv4_addr()
.unwrap_or(Ipv4Address::UNSPECIFIED),
dst_addr: ipv4_repr.src_addr,
next_header: IpProtocol::Icmp,
payload_len: icmp_repr.buffer_len(),
hop_limit: 64,
},
IpPayload::Icmpv4(icmp_repr),
))
}
/// Returns whether the destination address is the unicast address of a local interface.
///
/// Note: "local" means that the IP address belongs to the local interface, not to be confused
/// with the localhost IP (127.0.0.1).
fn is_unicast_local(&self, dst_addr: IpAddress) -> bool {
match dst_addr {
IpAddress::Ipv4(dst_addr) => self
.iface_cx
.ipv4_addr()
.is_some_and(|addr| addr == dst_addr),
}
}
}
impl<'a, E> PollContext<'a, E> {
pub(super) fn poll_egress<D, Q>(&mut self, device: &mut D, mut dispatch_phy: Q)
where
D: Device + ?Sized,
Q: FnMut(&Packet, &mut Context, D::TxToken<'_>),
{
while let Some(tx_token) = device.transmit(self.iface_cx.now()) {
if !self.dispatch_ipv4(tx_token, &mut dispatch_phy) {
break;
}
}
}
fn dispatch_ipv4<T, Q>(&mut self, tx_token: T, dispatch_phy: &mut Q) -> bool
where
T: TxToken,
Q: FnMut(&Packet, &mut Context, T),
{
let (did_something_tcp, tx_token) = self.dispatch_tcp(tx_token, dispatch_phy);
let Some(tx_token) = tx_token else {
return did_something_tcp;
};
let (did_something_udp, _tx_token) = self.dispatch_udp(tx_token, dispatch_phy);
did_something_tcp || did_something_udp
}
fn dispatch_tcp<T, Q>(&mut self, tx_token: T, dispatch_phy: &mut Q) -> (bool, Option<T>)
where
T: TxToken,
Q: FnMut(&Packet, &mut Context, T),
{
let mut tx_token = Some(tx_token);
let mut did_something = false;
for socket in self.tcp_sockets.iter() {
if !socket.need_dispatch(self.iface_cx.now()) {
continue;
}
// We set `did_something` even if no packets are actually generated. This is because a
// timer can expire, but no packets are actually generated.
did_something = true;
let mut deferred = None;
let reply = socket.dispatch(self.iface_cx, |cx, ip_repr, tcp_repr| {
let mut this = PollContext::new(cx, self.tcp_sockets, self.udp_sockets);
if !this.is_unicast_local(ip_repr.dst_addr()) {
dispatch_phy(
&Packet::new(ip_repr.clone(), IpPayload::Tcp(*tcp_repr)),
this.iface_cx,
tx_token.take().unwrap(),
);
return None;
}
if !socket.can_process(tcp_repr.dst_port) {
return this.process_tcp(ip_repr, tcp_repr);
}
// We cannot call `process_tcp` now because it may cause deadlocks. We will copy
// the packet and call `process_tcp` after releasing the socket lock.
deferred = Some((ip_repr.clone(), {
let mut data = vec![0; tcp_repr.buffer_len()];
tcp_repr.emit(
&mut TcpPacket::new_unchecked(data.as_mut_slice()),
&ip_repr.src_addr(),
&ip_repr.dst_addr(),
&ChecksumCapabilities::ignored(),
);
data
}));
None
});
match (deferred, reply) {
(None, None) => (),
(Some((ip_repr, ip_payload)), None) => {
if let Some(reply) = self.parse_and_process_tcp(
&ip_repr,
&ip_payload,
&ChecksumCapabilities::ignored(),
) {
dispatch_phy(&reply, self.iface_cx, tx_token.take().unwrap());
}
}
(None, Some((ip_repr, tcp_repr))) if !self.is_unicast_local(ip_repr.dst_addr()) => {
dispatch_phy(
&Packet::new(ip_repr, IpPayload::Tcp(tcp_repr)),
self.iface_cx,
tx_token.take().unwrap(),
);
}
(None, Some((ip_repr, tcp_repr))) => {
if let Some((new_ip_repr, new_tcp_repr)) =
self.process_tcp_until_outgoing(&ip_repr, &tcp_repr)
{
dispatch_phy(
&Packet::new(new_ip_repr, IpPayload::Tcp(new_tcp_repr)),
self.iface_cx,
tx_token.take().unwrap(),
);
}
}
(Some(_), Some(_)) => unreachable!(),
}
if tx_token.is_none() {
break;
}
}
(did_something, tx_token)
}
fn dispatch_udp<T, Q>(&mut self, tx_token: T, dispatch_phy: &mut Q) -> (bool, Option<T>)
where
T: TxToken,
Q: FnMut(&Packet, &mut Context, T),
{
let mut tx_token = Some(tx_token);
let mut did_something = false;
for socket in self.udp_sockets.iter() {
if !socket.need_dispatch(self.iface_cx.now()) {
continue;
}
// We set `did_something` even if no packets are actually generated. This is because a
// timer can expire, but no packets are actually generated.
did_something = true;
let mut deferred = None;
socket.dispatch(self.iface_cx, |cx, ip_repr, udp_repr, udp_payload| {
let mut this = PollContext::new(cx, self.tcp_sockets, self.udp_sockets);
if ip_repr.dst_addr().is_broadcast() || !this.is_unicast_local(ip_repr.dst_addr()) {
dispatch_phy(
&Packet::new(ip_repr.clone(), IpPayload::Udp(*udp_repr, udp_payload)),
this.iface_cx,
tx_token.take().unwrap(),
);
if !ip_repr.dst_addr().is_broadcast() {
return;
}
}
if !socket.can_process(udp_repr.dst_port) {
// TODO: Generate the ICMP message here once we're able to handle incoming ICMP
// messages.
let _ = this.process_udp(ip_repr, udp_repr, udp_payload);
return;
}
// We cannot call `process_udp` now because it may cause deadlocks. We will copy
// the packet and call `process_udp` after releasing the socket lock.
deferred = Some((ip_repr.clone(), {
let mut data = vec![0; udp_repr.header_len() + udp_payload.len()];
udp_repr.emit(
&mut UdpPacket::new_unchecked(&mut data),
&ip_repr.src_addr(),
&ip_repr.dst_addr(),
udp_payload.len(),
|payload| payload.copy_from_slice(udp_payload),
&ChecksumCapabilities::ignored(),
);
data
}));
});
if let Some((ip_repr, ip_payload)) = deferred {
if let Some(reply) = self.parse_and_process_udp(
&ip_repr,
&ip_payload,
&ChecksumCapabilities::ignored(),
) {
dispatch_phy(&reply, self.iface_cx, tx_token.take().unwrap());
}
}
if tx_token.is_none() {
break;
}
}
(did_something, tx_token)
}
}

View File

@ -1,44 +1,188 @@
// SPDX-License-Identifier: MPL-2.0
use alloc::sync::{Arc, Weak};
use alloc::{
boxed::Box,
sync::{Arc, Weak},
};
use core::{
ops::{Deref, DerefMut},
sync::atomic::{AtomicBool, AtomicU64, Ordering},
};
use ostd::sync::RwLock;
use ostd::sync::{LocalIrqDisabled, RwLock, SpinLock, SpinLockGuard};
use smoltcp::{
socket::tcp::ConnectError,
wire::{IpAddress, IpEndpoint},
iface::Context,
socket::{udp::UdpMetadata, PollAt},
time::Instant,
wire::{IpAddress, IpEndpoint, IpRepr, TcpRepr, UdpRepr},
};
use super::{event::SocketEventObserver, RawTcpSocket, RawUdpSocket};
use crate::iface::Iface;
pub(crate) enum SocketFamily {
Tcp,
Udp,
pub struct BoundSocket<T: AnySocket, E>(Arc<BoundSocketInner<T, E>>);
/// [`TcpSocket`] or [`UdpSocket`].
pub trait AnySocket {
type RawSocket;
/// Called by [`BoundSocket::new`].
fn new(socket: Box<Self::RawSocket>) -> Self;
/// Called by [`BoundSocket::drop`].
fn on_drop<E>(this: &Arc<BoundSocketInner<Self, E>>)
where
Self: Sized;
}
pub struct AnyBoundSocket<E>(Arc<AnyBoundSocketInner<E>>);
pub type BoundTcpSocket<E> = BoundSocket<TcpSocket, E>;
pub type BoundUdpSocket<E> = BoundSocket<UdpSocket, E>;
impl<E> AnyBoundSocket<E> {
/// Common states shared by [`BoundTcpSocketInner`] and [`BoundUdpSocketInner`].
pub struct BoundSocketInner<T, E> {
iface: Arc<dyn Iface<E>>,
port: u16,
socket: T,
observer: RwLock<Weak<dyn SocketEventObserver>>,
next_poll_at_ms: AtomicU64,
has_new_events: AtomicBool,
}
/// States needed by [`BoundTcpSocketInner`] but not [`BoundUdpSocketInner`].
pub struct TcpSocket {
socket: SpinLock<RawTcpSocketExt, LocalIrqDisabled>,
is_dead: AtomicBool,
}
struct RawTcpSocketExt {
socket: Box<RawTcpSocket>,
/// Whether the socket is in the background.
///
/// A background socket is a socket with its corresponding [`BoundSocket`] dropped. This means
/// that no more user events (like `send`/`recv`) can reach the socket, but it can be in a
/// state of waiting for certain network events (e.g., remote FIN/ACK packets), so
/// [`BoundSocketInner`] may still be alive for a while.
in_background: bool,
}
impl Deref for RawTcpSocketExt {
type Target = RawTcpSocket;
fn deref(&self) -> &Self::Target {
&self.socket
}
}
impl DerefMut for RawTcpSocketExt {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.socket
}
}
impl TcpSocket {
fn lock(&self) -> SpinLockGuard<RawTcpSocketExt, LocalIrqDisabled> {
self.socket.lock()
}
/// Returns whether the TCP socket is dead.
///
/// See [`BoundTcpSocketInner::is_dead`] for the definition of dead TCP sockets.
fn is_dead(&self) -> bool {
self.is_dead.load(Ordering::Relaxed)
}
/// Updates whether the TCP socket is dead.
///
/// See [`BoundTcpSocketInner::is_dead`] for the definition of dead TCP sockets.
///
/// This method must be called after handling network events. However, it is not necessary to
/// call this method after handling non-closing user events, because the socket can never be
/// dead if user events can reach the socket.
fn update_dead(&self, socket: &RawTcpSocketExt) {
if socket.in_background && socket.state() == smoltcp::socket::tcp::State::Closed {
self.is_dead.store(true, Ordering::Relaxed);
}
}
}
impl AnySocket for TcpSocket {
type RawSocket = RawTcpSocket;
fn new(socket: Box<Self::RawSocket>) -> Self {
let socket_ext = RawTcpSocketExt {
socket,
in_background: false,
};
Self {
socket: SpinLock::new(socket_ext),
is_dead: AtomicBool::new(false),
}
}
fn on_drop<E>(this: &Arc<BoundSocketInner<Self, E>>) {
let mut socket = this.socket.lock();
socket.in_background = true;
socket.close();
// A TCP socket may not be appropriate for immediate removal. We leave the removal decision
// to the polling logic.
this.update_next_poll_at_ms(PollAt::Now);
this.socket.update_dead(&socket);
}
}
/// States needed by [`BoundUdpSocketInner`] but not [`BoundTcpSocketInner`].
type UdpSocket = SpinLock<Box<RawUdpSocket>, LocalIrqDisabled>;
impl AnySocket for UdpSocket {
type RawSocket = RawUdpSocket;
fn new(socket: Box<Self::RawSocket>) -> Self {
Self::new(socket)
}
fn on_drop<E>(this: &Arc<BoundSocketInner<Self, E>>) {
this.socket.lock().close();
// A UDP socket can be removed immediately.
this.iface.common().remove_udp_socket(this);
}
}
impl<T: AnySocket, E> Drop for BoundSocket<T, E> {
fn drop(&mut self) {
T::on_drop(&self.0);
}
}
pub(crate) type BoundTcpSocketInner<E> = BoundSocketInner<TcpSocket, E>;
pub(crate) type BoundUdpSocketInner<E> = BoundSocketInner<UdpSocket, E>;
impl<T: AnySocket, E> BoundSocket<T, E> {
pub(crate) fn new(
iface: Arc<dyn Iface<E>>,
handle: smoltcp::iface::SocketHandle,
port: u16,
socket_family: SocketFamily,
socket: Box<T::RawSocket>,
observer: Weak<dyn SocketEventObserver>,
) -> Self {
Self(Arc::new(AnyBoundSocketInner {
Self(Arc::new(BoundSocketInner {
iface,
handle,
port,
socket_family,
socket: T::new(socket),
observer: RwLock::new(observer),
next_poll_at_ms: AtomicU64::new(u64::MAX),
has_new_events: AtomicBool::new(false),
}))
}
pub(crate) fn inner(&self) -> &Arc<AnyBoundSocketInner<E>> {
pub(crate) fn inner(&self) -> &Arc<BoundSocketInner<T, E>> {
&self.0
}
}
impl<T: AnySocket, E> BoundSocket<T, E> {
/// Sets the observer whose `on_events` will be called when certain iface events happen. After
/// setting, the new observer will fire once immediately to avoid missing any events.
///
@ -51,6 +195,15 @@ impl<E> AnyBoundSocket<E> {
self.0.on_iface_events();
}
/// Returns the observer.
///
/// See also [`Self::set_observer`].
pub fn observer(&self) -> Weak<dyn SocketEventObserver> {
// We never hold the write lock in IRQ handlers, so we don't need to disable IRQs when we
// get the read lock.
self.0.observer.read().clone()
}
pub fn local_endpoint(&self) -> Option<IpEndpoint> {
let ip_addr = {
let ipv4_addr = self.0.iface.ipv4_addr()?;
@ -59,58 +212,155 @@ impl<E> AnyBoundSocket<E> {
Some(IpEndpoint::new(ip_addr, self.0.port))
}
pub fn raw_with<T: smoltcp::socket::AnySocket<'static>, R, F: FnMut(&mut T) -> R>(
&self,
f: F,
) -> R {
self.0.raw_with(f)
}
/// Connects to a remote endpoint.
///
/// # Panics
///
/// This method will panic if the socket is not a TCP socket.
pub fn do_connect(&self, remote_endpoint: IpEndpoint) -> Result<(), ConnectError> {
let common = self.iface().common();
let mut sockets = common.sockets();
let socket = sockets.get_mut::<RawTcpSocket>(self.0.handle);
let mut iface = common.interface();
let cx = iface.context();
socket.connect(cx, remote_endpoint, self.0.port)
}
pub fn iface(&self) -> &Arc<dyn Iface<E>> {
&self.0.iface
}
}
impl<E> Drop for AnyBoundSocket<E> {
fn drop(&mut self) {
if self.0.start_closing() {
self.0.iface.common().remove_bound_socket_now(&self.0);
} else {
impl<E> BoundTcpSocket<E> {
/// Connects to a remote endpoint.
pub fn connect(
&self,
remote_endpoint: IpEndpoint,
) -> Result<(), smoltcp::socket::tcp::ConnectError> {
let common = self.iface().common();
let mut iface = common.interface();
let mut socket = self.0.socket.lock();
let result = socket.connect(iface.context(), remote_endpoint, self.0.port);
self.0
.iface
.common()
.remove_bound_socket_when_closed(&self.0);
.update_next_poll_at_ms(socket.poll_at(iface.context()));
result
}
/// Listens at a specified endpoint.
pub fn listen(
&self,
local_endpoint: IpEndpoint,
) -> Result<(), smoltcp::socket::tcp::ListenError> {
let mut socket = self.0.socket.lock();
socket.listen(local_endpoint)
}
pub fn send<F, R>(&self, f: F) -> Result<R, smoltcp::socket::tcp::SendError>
where
F: FnOnce(&mut [u8]) -> (usize, R),
{
let mut socket = self.0.socket.lock();
let result = socket.send(f);
self.0.update_next_poll_at_ms(PollAt::Now);
result
}
pub fn recv<F, R>(&self, f: F) -> Result<R, smoltcp::socket::tcp::RecvError>
where
F: FnOnce(&mut [u8]) -> (usize, R),
{
let mut socket = self.0.socket.lock();
let result = socket.recv(f);
self.0.update_next_poll_at_ms(PollAt::Now);
result
}
pub fn close(&self) {
let mut socket = self.0.socket.lock();
socket.close();
self.0.update_next_poll_at_ms(PollAt::Now);
}
/// Calls `f` with an immutable reference to the associated [`RawTcpSocket`].
//
// NOTE: If a mutable reference is required, add a method above that correctly updates the next
// polling time.
pub fn raw_with<F, R>(&self, f: F) -> R
where
F: FnOnce(&RawTcpSocket) -> R,
{
let socket = self.0.socket.lock();
f(&socket)
}
}
pub(crate) struct AnyBoundSocketInner<E> {
iface: Arc<dyn Iface<E>>,
handle: smoltcp::iface::SocketHandle,
port: u16,
socket_family: SocketFamily,
observer: RwLock<Weak<dyn SocketEventObserver>>,
impl<E> BoundUdpSocket<E> {
/// Binds to a specified endpoint.
pub fn bind(&self, local_endpoint: IpEndpoint) -> Result<(), smoltcp::socket::udp::BindError> {
let mut socket = self.0.socket.lock();
socket.bind(local_endpoint)
}
pub fn send<F, R>(
&self,
size: usize,
meta: impl Into<UdpMetadata>,
f: F,
) -> Result<R, crate::errors::udp::SendError>
where
F: FnOnce(&mut [u8]) -> R,
{
use smoltcp::socket::udp::SendError as SendErrorInner;
use crate::errors::udp::SendError;
let mut socket = self.0.socket.lock();
if size > socket.packet_send_capacity() {
return Err(SendError::TooLarge);
}
let buffer = match socket.send(size, meta) {
Ok(data) => data,
Err(SendErrorInner::Unaddressable) => return Err(SendError::Unaddressable),
Err(SendErrorInner::BufferFull) => return Err(SendError::BufferFull),
};
let result = f(buffer);
self.0.update_next_poll_at_ms(PollAt::Now);
Ok(result)
}
pub fn recv<F, R>(&self, f: F) -> Result<R, smoltcp::socket::udp::RecvError>
where
F: FnOnce(&[u8], UdpMetadata) -> R,
{
let mut socket = self.0.socket.lock();
let (data, meta) = socket.recv()?;
let result = f(data, meta);
self.0.update_next_poll_at_ms(PollAt::Now);
Ok(result)
}
/// Calls `f` with an immutable reference to the associated [`RawUdpSocket`].
//
// NOTE: If a mutable reference is required, add a method above that correctly updates the next
// polling time.
pub fn raw_with<F, R>(&self, f: F) -> R
where
F: FnOnce(&RawUdpSocket) -> R,
{
let socket = self.0.socket.lock();
f(&socket)
}
}
impl<T, E> BoundSocketInner<T, E> {
pub(crate) fn has_new_events(&self) -> bool {
self.has_new_events.load(Ordering::Relaxed)
}
impl<E> AnyBoundSocketInner<E> {
pub(crate) fn on_iface_events(&self) {
self.has_new_events.store(false, Ordering::Relaxed);
// We never hold the write lock in IRQ handlers, so we don't need to disable IRQs when we
// get the read lock.
let observer = Weak::upgrade(&*self.observer.read());
@ -120,51 +370,173 @@ impl<E> AnyBoundSocketInner<E> {
}
}
pub(crate) fn is_closed(&self) -> bool {
match self.socket_family {
SocketFamily::Tcp => self.raw_with(|socket: &mut RawTcpSocket| {
socket.state() == smoltcp::socket::tcp::State::Closed
}),
SocketFamily::Udp => true,
/// Returns the next polling time.
///
/// Note: a zero means polling should be done now and a `u64::MAX` means no polling is required
/// before new network or user events.
pub(crate) fn next_poll_at_ms(&self) -> u64 {
self.next_poll_at_ms.load(Ordering::Relaxed)
}
/// Updates the next polling time according to `poll_at`.
///
/// The update is typically needed after new network or user events have been handled, so this
/// method also marks that there may be new events, so that the event observer provided by
/// [`BoundSocket::set_observer`] can be notified later.
fn update_next_poll_at_ms(&self, poll_at: PollAt) {
self.has_new_events.store(true, Ordering::Relaxed);
match poll_at {
PollAt::Now => self.next_poll_at_ms.store(0, Ordering::Relaxed),
PollAt::Time(instant) => self
.next_poll_at_ms
.store(instant.total_millis() as u64, Ordering::Relaxed),
PollAt::Ingress => self.next_poll_at_ms.store(u64::MAX, Ordering::Relaxed),
}
}
}
/// Starts closing the socket and returns whether the socket is closed.
impl<T, E> BoundSocketInner<T, E> {
pub(crate) fn port(&self) -> u16 {
self.port
}
}
impl<E> BoundTcpSocketInner<E> {
/// Returns whether the TCP socket is dead.
///
/// For sockets that can be closed immediately, such as UDP sockets and TCP listening sockets,
/// this method will always return `true`.
/// A TCP socket is considered dead if and only if the following two conditions are met:
/// 1. The TCP connection is closed, so this socket cannot process any network events.
/// 2. The socket handle [`BoundTcpSocket`] is dropped, which means that this
/// [`BoundSocketInner`] is in background and no more user events can reach it.
pub(crate) fn is_dead(&self) -> bool {
self.socket.is_dead()
}
}
impl<T, E> BoundSocketInner<T, E> {
/// Returns whether an incoming packet _may_ be processed by the socket.
///
/// For other sockets, such as TCP connected sockets, they cannot be closed immediately because
/// we at least need to send the FIN packet and wait for the remote end to send an ACK packet.
/// In this case, this method will return `false` and [`Self::is_closed`] can be used to
/// determine if the closing process is complete.
fn start_closing(&self) -> bool {
match self.socket_family {
SocketFamily::Tcp => self.raw_with(|socket: &mut RawTcpSocket| {
socket.close();
socket.state() == smoltcp::socket::tcp::State::Closed
}),
SocketFamily::Udp => {
self.raw_with(|socket: &mut RawUdpSocket| socket.close());
/// The check is intended to be lock-free and fast, but may have false positives.
pub(crate) fn can_process(&self, dst_port: u16) -> bool {
self.port == dst_port
}
/// Returns whether the socket _may_ generate an outgoing packet.
///
/// The check is intended to be lock-free and fast, but may have false positives.
pub(crate) fn need_dispatch(&self, now: Instant) -> bool {
now.total_millis() as u64 >= self.next_poll_at_ms.load(Ordering::Relaxed)
}
}
#[derive(Debug, PartialEq, Eq, Clone)]
pub(crate) enum TcpProcessResult {
NotProcessed,
Processed,
ProcessedWithReply(IpRepr, TcpRepr<'static>),
}
impl<E> BoundTcpSocketInner<E> {
/// Tries to process an incoming packet and returns whether the packet is processed.
pub(crate) fn process(
&self,
cx: &mut Context,
ip_repr: &IpRepr,
tcp_repr: &TcpRepr,
) -> TcpProcessResult {
let mut socket = self.socket.lock();
if !socket.accepts(cx, ip_repr, tcp_repr) {
return TcpProcessResult::NotProcessed;
}
let result = match socket.process(cx, ip_repr, tcp_repr) {
None => TcpProcessResult::Processed,
Some((ip_repr, tcp_repr)) => TcpProcessResult::ProcessedWithReply(ip_repr, tcp_repr),
};
self.update_next_poll_at_ms(socket.poll_at(cx));
self.socket.update_dead(&socket);
result
}
/// Tries to generate an outgoing packet and dispatches the generated packet.
pub(crate) fn dispatch<D>(
&self,
cx: &mut Context,
dispatch: D,
) -> Option<(IpRepr, TcpRepr<'static>)>
where
D: FnOnce(&mut Context, &IpRepr, &TcpRepr) -> Option<(IpRepr, TcpRepr<'static>)>,
{
let mut socket = self.socket.lock();
let mut reply = None;
socket
.dispatch(cx, |cx, (ip_repr, tcp_repr)| {
reply = dispatch(cx, &ip_repr, &tcp_repr);
Ok::<(), ()>(())
})
.unwrap();
// `dispatch` can return a packet in response to the generated packet. If the socket
// accepts the packet, we can process it directly.
while let Some((ref ip_repr, ref tcp_repr)) = reply {
if !socket.accepts(cx, ip_repr, tcp_repr) {
break;
}
reply = socket.process(cx, ip_repr, tcp_repr);
}
self.update_next_poll_at_ms(socket.poll_at(cx));
self.socket.update_dead(&socket);
reply
}
}
impl<E> BoundUdpSocketInner<E> {
/// Tries to process an incoming packet and returns whether the packet is processed.
pub(crate) fn process(
&self,
cx: &mut Context,
ip_repr: &IpRepr,
udp_repr: &UdpRepr,
udp_payload: &[u8],
) -> bool {
let mut socket = self.socket.lock();
if !socket.accepts(cx, ip_repr, udp_repr) {
return false;
}
socket.process(
cx,
smoltcp::phy::PacketMeta::default(),
ip_repr,
udp_repr,
udp_payload,
);
self.update_next_poll_at_ms(socket.poll_at(cx));
true
}
}
}
pub fn raw_with<T: smoltcp::socket::AnySocket<'static>, R, F: FnMut(&mut T) -> R>(
&self,
mut f: F,
) -> R {
let mut sockets = self.iface.common().sockets();
let socket = sockets.get_mut::<T>(self.handle);
f(socket)
}
}
/// Tries to generate an outgoing packet and dispatches the generated packet.
pub(crate) fn dispatch<D>(&self, cx: &mut Context, dispatch: D)
where
D: FnOnce(&mut Context, &IpRepr, &UdpRepr, &[u8]),
{
let mut socket = self.socket.lock();
impl<E> Drop for AnyBoundSocketInner<E> {
fn drop(&mut self) {
let iface_common = self.iface.common();
iface_common.remove_socket(self.handle);
iface_common.release_port(self.port);
socket
.dispatch(cx, |cx, _meta, (ip_repr, udp_repr, udp_payload)| {
dispatch(cx, &ip_repr, &udp_repr, udp_payload);
Ok::<(), ()>(())
})
.unwrap();
self.update_next_poll_at_ms(socket.poll_at(cx));
}
}

View File

@ -4,12 +4,11 @@ mod bound;
mod event;
mod unbound;
pub use bound::AnyBoundSocket;
pub(crate) use bound::{AnyBoundSocketInner, SocketFamily};
pub use bound::{BoundTcpSocket, BoundUdpSocket};
pub(crate) use bound::{BoundTcpSocketInner, BoundUdpSocketInner, TcpProcessResult};
pub use event::SocketEventObserver;
pub(crate) use unbound::AnyRawSocket;
pub use unbound::{
AnyUnboundSocket, TCP_RECV_BUF_LEN, TCP_SEND_BUF_LEN, UDP_RECV_PAYLOAD_LEN,
UnboundTcpSocket, UnboundUdpSocket, TCP_RECV_BUF_LEN, TCP_SEND_BUF_LEN, UDP_RECV_PAYLOAD_LEN,
UDP_SEND_PAYLOAD_LEN,
};

View File

@ -1,34 +1,33 @@
// SPDX-License-Identifier: MPL-2.0
use alloc::{sync::Weak, vec};
use alloc::{boxed::Box, sync::Weak, vec};
use super::{event::SocketEventObserver, RawTcpSocket, RawUdpSocket};
pub struct AnyUnboundSocket {
socket_family: AnyRawSocket,
pub struct UnboundSocket<T> {
socket: Box<T>,
observer: Weak<dyn SocketEventObserver>,
}
#[allow(clippy::large_enum_variant)]
pub(crate) enum AnyRawSocket {
Tcp(RawTcpSocket),
Udp(RawUdpSocket),
}
pub type UnboundTcpSocket = UnboundSocket<RawTcpSocket>;
pub type UnboundUdpSocket = UnboundSocket<RawUdpSocket>;
impl AnyUnboundSocket {
pub fn new_tcp(observer: Weak<dyn SocketEventObserver>) -> Self {
impl UnboundTcpSocket {
pub fn new(observer: Weak<dyn SocketEventObserver>) -> Self {
let raw_tcp_socket = {
let rx_buffer = smoltcp::socket::tcp::SocketBuffer::new(vec![0u8; TCP_RECV_BUF_LEN]);
let tx_buffer = smoltcp::socket::tcp::SocketBuffer::new(vec![0u8; TCP_SEND_BUF_LEN]);
RawTcpSocket::new(rx_buffer, tx_buffer)
};
AnyUnboundSocket {
socket_family: AnyRawSocket::Tcp(raw_tcp_socket),
Self {
socket: Box::new(raw_tcp_socket),
observer,
}
}
}
pub fn new_udp(observer: Weak<dyn SocketEventObserver>) -> Self {
impl UnboundUdpSocket {
pub fn new(observer: Weak<dyn SocketEventObserver>) -> Self {
let raw_udp_socket = {
let metadata = smoltcp::socket::udp::PacketMetadata::EMPTY;
let rx_buffer = smoltcp::socket::udp::PacketBuffer::new(
@ -41,14 +40,16 @@ impl AnyUnboundSocket {
);
RawUdpSocket::new(rx_buffer, tx_buffer)
};
AnyUnboundSocket {
socket_family: AnyRawSocket::Udp(raw_udp_socket),
Self {
socket: Box::new(raw_udp_socket),
observer,
}
}
}
pub(crate) fn into_raw(self) -> (AnyRawSocket, Weak<dyn SocketEventObserver>) {
(self.socket_family, self.observer)
impl<T> UnboundSocket<T> {
pub(crate) fn into_raw(self) -> (Box<T>, Weak<dyn SocketEventObserver>) {
(self.socket, self.observer)
}
}

View File

@ -8,4 +8,5 @@ pub use init::{init, IFACES};
pub use poll::{lazy_init, poll_ifaces};
pub type Iface = dyn aster_bigtcp::iface::Iface<ext::IfaceExt>;
pub type AnyBoundSocket = aster_bigtcp::socket::AnyBoundSocket<ext::IfaceExt>;
pub type BoundTcpSocket = aster_bigtcp::socket::BoundTcpSocket<ext::IfaceExt>;
pub type BoundUdpSocket = aster_bigtcp::socket::BoundUdpSocket<ext::IfaceExt>;

View File

@ -3,12 +3,11 @@
use aster_bigtcp::{
errors::BindError,
iface::BindPortConfig,
socket::AnyUnboundSocket,
wire::{IpAddress, IpEndpoint},
};
use crate::{
net::iface::{AnyBoundSocket, Iface, IFACES},
net::iface::{Iface, IFACES},
prelude::*,
};
@ -46,11 +45,16 @@ fn get_ephemeral_iface(remote_ip_addr: &IpAddress) -> Arc<Iface> {
ifaces[0].clone()
}
pub(super) fn bind_socket(
unbound_socket: Box<AnyUnboundSocket>,
pub(super) fn bind_socket<S, T>(
unbound_socket: Box<S>,
endpoint: &IpEndpoint,
can_reuse: bool,
) -> core::result::Result<AnyBoundSocket, (Error, Box<AnyUnboundSocket>)> {
bind: impl FnOnce(
Arc<Iface>,
Box<S>,
BindPortConfig,
) -> core::result::Result<T, (BindError, Box<S>)>,
) -> core::result::Result<T, (Error, Box<S>)> {
let iface = match get_iface_to_bind(&endpoint.addr) {
Some(iface) => iface,
None => {
@ -64,9 +68,7 @@ pub(super) fn bind_socket(
let bind_port_config = BindPortConfig::new(endpoint.port, can_reuse);
iface
.bind_socket(unbound_socket, bind_port_config)
.map_err(|(err, unbound)| (err.into(), unbound))
bind(iface, unbound_socket, bind_port_config).map_err(|(err, unbound)| (err.into(), unbound))
}
impl From<BindError> for Error {

View File

@ -2,25 +2,24 @@
use aster_bigtcp::{
errors::udp::{RecvError, SendError},
socket::RawUdpSocket,
wire::IpEndpoint,
};
use crate::{
events::IoEvents,
net::{iface::AnyBoundSocket, socket::util::send_recv_flags::SendRecvFlags},
net::{iface::BoundUdpSocket, socket::util::send_recv_flags::SendRecvFlags},
prelude::*,
process::signal::Pollee,
util::{MultiRead, MultiWrite},
};
pub struct BoundDatagram {
bound_socket: AnyBoundSocket,
bound_socket: BoundUdpSocket,
remote_endpoint: Option<IpEndpoint>,
}
impl BoundDatagram {
pub fn new(bound_socket: AnyBoundSocket) -> Self {
pub fn new(bound_socket: BoundUdpSocket) -> Self {
Self {
bound_socket,
remote_endpoint: None,
@ -44,12 +43,10 @@ impl BoundDatagram {
writer: &mut dyn MultiWrite,
_flags: SendRecvFlags,
) -> Result<(usize, IpEndpoint)> {
let result = self.bound_socket.raw_with(|socket: &mut RawUdpSocket| {
socket.recv().map(|(packet, udp_metadata)| {
let result = self.bound_socket.recv(|packet, udp_metadata| {
let copied_res = writer.write(&mut VmReader::from(packet));
let endpoint = udp_metadata.endpoint;
(copied_res, endpoint)
})
});
match result {
@ -59,7 +56,7 @@ impl BoundDatagram {
return_errno_with_message!(Errno::EAGAIN, "the receive buffer is empty")
}
Err(RecvError::Truncated) => {
unreachable!("`Socket::recv` should never fail with `RecvError::Truncated`")
unreachable!("`recv` should never fail with `RecvError::Truncated`")
}
}
}
@ -70,23 +67,9 @@ impl BoundDatagram {
remote: &IpEndpoint,
_flags: SendRecvFlags,
) -> Result<usize> {
let reader_len = reader.sum_lens();
self.bound_socket.raw_with(|socket: &mut RawUdpSocket| {
if socket.payload_send_capacity() < reader_len {
return_errno_with_message!(Errno::EMSGSIZE, "the message is too large");
}
let socket_buffer = match socket.send(reader_len, *remote) {
Ok(socket_buffer) => socket_buffer,
Err(SendError::BufferFull) => {
return_errno_with_message!(Errno::EAGAIN, "the send buffer is full")
}
Err(SendError::Unaddressable) => {
return_errno_with_message!(Errno::EINVAL, "the destination address is invalid")
}
};
let result = self
.bound_socket
.send(reader.sum_lens(), *remote, |socket_buffer| {
// FIXME: If copy failed, we should not send any packet.
// But current smoltcp API seems not to support this behavior.
reader
@ -95,7 +78,20 @@ impl BoundDatagram {
warn!("unexpected UDP packet will be sent");
e
})
})
});
match result {
Ok(inner) => inner,
Err(SendError::TooLarge) => {
return_errno_with_message!(Errno::EMSGSIZE, "the message is too large");
}
Err(SendError::Unaddressable) => {
return_errno_with_message!(Errno::EINVAL, "the destination address is invalid");
}
Err(SendError::BufferFull) => {
return_errno_with_message!(Errno::EAGAIN, "the send buffer is full");
}
}
}
pub(super) fn init_pollee(&self, pollee: &Pollee) {
@ -104,7 +100,7 @@ impl BoundDatagram {
}
pub(super) fn update_io_events(&self, pollee: &Pollee) {
self.bound_socket.raw_with(|socket: &mut RawUdpSocket| {
self.bound_socket.raw_with(|socket| {
if socket.can_recv() {
pollee.add_events(IoEvents::IN);
} else {

View File

@ -154,13 +154,9 @@ impl DatagramSocket {
return_errno_with_message!(Errno::EAGAIN, "the socket is not bound");
};
let received =
bound_datagram
let received = bound_datagram
.try_recv(writer, flags)
.map(|(recv_bytes, remote_endpoint)| {
bound_datagram.update_io_events(&self.pollee);
(recv_bytes, remote_endpoint.into())
});
.map(|(recv_bytes, remote_endpoint)| (recv_bytes, remote_endpoint.into()));
drop(inner);
poll_ifaces();
@ -192,12 +188,7 @@ impl DatagramSocket {
return_errno_with_message!(Errno::EAGAIN, "the socket is not bound")
};
let sent_bytes = bound_datagram
.try_send(reader, remote, flags)
.map(|sent_bytes| {
bound_datagram.update_io_events(&self.pollee);
sent_bytes
});
let sent_bytes = bound_datagram.try_send(reader, remote, flags);
drop(inner);
poll_ifaces();

View File

@ -3,7 +3,7 @@
use alloc::sync::Weak;
use aster_bigtcp::{
socket::{AnyUnboundSocket, RawUdpSocket, SocketEventObserver},
socket::{SocketEventObserver, UnboundUdpSocket},
wire::IpEndpoint,
};
@ -13,13 +13,13 @@ use crate::{
};
pub struct UnboundDatagram {
unbound_socket: Box<AnyUnboundSocket>,
unbound_socket: Box<UnboundUdpSocket>,
}
impl UnboundDatagram {
pub fn new(observer: Weak<dyn SocketEventObserver>) -> Self {
Self {
unbound_socket: Box::new(AnyUnboundSocket::new_udp(observer)),
unbound_socket: Box::new(UnboundUdpSocket::new(observer)),
}
}
@ -28,15 +28,18 @@ impl UnboundDatagram {
endpoint: &IpEndpoint,
can_reuse: bool,
) -> core::result::Result<BoundDatagram, (Error, Self)> {
let bound_socket = match bind_socket(self.unbound_socket, endpoint, can_reuse) {
let bound_socket = match bind_socket(
self.unbound_socket,
endpoint,
can_reuse,
|iface, socket, config| iface.bind_udp(socket, config),
) {
Ok(bound_socket) => bound_socket,
Err((err, unbound_socket)) => return Err((err, Self { unbound_socket })),
};
let bound_endpoint = bound_socket.local_endpoint().unwrap();
bound_socket.raw_with(|socket: &mut RawUdpSocket| {
socket.bind(bound_endpoint).unwrap();
});
bound_socket.bind(bound_endpoint).unwrap();
Ok(BoundDatagram::new(bound_socket))
}

View File

@ -4,14 +4,14 @@ use alloc::sync::Weak;
use aster_bigtcp::{
errors::tcp::{RecvError, SendError},
socket::{RawTcpSocket, SocketEventObserver},
socket::SocketEventObserver,
wire::IpEndpoint,
};
use crate::{
events::IoEvents,
net::{
iface::AnyBoundSocket,
iface::BoundTcpSocket,
socket::util::{send_recv_flags::SendRecvFlags, shutdown_cmd::SockShutdownCmd},
},
prelude::*,
@ -20,7 +20,7 @@ use crate::{
};
pub struct ConnectedStream {
bound_socket: AnyBoundSocket,
bound_socket: BoundTcpSocket,
remote_endpoint: IpEndpoint,
/// Indicates whether this connection is "new" in a `connect()` system call.
///
@ -37,7 +37,7 @@ pub struct ConnectedStream {
impl ConnectedStream {
pub fn new(
bound_socket: AnyBoundSocket,
bound_socket: BoundTcpSocket,
remote_endpoint: IpEndpoint,
is_new_connection: bool,
) -> Self {
@ -50,20 +50,16 @@ impl ConnectedStream {
pub fn shutdown(&self, _cmd: SockShutdownCmd) -> Result<()> {
// TODO: deal with cmd
self.bound_socket.raw_with(|socket: &mut RawTcpSocket| {
socket.close();
});
self.bound_socket.close();
Ok(())
}
pub fn try_recv(&self, writer: &mut dyn MultiWrite, _flags: SendRecvFlags) -> Result<usize> {
let result = self.bound_socket.raw_with(|socket: &mut RawTcpSocket| {
socket.recv(
|socket_buffer| match writer.write(&mut VmReader::from(&*socket_buffer)) {
let result = self.bound_socket.recv(|socket_buffer| {
match writer.write(&mut VmReader::from(&*socket_buffer)) {
Ok(len) => (len, Ok(len)),
Err(e) => (0, Err(e)),
},
)
}
});
match result {
@ -78,13 +74,11 @@ impl ConnectedStream {
}
pub fn try_send(&self, reader: &mut dyn MultiRead, _flags: SendRecvFlags) -> Result<usize> {
let result = self.bound_socket.raw_with(|socket: &mut RawTcpSocket| {
socket.send(
|socket_buffer| match reader.read(&mut VmWriter::from(socket_buffer)) {
let result = self.bound_socket.send(|socket_buffer| {
match reader.read(&mut VmWriter::from(socket_buffer)) {
Ok(len) => (len, Ok(len)),
Err(e) => (0, Err(e)),
},
)
}
});
match result {
@ -123,7 +117,7 @@ impl ConnectedStream {
}
pub(super) fn update_io_events(&self, pollee: &Pollee) {
self.bound_socket.raw_with(|socket: &mut RawTcpSocket| {
self.bound_socket.raw_with(|socket| {
if socket.can_recv() {
pollee.add_events(IoEvents::IN);
} else {

View File

@ -1,12 +1,12 @@
// SPDX-License-Identifier: MPL-2.0
use aster_bigtcp::{socket::RawTcpSocket, wire::IpEndpoint};
use aster_bigtcp::wire::IpEndpoint;
use super::{connected::ConnectedStream, init::InitStream};
use crate::{net::iface::AnyBoundSocket, prelude::*, process::signal::Pollee};
use crate::{net::iface::BoundTcpSocket, prelude::*, process::signal::Pollee};
pub struct ConnectingStream {
bound_socket: AnyBoundSocket,
bound_socket: BoundTcpSocket,
remote_endpoint: IpEndpoint,
conn_result: RwLock<Option<ConnResult>>,
}
@ -24,15 +24,15 @@ pub enum NonConnectedStream {
impl ConnectingStream {
pub fn new(
bound_socket: AnyBoundSocket,
bound_socket: BoundTcpSocket,
remote_endpoint: IpEndpoint,
) -> core::result::Result<Self, (Error, AnyBoundSocket)> {
) -> core::result::Result<Self, (Error, BoundTcpSocket)> {
// The only reason this method might fail is because we're trying to connect to an
// unspecified address (i.e. 0.0.0.0). We currently have no support for binding to,
// listening on, or connecting to the unspecified address.
//
// We assume the remote will just refuse to connect, so we return `ECONNREFUSED`.
if bound_socket.do_connect(remote_endpoint).is_err() {
if bound_socket.connect(remote_endpoint).is_err() {
return Err((
Error::with_message(
Errno::ECONNREFUSED,
@ -91,7 +91,7 @@ impl ConnectingStream {
return false;
}
self.bound_socket.raw_with(|socket: &mut RawTcpSocket| {
self.bound_socket.raw_with(|socket| {
let mut result = self.conn_result.write();
if result.is_some() {
return false;

View File

@ -3,7 +3,7 @@
use alloc::sync::Weak;
use aster_bigtcp::{
socket::{AnyUnboundSocket, SocketEventObserver},
socket::{SocketEventObserver, UnboundTcpSocket},
wire::IpEndpoint,
};
@ -11,7 +11,7 @@ use super::{connecting::ConnectingStream, listen::ListenStream};
use crate::{
events::IoEvents,
net::{
iface::AnyBoundSocket,
iface::BoundTcpSocket,
socket::ip::common::{bind_socket, get_ephemeral_endpoint},
},
prelude::*,
@ -19,16 +19,16 @@ use crate::{
};
pub enum InitStream {
Unbound(Box<AnyUnboundSocket>),
Bound(AnyBoundSocket),
Unbound(Box<UnboundTcpSocket>),
Bound(BoundTcpSocket),
}
impl InitStream {
pub fn new(observer: Weak<dyn SocketEventObserver>) -> Self {
InitStream::Unbound(Box::new(AnyUnboundSocket::new_tcp(observer)))
InitStream::Unbound(Box::new(UnboundTcpSocket::new(observer)))
}
pub fn new_bound(bound_socket: AnyBoundSocket) -> Self {
pub fn new_bound(bound_socket: BoundTcpSocket) -> Self {
InitStream::Bound(bound_socket)
}
@ -36,7 +36,7 @@ impl InitStream {
self,
endpoint: &IpEndpoint,
can_reuse: bool,
) -> core::result::Result<AnyBoundSocket, (Error, Self)> {
) -> core::result::Result<BoundTcpSocket, (Error, Self)> {
let unbound_socket = match self {
InitStream::Unbound(unbound_socket) => unbound_socket,
InitStream::Bound(bound_socket) => {
@ -46,7 +46,12 @@ impl InitStream {
));
}
};
let bound_socket = match bind_socket(unbound_socket, endpoint, can_reuse) {
let bound_socket = match bind_socket(
unbound_socket,
endpoint,
can_reuse,
|iface, socket, config| iface.bind_tcp(socket, config),
) {
Ok(bound_socket) => bound_socket,
Err((err, unbound_socket)) => return Err((err, InitStream::Unbound(unbound_socket))),
};
@ -56,7 +61,7 @@ impl InitStream {
fn bind_to_ephemeral_endpoint(
self,
remote_endpoint: &IpEndpoint,
) -> core::result::Result<AnyBoundSocket, (Error, Self)> {
) -> core::result::Result<BoundTcpSocket, (Error, Self)> {
let endpoint = get_ephemeral_endpoint(remote_endpoint);
self.bind(&endpoint, false)
}

View File

@ -1,28 +1,25 @@
// SPDX-License-Identifier: MPL-2.0
use aster_bigtcp::{
errors::tcp::ListenError,
iface::BindPortConfig,
socket::{AnyUnboundSocket, RawTcpSocket},
wire::IpEndpoint,
errors::tcp::ListenError, iface::BindPortConfig, socket::UnboundTcpSocket, wire::IpEndpoint,
};
use super::connected::ConnectedStream;
use crate::{events::IoEvents, net::iface::AnyBoundSocket, prelude::*, process::signal::Pollee};
use crate::{events::IoEvents, net::iface::BoundTcpSocket, prelude::*, process::signal::Pollee};
pub struct ListenStream {
backlog: usize,
/// A bound socket held to ensure the TCP port cannot be released
bound_socket: AnyBoundSocket,
bound_socket: BoundTcpSocket,
/// Backlog sockets listening at the local endpoint
backlog_sockets: RwLock<Vec<BacklogSocket>>,
}
impl ListenStream {
pub fn new(
bound_socket: AnyBoundSocket,
bound_socket: BoundTcpSocket,
backlog: usize,
) -> core::result::Result<Self, (Error, AnyBoundSocket)> {
) -> core::result::Result<Self, (Error, BoundTcpSocket)> {
const SOMAXCONN: usize = 4096;
let somaxconn = SOMAXCONN.min(backlog);
@ -102,30 +99,28 @@ impl ListenStream {
}
struct BacklogSocket {
bound_socket: AnyBoundSocket,
bound_socket: BoundTcpSocket,
}
impl BacklogSocket {
// FIXME: All of the error codes below seem to have no Linux equivalents, and I see no reason
// why the error may occur. Perhaps it is better to call `unwrap()` directly?
fn new(bound_socket: &AnyBoundSocket) -> Result<Self> {
fn new(bound_socket: &BoundTcpSocket) -> Result<Self> {
let local_endpoint = bound_socket.local_endpoint().ok_or(Error::with_message(
Errno::EINVAL,
"the socket is not bound",
))?;
let unbound_socket = Box::new(AnyUnboundSocket::new_tcp(Weak::<()>::new()));
let unbound_socket = Box::new(UnboundTcpSocket::new(bound_socket.observer()));
let bound_socket = {
let iface = bound_socket.iface();
let bind_port_config = BindPortConfig::new(local_endpoint.port, true);
iface
.bind_socket(unbound_socket, bind_port_config)
.bind_tcp(unbound_socket, bind_port_config)
.map_err(|(err, _)| err)?
};
let result = bound_socket
.raw_with(|raw_tcp_socket: &mut RawTcpSocket| raw_tcp_socket.listen(local_endpoint));
match result {
match bound_socket.listen(local_endpoint) {
Ok(()) => Ok(Self { bound_socket }),
Err(ListenError::Unaddressable) => {
return_errno_with_message!(Errno::EINVAL, "the listening address is invalid")
@ -137,16 +132,15 @@ impl BacklogSocket {
}
fn is_active(&self) -> bool {
self.bound_socket
.raw_with(|socket: &mut RawTcpSocket| socket.is_active())
self.bound_socket.raw_with(|socket| socket.is_active())
}
fn remote_endpoint(&self) -> Option<IpEndpoint> {
self.bound_socket
.raw_with(|socket: &mut RawTcpSocket| socket.remote_endpoint())
.raw_with(|socket| socket.remote_endpoint())
}
fn into_bound_socket(self) -> AnyBoundSocket {
fn into_bound_socket(self) -> BoundTcpSocket {
self.bound_socket
}
}

View File

@ -230,18 +230,13 @@ impl StreamSocket {
return_errno_with_message!(Errno::EINVAL, "the socket is not listening");
};
let accepted = listen_stream.try_accept().map(|connected_stream| {
listen_stream.try_accept().map(|connected_stream| {
listen_stream.update_io_events(&self.pollee);
let remote_endpoint = connected_stream.remote_endpoint();
let accepted_socket = Self::new_connected(connected_stream);
(accepted_socket as _, remote_endpoint.into())
});
drop(state);
poll_ifaces();
accepted
})
}
fn try_recv(
@ -262,8 +257,6 @@ impl StreamSocket {
};
let received = connected_stream.try_recv(writer, flags).map(|recv_bytes| {
connected_stream.update_io_events(&self.pollee);
let remote_endpoint = connected_stream.remote_endpoint();
(recv_bytes, remote_endpoint.into())
});
@ -302,10 +295,7 @@ impl StreamSocket {
}
};
let sent_bytes = connected_stream.try_send(reader, flags).map(|sent_bytes| {
connected_stream.update_io_events(&self.pollee);
sent_bytes
});
let sent_bytes = connected_stream.try_send(reader, flags);
drop(state);
poll_ifaces();
@ -658,3 +648,11 @@ impl SocketEventObserver for StreamSocket {
}
}
}
impl Drop for StreamSocket {
fn drop(&mut self) {
self.state.write().take();
poll_ifaces();
}
}