@ -1,5 +1,6 @@
#![ allow(dead_code) ]
use alloc ::{ boxed ::Box , sync ::Arc , vec ::Vec } ;
use hashbrown ::HashMap ;
use smoltcp ::{
iface ::{ SocketHandle , SocketSet } ,
socket ::{ raw , tcp , udp } ,
@ -25,6 +26,100 @@ lazy_static! {
/// TODO: 优化这里, 自己实现SocketSet! ! ! 现在这样的话, 不管全局有多少个网卡, 每个时间点都只会有1个进程能够访问socket
pub static ref SOCKET_SET : SpinLock < SocketSet < 'static > > = SpinLock ::new ( SocketSet ::new ( vec! [ ] ) ) ;
pub static ref SOCKET_WAITQUEUE : WaitQueue = WaitQueue ::INIT ;
/// 端口管理器
pub static ref PORT_MANAGER : PortManager = PortManager ::new ( ) ;
}
/// @brief TCP 和 UDP 的端口管理器。
/// 如果 TCP/UDP 的 socket 绑定了某个端口,它会在对应的表中记录,以检测端口冲突。
pub struct PortManager {
// TCP 端口记录表
tcp_port_table : SpinLock < HashMap < u16 , Arc < GlobalSocketHandle > > > ,
// UDP 端口记录表
udp_port_table : SpinLock < HashMap < u16 , Arc < GlobalSocketHandle > > > ,
}
impl PortManager {
pub fn new ( ) -> Self {
return Self {
tcp_port_table : SpinLock ::new ( HashMap ::new ( ) ) ,
udp_port_table : SpinLock ::new ( HashMap ::new ( ) ) ,
} ;
}
/// @brief 自动分配一个相对应协议中未被使用的PORT, 如果动态端口均已被占用, 返回错误码 EADDRINUSE
pub fn get_ephemeral_port ( & self , socket_type : SocketType ) -> Result < u16 , SystemError > {
// TODO selects non-conflict high port
static mut EPHEMERAL_PORT : u16 = 0 ;
unsafe {
if EPHEMERAL_PORT = = 0 {
EPHEMERAL_PORT = ( 49152 + rand ( ) % ( 65536 - 49152 ) ) as u16 ;
}
}
let mut remaining = 65536 - 49152 ; // 剩余尝试分配端口次数
let mut port : u16 ;
while remaining > 0 {
unsafe {
if EPHEMERAL_PORT = = 65535 {
EPHEMERAL_PORT = 49152 ;
} else {
EPHEMERAL_PORT = EPHEMERAL_PORT + 1 ;
}
port = EPHEMERAL_PORT ;
}
// 使用 ListenTable 检查端口是否被占用
let listen_table_guard = match socket_type {
SocketType ::UdpSocket = > self . udp_port_table . lock ( ) ,
SocketType ::TcpSocket = > self . tcp_port_table . lock ( ) ,
SocketType ::RawSocket = > todo! ( ) ,
} ;
if let None = listen_table_guard . get ( & port ) {
drop ( listen_table_guard ) ;
return Ok ( port ) ;
}
remaining - = 1 ;
}
return Err ( SystemError ::EADDRINUSE ) ;
}
/// @brief 检测给定端口是否已被占用,如果未被占用则在 TCP/UDP 对应的表中记录
///
/// TODO: 增加支持端口复用的逻辑
pub fn get_port (
& self ,
socket_type : SocketType ,
port : u16 ,
handle : Arc < GlobalSocketHandle > ,
) -> Result < ( ) , SystemError > {
if port > 0 {
let mut listen_table_guard = match socket_type {
SocketType ::UdpSocket = > self . udp_port_table . lock ( ) ,
SocketType ::TcpSocket = > self . tcp_port_table . lock ( ) ,
SocketType ::RawSocket = > panic! ( " RawSocket cann't bind a port " ) ,
} ;
match listen_table_guard . get ( & port ) {
Some ( _ ) = > return Err ( SystemError ::EADDRINUSE ) ,
None = > listen_table_guard . insert ( port , handle ) ,
} ;
drop ( listen_table_guard ) ;
}
return Ok ( ( ) ) ;
}
/// @brief 在对应的端口记录表中将端口和 socket 解绑
pub fn unbind_port ( & self , socket_type : SocketType , port : u16 ) -> Result < ( ) , SystemError > {
let mut listen_table_guard = match socket_type {
SocketType ::UdpSocket = > self . udp_port_table . lock ( ) ,
SocketType ::TcpSocket = > self . tcp_port_table . lock ( ) ,
SocketType ::RawSocket = > return Ok ( ( ) ) ,
} ;
listen_table_guard . remove ( & port ) ;
drop ( listen_table_guard ) ;
return Ok ( ( ) ) ;
}
}
/* For setsockopt(2) */
@ -38,8 +133,8 @@ pub const SOL_SOCKET: u8 = 1;
pub struct GlobalSocketHandle ( SocketHandle ) ;
impl GlobalSocketHandle {
pub fn new ( handle : SocketHandle ) -> Self {
Self ( handle )
pub fn new ( handle : SocketHandle ) -> Arc < Self> {
return Arc ::new ( Self ( handle ) ) ;
}
}
@ -59,7 +154,7 @@ impl Drop for GlobalSocketHandle {
}
/// @brief socket的类型
#[ derive(Debug) ]
#[ derive(Debug, Clone, Copy ) ]
pub enum SocketType {
/// 原始的socket
RawSocket ,
@ -86,7 +181,7 @@ bitflags! {
}
}
#[ derive(Debug) ]
#[ derive(Debug, Clone ) ]
/// @brief 在trait Socket的metadata函数中返回该结构体供外部使用
pub struct SocketMetadata {
/// socket的类型
@ -101,18 +196,36 @@ pub struct SocketMetadata {
pub options : SocketOptions ,
}
impl SocketMetadata {
fn new (
socket_type : SocketType ,
send_buf_size : usize ,
recv_buf_size : usize ,
metadata_buf_size : usize ,
options : SocketOptions ,
) -> Self {
Self {
socket_type ,
send_buf_size ,
recv_buf_size ,
metadata_buf_size ,
options ,
}
}
}
/// @brief 表示原始的socket。原始套接字绕过传输层协议( 如 TCP 或 UDP) 并提供对网络层协议( 如 IP) 的直接访问。
///
/// ref: https://man7.org/linux/man-pages/man7/raw.7.html
#[ derive(Debug, Clone) ]
pub struct RawSocket {
handle : GlobalSocketHandle ,
handle : Arc < GlobalSocketHandle> ,
/// 用户发送的数据包是否包含了IP头.
/// 如果是true, 用户发送的数据包, 必须包含IP头。( 即用户要自行设置IP头+数据)
/// 如果是false, 用户发送的数据包, 不包含IP头。( 即用户只要设置数据)
header_included : bool ,
/// socket的选项
options : SocketOptions ,
/// socket的metadata
metadata : SocketMetadata ,
}
impl RawSocket {
@ -147,12 +260,21 @@ impl RawSocket {
) ;
// 把socket添加到socket集合中, 并得到socket的句柄
let handle : GlobalSocketHandle = GlobalSocketHandle ::new ( SOCKET_SET . lock ( ) . add ( socket ) ) ;
let handle : Arc < GlobalSocketHandle > =
GlobalSocketHandle ::new ( SOCKET_SET . lock ( ) . add ( socket ) ) ;
let metadata = SocketMetadata ::new (
SocketType ::RawSocket ,
Self ::DEFAULT_RX_BUF_SIZE ,
Self ::DEFAULT_TX_BUF_SIZE ,
Self ::DEFAULT_METADATA_BUF_SIZE ,
options ,
) ;
return Self {
handle ,
header_included : false ,
options ,
metadata ,
} ;
}
}
@ -177,7 +299,7 @@ impl Socket for RawSocket {
) ;
}
Err ( smoltcp ::socket ::raw ::RecvError ::Exhausted ) = > {
if ! self . options . contains ( SocketOptions ::BLOCK ) {
if ! self . metadata . options. contains ( SocketOptions ::BLOCK ) {
// 如果是非阻塞的socket, 就返回错误
return ( Err ( SystemError ::EAGAIN_OR_EWOULDBLOCK ) , Endpoint ::Ip ( None ) ) ;
}
@ -271,7 +393,7 @@ impl Socket for RawSocket {
}
fn metadata ( & self ) -> Result < SocketMetadata , SystemError > {
todo! ( )
Ok ( self . metadata . clone ( ) )
}
fn box_clone ( & self ) -> alloc ::boxed ::Box < dyn Socket > {
@ -284,9 +406,9 @@ impl Socket for RawSocket {
/// https://man7.org/linux/man-pages/man7/udp.7.html
#[ derive(Debug, Clone) ]
pub struct UdpSocket {
pub handle : GlobalSocketHandle ,
pub handle : Arc < GlobalSocketHandle> ,
remote_endpoint : Option < Endpoint > , // 记录远程endpoint提供给connect(), 应该使用IP地址。
options : SocketOptions ,
metadata : SocketMetadata ,
}
impl UdpSocket {
@ -315,17 +437,29 @@ impl UdpSocket {
let socket = udp ::Socket ::new ( tx_buffer , rx_buffer ) ;
// 把socket添加到socket集合中, 并得到socket的句柄
let handle : GlobalSocketHandle = GlobalSocketHandle ::new ( SOCKET_SET . lock ( ) . add ( socket ) ) ;
let handle : Arc < GlobalSocketHandle > =
GlobalSocketHandle ::new ( SOCKET_SET . lock ( ) . add ( socket ) ) ;
let metadata = SocketMetadata ::new (
SocketType ::UdpSocket ,
Self ::DEFAULT_RX_BUF_SIZE ,
Self ::DEFAULT_TX_BUF_SIZE ,
Self ::DEFAULT_METADATA_BUF_SIZE ,
options ,
) ;
return Self {
handle ,
remote_endpoint : None ,
options ,
metadata ,
} ;
}
fn do_bind ( & self , socket : & mut udp ::Socket , endpoint : Endpoint ) -> Result < ( ) , SystemError > {
if let Endpoint ::Ip ( Some ( ip ) ) = endpoint {
// 检测端口是否已被占用
PORT_MANAGER . get_port ( self . metadata . socket_type , ip . port , self . handle . clone ( ) ) ? ;
let bind_res = if ip . addr . is_unspecified ( ) {
socket . bind ( ip . port )
} else {
@ -388,7 +522,7 @@ impl Socket for UdpSocket {
// kdebug!("is open()={}", socket.is_open());
// kdebug!("socket endpoint={:?}", socket.endpoint());
if socket . endpoint ( ) . port = = 0 {
let temp_port = get_ephemeral_port ( ) ;
let temp_port = PORT_MANAGER . get_ephemeral_port ( self . metadata . socket_type ) ? ;
let local_ep = match remote_endpoint . addr {
// 远程remote endpoint使用什么协议, 发送的时候使用的协议是一样的吧
@ -461,7 +595,7 @@ impl Socket for UdpSocket {
todo! ( )
}
fn metadata ( & self ) -> Result < SocketMetadata , SystemError > {
todo! ( )
Ok ( self . metadata . clone ( ) )
}
fn box_clone ( & self ) -> alloc ::boxed ::Box < dyn Socket > {
@ -499,10 +633,10 @@ impl Socket for UdpSocket {
/// https://man7.org/linux/man-pages/man7/tcp.7.html
#[ derive(Debug, Clone) ]
pub struct TcpSocket {
handle : GlobalSocketHandle ,
handle : Arc < GlobalSocketHandle> ,
local_endpoint : Option < wire ::IpEndpoint > , // save local endpoint for bind()
is_listening : bool ,
options : SocketOptions ,
metadata : SocketMetadata ,
}
impl TcpSocket {
@ -525,13 +659,22 @@ impl TcpSocket {
let socket = tcp ::Socket ::new ( tx_buffer , rx_buffer ) ;
// 把socket添加到socket集合中, 并得到socket的句柄
let handle : GlobalSocketHandle = GlobalSocketHandle ::new ( SOCKET_SET . lock ( ) . add ( socket ) ) ;
let handle : Arc < GlobalSocketHandle > =
GlobalSocketHandle ::new ( SOCKET_SET . lock ( ) . add ( socket ) ) ;
let metadata = SocketMetadata ::new (
SocketType ::TcpSocket ,
Self ::DEFAULT_RX_BUF_SIZE ,
Self ::DEFAULT_TX_BUF_SIZE ,
Self ::DEFAULT_METADATA_BUF_SIZE ,
options ,
) ;
return Self {
handle ,
local_endpoint : None ,
is_listening : false ,
options ,
metadata ,
} ;
}
fn do_listen (
@ -546,7 +689,7 @@ impl TcpSocket {
// kdebug!("Tcp Socket Listen on {local_endpoint}");
socket . listen ( local_endpoint )
} ;
// todo : 增加端口占用检查
// TODO : 增加端口占用检查
return match listen_result {
Ok ( ( ) ) = > {
// kdebug!(
@ -668,7 +811,7 @@ impl Socket for TcpSocket {
let socket = sockets . get_mut ::< tcp ::Socket > ( self . handle . 0 ) ;
if let Endpoint ::Ip ( Some ( ip ) ) = endpoint {
let temp_port = get_ephemeral_port ( ) ;
let temp_port = PORT_MANAGER . get_ephemeral_port ( self . metadata . socket_type ) ? ;
// kdebug!("temp_port: {}", temp_port);
let iface : Arc < dyn NetDriver > = NET_DRIVERS . write ( ) . get ( & 0 ) . unwrap ( ) . clone ( ) ;
let mut inner_iface = iface . inner_iface ( ) . lock ( ) ;
@ -737,9 +880,12 @@ impl Socket for TcpSocket {
fn bind ( & mut self , endpoint : Endpoint ) -> Result < ( ) , SystemError > {
if let Endpoint ::Ip ( Some ( mut ip ) ) = endpoint {
if ip . port = = 0 {
ip . port = get_ephemeral_port ( ) ;
ip . port = PORT_MANAGER . get_ephemeral_port ( self . metadata . socket_type ) ? ;
}
// 检测端口是否已被占用
PORT_MANAGER . get_port ( self . metadata . socket_type , ip . port , self . handle . clone ( ) ) ? ;
self . local_endpoint = Some ( ip ) ;
self . is_listening = false ;
return Ok ( ( ) ) ;
@ -785,11 +931,19 @@ impl Socket for TcpSocket {
let new_handle = GlobalSocketHandle ::new ( sockets . add ( tcp_socket ) ) ;
let old_handle = ::core ::mem ::replace ( & mut self . handle , new_handle ) ;
let metadata = SocketMetadata {
socket_type : SocketType ::TcpSocket ,
send_buf_size : Self ::DEFAULT_RX_BUF_SIZE ,
recv_buf_size : Self ::DEFAULT_TX_BUF_SIZE ,
metadata_buf_size : Self ::DEFAULT_METADATA_BUF_SIZE ,
options : self . metadata . options ,
} ;
Box ::new ( TcpSocket {
handle : old_handle ,
local_endpoint : self . local_endpoint ,
is_listening : false ,
options : self . options ,
metadata ,
} )
} ;
// kdebug!("tcp accept: new socket: {:?}", new_socket);
@ -825,7 +979,7 @@ impl Socket for TcpSocket {
}
fn metadata ( & self ) -> Result < SocketMetadata , SystemError > {
todo! ( )
Ok ( self . metadata . clone ( ) )
}
fn box_clone ( & self ) -> alloc ::boxed ::Box < dyn Socket > {
@ -833,26 +987,6 @@ impl Socket for TcpSocket {
}
}
/// @breif 自动分配一个未被使用的PORT
///
/// TODO: 增加ListenTable, 用于检查端口是否被占用
pub fn get_ephemeral_port ( ) -> u16 {
// TODO selects non-conflict high port
static mut EPHEMERAL_PORT : u16 = 0 ;
unsafe {
if EPHEMERAL_PORT = = 0 {
EPHEMERAL_PORT = ( 49152 + rand ( ) % ( 65536 - 49152 ) ) as u16 ;
}
if EPHEMERAL_PORT = = 65535 {
EPHEMERAL_PORT = 49152 ;
} else {
EPHEMERAL_PORT = EPHEMERAL_PORT + 1 ;
}
EPHEMERAL_PORT
}
}
/// @brief 地址族的枚举
///
/// 参考: https://opengrok.ringotek.cn/xref/linux-5.19.10/include/linux/socket.h#180
@ -1012,6 +1146,10 @@ impl IndexNode for SocketInode {
& self ,
_data : & mut crate ::filesystem ::vfs ::FilePrivateData ,
) -> Result < ( ) , SystemError > {
let socket = self . 0. lock ( ) ;
if let Some ( Endpoint ::Ip ( Some ( ip ) ) ) = socket . endpoint ( ) {
PORT_MANAGER . unbind_port ( socket . metadata ( ) . unwrap ( ) . socket_type , ip . port ) ? ;
}
return Ok ( ( ) ) ;
}