Make connection addition and deletion in order

This commit is contained in:
Ruihan Li 2025-03-18 20:03:28 +08:00 committed by Tate, Hongliang Tian
parent a7e718e812
commit 240192f735
3 changed files with 94 additions and 72 deletions

View File

@ -15,7 +15,7 @@ use smoltcp::{
};
use super::{
poll::{FnHelper, PollContext},
poll::{FnHelper, PollContext, SocketTableAction},
poll_iface::PollableIface,
port::BindPortConfig,
time::get_network_timestamp,
@ -183,26 +183,23 @@ impl<E: Ext> IfaceCommon<E> {
interface.context_mut().now = get_network_timestamp();
let mut sockets = self.sockets.lock();
let mut dead_tcp_conns = Vec::new();
let mut socket_actions = Vec::new();
let mut new_tcp_conns = Vec::new();
let mut context = PollContext::new(
interface.as_mut(),
&sockets,
&mut new_tcp_conns,
&mut dead_tcp_conns,
);
let mut context = PollContext::new(interface.as_mut(), &sockets, &mut socket_actions);
context.poll_ingress(device, &mut process_phy, &mut dispatch_phy);
context.poll_egress(device, &mut dispatch_phy);
// Insert new connections and remove dead connections.
for new_tcp_conn in new_tcp_conns.into_iter() {
let res = sockets.insert_connection(new_tcp_conn);
debug_assert!(res.is_ok());
}
for dead_conn_key in dead_tcp_conns.into_iter() {
sockets.remove_dead_tcp_connection(&dead_conn_key);
for action in socket_actions.into_iter() {
match action {
SocketTableAction::AddTcpConn(new_tcp_conn) => {
let res = sockets.insert_connection(new_tcp_conn);
debug_assert!(res.is_ok());
}
SocketTableAction::DelTcpConn(dead_conn_key) => {
sockets.remove_dead_tcp_connection(&dead_conn_key);
}
}
}
// Notify all socket events.

View File

@ -25,22 +25,28 @@ use crate::{
pub(super) struct PollContext<'a, E: Ext> {
iface: PollableIfaceMut<'a, E>,
sockets: &'a SocketTable<E>,
new_tcp_conns: &'a mut Vec<Arc<TcpConnectionBg<E>>>,
dead_tcp_conns: &'a mut Vec<ConnectionKey>,
actions: &'a mut Vec<SocketTableAction<E>>,
}
/// Socket table actions such as adding or removing TCP connections.
///
/// Note that they must be performed in order. This is because the same connection key can occur
/// multiple times, but with different types of operations (e.g., add or remove).
pub(super) enum SocketTableAction<E: Ext> {
AddTcpConn(Arc<TcpConnectionBg<E>>),
DelTcpConn(ConnectionKey),
}
impl<'a, E: Ext> PollContext<'a, E> {
pub(super) fn new(
iface: PollableIfaceMut<'a, E>,
sockets: &'a SocketTable<E>,
new_tcp_conns: &'a mut Vec<Arc<TcpConnectionBg<E>>>,
dead_tcp_conns: &'a mut Vec<ConnectionKey>,
actions: &'a mut Vec<SocketTableAction<E>>,
) -> Self {
Self {
iface,
sockets,
new_tcp_conns,
dead_tcp_conns,
actions,
}
}
}
@ -159,7 +165,60 @@ impl<E: Ext> PollContext<'_, E> {
ip_repr: &IpRepr,
tcp_repr: &TcpRepr,
) -> Option<(IpRepr, TcpRepr<'static>)> {
// Process packets that request to create new connections first.
// Process packets belonging to existing connections first.
// Note that we must do this first because SYN packets may match existing TIME-WAIT
// sockets. See comments in `TcpConnectionBg::process` for details.
let connection_key = ConnectionKey::new(
ip_repr.dst_addr(),
tcp_repr.dst_port,
ip_repr.src_addr(),
tcp_repr.src_port,
);
let mut connection_in_table = self.sockets.lookup_connection(&connection_key);
loop {
// First try the connection in the socket table, as this is the most common. If it
// fails, it might mean that the connection is dead, the next step is to try the new
// connections instead.
let (should_break, connection) = if let Some(conn) = connection_in_table.take() {
(false, Some(conn))
} else {
// Find in reverse order because old connections must have been dead.
(
true,
self.actions
.iter()
.rev()
.flat_map(|action| match action {
SocketTableAction::AddTcpConn(conn) => Some(conn),
SocketTableAction::DelTcpConn(_) => None,
})
.find(|conn| conn.connection_key() == &connection_key),
)
};
if let Some(connection) = connection {
let (process_result, became_dead) =
connection.process(&mut self.iface, ip_repr, tcp_repr);
if *became_dead {
self.actions
.push(SocketTableAction::DelTcpConn(*connection.connection_key()));
}
match process_result {
TcpProcessResult::NotProcessed => {}
TcpProcessResult::Processed => return None,
TcpProcessResult::ProcessedWithReply(ip_repr, tcp_repr) => {
return Some((ip_repr, tcp_repr))
}
}
}
if should_break {
break;
}
}
// Process packets that request to create new connections second.
if tcp_repr.control == TcpControl::Syn && tcp_repr.ack_number.is_none() {
let listener_key = ListenerKey::new(ip_repr.dst_addr(), tcp_repr.dst_port);
if let Some(listener) = self.sockets.lookup_listener(&listener_key) {
@ -167,7 +226,7 @@ impl<E: Ext> PollContext<'_, E> {
listener.process(&mut self.iface, ip_repr, tcp_repr);
if let Some(tcp_conn) = new_tcp_conn {
self.new_tcp_conns.push(tcp_conn);
self.actions.push(SocketTableAction::AddTcpConn(tcp_conn));
}
match processed {
@ -180,36 +239,6 @@ impl<E: Ext> PollContext<'_, E> {
}
}
// Process packets belonging to existing connections second.
let connection_key = ConnectionKey::new(
ip_repr.dst_addr(),
tcp_repr.dst_port,
ip_repr.src_addr(),
tcp_repr.src_port,
);
let connection = if let Some(connection) = self.sockets.lookup_connection(&connection_key) {
Some(connection)
} else {
self.new_tcp_conns
.iter()
.find(|tcp_conn| tcp_conn.connection_key() == &connection_key)
};
if let Some(connection) = connection {
let (process_result, became_dead) =
connection.process(&mut self.iface, ip_repr, tcp_repr);
if *became_dead {
self.dead_tcp_conns.push(*connection.connection_key());
}
match process_result {
TcpProcessResult::NotProcessed => {}
TcpProcessResult::Processed => return None,
TcpProcessResult::ProcessedWithReply(ip_repr, tcp_repr) => {
return Some((ip_repr, tcp_repr))
}
}
}
// "In no case does receipt of a segment containing RST give rise to a RST in response."
// See <https://datatracker.ietf.org/doc/html/rfc9293#section-4-1.64>.
if tcp_repr.control == TcpControl::Rst {
@ -359,7 +388,6 @@ impl<E: Ext> PollContext<'_, E> {
{
let mut tx_token = Some(tx_token);
let mut did_something = false;
let mut dead_conns = Vec::new();
loop {
let Some(socket) = self.iface.pop_pending_tcp() else {
@ -374,8 +402,7 @@ impl<E: Ext> PollContext<'_, E> {
let (reply, became_dead) =
TcpConnectionBg::dispatch(&socket, &mut self.iface, |iface, ip_repr, tcp_repr| {
let mut this =
PollContext::new(iface, self.sockets, self.new_tcp_conns, &mut dead_conns);
let mut this = PollContext::new(iface, self.sockets, self.actions);
if !this.is_unicast_local(ip_repr.dst_addr()) {
dispatch_phy(
@ -407,7 +434,8 @@ impl<E: Ext> PollContext<'_, E> {
});
if *became_dead {
self.dead_tcp_conns.push(*socket.connection_key());
self.actions
.push(SocketTableAction::DelTcpConn(*socket.connection_key()));
}
match (deferred, reply) {
@ -447,8 +475,6 @@ impl<E: Ext> PollContext<'_, E> {
}
}
self.dead_tcp_conns.append(&mut dead_conns);
(did_something, tx_token)
}
@ -459,7 +485,8 @@ impl<E: Ext> PollContext<'_, E> {
{
let mut tx_token = Some(tx_token);
let mut did_something = false;
let mut dead_conns = Vec::new();
let mut actions = Vec::new();
for socket in self.sockets.udp_socket_iter() {
if !socket.need_dispatch() {
@ -475,8 +502,7 @@ impl<E: Ext> PollContext<'_, E> {
let (cx, pending) = self.iface.inner_mut();
socket.dispatch(cx, |cx, ip_repr, udp_repr, udp_payload| {
let iface = PollableIfaceMut::new(cx, pending);
let mut this =
PollContext::new(iface, self.sockets, self.new_tcp_conns, &mut dead_conns);
let mut this = PollContext::new(iface, self.sockets, &mut actions);
if ip_repr.dst_addr().is_broadcast() || !this.is_unicast_local(ip_repr.dst_addr()) {
dispatch_phy(
@ -527,10 +553,10 @@ impl<E: Ext> PollContext<'_, E> {
}
}
// `dead_conns` should be empty,
// because we are using UDP sockets,
// and the `dead_conns` contains only dead TCP connections.
debug_assert!(dead_conns.is_empty());
// `actions` should be empty,
// because we are dealing with UDP sockets,
// and the `actions` contains only TCP actions.
debug_assert!(actions.is_empty());
(did_something, tx_token)
}

View File

@ -286,14 +286,13 @@ impl<E: Ext> SocketTable<E> {
&mut self.connection_buckets[bucket_index as usize]
};
if let Some(index) = bucket
let index = bucket
.connections
.iter()
.position(|tcp_connection| tcp_connection.connection_key() == key)
{
let connection = bucket.connections.swap_remove(index);
connection.on_dead_events();
}
.unwrap();
let connection = bucket.connections.swap_remove(index);
connection.on_dead_events();
}
pub(crate) fn remove_udp_socket(