socket统一改用GlobalSocketHandle,并且修复fcntl SETFD的错误 (#730)

* socket统一改用`GlobalSocketHandle`,并且修复fcntl SETFD的错误

---------

Co-authored-by: longjin <longjin@DragonOS.org>
This commit is contained in:
GnoCiYeH 2024-04-15 22:01:32 +08:00 committed by GitHub
parent 7162a8358d
commit d623e90231
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 247 additions and 148 deletions

View File

@ -88,7 +88,7 @@ pub extern "sysv64" fn syscall_handler(frame: &mut TrapFrame) {
mfence(); mfence();
let pid = ProcessManager::current_pcb().pid(); let pid = ProcessManager::current_pcb().pid();
let show = false; let show = false;
// let show = if syscall_num != SYS_SCHED && pid.data() > 3 { // let show = if syscall_num != SYS_SCHED && pid.data() >= 7 {
// true // true
// } else { // } else {
// false // false

View File

@ -1024,6 +1024,15 @@ impl Syscall {
oldfd: i32, oldfd: i32,
newfd: i32, newfd: i32,
fd_table_guard: &mut RwLockWriteGuard<'_, FileDescriptorVec>, fd_table_guard: &mut RwLockWriteGuard<'_, FileDescriptorVec>,
) -> Result<usize, SystemError> {
Self::do_dup3(oldfd, newfd, FileMode::empty(), fd_table_guard)
}
fn do_dup3(
oldfd: i32,
newfd: i32,
flags: FileMode,
fd_table_guard: &mut RwLockWriteGuard<'_, FileDescriptorVec>,
) -> Result<usize, SystemError> { ) -> Result<usize, SystemError> {
// 确认oldfd, newid是否有效 // 确认oldfd, newid是否有效
if !(FileDescriptorVec::validate_fd(oldfd) && FileDescriptorVec::validate_fd(newfd)) { if !(FileDescriptorVec::validate_fd(oldfd) && FileDescriptorVec::validate_fd(newfd)) {
@ -1047,8 +1056,12 @@ impl Syscall {
.get_file_by_fd(oldfd) .get_file_by_fd(oldfd)
.ok_or(SystemError::EBADF)?; .ok_or(SystemError::EBADF)?;
let new_file = old_file.try_clone().ok_or(SystemError::EBADF)?; let new_file = old_file.try_clone().ok_or(SystemError::EBADF)?;
// dup2默认非cloexec
new_file.set_close_on_exec(false); if flags.contains(FileMode::O_CLOEXEC) {
new_file.set_close_on_exec(true);
} else {
new_file.set_close_on_exec(false);
}
// 申请文件描述符,并把文件对象存入其中 // 申请文件描述符,并把文件对象存入其中
let res = fd_table_guard let res = fd_table_guard
.alloc_fd(new_file, Some(newfd)) .alloc_fd(new_file, Some(newfd))
@ -1064,8 +1077,9 @@ impl Syscall {
/// - `cmd`:命令 /// - `cmd`:命令
/// - `arg`:参数 /// - `arg`:参数
pub fn fcntl(fd: i32, cmd: FcntlCommand, arg: i32) -> Result<usize, SystemError> { pub fn fcntl(fd: i32, cmd: FcntlCommand, arg: i32) -> Result<usize, SystemError> {
// kdebug!("fcntl ({cmd:?}) fd: {fd}, arg={arg}");
match cmd { match cmd {
FcntlCommand::DupFd => { FcntlCommand::DupFd | FcntlCommand::DupFdCloexec => {
if arg < 0 || arg as usize >= FileDescriptorVec::PROCESS_MAX_FD { if arg < 0 || arg as usize >= FileDescriptorVec::PROCESS_MAX_FD {
return Err(SystemError::EBADF); return Err(SystemError::EBADF);
} }
@ -1074,7 +1088,16 @@ impl Syscall {
let binding = ProcessManager::current_pcb().fd_table(); let binding = ProcessManager::current_pcb().fd_table();
let mut fd_table_guard = binding.write(); let mut fd_table_guard = binding.write();
if fd_table_guard.get_file_by_fd(i as i32).is_none() { if fd_table_guard.get_file_by_fd(i as i32).is_none() {
return Self::do_dup2(fd, i as i32, &mut fd_table_guard); if cmd == FcntlCommand::DupFd {
return Self::do_dup2(fd, i as i32, &mut fd_table_guard);
} else {
return Self::do_dup3(
fd,
i as i32,
FileMode::O_CLOEXEC,
&mut fd_table_guard,
);
}
} }
} }
return Err(SystemError::EMFILE); return Err(SystemError::EMFILE);
@ -1083,12 +1106,15 @@ impl Syscall {
// Get file descriptor flags. // Get file descriptor flags.
let binding = ProcessManager::current_pcb().fd_table(); let binding = ProcessManager::current_pcb().fd_table();
let fd_table_guard = binding.read(); let fd_table_guard = binding.read();
if let Some(file) = fd_table_guard.get_file_by_fd(fd) { if let Some(file) = fd_table_guard.get_file_by_fd(fd) {
// drop guard 以避免无法调度的问题 // drop guard 以避免无法调度的问题
drop(fd_table_guard); drop(fd_table_guard);
if file.close_on_exec() { if file.close_on_exec() {
return Ok(FD_CLOEXEC as usize); return Ok(FD_CLOEXEC as usize);
} else {
return Ok(0);
} }
} }
return Err(SystemError::EBADF); return Err(SystemError::EBADF);
@ -1145,8 +1171,8 @@ impl Syscall {
// TODO: unimplemented // TODO: unimplemented
// 未实现的命令返回0不报错。 // 未实现的命令返回0不报错。
// kwarn!("fcntl: unimplemented command: {:?}, defaults to 0.", cmd); kwarn!("fcntl: unimplemented command: {:?}, defaults to 0.", cmd);
return Ok(0); return Err(SystemError::ENOSYS);
} }
} }
} }

View File

@ -12,7 +12,7 @@ use crate::{
use super::{ use super::{
event_poll::{EPollEventType, EventPoll}, event_poll::{EPollEventType, EventPoll},
socket::{inet::TcpSocket, HANDLE_MAP, SOCKET_SET}, socket::{handle::GlobalSocketHandle, inet::TcpSocket, HANDLE_MAP, SOCKET_SET},
}; };
/// The network poll function, which will be called by timer. /// The network poll function, which will be called by timer.
@ -188,7 +188,8 @@ pub fn poll_ifaces_try_lock_onetime() -> Result<(), SystemError> {
fn send_event(sockets: &smoltcp::iface::SocketSet) -> Result<(), SystemError> { fn send_event(sockets: &smoltcp::iface::SocketSet) -> Result<(), SystemError> {
for (handle, socket_type) in sockets.iter() { for (handle, socket_type) in sockets.iter() {
let handle_guard = HANDLE_MAP.read_irqsave(); let handle_guard = HANDLE_MAP.read_irqsave();
let item = handle_guard.get(&handle); let global_handle = GlobalSocketHandle::new_smoltcp_handle(handle);
let item = handle_guard.get(&global_handle);
if item.is_none() { if item.is_none() {
continue; continue;
} }
@ -203,7 +204,7 @@ fn send_event(sockets: &smoltcp::iface::SocketSet) -> Result<(), SystemError> {
match socket_type { match socket_type {
smoltcp::socket::Socket::Raw(_) | smoltcp::socket::Socket::Udp(_) => { smoltcp::socket::Socket::Raw(_) | smoltcp::socket::Socket::Udp(_) => {
handle_guard handle_guard
.get(&handle) .get(&global_handle)
.unwrap() .unwrap()
.wait_queue .wait_queue
.wakeup_any(events); .wakeup_any(events);
@ -217,7 +218,7 @@ fn send_event(sockets: &smoltcp::iface::SocketSet) -> Result<(), SystemError> {
events |= TcpSocket::CAN_CONNECT; events |= TcpSocket::CAN_CONNECT;
} }
handle_guard handle_guard
.get(&handle) .get(&global_handle)
.unwrap() .unwrap()
.wait_queue .wait_queue
.wakeup_any(events); .wakeup_any(events);
@ -227,7 +228,7 @@ fn send_event(sockets: &smoltcp::iface::SocketSet) -> Result<(), SystemError> {
} }
drop(handle_guard); drop(handle_guard);
let mut handle_guard = HANDLE_MAP.write_irqsave(); let mut handle_guard = HANDLE_MAP.write_irqsave();
let handle_item = handle_guard.get_mut(&handle).unwrap(); let handle_item = handle_guard.get_mut(&global_handle).unwrap();
EventPoll::wakeup_epoll( EventPoll::wakeup_epoll(
&handle_item.epitems, &handle_item.epitems,
EPollEventType::from_bits_truncate(events as u32), EPollEventType::from_bits_truncate(events as u32),

View File

@ -0,0 +1,39 @@
use ida::IdAllocator;
use smoltcp::iface::SocketHandle;
int_like!(KernelHandle, usize);
/// # socket的句柄管理组件
/// 它在smoltcp的SocketHandle上封装了一层增加更多的功能。
/// 比如在socket被关闭时自动释放socket的资源通知系统的其他组件。
#[derive(Debug, Hash, Eq, PartialEq, Clone, Copy)]
pub enum GlobalSocketHandle {
Smoltcp(SocketHandle),
Kernel(KernelHandle),
}
static KERNEL_HANDLE_IDA: IdAllocator = IdAllocator::new(0, usize::MAX);
impl GlobalSocketHandle {
pub fn new_smoltcp_handle(handle: SocketHandle) -> Self {
return Self::Smoltcp(handle);
}
pub fn new_kernel_handle() -> Self {
return Self::Kernel(KernelHandle::new(KERNEL_HANDLE_IDA.alloc().unwrap()));
}
pub fn smoltcp_handle(&self) -> Option<SocketHandle> {
if let Self::Smoltcp(sh) = *self {
return Some(sh);
}
None
}
pub fn kernel_handle(&self) -> Option<KernelHandle> {
if let Self::Kernel(kh) = *self {
return Some(kh);
}
None
}
}

View File

@ -1,7 +1,6 @@
use alloc::{boxed::Box, sync::Arc, vec::Vec}; use alloc::{boxed::Box, sync::Arc, vec::Vec};
use smoltcp::{ use smoltcp::{
iface::SocketHandle, socket::{raw, tcp, udp, AnySocket},
socket::{raw, tcp, udp},
wire, wire,
}; };
use system_error::SystemError; use system_error::SystemError;
@ -18,8 +17,8 @@ use crate::{
}; };
use super::{ use super::{
GlobalSocketHandle, Socket, SocketHandleItem, SocketMetadata, SocketOptions, SocketPollMethod, handle::GlobalSocketHandle, Socket, SocketHandleItem, SocketMetadata, SocketOptions,
SocketType, HANDLE_MAP, PORT_MANAGER, SOCKET_SET, SocketPollMethod, SocketType, HANDLE_MAP, PORT_MANAGER, SOCKET_SET,
}; };
/// @brief 表示原始的socket。原始套接字绕过传输层协议如 TCP 或 UDP并提供对网络层协议如 IP的直接访问。 /// @brief 表示原始的socket。原始套接字绕过传输层协议如 TCP 或 UDP并提供对网络层协议如 IP的直接访问。
@ -27,7 +26,7 @@ use super::{
/// ref: https://man7.org/linux/man-pages/man7/raw.7.html /// ref: https://man7.org/linux/man-pages/man7/raw.7.html
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct RawSocket { pub struct RawSocket {
handle: Arc<GlobalSocketHandle>, handle: GlobalSocketHandle,
/// 用户发送的数据包是否包含了IP头. /// 用户发送的数据包是否包含了IP头.
/// 如果是true用户发送的数据包必须包含IP头。即用户要自行设置IP头+数据) /// 如果是true用户发送的数据包必须包含IP头。即用户要自行设置IP头+数据)
/// 如果是false用户发送的数据包不包含IP头。即用户只要设置数据 /// 如果是false用户发送的数据包不包含IP头。即用户只要设置数据
@ -68,8 +67,7 @@ impl RawSocket {
); );
// 把socket添加到socket集合中并得到socket的句柄 // 把socket添加到socket集合中并得到socket的句柄
let handle: Arc<GlobalSocketHandle> = let handle = GlobalSocketHandle::new_smoltcp_handle(SOCKET_SET.lock_irqsave().add(socket));
GlobalSocketHandle::new(SOCKET_SET.lock_irqsave().add(socket));
let metadata = SocketMetadata::new( let metadata = SocketMetadata::new(
SocketType::Raw, SocketType::Raw,
@ -88,12 +86,20 @@ impl RawSocket {
} }
impl Socket for RawSocket { impl Socket for RawSocket {
fn close(&mut self) {
let mut socket_set_guard = SOCKET_SET.lock_irqsave();
socket_set_guard.remove(self.handle.smoltcp_handle().unwrap()); // 删除的时候会发送一条FINISH的信息
drop(socket_set_guard);
poll_ifaces();
}
fn read(&self, buf: &mut [u8]) -> (Result<usize, SystemError>, Endpoint) { fn read(&self, buf: &mut [u8]) -> (Result<usize, SystemError>, Endpoint) {
poll_ifaces(); poll_ifaces();
loop { loop {
// 如何优化这里? // 如何优化这里?
let mut socket_set_guard = SOCKET_SET.lock_irqsave(); let mut socket_set_guard = SOCKET_SET.lock_irqsave();
let socket = socket_set_guard.get_mut::<raw::Socket>(self.handle.0); let socket =
socket_set_guard.get_mut::<raw::Socket>(self.handle.smoltcp_handle().unwrap());
match socket.recv_slice(buf) { match socket.recv_slice(buf) {
Ok(len) => { Ok(len) => {
@ -126,7 +132,8 @@ impl Socket for RawSocket {
// 如果用户发送的数据包包含IP头则直接发送 // 如果用户发送的数据包包含IP头则直接发送
if self.header_included { if self.header_included {
let mut socket_set_guard = SOCKET_SET.lock_irqsave(); let mut socket_set_guard = SOCKET_SET.lock_irqsave();
let socket = socket_set_guard.get_mut::<raw::Socket>(self.handle.0); let socket =
socket_set_guard.get_mut::<raw::Socket>(self.handle.smoltcp_handle().unwrap());
match socket.send_slice(buf) { match socket.send_slice(buf) {
Ok(_) => { Ok(_) => {
return Ok(buf.len()); return Ok(buf.len());
@ -141,7 +148,7 @@ impl Socket for RawSocket {
if let Some(Endpoint::Ip(Some(endpoint))) = to { if let Some(Endpoint::Ip(Some(endpoint))) = to {
let mut socket_set_guard = SOCKET_SET.lock_irqsave(); let mut socket_set_guard = SOCKET_SET.lock_irqsave();
let socket: &mut raw::Socket = let socket: &mut raw::Socket =
socket_set_guard.get_mut::<raw::Socket>(self.handle.0); socket_set_guard.get_mut::<raw::Socket>(self.handle.smoltcp_handle().unwrap());
// 暴力解决方案只考虑0号网卡。 TODO考虑多网卡的情况 // 暴力解决方案只考虑0号网卡。 TODO考虑多网卡的情况
let iface = NET_DRIVERS.read_irqsave().get(&0).unwrap().clone(); let iface = NET_DRIVERS.read_irqsave().get(&0).unwrap().clone();
@ -209,8 +216,8 @@ impl Socket for RawSocket {
Box::new(self.clone()) Box::new(self.clone())
} }
fn socket_handle(&self) -> SocketHandle { fn socket_handle(&self) -> GlobalSocketHandle {
self.handle.0 self.handle
} }
fn as_any_ref(&self) -> &dyn core::any::Any { fn as_any_ref(&self) -> &dyn core::any::Any {
@ -227,7 +234,7 @@ impl Socket for RawSocket {
/// https://man7.org/linux/man-pages/man7/udp.7.html /// https://man7.org/linux/man-pages/man7/udp.7.html
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct UdpSocket { pub struct UdpSocket {
pub handle: Arc<GlobalSocketHandle>, pub handle: GlobalSocketHandle,
remote_endpoint: Option<Endpoint>, // 记录远程endpoint提供给connect() 应该使用IP地址。 remote_endpoint: Option<Endpoint>, // 记录远程endpoint提供给connect() 应该使用IP地址。
metadata: SocketMetadata, metadata: SocketMetadata,
} }
@ -257,8 +264,8 @@ impl UdpSocket {
let socket = udp::Socket::new(rx_buffer, tx_buffer); let socket = udp::Socket::new(rx_buffer, tx_buffer);
// 把socket添加到socket集合中并得到socket的句柄 // 把socket添加到socket集合中并得到socket的句柄
let handle: Arc<GlobalSocketHandle> = let handle: GlobalSocketHandle =
GlobalSocketHandle::new(SOCKET_SET.lock_irqsave().add(socket)); GlobalSocketHandle::new_smoltcp_handle(SOCKET_SET.lock_irqsave().add(socket));
let metadata = SocketMetadata::new( let metadata = SocketMetadata::new(
SocketType::Udp, SocketType::Udp,
@ -301,13 +308,21 @@ impl UdpSocket {
} }
impl Socket for UdpSocket { impl Socket for UdpSocket {
fn close(&mut self) {
let mut socket_set_guard = SOCKET_SET.lock_irqsave();
socket_set_guard.remove(self.handle.smoltcp_handle().unwrap()); // 删除的时候会发送一条FINISH的信息
drop(socket_set_guard);
poll_ifaces();
}
/// @brief 在read函数执行之前请先bind到本地的指定端口 /// @brief 在read函数执行之前请先bind到本地的指定端口
fn read(&self, buf: &mut [u8]) -> (Result<usize, SystemError>, Endpoint) { fn read(&self, buf: &mut [u8]) -> (Result<usize, SystemError>, Endpoint) {
loop { loop {
// kdebug!("Wait22 to Read"); // kdebug!("Wait22 to Read");
poll_ifaces(); poll_ifaces();
let mut socket_set_guard = SOCKET_SET.lock_irqsave(); let mut socket_set_guard = SOCKET_SET.lock_irqsave();
let socket = socket_set_guard.get_mut::<udp::Socket>(self.handle.0); let socket =
socket_set_guard.get_mut::<udp::Socket>(self.handle.smoltcp_handle().unwrap());
// kdebug!("Wait to Read"); // kdebug!("Wait to Read");
@ -344,7 +359,7 @@ impl Socket for UdpSocket {
// kdebug!("udp write: remote = {:?}", remote_endpoint); // kdebug!("udp write: remote = {:?}", remote_endpoint);
let mut socket_set_guard = SOCKET_SET.lock_irqsave(); let mut socket_set_guard = SOCKET_SET.lock_irqsave();
let socket = socket_set_guard.get_mut::<udp::Socket>(self.handle.0); let socket = socket_set_guard.get_mut::<udp::Socket>(self.handle.smoltcp_handle().unwrap());
// kdebug!("is open()={}", socket.is_open()); // kdebug!("is open()={}", socket.is_open());
// kdebug!("socket endpoint={:?}", socket.endpoint()); // kdebug!("socket endpoint={:?}", socket.endpoint());
if socket.can_send() { if socket.can_send() {
@ -369,14 +384,14 @@ impl Socket for UdpSocket {
fn bind(&mut self, endpoint: Endpoint) -> Result<(), SystemError> { fn bind(&mut self, endpoint: Endpoint) -> Result<(), SystemError> {
let mut sockets = SOCKET_SET.lock_irqsave(); let mut sockets = SOCKET_SET.lock_irqsave();
let socket = sockets.get_mut::<udp::Socket>(self.handle.0); let socket = sockets.get_mut::<udp::Socket>(self.handle.smoltcp_handle().unwrap());
// kdebug!("UDP Bind to {:?}", endpoint); // kdebug!("UDP Bind to {:?}", endpoint);
return self.do_bind(socket, endpoint); return self.do_bind(socket, endpoint);
} }
fn poll(&self) -> EPollEventType { fn poll(&self) -> EPollEventType {
let sockets = SOCKET_SET.lock_irqsave(); let sockets = SOCKET_SET.lock_irqsave();
let socket = sockets.get::<udp::Socket>(self.handle.0); let socket = sockets.get::<udp::Socket>(self.handle.smoltcp_handle().unwrap());
return SocketPollMethod::udp_poll( return SocketPollMethod::udp_poll(
socket, socket,
@ -417,7 +432,7 @@ impl Socket for UdpSocket {
fn endpoint(&self) -> Option<Endpoint> { fn endpoint(&self) -> Option<Endpoint> {
let sockets = SOCKET_SET.lock_irqsave(); let sockets = SOCKET_SET.lock_irqsave();
let socket = sockets.get::<udp::Socket>(self.handle.0); let socket = sockets.get::<udp::Socket>(self.handle.smoltcp_handle().unwrap());
let listen_endpoint = socket.endpoint(); let listen_endpoint = socket.endpoint();
if listen_endpoint.port == 0 { if listen_endpoint.port == 0 {
@ -440,8 +455,8 @@ impl Socket for UdpSocket {
return self.remote_endpoint.clone(); return self.remote_endpoint.clone();
} }
fn socket_handle(&self) -> SocketHandle { fn socket_handle(&self) -> GlobalSocketHandle {
self.handle.0 self.handle
} }
fn as_any_ref(&self) -> &dyn core::any::Any { fn as_any_ref(&self) -> &dyn core::any::Any {
@ -458,7 +473,7 @@ impl Socket for UdpSocket {
/// https://man7.org/linux/man-pages/man7/tcp.7.html /// https://man7.org/linux/man-pages/man7/tcp.7.html
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct TcpSocket { pub struct TcpSocket {
handles: Vec<Arc<GlobalSocketHandle>>, handles: Vec<GlobalSocketHandle>,
local_endpoint: Option<wire::IpEndpoint>, // save local endpoint for bind() local_endpoint: Option<wire::IpEndpoint>, // save local endpoint for bind()
is_listening: bool, is_listening: bool,
metadata: SocketMetadata, metadata: SocketMetadata,
@ -483,7 +498,7 @@ impl TcpSocket {
/// @return 返回创建的tcp的socket /// @return 返回创建的tcp的socket
pub fn new(options: SocketOptions) -> Self { pub fn new(options: SocketOptions) -> Self {
// 创建handles数组并把socket添加到socket集合中并得到socket的句柄 // 创建handles数组并把socket添加到socket集合中并得到socket的句柄
let handles: Vec<Arc<GlobalSocketHandle>> = vec![GlobalSocketHandle::new( let handles: Vec<GlobalSocketHandle> = vec![GlobalSocketHandle::new_smoltcp_handle(
SOCKET_SET.lock_irqsave().add(Self::create_new_socket()), SOCKET_SET.lock_irqsave().add(Self::create_new_socket()),
)]; )];
@ -542,6 +557,15 @@ impl TcpSocket {
} }
impl Socket for TcpSocket { impl Socket for TcpSocket {
fn close(&mut self) {
for handle in self.handles.iter() {
let mut socket_set_guard = SOCKET_SET.lock_irqsave();
socket_set_guard.remove(handle.smoltcp_handle().unwrap()); // 删除的时候会发送一条FINISH的信息
drop(socket_set_guard);
}
poll_ifaces();
}
fn read(&self, buf: &mut [u8]) -> (Result<usize, SystemError>, Endpoint) { fn read(&self, buf: &mut [u8]) -> (Result<usize, SystemError>, Endpoint) {
if HANDLE_MAP if HANDLE_MAP
.read_irqsave() .read_irqsave()
@ -558,7 +582,8 @@ impl Socket for TcpSocket {
poll_ifaces(); poll_ifaces();
let mut socket_set_guard = SOCKET_SET.lock_irqsave(); let mut socket_set_guard = SOCKET_SET.lock_irqsave();
let socket = socket_set_guard.get_mut::<tcp::Socket>(self.handles.get(0).unwrap().0); let socket = socket_set_guard
.get_mut::<tcp::Socket>(self.handles.get(0).unwrap().smoltcp_handle().unwrap());
// 如果socket已经关闭返回错误 // 如果socket已经关闭返回错误
if !socket.is_active() { if !socket.is_active() {
@ -626,7 +651,8 @@ impl Socket for TcpSocket {
let mut socket_set_guard = SOCKET_SET.lock_irqsave(); let mut socket_set_guard = SOCKET_SET.lock_irqsave();
let socket = socket_set_guard.get_mut::<tcp::Socket>(self.handles.get(0).unwrap().0); let socket = socket_set_guard
.get_mut::<tcp::Socket>(self.handles.get(0).unwrap().smoltcp_handle().unwrap());
if socket.is_open() { if socket.is_open() {
if socket.can_send() { if socket.can_send() {
@ -653,7 +679,8 @@ impl Socket for TcpSocket {
let mut socket_set_guard = SOCKET_SET.lock_irqsave(); let mut socket_set_guard = SOCKET_SET.lock_irqsave();
// kdebug!("tcp socket:poll, socket'len={}",self.handle.len()); // kdebug!("tcp socket:poll, socket'len={}",self.handle.len());
let socket = socket_set_guard.get_mut::<tcp::Socket>(self.handles.get(0).unwrap().0); let socket = socket_set_guard
.get_mut::<tcp::Socket>(self.handles.get(0).unwrap().smoltcp_handle().unwrap());
return SocketPollMethod::tcp_poll( return SocketPollMethod::tcp_poll(
socket, socket,
HANDLE_MAP HANDLE_MAP
@ -668,7 +695,8 @@ impl Socket for TcpSocket {
let mut sockets = SOCKET_SET.lock_irqsave(); let mut sockets = SOCKET_SET.lock_irqsave();
// kdebug!("tcp socket:connect, socket'len={}",self.handle.len()); // kdebug!("tcp socket:connect, socket'len={}",self.handle.len());
let socket = sockets.get_mut::<tcp::Socket>(self.handles.get(0).unwrap().0); let socket =
sockets.get_mut::<tcp::Socket>(self.handles.get(0).unwrap().smoltcp_handle().unwrap());
if let Endpoint::Ip(Some(ip)) = endpoint { if let Endpoint::Ip(Some(ip)) = endpoint {
let temp_port = PORT_MANAGER.get_ephemeral_port(self.metadata.socket_type)?; let temp_port = PORT_MANAGER.get_ephemeral_port(self.metadata.socket_type)?;
@ -689,7 +717,9 @@ impl Socket for TcpSocket {
loop { loop {
poll_ifaces(); poll_ifaces();
let mut sockets = SOCKET_SET.lock_irqsave(); let mut sockets = SOCKET_SET.lock_irqsave();
let socket = sockets.get_mut::<tcp::Socket>(self.handles.get(0).unwrap().0); let socket = sockets.get_mut::<tcp::Socket>(
self.handles.get(0).unwrap().smoltcp_handle().unwrap(),
);
match socket.state() { match socket.state() {
tcp::State::Established => { tcp::State::Established => {
@ -741,9 +771,9 @@ impl Socket for TcpSocket {
let mut handle_guard = HANDLE_MAP.write_irqsave(); let mut handle_guard = HANDLE_MAP.write_irqsave();
self.handles.extend((handlen..backlog).map(|_| { self.handles.extend((handlen..backlog).map(|_| {
let socket = Self::create_new_socket(); let socket = Self::create_new_socket();
let handle = GlobalSocketHandle::new(sockets.add(socket)); let handle = GlobalSocketHandle::new_smoltcp_handle(sockets.add(socket));
let handle_item = SocketHandleItem::new(); let handle_item = SocketHandleItem::new();
handle_guard.insert(handle.0, handle_item); handle_guard.insert(handle, handle_item);
handle handle
})); }));
// kdebug!("tcp socket:listen, socket'len={}",self.handle.len()); // kdebug!("tcp socket:listen, socket'len={}",self.handle.len());
@ -753,7 +783,7 @@ impl Socket for TcpSocket {
for i in 0..backlog { for i in 0..backlog {
let handle = self.handles.get(i).unwrap(); let handle = self.handles.get(i).unwrap();
let socket = sockets.get_mut::<tcp::Socket>(handle.0); let socket = sockets.get_mut::<tcp::Socket>(handle.smoltcp_handle().unwrap());
if !socket.is_listening() { if !socket.is_listening() {
// kdebug!("Tcp Socket is already listening on {local_endpoint}"); // kdebug!("Tcp Socket is already listening on {local_endpoint}");
@ -803,79 +833,89 @@ impl Socket for TcpSocket {
// 随机获取访问的socket的handle // 随机获取访问的socket的handle
let index: usize = rand() % self.handles.len(); let index: usize = rand() % self.handles.len();
let handle = self.handles.get(index).unwrap(); let handle = self.handles.get(index).unwrap();
let socket = sockets.get_mut::<tcp::Socket>(handle.0);
if socket.is_active() { let socket = sockets
// kdebug!("tcp accept: socket.is_active()"); .iter_mut()
let remote_ep = socket.remote_endpoint().ok_or(SystemError::ENOTCONN)?; .find(|y| {
tcp::Socket::downcast(y.1)
.map(|y| y.is_active())
.unwrap_or(false)
})
.map(|y| tcp::Socket::downcast_mut(y.1).unwrap());
if let Some(socket) = socket {
if socket.is_active() {
// kdebug!("tcp accept: socket.is_active()");
let remote_ep = socket.remote_endpoint().ok_or(SystemError::ENOTCONN)?;
let new_socket = { let new_socket = {
// The new TCP socket used for sending and receiving data. // The new TCP socket used for sending and receiving data.
let mut tcp_socket = Self::create_new_socket(); let mut tcp_socket = Self::create_new_socket();
self.do_listen(&mut tcp_socket, endpoint) self.do_listen(&mut tcp_socket, endpoint)
.expect("do_listen failed"); .expect("do_listen failed");
// tcp_socket.listen(endpoint).unwrap(); // tcp_socket.listen(endpoint).unwrap();
// 之所以把old_handle存入new_socket, 是因为当前时刻smoltcp已经把old_handle对应的socket与远程的endpoint关联起来了 // 之所以把old_handle存入new_socket, 是因为当前时刻smoltcp已经把old_handle对应的socket与远程的endpoint关联起来了
// 因此需要再为当前的socket分配一个新的handle // 因此需要再为当前的socket分配一个新的handle
let new_handle = GlobalSocketHandle::new(sockets.add(tcp_socket)); let new_handle =
let old_handle = ::core::mem::replace( GlobalSocketHandle::new_smoltcp_handle(sockets.add(tcp_socket));
&mut *self.handles.get_mut(index).unwrap(), let old_handle = ::core::mem::replace(
new_handle.clone(), &mut *self.handles.get_mut(index).unwrap(),
); new_handle,
);
let metadata = SocketMetadata::new( let metadata = SocketMetadata::new(
SocketType::Tcp, SocketType::Tcp,
Self::DEFAULT_TX_BUF_SIZE, Self::DEFAULT_TX_BUF_SIZE,
Self::DEFAULT_RX_BUF_SIZE, Self::DEFAULT_RX_BUF_SIZE,
Self::DEFAULT_METADATA_BUF_SIZE, Self::DEFAULT_METADATA_BUF_SIZE,
self.metadata.options, self.metadata.options,
); );
let new_socket = Box::new(TcpSocket { let new_socket = Box::new(TcpSocket {
handles: vec![old_handle.clone()], handles: vec![old_handle],
local_endpoint: self.local_endpoint, local_endpoint: self.local_endpoint,
is_listening: false, is_listening: false,
metadata, metadata,
}); });
// kdebug!("tcp socket:after accept, socket'len={}",new_socket.handle.len()); // kdebug!("tcp socket:after accept, socket'len={}",new_socket.handle.len());
// 更新端口与 socket 的绑定 // 更新端口与 socket 的绑定
if let Some(Endpoint::Ip(Some(ip))) = self.endpoint() { if let Some(Endpoint::Ip(Some(ip))) = self.endpoint() {
PORT_MANAGER.unbind_port(self.metadata.socket_type, ip.port)?; PORT_MANAGER.unbind_port(self.metadata.socket_type, ip.port)?;
PORT_MANAGER.bind_port( PORT_MANAGER.bind_port(
self.metadata.socket_type, self.metadata.socket_type,
ip.port, ip.port,
*new_socket.clone(), *new_socket.clone(),
)?; )?;
} }
// 更新handle表 // 更新handle表
let mut handle_guard = HANDLE_MAP.write_irqsave(); let mut handle_guard = HANDLE_MAP.write_irqsave();
// 先删除原来的 // 先删除原来的
let item = handle_guard.remove(&old_handle.0).unwrap(); let item = handle_guard.remove(&old_handle).unwrap();
// 按照smoltcp行为将新的handle绑定到原来的item // 按照smoltcp行为将新的handle绑定到原来的item
handle_guard.insert(new_handle.0, item); handle_guard.insert(new_handle, item);
let new_item = SocketHandleItem::new(); let new_item = SocketHandleItem::new();
// 插入新的item // 插入新的item
handle_guard.insert(old_handle.0, new_item); handle_guard.insert(old_handle, new_item);
new_socket new_socket
}; };
// kdebug!("tcp accept: new socket: {:?}", new_socket); // kdebug!("tcp accept: new socket: {:?}", new_socket);
drop(sockets); drop(sockets);
poll_ifaces(); poll_ifaces();
return Ok((new_socket, Endpoint::Ip(Some(remote_ep)))); return Ok((new_socket, Endpoint::Ip(Some(remote_ep))));
}
} }
// kdebug!("tcp socket:before sleep, handle_guard'len={}",HANDLE_MAP.write_irqsave().len()); // kdebug!("tcp socket:before sleep, handle_guard'len={}",HANDLE_MAP.write_irqsave().len());
drop(sockets); drop(sockets);
SocketHandleItem::sleep(handle.0, Self::CAN_ACCPET, HANDLE_MAP.read_irqsave()); SocketHandleItem::sleep(*handle, Self::CAN_ACCPET, HANDLE_MAP.read_irqsave());
// kdebug!("tcp socket:after sleep, handle_guard'len={}",HANDLE_MAP.write_irqsave().len()); // kdebug!("tcp socket:after sleep, handle_guard'len={}",HANDLE_MAP.write_irqsave().len());
} }
} }
@ -887,7 +927,8 @@ impl Socket for TcpSocket {
let sockets = SOCKET_SET.lock_irqsave(); let sockets = SOCKET_SET.lock_irqsave();
// kdebug!("tcp socket:endpoint, socket'len={}",self.handle.len()); // kdebug!("tcp socket:endpoint, socket'len={}",self.handle.len());
let socket = sockets.get::<tcp::Socket>(self.handles.get(0).unwrap().0); let socket =
sockets.get::<tcp::Socket>(self.handles.get(0).unwrap().smoltcp_handle().unwrap());
if let Some(ep) = socket.local_endpoint() { if let Some(ep) = socket.local_endpoint() {
result = Some(Endpoint::Ip(Some(ep))); result = Some(Endpoint::Ip(Some(ep)));
} }
@ -899,7 +940,8 @@ impl Socket for TcpSocket {
let sockets = SOCKET_SET.lock_irqsave(); let sockets = SOCKET_SET.lock_irqsave();
// kdebug!("tcp socket:peer_endpoint, socket'len={}",self.handle.len()); // kdebug!("tcp socket:peer_endpoint, socket'len={}",self.handle.len());
let socket = sockets.get::<tcp::Socket>(self.handles.get(0).unwrap().0); let socket =
sockets.get::<tcp::Socket>(self.handles.get(0).unwrap().smoltcp_handle().unwrap());
return socket.remote_endpoint().map(|x| Endpoint::Ip(Some(x))); return socket.remote_endpoint().map(|x| Endpoint::Ip(Some(x)));
} }
@ -911,10 +953,10 @@ impl Socket for TcpSocket {
Box::new(self.clone()) Box::new(self.clone())
} }
fn socket_handle(&self) -> SocketHandle { fn socket_handle(&self) -> GlobalSocketHandle {
// kdebug!("tcp socket:socket_handle, socket'len={}",self.handle.len()); // kdebug!("tcp socket:socket_handle, socket'len={}",self.handle.len());
self.handles.get(0).unwrap().0 *self.handles.get(0).unwrap()
} }
fn as_any_ref(&self) -> &dyn core::any::Any { fn as_any_ref(&self) -> &dyn core::any::Any {

View File

@ -9,7 +9,7 @@ use alloc::{
}; };
use hashbrown::HashMap; use hashbrown::HashMap;
use smoltcp::{ use smoltcp::{
iface::{SocketHandle, SocketSet}, iface::SocketSet,
socket::{self, tcp, udp}, socket::{self, tcp, udp},
}; };
use system_error::SystemError; use system_error::SystemError;
@ -29,16 +29,17 @@ use crate::{
}; };
use self::{ use self::{
handle::GlobalSocketHandle,
inet::{RawSocket, TcpSocket, UdpSocket}, inet::{RawSocket, TcpSocket, UdpSocket},
unix::{SeqpacketSocket, StreamSocket}, unix::{SeqpacketSocket, StreamSocket},
}; };
use super::{ use super::{
event_poll::{EPollEventType, EPollItem, EventPoll}, event_poll::{EPollEventType, EPollItem, EventPoll},
net_core::poll_ifaces,
Endpoint, Protocol, ShutdownType, Endpoint, Protocol, ShutdownType,
}; };
pub mod handle;
pub mod inet; pub mod inet;
pub mod unix; pub mod unix;
@ -48,7 +49,7 @@ lazy_static! {
pub static ref SOCKET_SET: SpinLock<SocketSet<'static >> = SpinLock::new(SocketSet::new(vec![])); pub static ref SOCKET_SET: SpinLock<SocketSet<'static >> = SpinLock::new(SocketSet::new(vec![]));
/// SocketHandle表每个SocketHandle对应一个SocketHandleItem /// SocketHandle表每个SocketHandle对应一个SocketHandleItem
/// 注意!:在网卡中断中需要拿到这张表的🔓,在获取读锁时应该确保关中断避免死锁 /// 注意!:在网卡中断中需要拿到这张表的🔓,在获取读锁时应该确保关中断避免死锁
pub static ref HANDLE_MAP: RwLock<HashMap<SocketHandle, SocketHandleItem>> = RwLock::new(HashMap::new()); pub static ref HANDLE_MAP: RwLock<HashMap<GlobalSocketHandle, SocketHandleItem>> = RwLock::new(HashMap::new());
/// 端口管理器 /// 端口管理器
pub static ref PORT_MANAGER: PortManager = PortManager::new(); pub static ref PORT_MANAGER: PortManager = PortManager::new();
} }
@ -83,6 +84,11 @@ pub(super) fn new_socket(
return Err(SystemError::EAFNOSUPPORT); return Err(SystemError::EAFNOSUPPORT);
} }
}; };
let handle_item = SocketHandleItem::new();
HANDLE_MAP
.write_irqsave()
.insert(socket.socket_handle(), handle_item);
Ok(socket) Ok(socket)
} }
@ -224,9 +230,7 @@ pub trait Socket: Sync + Send + Debug + Any {
Ok(()) Ok(())
} }
fn socket_handle(&self) -> SocketHandle { fn socket_handle(&self) -> GlobalSocketHandle;
todo!()
}
fn write_buffer(&self, _buf: &[u8]) -> Result<usize, SystemError> { fn write_buffer(&self, _buf: &[u8]) -> Result<usize, SystemError> {
todo!() todo!()
@ -272,6 +276,8 @@ pub trait Socket: Sync + Send + Debug + Any {
Ok(()) Ok(())
} }
fn close(&mut self);
} }
impl Clone for Box<dyn Socket> { impl Clone for Box<dyn Socket> {
@ -329,6 +335,7 @@ impl IndexNode for SocketInode {
.write_irqsave() .write_irqsave()
.remove(&socket.socket_handle()) .remove(&socket.socket_handle())
.unwrap(); .unwrap();
socket.close();
} }
Ok(()) Ok(())
@ -409,9 +416,9 @@ impl SocketHandleItem {
/// ## 在socket的等待队列上睡眠 /// ## 在socket的等待队列上睡眠
pub fn sleep( pub fn sleep(
socket_handle: SocketHandle, socket_handle: GlobalSocketHandle,
events: u64, events: u64,
handle_map_guard: RwLockReadGuard<'_, HashMap<SocketHandle, SocketHandleItem>>, handle_map_guard: RwLockReadGuard<'_, HashMap<GlobalSocketHandle, SocketHandleItem>>,
) { ) {
unsafe { unsafe {
handle_map_guard handle_map_guard
@ -544,33 +551,6 @@ impl PortManager {
} }
} }
/// # socket的句柄管理组件
/// 它在smoltcp的SocketHandle上封装了一层增加更多的功能。
/// 比如在socket被关闭时自动释放socket的资源通知系统的其他组件。
#[derive(Debug)]
pub struct GlobalSocketHandle(SocketHandle);
impl GlobalSocketHandle {
pub fn new(handle: SocketHandle) -> Arc<Self> {
return Arc::new(Self(handle));
}
}
impl Clone for GlobalSocketHandle {
fn clone(&self) -> Self {
Self(self.0)
}
}
impl Drop for GlobalSocketHandle {
fn drop(&mut self) {
let mut socket_set_guard = SOCKET_SET.lock_irqsave();
socket_set_guard.remove(self.0); // 删除的时候会发送一条FINISH的信息
drop(socket_set_guard);
poll_ifaces();
}
}
/// @brief socket的类型 /// @brief socket的类型
#[derive(Debug, Clone, Copy, PartialEq)] #[derive(Debug, Clone, Copy, PartialEq)]
pub enum SocketType { pub enum SocketType {

View File

@ -3,13 +3,16 @@ use system_error::SystemError;
use crate::{libs::spinlock::SpinLock, net::Endpoint}; use crate::{libs::spinlock::SpinLock, net::Endpoint};
use super::{Socket, SocketInode, SocketMetadata, SocketOptions, SocketType}; use super::{
handle::GlobalSocketHandle, Socket, SocketInode, SocketMetadata, SocketOptions, SocketType,
};
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct StreamSocket { pub struct StreamSocket {
metadata: SocketMetadata, metadata: SocketMetadata,
buffer: Arc<SpinLock<Vec<u8>>>, buffer: Arc<SpinLock<Vec<u8>>>,
peer_inode: Option<Arc<SocketInode>>, peer_inode: Option<Arc<SocketInode>>,
handle: GlobalSocketHandle,
} }
impl StreamSocket { impl StreamSocket {
@ -37,11 +40,18 @@ impl StreamSocket {
metadata, metadata,
buffer, buffer,
peer_inode: None, peer_inode: None,
handle: GlobalSocketHandle::new_kernel_handle(),
} }
} }
} }
impl Socket for StreamSocket { impl Socket for StreamSocket {
fn socket_handle(&self) -> GlobalSocketHandle {
self.handle
}
fn close(&mut self) {}
fn read(&self, buf: &mut [u8]) -> (Result<usize, SystemError>, Endpoint) { fn read(&self, buf: &mut [u8]) -> (Result<usize, SystemError>, Endpoint) {
let mut buffer = self.buffer.lock_irqsave(); let mut buffer = self.buffer.lock_irqsave();
@ -110,6 +120,7 @@ pub struct SeqpacketSocket {
metadata: SocketMetadata, metadata: SocketMetadata,
buffer: Arc<SpinLock<Vec<u8>>>, buffer: Arc<SpinLock<Vec<u8>>>,
peer_inode: Option<Arc<SocketInode>>, peer_inode: Option<Arc<SocketInode>>,
handle: GlobalSocketHandle,
} }
impl SeqpacketSocket { impl SeqpacketSocket {
@ -137,11 +148,14 @@ impl SeqpacketSocket {
metadata, metadata,
buffer, buffer,
peer_inode: None, peer_inode: None,
handle: GlobalSocketHandle::new_kernel_handle(),
} }
} }
} }
impl Socket for SeqpacketSocket { impl Socket for SeqpacketSocket {
fn close(&mut self) {}
fn read(&self, buf: &mut [u8]) -> (Result<usize, SystemError>, Endpoint) { fn read(&self, buf: &mut [u8]) -> (Result<usize, SystemError>, Endpoint) {
let mut buffer = self.buffer.lock_irqsave(); let mut buffer = self.buffer.lock_irqsave();
@ -188,6 +202,10 @@ impl Socket for SeqpacketSocket {
Ok(len) Ok(len)
} }
fn socket_handle(&self) -> GlobalSocketHandle {
self.handle
}
fn metadata(&self) -> SocketMetadata { fn metadata(&self) -> SocketMetadata {
self.metadata.clone() self.metadata.clone()
} }

View File

@ -19,7 +19,7 @@ use crate::{
}; };
use super::{ use super::{
socket::{new_socket, PosixSocketType, Socket, SocketHandleItem, SocketInode, HANDLE_MAP}, socket::{new_socket, PosixSocketType, Socket, SocketInode},
Endpoint, Protocol, ShutdownType, Endpoint, Protocol, ShutdownType,
}; };
@ -44,13 +44,6 @@ impl Syscall {
let socket = new_socket(address_family, socket_type, protocol)?; let socket = new_socket(address_family, socket_type, protocol)?;
if address_family != AddressFamily::Unix {
let handle_item = SocketHandleItem::new();
HANDLE_MAP
.write_irqsave()
.insert(socket.socket_handle(), handle_item);
}
let socketinode: Arc<SocketInode> = SocketInode::new(socket); let socketinode: Arc<SocketInode> = SocketInode::new(socket);
let f = File::new(socketinode, FileMode::O_RDWR)?; let f = File::new(socketinode, FileMode::O_RDWR)?;
// 把socket添加到当前进程的文件描述符表中 // 把socket添加到当前进程的文件描述符表中