mirror of
https://github.com/asterinas/asterinas.git
synced 2025-06-27 19:33:23 +00:00
Properly close sockets
This commit is contained in:
committed by
Tate, Hongliang Tian
parent
94b23e077d
commit
328ce9e92c
@ -3,7 +3,7 @@
|
||||
use alloc::collections::btree_map::Entry;
|
||||
use core::sync::atomic::{AtomicU64, Ordering};
|
||||
|
||||
use keyable_arc::KeyableWeak;
|
||||
use keyable_arc::KeyableArc;
|
||||
use ostd::sync::WaitQueue;
|
||||
use smoltcp::{
|
||||
iface::{SocketHandle, SocketSet},
|
||||
@ -12,10 +12,10 @@ use smoltcp::{
|
||||
};
|
||||
|
||||
use super::{
|
||||
any_socket::{AnyBoundSocket, AnyRawSocket, AnyUnboundSocket, SocketFamily},
|
||||
any_socket::{AnyBoundSocketInner, AnyRawSocket, AnyUnboundSocket, SocketFamily},
|
||||
time::get_network_timestamp,
|
||||
util::BindPortConfig,
|
||||
Iface, Ipv4Address,
|
||||
AnyBoundSocket, Iface, Ipv4Address,
|
||||
};
|
||||
use crate::prelude::*;
|
||||
|
||||
@ -25,7 +25,8 @@ pub struct IfaceCommon {
|
||||
used_ports: RwLock<BTreeMap<u16, usize>>,
|
||||
/// The time should do next poll. We stores the total milliseconds since system boots up.
|
||||
next_poll_at_ms: AtomicU64,
|
||||
bound_sockets: RwLock<BTreeSet<KeyableWeak<AnyBoundSocket>>>,
|
||||
bound_sockets: RwLock<BTreeSet<KeyableArc<AnyBoundSocketInner>>>,
|
||||
closing_sockets: SpinLock<BTreeSet<KeyableArc<AnyBoundSocketInner>>>,
|
||||
/// The wait queue that background polling thread will sleep on
|
||||
polling_wait_queue: WaitQueue,
|
||||
}
|
||||
@ -40,6 +41,7 @@ impl IfaceCommon {
|
||||
used_ports: RwLock::new(used_ports),
|
||||
next_poll_at_ms: AtomicU64::new(0),
|
||||
bound_sockets: RwLock::new(BTreeSet::new()),
|
||||
closing_sockets: SpinLock::new(BTreeSet::new()),
|
||||
polling_wait_queue: WaitQueue::new(),
|
||||
}
|
||||
}
|
||||
@ -109,7 +111,7 @@ impl IfaceCommon {
|
||||
iface: Arc<dyn Iface>,
|
||||
socket: Box<AnyUnboundSocket>,
|
||||
config: BindPortConfig,
|
||||
) -> core::result::Result<Arc<AnyBoundSocket>, (Error, Box<AnyUnboundSocket>)> {
|
||||
) -> core::result::Result<AnyBoundSocket, (Error, Box<AnyUnboundSocket>)> {
|
||||
let port = if let Some(port) = config.port() {
|
||||
port
|
||||
} else {
|
||||
@ -135,7 +137,7 @@ impl IfaceCommon {
|
||||
),
|
||||
};
|
||||
let bound_socket = AnyBoundSocket::new(iface, handle, port, socket_family, observer);
|
||||
self.insert_bound_socket(&bound_socket).unwrap();
|
||||
self.insert_bound_socket(bound_socket.inner());
|
||||
|
||||
Ok(bound_socket)
|
||||
}
|
||||
@ -184,10 +186,15 @@ impl IfaceCommon {
|
||||
|
||||
if has_events {
|
||||
self.bound_sockets.read().iter().for_each(|bound_socket| {
|
||||
if let Some(bound_socket) = bound_socket.upgrade() {
|
||||
bound_socket.on_iface_events();
|
||||
}
|
||||
bound_socket.on_iface_events();
|
||||
});
|
||||
|
||||
let closed_sockets = self
|
||||
.closing_sockets
|
||||
.lock()
|
||||
.extract_if(|closing_socket| closing_socket.is_closed())
|
||||
.collect::<Vec<_>>();
|
||||
drop(closed_sockets);
|
||||
}
|
||||
}
|
||||
|
||||
@ -200,19 +207,35 @@ impl IfaceCommon {
|
||||
}
|
||||
}
|
||||
|
||||
fn insert_bound_socket(&self, socket: &Arc<AnyBoundSocket>) -> Result<()> {
|
||||
let weak_ref = KeyableWeak::from(Arc::downgrade(socket));
|
||||
let mut bound_sockets = self.bound_sockets.write();
|
||||
if bound_sockets.contains(&weak_ref) {
|
||||
return_errno_with_message!(Errno::EINVAL, "the socket is already bound");
|
||||
}
|
||||
bound_sockets.insert(weak_ref);
|
||||
Ok(())
|
||||
fn insert_bound_socket(&self, socket: &Arc<AnyBoundSocketInner>) {
|
||||
let keyable_socket = KeyableArc::from(socket.clone());
|
||||
|
||||
let inserted = self.bound_sockets.write().insert(keyable_socket);
|
||||
assert!(inserted);
|
||||
}
|
||||
|
||||
pub(super) fn remove_bound_socket(&self, socket: Weak<AnyBoundSocket>) {
|
||||
let weak_ref = KeyableWeak::from(socket);
|
||||
self.bound_sockets.write().remove(&weak_ref);
|
||||
pub(super) fn remove_bound_socket_now(&self, socket: &Arc<AnyBoundSocketInner>) {
|
||||
let keyable_socket = KeyableArc::from(socket.clone());
|
||||
|
||||
let removed = self.bound_sockets.write().remove(&keyable_socket);
|
||||
assert!(removed);
|
||||
}
|
||||
|
||||
pub(super) fn remove_bound_socket_when_closed(&self, socket: &Arc<AnyBoundSocketInner>) {
|
||||
let keyable_socket = KeyableArc::from(socket.clone());
|
||||
|
||||
let removed = self.bound_sockets.write().remove(&keyable_socket);
|
||||
assert!(removed);
|
||||
|
||||
let mut closing_sockets = self.closing_sockets.lock();
|
||||
|
||||
// Check `is_closed` after holding the lock to avoid race conditions.
|
||||
if keyable_socket.is_closed() {
|
||||
return;
|
||||
}
|
||||
|
||||
let inserted = closing_sockets.insert(keyable_socket);
|
||||
assert!(inserted);
|
||||
}
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user