From 37cef00bb404c9cc01509c12df57548029967dc2 Mon Sep 17 00:00:00 2001 From: Samuel Dai Date: Sat, 11 May 2024 17:17:43 +0800 Subject: [PATCH] fix(net): Fix TCP Unresponsiveness and Inability to Close Connections (#791) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix(net): Improve stability. 为RawSocket与UdpSocket实现close时调用close方法,符合smoltcp的行为。为SocketInode实现drop,保证程序任何情况下退出时都能正确close对应socket, 释放被占用的端口。 * fix(net): Correct socket close behavior. --- kernel/src/net/net_core.rs | 7 +- kernel/src/net/socket/inet.rs | 180 +++++++++++++++++----------------- kernel/src/net/socket/mod.rs | 88 +++++++++-------- user/apps/http_server/main.c | 2 + 4 files changed, 143 insertions(+), 134 deletions(-) diff --git a/kernel/src/net/net_core.rs b/kernel/src/net/net_core.rs index eb7efe7b..b555a098 100644 --- a/kernel/src/net/net_core.rs +++ b/kernel/src/net/net_core.rs @@ -217,6 +217,9 @@ fn send_event(sockets: &smoltcp::iface::SocketSet) -> Result<(), SystemError> { if inner_socket.state() == smoltcp::socket::tcp::State::Established { events |= TcpSocket::CAN_CONNECT; } + if inner_socket.state() == smoltcp::socket::tcp::State::CloseWait { + events |= EPollEventType::EPOLLHUP.bits() as u64; + } handle_guard .get(&global_handle) .unwrap() @@ -226,13 +229,11 @@ fn send_event(sockets: &smoltcp::iface::SocketSet) -> Result<(), SystemError> { smoltcp::socket::Socket::Dhcpv4(_) => {} smoltcp::socket::Socket::Dns(_) => unimplemented!("Dns socket hasn't unimplemented"), } - drop(handle_guard); - let mut handle_guard = HANDLE_MAP.write_irqsave(); - let handle_item = handle_guard.get_mut(&global_handle).unwrap(); EventPoll::wakeup_epoll( &handle_item.epitems, EPollEventType::from_bits_truncate(events as u32), )?; + drop(handle_guard); // crate::kdebug!( // "{} send_event {:?}", // handle, diff --git a/kernel/src/net/socket/inet.rs b/kernel/src/net/socket/inet.rs index 6c4580a6..09ceea96 100644 --- a/kernel/src/net/socket/inet.rs +++ b/kernel/src/net/socket/inet.rs @@ -1,12 +1,11 @@ use alloc::{boxed::Box, sync::Arc, vec::Vec}; use smoltcp::{ - socket::{raw, tcp, udp, AnySocket}, + socket::{raw, tcp, udp}, wire, }; use system_error::SystemError; use crate::{ - arch::rand::rand, driver::net::NetDevice, kerror, kwarn, libs::rwlock::RwLock, @@ -88,7 +87,11 @@ impl RawSocket { impl Socket for RawSocket { fn close(&mut self) { let mut socket_set_guard = SOCKET_SET.lock_irqsave(); - socket_set_guard.remove(self.handle.smoltcp_handle().unwrap()); // 删除的时候,会发送一条FINISH的信息? + if let smoltcp::socket::Socket::Udp(mut sock) = + socket_set_guard.remove(self.handle.smoltcp_handle().unwrap()) + { + sock.close(); + } drop(socket_set_guard); poll_ifaces(); } @@ -289,7 +292,7 @@ impl UdpSocket { ip.port = PORT_MANAGER.get_ephemeral_port(self.metadata.socket_type)?; } // 检测端口是否已被占用 - PORT_MANAGER.bind_port(self.metadata.socket_type, ip.port, self.clone())?; + PORT_MANAGER.bind_port(self.metadata.socket_type, ip.port)?; let bind_res = if ip.addr.is_unspecified() { socket.bind(ip.port) @@ -310,7 +313,11 @@ impl UdpSocket { impl Socket for UdpSocket { fn close(&mut self) { let mut socket_set_guard = SOCKET_SET.lock_irqsave(); - socket_set_guard.remove(self.handle.smoltcp_handle().unwrap()); // 删除的时候,会发送一条FINISH的信息? + if let smoltcp::socket::Socket::Udp(mut sock) = + socket_set_guard.remove(self.handle.smoltcp_handle().unwrap()) + { + sock.close(); + } drop(socket_set_guard); poll_ifaces(); } @@ -559,11 +566,20 @@ impl TcpSocket { impl Socket for TcpSocket { fn close(&mut self) { for handle in self.handles.iter() { - let mut socket_set_guard = SOCKET_SET.lock_irqsave(); - socket_set_guard.remove(handle.smoltcp_handle().unwrap()); // 删除的时候,会发送一条FINISH的信息? - drop(socket_set_guard); + { + let mut socket_set_guard = SOCKET_SET.lock_irqsave(); + let smoltcp_handle = handle.smoltcp_handle().unwrap(); + socket_set_guard + .get_mut::(smoltcp_handle) + .close(); + drop(socket_set_guard); + } + poll_ifaces(); + SOCKET_SET + .lock_irqsave() + .remove(handle.smoltcp_handle().unwrap()); + // kdebug!("[Socket] [TCP] Close: {:?}", handle); } - poll_ifaces(); } fn read(&self, buf: &mut [u8]) -> (Result, Endpoint) { @@ -627,7 +643,7 @@ impl Socket for TcpSocket { drop(socket_set_guard); SocketHandleItem::sleep( self.socket_handle(), - EPollEventType::EPOLLIN.bits() as u64, + (EPollEventType::EPOLLIN.bits() | EPollEventType::EPOLLHUP.bits()) as u64, HANDLE_MAP.read_irqsave(), ); } @@ -697,7 +713,7 @@ impl Socket for TcpSocket { if let Endpoint::Ip(Some(ip)) = endpoint { let temp_port = PORT_MANAGER.get_ephemeral_port(self.metadata.socket_type)?; // 检测端口是否被占用 - PORT_MANAGER.bind_port(self.metadata.socket_type, temp_port, self.clone())?; + PORT_MANAGER.bind_port(self.metadata.socket_type, temp_port)?; // kdebug!("temp_port: {}", temp_port); let iface: Arc = NET_DEVICES.write_irqsave().get(&0).unwrap().clone(); @@ -750,7 +766,7 @@ impl Socket for TcpSocket { /// @brief tcp socket 监听 local_endpoint 端口 /// - /// @param backlog 未处理的连接队列的最大长度. 由于smoltcp不支持backlog,所以这个参数目前无效 + /// @param backlog 未处理的连接队列的最大长度 fn listen(&mut self, backlog: usize) -> Result<(), SystemError> { if self.is_listening { return Ok(()); @@ -763,12 +779,14 @@ impl Socket for TcpSocket { let backlog = handlen.max(backlog); // 添加剩余需要构建的socket - // kdebug!("tcp socket:before listen, socket'len={}",self.handle.len()); + // kdebug!("tcp socket:before listen, socket'len={}", self.handle_list.len()); let mut handle_guard = HANDLE_MAP.write_irqsave(); + let wait_queue = Arc::clone(&handle_guard.get(&self.socket_handle()).unwrap().wait_queue); + self.handles.extend((handlen..backlog).map(|_| { let socket = Self::create_new_socket(); let handle = GlobalSocketHandle::new_smoltcp_handle(sockets.add(socket)); - let handle_item = SocketHandleItem::new(); + let handle_item = SocketHandleItem::new(Some(wait_queue.clone())); handle_guard.insert(handle, handle_item); handle })); @@ -797,7 +815,7 @@ impl Socket for TcpSocket { } // 检测端口是否已被占用 - PORT_MANAGER.bind_port(self.metadata.socket_type, ip.port, self.clone())?; + PORT_MANAGER.bind_port(self.metadata.socket_type, ip.port)?; // kdebug!("tcp socket:bind, socket'len={}",self.handle.len()); self.local_endpoint = Some(ip); @@ -818,100 +836,78 @@ impl Socket for TcpSocket { } fn accept(&mut self) -> Result<(Box, Endpoint), SystemError> { + if !self.is_listening { + return Err(SystemError::EINVAL); + } let endpoint = self.local_endpoint.ok_or(SystemError::EINVAL)?; loop { // kdebug!("tcp accept: poll_ifaces()"); poll_ifaces(); - // kdebug!("tcp socket:accept, socket'len={}",self.handle.len()); + // kdebug!("tcp socket:accept, socket'len={}", self.handle_list.len()); - let mut sockets = SOCKET_SET.lock_irqsave(); + let mut sockset = SOCKET_SET.lock_irqsave(); + // Get the corresponding activated handler + let global_handle_index = self.handles.iter().position(|handle| { + let con_smol_sock = sockset.get::(handle.smoltcp_handle().unwrap()); + con_smol_sock.is_active() + }); - // 随机获取访问的socket的handle - let index: usize = rand() % self.handles.len(); - let handle = self.handles.get(index).unwrap(); + if let Some(handle_index) = global_handle_index { + let con_smol_sock = sockset + .get::(self.handles[handle_index].smoltcp_handle().unwrap()); - let socket = sockets - .iter_mut() - .find(|y| { - tcp::Socket::downcast(y.1) - .map(|y| y.is_active()) - .unwrap_or(false) - }) - .map(|y| tcp::Socket::downcast_mut(y.1).unwrap()); - if let Some(socket) = socket { - if socket.is_active() { - // kdebug!("tcp accept: socket.is_active()"); - let remote_ep = socket.remote_endpoint().ok_or(SystemError::ENOTCONN)?; + // kdebug!("[Socket] [TCP] Accept: {:?}", handle); + // handle is connected socket's handle + let remote_ep = con_smol_sock + .remote_endpoint() + .ok_or(SystemError::ENOTCONN)?; - let new_socket = { - // The new TCP socket used for sending and receiving data. - let mut tcp_socket = Self::create_new_socket(); - self.do_listen(&mut tcp_socket, endpoint) - .expect("do_listen failed"); + let mut tcp_socket = Self::create_new_socket(); + self.do_listen(&mut tcp_socket, endpoint)?; - // tcp_socket.listen(endpoint).unwrap(); + let new_handle = GlobalSocketHandle::new_smoltcp_handle(sockset.add(tcp_socket)); - // 之所以把old_handle存入new_socket, 是因为当前时刻,smoltcp已经把old_handle对应的socket与远程的endpoint关联起来了 - // 因此需要再为当前的socket分配一个新的handle - let new_handle = - GlobalSocketHandle::new_smoltcp_handle(sockets.add(tcp_socket)); - let old_handle = ::core::mem::replace( - &mut *self.handles.get_mut(index).unwrap(), - new_handle, - ); + // let handle in TcpSock be the new empty handle, and return the old connected handle + let old_handle = core::mem::replace(&mut self.handles[handle_index], new_handle); - let metadata = SocketMetadata::new( - SocketType::Tcp, - Self::DEFAULT_TX_BUF_SIZE, - Self::DEFAULT_RX_BUF_SIZE, - Self::DEFAULT_METADATA_BUF_SIZE, - self.metadata.options, - ); + let metadata = SocketMetadata::new( + SocketType::Tcp, + Self::DEFAULT_TX_BUF_SIZE, + Self::DEFAULT_RX_BUF_SIZE, + Self::DEFAULT_METADATA_BUF_SIZE, + self.metadata.options, + ); - let new_socket = Box::new(TcpSocket { - handles: vec![old_handle], - local_endpoint: self.local_endpoint, - is_listening: false, - metadata, - }); - // kdebug!("tcp socket:after accept, socket'len={}",new_socket.handle.len()); + let sock_ret = Box::new(TcpSocket { + handles: vec![old_handle], + local_endpoint: self.local_endpoint, + is_listening: false, + metadata, + }); - // 更新端口与 socket 的绑定 - if let Some(Endpoint::Ip(Some(ip))) = self.endpoint() { - PORT_MANAGER.unbind_port(self.metadata.socket_type, ip.port)?; - PORT_MANAGER.bind_port( - self.metadata.socket_type, - ip.port, - *new_socket.clone(), - )?; - } + { + let mut handle_guard = HANDLE_MAP.write_irqsave(); + // 先删除原来的 + let item = handle_guard.remove(&old_handle).unwrap(); - // 更新handle表 - let mut handle_guard = HANDLE_MAP.write_irqsave(); - // 先删除原来的 - - let item = handle_guard.remove(&old_handle).unwrap(); - - // 按照smoltcp行为,将新的handle绑定到原来的item - handle_guard.insert(new_handle, item); - let new_item = SocketHandleItem::new(); - - // 插入新的item - handle_guard.insert(old_handle, new_item); - - new_socket - }; - // kdebug!("tcp accept: new socket: {:?}", new_socket); - drop(sockets); - poll_ifaces(); - - return Ok((new_socket, Endpoint::Ip(Some(remote_ep)))); + // 按照smoltcp行为,将新的handle绑定到原来的item + let new_item = SocketHandleItem::new(None); + handle_guard.insert(old_handle, new_item); + // 插入新的item + handle_guard.insert(new_handle, item); + drop(handle_guard); } + return Ok((sock_ret, Endpoint::Ip(Some(remote_ep)))); } - // kdebug!("tcp socket:before sleep, handle_guard'len={}",HANDLE_MAP.write_irqsave().len()); - drop(sockets); - SocketHandleItem::sleep(*handle, Self::CAN_ACCPET, HANDLE_MAP.read_irqsave()); + drop(sockset); + + // kdebug!("[TCP] [Accept] sleeping socket with handle: {:?}", self.handles.get(0).unwrap().smoltcp_handle().unwrap()); + SocketHandleItem::sleep( + self.socket_handle(), // NOTICE + Self::CAN_ACCPET, + HANDLE_MAP.read_irqsave(), + ); // kdebug!("tcp socket:after sleep, handle_guard'len={}",HANDLE_MAP.write_irqsave().len()); } } diff --git a/kernel/src/net/socket/mod.rs b/kernel/src/net/socket/mod.rs index 70fb5414..c9864273 100644 --- a/kernel/src/net/socket/mod.rs +++ b/kernel/src/net/socket/mod.rs @@ -25,6 +25,7 @@ use crate::{ spinlock::{SpinLock, SpinLockGuard}, wait_queue::EventWaitQueue, }, + process::{Pid, ProcessManager}, sched::{schedule, SchedMode}, }; @@ -85,7 +86,7 @@ pub(super) fn new_socket( } }; - let handle_item = SocketHandleItem::new(); + let handle_item = SocketHandleItem::new(None); HANDLE_MAP .write_irqsave() .insert(socket.socket_handle(), handle_item); @@ -303,6 +304,40 @@ impl SocketInode { pub unsafe fn inner_no_preempt(&self) -> SpinLockGuard> { self.0.lock_no_preempt() } + + fn do_close(&self) -> Result<(), SystemError> { + let prev_ref_count = self.1.fetch_sub(1, core::sync::atomic::Ordering::SeqCst); + if prev_ref_count == 1 { + // 最后一次关闭,需要释放 + let mut socket = self.0.lock_irqsave(); + + if socket.metadata().socket_type == SocketType::Unix { + return Ok(()); + } + + if let Some(Endpoint::Ip(Some(ip))) = socket.endpoint() { + PORT_MANAGER.unbind_port(socket.metadata().socket_type, ip.port); + } + + socket.clear_epoll()?; + + HANDLE_MAP + .write_irqsave() + .remove(&socket.socket_handle()) + .unwrap(); + socket.close(); + } + + Ok(()) + } +} + +impl Drop for SocketInode { + fn drop(&mut self) { + for _ in 0..self.1.load(core::sync::atomic::Ordering::SeqCst) { + let _ = self.do_close(); + } + } } impl IndexNode for SocketInode { @@ -316,29 +351,7 @@ impl IndexNode for SocketInode { } fn close(&self, _data: SpinLockGuard) -> Result<(), SystemError> { - let prev_ref_count = self.1.fetch_sub(1, core::sync::atomic::Ordering::SeqCst); - if prev_ref_count == 1 { - // 最后一次关闭,需要释放 - let mut socket = self.0.lock_irqsave(); - - if socket.metadata().socket_type == SocketType::Unix { - return Ok(()); - } - - if let Some(Endpoint::Ip(Some(ip))) = socket.endpoint() { - PORT_MANAGER.unbind_port(socket.metadata().socket_type, ip.port)?; - } - - socket.clear_epoll()?; - - HANDLE_MAP - .write_irqsave() - .remove(&socket.socket_handle()) - .unwrap(); - socket.close(); - } - - Ok(()) + self.do_close() } fn read_at( @@ -400,16 +413,16 @@ pub struct SocketHandleItem { /// shutdown状态 pub shutdown_type: RwLock, /// socket的waitqueue - pub wait_queue: EventWaitQueue, + pub wait_queue: Arc, /// epitems,考虑写在这是否是最优解? pub epitems: SpinLock>>, } impl SocketHandleItem { - pub fn new() -> Self { + pub fn new(wait_queue: Option>) -> Self { Self { shutdown_type: RwLock::new(ShutdownType::empty()), - wait_queue: EventWaitQueue::new(), + wait_queue: wait_queue.unwrap_or(Arc::new(EventWaitQueue::new())), epitems: SpinLock::new(LinkedList::new()), } } @@ -463,9 +476,9 @@ impl SocketHandleItem { /// 如果 TCP/UDP 的 socket 绑定了某个端口,它会在对应的表中记录,以检测端口冲突。 pub struct PortManager { // TCP 端口记录表 - tcp_port_table: SpinLock>>, + tcp_port_table: SpinLock>, // UDP 端口记录表 - udp_port_table: SpinLock>>, + udp_port_table: SpinLock>, } impl PortManager { @@ -517,12 +530,7 @@ impl PortManager { /// @brief 检测给定端口是否已被占用,如果未被占用则在 TCP/UDP 对应的表中记录 /// /// TODO: 增加支持端口复用的逻辑 - pub fn bind_port( - &self, - socket_type: SocketType, - port: u16, - socket: impl Socket, - ) -> Result<(), SystemError> { + pub fn bind_port(&self, socket_type: SocketType, port: u16) -> Result<(), SystemError> { if port > 0 { let mut listen_table_guard = match socket_type { SocketType::Udp => self.udp_port_table.lock(), @@ -531,7 +539,7 @@ impl PortManager { }; match listen_table_guard.get(&port) { Some(_) => return Err(SystemError::EADDRINUSE), - None => listen_table_guard.insert(port, Arc::new(socket)), + None => listen_table_guard.insert(port, ProcessManager::current_pid()), }; drop(listen_table_guard); } @@ -539,15 +547,17 @@ impl PortManager { } /// @brief 在对应的端口记录表中将端口和 socket 解绑 - pub fn unbind_port(&self, socket_type: SocketType, port: u16) -> Result<(), SystemError> { + /// should call this function when socket is closed or aborted + pub fn unbind_port(&self, socket_type: SocketType, port: u16) { let mut listen_table_guard = match socket_type { SocketType::Udp => self.udp_port_table.lock(), SocketType::Tcp => self.tcp_port_table.lock(), - _ => return Ok(()), + _ => { + return; + } }; listen_table_guard.remove(&port); drop(listen_table_guard); - return Ok(()); } } diff --git a/user/apps/http_server/main.c b/user/apps/http_server/main.c index 95dbb2d1..76fb3354 100644 --- a/user/apps/http_server/main.c +++ b/user/apps/http_server/main.c @@ -233,6 +233,8 @@ int main(int argc, char const *argv[]) // 关闭客户端连接 close(new_socket); } + // 关闭tcp socket + close(server_fd); return 0; } \ No newline at end of file