增加 ListenTable 来检测端口占用 (#291)

* 增加 ListenTable 来检测端口占用


* 使用Arc封装GlobalSocketHandle

* 删除 listen 处的端口检测逻辑,延至实现端口复用时完成

* 设立两张表,分别记录TCP和UDP的端口占用

* 实现 meatadata 相关逻辑

* 实现socket关闭时,端口在表中移除

* 使用端口管理器重构端口记录表

* 修正与RawSocket相关的端口管理逻辑

* 补充测试文件

* 修正 unbind_port 在逻辑错误

* 修正格式问题

---------

Co-authored-by: longjin <longjin@RinGoTek.cn>
This commit is contained in:
Xshine
2023-07-28 17:51:05 +08:00
committed by GitHub
parent 7cc4a02c7f
commit 821bb9a2dc
9 changed files with 716 additions and 51 deletions

View File

@ -228,7 +228,7 @@ pub fn do_mkdir(path: &str, _mode: FileMode) -> Result<u64, SystemError> {
return Ok(0);
}
/// @breif 删除文件夹
/// @brief 删除文件夹
pub fn do_remove_dir(path: &str) -> Result<u64, SystemError> {
// 文件名过长
if path.len() > PAGE_4K_SIZE as usize {

View File

@ -81,7 +81,7 @@ pub trait BlockDevice: Any + Send + Sync + Debug {
/// @brief: 同步磁盘信息把所有的dirty数据写回硬盘 - 待实现
fn sync(&self) -> Result<(), SystemError>;
/// @breif: 每个块设备都必须固定自己块大小而且该块大小必须是2的幂次
/// @brief: 每个块设备都必须固定自己块大小而且该块大小必须是2的幂次
/// @return: 返回一个固定量,硬编码(编程的时候固定的常量).
fn blk_size_log2(&self) -> u8;

View File

@ -70,9 +70,9 @@ pub struct BuddyAllocator<A> {
impl<A: MemoryManagementArch> BuddyAllocator<A> {
const BUDDY_ENTRIES: usize =
// 定义一个变量记录buddy表的大小
// 定义一个变量记录buddy表的大小
(A::PAGE_SIZE - mem::size_of::<PageList<A>>()) / mem::size_of::<PhysAddr>();
pub unsafe fn new(mut bump_allocator: BumpAllocator<A>) -> Option<Self> {
let initial_free_pages = bump_allocator.usage().free();
kdebug!("Free pages before init buddy: {:?}", initial_free_pages);

View File

@ -1,5 +1,6 @@
#![allow(dead_code)]
use alloc::{boxed::Box, sync::Arc, vec::Vec};
use hashbrown::HashMap;
use smoltcp::{
iface::{SocketHandle, SocketSet},
socket::{raw, tcp, udp},
@ -25,6 +26,100 @@ lazy_static! {
/// TODO: 优化这里自己实现SocketSet现在这样的话不管全局有多少个网卡每个时间点都只会有1个进程能够访问socket
pub static ref SOCKET_SET: SpinLock<SocketSet<'static >> = SpinLock::new(SocketSet::new(vec![]));
pub static ref SOCKET_WAITQUEUE: WaitQueue = WaitQueue::INIT;
/// 端口管理器
pub static ref PORT_MANAGER: PortManager = PortManager::new();
}
/// @brief TCP 和 UDP 的端口管理器。
/// 如果 TCP/UDP 的 socket 绑定了某个端口,它会在对应的表中记录,以检测端口冲突。
pub struct PortManager {
// TCP 端口记录表
tcp_port_table: SpinLock<HashMap<u16, Arc<GlobalSocketHandle>>>,
// UDP 端口记录表
udp_port_table: SpinLock<HashMap<u16, Arc<GlobalSocketHandle>>>,
}
impl PortManager {
pub fn new() -> Self {
return Self {
tcp_port_table: SpinLock::new(HashMap::new()),
udp_port_table: SpinLock::new(HashMap::new()),
};
}
/// @brief 自动分配一个相对应协议中未被使用的PORT如果动态端口均已被占用返回错误码 EADDRINUSE
pub fn get_ephemeral_port(&self, socket_type: SocketType) -> Result<u16, SystemError> {
// TODO selects non-conflict high port
static mut EPHEMERAL_PORT: u16 = 0;
unsafe {
if EPHEMERAL_PORT == 0 {
EPHEMERAL_PORT = (49152 + rand() % (65536 - 49152)) as u16;
}
}
let mut remaining = 65536 - 49152; // 剩余尝试分配端口次数
let mut port: u16;
while remaining > 0 {
unsafe {
if EPHEMERAL_PORT == 65535 {
EPHEMERAL_PORT = 49152;
} else {
EPHEMERAL_PORT = EPHEMERAL_PORT + 1;
}
port = EPHEMERAL_PORT;
}
// 使用 ListenTable 检查端口是否被占用
let listen_table_guard = match socket_type {
SocketType::UdpSocket => self.udp_port_table.lock(),
SocketType::TcpSocket => self.tcp_port_table.lock(),
SocketType::RawSocket => todo!(),
};
if let None = listen_table_guard.get(&port) {
drop(listen_table_guard);
return Ok(port);
}
remaining -= 1;
}
return Err(SystemError::EADDRINUSE);
}
/// @brief 检测给定端口是否已被占用,如果未被占用则在 TCP/UDP 对应的表中记录
///
/// TODO: 增加支持端口复用的逻辑
pub fn get_port(
&self,
socket_type: SocketType,
port: u16,
handle: Arc<GlobalSocketHandle>,
) -> Result<(), SystemError> {
if port > 0 {
let mut listen_table_guard = match socket_type {
SocketType::UdpSocket => self.udp_port_table.lock(),
SocketType::TcpSocket => self.tcp_port_table.lock(),
SocketType::RawSocket => panic!("RawSocket cann't bind a port"),
};
match listen_table_guard.get(&port) {
Some(_) => return Err(SystemError::EADDRINUSE),
None => listen_table_guard.insert(port, handle),
};
drop(listen_table_guard);
}
return Ok(());
}
/// @brief 在对应的端口记录表中将端口和 socket 解绑
pub fn unbind_port(&self, socket_type: SocketType, port: u16) -> Result<(), SystemError> {
let mut listen_table_guard = match socket_type {
SocketType::UdpSocket => self.udp_port_table.lock(),
SocketType::TcpSocket => self.tcp_port_table.lock(),
SocketType::RawSocket => return Ok(()),
};
listen_table_guard.remove(&port);
drop(listen_table_guard);
return Ok(());
}
}
/* For setsockopt(2) */
@ -38,8 +133,8 @@ pub const SOL_SOCKET: u8 = 1;
pub struct GlobalSocketHandle(SocketHandle);
impl GlobalSocketHandle {
pub fn new(handle: SocketHandle) -> Self {
Self(handle)
pub fn new(handle: SocketHandle) -> Arc<Self> {
return Arc::new(Self(handle));
}
}
@ -59,7 +154,7 @@ impl Drop for GlobalSocketHandle {
}
/// @brief socket的类型
#[derive(Debug)]
#[derive(Debug, Clone, Copy)]
pub enum SocketType {
/// 原始的socket
RawSocket,
@ -86,7 +181,7 @@ bitflags! {
}
}
#[derive(Debug)]
#[derive(Debug, Clone)]
/// @brief 在trait Socket的metadata函数中返回该结构体供外部使用
pub struct SocketMetadata {
/// socket的类型
@ -101,18 +196,36 @@ pub struct SocketMetadata {
pub options: SocketOptions,
}
impl SocketMetadata {
fn new(
socket_type: SocketType,
send_buf_size: usize,
recv_buf_size: usize,
metadata_buf_size: usize,
options: SocketOptions,
) -> Self {
Self {
socket_type,
send_buf_size,
recv_buf_size,
metadata_buf_size,
options,
}
}
}
/// @brief 表示原始的socket。原始套接字绕过传输层协议如 TCP 或 UDP并提供对网络层协议如 IP的直接访问。
///
/// ref: https://man7.org/linux/man-pages/man7/raw.7.html
#[derive(Debug, Clone)]
pub struct RawSocket {
handle: GlobalSocketHandle,
handle: Arc<GlobalSocketHandle>,
/// 用户发送的数据包是否包含了IP头.
/// 如果是true用户发送的数据包必须包含IP头。即用户要自行设置IP头+数据)
/// 如果是false用户发送的数据包不包含IP头。即用户只要设置数据
header_included: bool,
/// socket的选项
options: SocketOptions,
/// socket的metadata
metadata: SocketMetadata,
}
impl RawSocket {
@ -147,12 +260,21 @@ impl RawSocket {
);
// 把socket添加到socket集合中并得到socket的句柄
let handle: GlobalSocketHandle = GlobalSocketHandle::new(SOCKET_SET.lock().add(socket));
let handle: Arc<GlobalSocketHandle> =
GlobalSocketHandle::new(SOCKET_SET.lock().add(socket));
let metadata = SocketMetadata::new(
SocketType::RawSocket,
Self::DEFAULT_RX_BUF_SIZE,
Self::DEFAULT_TX_BUF_SIZE,
Self::DEFAULT_METADATA_BUF_SIZE,
options,
);
return Self {
handle,
header_included: false,
options,
metadata,
};
}
}
@ -177,7 +299,7 @@ impl Socket for RawSocket {
);
}
Err(smoltcp::socket::raw::RecvError::Exhausted) => {
if !self.options.contains(SocketOptions::BLOCK) {
if !self.metadata.options.contains(SocketOptions::BLOCK) {
// 如果是非阻塞的socket就返回错误
return (Err(SystemError::EAGAIN_OR_EWOULDBLOCK), Endpoint::Ip(None));
}
@ -271,7 +393,7 @@ impl Socket for RawSocket {
}
fn metadata(&self) -> Result<SocketMetadata, SystemError> {
todo!()
Ok(self.metadata.clone())
}
fn box_clone(&self) -> alloc::boxed::Box<dyn Socket> {
@ -284,9 +406,9 @@ impl Socket for RawSocket {
/// https://man7.org/linux/man-pages/man7/udp.7.html
#[derive(Debug, Clone)]
pub struct UdpSocket {
pub handle: GlobalSocketHandle,
pub handle: Arc<GlobalSocketHandle>,
remote_endpoint: Option<Endpoint>, // 记录远程endpoint提供给connect() 应该使用IP地址。
options: SocketOptions,
metadata: SocketMetadata,
}
impl UdpSocket {
@ -315,17 +437,29 @@ impl UdpSocket {
let socket = udp::Socket::new(tx_buffer, rx_buffer);
// 把socket添加到socket集合中并得到socket的句柄
let handle: GlobalSocketHandle = GlobalSocketHandle::new(SOCKET_SET.lock().add(socket));
let handle: Arc<GlobalSocketHandle> =
GlobalSocketHandle::new(SOCKET_SET.lock().add(socket));
let metadata = SocketMetadata::new(
SocketType::UdpSocket,
Self::DEFAULT_RX_BUF_SIZE,
Self::DEFAULT_TX_BUF_SIZE,
Self::DEFAULT_METADATA_BUF_SIZE,
options,
);
return Self {
handle,
remote_endpoint: None,
options,
metadata,
};
}
fn do_bind(&self, socket: &mut udp::Socket, endpoint: Endpoint) -> Result<(), SystemError> {
if let Endpoint::Ip(Some(ip)) = endpoint {
// 检测端口是否已被占用
PORT_MANAGER.get_port(self.metadata.socket_type, ip.port, self.handle.clone())?;
let bind_res = if ip.addr.is_unspecified() {
socket.bind(ip.port)
} else {
@ -388,7 +522,7 @@ impl Socket for UdpSocket {
// kdebug!("is open()={}", socket.is_open());
// kdebug!("socket endpoint={:?}", socket.endpoint());
if socket.endpoint().port == 0 {
let temp_port = get_ephemeral_port();
let temp_port = PORT_MANAGER.get_ephemeral_port(self.metadata.socket_type)?;
let local_ep = match remote_endpoint.addr {
// 远程remote endpoint使用什么协议发送的时候使用的协议是一样的吧
@ -461,7 +595,7 @@ impl Socket for UdpSocket {
todo!()
}
fn metadata(&self) -> Result<SocketMetadata, SystemError> {
todo!()
Ok(self.metadata.clone())
}
fn box_clone(&self) -> alloc::boxed::Box<dyn Socket> {
@ -499,10 +633,10 @@ impl Socket for UdpSocket {
/// https://man7.org/linux/man-pages/man7/tcp.7.html
#[derive(Debug, Clone)]
pub struct TcpSocket {
handle: GlobalSocketHandle,
handle: Arc<GlobalSocketHandle>,
local_endpoint: Option<wire::IpEndpoint>, // save local endpoint for bind()
is_listening: bool,
options: SocketOptions,
metadata: SocketMetadata,
}
impl TcpSocket {
@ -525,13 +659,22 @@ impl TcpSocket {
let socket = tcp::Socket::new(tx_buffer, rx_buffer);
// 把socket添加到socket集合中并得到socket的句柄
let handle: GlobalSocketHandle = GlobalSocketHandle::new(SOCKET_SET.lock().add(socket));
let handle: Arc<GlobalSocketHandle> =
GlobalSocketHandle::new(SOCKET_SET.lock().add(socket));
let metadata = SocketMetadata::new(
SocketType::TcpSocket,
Self::DEFAULT_RX_BUF_SIZE,
Self::DEFAULT_TX_BUF_SIZE,
Self::DEFAULT_METADATA_BUF_SIZE,
options,
);
return Self {
handle,
local_endpoint: None,
is_listening: false,
options,
metadata,
};
}
fn do_listen(
@ -546,7 +689,7 @@ impl TcpSocket {
// kdebug!("Tcp Socket Listen on {local_endpoint}");
socket.listen(local_endpoint)
};
// todo: 增加端口占用检查
// TODO: 增加端口占用检查
return match listen_result {
Ok(()) => {
// kdebug!(
@ -668,7 +811,7 @@ impl Socket for TcpSocket {
let socket = sockets.get_mut::<tcp::Socket>(self.handle.0);
if let Endpoint::Ip(Some(ip)) = endpoint {
let temp_port = get_ephemeral_port();
let temp_port = PORT_MANAGER.get_ephemeral_port(self.metadata.socket_type)?;
// kdebug!("temp_port: {}", temp_port);
let iface: Arc<dyn NetDriver> = NET_DRIVERS.write().get(&0).unwrap().clone();
let mut inner_iface = iface.inner_iface().lock();
@ -737,9 +880,12 @@ impl Socket for TcpSocket {
fn bind(&mut self, endpoint: Endpoint) -> Result<(), SystemError> {
if let Endpoint::Ip(Some(mut ip)) = endpoint {
if ip.port == 0 {
ip.port = get_ephemeral_port();
ip.port = PORT_MANAGER.get_ephemeral_port(self.metadata.socket_type)?;
}
// 检测端口是否已被占用
PORT_MANAGER.get_port(self.metadata.socket_type, ip.port, self.handle.clone())?;
self.local_endpoint = Some(ip);
self.is_listening = false;
return Ok(());
@ -785,11 +931,19 @@ impl Socket for TcpSocket {
let new_handle = GlobalSocketHandle::new(sockets.add(tcp_socket));
let old_handle = ::core::mem::replace(&mut self.handle, new_handle);
let metadata = SocketMetadata {
socket_type: SocketType::TcpSocket,
send_buf_size: Self::DEFAULT_RX_BUF_SIZE,
recv_buf_size: Self::DEFAULT_TX_BUF_SIZE,
metadata_buf_size: Self::DEFAULT_METADATA_BUF_SIZE,
options: self.metadata.options,
};
Box::new(TcpSocket {
handle: old_handle,
local_endpoint: self.local_endpoint,
is_listening: false,
options: self.options,
metadata,
})
};
// kdebug!("tcp accept: new socket: {:?}", new_socket);
@ -825,7 +979,7 @@ impl Socket for TcpSocket {
}
fn metadata(&self) -> Result<SocketMetadata, SystemError> {
todo!()
Ok(self.metadata.clone())
}
fn box_clone(&self) -> alloc::boxed::Box<dyn Socket> {
@ -833,26 +987,6 @@ impl Socket for TcpSocket {
}
}
/// @breif 自动分配一个未被使用的PORT
///
/// TODO: 增加ListenTable, 用于检查端口是否被占用
pub fn get_ephemeral_port() -> u16 {
// TODO selects non-conflict high port
static mut EPHEMERAL_PORT: u16 = 0;
unsafe {
if EPHEMERAL_PORT == 0 {
EPHEMERAL_PORT = (49152 + rand() % (65536 - 49152)) as u16;
}
if EPHEMERAL_PORT == 65535 {
EPHEMERAL_PORT = 49152;
} else {
EPHEMERAL_PORT = EPHEMERAL_PORT + 1;
}
EPHEMERAL_PORT
}
}
/// @brief 地址族的枚举
///
/// 参考https://opengrok.ringotek.cn/xref/linux-5.19.10/include/linux/socket.h#180
@ -1012,6 +1146,10 @@ impl IndexNode for SocketInode {
&self,
_data: &mut crate::filesystem::vfs::FilePrivateData,
) -> Result<(), SystemError> {
let socket = self.0.lock();
if let Some(Endpoint::Ip(Some(ip))) = socket.endpoint() {
PORT_MANAGER.unbind_port(socket.metadata().unwrap().socket_type, ip.port)?;
}
return Ok(());
}

View File

@ -348,7 +348,6 @@ impl process_control_block {
}
}
/// @brief 初始化pid=1的进程的stdio
pub fn init_stdio() -> Result<(), SystemError> {
if current_pcb().pid != 1 {