Properly close sockets

This commit is contained in:
Ruihan Li
2024-07-23 23:03:10 +08:00
committed by Tate, Hongliang Tian
parent 94b23e077d
commit 328ce9e92c
11 changed files with 162 additions and 95 deletions

View File

@ -59,14 +59,7 @@ impl AnyUnboundSocket {
} }
} }
pub struct AnyBoundSocket { pub struct AnyBoundSocket(Arc<AnyBoundSocketInner>);
iface: Arc<dyn Iface>,
handle: smoltcp::iface::SocketHandle,
port: u16,
socket_family: SocketFamily,
observer: RwLock<Weak<dyn Observer<()>>>,
weak_self: Weak<Self>,
}
impl AnyBoundSocket { impl AnyBoundSocket {
pub(super) fn new( pub(super) fn new(
@ -75,21 +68,18 @@ impl AnyBoundSocket {
port: u16, port: u16,
socket_family: SocketFamily, socket_family: SocketFamily,
observer: Weak<dyn Observer<()>>, observer: Weak<dyn Observer<()>>,
) -> Arc<Self> { ) -> Self {
Arc::new_cyclic(|weak_self| Self { Self(Arc::new(AnyBoundSocketInner {
iface, iface,
handle, handle,
port, port,
socket_family, socket_family,
observer: RwLock::new(observer), observer: RwLock::new(observer),
weak_self: weak_self.clone(), }))
})
} }
pub(super) fn on_iface_events(&self) { pub(super) fn inner(&self) -> &Arc<AnyBoundSocketInner> {
if let Some(observer) = Weak::upgrade(&*self.observer.read()) { &self.0
observer.on_events(&())
}
} }
/// Set the observer whose `on_events` will be called when certain iface events happen. After /// Set the observer whose `on_events` will be called when certain iface events happen. After
@ -99,17 +89,101 @@ impl AnyBoundSocket {
/// that the old observer will never be called after the setting. Users should be aware of this /// that the old observer will never be called after the setting. Users should be aware of this
/// and proactively handle the race conditions if necessary. /// and proactively handle the race conditions if necessary.
pub fn set_observer(&self, handler: Weak<dyn Observer<()>>) { pub fn set_observer(&self, handler: Weak<dyn Observer<()>>) {
*self.observer.write() = handler; *self.0.observer.write() = handler;
self.on_iface_events(); self.0.on_iface_events();
} }
pub fn local_endpoint(&self) -> Option<IpEndpoint> { pub fn local_endpoint(&self) -> Option<IpEndpoint> {
let ip_addr = { let ip_addr = {
let ipv4_addr = self.iface.ipv4_addr()?; let ipv4_addr = self.0.iface.ipv4_addr()?;
IpAddress::Ipv4(ipv4_addr) IpAddress::Ipv4(ipv4_addr)
}; };
Some(IpEndpoint::new(ip_addr, self.port)) Some(IpEndpoint::new(ip_addr, self.0.port))
}
pub fn raw_with<T: smoltcp::socket::AnySocket<'static>, R, F: FnMut(&mut T) -> R>(
&self,
f: F,
) -> R {
self.0.raw_with(f)
}
/// Try to connect to a remote endpoint. Tcp socket only.
pub fn do_connect(&self, remote_endpoint: IpEndpoint) -> Result<()> {
let mut sockets = self.0.iface.sockets();
let socket = sockets.get_mut::<RawTcpSocket>(self.0.handle);
let port = self.0.port;
let mut iface_inner = self.0.iface.iface_inner();
let cx = iface_inner.context();
socket
.connect(cx, remote_endpoint, port)
.map_err(|_| Error::with_message(Errno::ENOBUFS, "send connection request failed"))?;
Ok(())
}
pub fn iface(&self) -> &Arc<dyn Iface> {
&self.0.iface
}
}
impl Drop for AnyBoundSocket {
fn drop(&mut self) {
if self.0.start_closing() {
self.0.iface.common().remove_bound_socket_now(&self.0);
} else {
self.0
.iface
.common()
.remove_bound_socket_when_closed(&self.0);
}
}
}
pub(super) struct AnyBoundSocketInner {
iface: Arc<dyn Iface>,
handle: smoltcp::iface::SocketHandle,
port: u16,
socket_family: SocketFamily,
observer: RwLock<Weak<dyn Observer<()>>>,
}
impl AnyBoundSocketInner {
pub(super) fn on_iface_events(&self) {
if let Some(observer) = Weak::upgrade(&*self.observer.read()) {
observer.on_events(&())
}
}
pub(super) fn is_closed(&self) -> bool {
match self.socket_family {
SocketFamily::Tcp => self.raw_with(|socket: &mut RawTcpSocket| {
socket.state() == smoltcp::socket::tcp::State::Closed
}),
SocketFamily::Udp => true,
}
}
/// Starts closing the socket and returns whether the socket is closed.
///
/// For sockets that can be closed immediately, such as UDP sockets and TCP listening sockets,
/// this method will always return `true`.
///
/// For other sockets, such as TCP connected sockets, they cannot be closed immediately because
/// we at least need to send the FIN packet and wait for the remote end to send an ACK packet.
/// In this case, this method will return `false` and [`Self::is_closed`] can be used to
/// determine if the closing process is complete.
fn start_closing(&self) -> bool {
match self.socket_family {
SocketFamily::Tcp => self.raw_with(|socket: &mut RawTcpSocket| {
socket.close();
socket.state() == smoltcp::socket::tcp::State::Closed
}),
SocketFamily::Udp => {
self.raw_with(|socket: &mut RawUdpSocket| socket.close());
true
}
}
} }
pub fn raw_with<T: smoltcp::socket::AnySocket<'static>, R, F: FnMut(&mut T) -> R>( pub fn raw_with<T: smoltcp::socket::AnySocket<'static>, R, F: FnMut(&mut T) -> R>(
@ -120,43 +194,13 @@ impl AnyBoundSocket {
let socket = sockets.get_mut::<T>(self.handle); let socket = sockets.get_mut::<T>(self.handle);
f(socket) f(socket)
} }
/// Try to connect to a remote endpoint. Tcp socket only.
pub fn do_connect(&self, remote_endpoint: IpEndpoint) -> Result<()> {
let mut sockets = self.iface.sockets();
let socket = sockets.get_mut::<RawTcpSocket>(self.handle);
let port = self.port;
let mut iface_inner = self.iface.iface_inner();
let cx = iface_inner.context();
socket
.connect(cx, remote_endpoint, port)
.map_err(|_| Error::with_message(Errno::ENOBUFS, "send connection request failed"))?;
Ok(())
} }
pub fn iface(&self) -> &Arc<dyn Iface> { impl Drop for AnyBoundSocketInner {
&self.iface
}
pub(super) fn weak_ref(&self) -> Weak<Self> {
self.weak_self.clone()
}
fn close(&self) {
match self.socket_family {
SocketFamily::Tcp => self.raw_with(|socket: &mut RawTcpSocket| socket.close()),
SocketFamily::Udp => self.raw_with(|socket: &mut RawUdpSocket| socket.close()),
}
}
}
impl Drop for AnyBoundSocket {
fn drop(&mut self) { fn drop(&mut self) {
self.close(); let iface_common = self.iface.common();
self.iface.poll(); iface_common.remove_socket(self.handle);
self.iface.common().remove_socket(self.handle); iface_common.release_port(self.port);
self.iface.common().release_port(self.port);
self.iface.common().remove_bound_socket(self.weak_ref());
} }
} }

View File

@ -3,7 +3,7 @@
use alloc::collections::btree_map::Entry; use alloc::collections::btree_map::Entry;
use core::sync::atomic::{AtomicU64, Ordering}; use core::sync::atomic::{AtomicU64, Ordering};
use keyable_arc::KeyableWeak; use keyable_arc::KeyableArc;
use ostd::sync::WaitQueue; use ostd::sync::WaitQueue;
use smoltcp::{ use smoltcp::{
iface::{SocketHandle, SocketSet}, iface::{SocketHandle, SocketSet},
@ -12,10 +12,10 @@ use smoltcp::{
}; };
use super::{ use super::{
any_socket::{AnyBoundSocket, AnyRawSocket, AnyUnboundSocket, SocketFamily}, any_socket::{AnyBoundSocketInner, AnyRawSocket, AnyUnboundSocket, SocketFamily},
time::get_network_timestamp, time::get_network_timestamp,
util::BindPortConfig, util::BindPortConfig,
Iface, Ipv4Address, AnyBoundSocket, Iface, Ipv4Address,
}; };
use crate::prelude::*; use crate::prelude::*;
@ -25,7 +25,8 @@ pub struct IfaceCommon {
used_ports: RwLock<BTreeMap<u16, usize>>, used_ports: RwLock<BTreeMap<u16, usize>>,
/// The time should do next poll. We stores the total milliseconds since system boots up. /// The time should do next poll. We stores the total milliseconds since system boots up.
next_poll_at_ms: AtomicU64, next_poll_at_ms: AtomicU64,
bound_sockets: RwLock<BTreeSet<KeyableWeak<AnyBoundSocket>>>, bound_sockets: RwLock<BTreeSet<KeyableArc<AnyBoundSocketInner>>>,
closing_sockets: SpinLock<BTreeSet<KeyableArc<AnyBoundSocketInner>>>,
/// The wait queue that background polling thread will sleep on /// The wait queue that background polling thread will sleep on
polling_wait_queue: WaitQueue, polling_wait_queue: WaitQueue,
} }
@ -40,6 +41,7 @@ impl IfaceCommon {
used_ports: RwLock::new(used_ports), used_ports: RwLock::new(used_ports),
next_poll_at_ms: AtomicU64::new(0), next_poll_at_ms: AtomicU64::new(0),
bound_sockets: RwLock::new(BTreeSet::new()), bound_sockets: RwLock::new(BTreeSet::new()),
closing_sockets: SpinLock::new(BTreeSet::new()),
polling_wait_queue: WaitQueue::new(), polling_wait_queue: WaitQueue::new(),
} }
} }
@ -109,7 +111,7 @@ impl IfaceCommon {
iface: Arc<dyn Iface>, iface: Arc<dyn Iface>,
socket: Box<AnyUnboundSocket>, socket: Box<AnyUnboundSocket>,
config: BindPortConfig, config: BindPortConfig,
) -> core::result::Result<Arc<AnyBoundSocket>, (Error, Box<AnyUnboundSocket>)> { ) -> core::result::Result<AnyBoundSocket, (Error, Box<AnyUnboundSocket>)> {
let port = if let Some(port) = config.port() { let port = if let Some(port) = config.port() {
port port
} else { } else {
@ -135,7 +137,7 @@ impl IfaceCommon {
), ),
}; };
let bound_socket = AnyBoundSocket::new(iface, handle, port, socket_family, observer); let bound_socket = AnyBoundSocket::new(iface, handle, port, socket_family, observer);
self.insert_bound_socket(&bound_socket).unwrap(); self.insert_bound_socket(bound_socket.inner());
Ok(bound_socket) Ok(bound_socket)
} }
@ -184,10 +186,15 @@ impl IfaceCommon {
if has_events { if has_events {
self.bound_sockets.read().iter().for_each(|bound_socket| { self.bound_sockets.read().iter().for_each(|bound_socket| {
if let Some(bound_socket) = bound_socket.upgrade() {
bound_socket.on_iface_events(); bound_socket.on_iface_events();
}
}); });
let closed_sockets = self
.closing_sockets
.lock()
.extract_if(|closing_socket| closing_socket.is_closed())
.collect::<Vec<_>>();
drop(closed_sockets);
} }
} }
@ -200,19 +207,35 @@ impl IfaceCommon {
} }
} }
fn insert_bound_socket(&self, socket: &Arc<AnyBoundSocket>) -> Result<()> { fn insert_bound_socket(&self, socket: &Arc<AnyBoundSocketInner>) {
let weak_ref = KeyableWeak::from(Arc::downgrade(socket)); let keyable_socket = KeyableArc::from(socket.clone());
let mut bound_sockets = self.bound_sockets.write();
if bound_sockets.contains(&weak_ref) { let inserted = self.bound_sockets.write().insert(keyable_socket);
return_errno_with_message!(Errno::EINVAL, "the socket is already bound"); assert!(inserted);
}
bound_sockets.insert(weak_ref);
Ok(())
} }
pub(super) fn remove_bound_socket(&self, socket: Weak<AnyBoundSocket>) { pub(super) fn remove_bound_socket_now(&self, socket: &Arc<AnyBoundSocketInner>) {
let weak_ref = KeyableWeak::from(socket); let keyable_socket = KeyableArc::from(socket.clone());
self.bound_sockets.write().remove(&weak_ref);
let removed = self.bound_sockets.write().remove(&keyable_socket);
assert!(removed);
}
pub(super) fn remove_bound_socket_when_closed(&self, socket: &Arc<AnyBoundSocketInner>) {
let keyable_socket = KeyableArc::from(socket.clone());
let removed = self.bound_sockets.write().remove(&keyable_socket);
assert!(removed);
let mut closing_sockets = self.closing_sockets.lock();
// Check `is_closed` after holding the lock to avoid race conditions.
if keyable_socket.is_closed() {
return;
}
let inserted = closing_sockets.insert(keyable_socket);
assert!(inserted);
} }
} }

