feat(net): 实现unix抽象地址空间 (#1017)

This commit is contained in:
Cai Junyuan 2024-10-28 20:29:08 +08:00 committed by GitHub
parent 8189cb1771
commit fad1c09757
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 455 additions and 8 deletions

View File

@ -35,8 +35,10 @@ impl PosixArgsSocketType {
} }
} }
use alloc::string::String;
use alloc::sync::Arc; use alloc::sync::Arc;
use core::ffi::CStr; use core::ffi::CStr;
use unix::ns::abs::{alloc_abs_addr, look_up_abs_addr};
use crate::{ use crate::{
filesystem::vfs::{FileType, IndexNode, ROOT_INODE, VFS_MAX_FOLLOW_SYMLINK_TIMES}, filesystem::vfs::{FileType, IndexNode, ROOT_INODE, VFS_MAX_FOLLOW_SYMLINK_TIMES},
@ -45,7 +47,7 @@ use crate::{
process::ProcessManager, process::ProcessManager,
}; };
use smoltcp; use smoltcp;
use system_error::SystemError; use system_error::SystemError::{self};
// 参考资料: https://pubs.opengroup.org/onlinepubs/9699919799/basedefs/netinet_in.h.html#tag_13_32 // 参考资料: https://pubs.opengroup.org/onlinepubs/9699919799/basedefs/netinet_in.h.html#tag_13_32
#[repr(C)] #[repr(C)]
@ -146,6 +148,42 @@ impl SockAddr {
AddressFamily::Unix => { AddressFamily::Unix => {
let addr_un: SockAddrUn = addr.addr_un; let addr_un: SockAddrUn = addr.addr_un;
if addr_un.sun_path[0] == 0 {
// 抽象地址空间,与文件系统没有关系
let path = CStr::from_bytes_until_nul(&addr_un.sun_path[1..])
.map_err(|_| {
log::error!("CStr::from_bytes_until_nul fail");
SystemError::EINVAL
})?
.to_str()
.map_err(|_| {
log::error!("CStr::to_str fail");
SystemError::EINVAL
})?;
// 向抽象地址管理器申请或查找抽象地址
let spath = String::from(path);
log::debug!("abs path: {}", spath);
let abs_find = match look_up_abs_addr(&spath) {
Ok(result) => result,
Err(_) => {
//未找到尝试分配abs
match alloc_abs_addr(spath.clone()) {
Ok(result) => {
log::debug!("alloc abs addr success!");
return Ok(result);
}
Err(e) => {
log::debug!("alloc abs addr failed!");
return Err(e);
}
};
}
};
log::debug!("find alloc abs addr success!");
return Ok(abs_find);
}
let path = CStr::from_bytes_until_nul(&addr_un.sun_path) let path = CStr::from_bytes_until_nul(&addr_un.sun_path)
.map_err(|_| { .map_err(|_| {
log::error!("CStr::from_bytes_until_nul fail"); log::error!("CStr::from_bytes_until_nul fail");

View File

@ -3,6 +3,8 @@ use alloc::{string::String, sync::Arc};
pub use smoltcp::wire::IpEndpoint; pub use smoltcp::wire::IpEndpoint;
use super::unix::ns::abs::AbsHandle;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub enum Endpoint { pub enum Endpoint {
/// 链路层端点 /// 链路层端点
@ -13,6 +15,8 @@ pub enum Endpoint {
Inode((Arc<socket::Inode>, String)), Inode((Arc<socket::Inode>, String)),
/// Unix传递id索引和path所用的端点 /// Unix传递id索引和path所用的端点
Unixpath((InodeId, String)), Unixpath((InodeId, String)),
/// Unix抽象端点
Abspath((AbsHandle, String)),
} }
/// @brief 链路层端点 /// @brief 链路层端点

View File

@ -1,5 +1,6 @@
pub mod ns;
pub(crate) mod seqpacket; pub(crate) mod seqpacket;
mod stream; pub mod stream;
use crate::{filesystem::vfs::InodeId, libs::rwlock::RwLock, net::socket::*}; use crate::{filesystem::vfs::InodeId, libs::rwlock::RwLock, net::socket::*};
use alloc::sync::Arc; use alloc::sync::Arc;
use hashbrown::HashMap; use hashbrown::HashMap;

View File

@ -0,0 +1,172 @@
use core::fmt;
use crate::libs::spinlock::SpinLock;
use crate::net::socket::Endpoint;
use alloc::string::String;
use hashbrown::HashMap;
use ida::IdAllocator;
use system_error::SystemError;
lazy_static! {
pub static ref ABSHANDLE_MAP: AbsHandleMap = AbsHandleMap::new();
}
lazy_static! {
pub static ref ABS_INODE_MAP: SpinLock<HashMap<usize, Endpoint>> =
SpinLock::new(HashMap::new());
}
static ABS_ADDRESS_ALLOCATOR: SpinLock<IdAllocator> =
SpinLock::new(IdAllocator::new(0, (1 << 20) as usize).unwrap());
#[derive(Debug, Clone)]
pub struct AbsHandle(usize);
impl AbsHandle {
pub fn new(name: usize) -> Self {
Self(name)
}
pub fn name(&self) -> usize {
self.0
}
}
impl fmt::Display for AbsHandle {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{:05x}", self.0)
}
}
/// 抽象地址映射表
///
/// 负责管理抽象命名空间内的地址
pub struct AbsHandleMap {
abs_handle_map: SpinLock<HashMap<String, Endpoint>>,
}
impl AbsHandleMap {
pub fn new() -> Self {
Self {
abs_handle_map: SpinLock::new(HashMap::new()),
}
}
/// 插入新的地址映射
pub fn insert(&self, name: String) -> Result<Endpoint, SystemError> {
let mut guard = self.abs_handle_map.lock();
//检查name是否被占用
if guard.contains_key(&name) {
return Err(SystemError::ENOMEM);
}
let ads_addr = match self.alloc(name.clone()) {
Some(addr) => addr.clone(),
None => return Err(SystemError::ENOMEM),
};
guard.insert(name, ads_addr.clone());
return Ok(ads_addr);
}
/// 抽象空间地址分配器
///
/// ## 返回
///
/// 分配到的可用的抽象端点
pub fn alloc(&self, name: String) -> Option<Endpoint> {
let abs_addr = match ABS_ADDRESS_ALLOCATOR.lock().alloc() {
Some(addr) => addr,
//地址被分配
None => return None,
};
let result = Some(Endpoint::Abspath((AbsHandle::new(abs_addr), name)));
return result;
}
/// 进行地址映射
///
/// ## 参数
///
/// name用户定义的地址
pub fn look_up(&self, name: &String) -> Option<Endpoint> {
let guard = self.abs_handle_map.lock();
return guard.get(name).cloned();
}
/// 移除绑定的地址
///
/// ## 参数
///
/// name待删除的地址
pub fn remove(&self, name: &String) -> Result<(), SystemError> {
let abs_addr = match look_up_abs_addr(name) {
Ok(result) => match result {
Endpoint::Abspath((abshandle, _)) => abshandle.name(),
_ => return Err(SystemError::EINVAL),
},
Err(_) => return Err(SystemError::EINVAL),
};
//释放abs地址分配实例
ABS_ADDRESS_ALLOCATOR.lock().free(abs_addr);
//释放entry
let mut guard = self.abs_handle_map.lock();
guard.remove(name);
Ok(())
}
}
/// 分配抽象地址
///
/// ## 返回
///
/// 分配到的抽象地址
pub fn alloc_abs_addr(name: String) -> Result<Endpoint, SystemError> {
ABSHANDLE_MAP.insert(name)
}
/// 查找抽象地址
///
/// ## 参数
///
/// name用户socket字符地址
///
/// ## 返回
///
/// 查询到的抽象地址
pub fn look_up_abs_addr(name: &String) -> Result<Endpoint, SystemError> {
match ABSHANDLE_MAP.look_up(name) {
Some(result) => return Ok(result),
None => return Err(SystemError::EINVAL),
}
}
/// 删除抽象地址
///
/// ## 参数
/// name待删除的地址
///
/// ## 返回
/// 删除的抽象地址
pub fn remove_abs_addr(name: &String) -> Result<(), SystemError> {
let abs_addr = match look_up_abs_addr(name) {
Ok(addr) => match addr {
Endpoint::Abspath((addr, _)) => addr,
_ => return Err(SystemError::EINVAL),
},
Err(_) => return Err(SystemError::EINVAL),
};
match ABS_INODE_MAP.lock_irqsave().remove(&abs_addr.name()) {
Some(_) => log::debug!("free abs inode"),
None => log::debug!("not free abs inode"),
}
ABSHANDLE_MAP.remove(name)?;
log::debug!("free abs!");
Ok(())
}

View File

@ -0,0 +1 @@
pub mod abs;

View File

@ -4,6 +4,7 @@ use alloc::{
sync::{Arc, Weak}, sync::{Arc, Weak},
}; };
use core::sync::atomic::{AtomicBool, Ordering}; use core::sync::atomic::{AtomicBool, Ordering};
use unix::ns::abs::{remove_abs_addr, ABS_INODE_MAP};
use crate::sched::SchedMode; use crate::sched::SchedMode;
use crate::{libs::rwlock::RwLock, net::socket::*}; use crate::{libs::rwlock::RwLock, net::socket::*};
@ -136,6 +137,23 @@ impl Socket for SeqpacketSocket {
_ => return Err(SystemError::EINVAL), _ => return Err(SystemError::EINVAL),
} }
} }
Endpoint::Abspath((abs_addr, _)) => {
let inode_guard = ABS_INODE_MAP.lock_irqsave();
let inode = match inode_guard.get(&abs_addr.name()) {
Some(inode) => inode,
None => {
log::debug!("can not find inode from absInodeMap");
return Err(SystemError::EINVAL);
}
};
match inode {
Endpoint::Inode((inode, _)) => inode.clone(),
_ => {
log::debug!("when connect, find inode failed!");
return Err(SystemError::EINVAL);
}
}
}
_ => return Err(SystemError::EINVAL), _ => return Err(SystemError::EINVAL),
}; };
// 远端为服务端 // 远端为服务端
@ -197,6 +215,17 @@ impl Socket for SeqpacketSocket {
INODE_MAP.write_irqsave().insert(inodeid, inode); INODE_MAP.write_irqsave().insert(inodeid, inode);
Ok(()) Ok(())
} }
Endpoint::Abspath((abshandle, path)) => {
let inode = match &mut *self.inner.write() {
Inner::Init(init) => init.bind_path(path)?,
_ => {
log::error!("socket has listen or connected");
return Err(SystemError::EINVAL);
}
};
ABS_INODE_MAP.lock_irqsave().insert(abshandle.name(), inode);
Ok(())
}
_ => return Err(SystemError::EINVAL), _ => return Err(SystemError::EINVAL),
} }
} }
@ -260,7 +289,21 @@ impl Socket for SeqpacketSocket {
// log::debug!("seqpacket close"); // log::debug!("seqpacket close");
self.shutdown.recv_shutdown(); self.shutdown.recv_shutdown();
self.shutdown.send_shutdown(); self.shutdown.send_shutdown();
Ok(())
let path = match self.get_name()? {
Endpoint::Inode((_, path)) => path,
_ => return Err(SystemError::EINVAL),
};
//如果path是空的说明没有bind不用释放相关映射资源
if path.is_empty() {
return Ok(());
}
// TODO: 释放INODE_MAP相关资源
// 尝试释放相关抽象地址资源
let _ = remove_abs_addr(&path);
return Ok(());
} }
fn get_peer_name(&self) -> Result<Endpoint, SystemError> { fn get_peer_name(&self) -> Result<Endpoint, SystemError> {

View File

@ -49,8 +49,9 @@ impl Init {
} }
if let Some(Endpoint::Inode((inode, mut path))) = self.addr.take() { if let Some(Endpoint::Inode((inode, mut path))) = self.addr.take() {
path = sun_path; path = sun_path;
let epoint = Endpoint::Inode((inode, path)); let epoint = Endpoint::Inode((inode, path.clone()));
self.addr.replace(epoint.clone()); self.addr.replace(epoint.clone());
log::debug!("bind path in inode : {:?}", path);
return Ok(epoint); return Ok(epoint);
}; };

View File

@ -6,7 +6,10 @@ use alloc::{
use inner::{Connected, Init, Inner, Listener}; use inner::{Connected, Init, Inner, Listener};
use log::debug; use log::debug;
use system_error::SystemError; use system_error::SystemError;
use unix::INODE_MAP; use unix::{
ns::abs::{remove_abs_addr, ABSHANDLE_MAP, ABS_INODE_MAP},
INODE_MAP,
};
use crate::{ use crate::{
libs::rwlock::RwLock, libs::rwlock::RwLock,
@ -157,6 +160,23 @@ impl Socket for StreamSocket {
_ => return Err(SystemError::EINVAL), _ => return Err(SystemError::EINVAL),
} }
} }
Endpoint::Abspath((abs_addr, path)) => {
let inode_guard = ABS_INODE_MAP.lock_irqsave();
let inode = match inode_guard.get(&abs_addr.name()) {
Some(inode) => inode,
None => {
log::debug!("can not find inode from absInodeMap");
return Err(SystemError::EINVAL);
}
};
match inode {
Endpoint::Inode((inode, _)) => (inode.clone(), path),
_ => {
debug!("when connect, find inode failed!");
return Err(SystemError::EINVAL);
}
}
}
_ => return Err(SystemError::EINVAL), _ => return Err(SystemError::EINVAL),
}; };
@ -200,6 +220,17 @@ impl Socket for StreamSocket {
INODE_MAP.write_irqsave().insert(inodeid, inode); INODE_MAP.write_irqsave().insert(inodeid, inode);
Ok(()) Ok(())
} }
Endpoint::Abspath((abshandle, path)) => {
let inode = match &mut *self.inner.write() {
Inner::Init(init) => init.bind_path(path)?,
_ => {
log::error!("socket has listen or connected");
return Err(SystemError::EINVAL);
}
};
ABS_INODE_MAP.lock_irqsave().insert(abshandle.name(), inode);
Ok(())
}
_ => return Err(SystemError::EINVAL), _ => return Err(SystemError::EINVAL),
} }
} }
@ -290,7 +321,21 @@ impl Socket for StreamSocket {
fn close(&self) -> Result<(), SystemError> { fn close(&self) -> Result<(), SystemError> {
self.shutdown.recv_shutdown(); self.shutdown.recv_shutdown();
self.shutdown.send_shutdown(); self.shutdown.send_shutdown();
Ok(())
let path = match self.get_name()? {
Endpoint::Inode((_, path)) => path,
_ => return Err(SystemError::EINVAL),
};
//如果path是空的说明没有bind不用释放相关映射资源
if path.is_empty() {
return Ok(());
}
// TODO: 释放INODE_MAP相关资源
// 尝试释放相关抽象地址资源
let _ = remove_abs_addr(&path);
return Ok(());
} }
fn get_peer_name(&self) -> Result<Endpoint, SystemError> { fn get_peer_name(&self) -> Result<Endpoint, SystemError> {

View File

@ -5,7 +5,8 @@ use std::io::Error;
use std::mem; use std::mem;
use std::os::fd::RawFd; use std::os::fd::RawFd;
const SOCKET_PATH: &str = "/test.stream"; const SOCKET_PATH: &str = "./test.stream";
const SOCKET_ABSTRUCT_PATH: &str = "/abs.stream";
const MSG1: &str = "Hello, unix stream socket from Client!"; const MSG1: &str = "Hello, unix stream socket from Client!";
const MSG2: &str = "Hello, unix stream socket from Server!"; const MSG2: &str = "Hello, unix stream socket from Server!";
@ -44,6 +45,32 @@ fn bind_socket(fd: RawFd) -> Result<(), Error> {
Ok(()) Ok(())
} }
fn bind_abstruct_socket(fd: RawFd) -> Result<(), Error> {
unsafe {
let mut addr = sockaddr_un {
sun_family: AF_UNIX as u16,
sun_path: [0; 108],
};
addr.sun_path[0] = 0;
let path_cstr = CString::new(SOCKET_ABSTRUCT_PATH).unwrap();
let path_bytes = path_cstr.as_bytes();
for (i, &byte) in path_bytes.iter().enumerate() {
addr.sun_path[i + 1] = byte as i8;
}
if bind(
fd,
&addr as *const _ as *const sockaddr,
mem::size_of_val(&addr) as socklen_t,
) == -1
{
return Err(Error::last_os_error());
}
}
Ok(())
}
fn listen_socket(fd: RawFd) -> Result<(), Error> { fn listen_socket(fd: RawFd) -> Result<(), Error> {
unsafe { unsafe {
if listen(fd, 5) == -1 { if listen(fd, 5) == -1 {
@ -111,7 +138,7 @@ fn test_stream() -> Result<(), Error> {
send_message(client_fd, MSG2).expect("Failed to send message"); send_message(client_fd, MSG2).expect("Failed to send message");
println!("Server send finish"); println!("Server send finish");
unsafe { close(client_fd) }; unsafe { close(server_fd) };
}); });
let client_fd = create_stream_socket()?; let client_fd = create_stream_socket()?;
@ -173,9 +200,124 @@ fn test_stream() -> Result<(), Error> {
Ok(()) Ok(())
} }
fn test_abstruct_namespace() -> Result<(), Error> {
let server_fd = create_stream_socket()?;
bind_abstruct_socket(server_fd)?;
listen_socket(server_fd)?;
let server_thread = std::thread::spawn(move || {
let client_fd = accept_conn(server_fd).expect("Failed to accept connection");
println!("accept success!");
let recv_msg = recv_message(client_fd).expect("Failed to receive message");
println!("Server: Received message: {}", recv_msg);
send_message(client_fd, MSG2).expect("Failed to send message");
println!("Server send finish");
unsafe { close(server_fd) }
});
let client_fd = create_stream_socket()?;
unsafe {
let mut addr = sockaddr_un {
sun_family: AF_UNIX as u16,
sun_path: [0; 108],
};
addr.sun_path[0] = 0;
let path_cstr = CString::new(SOCKET_ABSTRUCT_PATH).unwrap();
let path_bytes = path_cstr.as_bytes();
for (i, &byte) in path_bytes.iter().enumerate() {
addr.sun_path[i + 1] = byte as i8;
}
if connect(
client_fd,
&addr as *const _ as *const sockaddr,
mem::size_of_val(&addr) as socklen_t,
) == -1
{
return Err(Error::last_os_error());
}
}
send_message(client_fd, MSG1)?;
// get peer_name
unsafe {
let mut addrss = sockaddr_un {
sun_family: AF_UNIX as u16,
sun_path: [0; 108],
};
let mut len = mem::size_of_val(&addrss) as socklen_t;
let res = getpeername(client_fd, &mut addrss as *mut _ as *mut sockaddr, &mut len);
if res == -1 {
return Err(Error::last_os_error());
}
let sun_path = addrss.sun_path.clone();
let peer_path: [u8; 108] = sun_path
.iter()
.map(|&x| x as u8)
.collect::<Vec<u8>>()
.try_into()
.unwrap();
println!(
"Client: Connected to server at path: {}",
String::from_utf8_lossy(&peer_path)
);
}
server_thread.join().expect("Server thread panicked");
println!("Client try recv!");
let recv_msg = recv_message(client_fd).expect("Failed to receive message from server");
println!("Client Received message: {}", recv_msg);
unsafe { close(client_fd) };
Ok(())
}
fn test_recourse_free() -> Result<(), Error> {
let client_fd = create_stream_socket()?;
unsafe {
let mut addr = sockaddr_un {
sun_family: AF_UNIX as u16,
sun_path: [0; 108],
};
addr.sun_path[0] = 0;
let path_cstr = CString::new(SOCKET_ABSTRUCT_PATH).unwrap();
let path_bytes = path_cstr.as_bytes();
for (i, &byte) in path_bytes.iter().enumerate() {
addr.sun_path[i + 1] = byte as i8;
}
if connect(
client_fd,
&addr as *const _ as *const sockaddr,
mem::size_of_val(&addr) as socklen_t,
) == -1
{
return Err(Error::last_os_error());
}
}
send_message(client_fd, MSG1)?;
unsafe { close(client_fd) };
Ok(())
}
fn main() { fn main() {
match test_stream() { match test_stream() {
Ok(_) => println!("test for unix stream success"), Ok(_) => println!("test for unix stream success"),
Err(_) => println!("test for unix stream failed"), Err(_) => println!("test for unix stream failed"),
} }
match test_abstruct_namespace() {
Ok(_) => println!("test for unix abstruct namespace success"),
Err(_) => println!("test for unix abstruct namespace failed"),
}
match test_recourse_free() {
Ok(_) => println!("not free!"),
Err(_) => println!("free!"),
}
} }