Create backlog sockets on demand

Author: Ruihan Li
Date: 2024-12-02 23:11:43 +08:00
Committed by: Tate, Hongliang Tian
Parent: a739848464
Commit: 776fd6a892
24 changed files with 947 additions and 781 deletions

Cargo.lock (generated): 2 lines changed

@ -74,6 +74,8 @@ dependencies = [
 "keyable-arc",
 "ostd",
 "smoltcp",
+"spin 0.9.8",
+"takeable",
 ]
 [[package]]


@ -18,3 +18,5 @@ smoltcp = { git = "https://github.com/asterinas/smoltcp", tag = "r_2024-11-08_f0
     "socket-udp",
     "socket-tcp",
 ] }
+spin = "0.9.4"
+takeable = "0.2.2"


@ -15,7 +15,7 @@ pub trait Ext {
     type ScheduleNextPoll: ScheduleNextPoll;
     /// The type for TCP sockets to observe events.
-    type TcpEventObserver: SocketEventObserver;
+    type TcpEventObserver: SocketEventObserver + Clone;
     /// The type for UDP sockets to observe events.
     type UdpEventObserver: SocketEventObserver;


@ -1,13 +1,13 @@
// SPDX-License-Identifier: MPL-2.0 // SPDX-License-Identifier: MPL-2.0
use alloc::{ use alloc::{
boxed::Box,
collections::{ collections::{
btree_map::{BTreeMap, Entry}, btree_map::{BTreeMap, Entry},
btree_set::BTreeSet, btree_set::BTreeSet,
}, },
string::String, string::String,
sync::Arc, sync::Arc,
vec::Vec,
}; };
use keyable_arc::KeyableArc; use keyable_arc::KeyableArc;
@ -15,7 +15,7 @@ use ostd::sync::{LocalIrqDisabled, SpinLock, SpinLockGuard};
use smoltcp::{ use smoltcp::{
iface::{packet::Packet, Context}, iface::{packet::Packet, Context},
phy::Device, phy::Device,
wire::{Ipv4Address, Ipv4Packet}, wire::{IpAddress, IpEndpoint, Ipv4Address, Ipv4Packet},
}; };
use super::{ use super::{
@ -27,33 +27,40 @@ use super::{
use crate::{ use crate::{
errors::BindError, errors::BindError,
ext::Ext, ext::Ext,
socket::{ socket::{TcpConnectionBg, TcpListenerBg, UdpSocketBg},
BoundTcpSocket, BoundTcpSocketInner, BoundUdpSocket, BoundUdpSocketInner, UnboundTcpSocket,
UnboundUdpSocket,
},
}; };
pub struct IfaceCommon<E: Ext> { pub struct IfaceCommon<E: Ext> {
name: String, name: String,
interface: SpinLock<smoltcp::iface::Interface, LocalIrqDisabled>, interface: SpinLock<smoltcp::iface::Interface, LocalIrqDisabled>,
used_ports: SpinLock<BTreeMap<u16, usize>, LocalIrqDisabled>, used_ports: SpinLock<BTreeMap<u16, usize>, LocalIrqDisabled>,
tcp_sockets: SpinLock<BTreeSet<KeyableArc<BoundTcpSocketInner<E>>>, LocalIrqDisabled>, sockets: SpinLock<SocketSet<E>, LocalIrqDisabled>,
udp_sockets: SpinLock<BTreeSet<KeyableArc<BoundUdpSocketInner<E>>>, LocalIrqDisabled>,
sched_poll: E::ScheduleNextPoll, sched_poll: E::ScheduleNextPoll,
} }
pub(super) struct SocketSet<E: Ext> {
pub(super) tcp_conn: BTreeSet<KeyableArc<TcpConnectionBg<E>>>,
pub(super) tcp_listen: BTreeSet<KeyableArc<TcpListenerBg<E>>>,
pub(super) udp: BTreeSet<KeyableArc<UdpSocketBg<E>>>,
}
impl<E: Ext> IfaceCommon<E> { impl<E: Ext> IfaceCommon<E> {
pub(super) fn new( pub(super) fn new(
name: String, name: String,
interface: smoltcp::iface::Interface, interface: smoltcp::iface::Interface,
sched_poll: E::ScheduleNextPoll, sched_poll: E::ScheduleNextPoll,
) -> Self { ) -> Self {
let sockets = SocketSet {
tcp_conn: BTreeSet::new(),
tcp_listen: BTreeSet::new(),
udp: BTreeSet::new(),
};
Self { Self {
name, name,
interface: SpinLock::new(interface), interface: SpinLock::new(interface),
used_ports: SpinLock::new(BTreeMap::new()), used_ports: SpinLock::new(BTreeMap::new()),
tcp_sockets: SpinLock::new(BTreeSet::new()), sockets: SpinLock::new(sockets),
udp_sockets: SpinLock::new(BTreeSet::new()),
sched_poll, sched_poll,
} }
} }
@ -82,52 +89,13 @@ const IP_LOCAL_PORT_START: u16 = 32768;
const IP_LOCAL_PORT_END: u16 = 60999; const IP_LOCAL_PORT_END: u16 = 60999;
impl<E: Ext> IfaceCommon<E> { impl<E: Ext> IfaceCommon<E> {
pub(super) fn bind_tcp( pub(super) fn bind(
&self, &self,
iface: Arc<dyn Iface<E>>, iface: Arc<dyn Iface<E>>,
socket: Box<UnboundTcpSocket>,
observer: E::TcpEventObserver,
config: BindPortConfig, config: BindPortConfig,
) -> core::result::Result<BoundTcpSocket<E>, (BindError, Box<UnboundTcpSocket>)> { ) -> core::result::Result<BoundPort<E>, BindError> {
let port = match self.bind_port(config) { let port = self.bind_port(config)?;
Ok(port) => port, Ok(BoundPort { iface, port })
Err(err) => return Err((err, socket)),
};
let raw_socket = socket.into_raw();
let bound_socket = BoundTcpSocket::new(iface, port, raw_socket, observer);
let inserted = self
.tcp_sockets
.lock()
.insert(KeyableArc::from(bound_socket.inner().clone()));
assert!(inserted);
Ok(bound_socket)
}
pub(super) fn bind_udp(
&self,
iface: Arc<dyn Iface<E>>,
socket: Box<UnboundUdpSocket>,
observer: E::UdpEventObserver,
config: BindPortConfig,
) -> core::result::Result<BoundUdpSocket<E>, (BindError, Box<UnboundUdpSocket>)> {
let port = match self.bind_port(config) {
Ok(port) => port,
Err(err) => return Err((err, socket)),
};
let raw_socket = socket.into_raw();
let bound_socket = BoundUdpSocket::new(iface, port, raw_socket, observer);
let inserted = self
.udp_sockets
.lock()
.insert(KeyableArc::from(bound_socket.inner().clone()));
assert!(inserted);
Ok(bound_socket)
} }
/// Allocates an unused ephemeral port. /// Allocates an unused ephemeral port.
@ -171,29 +139,6 @@ impl<E: Ext> IfaceCommon<E> {
Ok(port) Ok(port)
} }
}
impl<E: Ext> IfaceCommon<E> {
#[allow(clippy::mutable_key_type)]
fn remove_dead_tcp_sockets(&self, sockets: &mut BTreeSet<KeyableArc<BoundTcpSocketInner<E>>>) {
sockets.retain(|socket| {
if socket.is_dead() {
self.release_port(socket.port());
false
} else {
true
}
});
}
pub(crate) fn remove_udp_socket(&self, socket: &Arc<BoundUdpSocketInner<E>>) {
let keyable_socket = KeyableArc::from(socket.clone());
let removed = self.udp_sockets.lock().remove(&keyable_socket);
assert!(removed);
self.release_port(keyable_socket.port());
}
/// Releases the port so that it can be used again (if it is not being reused). /// Releases the port so that it can be used again (if it is not being reused).
fn release_port(&self, port: u16) { fn release_port(&self, port: u16) {
@ -206,11 +151,50 @@ impl<E: Ext> IfaceCommon<E> {
} }
} }
impl<E: Ext> IfaceCommon<E> {
pub(crate) fn register_tcp_connection(&self, socket: KeyableArc<TcpConnectionBg<E>>) {
let mut sockets = self.sockets.lock();
let inserted = sockets.tcp_conn.insert(socket);
debug_assert!(inserted);
}
pub(crate) fn register_tcp_listener(&self, socket: KeyableArc<TcpListenerBg<E>>) {
let mut sockets = self.sockets.lock();
let inserted = sockets.tcp_listen.insert(socket);
debug_assert!(inserted);
}
pub(crate) fn register_udp_socket(&self, socket: KeyableArc<UdpSocketBg<E>>) {
let mut sockets = self.sockets.lock();
let inserted = sockets.udp.insert(socket);
debug_assert!(inserted);
}
#[allow(clippy::mutable_key_type)]
fn remove_dead_tcp_connections(sockets: &mut BTreeSet<KeyableArc<TcpConnectionBg<E>>>) {
for socket in sockets.extract_if(|socket| socket.is_dead()) {
TcpConnectionBg::on_dead_events(socket);
}
}
pub(crate) fn remove_tcp_listener(&self, socket: &KeyableArc<TcpListenerBg<E>>) {
let mut sockets = self.sockets.lock();
let removed = sockets.tcp_listen.remove(socket);
debug_assert!(removed);
}
pub(crate) fn remove_udp_socket(&self, socket: &KeyableArc<UdpSocketBg<E>>) {
let mut sockets = self.sockets.lock();
let removed = sockets.udp.remove(socket);
debug_assert!(removed);
}
}
impl<E: Ext> IfaceCommon<E> { impl<E: Ext> IfaceCommon<E> {
pub(super) fn poll<D, P, Q>( pub(super) fn poll<D, P, Q>(
&self, &self,
device: &mut D, device: &mut D,
process_phy: P, mut process_phy: P,
mut dispatch_phy: Q, mut dispatch_phy: Q,
) -> Option<u64> ) -> Option<u64>
where where
@ -226,41 +210,85 @@ impl<E: Ext> IfaceCommon<E> {
let mut interface = self.interface(); let mut interface = self.interface();
interface.context().now = get_network_timestamp(); interface.context().now = get_network_timestamp();
let mut tcp_sockets = self.tcp_sockets.lock(); let mut sockets = self.sockets.lock();
let udp_sockets = self.udp_sockets.lock();
let mut context = PollContext::new(interface.context(), &tcp_sockets, &udp_sockets); loop {
context.poll_ingress(device, process_phy, &mut dispatch_phy); let mut new_tcp_conns = Vec::new();
context.poll_egress(device, dispatch_phy);
tcp_sockets.iter().for_each(|socket| { let mut context = PollContext::new(interface.context(), &sockets, &mut new_tcp_conns);
if socket.has_events() { context.poll_ingress(device, &mut process_phy, &mut dispatch_phy);
socket.on_events(); context.poll_egress(device, &mut dispatch_phy);
// New packets sent by new connections are not handled. So if there are new
// connections, try again.
if new_tcp_conns.is_empty() {
break;
} else {
sockets.tcp_conn.extend(new_tcp_conns);
} }
});
udp_sockets.iter().for_each(|socket| {
if socket.has_events() {
socket.on_events();
}
});
self.remove_dead_tcp_sockets(&mut tcp_sockets);
match (
tcp_sockets
.iter()
.map(|socket| socket.next_poll_at_ms())
.min(),
udp_sockets
.iter()
.map(|socket| socket.next_poll_at_ms())
.min(),
) {
(Some(tcp_poll_at), Some(udp_poll_at)) if tcp_poll_at <= udp_poll_at => {
Some(tcp_poll_at)
}
(tcp_poll_at, None) => tcp_poll_at,
(_, udp_poll_at) => udp_poll_at,
} }
Self::remove_dead_tcp_connections(&mut sockets.tcp_conn);
sockets.tcp_conn.iter().for_each(|socket| {
if socket.has_events() {
socket.on_events();
}
});
sockets.tcp_listen.iter().for_each(|socket| {
if socket.has_events() {
socket.on_events();
}
});
sockets.udp.iter().for_each(|socket| {
if socket.has_events() {
socket.on_events();
}
});
// Note that only TCP connections can have timers set, so as far as the time to poll is
// concerned, we only need to consider TCP connections.
sockets
.tcp_conn
.iter()
.map(|socket| socket.next_poll_at_ms())
.min()
}
}
/// A port bound to an iface.
///
/// When dropped, the port is automatically released.
//
// FIXME: TCP and UDP ports are independent. Find a way to track the protocol here.
pub struct BoundPort<E: Ext> {
iface: Arc<dyn Iface<E>>,
port: u16,
}
impl<E: Ext> BoundPort<E> {
/// Returns a reference to the iface.
pub fn iface(&self) -> &Arc<dyn Iface<E>> {
&self.iface
}
/// Returns the port number.
pub fn port(&self) -> u16 {
self.port
}
/// Returns the bound endpoint.
pub fn endpoint(&self) -> Option<IpEndpoint> {
let ip_addr = {
let ipv4_addr = self.iface().ipv4_addr()?;
IpAddress::Ipv4(ipv4_addr)
};
Some(IpEndpoint::new(ip_addr, self.port))
}
}
impl<E: Ext> Drop for BoundPort<E> {
fn drop(&mut self) {
self.iface.common().release_port(self.port);
} }
} }
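// --- Usage sketch (illustration only, not part of this commit) ---
// With this change, binding only reserves a port: `Iface::bind` returns a
// `BoundPort`, and the port is released automatically when the `BoundPort` is
// dropped. The `iface` argument and the exact `BindPortConfig` variant below
// are assumptions for illustration (`CanReuse` is the variant used elsewhere
// in this diff).
fn reserve_port_example<E: Ext>(iface: &Arc<dyn Iface<E>>) -> Result<(), BindError> {
    // Reserve a port; no TCP/UDP socket exists yet.
    let bound: BoundPort<E> = iface.bind(BindPortConfig::CanReuse(8080))?;

    assert_eq!(bound.port(), 8080);
    // `endpoint()` is `None` until the iface has an IPv4 address assigned.
    let _maybe_endpoint = bound.endpoint();

    // Dropping the `BoundPort` releases the port via its `Drop` impl.
    drop(bound);
    Ok(())
}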


@ -1,15 +1,11 @@
// SPDX-License-Identifier: MPL-2.0 // SPDX-License-Identifier: MPL-2.0
use alloc::{boxed::Box, sync::Arc}; use alloc::sync::Arc;
use smoltcp::wire::Ipv4Address; use smoltcp::wire::Ipv4Address;
use super::port::BindPortConfig; use super::{port::BindPortConfig, BoundPort};
use crate::{ use crate::{errors::BindError, ext::Ext};
errors::BindError,
ext::Ext,
socket::{BoundTcpSocket, BoundUdpSocket, UnboundTcpSocket, UnboundUdpSocket},
};
/// A network interface. /// A network interface.
/// ///
@ -34,24 +30,12 @@ impl<E: Ext> dyn Iface<E> {
/// FIXME: The reason for binding the socket and the iface together is because there are /// FIXME: The reason for binding the socket and the iface together is because there are
/// limitations inside smoltcp. See discussion at /// limitations inside smoltcp. See discussion at
/// <https://github.com/smoltcp-rs/smoltcp/issues/779>. /// <https://github.com/smoltcp-rs/smoltcp/issues/779>.
pub fn bind_tcp( pub fn bind(
self: &Arc<Self>, self: &Arc<Self>,
socket: Box<UnboundTcpSocket>,
observer: E::TcpEventObserver,
config: BindPortConfig, config: BindPortConfig,
) -> core::result::Result<BoundTcpSocket<E>, (BindError, Box<UnboundTcpSocket>)> { ) -> core::result::Result<BoundPort<E>, BindError> {
let common = self.common(); let common = self.common();
common.bind_tcp(self.clone(), socket, observer, config) common.bind(self.clone(), config)
}
pub fn bind_udp(
self: &Arc<Self>,
socket: Box<UnboundUdpSocket>,
observer: E::UdpEventObserver,
config: BindPortConfig,
) -> core::result::Result<BoundUdpSocket<E>, (BindError, Box<UnboundUdpSocket>)> {
let common = self.common();
common.bind_udp(self.clone(), socket, observer, config)
} }
/// Gets the name of the iface. /// Gets the name of the iface.


@ -9,6 +9,7 @@ mod port;
 mod sched;
 mod time;
+pub use common::BoundPort;
 pub use iface::Iface;
 pub use phy::{EtherIface, IpIface};
 pub use port::BindPortConfig;


@ -1,6 +1,6 @@
// SPDX-License-Identifier: MPL-2.0 // SPDX-License-Identifier: MPL-2.0
use alloc::{collections::btree_set::BTreeSet, vec}; use alloc::{vec, vec::Vec};
use keyable_arc::KeyableArc; use keyable_arc::KeyableArc;
use smoltcp::{ use smoltcp::{
@ -16,28 +16,28 @@ use smoltcp::{
}, },
}; };
use super::common::SocketSet;
use crate::{ use crate::{
ext::Ext, ext::Ext,
socket::{BoundTcpSocketInner, BoundUdpSocketInner, TcpProcessResult}, socket::{TcpConnectionBg, TcpListenerBg, TcpProcessResult},
}; };
pub(super) struct PollContext<'a, E: Ext> { pub(super) struct PollContext<'a, E: Ext> {
iface_cx: &'a mut Context, iface_cx: &'a mut Context,
tcp_sockets: &'a BTreeSet<KeyableArc<BoundTcpSocketInner<E>>>, sockets: &'a SocketSet<E>,
udp_sockets: &'a BTreeSet<KeyableArc<BoundUdpSocketInner<E>>>, new_tcp_conns: &'a mut Vec<KeyableArc<TcpConnectionBg<E>>>,
} }
impl<'a, E: Ext> PollContext<'a, E> { impl<'a, E: Ext> PollContext<'a, E> {
#[allow(clippy::mutable_key_type)]
pub(super) fn new( pub(super) fn new(
iface_cx: &'a mut Context, iface_cx: &'a mut Context,
tcp_sockets: &'a BTreeSet<KeyableArc<BoundTcpSocketInner<E>>>, sockets: &'a SocketSet<E>,
udp_sockets: &'a BTreeSet<KeyableArc<BoundUdpSocketInner<E>>>, new_tcp_conns: &'a mut Vec<KeyableArc<TcpConnectionBg<E>>>,
) -> Self { ) -> Self {
Self { Self {
iface_cx, iface_cx,
tcp_sockets, sockets,
udp_sockets, new_tcp_conns,
} }
} }
} }
@ -51,7 +51,7 @@ impl<E: Ext> PollContext<'_, E> {
pub(super) fn poll_ingress<D, P, Q>( pub(super) fn poll_ingress<D, P, Q>(
&mut self, &mut self,
device: &mut D, device: &mut D,
mut process_phy: P, process_phy: &mut P,
dispatch_phy: &mut Q, dispatch_phy: &mut Q,
) where ) where
D: Device + ?Sized, D: Device + ?Sized,
@ -158,12 +158,17 @@ impl<E: Ext> PollContext<'_, E> {
ip_repr: &IpRepr, ip_repr: &IpRepr,
tcp_repr: &TcpRepr, tcp_repr: &TcpRepr,
) -> Option<(IpRepr, TcpRepr<'static>)> { ) -> Option<(IpRepr, TcpRepr<'static>)> {
for socket in self.tcp_sockets.iter() { for socket in self
.sockets
.tcp_conn
.iter()
.chain(self.new_tcp_conns.iter())
{
if !socket.can_process(tcp_repr.dst_port) { if !socket.can_process(tcp_repr.dst_port) {
continue; continue;
} }
match socket.process(self.iface_cx, ip_repr, tcp_repr) { match TcpConnectionBg::process(socket, self.iface_cx, ip_repr, tcp_repr) {
TcpProcessResult::NotProcessed => continue, TcpProcessResult::NotProcessed => continue,
TcpProcessResult::Processed => return None, TcpProcessResult::Processed => return None,
TcpProcessResult::ProcessedWithReply(ip_repr, tcp_repr) => { TcpProcessResult::ProcessedWithReply(ip_repr, tcp_repr) => {
@ -172,6 +177,29 @@ impl<E: Ext> PollContext<'_, E> {
} }
} }
if tcp_repr.control == TcpControl::Syn && tcp_repr.ack_number.is_none() {
for socket in self.sockets.tcp_listen.iter() {
if !socket.can_process(tcp_repr.dst_port) {
continue;
}
let (processed, new_tcp_conn) =
TcpListenerBg::process(socket, self.iface_cx, ip_repr, tcp_repr);
if let Some(tcp_conn) = new_tcp_conn {
self.new_tcp_conns.push(tcp_conn);
}
match processed {
TcpProcessResult::NotProcessed => continue,
TcpProcessResult::Processed => return None,
TcpProcessResult::ProcessedWithReply(ip_repr, tcp_repr) => {
return Some((ip_repr, tcp_repr))
}
}
}
}
// "In no case does receipt of a segment containing RST give rise to a RST in response." // "In no case does receipt of a segment containing RST give rise to a RST in response."
// See <https://datatracker.ietf.org/doc/html/rfc9293#section-4-1.64>. // See <https://datatracker.ietf.org/doc/html/rfc9293#section-4-1.64>.
if tcp_repr.control == TcpControl::Rst { if tcp_repr.control == TcpControl::Rst {
@ -211,7 +239,7 @@ impl<E: Ext> PollContext<'_, E> {
fn process_udp(&mut self, ip_repr: &IpRepr, udp_repr: &UdpRepr, udp_payload: &[u8]) -> bool { fn process_udp(&mut self, ip_repr: &IpRepr, udp_repr: &UdpRepr, udp_payload: &[u8]) -> bool {
let mut processed = false; let mut processed = false;
for socket in self.udp_sockets.iter() { for socket in self.sockets.udp.iter() {
if !socket.can_process(udp_repr.dst_port) { if !socket.can_process(udp_repr.dst_port) {
continue; continue;
} }
@ -284,13 +312,13 @@ impl<E: Ext> PollContext<'_, E> {
} }
impl<E: Ext> PollContext<'_, E> { impl<E: Ext> PollContext<'_, E> {
pub(super) fn poll_egress<D, Q>(&mut self, device: &mut D, mut dispatch_phy: Q) pub(super) fn poll_egress<D, Q>(&mut self, device: &mut D, dispatch_phy: &mut Q)
where where
D: Device + ?Sized, D: Device + ?Sized,
Q: FnMut(&Packet, &mut Context, D::TxToken<'_>), Q: FnMut(&Packet, &mut Context, D::TxToken<'_>),
{ {
while let Some(tx_token) = device.transmit(self.iface_cx.now()) { while let Some(tx_token) = device.transmit(self.iface_cx.now()) {
if !self.dispatch_ipv4(tx_token, &mut dispatch_phy) { if !self.dispatch_ipv4(tx_token, dispatch_phy) {
break; break;
} }
} }
@ -320,7 +348,9 @@ impl<E: Ext> PollContext<'_, E> {
let mut tx_token = Some(tx_token); let mut tx_token = Some(tx_token);
let mut did_something = false; let mut did_something = false;
for socket in self.tcp_sockets.iter() { // We cannot dispatch packets from `new_tcp_conns` because we cannot borrow an immutable
// reference at this point. Instead, we will retry after the entire poll is complete.
for socket in self.sockets.tcp_conn.iter() {
if !socket.need_dispatch(self.iface_cx.now()) { if !socket.need_dispatch(self.iface_cx.now()) {
continue; continue;
} }
@ -331,37 +361,38 @@ impl<E: Ext> PollContext<'_, E> {
let mut deferred = None; let mut deferred = None;
let reply = socket.dispatch(self.iface_cx, |cx, ip_repr, tcp_repr| { let reply =
let mut this = PollContext::new(cx, self.tcp_sockets, self.udp_sockets); TcpConnectionBg::dispatch(socket, self.iface_cx, |cx, ip_repr, tcp_repr| {
let mut this = PollContext::new(cx, self.sockets, self.new_tcp_conns);
if !this.is_unicast_local(ip_repr.dst_addr()) { if !this.is_unicast_local(ip_repr.dst_addr()) {
dispatch_phy( dispatch_phy(
&Packet::new(ip_repr.clone(), IpPayload::Tcp(*tcp_repr)), &Packet::new(ip_repr.clone(), IpPayload::Tcp(*tcp_repr)),
this.iface_cx, this.iface_cx,
tx_token.take().unwrap(), tx_token.take().unwrap(),
); );
return None; return None;
} }
if !socket.can_process(tcp_repr.dst_port) { if !socket.can_process(tcp_repr.dst_port) {
return this.process_tcp(ip_repr, tcp_repr); return this.process_tcp(ip_repr, tcp_repr);
} }
// We cannot call `process_tcp` now because it may cause deadlocks. We will copy // We cannot call `process_tcp` now because it may cause deadlocks. We will copy
// the packet and call `process_tcp` after releasing the socket lock. // the packet and call `process_tcp` after releasing the socket lock.
deferred = Some((ip_repr.clone(), { deferred = Some((ip_repr.clone(), {
let mut data = vec![0; tcp_repr.buffer_len()]; let mut data = vec![0; tcp_repr.buffer_len()];
tcp_repr.emit( tcp_repr.emit(
&mut TcpPacket::new_unchecked(data.as_mut_slice()), &mut TcpPacket::new_unchecked(data.as_mut_slice()),
&ip_repr.src_addr(), &ip_repr.src_addr(),
&ip_repr.dst_addr(), &ip_repr.dst_addr(),
&ChecksumCapabilities::ignored(), &ChecksumCapabilities::ignored(),
); );
data data
})); }));
None None
}); });
match (deferred, reply) { match (deferred, reply) {
(None, None) => (), (None, None) => (),
@ -411,7 +442,7 @@ impl<E: Ext> PollContext<'_, E> {
let mut tx_token = Some(tx_token); let mut tx_token = Some(tx_token);
let mut did_something = false; let mut did_something = false;
for socket in self.udp_sockets.iter() { for socket in self.sockets.udp.iter() {
if !socket.need_dispatch(self.iface_cx.now()) { if !socket.need_dispatch(self.iface_cx.now()) {
continue; continue;
} }
@ -423,7 +454,7 @@ impl<E: Ext> PollContext<'_, E> {
let mut deferred = None; let mut deferred = None;
socket.dispatch(self.iface_cx, |cx, ip_repr, udp_repr, udp_payload| { socket.dispatch(self.iface_cx, |cx, ip_repr, udp_repr, udp_payload| {
let mut this = PollContext::new(cx, self.tcp_sockets, self.udp_sockets); let mut this = PollContext::new(cx, self.sockets, self.new_tcp_conns);
if ip_repr.dst_addr().is_broadcast() || !this.is_unicast_local(ip_repr.dst_addr()) { if ip_repr.dst_addr().is_broadcast() || !this.is_unicast_local(ip_repr.dst_addr()) {
dispatch_phy( dispatch_phy(


@ -1,75 +1,99 @@
// SPDX-License-Identifier: MPL-2.0 // SPDX-License-Identifier: MPL-2.0
use alloc::{boxed::Box, sync::Arc}; use alloc::{boxed::Box, collections::btree_set::BTreeSet, sync::Arc, vec::Vec};
use core::{ use core::{
borrow::Borrow,
ops::{Deref, DerefMut}, ops::{Deref, DerefMut},
sync::atomic::{AtomicBool, AtomicU64, AtomicU8, Ordering}, sync::atomic::{AtomicBool, AtomicU64, AtomicU8, Ordering},
}; };
use ostd::sync::{LocalIrqDisabled, RwLock, SpinLock, SpinLockGuard, WriteIrqDisabled}; use keyable_arc::KeyableArc;
use ostd::sync::{LocalIrqDisabled, SpinLock, SpinLockGuard};
use smoltcp::{ use smoltcp::{
iface::Context, iface::Context,
socket::{tcp::State, udp::UdpMetadata, PollAt}, socket::{tcp::State, udp::UdpMetadata, PollAt},
time::{Duration, Instant}, time::{Duration, Instant},
wire::{IpAddress, IpEndpoint, IpRepr, TcpControl, TcpRepr, UdpRepr}, wire::{IpEndpoint, IpRepr, TcpControl, TcpRepr, UdpRepr},
}; };
use spin::Once;
use takeable::Takeable;
use super::{ use super::{
event::{SocketEventObserver, SocketEvents}, event::{SocketEventObserver, SocketEvents},
option::RawTcpSetOption, option::{RawTcpOption, RawTcpSetOption},
unbound::{new_tcp_socket, new_udp_socket},
RawTcpSocket, RawUdpSocket, TcpStateCheck, RawTcpSocket, RawUdpSocket, TcpStateCheck,
}; };
use crate::{ext::Ext, iface::Iface}; use crate::{
ext::Ext,
iface::{BindPortConfig, BoundPort, Iface},
};
pub struct BoundSocket<T: AnySocket<E>, E: Ext>(Arc<BoundSocketInner<T, E>>); pub struct Socket<T: Inner<E>, E: Ext>(Takeable<KeyableArc<SocketBg<T, E>>>);
/// [`TcpSocket`] or [`UdpSocket`]. impl<T: Inner<E>, E: Ext> PartialEq for Socket<T, E> {
pub trait AnySocket<E> { fn eq(&self, other: &Self) -> bool {
type RawSocket; self.0.eq(&other.0)
}
}
impl<T: Inner<E>, E: Ext> Eq for Socket<T, E> {}
impl<T: Inner<E>, E: Ext> PartialOrd for Socket<T, E> {
fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
Some(self.cmp(other))
}
}
impl<T: Inner<E>, E: Ext> Ord for Socket<T, E> {
fn cmp(&self, other: &Self) -> core::cmp::Ordering {
self.0.cmp(&other.0)
}
}
impl<T: Inner<E>, E: Ext> Borrow<KeyableArc<SocketBg<T, E>>> for Socket<T, E> {
fn borrow(&self) -> &KeyableArc<SocketBg<T, E>> {
self.0.as_ref()
}
}
/// [`TcpConnectionInner`] or [`UdpSocketInner`].
pub trait Inner<E: Ext> {
type Observer: SocketEventObserver; type Observer: SocketEventObserver;
/// Called by [`BoundSocket::new`]. /// Called by [`Socket::drop`].
fn new(socket: Box<Self::RawSocket>) -> Self; fn on_drop(this: &KeyableArc<SocketBg<Self, E>>)
/// Called by [`BoundSocket::drop`].
fn on_drop(this: &Arc<BoundSocketInner<Self, E>>)
where where
E: Ext, E: Ext,
Self: Sized; Self: Sized;
} }
pub type BoundTcpSocket<E> = BoundSocket<TcpSocket, E>; pub type TcpConnection<E> = Socket<TcpConnectionInner<E>, E>;
pub type BoundUdpSocket<E> = BoundSocket<UdpSocket, E>; pub type TcpListener<E> = Socket<TcpListenerInner<E>, E>;
pub type UdpSocket<E> = Socket<UdpSocketInner, E>;
/// Common states shared by [`BoundTcpSocketInner`] and [`BoundUdpSocketInner`]. /// Common states shared by [`TcpConnectionBg`] and [`UdpSocketBg`].
pub struct BoundSocketInner<T: AnySocket<E>, E> { ///
iface: Arc<dyn Iface<E>>, /// In the type name, `Bg` means "background". Its meaning is described below:
port: u16, /// - A foreground socket (e.g., [`TcpConnection`]) handles system calls from the user program.
socket: T, /// - A background socket (e.g., [`TcpConnectionBg`]) handles packets from the network.
observer: RwLock<T::Observer, WriteIrqDisabled>, pub struct SocketBg<T: Inner<E>, E: Ext> {
bound: BoundPort<E>,
inner: T,
observer: Once<T::Observer>,
events: AtomicU8, events: AtomicU8,
next_poll_at_ms: AtomicU64, next_poll_at_ms: AtomicU64,
} }
/// States needed by [`BoundTcpSocketInner`] but not [`BoundUdpSocketInner`]. /// States needed by [`TcpConnectionBg`] but not [`UdpSocketBg`].
pub struct TcpSocket { pub struct TcpConnectionInner<E: Ext> {
socket: SpinLock<RawTcpSocketExt, LocalIrqDisabled>, socket: SpinLock<RawTcpSocketExt<E>, LocalIrqDisabled>,
is_dead: AtomicBool, is_dead: AtomicBool,
} }
struct RawTcpSocketExt { struct RawTcpSocketExt<E: Ext> {
socket: Box<RawTcpSocket>, socket: Box<RawTcpSocket>,
listener: Option<Arc<TcpListenerBg<E>>>,
has_connected: bool, has_connected: bool,
/// Whether the socket is in the background.
///
/// A background socket is a socket with its corresponding [`BoundSocket`] dropped. This means
/// that no more user events (like `send`/`recv`) can reach the socket, but it can be in a
/// state of waiting for certain network events (e.g., remote FIN/ACK packets), so
/// [`BoundSocketInner`] may still be alive for a while.
in_background: bool,
} }
impl Deref for RawTcpSocketExt { impl<E: Ext> Deref for RawTcpSocketExt<E> {
type Target = RawTcpSocket; type Target = RawTcpSocket;
fn deref(&self) -> &Self::Target { fn deref(&self) -> &Self::Target {
@ -77,18 +101,28 @@ impl Deref for RawTcpSocketExt {
} }
} }
impl DerefMut for RawTcpSocketExt { impl<E: Ext> DerefMut for RawTcpSocketExt<E> {
fn deref_mut(&mut self) -> &mut Self::Target { fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.socket &mut self.socket
} }
} }
impl RawTcpSocketExt { impl<E: Ext> RawTcpSocketExt<E> {
fn on_new_state(&mut self) -> SocketEvents { fn on_new_state(&mut self, this: &KeyableArc<TcpConnectionBg<E>>) -> SocketEvents {
if self.may_send() { if self.may_send() && !self.has_connected {
self.has_connected = true; self.has_connected = true;
if let Some(ref listener) = self.listener {
let mut backlog = listener.inner.lock();
if let Some(value) = backlog.connecting.take(this) {
backlog.connected.push(value);
}
listener.add_events(SocketEvents::CAN_RECV);
}
} }
self.update_dead(this);
if self.is_peer_closed() { if self.is_peer_closed() {
SocketEvents::PEER_CLOSED SocketEvents::PEER_CLOSED
} else if self.is_closed() { } else if self.is_closed() {
@ -97,148 +131,178 @@ impl RawTcpSocketExt {
SocketEvents::empty() SocketEvents::empty()
} }
} }
}
impl TcpSocket { /// Updates whether the TCP connection is dead.
fn lock(&self) -> SpinLockGuard<RawTcpSocketExt, LocalIrqDisabled> {
self.socket.lock()
}
/// Returns whether the TCP socket is dead.
/// ///
/// See [`BoundTcpSocketInner::is_dead`] for the definition of dead TCP sockets. /// See [`TcpConnectionBg::is_dead`] for the definition of dead TCP connections.
fn is_dead(&self) -> bool {
self.is_dead.load(Ordering::Relaxed)
}
/// Updates whether the TCP socket is dead.
///
/// See [`BoundTcpSocketInner::is_dead`] for the definition of dead TCP sockets.
/// ///
/// This method must be called after handling network events. However, it is not necessary to /// This method must be called after handling network events. However, it is not necessary to
/// call this method after handling non-closing user events, because the socket can never be /// call this method after handling non-closing user events, because the socket can never be
/// dead if user events can reach the socket. /// dead if it is not closed.
fn update_dead(&self, socket: &RawTcpSocketExt) { fn update_dead(&self, this: &KeyableArc<TcpConnectionBg<E>>) {
if socket.in_background && socket.state() == smoltcp::socket::tcp::State::Closed { if self.state() == smoltcp::socket::tcp::State::Closed {
self.is_dead.store(true, Ordering::Relaxed); this.inner.is_dead.store(true, Ordering::Relaxed);
} }
}
/// Sets the TCP socket in [`TimeWait`] state as dead. // According to the current smoltcp implementation, a backlog socket will return back to
/// // the `Listen` state if the connection is RSTed before its establishment.
/// See [`BoundTcpSocketInner::is_dead`] for the definition of dead TCP sockets. if self.state() == smoltcp::socket::tcp::State::Listen {
/// this.inner.is_dead.store(true, Ordering::Relaxed);
/// [`TimeWait`]: smoltcp::socket::tcp::State::TimeWait
fn set_dead_timewait(&self, socket: &RawTcpSocketExt) { if let Some(ref listener) = self.listener {
debug_assert!( let mut backlog = listener.inner.lock();
socket.in_background && socket.state() == smoltcp::socket::tcp::State::TimeWait // This may fail due to race conditions, but it's fine.
); let _ = backlog.connecting.remove(this);
self.is_dead.store(true, Ordering::Relaxed); }
}
} }
} }
impl<E: Ext> AnySocket<E> for TcpSocket { impl<E: Ext> TcpConnectionInner<E> {
type RawSocket = RawTcpSocket; fn new(socket: Box<RawTcpSocket>, listener: Option<Arc<TcpListenerBg<E>>>) -> Self {
type Observer = E::TcpEventObserver;
fn new(socket: Box<Self::RawSocket>) -> Self {
let socket_ext = RawTcpSocketExt { let socket_ext = RawTcpSocketExt {
socket, socket,
listener,
has_connected: false, has_connected: false,
in_background: false,
}; };
Self { TcpConnectionInner {
socket: SpinLock::new(socket_ext), socket: SpinLock::new(socket_ext),
is_dead: AtomicBool::new(false), is_dead: AtomicBool::new(false),
} }
} }
fn on_drop(this: &Arc<BoundSocketInner<Self, E>>) { fn lock(&self) -> SpinLockGuard<RawTcpSocketExt<E>, LocalIrqDisabled> {
let mut socket = this.socket.lock(); self.socket.lock()
}
socket.in_background = true; /// Returns whether the TCP connection is dead.
///
/// See [`TcpConnectionBg::is_dead`] for the definition of dead TCP connections.
fn is_dead(&self) -> bool {
self.is_dead.load(Ordering::Relaxed)
}
/// Sets the TCP connection in [`TimeWait`] state as dead.
///
/// See [`TcpConnectionBg::is_dead`] for the definition of dead TCP connections.
///
/// [`TimeWait`]: smoltcp::socket::tcp::State::TimeWait
fn set_dead_timewait(&self, socket: &RawTcpSocketExt<E>) {
debug_assert!(socket.state() == smoltcp::socket::tcp::State::TimeWait);
self.is_dead.store(true, Ordering::Relaxed);
}
}
impl<E: Ext> Inner<E> for TcpConnectionInner<E> {
type Observer = E::TcpEventObserver;
fn on_drop(this: &KeyableArc<SocketBg<Self, E>>) {
let mut socket = this.inner.lock();
// FIXME: Send RSTs when there is unread data.
socket.close(); socket.close();
// A TCP socket may not be appropriate for immediate removal. We leave the removal decision // A TCP connection may not be appropriate for immediate removal. We leave the removal
// to the polling logic. // decision to the polling logic.
this.update_next_poll_at_ms(PollAt::Now); this.update_next_poll_at_ms(PollAt::Now);
this.socket.update_dead(&socket); socket.update_dead(this);
} }
} }
/// States needed by [`BoundUdpSocketInner`] but not [`BoundTcpSocketInner`]. pub struct TcpBacklog<E: Ext> {
type UdpSocket = SpinLock<Box<RawUdpSocket>, LocalIrqDisabled>; socket: Box<RawTcpSocket>,
max_conn: usize,
connecting: BTreeSet<TcpConnection<E>>,
connected: Vec<TcpConnection<E>>,
}
impl<E: Ext> AnySocket<E> for UdpSocket { pub type TcpListenerInner<E> = SpinLock<TcpBacklog<E>, LocalIrqDisabled>;
type RawSocket = RawUdpSocket;
impl<E: Ext> Inner<E> for TcpListenerInner<E> {
type Observer = E::TcpEventObserver;
fn on_drop(this: &KeyableArc<SocketBg<Self, E>>) {
// A TCP listener can be removed immediately.
this.bound.iface().common().remove_tcp_listener(this);
let (connecting, connected) = {
let mut socket = this.inner.lock();
(
core::mem::take(&mut socket.connecting),
core::mem::take(&mut socket.connected),
)
};
// The locks on `connecting`/`connected` cannot be acquired after locking `self`;
// otherwise we might get a deadlock due to inconsistent lock ordering.
//
// FIXME: Send RSTs instead of going through the normal socket close process.
drop(connecting);
drop(connected);
}
}
/// States needed by [`UdpSocketBg`] but not [`TcpConnectionBg`].
type UdpSocketInner = SpinLock<Box<RawUdpSocket>, LocalIrqDisabled>;
impl<E: Ext> Inner<E> for UdpSocketInner {
type Observer = E::UdpEventObserver; type Observer = E::UdpEventObserver;
fn new(socket: Box<Self::RawSocket>) -> Self { fn on_drop(this: &KeyableArc<SocketBg<Self, E>>) {
Self::new(socket) this.inner.lock().close();
}
fn on_drop(this: &Arc<BoundSocketInner<Self, E>>)
where
E: Ext,
{
this.socket.lock().close();
// A UDP socket can be removed immediately. // A UDP socket can be removed immediately.
this.iface.common().remove_udp_socket(this); this.bound.iface().common().remove_udp_socket(this);
} }
} }
impl<T: AnySocket<E>, E: Ext> Drop for BoundSocket<T, E> { impl<T: Inner<E>, E: Ext> Drop for Socket<T, E> {
fn drop(&mut self) { fn drop(&mut self) {
T::on_drop(&self.0); if self.0.is_usable() {
T::on_drop(&self.0);
}
} }
} }
pub(crate) type BoundTcpSocketInner<E> = BoundSocketInner<TcpSocket, E>; pub(crate) type TcpConnectionBg<E> = SocketBg<TcpConnectionInner<E>, E>;
pub(crate) type BoundUdpSocketInner<E> = BoundSocketInner<UdpSocket, E>; pub(crate) type TcpListenerBg<E> = SocketBg<TcpListenerInner<E>, E>;
pub(crate) type UdpSocketBg<E> = SocketBg<UdpSocketInner, E>;
impl<T: AnySocket<E>, E: Ext> BoundSocket<T, E> { impl<T: Inner<E>, E: Ext> Socket<T, E> {
pub(crate) fn new( pub(crate) fn new(bound: BoundPort<E>, inner: T) -> Self {
iface: Arc<dyn Iface<E>>, Self(Takeable::new(KeyableArc::new(SocketBg {
port: u16, bound,
socket: Box<T::RawSocket>, inner,
observer: T::Observer, observer: Once::new(),
) -> Self {
Self(Arc::new(BoundSocketInner {
iface,
port,
socket: T::new(socket),
observer: RwLock::new(observer),
events: AtomicU8::new(0), events: AtomicU8::new(0),
next_poll_at_ms: AtomicU64::new(u64::MAX), next_poll_at_ms: AtomicU64::new(u64::MAX),
})) })))
} }
pub(crate) fn inner(&self) -> &Arc<BoundSocketInner<T, E>> { pub(crate) fn inner(&self) -> &KeyableArc<SocketBg<T, E>> {
&self.0 &self.0
} }
} }
impl<T: AnySocket<E>, E: Ext> BoundSocket<T, E> { impl<T: Inner<E>, E: Ext> Socket<T, E> {
/// Sets the observer whose `on_events` will be called when certain iface events happen. /// Initializes the observer whose `on_events` will be called when certain iface events happen.
/// ///
/// The caller needs to be responsible for race conditions if network events can occur /// The caller needs to be responsible for race conditions if network events can occur
/// simultaneously. /// simultaneously.
pub fn set_observer(&self, new_observer: T::Observer) { ///
*self.0.observer.write() = new_observer; /// Calling this method on a socket whose observer has already been initialized will have no
/// effect.
pub fn init_observer(&self, new_observer: T::Observer) {
self.0.observer.call_once(|| new_observer);
} }
pub fn local_endpoint(&self) -> Option<IpEndpoint> { pub fn local_endpoint(&self) -> Option<IpEndpoint> {
let ip_addr = { self.0.bound.endpoint()
let ipv4_addr = self.0.iface.ipv4_addr()?;
IpAddress::Ipv4(ipv4_addr)
};
Some(IpEndpoint::new(ip_addr, self.0.port))
} }
pub fn iface(&self) -> &Arc<dyn Iface<E>> { pub fn iface(&self) -> &Arc<dyn Iface<E>> {
&self.0.iface self.0.bound.iface()
} }
} }
@ -264,50 +328,76 @@ impl Deref for NeedIfacePoll {
} }
} }
impl<E: Ext> BoundTcpSocket<E> { impl<E: Ext> TcpConnection<E> {
/// Connects to a remote endpoint. /// Connects to a remote endpoint.
/// ///
/// Polling the iface is _always_ required after this method succeeds. /// Polling the iface is _always_ required after this method succeeds.
pub fn connect( pub fn new_connect(
&self, bound: BoundPort<E>,
remote_endpoint: IpEndpoint, remote_endpoint: IpEndpoint,
) -> Result<(), smoltcp::socket::tcp::ConnectError> { option: &RawTcpOption,
let common = self.iface().common(); observer: E::TcpEventObserver,
let mut iface = common.interface(); ) -> Result<Self, (BoundPort<E>, smoltcp::socket::tcp::ConnectError)> {
let socket = {
let mut socket = new_tcp_socket();
let mut socket = self.0.socket.lock(); option.apply(&mut socket);
socket.connect(iface.context(), remote_endpoint, self.0.port)?; let common = bound.iface().common();
let mut iface = common.interface();
socket.has_connected = false; if let Err(err) = socket.connect(iface.context(), remote_endpoint, bound.port()) {
self.0.update_next_poll_at_ms(PollAt::Now); drop(iface);
return Err((bound, err));
}
Ok(()) socket
};
let inner = TcpConnectionInner::new(socket, None);
let connection = Self::new(bound, inner);
connection.0.update_next_poll_at_ms(PollAt::Now);
connection.init_observer(observer);
connection
.iface()
.common()
.register_tcp_connection(connection.inner().clone());
Ok(connection)
} }
/// Returns the state of the connecting procedure. /// Returns the state of the connecting procedure.
pub fn connect_state(&self) -> ConnectState { pub fn connect_state(&self) -> ConnectState {
let socket = self.0.socket.lock(); let socket = self.0.inner.lock();
if socket.state() == State::SynSent || socket.state() == State::SynReceived { if socket.state() == State::SynSent || socket.state() == State::SynReceived {
ConnectState::Connecting ConnectState::Connecting
} else if socket.has_connected { } else if socket.has_connected {
ConnectState::Connected ConnectState::Connected
} else if KeyableArc::strong_count(self.0.as_ref()) > 1 {
// Now we should return `ConnectState::Refused`. However, when we do this, we must
// guarantee that `into_bound_port` can succeed (see the method's doc comments). We can
// only guarantee this after we have removed all `Arc<TcpConnectionBg>` in the iface's
// socket set.
//
// This branch serves to avoid a race condition: if the removal process hasn't
// finished, we will return `Connecting` so that the caller won't try to call
// `into_bound_port` (which may fail immediately).
ConnectState::Connecting
} else { } else {
ConnectState::Refused ConnectState::Refused
} }
} }
/// Listens at a specified endpoint. /// Converts back to the [`BoundPort`].
/// ///
/// Polling the iface is _not_ required after this method succeeds. /// This method will succeed if the connection is fully closed and no network events can reach
pub fn listen( /// this connection. We guarantee that this method will always succeed if
&self, /// [`Self::connect_state`] returns [`ConnectState::Refused`].
local_endpoint: IpEndpoint, pub fn into_bound_port(mut self) -> Option<BoundPort<E>> {
) -> Result<(), smoltcp::socket::tcp::ListenError> { let this: TcpConnectionBg<E> = Arc::into_inner(self.0.take().into())?;
let mut socket = self.0.socket.lock(); Some(this.bound)
socket.listen(local_endpoint)
} }
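// --- Usage sketch (illustration only, not part of this commit) ---
// Client-side connect flow with the new API. `remote`, `option`, and
// `observer` are assumed to be supplied by the caller, and the crate's
// re-exported items are assumed to be in scope.
fn connect_example<E: Ext>(
    bound: BoundPort<E>,
    remote: IpEndpoint,
    option: &RawTcpOption,
    observer: E::TcpEventObserver,
) -> Result<TcpConnection<E>, BoundPort<E>> {
    // On failure the reserved port is handed back so that it can be reused.
    let conn = match TcpConnection::new_connect(bound, remote, option, observer) {
        Ok(conn) => conn,
        Err((bound, _err)) => return Err(bound),
    };

    // The iface must be polled afterwards so the SYN actually goes out. Later,
    // `conn.connect_state()` reports Connecting/Connected/Refused; on `Refused`,
    // `conn.into_bound_port()` is guaranteed to return the port for reuse.
    Ok(conn)
}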
/// Sends some data. /// Sends some data.
@ -320,7 +410,7 @@ impl<E: Ext> BoundTcpSocket<E> {
let common = self.iface().common(); let common = self.iface().common();
let mut iface = common.interface(); let mut iface = common.interface();
let mut socket = self.0.socket.lock(); let mut socket = self.0.inner.lock();
let result = socket.send(f)?; let result = socket.send(f)?;
let need_poll = self let need_poll = self
@ -340,7 +430,7 @@ impl<E: Ext> BoundTcpSocket<E> {
let common = self.iface().common(); let common = self.iface().common();
let mut iface = common.interface(); let mut iface = common.interface();
let mut socket = self.0.socket.lock(); let mut socket = self.0.inner.lock();
let result = socket.recv(f)?; let result = socket.recv(f)?;
let need_poll = self let need_poll = self
@ -354,8 +444,9 @@ impl<E: Ext> BoundTcpSocket<E> {
/// ///
/// Polling the iface is _always_ required after this method succeeds. /// Polling the iface is _always_ required after this method succeeds.
pub fn close(&self) { pub fn close(&self) {
let mut socket = self.0.socket.lock(); let mut socket = self.0.inner.lock();
socket.listener = None;
socket.close(); socket.close();
self.0.update_next_poll_at_ms(PollAt::Now); self.0.update_next_poll_at_ms(PollAt::Now);
} }
@ -368,14 +459,14 @@ impl<E: Ext> BoundTcpSocket<E> {
where where
F: FnOnce(&RawTcpSocket) -> R, F: FnOnce(&RawTcpSocket) -> R,
{ {
let socket = self.0.socket.lock(); let socket = self.0.inner.lock();
f(&socket) f(&socket)
} }
} }
impl<E: Ext> RawTcpSetOption for BoundTcpSocket<E> { impl<E: Ext> RawTcpSetOption for TcpConnection<E> {
fn set_keep_alive(&mut self, interval: Option<Duration>) -> NeedIfacePoll { fn set_keep_alive(&self, interval: Option<Duration>) -> NeedIfacePoll {
let mut socket = self.0.socket.lock(); let mut socket = self.0.inner.lock();
socket.set_keep_alive(interval); socket.set_keep_alive(interval);
if interval.is_some() { if interval.is_some() {
@ -386,20 +477,130 @@ impl<E: Ext> RawTcpSetOption for BoundTcpSocket<E> {
} }
} }
fn set_nagle_enabled(&mut self, enabled: bool) { fn set_nagle_enabled(&self, enabled: bool) {
let mut socket = self.0.socket.lock(); let mut socket = self.0.inner.lock();
socket.set_nagle_enabled(enabled); socket.set_nagle_enabled(enabled);
} }
} }
impl<E: Ext> BoundUdpSocket<E> { impl<E: Ext> TcpListener<E> {
/// Listens at a specified endpoint.
///
/// Polling the iface is _not_ required after this method succeeds.
pub fn new_listen(
bound: BoundPort<E>,
max_conn: usize,
option: &RawTcpOption,
observer: E::TcpEventObserver,
) -> Result<Self, (BoundPort<E>, smoltcp::socket::tcp::ListenError)> {
let Some(local_endpoint) = bound.endpoint() else {
return Err((bound, smoltcp::socket::tcp::ListenError::Unaddressable));
};
let socket = {
let mut socket = new_tcp_socket();
option.apply(&mut socket);
if let Err(err) = socket.listen(local_endpoint) {
return Err((bound, err));
}
socket
};
let inner = TcpListenerInner::new(TcpBacklog {
socket,
max_conn,
connecting: BTreeSet::new(),
connected: Vec::new(),
});
let listener = Self::new(bound, inner);
listener.init_observer(observer);
listener
.iface()
.common()
.register_tcp_listener(listener.inner().clone());
Ok(listener)
}
/// Accepts a TCP connection.
///
/// Polling the iface is _not_ required after this method succeeds.
pub fn accept(&self) -> Option<(TcpConnection<E>, IpEndpoint)> {
let accepted = {
let mut backlog = self.0.inner.lock();
backlog.connected.pop()?
};
let remote_endpoint = {
// The lock on `accepted` cannot be acquired after locking `self`; otherwise we
// might get a deadlock due to inconsistent lock ordering.
let mut socket = accepted.0.inner.lock();
socket.listener = None;
socket.remote_endpoint()
};
Some((accepted, remote_endpoint.unwrap()))
}
/// Returns whether there is a TCP connection to accept.
///
/// It's the caller's responsibility to deal with race conditions when using this method.
pub fn can_accept(&self) -> bool {
!self.0.inner.lock().connected.is_empty()
}
}
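// --- Usage sketch (illustration only, not part of this commit) ---
// The backlog-on-demand flow: `new_listen` installs a single listening socket,
// `TcpListenerBg::process` creates backlog connection sockets when SYNs arrive,
// and `accept` pops established connections. `bound` and `observer` are
// assumptions for illustration; the backlog size 128 is arbitrary.
fn listen_example<E: Ext>(
    bound: BoundPort<E>,
    observer: E::TcpEventObserver,
) -> Result<(), smoltcp::socket::tcp::ListenError> {
    let option = RawTcpOption {
        keep_alive: None,
        is_nagle_enabled: true,
    };

    let listener = TcpListener::new_listen(bound, 128, &option, observer)
        .map_err(|(_bound, err)| err)?;

    // Typically called after the observer reports `CAN_RECV` on the listener.
    while let Some((conn, remote_endpoint)) = listener.accept() {
        // `conn` is an established `TcpConnection` sharing the listener's local
        // port (bound with `BindPortConfig::CanReuse`).
        let _ = (conn, remote_endpoint);
    }
    Ok(())
}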
impl<E: Ext> RawTcpSetOption for TcpListener<E> {
fn set_keep_alive(&self, interval: Option<Duration>) -> NeedIfacePoll {
let mut backlog = self.0.inner.lock();
backlog.socket.set_keep_alive(interval);
NeedIfacePoll::FALSE
}
fn set_nagle_enabled(&self, enabled: bool) {
let mut backlog = self.0.inner.lock();
backlog.socket.set_nagle_enabled(enabled);
}
}
impl<E: Ext> UdpSocket<E> {
/// Binds to a specified endpoint. /// Binds to a specified endpoint.
/// ///
/// Polling the iface is _not_ required after this method succeeds. /// Polling the iface is _not_ required after this method succeeds.
pub fn bind(&self, local_endpoint: IpEndpoint) -> Result<(), smoltcp::socket::udp::BindError> { pub fn new_bind(
let mut socket = self.0.socket.lock(); bound: BoundPort<E>,
observer: E::UdpEventObserver,
) -> Result<Self, (BoundPort<E>, smoltcp::socket::udp::BindError)> {
let Some(local_endpoint) = bound.endpoint() else {
return Err((bound, smoltcp::socket::udp::BindError::Unaddressable));
};
socket.bind(local_endpoint) let socket = {
let mut socket = new_udp_socket();
if let Err(err) = socket.bind(local_endpoint) {
return Err((bound, err));
}
socket
};
let inner = UdpSocketInner::new(socket);
let socket = Self::new(bound, inner);
socket.init_observer(observer);
socket
.iface()
.common()
.register_udp_socket(socket.inner().clone());
Ok(socket)
} }
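// --- Usage sketch (illustration only, not part of this commit) ---
// UDP follows the same pattern: the socket is created on an already-reserved
// port. `bound` and `observer` are assumptions for illustration.
fn udp_example<E: Ext>(
    bound: BoundPort<E>,
    observer: E::UdpEventObserver,
) -> Result<UdpSocket<E>, smoltcp::socket::udp::BindError> {
    // Fails with `Unaddressable` if the iface has no IPv4 address yet.
    UdpSocket::new_bind(bound, observer).map_err(|(_bound, err)| err)
}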
/// Sends some data. /// Sends some data.
@ -418,7 +619,7 @@ impl<E: Ext> BoundUdpSocket<E> {
use crate::errors::udp::SendError; use crate::errors::udp::SendError;
let mut socket = self.0.socket.lock(); let mut socket = self.0.inner.lock();
if size > socket.packet_send_capacity() { if size > socket.packet_send_capacity() {
return Err(SendError::TooLarge); return Err(SendError::TooLarge);
@ -442,7 +643,7 @@ impl<E: Ext> BoundUdpSocket<E> {
where where
F: FnOnce(&[u8], UdpMetadata) -> R, F: FnOnce(&[u8], UdpMetadata) -> R,
{ {
let mut socket = self.0.socket.lock(); let mut socket = self.0.inner.lock();
let (data, meta) = socket.recv()?; let (data, meta) = socket.recv()?;
let result = f(data, meta); let result = f(data, meta);
@ -458,12 +659,12 @@ impl<E: Ext> BoundUdpSocket<E> {
where where
F: FnOnce(&RawUdpSocket) -> R, F: FnOnce(&RawUdpSocket) -> R,
{ {
let socket = self.0.socket.lock(); let socket = self.0.inner.lock();
f(&socket) f(&socket)
} }
} }
impl<T: AnySocket<E>, E> BoundSocketInner<T, E> { impl<T: Inner<E>, E: Ext> SocketBg<T, E> {
pub(crate) fn has_events(&self) -> bool { pub(crate) fn has_events(&self) -> bool {
self.events.load(Ordering::Relaxed) != 0 self.events.load(Ordering::Relaxed) != 0
} }
@ -474,8 +675,28 @@ impl<T: AnySocket<E>, E> BoundSocketInner<T, E> {
let events = self.events.load(Ordering::Relaxed); let events = self.events.load(Ordering::Relaxed);
self.events.store(0, Ordering::Relaxed); self.events.store(0, Ordering::Relaxed);
let observer = self.observer.read(); if let Some(observer) = self.observer.get() {
observer.on_events(SocketEvents::from_bits_truncate(events)); observer.on_events(SocketEvents::from_bits_truncate(events));
}
}
pub(crate) fn on_dead_events(this: KeyableArc<Self>)
where
T::Observer: Clone,
{
// This method can only be called to process network events, so we assume we are holding the
// poll lock and no race conditions can occur.
let events = this.events.load(Ordering::Relaxed);
this.events.store(0, Ordering::Relaxed);
let observer = this.observer.get().cloned();
drop(this);
// Notify dead events after the `Arc` is dropped to ensure the observer sees this event
// with the expected reference count. See `TcpConnection::connect_state` for an example.
if let Some(ref observer) = observer {
observer.on_events(SocketEvents::from_bits_truncate(events));
}
} }
fn add_events(&self, new_events: SocketEvents) { fn add_events(&self, new_events: SocketEvents) {
@ -498,7 +719,7 @@ impl<T: AnySocket<E>, E> BoundSocketInner<T, E> {
/// ///
/// The update is typically needed after new network or user events have been handled, so this /// The update is typically needed after new network or user events have been handled, so this
/// method also marks that there may be new events, so that the event observer provided by /// method also marks that there may be new events, so that the event observer provided by
/// [`BoundSocket::set_observer`] can be notified later. /// [`Socket::init_observer`] can be notified later.
fn update_next_poll_at_ms(&self, poll_at: PollAt) -> NeedIfacePoll { fn update_next_poll_at_ms(&self, poll_at: PollAt) -> NeedIfacePoll {
match poll_at { match poll_at {
PollAt::Now => { PollAt::Now => {
@ -522,30 +743,23 @@ impl<T: AnySocket<E>, E> BoundSocketInner<T, E> {
} }
} }
impl<T: AnySocket<E>, E> BoundSocketInner<T, E> { impl<E: Ext> TcpConnectionBg<E> {
pub(crate) fn port(&self) -> u16 { /// Returns whether the TCP connection is dead.
self.port
}
}
impl<E: Ext> BoundTcpSocketInner<E> {
/// Returns whether the TCP socket is dead.
/// ///
/// A TCP socket is considered dead if and only if the following two conditions are met: /// A TCP connection is considered dead when and only when the TCP socket is in the closed
/// 1. The TCP connection is closed, so this socket cannot process any network events. /// state, meaning it's no longer accepting packets from the network. This is different from
/// 2. The socket handle [`BoundTcpSocket`] is dropped, which means that this /// the socket file being closed, which only initiates the socket close process.
/// [`BoundSocketInner`] is in background and no more user events can reach it.
pub(crate) fn is_dead(&self) -> bool { pub(crate) fn is_dead(&self) -> bool {
self.socket.is_dead() self.inner.is_dead()
} }
} }
impl<T: AnySocket<E>, E> BoundSocketInner<T, E> { impl<T: Inner<E>, E: Ext> SocketBg<T, E> {
/// Returns whether an incoming packet _may_ be processed by the socket. /// Returns whether an incoming packet _may_ be processed by the socket.
/// ///
/// The check is intended to be lock-free and fast, but may have false positives. /// The check is intended to be lock-free and fast, but may have false positives.
pub(crate) fn can_process(&self, dst_port: u16) -> bool { pub(crate) fn can_process(&self, dst_port: u16) -> bool {
self.port == dst_port self.bound.port() == dst_port
} }
/// Returns whether the socket _may_ generate an outgoing packet. /// Returns whether the socket _may_ generate an outgoing packet.
@ -563,15 +777,15 @@ pub(crate) enum TcpProcessResult {
ProcessedWithReply(IpRepr, TcpRepr<'static>), ProcessedWithReply(IpRepr, TcpRepr<'static>),
} }
impl<E: Ext> BoundTcpSocketInner<E> { impl<E: Ext> TcpConnectionBg<E> {
/// Tries to process an incoming packet and returns whether the packet is processed. /// Tries to process an incoming packet and returns whether the packet is processed.
pub(crate) fn process( pub(crate) fn process(
&self, this: &KeyableArc<Self>,
cx: &mut Context, cx: &mut Context,
ip_repr: &IpRepr, ip_repr: &IpRepr,
tcp_repr: &TcpRepr, tcp_repr: &TcpRepr,
) -> TcpProcessResult { ) -> TcpProcessResult {
let mut socket = self.socket.lock(); let mut socket = this.inner.lock();
if !socket.accepts(cx, ip_repr, tcp_repr) { if !socket.accepts(cx, ip_repr, tcp_repr) {
return TcpProcessResult::NotProcessed; return TcpProcessResult::NotProcessed;
@ -594,7 +808,7 @@ impl<E: Ext> BoundTcpSocketInner<E> {
&& tcp_repr.control == TcpControl::Syn && tcp_repr.control == TcpControl::Syn
&& tcp_repr.ack_number.is_none() && tcp_repr.ack_number.is_none()
{ {
self.socket.set_dead_timewait(&socket); this.inner.set_dead_timewait(&socket);
return TcpProcessResult::NotProcessed; return TcpProcessResult::NotProcessed;
} }
@ -609,26 +823,25 @@ impl<E: Ext> BoundTcpSocketInner<E> {
}; };
if socket.state() != old_state { if socket.state() != old_state {
events |= socket.on_new_state(); events |= socket.on_new_state(this);
} }
self.add_events(events); this.add_events(events);
self.update_next_poll_at_ms(socket.poll_at(cx)); this.update_next_poll_at_ms(socket.poll_at(cx));
self.socket.update_dead(&socket);
result result
} }
/// Tries to generate an outgoing packet and dispatches the generated packet. /// Tries to generate an outgoing packet and dispatches the generated packet.
pub(crate) fn dispatch<D>( pub(crate) fn dispatch<D>(
&self, this: &KeyableArc<Self>,
cx: &mut Context, cx: &mut Context,
dispatch: D, dispatch: D,
) -> Option<(IpRepr, TcpRepr<'static>)> ) -> Option<(IpRepr, TcpRepr<'static>)>
where where
D: FnOnce(&mut Context, &IpRepr, &TcpRepr) -> Option<(IpRepr, TcpRepr<'static>)>, D: FnOnce(&mut Context, &IpRepr, &TcpRepr) -> Option<(IpRepr, TcpRepr<'static>)>,
{ {
let mut socket = self.socket.lock(); let mut socket = this.inner.lock();
let old_state = socket.state(); let old_state = socket.state();
let mut events = SocketEvents::empty(); let mut events = SocketEvents::empty();
@ -652,18 +865,76 @@ impl<E: Ext> BoundTcpSocketInner<E> {
} }
if socket.state() != old_state { if socket.state() != old_state {
events |= socket.on_new_state(); events |= socket.on_new_state(this);
} }
self.add_events(events); this.add_events(events);
self.update_next_poll_at_ms(socket.poll_at(cx)); this.update_next_poll_at_ms(socket.poll_at(cx));
self.socket.update_dead(&socket);
reply reply
} }
} }
impl<E: Ext> BoundUdpSocketInner<E> { impl<E: Ext> TcpListenerBg<E> {
/// Tries to process an incoming packet and returns whether the packet is processed.
pub(crate) fn process(
this: &KeyableArc<Self>,
cx: &mut Context,
ip_repr: &IpRepr,
tcp_repr: &TcpRepr,
) -> (TcpProcessResult, Option<KeyableArc<TcpConnectionBg<E>>>) {
let mut backlog = this.inner.lock();
if !backlog.socket.accepts(cx, ip_repr, tcp_repr) {
return (TcpProcessResult::NotProcessed, None);
}
// FIXME: According to the Linux implementation, `max_conn` is the upper bound of
// `connected.len()`. We currently limit it to `connected.len() + connecting.len()` for
// simplicity.
if backlog.connected.len() + backlog.connecting.len() >= backlog.max_conn {
return (TcpProcessResult::Processed, None);
}
let result = match backlog.socket.process(cx, ip_repr, tcp_repr) {
None => TcpProcessResult::Processed,
Some((ip_repr, tcp_repr)) => TcpProcessResult::ProcessedWithReply(ip_repr, tcp_repr),
};
if backlog.socket.state() == smoltcp::socket::tcp::State::Listen {
return (result, None);
}
let new_socket = {
let mut socket = new_tcp_socket();
RawTcpOption::inherit(&backlog.socket, &mut socket);
socket.listen(backlog.socket.listen_endpoint()).unwrap();
socket
};
let inner = TcpConnectionInner::new(
core::mem::replace(&mut backlog.socket, new_socket),
Some(this.clone().into()),
);
let conn = TcpConnection::new(
this.bound
.iface()
.bind(BindPortConfig::CanReuse(this.bound.port()))
.unwrap(),
inner,
);
let conn_bg = conn.inner().clone();
let inserted = backlog.connecting.insert(conn);
assert!(inserted);
conn_bg.update_next_poll_at_ms(PollAt::Now);
(result, Some(conn_bg))
}
}
impl<E: Ext> UdpSocketBg<E> {
/// Tries to process an incoming packet and returns whether the packet is processed. /// Tries to process an incoming packet and returns whether the packet is processed.
pub(crate) fn process( pub(crate) fn process(
&self, &self,
@ -672,7 +943,7 @@ impl<E: Ext> BoundUdpSocketInner<E> {
udp_repr: &UdpRepr, udp_repr: &UdpRepr,
udp_payload: &[u8], udp_payload: &[u8],
) -> bool { ) -> bool {
let mut socket = self.socket.lock(); let mut socket = self.inner.lock();
if !socket.accepts(cx, ip_repr, udp_repr) { if !socket.accepts(cx, ip_repr, udp_repr) {
return false; return false;
@ -697,7 +968,7 @@ impl<E: Ext> BoundUdpSocketInner<E> {
where where
D: FnOnce(&mut Context, &IpRepr, &UdpRepr, &[u8]), D: FnOnce(&mut Context, &IpRepr, &UdpRepr, &[u8]),
{ {
let mut socket = self.socket.lock(); let mut socket = self.inner.lock();
socket socket
.dispatch(cx, |cx, _meta, (ip_repr, udp_repr, udp_payload)| { .dispatch(cx, |cx, _meta, (ip_repr, udp_repr, udp_payload)| {


@ -6,15 +6,12 @@ mod option;
 mod state;
 mod unbound;
-pub use bound::{BoundTcpSocket, BoundUdpSocket, ConnectState, NeedIfacePoll};
-pub(crate) use bound::{BoundTcpSocketInner, BoundUdpSocketInner, TcpProcessResult};
+pub use bound::{ConnectState, NeedIfacePoll, TcpConnection, TcpListener, UdpSocket};
+pub(crate) use bound::{TcpConnectionBg, TcpListenerBg, TcpProcessResult, UdpSocketBg};
 pub use event::{SocketEventObserver, SocketEvents};
-pub use option::RawTcpSetOption;
-pub use state::{TcpState, TcpStateCheck};
-pub use unbound::{
-    UnboundTcpSocket, UnboundUdpSocket, TCP_RECV_BUF_LEN, TCP_SEND_BUF_LEN, UDP_RECV_PAYLOAD_LEN,
-    UDP_SEND_PAYLOAD_LEN,
-};
+pub use option::{RawTcpOption, RawTcpSetOption};
+pub use state::TcpStateCheck;
+pub use unbound::{TCP_RECV_BUF_LEN, TCP_SEND_BUF_LEN, UDP_RECV_PAYLOAD_LEN, UDP_SEND_PAYLOAD_LEN};
 pub type RawTcpSocket = smoltcp::socket::tcp::Socket<'static>;
 pub type RawUdpSocket = smoltcp::socket::udp::Socket<'static>;


@ -2,20 +2,37 @@
 use smoltcp::time::Duration;
-use super::NeedIfacePoll;
+use super::{NeedIfacePoll, RawTcpSocket};
 /// A trait defines setting socket options on a raw socket.
-///
-/// TODO: When `UnboundSocket` is removed, all methods in this trait can accept
-/// `&self` instead of `&mut self` as parameter.
 pub trait RawTcpSetOption {
     /// Sets the keep alive interval.
     ///
     /// Polling the iface _may_ be required after this method succeeds.
-    fn set_keep_alive(&mut self, interval: Option<Duration>) -> NeedIfacePoll;
+    fn set_keep_alive(&self, interval: Option<Duration>) -> NeedIfacePoll;
     /// Enables or disables Nagle's Algorithm.
     ///
-    /// Polling the iface is not required after this method succeeds.
+    /// Polling the iface is _not_ required after this method succeeds.
-    fn set_nagle_enabled(&mut self, enabled: bool);
+    fn set_nagle_enabled(&self, enabled: bool);
}
/// Socket options on a raw socket.
pub struct RawTcpOption {
/// The keep alive interval.
pub keep_alive: Option<Duration>,
/// Whether Nagle's algorithm is enabled.
pub is_nagle_enabled: bool,
}
impl RawTcpOption {
pub(super) fn apply(&self, socket: &mut RawTcpSocket) {
socket.set_keep_alive(self.keep_alive);
socket.set_nagle_enabled(self.is_nagle_enabled);
}
pub(super) fn inherit(from: &RawTcpSocket, to: &mut RawTcpSocket) {
to.set_keep_alive(from.keep_alive());
to.set_nagle_enabled(from.nagle_enabled());
}
} }
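
A minimal sketch of how these helpers are meant to be used (they are `pub(super)`, so the sketch assumes it lives inside this `option` module where `RawTcpOption` and `RawTcpSocket` are in scope; buffer sizes and option values are placeholders, and the smoltcp calls are the same ones visible elsewhere in this diff):

```rust
use alloc::vec;
use smoltcp::{socket::tcp::SocketBuffer, time::Duration};

fn backlog_socket_options_sketch() {
    // A listener applies the options recorded at listen() time to its raw socket...
    let mut listen_socket = RawTcpSocket::new(
        SocketBuffer::new(vec![0u8; 65536]),
        SocketBuffer::new(vec![0u8; 65536]),
    );
    let option = RawTcpOption {
        keep_alive: Some(Duration::from_secs(75)),
        is_nagle_enabled: false,
    };
    option.apply(&mut listen_socket);

    // ...and every backlog socket created on demand inherits them from that socket.
    let mut backlog_socket = RawTcpSocket::new(
        SocketBuffer::new(vec![0u8; 65536]),
        SocketBuffer::new(vec![0u8; 65536]),
    );
    RawTcpOption::inherit(&listen_socket, &mut backlog_socket);
    assert!(!backlog_socket.nagle_enabled());
    assert_eq!(backlog_socket.keep_alive(), Some(Duration::from_secs(75)));
}
```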


@ -2,75 +2,31 @@
use alloc::{boxed::Box, vec}; use alloc::{boxed::Box, vec};
use super::{option::RawTcpSetOption, NeedIfacePoll, RawTcpSocket, RawUdpSocket}; use super::{RawTcpSocket, RawUdpSocket};
pub struct UnboundSocket<T> { pub(super) fn new_tcp_socket() -> Box<RawTcpSocket> {
socket: Box<T>, let raw_tcp_socket = {
let rx_buffer = smoltcp::socket::tcp::SocketBuffer::new(vec![0u8; TCP_RECV_BUF_LEN]);
let tx_buffer = smoltcp::socket::tcp::SocketBuffer::new(vec![0u8; TCP_SEND_BUF_LEN]);
RawTcpSocket::new(rx_buffer, tx_buffer)
};
Box::new(raw_tcp_socket)
} }
pub type UnboundTcpSocket = UnboundSocket<RawTcpSocket>; pub(super) fn new_udp_socket() -> Box<RawUdpSocket> {
pub type UnboundUdpSocket = UnboundSocket<RawUdpSocket>; let raw_udp_socket = {
let metadata = smoltcp::socket::udp::PacketMetadata::EMPTY;
impl UnboundTcpSocket { let rx_buffer = smoltcp::socket::udp::PacketBuffer::new(
pub fn new() -> Self { vec![metadata; UDP_METADATA_LEN],
let raw_tcp_socket = { vec![0u8; UDP_RECV_PAYLOAD_LEN],
let rx_buffer = smoltcp::socket::tcp::SocketBuffer::new(vec![0u8; TCP_RECV_BUF_LEN]); );
let tx_buffer = smoltcp::socket::tcp::SocketBuffer::new(vec![0u8; TCP_SEND_BUF_LEN]); let tx_buffer = smoltcp::socket::udp::PacketBuffer::new(
RawTcpSocket::new(rx_buffer, tx_buffer) vec![metadata; UDP_METADATA_LEN],
}; vec![0u8; UDP_SEND_PAYLOAD_LEN],
Self { );
socket: Box::new(raw_tcp_socket), RawUdpSocket::new(rx_buffer, tx_buffer)
} };
} Box::new(raw_udp_socket)
}
impl Default for UnboundTcpSocket {
fn default() -> Self {
Self::new()
}
}
impl RawTcpSetOption for UnboundTcpSocket {
fn set_keep_alive(&mut self, interval: Option<smoltcp::time::Duration>) -> NeedIfacePoll {
self.socket.set_keep_alive(interval);
NeedIfacePoll::FALSE
}
fn set_nagle_enabled(&mut self, enabled: bool) {
self.socket.set_nagle_enabled(enabled);
}
}
impl UnboundUdpSocket {
pub fn new() -> Self {
let raw_udp_socket = {
let metadata = smoltcp::socket::udp::PacketMetadata::EMPTY;
let rx_buffer = smoltcp::socket::udp::PacketBuffer::new(
vec![metadata; UDP_METADATA_LEN],
vec![0u8; UDP_RECV_PAYLOAD_LEN],
);
let tx_buffer = smoltcp::socket::udp::PacketBuffer::new(
vec![metadata; UDP_METADATA_LEN],
vec![0u8; UDP_SEND_PAYLOAD_LEN],
);
RawUdpSocket::new(rx_buffer, tx_buffer)
};
Self {
socket: Box::new(raw_udp_socket),
}
}
}
impl Default for UnboundUdpSocket {
fn default() -> Self {
Self::new()
}
}
impl<T> UnboundSocket<T> {
pub(crate) fn into_raw(self) -> Box<T> {
self.socket
}
} }
// TCP socket buffer sizes: // TCP socket buffer sizes:


@ -138,9 +138,22 @@ impl<T: ?Sized> KeyableArc<T> {
} }
/// Creates a new `KeyableWeak` pointer to this allocation. /// Creates a new `KeyableWeak` pointer to this allocation.
#[inline]
pub fn downgrade(this: &Self) -> KeyableWeak<T> { pub fn downgrade(this: &Self) -> KeyableWeak<T> {
Arc::downgrade(&this.0).into() Arc::downgrade(&this.0).into()
} }
/// Gets the number of strong pointers pointing to this allocation.
#[inline]
pub fn strong_count(this: &Self) -> usize {
Arc::strong_count(&this.0)
}
/// Gets the number of weak pointers pointing to this allocation.
#[inline]
pub fn weak_count(this: &Self) -> usize {
Arc::weak_count(&this.0)
}
} }
impl<T: ?Sized> Deref for KeyableArc<T> { impl<T: ?Sized> Deref for KeyableArc<T> {
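
A quick sketch of what the new counters report. It assumes `KeyableArc::new` and `Clone` mirror their `Arc` counterparts (an assumption; only `downgrade`, `strong_count`, and `weak_count` are shown in this diff):

```rust
use keyable_arc::KeyableArc;

fn counters_sketch() {
    // Assumption: `KeyableArc::new` wraps `Arc::new`.
    let first = KeyableArc::new(42u32);
    let second = first.clone();
    let weak = KeyableArc::downgrade(&first);

    assert_eq!(KeyableArc::strong_count(&first), 2);
    assert_eq!(KeyableArc::weak_count(&first), 1);

    drop(second);
    drop(weak);
    assert_eq!(KeyableArc::strong_count(&first), 1);
    assert_eq!(KeyableArc::weak_count(&first), 0);
}
```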


@ -9,5 +9,8 @@ pub use init::{init, IFACES};
pub use poll::lazy_init; pub use poll::lazy_init;
pub type Iface = dyn aster_bigtcp::iface::Iface<ext::BigtcpExt>; pub type Iface = dyn aster_bigtcp::iface::Iface<ext::BigtcpExt>;
pub type BoundTcpSocket = aster_bigtcp::socket::BoundTcpSocket<ext::BigtcpExt>; pub type BoundPort = aster_bigtcp::iface::BoundPort<ext::BigtcpExt>;
pub type BoundUdpSocket = aster_bigtcp::socket::BoundUdpSocket<ext::BigtcpExt>;
pub type TcpConnection = aster_bigtcp::socket::TcpConnection<ext::BigtcpExt>;
pub type TcpListener = aster_bigtcp::socket::TcpListener<ext::BigtcpExt>;
pub type UdpSocket = aster_bigtcp::socket::UdpSocket<ext::BigtcpExt>;


@ -7,7 +7,7 @@ use aster_bigtcp::{
}; };
use crate::{ use crate::{
net::iface::{Iface, IFACES}, net::iface::{BoundPort, Iface, IFACES},
prelude::*, prelude::*,
}; };
@ -45,30 +45,20 @@ fn get_ephemeral_iface(remote_ip_addr: &IpAddress) -> Arc<Iface> {
ifaces[0].clone() ifaces[0].clone()
} }
pub(super) fn bind_socket<S, T>( pub(super) fn bind_port(endpoint: &IpEndpoint, can_reuse: bool) -> Result<BoundPort> {
unbound_socket: Box<S>,
endpoint: &IpEndpoint,
can_reuse: bool,
bind: impl FnOnce(
Arc<Iface>,
Box<S>,
BindPortConfig,
) -> core::result::Result<T, (BindError, Box<S>)>,
) -> core::result::Result<T, (Error, Box<S>)> {
let iface = match get_iface_to_bind(&endpoint.addr) { let iface = match get_iface_to_bind(&endpoint.addr) {
Some(iface) => iface, Some(iface) => iface,
None => { None => {
let err = Error::with_message( return_errno_with_message!(
Errno::EADDRNOTAVAIL, Errno::EADDRNOTAVAIL,
"the address is not available from the local machine", "the address is not available from the local machine"
); );
return Err((err, unbound_socket));
} }
}; };
let bind_port_config = BindPortConfig::new(endpoint.port, can_reuse); let bind_port_config = BindPortConfig::new(endpoint.port, can_reuse);
bind(iface, unbound_socket, bind_port_config).map_err(|(err, unbound)| (err.into(), unbound)) Ok(iface.bind(bind_port_config)?)
} }
impl From<BindError> for Error { impl From<BindError> for Error {


@ -8,7 +8,7 @@ use aster_bigtcp::{
use crate::{ use crate::{
events::IoEvents, events::IoEvents,
net::{ net::{
iface::{BoundUdpSocket, Iface}, iface::{Iface, UdpSocket},
socket::util::send_recv_flags::SendRecvFlags, socket::util::send_recv_flags::SendRecvFlags,
}, },
prelude::*, prelude::*,
@ -16,12 +16,12 @@ use crate::{
}; };
pub struct BoundDatagram { pub struct BoundDatagram {
bound_socket: BoundUdpSocket, bound_socket: UdpSocket,
remote_endpoint: Option<IpEndpoint>, remote_endpoint: Option<IpEndpoint>,
} }
impl BoundDatagram { impl BoundDatagram {
pub fn new(bound_socket: BoundUdpSocket) -> Self { pub fn new(bound_socket: UdpSocket) -> Self {
Self { Self {
bound_socket, bound_socket,
remote_endpoint: None, remote_endpoint: None,


@ -1,19 +1,17 @@
// SPDX-License-Identifier: MPL-2.0 // SPDX-License-Identifier: MPL-2.0
use aster_bigtcp::{socket::UnboundUdpSocket, wire::IpEndpoint}; use aster_bigtcp::{socket::UdpSocket, wire::IpEndpoint};
use super::{bound::BoundDatagram, DatagramObserver}; use super::{bound::BoundDatagram, DatagramObserver};
use crate::{events::IoEvents, net::socket::ip::common::bind_socket, prelude::*}; use crate::{events::IoEvents, net::socket::ip::common::bind_port, prelude::*};
pub struct UnboundDatagram { pub struct UnboundDatagram {
unbound_socket: Box<UnboundUdpSocket>, _private: (),
} }
impl UnboundDatagram { impl UnboundDatagram {
pub fn new() -> Self { pub fn new() -> Self {
Self { Self { _private: () }
unbound_socket: Box::new(UnboundUdpSocket::new()),
}
} }
pub fn bind( pub fn bind(
@ -22,18 +20,17 @@ impl UnboundDatagram {
can_reuse: bool, can_reuse: bool,
observer: DatagramObserver, observer: DatagramObserver,
) -> core::result::Result<BoundDatagram, (Error, Self)> { ) -> core::result::Result<BoundDatagram, (Error, Self)> {
let bound_socket = match bind_socket( let bound_port = match bind_port(endpoint, can_reuse) {
self.unbound_socket, Ok(bound_port) => bound_port,
endpoint, Err(err) => return Err((err, self)),
can_reuse,
|iface, socket, config| iface.bind_udp(socket, observer, config),
) {
Ok(bound_socket) => bound_socket,
Err((err, unbound_socket)) => return Err((err, Self { unbound_socket })),
}; };
let bound_endpoint = bound_socket.local_endpoint().unwrap(); let bound_socket = match UdpSocket::new_bind(bound_port, observer) {
bound_socket.bind(bound_endpoint).unwrap(); Ok(bound_socket) => bound_socket,
Err((_, err)) => {
unreachable!("`new_bind` fails with {:?}, which should not happen", err)
}
};
Ok(BoundDatagram::new(bound_socket)) Ok(BoundDatagram::new(bound_socket))
} }


@ -12,7 +12,7 @@ use super::StreamObserver;
use crate::{ use crate::{
events::IoEvents, events::IoEvents,
net::{ net::{
iface::{BoundTcpSocket, Iface}, iface::{Iface, TcpConnection},
socket::util::{send_recv_flags::SendRecvFlags, shutdown_cmd::SockShutdownCmd}, socket::util::{send_recv_flags::SendRecvFlags, shutdown_cmd::SockShutdownCmd},
}, },
prelude::*, prelude::*,
@ -21,7 +21,7 @@ use crate::{
}; };
pub struct ConnectedStream { pub struct ConnectedStream {
bound_socket: BoundTcpSocket, tcp_conn: TcpConnection,
remote_endpoint: IpEndpoint, remote_endpoint: IpEndpoint,
/// Indicates whether this connection is "new" in a `connect()` system call. /// Indicates whether this connection is "new" in a `connect()` system call.
/// ///
@ -47,12 +47,12 @@ pub struct ConnectedStream {
impl ConnectedStream { impl ConnectedStream {
pub fn new( pub fn new(
bound_socket: BoundTcpSocket, tcp_conn: TcpConnection,
remote_endpoint: IpEndpoint, remote_endpoint: IpEndpoint,
is_new_connection: bool, is_new_connection: bool,
) -> Self { ) -> Self {
Self { Self {
bound_socket, tcp_conn,
remote_endpoint, remote_endpoint,
is_new_connection, is_new_connection,
is_receiving_closed: AtomicBool::new(false), is_receiving_closed: AtomicBool::new(false),
@ -70,7 +70,7 @@ impl ConnectedStream {
if cmd.shut_write() { if cmd.shut_write() {
self.is_sending_closed.store(true, Ordering::Relaxed); self.is_sending_closed.store(true, Ordering::Relaxed);
self.bound_socket.close(); self.tcp_conn.close();
events |= IoEvents::OUT | IoEvents::HUP; events |= IoEvents::OUT | IoEvents::HUP;
} }
@ -84,7 +84,7 @@ impl ConnectedStream {
writer: &mut dyn MultiWrite, writer: &mut dyn MultiWrite,
_flags: SendRecvFlags, _flags: SendRecvFlags,
) -> Result<(usize, NeedIfacePoll)> { ) -> Result<(usize, NeedIfacePoll)> {
let result = self.bound_socket.recv(|socket_buffer| { let result = self.tcp_conn.recv(|socket_buffer| {
match writer.write(&mut VmReader::from(&*socket_buffer)) { match writer.write(&mut VmReader::from(&*socket_buffer)) {
Ok(len) => (len, Ok(len)), Ok(len) => (len, Ok(len)),
Err(e) => (0, Err(e)), Err(e) => (0, Err(e)),
@ -116,7 +116,7 @@ impl ConnectedStream {
reader: &mut dyn MultiRead, reader: &mut dyn MultiRead,
_flags: SendRecvFlags, _flags: SendRecvFlags,
) -> Result<(usize, NeedIfacePoll)> { ) -> Result<(usize, NeedIfacePoll)> {
let result = self.bound_socket.send(|socket_buffer| { let result = self.tcp_conn.send(|socket_buffer| {
match reader.read(&mut VmWriter::from(socket_buffer)) { match reader.read(&mut VmWriter::from(socket_buffer)) {
Ok(len) => (len, Ok(len)), Ok(len) => (len, Ok(len)),
Err(e) => (0, Err(e)), Err(e) => (0, Err(e)),
@ -143,7 +143,7 @@ impl ConnectedStream {
} }
pub fn local_endpoint(&self) -> IpEndpoint { pub fn local_endpoint(&self) -> IpEndpoint {
self.bound_socket.local_endpoint().unwrap() self.tcp_conn.local_endpoint().unwrap()
} }
pub fn remote_endpoint(&self) -> IpEndpoint { pub fn remote_endpoint(&self) -> IpEndpoint {
@ -151,7 +151,7 @@ impl ConnectedStream {
} }
pub fn iface(&self) -> &Arc<Iface> { pub fn iface(&self) -> &Arc<Iface> {
self.bound_socket.iface() self.tcp_conn.iface()
} }
pub fn check_new(&mut self) -> Result<()> { pub fn check_new(&mut self) -> Result<()> {
@ -163,8 +163,12 @@ impl ConnectedStream {
Ok(()) Ok(())
} }
pub(super) fn init_observer(&self, observer: StreamObserver) {
self.tcp_conn.init_observer(observer);
}
pub(super) fn check_io_events(&self) -> IoEvents { pub(super) fn check_io_events(&self) -> IoEvents {
self.bound_socket.raw_with(|socket| { self.tcp_conn.raw_with(|socket| {
if socket.is_peer_closed() { if socket.is_peer_closed() {
// Only the sending side of peer socket is closed // Only the sending side of peer socket is closed
self.is_receiving_closed.store(true, Ordering::Relaxed); self.is_receiving_closed.store(true, Ordering::Relaxed);
@ -202,18 +206,14 @@ impl ConnectedStream {
}) })
} }
pub(super) fn set_observer(&self, observer: StreamObserver) {
self.bound_socket.set_observer(observer)
}
pub(super) fn set_raw_option<R>( pub(super) fn set_raw_option<R>(
&mut self, &self,
set_option: impl Fn(&mut dyn RawTcpSetOption) -> R, set_option: impl FnOnce(&dyn RawTcpSetOption) -> R,
) -> R { ) -> R {
set_option(&mut self.bound_socket) set_option(&self.tcp_conn)
} }
pub(super) fn raw_with<R>(&self, f: impl FnOnce(&RawTcpSocket) -> R) -> R { pub(super) fn raw_with<R>(&self, f: impl FnOnce(&RawTcpSocket) -> R) -> R {
self.bound_socket.raw_with(f) self.tcp_conn.raw_with(f)
} }
} }


@ -1,19 +1,19 @@
// SPDX-License-Identifier: MPL-2.0 // SPDX-License-Identifier: MPL-2.0
use aster_bigtcp::{ use aster_bigtcp::{
socket::{ConnectState, RawTcpSetOption}, socket::{ConnectState, RawTcpOption, RawTcpSetOption},
wire::IpEndpoint, wire::IpEndpoint,
}; };
use super::{connected::ConnectedStream, init::InitStream}; use super::{connected::ConnectedStream, init::InitStream, StreamObserver};
use crate::{ use crate::{
events::IoEvents, events::IoEvents,
net::iface::{BoundTcpSocket, Iface}, net::iface::{BoundPort, Iface, TcpConnection},
prelude::*, prelude::*,
}; };
pub struct ConnectingStream { pub struct ConnectingStream {
bound_socket: BoundTcpSocket, tcp_conn: TcpConnection,
remote_endpoint: IpEndpoint, remote_endpoint: IpEndpoint,
} }
@ -25,32 +25,38 @@ pub enum ConnResult {
impl ConnectingStream { impl ConnectingStream {
pub fn new( pub fn new(
bound_socket: BoundTcpSocket, bound_port: BoundPort,
remote_endpoint: IpEndpoint, remote_endpoint: IpEndpoint,
) -> core::result::Result<Self, (Error, BoundTcpSocket)> { option: &RawTcpOption,
observer: StreamObserver,
) -> core::result::Result<Self, (Error, BoundPort)> {
// The only reason this method might fail is because we're trying to connect to an // The only reason this method might fail is because we're trying to connect to an
// unspecified address (i.e. 0.0.0.0). We currently have no support for binding to, // unspecified address (i.e. 0.0.0.0). We currently have no support for binding to,
// listening on, or connecting to the unspecified address. // listening on, or connecting to the unspecified address.
// //
// We assume the remote will just refuse to connect, so we return `ECONNREFUSED`. // We assume the remote will just refuse to connect, so we return `ECONNREFUSED`.
if bound_socket.connect(remote_endpoint).is_err() { let tcp_conn =
return Err(( match TcpConnection::new_connect(bound_port, remote_endpoint, option, observer) {
Error::with_message( Ok(tcp_conn) => tcp_conn,
Errno::ECONNREFUSED, Err((bound_port, _)) => {
"connecting to an unspecified address is not supported", return Err((
), Error::with_message(
bound_socket, Errno::ECONNREFUSED,
)); "connecting to an unspecified address is not supported",
} ),
bound_port,
))
}
};
Ok(Self { Ok(Self {
bound_socket, tcp_conn,
remote_endpoint, remote_endpoint,
}) })
} }
pub fn has_result(&self) -> bool { pub fn has_result(&self) -> bool {
match self.bound_socket.connect_state() { match self.tcp_conn.connect_state() {
ConnectState::Connecting => false, ConnectState::Connecting => false,
ConnectState::Connected => true, ConnectState::Connected => true,
ConnectState::Refused => true, ConnectState::Refused => true,
@ -58,21 +64,23 @@ impl ConnectingStream {
} }
pub fn into_result(self) -> ConnResult { pub fn into_result(self) -> ConnResult {
let next_state = self.bound_socket.connect_state(); let next_state = self.tcp_conn.connect_state();
match next_state { match next_state {
ConnectState::Connecting => ConnResult::Connecting(self), ConnectState::Connecting => ConnResult::Connecting(self),
ConnectState::Connected => ConnResult::Connected(ConnectedStream::new( ConnectState::Connected => ConnResult::Connected(ConnectedStream::new(
self.bound_socket, self.tcp_conn,
self.remote_endpoint, self.remote_endpoint,
true, true,
)), )),
ConnectState::Refused => ConnResult::Refused(InitStream::new_bound(self.bound_socket)), ConnectState::Refused => ConnResult::Refused(InitStream::new_bound(
self.tcp_conn.into_bound_port().unwrap(),
)),
} }
} }
pub fn local_endpoint(&self) -> IpEndpoint { pub fn local_endpoint(&self) -> IpEndpoint {
self.bound_socket.local_endpoint().unwrap() self.tcp_conn.local_endpoint().unwrap()
} }
pub fn remote_endpoint(&self) -> IpEndpoint { pub fn remote_endpoint(&self) -> IpEndpoint {
@ -80,7 +88,7 @@ impl ConnectingStream {
} }
pub fn iface(&self) -> &Arc<Iface> { pub fn iface(&self) -> &Arc<Iface> {
self.bound_socket.iface() self.tcp_conn.iface()
} }
pub(super) fn check_io_events(&self) -> IoEvents { pub(super) fn check_io_events(&self) -> IoEvents {
@ -88,9 +96,9 @@ impl ConnectingStream {
} }
pub(super) fn set_raw_option<R>( pub(super) fn set_raw_option<R>(
&mut self, &self,
set_option: impl Fn(&mut dyn RawTcpSetOption) -> R, set_option: impl FnOnce(&dyn RawTcpSetOption) -> R,
) -> R { ) -> R {
set_option(&mut self.bound_socket) set_option(&self.tcp_conn)
} }
} }


@ -1,43 +1,38 @@
// SPDX-License-Identifier: MPL-2.0 // SPDX-License-Identifier: MPL-2.0
use aster_bigtcp::{ use aster_bigtcp::{socket::RawTcpOption, wire::IpEndpoint};
socket::{RawTcpSetOption, UnboundTcpSocket},
wire::IpEndpoint,
};
use super::{connecting::ConnectingStream, listen::ListenStream, StreamObserver}; use super::{connecting::ConnectingStream, listen::ListenStream, StreamObserver};
use crate::{ use crate::{
events::IoEvents, events::IoEvents,
net::{ net::{
iface::BoundTcpSocket, iface::BoundPort,
socket::ip::common::{bind_socket, get_ephemeral_endpoint}, socket::ip::common::{bind_port, get_ephemeral_endpoint},
}, },
prelude::*, prelude::*,
process::signal::Pollee,
}; };
pub enum InitStream { pub enum InitStream {
Unbound(Box<UnboundTcpSocket>), Unbound,
Bound(BoundTcpSocket), Bound(BoundPort),
} }
impl InitStream { impl InitStream {
pub fn new() -> Self { pub fn new() -> Self {
InitStream::Unbound(Box::new(UnboundTcpSocket::new())) InitStream::Unbound
} }
pub fn new_bound(bound_socket: BoundTcpSocket) -> Self { pub fn new_bound(bound_port: BoundPort) -> Self {
InitStream::Bound(bound_socket) InitStream::Bound(bound_port)
} }
pub fn bind( pub fn bind(
self, self,
endpoint: &IpEndpoint, endpoint: &IpEndpoint,
can_reuse: bool, can_reuse: bool,
observer: StreamObserver, ) -> core::result::Result<BoundPort, (Error, Self)> {
) -> core::result::Result<BoundTcpSocket, (Error, Self)> { match self {
let unbound_socket = match self { InitStream::Unbound => (),
InitStream::Unbound(unbound_socket) => unbound_socket,
InitStream::Bound(bound_socket) => { InitStream::Bound(bound_socket) => {
return Err(( return Err((
Error::with_message(Errno::EINVAL, "the socket is already bound to an address"), Error::with_message(Errno::EINVAL, "the socket is already bound to an address"),
@ -45,48 +40,45 @@ impl InitStream {
)); ));
} }
}; };
let bound_socket = match bind_socket(
unbound_socket, let bound_port = match bind_port(endpoint, can_reuse) {
endpoint, Ok(bound_port) => bound_port,
can_reuse, Err(err) => return Err((err, Self::Unbound)),
|iface, socket, config| iface.bind_tcp(socket, observer, config),
) {
Ok(bound_socket) => bound_socket,
Err((err, unbound_socket)) => return Err((err, InitStream::Unbound(unbound_socket))),
}; };
Ok(bound_socket)
Ok(bound_port)
} }
fn bind_to_ephemeral_endpoint( fn bind_to_ephemeral_endpoint(
self, self,
remote_endpoint: &IpEndpoint, remote_endpoint: &IpEndpoint,
observer: StreamObserver, ) -> core::result::Result<BoundPort, (Error, Self)> {
) -> core::result::Result<BoundTcpSocket, (Error, Self)> {
let endpoint = get_ephemeral_endpoint(remote_endpoint); let endpoint = get_ephemeral_endpoint(remote_endpoint);
self.bind(&endpoint, false, observer) self.bind(&endpoint, false)
} }
pub fn connect( pub fn connect(
self, self,
remote_endpoint: &IpEndpoint, remote_endpoint: &IpEndpoint,
pollee: &Pollee, option: &RawTcpOption,
observer: StreamObserver,
) -> core::result::Result<ConnectingStream, (Error, Self)> { ) -> core::result::Result<ConnectingStream, (Error, Self)> {
let bound_socket = match self { let bound_port = match self {
InitStream::Bound(bound_socket) => bound_socket, InitStream::Bound(bound_port) => bound_port,
InitStream::Unbound(_) => self InitStream::Unbound => self.bind_to_ephemeral_endpoint(remote_endpoint)?,
.bind_to_ephemeral_endpoint(remote_endpoint, StreamObserver::new(pollee.clone()))?,
}; };
ConnectingStream::new(bound_socket, *remote_endpoint) ConnectingStream::new(bound_port, *remote_endpoint, option, observer)
.map_err(|(err, bound_socket)| (err, InitStream::Bound(bound_socket))) .map_err(|(err, bound_port)| (err, InitStream::Bound(bound_port)))
} }
pub fn listen( pub fn listen(
self, self,
backlog: usize, backlog: usize,
pollee: &Pollee, option: &RawTcpOption,
observer: StreamObserver,
) -> core::result::Result<ListenStream, (Error, Self)> { ) -> core::result::Result<ListenStream, (Error, Self)> {
let InitStream::Bound(bound_socket) = self else { let InitStream::Bound(bound_port) = self else {
// FIXME: The socket should be bound to INADDR_ANY (i.e., 0.0.0.0) with an ephemeral // FIXME: The socket should be bound to INADDR_ANY (i.e., 0.0.0.0) with an ephemeral
// port. However, INADDR_ANY is not yet supported, so we need to return an error first. // port. However, INADDR_ANY is not yet supported, so we need to return an error first.
debug_assert!(false, "listen() without bind() is not implemented"); debug_assert!(false, "listen() without bind() is not implemented");
@ -96,14 +88,13 @@ impl InitStream {
)); ));
}; };
ListenStream::new(bound_socket, backlog, pollee) Ok(ListenStream::new(bound_port, backlog, option, observer))
.map_err(|(err, bound_socket)| (err, InitStream::Bound(bound_socket)))
} }
pub fn local_endpoint(&self) -> Option<IpEndpoint> { pub fn local_endpoint(&self) -> Option<IpEndpoint> {
match self { match self {
InitStream::Unbound(_) => None, InitStream::Unbound => None,
InitStream::Bound(bound_socket) => Some(bound_socket.local_endpoint().unwrap()), InitStream::Bound(bound_port) => Some(bound_port.endpoint().unwrap()),
} }
} }
@ -111,14 +102,4 @@ impl InitStream {
// Linux adds OUT and HUP events for a newly created socket // Linux adds OUT and HUP events for a newly created socket
IoEvents::OUT | IoEvents::HUP IoEvents::OUT | IoEvents::HUP
} }
pub(super) fn set_raw_option<R>(
&mut self,
set_option: impl Fn(&mut dyn RawTcpSetOption) -> R,
) -> R {
match self {
InitStream::Unbound(unbound_socket) => set_option(unbound_socket.as_mut()),
InitStream::Bound(bound_socket) => set_option(bound_socket),
}
}
} }


@ -1,103 +1,59 @@
// SPDX-License-Identifier: MPL-2.0 // SPDX-License-Identifier: MPL-2.0
use aster_bigtcp::{ use aster_bigtcp::{
errors::tcp::ListenError, socket::{RawTcpOption, RawTcpSetOption},
iface::BindPortConfig,
socket::{RawTcpSetOption, TcpState, UnboundTcpSocket},
wire::IpEndpoint, wire::IpEndpoint,
}; };
use ostd::sync::PreemptDisabled;
use super::{connected::ConnectedStream, StreamObserver}; use super::{connected::ConnectedStream, StreamObserver};
use crate::{ use crate::{
events::IoEvents, events::IoEvents,
net::iface::{BoundTcpSocket, Iface}, net::iface::{BoundPort, Iface, TcpListener},
prelude::*, prelude::*,
process::signal::Pollee,
}; };
pub struct ListenStream { pub struct ListenStream {
backlog: usize, tcp_listener: TcpListener,
/// A bound socket held to ensure the TCP port cannot be released
bound_socket: BoundTcpSocket,
/// Backlog sockets listening at the local endpoint
backlog_sockets: RwLock<Vec<BacklogSocket>, PreemptDisabled>,
} }
impl ListenStream { impl ListenStream {
pub fn new( pub fn new(
bound_socket: BoundTcpSocket, bound_port: BoundPort,
backlog: usize, backlog: usize,
pollee: &Pollee, option: &RawTcpOption,
) -> core::result::Result<Self, (Error, BoundTcpSocket)> { observer: StreamObserver,
) -> Self {
const SOMAXCONN: usize = 4096; const SOMAXCONN: usize = 4096;
let somaxconn = SOMAXCONN.min(backlog); let max_conn = SOMAXCONN.min(backlog);
let listen_stream = Self { let tcp_listener = match TcpListener::new_listen(bound_port, max_conn, option, observer) {
backlog: somaxconn, Ok(tcp_listener) => tcp_listener,
bound_socket, Err((_, err)) => {
backlog_sockets: RwLock::new(Vec::new()), unreachable!("`new_listen` fails with {:?}, which should not happen", err)
}
}; };
if let Err(err) = listen_stream.fill_backlog_sockets(pollee) {
return Err((err, listen_stream.bound_socket)); Self { tcp_listener }
}
Ok(listen_stream)
} }
/// Append sockets listening at LocalEndPoint to support backlog pub fn try_accept(&self) -> Result<ConnectedStream> {
fn fill_backlog_sockets(&self, pollee: &Pollee) -> Result<()> { let (new_conn, remote_endpoint) = self.tcp_listener.accept().ok_or_else(|| {
let mut backlog_sockets = self.backlog_sockets.write(); Error::with_message(Errno::EAGAIN, "no pending connection is available")
})?;
let backlog = self.backlog; Ok(ConnectedStream::new(new_conn, remote_endpoint, false))
let current_backlog_len = backlog_sockets.len();
debug_assert!(backlog >= current_backlog_len);
if backlog == current_backlog_len {
return Ok(());
}
for _ in current_backlog_len..backlog {
let backlog_socket = BacklogSocket::new(&self.bound_socket, pollee)?;
backlog_sockets.push(backlog_socket);
}
Ok(())
}
pub fn try_accept(&self, pollee: &Pollee) -> Result<ConnectedStream> {
let mut backlog_sockets = self.backlog_sockets.write();
let index = backlog_sockets
.iter()
.position(|backlog_socket| backlog_socket.can_accept())
.ok_or_else(|| {
Error::with_message(Errno::EAGAIN, "no pending connection is available")
})?;
let active_backlog_socket = backlog_sockets.remove(index);
if let Ok(backlog_socket) = BacklogSocket::new(&self.bound_socket, pollee) {
backlog_sockets.push(backlog_socket);
}
let remote_endpoint = active_backlog_socket.remote_endpoint().unwrap();
Ok(ConnectedStream::new(
active_backlog_socket.into_bound_socket(),
remote_endpoint,
false,
))
} }
pub fn local_endpoint(&self) -> IpEndpoint { pub fn local_endpoint(&self) -> IpEndpoint {
self.bound_socket.local_endpoint().unwrap() self.tcp_listener.local_endpoint().unwrap()
} }
pub fn iface(&self) -> &Arc<Iface> { pub fn iface(&self) -> &Arc<Iface> {
self.bound_socket.iface() self.tcp_listener.iface()
} }
pub(super) fn check_io_events(&self) -> IoEvents { pub(super) fn check_io_events(&self) -> IoEvents {
let backlog_sockets = self.backlog_sockets.read(); let can_accept = self.tcp_listener.can_accept();
let can_accept = backlog_sockets.iter().any(|socket| socket.can_accept());
// If network packets come in simultaneously, the socket state may change in the middle. // If network packets come in simultaneously, the socket state may change in the middle.
// However, the current pollee implementation should be able to handle this race condition. // However, the current pollee implementation should be able to handle this race condition.
@ -108,97 +64,10 @@ impl ListenStream {
} }
} }
/// Calls `f` to set socket option on raw socket.
///
/// This method will call `f` on the bound socket and each backlog socket that is in `Listen` state.
pub(super) fn set_raw_option<R>( pub(super) fn set_raw_option<R>(
&mut self, &self,
set_option: impl Fn(&mut dyn RawTcpSetOption) -> R, set_option: impl FnOnce(&dyn RawTcpSetOption) -> R,
) -> R { ) -> R {
self.backlog_sockets.write().iter_mut().for_each(|socket| { set_option(&self.tcp_listener)
if socket
.bound_socket
.raw_with(|raw_tcp_socket| raw_tcp_socket.state() != TcpState::Listen)
{
return;
}
// If the socket receives SYN after above check,
// we will also set keep alive on the socket that is not in `Listen` state.
// But such a race doesn't matter, we just let it happen.
set_option(&mut socket.bound_socket);
});
set_option(&mut self.bound_socket)
}
}
struct BacklogSocket {
bound_socket: BoundTcpSocket,
}
impl BacklogSocket {
// FIXME: All of the error codes below seem to have no Linux equivalents, and I see no reason
// why the error may occur. Perhaps it is better to call `unwrap()` directly?
fn new(bound_socket: &BoundTcpSocket, pollee: &Pollee) -> Result<Self> {
let local_endpoint = bound_socket.local_endpoint().ok_or(Error::with_message(
Errno::EINVAL,
"the socket is not bound",
))?;
let unbound_socket = {
let mut unbound = UnboundTcpSocket::new();
unbound.set_keep_alive(bound_socket.raw_with(|socket| socket.keep_alive()));
unbound.set_nagle_enabled(bound_socket.raw_with(|socket| socket.nagle_enabled()));
// TODO: Inherit other options that can be set via `setsockopt` from bound socket
Box::new(unbound)
};
let bound_socket = {
let iface = bound_socket.iface();
let bind_port_config = BindPortConfig::new(local_endpoint.port, true);
iface
.bind_tcp(
unbound_socket,
StreamObserver::new(pollee.clone()),
bind_port_config,
)
.map_err(|(err, _)| err)?
};
match bound_socket.listen(local_endpoint) {
Ok(()) => Ok(Self { bound_socket }),
Err(ListenError::Unaddressable) => {
return_errno_with_message!(Errno::EINVAL, "the listening address is invalid")
}
Err(ListenError::InvalidState) => {
return_errno_with_message!(Errno::EINVAL, "the listening socket is invalid")
}
}
}
/// Returns whether the backlog socket can be `accept`ed.
///
/// According to the Linux implementation, assuming the TCP Fast Open mechanism is off, a
/// backlog socket becomes ready to be returned in the `accept` system call when the 3-way
/// handshake is complete (i.e., when it enters the ESTABLISHED state).
///
/// The Linux kernel implementation can be found at
/// <https://elixir.bootlin.com/linux/v6.11.8/source/net/ipv4/tcp_input.c#L7304>.
//
// FIXME: Some sockets may be dead (e.g., RSTed), and such sockets can never become alive
// again. We need to remove them from the backlog sockets.
fn can_accept(&self) -> bool {
self.bound_socket.raw_with(|socket| socket.may_send())
}
fn remote_endpoint(&self) -> Option<IpEndpoint> {
self.bound_socket
.raw_with(|socket| socket.remote_endpoint())
}
fn into_bound_socket(self) -> BoundTcpSocket {
self.bound_socket
} }
} }
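
With backlog sockets now created on demand inside `aster-bigtcp`, the stream layer shrinks to the two calls above. A rough sketch of the new flow, using the type and function names from this diff (the wrapper function name and the backlog of 128 are placeholders; `bound_port`, `raw_option`, and `pollee` are assumed to come from the surrounding socket state, and `Result` is the kernel prelude's):

```rust
fn listen_and_accept_once_sketch(
    bound_port: BoundPort,
    raw_option: &RawTcpOption,
    pollee: Pollee,
) -> Result<ConnectedStream> {
    // No backlog sockets are pre-allocated here anymore; `TcpListener::new_listen`
    // lets the interface create them on demand as SYNs arrive.
    let listener = ListenStream::new(bound_port, 128, raw_option, StreamObserver::new(pollee));

    // Returns EAGAIN until some backlog connection completes its handshake;
    // callers wait for IN events on the pollee and retry.
    listener.try_accept()
}
```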


@ -3,14 +3,14 @@
use core::sync::atomic::{AtomicBool, Ordering}; use core::sync::atomic::{AtomicBool, Ordering};
use aster_bigtcp::{ use aster_bigtcp::{
socket::{NeedIfacePoll, RawTcpSetOption}, socket::{NeedIfacePoll, RawTcpOption, RawTcpSetOption},
wire::IpEndpoint, wire::IpEndpoint,
}; };
use connected::ConnectedStream; use connected::ConnectedStream;
use connecting::{ConnResult, ConnectingStream}; use connecting::{ConnResult, ConnectingStream};
use init::InitStream; use init::InitStream;
use listen::ListenStream; use listen::ListenStream;
use options::{Congestion, MaxSegment, NoDelay, WindowClamp}; use options::{Congestion, MaxSegment, NoDelay, WindowClamp, KEEPALIVE_INTERVAL};
use ostd::sync::{PreemptDisabled, RwLockReadGuard, RwLockWriteGuard}; use ostd::sync::{PreemptDisabled, RwLockReadGuard, RwLockWriteGuard};
use takeable::Takeable; use takeable::Takeable;
use util::TcpOptionSet; use util::TcpOptionSet;
@ -83,6 +83,13 @@ impl OptionSet {
let tcp = TcpOptionSet::new(); let tcp = TcpOptionSet::new();
OptionSet { socket, tcp } OptionSet { socket, tcp }
} }
fn raw(&self) -> RawTcpOption {
RawTcpOption {
keep_alive: self.socket.keep_alive().then_some(KEEPALIVE_INTERVAL),
is_nagle_enabled: !self.tcp.no_delay(),
}
}
} }
impl StreamSocket { impl StreamSocket {
@ -114,7 +121,7 @@ impl StreamSocket {
}); });
let pollee = Pollee::new(); let pollee = Pollee::new();
connected_stream.set_observer(StreamObserver::new(pollee.clone())); connected_stream.init_observer(StreamObserver::new(pollee.clone()));
Arc::new(Self { Arc::new(Self {
options: RwLock::new(options), options: RwLock::new(options),
@ -207,7 +214,9 @@ impl StreamSocket {
// `Some(_)` if blocking is not necessary or not allowed. // `Some(_)` if blocking is not necessary or not allowed.
fn start_connect(&self, remote_endpoint: &IpEndpoint) -> Option<Result<()>> { fn start_connect(&self, remote_endpoint: &IpEndpoint) -> Option<Result<()>> {
let is_nonblocking = self.is_nonblocking(); let is_nonblocking = self.is_nonblocking();
let mut state = self.write_updated_state(); let (options, mut state) = self.update_connecting();
let raw_option = options.raw();
let (result_or_block, iface_to_poll) = state.borrow_result(|mut owned_state| { let (result_or_block, iface_to_poll) = state.borrow_result(|mut owned_state| {
let init_stream = match owned_state { let init_stream = match owned_state {
@ -243,7 +252,11 @@ impl StreamSocket {
} }
}; };
let connecting_stream = match init_stream.connect(remote_endpoint, &self.pollee) { let connecting_stream = match init_stream.connect(
remote_endpoint,
&raw_option,
StreamObserver::new(self.pollee.clone()),
) {
Ok(connecting_stream) => connecting_stream, Ok(connecting_stream) => connecting_stream,
Err((err, init_stream)) => { Err((err, init_stream)) => {
return (State::Init(init_stream), (Some(Err(err)), None)); return (State::Init(init_stream), (Some(Err(err)), None));
@ -298,13 +311,11 @@ impl StreamSocket {
return_errno_with_message!(Errno::EINVAL, "the socket is not listening"); return_errno_with_message!(Errno::EINVAL, "the socket is not listening");
}; };
let accepted = listen_stream let accepted = listen_stream.try_accept().map(|connected_stream| {
.try_accept(&self.pollee) let remote_endpoint = connected_stream.remote_endpoint();
.map(|connected_stream| { let accepted_socket = Self::new_accepted(connected_stream);
let remote_endpoint = connected_stream.remote_endpoint(); (accepted_socket as _, remote_endpoint.into())
let accepted_socket = Self::new_accepted(connected_stream); });
(accepted_socket as _, remote_endpoint.into())
});
let iface_to_poll = listen_stream.iface().clone(); let iface_to_poll = listen_stream.iface().clone();
drop(state); drop(state);
@ -475,18 +486,14 @@ impl Socket for StreamSocket {
); );
}; };
let bound_socket = match init_stream.bind( let bound_port = match init_stream.bind(&endpoint, can_reuse) {
&endpoint, Ok(bound_port) => bound_port,
can_reuse,
StreamObserver::new(self.pollee.clone()),
) {
Ok(bound_socket) => bound_socket,
Err((err, init_stream)) => { Err((err, init_stream)) => {
return (State::Init(init_stream), Err(err)); return (State::Init(init_stream), Err(err));
} }
}; };
(State::Init(InitStream::new_bound(bound_socket)), Ok(())) (State::Init(InitStream::new_bound(bound_port)), Ok(()))
}) })
} }
@ -501,7 +508,9 @@ impl Socket for StreamSocket {
} }
fn listen(&self, backlog: usize) -> Result<()> { fn listen(&self, backlog: usize) -> Result<()> {
let mut state = self.write_updated_state(); let (options, mut state) = self.update_connecting();
let raw_option = options.raw();
state.borrow_result(|owned_state| { state.borrow_result(|owned_state| {
let init_stream = match owned_state { let init_stream = match owned_state {
@ -520,7 +529,11 @@ impl Socket for StreamSocket {
} }
}; };
let listen_stream = match init_stream.listen(backlog, &self.pollee) { let listen_stream = match init_stream.listen(
backlog,
&raw_option,
StreamObserver::new(self.pollee.clone()),
) {
Ok(listen_stream) => listen_stream, Ok(listen_stream) => listen_stream,
Err((err, init_stream)) => { Err((err, init_stream)) => {
return (State::Init(init_stream), Err(err)); return (State::Init(init_stream), Err(err));
@ -701,7 +714,7 @@ impl Socket for StreamSocket {
tcp_no_delay: NoDelay => { tcp_no_delay: NoDelay => {
let no_delay = tcp_no_delay.get().unwrap(); let no_delay = tcp_no_delay.get().unwrap();
options.tcp.set_no_delay(*no_delay); options.tcp.set_no_delay(*no_delay);
state.set_raw_option(|raw_socket: &mut dyn RawTcpSetOption| raw_socket.set_nagle_enabled(!no_delay)); state.set_raw_option(|raw_socket: &dyn RawTcpSetOption| raw_socket.set_nagle_enabled(!no_delay));
}, },
tcp_congestion: Congestion => { tcp_congestion: Congestion => {
let congestion = tcp_congestion.get().unwrap(); let congestion = tcp_congestion.get().unwrap();
@ -736,14 +749,16 @@ impl Socket for StreamSocket {
impl State { impl State {
/// Calls `f` to set raw socket option. /// Calls `f` to set raw socket option.
/// ///
/// Note that for listening socket, `f` is called on all backlog sockets in `Listen` State. /// For listening sockets, socket options are inherited by new connections. However, they are
/// That is to say, `f` won't be called on backlog sockets in `SynReceived` or `Established` state. /// not updated for connections in the backlog queue.
fn set_raw_option<R>(&mut self, set_option: impl Fn(&mut dyn RawTcpSetOption) -> R) -> R { fn set_raw_option<R>(&self, set_option: impl FnOnce(&dyn RawTcpSetOption) -> R) -> Option<R> {
match self { match self {
State::Init(init_stream) => init_stream.set_raw_option(set_option), State::Init(_) => None,
State::Connecting(connecting_stream) => connecting_stream.set_raw_option(set_option), State::Connecting(connecting_stream) => {
State::Connected(connected_stream) => connected_stream.set_raw_option(set_option), Some(connecting_stream.set_raw_option(set_option))
State::Listen(listen_stream) => listen_stream.set_raw_option(set_option), }
State::Connected(connected_stream) => Some(connected_stream.set_raw_option(set_option)),
State::Listen(listen_stream) => Some(listen_stream.set_raw_option(set_option)),
} }
} }
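
Concretely, that inheritance rule means an option change on a listening socket only affects connections established afterwards. A fragment mirroring the `TCP_NODELAY` path above (`options` and `state` are the locked option set and socket state, as in the `setsockopt` handling shown earlier in this diff):

```rust
// Record TCP_NODELAY, then push it down to the live raw socket (if any).
// On a listener this updates what *future* connections will inherit;
// connections already queued in the backlog keep their original options.
let no_delay = true;
options.tcp.set_no_delay(no_delay);
let _ = state.set_raw_option(|raw_socket: &dyn RawTcpSetOption| {
    raw_socket.set_nagle_enabled(!no_delay)
});
```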
@ -758,24 +773,17 @@ impl State {
} }
impl SetSocketLevelOption for State { impl SetSocketLevelOption for State {
fn set_keep_alive(&mut self, keep_alive: bool) -> NeedIfacePoll { fn set_keep_alive(&self, keep_alive: bool) -> NeedIfacePoll {
/// The keepalive interval.
///
/// The linux value can be found at `/proc/sys/net/ipv4/tcp_keepalive_intvl`,
/// which is by default 75 seconds for most Linux distributions.
const KEEPALIVE_INTERVAL: aster_bigtcp::time::Duration =
aster_bigtcp::time::Duration::from_secs(75);
let interval = if keep_alive { let interval = if keep_alive {
Some(KEEPALIVE_INTERVAL) Some(KEEPALIVE_INTERVAL)
} else { } else {
None None
}; };
let set_keepalive = let set_keepalive = |raw_socket: &dyn RawTcpSetOption| raw_socket.set_keep_alive(interval);
|raw_socket: &mut dyn RawTcpSetOption| raw_socket.set_keep_alive(interval);
self.set_raw_option(set_keepalive) self.set_raw_option(set_keepalive)
.unwrap_or(NeedIfacePoll::FALSE)
} }
} }


@ -4,6 +4,7 @@ use aster_bigtcp::socket::{SocketEventObserver, SocketEvents};
use crate::{events::IoEvents, process::signal::Pollee}; use crate::{events::IoEvents, process::signal::Pollee};
#[derive(Clone)]
pub struct StreamObserver(Pollee); pub struct StreamObserver(Pollee);
impl StreamObserver { impl StreamObserver {


@ -9,3 +9,10 @@ impl_socket_options!(
pub struct MaxSegment(u32); pub struct MaxSegment(u32);
pub struct WindowClamp(u32); pub struct WindowClamp(u32);
); );
/// The keepalive interval.
///
/// The Linux value can be found at `/proc/sys/net/ipv4/tcp_keepalive_intvl`, /// The Linux value can be found at `/proc/sys/net/ipv4/tcp_keepalive_intvl`,
/// which is by default 75 seconds for most Linux distributions.
pub(super) const KEEPALIVE_INTERVAL: aster_bigtcp::time::Duration =
aster_bigtcp::time::Duration::from_secs(75);


@ -173,7 +173,7 @@ impl LingerOption {
/// A trait used for setting socket level options on actual sockets. /// A trait used for setting socket level options on actual sockets.
pub(in crate::net) trait SetSocketLevelOption { pub(in crate::net) trait SetSocketLevelOption {
/// Sets whether keepalive messages are enabled. /// Sets whether keepalive messages are enabled.
fn set_keep_alive(&mut self, _keep_alive: bool) -> NeedIfacePoll { fn set_keep_alive(&self, _keep_alive: bool) -> NeedIfacePoll {
NeedIfacePoll::FALSE NeedIfacePoll::FALSE
} }
} }
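
For states that have no raw TCP socket to configure, the default method body is enough. A tiny sketch (the `PlaceholderState` type is hypothetical and only illustrates relying on the default):

```rust
// Hypothetical state with no raw TCP socket behind it: the default method
// body applies, so enabling keepalive is a no-op and needs no iface poll.
struct PlaceholderState;

impl SetSocketLevelOption for PlaceholderState {}

fn keepalive_default_sketch(state: &PlaceholderState) -> NeedIfacePoll {
    state.set_keep_alive(true) // returns NeedIfacePoll::FALSE
}
```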