View File

@ -45,7 +45,7 @@ pub trait Iface: internal::IfaceInternal + Send + Sync {
&self, &self,
socket: Box<AnyUnboundSocket>, socket: Box<AnyUnboundSocket>,
config: BindPortConfig, config: BindPortConfig,
) -> core::result::Result<Arc<AnyBoundSocket>, (Error, Box<AnyUnboundSocket>)> { ) -> core::result::Result<AnyBoundSocket, (Error, Box<AnyUnboundSocket>)> {
let common = self.common(); let common = self.common();
common.bind_socket(self.arc_self(), socket, config) common.bind_socket(self.arc_self(), socket, config)
} }

View File

@ -46,7 +46,7 @@ pub(super) fn bind_socket(
unbound_socket: Box<AnyUnboundSocket>, unbound_socket: Box<AnyUnboundSocket>,
endpoint: &IpEndpoint, endpoint: &IpEndpoint,
can_reuse: bool, can_reuse: bool,
) -> core::result::Result<Arc<AnyBoundSocket>, (Error, Box<AnyUnboundSocket>)> { ) -> core::result::Result<AnyBoundSocket, (Error, Box<AnyUnboundSocket>)> {
let iface = match get_iface_to_bind(&endpoint.addr) { let iface = match get_iface_to_bind(&endpoint.addr) {
Some(iface) => iface, Some(iface) => iface,
None => { None => {

View File

@ -13,12 +13,12 @@ use crate::{
}; };
pub struct BoundDatagram { pub struct BoundDatagram {
bound_socket: Arc<AnyBoundSocket>, bound_socket: AnyBoundSocket,
remote_endpoint: Option<IpEndpoint>, remote_endpoint: Option<IpEndpoint>,
} }
impl BoundDatagram { impl BoundDatagram {
pub fn new(bound_socket: Arc<AnyBoundSocket>) -> Self { pub fn new(bound_socket: AnyBoundSocket) -> Self {
Self { Self {
bound_socket, bound_socket,
remote_endpoint: None, remote_endpoint: None,

View File

@ -15,7 +15,7 @@ use crate::{
}; };
pub struct ConnectedStream { pub struct ConnectedStream {
bound_socket: Arc<AnyBoundSocket>, bound_socket: AnyBoundSocket,
remote_endpoint: IpEndpoint, remote_endpoint: IpEndpoint,
/// Indicates whether this connection is "new" in a `connect()` system call. /// Indicates whether this connection is "new" in a `connect()` system call.
/// ///
@ -32,7 +32,7 @@ pub struct ConnectedStream {
impl ConnectedStream { impl ConnectedStream {
pub fn new( pub fn new(
bound_socket: Arc<AnyBoundSocket>, bound_socket: AnyBoundSocket,
remote_endpoint: IpEndpoint, remote_endpoint: IpEndpoint,
is_new_connection: bool, is_new_connection: bool,
) -> Self { ) -> Self {

View File

@ -8,7 +8,7 @@ use crate::{
}; };
pub struct ConnectingStream { pub struct ConnectingStream {
bound_socket: Arc<AnyBoundSocket>, bound_socket: AnyBoundSocket,
remote_endpoint: IpEndpoint, remote_endpoint: IpEndpoint,
conn_result: RwLock<Option<ConnResult>>, conn_result: RwLock<Option<ConnResult>>,
} }
@ -26,9 +26,9 @@ pub enum NonConnectedStream {
impl ConnectingStream { impl ConnectingStream {
pub fn new( pub fn new(
bound_socket: Arc<AnyBoundSocket>, bound_socket: AnyBoundSocket,
remote_endpoint: IpEndpoint, remote_endpoint: IpEndpoint,
) -> core::result::Result<Self, (Error, Arc<AnyBoundSocket>)> { ) -> core::result::Result<Self, (Error, AnyBoundSocket)> {
if let Err(err) = bound_socket.do_connect(remote_endpoint) { if let Err(err) = bound_socket.do_connect(remote_endpoint) {
return Err((err, bound_socket)); return Err((err, bound_socket));
} }

View File

@ -15,7 +15,7 @@ use crate::{
pub enum InitStream { pub enum InitStream {
Unbound(Box<AnyUnboundSocket>), Unbound(Box<AnyUnboundSocket>),
Bound(Arc<AnyBoundSocket>), Bound(AnyBoundSocket),
} }
impl InitStream { impl InitStream {
@ -23,14 +23,14 @@ impl InitStream {
InitStream::Unbound(Box::new(AnyUnboundSocket::new_tcp(observer))) InitStream::Unbound(Box::new(AnyUnboundSocket::new_tcp(observer)))
} }
pub fn new_bound(bound_socket: Arc<AnyBoundSocket>) -> Self { pub fn new_bound(bound_socket: AnyBoundSocket) -> Self {
InitStream::Bound(bound_socket) InitStream::Bound(bound_socket)
} }
pub fn bind( pub fn bind(
self, self,
endpoint: &IpEndpoint, endpoint: &IpEndpoint,
) -> core::result::Result<Arc<AnyBoundSocket>, (Error, Self)> { ) -> core::result::Result<AnyBoundSocket, (Error, Self)> {
let unbound_socket = match self { let unbound_socket = match self {
InitStream::Unbound(unbound_socket) => unbound_socket, InitStream::Unbound(unbound_socket) => unbound_socket,
InitStream::Bound(bound_socket) => { InitStream::Bound(bound_socket) => {
@ -50,7 +50,7 @@ impl InitStream {
fn bind_to_ephemeral_endpoint( fn bind_to_ephemeral_endpoint(
self, self,
remote_endpoint: &IpEndpoint, remote_endpoint: &IpEndpoint,
) -> core::result::Result<Arc<AnyBoundSocket>, (Error, Self)> { ) -> core::result::Result<AnyBoundSocket, (Error, Self)> {
let endpoint = get_ephemeral_endpoint(remote_endpoint); let endpoint = get_ephemeral_endpoint(remote_endpoint);
self.bind(&endpoint) self.bind(&endpoint)
} }

View File

@ -13,16 +13,16 @@ use crate::{
pub struct ListenStream { pub struct ListenStream {
backlog: usize, backlog: usize,
/// A bound socket held to ensure the TCP port cannot be released /// A bound socket held to ensure the TCP port cannot be released
bound_socket: Arc<AnyBoundSocket>, bound_socket: AnyBoundSocket,
/// Backlog sockets listening at the local endpoint /// Backlog sockets listening at the local endpoint
backlog_sockets: RwLock<Vec<BacklogSocket>>, backlog_sockets: RwLock<Vec<BacklogSocket>>,
} }
impl ListenStream { impl ListenStream {
pub fn new( pub fn new(
bound_socket: Arc<AnyBoundSocket>, bound_socket: AnyBoundSocket,
backlog: usize, backlog: usize,
) -> core::result::Result<Self, (Error, Arc<AnyBoundSocket>)> { ) -> core::result::Result<Self, (Error, AnyBoundSocket)> {
let listen_stream = Self { let listen_stream = Self {
backlog, backlog,
bound_socket, bound_socket,
@ -99,13 +99,13 @@ impl ListenStream {
} }
struct BacklogSocket { struct BacklogSocket {
bound_socket: Arc<AnyBoundSocket>, bound_socket: AnyBoundSocket,
} }
impl BacklogSocket { impl BacklogSocket {
// FIXME: All of the error codes below seem to have no Linux equivalents, and I see no reason // FIXME: All of the error codes below seem to have no Linux equivalents, and I see no reason
// why the error may occur. Perhaps it is better to call `unwrap()` directly? // why the error may occur. Perhaps it is better to call `unwrap()` directly?
fn new(bound_socket: &Arc<AnyBoundSocket>) -> Result<Self> { fn new(bound_socket: &AnyBoundSocket) -> Result<Self> {
let local_endpoint = bound_socket.local_endpoint().ok_or(Error::with_message( let local_endpoint = bound_socket.local_endpoint().ok_or(Error::with_message(
Errno::EINVAL, Errno::EINVAL,
"the socket is not bound", "the socket is not bound",
@ -143,7 +143,7 @@ impl BacklogSocket {
.raw_with(|socket: &mut RawTcpSocket| socket.remote_endpoint()) .raw_with(|socket: &mut RawTcpSocket| socket.remote_endpoint())
} }
fn into_bound_socket(self) -> Arc<AnyBoundSocket> { fn into_bound_socket(self) -> AnyBoundSocket {
self.bound_socket self.bound_socket
} }
} }

View File

@ -131,7 +131,7 @@ int main(void)
for (backlog = 0; backlog <= MAX_TEST_BACKLOG; ++backlog) { for (backlog = 0; backlog <= MAX_TEST_BACKLOG; ++backlog) {
// Avoid "bind: Address already in use" // Avoid "bind: Address already in use"
addr.sin_port = htons(8080 + backlog); addr.sin_port = htons(10000 + backlog);
err = test_listen_backlog(&addr, backlog); err = test_listen_backlog(&addr, backlog);
if (err != 0) if (err != 0)

View File

@ -265,7 +265,7 @@ int main(void)
struct sockaddr_in addr; struct sockaddr_in addr;
addr.sin_family = AF_INET; addr.sin_family = AF_INET;
addr.sin_port = htons(8080); addr.sin_port = htons(9999);
if (inet_aton("127.0.0.1", &addr.sin_addr) < 0) { if (inet_aton("127.0.0.1", &addr.sin_addr) < 0) {
fprintf(stderr, "inet_aton cannot parse 127.0.0.1\n"); fprintf(stderr, "inet_aton cannot parse 127.0.0.1\n");
return -1; return -1;