diff --git a/kernel/aster-nix/src/net/iface/any_socket.rs b/kernel/aster-nix/src/net/iface/any_socket.rs index 912c7e99e..a95f0c04e 100644 --- a/kernel/aster-nix/src/net/iface/any_socket.rs +++ b/kernel/aster-nix/src/net/iface/any_socket.rs @@ -59,14 +59,7 @@ impl AnyUnboundSocket { } } -pub struct AnyBoundSocket { - iface: Arc, - handle: smoltcp::iface::SocketHandle, - port: u16, - socket_family: SocketFamily, - observer: RwLock>>, - weak_self: Weak, -} +pub struct AnyBoundSocket(Arc); impl AnyBoundSocket { pub(super) fn new( @@ -75,21 +68,18 @@ impl AnyBoundSocket { port: u16, socket_family: SocketFamily, observer: Weak>, - ) -> Arc { - Arc::new_cyclic(|weak_self| Self { + ) -> Self { + Self(Arc::new(AnyBoundSocketInner { iface, handle, port, socket_family, observer: RwLock::new(observer), - weak_self: weak_self.clone(), - }) + })) } - pub(super) fn on_iface_events(&self) { - if let Some(observer) = Weak::upgrade(&*self.observer.read()) { - observer.on_events(&()) - } + pub(super) fn inner(&self) -> &Arc { + &self.0 } /// Set the observer whose `on_events` will be called when certain iface events happen. After @@ -99,17 +89,101 @@ impl AnyBoundSocket { /// that the old observer will never be called after the setting. Users should be aware of this /// and proactively handle the race conditions if necessary. pub fn set_observer(&self, handler: Weak>) { - *self.observer.write() = handler; + *self.0.observer.write() = handler; - self.on_iface_events(); + self.0.on_iface_events(); } pub fn local_endpoint(&self) -> Option { let ip_addr = { - let ipv4_addr = self.iface.ipv4_addr()?; + let ipv4_addr = self.0.iface.ipv4_addr()?; IpAddress::Ipv4(ipv4_addr) }; - Some(IpEndpoint::new(ip_addr, self.port)) + Some(IpEndpoint::new(ip_addr, self.0.port)) + } + + pub fn raw_with, R, F: FnMut(&mut T) -> R>( + &self, + f: F, + ) -> R { + self.0.raw_with(f) + } + + /// Try to connect to a remote endpoint. Tcp socket only. + pub fn do_connect(&self, remote_endpoint: IpEndpoint) -> Result<()> { + let mut sockets = self.0.iface.sockets(); + let socket = sockets.get_mut::(self.0.handle); + let port = self.0.port; + let mut iface_inner = self.0.iface.iface_inner(); + let cx = iface_inner.context(); + socket + .connect(cx, remote_endpoint, port) + .map_err(|_| Error::with_message(Errno::ENOBUFS, "send connection request failed"))?; + Ok(()) + } + + pub fn iface(&self) -> &Arc { + &self.0.iface + } +} + +impl Drop for AnyBoundSocket { + fn drop(&mut self) { + if self.0.start_closing() { + self.0.iface.common().remove_bound_socket_now(&self.0); + } else { + self.0 + .iface + .common() + .remove_bound_socket_when_closed(&self.0); + } + } +} + +pub(super) struct AnyBoundSocketInner { + iface: Arc, + handle: smoltcp::iface::SocketHandle, + port: u16, + socket_family: SocketFamily, + observer: RwLock>>, +} + +impl AnyBoundSocketInner { + pub(super) fn on_iface_events(&self) { + if let Some(observer) = Weak::upgrade(&*self.observer.read()) { + observer.on_events(&()) + } + } + + pub(super) fn is_closed(&self) -> bool { + match self.socket_family { + SocketFamily::Tcp => self.raw_with(|socket: &mut RawTcpSocket| { + socket.state() == smoltcp::socket::tcp::State::Closed + }), + SocketFamily::Udp => true, + } + } + + /// Starts closing the socket and returns whether the socket is closed. + /// + /// For sockets that can be closed immediately, such as UDP sockets and TCP listening sockets, + /// this method will always return `true`. + /// + /// For other sockets, such as TCP connected sockets, they cannot be closed immediately because + /// we at least need to send the FIN packet and wait for the remote end to send an ACK packet. + /// In this case, this method will return `false` and [`Self::is_closed`] can be used to + /// determine if the closing process is complete. + fn start_closing(&self) -> bool { + match self.socket_family { + SocketFamily::Tcp => self.raw_with(|socket: &mut RawTcpSocket| { + socket.close(); + socket.state() == smoltcp::socket::tcp::State::Closed + }), + SocketFamily::Udp => { + self.raw_with(|socket: &mut RawUdpSocket| socket.close()); + true + } + } } pub fn raw_with, R, F: FnMut(&mut T) -> R>( @@ -120,43 +194,13 @@ impl AnyBoundSocket { let socket = sockets.get_mut::(self.handle); f(socket) } - - /// Try to connect to a remote endpoint. Tcp socket only. - pub fn do_connect(&self, remote_endpoint: IpEndpoint) -> Result<()> { - let mut sockets = self.iface.sockets(); - let socket = sockets.get_mut::(self.handle); - let port = self.port; - let mut iface_inner = self.iface.iface_inner(); - let cx = iface_inner.context(); - socket - .connect(cx, remote_endpoint, port) - .map_err(|_| Error::with_message(Errno::ENOBUFS, "send connection request failed"))?; - Ok(()) - } - - pub fn iface(&self) -> &Arc { - &self.iface - } - - pub(super) fn weak_ref(&self) -> Weak { - self.weak_self.clone() - } - - fn close(&self) { - match self.socket_family { - SocketFamily::Tcp => self.raw_with(|socket: &mut RawTcpSocket| socket.close()), - SocketFamily::Udp => self.raw_with(|socket: &mut RawUdpSocket| socket.close()), - } - } } -impl Drop for AnyBoundSocket { +impl Drop for AnyBoundSocketInner { fn drop(&mut self) { - self.close(); - self.iface.poll(); - self.iface.common().remove_socket(self.handle); - self.iface.common().release_port(self.port); - self.iface.common().remove_bound_socket(self.weak_ref()); + let iface_common = self.iface.common(); + iface_common.remove_socket(self.handle); + iface_common.release_port(self.port); } } diff --git a/kernel/aster-nix/src/net/iface/common.rs b/kernel/aster-nix/src/net/iface/common.rs index ccfbbb2f1..d058efa0c 100644 --- a/kernel/aster-nix/src/net/iface/common.rs +++ b/kernel/aster-nix/src/net/iface/common.rs @@ -3,7 +3,7 @@ use alloc::collections::btree_map::Entry; use core::sync::atomic::{AtomicU64, Ordering}; -use keyable_arc::KeyableWeak; +use keyable_arc::KeyableArc; use ostd::sync::WaitQueue; use smoltcp::{ iface::{SocketHandle, SocketSet}, @@ -12,10 +12,10 @@ use smoltcp::{ }; use super::{ - any_socket::{AnyBoundSocket, AnyRawSocket, AnyUnboundSocket, SocketFamily}, + any_socket::{AnyBoundSocketInner, AnyRawSocket, AnyUnboundSocket, SocketFamily}, time::get_network_timestamp, util::BindPortConfig, - Iface, Ipv4Address, + AnyBoundSocket, Iface, Ipv4Address, }; use crate::prelude::*; @@ -25,7 +25,8 @@ pub struct IfaceCommon { used_ports: RwLock>, /// The time should do next poll. We stores the total milliseconds since system boots up. next_poll_at_ms: AtomicU64, - bound_sockets: RwLock>>, + bound_sockets: RwLock>>, + closing_sockets: SpinLock>>, /// The wait queue that background polling thread will sleep on polling_wait_queue: WaitQueue, } @@ -40,6 +41,7 @@ impl IfaceCommon { used_ports: RwLock::new(used_ports), next_poll_at_ms: AtomicU64::new(0), bound_sockets: RwLock::new(BTreeSet::new()), + closing_sockets: SpinLock::new(BTreeSet::new()), polling_wait_queue: WaitQueue::new(), } } @@ -109,7 +111,7 @@ impl IfaceCommon { iface: Arc, socket: Box, config: BindPortConfig, - ) -> core::result::Result, (Error, Box)> { + ) -> core::result::Result)> { let port = if let Some(port) = config.port() { port } else { @@ -135,7 +137,7 @@ impl IfaceCommon { ), }; let bound_socket = AnyBoundSocket::new(iface, handle, port, socket_family, observer); - self.insert_bound_socket(&bound_socket).unwrap(); + self.insert_bound_socket(bound_socket.inner()); Ok(bound_socket) } @@ -184,10 +186,15 @@ impl IfaceCommon { if has_events { self.bound_sockets.read().iter().for_each(|bound_socket| { - if let Some(bound_socket) = bound_socket.upgrade() { - bound_socket.on_iface_events(); - } + bound_socket.on_iface_events(); }); + + let closed_sockets = self + .closing_sockets + .lock() + .extract_if(|closing_socket| closing_socket.is_closed()) + .collect::>(); + drop(closed_sockets); } } @@ -200,19 +207,35 @@ impl IfaceCommon { } } - fn insert_bound_socket(&self, socket: &Arc) -> Result<()> { - let weak_ref = KeyableWeak::from(Arc::downgrade(socket)); - let mut bound_sockets = self.bound_sockets.write(); - if bound_sockets.contains(&weak_ref) { - return_errno_with_message!(Errno::EINVAL, "the socket is already bound"); - } - bound_sockets.insert(weak_ref); - Ok(()) + fn insert_bound_socket(&self, socket: &Arc) { + let keyable_socket = KeyableArc::from(socket.clone()); + + let inserted = self.bound_sockets.write().insert(keyable_socket); + assert!(inserted); } - pub(super) fn remove_bound_socket(&self, socket: Weak) { - let weak_ref = KeyableWeak::from(socket); - self.bound_sockets.write().remove(&weak_ref); + pub(super) fn remove_bound_socket_now(&self, socket: &Arc) { + let keyable_socket = KeyableArc::from(socket.clone()); + + let removed = self.bound_sockets.write().remove(&keyable_socket); + assert!(removed); + } + + pub(super) fn remove_bound_socket_when_closed(&self, socket: &Arc) { + let keyable_socket = KeyableArc::from(socket.clone()); + + let removed = self.bound_sockets.write().remove(&keyable_socket); + assert!(removed); + + let mut closing_sockets = self.closing_sockets.lock(); + + // Check `is_closed` after holding the lock to avoid race conditions. + if keyable_socket.is_closed() { + return; + } + + let inserted = closing_sockets.insert(keyable_socket); + assert!(inserted); } } diff --git a/kernel/aster-nix/src/net/iface/mod.rs b/kernel/aster-nix/src/net/iface/mod.rs index d90b746d9..47149a0e3 100644 --- a/kernel/aster-nix/src/net/iface/mod.rs +++ b/kernel/aster-nix/src/net/iface/mod.rs @@ -45,7 +45,7 @@ pub trait Iface: internal::IfaceInternal + Send + Sync { &self, socket: Box, config: BindPortConfig, - ) -> core::result::Result, (Error, Box)> { + ) -> core::result::Result)> { let common = self.common(); common.bind_socket(self.arc_self(), socket, config) } diff --git a/kernel/aster-nix/src/net/socket/ip/common.rs b/kernel/aster-nix/src/net/socket/ip/common.rs index 5edfa2ebe..785bc5a03 100644 --- a/kernel/aster-nix/src/net/socket/ip/common.rs +++ b/kernel/aster-nix/src/net/socket/ip/common.rs @@ -46,7 +46,7 @@ pub(super) fn bind_socket( unbound_socket: Box, endpoint: &IpEndpoint, can_reuse: bool, -) -> core::result::Result, (Error, Box)> { +) -> core::result::Result)> { let iface = match get_iface_to_bind(&endpoint.addr) { Some(iface) => iface, None => { diff --git a/kernel/aster-nix/src/net/socket/ip/datagram/bound.rs b/kernel/aster-nix/src/net/socket/ip/datagram/bound.rs index 4fc9f4bfd..e8869bcb2 100644 --- a/kernel/aster-nix/src/net/socket/ip/datagram/bound.rs +++ b/kernel/aster-nix/src/net/socket/ip/datagram/bound.rs @@ -13,12 +13,12 @@ use crate::{ }; pub struct BoundDatagram { - bound_socket: Arc, + bound_socket: AnyBoundSocket, remote_endpoint: Option, } impl BoundDatagram { - pub fn new(bound_socket: Arc) -> Self { + pub fn new(bound_socket: AnyBoundSocket) -> Self { Self { bound_socket, remote_endpoint: None, diff --git a/kernel/aster-nix/src/net/socket/ip/stream/connected.rs b/kernel/aster-nix/src/net/socket/ip/stream/connected.rs index afb4527c4..6ac2b9122 100644 --- a/kernel/aster-nix/src/net/socket/ip/stream/connected.rs +++ b/kernel/aster-nix/src/net/socket/ip/stream/connected.rs @@ -15,7 +15,7 @@ use crate::{ }; pub struct ConnectedStream { - bound_socket: Arc, + bound_socket: AnyBoundSocket, remote_endpoint: IpEndpoint, /// Indicates whether this connection is "new" in a `connect()` system call. /// @@ -32,7 +32,7 @@ pub struct ConnectedStream { impl ConnectedStream { pub fn new( - bound_socket: Arc, + bound_socket: AnyBoundSocket, remote_endpoint: IpEndpoint, is_new_connection: bool, ) -> Self { diff --git a/kernel/aster-nix/src/net/socket/ip/stream/connecting.rs b/kernel/aster-nix/src/net/socket/ip/stream/connecting.rs index 3503a9afe..2aca5c52b 100644 --- a/kernel/aster-nix/src/net/socket/ip/stream/connecting.rs +++ b/kernel/aster-nix/src/net/socket/ip/stream/connecting.rs @@ -8,7 +8,7 @@ use crate::{ }; pub struct ConnectingStream { - bound_socket: Arc, + bound_socket: AnyBoundSocket, remote_endpoint: IpEndpoint, conn_result: RwLock>, } @@ -26,9 +26,9 @@ pub enum NonConnectedStream { impl ConnectingStream { pub fn new( - bound_socket: Arc, + bound_socket: AnyBoundSocket, remote_endpoint: IpEndpoint, - ) -> core::result::Result)> { + ) -> core::result::Result { if let Err(err) = bound_socket.do_connect(remote_endpoint) { return Err((err, bound_socket)); } diff --git a/kernel/aster-nix/src/net/socket/ip/stream/init.rs b/kernel/aster-nix/src/net/socket/ip/stream/init.rs index 8f103a404..81c1df518 100644 --- a/kernel/aster-nix/src/net/socket/ip/stream/init.rs +++ b/kernel/aster-nix/src/net/socket/ip/stream/init.rs @@ -15,7 +15,7 @@ use crate::{ pub enum InitStream { Unbound(Box), - Bound(Arc), + Bound(AnyBoundSocket), } impl InitStream { @@ -23,14 +23,14 @@ impl InitStream { InitStream::Unbound(Box::new(AnyUnboundSocket::new_tcp(observer))) } - pub fn new_bound(bound_socket: Arc) -> Self { + pub fn new_bound(bound_socket: AnyBoundSocket) -> Self { InitStream::Bound(bound_socket) } pub fn bind( self, endpoint: &IpEndpoint, - ) -> core::result::Result, (Error, Self)> { + ) -> core::result::Result { let unbound_socket = match self { InitStream::Unbound(unbound_socket) => unbound_socket, InitStream::Bound(bound_socket) => { @@ -50,7 +50,7 @@ impl InitStream { fn bind_to_ephemeral_endpoint( self, remote_endpoint: &IpEndpoint, - ) -> core::result::Result, (Error, Self)> { + ) -> core::result::Result { let endpoint = get_ephemeral_endpoint(remote_endpoint); self.bind(&endpoint) } diff --git a/kernel/aster-nix/src/net/socket/ip/stream/listen.rs b/kernel/aster-nix/src/net/socket/ip/stream/listen.rs index ef63ce156..d7bdc593c 100644 --- a/kernel/aster-nix/src/net/socket/ip/stream/listen.rs +++ b/kernel/aster-nix/src/net/socket/ip/stream/listen.rs @@ -13,16 +13,16 @@ use crate::{ pub struct ListenStream { backlog: usize, /// A bound socket held to ensure the TCP port cannot be released - bound_socket: Arc, + bound_socket: AnyBoundSocket, /// Backlog sockets listening at the local endpoint backlog_sockets: RwLock>, } impl ListenStream { pub fn new( - bound_socket: Arc, + bound_socket: AnyBoundSocket, backlog: usize, - ) -> core::result::Result)> { + ) -> core::result::Result { let listen_stream = Self { backlog, bound_socket, @@ -99,13 +99,13 @@ impl ListenStream { } struct BacklogSocket { - bound_socket: Arc, + bound_socket: AnyBoundSocket, } impl BacklogSocket { // FIXME: All of the error codes below seem to have no Linux equivalents, and I see no reason // why the error may occur. Perhaps it is better to call `unwrap()` directly? - fn new(bound_socket: &Arc) -> Result { + fn new(bound_socket: &AnyBoundSocket) -> Result { let local_endpoint = bound_socket.local_endpoint().ok_or(Error::with_message( Errno::EINVAL, "the socket is not bound", @@ -143,7 +143,7 @@ impl BacklogSocket { .raw_with(|socket: &mut RawTcpSocket| socket.remote_endpoint()) } - fn into_bound_socket(self) -> Arc { + fn into_bound_socket(self) -> AnyBoundSocket { self.bound_socket } } diff --git a/test/apps/network/listen_backlog.c b/test/apps/network/listen_backlog.c index 949189165..301a90bde 100644 --- a/test/apps/network/listen_backlog.c +++ b/test/apps/network/listen_backlog.c @@ -131,7 +131,7 @@ int main(void) for (backlog = 0; backlog <= MAX_TEST_BACKLOG; ++backlog) { // Avoid "bind: Address already in use" - addr.sin_port = htons(8080 + backlog); + addr.sin_port = htons(10000 + backlog); err = test_listen_backlog(&addr, backlog); if (err != 0) diff --git a/test/apps/network/send_buf_full.c b/test/apps/network/send_buf_full.c index d4b8671e9..6889fe1a0 100644 --- a/test/apps/network/send_buf_full.c +++ b/test/apps/network/send_buf_full.c @@ -265,7 +265,7 @@ int main(void) struct sockaddr_in addr; addr.sin_family = AF_INET; - addr.sin_port = htons(8080); + addr.sin_port = htons(9999); if (inet_aton("127.0.0.1", &addr.sin_addr) < 0) { fprintf(stderr, "inet_aton cannot parse 127.0.0.1\n"); return -1;