Merge pull request #974 from Samuka007:feat-network-rebuild

clean format, enable ctrl-c in accept blocking
This commit is contained in:
Samuel Dai 2024-10-14 20:23:46 +08:00 committed by GitHub
commit 9a1fe0f989
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
22 changed files with 275 additions and 180 deletions

View File

@ -133,7 +133,7 @@ pub extern "sysv64" fn syscall_handler(frame: &mut TrapFrame) {
show &= false; show &= false;
} }
} }
show = false; show &= false;
if show { if show {
debug!("[SYS] [Pid: {:?}] [Call: {:?}]", pid, to_print); debug!("[SYS] [Pid: {:?}] [Call: {:?}]", pid, to_print);
} }

View File

@ -1,8 +1,10 @@
/// 引入Module /// 引入Module
use crate::{driver::{ use crate::{
driver::{
base::{ base::{
device::{ device::{
device_number::{DeviceNumber, Major}, Device, DeviceError, IdTable, BLOCKDEVS device_number::{DeviceNumber, Major},
Device, DeviceError, IdTable, BLOCKDEVS,
}, },
map::{ map::{
DeviceStruct, DEV_MAJOR_DYN_END, DEV_MAJOR_DYN_EXT_END, DEV_MAJOR_DYN_EXT_START, DeviceStruct, DEV_MAJOR_DYN_END, DEV_MAJOR_DYN_EXT_END, DEV_MAJOR_DYN_EXT_START,
@ -10,7 +12,9 @@ use crate::{driver::{
}, },
}, },
block::cache::{cached_block_device::BlockCache, BlockCacheError, BLOCK_SIZE}, block::cache::{cached_block_device::BlockCache, BlockCacheError, BLOCK_SIZE},
}, filesystem::sysfs::AttributeGroup}; },
filesystem::sysfs::AttributeGroup,
};
use alloc::{string::String, sync::Arc, vec::Vec}; use alloc::{string::String, sync::Arc, vec::Vec};
use core::{any::Any, fmt::Display, ops::Deref}; use core::{any::Any, fmt::Display, ops::Deref};

View File

@ -157,7 +157,11 @@ impl Attribute for UeventAttr {
writeln!(&mut uevent_content, "DEVTYPE=char").unwrap(); writeln!(&mut uevent_content, "DEVTYPE=char").unwrap();
} }
DeviceType::Net => { DeviceType::Net => {
let net_device = device.clone().cast::<dyn Iface>().map_err(|e: Arc<dyn Device>| { let net_device =
device
.clone()
.cast::<dyn Iface>()
.map_err(|e: Arc<dyn Device>| {
warn!("device:{:?} is not a net device!", e); warn!("device:{:?} is not a net device!", e);
SystemError::EINVAL SystemError::EINVAL
})?; })?;
@ -200,7 +204,6 @@ impl Attribute for UeventAttr {
} }
} }
/// 将设备的基本信息写入 uevent 文件 /// 将设备的基本信息写入 uevent 文件
fn sysfs_emit_str(buf: &mut [u8], content: &str) -> Result<usize, SystemError> { fn sysfs_emit_str(buf: &mut [u8], content: &str) -> Result<usize, SystemError> {
log::info!("sysfs_emit_str"); log::info!("sysfs_emit_str");

View File

@ -241,15 +241,18 @@ impl IfaceCommon {
self.poll_at_ms.store(0, Ordering::Relaxed); self.poll_at_ms.store(0, Ordering::Relaxed);
} }
if has_events { // if has_events {
// log::debug!("IfaceCommon::poll: has_events"); // log::debug!("IfaceCommon::poll: has_events");
// We never try to hold the write lock in the IRQ context, and we disable IRQ when // We never try to hold the write lock in the IRQ context, and we disable IRQ when
// holding the write lock. So we don't need to disable IRQ when holding the read lock. // holding the write lock. So we don't need to disable IRQ when holding the read lock.
self.bounds.read().iter().for_each(|bound_socket| { self.bounds.read().iter().for_each(|bound_socket| {
bound_socket.on_iface_events(); bound_socket.on_iface_events();
if has_events {
bound_socket bound_socket
.wait_queue() .wait_queue()
.wakeup(Some(ProcessState::Blocked(true))); .wakeup(Some(ProcessState::Blocked(true)));
}
}); });
// let closed_sockets = self // let closed_sockets = self
@ -258,7 +261,7 @@ impl IfaceCommon {
// .extract_if(|closing_socket| closing_socket.is_closed()) // .extract_if(|closing_socket| closing_socket.is_closed())
// .collect::<Vec<_>>(); // .collect::<Vec<_>>();
// drop(closed_sockets); // drop(closed_sockets);
} // }
} }
pub fn update_ip_addrs(&self, ip_addrs: &[smoltcp::wire::IpCidr]) -> Result<(), SystemError> { pub fn update_ip_addrs(&self, ip_addrs: &[smoltcp::wire::IpCidr]) -> Result<(), SystemError> {

View File

@ -257,7 +257,7 @@ impl Device for VirtIONetDevice {
impl VirtIODevice for VirtIONetDevice { impl VirtIODevice for VirtIONetDevice {
fn handle_irq(&self, _irq: IrqNumber) -> Result<IrqReturn, SystemError> { fn handle_irq(&self, _irq: IrqNumber) -> Result<IrqReturn, SystemError> {
log::warn!("VirtioInterface: poll_ifaces_try_lock_onetime -> poll_ifaces"); // log::warn!("VirtioInterface: poll_ifaces_try_lock_onetime -> poll_ifaces");
poll_ifaces(); poll_ifaces();
return Ok(IrqReturn::Handled); return Ok(IrqReturn::Handled);
} }

View File

@ -1,7 +1,6 @@
use alloc::vec::Vec; use alloc::vec::Vec;
use alloc::{string::String, sync::Arc}; use alloc::{string::String, sync::Arc};
use log::debug;
use system_error::SystemError; use system_error::SystemError;
use crate::libs::spinlock::SpinLock; use crate::libs::spinlock::SpinLock;
@ -43,14 +42,14 @@ impl Buffer {
let len = core::cmp::min(buf.len(), read_buffer.len()); let len = core::cmp::min(buf.len(), read_buffer.len());
buf[..len].copy_from_slice(&read_buffer[..len]); buf[..len].copy_from_slice(&read_buffer[..len]);
let _ = read_buffer.split_off(len); let _ = read_buffer.split_off(len);
log::debug!("recv buf {}", String::from_utf8_lossy(buf)); // log::debug!("recv buf {}", String::from_utf8_lossy(buf));
return Ok(len); return Ok(len);
} }
pub fn write_read_buffer(&self, buf: &[u8]) -> Result<usize, SystemError> { pub fn write_read_buffer(&self, buf: &[u8]) -> Result<usize, SystemError> {
let mut buffer = self.read_buffer.lock_irqsave(); let mut buffer = self.read_buffer.lock_irqsave();
log::debug!("send buf {}", String::from_utf8_lossy(buf)); // log::debug!("send buf {}", String::from_utf8_lossy(buf));
let len = buf.len(); let len = buf.len();
if self.metadata.buf_size - buffer.len() < len { if self.metadata.buf_size - buffer.len() < len {
return Err(SystemError::ENOBUFS); return Err(SystemError::ENOBUFS);

View File

@ -45,7 +45,7 @@ impl BoundInner {
// iface // iface
// } // }
// 强绑VirtualIO // 强绑VirtualIO
log::debug!("Not bind to any iface, bind to virtIO"); // log::debug!("Not bind to any iface, bind to virtIO");
let iface = NET_DEVICES let iface = NET_DEVICES
.read_irqsave() .read_irqsave()
.get(&0) .get(&0)

View File

@ -3,6 +3,7 @@ use core::sync::atomic::{AtomicU32, AtomicUsize};
use crate::libs::rwlock::RwLock; use crate::libs::rwlock::RwLock;
use crate::net::socket::EPollEventType; use crate::net::socket::EPollEventType;
use crate::net::socket::{self, inet::Types}; use crate::net::socket::{self, inet::Types};
use alloc::boxed::Box;
use alloc::vec::Vec; use alloc::vec::Vec;
use smoltcp; use smoltcp;
use system_error::SystemError::{self, *}; use system_error::SystemError::{self, *};
@ -30,13 +31,13 @@ where
#[derive(Debug)] #[derive(Debug)]
pub enum Init { pub enum Init {
Unbound(smoltcp::socket::tcp::Socket<'static>), Unbound(Box<smoltcp::socket::tcp::Socket<'static>>),
Bound((socket::inet::BoundInner, smoltcp::wire::IpEndpoint)), Bound((socket::inet::BoundInner, smoltcp::wire::IpEndpoint)),
} }
impl Init { impl Init {
pub(super) fn new() -> Self { pub(super) fn new() -> Self {
Init::Unbound(new_smoltcp_socket()) Init::Unbound(Box::new(new_smoltcp_socket()))
} }
/// 传入一个已经绑定的socket /// 传入一个已经绑定的socket
@ -55,7 +56,7 @@ impl Init {
) -> Result<Self, SystemError> { ) -> Result<Self, SystemError> {
match self { match self {
Init::Unbound(socket) => { Init::Unbound(socket) => {
let bound = socket::inet::BoundInner::bind(socket, &local_endpoint.addr)?; let bound = socket::inet::BoundInner::bind(*socket, &local_endpoint.addr)?;
bound bound
.port_manager() .port_manager()
.bind_port(Types::Tcp, local_endpoint.port)?; .bind_port(Types::Tcp, local_endpoint.port)?;
@ -73,7 +74,7 @@ impl Init {
match self { match self {
Init::Unbound(socket) => { Init::Unbound(socket) => {
let (bound, address) = let (bound, address) =
socket::inet::BoundInner::bind_ephemeral(socket, remote_endpoint.addr) socket::inet::BoundInner::bind_ephemeral(*socket, remote_endpoint.addr)
.map_err(|err| (Self::new(), err))?; .map_err(|err| (Self::new(), err))?;
let bound_port = bound let bound_port = bound
.port_manager() .port_manager()
@ -125,7 +126,7 @@ impl Init {
} else { } else {
smoltcp::wire::IpListenEndpoint::from(local) smoltcp::wire::IpListenEndpoint::from(local)
}; };
log::debug!("listen at {:?}", listen_addr); // log::debug!("listen at {:?}", listen_addr);
let mut inners = Vec::new(); let mut inners = Vec::new();
if let Err(err) = || -> Result<(), SystemError> { if let Err(err) = || -> Result<(), SystemError> {
for _ in 0..(backlog - 1) { for _ in 0..(backlog - 1) {
@ -440,4 +441,13 @@ impl Inner {
Inner::Established(est) => est.with_mut(|socket| socket.recv_capacity()), Inner::Established(est) => est.with_mut(|socket| socket.recv_capacity()),
} }
} }
pub fn iface(&self) -> Option<&alloc::sync::Arc<dyn crate::driver::net::Iface>> {
match self {
Inner::Init(_) => None,
Inner::Connecting(conn) => Some(conn.inner.iface()),
Inner::Listening(listen) => Some(listen.inners[0].iface()),
Inner::Established(est) => Some(est.inner.iface()),
}
}
} }

View File

@ -185,11 +185,19 @@ impl TcpSocket {
} }
pub fn try_recv(&self, buf: &mut [u8]) -> Result<usize, SystemError> { pub fn try_recv(&self, buf: &mut [u8]) -> Result<usize, SystemError> {
poll_ifaces(); self.inner
match self.inner.read().as_ref().expect("Tcp Inner is None") { .read()
.as_ref()
.map(|inner| {
inner.iface().unwrap().poll();
let result = match inner {
Inner::Established(inner) => inner.recv_slice(buf), Inner::Established(inner) => inner.recv_slice(buf),
_ => Err(EINVAL), _ => Err(EINVAL),
} };
inner.iface().unwrap().poll();
result
})
.unwrap()
} }
pub fn try_send(&self, buf: &[u8]) -> Result<usize, SystemError> { pub fn try_send(&self, buf: &[u8]) -> Result<usize, SystemError> {
@ -221,6 +229,7 @@ impl TcpSocket {
// should only call on accept // should only call on accept
fn is_acceptable(&self) -> bool { fn is_acceptable(&self) -> bool {
// (self.poll() & EP::EPOLLIN.bits() as usize) != 0 // (self.poll() & EP::EPOLLIN.bits() as usize) != 0
self.inner.read().as_ref().unwrap().iface().unwrap().poll();
EP::from_bits_truncate(self.poll() as u32).contains(EP::EPOLLIN) EP::from_bits_truncate(self.poll() as u32).contains(EP::EPOLLIN)
} }
} }
@ -233,7 +242,7 @@ impl Socket for TcpSocket {
fn get_name(&self) -> Result<Endpoint, SystemError> { fn get_name(&self) -> Result<Endpoint, SystemError> {
match self.inner.read().as_ref().expect("Tcp Inner is None") { match self.inner.read().as_ref().expect("Tcp Inner is None") {
Inner::Init(Init::Unbound(_)) => Ok(Endpoint::Ip(UNSPECIFIED_LOCAL_ENDPOINT)), Inner::Init(Init::Unbound(_)) => Ok(Endpoint::Ip(UNSPECIFIED_LOCAL_ENDPOINT)),
Inner::Init(Init::Bound((_, local))) => Ok(Endpoint::Ip(local.clone())), Inner::Init(Init::Bound((_, local))) => Ok(Endpoint::Ip(*local)),
Inner::Connecting(connecting) => Ok(Endpoint::Ip(connecting.get_name())), Inner::Connecting(connecting) => Ok(Endpoint::Ip(connecting.get_name())),
Inner::Established(established) => Ok(Endpoint::Ip(established.local_endpoint())), Inner::Established(established) => Ok(Endpoint::Ip(established.local_endpoint())),
Inner::Listening(listening) => Ok(Endpoint::Ip(listening.get_name())), Inner::Listening(listening) => Ok(Endpoint::Ip(listening.get_name())),
@ -255,7 +264,7 @@ impl Socket for TcpSocket {
} }
fn poll(&self) -> usize { fn poll(&self) -> usize {
self.pollee.load(core::sync::atomic::Ordering::Relaxed) self.pollee.load(core::sync::atomic::Ordering::SeqCst)
} }
fn listen(&self, backlog: usize) -> Result<(), SystemError> { fn listen(&self, backlog: usize) -> Result<(), SystemError> {

View File

@ -11,7 +11,7 @@ fn create_inet_socket(
socket_type: Type, socket_type: Type,
protocol: smoltcp::wire::IpProtocol, protocol: smoltcp::wire::IpProtocol,
) -> Result<Arc<dyn Socket>, SystemError> { ) -> Result<Arc<dyn Socket>, SystemError> {
log::debug!("type: {:?}, protocol: {:?}", socket_type, protocol); // log::debug!("type: {:?}, protocol: {:?}", socket_type, protocol);
use smoltcp::wire::IpProtocol::*; use smoltcp::wire::IpProtocol::*;
use Type::*; use Type::*;
match socket_type { match socket_type {

View File

@ -27,7 +27,7 @@ impl family::Family for Unix {
impl Unix { impl Unix {
pub fn new_pairs(socket_type: Type) -> Result<(Arc<Inode>, Arc<Inode>), SystemError> { pub fn new_pairs(socket_type: Type) -> Result<(Arc<Inode>, Arc<Inode>), SystemError> {
log::debug!("socket_type {:?}", socket_type); // log::debug!("socket_type {:?}", socket_type);
match socket_type { match socket_type {
Type::SeqPacket => seqpacket::SeqpacketSocket::new_pairs(), Type::SeqPacket => seqpacket::SeqpacketSocket::new_pairs(),
Type::Stream | Type::Datagram => stream::StreamSocket::new_pairs(), Type::Stream | Type::Datagram => stream::StreamSocket::new_pairs(),

View File

@ -62,11 +62,7 @@ pub(super) struct Listener {
impl Listener { impl Listener {
pub(super) fn new(inode: Endpoint, backlog: usize) -> Self { pub(super) fn new(inode: Endpoint, backlog: usize) -> Self {
log::debug!("backlog {}", backlog); log::debug!("backlog {}", backlog);
let back = if backlog > 1024 { let back = if backlog > 1024 { 1024_usize } else { backlog };
1024 as usize
} else {
backlog
};
return Self { return Self {
inode, inode,
backlog: AtomicUsize::new(back), backlog: AtomicUsize::new(back),
@ -82,7 +78,7 @@ impl Listener {
log::debug!(" incom len {}", incoming_conns.len()); log::debug!(" incom len {}", incoming_conns.len());
let conn = incoming_conns let conn = incoming_conns
.pop_front() .pop_front()
.ok_or_else(|| SystemError::EAGAIN_OR_EWOULDBLOCK)?; .ok_or(SystemError::EAGAIN_OR_EWOULDBLOCK)?;
let socket = let socket =
Arc::downcast::<SeqpacketSocket>(conn.inner()).map_err(|_| SystemError::EINVAL)?; Arc::downcast::<SeqpacketSocket>(conn.inner()).map_err(|_| SystemError::EINVAL)?;
let peer = match &*socket.inner.read() { let peer = match &*socket.inner.read() {
@ -190,7 +186,7 @@ impl Connected {
if self.can_send()? { if self.can_send()? {
return self.send_slice(buf); return self.send_slice(buf);
} else { } else {
log::debug!("can not send {:?}", String::from_utf8_lossy(&buf[..])); log::debug!("can not send {:?}", String::from_utf8_lossy(buf));
return Err(SystemError::ENOBUFS); return Err(SystemError::ENOBUFS);
} }
} }

View File

@ -230,11 +230,7 @@ impl Socket for SeqpacketSocket {
if !self.is_nonblocking() { if !self.is_nonblocking() {
loop { loop {
wq_wait_event_interruptible!(self.wait_queue, self.is_acceptable(), {})?; wq_wait_event_interruptible!(self.wait_queue, self.is_acceptable(), {})?;
match self match self.try_accept() {
.try_accept()
.map(|(seqpacket_socket, remote_endpoint)| {
(seqpacket_socket, Endpoint::from(remote_endpoint))
}) {
Ok((socket, epoint)) => return Ok((socket, epoint)), Ok((socket, epoint)) => return Ok((socket, epoint)),
Err(_) => continue, Err(_) => continue,
} }
@ -260,7 +256,7 @@ impl Socket for SeqpacketSocket {
} }
fn close(&self) -> Result<(), SystemError> { fn close(&self) -> Result<(), SystemError> {
log::debug!("seqpacket close"); // log::debug!("seqpacket close");
self.shutdown.recv_shutdown(); self.shutdown.recv_shutdown();
self.shutdown.send_shutdown(); self.shutdown.send_shutdown();
Ok(()) Ok(())
@ -274,7 +270,7 @@ impl Socket for SeqpacketSocket {
}; };
if let Some(endpoint) = endpoint { if let Some(endpoint) = endpoint {
return Ok(Endpoint::from(endpoint)); return Ok(endpoint);
} else { } else {
return Err(SystemError::EAGAIN_OR_EWOULDBLOCK); return Err(SystemError::EAGAIN_OR_EWOULDBLOCK);
} }
@ -289,7 +285,7 @@ impl Socket for SeqpacketSocket {
}; };
if let Some(endpoint) = endpoint { if let Some(endpoint) = endpoint {
return Ok(Endpoint::from(endpoint)); return Ok(endpoint);
} else { } else {
return Err(SystemError::EAGAIN_OR_EWOULDBLOCK); return Err(SystemError::EAGAIN_OR_EWOULDBLOCK);
} }
@ -402,7 +398,7 @@ impl Socket for SeqpacketSocket {
flags: MessageFlag, flags: MessageFlag,
_address: Option<Endpoint>, _address: Option<Endpoint>,
) -> Result<(usize, Endpoint), SystemError> { ) -> Result<(usize, Endpoint), SystemError> {
log::debug!("recvfrom flags {:?}", flags); // log::debug!("recvfrom flags {:?}", flags);
if flags.contains(MessageFlag::OOB) { if flags.contains(MessageFlag::OOB) {
return Err(SystemError::EOPNOTSUPP_OR_ENOTSUP); return Err(SystemError::EOPNOTSUPP_OR_ENOTSUP);
} }
@ -417,7 +413,7 @@ impl Socket for SeqpacketSocket {
match &*self.inner.write() { match &*self.inner.write() {
Inner::Connected(connected) => match connected.recv_slice(buffer) { Inner::Connected(connected) => match connected.recv_slice(buffer) {
Ok(usize) => { Ok(usize) => {
log::debug!("recvs from successfully"); // log::debug!("recvs from successfully");
return Ok((usize, connected.peer_endpoint().unwrap().clone())); return Ok((usize, connected.peer_endpoint().unwrap().clone()));
} }
Err(_) => continue, Err(_) => continue,

View File

@ -231,10 +231,7 @@ impl Socket for StreamSocket {
//目前只实现了阻塞式实现 //目前只实现了阻塞式实现
loop { loop {
wq_wait_event_interruptible!(self.wait_queue, self.is_acceptable(), {})?; wq_wait_event_interruptible!(self.wait_queue, self.is_acceptable(), {})?;
match self match self.try_accept() {
.try_accept()
.map(|(stream_socket, remote_endpoint)| (stream_socket, remote_endpoint))
{
Ok((socket, endpoint)) => { Ok((socket, endpoint)) => {
debug!("server accept!:{:?}", endpoint); debug!("server accept!:{:?}", endpoint);
return Ok((socket, endpoint)); return Ok((socket, endpoint));

View File

@ -41,18 +41,18 @@ impl Syscall {
protocol: usize, protocol: usize,
) -> Result<usize, SystemError> { ) -> Result<usize, SystemError> {
// 打印收到的参数 // 打印收到的参数
log::debug!( // log::debug!(
"socket: address_family={:?}, socket_type={:?}, protocol={:?}", // "socket: address_family={:?}, socket_type={:?}, protocol={:?}",
address_family, // address_family,
socket_type, // socket_type,
protocol // protocol
); // );
let address_family = socket::AddressFamily::try_from(address_family as u16)?; let address_family = socket::AddressFamily::try_from(address_family as u16)?;
let type_arg = SysArgSocketType::from_bits_truncate(socket_type as u32); let type_arg = SysArgSocketType::from_bits_truncate(socket_type as u32);
let is_nonblock = type_arg.is_nonblock(); let is_nonblock = type_arg.is_nonblock();
let is_close_on_exec = type_arg.is_cloexec(); let is_close_on_exec = type_arg.is_cloexec();
let stype = socket::Type::try_from(type_arg)?; let stype = socket::Type::try_from(type_arg)?;
log::debug!("type_arg {:?} stype {:?}", type_arg, stype); // log::debug!("type_arg {:?} stype {:?}", type_arg, stype);
let inode = socket::create_socket( let inode = socket::create_socket(
address_family, address_family,
@ -256,7 +256,7 @@ impl Syscall {
let socket: Arc<socket::Inode> = ProcessManager::current_pcb() let socket: Arc<socket::Inode> = ProcessManager::current_pcb()
.get_socket(fd as i32) .get_socket(fd as i32)
.ok_or(SystemError::EBADF)?; .ok_or(SystemError::EBADF)?;
log::debug!("bind: socket={:?}", socket); // log::debug!("bind: socket={:?}", socket);
socket.bind(endpoint)?; socket.bind(endpoint)?;
Ok(0) Ok(0)
} }

View File

@ -312,7 +312,7 @@ impl From<Endpoint> for SockAddr {
} }
let addr_un = SockAddrUn { let addr_un = SockAddrUn {
sun_family: AddressFamily::Unix as u16, sun_family: AddressFamily::Unix as u16,
sun_path: sun_path, sun_path,
}; };
return SockAddr { addr_un }; return SockAddr { addr_un };
} }

View File

@ -101,7 +101,7 @@ impl Ping {
for i in 0..this.config.count { for i in 0..this.config.count {
let _this = this.clone(); let _this = this.clone();
let handle = thread::spawn(move||{ let handle = thread::spawn(move || {
_this.ping(i).unwrap(); _this.ping(i).unwrap();
}); });
_send.fetch_add(1, Ordering::SeqCst); _send.fetch_add(1, Ordering::SeqCst);

View File

@ -1,7 +1,10 @@
use libc::{sockaddr, sockaddr_storage, recvfrom, bind, sendto, socket, AF_NETLINK, SOCK_DGRAM, SOCK_CLOEXEC, getpid, c_void}; use libc::{
bind, c_void, getpid, recvfrom, sendto, sockaddr, sockaddr_storage, socket, AF_NETLINK,
SOCK_CLOEXEC, SOCK_DGRAM,
};
use nix::libc; use nix::libc;
use std::os::unix::io::RawFd; use std::os::unix::io::RawFd;
use std::{ mem, io}; use std::{io, mem};
#[repr(C)] #[repr(C)]
struct Nlmsghdr { struct Nlmsghdr {
@ -14,7 +17,11 @@ struct Nlmsghdr {
fn create_netlink_socket() -> io::Result<RawFd> { fn create_netlink_socket() -> io::Result<RawFd> {
let sockfd = unsafe { let sockfd = unsafe {
socket(AF_NETLINK, SOCK_DGRAM | SOCK_CLOEXEC, libc::NETLINK_KOBJECT_UEVENT) socket(
AF_NETLINK,
SOCK_DGRAM | SOCK_CLOEXEC,
libc::NETLINK_KOBJECT_UEVENT,
)
}; };
if sockfd < 0 { if sockfd < 0 {
@ -33,7 +40,11 @@ fn bind_netlink_socket(sock: RawFd) -> io::Result<()> {
addr.nl_groups = 0; addr.nl_groups = 0;
let ret = unsafe { let ret = unsafe {
bind(sock, &addr as *const _ as *const sockaddr, mem::size_of::<libc::sockaddr_nl>() as u32) bind(
sock,
&addr as *const _ as *const sockaddr,
mem::size_of::<libc::sockaddr_nl>() as u32,
)
}; };
if ret < 0 { if ret < 0 {
@ -90,7 +101,10 @@ fn receive_uevent(sock: RawFd) -> io::Result<String> {
// 检查套接字文件描述符是否有效 // 检查套接字文件描述符是否有效
if sock < 0 { if sock < 0 {
println!("Invalid socket file descriptor: {}", sock); println!("Invalid socket file descriptor: {}", sock);
return Err(io::Error::new(io::ErrorKind::InvalidInput, "Invalid socket file descriptor")); return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"Invalid socket file descriptor",
));
} }
let mut buf = [0u8; 1024]; let mut buf = [0u8; 1024];
@ -100,7 +114,10 @@ fn receive_uevent(sock: RawFd) -> io::Result<String> {
// 检查缓冲区指针和长度是否有效 // 检查缓冲区指针和长度是否有效
if buf.is_empty() { if buf.is_empty() {
println!("Buffer is empty"); println!("Buffer is empty");
return Err(io::Error::new(io::ErrorKind::InvalidInput, "Buffer is empty")); return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"Buffer is empty",
));
} }
let len = unsafe { let len = unsafe {
recvfrom( recvfrom(
@ -122,13 +139,19 @@ fn receive_uevent(sock: RawFd) -> io::Result<String> {
let nlmsghdr_size = mem::size_of::<Nlmsghdr>(); let nlmsghdr_size = mem::size_of::<Nlmsghdr>();
if (len as usize) < nlmsghdr_size { if (len as usize) < nlmsghdr_size {
println!("Received message is too short"); println!("Received message is too short");
return Err(io::Error::new(io::ErrorKind::InvalidData, "Received message is too short")); return Err(io::Error::new(
io::ErrorKind::InvalidData,
"Received message is too short",
));
} }
let nlmsghdr = unsafe { &*(buf.as_ptr() as *const Nlmsghdr) }; let nlmsghdr = unsafe { &*(buf.as_ptr() as *const Nlmsghdr) };
if nlmsghdr.nlmsg_len as isize > len { if nlmsghdr.nlmsg_len as isize > len {
println!("Received message is incomplete"); println!("Received message is incomplete");
return Err(io::Error::new(io::ErrorKind::InvalidData, "Received message is incomplete")); return Err(io::Error::new(
io::ErrorKind::InvalidData,
"Received message is incomplete",
));
} }
let message_data = &buf[nlmsghdr_size..nlmsghdr.nlmsg_len as usize]; let message_data = &buf[nlmsghdr_size..nlmsghdr.nlmsg_len as usize];

View File

@ -1,8 +1,8 @@
mod seq_socket;
mod seq_pair; mod seq_pair;
mod seq_socket;
use seq_socket::test_seq_socket;
use seq_pair::test_seq_pair; use seq_pair::test_seq_pair;
use seq_socket::test_seq_socket;
fn main() -> Result<(), std::io::Error> { fn main() -> Result<(), std::io::Error> {
if let Err(e) = test_seq_socket() { if let Err(e) = test_seq_socket() {

View File

@ -1,16 +1,17 @@
use nix::sys::socket::{socketpair, AddressFamily, SockFlag, SockType}; use nix::sys::socket::{socketpair, AddressFamily, SockFlag, SockType};
use std::fs::File; use std::fs::File;
use std::io::{Read, Write,Error}; use std::io::{Error, Read, Write};
use std::os::fd::FromRawFd; use std::os::fd::FromRawFd;
pub fn test_seq_pair()->Result<(),Error>{ pub fn test_seq_pair() -> Result<(), Error> {
// 创建 socket pair // 创建 socket pair
let (sock1, sock2) = socketpair( let (sock1, sock2) = socketpair(
AddressFamily::Unix, AddressFamily::Unix,
SockType::SeqPacket, // 使用 SeqPacket 类型 SockType::SeqPacket, // 使用 SeqPacket 类型
None, // 协议默认 None, // 协议默认
SockFlag::empty(), SockFlag::empty(),
).expect("Failed to create socket pair"); )
.expect("Failed to create socket pair");
let mut socket1 = unsafe { File::from_raw_fd(sock1) }; let mut socket1 = unsafe { File::from_raw_fd(sock1) };
let mut socket2 = unsafe { File::from_raw_fd(sock2) }; let mut socket2 = unsafe { File::from_raw_fd(sock2) };

View File

@ -1,16 +1,14 @@
use libc::*; use libc::*;
use std::{fs, str};
use std::ffi::CString; use std::ffi::CString;
use std::io::Error; use std::io::Error;
use std::mem; use std::mem;
use std::os::unix::io::RawFd; use std::os::unix::io::RawFd;
use std::{fs, str};
const SOCKET_PATH: &str = "/test.seqpacket"; const SOCKET_PATH: &str = "/test.seqpacket";
const MSG1: &str = "Hello, Unix SEQPACKET socket from Client!"; const MSG1: &str = "Hello, Unix SEQPACKET socket from Client!";
const MSG2: &str = "Hello, Unix SEQPACKET socket from Server!"; const MSG2: &str = "Hello, Unix SEQPACKET socket from Server!";
fn create_seqpacket_socket() -> Result<RawFd, Error> { fn create_seqpacket_socket() -> Result<RawFd, Error> {
unsafe { unsafe {
let fd = socket(AF_UNIX, SOCK_SEQPACKET, 0); let fd = socket(AF_UNIX, SOCK_SEQPACKET, 0);
@ -33,7 +31,12 @@ fn bind_socket(fd: RawFd) -> Result<(), Error> {
addr.sun_path[i] = byte as i8; addr.sun_path[i] = byte as i8;
} }
if bind(fd, &addr as *const _ as *const sockaddr, mem::size_of_val(&addr) as socklen_t) == -1 { if bind(
fd,
&addr as *const _ as *const sockaddr,
mem::size_of_val(&addr) as socklen_t,
) == -1
{
return Err(Error::last_os_error()); return Err(Error::last_os_error());
} }
} }
@ -68,7 +71,13 @@ fn accept_connection(fd: RawFd) -> Result<RawFd, Error> {
fn send_message(fd: RawFd, msg: &str) -> Result<(), Error> { fn send_message(fd: RawFd, msg: &str) -> Result<(), Error> {
unsafe { unsafe {
let msg_bytes = msg.as_bytes(); let msg_bytes = msg.as_bytes();
if send(fd, msg_bytes.as_ptr() as *const libc::c_void, msg_bytes.len(), 0) == -1 { if send(
fd,
msg_bytes.as_ptr() as *const libc::c_void,
msg_bytes.len(),
0,
) == -1
{
return Err(Error::last_os_error()); return Err(Error::last_os_error());
} }
} }
@ -78,7 +87,12 @@ fn send_message(fd: RawFd, msg: &str) -> Result<(), Error> {
fn receive_message(fd: RawFd) -> Result<String, Error> { fn receive_message(fd: RawFd) -> Result<String, Error> {
let mut buffer = [0; 1024]; let mut buffer = [0; 1024];
unsafe { unsafe {
let len = recv(fd, buffer.as_mut_ptr() as *mut libc::c_void, buffer.len(), 0); let len = recv(
fd,
buffer.as_mut_ptr() as *mut libc::c_void,
buffer.len(),
0,
);
if len == -1 { if len == -1 {
return Err(Error::last_os_error()); return Err(Error::last_os_error());
} }
@ -86,7 +100,7 @@ fn receive_message(fd: RawFd) -> Result<String, Error> {
} }
} }
pub fn test_seq_socket() ->Result<(), Error>{ pub fn test_seq_socket() -> Result<(), Error> {
// Create and bind the server socket // Create and bind the server socket
fs::remove_file(&SOCKET_PATH).ok(); fs::remove_file(&SOCKET_PATH).ok();
@ -121,7 +135,12 @@ pub fn test_seq_socket() ->Result<(), Error>{
for (i, &byte) in path_bytes.iter().enumerate() { for (i, &byte) in path_bytes.iter().enumerate() {
addr.sun_path[i] = byte as i8; addr.sun_path[i] = byte as i8;
} }
if connect(client_fd, &addr as *const _ as *const sockaddr, mem::size_of_val(&addr) as socklen_t) == -1 { if connect(
client_fd,
&addr as *const _ as *const sockaddr,
mem::size_of_val(&addr) as socklen_t,
) == -1
{
return Err(Error::last_os_error()); return Err(Error::last_os_error());
} }
} }
@ -140,9 +159,16 @@ pub fn test_seq_socket() ->Result<(), Error>{
return Err(Error::last_os_error()); return Err(Error::last_os_error());
} }
let sun_path = addrss.sun_path.clone(); let sun_path = addrss.sun_path.clone();
let peer_path:[u8;108] = sun_path.iter().map(|&x| x as u8).collect::<Vec<u8>>().try_into().unwrap(); let peer_path: [u8; 108] = sun_path
println!("Client: Connected to server at path: {}", String::from_utf8_lossy(&peer_path)); .iter()
.map(|&x| x as u8)
.collect::<Vec<u8>>()
.try_into()
.unwrap();
println!(
"Client: Connected to server at path: {}",
String::from_utf8_lossy(&peer_path)
);
} }
server_thread.join().expect("Server thread panicked"); server_thread.join().expect("Server thread panicked");

View File

@ -1,19 +1,19 @@
use std::io::Error;
use std::os::fd::RawFd;
use std::fs;
use libc::*; use libc::*;
use std::ffi::CString; use std::ffi::CString;
use std::fs;
use std::io::Error;
use std::mem; use std::mem;
use std::os::fd::RawFd;
const SOCKET_PATH: &str = "/test.stream"; const SOCKET_PATH: &str = "/test.stream";
const MSG1: &str = "Hello, unix stream socket from Client!"; const MSG1: &str = "Hello, unix stream socket from Client!";
const MSG2: &str = "Hello, unix stream socket from Server!"; const MSG2: &str = "Hello, unix stream socket from Server!";
fn create_stream_socket() -> Result<RawFd, Error>{ fn create_stream_socket() -> Result<RawFd, Error> {
unsafe { unsafe {
let fd = socket(AF_UNIX, SOCK_STREAM, 0); let fd = socket(AF_UNIX, SOCK_STREAM, 0);
if fd == -1 { if fd == -1 {
return Err(Error::last_os_error()) return Err(Error::last_os_error());
} }
Ok(fd) Ok(fd)
} }
@ -31,7 +31,12 @@ fn bind_socket(fd: RawFd) -> Result<(), Error> {
addr.sun_path[i] = byte as i8; addr.sun_path[i] = byte as i8;
} }
if bind(fd, &addr as *const _ as *const sockaddr, mem::size_of_val(&addr) as socklen_t) == -1 { if bind(
fd,
&addr as *const _ as *const sockaddr,
mem::size_of_val(&addr) as socklen_t,
) == -1
{
return Err(Error::last_os_error()); return Err(Error::last_os_error());
} }
} }
@ -61,7 +66,13 @@ fn accept_conn(fd: RawFd) -> Result<RawFd, Error> {
fn send_message(fd: RawFd, msg: &str) -> Result<(), Error> { fn send_message(fd: RawFd, msg: &str) -> Result<(), Error> {
unsafe { unsafe {
let msg_bytes = msg.as_bytes(); let msg_bytes = msg.as_bytes();
if send(fd, msg_bytes.as_ptr() as *const libc::c_void, msg_bytes.len(), 0)== -1 { if send(
fd,
msg_bytes.as_ptr() as *const libc::c_void,
msg_bytes.len(),
0,
) == -1
{
return Err(Error::last_os_error()); return Err(Error::last_os_error());
} }
} }
@ -71,7 +82,12 @@ fn send_message(fd: RawFd, msg: &str) -> Result<(), Error> {
fn recv_message(fd: RawFd) -> Result<String, Error> { fn recv_message(fd: RawFd) -> Result<String, Error> {
let mut buffer = [0; 1024]; let mut buffer = [0; 1024];
unsafe { unsafe {
let len = recv(fd, buffer.as_mut_ptr() as *mut libc::c_void, buffer.len(),0); let len = recv(
fd,
buffer.as_mut_ptr() as *mut libc::c_void,
buffer.len(),
0,
);
if len == -1 { if len == -1 {
return Err(Error::last_os_error()); return Err(Error::last_os_error());
} }
@ -95,7 +111,7 @@ fn test_stream() -> Result<(), Error> {
send_message(client_fd, MSG2).expect("Failed to send message"); send_message(client_fd, MSG2).expect("Failed to send message");
println!("Server send finish"); println!("Server send finish");
unsafe {close(client_fd)}; unsafe { close(client_fd) };
}); });
let client_fd = create_stream_socket()?; let client_fd = create_stream_socket()?;
@ -111,7 +127,12 @@ fn test_stream() -> Result<(), Error> {
addr.sun_path[i] = byte as i8; addr.sun_path[i] = byte as i8;
} }
if connect(client_fd, &addr as *const _ as *const sockaddr, mem::size_of_val(&addr) as socklen_t) == -1 { if connect(
client_fd,
&addr as *const _ as *const sockaddr,
mem::size_of_val(&addr) as socklen_t,
) == -1
{
return Err(Error::last_os_error()); return Err(Error::last_os_error());
} }
} }
@ -129,9 +150,16 @@ fn test_stream() -> Result<(), Error> {
return Err(Error::last_os_error()); return Err(Error::last_os_error());
} }
let sun_path = addrss.sun_path.clone(); let sun_path = addrss.sun_path.clone();
let peer_path:[u8;108] = sun_path.iter().map(|&x| x as u8).collect::<Vec<u8>>().try_into().unwrap(); let peer_path: [u8; 108] = sun_path
println!("Client: Connected to server at path: {}", String::from_utf8_lossy(&peer_path)); .iter()
.map(|&x| x as u8)
.collect::<Vec<u8>>()
.try_into()
.unwrap();
println!(
"Client: Connected to server at path: {}",
String::from_utf8_lossy(&peer_path)
);
} }
server_thread.join().expect("Server thread panicked"); server_thread.join().expect("Server thread panicked");
@ -139,7 +167,7 @@ fn test_stream() -> Result<(), Error> {
let recv_msg = recv_message(client_fd).expect("Failed to receive message from server"); let recv_msg = recv_message(client_fd).expect("Failed to receive message from server");
println!("Client Received message: {}", recv_msg); println!("Client Received message: {}", recv_msg);
unsafe {close(client_fd)}; unsafe { close(client_fd) };
fs::remove_file(&SOCKET_PATH).ok(); fs::remove_file(&SOCKET_PATH).ok();
Ok(()) Ok(())
@ -148,6 +176,6 @@ fn test_stream() -> Result<(), Error> {
fn main() { fn main() {
match test_stream() { match test_stream() {
Ok(_) => println!("test for unix stream success"), Ok(_) => println!("test for unix stream success"),
Err(_) => println!("test for unix stream failed") Err(_) => println!("test for unix stream failed"),
} }
} }