socket统一改用GlobalSocketHandle,并且修复fcntl SETFD的错误 (#730)

* socket统一改用`GlobalSocketHandle`,并且修复fcntl SETFD的错误

---------

Co-authored-by: longjin <longjin@DragonOS.org>
This commit is contained in:
GnoCiYeH 2024-04-15 22:01:32 +08:00 committed by GitHub
parent 7162a8358d
commit d623e90231
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 247 additions and 148 deletions

View File

@ -88,7 +88,7 @@ pub extern "sysv64" fn syscall_handler(frame: &mut TrapFrame) {
mfence();
let pid = ProcessManager::current_pcb().pid();
let show = false;
// let show = if syscall_num != SYS_SCHED && pid.data() > 3 {
// let show = if syscall_num != SYS_SCHED && pid.data() >= 7 {
// true
// } else {
// false

View File

@ -1024,6 +1024,15 @@ impl Syscall {
oldfd: i32,
newfd: i32,
fd_table_guard: &mut RwLockWriteGuard<'_, FileDescriptorVec>,
) -> Result<usize, SystemError> {
Self::do_dup3(oldfd, newfd, FileMode::empty(), fd_table_guard)
}
fn do_dup3(
oldfd: i32,
newfd: i32,
flags: FileMode,
fd_table_guard: &mut RwLockWriteGuard<'_, FileDescriptorVec>,
) -> Result<usize, SystemError> {
// 确认oldfd, newid是否有效
if !(FileDescriptorVec::validate_fd(oldfd) && FileDescriptorVec::validate_fd(newfd)) {
@ -1047,8 +1056,12 @@ impl Syscall {
.get_file_by_fd(oldfd)
.ok_or(SystemError::EBADF)?;
let new_file = old_file.try_clone().ok_or(SystemError::EBADF)?;
// dup2默认非cloexec
if flags.contains(FileMode::O_CLOEXEC) {
new_file.set_close_on_exec(true);
} else {
new_file.set_close_on_exec(false);
}
// 申请文件描述符,并把文件对象存入其中
let res = fd_table_guard
.alloc_fd(new_file, Some(newfd))
@ -1064,8 +1077,9 @@ impl Syscall {
/// - `cmd`:命令
/// - `arg`:参数
pub fn fcntl(fd: i32, cmd: FcntlCommand, arg: i32) -> Result<usize, SystemError> {
// kdebug!("fcntl ({cmd:?}) fd: {fd}, arg={arg}");
match cmd {
FcntlCommand::DupFd => {
FcntlCommand::DupFd | FcntlCommand::DupFdCloexec => {
if arg < 0 || arg as usize >= FileDescriptorVec::PROCESS_MAX_FD {
return Err(SystemError::EBADF);
}
@ -1074,7 +1088,16 @@ impl Syscall {
let binding = ProcessManager::current_pcb().fd_table();
let mut fd_table_guard = binding.write();
if fd_table_guard.get_file_by_fd(i as i32).is_none() {
if cmd == FcntlCommand::DupFd {
return Self::do_dup2(fd, i as i32, &mut fd_table_guard);
} else {
return Self::do_dup3(
fd,
i as i32,
FileMode::O_CLOEXEC,
&mut fd_table_guard,
);
}
}
}
return Err(SystemError::EMFILE);
@ -1083,12 +1106,15 @@ impl Syscall {
// Get file descriptor flags.
let binding = ProcessManager::current_pcb().fd_table();
let fd_table_guard = binding.read();
if let Some(file) = fd_table_guard.get_file_by_fd(fd) {
// drop guard 以避免无法调度的问题
drop(fd_table_guard);
if file.close_on_exec() {
return Ok(FD_CLOEXEC as usize);
} else {
return Ok(0);
}
}
return Err(SystemError::EBADF);
@ -1145,8 +1171,8 @@ impl Syscall {
// TODO: unimplemented
// 未实现的命令返回0不报错。
// kwarn!("fcntl: unimplemented command: {:?}, defaults to 0.", cmd);
return Ok(0);
kwarn!("fcntl: unimplemented command: {:?}, defaults to 0.", cmd);
return Err(SystemError::ENOSYS);
}
}
}

View File

@ -12,7 +12,7 @@ use crate::{
use super::{
event_poll::{EPollEventType, EventPoll},
socket::{inet::TcpSocket, HANDLE_MAP, SOCKET_SET},
socket::{handle::GlobalSocketHandle, inet::TcpSocket, HANDLE_MAP, SOCKET_SET},
};
/// The network poll function, which will be called by timer.
@ -188,7 +188,8 @@ pub fn poll_ifaces_try_lock_onetime() -> Result<(), SystemError> {
fn send_event(sockets: &smoltcp::iface::SocketSet) -> Result<(), SystemError> {
for (handle, socket_type) in sockets.iter() {
let handle_guard = HANDLE_MAP.read_irqsave();
let item = handle_guard.get(&handle);
let global_handle = GlobalSocketHandle::new_smoltcp_handle(handle);
let item = handle_guard.get(&global_handle);
if item.is_none() {
continue;
}
@ -203,7 +204,7 @@ fn send_event(sockets: &smoltcp::iface::SocketSet) -> Result<(), SystemError> {
match socket_type {
smoltcp::socket::Socket::Raw(_) | smoltcp::socket::Socket::Udp(_) => {
handle_guard
.get(&handle)
.get(&global_handle)
.unwrap()
.wait_queue
.wakeup_any(events);
@ -217,7 +218,7 @@ fn send_event(sockets: &smoltcp::iface::SocketSet) -> Result<(), SystemError> {
events |= TcpSocket::CAN_CONNECT;
}
handle_guard
.get(&handle)
.get(&global_handle)
.unwrap()
.wait_queue
.wakeup_any(events);
@ -227,7 +228,7 @@ fn send_event(sockets: &smoltcp::iface::SocketSet) -> Result<(), SystemError> {
}
drop(handle_guard);
let mut handle_guard = HANDLE_MAP.write_irqsave();
let handle_item = handle_guard.get_mut(&handle).unwrap();
let handle_item = handle_guard.get_mut(&global_handle).unwrap();
EventPoll::wakeup_epoll(
&handle_item.epitems,
EPollEventType::from_bits_truncate(events as u32),

View File

@ -0,0 +1,39 @@
use ida::IdAllocator;
use smoltcp::iface::SocketHandle;
int_like!(KernelHandle, usize);
/// # socket的句柄管理组件
/// 它在smoltcp的SocketHandle上封装了一层增加更多的功能。
/// 比如在socket被关闭时自动释放socket的资源通知系统的其他组件。
#[derive(Debug, Hash, Eq, PartialEq, Clone, Copy)]
pub enum GlobalSocketHandle {
Smoltcp(SocketHandle),
Kernel(KernelHandle),
}
static KERNEL_HANDLE_IDA: IdAllocator = IdAllocator::new(0, usize::MAX);
impl GlobalSocketHandle {
pub fn new_smoltcp_handle(handle: SocketHandle) -> Self {
return Self::Smoltcp(handle);
}
pub fn new_kernel_handle() -> Self {
return Self::Kernel(KernelHandle::new(KERNEL_HANDLE_IDA.alloc().unwrap()));
}
pub fn smoltcp_handle(&self) -> Option<SocketHandle> {
if let Self::Smoltcp(sh) = *self {
return Some(sh);
}
None
}
pub fn kernel_handle(&self) -> Option<KernelHandle> {
if let Self::Kernel(kh) = *self {
return Some(kh);
}
None
}
}

View File

@ -1,7 +1,6 @@
use alloc::{boxed::Box, sync::Arc, vec::Vec};
use smoltcp::{
iface::SocketHandle,
socket::{raw, tcp, udp},
socket::{raw, tcp, udp, AnySocket},
wire,
};
use system_error::SystemError;
@ -18,8 +17,8 @@ use crate::{
};
use super::{
GlobalSocketHandle, Socket, SocketHandleItem, SocketMetadata, SocketOptions, SocketPollMethod,
SocketType, HANDLE_MAP, PORT_MANAGER, SOCKET_SET,
handle::GlobalSocketHandle, Socket, SocketHandleItem, SocketMetadata, SocketOptions,
SocketPollMethod, SocketType, HANDLE_MAP, PORT_MANAGER, SOCKET_SET,
};
/// @brief 表示原始的socket。原始套接字绕过传输层协议如 TCP 或 UDP并提供对网络层协议如 IP的直接访问。
@ -27,7 +26,7 @@ use super::{
/// ref: https://man7.org/linux/man-pages/man7/raw.7.html
#[derive(Debug, Clone)]
pub struct RawSocket {
handle: Arc<GlobalSocketHandle>,
handle: GlobalSocketHandle,
/// 用户发送的数据包是否包含了IP头.
/// 如果是true用户发送的数据包必须包含IP头。即用户要自行设置IP头+数据)
/// 如果是false用户发送的数据包不包含IP头。即用户只要设置数据
@ -68,8 +67,7 @@ impl RawSocket {
);
// 把socket添加到socket集合中并得到socket的句柄
let handle: Arc<GlobalSocketHandle> =
GlobalSocketHandle::new(SOCKET_SET.lock_irqsave().add(socket));
let handle = GlobalSocketHandle::new_smoltcp_handle(SOCKET_SET.lock_irqsave().add(socket));
let metadata = SocketMetadata::new(
SocketType::Raw,
@ -88,12 +86,20 @@ impl RawSocket {
}
impl Socket for RawSocket {
fn close(&mut self) {
let mut socket_set_guard = SOCKET_SET.lock_irqsave();
socket_set_guard.remove(self.handle.smoltcp_handle().unwrap()); // 删除的时候会发送一条FINISH的信息
drop(socket_set_guard);
poll_ifaces();
}
fn read(&self, buf: &mut [u8]) -> (Result<usize, SystemError>, Endpoint) {
poll_ifaces();
loop {
// 如何优化这里?
let mut socket_set_guard = SOCKET_SET.lock_irqsave();
let socket = socket_set_guard.get_mut::<raw::Socket>(self.handle.0);
let socket =
socket_set_guard.get_mut::<raw::Socket>(self.handle.smoltcp_handle().unwrap());
match socket.recv_slice(buf) {
Ok(len) => {
@ -126,7 +132,8 @@ impl Socket for RawSocket {
// 如果用户发送的数据包包含IP头则直接发送
if self.header_included {
let mut socket_set_guard = SOCKET_SET.lock_irqsave();
let socket = socket_set_guard.get_mut::<raw::Socket>(self.handle.0);
let socket =
socket_set_guard.get_mut::<raw::Socket>(self.handle.smoltcp_handle().unwrap());
match socket.send_slice(buf) {
Ok(_) => {
return Ok(buf.len());
@ -141,7 +148,7 @@ impl Socket for RawSocket {
if let Some(Endpoint::Ip(Some(endpoint))) = to {
let mut socket_set_guard = SOCKET_SET.lock_irqsave();
let socket: &mut raw::Socket =
socket_set_guard.get_mut::<raw::Socket>(self.handle.0);
socket_set_guard.get_mut::<raw::Socket>(self.handle.smoltcp_handle().unwrap());
// 暴力解决方案只考虑0号网卡。 TODO考虑多网卡的情况
let iface = NET_DRIVERS.read_irqsave().get(&0).unwrap().clone();
@ -209,8 +216,8 @@ impl Socket for RawSocket {
Box::new(self.clone())
}
fn socket_handle(&self) -> SocketHandle {
self.handle.0
fn socket_handle(&self) -> GlobalSocketHandle {
self.handle
}
fn as_any_ref(&self) -> &dyn core::any::Any {
@ -227,7 +234,7 @@ impl Socket for RawSocket {
/// https://man7.org/linux/man-pages/man7/udp.7.html
#[derive(Debug, Clone)]
pub struct UdpSocket {
pub handle: Arc<GlobalSocketHandle>,
pub handle: GlobalSocketHandle,
remote_endpoint: Option<Endpoint>, // 记录远程endpoint提供给connect() 应该使用IP地址。
metadata: SocketMetadata,
}
@ -257,8 +264,8 @@ impl UdpSocket {
let socket = udp::Socket::new(rx_buffer, tx_buffer);
// 把socket添加到socket集合中并得到socket的句柄
let handle: Arc<GlobalSocketHandle> =
GlobalSocketHandle::new(SOCKET_SET.lock_irqsave().add(socket));
let handle: GlobalSocketHandle =
GlobalSocketHandle::new_smoltcp_handle(SOCKET_SET.lock_irqsave().add(socket));
let metadata = SocketMetadata::new(
SocketType::Udp,
@ -301,13 +308,21 @@ impl UdpSocket {
}
impl Socket for UdpSocket {
fn close(&mut self) {
let mut socket_set_guard = SOCKET_SET.lock_irqsave();
socket_set_guard.remove(self.handle.smoltcp_handle().unwrap()); // 删除的时候会发送一条FINISH的信息
drop(socket_set_guard);
poll_ifaces();
}
/// @brief 在read函数执行之前请先bind到本地的指定端口
fn read(&self, buf: &mut [u8]) -> (Result<usize, SystemError>, Endpoint) {
loop {
// kdebug!("Wait22 to Read");
poll_ifaces();
let mut socket_set_guard = SOCKET_SET.lock_irqsave();
let socket = socket_set_guard.get_mut::<udp::Socket>(self.handle.0);
let socket =
socket_set_guard.get_mut::<udp::Socket>(self.handle.smoltcp_handle().unwrap());
// kdebug!("Wait to Read");
@ -344,7 +359,7 @@ impl Socket for UdpSocket {
// kdebug!("udp write: remote = {:?}", remote_endpoint);
let mut socket_set_guard = SOCKET_SET.lock_irqsave();
let socket = socket_set_guard.get_mut::<udp::Socket>(self.handle.0);
let socket = socket_set_guard.get_mut::<udp::Socket>(self.handle.smoltcp_handle().unwrap());
// kdebug!("is open()={}", socket.is_open());
// kdebug!("socket endpoint={:?}", socket.endpoint());
if socket.can_send() {
@ -369,14 +384,14 @@ impl Socket for UdpSocket {
fn bind(&mut self, endpoint: Endpoint) -> Result<(), SystemError> {
let mut sockets = SOCKET_SET.lock_irqsave();
let socket = sockets.get_mut::<udp::Socket>(self.handle.0);
let socket = sockets.get_mut::<udp::Socket>(self.handle.smoltcp_handle().unwrap());
// kdebug!("UDP Bind to {:?}", endpoint);
return self.do_bind(socket, endpoint);
}
fn poll(&self) -> EPollEventType {
let sockets = SOCKET_SET.lock_irqsave();
let socket = sockets.get::<udp::Socket>(self.handle.0);
let socket = sockets.get::<udp::Socket>(self.handle.smoltcp_handle().unwrap());
return SocketPollMethod::udp_poll(
socket,
@ -417,7 +432,7 @@ impl Socket for UdpSocket {
fn endpoint(&self) -> Option<Endpoint> {
let sockets = SOCKET_SET.lock_irqsave();
let socket = sockets.get::<udp::Socket>(self.handle.0);
let socket = sockets.get::<udp::Socket>(self.handle.smoltcp_handle().unwrap());
let listen_endpoint = socket.endpoint();
if listen_endpoint.port == 0 {
@ -440,8 +455,8 @@ impl Socket for UdpSocket {
return self.remote_endpoint.clone();
}
fn socket_handle(&self) -> SocketHandle {
self.handle.0
fn socket_handle(&self) -> GlobalSocketHandle {
self.handle
}
fn as_any_ref(&self) -> &dyn core::any::Any {
@ -458,7 +473,7 @@ impl Socket for UdpSocket {
/// https://man7.org/linux/man-pages/man7/tcp.7.html
#[derive(Debug, Clone)]
pub struct TcpSocket {
handles: Vec<Arc<GlobalSocketHandle>>,
handles: Vec<GlobalSocketHandle>,
local_endpoint: Option<wire::IpEndpoint>, // save local endpoint for bind()
is_listening: bool,
metadata: SocketMetadata,
@ -483,7 +498,7 @@ impl TcpSocket {
/// @return 返回创建的tcp的socket
pub fn new(options: SocketOptions) -> Self {
// 创建handles数组并把socket添加到socket集合中并得到socket的句柄
let handles: Vec<Arc<GlobalSocketHandle>> = vec![GlobalSocketHandle::new(
let handles: Vec<GlobalSocketHandle> = vec![GlobalSocketHandle::new_smoltcp_handle(
SOCKET_SET.lock_irqsave().add(Self::create_new_socket()),
)];
@ -542,6 +557,15 @@ impl TcpSocket {
}
impl Socket for TcpSocket {
fn close(&mut self) {
for handle in self.handles.iter() {
let mut socket_set_guard = SOCKET_SET.lock_irqsave();
socket_set_guard.remove(handle.smoltcp_handle().unwrap()); // 删除的时候会发送一条FINISH的信息
drop(socket_set_guard);
}
poll_ifaces();
}
fn read(&self, buf: &mut [u8]) -> (Result<usize, SystemError>, Endpoint) {
if HANDLE_MAP
.read_irqsave()
@ -558,7 +582,8 @@ impl Socket for TcpSocket {
poll_ifaces();
let mut socket_set_guard = SOCKET_SET.lock_irqsave();
let socket = socket_set_guard.get_mut::<tcp::Socket>(self.handles.get(0).unwrap().0);
let socket = socket_set_guard
.get_mut::<tcp::Socket>(self.handles.get(0).unwrap().smoltcp_handle().unwrap());
// 如果socket已经关闭返回错误
if !socket.is_active() {
@ -626,7 +651,8 @@ impl Socket for TcpSocket {
let mut socket_set_guard = SOCKET_SET.lock_irqsave();
let socket = socket_set_guard.get_mut::<tcp::Socket>(self.handles.get(0).unwrap().0);
let socket = socket_set_guard
.get_mut::<tcp::Socket>(self.handles.get(0).unwrap().smoltcp_handle().unwrap());
if socket.is_open() {
if socket.can_send() {
@ -653,7 +679,8 @@ impl Socket for TcpSocket {
let mut socket_set_guard = SOCKET_SET.lock_irqsave();
// kdebug!("tcp socket:poll, socket'len={}",self.handle.len());
let socket = socket_set_guard.get_mut::<tcp::Socket>(self.handles.get(0).unwrap().0);
let socket = socket_set_guard
.get_mut::<tcp::Socket>(self.handles.get(0).unwrap().smoltcp_handle().unwrap());
return SocketPollMethod::tcp_poll(
socket,
HANDLE_MAP
@ -668,7 +695,8 @@ impl Socket for TcpSocket {
let mut sockets = SOCKET_SET.lock_irqsave();
// kdebug!("tcp socket:connect, socket'len={}",self.handle.len());
let socket = sockets.get_mut::<tcp::Socket>(self.handles.get(0).unwrap().0);
let socket =
sockets.get_mut::<tcp::Socket>(self.handles.get(0).unwrap().smoltcp_handle().unwrap());
if let Endpoint::Ip(Some(ip)) = endpoint {
let temp_port = PORT_MANAGER.get_ephemeral_port(self.metadata.socket_type)?;
@ -689,7 +717,9 @@ impl Socket for TcpSocket {
loop {
poll_ifaces();
let mut sockets = SOCKET_SET.lock_irqsave();
let socket = sockets.get_mut::<tcp::Socket>(self.handles.get(0).unwrap().0);
let socket = sockets.get_mut::<tcp::Socket>(
self.handles.get(0).unwrap().smoltcp_handle().unwrap(),
);
match socket.state() {
tcp::State::Established => {
@ -741,9 +771,9 @@ impl Socket for TcpSocket {
let mut handle_guard = HANDLE_MAP.write_irqsave();
self.handles.extend((handlen..backlog).map(|_| {
let socket = Self::create_new_socket();
let handle = GlobalSocketHandle::new(sockets.add(socket));
let handle = GlobalSocketHandle::new_smoltcp_handle(sockets.add(socket));
let handle_item = SocketHandleItem::new();
handle_guard.insert(handle.0, handle_item);
handle_guard.insert(handle, handle_item);
handle
}));
// kdebug!("tcp socket:listen, socket'len={}",self.handle.len());
@ -753,7 +783,7 @@ impl Socket for TcpSocket {
for i in 0..backlog {
let handle = self.handles.get(i).unwrap();
let socket = sockets.get_mut::<tcp::Socket>(handle.0);
let socket = sockets.get_mut::<tcp::Socket>(handle.smoltcp_handle().unwrap());
if !socket.is_listening() {
// kdebug!("Tcp Socket is already listening on {local_endpoint}");
@ -803,8 +833,16 @@ impl Socket for TcpSocket {
// 随机获取访问的socket的handle
let index: usize = rand() % self.handles.len();
let handle = self.handles.get(index).unwrap();
let socket = sockets.get_mut::<tcp::Socket>(handle.0);
let socket = sockets
.iter_mut()
.find(|y| {
tcp::Socket::downcast(y.1)
.map(|y| y.is_active())
.unwrap_or(false)
})
.map(|y| tcp::Socket::downcast_mut(y.1).unwrap());
if let Some(socket) = socket {
if socket.is_active() {
// kdebug!("tcp accept: socket.is_active()");
let remote_ep = socket.remote_endpoint().ok_or(SystemError::ENOTCONN)?;
@ -819,10 +857,11 @@ impl Socket for TcpSocket {
// 之所以把old_handle存入new_socket, 是因为当前时刻smoltcp已经把old_handle对应的socket与远程的endpoint关联起来了
// 因此需要再为当前的socket分配一个新的handle
let new_handle = GlobalSocketHandle::new(sockets.add(tcp_socket));
let new_handle =
GlobalSocketHandle::new_smoltcp_handle(sockets.add(tcp_socket));
let old_handle = ::core::mem::replace(
&mut *self.handles.get_mut(index).unwrap(),
new_handle.clone(),
new_handle,
);
let metadata = SocketMetadata::new(
@ -834,7 +873,7 @@ impl Socket for TcpSocket {
);
let new_socket = Box::new(TcpSocket {
handles: vec![old_handle.clone()],
handles: vec![old_handle],
local_endpoint: self.local_endpoint,
is_listening: false,
metadata,
@ -855,14 +894,14 @@ impl Socket for TcpSocket {
let mut handle_guard = HANDLE_MAP.write_irqsave();
// 先删除原来的
let item = handle_guard.remove(&old_handle.0).unwrap();
let item = handle_guard.remove(&old_handle).unwrap();
// 按照smoltcp行为将新的handle绑定到原来的item
handle_guard.insert(new_handle.0, item);
handle_guard.insert(new_handle, item);
let new_item = SocketHandleItem::new();
// 插入新的item
handle_guard.insert(old_handle.0, new_item);
handle_guard.insert(old_handle, new_item);
new_socket
};
@ -872,10 +911,11 @@ impl Socket for TcpSocket {
return Ok((new_socket, Endpoint::Ip(Some(remote_ep))));
}
}
// kdebug!("tcp socket:before sleep, handle_guard'len={}",HANDLE_MAP.write_irqsave().len());
drop(sockets);
SocketHandleItem::sleep(handle.0, Self::CAN_ACCPET, HANDLE_MAP.read_irqsave());
SocketHandleItem::sleep(*handle, Self::CAN_ACCPET, HANDLE_MAP.read_irqsave());
// kdebug!("tcp socket:after sleep, handle_guard'len={}",HANDLE_MAP.write_irqsave().len());
}
}
@ -887,7 +927,8 @@ impl Socket for TcpSocket {
let sockets = SOCKET_SET.lock_irqsave();
// kdebug!("tcp socket:endpoint, socket'len={}",self.handle.len());
let socket = sockets.get::<tcp::Socket>(self.handles.get(0).unwrap().0);
let socket =
sockets.get::<tcp::Socket>(self.handles.get(0).unwrap().smoltcp_handle().unwrap());
if let Some(ep) = socket.local_endpoint() {
result = Some(Endpoint::Ip(Some(ep)));
}
@ -899,7 +940,8 @@ impl Socket for TcpSocket {
let sockets = SOCKET_SET.lock_irqsave();
// kdebug!("tcp socket:peer_endpoint, socket'len={}",self.handle.len());
let socket = sockets.get::<tcp::Socket>(self.handles.get(0).unwrap().0);
let socket =
sockets.get::<tcp::Socket>(self.handles.get(0).unwrap().smoltcp_handle().unwrap());
return socket.remote_endpoint().map(|x| Endpoint::Ip(Some(x)));
}
@ -911,10 +953,10 @@ impl Socket for TcpSocket {
Box::new(self.clone())
}
fn socket_handle(&self) -> SocketHandle {
fn socket_handle(&self) -> GlobalSocketHandle {
// kdebug!("tcp socket:socket_handle, socket'len={}",self.handle.len());
self.handles.get(0).unwrap().0
*self.handles.get(0).unwrap()
}
fn as_any_ref(&self) -> &dyn core::any::Any {

View File

@ -9,7 +9,7 @@ use alloc::{
};
use hashbrown::HashMap;
use smoltcp::{
iface::{SocketHandle, SocketSet},
iface::SocketSet,
socket::{self, tcp, udp},
};
use system_error::SystemError;
@ -29,16 +29,17 @@ use crate::{
};
use self::{
handle::GlobalSocketHandle,
inet::{RawSocket, TcpSocket, UdpSocket},
unix::{SeqpacketSocket, StreamSocket},
};
use super::{
event_poll::{EPollEventType, EPollItem, EventPoll},
net_core::poll_ifaces,
Endpoint, Protocol, ShutdownType,
};
pub mod handle;
pub mod inet;
pub mod unix;
@ -48,7 +49,7 @@ lazy_static! {
pub static ref SOCKET_SET: SpinLock<SocketSet<'static >> = SpinLock::new(SocketSet::new(vec![]));
/// SocketHandle表每个SocketHandle对应一个SocketHandleItem
/// 注意!:在网卡中断中需要拿到这张表的🔓,在获取读锁时应该确保关中断避免死锁
pub static ref HANDLE_MAP: RwLock<HashMap<SocketHandle, SocketHandleItem>> = RwLock::new(HashMap::new());
pub static ref HANDLE_MAP: RwLock<HashMap<GlobalSocketHandle, SocketHandleItem>> = RwLock::new(HashMap::new());
/// 端口管理器
pub static ref PORT_MANAGER: PortManager = PortManager::new();
}
@ -83,6 +84,11 @@ pub(super) fn new_socket(
return Err(SystemError::EAFNOSUPPORT);
}
};
let handle_item = SocketHandleItem::new();
HANDLE_MAP
.write_irqsave()
.insert(socket.socket_handle(), handle_item);
Ok(socket)
}
@ -224,9 +230,7 @@ pub trait Socket: Sync + Send + Debug + Any {
Ok(())
}
fn socket_handle(&self) -> SocketHandle {
todo!()
}
fn socket_handle(&self) -> GlobalSocketHandle;
fn write_buffer(&self, _buf: &[u8]) -> Result<usize, SystemError> {
todo!()
@ -272,6 +276,8 @@ pub trait Socket: Sync + Send + Debug + Any {
Ok(())
}
fn close(&mut self);
}
impl Clone for Box<dyn Socket> {
@ -329,6 +335,7 @@ impl IndexNode for SocketInode {
.write_irqsave()
.remove(&socket.socket_handle())
.unwrap();
socket.close();
}
Ok(())
@ -409,9 +416,9 @@ impl SocketHandleItem {
/// ## 在socket的等待队列上睡眠
pub fn sleep(
socket_handle: SocketHandle,
socket_handle: GlobalSocketHandle,
events: u64,
handle_map_guard: RwLockReadGuard<'_, HashMap<SocketHandle, SocketHandleItem>>,
handle_map_guard: RwLockReadGuard<'_, HashMap<GlobalSocketHandle, SocketHandleItem>>,
) {
unsafe {
handle_map_guard
@ -544,33 +551,6 @@ impl PortManager {
}
}
/// # socket的句柄管理组件
/// 它在smoltcp的SocketHandle上封装了一层增加更多的功能。
/// 比如在socket被关闭时自动释放socket的资源通知系统的其他组件。
#[derive(Debug)]
pub struct GlobalSocketHandle(SocketHandle);
impl GlobalSocketHandle {
pub fn new(handle: SocketHandle) -> Arc<Self> {
return Arc::new(Self(handle));
}
}
impl Clone for GlobalSocketHandle {
fn clone(&self) -> Self {
Self(self.0)
}
}
impl Drop for GlobalSocketHandle {
fn drop(&mut self) {
let mut socket_set_guard = SOCKET_SET.lock_irqsave();
socket_set_guard.remove(self.0); // 删除的时候会发送一条FINISH的信息
drop(socket_set_guard);
poll_ifaces();
}
}
/// @brief socket的类型
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum SocketType {

View File

@ -3,13 +3,16 @@ use system_error::SystemError;
use crate::{libs::spinlock::SpinLock, net::Endpoint};
use super::{Socket, SocketInode, SocketMetadata, SocketOptions, SocketType};
use super::{
handle::GlobalSocketHandle, Socket, SocketInode, SocketMetadata, SocketOptions, SocketType,
};
#[derive(Debug, Clone)]
pub struct StreamSocket {
metadata: SocketMetadata,
buffer: Arc<SpinLock<Vec<u8>>>,
peer_inode: Option<Arc<SocketInode>>,
handle: GlobalSocketHandle,
}
impl StreamSocket {
@ -37,11 +40,18 @@ impl StreamSocket {
metadata,
buffer,
peer_inode: None,
handle: GlobalSocketHandle::new_kernel_handle(),
}
}
}
impl Socket for StreamSocket {
fn socket_handle(&self) -> GlobalSocketHandle {
self.handle
}
fn close(&mut self) {}
fn read(&self, buf: &mut [u8]) -> (Result<usize, SystemError>, Endpoint) {
let mut buffer = self.buffer.lock_irqsave();
@ -110,6 +120,7 @@ pub struct SeqpacketSocket {
metadata: SocketMetadata,
buffer: Arc<SpinLock<Vec<u8>>>,
peer_inode: Option<Arc<SocketInode>>,
handle: GlobalSocketHandle,
}
impl SeqpacketSocket {
@ -137,11 +148,14 @@ impl SeqpacketSocket {
metadata,
buffer,
peer_inode: None,
handle: GlobalSocketHandle::new_kernel_handle(),
}
}
}
impl Socket for SeqpacketSocket {
fn close(&mut self) {}
fn read(&self, buf: &mut [u8]) -> (Result<usize, SystemError>, Endpoint) {
let mut buffer = self.buffer.lock_irqsave();
@ -188,6 +202,10 @@ impl Socket for SeqpacketSocket {
Ok(len)
}
fn socket_handle(&self) -> GlobalSocketHandle {
self.handle
}
fn metadata(&self) -> SocketMetadata {
self.metadata.clone()
}

View File

@ -19,7 +19,7 @@ use crate::{
};
use super::{
socket::{new_socket, PosixSocketType, Socket, SocketHandleItem, SocketInode, HANDLE_MAP},
socket::{new_socket, PosixSocketType, Socket, SocketInode},
Endpoint, Protocol, ShutdownType,
};
@ -44,13 +44,6 @@ impl Syscall {
let socket = new_socket(address_family, socket_type, protocol)?;
if address_family != AddressFamily::Unix {
let handle_item = SocketHandleItem::new();
HANDLE_MAP
.write_irqsave()
.insert(socket.socket_handle(), handle_item);
}
let socketinode: Arc<SocketInode> = SocketInode::new(socket);
let f = File::new(socketinode, FileMode::O_RDWR)?;
// 把socket添加到当前进程的文件描述符表中