From 8fe49e190e0e15bbb9ca9dd13d5e8f329a875492 Mon Sep 17 00:00:00 2001 From: Samuka007 Date: Mon, 14 Oct 2024 12:11:27 +0000 Subject: [PATCH] make fmt --- kernel/src/arch/x86_64/syscall/mod.rs | 2 +- kernel/src/driver/base/block/block_device.rs | 24 +-- kernel/src/driver/base/uevent/mod.rs | 13 +- kernel/src/driver/net/mod.rs | 12 +- kernel/src/net/socket/inet/stream/inner.rs | 9 +- kernel/src/net/socket/inet/stream/mod.rs | 24 +-- kernel/src/net/socket/unix/seqpacket/inner.rs | 10 +- kernel/src/net/socket/unix/seqpacket/mod.rs | 10 +- kernel/src/net/socket/unix/stream/mod.rs | 5 +- kernel/src/net/syscall_util.rs | 2 +- user/apps/ping/src/ping.rs | 2 +- user/apps/test-uevent/src/main.rs | 39 ++++- user/apps/test_seqpacket/src/main.rs | 6 +- user/apps/test_seqpacket/src/seq_pair.rs | 9 +- user/apps/test_seqpacket/src/seq_socket.rs | 160 ++++++++++-------- user/apps/test_unix_stream_socket/src/main.rs | 64 +++++-- 16 files changed, 235 insertions(+), 156 deletions(-) diff --git a/kernel/src/arch/x86_64/syscall/mod.rs b/kernel/src/arch/x86_64/syscall/mod.rs index e977d0b8..d9468df6 100644 --- a/kernel/src/arch/x86_64/syscall/mod.rs +++ b/kernel/src/arch/x86_64/syscall/mod.rs @@ -133,7 +133,7 @@ pub extern "sysv64" fn syscall_handler(frame: &mut TrapFrame) { show &= false; } } - show = false; + show &= false; if show { debug!("[SYS] [Pid: {:?}] [Call: {:?}]", pid, to_print); } diff --git a/kernel/src/driver/base/block/block_device.rs b/kernel/src/driver/base/block/block_device.rs index 623b5008..369d9bc1 100644 --- a/kernel/src/driver/base/block/block_device.rs +++ b/kernel/src/driver/base/block/block_device.rs @@ -1,16 +1,20 @@ /// 引入Module -use crate::{driver::{ - base::{ - device::{ - device_number::{DeviceNumber, Major}, Device, DeviceError, IdTable, BLOCKDEVS - }, - map::{ - DeviceStruct, DEV_MAJOR_DYN_END, DEV_MAJOR_DYN_EXT_END, DEV_MAJOR_DYN_EXT_START, - DEV_MAJOR_HASH_SIZE, DEV_MAJOR_MAX, +use crate::{ + driver::{ + base::{ + device::{ + device_number::{DeviceNumber, Major}, + Device, DeviceError, IdTable, BLOCKDEVS, + }, + map::{ + DeviceStruct, DEV_MAJOR_DYN_END, DEV_MAJOR_DYN_EXT_END, DEV_MAJOR_DYN_EXT_START, + DEV_MAJOR_HASH_SIZE, DEV_MAJOR_MAX, + }, }, + block::cache::{cached_block_device::BlockCache, BlockCacheError, BLOCK_SIZE}, }, - block::cache::{cached_block_device::BlockCache, BlockCacheError, BLOCK_SIZE}, -}, filesystem::sysfs::AttributeGroup}; + filesystem::sysfs::AttributeGroup, +}; use alloc::{string::String, sync::Arc, vec::Vec}; use core::{any::Any, fmt::Display, ops::Deref}; diff --git a/kernel/src/driver/base/uevent/mod.rs b/kernel/src/driver/base/uevent/mod.rs index e85c1df1..442551b8 100644 --- a/kernel/src/driver/base/uevent/mod.rs +++ b/kernel/src/driver/base/uevent/mod.rs @@ -157,10 +157,14 @@ impl Attribute for UeventAttr { writeln!(&mut uevent_content, "DEVTYPE=char").unwrap(); } DeviceType::Net => { - let net_device = device.clone().cast::().map_err(|e: Arc| { - warn!("device:{:?} is not a net device!", e); - SystemError::EINVAL - })?; + let net_device = + device + .clone() + .cast::() + .map_err(|e: Arc| { + warn!("device:{:?} is not a net device!", e); + SystemError::EINVAL + })?; let iface_id = net_device.nic_id(); let device_name = device.name(); writeln!(&mut uevent_content, "INTERFACE={}", device_name).unwrap(); @@ -200,7 +204,6 @@ impl Attribute for UeventAttr { } } - /// 将设备的基本信息写入 uevent 文件 fn sysfs_emit_str(buf: &mut [u8], content: &str) -> Result { log::info!("sysfs_emit_str"); diff --git a/kernel/src/driver/net/mod.rs b/kernel/src/driver/net/mod.rs index f11ffffc..bf57a885 100644 --- a/kernel/src/driver/net/mod.rs +++ b/kernel/src/driver/net/mod.rs @@ -255,12 +255,12 @@ impl IfaceCommon { } }); - // let closed_sockets = self - // .closing_sockets - // .lock_irq_disabled() - // .extract_if(|closing_socket| closing_socket.is_closed()) - // .collect::>(); - // drop(closed_sockets); + // let closed_sockets = self + // .closing_sockets + // .lock_irq_disabled() + // .extract_if(|closing_socket| closing_socket.is_closed()) + // .collect::>(); + // drop(closed_sockets); // } } diff --git a/kernel/src/net/socket/inet/stream/inner.rs b/kernel/src/net/socket/inet/stream/inner.rs index badf900b..1292e907 100644 --- a/kernel/src/net/socket/inet/stream/inner.rs +++ b/kernel/src/net/socket/inet/stream/inner.rs @@ -3,6 +3,7 @@ use core::sync::atomic::{AtomicU32, AtomicUsize}; use crate::libs::rwlock::RwLock; use crate::net::socket::EPollEventType; use crate::net::socket::{self, inet::Types}; +use alloc::boxed::Box; use alloc::vec::Vec; use smoltcp; use system_error::SystemError::{self, *}; @@ -30,13 +31,13 @@ where #[derive(Debug)] pub enum Init { - Unbound(smoltcp::socket::tcp::Socket<'static>), + Unbound(Box>), Bound((socket::inet::BoundInner, smoltcp::wire::IpEndpoint)), } impl Init { pub(super) fn new() -> Self { - Init::Unbound(new_smoltcp_socket()) + Init::Unbound(Box::new(new_smoltcp_socket())) } /// 传入一个已经绑定的socket @@ -55,7 +56,7 @@ impl Init { ) -> Result { match self { Init::Unbound(socket) => { - let bound = socket::inet::BoundInner::bind(socket, &local_endpoint.addr)?; + let bound = socket::inet::BoundInner::bind(*socket, &local_endpoint.addr)?; bound .port_manager() .bind_port(Types::Tcp, local_endpoint.port)?; @@ -73,7 +74,7 @@ impl Init { match self { Init::Unbound(socket) => { let (bound, address) = - socket::inet::BoundInner::bind_ephemeral(socket, remote_endpoint.addr) + socket::inet::BoundInner::bind_ephemeral(*socket, remote_endpoint.addr) .map_err(|err| (Self::new(), err))?; let bound_port = bound .port_manager() diff --git a/kernel/src/net/socket/inet/stream/mod.rs b/kernel/src/net/socket/inet/stream/mod.rs index 609be573..8f8cc745 100644 --- a/kernel/src/net/socket/inet/stream/mod.rs +++ b/kernel/src/net/socket/inet/stream/mod.rs @@ -185,15 +185,19 @@ impl TcpSocket { } pub fn try_recv(&self, buf: &mut [u8]) -> Result { - self.inner.read().as_ref().map(|inner| { - inner.iface().unwrap().poll(); - let result = match inner { - Inner::Established(inner) => inner.recv_slice(buf), - _ => Err(EINVAL), - }; - inner.iface().unwrap().poll(); - result - }).unwrap() + self.inner + .read() + .as_ref() + .map(|inner| { + inner.iface().unwrap().poll(); + let result = match inner { + Inner::Established(inner) => inner.recv_slice(buf), + _ => Err(EINVAL), + }; + inner.iface().unwrap().poll(); + result + }) + .unwrap() } pub fn try_send(&self, buf: &[u8]) -> Result { @@ -238,7 +242,7 @@ impl Socket for TcpSocket { fn get_name(&self) -> Result { match self.inner.read().as_ref().expect("Tcp Inner is None") { Inner::Init(Init::Unbound(_)) => Ok(Endpoint::Ip(UNSPECIFIED_LOCAL_ENDPOINT)), - Inner::Init(Init::Bound((_, local))) => Ok(Endpoint::Ip(local.clone())), + Inner::Init(Init::Bound((_, local))) => Ok(Endpoint::Ip(*local)), Inner::Connecting(connecting) => Ok(Endpoint::Ip(connecting.get_name())), Inner::Established(established) => Ok(Endpoint::Ip(established.local_endpoint())), Inner::Listening(listening) => Ok(Endpoint::Ip(listening.get_name())), diff --git a/kernel/src/net/socket/unix/seqpacket/inner.rs b/kernel/src/net/socket/unix/seqpacket/inner.rs index c43921ec..9875d18e 100644 --- a/kernel/src/net/socket/unix/seqpacket/inner.rs +++ b/kernel/src/net/socket/unix/seqpacket/inner.rs @@ -62,11 +62,7 @@ pub(super) struct Listener { impl Listener { pub(super) fn new(inode: Endpoint, backlog: usize) -> Self { log::debug!("backlog {}", backlog); - let back = if backlog > 1024 { - 1024 as usize - } else { - backlog - }; + let back = if backlog > 1024 { 1024_usize } else { backlog }; return Self { inode, backlog: AtomicUsize::new(back), @@ -82,7 +78,7 @@ impl Listener { log::debug!(" incom len {}", incoming_conns.len()); let conn = incoming_conns .pop_front() - .ok_or_else(|| SystemError::EAGAIN_OR_EWOULDBLOCK)?; + .ok_or(SystemError::EAGAIN_OR_EWOULDBLOCK)?; let socket = Arc::downcast::(conn.inner()).map_err(|_| SystemError::EINVAL)?; let peer = match &*socket.inner.read() { @@ -190,7 +186,7 @@ impl Connected { if self.can_send()? { return self.send_slice(buf); } else { - log::debug!("can not send {:?}", String::from_utf8_lossy(&buf[..])); + log::debug!("can not send {:?}", String::from_utf8_lossy(buf)); return Err(SystemError::ENOBUFS); } } diff --git a/kernel/src/net/socket/unix/seqpacket/mod.rs b/kernel/src/net/socket/unix/seqpacket/mod.rs index a1d29788..ada152fa 100644 --- a/kernel/src/net/socket/unix/seqpacket/mod.rs +++ b/kernel/src/net/socket/unix/seqpacket/mod.rs @@ -230,11 +230,7 @@ impl Socket for SeqpacketSocket { if !self.is_nonblocking() { loop { wq_wait_event_interruptible!(self.wait_queue, self.is_acceptable(), {})?; - match self - .try_accept() - .map(|(seqpacket_socket, remote_endpoint)| { - (seqpacket_socket, Endpoint::from(remote_endpoint)) - }) { + match self.try_accept() { Ok((socket, epoint)) => return Ok((socket, epoint)), Err(_) => continue, } @@ -274,7 +270,7 @@ impl Socket for SeqpacketSocket { }; if let Some(endpoint) = endpoint { - return Ok(Endpoint::from(endpoint)); + return Ok(endpoint); } else { return Err(SystemError::EAGAIN_OR_EWOULDBLOCK); } @@ -289,7 +285,7 @@ impl Socket for SeqpacketSocket { }; if let Some(endpoint) = endpoint { - return Ok(Endpoint::from(endpoint)); + return Ok(endpoint); } else { return Err(SystemError::EAGAIN_OR_EWOULDBLOCK); } diff --git a/kernel/src/net/socket/unix/stream/mod.rs b/kernel/src/net/socket/unix/stream/mod.rs index 2978a0b1..0ee57d1b 100644 --- a/kernel/src/net/socket/unix/stream/mod.rs +++ b/kernel/src/net/socket/unix/stream/mod.rs @@ -231,10 +231,7 @@ impl Socket for StreamSocket { //目前只实现了阻塞式实现 loop { wq_wait_event_interruptible!(self.wait_queue, self.is_acceptable(), {})?; - match self - .try_accept() - .map(|(stream_socket, remote_endpoint)| (stream_socket, remote_endpoint)) - { + match self.try_accept() { Ok((socket, endpoint)) => { debug!("server accept!:{:?}", endpoint); return Ok((socket, endpoint)); diff --git a/kernel/src/net/syscall_util.rs b/kernel/src/net/syscall_util.rs index 298ec84e..41d521fb 100644 --- a/kernel/src/net/syscall_util.rs +++ b/kernel/src/net/syscall_util.rs @@ -312,7 +312,7 @@ impl From for SockAddr { } let addr_un = SockAddrUn { sun_family: AddressFamily::Unix as u16, - sun_path: sun_path, + sun_path, }; return SockAddr { addr_un }; } diff --git a/user/apps/ping/src/ping.rs b/user/apps/ping/src/ping.rs index 7af719bc..a17881dc 100644 --- a/user/apps/ping/src/ping.rs +++ b/user/apps/ping/src/ping.rs @@ -101,7 +101,7 @@ impl Ping { for i in 0..this.config.count { let _this = this.clone(); - let handle = thread::spawn(move||{ + let handle = thread::spawn(move || { _this.ping(i).unwrap(); }); _send.fetch_add(1, Ordering::SeqCst); diff --git a/user/apps/test-uevent/src/main.rs b/user/apps/test-uevent/src/main.rs index d4831326..4e6b4d21 100644 --- a/user/apps/test-uevent/src/main.rs +++ b/user/apps/test-uevent/src/main.rs @@ -1,7 +1,10 @@ -use libc::{sockaddr, sockaddr_storage, recvfrom, bind, sendto, socket, AF_NETLINK, SOCK_DGRAM, SOCK_CLOEXEC, getpid, c_void}; +use libc::{ + bind, c_void, getpid, recvfrom, sendto, sockaddr, sockaddr_storage, socket, AF_NETLINK, + SOCK_CLOEXEC, SOCK_DGRAM, +}; use nix::libc; use std::os::unix::io::RawFd; -use std::{ mem, io}; +use std::{io, mem}; #[repr(C)] struct Nlmsghdr { @@ -14,7 +17,11 @@ struct Nlmsghdr { fn create_netlink_socket() -> io::Result { let sockfd = unsafe { - socket(AF_NETLINK, SOCK_DGRAM | SOCK_CLOEXEC, libc::NETLINK_KOBJECT_UEVENT) + socket( + AF_NETLINK, + SOCK_DGRAM | SOCK_CLOEXEC, + libc::NETLINK_KOBJECT_UEVENT, + ) }; if sockfd < 0 { @@ -33,7 +40,11 @@ fn bind_netlink_socket(sock: RawFd) -> io::Result<()> { addr.nl_groups = 0; let ret = unsafe { - bind(sock, &addr as *const _ as *const sockaddr, mem::size_of::() as u32) + bind( + sock, + &addr as *const _ as *const sockaddr, + mem::size_of::() as u32, + ) }; if ret < 0 { @@ -90,7 +101,10 @@ fn receive_uevent(sock: RawFd) -> io::Result { // 检查套接字文件描述符是否有效 if sock < 0 { println!("Invalid socket file descriptor: {}", sock); - return Err(io::Error::new(io::ErrorKind::InvalidInput, "Invalid socket file descriptor")); + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Invalid socket file descriptor", + )); } let mut buf = [0u8; 1024]; @@ -100,7 +114,10 @@ fn receive_uevent(sock: RawFd) -> io::Result { // 检查缓冲区指针和长度是否有效 if buf.is_empty() { println!("Buffer is empty"); - return Err(io::Error::new(io::ErrorKind::InvalidInput, "Buffer is empty")); + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Buffer is empty", + )); } let len = unsafe { recvfrom( @@ -122,13 +139,19 @@ fn receive_uevent(sock: RawFd) -> io::Result { let nlmsghdr_size = mem::size_of::(); if (len as usize) < nlmsghdr_size { println!("Received message is too short"); - return Err(io::Error::new(io::ErrorKind::InvalidData, "Received message is too short")); + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "Received message is too short", + )); } let nlmsghdr = unsafe { &*(buf.as_ptr() as *const Nlmsghdr) }; if nlmsghdr.nlmsg_len as isize > len { println!("Received message is incomplete"); - return Err(io::Error::new(io::ErrorKind::InvalidData, "Received message is incomplete")); + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "Received message is incomplete", + )); } let message_data = &buf[nlmsghdr_size..nlmsghdr.nlmsg_len as usize]; diff --git a/user/apps/test_seqpacket/src/main.rs b/user/apps/test_seqpacket/src/main.rs index dc1953e2..9657b36a 100644 --- a/user/apps/test_seqpacket/src/main.rs +++ b/user/apps/test_seqpacket/src/main.rs @@ -1,8 +1,8 @@ -mod seq_socket; mod seq_pair; +mod seq_socket; -use seq_socket::test_seq_socket; use seq_pair::test_seq_pair; +use seq_socket::test_seq_socket; fn main() -> Result<(), std::io::Error> { if let Err(e) = test_seq_socket() { @@ -187,4 +187,4 @@ fn main() -> Result<(), std::io::Error> { // let len = socket1.read(&mut buf)?; // println!("sock1 receive: {:?}", String::from_utf8_lossy(&buf[..len])); // Ok(()) -// } \ No newline at end of file +// } diff --git a/user/apps/test_seqpacket/src/seq_pair.rs b/user/apps/test_seqpacket/src/seq_pair.rs index f474b200..3c9c3818 100644 --- a/user/apps/test_seqpacket/src/seq_pair.rs +++ b/user/apps/test_seqpacket/src/seq_pair.rs @@ -1,16 +1,17 @@ use nix::sys::socket::{socketpair, AddressFamily, SockFlag, SockType}; use std::fs::File; -use std::io::{Read, Write,Error}; +use std::io::{Error, Read, Write}; use std::os::fd::FromRawFd; -pub fn test_seq_pair()->Result<(),Error>{ +pub fn test_seq_pair() -> Result<(), Error> { // 创建 socket pair let (sock1, sock2) = socketpair( AddressFamily::Unix, SockType::SeqPacket, // 使用 SeqPacket 类型 None, // 协议默认 SockFlag::empty(), - ).expect("Failed to create socket pair"); + ) + .expect("Failed to create socket pair"); let mut socket1 = unsafe { File::from_raw_fd(sock1) }; let mut socket2 = unsafe { File::from_raw_fd(sock2) }; @@ -36,4 +37,4 @@ pub fn test_seq_pair()->Result<(),Error>{ let len = socket1.read(&mut buf)?; println!("sock1 receive: {:?}", String::from_utf8_lossy(&buf[..len])); Ok(()) -} \ No newline at end of file +} diff --git a/user/apps/test_seqpacket/src/seq_socket.rs b/user/apps/test_seqpacket/src/seq_socket.rs index a2f08c11..81b3db5b 100644 --- a/user/apps/test_seqpacket/src/seq_socket.rs +++ b/user/apps/test_seqpacket/src/seq_socket.rs @@ -1,16 +1,14 @@ - use libc::*; -use std::{fs, str}; use std::ffi::CString; use std::io::Error; use std::mem; use std::os::unix::io::RawFd; +use std::{fs, str}; const SOCKET_PATH: &str = "/test.seqpacket"; const MSG1: &str = "Hello, Unix SEQPACKET socket from Client!"; const MSG2: &str = "Hello, Unix SEQPACKET socket from Server!"; - fn create_seqpacket_socket() -> Result { unsafe { let fd = socket(AF_UNIX, SOCK_SEQPACKET, 0); @@ -33,7 +31,12 @@ fn bind_socket(fd: RawFd) -> Result<(), Error> { addr.sun_path[i] = byte as i8; } - if bind(fd, &addr as *const _ as *const sockaddr, mem::size_of_val(&addr) as socklen_t) == -1 { + if bind( + fd, + &addr as *const _ as *const sockaddr, + mem::size_of_val(&addr) as socklen_t, + ) == -1 + { return Err(Error::last_os_error()); } } @@ -68,7 +71,13 @@ fn accept_connection(fd: RawFd) -> Result { fn send_message(fd: RawFd, msg: &str) -> Result<(), Error> { unsafe { let msg_bytes = msg.as_bytes(); - if send(fd, msg_bytes.as_ptr() as *const libc::c_void, msg_bytes.len(), 0) == -1 { + if send( + fd, + msg_bytes.as_ptr() as *const libc::c_void, + msg_bytes.len(), + 0, + ) == -1 + { return Err(Error::last_os_error()); } } @@ -78,7 +87,12 @@ fn send_message(fd: RawFd, msg: &str) -> Result<(), Error> { fn receive_message(fd: RawFd) -> Result { let mut buffer = [0; 1024]; unsafe { - let len = recv(fd, buffer.as_mut_ptr() as *mut libc::c_void, buffer.len(), 0); + let len = recv( + fd, + buffer.as_mut_ptr() as *mut libc::c_void, + buffer.len(), + 0, + ); if len == -1 { return Err(Error::last_os_error()); } @@ -86,70 +100,82 @@ fn receive_message(fd: RawFd) -> Result { } } -pub fn test_seq_socket() ->Result<(), Error>{ - // Create and bind the server socket - fs::remove_file(&SOCKET_PATH).ok(); +pub fn test_seq_socket() -> Result<(), Error> { + // Create and bind the server socket + fs::remove_file(&SOCKET_PATH).ok(); - let server_fd = create_seqpacket_socket()?; - bind_socket(server_fd)?; - listen_socket(server_fd)?; + let server_fd = create_seqpacket_socket()?; + bind_socket(server_fd)?; + listen_socket(server_fd)?; - // Accept connection in a separate thread - let server_thread = std::thread::spawn(move || { - let client_fd = accept_connection(server_fd).expect("Failed to accept connection"); - - // Receive and print message - let received_msg = receive_message(client_fd).expect("Failed to receive message"); - println!("Server: Received message: {}", received_msg); - - send_message(client_fd, MSG2).expect("Failed to send message"); - - // Close client connection - unsafe { close(client_fd) }; - }); - - // Create and connect the client socket - let client_fd = create_seqpacket_socket()?; - unsafe { - let mut addr = sockaddr_un { - sun_family: AF_UNIX as u16, - sun_path: [0; 108], - }; - let path_cstr = CString::new(SOCKET_PATH).unwrap(); - let path_bytes = path_cstr.as_bytes(); - // Convert u8 to i8 - for (i, &byte) in path_bytes.iter().enumerate() { - addr.sun_path[i] = byte as i8; - } - if connect(client_fd, &addr as *const _ as *const sockaddr, mem::size_of_val(&addr) as socklen_t) == -1 { - return Err(Error::last_os_error()); - } - } - send_message(client_fd, MSG1)?; + // Accept connection in a separate thread + let server_thread = std::thread::spawn(move || { + let client_fd = accept_connection(server_fd).expect("Failed to accept connection"); + + // Receive and print message let received_msg = receive_message(client_fd).expect("Failed to receive message"); - println!("Client: Received message: {}", received_msg); - // get peer_name - unsafe { - let mut addrss = sockaddr_un { - sun_family: AF_UNIX as u16, - sun_path: [0; 108], - }; - let mut len = mem::size_of_val(&addrss) as socklen_t; - let res = getpeername(client_fd, &mut addrss as *mut _ as *mut sockaddr, &mut len); - if res == -1 { - return Err(Error::last_os_error()); - } - let sun_path = addrss.sun_path.clone(); - let peer_path:[u8;108] = sun_path.iter().map(|&x| x as u8).collect::>().try_into().unwrap(); - println!("Client: Connected to server at path: {}", String::from_utf8_lossy(&peer_path)); + println!("Server: Received message: {}", received_msg); + + send_message(client_fd, MSG2).expect("Failed to send message"); - } - - server_thread.join().expect("Server thread panicked"); - let received_msg = receive_message(client_fd).expect("Failed to receive message"); - println!("Client: Received message: {}", received_msg); // Close client connection unsafe { close(client_fd) }; - fs::remove_file(&SOCKET_PATH).ok(); - Ok(()) -} \ No newline at end of file + }); + + // Create and connect the client socket + let client_fd = create_seqpacket_socket()?; + unsafe { + let mut addr = sockaddr_un { + sun_family: AF_UNIX as u16, + sun_path: [0; 108], + }; + let path_cstr = CString::new(SOCKET_PATH).unwrap(); + let path_bytes = path_cstr.as_bytes(); + // Convert u8 to i8 + for (i, &byte) in path_bytes.iter().enumerate() { + addr.sun_path[i] = byte as i8; + } + if connect( + client_fd, + &addr as *const _ as *const sockaddr, + mem::size_of_val(&addr) as socklen_t, + ) == -1 + { + return Err(Error::last_os_error()); + } + } + send_message(client_fd, MSG1)?; + let received_msg = receive_message(client_fd).expect("Failed to receive message"); + println!("Client: Received message: {}", received_msg); + // get peer_name + unsafe { + let mut addrss = sockaddr_un { + sun_family: AF_UNIX as u16, + sun_path: [0; 108], + }; + let mut len = mem::size_of_val(&addrss) as socklen_t; + let res = getpeername(client_fd, &mut addrss as *mut _ as *mut sockaddr, &mut len); + if res == -1 { + return Err(Error::last_os_error()); + } + let sun_path = addrss.sun_path.clone(); + let peer_path: [u8; 108] = sun_path + .iter() + .map(|&x| x as u8) + .collect::>() + .try_into() + .unwrap(); + println!( + "Client: Connected to server at path: {}", + String::from_utf8_lossy(&peer_path) + ); + } + + server_thread.join().expect("Server thread panicked"); + let received_msg = receive_message(client_fd).expect("Failed to receive message"); + println!("Client: Received message: {}", received_msg); + // Close client connection + unsafe { close(client_fd) }; + fs::remove_file(&SOCKET_PATH).ok(); + Ok(()) +} diff --git a/user/apps/test_unix_stream_socket/src/main.rs b/user/apps/test_unix_stream_socket/src/main.rs index 1e52748b..27b1cc8c 100644 --- a/user/apps/test_unix_stream_socket/src/main.rs +++ b/user/apps/test_unix_stream_socket/src/main.rs @@ -1,19 +1,19 @@ -use std::io::Error; -use std::os::fd::RawFd; -use std::fs; use libc::*; use std::ffi::CString; +use std::fs; +use std::io::Error; use std::mem; +use std::os::fd::RawFd; const SOCKET_PATH: &str = "/test.stream"; const MSG1: &str = "Hello, unix stream socket from Client!"; const MSG2: &str = "Hello, unix stream socket from Server!"; -fn create_stream_socket() -> Result{ +fn create_stream_socket() -> Result { unsafe { let fd = socket(AF_UNIX, SOCK_STREAM, 0); if fd == -1 { - return Err(Error::last_os_error()) + return Err(Error::last_os_error()); } Ok(fd) } @@ -31,7 +31,12 @@ fn bind_socket(fd: RawFd) -> Result<(), Error> { addr.sun_path[i] = byte as i8; } - if bind(fd, &addr as *const _ as *const sockaddr, mem::size_of_val(&addr) as socklen_t) == -1 { + if bind( + fd, + &addr as *const _ as *const sockaddr, + mem::size_of_val(&addr) as socklen_t, + ) == -1 + { return Err(Error::last_os_error()); } } @@ -61,7 +66,13 @@ fn accept_conn(fd: RawFd) -> Result { fn send_message(fd: RawFd, msg: &str) -> Result<(), Error> { unsafe { let msg_bytes = msg.as_bytes(); - if send(fd, msg_bytes.as_ptr() as *const libc::c_void, msg_bytes.len(), 0)== -1 { + if send( + fd, + msg_bytes.as_ptr() as *const libc::c_void, + msg_bytes.len(), + 0, + ) == -1 + { return Err(Error::last_os_error()); } } @@ -71,7 +82,12 @@ fn send_message(fd: RawFd, msg: &str) -> Result<(), Error> { fn recv_message(fd: RawFd) -> Result { let mut buffer = [0; 1024]; unsafe { - let len = recv(fd, buffer.as_mut_ptr() as *mut libc::c_void, buffer.len(),0); + let len = recv( + fd, + buffer.as_mut_ptr() as *mut libc::c_void, + buffer.len(), + 0, + ); if len == -1 { return Err(Error::last_os_error()); } @@ -82,7 +98,7 @@ fn recv_message(fd: RawFd) -> Result { fn test_stream() -> Result<(), Error> { fs::remove_file(&SOCKET_PATH).ok(); - let server_fd = create_stream_socket()?; + let server_fd = create_stream_socket()?; bind_socket(server_fd)?; listen_socket(server_fd)?; @@ -95,7 +111,7 @@ fn test_stream() -> Result<(), Error> { send_message(client_fd, MSG2).expect("Failed to send message"); println!("Server send finish"); - unsafe {close(client_fd)}; + unsafe { close(client_fd) }; }); let client_fd = create_stream_socket()?; @@ -111,9 +127,14 @@ fn test_stream() -> Result<(), Error> { addr.sun_path[i] = byte as i8; } - if connect(client_fd, &addr as *const _ as *const sockaddr, mem::size_of_val(&addr) as socklen_t) == -1 { + if connect( + client_fd, + &addr as *const _ as *const sockaddr, + mem::size_of_val(&addr) as socklen_t, + ) == -1 + { return Err(Error::last_os_error()); - } + } } send_message(client_fd, MSG1)?; @@ -129,9 +150,16 @@ fn test_stream() -> Result<(), Error> { return Err(Error::last_os_error()); } let sun_path = addrss.sun_path.clone(); - let peer_path:[u8;108] = sun_path.iter().map(|&x| x as u8).collect::>().try_into().unwrap(); - println!("Client: Connected to server at path: {}", String::from_utf8_lossy(&peer_path)); - + let peer_path: [u8; 108] = sun_path + .iter() + .map(|&x| x as u8) + .collect::>() + .try_into() + .unwrap(); + println!( + "Client: Connected to server at path: {}", + String::from_utf8_lossy(&peer_path) + ); } server_thread.join().expect("Server thread panicked"); @@ -139,7 +167,7 @@ fn test_stream() -> Result<(), Error> { let recv_msg = recv_message(client_fd).expect("Failed to receive message from server"); println!("Client Received message: {}", recv_msg); - unsafe {close(client_fd)}; + unsafe { close(client_fd) }; fs::remove_file(&SOCKET_PATH).ok(); Ok(()) @@ -148,6 +176,6 @@ fn test_stream() -> Result<(), Error> { fn main() { match test_stream() { Ok(_) => println!("test for unix stream success"), - Err(_) => println!("test for unix stream failed") + Err(_) => println!("test for unix stream failed"), } -} \ No newline at end of file +}