From bbea79ec193ce9382b25846545902814d9cae74d Mon Sep 17 00:00:00 2001 From: xiaolin2004 <1553367438@qq.com> Date: Thu, 28 Nov 2024 19:20:01 +0800 Subject: [PATCH] =?UTF-8?q?feat(socket):=20=E6=B7=BB=E5=8A=A0shutdown?= =?UTF-8?q?=E6=96=B9=E6=B3=95=E5=B9=B6=E5=AE=9E=E7=8E=B0ShutdownTemp?= =?UTF-8?q?=E7=9A=84TryFrom=E8=BD=AC=E6=8D=A2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- kernel/src/net/socket/common/shutdown.rs | 21 ++++++++++++++++++--- kernel/src/net/socket/inet/stream/mod.rs | 22 ++++++++++++++++++++++ kernel/src/net/syscall.rs | 2 +- 3 files changed, 41 insertions(+), 4 deletions(-) diff --git a/kernel/src/net/socket/common/shutdown.rs b/kernel/src/net/socket/common/shutdown.rs index 69c90ad2..096cb43d 100644 --- a/kernel/src/net/socket/common/shutdown.rs +++ b/kernel/src/net/socket/common/shutdown.rs @@ -1,4 +1,6 @@ -use core::sync::atomic::AtomicU8; +use core::{default, sync::atomic::AtomicU8}; + +use system_error::SystemError; bitflags! { /// @brief 用于指定socket的关闭类型 @@ -101,8 +103,8 @@ impl ShutdownTemp { self.bit == 0 } - pub fn from_how(how: usize) -> Self { - Self { bit: how as u8 + 1 } + pub fn bits(&self) -> ShutdownBit { + ShutdownBit { bits: self.bit } } } @@ -116,3 +118,16 @@ impl From for ShutdownTemp { } } } + +impl TryFrom for ShutdownTemp { + type Error = SystemError; + + fn try_from(value: usize) -> Result { + match value { + 0 | 1 | 2 => Ok(ShutdownTemp { + bit: value as u8 + 1, + }), + _ => Err(SystemError::EINVAL), + } + } +} diff --git a/kernel/src/net/socket/inet/stream/mod.rs b/kernel/src/net/socket/inet/stream/mod.rs index ac00eb90..4751a917 100644 --- a/kernel/src/net/socket/inet/stream/mod.rs +++ b/kernel/src/net/socket/inet/stream/mod.rs @@ -338,6 +338,28 @@ impl Socket for TcpSocket { .recv_buffer_size() } + + fn shutdown(&self, how: ShutdownTemp) -> Result<(), SystemError> { + let self_shutdown = self.shutdown.get().bits(); + let diff = how.bits().difference(self_shutdown); + match diff.is_empty(){ + true => { + return Ok(()) + }, + false => { + if diff.contains(ShutdownBit::SHUT_RD){ + self.shutdown.recv_shutdown(); + // TODO 协议栈处理 + } + if diff.contains(ShutdownBit::SHUT_WR){ + self.shutdown.send_shutdown(); + // TODO 协议栈处理 + } + }, + } + Ok(()) + } + fn close(&self) -> Result<(), SystemError> { let inner = self.inner .write() diff --git a/kernel/src/net/syscall.rs b/kernel/src/net/syscall.rs index bea15e20..f5d50922 100644 --- a/kernel/src/net/syscall.rs +++ b/kernel/src/net/syscall.rs @@ -367,7 +367,7 @@ impl Syscall { let socket: Arc = ProcessManager::current_pcb() .get_socket(fd as i32) .ok_or(SystemError::EBADF)?; - socket.shutdown(socket::ShutdownTemp::from_how(how))?; + socket.shutdown(how.try_into()?)?; return Ok(0); }