feat(net): 实现unix抽象地址空间 (#1017)

This commit is contained in:
Cai Junyuan
2024-10-28 20:29:08 +08:00
committed by GitHub
parent 8189cb1771
commit fad1c09757
9 changed files with 455 additions and 8 deletions

View File

@ -5,7 +5,8 @@ use std::io::Error;
use std::mem;
use std::os::fd::RawFd;
const SOCKET_PATH: &str = "/test.stream";
const SOCKET_PATH: &str = "./test.stream";
const SOCKET_ABSTRUCT_PATH: &str = "/abs.stream";
const MSG1: &str = "Hello, unix stream socket from Client!";
const MSG2: &str = "Hello, unix stream socket from Server!";
@ -44,6 +45,32 @@ fn bind_socket(fd: RawFd) -> Result<(), Error> {
Ok(())
}
fn bind_abstruct_socket(fd: RawFd) -> Result<(), Error> {
unsafe {
let mut addr = sockaddr_un {
sun_family: AF_UNIX as u16,
sun_path: [0; 108],
};
addr.sun_path[0] = 0;
let path_cstr = CString::new(SOCKET_ABSTRUCT_PATH).unwrap();
let path_bytes = path_cstr.as_bytes();
for (i, &byte) in path_bytes.iter().enumerate() {
addr.sun_path[i + 1] = byte as i8;
}
if bind(
fd,
&addr as *const _ as *const sockaddr,
mem::size_of_val(&addr) as socklen_t,
) == -1
{
return Err(Error::last_os_error());
}
}
Ok(())
}
fn listen_socket(fd: RawFd) -> Result<(), Error> {
unsafe {
if listen(fd, 5) == -1 {
@ -111,7 +138,7 @@ fn test_stream() -> Result<(), Error> {
send_message(client_fd, MSG2).expect("Failed to send message");
println!("Server send finish");
unsafe { close(client_fd) };
unsafe { close(server_fd) };
});
let client_fd = create_stream_socket()?;
@ -173,9 +200,124 @@ fn test_stream() -> Result<(), Error> {
Ok(())
}
fn test_abstruct_namespace() -> Result<(), Error> {
let server_fd = create_stream_socket()?;
bind_abstruct_socket(server_fd)?;
listen_socket(server_fd)?;
let server_thread = std::thread::spawn(move || {
let client_fd = accept_conn(server_fd).expect("Failed to accept connection");
println!("accept success!");
let recv_msg = recv_message(client_fd).expect("Failed to receive message");
println!("Server: Received message: {}", recv_msg);
send_message(client_fd, MSG2).expect("Failed to send message");
println!("Server send finish");
unsafe { close(server_fd) }
});
let client_fd = create_stream_socket()?;
unsafe {
let mut addr = sockaddr_un {
sun_family: AF_UNIX as u16,
sun_path: [0; 108],
};
addr.sun_path[0] = 0;
let path_cstr = CString::new(SOCKET_ABSTRUCT_PATH).unwrap();
let path_bytes = path_cstr.as_bytes();
for (i, &byte) in path_bytes.iter().enumerate() {
addr.sun_path[i + 1] = byte as i8;
}
if connect(
client_fd,
&addr as *const _ as *const sockaddr,
mem::size_of_val(&addr) as socklen_t,
) == -1
{
return Err(Error::last_os_error());
}
}
send_message(client_fd, MSG1)?;
// get peer_name
unsafe {
let mut addrss = sockaddr_un {
sun_family: AF_UNIX as u16,
sun_path: [0; 108],
};
let mut len = mem::size_of_val(&addrss) as socklen_t;
let res = getpeername(client_fd, &mut addrss as *mut _ as *mut sockaddr, &mut len);
if res == -1 {
return Err(Error::last_os_error());
}
let sun_path = addrss.sun_path.clone();
let peer_path: [u8; 108] = sun_path
.iter()
.map(|&x| x as u8)
.collect::<Vec<u8>>()
.try_into()
.unwrap();
println!(
"Client: Connected to server at path: {}",
String::from_utf8_lossy(&peer_path)
);
}
server_thread.join().expect("Server thread panicked");
println!("Client try recv!");
let recv_msg = recv_message(client_fd).expect("Failed to receive message from server");
println!("Client Received message: {}", recv_msg);
unsafe { close(client_fd) };
Ok(())
}
fn test_recourse_free() -> Result<(), Error> {
let client_fd = create_stream_socket()?;
unsafe {
let mut addr = sockaddr_un {
sun_family: AF_UNIX as u16,
sun_path: [0; 108],
};
addr.sun_path[0] = 0;
let path_cstr = CString::new(SOCKET_ABSTRUCT_PATH).unwrap();
let path_bytes = path_cstr.as_bytes();
for (i, &byte) in path_bytes.iter().enumerate() {
addr.sun_path[i + 1] = byte as i8;
}
if connect(
client_fd,
&addr as *const _ as *const sockaddr,
mem::size_of_val(&addr) as socklen_t,
) == -1
{
return Err(Error::last_os_error());
}
}
send_message(client_fd, MSG1)?;
unsafe { close(client_fd) };
Ok(())
}
fn main() {
match test_stream() {
Ok(_) => println!("test for unix stream success"),
Err(_) => println!("test for unix stream failed"),
}
match test_abstruct_namespace() {
Ok(_) => println!("test for unix abstruct namespace success"),
Err(_) => println!("test for unix abstruct namespace failed"),
}
match test_recourse_free() {
Ok(_) => println!("not free!"),
Err(_) => println!("free!"),
}
}