ready for merge in master (#964)

uevent should be format

Enum of smoltcp socket should be optimized.

need to add interface for routing subsys

actix is still not abled to run.

clean some casual added code to other places
This commit is contained in:
Samuel Dai 2024-10-10 17:53:39 +08:00 committed by GitHub
parent 79eda4bcf9
commit 40d9375b6b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
102 changed files with 10497 additions and 3303 deletions

View File

@ -1,7 +1,7 @@
use crate::{ use crate::{
arch::{ arch::{
ipc::signal::X86_64SignalArch, ipc::signal::X86_64SignalArch,
syscall::nr::{SYS_ARCH_PRCTL, SYS_RT_SIGRETURN}, syscall::nr::{SysCall, SYS_ARCH_PRCTL, SYS_RT_SIGRETURN},
CurrentIrqArch, CurrentIrqArch,
}, },
exception::InterruptArch, exception::InterruptArch,
@ -53,7 +53,7 @@ macro_rules! syscall_return {
if $show { if $show {
let pid = ProcessManager::current_pcb().pid(); let pid = ProcessManager::current_pcb().pid();
debug!("syscall return:pid={:?},ret= {:?}\n", pid, ret as isize); debug!("[SYS] [Pid: {:?}] [Retn: {:?}]", pid, ret as i64);
} }
unsafe { unsafe {
@ -63,6 +63,24 @@ macro_rules! syscall_return {
}}; }};
} }
macro_rules! normal_syscall_return {
($val:expr, $regs:expr, $show:expr) => {{
let ret = $val;
if $show {
let pid = ProcessManager::current_pcb().pid();
debug!("[SYS] [Pid: {:?}] [Retn: {:?}]", pid, ret);
}
$regs.rax = ret.unwrap_or_else(|e| e.to_posix_errno() as usize) as u64;
unsafe {
CurrentIrqArch::interrupt_disable();
}
return;
}};
}
#[no_mangle] #[no_mangle]
pub extern "sysv64" fn syscall_handler(frame: &mut TrapFrame) { pub extern "sysv64" fn syscall_handler(frame: &mut TrapFrame) {
let syscall_num = frame.rax as usize; let syscall_num = frame.rax as usize;
@ -87,15 +105,38 @@ pub extern "sysv64" fn syscall_handler(frame: &mut TrapFrame) {
]; ];
mfence(); mfence();
let pid = ProcessManager::current_pcb().pid(); let pid = ProcessManager::current_pcb().pid();
let show = false; let mut show = (syscall_num != SYS_SCHED) && (pid.data() >= 7);
// let show = if syscall_num != SYS_SCHED && pid.data() >= 7 { // let mut show = true;
// true
// } else { let to_print = SysCall::try_from(syscall_num);
// false if let Ok(to_print) = to_print {
// }; use SysCall::*;
match to_print {
SYS_ACCEPT | SYS_ACCEPT4 | SYS_BIND | SYS_CONNECT | SYS_SHUTDOWN | SYS_LISTEN => {
show &= true;
}
SYS_RECVFROM | SYS_SENDTO | SYS_SENDMSG | SYS_RECVMSG => {
show &= true;
}
SYS_SOCKET | SYS_GETSOCKNAME | SYS_GETPEERNAME | SYS_SOCKETPAIR | SYS_SETSOCKOPT
| SYS_GETSOCKOPT => {
show &= true;
}
SYS_OPEN | SYS_OPENAT | SYS_CREAT | SYS_CLOSE => {
show &= true;
}
SYS_READ | SYS_WRITE | SYS_READV | SYS_WRITEV | SYS_PREAD64 | SYS_PWRITE64
| SYS_PREADV | SYS_PWRITEV | SYS_PREADV2 => {
show &= true;
}
_ => {
show &= false;
}
}
if show { if show {
debug!("syscall: pid: {:?}, num={:?}\n", pid, syscall_num); debug!("[SYS] [Pid: {:?}] [Call: {:?}]", pid, to_print);
}
} }
// Arch specific syscall // Arch specific syscall
@ -108,21 +149,11 @@ pub extern "sysv64" fn syscall_handler(frame: &mut TrapFrame) {
); );
} }
SYS_ARCH_PRCTL => { SYS_ARCH_PRCTL => {
syscall_return!( normal_syscall_return!(Syscall::arch_prctl(args[0], args[1]), frame, show);
Syscall::arch_prctl(args[0], args[1])
.unwrap_or_else(|e| e.to_posix_errno() as usize),
frame,
show
);
} }
_ => {} _ => {}
} }
syscall_return!( normal_syscall_return!(Syscall::handle(syscall_num, &args, frame), frame, show);
Syscall::handle(syscall_num, &args, frame).unwrap_or_else(|e| e.to_posix_errno() as usize)
as u64,
frame,
show
);
} }
/// 系统调用初始化 /// 系统调用初始化

View File

@ -355,3 +355,381 @@ pub const SYS_WAIT4: usize = 61;
pub const SYS_WAITID: usize = 247; pub const SYS_WAITID: usize = 247;
pub const SYS_WRITE: usize = 1; pub const SYS_WRITE: usize = 1;
pub const SYS_WRITEV: usize = 20; pub const SYS_WRITEV: usize = 20;
use num_traits::{FromPrimitive, ToPrimitive};
use system_error::SystemError;
#[allow(non_camel_case_types)]
#[derive(Debug, Clone, Copy, FromPrimitive, ToPrimitive, PartialEq, Eq)]
pub enum SysCall {
SYS__SYSCTL = 156,
SYS_ACCEPT = 43,
SYS_ACCEPT4 = 288,
SYS_ACCESS = 21,
SYS_ACCT = 163,
SYS_ADD_KEY = 248,
SYS_ADJTIMEX = 159,
SYS_AFS_SYSCALL = 183,
SYS_ALARM = 37,
SYS_ARCH_PRCTL = 158,
SYS_BIND = 49,
SYS_BPF = 321,
SYS_BRK = 12,
SYS_CAPGET = 125,
SYS_CAPSET = 126,
SYS_CHDIR = 80,
SYS_CHMOD = 90,
SYS_CHOWN = 92,
SYS_CHROOT = 161,
SYS_CLOCK_ADJTIME = 305,
SYS_CLOCK_GETRES = 229,
SYS_CLOCK_GETTIME = 228,
SYS_CLOCK_NANOSLEEP = 230,
SYS_CLOCK_SETTIME = 227,
SYS_CLONE = 56,
SYS_CLONE3 = 435,
SYS_CLOSE = 3,
SYS_CLOSE_RANGE = 436,
SYS_CONNECT = 42,
SYS_COPY_FILE_RANGE = 326,
SYS_CREAT = 85,
SYS_CREATE_MODULE = 174,
SYS_DELETE_MODULE = 176,
SYS_DUP = 32,
SYS_DUP2 = 33,
SYS_DUP3 = 292,
SYS_EPOLL_CREATE = 213,
SYS_EPOLL_CREATE1 = 291,
SYS_EPOLL_CTL = 233,
SYS_EPOLL_CTL_OLD = 214,
SYS_EPOLL_PWAIT = 281,
SYS_EPOLL_PWAIT2 = 441,
SYS_EPOLL_WAIT = 232,
SYS_EPOLL_WAIT_OLD = 215,
SYS_EVENTFD = 284,
SYS_EVENTFD2 = 290,
SYS_EXECVE = 59,
SYS_EXECVEAT = 322,
SYS_EXIT = 60,
SYS_EXIT_GROUP = 231,
SYS_FACCESSAT = 269,
SYS_FACCESSAT2 = 439,
SYS_FADVISE64 = 221,
SYS_FALLOCATE = 285,
SYS_FANOTIFY_INIT = 300,
SYS_FANOTIFY_MARK = 301,
SYS_FCHDIR = 81,
SYS_FCHMOD = 91,
SYS_FCHMODAT = 268,
SYS_FCHOWN = 93,
SYS_FCHOWNAT = 260,
SYS_FCNTL = 72,
SYS_FDATASYNC = 75,
SYS_FGETXATTR = 193,
SYS_FINIT_MODULE = 313,
SYS_FLISTXATTR = 196,
SYS_FLOCK = 73,
SYS_FORK = 57,
SYS_FREMOVEXATTR = 199,
SYS_FSCONFIG = 431,
SYS_FSETXATTR = 190,
SYS_FSMOUNT = 432,
SYS_FSOPEN = 430,
SYS_FSPICK = 433,
SYS_FSTAT = 5,
SYS_FSTATFS = 138,
SYS_FSYNC = 74,
SYS_FTRUNCATE = 77,
SYS_FUTEX = 202,
SYS_FUTIMESAT = 261,
SYS_GET_KERNEL_SYMS = 177,
SYS_GET_MEMPOLICY = 239,
SYS_GET_ROBUST_LIST = 274,
SYS_GET_THREAD_AREA = 211,
SYS_GETCPU = 309,
SYS_GETCWD = 79,
SYS_GETDENTS = 78,
SYS_GETDENTS64 = 217,
SYS_GETEGID = 108,
SYS_GETEUID = 107,
SYS_GETGID = 104,
SYS_GETGROUPS = 115,
SYS_GETITIMER = 36,
SYS_GETPEERNAME = 52,
SYS_GETPGID = 121,
SYS_GETPGRP = 111,
SYS_GETPID = 39,
SYS_GETPMSG = 181,
SYS_GETPPID = 110,
SYS_GETPRIORITY = 140,
SYS_GETRANDOM = 318,
SYS_GETRESGID = 120,
SYS_GETRESUID = 118,
SYS_GETRLIMIT = 97,
SYS_GETRUSAGE = 98,
SYS_GETSID = 124,
SYS_GETSOCKNAME = 51,
SYS_GETSOCKOPT = 55,
SYS_GETTID = 186,
SYS_GETTIMEOFDAY = 96,
SYS_GETUID = 102,
SYS_GETXATTR = 191,
SYS_INIT_MODULE = 175,
SYS_INOTIFY_ADD_WATCH = 254,
SYS_INOTIFY_INIT = 253,
SYS_INOTIFY_INIT1 = 294,
SYS_INOTIFY_RM_WATCH = 255,
SYS_IO_CANCEL = 210,
SYS_IO_DESTROY = 207,
SYS_IO_GETEVENTS = 208,
SYS_IO_PGETEVENTS = 333,
SYS_IO_SETUP = 206,
SYS_IO_SUBMIT = 209,
SYS_IO_URING_ENTER = 426,
SYS_IO_URING_REGISTER = 427,
SYS_IO_URING_SETUP = 425,
SYS_IOCTL = 16,
SYS_IOPERM = 173,
SYS_IOPL = 172,
SYS_IOPRIO_GET = 252,
SYS_IOPRIO_SET = 251,
SYS_KCMP = 312,
SYS_KEXEC_FILE_LOAD = 320,
SYS_KEXEC_LOAD = 246,
SYS_KEYCTL = 250,
SYS_KILL = 62,
SYS_LCHOWN = 94,
SYS_LGETXATTR = 192,
SYS_LINK = 86,
SYS_LINKAT = 265,
SYS_LISTEN = 50,
SYS_LISTXATTR = 194,
SYS_LLISTXATTR = 195,
SYS_LOOKUP_DCOOKIE = 212,
SYS_LREMOVEXATTR = 198,
SYS_LSEEK = 8,
SYS_LSETXATTR = 189,
SYS_LSTAT = 6,
SYS_MADVISE = 28,
SYS_MBIND = 237,
SYS_MEMBARRIER = 324,
SYS_MEMFD_CREATE = 319,
SYS_MIGRATE_PAGES = 256,
SYS_MINCORE = 27,
SYS_MKDIR = 83,
SYS_MKDIRAT = 258,
SYS_MKNOD = 133,
SYS_MKNODAT = 259,
SYS_MLOCK = 149,
SYS_MLOCK2 = 325,
SYS_MLOCKALL = 151,
SYS_MMAP = 9,
SYS_MODIFY_LDT = 154,
SYS_MOUNT = 165,
SYS_MOUNT_SETATTR = 442,
SYS_MOVE_MOUNT = 429,
SYS_MOVE_PAGES = 279,
SYS_MPROTECT = 10,
SYS_MQ_GETSETATTR = 245,
SYS_MQ_NOTIFY = 244,
SYS_MQ_OPEN = 240,
SYS_MQ_TIMEDRECEIVE = 243,
SYS_MQ_TIMEDSEND = 242,
SYS_MQ_UNLINK = 241,
SYS_MREMAP = 25,
SYS_MSGCTL = 71,
SYS_MSGGET = 68,
SYS_MSGRCV = 70,
SYS_MSGSND = 69,
SYS_MSYNC = 26,
SYS_MUNLOCK = 150,
SYS_MUNLOCKALL = 152,
SYS_MUNMAP = 11,
SYS_NAME_TO_HANDLE_AT = 303,
SYS_NANOSLEEP = 35,
SYS_NEWFSTATAT = 262,
SYS_NFSSERVCTL = 180,
SYS_OPEN = 2,
SYS_OPEN_BY_HANDLE_AT = 304,
SYS_OPEN_TREE = 428,
SYS_OPENAT = 257,
SYS_OPENAT2 = 437,
SYS_PAUSE = 34,
SYS_PERF_EVENT_OPEN = 298,
SYS_PERSONALITY = 135,
SYS_PIDFD_GETFD = 438,
SYS_PIDFD_OPEN = 434,
SYS_PIDFD_SEND_SIGNAL = 424,
SYS_PIPE = 22,
SYS_PIPE2 = 293,
SYS_PIVOT_ROOT = 155,
SYS_PKEY_ALLOC = 330,
SYS_PKEY_FREE = 331,
SYS_PKEY_MPROTECT = 329,
SYS_POLL = 7,
SYS_PPOLL = 271,
SYS_PRCTL = 157,
SYS_PREAD64 = 17,
SYS_PREADV = 295,
SYS_PREADV2 = 327,
SYS_PRLIMIT64 = 302,
SYS_PROCESS_MADVISE = 440,
SYS_PROCESS_VM_READV = 310,
SYS_PROCESS_VM_WRITEV = 311,
SYS_PSELECT6 = 270,
SYS_PTRACE = 101,
SYS_PUTPMSG = 182,
SYS_PWRITE64 = 18,
SYS_PWRITEV = 296,
SYS_PWRITEV2 = 328,
SYS_QUERY_MODULE = 178,
SYS_QUOTACTL = 179,
SYS_READ = 0,
SYS_READAHEAD = 187,
SYS_READLINK = 89,
SYS_READLINKAT = 267,
SYS_READV = 19,
SYS_REBOOT = 169,
SYS_RECVFROM = 45,
SYS_RECVMMSG = 299,
SYS_RECVMSG = 47,
SYS_REMAP_FILE_PAGES = 216,
SYS_REMOVEXATTR = 197,
SYS_RENAME = 82,
SYS_RENAMEAT = 264,
SYS_RENAMEAT2 = 316,
SYS_REQUEST_KEY = 249,
SYS_RESTART_SYSCALL = 219,
SYS_RMDIR = 84,
SYS_RSEQ = 334,
SYS_RT_SIGACTION = 13,
SYS_RT_SIGPENDING = 127,
SYS_RT_SIGPROCMASK = 14,
SYS_RT_SIGQUEUEINFO = 129,
SYS_RT_SIGRETURN = 15,
SYS_RT_SIGSUSPEND = 130,
SYS_RT_SIGTIMEDWAIT = 128,
SYS_RT_TGSIGQUEUEINFO = 297,
SYS_SCHED_GET_PRIORITY_MAX = 146,
SYS_SCHED_GET_PRIORITY_MIN = 147,
SYS_SCHED_GETAFFINITY = 204,
SYS_SCHED_GETATTR = 315,
SYS_SCHED_GETPARAM = 143,
SYS_SCHED_GETSCHEDULER = 145,
SYS_SCHED_RR_GET_INTERVAL = 148,
SYS_SCHED_SETAFFINITY = 203,
SYS_SCHED_SETATTR = 314,
SYS_SCHED_SETPARAM = 142,
SYS_SCHED_SETSCHEDULER = 144,
SYS_SCHED_YIELD = 24,
SYS_SECCOMP = 317,
SYS_SECURITY = 185,
SYS_SELECT = 23,
SYS_SEMCTL = 66,
SYS_SEMGET = 64,
SYS_SEMOP = 65,
SYS_SEMTIMEDOP = 220,
SYS_SENDFILE = 40,
SYS_SENDMMSG = 307,
SYS_SENDMSG = 46,
SYS_SENDTO = 44,
SYS_SET_MEMPOLICY = 238,
SYS_SET_ROBUST_LIST = 273,
SYS_SET_THREAD_AREA = 205,
SYS_SET_TID_ADDRESS = 218,
SYS_SETDOMAINNAME = 171,
SYS_SETFSGID = 123,
SYS_SETFSUID = 122,
SYS_SETGID = 106,
SYS_SETGROUPS = 116,
SYS_SETHOSTNAME = 170,
SYS_SETITIMER = 38,
SYS_SETNS = 308,
SYS_SETPGID = 109,
SYS_SETPRIORITY = 141,
SYS_SETREGID = 114,
SYS_SETRESGID = 119,
SYS_SETRESUID = 117,
SYS_SETREUID = 113,
SYS_SETRLIMIT = 160,
SYS_SETSID = 112,
SYS_SETSOCKOPT = 54,
SYS_SETTIMEOFDAY = 164,
SYS_SETUID = 105,
SYS_SETXATTR = 188,
SYS_SHMAT = 30,
SYS_SHMCTL = 31,
SYS_SHMDT = 67,
SYS_SHMGET = 29,
SYS_SHUTDOWN = 48,
SYS_SIGALTSTACK = 131,
SYS_SIGNALFD = 282,
SYS_SIGNALFD4 = 289,
SYS_SOCKET = 41,
SYS_SOCKETPAIR = 53,
SYS_SPLICE = 275,
SYS_STAT = 4,
SYS_STATFS = 137,
SYS_STATX = 332,
SYS_SWAPOFF = 168,
SYS_SWAPON = 167,
SYS_SYMLINK = 88,
SYS_SYMLINKAT = 266,
SYS_SYNC = 162,
SYS_SYNC_FILE_RANGE = 277,
SYS_SYNCFS = 306,
SYS_SYSFS = 139,
SYS_SYSINFO = 99,
SYS_SYSLOG = 103,
SYS_TEE = 276,
SYS_TGKILL = 234,
SYS_TIME = 201,
SYS_TIMER_CREATE = 222,
SYS_TIMER_DELETE = 226,
SYS_TIMER_GETOVERRUN = 225,
SYS_TIMER_GETTIME = 224,
SYS_TIMER_SETTIME = 223,
SYS_TIMERFD_CREATE = 283,
SYS_TIMERFD_GETTIME = 287,
SYS_TIMERFD_SETTIME = 286,
SYS_TIMES = 100,
SYS_TKILL = 200,
SYS_TRUNCATE = 76,
SYS_TUXCALL = 184,
SYS_UMASK = 95,
SYS_UMOUNT2 = 166,
SYS_UNAME = 63,
SYS_UNLINK = 87,
SYS_UNLINKAT = 263,
SYS_UNSHARE = 272,
SYS_USELIB = 134,
SYS_USERFAULTFD = 323,
SYS_USTAT = 136,
SYS_UTIME = 132,
SYS_UTIMENSAT = 280,
SYS_UTIMES = 235,
SYS_VFORK = 58,
SYS_VHANGUP = 153,
SYS_VMSPLICE = 278,
SYS_VSERVER = 236,
SYS_WAIT4 = 61,
SYS_WAITID = 247,
SYS_WRITE = 1,
SYS_WRITEV = 20,
}
impl TryFrom<usize> for SysCall {
type Error = SystemError;
fn try_from(value: usize) -> Result<Self, Self::Error> {
match <Self as FromPrimitive>::from_usize(value) {
Some(p) => Ok(p),
None => Err(SystemError::EINVAL),
}
}
}
impl From<SysCall> for usize {
fn from(value: SysCall) -> Self {
<SysCall as ToPrimitive>::to_usize(&value).unwrap()
}
}

View File

@ -22,7 +22,8 @@ use log::{debug, error, warn};
use system_error::SystemError; use system_error::SystemError;
use super::{acpi_kset, AcpiManager}; use super::{acpi_kset, AcpiManager};
use crate::driver::base::uevent::kobject_uevent::kobject_uevent;
use crate::driver::base::uevent::KobjectAction;
static mut __HOTPLUG_KSET_INSTANCE: Option<Arc<KSet>> = None; static mut __HOTPLUG_KSET_INSTANCE: Option<Arc<KSet>> = None;
static mut __ACPI_TABLES_KSET_INSTANCE: Option<Arc<KSet>> = None; static mut __ACPI_TABLES_KSET_INSTANCE: Option<Arc<KSet>> = None;
static mut __ACPI_TABLES_DATA_KSET_INSTANCE: Option<Arc<KSet>> = None; static mut __ACPI_TABLES_DATA_KSET_INSTANCE: Option<Arc<KSet>> = None;
@ -115,7 +116,27 @@ impl AcpiManager {
acpi_table_attr_list().write().push(attr); acpi_table_attr_list().write().push(attr);
self.acpi_table_data_init(&header)?; self.acpi_table_data_init(&header)?;
} }
// TODO:UEVENT
unsafe {
let _ = kobject_uevent(
acpi_tables_kset.clone() as Arc<dyn KObject>,
KobjectAction::KOBJADD,
);
let _ = kobject_uevent(
__ACPI_TABLES_DATA_KSET_INSTANCE
.as_ref()
.map(|kset| kset.clone() as Arc<dyn KObject>)
.unwrap(),
KobjectAction::KOBJADD,
);
let _ = kobject_uevent(
__ACPI_TABLES_DYNAMIC_KSET_INSTANCE
.as_ref()
.map(|kset| kset.clone() as Arc<dyn KObject>)
.unwrap(),
KobjectAction::KOBJADD,
);
}
return Ok(()); return Ok(());
} }

View File

@ -571,6 +571,7 @@ impl DriverManager {
} }
// todo: 发送kobj bind的uevent // todo: 发送kobj bind的uevent
// kobject_uevent();
} }
fn driver_is_bound(&self, device: &Arc<dyn Device>) -> bool { fn driver_is_bound(&self, device: &Arc<dyn Device>) -> bool {

View File

@ -2,6 +2,8 @@ use super::{
bus::{bus_manager, Bus}, bus::{bus_manager, Bus},
Device, DeviceMatchName, DeviceMatcher, IdTable, Device, DeviceMatchName, DeviceMatcher, IdTable,
}; };
use crate::driver::base::uevent::kobject_uevent::kobject_uevent;
use crate::driver::base::uevent::KobjectAction;
use crate::{ use crate::{
driver::base::{ driver::base::{
device::{bus::BusNotifyEvent, dd::DeviceAttrCoredump, device_manager}, device::{bus::BusNotifyEvent, dd::DeviceAttrCoredump, device_manager},
@ -17,7 +19,6 @@ use alloc::{
use core::fmt::Debug; use core::fmt::Debug;
use log::error; use log::error;
use system_error::SystemError; use system_error::SystemError;
/// @brief: Driver error /// @brief: Driver error
#[allow(dead_code)] #[allow(dead_code)]
#[derive(Debug, PartialEq, Eq, Clone, Copy)] #[derive(Debug, PartialEq, Eq, Clone, Copy)]
@ -218,7 +219,9 @@ impl DriverManager {
bus_manager().remove_driver(&driver); bus_manager().remove_driver(&driver);
})?; })?;
// todo: 发送uevent // todo: 发送uevent类型问题
let _ = kobject_uevent(driver.clone() as Arc<dyn KObject>, KobjectAction::KOBJADD);
// deferred_probe_extend_timeout();
return Ok(()); return Ok(());
} }

View File

@ -506,7 +506,7 @@ impl DeviceManager {
} }
let kobject_parent = self.get_device_parent(&device, deivce_parent)?; let kobject_parent = self.get_device_parent(&device, deivce_parent)?;
if let Some(ref kobj) = kobject_parent { if let Some(ref kobj) = kobject_parent {
log::debug!("kobject parent: {:?}", kobj.name()); log::info!("kobject parent: {:?}", kobj.name());
} }
if let Some(kobject_parent) = kobject_parent { if let Some(kobject_parent) = kobject_parent {
// debug!( // debug!(
@ -547,7 +547,7 @@ impl DeviceManager {
} }
// todo: 发送uevent: KOBJ_ADD // todo: 发送uevent: KOBJ_ADD
// kobject_uevent();
// probe drivers for a new device // probe drivers for a new device
bus_probe_device(&device); bus_probe_device(&device);

View File

@ -1,6 +1,7 @@
use core::{any::Any, fmt::Debug, hash::Hash, ops::Deref}; use core::{any::Any, fmt::Debug, hash::Hash, ops::Deref};
use alloc::{ use alloc::{
boxed::Box,
string::String, string::String,
sync::{Arc, Weak}, sync::{Arc, Weak},
}; };
@ -21,7 +22,7 @@ use crate::{
use system_error::SystemError; use system_error::SystemError;
use super::kset::KSet; use super::{kset::KSet, uevent::kobject_uevent};
pub trait KObject: Any + Send + Sync + Debug + CastFromSync { pub trait KObject: Any + Send + Sync + Debug + CastFromSync {
fn as_any_ref(&self) -> &dyn core::any::Any; fn as_any_ref(&self) -> &dyn core::any::Any;
@ -103,10 +104,9 @@ bitflags! {
const ADD_UEVENT_SENT = 1 << 1; const ADD_UEVENT_SENT = 1 << 1;
const REMOVE_UEVENT_SENT = 1 << 2; const REMOVE_UEVENT_SENT = 1 << 2;
const INITIALIZED = 1 << 3; const INITIALIZED = 1 << 3;
const UEVENT_SUPPRESS = 1 << 4;
} }
} }
#[derive(Debug)] #[derive(Debug)]
pub struct LockedKObjectState(RwLock<KObjectState>); pub struct LockedKObjectState(RwLock<KObjectState>);
@ -251,7 +251,7 @@ impl KObjectManager {
} }
// todo: 发送uevent: KOBJ_REMOVE // todo: 发送uevent: KOBJ_REMOVE
// kobject_uevent();
sysfs_instance().remove_dir(&kobj); sysfs_instance().remove_dir(&kobj);
kobj.update_kobj_state(None, Some(KObjectState::IN_SYSFS)); kobj.update_kobj_state(None, Some(KObjectState::IN_SYSFS));
let kset = kobj.kset(); let kset = kobj.kset();
@ -260,6 +260,105 @@ impl KObjectManager {
} }
kobj.set_parent(None); kobj.set_parent(None);
} }
fn get_kobj_path_length(kobj: &Arc<dyn KObject>) -> usize {
log::info!("get_kobj_path_length() kobj:{:?}", kobj.name());
let mut length = 1;
let mut parent = kobj.parent().unwrap().upgrade().unwrap();
/* walk up the ancestors until we hit the one pointing to the
* root.
* Add 1 to strlen for leading '/' of each level.
*/
let mut length = 0; // 确保 length 被正确初始化
let mut iteration_count = 0; // 用于记录迭代次数
const MAX_ITERATIONS: usize = 10; // 最大迭代次数
loop {
log::info!(
"Iteration {}: parent.name():{:?}",
iteration_count,
parent.name()
);
length += parent.name().len() + 1;
if let Some(weak_parent) = parent.parent() {
if let Some(upgraded_parent) = weak_parent.upgrade() {
parent = upgraded_parent;
} else {
log::error!("Failed to upgrade weak reference to parent");
break;
}
} else {
log::error!("Parent has no parent");
break;
}
iteration_count += 1;
if iteration_count >= MAX_ITERATIONS {
log::error!("Reached maximum iteration count, breaking to avoid infinite loop");
break;
}
}
return length;
}
/*
static void fill_kobj_path(struct kobject *kobj, char *path, int length)
{
struct kobject *parent;
--length;
for (parent = kobj; parent; parent = parent->parent) {
int cur = strlen(kobject_name(parent));
/* back up enough to print this name with '/' */
length -= cur;
memcpy(path + length, kobject_name(parent), cur);
*(path + --length) = '/';
}
pr_debug("kobject: '%s' (%p): %s: path = '%s'\n", kobject_name(kobj),
kobj, __func__, path);
}
*/
fn fill_kobj_path(kobj: &Arc<dyn KObject>, path: &mut [u8], length: usize) {
let mut parent = kobj.parent().unwrap().upgrade().unwrap();
let mut length = length;
length -= 1;
loop {
log::info!("fill_kobj_path parent.name():{:?}", parent.name());
let cur = parent.name().len();
if length < cur + 1 {
// 如果剩余长度不足以容纳当前名称和分隔符,则退出
break;
}
length -= cur;
let parent_name = parent.name();
let name = parent_name.as_bytes();
for i in 0..cur {
path[length + i] = name[i];
}
length -= 1;
path[length] = '/' as u8;
if let Some(weak_parent) = parent.parent() {
if let Some(upgraded_parent) = weak_parent.upgrade() {
parent = upgraded_parent;
} else {
break;
}
} else {
break;
}
}
}
// TODO: 实现kobject_get_path
// https://code.dragonos.org.cn/xref/linux-6.1.9/lib/kobject.c#139
pub fn kobject_get_path(kobj: &Arc<dyn KObject>) -> String {
log::debug!("kobject_get_path() kobj:{:?}", kobj.name());
let length = Self::get_kobj_path_length(kobj);
let path: &mut [u8] = &mut vec![0; length];
Self::fill_kobj_path(kobj, path, length);
let path_string = String::from_utf8(path.to_vec()).unwrap();
return path_string;
}
} }
/// 动态创建的kobject对象的ktype /// 动态创建的kobject对象的ktype

View File

@ -6,8 +6,11 @@ use alloc::{
use core::hash::Hash; use core::hash::Hash;
use super::kobject::{ use super::{
kobject::{
DynamicKObjKType, KObjType, KObject, KObjectManager, KObjectState, LockedKObjectState, DynamicKObjKType, KObjType, KObject, KObjectManager, KObjectState, LockedKObjectState,
},
uevent::KobjUeventEnv,
}; };
use crate::{ use crate::{
filesystem::kernfs::KernFSInode, filesystem::kernfs::KernFSInode,
@ -26,6 +29,8 @@ pub struct KSet {
/// 与父节点有关的一些信息 /// 与父节点有关的一些信息
parent_data: RwLock<KSetParentData>, parent_data: RwLock<KSetParentData>,
self_ref: Weak<KSet>, self_ref: Weak<KSet>,
/// kset用于发送uevent的操作函数集。kset能够发送它所包含的各种子kobj、孙kobj的消息即kobj或其父辈、爷爷辈都可以发送消息优先父辈然后是爷爷辈以此类推
pub uevent_ops: Option<Arc<dyn KSetUeventOps>>,
} }
impl Hash for KSet { impl Hash for KSet {
@ -51,6 +56,7 @@ impl KSet {
kobj_state: LockedKObjectState::new(None), kobj_state: LockedKObjectState::new(None),
parent_data: RwLock::new(KSetParentData::new(None, None)), parent_data: RwLock::new(KSetParentData::new(None, None)),
self_ref: Weak::default(), self_ref: Weak::default(),
uevent_ops: Some(Arc::new(KSetUeventOpsDefault)),
}; };
let r = Arc::new(r); let r = Arc::new(r);
@ -91,6 +97,7 @@ impl KSet {
pub fn register(&self, join_kset: Option<Arc<KSet>>) -> Result<(), SystemError> { pub fn register(&self, join_kset: Option<Arc<KSet>>) -> Result<(), SystemError> {
return KObjectManager::add_kobj(self.self_ref.upgrade().unwrap(), join_kset); return KObjectManager::add_kobj(self.self_ref.upgrade().unwrap(), join_kset);
// todo: 引入uevent之后发送uevent // todo: 引入uevent之后发送uevent
// kobject_uevent();
} }
/// 注销一个kset /// 注销一个kset
@ -232,3 +239,26 @@ impl InnerKSet {
} }
} }
} }
//https://code.dragonos.org.cn/xref/linux-6.1.9/include/linux/kobject.h#137
use core::fmt::Debug;
pub trait KSetUeventOps: Debug + Send + Sync {
fn filter(&self) -> Option<i32>;
fn uevent_name(&self) -> String;
fn uevent(&self, env: &KobjUeventEnv) -> i32;
}
#[derive(Debug)]
pub struct KSetUeventOpsDefault;
impl KSetUeventOps for KSetUeventOpsDefault {
fn filter(&self) -> Option<i32> {
Some(0)
}
fn uevent_name(&self) -> String {
String::new()
}
fn uevent(&self, env: &KobjUeventEnv) -> i32 {
0
}
}

View File

@ -12,3 +12,4 @@ pub mod map;
pub mod platform; pub mod platform;
pub mod subsys; pub mod subsys;
pub mod swnode; pub mod swnode;
pub mod uevent;

View File

@ -0,0 +1,504 @@
// https://code.dragonos.org.cn/xref/linux-6.1.9/lib/kobject_uevent.c
use super::KObject;
use super::KobjUeventEnv;
use super::KobjectAction;
use super::{UEVENT_BUFFER_SIZE, UEVENT_NUM_ENVP};
use crate::driver::base::kobject::{KObjectManager, KObjectState};
use crate::init::initcall::INITCALL_POSTCORE;
use crate::libs::mutex::Mutex;
use crate::libs::rwlock::RwLock;
use crate::net::socket::netlink::af_netlink::netlink_has_listeners;
use crate::net::socket::netlink::af_netlink::NetlinkSocket;
use crate::net::socket::netlink::af_netlink::{netlink_broadcast, NetlinkSock};
use crate::net::socket::netlink::netlink::{
netlink_kernel_create, NetlinkKernelCfg, NETLINK_KOBJECT_UEVENT, NL_CFG_F_NONROOT_RECV,
};
use crate::net::socket::netlink::skbuff::SkBuff;
use alloc::boxed::Box;
use alloc::collections::LinkedList;
use alloc::string::{String, ToString};
use alloc::sync::Arc;
use alloc::vec::Vec;
use core::fmt::Write;
use num::Zero;
use system_error::SystemError;
use unified_init::macros::unified_init;
// 全局变量
pub static UEVENT_SEQNUM: u64 = 0;
// #ifdef CONFIG_UEVENT_HELPER
// char uevent_helper[UEVENT_HELPER_PATH_LEN] = CONFIG_UEVENT_HELPER_PATH;
// #endif
struct UeventSock {
inner: NetlinkSock,
}
impl UeventSock {
pub fn new(inner: NetlinkSock) -> Self {
UeventSock { inner }
}
}
// 用于存储所有用于发送 uevent 消息的 netlink sockets。这些 sockets 用于在内核和用户空间之间传递设备事件通知。
// 每当需要发送 uevent 消息时,内核会遍历这个链表,并通过其中的每一个 socket 发送消息。
// 使用 Mutex 保护全局链表
lazy_static::lazy_static! {
static ref UEVENT_SOCK_LIST: Mutex<LinkedList<UeventSock>> = Mutex::new(LinkedList::new());
}
// 回调函数,当接收到 uevent 消息时调用
fn uevent_net_rcv() {
// netlink_rcv_skb(skb, &uevent_net_rcv_skb);
}
/// 内核初始化的时候,在设备初始化之前执行
#[unified_init(INITCALL_POSTCORE)]
fn kobejct_uevent_init() -> Result<(), SystemError> {
// todo: net namespace
return uevent_net_init();
}
// TODO等net namespace实现后添加 net 参数和相关操作
// 内核启动的时候,即使没有进行网络命名空间的隔离也需要调用这个函数
// 支持 net namespace 之后需要在每个 net namespace 初始化的时候调用这个函数
/// 为每一个 net namespace 初始化 uevent
fn uevent_net_init() -> Result<(), SystemError> {
let cfg = NetlinkKernelCfg {
groups: 1,
flags: NL_CFG_F_NONROOT_RECV,
..Default::default()
};
// 创建一个内核 netlink socket
let ue_sk = UeventSock::new(netlink_kernel_create(NETLINK_KOBJECT_UEVENT, Some(cfg)).unwrap());
// todo: net namespace
// net.uevent_sock = ue_sk;
// 每个 net namespace 向链表中添加一个新的 uevent socket
UEVENT_SOCK_LIST.lock().push_back(ue_sk);
log::info!("uevent_net_init finish");
return Ok(());
}
// 系统关闭时清理
fn uevent_net_exit() {
// 清理链表
UEVENT_SOCK_LIST.lock().clear();
}
// /* This lock protects uevent_seqnum and uevent_sock_list */
// static DEFINE_MUTEX(uevent_sock_mutex);
// to be adjust
pub const BUFFERSIZE: usize = 666;
/*
kobject_uevent_envenvp为环境变量action的uevent
kobj本身或者其parent是否从属于某个kset2kobject没有加入ksetuevent的
kobj->uevent_suppress是否设置uevent上报并返回3Kobject的uevent_suppress标志Kobject的uevent的上报
kset有kset->filter函数43.2filter接口的说明kset可以通过filter接口过滤不希望上报的event
kset是否有合法的名称subsystemuevent
bufferenv指针中Kobject在sysfs中路径信息sysfs中访问它
add_uevent_var接口Actionsubsystem等信息env指针中
envp不空add_uevent_var接口env指针中
kset存在kset->uevent接口kset统一的环境变量到env指针
ACTION的类型kobj->state_add_uevent_sent和kobj->state_remove_uevent_sent变量
add_uevent_var接口"SEQNUM=%llu”的序列号
"CONFIG_NET”则使用netlink发送该uevent
uevent_helpersubsystem以及添加了标准环境变量HOME=/PATH=/sbin:/bin:/usr/sbin:/usr/binenv指针为参数kmod模块提供的call_usermodehelper函数uevent
uevent_helper的内容是由内核配置项CONFIG_UEVENT_HELPER_PATH(./drivers/base/Kconfig)(lib/kobject_uevent.c, line 32)uevent"/sbin/hotplug”。
call_usermodehelper的作用fork一个进程uevent为参数uevent_helper
kobject_ueventkobject_uevent_env功能一样
add_uevent_varprintfprintk等copy到env指针中
kobject_action_typeenum kobject_action类型的Action
*/
//kobject_uevent->kobject_uevent_env
pub fn kobject_uevent(kobj: Arc<dyn KObject>, action: KobjectAction) -> Result<(), SystemError> {
// kobject_uevent和kobject_uevent_env功能一样只是没有指定任何的环境变量
match kobject_uevent_env(kobj, action, None) {
Ok(_) => Ok(()),
Err(e) => Err(e),
}
}
pub fn kobject_uevent_env(
kobj: Arc<dyn KObject>,
action: KobjectAction,
envp_ext: Option<Vec<String>>,
) -> Result<i32, SystemError> {
log::info!("kobject_uevent_env: kobj: {:?}, action: {:?}", kobj, action);
let mut state = KObjectState::empty();
let mut top_kobj = kobj.parent().unwrap().upgrade().unwrap();
let mut retval: i32;
let action_string = match action {
KobjectAction::KOBJADD => "add".to_string(),
KobjectAction::KOBJREMOVE => "remove".to_string(),
KobjectAction::KOBJCHANGE => "change".to_string(),
KobjectAction::KOBJMOVE => "move".to_string(),
KobjectAction::KOBJONLINE => "online".to_string(),
KobjectAction::KOBJOFFLINE => "offline".to_string(),
KobjectAction::KOBJBIND => "bind".to_string(),
KobjectAction::KOBJUNBIND => "unbind".to_string(),
};
/*
* Mark "remove" event done regardless of result, for some subsystems
* do not want to re-trigger "remove" event via automatic cleanup.
*/
if let KobjectAction::KOBJREMOVE = action {
log::info!("kobject_uevent_env: action: remove");
state.insert(KObjectState::REMOVE_UEVENT_SENT);
}
// 不断向上查找直到找到最顶层的kobject
while let Some(weak_parent) = top_kobj.parent() {
log::info!("kobject_uevent_env: top_kobj: {:?}", top_kobj);
top_kobj = weak_parent.upgrade().unwrap();
}
/* 查找当前kobject或其parent是否从属于某个kset;如果都不从属于某个kset则返回错误。(说明一个kobject若没有加入kset是不会上报uevent的) */
if kobj.kset().is_none() && top_kobj.kset().is_none() {
log::info!("attempted to send uevent without kset!\n");
return Err(SystemError::EINVAL);
}
let kset = top_kobj.kset();
// 判断该 kobject 的状态是否设置了uevent_suppress如果设置了则忽略所有的uevent上报并返回
if kobj.kobj_state().contains(KObjectState::UEVENT_SUPPRESS) {
log::info!("uevent_suppress caused the event to drop!");
return Ok(0);
}
// 如果所属的kset的kset->filter返回的是0过滤此次上报
if let Some(kset_ref) = kset.as_ref() {
if let Some(uevent_ops) = &kset_ref.uevent_ops {
if uevent_ops.filter() == Some(0) {
log::info!("filter caused the event to drop!");
return Ok(0);
}
}
}
// 判断所属的kset是否有合法的名称称作subsystem和前期的内核版本有区别否则不允许上报uevent
// originating subsystem
let subsystem: String = if let Some(kset_ref) = kset.as_ref() {
if let Some(uevent_ops) = &kset_ref.uevent_ops {
let name = uevent_ops.uevent_name();
if !name.is_empty() {
name
} else {
kobj.name()
}
} else {
kobj.name()
}
} else {
kobj.name()
};
if subsystem.is_empty() {
log::info!("unset subsystem caused the event to drop!");
}
log::info!("kobject_uevent_env: subsystem: {}", subsystem);
// 创建一个用于环境变量的缓冲区
let mut env = Box::new(KobjUeventEnv {
argv: Vec::with_capacity(UEVENT_NUM_ENVP),
envp: Vec::with_capacity(UEVENT_NUM_ENVP),
envp_idx: 0,
buf: vec![0; UEVENT_BUFFER_SIZE],
buflen: 0,
});
if env.buf.is_empty() {
log::error!("kobject_uevent_env: failed to allocate buffer");
return Err(SystemError::ENOMEM);
}
// 获取设备的完整对象路径
let devpath: String = KObjectManager::kobject_get_path(&kobj);
log::info!("kobject_uevent_env: devpath: {}", devpath);
if devpath.is_empty() {
retval = SystemError::ENOENT.to_posix_errno();
// goto exit
drop(devpath);
drop(env);
log::warn!("kobject_uevent_env: devpath is empty");
return Ok(retval);
}
retval = add_uevent_var(&mut env, "ACTION=%s", &action_string).unwrap();
log::info!("kobject_uevent_env: retval: {}", retval);
if !retval.is_zero() {
drop(devpath);
drop(env);
log::info!("add_uevent_var failed ACTION");
return Ok(retval);
};
retval = add_uevent_var(&mut env, "DEVPATH=%s", &devpath).unwrap();
if !retval.is_zero() {
drop(devpath);
drop(env);
log::info!("add_uevent_var failed DEVPATH");
return Ok(retval);
};
retval = add_uevent_var(&mut env, "SUBSYSTEM=%s", &subsystem).unwrap();
if !retval.is_zero() {
drop(devpath);
drop(env);
log::info!("add_uevent_var failed SUBSYSTEM");
return Ok(retval);
};
/* keys passed in from the caller */
if let Some(env_ext) = envp_ext {
for var in env_ext {
let retval = add_uevent_var(&mut env, "%s", &var).unwrap();
if !retval.is_zero() {
drop(devpath);
drop(env);
log::info!("add_uevent_var failed");
return Ok(retval);
}
}
}
if let Some(kset_ref) = kset.as_ref() {
if let Some(uevent_ops) = kset_ref.uevent_ops.as_ref() {
if uevent_ops.uevent(&env) != 0 {
retval = uevent_ops.uevent(&env);
if retval.is_zero() {
log::info!("kset uevent caused the event to drop!");
// goto exit
drop(devpath);
drop(env);
return Ok(retval);
}
}
}
}
match action {
KobjectAction::KOBJADD => {
state.insert(KObjectState::ADD_UEVENT_SENT);
}
KobjectAction::KOBJUNBIND => {
zap_modalias_env(&mut env);
}
_ => {}
}
//mutex_lock(&uevent_sock_mutex);
/* we will send an event, so request a new sequence number */
retval = add_uevent_var(&mut env, "SEQNUM=%llu", &(UEVENT_SEQNUM + 1).to_string()).unwrap();
if !retval.is_zero() {
drop(devpath);
drop(env);
log::info!("add_uevent_var failed");
return Ok(retval);
}
retval = kobject_uevent_net_broadcast(kobj, &env, &action_string, &devpath);
//mutex_unlock(&uevent_sock_mutex);
#[cfg(feature = "UEVENT_HELPER")]
fn handle_uevent_helper() {
// TODO
// 在特性 `UEVENT_HELPER` 开启的情况下,这里的代码会执行
// 指定处理uevent的用户空间程序通常是热插拔程序mdev、udevd等
// /* call uevent_helper, usually only enabled during early boot */
// if (uevent_helper[0] && !kobj_usermode_filter(kobj)) {
// struct subprocess_info *info;
// retval = add_uevent_var(env, "HOME=/");
// if (retval)
// goto exit;
// retval = add_uevent_var(env,
// "PATH=/sbin:/bin:/usr/sbin:/usr/bin");
// if (retval)
// goto exit;
// retval = init_uevent_argv(env, subsystem);
// if (retval)
// goto exit;
// retval = -ENOMEM;
// info = call_usermodehelper_setup(env->argv[0], env->argv,
// env->envp, GFP_KERNEL,
// NULL, cleanup_uevent_env, env);
// if (info) {
// retval = call_usermodehelper_exec(info, UMH_NO_WAIT);
// env = NULL; /* freed by cleanup_uevent_env */
// }
// }
}
#[cfg(not(feature = "UEVENT_HELPER"))]
fn handle_uevent_helper() {
// 在特性 `UEVENT_HELPER` 关闭的情况下,这里的代码会执行
}
handle_uevent_helper();
drop(devpath);
drop(env);
log::info!("kobject_uevent_env: retval: {}", retval);
return Ok(retval);
}
pub fn add_uevent_var(
env: &mut Box<KobjUeventEnv>,
format: &str,
args: &str,
) -> Result<i32, SystemError> {
log::info!("add_uevent_var: format: {}, args: {}", format, args);
if env.envp_idx >= env.envp.capacity() {
log::info!("add_uevent_var: too many keys");
return Err(SystemError::ENOMEM);
}
let mut buffer = String::new();
write!(&mut buffer, "{} {}", format, args).map_err(|_| SystemError::ENOMEM)?;
let len = buffer.len();
if len >= env.buf.capacity() - env.buflen {
log::info!("add_uevent_var: buffer size too small");
return Err(SystemError::ENOMEM);
}
// Convert the buffer to bytes and add to env.buf
env.buf.extend_from_slice(buffer.as_bytes());
env.buf.push(0); // Null-terminate the string
env.buflen += len + 1;
// Add the string to envp
env.envp.push(buffer);
env.envp_idx += 1;
Ok(0)
}
// 用于处理设备树中与模块相关的环境变量
fn zap_modalias_env(env: &mut Box<KobjUeventEnv>) {
// 定义一个静态字符串
const MODALIAS_PREFIX: &str = "MODALIAS=";
let mut len: usize;
let mut i = 0;
while i < env.envp_idx {
// 如果是以 MODALIAS= 开头的字符串
if env.envp[i].starts_with(MODALIAS_PREFIX) {
len = env.envp[i].len() + 1;
// 如果不是最后一个元素
if i != env.envp_idx - 1 {
// 将后续的环境变量向前移动,以覆盖掉 "MODALIAS=" 前缀的环境变量
for j in i..env.envp_idx - 1 {
env.envp[j] = env.envp[j + 1].clone();
}
}
// 减少环境变量数组的索引,因为一个变量已经被移除
env.envp_idx -= 1;
// 减少环境变量的总长度
env.buflen -= len;
} else {
i += 1;
}
}
}
// 用于处理网络相关的uevent通用事件广播
// https://code.dragonos.org.cn/xref/linux-6.1.9/lib/kobject_uevent.c#381
pub fn kobject_uevent_net_broadcast(
kobj: Arc<dyn KObject>,
env: &KobjUeventEnv,
action_string: &str,
devpath: &str,
) -> i32 {
let mut ret = 0;
// let net:Net = None;
// let mut ops = kobj_ns_ops(kobj);
// if (!ops && kobj.kset().is_some()) {
// let ksobj:KObject = &kobj.kset().kobj();
// if (ksobj.parent() != NULL){
// ops = kobj_ns_ops(ksobj.parent());
// }
// }
// TODO: net结构体
// https://code.dragonos.org.cn/xref/linux-6.1.9/include/net/net_namespace.h#60
/* kobjects currently only carry network namespace tags and they
* are the only tag relevant here since we want to decide which
* network namespaces to broadcast the uevent into.
*/
// if (ops && ops.netlink_ns() && kobj.ktype().namespace())
// if (ops.type() == KOBJ_NS_TYPE_NET)
// net = kobj.ktype().namespace(kobj);
// 如果有网络命名空间则广播标记的uevent如果没有则广播未标记的uevent
// if !net.is_none() {
// ret = uevent_net_broadcast_tagged(net.unwrap(), env, action_string, devpath);
// } else {
ret = uevent_net_broadcast_untagged(env, action_string, devpath);
// }
log::info!("kobject_uevent_net_broadcast finish. ret: {}", ret);
ret
}
pub fn uevent_net_broadcast_tagged(
sk: &dyn NetlinkSocket,
env: &KobjUeventEnv,
action_string: &str,
devpath: &str,
) -> i32 {
let ret = 0;
ret
}
/// 分配一个用于 uevent 消息的 skbsocket buffer
pub fn alloc_uevent_skb<'a>(
env: &'a KobjUeventEnv,
action_string: &'a str,
devpath: &'a str,
) -> Arc<RwLock<SkBuff>> {
let skb = Arc::new(RwLock::new(SkBuff::new()));
skb
}
// https://code.dragonos.org.cn/xref/linux-6.1.9/lib/kobject_uevent.c#309
/// 广播一个未标记的 uevent 消息
pub fn uevent_net_broadcast_untagged(
env: &KobjUeventEnv,
action_string: &str,
devpath: &str,
) -> i32 {
log::info!(
"uevent_net_broadcast_untagged: action_string: {}, devpath: {}",
action_string,
devpath
);
let mut retval = 0;
let mut skb = Arc::new(RwLock::new(SkBuff::new()));
// 锁定 UEVENT_SOCK_LIST 并遍历
let ue_sk_list = UEVENT_SOCK_LIST.lock();
for ue_sk in ue_sk_list.iter() {
// 如果没有监听者,则跳过
if netlink_has_listeners(&ue_sk.inner, 1) == 0 {
log::info!("uevent_net_broadcast_untagged: no listeners");
continue;
}
// 如果 skb 为空,则分配一个新的 skb
if skb.read().is_empty() {
log::info!("uevent_net_broadcast_untagged: alloc_uevent_skb failed");
retval = SystemError::ENOMEM.to_posix_errno();
skb = alloc_uevent_skb(env, action_string, devpath);
if skb.read().is_empty() {
continue;
}
}
log::info!("next is netlink_broadcast");
let netlink_socket: Arc<dyn NetlinkSocket> = Arc::new(ue_sk.inner.clone());
retval = match netlink_broadcast(&netlink_socket, Arc::clone(&skb), 0, 1, 1) {
Ok(_) => 0,
Err(err) => err.to_posix_errno(),
};
log::info!("finished netlink_broadcast");
// ENOBUFS should be handled in userspace
if retval == SystemError::ENOBUFS.to_posix_errno()
|| retval == SystemError::ESRCH.to_posix_errno()
{
retval = 0;
}
}
// consume_skb(skb);
retval
}

View File

@ -0,0 +1,102 @@
// include/linux/kobject.h
// lib/kobject_uevent.c
/*
UEVENT_HELPER_PATH_LEN
UEVENT_NUM_ENVP
_KOBJECT_H_
Variable
__randomize_layout
Enum
kobject_action
Struct
kobj_attribute
kobj_type
kobj_uevent_env
kobject
kset
kset_uevent_ops
Function
get_ktype
kobject_name
kset_get
kset_put
to_kset
*/
use crate::driver::base::kobject::KObject;
use alloc::string::String;
use alloc::vec::Vec;
pub mod kobject_uevent;
// https://code.dragonos.org.cn/xref/linux-6.1.9/lib/kobject_uevent.c?fi=kobject_uevent#457
// kobject_action
#[derive(Debug)]
pub enum KobjectAction {
KOBJADD,
KOBJREMOVE, //Kobject或上层数据结构的添加/移除事件
KOBJCHANGE, //Kobject或上层数据结构的状态或者内容发生改变; 如果设备驱动需要上报的事件不再上面事件的范围内或者是自定义的事件可以使用该event并携带相应的参数。
KOBJMOVE, //Kobject或上层数据结构更改名称或者更改Parent意味着在sysfs中更改了目录结构
KOBJONLINE,
KOBJOFFLINE, //Kobject或上层数据结构的上线/下线事件,其实是是否使能
KOBJBIND,
KOBJUNBIND,
}
/*
@parament:
envpUEVENT_NUM_ENVP
envp_idx访index
bufbufferUEVENT_BUFFER_SIZE
buflen访buf的变量
*/
//https://code.dragonos.org.cn/xref/linux-6.1.9/include/linux/kobject.h#31
pub const UEVENT_NUM_ENVP: usize = 64;
pub const UEVENT_BUFFER_SIZE: usize = 2048;
pub const UEVENT_HELPER_PATH_LEN: usize = 256;
/// Represents the environment for handling kernel object uevents.
/*
envpUEVENT_NUM_ENVP
envp_idx访index
bufbufferUEVENT_BUFFER_SIZE
buflen访buf的变量
*/
// 表示一个待发送的uevent
#[derive(Debug)]
pub struct KobjUeventEnv {
argv: Vec<String>,
envp: Vec<String>,
envp_idx: usize,
buf: Vec<u8>,
buflen: usize,
}
// kset_uevent_ops是为kset量身订做的一个数据结构里面包含filter和uevent两个回调函数用处如下
/*
filterKobject需要上报uevent时kset可以通过该接口过滤event
namekset的名称kset没有合法的名称Kobject将不允许上报uvent
ueventKobject需要上报uevent时kset可以通过该接口统一为这些event添加环境变量uevent时的环境变量都是相同的kset统一处理Kobject独自添加了
*/

View File

@ -8,7 +8,9 @@ use crate::{
device::{bus::Bus, driver::Driver, Device, DeviceCommonData, DeviceType, IdTable}, device::{bus::Bus, driver::Driver, Device, DeviceCommonData, DeviceType, IdTable},
kobject::{KObjType, KObject, KObjectCommonData, KObjectState, LockedKObjectState}, kobject::{KObjType, KObject, KObjectCommonData, KObjectState, LockedKObjectState},
}, },
net::{register_netdevice, NetDeivceState, NetDevice, NetDeviceCommonData, Operstate}, net::{
register_netdevice, Iface, IfaceCommon, NetDeivceState, NetDeviceCommonData, Operstate,
},
}, },
libs::{ libs::{
rwlock::{RwLockReadGuard, RwLockWriteGuard}, rwlock::{RwLockReadGuard, RwLockWriteGuard},
@ -27,11 +29,8 @@ use core::{
ops::{Deref, DerefMut}, ops::{Deref, DerefMut},
}; };
use log::info; use log::info;
use smoltcp::{ use smoltcp::{phy, wire::HardwareAddress};
phy, // use system_error::SystemError;
wire::{self, HardwareAddress},
};
use system_error::SystemError;
use super::e1000e::{E1000EBuffer, E1000EDevice}; use super::e1000e::{E1000EBuffer, E1000EDevice};
@ -78,12 +77,12 @@ impl Debug for E1000EDriverWrapper {
} }
} }
#[cast_to([sync] NetDevice)] #[cast_to([sync] Iface)]
#[cast_to([sync] Device)] #[cast_to([sync] Device)]
#[derive(Debug)]
pub struct E1000EInterface { pub struct E1000EInterface {
driver: E1000EDriverWrapper, driver: E1000EDriverWrapper,
iface_id: usize, common: IfaceCommon,
iface: SpinLock<smoltcp::iface::Interface>,
name: String, name: String,
inner: SpinLock<InnerE1000EInterface>, inner: SpinLock<InnerE1000EInterface>,
locked_kobj_state: LockedKObjectState, locked_kobj_state: LockedKObjectState,
@ -201,11 +200,9 @@ impl E1000EInterface {
let iface = let iface =
smoltcp::iface::Interface::new(iface_config, &mut driver, Instant::now().into()); smoltcp::iface::Interface::new(iface_config, &mut driver, Instant::now().into());
let driver: E1000EDriverWrapper = E1000EDriverWrapper(UnsafeCell::new(driver));
let result = Arc::new(E1000EInterface { let result = Arc::new(E1000EInterface {
driver, driver: E1000EDriverWrapper(UnsafeCell::new(driver)),
iface_id, common: IfaceCommon::new(iface_id, iface),
iface: SpinLock::new(iface),
name: format!("eth{}", iface_id), name: format!("eth{}", iface_id),
inner: SpinLock::new(InnerE1000EInterface { inner: SpinLock::new(InnerE1000EInterface {
netdevice_common: NetDeviceCommonData::default(), netdevice_common: NetDeviceCommonData::default(),
@ -223,16 +220,6 @@ impl E1000EInterface {
} }
} }
impl Debug for E1000EInterface {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
f.debug_struct("E1000EInterface")
.field("iface_id", &self.iface_id)
.field("iface", &"smoltcp::iface::Interface")
.field("name", &self.name)
.finish()
}
}
impl Device for E1000EInterface { impl Device for E1000EInterface {
fn dev_type(&self) -> DeviceType { fn dev_type(&self) -> DeviceType {
DeviceType::Net DeviceType::Net
@ -302,52 +289,23 @@ impl Device for E1000EInterface {
} }
} }
impl NetDevice for E1000EInterface { impl Iface for E1000EInterface {
fn common(&self) -> &IfaceCommon {
return &self.common;
}
fn mac(&self) -> smoltcp::wire::EthernetAddress { fn mac(&self) -> smoltcp::wire::EthernetAddress {
let mac = self.driver.inner.lock().mac_address(); let mac = self.driver.inner.lock().mac_address();
return smoltcp::wire::EthernetAddress::from_bytes(&mac); return smoltcp::wire::EthernetAddress::from_bytes(&mac);
} }
#[inline]
fn nic_id(&self) -> usize {
return self.iface_id;
}
#[inline] #[inline]
fn iface_name(&self) -> String { fn iface_name(&self) -> String {
return self.name.clone(); return self.name.clone();
} }
fn update_ip_addrs(&self, ip_addrs: &[wire::IpCidr]) -> Result<(), SystemError> { fn poll(&self) {
if ip_addrs.len() != 1 { self.common.poll(self.driver.force_get_mut())
return Err(SystemError::EINVAL);
}
self.iface.lock().update_ip_addrs(|addrs| {
let dest = addrs.iter_mut().next();
if let Some(dest) = dest {
*dest = ip_addrs[0];
} else {
addrs.push(ip_addrs[0]).expect("Push ipCidr failed: full");
}
});
return Ok(());
}
fn poll(&self, sockets: &mut smoltcp::iface::SocketSet) -> Result<(), SystemError> {
let timestamp: smoltcp::time::Instant = Instant::now().into();
let mut guard = self.iface.lock();
let poll_res = guard.poll(timestamp, self.driver.force_get_mut(), sockets);
if poll_res {
return Ok(());
}
return Err(SystemError::EAGAIN_OR_EWOULDBLOCK);
}
#[inline(always)]
fn inner_iface(&self) -> &SpinLock<smoltcp::iface::Interface> {
return &self.iface;
} }
fn addr_assign_type(&self) -> u8 { fn addr_assign_type(&self) -> u8 {

View File

@ -7,7 +7,8 @@ use crate::{
irqdesc::{IrqHandler, IrqReturn}, irqdesc::{IrqHandler, IrqReturn},
IrqNumber, IrqNumber,
}, },
net::net_core::poll_ifaces_try_lock_onetime, // net::net_core::poll_ifaces_try_lock_onetime,
net::net_core::poll_ifaces,
}; };
/// 默认的网卡中断处理函数 /// 默认的网卡中断处理函数
@ -21,7 +22,9 @@ impl IrqHandler for DefaultNetIrqHandler {
_static_data: Option<&dyn IrqHandlerData>, _static_data: Option<&dyn IrqHandlerData>,
_dynamic_data: Option<Arc<dyn IrqHandlerData>>, _dynamic_data: Option<Arc<dyn IrqHandlerData>>,
) -> Result<IrqReturn, SystemError> { ) -> Result<IrqReturn, SystemError> {
poll_ifaces_try_lock_onetime().ok(); // poll_ifaces_try_lock_onetime().ok();
log::warn!("DefaultNetIrqHandler: poll_ifaces_try_lock_onetime -> poll_ifaces");
poll_ifaces();
Ok(IrqReturn::Handled) Ok(IrqReturn::Handled)
} }
} }

View File

@ -28,7 +28,9 @@ use smoltcp::{
use system_error::SystemError; use system_error::SystemError;
use unified_init::macros::unified_init; use unified_init::macros::unified_init;
use super::{register_netdevice, NetDeivceState, NetDevice, NetDeviceCommonData, Operstate}; use super::{register_netdevice, NetDeivceState, NetDeviceCommonData, Operstate};
use super::{Iface, IfaceCommon};
const DEVICE_NAME: &str = "loopback"; const DEVICE_NAME: &str = "loopback";
@ -81,6 +83,7 @@ impl phy::TxToken for LoopbackTxToken {
let result = f(buffer.as_mut_slice()); let result = f(buffer.as_mut_slice());
let mut device = self.driver.inner.lock(); let mut device = self.driver.inner.lock();
device.loopback_transmit(buffer); device.loopback_transmit(buffer);
// debug!("lo transmit!");
result result
} }
} }
@ -112,7 +115,7 @@ impl Loopback {
let buffer = self.queue.pop_front(); let buffer = self.queue.pop_front();
match buffer { match buffer {
Some(buffer) => { Some(buffer) => {
//debug!("lo receive:{:?}", buffer); // debug!("lo receive:{:?}", buffer);
return buffer; return buffer;
} }
None => { None => {
@ -127,7 +130,7 @@ impl Loopback {
/// - &mut self自身可变引用 /// - &mut self自身可变引用
/// - buffer需要发送的数据包 /// - buffer需要发送的数据包
pub fn loopback_transmit(&mut self, buffer: Vec<u8>) { pub fn loopback_transmit(&mut self, buffer: Vec<u8>) {
//debug!("lo transmit!"); // debug!("lo transmit:{:?}", buffer);
self.queue.push_back(buffer) self.queue.push_back(buffer)
} }
} }
@ -136,6 +139,7 @@ impl Loopback {
/// 为实现获得不可变引用的Interface的内部可变性故为Driver提供UnsafeCell包裹器 /// 为实现获得不可变引用的Interface的内部可变性故为Driver提供UnsafeCell包裹器
/// ///
/// 参考virtio_net.rs /// 参考virtio_net.rs
#[derive(Debug)]
struct LoopbackDriverWapper(UnsafeCell<LoopbackDriver>); struct LoopbackDriverWapper(UnsafeCell<LoopbackDriver>);
unsafe impl Send for LoopbackDriverWapper {} unsafe impl Send for LoopbackDriverWapper {}
unsafe impl Sync for LoopbackDriverWapper {} unsafe impl Sync for LoopbackDriverWapper {}
@ -214,8 +218,10 @@ impl phy::Device for LoopbackDriver {
let buffer = self.inner.lock().loopback_receive(); let buffer = self.inner.lock().loopback_receive();
//receive队列为为空返回NONE值以通知上层没有可以receive的包 //receive队列为为空返回NONE值以通知上层没有可以receive的包
if buffer.is_empty() { if buffer.is_empty() {
// log::debug!("lo receive none!");
return Option::None; return Option::None;
} }
// log::debug!("lo receive!");
let rx = LoopbackRxToken { buffer }; let rx = LoopbackRxToken { buffer };
let tx = LoopbackTxToken { let tx = LoopbackTxToken {
driver: self.clone(), driver: self.clone(),
@ -232,6 +238,7 @@ impl phy::Device for LoopbackDriver {
/// ## 返回值 /// ## 返回值
/// - 返回一个 `Some`,其中包含一个发送令牌,该令牌包含一个对自身的克隆引用 /// - 返回一个 `Some`,其中包含一个发送令牌,该令牌包含一个对自身的克隆引用
fn transmit(&mut self, _timestamp: smoltcp::time::Instant) -> Option<Self::TxToken<'_>> { fn transmit(&mut self, _timestamp: smoltcp::time::Instant) -> Option<Self::TxToken<'_>> {
// log::debug!("lo transmit!");
Some(LoopbackTxToken { Some(LoopbackTxToken {
driver: self.clone(), driver: self.clone(),
}) })
@ -240,13 +247,12 @@ impl phy::Device for LoopbackDriver {
/// ## LoopbackInterface结构 /// ## LoopbackInterface结构
/// 封装驱动包裹器和iface设置接口名称 /// 封装驱动包裹器和iface设置接口名称
#[cast_to([sync] NetDevice)] #[cast_to([sync] Iface)]
#[cast_to([sync] Device)] #[cast_to([sync] Device)]
#[derive(Debug)]
pub struct LoopbackInterface { pub struct LoopbackInterface {
driver: LoopbackDriverWapper, driver: LoopbackDriverWapper,
iface_id: usize, common: IfaceCommon,
iface: SpinLock<smoltcp::iface::Interface>,
name: String,
inner: SpinLock<InnerLoopbackInterface>, inner: SpinLock<InnerLoopbackInterface>,
locked_kobj_state: LockedKObjectState, locked_kobj_state: LockedKObjectState,
} }
@ -280,16 +286,20 @@ impl LoopbackInterface {
smoltcp::iface::Interface::new(iface_config, &mut driver, Instant::now().into()); smoltcp::iface::Interface::new(iface_config, &mut driver, Instant::now().into());
//设置网卡地址为127.0.0.1 //设置网卡地址为127.0.0.1
iface.update_ip_addrs(|ip_addrs| { iface.update_ip_addrs(|ip_addrs| {
for i in 1..=2 {
ip_addrs ip_addrs
.push(IpCidr::new(IpAddress::v4(127, 0, 0, 1), 8)) .push(IpCidr::new(IpAddress::v4(127, 0, 0, i), 8))
.unwrap(); .expect("Push ipCidr failed: full");
}
}); });
let driver = LoopbackDriverWapper(UnsafeCell::new(driver));
// iface.routes_mut().update(|routes_map| {
// routes_map[0].
// });
Arc::new(LoopbackInterface { Arc::new(LoopbackInterface {
driver, driver: LoopbackDriverWapper(UnsafeCell::new(driver)),
iface_id, common: IfaceCommon::new(iface_id, iface),
iface: SpinLock::new(iface),
name: "lo".to_string(),
inner: SpinLock::new(InnerLoopbackInterface { inner: SpinLock::new(InnerLoopbackInterface {
netdevice_common: NetDeviceCommonData::default(), netdevice_common: NetDeviceCommonData::default(),
device_common: DeviceCommonData::default(), device_common: DeviceCommonData::default(),
@ -304,16 +314,7 @@ impl LoopbackInterface {
} }
} }
impl Debug for LoopbackInterface { //TODO: 向sysfs注册lo设备
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
f.debug_struct("LoopbackInterface")
.field("iface_id", &self.iface_id)
.field("iface", &"smtoltcp::iface::Interface")
.field("name", &self.name)
.finish()
}
}
impl KObject for LoopbackInterface { impl KObject for LoopbackInterface {
fn as_any_ref(&self) -> &dyn core::any::Any { fn as_any_ref(&self) -> &dyn core::any::Any {
self self
@ -348,7 +349,7 @@ impl KObject for LoopbackInterface {
} }
fn name(&self) -> String { fn name(&self) -> String {
self.name.clone() "lo".to_string()
} }
fn set_name(&self, _name: String) { fn set_name(&self, _name: String) {
@ -441,72 +442,23 @@ impl Device for LoopbackInterface {
} }
} }
impl NetDevice for LoopbackInterface { impl Iface for LoopbackInterface {
fn common(&self) -> &IfaceCommon {
&self.common
}
fn iface_name(&self) -> String {
"lo".to_string()
}
/// 由于lo网卡设备不是实际的物理设备其mac地址需要手动设置为一个默认值这里默认为00:00:00:00:00 /// 由于lo网卡设备不是实际的物理设备其mac地址需要手动设置为一个默认值这里默认为00:00:00:00:00
fn mac(&self) -> smoltcp::wire::EthernetAddress { fn mac(&self) -> smoltcp::wire::EthernetAddress {
let mac = [0x00, 0x00, 0x00, 0x00, 0x00, 0x00]; let mac = [0x00, 0x00, 0x00, 0x00, 0x00, 0x00];
smoltcp::wire::EthernetAddress(mac) smoltcp::wire::EthernetAddress(mac)
} }
#[inline] fn poll(&self) {
fn nic_id(&self) -> usize { self.common.poll(self.driver.force_get_mut())
self.iface_id
}
#[inline]
fn iface_name(&self) -> String {
self.name.clone()
}
/// ## `update_ip_addrs` 用于更新接口的 IP 地址。
///
/// ## 参数
/// - `&self` :自身引用
/// - `ip_addrs` :一个包含 `smoltcp::wire::IpCidr` 的切片,表示要设置的 IP 地址和子网掩码
///
/// ## 返回值
/// - 如果 `ip_addrs` 的长度不为 1返回 `Err(SystemError::EINVAL)`,表示输入参数无效
/// - 如果更新成功,返回 `Ok(())`
fn update_ip_addrs(
&self,
ip_addrs: &[smoltcp::wire::IpCidr],
) -> Result<(), system_error::SystemError> {
if ip_addrs.len() != 1 {
return Err(SystemError::EINVAL);
}
self.iface.lock().update_ip_addrs(|addrs| {
let dest = addrs.iter_mut().next();
if let Some(dest) = dest {
*dest = ip_addrs[0];
} else {
addrs.push(ip_addrs[0]).expect("Push ipCidr failed: full");
}
});
return Ok(());
}
/// ## `poll` 用于轮询接口的状态。
///
/// ## 参数
/// - `&self` :自身引用
/// - `sockets` :一个可变引用到 `smoltcp::iface::SocketSet`,表示要轮询的套接字集
///
/// ## 返回值
/// - 如果轮询成功,返回 `Ok(())`
/// - 如果轮询失败,返回 `Err(SystemError::EAGAIN_OR_EWOULDBLOCK)`,表示需要再次尝试或者操作会阻塞
fn poll(&self, sockets: &mut smoltcp::iface::SocketSet) -> Result<(), SystemError> {
let timestamp: smoltcp::time::Instant = Instant::now().into();
let mut guard = self.iface.lock();
let poll_res = guard.poll(timestamp, self.driver.force_get_mut(), sockets);
if poll_res {
return Ok(());
}
return Err(SystemError::EAGAIN_OR_EWOULDBLOCK);
}
#[inline(always)]
fn inner_iface(&self) -> &SpinLock<smoltcp::iface::Interface> {
return &self.iface;
} }
fn addr_assign_type(&self) -> u8 { fn addr_assign_type(&self) -> u8 {
@ -538,7 +490,7 @@ impl NetDevice for LoopbackInterface {
pub fn loopback_probe() { pub fn loopback_probe() {
loopback_driver_init(); loopback_driver_init();
} }
/// ## lo网卡设备初始化函数 /// # lo网卡设备初始化函数
/// 创建驱动和iface初始化一个lo网卡添加到全局NET_DEVICES中 /// 创建驱动和iface初始化一个lo网卡添加到全局NET_DEVICES中
pub fn loopback_driver_init() { pub fn loopback_driver_init() {
let driver = LoopbackDriver::new(); let driver = LoopbackDriver::new();
@ -548,14 +500,16 @@ pub fn loopback_driver_init() {
NET_DEVICES NET_DEVICES
.write_irqsave() .write_irqsave()
.insert(iface.iface_id, iface.clone()); .insert(iface.nic_id(), iface.clone());
register_netdevice(iface.clone()).expect("register lo device failed"); register_netdevice(iface.clone()).expect("register lo device failed");
} }
/// ## lo网卡设备的注册函数 /// ## lo网卡设备的注册函数
#[unified_init(INITCALL_DEVICE)] //TODO: 现在先不用初始化宏进行注册使virtonet排在网卡列表头待网络子系统重构后再使用初始化宏并修复该bug
// #[unified_init(INITCALL_DEVICE)]
pub fn loopback_init() -> Result<(), SystemError> { pub fn loopback_init() -> Result<(), SystemError> {
loopback_probe(); loopback_probe();
log::debug!("Successfully init loopback device");
return Ok(()); return Ok(());
} }

View File

@ -1,3 +1,4 @@
use alloc::{fmt, vec::Vec};
use alloc::{string::String, sync::Arc}; use alloc::{string::String, sync::Arc};
use smoltcp::{ use smoltcp::{
iface, iface,
@ -5,8 +6,12 @@ use smoltcp::{
}; };
use sysfs::netdev_register_kobject; use sysfs::netdev_register_kobject;
use super::base::device::Device; use crate::{
use crate::libs::spinlock::SpinLock; libs::{rwlock::RwLock, spinlock::SpinLock},
net::socket::inet::{common::PortManager, InetSocket},
process::ProcessState,
};
use smoltcp;
use system_error::SystemError; use system_error::SystemError;
pub mod class; pub mod class;
@ -52,23 +57,63 @@ pub enum Operstate {
} }
#[allow(dead_code)] #[allow(dead_code)]
pub trait NetDevice: Device { pub trait Iface: crate::driver::base::device::Device {
/// @brief 获取网卡的MAC地址 /// # `common`
fn mac(&self) -> EthernetAddress; /// 获取网卡的公共信息
fn common(&self) -> &IfaceCommon;
/// # `mac`
/// 获取网卡的MAC地址
fn mac(&self) -> smoltcp::wire::EthernetAddress;
/// # `name`
/// 获取网卡名
fn iface_name(&self) -> String; fn iface_name(&self) -> String;
/// @brief 获取网卡的id /// # `nic_id`
fn nic_id(&self) -> usize; /// 获取网卡id
fn nic_id(&self) -> usize {
self.common().iface_id
}
fn poll(&self, sockets: &mut iface::SocketSet) -> Result<(), SystemError>; /// # `poll`
/// 用于轮询接口的状态。
/// ## 参数
/// - `sockets` :一个可变引用到 `smoltcp::iface::SocketSet`,表示要轮询的套接字集
/// ## 返回值
/// - 成功返回 `Ok(())`
/// - 如果轮询失败,返回 `Err(SystemError::EAGAIN_OR_EWOULDBLOCK)`,表示需要再次尝试或者操作会阻塞
fn poll(&self);
fn update_ip_addrs(&self, ip_addrs: &[wire::IpCidr]) -> Result<(), SystemError>; /// # `update_ip_addrs`
/// 用于更新接口的 IP 地址
/// ## 参数
/// - `ip_addrs` :一个包含 `smoltcp::wire::IpCidr` 的切片,表示要设置的 IP 地址和子网掩码
/// ## 返回值
/// - 如果 `ip_addrs` 的长度不为 1返回 `Err(SystemError::EINVAL)`,表示输入参数无效
fn update_ip_addrs(&self, ip_addrs: &[smoltcp::wire::IpCidr]) -> Result<(), SystemError> {
self.common().update_ip_addrs(ip_addrs)
}
/// @brief 获取smoltcp的网卡接口类型 /// @brief 获取smoltcp的网卡接口类型
fn inner_iface(&self) -> &SpinLock<smoltcp::iface::Interface>; #[inline(always)]
fn smol_iface(&self) -> &SpinLock<smoltcp::iface::Interface> {
&self.common().smol_iface
}
// fn as_any_ref(&'static self) -> &'static dyn core::any::Any; // fn as_any_ref(&'static self) -> &'static dyn core::any::Any;
/// # `sockets`
/// 获取网卡的套接字集
fn sockets(&self) -> &SpinLock<smoltcp::iface::SocketSet<'static>> {
&self.common().sockets
}
/// # `port_manager`
/// 用于管理网卡的端口
fn port_manager(&self) -> &PortManager {
&self.common().port_manager
}
fn addr_assign_type(&self) -> u8; fn addr_assign_type(&self) -> u8;
fn net_device_type(&self) -> u16; fn net_device_type(&self) -> u16;
@ -108,7 +153,7 @@ impl Default for NetDeviceCommonData {
/// 将网络设备注册到sysfs中 /// 将网络设备注册到sysfs中
/// 参考https://code.dragonos.org.cn/xref/linux-2.6.39/net/core/dev.c?fi=register_netdev#5373 /// 参考https://code.dragonos.org.cn/xref/linux-2.6.39/net/core/dev.c?fi=register_netdev#5373
fn register_netdevice(dev: Arc<dyn NetDevice>) -> Result<(), SystemError> { fn register_netdevice(dev: Arc<dyn Iface>) -> Result<(), SystemError> {
// 在sysfs中注册设备 // 在sysfs中注册设备
netdev_register_kobject(dev.clone())?; netdev_register_kobject(dev.clone())?;
@ -117,3 +162,124 @@ fn register_netdevice(dev: Arc<dyn NetDevice>) -> Result<(), SystemError> {
return Ok(()); return Ok(());
} }
pub struct IfaceCommon {
iface_id: usize,
smol_iface: SpinLock<smoltcp::iface::Interface>,
/// 存smoltcp网卡的套接字集
sockets: SpinLock<smoltcp::iface::SocketSet<'static>>,
/// 存 kernel wrap smoltcp socket 的集合
bounds: RwLock<Vec<Arc<dyn InetSocket>>>,
/// 端口管理器
port_manager: PortManager,
/// 下次轮询的时间
poll_at_ms: core::sync::atomic::AtomicU64,
}
impl fmt::Debug for IfaceCommon {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("IfaceCommon")
.field("iface_id", &self.iface_id)
.field("sockets", &self.sockets)
.field("bounds", &self.bounds)
.field("port_manager", &self.port_manager)
.field("poll_at_ms", &self.poll_at_ms)
.finish()
}
}
impl IfaceCommon {
pub fn new(iface_id: usize, iface: smoltcp::iface::Interface) -> Self {
IfaceCommon {
iface_id,
smol_iface: SpinLock::new(iface),
sockets: SpinLock::new(smoltcp::iface::SocketSet::new(Vec::new())),
bounds: RwLock::new(Vec::new()),
port_manager: PortManager::new(),
poll_at_ms: core::sync::atomic::AtomicU64::new(0),
}
}
pub fn poll<D>(&self, device: &mut D)
where
D: smoltcp::phy::Device + ?Sized,
{
let timestamp = crate::time::Instant::now().into();
let mut sockets = self.sockets.lock_irqsave();
let mut interface = self.smol_iface.lock_irqsave();
let (has_events, poll_at) = {
let mut has_events = false;
let mut poll_at;
loop {
has_events |= interface.poll(timestamp, device, &mut sockets);
poll_at = interface.poll_at(timestamp, &sockets);
let Some(instant) = poll_at else {
break;
};
if instant > timestamp {
break;
}
}
(has_events, poll_at)
};
// drop sockets here to avoid deadlock
drop(interface);
drop(sockets);
use core::sync::atomic::Ordering;
if let Some(instant) = poll_at {
let _old_instant = self.poll_at_ms.load(Ordering::Relaxed);
let new_instant = instant.total_millis() as u64;
self.poll_at_ms.store(new_instant, Ordering::Relaxed);
// if old_instant == 0 || new_instant < old_instant {
// self.polling_wait_queue.wake_all();
// }
} else {
self.poll_at_ms.store(0, Ordering::Relaxed);
}
if has_events {
// log::debug!("IfaceCommon::poll: has_events");
// We never try to hold the write lock in the IRQ context, and we disable IRQ when
// holding the write lock. So we don't need to disable IRQ when holding the read lock.
self.bounds.read().iter().for_each(|bound_socket| {
bound_socket.on_iface_events();
bound_socket
.wait_queue()
.wakeup(Some(ProcessState::Blocked(true)));
});
// let closed_sockets = self
// .closing_sockets
// .lock_irq_disabled()
// .extract_if(|closing_socket| closing_socket.is_closed())
// .collect::<Vec<_>>();
// drop(closed_sockets);
}
}
pub fn update_ip_addrs(&self, ip_addrs: &[smoltcp::wire::IpCidr]) -> Result<(), SystemError> {
if ip_addrs.len() != 1 {
return Err(SystemError::EINVAL);
}
self.smol_iface.lock().update_ip_addrs(|addrs| {
let dest = addrs.iter_mut().next();
if let Some(dest) = dest {
*dest = ip_addrs[0];
} else {
addrs.push(ip_addrs[0]).expect("Push ipCidr failed: full");
}
});
return Ok(());
}
// 需要bounds储存具体的Inet Socket信息以提供不同种类inet socket的事件分发
pub fn bind_socket(&self, socket: Arc<dyn InetSocket>) {
self.bounds.write().push(socket);
}
}

View File

@ -17,11 +17,11 @@ use intertrait::cast::CastArc;
use log::error; use log::error;
use system_error::SystemError; use system_error::SystemError;
use super::{class::sys_class_net_instance, NetDeivceState, NetDevice, Operstate}; use super::{class::sys_class_net_instance, Iface, NetDeivceState, Operstate};
/// 将设备注册到`/sys/class/net`目录下 /// 将设备注册到`/sys/class/net`目录下
/// 参考https://code.dragonos.org.cn/xref/linux-2.6.39/net/core/net-sysfs.c?fi=netdev_register_kobject#1311 /// 参考https://code.dragonos.org.cn/xref/linux-2.6.39/net/core/net-sysfs.c?fi=netdev_register_kobject#1311
pub fn netdev_register_kobject(dev: Arc<dyn NetDevice>) -> Result<(), SystemError> { pub fn netdev_register_kobject(dev: Arc<dyn Iface>) -> Result<(), SystemError> {
// 初始化设备 // 初始化设备
device_manager().device_default_initialize(&(dev.clone() as Arc<dyn Device>)); device_manager().device_default_initialize(&(dev.clone() as Arc<dyn Device>));
@ -103,8 +103,8 @@ impl Attribute for AttrAddrAssignType {
} }
fn show(&self, kobj: Arc<dyn KObject>, buf: &mut [u8]) -> Result<usize, SystemError> { fn show(&self, kobj: Arc<dyn KObject>, buf: &mut [u8]) -> Result<usize, SystemError> {
let net_device = kobj.cast::<dyn NetDevice>().map_err(|_| { let net_device = kobj.cast::<dyn Iface>().map_err(|_| {
error!("AttrAddrAssignType::show() failed: kobj is not a NetDevice"); error!("AttrAddrAssignType::show() failed: kobj is not a Iface");
SystemError::EINVAL SystemError::EINVAL
})?; })?;
let addr_assign_type = net_device.addr_assign_type(); let addr_assign_type = net_device.addr_assign_type();
@ -271,8 +271,8 @@ impl Attribute for AttrType {
} }
fn show(&self, kobj: Arc<dyn KObject>, buf: &mut [u8]) -> Result<usize, SystemError> { fn show(&self, kobj: Arc<dyn KObject>, buf: &mut [u8]) -> Result<usize, SystemError> {
let net_deive = kobj.cast::<dyn NetDevice>().map_err(|_| { let net_deive = kobj.cast::<dyn Iface>().map_err(|_| {
error!("AttrType::show() failed: kobj is not a NetDevice"); error!("AttrType::show() failed: kobj is not a Iface");
SystemError::EINVAL SystemError::EINVAL
})?; })?;
let net_type = net_deive.net_device_type(); let net_type = net_deive.net_device_type();
@ -322,8 +322,8 @@ impl Attribute for AttrAddress {
} }
fn show(&self, kobj: Arc<dyn KObject>, buf: &mut [u8]) -> Result<usize, SystemError> { fn show(&self, kobj: Arc<dyn KObject>, buf: &mut [u8]) -> Result<usize, SystemError> {
let net_device = kobj.cast::<dyn NetDevice>().map_err(|_| { let net_device = kobj.cast::<dyn Iface>().map_err(|_| {
error!("AttrAddress::show() failed: kobj is not a NetDevice"); error!("AttrAddress::show() failed: kobj is not a Iface");
SystemError::EINVAL SystemError::EINVAL
})?; })?;
let mac_addr = net_device.mac(); let mac_addr = net_device.mac();
@ -373,8 +373,8 @@ impl Attribute for AttrCarrier {
} }
fn show(&self, kobj: Arc<dyn KObject>, buf: &mut [u8]) -> Result<usize, SystemError> { fn show(&self, kobj: Arc<dyn KObject>, buf: &mut [u8]) -> Result<usize, SystemError> {
let net_device = kobj.cast::<dyn NetDevice>().map_err(|_| { let net_device = kobj.cast::<dyn Iface>().map_err(|_| {
error!("AttrCarrier::show() failed: kobj is not a NetDevice"); error!("AttrCarrier::show() failed: kobj is not a Iface");
SystemError::EINVAL SystemError::EINVAL
})?; })?;
if net_device if net_device
@ -489,8 +489,8 @@ impl Attribute for AttrOperstate {
} }
fn show(&self, _kobj: Arc<dyn KObject>, _buf: &mut [u8]) -> Result<usize, SystemError> { fn show(&self, _kobj: Arc<dyn KObject>, _buf: &mut [u8]) -> Result<usize, SystemError> {
let net_device = _kobj.cast::<dyn NetDevice>().map_err(|_| { let net_device = _kobj.cast::<dyn Iface>().map_err(|_| {
error!("AttrOperstate::show() failed: kobj is not a NetDevice"); error!("AttrOperstate::show() failed: kobj is not a Iface");
SystemError::EINVAL SystemError::EINVAL
})?; })?;
if !net_device if !net_device

View File

@ -16,7 +16,7 @@ use smoltcp::{iface, phy, wire};
use unified_init::macros::unified_init; use unified_init::macros::unified_init;
use virtio_drivers::device::net::VirtIONet; use virtio_drivers::device::net::VirtIONet;
use super::{NetDeivceState, NetDevice, NetDeviceCommonData, Operstate}; use super::{Iface, NetDeivceState, NetDeviceCommonData, Operstate};
use crate::{ use crate::{
arch::rand::rand, arch::rand::rand,
driver::{ driver::{
@ -47,7 +47,7 @@ use crate::{
rwlock::{RwLockReadGuard, RwLockWriteGuard}, rwlock::{RwLockReadGuard, RwLockWriteGuard},
spinlock::{SpinLock, SpinLockGuard}, spinlock::{SpinLock, SpinLockGuard},
}, },
net::{generate_iface_id, net_core::poll_ifaces_try_lock_onetime, NET_DEVICES}, net::{generate_iface_id, net_core::poll_ifaces, NET_DEVICES},
time::Instant, time::Instant,
}; };
use system_error::SystemError; use system_error::SystemError;
@ -253,7 +253,8 @@ impl Device for VirtIONetDevice {
impl VirtIODevice for VirtIONetDevice { impl VirtIODevice for VirtIONetDevice {
fn handle_irq(&self, _irq: IrqNumber) -> Result<IrqReturn, SystemError> { fn handle_irq(&self, _irq: IrqNumber) -> Result<IrqReturn, SystemError> {
poll_ifaces_try_lock_onetime().ok(); log::warn!("VirtioInterface: poll_ifaces_try_lock_onetime -> poll_ifaces");
poll_ifaces();
return Ok(IrqReturn::Handled); return Ok(IrqReturn::Handled);
} }
@ -362,13 +363,13 @@ impl Debug for VirtIONicDeviceInner {
} }
} }
#[cast_to([sync] NetDevice)] #[cast_to([sync] Iface)]
#[cast_to([sync] Device)] #[cast_to([sync] Device)]
#[derive(Debug)]
pub struct VirtioInterface { pub struct VirtioInterface {
device_inner: VirtIONicDeviceInnerWrapper, device_inner: VirtIONicDeviceInnerWrapper,
iface_id: usize,
iface_name: String, iface_name: String,
iface: SpinLock<iface::Interface>, iface_common: super::IfaceCommon,
inner: SpinLock<InnerVirtIOInterface>, inner: SpinLock<InnerVirtIOInterface>,
locked_kobj_state: LockedKObjectState, locked_kobj_state: LockedKObjectState,
} }
@ -380,17 +381,6 @@ struct InnerVirtIOInterface {
netdevice_common: NetDeviceCommonData, netdevice_common: NetDeviceCommonData,
} }
impl core::fmt::Debug for VirtioInterface {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
f.debug_struct("VirtioInterface")
.field("iface_id", &self.iface_id)
.field("iface_name", &self.iface_name)
.field("inner", &self.inner)
.field("locked_kobj_state", &self.locked_kobj_state)
.finish()
}
}
impl VirtioInterface { impl VirtioInterface {
pub fn new(mut device_inner: VirtIONicDeviceInner) -> Arc<Self> { pub fn new(mut device_inner: VirtIONicDeviceInner) -> Arc<Self> {
let iface_id = generate_iface_id(); let iface_id = generate_iface_id();
@ -403,10 +393,9 @@ impl VirtioInterface {
let result = Arc::new(VirtioInterface { let result = Arc::new(VirtioInterface {
device_inner: VirtIONicDeviceInnerWrapper(UnsafeCell::new(device_inner)), device_inner: VirtIONicDeviceInnerWrapper(UnsafeCell::new(device_inner)),
iface_id,
locked_kobj_state: LockedKObjectState::default(), locked_kobj_state: LockedKObjectState::default(),
iface: SpinLock::new(iface),
iface_name: format!("eth{}", iface_id), iface_name: format!("eth{}", iface_id),
iface_common: super::IfaceCommon::new(iface_id, iface),
inner: SpinLock::new(InnerVirtIOInterface { inner: SpinLock::new(InnerVirtIOInterface {
kobj_common: KObjectCommonData::default(), kobj_common: KObjectCommonData::default(),
device_common: DeviceCommonData::default(), device_common: DeviceCommonData::default(),
@ -431,7 +420,7 @@ impl VirtioInterface {
impl Drop for VirtioInterface { impl Drop for VirtioInterface {
fn drop(&mut self) { fn drop(&mut self) {
// 从全局的网卡接口信息表中删除这个网卡的接口信息 // 从全局的网卡接口信息表中删除这个网卡的接口信息
NET_DEVICES.write_irqsave().remove(&self.iface_id); NET_DEVICES.write_irqsave().remove(&self.nic_id());
} }
} }
@ -624,57 +613,25 @@ pub fn virtio_net(
} }
} }
impl NetDevice for VirtioInterface { impl Iface for VirtioInterface {
fn common(&self) -> &super::IfaceCommon {
&self.iface_common
}
fn mac(&self) -> wire::EthernetAddress { fn mac(&self) -> wire::EthernetAddress {
let mac: [u8; 6] = self.device_inner.inner.lock().mac_address(); let mac: [u8; 6] = self.device_inner.inner.lock().mac_address();
return wire::EthernetAddress::from_bytes(&mac); return wire::EthernetAddress::from_bytes(&mac);
} }
#[inline]
fn nic_id(&self) -> usize {
return self.iface_id;
}
#[inline] #[inline]
fn iface_name(&self) -> String { fn iface_name(&self) -> String {
return self.iface_name.clone(); return self.iface_name.clone();
} }
fn update_ip_addrs(&self, ip_addrs: &[wire::IpCidr]) -> Result<(), SystemError> { fn poll(&self) {
if ip_addrs.len() != 1 { self.iface_common.poll(self.device_inner.force_get_mut())
return Err(SystemError::EINVAL);
} }
self.iface.lock().update_ip_addrs(|addrs| {
let dest = addrs.iter_mut().next();
if let Some(dest) = dest {
*dest = ip_addrs[0];
} else {
addrs
.push(ip_addrs[0])
.expect("Push wire::IpCidr failed: full");
}
});
return Ok(());
}
fn poll(&self, sockets: &mut iface::SocketSet) -> Result<(), SystemError> {
let timestamp: smoltcp::time::Instant = Instant::now().into();
let mut guard = self.iface.lock();
let poll_res = guard.poll(timestamp, self.device_inner.force_get_mut(), sockets);
// todo: notify!!!
// debug!("Virtio Interface poll:{poll_res}");
if poll_res {
return Ok(());
}
return Err(SystemError::EAGAIN_OR_EWOULDBLOCK);
}
#[inline(always)]
fn inner_iface(&self) -> &SpinLock<iface::Interface> {
return &self.iface;
}
// fn as_any_ref(&'static self) -> &'static dyn core::any::Any { // fn as_any_ref(&'static self) -> &'static dyn core::any::Any {
// return self; // return self;
// } // }
@ -839,7 +796,7 @@ impl VirtIODriver for VirtIONetDriver {
// 设置iface的父设备为virtio_net_device // 设置iface的父设备为virtio_net_device
iface.set_dev_parent(Some(Arc::downgrade(&virtio_net_device) as Weak<dyn Device>)); iface.set_dev_parent(Some(Arc::downgrade(&virtio_net_device) as Weak<dyn Device>));
// 在sysfs中注册iface // 在sysfs中注册iface
register_netdevice(iface.clone() as Arc<dyn NetDevice>)?; register_netdevice(iface.clone() as Arc<dyn Iface>)?;
// 将网卡的接口信息注册到全局的网卡接口信息表中 // 将网卡的接口信息注册到全局的网卡接口信息表中
NET_DEVICES NET_DEVICES

View File

@ -23,7 +23,7 @@ use crate::{
mm::{page::Page, MemoryManagementArch}, mm::{page::Page, MemoryManagementArch},
net::{ net::{
event_poll::{EPollItem, EPollPrivateData, EventPoll}, event_poll::{EPollItem, EPollPrivateData, EventPoll},
socket::SocketInode, socket::Inode as SocketInode,
}, },
process::{cred::Cred, ProcessManager}, process::{cred::Cred, ProcessManager},
}; };
@ -570,9 +570,10 @@ impl File {
match self.file_type { match self.file_type {
FileType::Socket => { FileType::Socket => {
let inode = self.inode.downcast_ref::<SocketInode>().unwrap(); let inode = self.inode.downcast_ref::<SocketInode>().unwrap();
let mut socket = inode.inner(); // let mut socket = inode.inner();
return socket.add_epoll(epitem); inode.epoll_items().add(epitem);
return Ok(());
} }
FileType::Pipe => { FileType::Pipe => {
let inode = self.inode.downcast_ref::<LockedPipeInode>().unwrap(); let inode = self.inode.downcast_ref::<LockedPipeInode>().unwrap();
@ -592,12 +593,12 @@ impl File {
/// ## 删除一个绑定的epoll /// ## 删除一个绑定的epoll
pub fn remove_epoll(&self, epoll: &Weak<SpinLock<EventPoll>>) -> Result<(), SystemError> { pub fn remove_epoll(&self, epoll: &Weak<SpinLock<EventPoll>>) -> Result<(), SystemError> {
match self.file_type { match self.file_type {
FileType::Socket => { FileType::Socket => self
let inode = self.inode.downcast_ref::<SocketInode>().unwrap(); .inode
let mut socket = inode.inner(); .downcast_ref::<SocketInode>()
.unwrap()
socket.remove_epoll(epoll) .epoll_items()
} .remove(epoll),
FileType::Pipe => { FileType::Pipe => {
let inode = self.inode.downcast_ref::<LockedPipeInode>().unwrap(); let inode = self.inode.downcast_ref::<LockedPipeInode>().unwrap();
inode.inner().lock().remove_epoll(epoll) inode.inner().lock().remove_epoll(epoll)

View File

@ -1,5 +1,5 @@
use alloc::sync::Arc; use alloc::sync::Arc;
use log::warn; use log::{debug, warn};
use system_error::SystemError; use system_error::SystemError;
use super::{ use super::{
@ -81,7 +81,7 @@ fn do_sys_openat2(
how: OpenHow, how: OpenHow,
follow_symlink: bool, follow_symlink: bool,
) -> Result<usize, SystemError> { ) -> Result<usize, SystemError> {
// debug!("open path: {}, how: {:?}", path, how); //debug!("open path: {}, how: {:?}", path, how);
let path = path.trim(); let path = path.trim();
let (inode_begin, path) = user_path_at(&ProcessManager::current_pcb(), dirfd, path)?; let (inode_begin, path) = user_path_at(&ProcessManager::current_pcb(), dirfd, path)?;

View File

@ -2,7 +2,7 @@ use core::ffi::c_void;
use core::mem::size_of; use core::mem::size_of;
use alloc::{string::String, sync::Arc, vec::Vec}; use alloc::{string::String, sync::Arc, vec::Vec};
use log::warn; use log::{debug, warn};
use system_error::SystemError; use system_error::SystemError;
use crate::producefs; use crate::producefs;

View File

@ -8,7 +8,10 @@ use system_error::SystemError;
use crate::{ use crate::{
arch::{interrupt::TrapFrame, process::arch_switch_to_user}, arch::{interrupt::TrapFrame, process::arch_switch_to_user},
driver::{net::e1000e::e1000e::e1000e_init, virtio::virtio::virtio_probe}, driver::{
net::{e1000e::e1000e::e1000e_init, loopback::loopback_init},
virtio::virtio::virtio_probe,
},
filesystem::vfs::core::mount_root_fs, filesystem::vfs::core::mount_root_fs,
net::net_core::net_init, net::net_core::net_init,
process::{kthread::KernelThreadMechanism, stdio::stdio_init, ProcessFlags, ProcessManager}, process::{kthread::KernelThreadMechanism, stdio::stdio_init, ProcessFlags, ProcessManager},
@ -40,6 +43,7 @@ fn kernel_init() -> Result<(), SystemError> {
net_init().unwrap_or_else(|err| { net_init().unwrap_or_else(|err| {
error!("Failed to initialize network: {:?}", err); error!("Failed to initialize network: {:?}", err);
}); });
loopback_init()?;
debug!("initial kernel thread done."); debug!("initial kernel thread done.");

View File

@ -1,3 +1,7 @@
//! # 网络模块
//! 注意net模块下为了方便导入模块细分且共用部分模块直接使用
//! `pub use`导出,导入时也常见`use crate::net::socket::*`的写法,
//! 敬请注意。
use core::{ use core::{
fmt::{self, Debug}, fmt::{self, Debug},
sync::atomic::AtomicUsize, sync::atomic::AtomicUsize,
@ -5,20 +9,18 @@ use core::{
use alloc::{collections::BTreeMap, sync::Arc}; use alloc::{collections::BTreeMap, sync::Arc};
use crate::{driver::net::NetDevice, libs::rwlock::RwLock}; use crate::{driver::net::Iface, libs::rwlock::RwLock};
use smoltcp::wire::IpEndpoint;
use self::socket::SocketInode;
pub mod event_poll; pub mod event_poll;
pub mod net_core; pub mod net_core;
pub mod socket; pub mod socket;
pub mod syscall; pub mod syscall;
pub mod syscall_util;
lazy_static! { lazy_static! {
/// # 所有网络接口的列表 /// # 所有网络接口的列表
/// 这个列表在中断上下文会使用到因此需要irqsave /// 这个列表在中断上下文会使用到因此需要irqsave
pub static ref NET_DEVICES: RwLock<BTreeMap<usize, Arc<dyn NetDevice>>> = RwLock::new(BTreeMap::new()); pub static ref NET_DEVICES: RwLock<BTreeMap<usize, Arc<dyn Iface>>> = RwLock::new(BTreeMap::new());
} }
/// 生成网络接口的id (全局自增) /// 生成网络接口的id (全局自增)
@ -26,120 +28,3 @@ pub fn generate_iface_id() -> usize {
static IFACE_ID: AtomicUsize = AtomicUsize::new(0); static IFACE_ID: AtomicUsize = AtomicUsize::new(0);
return IFACE_ID.fetch_add(1, core::sync::atomic::Ordering::SeqCst); return IFACE_ID.fetch_add(1, core::sync::atomic::Ordering::SeqCst);
} }
bitflags! {
/// @brief 用于指定socket的关闭类型
/// 参考https://code.dragonos.org.cn/xref/linux-6.1.9/include/net/sock.h?fi=SHUTDOWN_MASK#1573
pub struct ShutdownType: u8 {
const RCV_SHUTDOWN = 1;
const SEND_SHUTDOWN = 2;
const SHUTDOWN_MASK = 3;
}
}
#[derive(Debug, Clone)]
pub enum Endpoint {
/// 链路层端点
LinkLayer(LinkLayerEndpoint),
/// 网络层端点
Ip(Option<IpEndpoint>),
/// inode端点
Inode(Option<Arc<SocketInode>>),
// todo: 增加NetLink机制后增加NetLink端点
}
/// @brief 链路层端点
#[derive(Debug, Clone)]
pub struct LinkLayerEndpoint {
/// 网卡的接口号
pub interface: usize,
}
impl LinkLayerEndpoint {
/// @brief 创建一个链路层端点
///
/// @param interface 网卡的接口号
///
/// @return 返回创建的链路层端点
pub fn new(interface: usize) -> Self {
Self { interface }
}
}
/// IP datagram encapsulated protocol.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
#[repr(u8)]
pub enum Protocol {
HopByHop = 0x00,
Icmp = 0x01,
Igmp = 0x02,
Tcp = 0x06,
Udp = 0x11,
Ipv6Route = 0x2b,
Ipv6Frag = 0x2c,
Icmpv6 = 0x3a,
Ipv6NoNxt = 0x3b,
Ipv6Opts = 0x3c,
Unknown(u8),
}
impl fmt::Display for Protocol {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
Protocol::HopByHop => write!(f, "Hop-by-Hop"),
Protocol::Icmp => write!(f, "ICMP"),
Protocol::Igmp => write!(f, "IGMP"),
Protocol::Tcp => write!(f, "TCP"),
Protocol::Udp => write!(f, "UDP"),
Protocol::Ipv6Route => write!(f, "IPv6-Route"),
Protocol::Ipv6Frag => write!(f, "IPv6-Frag"),
Protocol::Icmpv6 => write!(f, "ICMPv6"),
Protocol::Ipv6NoNxt => write!(f, "IPv6-NoNxt"),
Protocol::Ipv6Opts => write!(f, "IPv6-Opts"),
Protocol::Unknown(id) => write!(f, "0x{id:02x}"),
}
}
}
impl From<smoltcp::wire::IpProtocol> for Protocol {
fn from(value: smoltcp::wire::IpProtocol) -> Self {
let x: u8 = value.into();
Protocol::from(x)
}
}
impl From<u8> for Protocol {
fn from(value: u8) -> Self {
match value {
0x00 => Protocol::HopByHop,
0x01 => Protocol::Icmp,
0x02 => Protocol::Igmp,
0x06 => Protocol::Tcp,
0x11 => Protocol::Udp,
0x2b => Protocol::Ipv6Route,
0x2c => Protocol::Ipv6Frag,
0x3a => Protocol::Icmpv6,
0x3b => Protocol::Ipv6NoNxt,
0x3c => Protocol::Ipv6Opts,
_ => Protocol::Unknown(value),
}
}
}
impl From<Protocol> for u8 {
fn from(value: Protocol) -> Self {
match value {
Protocol::HopByHop => 0x00,
Protocol::Icmp => 0x01,
Protocol::Igmp => 0x02,
Protocol::Tcp => 0x06,
Protocol::Udp => 0x11,
Protocol::Ipv6Route => 0x2b,
Protocol::Ipv6Frag => 0x2c,
Protocol::Icmpv6 => 0x3a,
Protocol::Ipv6NoNxt => 0x3b,
Protocol::Ipv6Opts => 0x3c,
Protocol::Unknown(id) => id,
}
}
}

View File

@ -4,17 +4,12 @@ use smoltcp::{socket::dhcpv4, wire};
use system_error::SystemError; use system_error::SystemError;
use crate::{ use crate::{
driver::net::{NetDevice, Operstate}, driver::net::{Iface, Operstate},
libs::rwlock::RwLockReadGuard, libs::rwlock::RwLockReadGuard,
net::{socket::SocketPollMethod, NET_DEVICES}, net::NET_DEVICES,
time::timer::{next_n_ms_timer_jiffies, Timer, TimerFunction}, time::timer::{next_n_ms_timer_jiffies, Timer, TimerFunction},
}; };
use super::{
event_poll::{EPollEventType, EventPoll},
socket::{handle::GlobalSocketHandle, inet::TcpSocket, HANDLE_MAP, SOCKET_SET},
};
/// The network poll function, which will be called by timer. /// The network poll function, which will be called by timer.
/// ///
/// The main purpose of this function is to poll all network interfaces. /// The main purpose of this function is to poll all network interfaces.
@ -24,7 +19,7 @@ struct NetWorkPollFunc;
impl TimerFunction for NetWorkPollFunc { impl TimerFunction for NetWorkPollFunc {
fn run(&mut self) -> Result<(), SystemError> { fn run(&mut self) -> Result<(), SystemError> {
poll_ifaces_try_lock(10).ok(); poll_ifaces();
let next_time = next_n_ms_timer_jiffies(10); let next_time = next_n_ms_timer_jiffies(10);
let timer = Timer::new(Box::new(NetWorkPollFunc), next_time); let timer = Timer::new(Box::new(NetWorkPollFunc), next_time);
timer.activate(); timer.activate();
@ -43,10 +38,10 @@ pub fn net_init() -> Result<(), SystemError> {
fn dhcp_query() -> Result<(), SystemError> { fn dhcp_query() -> Result<(), SystemError> {
let binding = NET_DEVICES.write_irqsave(); let binding = NET_DEVICES.write_irqsave();
log::debug!("binding: {:?}", *binding);
//由于现在os未实现在用户态为网卡动态分配内存而lo网卡的id最先分配且ip固定不能被分配 //由于现在os未实现在用户态为网卡动态分配内存而lo网卡的id最先分配且ip固定不能被分配
//所以特判取用id为1的网卡也就是virto_net //所以特判取用id为0的网卡也就是virto_net
let net_face = binding.get(&1).ok_or(SystemError::ENODEV)?.clone(); let net_face = binding.get(&0).ok_or(SystemError::ENODEV)?.clone();
drop(binding); drop(binding);
@ -59,13 +54,16 @@ fn dhcp_query() -> Result<(), SystemError> {
// IMPORTANT: This should be removed in production. // IMPORTANT: This should be removed in production.
dhcp_socket.set_max_lease_duration(Some(smoltcp::time::Duration::from_secs(10))); dhcp_socket.set_max_lease_duration(Some(smoltcp::time::Duration::from_secs(10)));
let dhcp_handle = SOCKET_SET.lock_irqsave().add(dhcp_socket); let sockets = || net_face.sockets().lock_irqsave();
// let dhcp_handle = SOCKET_SET.lock_irqsave().add(dhcp_socket);
let dhcp_handle = sockets().add(dhcp_socket);
const DHCP_TRY_ROUND: u8 = 10; const DHCP_TRY_ROUND: u8 = 10;
for i in 0..DHCP_TRY_ROUND { for i in 0..DHCP_TRY_ROUND {
debug!("DHCP try round: {}", i); log::debug!("DHCP try round: {}", i);
net_face.poll(&mut SOCKET_SET.lock_irqsave()).ok(); net_face.poll();
let mut binding = SOCKET_SET.lock_irqsave(); let mut binding = sockets();
let event = binding.get_mut::<dhcpv4::Socket>(dhcp_handle).poll(); let event = binding.get_mut::<dhcpv4::Socket>(dhcp_handle).poll();
match event { match event {
@ -81,13 +79,26 @@ fn dhcp_query() -> Result<(), SystemError> {
.ok(); .ok();
if let Some(router) = config.router { if let Some(router) = config.router {
net_face let mut smol_iface = net_face.smol_iface().lock();
.inner_iface() smol_iface.routes_mut().update(|table| {
.lock() let _ = table.push(smoltcp::iface::Route {
cidr: smoltcp::wire::IpCidr::Ipv4(smoltcp::wire::Ipv4Cidr::new(
smoltcp::wire::Ipv4Address::new(127, 0, 0, 0),
8,
)),
via_router: smoltcp::wire::IpAddress::v4(127, 0, 0, 1),
preferred_until: None,
expires_at: None,
});
});
if smol_iface
.routes_mut() .routes_mut()
.add_default_ipv4_route(router) .add_default_ipv4_route(router)
.unwrap(); .is_err()
let cidr = net_face.inner_iface().lock().ip_addrs().first().cloned(); {
log::warn!("Route table full");
}
let cidr = smol_iface.ip_addrs().first().cloned();
if let Some(cidr) = cidr { if let Some(cidr) = cidr {
// 这里先在这里将网卡设置为up后面等netlink实现了再修改 // 这里先在这里将网卡设置为up后面等netlink实现了再修改
net_face.set_operstate(Operstate::IF_OPER_UP); net_face.set_operstate(Operstate::IF_OPER_UP);
@ -96,7 +107,7 @@ fn dhcp_query() -> Result<(), SystemError> {
} }
} else { } else {
net_face net_face
.inner_iface() .smol_iface()
.lock() .lock()
.routes_mut() .routes_mut()
.remove_default_ipv4_route(); .remove_default_ipv4_route();
@ -112,7 +123,7 @@ fn dhcp_query() -> Result<(), SystemError> {
))]) ))])
.ok(); .ok();
net_face net_face
.inner_iface() .smol_iface()
.lock() .lock()
.routes_mut() .routes_mut()
.remove_default_ipv4_route(); .remove_default_ipv4_route();
@ -124,123 +135,109 @@ fn dhcp_query() -> Result<(), SystemError> {
} }
pub fn poll_ifaces() { pub fn poll_ifaces() {
let guard: RwLockReadGuard<BTreeMap<usize, Arc<dyn NetDevice>>> = NET_DEVICES.read_irqsave(); // log::debug!("poll_ifaces");
let guard: RwLockReadGuard<BTreeMap<usize, Arc<dyn Iface>>> = NET_DEVICES.read_irqsave();
if guard.len() == 0 { if guard.len() == 0 {
warn!("poll_ifaces: No net driver found!"); warn!("poll_ifaces: No net driver found!");
return; return;
} }
let mut sockets = SOCKET_SET.lock_irqsave();
for (_, iface) in guard.iter() { for (_, iface) in guard.iter() {
iface.poll(&mut sockets).ok(); iface.poll();
} }
let _ = send_event(&sockets);
} }
/// 对ifaces进行轮询最多对SOCKET_SET尝试times次加锁。 // /// 对ifaces进行轮询最多对SOCKET_SET尝试times次加锁。
/// // ///
/// @return 轮询成功返回Ok(()) // /// @return 轮询成功返回Ok(())
/// @return 加锁超时返回SystemError::EAGAIN_OR_EWOULDBLOCK // /// @return 加锁超时返回SystemError::EAGAIN_OR_EWOULDBLOCK
/// @return 没有网卡返回SystemError::ENODEV // /// @return 没有网卡返回SystemError::ENODEV
pub fn poll_ifaces_try_lock(times: u16) -> Result<(), SystemError> { // pub fn poll_ifaces_try_lock(times: u16) -> Result<(), SystemError> {
let mut i = 0; // let mut i = 0;
while i < times { // while i < times {
let guard: RwLockReadGuard<BTreeMap<usize, Arc<dyn NetDevice>>> = // let guard: RwLockReadGuard<BTreeMap<usize, Arc<dyn Iface>>> =
NET_DEVICES.read_irqsave(); // NET_DEVICES.read_irqsave();
if guard.len() == 0 { // if guard.len() == 0 {
warn!("poll_ifaces: No net driver found!"); // warn!("poll_ifaces: No net driver found!");
// 没有网卡,返回错误 // // 没有网卡,返回错误
return Err(SystemError::ENODEV); // return Err(SystemError::ENODEV);
} // }
let sockets = SOCKET_SET.try_lock_irqsave(); // for (_, iface) in guard.iter() {
// 加锁失败,继续尝试 // iface.poll();
if sockets.is_err() { // }
i += 1; // return Ok(());
continue; // }
} // // 尝试次数用完,返回错误
// return Err(SystemError::EAGAIN_OR_EWOULDBLOCK);
// }
let mut sockets = sockets.unwrap(); // /// 对ifaces进行轮询最多对SOCKET_SET尝试一次加锁。
for (_, iface) in guard.iter() { // ///
iface.poll(&mut sockets).ok(); // /// @return 轮询成功返回Ok(())
} // /// @return 加锁超时返回SystemError::EAGAIN_OR_EWOULDBLOCK
send_event(&sockets)?; // /// @return 没有网卡返回SystemError::ENODEV
return Ok(()); // pub fn poll_ifaces_try_lock_onetime() -> Result<(), SystemError> {
} // let guard: RwLockReadGuard<BTreeMap<usize, Arc<dyn Iface>>> = NET_DEVICES.read_irqsave();
// 尝试次数用完,返回错误 // if guard.len() == 0 {
return Err(SystemError::EAGAIN_OR_EWOULDBLOCK); // warn!("poll_ifaces: No net driver found!");
} // // 没有网卡,返回错误
// return Err(SystemError::ENODEV);
// }
// for (_, iface) in guard.iter() {
// let _ = iface.poll();
// }
// send_event()?;
// return Ok(());
// }
/// 对ifaces进行轮询最多对SOCKET_SET尝试一次加锁。 // /// ### 处理轮询后的事件
/// // fn send_event() -> Result<(), SystemError> {
/// @return 轮询成功返回Ok(()) // for (handle, socket_type) in .lock().iter() {
/// @return 加锁超时返回SystemError::EAGAIN_OR_EWOULDBLOCK
/// @return 没有网卡返回SystemError::ENODEV
pub fn poll_ifaces_try_lock_onetime() -> Result<(), SystemError> {
let guard: RwLockReadGuard<BTreeMap<usize, Arc<dyn NetDevice>>> = NET_DEVICES.read_irqsave();
if guard.len() == 0 {
warn!("poll_ifaces: No net driver found!");
// 没有网卡,返回错误
return Err(SystemError::ENODEV);
}
let mut sockets = SOCKET_SET.try_lock_irqsave()?;
for (_, iface) in guard.iter() {
iface.poll(&mut sockets).ok();
}
send_event(&sockets)?;
return Ok(());
}
/// ### 处理轮询后的事件 // let global_handle = GlobalSocketHandle::new_smoltcp_handle(handle);
fn send_event(sockets: &smoltcp::iface::SocketSet) -> Result<(), SystemError> {
for (handle, socket_type) in sockets.iter() {
let handle_guard = HANDLE_MAP.read_irqsave();
let global_handle = GlobalSocketHandle::new_smoltcp_handle(handle);
let item: Option<&super::socket::SocketHandleItem> = handle_guard.get(&global_handle);
if item.is_none() {
continue;
}
let handle_item = item.unwrap(); // let handle_guard = HANDLE_MAP.read_irqsave();
let posix_item = handle_item.posix_item(); // let item: Option<&super::socket::SocketHandleItem> = handle_guard.get(&global_handle);
if posix_item.is_none() { // if item.is_none() {
continue; // continue;
} // }
let posix_item = posix_item.unwrap();
// 获取socket上的事件 // let handle_item = item.unwrap();
let mut events = SocketPollMethod::poll(socket_type, handle_item).bits() as u64; // let posix_item = handle_item.posix_item();
// if posix_item.is_none() {
// continue;
// }
// let posix_item = posix_item.unwrap();
// 分发到相应类型socket处理 // // 获取socket上的事件
match socket_type { // let mut events = SocketPollMethod::poll(socket_type, handle_item).bits() as u64;
smoltcp::socket::Socket::Raw(_) | smoltcp::socket::Socket::Udp(_) => {
posix_item.wakeup_any(events);
}
smoltcp::socket::Socket::Icmp(_) => unimplemented!("Icmp socket hasn't unimplemented"),
smoltcp::socket::Socket::Tcp(inner_socket) => {
if inner_socket.is_active() {
events |= TcpSocket::CAN_ACCPET;
}
if inner_socket.state() == smoltcp::socket::tcp::State::Established {
events |= TcpSocket::CAN_CONNECT;
}
if inner_socket.state() == smoltcp::socket::tcp::State::CloseWait {
events |= EPollEventType::EPOLLHUP.bits() as u64;
}
posix_item.wakeup_any(events); // // 分发到相应类型socket处理
} // match socket_type {
smoltcp::socket::Socket::Dhcpv4(_) => {} // smoltcp::socket::Socket::Raw(_) | smoltcp::socket::Socket::Udp(_) => {
smoltcp::socket::Socket::Dns(_) => unimplemented!("Dns socket hasn't unimplemented"), // posix_item.wakeup_any(events);
} // }
EventPoll::wakeup_epoll( // smoltcp::socket::Socket::Icmp(_) => unimplemented!("Icmp socket hasn't unimplemented"),
&posix_item.epitems, // smoltcp::socket::Socket::Tcp(inner_socket) => {
EPollEventType::from_bits_truncate(events as u32), // if inner_socket.is_active() {
)?; // events |= TcpSocket::CAN_ACCPET;
drop(handle_guard); // }
// crate::debug!( // if inner_socket.state() == smoltcp::socket::tcp::State::Established {
// "{} send_event {:?}", // events |= TcpSocket::CAN_CONNECT;
// handle, // }
// EPollEventType::from_bits_truncate(events as u32) // if inner_socket.state() == smoltcp::socket::tcp::State::CloseWait {
// ); // events |= EPollEventType::EPOLLHUP.bits() as u64;
} // }
Ok(())
} // posix_item.wakeup_any(events);
// }
// smoltcp::socket::Socket::Dhcpv4(_) => {}
// smoltcp::socket::Socket::Dns(_) => unimplemented!("Dns socket hasn't unimplemented"),
// }
// EventPoll::wakeup_epoll(
// &posix_item.epitems,
// EPollEventType::from_bits_truncate(events as u32),
// )?;
// drop(handle_guard);
// }
// Ok(())
// }

View File

@ -0,0 +1,143 @@
#![allow(unused_variables)]
use crate::net::socket::*;
use crate::net::syscall_util::MsgHdr;
use alloc::sync::Arc;
use core::any::Any;
use core::fmt::Debug;
use system_error::SystemError::{self, *};
/// # `Socket` methods
/// ## Reference
/// - [Posix standard](https://pubs.opengroup.org/onlinepubs/9699919799/)
pub trait Socket: Sync + Send + Debug + Any {
/// # `wait_queue`
/// 获取socket的wait queue
fn wait_queue(&self) -> &WaitQueue;
/// # `socket_poll`
/// 获取socket的事件。
fn poll(&self) -> usize;
fn send_buffer_size(&self) -> usize;
fn recv_buffer_size(&self) -> usize;
/// # `accept`
/// 接受连接仅用于listening stream socket
/// ## Block
/// 如果没有连接到来,会阻塞
fn accept(&self) -> Result<(Arc<Inode>, Endpoint), SystemError> {
Err(ENOSYS)
}
/// # `bind`
/// 对应于POSIX的bind函数用于绑定到本机指定的端点
fn bind(&self, endpoint: Endpoint) -> Result<(), SystemError> {
Err(ENOSYS)
}
/// # `close`
/// 关闭socket
fn close(&self) -> Result<(), SystemError> {
Ok(())
}
/// # `connect`
/// 对应于POSIX的connect函数用于连接到指定的远程服务器端点
fn connect(&self, endpoint: Endpoint) -> Result<(), SystemError> {
Err(ENOSYS)
}
// fnctl
// freeaddrinfo
// getaddrinfo
// getnameinfo
/// # `get_peer_name`
/// 获取对端的地址
fn get_peer_name(&self) -> Result<Endpoint, SystemError> {
Err(ENOSYS)
}
/// # `get_name`
/// 获取socket的地址
fn get_name(&self) -> Result<Endpoint, SystemError> {
Err(ENOSYS)
}
/// # `get_option`
/// 对应于 Posix `getsockopt` 获取socket选项
fn get_option(
&self,
level: OptionsLevel,
name: usize,
value: &mut [u8],
) -> Result<usize, SystemError> {
log::warn!("getsockopt is not implemented");
Ok(0)
}
/// # `listen`
/// 监听socket仅用于stream socket
fn listen(&self, backlog: usize) -> Result<(), SystemError> {
Err(ENOSYS)
}
// poll
// pselect
/// # `read`
fn read(&self, buffer: &mut [u8]) -> Result<usize, SystemError> {
self.recv(buffer, MessageFlag::empty())
}
/// # `recv`
/// 接收数据,`read` = `recv` with flags = 0
fn recv(&self, buffer: &mut [u8], flags: MessageFlag) -> Result<usize, SystemError> {
Err(ENOSYS)
}
/// # `recv_from`
fn recv_from(
&self,
buffer: &mut [u8],
flags: MessageFlag,
address: Option<Endpoint>,
) -> Result<(usize, Endpoint), SystemError> {
Err(ENOSYS)
}
/// # `recv_msg`
fn recv_msg(&self, msg: &mut MsgHdr, flags: MessageFlag) -> Result<usize, SystemError> {
Err(ENOSYS)
}
// select
/// # `send`
fn send(&self, buffer: &[u8], flags: MessageFlag) -> Result<usize, SystemError> {
Err(ENOSYS)
}
/// # `send_msg`
fn send_msg(&self, msg: &MsgHdr, flags: MessageFlag) -> Result<usize, SystemError> {
Err(ENOSYS)
}
/// # `send_to`
fn send_to(
&self,
buffer: &[u8],
flags: MessageFlag,
address: Endpoint,
) -> Result<usize, SystemError> {
Err(ENOSYS)
}
/// # `set_option`
/// Posix `setsockopt` 设置socket选项
/// ## Parameters
/// - level 选项的层次
/// - name 选项的名称
/// - value 选项的值
/// ## Reference
/// https://code.dragonos.org.cn/s?refs=sk_setsockopt&project=linux-6.6.21
fn set_option(&self, level: OptionsLevel, name: usize, val: &[u8]) -> Result<(), SystemError> {
log::warn!("setsockopt is not implemented");
Ok(())
}
/// # `shutdown`
fn shutdown(&self, how: ShutdownTemp) -> Result<(), SystemError> {
Err(ENOSYS)
}
// sockatmark
// socket
// socketpair
/// # `write`
fn write(&self, buffer: &[u8]) -> Result<usize, SystemError> {
self.send(buffer, MessageFlag::empty())
}
// fn write_buffer(&self, _buf: &[u8]) -> Result<usize, SystemError> {
// todo!()
// }
}

View File

@ -0,0 +1,91 @@
use alloc::vec::Vec;
use alloc::{string::String, sync::Arc};
use log::debug;
use system_error::SystemError;
use crate::libs::spinlock::SpinLock;
#[derive(Debug)]
pub struct Buffer {
metadata: Metadata,
read_buffer: SpinLock<Vec<u8>>,
write_buffer: SpinLock<Vec<u8>>,
}
impl Buffer {
pub fn new() -> Arc<Self> {
Arc::new(Self {
metadata: Metadata::default(),
read_buffer: SpinLock::new(Vec::new()),
write_buffer: SpinLock::new(Vec::new()),
})
}
pub fn is_read_buf_empty(&self) -> bool {
return self.read_buffer.lock().is_empty();
}
pub fn is_read_buf_full(&self) -> bool {
return self.metadata.buf_size - self.read_buffer.lock().len() == 0;
}
pub fn is_write_buf_empty(&self) -> bool {
return self.write_buffer.lock().is_empty();
}
pub fn is_write_buf_full(&self) -> bool {
return self.write_buffer.lock().len() >= self.metadata.buf_size;
}
pub fn read_read_buffer(&self, buf: &mut [u8]) -> Result<usize, SystemError> {
let mut read_buffer = self.read_buffer.lock_irqsave();
let len = core::cmp::min(buf.len(), read_buffer.len());
buf[..len].copy_from_slice(&read_buffer[..len]);
let _ = read_buffer.split_off(len);
log::debug!("recv buf {}", String::from_utf8_lossy(buf));
return Ok(len);
}
pub fn write_read_buffer(&self, buf: &[u8]) -> Result<usize, SystemError> {
let mut buffer = self.read_buffer.lock_irqsave();
log::debug!("send buf {}", String::from_utf8_lossy(buf));
let len = buf.len();
if self.metadata.buf_size - buffer.len() < len {
return Err(SystemError::ENOBUFS);
}
buffer.extend_from_slice(buf);
Ok(len)
}
pub fn write_write_buffer(&self, buf: &[u8]) -> Result<usize, SystemError> {
let mut buffer = self.write_buffer.lock_irqsave();
let len = buf.len();
if self.metadata.buf_size - buffer.len() < len {
return Err(SystemError::ENOBUFS);
}
buffer.extend_from_slice(buf);
Ok(len)
}
}
#[derive(Debug)]
pub struct Metadata {
/// 默认的元数据缓冲区大小
metadata_buf_size: usize,
/// 默认的缓冲区大小
buf_size: usize,
}
impl Default for Metadata {
fn default() -> Self {
Self {
metadata_buf_size: 1024,
buf_size: 64 * 1024,
}
}
}

View File

@ -0,0 +1,64 @@
use alloc::{
collections::LinkedList,
sync::{Arc, Weak},
vec::Vec,
};
use system_error::SystemError;
use crate::{
libs::{spinlock::SpinLock, wait_queue::EventWaitQueue},
net::event_poll::{EPollEventType, EPollItem, EventPoll},
process::ProcessManager,
sched::{schedule, SchedMode},
};
#[derive(Debug, Clone)]
pub struct EPollItems {
items: Arc<SpinLock<LinkedList<Arc<EPollItem>>>>,
}
impl Default for EPollItems {
fn default() -> Self {
Self {
items: Arc::new(SpinLock::new(LinkedList::new())),
}
}
}
impl EPollItems {
pub fn add(&self, item: Arc<EPollItem>) {
self.items.lock_irqsave().push_back(item);
}
pub fn remove(&self, item: &Weak<SpinLock<EventPoll>>) -> Result<(), SystemError> {
let to_remove = self
.items
.lock_irqsave()
.extract_if(|x| x.epoll().ptr_eq(item))
.collect::<Vec<_>>();
let result = if !to_remove.is_empty() {
Ok(())
} else {
Err(SystemError::ENOENT)
};
drop(to_remove);
return result;
}
pub fn clear(&self) -> Result<(), SystemError> {
let mut guard = self.items.lock_irqsave();
let mut result = Ok(());
guard.iter().for_each(|item| {
if let Some(epoll) = item.epoll().upgrade() {
let _ =
EventPoll::ep_remove(&mut epoll.lock_irqsave(), item.fd(), None).map_err(|e| {
result = Err(e);
});
}
});
guard.clear();
return result;
}
}

View File

@ -0,0 +1,20 @@
// pub mod poll_unit;
mod epoll_items;
pub mod shutdown;
pub use epoll_items::EPollItems;
#[allow(dead_code)]
pub use shutdown::Shutdown;
// /// @brief 在trait Socket的metadata函数中返回该结构体供外部使用
// #[derive(Debug, Clone)]
// pub struct Metadata {
// /// 接收缓冲区的大小
// pub rx_buf_size: usize,
// /// 发送缓冲区的大小
// pub tx_buf_size: usize,
// /// 元数据的缓冲区的大小
// pub metadata_buf_size: usize,
// /// socket的选项
// pub options: SocketOptions,
// }

View File

@ -0,0 +1,72 @@
use alloc::{
collections::LinkedList,
sync::{Arc, Weak},
vec::Vec,
};
use system_error::SystemError;
use crate::{
libs::{spinlock::SpinLock, wait_queue::EventWaitQueue},
net::event_poll::{EPollEventType, EPollItem, EventPoll},
process::ProcessManager,
sched::{schedule, SchedMode},
};
#[derive(Debug, Clone)]
pub struct WaitQueue {
/// socket的waitqueue
wait_queue: Arc<EventWaitQueue>,
}
impl Default for WaitQueue {
fn default() -> Self {
Self {
wait_queue: Default::default(),
}
}
}
impl WaitQueue {
pub fn new(wait_queue: EventWaitQueue) -> Self {
Self {
wait_queue: Arc::new(wait_queue),
}
}
/// # `wakeup_any`
/// 唤醒该队列上等待events的进程
/// ## 参数
/// - events: 发生的事件
/// 需要注意的是只要触发了events中的任意一件事件进程都会被唤醒
pub fn wakeup_any(&self, events: EPollEventType) {
self.wait_queue.wakeup_any(events.bits() as u64);
}
/// # `wait_for`
/// 等待events事件发生
pub fn wait_for(&self, events: EPollEventType) {
unsafe {
ProcessManager::preempt_disable();
self.wait_queue.sleep_without_schedule(events.bits() as u64);
ProcessManager::preempt_enable();
}
schedule(SchedMode::SM_NONE);
}
/// # `busy_wait`
/// 轮询一个会返回EPAGAIN_OR_EWOULDBLOCK的函数
pub fn busy_wait<F, R>(&self, events: EPollEventType, mut f: F) -> Result<R, SystemError>
where
F: FnMut() -> Result<R, SystemError>,
{
loop {
match f() {
Ok(r) => return Ok(r),
Err(SystemError::EAGAIN_OR_EWOULDBLOCK) => {
self.wait_for(events);
}
Err(e) => return Err(e),
}
}
}
}

View File

@ -0,0 +1,118 @@
use core::sync::atomic::AtomicU8;
bitflags! {
/// @brief 用于指定socket的关闭类型
/// 参考https://code.dragonos.org.cn/xref/linux-6.1.9/include/net/sock.h?fi=SHUTDOWN_MASK#1573
pub struct ShutdownBit: u8 {
const SHUT_RD = 0;
const SHUT_WR = 1;
const SHUT_RDWR = 2;
}
}
const RCV_SHUTDOWN: u8 = 0x01;
const SEND_SHUTDOWN: u8 = 0x02;
const SHUTDOWN_MASK: u8 = 0x03;
#[derive(Debug, Default)]
pub struct Shutdown {
bit: AtomicU8,
}
impl From<ShutdownBit> for Shutdown {
fn from(shutdown_bit: ShutdownBit) -> Self {
match shutdown_bit {
ShutdownBit::SHUT_RD => Shutdown {
bit: AtomicU8::new(RCV_SHUTDOWN),
},
ShutdownBit::SHUT_WR => Shutdown {
bit: AtomicU8::new(SEND_SHUTDOWN),
},
ShutdownBit::SHUT_RDWR => Shutdown {
bit: AtomicU8::new(SHUTDOWN_MASK),
},
_ => Shutdown::default(),
}
}
}
impl Shutdown {
pub fn new() -> Self {
Self {
bit: AtomicU8::new(0),
}
}
pub fn recv_shutdown(&self) {
self.bit
.fetch_or(RCV_SHUTDOWN, core::sync::atomic::Ordering::SeqCst);
}
pub fn send_shutdown(&self) {
self.bit
.fetch_or(SEND_SHUTDOWN, core::sync::atomic::Ordering::SeqCst);
}
// pub fn is_recv_shutdown(&self) -> bool {
// self.bit.load(core::sync::atomic::Ordering::SeqCst) & RCV_SHUTDOWN != 0
// }
// pub fn is_send_shutdown(&self) -> bool {
// self.bit.load(core::sync::atomic::Ordering::SeqCst) & SEND_SHUTDOWN != 0
// }
// pub fn is_both_shutdown(&self) -> bool {
// self.bit.load(core::sync::atomic::Ordering::SeqCst) & SHUTDOWN_MASK == SHUTDOWN_MASK
// }
// pub fn is_empty(&self) -> bool {
// self.bit.load(core::sync::atomic::Ordering::SeqCst) == 0
// }
pub fn from_how(how: usize) -> Self {
Self::from(ShutdownBit::from_bits_truncate(how as u8))
}
pub fn get(&self) -> ShutdownTemp {
ShutdownTemp {
bit: self.bit.load(core::sync::atomic::Ordering::SeqCst),
}
}
}
pub struct ShutdownTemp {
bit: u8,
}
impl ShutdownTemp {
pub fn is_recv_shutdown(&self) -> bool {
self.bit & RCV_SHUTDOWN != 0
}
pub fn is_send_shutdown(&self) -> bool {
self.bit & SEND_SHUTDOWN != 0
}
pub fn is_both_shutdown(&self) -> bool {
self.bit & SHUTDOWN_MASK == SHUTDOWN_MASK
}
pub fn is_empty(&self) -> bool {
self.bit == 0
}
pub fn from_how(how: usize) -> Self {
Self { bit: how as u8 + 1 }
}
}
impl From<ShutdownBit> for ShutdownTemp {
fn from(shutdown_bit: ShutdownBit) -> Self {
match shutdown_bit {
ShutdownBit::SHUT_RD => Self { bit: RCV_SHUTDOWN },
ShutdownBit::SHUT_WR => Self { bit: SEND_SHUTDOWN },
ShutdownBit::SHUT_RDWR => Self { bit: SHUTDOWN_MASK },
_ => Self { bit: 0 },
}
}
}

View File

@ -0,0 +1,76 @@
const SOL_SOCKET: u16 = 1;
#[derive(Debug, Clone, Copy, FromPrimitive, ToPrimitive, PartialEq, Eq)]
pub enum IPProtocol {
/// Dummy protocol for TCP.
IP = 0,
/// Internet Control Message Protocol.
ICMP = 1,
/// Internet Group Management Protocol.
IGMP = 2,
/// IPIP tunnels (older KA9Q tunnels use 94).
IPIP = 4,
/// Transmission Control Protocol.
TCP = 6,
/// Exterior Gateway Protocol.
EGP = 8,
/// PUP protocol.
PUP = 12,
/// User Datagram Protocol.
UDP = 17,
/// XNS IDP protocol.
IDP = 22,
/// SO Transport Protocol Class 4.
TP = 29,
/// Datagram Congestion Control Protocol.
DCCP = 33,
/// IPv6-in-IPv4 tunnelling.
IPv6 = 41,
/// RSVP Protocol.
RSVP = 46,
/// Generic Routing Encapsulation. (Cisco GRE) (rfc 1701, 1702)
GRE = 47,
/// Encapsulation Security Payload protocol
ESP = 50,
/// Authentication Header protocol
AH = 51,
/// Multicast Transport Protocol.
MTP = 92,
/// IP option pseudo header for BEET
BEETPH = 94,
/// Encapsulation Header.
ENCAP = 98,
/// Protocol Independent Multicast.
PIM = 103,
/// Compression Header Protocol.
COMP = 108,
/// Stream Control Transport Protocol
SCTP = 132,
/// UDP-Lite protocol (RFC 3828)
UDPLITE = 136,
/// MPLS in IP (RFC 4023)
MPLSINIP = 137,
/// Ethernet-within-IPv6 Encapsulation
ETHERNET = 143,
/// Raw IP packets
RAW = 255,
/// Multipath TCP connection
MPTCP = 262,
}
impl TryFrom<u16> for IPProtocol {
type Error = system_error::SystemError;
fn try_from(value: u16) -> Result<Self, Self::Error> {
match <Self as num_traits::FromPrimitive>::from_u16(value) {
Some(p) => Ok(p),
None => Err(system_error::SystemError::EPROTONOSUPPORT),
}
}
}
impl From<IPProtocol> for u16 {
fn from(value: IPProtocol) -> Self {
<IPProtocol as num_traits::ToPrimitive>::to_u16(&value).unwrap()
}
}

View File

@ -0,0 +1,32 @@
mod option;
pub use option::Options;
mod option_level;
pub use option_level::OptionsLevel;
mod msg_flag;
pub use msg_flag::MessageFlag;
mod ipproto;
pub use ipproto::IPProtocol;
#[derive(Debug, Clone, Copy, PartialEq, Eq, FromPrimitive, ToPrimitive)]
pub enum Type {
Stream = 1,
Datagram = 2,
Raw = 3,
RDM = 4,
SeqPacket = 5,
DCCP = 6,
Packet = 10,
}
use crate::net::syscall_util::SysArgSocketType;
impl TryFrom<SysArgSocketType> for Type {
type Error = system_error::SystemError;
fn try_from(x: SysArgSocketType) -> Result<Self, Self::Error> {
use num_traits::FromPrimitive;
return <Self as FromPrimitive>::from_u32(x.types().bits())
.ok_or(system_error::SystemError::EINVAL);
}
}

View File

@ -0,0 +1,110 @@
bitflags::bitflags! {
/// # Message Flags
/// Flags we can use with send/ and recv. \
/// Added those for 1003.1g not all are supported yet
/// ## Reference
/// - [Linux Socket Flags](https://code.dragonos.org.cn/xref/linux-6.6.21/include/linux/socket.h#299)
pub struct MessageFlag: u32 {
/// `MSG_OOB`
/// `0b0000_0001`\
/// Process out-of-band data.
const OOB = 1;
/// `MSG_PEEK`
/// `0b0000_0010`\
/// Peek at an incoming message.
const PEEK = 2;
/// `MSG_DONTROUTE`
/// `0b0000_0100`\
/// Don't use routing tables.
const DONTROUTE = 4;
/// `MSG_TRYHARD`
/// `0b0000_0100`\
/// `MSG_TRYHARD` is not defined in the standard, but it is used in Linux.
const TRYHARD = 4;
/// `MSG_CTRUNC`
/// `0b0000_1000`\
/// Control data lost before delivery.
const CTRUNC = 8;
/// `MSG_PROBE`
/// `0b0001_0000`\
const PROBE = 0x10;
/// `MSG_TRUNC`
/// `0b0010_0000`\
/// Data truncated before delivery.
const TRUNC = 0x20;
/// `MSG_DONTWAIT`
/// `0b0100_0000`\
/// This flag is used to make the socket non-blocking.
const DONTWAIT = 0x40;
/// `MSG_EOR`
/// `0b1000_0000`\
/// End of record.
const EOR = 0x80;
/// `MSG_WAITALL`
/// `0b0001_0000_0000`\
/// Wait for full request or error.
const WAITALL = 0x100;
/// `MSG_FIN`
/// `0b0010_0000_0000`\
/// Terminate the connection.
const FIN = 0x200;
/// `MSG_SYN`
/// `0b0100_0000_0000`\
/// Synchronize sequence numbers.
const SYN = 0x400;
/// `MSG_CONFIRM`
/// `0b1000_0000_0000`\
/// Confirm path validity.
const CONFIRM = 0x800;
/// `MSG_RST`
/// `0b0001_0000_0000_0000`\
/// Reset the connection.
const RST = 0x1000;
/// `MSG_ERRQUEUE`
/// `0b0010_0000_0000_0000`\
/// Fetch message from error queue.
const ERRQUEUE = 0x2000;
/// `MSG_NOSIGNAL`
/// `0b0100_0000_0000_0000`\
/// Do not generate a signal.
const NOSIGNAL = 0x4000;
/// `MSG_MORE`
/// `0b1000_0000_0000_0000`\
/// Sender will send more.
const MORE = 0x8000;
/// `MSG_WAITFORONE`
/// `0b0001_0000_0000_0000_0000`\
/// For nonblocking operation.
const WAITFORONE = 0x10000;
/// `MSG_SENDPAGE_NOPOLICY`
/// `0b0010_0000_0000_0000_0000`\
/// Sendpage: do not apply policy.
const SENDPAGE_NOPOLICY = 0x10000;
/// `MSG_BATCH`
/// `0b0100_0000_0000_0000_0000`\
/// Sendpage: next message is batch.
const BATCH = 0x40000;
/// `MSG_EOF`
const EOF = Self::FIN.bits;
/// `MSG_NO_SHARED_FRAGS`
const NO_SHARED_FRAGS = 0x80000;
/// `MSG_SENDPAGE_DECRYPTED`
const SENDPAGE_DECRYPTED = 0x10_0000;
/// `MSG_ZEROCOPY`
const ZEROCOPY = 0x400_0000;
/// `MSG_SPLICE_PAGES`
const SPLICE_PAGES = 0x800_0000;
/// `MSG_FASTOPEN`
const FASTOPEN = 0x2000_0000;
/// `MSG_CMSG_CLOEXEC`
const CMSG_CLOEXEC = 0x4000_0000;
/// `MSG_CMSG_COMPAT`
// if define CONFIG_COMPAT
// const CMSG_COMPAT = 0x8000_0000;
const CMSG_COMPAT = 0;
/// `MSG_INTERNAL_SENDMSG_FLAGS`
const INTERNAL_SENDMSG_FLAGS
= Self::SPLICE_PAGES.bits | Self::SENDPAGE_NOPOLICY.bits | Self::SENDPAGE_DECRYPTED.bits;
}
}

View File

@ -0,0 +1,92 @@
#[derive(Debug, Clone, Copy, PartialEq, Eq, FromPrimitive, ToPrimitive)]
#[allow(non_camel_case_types)]
pub enum Options {
DEBUG = 1,
REUSEADDR = 2,
TYPE = 3,
ERROR = 4,
DONTROUTE = 5,
BROADCAST = 6,
SNDBUF = 7,
RCVBUF = 8,
SNDBUFFORCE = 32,
RCVBUFFORCE = 33,
KEEPALIVE = 9,
OOBINLINE = 10,
NO_CHECK = 11,
PRIORITY = 12,
LINGER = 13,
BSDCOMPAT = 14,
REUSEPORT = 15,
PASSCRED = 16,
PEERCRED = 17,
RCVLOWAT = 18,
SNDLOWAT = 19,
RCVTIMEO_OLD = 20,
SNDTIMEO_OLD = 21,
SECURITY_AUTHENTICATION = 22,
SECURITY_ENCRYPTION_TRANSPORT = 23,
SECURITY_ENCRYPTION_NETWORK = 24,
BINDTODEVICE = 25,
/// 与GET_FILTER相同
ATTACH_FILTER = 26,
DETACH_FILTER = 27,
PEERNAME = 28,
ACCEPTCONN = 30,
PEERSEC = 31,
PASSSEC = 34,
MARK = 36,
PROTOCOL = 38,
DOMAIN = 39,
RXQ_OVFL = 40,
/// 与SCM_WIFI_STATUS相同
WIFI_STATUS = 41,
PEEK_OFF = 42,
/* Instruct lower device to use last 4-bytes of skb data as FCS */
NOFCS = 43,
LOCK_FILTER = 44,
SELECT_ERR_QUEUE = 45,
BUSY_POLL = 46,
MAX_PACING_RATE = 47,
BPF_EXTENSIONS = 48,
INCOMING_CPU = 49,
ATTACH_BPF = 50,
// DETACH_BPF = DETACH_FILTER,
ATTACH_REUSEPORT_CBPF = 51,
ATTACH_REUSEPORT_EBPF = 52,
CNX_ADVICE = 53,
SCM_TIMESTAMPING_OPT_STATS = 54,
MEMINFO = 55,
INCOMING_NAPI_ID = 56,
COOKIE = 57,
SCM_TIMESTAMPING_PKTINFO = 58,
PEERGROUPS = 59,
ZEROCOPY = 60,
/// 与SCM_TXTIME相同
TXTIME = 61,
BINDTOIFINDEX = 62,
TIMESTAMP_OLD = 29,
TIMESTAMPNS_OLD = 35,
TIMESTAMPING_OLD = 37,
TIMESTAMP_NEW = 63,
TIMESTAMPNS_NEW = 64,
TIMESTAMPING_NEW = 65,
RCVTIMEO_NEW = 66,
SNDTIMEO_NEW = 67,
DETACH_REUSEPORT_BPF = 68,
PREFER_BUSY_POLL = 69,
BUSY_POLL_BUDGET = 70,
NETNS_COOKIE = 71,
BUF_LOCK = 72,
RESERVE_MEM = 73,
TXREHASH = 74,
RCVMARK = 75,
}
impl TryFrom<u32> for Options {
type Error = system_error::SystemError;
fn try_from(x: u32) -> Result<Self, Self::Error> {
use num_traits::FromPrimitive;
return <Self as FromPrimitive>::from_u32(x).ok_or(system_error::SystemError::EINVAL);
}
}

View File

@ -0,0 +1,115 @@
// pub const SOL_SOCKET: u8 = 1,
// bitflags::bitflags! {
// pub struct OptionsLevel: u32 {
// const IP = 0,
// // const SOL_ICMP = 1, // No-no-no! Due to Linux :-) we cannot
// const SOCKET = 1,
// const TCP = 6,
// const UDP = 17,
// const IPV6 = 41,
// const ICMPV6 = 58,
// const SCTP = 132,
// const UDPLITE = 136, // UDP-Lite (RFC 3828)
// const RAW = 255,
// const IPX = 256,
// const AX25 = 257,
// const ATALK = 258,
// const NETROM = 259,
// const ROSE = 260,
// const DECNET = 261,
// const X25 = 262,
// const PACKET = 263,
// const ATM = 264, // ATM layer (cell level)
// const AAL = 265, // ATM Adaption Layer (packet level)
// const IRDA = 266,
// const NETBEUI = 267,
// const LLC = 268,
// const DCCP = 269,
// const NETLINK = 270,
// const TIPC = 271,
// const RXRPC = 272,
// const PPPOL2TP = 273,
// const BLUETOOTH = 274,
// const PNPIPE = 275,
// const RDS = 276,
// const IUCV = 277,
// const CAIF = 278,
// const ALG = 279,
// const NFC = 280,
// const KCM = 281,
// const TLS = 282,
// const XDP = 283,
// const MPTCP = 284,
// const MCTP = 285,
// const SMC = 286,
// const VSOCK = 287,
// }
// }
/// # SOL (Socket Option Level)
/// Setsockoptions(2) level. Thanks to BSD these must match IPPROTO_xxx
/// ## Reference
/// - [Setsockoptions(2) level](https://code.dragonos.org.cn/xref/linux-6.6.21/include/linux/socket.h#345)
#[derive(Debug, Clone, Copy, PartialEq, Eq, FromPrimitive, ToPrimitive)]
#[allow(non_camel_case_types)]
pub enum OptionsLevel {
IP = 0,
SOCKET = 1,
// ICMP = 1, No-no-no! Due to Linux :-) we cannot
TCP = 6,
UDP = 17,
IPV6 = 41,
ICMPV6 = 58,
SCTP = 132,
UDPLITE = 136, // UDP-Lite (RFC 3828)
RAW = 255,
IPX = 256,
AX25 = 257,
ATALK = 258,
NETROM = 259,
ROSE = 260,
DECNET = 261,
X25 = 262,
PACKET = 263,
ATM = 264, // ATM layer (cell level)
AAL = 265, // ATM Adaption Layer (packet level)
IRDA = 266,
NETBEUI = 267,
LLC = 268,
DCCP = 269,
NETLINK = 270,
TIPC = 271,
RXRPC = 272,
PPPOL2TP = 273,
BLUETOOTH = 274,
PNPIPE = 275,
RDS = 276,
IUCV = 277,
CAIF = 278,
ALG = 279,
NFC = 280,
KCM = 281,
TLS = 282,
XDP = 283,
MPTCP = 284,
MCTP = 285,
SMC = 286,
VSOCK = 287,
}
impl TryFrom<u32> for OptionsLevel {
type Error = system_error::SystemError;
fn try_from(value: u32) -> Result<Self, Self::Error> {
match <Self as num_traits::FromPrimitive>::from_u32(value) {
Some(p) => Ok(p),
None => Err(system_error::SystemError::EPROTONOSUPPORT),
}
}
}
impl From<OptionsLevel> for u32 {
fn from(value: OptionsLevel) -> Self {
<OptionsLevel as num_traits::ToPrimitive>::to_u32(&value).unwrap()
}
}

View File

@ -0,0 +1,133 @@
// bitflags! {
// // #[derive(PartialEq, Eq, Debug, Clone, Copy)]
// pub struct Options: u32 {
// const DEBUG = 1;
// const REUSEADDR = 2;
// const TYPE = 3;
// const ERROR = 4;
// const DONTROUTE = 5;
// const BROADCAST = 6;
// const SNDBUF = 7;
// const RCVBUF = 8;
// const SNDBUFFORCE = 32;
// const RCVBUFFORCE = 33;
// const KEEPALIVE = 9;
// const OOBINLINE = 10;
// const NO_CHECK = 11;
// const PRIORITY = 12;
// const LINGER = 13;
// const BSDCOMPAT = 14;
// const REUSEPORT = 15;
// const PASSCRED = 16;
// const PEERCRED = 17;
// const RCVLOWAT = 18;
// const SNDLOWAT = 19;
// const RCVTIMEO_OLD = 20;
// const SNDTIMEO_OLD = 21;
//
// const SECURITY_AUTHENTICATION = 22;
// const SECURITY_ENCRYPTION_TRANSPORT = 23;
// const SECURITY_ENCRYPTION_NETWORK = 24;
//
// const BINDTODEVICE = 25;
//
// /// 与GET_FILTER相同
// const ATTACH_FILTER = 26;
// const DETACH_FILTER = 27;
//
// const PEERNAME = 28;
//
// const ACCEPTCONN = 30;
//
// const PEERSEC = 31;
// const PASSSEC = 34;
//
// const MARK = 36;
//
// const PROTOCOL = 38;
// const DOMAIN = 39;
//
// const RXQ_OVFL = 40;
//
// /// 与SCM_WIFI_STATUS相同
// const WIFI_STATUS = 41;
// const PEEK_OFF = 42;
//
// /* Instruct lower device to use last 4-bytes of skb data as FCS */
// const NOFCS = 43;
//
// const LOCK_FILTER = 44;
// const SELECT_ERR_QUEUE = 45;
// const BUSY_POLL = 46;
// const MAX_PACING_RATE = 47;
// const BPF_EXTENSIONS = 48;
// const INCOMING_CPU = 49;
// const ATTACH_BPF = 50;
// // DETACH_BPF = DETACH_FILTER;
// const ATTACH_REUSEPORT_CBPF = 51;
// const ATTACH_REUSEPORT_EBPF = 52;
//
// const CNX_ADVICE = 53;
// const SCM_TIMESTAMPING_OPT_STATS = 54;
// const MEMINFO = 55;
// const INCOMING_NAPI_ID = 56;
// const COOKIE = 57;
// const SCM_TIMESTAMPING_PKTINFO = 58;
// const PEERGROUPS = 59;
// const ZEROCOPY = 60;
// /// 与SCM_TXTIME相同
// const TXTIME = 61;
//
// const BINDTOIFINDEX = 62;
//
// const TIMESTAMP_OLD = 29;
// const TIMESTAMPNS_OLD = 35;
// const TIMESTAMPING_OLD = 37;
// const TIMESTAMP_NEW = 63;
// const TIMESTAMPNS_NEW = 64;
// const TIMESTAMPING_NEW = 65;
//
// const RCVTIMEO_NEW = 66;
// const SNDTIMEO_NEW = 67;
//
// const DETACH_REUSEPORT_BPF = 68;
//
// const PREFER_BUSY_POLL = 69;
// const BUSY_POLL_BUDGET = 70;
//
// const NETNS_COOKIE = 71;
// const BUF_LOCK = 72;
// const RESERVE_MEM = 73;
// const TXREHASH = 74;
// const RCVMARK = 75;
// }
// }
// bitflags::bitflags! {
// pub struct Level: i32 {
// const SOL_SOCKET = 1;
// const IPPROTO_IP = super::ip::Protocol::IP.bits();
// const IPPROTO_IPV6 = super::ip::Protocol::IPv6.bits();
// const IPPROTO_TCP = super::ip::Protocol::TCP.bits();
// }
// }
// bitflags! {
// /// @brief socket的选项
// #[derive(Default)]
// pub struct Options: u32 {
// /// 是否阻塞
// const BLOCK = 1 << 0;
// /// 是否允许广播
// const BROADCAST = 1 << 1;
// /// 是否允许多播
// const MULTICAST = 1 << 2;
// /// 是否允许重用地址
// const REUSEADDR = 1 << 3;
// /// 是否允许重用端口
// const REUSEPORT = 1 << 4;
// }
// }

View File

@ -0,0 +1,43 @@
use crate::{filesystem::vfs::InodeId, net::socket};
use alloc::{string::String, sync::Arc};
pub use smoltcp::wire::IpEndpoint;
pub use socket::netlink::endpoint::NetlinkEndpoint;
#[derive(Debug, Clone)]
pub enum Endpoint {
/// 链路层端点
LinkLayer(LinkLayerEndpoint),
/// 网络层端点
Ip(IpEndpoint),
/// inode端点,Unix实际保存的端点
Inode((Arc<socket::Inode>, String)),
/// Unix传递id索引和path所用的端点
Unixpath((InodeId, String)),
/// NetLink端点
Netlink(NetlinkEndpoint),
}
/// @brief 链路层端点
#[derive(Debug, Clone)]
pub struct LinkLayerEndpoint {
/// 网卡的接口号
pub interface: usize,
}
impl LinkLayerEndpoint {
/// @brief 创建一个链路层端点
///
/// @param interface 网卡的接口号
///
/// @return 返回创建的链路层端点
pub fn new(interface: usize) -> Self {
Self { interface }
}
}
impl From<IpEndpoint> for Endpoint {
fn from(endpoint: IpEndpoint) -> Self {
Self::Ip(endpoint)
}
}

View File

@ -0,0 +1,121 @@
/// # AddressFamily
/// Socket address families.
/// ## Reference
/// https://code.dragonos.org.cn/xref/linux-5.19.10/include/linux/socket.h#180
#[derive(Debug, Clone, Copy, PartialEq, Eq, FromPrimitive, ToPrimitive)]
pub enum AddressFamily {
/// AF_UNSPEC 表示地址族未指定
Unspecified = 0,
/// AF_UNIX 表示Unix域的socket (与AF_LOCAL相同)
Unix = 1,
/// AF_INET 表示IPv4的socket
INet = 2,
/// AF_AX25 表示AMPR AX.25的socket
AX25 = 3,
/// AF_IPX 表示IPX的socket
IPX = 4,
/// AF_APPLETALK 表示Appletalk的socket
Appletalk = 5,
/// AF_NETROM 表示AMPR NET/ROM的socket
Netrom = 6,
/// AF_BRIDGE 表示多协议桥接的socket
Bridge = 7,
/// AF_ATMPVC 表示ATM PVCs的socket
Atmpvc = 8,
/// AF_X25 表示X.25的socket
X25 = 9,
/// AF_INET6 表示IPv6的socket
INet6 = 10,
/// AF_ROSE 表示AMPR ROSE的socket
Rose = 11,
/// AF_DECnet Reserved for DECnet project
Decnet = 12,
/// AF_NETBEUI Reserved for 802.2LLC project
Netbeui = 13,
/// AF_SECURITY 表示Security callback的伪AF
Security = 14,
/// AF_KEY 表示Key management API
Key = 15,
/// AF_NETLINK 表示Netlink的socket
Netlink = 16,
/// AF_PACKET 表示Low level packet interface
Packet = 17,
/// AF_ASH 表示Ash
Ash = 18,
/// AF_ECONET 表示Acorn Econet
Econet = 19,
/// AF_ATMSVC 表示ATM SVCs
Atmsvc = 20,
/// AF_RDS 表示Reliable Datagram Sockets
Rds = 21,
/// AF_SNA 表示Linux SNA Project
Sna = 22,
/// AF_IRDA 表示IRDA sockets
Irda = 23,
/// AF_PPPOX 表示PPPoX sockets
Pppox = 24,
/// AF_WANPIPE 表示WANPIPE API sockets
WanPipe = 25,
/// AF_LLC 表示Linux LLC
Llc = 26,
/// AF_IB 表示Native InfiniBand address
/// 介绍https://access.redhat.com/documentation/en-us/red_hat_enterprise_linux/9/html-single/configuring_infiniband_and_rdma_networks/index#understanding-infiniband-and-rdma_configuring-infiniband-and-rdma-networks
Ib = 27,
/// AF_MPLS 表示MPLS
Mpls = 28,
/// AF_CAN 表示Controller Area Network
Can = 29,
/// AF_TIPC 表示TIPC sockets
Tipc = 30,
/// AF_BLUETOOTH 表示Bluetooth sockets
Bluetooth = 31,
/// AF_IUCV 表示IUCV sockets
Iucv = 32,
/// AF_RXRPC 表示RxRPC sockets
Rxrpc = 33,
/// AF_ISDN 表示mISDN sockets
Isdn = 34,
/// AF_PHONET 表示Phonet sockets
Phonet = 35,
/// AF_IEEE802154 表示IEEE 802.15.4 sockets
Ieee802154 = 36,
/// AF_CAIF 表示CAIF sockets
Caif = 37,
/// AF_ALG 表示Algorithm sockets
Alg = 38,
/// AF_NFC 表示NFC sockets
Nfc = 39,
/// AF_VSOCK 表示vSockets
Vsock = 40,
/// AF_KCM 表示Kernel Connection Multiplexor
Kcm = 41,
/// AF_QIPCRTR 表示Qualcomm IPC Router
Qipcrtr = 42,
/// AF_SMC 表示SMC-R sockets.
/// reserve number for PF_SMC protocol family that reuses AF_INET address family
Smc = 43,
/// AF_XDP 表示XDP sockets
Xdp = 44,
/// AF_MCTP 表示Management Component Transport Protocol
Mctp = 45,
/// AF_MAX 表示最大的地址族
Max = 46,
}
use system_error::SystemError;
impl core::convert::TryFrom<u16> for AddressFamily {
type Error = system_error::SystemError;
fn try_from(x: u16) -> Result<Self, Self::Error> {
use num_traits::FromPrimitive;
use SystemError::*;
return <Self as FromPrimitive>::from_u16(x).ok_or(EINVAL);
}
}
use crate::net::socket;
use alloc::sync::Arc;
pub trait Family {
fn socket(stype: socket::Type, protocol: u32) -> Result<Arc<socket::Inode>, SystemError>;
}

View File

@ -1,42 +0,0 @@
use ida::IdAllocator;
use smoltcp::iface::SocketHandle;
use crate::libs::spinlock::SpinLock;
int_like!(KernelHandle, usize);
/// # socket的句柄管理组件
/// 它在smoltcp的SocketHandle上封装了一层增加更多的功能。
/// 比如在socket被关闭时自动释放socket的资源通知系统的其他组件。
#[derive(Debug, Hash, Eq, PartialEq, Clone, Copy)]
pub enum GlobalSocketHandle {
Smoltcp(SocketHandle),
Kernel(KernelHandle),
}
static KERNEL_HANDLE_IDA: SpinLock<IdAllocator> =
SpinLock::new(IdAllocator::new(0, usize::MAX).unwrap());
impl GlobalSocketHandle {
pub fn new_smoltcp_handle(handle: SocketHandle) -> Self {
return Self::Smoltcp(handle);
}
pub fn new_kernel_handle() -> Self {
return Self::Kernel(KernelHandle::new(KERNEL_HANDLE_IDA.lock().alloc().unwrap()));
}
pub fn smoltcp_handle(&self) -> Option<SocketHandle> {
if let Self::Smoltcp(sh) = *self {
return Some(sh);
}
None
}
pub fn kernel_handle(&self) -> Option<KernelHandle> {
if let Self::Kernel(kh) = *self {
return Some(kh);
}
None
}
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,148 @@
use crate::net::{Iface, NET_DEVICES};
use alloc::sync::Arc;
use system_error::SystemError::{self, *};
pub mod port;
pub use port::PortManager;
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Types {
Raw,
Icmp,
Udp,
Tcp,
Dhcpv4,
Dns,
}
/**
* listen问题socket在绑定单网卡下的问题
*/
#[derive(Debug)]
pub struct BoundInner {
handle: smoltcp::iface::SocketHandle,
iface: Arc<dyn Iface>,
// inner: Vec<(smoltcp::iface::SocketHandle, Arc<dyn Iface>)>
// address: smoltcp::wire::IpAddress,
}
impl BoundInner {
/// # `bind`
/// 将socket绑定到指定的地址上置入指定的网络接口中
pub fn bind<T>(
socket: T,
// socket_type: Types,
address: &smoltcp::wire::IpAddress,
) -> Result<Self, SystemError>
where
T: smoltcp::socket::AnySocket<'static>,
{
if address.is_unspecified() {
// let inner = Vec::new();
// for (_, iface) in *NET_DEVICES.read_irqsave() {
// let handle = iface.sockets().lock_no_preempt().add(socket);
// iface
// }
// 强绑VirtualIO
log::debug!("Not bind to any iface, bind to virtIO");
let iface = NET_DEVICES
.read_irqsave()
.get(&0)
.expect("??bind without virtIO, serious?")
.clone();
let handle = iface.sockets().lock_no_preempt().add(socket);
return Ok(Self { handle, iface });
} else {
let iface = get_iface_to_bind(address).ok_or(ENODEV)?;
let handle = iface.sockets().lock_no_preempt().add(socket);
// log::debug!("Bind to iface: {}", iface.iface_name());
// return Ok(Self { inner: vec![(handle, iface)] });
return Ok(Self { handle, iface });
}
}
pub fn bind_ephemeral<T>(
socket: T,
// socket_type: Types,
remote: smoltcp::wire::IpAddress,
) -> Result<(Self, smoltcp::wire::IpAddress), SystemError>
where
T: smoltcp::socket::AnySocket<'static>,
{
let (iface, address) = get_ephemeral_iface(&remote);
// let bound_port = iface.port_manager().bind_ephemeral_port(socket_type)?;
let handle = iface.sockets().lock_no_preempt().add(socket);
// let endpoint = smoltcp::wire::IpEndpoint::new(local_addr, bound_port);
Ok((Self { handle, iface }, address))
}
pub fn port_manager(&self) -> &PortManager {
self.iface.port_manager()
}
pub fn with_mut<T: smoltcp::socket::AnySocket<'static>, R, F: FnMut(&mut T) -> R>(
&self,
mut f: F,
) -> R {
f(self.iface.sockets().lock().get_mut::<T>(self.handle))
}
pub fn with<T: smoltcp::socket::AnySocket<'static>, R, F: Fn(&T) -> R>(&self, f: F) -> R {
f(self.iface.sockets().lock().get::<T>(self.handle))
}
pub fn iface(&self) -> &Arc<dyn Iface> {
&self.iface
}
pub fn release(&self) {
self.iface.sockets().lock().remove(self.handle);
}
}
#[inline]
pub fn get_iface_to_bind(ip_addr: &smoltcp::wire::IpAddress) -> Option<Arc<dyn Iface>> {
// log::debug!("get_iface_to_bind: {:?}", ip_addr);
// if ip_addr.is_unspecified()
crate::net::NET_DEVICES
.read_irqsave()
.iter()
.find(|(_, iface)| {
let guard = iface.smol_iface().lock();
// log::debug!("iface name: {}, ip: {:?}", iface.iface_name(), guard.ip_addrs());
return guard.has_ip_addr(*ip_addr);
})
.map(|(_, iface)| iface.clone())
}
/// Get a suitable iface to deal with sendto/connect request if the socket is not bound to an iface.
/// If the remote address is the same as that of some iface, we will use the iface.
/// Otherwise, we will use a default interface.
fn get_ephemeral_iface(
remote_ip_addr: &smoltcp::wire::IpAddress,
) -> (Arc<dyn Iface>, smoltcp::wire::IpAddress) {
get_iface_to_bind(remote_ip_addr)
.map(|iface| (iface, *remote_ip_addr))
.or({
let ifaces = NET_DEVICES.read_irqsave();
ifaces.iter().find_map(|(_, iface)| {
iface
.smol_iface()
.lock()
.ip_addrs()
.iter()
.find(|cidr| cidr.contains_addr(remote_ip_addr))
.map(|cidr| (iface.clone(), cidr.address()))
})
})
.or({
NET_DEVICES.read_irqsave().values().next().map(|iface| {
(
iface.clone(),
iface.smol_iface().lock().ip_addrs()[0].address(),
)
})
})
.expect("No network interface")
}

View File

@ -0,0 +1,114 @@
use hashbrown::HashMap;
use system_error::SystemError;
use crate::{
arch::rand::rand,
libs::spinlock::SpinLock,
process::{Pid, ProcessManager},
};
use super::Types::{self, *};
/// # TCP 和 UDP 的端口管理器。
/// 如果 TCP/UDP 的 socket 绑定了某个端口,它会在对应的表中记录,以检测端口冲突。
#[derive(Debug)]
pub struct PortManager {
// TCP 端口记录表
tcp_port_table: SpinLock<HashMap<u16, Pid>>,
// UDP 端口记录表
udp_port_table: SpinLock<HashMap<u16, Pid>>,
}
impl PortManager {
pub fn new() -> Self {
return Self {
tcp_port_table: SpinLock::new(HashMap::new()),
udp_port_table: SpinLock::new(HashMap::new()),
};
}
/// @brief 自动分配一个相对应协议中未被使用的PORT如果动态端口均已被占用返回错误码 EADDRINUSE
pub fn get_ephemeral_port(&self, socket_type: Types) -> Result<u16, SystemError> {
// TODO: selects non-conflict high port
static mut EPHEMERAL_PORT: u16 = 0;
unsafe {
if EPHEMERAL_PORT == 0 {
EPHEMERAL_PORT = (49152 + rand() % (65536 - 49152)) as u16;
}
}
let mut remaining = 65536 - 49152; // 剩余尝试分配端口次数
let mut port: u16;
while remaining > 0 {
unsafe {
if EPHEMERAL_PORT == 65535 {
EPHEMERAL_PORT = 49152;
} else {
EPHEMERAL_PORT += 1;
}
port = EPHEMERAL_PORT;
}
// 使用 ListenTable 检查端口是否被占用
let listen_table_guard = match socket_type {
Udp => self.udp_port_table.lock(),
Tcp => self.tcp_port_table.lock(),
_ => panic!("{:?} cann't get a port", socket_type),
};
if listen_table_guard.get(&port).is_none() {
drop(listen_table_guard);
return Ok(port);
}
remaining -= 1;
}
return Err(SystemError::EADDRINUSE);
}
#[inline]
pub fn bind_ephemeral_port(&self, socket_type: Types) -> Result<u16, SystemError> {
let port = self.get_ephemeral_port(socket_type)?;
self.bind_port(socket_type, port)?;
return Ok(port);
}
/// @brief 检测给定端口是否已被占用,如果未被占用则在 TCP/UDP 对应的表中记录
///
/// TODO: 增加支持端口复用的逻辑
pub fn bind_port(&self, socket_type: Types, port: u16) -> Result<(), SystemError> {
if port > 0 {
match socket_type {
Udp => {
let mut guard = self.udp_port_table.lock();
if guard.get(&port).is_some() {
return Err(SystemError::EADDRINUSE);
}
guard.insert(port, ProcessManager::current_pid());
}
Tcp => {
let mut guard = self.tcp_port_table.lock();
if guard.get(&port).is_some() {
return Err(SystemError::EADDRINUSE);
}
guard.insert(port, ProcessManager::current_pid());
}
_ => {}
};
}
return Ok(());
}
/// @brief 在对应的端口记录表中将端口和 socket 解绑
/// should call this function when socket is closed or aborted
pub fn unbind_port(&self, socket_type: Types, port: u16) {
match socket_type {
Udp => {
self.udp_port_table.lock().remove(&port);
}
Tcp => {
self.tcp_port_table.lock().remove(&port);
}
_ => {}
};
}
}

View File

@ -0,0 +1,156 @@
use smoltcp;
use system_error::SystemError::{self, *};
use crate::{
libs::spinlock::SpinLock,
net::socket::inet::common::{BoundInner, Types as InetTypes},
};
pub type SmolUdpSocket = smoltcp::socket::udp::Socket<'static>;
pub const DEFAULT_METADATA_BUF_SIZE: usize = 1024;
pub const DEFAULT_RX_BUF_SIZE: usize = 64 * 1024;
pub const DEFAULT_TX_BUF_SIZE: usize = 64 * 1024;
#[derive(Debug)]
pub struct UnboundUdp {
socket: SmolUdpSocket,
}
impl UnboundUdp {
pub fn new() -> Self {
let rx_buffer = smoltcp::socket::udp::PacketBuffer::new(
vec![smoltcp::socket::udp::PacketMetadata::EMPTY; DEFAULT_METADATA_BUF_SIZE],
vec![0; DEFAULT_RX_BUF_SIZE],
);
let tx_buffer = smoltcp::socket::udp::PacketBuffer::new(
vec![smoltcp::socket::udp::PacketMetadata::EMPTY; DEFAULT_METADATA_BUF_SIZE],
vec![0; DEFAULT_TX_BUF_SIZE],
);
let socket = SmolUdpSocket::new(rx_buffer, tx_buffer);
return Self { socket };
}
pub fn bind(
mut self,
local_endpoint: smoltcp::wire::IpEndpoint,
) -> Result<BoundUdp, SystemError> {
// let (addr, port) = (local_endpoint.addr, local_endpoint.port);
if self.socket.bind(local_endpoint).is_err() {
return Err(EINVAL);
}
let inner = BoundInner::bind(self.socket, &local_endpoint.addr)?;
inner
.port_manager()
.bind_port(InetTypes::Udp, local_endpoint.port)?;
Ok(BoundUdp {
inner,
remote: SpinLock::new(None),
})
}
pub fn bind_ephemeral(self, remote: smoltcp::wire::IpAddress) -> Result<BoundUdp, SystemError> {
// let (addr, port) = (remote.addr, remote.port);
let (inner, address) = BoundInner::bind_ephemeral(self.socket, remote)?;
let bound_port = inner.port_manager().bind_ephemeral_port(InetTypes::Udp)?;
let endpoint = smoltcp::wire::IpEndpoint::new(address, bound_port);
Ok(BoundUdp {
inner,
remote: SpinLock::new(Some(endpoint)),
})
}
pub fn close(&mut self) {
self.socket.close();
}
}
#[derive(Debug)]
pub struct BoundUdp {
inner: BoundInner,
remote: SpinLock<Option<smoltcp::wire::IpEndpoint>>,
}
impl BoundUdp {
pub fn with_mut_socket<F, T>(&self, f: F) -> T
where
F: FnMut(&mut SmolUdpSocket) -> T,
{
self.inner.with_mut(f)
}
pub fn with_socket<F, T>(&self, f: F) -> T
where
F: Fn(&SmolUdpSocket) -> T,
{
self.inner.with(f)
}
pub fn endpoint(&self) -> smoltcp::wire::IpListenEndpoint {
self.inner
.with::<SmolUdpSocket, _, _>(|socket| socket.endpoint())
}
pub fn connect(&self, remote: smoltcp::wire::IpEndpoint) {
self.remote.lock().replace(remote);
}
#[inline]
pub fn try_recv(
&self,
buf: &mut [u8],
) -> Result<(usize, smoltcp::wire::IpEndpoint), SystemError> {
self.with_mut_socket(|socket| {
if socket.can_recv() {
if let Ok((size, metadata)) = socket.recv_slice(buf) {
return Ok((size, metadata.endpoint));
}
}
return Err(EAGAIN_OR_EWOULDBLOCK);
})
}
#[inline]
pub fn can_recv(&self) -> bool {
self.with_socket(|socket| socket.can_recv())
}
pub fn try_send(
&self,
buf: &[u8],
to: Option<smoltcp::wire::IpEndpoint>,
) -> Result<usize, SystemError> {
let remote = to.or(*self.remote.lock()).ok_or(ENOTCONN)?;
let result = self.with_mut_socket(|socket| {
if socket.can_send() && socket.send_slice(buf, remote).is_ok() {
log::debug!("send {} bytes", buf.len());
return Ok(buf.len());
}
return Err(ENOBUFS);
});
return result;
}
pub fn inner(&self) -> &BoundInner {
&self.inner
}
pub fn close(&self) {
self.inner
.iface()
.port_manager()
.unbind_port(InetTypes::Udp, self.endpoint().port);
self.with_mut_socket(|socket| {
socket.close();
});
}
}
// Udp Inner 负责其内部资源管理
#[derive(Debug)]
pub enum UdpInner {
Unbound(UnboundUdp),
Bound(BoundUdp),
}

View File

@ -0,0 +1,453 @@
use inet::InetSocket;
use smoltcp;
use system_error::SystemError::{self, *};
use crate::filesystem::vfs::IndexNode;
use crate::libs::rwlock::RwLock;
use crate::libs::spinlock::SpinLock;
use crate::net::event_poll::EPollEventType;
use crate::net::net_core::poll_ifaces;
use crate::net::socket::*;
use alloc::sync::{Arc, Weak};
use core::sync::atomic::AtomicBool;
pub mod inner;
use inner::*;
type EP = EPollEventType;
// Udp Socket 负责提供状态切换接口、执行状态切换
#[derive(Debug)]
pub struct UdpSocket {
inner: RwLock<Option<UdpInner>>,
nonblock: AtomicBool,
wait_queue: WaitQueue,
self_ref: Weak<UdpSocket>,
}
impl UdpSocket {
pub fn new(nonblock: bool) -> Arc<Self> {
return Arc::new_cyclic(|me| Self {
inner: RwLock::new(Some(UdpInner::Unbound(UnboundUdp::new()))),
nonblock: AtomicBool::new(nonblock),
wait_queue: WaitQueue::default(),
self_ref: me.clone(),
});
}
pub fn is_nonblock(&self) -> bool {
self.nonblock.load(core::sync::atomic::Ordering::Relaxed)
}
pub fn do_bind(&self, local_endpoint: smoltcp::wire::IpEndpoint) -> Result<(), SystemError> {
let mut inner = self.inner.write();
if let Some(UdpInner::Unbound(unbound)) = inner.take() {
let bound = unbound.bind(local_endpoint)?;
bound
.inner()
.iface()
.common()
.bind_socket(self.self_ref.upgrade().unwrap());
*inner = Some(UdpInner::Bound(bound));
return Ok(());
}
return Err(EINVAL);
}
pub fn bind_emphemeral(&self, remote: smoltcp::wire::IpAddress) -> Result<(), SystemError> {
let mut inner_guard = self.inner.write();
let bound = match inner_guard.take().expect("Udp inner is None") {
UdpInner::Bound(inner) => inner,
UdpInner::Unbound(inner) => inner.bind_ephemeral(remote)?,
};
inner_guard.replace(UdpInner::Bound(bound));
return Ok(());
}
pub fn is_bound(&self) -> bool {
let inner = self.inner.read();
if let Some(UdpInner::Bound(_)) = &*inner {
return true;
}
return false;
}
pub fn close(&self) {
let mut inner = self.inner.write();
if let Some(UdpInner::Bound(bound)) = &mut *inner {
bound.close();
inner.take();
}
}
pub fn try_recv(
&self,
buf: &mut [u8],
) -> Result<(usize, smoltcp::wire::IpEndpoint), SystemError> {
poll_ifaces();
let received = match self.inner.read().as_ref().expect("Udp Inner is None") {
UdpInner::Bound(bound) => bound.try_recv(buf),
_ => Err(ENOTCONN),
};
return received;
}
#[inline]
pub fn can_recv(&self) -> bool {
self.on_events().contains(EP::EPOLLIN)
}
#[inline]
pub fn can_send(&self) -> bool {
self.on_events().contains(EP::EPOLLOUT)
}
pub fn try_send(
&self,
buf: &[u8],
to: Option<smoltcp::wire::IpEndpoint>,
) -> Result<usize, SystemError> {
{
let mut inner_guard = self.inner.write();
let inner = match inner_guard.take().expect("Udp Inner is None") {
UdpInner::Bound(bound) => bound,
UdpInner::Unbound(unbound) => {
unbound.bind_ephemeral(to.ok_or(EADDRNOTAVAIL)?.addr)?
}
};
// size = inner.try_send(buf, to)?;
inner_guard.replace(UdpInner::Bound(inner));
};
// Optimize: 拿两次锁的平均效率是否比一次长时间的读锁效率要高?
let result = match self.inner.read().as_ref().expect("Udp Inner is None") {
UdpInner::Bound(bound) => bound.try_send(buf, to),
_ => Err(ENOTCONN),
};
poll_ifaces();
return result;
}
pub fn read(&self, buf: &mut [u8]) -> Result<usize, SystemError> {
if self.is_nonblock() {
return self.try_recv(buf).map(|(size, _)| size);
} else {
// return self
// .wait_queue
// .busy_wait(EP::EPOLLIN, || self.try_recv(buf).map(|(size, _)| size));
todo!()
}
}
pub fn on_events(&self) -> EPollEventType {
let mut event = EPollEventType::empty();
match self.inner.read().as_ref().unwrap() {
UdpInner::Unbound(_) => {
event.insert(EP::EPOLLOUT | EP::EPOLLWRNORM | EP::EPOLLWRBAND);
}
UdpInner::Bound(bound) => {
let (can_recv, can_send) =
bound.with_socket(|socket| (socket.can_recv(), socket.can_send()));
if can_recv {
event.insert(EP::EPOLLIN | EP::EPOLLRDNORM);
}
if can_send {
event.insert(EP::EPOLLOUT | EP::EPOLLWRNORM | EP::EPOLLWRBAND);
} else {
todo!("缓冲区空间不够,需要使用信号处理");
}
}
}
return event;
}
}
impl Socket for UdpSocket {
fn wait_queue(&self) -> &WaitQueue {
&self.wait_queue
}
fn poll(&self) -> usize {
self.on_events().bits() as usize
}
fn bind(&self, local_endpoint: Endpoint) -> Result<(), SystemError> {
if let Endpoint::Ip(local_endpoint) = local_endpoint {
return self.do_bind(local_endpoint);
}
Err(EAFNOSUPPORT)
}
fn send_buffer_size(&self) -> usize {
match self.inner.read().as_ref().unwrap() {
UdpInner::Bound(bound) => bound.with_socket(|socket| socket.payload_send_capacity()),
_ => inner::DEFAULT_TX_BUF_SIZE,
}
}
fn recv_buffer_size(&self) -> usize {
match self.inner.read().as_ref().unwrap() {
UdpInner::Bound(bound) => bound.with_socket(|socket| socket.payload_recv_capacity()),
_ => inner::DEFAULT_RX_BUF_SIZE,
}
}
fn connect(&self, endpoint: Endpoint) -> Result<(), SystemError> {
if let Endpoint::Ip(remote) = endpoint {
self.bind_emphemeral(remote.addr)?;
if let UdpInner::Bound(inner) = self.inner.read().as_ref().expect("UDP Inner disappear")
{
inner.connect(remote);
return Ok(());
} else {
panic!("");
}
}
return Err(EAFNOSUPPORT);
}
fn send(&self, buffer: &[u8], flags: MessageFlag) -> Result<usize, SystemError> {
// if flags.contains(MessageFlag::DONTWAIT) {
return self.try_send(buffer, None);
// } else {
// // return self
// // .wait_queue
// // .busy_wait(EP::EPOLLOUT, || self.try_send(buffer, None));
// todo!()
// }
}
fn send_to(
&self,
buffer: &[u8],
flags: MessageFlag,
address: Endpoint,
) -> Result<usize, SystemError> {
// if flags.contains(MessageFlag::DONTWAIT) {
if let Endpoint::Ip(remote) = address {
return self.try_send(buffer, Some(remote));
}
// } else {
// // return self
// // .wait_queue
// // .busy_wait(EP::EPOLLOUT, || {
// // if let Endpoint::Ip(remote) = address {
// // return self.try_send(buffer, Some(remote.addr));
// // }
// // return Err(EAFNOSUPPORT);
// // });
// todo!()
// }
return Err(EINVAL);
}
fn recv(&self, buffer: &mut [u8], flags: MessageFlag) -> Result<usize, SystemError> {
use crate::sched::SchedMode;
return if self.is_nonblock() || flags.contains(MessageFlag::DONTWAIT) {
self.try_recv(buffer)
} else {
loop {
match self.try_recv(buffer) {
Err(EAGAIN_OR_EWOULDBLOCK) => {
wq_wait_event_interruptible!(self.wait_queue, self.can_recv(), {})?;
}
result => break result,
}
}
}
.map(|(len, _)| len);
}
fn recv_from(
&self,
buffer: &mut [u8],
flags: MessageFlag,
address: Option<Endpoint>,
) -> Result<(usize, Endpoint), SystemError> {
use crate::sched::SchedMode;
// could block io
if let Some(endpoint) = address {
self.connect(endpoint)?;
}
return if self.is_nonblock() || flags.contains(MessageFlag::DONTWAIT) {
self.try_recv(buffer)
} else {
loop {
match self.try_recv(buffer) {
Err(EAGAIN_OR_EWOULDBLOCK) => {
wq_wait_event_interruptible!(self.wait_queue, self.can_recv(), {})?;
log::debug!("UdpSocket::recv_from: wake up");
}
result => break result,
}
}
}
.map(|(len, remote)| (len, Endpoint::Ip(remote)));
}
}
impl InetSocket for UdpSocket {
fn on_iface_events(&self) {
return;
}
}
bitflags! {
pub struct UdpSocketOptions: u32 {
const ZERO = 0; /* No UDP options */
const UDP_CORK = 1; /* Never send partially complete segments */
const UDP_ENCAP = 100; /* Set the socket to accept encapsulated packets */
const UDP_NO_CHECK6_TX = 101; /* Disable sending checksum for UDP6X */
const UDP_NO_CHECK6_RX = 102; /* Disable accepting checksum for UDP6 */
const UDP_SEGMENT = 103; /* Set GSO segmentation size */
const UDP_GRO = 104; /* This socket can receive UDP GRO packets */
const UDPLITE_SEND_CSCOV = 10; /* sender partial coverage (as sent) */
const UDPLITE_RECV_CSCOV = 11; /* receiver partial coverage (threshold ) */
}
}
bitflags! {
pub struct UdpEncapTypes: u8 {
const ZERO = 0;
const ESPINUDP_NON_IKE = 1; // draft-ietf-ipsec-nat-t-ike-00/01
const ESPINUDP = 2; // draft-ietf-ipsec-udp-encaps-06
const L2TPINUDP = 3; // rfc2661
const GTP0 = 4; // GSM TS 09.60
const GTP1U = 5; // 3GPP TS 29.060
const RXRPC = 6;
const ESPINTCP = 7; // Yikes, this is really xfrm encap types.
}
}
// fn sock_set_option(
// &self,
// _socket: &mut udp::Socket,
// _level: SocketOptionsLevel,
// optname: PosixSocketOption,
// _optval: &[u8],
// ) -> Result<(), SystemError> {
// use PosixSocketOption::*;
// use SystemError::*;
// if optname == SO_BINDTODEVICE {
// todo!("SO_BINDTODEVICE");
// }
// match optname {
// SO_TYPE => {}
// SO_PROTOCOL => {}
// SO_DOMAIN => {}
// SO_ERROR => {
// return Err(ENOPROTOOPT);
// }
// SO_TIMESTAMP_OLD => {}
// SO_TIMESTAMP_NEW => {}
// SO_TIMESTAMPNS_OLD => {}
// SO_TIMESTAMPING_OLD => {}
// SO_RCVTIMEO_OLD => {}
// SO_SNDTIMEO_OLD => {}
// // if define CONFIG_NET_RX_BUSY_POLL
// SO_BUSY_POLL | SO_PREFER_BUSY_POLL | SO_BUSY_POLL_BUDGET => {
// debug!("Unsupported socket option: {:?}", optname);
// return Err(ENOPROTOOPT);
// }
// // end if
// optname => {
// debug!("Unsupported socket option: {:?}", optname);
// return Err(ENOPROTOOPT);
// }
// }
// return Ok(());
// }
// fn udp_set_option(
// &self,
// level: SocketOptionsLevel,
// optname: usize,
// optval: &[u8],
// ) -> Result<(), SystemError> {
// use PosixSocketOption::*;
// let so_opt_name =
// PosixSocketOption::try_from(optname as i32)
// .map_err(|_| SystemError::ENOPROTOOPT)?;
// if level == SocketOptionsLevel::SOL_SOCKET {
// self.with_mut_socket(f)
// self.sock_set_option(self., level, so_opt_name, optval)?;
// if so_opt_name == SO_RCVBUF || so_opt_name == SO_RCVBUFFORCE {
// todo!("SO_RCVBUF");
// }
// }
// match UdpSocketOptions::from_bits_truncate(optname as u32) {
// UdpSocketOptions::UDP_CORK => {
// todo!("UDP_CORK");
// }
// UdpSocketOptions::UDP_ENCAP => {
// match UdpEncapTypes::from_bits_truncate(optval[0]) {
// UdpEncapTypes::ESPINUDP_NON_IKE => {
// todo!("ESPINUDP_NON_IKE");
// }
// UdpEncapTypes::ESPINUDP => {
// todo!("ESPINUDP");
// }
// UdpEncapTypes::L2TPINUDP => {
// todo!("L2TPINUDP");
// }
// UdpEncapTypes::GTP0 => {
// todo!("GTP0");
// }
// UdpEncapTypes::GTP1U => {
// todo!("GTP1U");
// }
// UdpEncapTypes::RXRPC => {
// todo!("RXRPC");
// }
// UdpEncapTypes::ESPINTCP => {
// todo!("ESPINTCP");
// }
// UdpEncapTypes::ZERO => {}
// _ => {
// return Err(SystemError::ENOPROTOOPT);
// }
// }
// }
// UdpSocketOptions::UDP_NO_CHECK6_TX => {
// todo!("UDP_NO_CHECK6_TX");
// }
// UdpSocketOptions::UDP_NO_CHECK6_RX => {
// todo!("UDP_NO_CHECK6_RX");
// }
// UdpSocketOptions::UDP_SEGMENT => {
// todo!("UDP_SEGMENT");
// }
// UdpSocketOptions::UDP_GRO => {
// todo!("UDP_GRO");
// }
// UdpSocketOptions::UDPLITE_RECV_CSCOV => {
// todo!("UDPLITE_RECV_CSCOV");
// }
// UdpSocketOptions::UDPLITE_SEND_CSCOV => {
// todo!("UDPLITE_SEND_CSCOV");
// }
// UdpSocketOptions::ZERO => {}
// _ => {
// return Err(SystemError::ENOPROTOOPT);
// }
// }
// return Ok(());
// }

View File

@ -0,0 +1,68 @@
bitflags! {
pub struct IpOptions: u32 {
const IP_TOS = 1; // Type of service
const IP_TTL = 2; // Time to live
const IP_HDRINCL = 3; // Header compression
const IP_OPTIONS = 4; // IP options
const IP_ROUTER_ALERT = 5; // Router alert
const IP_RECVOPTS = 6; // Receive options
const IP_RETOPTS = 7; // Return options
const IP_PKTINFO = 8; // Packet information
const IP_PKTOPTIONS = 9; // Packet options
const IP_MTU_DISCOVER = 10; // MTU discovery
const IP_RECVERR = 11; // Receive errors
const IP_RECVTTL = 12; // Receive time to live
const IP_RECVTOS = 13; // Receive type of service
const IP_MTU = 14; // MTU
const IP_FREEBIND = 15; // Freebind
const IP_IPSEC_POLICY = 16; // IPsec policy
const IP_XFRM_POLICY = 17; // IPipsec transform policy
const IP_PASSSEC = 18; // Pass security
const IP_TRANSPARENT = 19; // Transparent
const IP_RECVRETOPTS = 20; // Receive return options (deprecated)
const IP_ORIGDSTADDR = 21; // Originate destination address (used by TProxy)
const IP_RECVORIGDSTADDR = 21; // Receive originate destination address
const IP_MINTTL = 22; // Minimum time to live
const IP_NODEFRAG = 23; // Don't fragment (used by TProxy)
const IP_CHECKSUM = 24; // Checksum offload (used by TProxy)
const IP_BIND_ADDRESS_NO_PORT = 25; // Bind to address without port (used by TProxy)
const IP_RECVFRAGSIZE = 26; // Receive fragment size
const IP_RECVERR_RFC4884 = 27; // Receive ICMPv6 error notifications
const IP_PMTUDISC_DONT = 28; // Don't send DF frames
const IP_PMTUDISC_DO = 29; // Always DF
const IP_PMTUDISC_PROBE = 30; // Ignore dst pmtu
const IP_PMTUDISC_INTERFACE = 31; // Always use interface mtu (ignores dst pmtu)
const IP_PMTUDISC_OMIT = 32; // Weaker version of IP_PMTUDISC_INTERFACE
const IP_MULTICAST_IF = 33; // Multicast interface
const IP_MULTICAST_TTL = 34; // Multicast time to live
const IP_MULTICAST_LOOP = 35; // Multicast loopback
const IP_ADD_MEMBERSHIP = 36; // Add multicast group membership
const IP_DROP_MEMBERSHIP = 37; // Drop multicast group membership
const IP_UNBLOCK_SOURCE = 38; // Unblock source
const IP_BLOCK_SOURCE = 39; // Block source
const IP_ADD_SOURCE_MEMBERSHIP = 40; // Add source multicast group membership
const IP_DROP_SOURCE_MEMBERSHIP = 41; // Drop source multicast group membership
const IP_MSFILTER = 42; // Multicast source filter
const MCAST_JOIN_GROUP = 43; // Join a multicast group
const MCAST_BLOCK_SOURCE = 44; // Block a multicast source
const MCAST_UNBLOCK_SOURCE = 45; // Unblock a multicast source
const MCAST_LEAVE_GROUP = 46; // Leave a multicast group
const MCAST_JOIN_SOURCE_GROUP = 47; // Join a multicast source group
const MCAST_LEAVE_SOURCE_GROUP = 48; // Leave a multicast source group
const MCAST_MSFILTER = 49; // Multicast source filter
const IP_MULTICAST_ALL = 50; // Multicast all
const IP_UNICAST_IF = 51; // Unicast interface
const IP_LOCAL_PORT_RANGE = 52; // Local port range
const IP_PROTOCOL = 53; // Protocol
// ... other flags ...
}
}

View File

@ -0,0 +1,150 @@
use alloc::sync::Arc;
use smoltcp;
use system_error::SystemError::{self, *};
// pub mod raw;
// pub mod icmp;
pub mod common;
pub mod datagram;
pub mod stream;
pub mod syscall;
pub use common::BoundInner;
pub use common::Types;
// pub use raw::RawSocket;
pub use datagram::UdpSocket;
pub use stream::TcpSocket;
pub use syscall::Inet;
use crate::filesystem::vfs::IndexNode;
use super::Socket;
use smoltcp::wire::*;
/// A local endpoint, which indicates that the local endpoint is unspecified.
///
/// According to the Linux man pages and the Linux implementation, `getsockname()` will _not_ fail
/// even if the socket is unbound. Instead, it will return an unspecified socket address. This
/// unspecified endpoint helps with that.
const UNSPECIFIED_LOCAL_ENDPOINT: IpEndpoint =
IpEndpoint::new(IpAddress::Ipv4(Ipv4Address::UNSPECIFIED), 0);
pub trait InetSocket: Socket {
/// `on_iface_events`
/// 通知socket发生的事件
fn on_iface_events(&self);
}
// #[derive(Debug)]
// pub enum InetSocket {
// // Raw(RawSocket),
// Udp(UdpSocket),
// Tcp(TcpSocket),
// }
// impl InetSocket {
// /// # `on_iface_events`
// /// 通知socket发生了事件
// pub fn on_iface_events(&self) {
// todo!()
// }
// }
// impl IndexNode for InetSocket {
// }
// impl Socket for InetSocket {
// fn epoll_items(&self) -> &super::common::poll_unit::EPollItems {
// match self {
// InetSocket::Udp(udp) => udp.epoll_items(),
// InetSocket::Tcp(tcp) => tcp.epoll_items(),
// }
// }
// fn bind(&self, endpoint: crate::net::Endpoint) -> Result<(), SystemError> {
// if let crate::net::Endpoint::Ip(ip) = endpoint {
// match self {
// InetSocket::Udp(udp) => {
// udp.do_bind(ip)?;
// },
// InetSocket::Tcp(tcp) => {
// tcp.do_bind(ip)?;
// },
// }
// return Ok(());
// }
// return Err(EINVAL);
// }
// fn wait_queue(&self) -> &super::common::poll_unit::WaitQueue {
// todo!()
// }
// fn on_iface_events(&self) {
// todo!()
// }
// }
// pub trait Socket: FileLike + Send + Sync {
// /// Assign the address specified by socket_addr to the socket
// fn bind(&self, _socket_addr: SocketAddr) -> Result<()> {
// return_errno_with_message!(Errno::EOPNOTSUPP, "bind() is not supported");
// }
// /// Build connection for a given address
// fn connect(&self, _socket_addr: SocketAddr) -> Result<()> {
// return_errno_with_message!(Errno::EOPNOTSUPP, "connect() is not supported");
// }
// /// Listen for connections on a socket
// fn listen(&self, _backlog: usize) -> Result<()> {
// return_errno_with_message!(Errno::EOPNOTSUPP, "listen() is not supported");
// }
// /// Accept a connection on a socket
// fn accept(&self) -> Result<(Arc<dyn FileLike>, SocketAddr)> {
// return_errno_with_message!(Errno::EOPNOTSUPP, "accept() is not supported");
// }
// /// Shut down part of a full-duplex connection
// fn shutdown(&self, _cmd: SockShutdownCmd) -> Result<()> {
// return_errno_with_message!(Errno::EOPNOTSUPP, "shutdown() is not supported");
// }
// /// Get address of this socket.
// fn addr(&self) -> Result<SocketAddr> {
// return_errno_with_message!(Errno::EOPNOTSUPP, "getsockname() is not supported");
// }
// /// Get address of peer socket
// fn peer_addr(&self) -> Result<SocketAddr> {
// return_errno_with_message!(Errno::EOPNOTSUPP, "getpeername() is not supported");
// }
// /// Get options on the socket. The resulted option will put in the `option` parameter, if
// /// this method returns success.
// fn get_option(&self, _option: &mut dyn SocketOption) -> Result<()> {
// return_errno_with_message!(Errno::EOPNOTSUPP, "getsockopt() is not supported");
// }
// /// Set options on the socket.
// fn set_option(&self, _option: &dyn SocketOption) -> Result<()> {
// return_errno_with_message!(Errno::EOPNOTSUPP, "setsockopt() is not supported");
// }
// /// Sends a message on a socket.
// fn sendmsg(
// &self,
// io_vecs: &[IoVec],
// message_header: MessageHeader,
// flags: SendRecvFlags,
// ) -> Result<usize>;
// /// Receives a message from a socket.
// ///
// /// If successful, the `io_vecs` buffer will be filled with the received content.
// /// This method returns the length of the received message,
// /// and the message header.
// fn recvmsg(&self, io_vecs: &[IoVec], flags: SendRecvFlags) -> Result<(usize, MessageHeader)>;
// }

View File

@ -0,0 +1,443 @@
use core::sync::atomic::{AtomicU32, AtomicUsize};
use crate::libs::rwlock::RwLock;
use crate::net::socket::EPollEventType;
use crate::net::socket::{self, inet::Types};
use alloc::vec::Vec;
use smoltcp;
use system_error::SystemError::{self, *};
use super::inet::UNSPECIFIED_LOCAL_ENDPOINT;
pub const DEFAULT_METADATA_BUF_SIZE: usize = 1024;
pub const DEFAULT_RX_BUF_SIZE: usize = 512 * 1024;
pub const DEFAULT_TX_BUF_SIZE: usize = 512 * 1024;
fn new_smoltcp_socket() -> smoltcp::socket::tcp::Socket<'static> {
let rx_buffer = smoltcp::socket::tcp::SocketBuffer::new(vec![0; DEFAULT_RX_BUF_SIZE]);
let tx_buffer = smoltcp::socket::tcp::SocketBuffer::new(vec![0; DEFAULT_TX_BUF_SIZE]);
smoltcp::socket::tcp::Socket::new(rx_buffer, tx_buffer)
}
fn new_listen_smoltcp_socket<T>(local_endpoint: T) -> smoltcp::socket::tcp::Socket<'static>
where
T: Into<smoltcp::wire::IpListenEndpoint>,
{
let mut socket = new_smoltcp_socket();
socket.listen(local_endpoint).unwrap();
socket
}
#[derive(Debug)]
pub enum Init {
Unbound(smoltcp::socket::tcp::Socket<'static>),
Bound((socket::inet::BoundInner, smoltcp::wire::IpEndpoint)),
}
impl Init {
pub(super) fn new() -> Self {
Init::Unbound(new_smoltcp_socket())
}
/// 传入一个已经绑定的socket
pub(super) fn new_bound(inner: socket::inet::BoundInner) -> Self {
let endpoint = inner.with::<smoltcp::socket::tcp::Socket, _, _>(|socket| {
socket
.local_endpoint()
.expect("A Bound Socket Must Have A Local Endpoint")
});
Init::Bound((inner, endpoint))
}
pub(super) fn bind(
self,
local_endpoint: smoltcp::wire::IpEndpoint,
) -> Result<Self, SystemError> {
match self {
Init::Unbound(socket) => {
let bound = socket::inet::BoundInner::bind(socket, &local_endpoint.addr)?;
bound
.port_manager()
.bind_port(Types::Tcp, local_endpoint.port)?;
// bound.iface().common().bind_socket()
Ok(Init::Bound((bound, local_endpoint)))
}
Init::Bound(_) => Err(EINVAL),
}
}
pub(super) fn bind_to_ephemeral(
self,
remote_endpoint: smoltcp::wire::IpEndpoint,
) -> Result<(socket::inet::BoundInner, smoltcp::wire::IpEndpoint), (Self, SystemError)> {
match self {
Init::Unbound(socket) => {
let (bound, address) =
socket::inet::BoundInner::bind_ephemeral(socket, remote_endpoint.addr)
.map_err(|err| (Self::new(), err))?;
let bound_port = bound
.port_manager()
.bind_ephemeral_port(Types::Tcp)
.map_err(|err| (Self::new(), err))?;
let endpoint = smoltcp::wire::IpEndpoint::new(address, bound_port);
Ok((bound, endpoint))
}
Init::Bound(_) => Err((self, EINVAL)),
}
}
pub(super) fn connect(
self,
remote_endpoint: smoltcp::wire::IpEndpoint,
) -> Result<Connecting, (Self, SystemError)> {
let (inner, local) = match self {
Init::Unbound(_) => self.bind_to_ephemeral(remote_endpoint)?,
Init::Bound(inner) => inner,
};
if local.addr.is_unspecified() {
return Err((Init::Bound((inner, local)), EINVAL));
}
let result = inner.with_mut::<smoltcp::socket::tcp::Socket, _, _>(|socket| {
socket
.connect(
inner.iface().smol_iface().lock().context(),
remote_endpoint,
local,
)
.map_err(|_| ECONNREFUSED)
});
match result {
Ok(_) => Ok(Connecting::new(inner)),
Err(err) => Err((Init::Bound((inner, local)), err)),
}
}
/// # `listen`
pub(super) fn listen(self, backlog: usize) -> Result<Listening, (Self, SystemError)> {
let (inner, local) = match self {
Init::Unbound(_) => {
return Err((self, EINVAL));
}
Init::Bound(inner) => inner,
};
let listen_addr = if local.addr.is_unspecified() {
smoltcp::wire::IpListenEndpoint::from(local.port)
} else {
smoltcp::wire::IpListenEndpoint::from(local)
};
log::debug!("listen at {:?}", listen_addr);
let mut inners = Vec::new();
if let Err(err) = || -> Result<(), SystemError> {
for _ in 0..(backlog - 1) {
// -1 because the first one is already bound
let new_listen = socket::inet::BoundInner::bind(
new_listen_smoltcp_socket(listen_addr),
&local.addr,
)?;
inners.push(new_listen);
}
Ok(())
}() {
return Err((Init::Bound((inner, local)), err));
}
if let Err(err) = inner.with_mut::<smoltcp::socket::tcp::Socket, _, _>(|socket| {
socket.listen(listen_addr).map_err(|_| ECONNREFUSED)
}) {
return Err((Init::Bound((inner, local)), err));
}
inners.push(inner);
return Ok(Listening {
inners,
connect: AtomicUsize::new(0),
});
}
}
#[derive(Debug, Default, Clone, Copy)]
enum ConnectResult {
Connected,
#[default]
Connecting,
Refused,
}
#[derive(Debug)]
pub struct Connecting {
inner: socket::inet::BoundInner,
result: RwLock<ConnectResult>,
}
impl Connecting {
fn new(inner: socket::inet::BoundInner) -> Self {
Connecting {
inner,
result: RwLock::new(ConnectResult::Connecting),
}
}
pub fn with_mut<R, F: FnMut(&mut smoltcp::socket::tcp::Socket<'static>) -> R>(
&self,
f: F,
) -> R {
self.inner.with_mut(f)
}
pub fn into_result(self) -> (Inner, Option<SystemError>) {
use ConnectResult::*;
let result = *self.result.read_irqsave();
match result {
Connecting => (Inner::Connecting(self), Some(EAGAIN_OR_EWOULDBLOCK)),
Connected => (Inner::Established(Established { inner: self.inner }), None),
Refused => (Inner::Init(Init::new_bound(self.inner)), Some(ECONNREFUSED)),
}
}
/// Returns `true` when `conn_result` becomes ready, which indicates that the caller should
/// invoke the `into_result()` method as soon as possible.
///
/// Since `into_result()` needs to be called only once, this method will return `true`
/// _exactly_ once. The caller is responsible for not missing this event.
#[must_use]
pub(super) fn update_io_events(&self) -> bool {
if matches!(*self.result.read_irqsave(), ConnectResult::Connecting) {
return false;
}
self.inner
.with_mut(|socket: &mut smoltcp::socket::tcp::Socket| {
let mut result = self.result.write_irqsave();
if matches!(*result, ConnectResult::Refused | ConnectResult::Connected) {
return false; // Already connected or refused
}
// Connected
if socket.can_send() {
*result = ConnectResult::Connected;
return true;
}
// Connecting
if socket.is_open() {
return false;
}
// Refused
*result = ConnectResult::Refused;
return true;
})
}
pub fn get_name(&self) -> smoltcp::wire::IpEndpoint {
self.inner
.with::<smoltcp::socket::tcp::Socket, _, _>(|socket| {
socket
.local_endpoint()
.expect("A Connecting Tcp With No Local Endpoint")
})
}
}
#[derive(Debug)]
pub struct Listening {
inners: Vec<socket::inet::BoundInner>,
connect: AtomicUsize,
}
impl Listening {
pub fn accept(&mut self) -> Result<(Established, smoltcp::wire::IpEndpoint), SystemError> {
let connected: &mut socket::inet::BoundInner = self
.inners
.get_mut(self.connect.load(core::sync::atomic::Ordering::Relaxed))
.unwrap();
if connected.with::<smoltcp::socket::tcp::Socket, _, _>(|socket| !socket.is_active()) {
return Err(EAGAIN_OR_EWOULDBLOCK);
}
let (local_endpoint, remote_endpoint) = connected
.with::<smoltcp::socket::tcp::Socket, _, _>(|socket| {
(
socket
.local_endpoint()
.expect("A Connected Tcp With No Local Endpoint"),
socket
.remote_endpoint()
.expect("A Connected Tcp With No Remote Endpoint"),
)
});
// log::debug!("local at {:?}", local_endpoint);
let mut new_listen = socket::inet::BoundInner::bind(
new_listen_smoltcp_socket(local_endpoint),
&local_endpoint.addr,
)?;
// swap the connected socket with the new_listen socket
// TODO is smoltcp socket swappable?
core::mem::swap(&mut new_listen, connected);
return Ok((Established { inner: new_listen }, remote_endpoint));
}
pub fn update_io_events(&self, pollee: &AtomicUsize) {
let position = self.inners.iter().position(|inner| {
inner.with::<smoltcp::socket::tcp::Socket, _, _>(|socket| socket.is_active())
});
if let Some(position) = position {
// log::debug!("Can accept!");
self.connect
.store(position, core::sync::atomic::Ordering::Relaxed);
pollee.fetch_or(
EPollEventType::EPOLLIN.bits() as usize,
core::sync::atomic::Ordering::Relaxed,
);
} else {
// log::debug!("Can't accept!");
pollee.fetch_and(
!EPollEventType::EPOLLIN.bits() as usize,
core::sync::atomic::Ordering::Relaxed,
);
}
}
pub fn get_name(&self) -> smoltcp::wire::IpEndpoint {
self.inners[0].with::<smoltcp::socket::tcp::Socket, _, _>(|socket| {
if let Some(name) = socket.local_endpoint() {
return name;
} else {
return UNSPECIFIED_LOCAL_ENDPOINT;
}
})
}
}
#[derive(Debug)]
pub struct Established {
inner: socket::inet::BoundInner,
}
impl Established {
pub fn with_mut<R, F: FnMut(&mut smoltcp::socket::tcp::Socket<'static>) -> R>(
&self,
f: F,
) -> R {
self.inner.with_mut(f)
}
pub fn with<R, F: Fn(&smoltcp::socket::tcp::Socket<'static>) -> R>(&self, f: F) -> R {
self.inner.with(f)
}
pub fn close(&self) {
self.inner
.with_mut::<smoltcp::socket::tcp::Socket, _, _>(|socket| socket.close());
self.inner.iface().poll();
}
pub fn release(&self) {
self.inner.release();
}
pub fn local_endpoint(&self) -> smoltcp::wire::IpEndpoint {
self.inner
.with::<smoltcp::socket::tcp::Socket, _, _>(|socket| socket.local_endpoint())
.unwrap()
}
pub fn remote_endpoint(&self) -> smoltcp::wire::IpEndpoint {
self.inner
.with::<smoltcp::socket::tcp::Socket, _, _>(|socket| socket.remote_endpoint().unwrap())
}
pub fn recv_slice(&self, buf: &mut [u8]) -> Result<usize, SystemError> {
self.inner
.with_mut::<smoltcp::socket::tcp::Socket, _, _>(|socket| {
use smoltcp::socket::tcp::RecvError::*;
if socket.can_send() {
match socket.recv_slice(buf) {
Ok(size) => Ok(size),
Err(InvalidState) => {
log::error!("TcpSocket::try_recv: InvalidState");
Err(ENOTCONN)
}
Err(Finished) => Ok(0),
}
} else {
Err(ENOBUFS)
}
})
}
pub fn send_slice(&self, buf: &[u8]) -> Result<usize, SystemError> {
self.inner
.with_mut::<smoltcp::socket::tcp::Socket, _, _>(|socket| {
if socket.can_send() {
socket.send_slice(buf).map_err(|_| ECONNABORTED)
} else {
Err(ENOBUFS)
}
})
}
pub fn update_io_events(&self, pollee: &AtomicUsize) {
self.inner
.with_mut::<smoltcp::socket::tcp::Socket, _, _>(|socket| {
if socket.can_send() {
pollee.fetch_or(
EPollEventType::EPOLLOUT.bits() as usize,
core::sync::atomic::Ordering::Relaxed,
);
} else {
pollee.fetch_and(
!EPollEventType::EPOLLOUT.bits() as usize,
core::sync::atomic::Ordering::Relaxed,
);
}
if socket.can_recv() {
pollee.fetch_or(
EPollEventType::EPOLLIN.bits() as usize,
core::sync::atomic::Ordering::Relaxed,
);
} else {
pollee.fetch_and(
!EPollEventType::EPOLLIN.bits() as usize,
core::sync::atomic::Ordering::Relaxed,
);
}
})
}
}
#[derive(Debug)]
pub enum Inner {
Init(Init),
Connecting(Connecting),
Listening(Listening),
Established(Established),
}
impl Inner {
pub fn send_buffer_size(&self) -> usize {
match self {
Inner::Init(_) => DEFAULT_TX_BUF_SIZE,
Inner::Connecting(conn) => conn.with_mut(|socket| socket.send_capacity()),
// only the first socket in the list is used for sending
Inner::Listening(listen) => listen.inners[0]
.with_mut::<smoltcp::socket::tcp::Socket, _, _>(|socket| socket.send_capacity()),
Inner::Established(est) => est.with_mut(|socket| socket.send_capacity()),
}
}
pub fn recv_buffer_size(&self) -> usize {
match self {
Inner::Init(_) => DEFAULT_RX_BUF_SIZE,
Inner::Connecting(conn) => conn.with_mut(|socket| socket.recv_capacity()),
// only the first socket in the list is used for receiving
Inner::Listening(listen) => listen.inners[0]
.with_mut::<smoltcp::socket::tcp::Socket, _, _>(|socket| socket.recv_capacity()),
Inner::Established(est) => est.with_mut(|socket| socket.recv_capacity()),
}
}
}

View File

@ -0,0 +1,485 @@
use alloc::sync::{Arc, Weak};
use core::sync::atomic::{AtomicBool, AtomicUsize};
use system_error::SystemError::{self, *};
use crate::libs::rwlock::RwLock;
use crate::net::event_poll::EPollEventType;
use crate::net::net_core::poll_ifaces;
use crate::net::socket::*;
use crate::sched::SchedMode;
use inet::{InetSocket, UNSPECIFIED_LOCAL_ENDPOINT};
use smoltcp;
pub mod inner;
use inner::*;
type EP = EPollEventType;
#[derive(Debug)]
pub struct TcpSocket {
inner: RwLock<Option<Inner>>,
shutdown: Shutdown,
nonblock: AtomicBool,
epitems: EPollItems,
wait_queue: WaitQueue,
self_ref: Weak<Self>,
pollee: AtomicUsize,
}
impl TcpSocket {
pub fn new(nonblock: bool) -> Arc<Self> {
Arc::new_cyclic(|me| Self {
inner: RwLock::new(Some(Inner::Init(Init::new()))),
shutdown: Shutdown::new(),
nonblock: AtomicBool::new(nonblock),
epitems: EPollItems::default(),
wait_queue: WaitQueue::default(),
self_ref: me.clone(),
pollee: AtomicUsize::new((EP::EPOLLIN.bits() | EP::EPOLLOUT.bits()) as usize),
})
}
pub fn new_established(inner: Established, nonblock: bool) -> Arc<Self> {
Arc::new_cyclic(|me| Self {
inner: RwLock::new(Some(Inner::Established(inner))),
shutdown: Shutdown::new(),
nonblock: AtomicBool::new(nonblock),
epitems: EPollItems::default(),
wait_queue: WaitQueue::default(),
self_ref: me.clone(),
pollee: AtomicUsize::new((EP::EPOLLIN.bits() | EP::EPOLLOUT.bits()) as usize),
})
}
pub fn is_nonblock(&self) -> bool {
self.nonblock.load(core::sync::atomic::Ordering::Relaxed)
}
#[inline]
fn write_state<F>(&self, mut f: F) -> Result<(), SystemError>
where
F: FnMut(Inner) -> Result<Inner, SystemError>,
{
let mut inner_guard = self.inner.write();
let inner = inner_guard.take().expect("Tcp Inner is None");
let update = f(inner)?;
inner_guard.replace(update);
Ok(())
}
pub fn do_bind(&self, local_endpoint: smoltcp::wire::IpEndpoint) -> Result<(), SystemError> {
let mut writer = self.inner.write();
match writer.take().expect("Tcp Inner is None") {
Inner::Init(inner) => {
let bound = inner.bind(local_endpoint)?;
if let Init::Bound((ref bound, _)) = bound {
bound
.iface()
.common()
.bind_socket(self.self_ref.upgrade().unwrap());
}
writer.replace(Inner::Init(bound));
Ok(())
}
_ => Err(EINVAL),
}
}
pub fn do_listen(&self, backlog: usize) -> Result<(), SystemError> {
let mut writer = self.inner.write();
let inner = writer.take().expect("Tcp Inner is None");
let (listening, err) = match inner {
Inner::Init(init) => {
let listen_result = init.listen(backlog);
match listen_result {
Ok(listening) => (Inner::Listening(listening), None),
Err((init, err)) => (Inner::Init(init), Some(err)),
}
}
_ => (inner, Some(EINVAL)),
};
writer.replace(listening);
drop(writer);
if let Some(err) = err {
return Err(err);
}
return Ok(());
}
pub fn try_accept(&self) -> Result<(Arc<TcpSocket>, smoltcp::wire::IpEndpoint), SystemError> {
poll_ifaces();
match self.inner.write().as_mut().expect("Tcp Inner is None") {
Inner::Listening(listening) => listening.accept().map(|(stream, remote)| {
(
TcpSocket::new_established(stream, self.is_nonblock()),
remote,
)
}),
_ => Err(EINVAL),
}
}
pub fn start_connect(
&self,
remote_endpoint: smoltcp::wire::IpEndpoint,
) -> Result<(), SystemError> {
let mut writer = self.inner.write();
let inner = writer.take().expect("Tcp Inner is None");
let (init, err) = match inner {
Inner::Init(init) => {
let conn_result = init.connect(remote_endpoint);
match conn_result {
Ok(connecting) => (
Inner::Connecting(connecting),
if self.is_nonblock() {
None
} else {
Some(EINPROGRESS)
},
),
Err((init, err)) => (Inner::Init(init), Some(err)),
}
}
Inner::Connecting(connecting) if self.is_nonblock() => {
(Inner::Connecting(connecting), Some(EALREADY))
}
Inner::Connecting(connecting) => (Inner::Connecting(connecting), None),
Inner::Listening(inner) => (Inner::Listening(inner), Some(EISCONN)),
Inner::Established(inner) => (Inner::Established(inner), Some(EISCONN)),
};
writer.replace(init);
drop(writer);
poll_ifaces();
if let Some(err) = err {
return Err(err);
}
return Ok(());
}
pub fn finish_connect(&self) -> Result<(), SystemError> {
let mut writer = self.inner.write();
let Inner::Connecting(conn) = writer.take().expect("Tcp Inner is None") else {
log::error!("TcpSocket::finish_connect: not Connecting");
return Err(EINVAL);
};
let (inner, err) = conn.into_result();
writer.replace(inner);
drop(writer);
if let Some(err) = err {
return Err(err);
}
return Ok(());
}
pub fn check_connect(&self) -> Result<(), SystemError> {
match self.inner.read().as_ref().expect("Tcp Inner is None") {
Inner::Connecting(_) => Err(EAGAIN_OR_EWOULDBLOCK),
Inner::Established(_) => Ok(()), // TODO check established
_ => Err(EINVAL), // TODO socket error options
}
}
pub fn try_recv(&self, buf: &mut [u8]) -> Result<usize, SystemError> {
poll_ifaces();
match self.inner.read().as_ref().expect("Tcp Inner is None") {
Inner::Established(inner) => inner.recv_slice(buf),
_ => Err(EINVAL),
}
}
pub fn try_send(&self, buf: &[u8]) -> Result<usize, SystemError> {
match self.inner.read().as_ref().expect("Tcp Inner is None") {
Inner::Established(inner) => {
let sent = inner.send_slice(buf);
poll_ifaces();
sent
}
_ => Err(EINVAL),
}
}
fn update_events(&self) -> bool {
match self.inner.read().as_ref().expect("Tcp Inner is None") {
Inner::Init(_) => false,
Inner::Connecting(connecting) => connecting.update_io_events(),
Inner::Established(established) => {
established.update_io_events(&self.pollee);
false
}
Inner::Listening(listening) => {
listening.update_io_events(&self.pollee);
false
}
}
}
// should only call on accept
fn is_acceptable(&self) -> bool {
// (self.poll() & EP::EPOLLIN.bits() as usize) != 0
EP::from_bits_truncate(self.poll() as u32).contains(EP::EPOLLIN)
}
}
impl Socket for TcpSocket {
fn wait_queue(&self) -> &WaitQueue {
&self.wait_queue
}
fn get_name(&self) -> Result<Endpoint, SystemError> {
match self.inner.read().as_ref().expect("Tcp Inner is None") {
Inner::Init(Init::Unbound(_)) => Ok(Endpoint::Ip(UNSPECIFIED_LOCAL_ENDPOINT)),
Inner::Init(Init::Bound((_, local))) => Ok(Endpoint::Ip(local.clone())),
Inner::Connecting(connecting) => Ok(Endpoint::Ip(connecting.get_name())),
Inner::Established(established) => Ok(Endpoint::Ip(established.local_endpoint())),
Inner::Listening(listening) => Ok(Endpoint::Ip(listening.get_name())),
}
}
fn bind(&self, endpoint: Endpoint) -> Result<(), SystemError> {
if let Endpoint::Ip(addr) = endpoint {
return self.do_bind(addr);
}
return Err(EINVAL);
}
fn connect(&self, endpoint: Endpoint) -> Result<(), SystemError> {
if let Endpoint::Ip(addr) = endpoint {
return self.start_connect(addr);
}
return Err(EINVAL);
}
fn poll(&self) -> usize {
self.pollee.load(core::sync::atomic::Ordering::Relaxed)
}
fn listen(&self, backlog: usize) -> Result<(), SystemError> {
self.do_listen(backlog)
}
fn accept(&self) -> Result<(Arc<Inode>, Endpoint), SystemError> {
// could block io
if self.is_nonblock() {
self.try_accept()
} else {
loop {
// log::debug!("TcpSocket::accept: wake up");
match self.try_accept() {
Err(EAGAIN_OR_EWOULDBLOCK) => {
wq_wait_event_interruptible!(self.wait_queue, self.is_acceptable(), {})?;
}
result => break result,
}
}
}
.map(|(inner, endpoint)| (Inode::new(inner), Endpoint::Ip(endpoint)))
}
fn recv(&self, buffer: &mut [u8], _flags: MessageFlag) -> Result<usize, SystemError> {
self.try_recv(buffer)
}
fn send(&self, buffer: &[u8], _flags: MessageFlag) -> Result<usize, SystemError> {
self.try_send(buffer)
}
fn send_buffer_size(&self) -> usize {
self.inner
.read()
.as_ref()
.expect("Tcp Inner is None")
.send_buffer_size()
}
fn recv_buffer_size(&self) -> usize {
self.inner
.read()
.as_ref()
.expect("Tcp Inner is None")
.recv_buffer_size()
}
fn close(&self) -> Result<(), SystemError> {
match self.inner.read().as_ref().expect("Tcp Inner is None") {
Inner::Init(_) => {}
Inner::Connecting(_) => {
return Err(EINPROGRESS);
}
Inner::Established(es) => {
es.close();
es.release();
}
Inner::Listening(_) => {}
}
Ok(())
}
}
impl InetSocket for TcpSocket {
fn on_iface_events(&self) {
if self.update_events() {
let result = self.finish_connect();
// set error
}
}
}
// #[derive(Debug)]
// // #[cast_to([sync] IndexNode)]
// struct TcpStream {
// inner: Established,
// shutdown: Shutdown,
// nonblock: AtomicBool,
// epitems: EPollItems,
// wait_queue: WaitQueue,
// self_ref: Weak<Self>,
// }
// impl TcpStream {
// pub fn is_nonblock(&self) -> bool {
// self.nonblock.load(core::sync::atomic::Ordering::Relaxed)
// }
// pub fn read(&self, buf: &mut [u8]) -> Result<usize, SystemError> {
// if self.nonblock.load(core::sync::atomic::Ordering::Relaxed) {
// return self.recv_slice(buf);
// } else {
// return self.wait_queue().busy_wait(
// EP::EPOLLIN,
// || self.recv_slice(buf)
// )
// }
// }
// pub fn recv_slice(&self, buf: &mut [u8]) -> Result<usize, SystemError> {
// let received = self.inner.recv_slice(buf);
// poll_ifaces();
// received
// }
// pub fn send_slice(&self, buf: &[u8]) -> Result<usize, SystemError> {
// let sent = self.inner.send_slice(buf);
// poll_ifaces();
// sent
// }
// }
// use crate::net::socket::{Inode, Socket};
// use crate::filesystem::vfs::IndexNode;
// impl IndexNode for TcpStream {
// fn read_at(
// &self,
// _offset: usize,
// _len: usize,
// buf: &mut [u8],
// data: crate::libs::spinlock::SpinLockGuard<crate::filesystem::vfs::FilePrivateData>,
// ) -> Result<usize, SystemError> {
// drop(data);
// self.read(buf)
// }
// fn write_at(
// &self,
// _offset: usize,
// _len: usize,
// buf: &[u8],
// data: crate::libs::spinlock::SpinLockGuard<crate::filesystem::vfs::FilePrivateData>,
// ) -> Result<usize, SystemError> {
// drop(data);
// self.send_slice(buf)
// }
// fn fs(&self) -> alloc::sync::Arc<dyn crate::filesystem::vfs::FileSystem> {
// todo!("TcpSocket::fs")
// }
// fn as_any_ref(&self) -> &dyn core::any::Any {
// self
// }
// fn list(&self) -> Result<alloc::vec::Vec<alloc::string::String>, SystemError> {
// todo!("TcpSocket::list")
// }
// }
// impl Socket for TcpStream {
// fn wait_queue(&self) -> WaitQueue {
// self.wait_queue.clone()
// }
// fn poll(&self) -> usize {
// // self.inner.with(|socket| {
// // let mut mask = EPollEventType::empty();
// // let shutdown = self.shutdown.get();
// // let state = socket.state();
// // use smoltcp::socket::tcp::State::*;
// // type EP = crate::net::event_poll::EPollEventType;
// // if shutdown.is_both_shutdown() || state == Closed {
// // mask |= EP::EPOLLHUP;
// // }
// // if shutdown.is_recv_shutdown() {
// // mask |= EP::EPOLLIN | EP::EPOLLRDNORM | EP::EPOLLRDHUP;
// // }
// // if state != SynSent && state != SynReceived {
// // if socket.can_recv() {
// // mask |= EP::EPOLLIN | EP::EPOLLRDNORM;
// // }
// // if !shutdown.is_send_shutdown() {
// // // __sk_stream_is_writeable这是一个内联函数用于判断一个TCP套接字是否可写。
// // //
// // // 以下是函数的逐行解释:
// // // static inline bool __sk_stream_is_writeable(const struct sock *sk, int wake)
// // // - 这行定义了函数__sk_stream_is_writeable它是一个内联函数static inline
// // // 这意味着在调用点直接展开代码,而不是调用函数体。函数接收两个参数:
// // // 一个指向struct sock对象的指针sk代表套接字和一个整型变量wake。
// // //
// // // return sk_stream_wspace(sk) >= sk_stream_min_wspace(sk) &&
// // // - 这行代码调用了sk_stream_wspace函数获取套接字sk的可写空间write space大小。
// // // 随后与sk_stream_min_wspace调用结果进行比较该函数返回套接字为了保持稳定写入速度所需的
// // // 最小可写空间。如果当前可写空间大于或等于最小可写空间,则表达式为真。
// // // __sk_stream_memory_free(sk, wake);
// // // - 这行代码调用了__sk_stream_memory_free函数它可能用于检查套接字的内存缓冲区是否
// // // 有足够的空间可供写入数据。参数wake可能用于通知网络协议栈有数据需要发送如果设置了相应的标志。
// // // 综上所述__sk_stream_is_writeable函数的目的是判断一个TCP套接字是否可以安全地进行写操作
// // // 它基于套接字的当前可写空间和所需的最小空间以及内存缓冲区的可用性。只有当这两个条件都满足时,
// // // 函数才会返回true表示套接字是可写的。
// // if socket.can_send() {
// // mask |= EP::EPOLLOUT | EP::EPOLLWRNORM | EP::EPOLLWRBAND;
// // } else {
// // todo!("TcpStream::poll: buffer space not enough");
// // }
// // } else {
// // mask |= EP::EPOLLOUT | EP::EPOLLWRNORM;
// // }
// // // TODO tcp urg data => EPOLLPRI
// // } else if state == SynSent /* inet_test_bit */ {
// // log::warn!("Active TCP fastopen socket with defer_connect");
// // mask |= EP::EPOLLOUT | EP::EPOLLWRNORM;
// // }
// // // TODO socket error
// // return Ok(mask);
// // })
// self.pollee.load(core::sync::atomic::Ordering::Relaxed)
// }
// fn send_buffer_size(&self) -> usize {
// self.inner.with(|socket| socket.send_capacity())
// }
// fn recv_buffer_size(&self) -> usize {
// self.inner.with(|socket| socket.recv_capacity())
// }
// }

View File

@ -0,0 +1,55 @@
use alloc::sync::Arc;
use smoltcp;
use system_error::SystemError::{self, *};
use inet::{TcpSocket, UdpSocket};
// use crate::net::syscall_util::SysArgSocketType;
use crate::net::socket::*;
fn create_inet_socket(
socket_type: Type,
protocol: smoltcp::wire::IpProtocol,
) -> Result<Arc<dyn Socket>, SystemError> {
log::debug!("type: {:?}, protocol: {:?}", socket_type, protocol);
use smoltcp::wire::IpProtocol::*;
use Type::*;
match socket_type {
Datagram => {
match protocol {
HopByHop | Udp => {
return Ok(UdpSocket::new(false));
}
_ => {
return Err(EPROTONOSUPPORT);
}
}
// if !matches!(protocol, Udp) {
// return Err(EPROTONOSUPPORT);
// }
// return Ok(UdpSocket::new(false));
}
Stream => match protocol {
HopByHop | Tcp => {
return Ok(TcpSocket::new(false));
}
_ => {
return Err(EPROTONOSUPPORT);
}
},
Raw => {
todo!("raw")
}
_ => {
return Err(EPROTONOSUPPORT);
}
}
}
pub struct Inet;
impl family::Family for Inet {
fn socket(stype: Type, protocol: u32) -> Result<Arc<Inode>, SystemError> {
let socket = create_inet_socket(stype, smoltcp::wire::IpProtocol::from(protocol as u8))?;
Ok(Inode::new(socket))
}
}

View File

@ -0,0 +1,195 @@
use crate::filesystem::vfs::IndexNode;
use alloc::sync::Arc;
use system_error::SystemError;
use crate::net::socket::*;
#[derive(Debug)]
pub struct Inode {
inner: Arc<dyn Socket>,
epoll_items: EPollItems,
}
impl IndexNode for Inode {
fn read_at(
&self,
_offset: usize,
_len: usize,
buf: &mut [u8],
data: crate::libs::spinlock::SpinLockGuard<crate::filesystem::vfs::FilePrivateData>,
) -> Result<usize, SystemError> {
drop(data);
self.inner.read(buf)
}
fn write_at(
&self,
_offset: usize,
_len: usize,
buf: &[u8],
data: crate::libs::spinlock::SpinLockGuard<crate::filesystem::vfs::FilePrivateData>,
) -> Result<usize, SystemError> {
drop(data);
self.inner.write(buf)
}
/* Following are not yet available in socket */
fn as_any_ref(&self) -> &dyn core::any::Any {
self
}
/* filesystem associate interfaces are about unix and netlink socket */
fn fs(&self) -> Arc<dyn crate::filesystem::vfs::FileSystem> {
unimplemented!()
}
fn list(&self) -> Result<alloc::vec::Vec<alloc::string::String>, SystemError> {
unimplemented!()
}
fn poll(
&self,
private_data: &crate::filesystem::vfs::FilePrivateData,
) -> Result<usize, SystemError> {
// let _ = private_data;
Ok(self.inner.poll())
}
fn open(
&self,
_data: crate::libs::spinlock::SpinLockGuard<crate::filesystem::vfs::FilePrivateData>,
_mode: &crate::filesystem::vfs::file::FileMode,
) -> Result<(), SystemError> {
Ok(())
}
fn metadata(&self) -> Result<crate::filesystem::vfs::Metadata, SystemError> {
let meta = crate::filesystem::vfs::Metadata {
mode: crate::filesystem::vfs::syscall::ModeType::from_bits_truncate(0o755),
file_type: crate::filesystem::vfs::FileType::Socket,
size: self.send_buffer_size() as i64,
..Default::default()
};
return Ok(meta);
}
fn close(
&self,
_data: crate::libs::spinlock::SpinLockGuard<crate::filesystem::vfs::FilePrivateData>,
) -> Result<(), SystemError> {
self.inner.close()
}
}
impl Inode {
// pub fn wait_queue(&self) -> WaitQueue {
// self.inner.wait_queue()
// }
pub fn send_buffer_size(&self) -> usize {
self.inner.send_buffer_size()
}
pub fn recv_buffer_size(&self) -> usize {
self.inner.recv_buffer_size()
}
pub fn accept(&self) -> Result<(Arc<Self>, Endpoint), SystemError> {
self.inner.accept()
}
pub fn bind(&self, endpoint: Endpoint) -> Result<(), SystemError> {
self.inner.bind(endpoint)
}
pub fn set_option(
&self,
level: OptionsLevel,
name: usize,
value: &[u8],
) -> Result<(), SystemError> {
self.inner.set_option(level, name, value)
}
pub fn get_option(
&self,
level: OptionsLevel,
name: usize,
value: &mut [u8],
) -> Result<usize, SystemError> {
self.inner.get_option(level, name, value)
}
pub fn listen(&self, backlog: usize) -> Result<(), SystemError> {
self.inner.listen(backlog)
}
pub fn send_to(
&self,
buffer: &[u8],
address: Endpoint,
flags: MessageFlag,
) -> Result<usize, SystemError> {
self.inner.send_to(buffer, flags, address)
}
pub fn send(&self, buffer: &[u8], flags: MessageFlag) -> Result<usize, SystemError> {
self.inner.send(buffer, flags)
}
pub fn recv(&self, buffer: &mut [u8], flags: MessageFlag) -> Result<usize, SystemError> {
self.inner.recv(buffer, flags)
}
// TODO receive from split with endpoint or not
pub fn recv_from(
&self,
buffer: &mut [u8],
flags: MessageFlag,
address: Option<Endpoint>,
) -> Result<(usize, Endpoint), SystemError> {
self.inner.recv_from(buffer, flags, address)
}
pub fn shutdown(&self, how: ShutdownTemp) -> Result<(), SystemError> {
self.inner.shutdown(how)
}
pub fn connect(&self, endpoint: Endpoint) -> Result<(), SystemError> {
self.inner.connect(endpoint)
}
pub fn get_name(&self) -> Result<Endpoint, SystemError> {
self.inner.get_name()
}
pub fn get_peer_name(&self) -> Result<Endpoint, SystemError> {
self.inner.get_peer_name()
}
pub fn new(inner: Arc<dyn Socket>) -> Arc<Self> {
Arc::new(Self {
inner,
epoll_items: EPollItems::default(),
})
}
/// # `epoll_items`
/// socket的epoll事件集
pub fn epoll_items(&self) -> EPollItems {
self.epoll_items.clone()
}
pub fn set_nonblock(&self, nonblock: bool) {
log::warn!("nonblock is not support yet");
}
pub fn set_close_on_exec(&self, close_on_exec: bool) {
log::warn!("close_on_exec is not support yet");
}
pub fn inner(&self) -> Arc<dyn Socket> {
return self.inner.clone();
}
}

View File

@ -1,919 +1,28 @@
use core::{any::Any, fmt::Debug, sync::atomic::AtomicUsize}; mod base;
mod buffer;
use alloc::{ mod common;
boxed::Box, mod define;
collections::LinkedList, mod endpoint;
string::String, mod family;
sync::{Arc, Weak},
vec::Vec,
};
use hashbrown::HashMap;
use log::warn;
use smoltcp::{
iface::SocketSet,
socket::{self, raw, tcp, udp},
};
use system_error::SystemError;
use crate::{
arch::rand::rand,
filesystem::vfs::{
file::FileMode, syscall::ModeType, FilePrivateData, FileSystem, FileType, IndexNode,
Metadata,
},
libs::{
rwlock::{RwLock, RwLockWriteGuard},
spinlock::{SpinLock, SpinLockGuard},
wait_queue::EventWaitQueue,
},
process::{Pid, ProcessManager},
sched::{schedule, SchedMode},
};
use self::{
handle::GlobalSocketHandle,
inet::{RawSocket, TcpSocket, UdpSocket},
unix::{SeqpacketSocket, StreamSocket},
};
use super::{
event_poll::{EPollEventType, EPollItem, EventPoll},
Endpoint, Protocol, ShutdownType,
};
pub mod handle;
pub mod inet; pub mod inet;
mod inode;
pub mod netlink;
pub mod unix; pub mod unix;
mod utils;
lazy_static! {
/// 所有socket的集合 use crate::libs::wait_queue::WaitQueue;
/// TODO: 优化这里自己实现SocketSet现在这样的话不管全局有多少个网卡每个时间点都只会有1个进程能够访问socket pub use base::Socket;
pub static ref SOCKET_SET: SpinLock<SocketSet<'static >> = SpinLock::new(SocketSet::new(vec![])); use buffer::Buffer;
/// SocketHandle表每个SocketHandle对应一个SocketHandleItem pub use common::{
/// 注意!:在网卡中断中需要拿到这张表的🔓,在获取读锁时应该确保关中断避免死锁 shutdown::*,
pub static ref HANDLE_MAP: RwLock<HashMap<GlobalSocketHandle, SocketHandleItem>> = RwLock::new(HashMap::new()); // poll_unit::{EPollItems, WaitQueue},
/// 端口管理器 EPollItems,
pub static ref PORT_MANAGER: PortManager = PortManager::new(); };
} pub use define::*;
pub use endpoint::*;
/* For setsockopt(2) */ pub use family::{AddressFamily, Family};
// See: linux-5.19.10/include/uapi/asm-generic/socket.h#9 pub use inode::Inode;
pub const SOL_SOCKET: u8 = 1; pub use utils::create_socket;
/// 根据地址族、socket类型和协议创建socket pub use crate::net::event_poll::EPollEventType;
pub(super) fn new_socket( // pub use crate::net::sys
address_family: AddressFamily,
socket_type: PosixSocketType,
protocol: Protocol,
) -> Result<Box<dyn Socket>, SystemError> {
let socket: Box<dyn Socket> = match address_family {
AddressFamily::Unix => match socket_type {
PosixSocketType::Stream => Box::new(StreamSocket::new(SocketOptions::default())),
PosixSocketType::SeqPacket => Box::new(SeqpacketSocket::new(SocketOptions::default())),
_ => {
return Err(SystemError::EINVAL);
}
},
AddressFamily::INet => match socket_type {
PosixSocketType::Stream => Box::new(TcpSocket::new(SocketOptions::default())),
PosixSocketType::Datagram => Box::new(UdpSocket::new(SocketOptions::default())),
PosixSocketType::Raw => Box::new(RawSocket::new(protocol, SocketOptions::default())),
_ => {
return Err(SystemError::EINVAL);
}
},
_ => {
return Err(SystemError::EAFNOSUPPORT);
}
};
let handle_item = SocketHandleItem::new(Arc::downgrade(&socket.posix_item()));
HANDLE_MAP
.write_irqsave()
.insert(socket.socket_handle(), handle_item);
Ok(socket)
}
pub trait Socket: Sync + Send + Debug + Any {
/// @brief 从socket中读取数据如果socket是阻塞的那么直到读取到数据才返回
///
/// @param buf 读取到的数据存放的缓冲区
///
/// @return - 成功:(返回读取的数据的长度,读取数据的端点).
/// - 失败:错误码
fn read(&self, buf: &mut [u8]) -> (Result<usize, SystemError>, Endpoint);
/// @brief 向socket中写入数据。如果socket是阻塞的那么直到写入的数据全部写入socket中才返回
///
/// @param buf 要写入的数据
/// @param to 要写入的目的端点如果是None那么写入的数据将会被丢弃
///
/// @return 返回写入的数据的长度
fn write(&self, buf: &[u8], to: Option<Endpoint>) -> Result<usize, SystemError>;
/// @brief 对应于POSIX的connect函数用于连接到指定的远程服务器端点
///
/// It is used to establish a connection to a remote server.
/// When a socket is connected to a remote server,
/// the operating system will establish a network connection with the server
/// and allow data to be sent and received between the local socket and the remote server.
///
/// @param endpoint 要连接的端点
///
/// @return 返回连接是否成功
fn connect(&mut self, _endpoint: Endpoint) -> Result<(), SystemError>;
/// @brief 对应于POSIX的bind函数用于绑定到本机指定的端点
///
/// The bind() function is used to associate a socket with a particular IP address and port number on the local machine.
///
/// @param endpoint 要绑定的端点
///
/// @return 返回绑定是否成功
fn bind(&mut self, _endpoint: Endpoint) -> Result<(), SystemError> {
Err(SystemError::ENOSYS)
}
/// @brief 对应于 POSIX 的 shutdown 函数用于关闭socket。
///
/// shutdown() 函数用于启动网络连接的正常关闭。
/// 当在两个端点之间建立网络连接时,任一端点都可以通过调用其端点对象上的 shutdown() 函数来启动关闭序列。
/// 此函数向远程端点发送关闭消息以指示本地端点不再接受新数据。
///
/// @return 返回是否成功关闭
fn shutdown(&mut self, _type: ShutdownType) -> Result<(), SystemError> {
Err(SystemError::ENOSYS)
}
/// @brief 对应于POSIX的listen函数用于监听端点
///
/// @param backlog 最大的等待连接数
///
/// @return 返回监听是否成功
fn listen(&mut self, _backlog: usize) -> Result<(), SystemError> {
Err(SystemError::ENOSYS)
}
/// @brief 对应于POSIX的accept函数用于接受连接
///
/// @param endpoint 对端的端点
///
/// @return 返回接受连接是否成功
fn accept(&mut self) -> Result<(Box<dyn Socket>, Endpoint), SystemError> {
Err(SystemError::ENOSYS)
}
/// @brief 获取socket的端点
///
/// @return 返回socket的端点
fn endpoint(&self) -> Option<Endpoint> {
None
}
/// @brief 获取socket的对端端点
///
/// @return 返回socket的对端端点
fn peer_endpoint(&self) -> Option<Endpoint> {
None
}
/// @brief
/// The purpose of the poll function is to provide
/// a non-blocking way to check if a socket is ready for reading or writing,
/// so that you can efficiently handle multiple sockets in a single thread or event loop.
///
/// @return (in, out, err)
///
/// The first boolean value indicates whether the socket is ready for reading. If it is true, then there is data available to be read from the socket without blocking.
/// The second boolean value indicates whether the socket is ready for writing. If it is true, then data can be written to the socket without blocking.
/// The third boolean value indicates whether the socket has encountered an error condition. If it is true, then the socket is in an error state and should be closed or reset
///
fn poll(&self) -> EPollEventType {
EPollEventType::empty()
}
/// @brief socket的ioctl函数
///
/// @param cmd ioctl命令
/// @param arg0 ioctl命令的第一个参数
/// @param arg1 ioctl命令的第二个参数
/// @param arg2 ioctl命令的第三个参数
///
/// @return 返回ioctl命令的返回值
fn ioctl(
&self,
_cmd: usize,
_arg0: usize,
_arg1: usize,
_arg2: usize,
) -> Result<usize, SystemError> {
Ok(0)
}
/// @brief 获取socket的元数据
fn metadata(&self) -> SocketMetadata;
fn box_clone(&self) -> Box<dyn Socket>;
/// @brief 设置socket的选项
///
/// @param level 选项的层次
/// @param optname 选项的名称
/// @param optval 选项的值
///
/// @return 返回设置是否成功, 如果不支持该选项返回ENOSYS
fn setsockopt(
&self,
_level: usize,
_optname: usize,
_optval: &[u8],
) -> Result<(), SystemError> {
warn!("setsockopt is not implemented");
Ok(())
}
fn socket_handle(&self) -> GlobalSocketHandle;
fn write_buffer(&self, _buf: &[u8]) -> Result<usize, SystemError> {
todo!()
}
fn as_any_ref(&self) -> &dyn Any;
fn as_any_mut(&mut self) -> &mut dyn Any;
fn add_epoll(&mut self, epitem: Arc<EPollItem>) -> Result<(), SystemError> {
let posix_item = self.posix_item();
posix_item.add_epoll(epitem);
Ok(())
}
fn remove_epoll(&mut self, epoll: &Weak<SpinLock<EventPoll>>) -> Result<(), SystemError> {
let posix_item = self.posix_item();
posix_item.remove_epoll(epoll)?;
Ok(())
}
fn clear_epoll(&mut self) -> Result<(), SystemError> {
let posix_item = self.posix_item();
for epitem in posix_item.epitems.lock_irqsave().iter() {
let epoll = epitem.epoll();
if let Some(epoll) = epoll.upgrade() {
EventPoll::ep_remove(&mut epoll.lock_irqsave(), epitem.fd(), None)?;
}
}
Ok(())
}
fn close(&mut self);
fn posix_item(&self) -> Arc<PosixSocketHandleItem>;
}
impl Clone for Box<dyn Socket> {
fn clone(&self) -> Box<dyn Socket> {
self.box_clone()
}
}
/// # Socket在文件系统中的inode封装
#[derive(Debug)]
pub struct SocketInode(SpinLock<Box<dyn Socket>>, AtomicUsize);
impl SocketInode {
pub fn new(socket: Box<dyn Socket>) -> Arc<Self> {
Arc::new(Self(SpinLock::new(socket), AtomicUsize::new(0)))
}
#[inline]
pub fn inner(&self) -> SpinLockGuard<Box<dyn Socket>> {
self.0.lock()
}
pub unsafe fn inner_no_preempt(&self) -> SpinLockGuard<Box<dyn Socket>> {
self.0.lock_no_preempt()
}
fn do_close(&self) -> Result<(), SystemError> {
let prev_ref_count = self.1.fetch_sub(1, core::sync::atomic::Ordering::SeqCst);
if prev_ref_count == 1 {
// 最后一次关闭,需要释放
let mut socket = self.0.lock_irqsave();
if socket.metadata().socket_type == SocketType::Unix {
return Ok(());
}
if let Some(Endpoint::Ip(Some(ip))) = socket.endpoint() {
PORT_MANAGER.unbind_port(socket.metadata().socket_type, ip.port);
}
socket.clear_epoll()?;
HANDLE_MAP
.write_irqsave()
.remove(&socket.socket_handle())
.unwrap();
socket.close();
}
Ok(())
}
}
impl Drop for SocketInode {
fn drop(&mut self) {
for _ in 0..self.1.load(core::sync::atomic::Ordering::SeqCst) {
let _ = self.do_close();
}
}
}
impl IndexNode for SocketInode {
fn open(
&self,
_data: SpinLockGuard<FilePrivateData>,
_mode: &FileMode,
) -> Result<(), SystemError> {
self.1.fetch_add(1, core::sync::atomic::Ordering::SeqCst);
Ok(())
}
fn close(&self, _data: SpinLockGuard<FilePrivateData>) -> Result<(), SystemError> {
self.do_close()
}
fn read_at(
&self,
_offset: usize,
len: usize,
buf: &mut [u8],
data: SpinLockGuard<FilePrivateData>,
) -> Result<usize, SystemError> {
drop(data);
self.0.lock_no_preempt().read(&mut buf[0..len]).0
}
fn write_at(
&self,
_offset: usize,
len: usize,
buf: &[u8],
data: SpinLockGuard<FilePrivateData>,
) -> Result<usize, SystemError> {
drop(data);
self.0.lock_no_preempt().write(&buf[0..len], None)
}
fn poll(&self, _private_data: &FilePrivateData) -> Result<usize, SystemError> {
let events = self.0.lock_irqsave().poll();
return Ok(events.bits() as usize);
}
fn fs(&self) -> Arc<dyn FileSystem> {
todo!()
}
fn as_any_ref(&self) -> &dyn Any {
self
}
fn list(&self) -> Result<Vec<String>, SystemError> {
return Err(SystemError::ENOTDIR);
}
fn metadata(&self) -> Result<Metadata, SystemError> {
let meta = Metadata {
mode: ModeType::from_bits_truncate(0o755),
file_type: FileType::Socket,
..Default::default()
};
return Ok(meta);
}
fn resize(&self, _len: usize) -> Result<(), SystemError> {
return Ok(());
}
}
#[derive(Debug)]
pub struct PosixSocketHandleItem {
/// socket的waitqueue
wait_queue: Arc<EventWaitQueue>,
pub epitems: SpinLock<LinkedList<Arc<EPollItem>>>,
}
impl PosixSocketHandleItem {
pub fn new(wait_queue: Option<Arc<EventWaitQueue>>) -> Self {
Self {
wait_queue: wait_queue.unwrap_or(Arc::new(EventWaitQueue::new())),
epitems: SpinLock::new(LinkedList::new()),
}
}
/// ## 在socket的等待队列上睡眠
pub fn sleep(&self, events: u64) {
unsafe {
ProcessManager::preempt_disable();
self.wait_queue.sleep_without_schedule(events);
ProcessManager::preempt_enable();
}
schedule(SchedMode::SM_NONE);
}
pub fn add_epoll(&self, epitem: Arc<EPollItem>) {
self.epitems.lock_irqsave().push_back(epitem)
}
pub fn remove_epoll(&self, epoll: &Weak<SpinLock<EventPoll>>) -> Result<(), SystemError> {
let is_remove = !self
.epitems
.lock_irqsave()
.extract_if(|x| x.epoll().ptr_eq(epoll))
.collect::<Vec<_>>()
.is_empty();
if is_remove {
return Ok(());
}
Err(SystemError::ENOENT)
}
/// ### 唤醒该队列上等待events的进程
///
/// ### 参数
/// - events: 发生的事件
///
/// 需要注意的是只要触发了events中的任意一件事件进程都会被唤醒
pub fn wakeup_any(&self, events: u64) {
self.wait_queue.wakeup_any(events);
}
}
#[derive(Debug)]
pub struct SocketHandleItem {
/// 对应的posix socket是否为listen的
pub is_posix_listen: bool,
/// shutdown状态
pub shutdown_type: RwLock<ShutdownType>,
pub posix_item: Weak<PosixSocketHandleItem>,
}
impl SocketHandleItem {
pub fn new(posix_item: Weak<PosixSocketHandleItem>) -> Self {
Self {
is_posix_listen: false,
shutdown_type: RwLock::new(ShutdownType::empty()),
posix_item,
}
}
pub fn shutdown_type(&self) -> ShutdownType {
*self.shutdown_type.read()
}
pub fn shutdown_type_writer(&mut self) -> RwLockWriteGuard<ShutdownType> {
self.shutdown_type.write_irqsave()
}
pub fn reset_shutdown_type(&self) {
*self.shutdown_type.write() = ShutdownType::empty();
}
pub fn posix_item(&self) -> Option<Arc<PosixSocketHandleItem>> {
self.posix_item.upgrade()
}
}
/// # TCP 和 UDP 的端口管理器。
/// 如果 TCP/UDP 的 socket 绑定了某个端口,它会在对应的表中记录,以检测端口冲突。
pub struct PortManager {
// TCP 端口记录表
tcp_port_table: SpinLock<HashMap<u16, Pid>>,
// UDP 端口记录表
udp_port_table: SpinLock<HashMap<u16, Pid>>,
}
impl PortManager {
pub fn new() -> Self {
return Self {
tcp_port_table: SpinLock::new(HashMap::new()),
udp_port_table: SpinLock::new(HashMap::new()),
};
}
/// @brief 自动分配一个相对应协议中未被使用的PORT如果动态端口均已被占用返回错误码 EADDRINUSE
pub fn get_ephemeral_port(&self, socket_type: SocketType) -> Result<u16, SystemError> {
// TODO: selects non-conflict high port
static mut EPHEMERAL_PORT: u16 = 0;
unsafe {
if EPHEMERAL_PORT == 0 {
EPHEMERAL_PORT = (49152 + rand() % (65536 - 49152)) as u16;
}
}
let mut remaining = 65536 - 49152; // 剩余尝试分配端口次数
let mut port: u16;
while remaining > 0 {
unsafe {
if EPHEMERAL_PORT == 65535 {
EPHEMERAL_PORT = 49152;
} else {
EPHEMERAL_PORT += 1;
}
port = EPHEMERAL_PORT;
}
// 使用 ListenTable 检查端口是否被占用
let listen_table_guard = match socket_type {
SocketType::Udp => self.udp_port_table.lock(),
SocketType::Tcp => self.tcp_port_table.lock(),
_ => panic!("{:?} cann't get a port", socket_type),
};
if listen_table_guard.get(&port).is_none() {
drop(listen_table_guard);
return Ok(port);
}
remaining -= 1;
}
return Err(SystemError::EADDRINUSE);
}
/// @brief 检测给定端口是否已被占用,如果未被占用则在 TCP/UDP 对应的表中记录
///
/// TODO: 增加支持端口复用的逻辑
pub fn bind_port(&self, socket_type: SocketType, port: u16) -> Result<(), SystemError> {
if port > 0 {
let mut listen_table_guard = match socket_type {
SocketType::Udp => self.udp_port_table.lock(),
SocketType::Tcp => self.tcp_port_table.lock(),
_ => panic!("{:?} cann't bind a port", socket_type),
};
match listen_table_guard.get(&port) {
Some(_) => return Err(SystemError::EADDRINUSE),
None => listen_table_guard.insert(port, ProcessManager::current_pid()),
};
drop(listen_table_guard);
}
return Ok(());
}
/// @brief 在对应的端口记录表中将端口和 socket 解绑
/// should call this function when socket is closed or aborted
pub fn unbind_port(&self, socket_type: SocketType, port: u16) {
let mut listen_table_guard = match socket_type {
SocketType::Udp => self.udp_port_table.lock(),
SocketType::Tcp => self.tcp_port_table.lock(),
_ => {
return;
}
};
listen_table_guard.remove(&port);
drop(listen_table_guard);
}
}
/// @brief socket的类型
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum SocketType {
/// 原始的socket
Raw,
/// 用于Tcp通信的 Socket
Tcp,
/// 用于Udp通信的 Socket
Udp,
/// unix域的 Socket
Unix,
}
bitflags! {
/// @brief socket的选项
#[derive(Default)]
pub struct SocketOptions: u32 {
/// 是否阻塞
const BLOCK = 1 << 0;
/// 是否允许广播
const BROADCAST = 1 << 1;
/// 是否允许多播
const MULTICAST = 1 << 2;
/// 是否允许重用地址
const REUSEADDR = 1 << 3;
/// 是否允许重用端口
const REUSEPORT = 1 << 4;
}
}
#[derive(Debug, Clone)]
/// @brief 在trait Socket的metadata函数中返回该结构体供外部使用
pub struct SocketMetadata {
/// socket的类型
pub socket_type: SocketType,
/// 接收缓冲区的大小
pub rx_buf_size: usize,
/// 发送缓冲区的大小
pub tx_buf_size: usize,
/// 元数据的缓冲区的大小
pub metadata_buf_size: usize,
/// socket的选项
pub options: SocketOptions,
}
impl SocketMetadata {
fn new(
socket_type: SocketType,
rx_buf_size: usize,
tx_buf_size: usize,
metadata_buf_size: usize,
options: SocketOptions,
) -> Self {
Self {
socket_type,
rx_buf_size,
tx_buf_size,
metadata_buf_size,
options,
}
}
}
/// @brief 地址族的枚举
///
/// 参考https://code.dragonos.org.cn/xref/linux-5.19.10/include/linux/socket.h#180
#[derive(Debug, Clone, Copy, PartialEq, Eq, FromPrimitive, ToPrimitive)]
pub enum AddressFamily {
/// AF_UNSPEC 表示地址族未指定
Unspecified = 0,
/// AF_UNIX 表示Unix域的socket (与AF_LOCAL相同)
Unix = 1,
/// AF_INET 表示IPv4的socket
INet = 2,
/// AF_AX25 表示AMPR AX.25的socket
AX25 = 3,
/// AF_IPX 表示IPX的socket
IPX = 4,
/// AF_APPLETALK 表示Appletalk的socket
Appletalk = 5,
/// AF_NETROM 表示AMPR NET/ROM的socket
Netrom = 6,
/// AF_BRIDGE 表示多协议桥接的socket
Bridge = 7,
/// AF_ATMPVC 表示ATM PVCs的socket
Atmpvc = 8,
/// AF_X25 表示X.25的socket
X25 = 9,
/// AF_INET6 表示IPv6的socket
INet6 = 10,
/// AF_ROSE 表示AMPR ROSE的socket
Rose = 11,
/// AF_DECnet Reserved for DECnet project
Decnet = 12,
/// AF_NETBEUI Reserved for 802.2LLC project
Netbeui = 13,
/// AF_SECURITY 表示Security callback的伪AF
Security = 14,
/// AF_KEY 表示Key management API
Key = 15,
/// AF_NETLINK 表示Netlink的socket
Netlink = 16,
/// AF_PACKET 表示Low level packet interface
Packet = 17,
/// AF_ASH 表示Ash
Ash = 18,
/// AF_ECONET 表示Acorn Econet
Econet = 19,
/// AF_ATMSVC 表示ATM SVCs
Atmsvc = 20,
/// AF_RDS 表示Reliable Datagram Sockets
Rds = 21,
/// AF_SNA 表示Linux SNA Project
Sna = 22,
/// AF_IRDA 表示IRDA sockets
Irda = 23,
/// AF_PPPOX 表示PPPoX sockets
Pppox = 24,
/// AF_WANPIPE 表示WANPIPE API sockets
WanPipe = 25,
/// AF_LLC 表示Linux LLC
Llc = 26,
/// AF_IB 表示Native InfiniBand address
/// 介绍https://access.redhat.com/documentation/en-us/red_hat_enterprise_linux/9/html-single/configuring_infiniband_and_rdma_networks/index#understanding-infiniband-and-rdma_configuring-infiniband-and-rdma-networks
Ib = 27,
/// AF_MPLS 表示MPLS
Mpls = 28,
/// AF_CAN 表示Controller Area Network
Can = 29,
/// AF_TIPC 表示TIPC sockets
Tipc = 30,
/// AF_BLUETOOTH 表示Bluetooth sockets
Bluetooth = 31,
/// AF_IUCV 表示IUCV sockets
Iucv = 32,
/// AF_RXRPC 表示RxRPC sockets
Rxrpc = 33,
/// AF_ISDN 表示mISDN sockets
Isdn = 34,
/// AF_PHONET 表示Phonet sockets
Phonet = 35,
/// AF_IEEE802154 表示IEEE 802.15.4 sockets
Ieee802154 = 36,
/// AF_CAIF 表示CAIF sockets
Caif = 37,
/// AF_ALG 表示Algorithm sockets
Alg = 38,
/// AF_NFC 表示NFC sockets
Nfc = 39,
/// AF_VSOCK 表示vSockets
Vsock = 40,
/// AF_KCM 表示Kernel Connection Multiplexor
Kcm = 41,
/// AF_QIPCRTR 表示Qualcomm IPC Router
Qipcrtr = 42,
/// AF_SMC 表示SMC-R sockets.
/// reserve number for PF_SMC protocol family that reuses AF_INET address family
Smc = 43,
/// AF_XDP 表示XDP sockets
Xdp = 44,
/// AF_MCTP 表示Management Component Transport Protocol
Mctp = 45,
/// AF_MAX 表示最大的地址族
Max = 46,
}
impl TryFrom<u16> for AddressFamily {
type Error = SystemError;
fn try_from(x: u16) -> Result<Self, Self::Error> {
use num_traits::FromPrimitive;
return <Self as FromPrimitive>::from_u16(x).ok_or(SystemError::EINVAL);
}
}
/// @brief posix套接字类型的枚举(这些值与linux内核中的值一致)
#[derive(Debug, Clone, Copy, PartialEq, Eq, FromPrimitive, ToPrimitive)]
pub enum PosixSocketType {
Stream = 1,
Datagram = 2,
Raw = 3,
Rdm = 4,
SeqPacket = 5,
Dccp = 6,
Packet = 10,
}
impl TryFrom<u8> for PosixSocketType {
type Error = SystemError;
fn try_from(x: u8) -> Result<Self, Self::Error> {
use num_traits::FromPrimitive;
return <Self as FromPrimitive>::from_u8(x).ok_or(SystemError::EINVAL);
}
}
/// ### 为socket提供无锁的poll方法
///
/// 因为在网卡中断中需要轮询socket的状态如果使用socket文件或者其inode来poll
/// 在当前的设计会必然死锁所以引用这一个设计来解决提供无🔓的poll
pub struct SocketPollMethod;
impl SocketPollMethod {
pub fn poll(socket: &socket::Socket, handle_item: &SocketHandleItem) -> EPollEventType {
let shutdown = handle_item.shutdown_type();
match socket {
socket::Socket::Udp(udp) => Self::udp_poll(udp, shutdown),
socket::Socket::Tcp(tcp) => Self::tcp_poll(tcp, shutdown, handle_item.is_posix_listen),
socket::Socket::Raw(raw) => Self::raw_poll(raw, shutdown),
_ => todo!(),
}
}
pub fn tcp_poll(
socket: &tcp::Socket,
shutdown: ShutdownType,
is_posix_listen: bool,
) -> EPollEventType {
let mut events = EPollEventType::empty();
// debug!("enter tcp_poll! is_posix_listen:{}", is_posix_listen);
// 处理listen的socket
if is_posix_listen {
// 如果是listen的socket那么只有EPOLLIN和EPOLLRDNORM
if socket.is_active() {
events.insert(EPollEventType::EPOLL_LISTEN_CAN_ACCEPT);
}
// debug!("tcp_poll listen socket! events:{:?}", events);
return events;
}
let state = socket.state();
if shutdown == ShutdownType::SHUTDOWN_MASK || state == tcp::State::Closed {
events.insert(EPollEventType::EPOLLHUP);
}
if shutdown.contains(ShutdownType::RCV_SHUTDOWN) {
events.insert(
EPollEventType::EPOLLIN | EPollEventType::EPOLLRDNORM | EPollEventType::EPOLLRDHUP,
);
}
// Connected or passive Fast Open socket?
if state != tcp::State::SynSent && state != tcp::State::SynReceived {
// socket有可读数据
if socket.can_recv() {
events.insert(EPollEventType::EPOLLIN | EPollEventType::EPOLLRDNORM);
}
if !(shutdown.contains(ShutdownType::SEND_SHUTDOWN)) {
// 缓冲区可写这里判断可写的逻辑好像跟linux不太一样
if socket.send_queue() < socket.send_capacity() {
events.insert(EPollEventType::EPOLLOUT | EPollEventType::EPOLLWRNORM);
} else {
// TODO触发缓冲区已满的信号SIGIO
todo!("A signal SIGIO that the buffer is full needs to be sent");
}
} else {
// 如果我们的socket关闭了SEND_SHUTDOWNepoll事件就是EPOLLOUT
events.insert(EPollEventType::EPOLLOUT | EPollEventType::EPOLLWRNORM);
}
} else if state == tcp::State::SynSent {
events.insert(EPollEventType::EPOLLOUT | EPollEventType::EPOLLWRNORM);
}
// socket发生错误
// TODO: 这里的逻辑可能有问题需要进一步验证是否is_active()==false就代表socket发生错误
if !socket.is_active() {
events.insert(EPollEventType::EPOLLERR);
}
events
}
pub fn udp_poll(socket: &udp::Socket, shutdown: ShutdownType) -> EPollEventType {
let mut event = EPollEventType::empty();
if shutdown.contains(ShutdownType::RCV_SHUTDOWN) {
event.insert(
EPollEventType::EPOLLRDHUP | EPollEventType::EPOLLIN | EPollEventType::EPOLLRDNORM,
);
}
if shutdown.contains(ShutdownType::SHUTDOWN_MASK) {
event.insert(EPollEventType::EPOLLHUP);
}
if socket.can_recv() {
event.insert(EPollEventType::EPOLLIN | EPollEventType::EPOLLRDNORM);
}
if socket.can_send() {
event.insert(
EPollEventType::EPOLLOUT
| EPollEventType::EPOLLWRNORM
| EPollEventType::EPOLLWRBAND,
);
} else {
// TODO: 缓冲区空间不够,需要使用信号处理
todo!()
}
return event;
}
pub fn raw_poll(socket: &raw::Socket, shutdown: ShutdownType) -> EPollEventType {
//debug!("enter raw_poll!");
let mut event = EPollEventType::empty();
if shutdown.contains(ShutdownType::RCV_SHUTDOWN) {
event.insert(
EPollEventType::EPOLLRDHUP | EPollEventType::EPOLLIN | EPollEventType::EPOLLRDNORM,
);
}
if shutdown.contains(ShutdownType::SHUTDOWN_MASK) {
event.insert(EPollEventType::EPOLLHUP);
}
if socket.can_recv() {
//debug!("poll can recv!");
event.insert(EPollEventType::EPOLLIN | EPollEventType::EPOLLRDNORM);
} else {
//debug!("poll can not recv!");
}
if socket.can_send() {
//debug!("poll can send!");
event.insert(
EPollEventType::EPOLLOUT
| EPollEventType::EPOLLWRNORM
| EPollEventType::EPOLLWRBAND,
);
} else {
//debug!("poll can not send!");
// TODO: 缓冲区空间不够,需要使用信号处理
todo!()
}
return event;
}
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,9 @@
use super::skbuff::SkBuff;
use crate::libs::rwlock::RwLock;
use alloc::sync::Arc;
use core::fmt::Debug;
pub trait NetlinkCallback: Send + Sync + Debug {
/// 接收到netlink数据包时的回调函数
fn netlink_rcv(&self, skb: Arc<RwLock<SkBuff>>) -> i32;
}
struct NetlinkCallbackData {}

View File

@ -0,0 +1,10 @@
use crate::net::syscall::SockAddrNl;
#[derive(Debug, Clone)]
pub struct NetlinkEndpoint {
pub addr: SockAddrNl,
}
impl NetlinkEndpoint {
pub fn new(addr: SockAddrNl) -> Self {
NetlinkEndpoint { addr }
}
}

View File

@ -0,0 +1,44 @@
use alloc::sync::Arc;
use netlink::NETLINK_KOBJECT_UEVENT;
use system_error::SystemError;
use crate::driver::base::uevent::KobjUeventEnv;
use super::{family, inet::datagram, Inode, Socket, Type};
//https://code.dragonos.org.cn/xref/linux-6.1.9/net/netlink/
/*
.. - -
Kconfig
Makefile
af_netlink.c
af_netlink.h
diag.c Netlink Netlink
genetlink.c
policy.c
*/
// Top-level module defining the public API for Netlink
pub mod af_netlink;
pub mod callback;
pub mod endpoint;
pub mod netlink;
pub mod netlink_proto;
pub mod skbuff;
pub mod sock;
pub struct Netlink;
impl family::Family for Netlink {
/// 用户空间创建一个新的套接字的入口
fn socket(stype: Type, _protocol: u32) -> Result<Arc<Inode>, SystemError> {
let socket = create_netlink_socket(_protocol)?;
Ok(Inode::new(socket))
}
}
/// 用户空间创建一个新的Netlink套接字
fn create_netlink_socket(_protocol: u32) -> Result<Arc<dyn Socket>, SystemError> {
match _protocol as usize {
NETLINK_KOBJECT_UEVENT => Ok(Arc::new(af_netlink::NetlinkSock::new())),
_ => Err(SystemError::EPROTONOSUPPORT),
}
}

View File

@ -0,0 +1,319 @@
use alloc::{
boxed::Box,
slice,
sync::{Arc, Weak},
vec::Vec,
};
use system_error::SystemError;
//定义Netlink消息的结构体如NLmsghdr和geNLmsghdr(拓展的netlink消息头),以及用于封包和解包消息的函数。
//参考 https://code.dragonos.org.cn/xref/linux-6.1.9/include/linux/netlink.h
// SPDX-License-Identifier: GPL-2.0 WITH Linux-syscall-note
// Ensure the header is only included once
use crate::libs::mutex::Mutex;
use core::mem;
use super::af_netlink::{
netlink_insert, Listeners, NetlinkFlags, NetlinkSock, NetlinkSocket, NL_TABLE,
};
// Netlink protocol family
pub const NETLINK_ROUTE: usize = 0;
pub const NETLINK_UNUSED: usize = 1;
pub const NETLINK_USERSOCK: usize = 2;
pub const NETLINK_FIREWALL: usize = 3;
pub const NETLINK_SOCK_DIAG: usize = 4;
pub const NETLINK_NFLOG: usize = 5;
pub const NETLINK_XFRM: usize = 6;
pub const NETLINK_SELINUX: usize = 7;
pub const NETLINK_ISCSI: usize = 8;
pub const NETLINK_AUDIT: usize = 9;
pub const NETLINK_FIB_LOOKUP: usize = 10;
pub const NETLINK_CONNECTOR: usize = 11;
pub const NETLINK_NETFILTER: usize = 12;
pub const NETLINK_IP6_FW: usize = 13;
pub const NETLINK_DNRTMSG: usize = 14;
// implemente uevent needed
pub const NETLINK_KOBJECT_UEVENT: usize = 15;
pub const NETLINK_GENERIC: usize = 16;
// pub const NETLINK_DM : usize = 17; // Assuming DM Events is unused, not defined
pub const NETLINK_SCSITRANSPORT: usize = 18;
pub const NETLINK_ECRYPTFS: usize = 19;
pub const NETLINK_RDMA: usize = 20;
pub const NETLINK_CRYPTO: usize = 21;
pub const NETLINK_SMC: usize = 22;
//pub const NETLINK_INET_DIAG = NETLINK_SOCK_DIAG;
pub const NETLINK_INET_DIAG: usize = 4;
pub const MAX_LINKS: usize = 32;
pub const NL_CFG_F_NONROOT_RECV: u32 = 1 << 0;
pub const NL_CFG_F_NONROOT_SEND: u32 = 1 << 1;
bitflags! {
/// 四种通用的消息类型 nlmsg_type
pub struct NLmsgType: u8 {
/* Nothing. */
const NLMSG_NOOP = 0x1;
/* Error */
const NLMSG_ERROR = 0x2;
/* End of a dump */
const NLMSG_DONE = 0x3;
/* Data lost */
const NLMSG_OVERRUN = 0x4;
}
//消息标记 nlmsg_flags
// const NLM_F_REQUEST = 1; /* It is request message. */
// const NLM_F_MULTI = 2; /* Multipart message, terminated by NLMSG_DONE */
// const NLM_F_ACK = 4; /* Reply with ack, with zero or error code */
// const NLM_F_ECHO = 8; /* Echo this request */
// const NLM_F_DUMP_INTR = 16; /* Dump was inconsistent due to sequence change */
pub struct NLmsgFlags: u16 {
/* Flags values */
const NLM_F_REQUEST = 0x01;
const NLM_F_MULTI = 0x02;
const NLM_F_ACK = 0x04;
const NLM_F_ECHO = 0x08;
const NLM_F_DUMP_INTR = 0x10;
const NLM_F_DUMP_FILTERED = 0x20;
/* Modifiers to GET request */
const NLM_F_ROOT = 0x100; /* specify tree root */
const NLM_F_MATCH = 0x200; /* return all matching */
const NLM_F_ATOMIC = 0x400; /* atomic GET */
//const NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH;
const NLM_F_DUMP = 0x100 | 0x200;
/* Modifiers to NEW request */
const NLM_F_REPLACE = 0x100; /* Override existing */
const NLM_F_EXCL = 0x200; /* Do not touch, if it exists */
const NLM_F_CREATE = 0x400; /* Create, if it does not exist */
const NLM_F_APPEND = 0x800; /* Add to end of list */
/* Modifiers to DELETE request */
const NLM_F_NONREC = 0x100; /* Do not delete recursively */
/* Flags for ACK message */
const NLM_F_CAPPED = 0x100; /* request was capped */
const NLM_F_ACK_TLVS = 0x200; /* extended ACK TVLs were included */
}
}
/// netlink消息报头
/**
* struct NLmsghdr - fixed format metadata header of Netlink messages
* @nlmsg_len: Length of message including header
* @nlmsg_type: Message content type
* @nlmsg_flags: Additional flags
* @nlmsg_seq: Sequence number
* @nlmsg_pid: Sending process port ID
*/
pub struct NLmsghdr {
pub nlmsg_len: usize,
pub nlmsg_type: NLmsgType,
pub nlmsg_flags: NLmsgFlags,
pub nlmsg_seq: u32,
pub nlmsg_pid: u32,
}
const NLMSG_ALIGNTO: usize = 4;
#[derive(Debug, PartialEq, Copy, Clone)]
pub enum NetlinkState {
NetlinkUnconnected = 0,
NetlinkConnected,
NETLINK_S_CONGESTED = 2,
}
fn nlmsg_align(len: usize) -> usize {
(len + NLMSG_ALIGNTO - 1) & !(NLMSG_ALIGNTO - 1)
}
fn nlmsg_hdrlen() -> usize {
nlmsg_align(mem::size_of::<NLmsghdr>())
}
fn nlmsg_length(len: usize) -> usize {
len + nlmsg_hdrlen()
}
fn nlmsg_space(len: usize) -> usize {
nlmsg_align(nlmsg_length(len))
}
unsafe fn nlmsg_data(nlh: &NLmsghdr) -> *mut u8 {
((nlh as *const NLmsghdr) as *mut u8).add(nlmsg_length(0))
}
unsafe fn nlmsg_next(nlh: *mut NLmsghdr, len: usize) -> *mut NLmsghdr {
let nlmsg_len = (*nlh).nlmsg_len;
let new_len = len - nlmsg_align(nlmsg_len);
nlh.add(nlmsg_align(nlmsg_len))
}
fn nlmsg_ok(nlh: &NLmsghdr, len: usize) -> bool {
len >= nlmsg_hdrlen() && nlh.nlmsg_len >= nlmsg_hdrlen() && nlh.nlmsg_len <= len
}
fn nlmsg_payload(nlh: &NLmsghdr, len: usize) -> usize {
nlh.nlmsg_len - nlmsg_space(len)
}
// 定义类型别名来简化闭包类型的定义
type InputCallback = Arc<dyn FnMut() + Send + Sync>;
type BindCallback = Arc<dyn Fn(i32) -> i32 + Send + Sync>;
type UnbindCallback = Arc<dyn Fn(i32) -> i32 + Send + Sync>;
type CompareCallback = Arc<dyn Fn(&NetlinkSock) -> bool + Send + Sync>;
/// 该结构包含了内核netlink的可选参数:
#[derive(Default)]
pub struct NetlinkKernelCfg {
pub groups: u32,
pub flags: u32,
pub input: Option<InputCallback>,
pub bind: Option<BindCallback>,
pub unbind: Option<UnbindCallback>,
pub compare: Option<CompareCallback>,
}
impl NetlinkKernelCfg {
pub fn new() -> Self {
NetlinkKernelCfg {
groups: 32,
flags: 0,
input: None,
bind: None,
unbind: None,
compare: None,
}
}
pub fn set_input<F>(&mut self, callback: F)
where
F: FnMut() + Send + Sync + 'static,
{
self.input = Some(Arc::new(callback));
}
pub fn set_bind<F>(&mut self, callback: F)
where
F: Fn(i32) -> i32 + Send + Sync + 'static,
{
self.bind = Some(Arc::new(callback));
}
pub fn set_unbind<F>(&mut self, callback: F)
where
F: Fn(i32) -> i32 + Send + Sync + 'static,
{
self.unbind = Some(Arc::new(callback));
}
pub fn set_compare<F>(&mut self, callback: F)
where
F: Fn(&NetlinkSock) -> bool + Send + Sync + 'static,
{
self.compare = Some(Arc::new(callback));
}
}
//https://code.dragonos.org.cn/xref/linux-6.1.9/include/linux/netlink.h#229
//netlink属性头
struct NLattr {
nla_len: u16,
nla_type: u16,
}
pub trait VecExt {
fn align4(&mut self);
fn push_ext<T: Sized>(&mut self, data: T);
fn set_ext<T: Sized>(&mut self, offset: usize, data: T);
}
impl VecExt for Vec<u8> {
fn align4(&mut self) {
let len = (self.len() + 3) & !3;
if len > self.len() {
self.resize(len, 0);
}
}
fn push_ext<T: Sized>(&mut self, data: T) {
#[allow(unsafe_code)]
let bytes =
unsafe { slice::from_raw_parts(&data as *const T as *const u8, size_of::<T>()) };
for byte in bytes {
self.push(*byte);
}
}
fn set_ext<T: Sized>(&mut self, offset: usize, data: T) {
if self.len() < offset + size_of::<T>() {
self.resize(offset + size_of::<T>(), 0);
}
#[allow(unsafe_code)]
let bytes =
unsafe { slice::from_raw_parts(&data as *const T as *const u8, size_of::<T>()) };
self[offset..(bytes.len() + offset)].copy_from_slice(bytes);
}
}
// todo net namespace
pub fn netlink_kernel_create(
unit: usize,
cfg: Option<NetlinkKernelCfg>,
) -> Result<NetlinkSock, SystemError> {
// THIS_MODULE
let mut nlk: NetlinkSock = NetlinkSock::new();
let sk: Arc<Mutex<Box<dyn NetlinkSocket>>> = Arc::new(Mutex::new(Box::new(nlk.clone())));
let groups: u32;
if unit >= MAX_LINKS {
return Err(SystemError::EINVAL);
}
__netlink_create(&mut nlk, unit, 1).expect("__netlink_create failed");
if let Some(cfg) = cfg.as_ref() {
if cfg.groups < 32 {
groups = 32;
} else {
groups = cfg.groups;
}
} else {
groups = 32;
}
let listeners = Listeners::new();
// todo设计和实现回调函数
// sk.sk_data_read = netlink_data_ready;
// if cfg.is_some() && cfg.unwrap().input.is_some(){
// nlk.netlink_rcv = cfg.unwrap().input;
// }
netlink_insert(sk, 0).expect("netlink_insert failed");
nlk.flags |= NetlinkFlags::NETLINK_F_KERNEL_SOCKET.bits();
let mut nl_table = NL_TABLE.write();
if nl_table[unit].get_registered() == 0 {
nl_table[unit].set_groups(groups);
if let Some(cfg) = cfg.as_ref() {
nl_table[unit].bind = cfg.bind.clone();
nl_table[unit].unbind = cfg.unbind.clone();
nl_table[unit].set_flags(cfg.flags);
if cfg.compare.is_some() {
nl_table[unit].compare = cfg.compare.clone();
}
nl_table[unit].set_registered(1);
} else {
drop(listeners);
let registered = nl_table[unit].get_registered();
nl_table[unit].set_registered(registered + 1);
}
}
return Ok(nlk);
}
fn __netlink_create(nlk: &mut NetlinkSock, unit: usize, kern: usize) -> Result<i32, SystemError> {
// 其他的初始化配置参数
nlk.flags = kern as u32;
nlk.protocol = unit;
return Ok(0);
}
pub fn sk_data_ready(nlk: Arc<NetlinkSock>) -> Result<(), SystemError> {
// 唤醒
return Ok(());
}

View File

@ -0,0 +1,56 @@
use bitmap::{traits::BitMapOps, AllocBitmap};
use core::intrinsics::unlikely;
use system_error::SystemError;
use crate::libs::lazy_init::Lazy;
pub const PROTO_INUSE_NR: usize = 64;
// pub static mut PROTO_INUSE_IDX: Lazy<AllocBitmap> = Lazy::new();
// pub static PROTO_INUSE_IDX: Lazy<AllocBitmap> = Lazy::new(<AllocBitmap::new(PROTO_INUSE_NR));
/// 协议操作集的trait
pub trait Protocol {
fn close(&self);
// fn first_false_index(&self, proto_inuse_idx:usize, proto_inuse_nr:usize)->usize;
}
/// 协议操作集的结构体
pub struct Proto<'a> {
name: &'a str,
// owner: THIS_MODULE,
obj_size: usize,
inuse_idx: Option<usize>,
}
impl Protocol for Proto<'_> {
fn close(&self) {}
}
/// 静态变量用于注册netlink协议是一个操作集结构体的实例
// https://code.dragonos.org.cn/xref/linux-6.1.9/net/netlink/af_netlink.c#634
pub static mut NETLINK_PROTO: Proto = Proto {
name: "NETLINK",
// owner: THIS_MODULE,
obj_size: core::mem::size_of::<Proto>(),
// 运行时分配的索引
inuse_idx: None,
};
// https://code.dragonos.org.cn/xref/linux-6.1.9/net/core/sock.c?fi=proto_register#3853
/// 注册协议
pub fn proto_register(proto: &mut Proto, alloc_slab: i32) -> Result<i32, SystemError> {
let mut ret = Err(SystemError::ENOBUFS);
if alloc_slab != 0 {
log::info!("TODO: netlink_proto: slab allocation not supported\n");
return ret;
}
ret = assign_proto_idx(proto);
ret
}
// https://code.dragonos.org.cn/xref/linux-6.1.9/net/core/sock.c?fi=proto_register#3752
/// 为协议分配一个索引
pub fn assign_proto_idx(prot: &mut Proto) -> Result<i32, SystemError> {
// prot.inuse_idx = unsafe { PROTO_INUSE_IDX.first_false_index() };
// 如果没有找到空闲的索引
if unlikely(prot.inuse_idx == Some(PROTO_INUSE_NR - 1)) {
log::info!("PROTO_INUSE_NR exhausted\n");
return Err(SystemError::ENOSPC);
}
// 为协议分配一个索引
// unsafe { PROTO_INUSE_IDX.set((prot.inuse_idx).unwrap(), true) };
return Ok(0);
}

View File

@ -0,0 +1,109 @@
use super::af_netlink::{NetlinkSock, NetlinkSocket};
use crate::libs::{mutex::Mutex, rwlock::RwLock};
use alloc::{boxed::Box, sync::Arc};
// 曾用方案:在 smoltcp::PacketBuffer 的基础上封装了一层,用于处理 netlink 协议中网络数据包(skb)的相关操作
// 暂时弃用,目前尝试使用更简单的方式处理 skb
#[derive(Debug, Clone)]
pub struct SkBuff {
pub sk: Arc<Mutex<Box<dyn NetlinkSocket>>>,
pub len: u32,
pub pkt_type: u32,
pub mark: u32,
pub queue_mapping: u32,
pub protocol: u32,
pub vlan_present: u32,
pub vlan_tci: u32,
pub vlan_proto: u32,
pub priority: u32,
pub ingress_ifindex: u32,
pub ifindex: u32,
pub tc_index: u32,
pub cb: [u32; 5],
pub hash: u32,
pub tc_classid: u32,
pub data: u32,
pub data_end: u32,
pub napi_id: u32,
pub family: u32,
pub remote_ip4: u32,
pub local_ip4: u32,
pub remote_ip6: [u32; 4],
pub local_ip6: [u32; 4],
pub remote_port: u32,
pub local_port: u32,
pub data_meta: u32,
pub tstamp: u64,
pub wire_len: u32,
pub gso_segs: u32,
pub gso_size: u32,
pub tstamp_type: u8,
pub _bitfield_align_1: [u8; 0],
pub hwtstamp: u64,
}
impl SkBuff {
pub fn new() -> Self {
SkBuff {
sk: Arc::new(Mutex::new(Box::new(NetlinkSock::new()))),
len: 0,
pkt_type: 0,
mark: 0,
queue_mapping: 0,
protocol: 0,
vlan_present: 0,
vlan_tci: 0,
vlan_proto: 0,
priority: 0,
ingress_ifindex: 0,
ifindex: 0,
tc_index: 0,
cb: [0; 5],
hash: 0,
tc_classid: 0,
data: 0,
data_end: 0,
napi_id: 0,
family: 0,
remote_ip4: 0,
local_ip4: 0,
remote_ip6: [0; 4],
local_ip6: [0; 4],
remote_port: 0,
local_port: 0,
data_meta: 0,
tstamp: 0,
wire_len: 0,
gso_segs: 0,
gso_size: 0,
tstamp_type: 0,
_bitfield_align_1: [0; 0],
hwtstamp: 0,
}
}
pub fn is_empty(&self) -> bool {
self.len == 0
}
}
// 处理网络套接字的过度运行情况
pub fn netlink_overrun(sk: &Arc<Mutex<Box<dyn NetlinkSocket>>>) {
// Implementation of the function
}
// 用于检查网络数据包(skb)是否被共享
pub fn skb_shared(skb: &RwLock<SkBuff>) -> bool {
// Implementation of the function
false
}
/// 处理被孤儿化的网络数据包(skb)
/// 孤儿化网络数据包意味着数据包不再与任何套接字关联,
/// 通常是因为发送数据包时指定了 MSG_DONTWAIT 标志,这告诉内核不要等待必要的资源(如内存),而是尽可能快地发送数据包。
pub fn skb_orphan(skb: &Arc<RwLock<SkBuff>>) {
// TODO: Implementation of the function
}
fn skb_recv_datagram() {}
fn skb_try_recv_datagram() {}
fn skb_try_recv_from_queue() {}

View File

@ -0,0 +1,34 @@
// Sock flags in Rust
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SockFlags {
SockDead,
SockDone,
SockUrginline,
SockKeepopen,
SockLinger,
SockDestroy,
SockBroadcast,
SockTimestamp,
SockZapped,
SockUseWriteQueue, // whether to call sk->sk_write_space in sock_wfree
SockDbg, // %SO_DEBUG setting
SockRcvtstamp, // %SO_TIMESTAMP setting
SockRcvtstampns, // %SO_TIMESTAMPNS setting
SockLocalroute, // route locally only, %SO_DONTROUTE setting
SockMemalloc, // VM depends on this socket for swapping
SockTimestampingRxSoftware, // %SOF_TIMESTAMPING_RX_SOFTWARE
SockFasync, // fasync() active
SockRxqOvfl,
SockZerocopy, // buffers from userspace
SockWifiStatus, // push wifi status to userspace
SockNofcs, // Tell NIC not to do the Ethernet FCS.
// Will use last 4 bytes of packet sent from
// user-space instead.
SockFilterLocked, // Filter cannot be changed anymore
SockSelectErrQueue, // Wake select on error queue
SockRcuFree, // wait rcu grace period in sk_destruct()
SockTxtime,
SockXdp, // XDP is attached
SockTstampNew, // Indicates 64 bit timestamps always
SockRcvmark, // Receive SO_MARK ancillary data with packet
}

View File

@ -1,239 +0,0 @@
use alloc::{boxed::Box, sync::Arc, vec::Vec};
use system_error::SystemError;
use crate::{libs::spinlock::SpinLock, net::Endpoint};
use super::{
handle::GlobalSocketHandle, PosixSocketHandleItem, Socket, SocketInode, SocketMetadata,
SocketOptions, SocketType,
};
#[derive(Debug, Clone)]
pub struct StreamSocket {
metadata: SocketMetadata,
buffer: Arc<SpinLock<Vec<u8>>>,
peer_inode: Option<Arc<SocketInode>>,
handle: GlobalSocketHandle,
posix_item: Arc<PosixSocketHandleItem>,
}
impl StreamSocket {
/// 默认的元数据缓冲区大小
pub const DEFAULT_METADATA_BUF_SIZE: usize = 1024;
/// 默认的缓冲区大小
pub const DEFAULT_BUF_SIZE: usize = 64 * 1024;
/// # 创建一个 Stream Socket
///
/// ## 参数
/// - `options`: socket选项
pub fn new(options: SocketOptions) -> Self {
let buffer = Arc::new(SpinLock::new(Vec::with_capacity(Self::DEFAULT_BUF_SIZE)));
let metadata = SocketMetadata::new(
SocketType::Unix,
Self::DEFAULT_BUF_SIZE,
Self::DEFAULT_BUF_SIZE,
Self::DEFAULT_METADATA_BUF_SIZE,
options,
);
let posix_item = Arc::new(PosixSocketHandleItem::new(None));
Self {
metadata,
buffer,
peer_inode: None,
handle: GlobalSocketHandle::new_kernel_handle(),
posix_item,
}
}
}
impl Socket for StreamSocket {
fn posix_item(&self) -> Arc<PosixSocketHandleItem> {
self.posix_item.clone()
}
fn socket_handle(&self) -> GlobalSocketHandle {
self.handle
}
fn close(&mut self) {}
fn read(&self, buf: &mut [u8]) -> (Result<usize, SystemError>, Endpoint) {
let mut buffer = self.buffer.lock_irqsave();
let len = core::cmp::min(buf.len(), buffer.len());
buf[..len].copy_from_slice(&buffer[..len]);
let _ = buffer.split_off(len);
(Ok(len), Endpoint::Inode(self.peer_inode.clone()))
}
fn write(&self, buf: &[u8], _to: Option<Endpoint>) -> Result<usize, SystemError> {
if self.peer_inode.is_none() {
return Err(SystemError::ENOTCONN);
}
let peer_inode = self.peer_inode.clone().unwrap();
let len = peer_inode.inner().write_buffer(buf)?;
Ok(len)
}
fn connect(&mut self, endpoint: Endpoint) -> Result<(), SystemError> {
if self.peer_inode.is_some() {
return Err(SystemError::EISCONN);
}
if let Endpoint::Inode(inode) = endpoint {
self.peer_inode = inode;
Ok(())
} else {
Err(SystemError::EINVAL)
}
}
fn write_buffer(&self, buf: &[u8]) -> Result<usize, SystemError> {
let mut buffer = self.buffer.lock_irqsave();
let len = buf.len();
if buffer.capacity() - buffer.len() < len {
return Err(SystemError::ENOBUFS);
}
buffer.extend_from_slice(buf);
Ok(len)
}
fn metadata(&self) -> SocketMetadata {
self.metadata.clone()
}
fn box_clone(&self) -> Box<dyn Socket> {
Box::new(self.clone())
}
fn as_any_ref(&self) -> &dyn core::any::Any {
self
}
fn as_any_mut(&mut self) -> &mut dyn core::any::Any {
self
}
}
#[derive(Debug, Clone)]
pub struct SeqpacketSocket {
metadata: SocketMetadata,
buffer: Arc<SpinLock<Vec<u8>>>,
peer_inode: Option<Arc<SocketInode>>,
handle: GlobalSocketHandle,
posix_item: Arc<PosixSocketHandleItem>,
}
impl SeqpacketSocket {
/// 默认的元数据缓冲区大小
pub const DEFAULT_METADATA_BUF_SIZE: usize = 1024;
/// 默认的缓冲区大小
pub const DEFAULT_BUF_SIZE: usize = 64 * 1024;
/// # 创建一个 Seqpacket Socket
///
/// ## 参数
/// - `options`: socket选项
pub fn new(options: SocketOptions) -> Self {
let buffer = Arc::new(SpinLock::new(Vec::with_capacity(Self::DEFAULT_BUF_SIZE)));
let metadata = SocketMetadata::new(
SocketType::Unix,
Self::DEFAULT_BUF_SIZE,
Self::DEFAULT_BUF_SIZE,
Self::DEFAULT_METADATA_BUF_SIZE,
options,
);
let posix_item = Arc::new(PosixSocketHandleItem::new(None));
Self {
metadata,
buffer,
peer_inode: None,
handle: GlobalSocketHandle::new_kernel_handle(),
posix_item,
}
}
}
impl Socket for SeqpacketSocket {
fn posix_item(&self) -> Arc<PosixSocketHandleItem> {
self.posix_item.clone()
}
fn close(&mut self) {}
fn read(&self, buf: &mut [u8]) -> (Result<usize, SystemError>, Endpoint) {
let mut buffer = self.buffer.lock_irqsave();
let len = core::cmp::min(buf.len(), buffer.len());
buf[..len].copy_from_slice(&buffer[..len]);
let _ = buffer.split_off(len);
(Ok(len), Endpoint::Inode(self.peer_inode.clone()))
}
fn write(&self, buf: &[u8], _to: Option<Endpoint>) -> Result<usize, SystemError> {
if self.peer_inode.is_none() {
return Err(SystemError::ENOTCONN);
}
let peer_inode = self.peer_inode.clone().unwrap();
let len = peer_inode.inner().write_buffer(buf)?;
Ok(len)
}
fn connect(&mut self, endpoint: Endpoint) -> Result<(), SystemError> {
if self.peer_inode.is_some() {
return Err(SystemError::EISCONN);
}
if let Endpoint::Inode(inode) = endpoint {
self.peer_inode = inode;
Ok(())
} else {
Err(SystemError::EINVAL)
}
}
fn write_buffer(&self, buf: &[u8]) -> Result<usize, SystemError> {
let mut buffer = self.buffer.lock_irqsave();
let len = buf.len();
if buffer.capacity() - buffer.len() < len {
return Err(SystemError::ENOBUFS);
}
buffer.extend_from_slice(buf);
Ok(len)
}
fn socket_handle(&self) -> GlobalSocketHandle {
self.handle
}
fn metadata(&self) -> SocketMetadata {
self.metadata.clone()
}
fn box_clone(&self) -> Box<dyn Socket> {
Box::new(self.clone())
}
fn as_any_ref(&self) -> &dyn core::any::Any {
self
}
fn as_any_mut(&mut self) -> &mut dyn core::any::Any {
self
}
}

View File

@ -0,0 +1,37 @@
pub(crate) mod seqpacket;
mod stream;
use crate::{filesystem::vfs::InodeId, libs::rwlock::RwLock, net::socket::*};
use alloc::sync::Arc;
use hashbrown::HashMap;
use system_error::SystemError::{self, *};
pub struct Unix;
lazy_static! {
pub static ref INODE_MAP: RwLock<HashMap<InodeId, Endpoint>> = RwLock::new(HashMap::new());
}
fn create_unix_socket(sock_type: Type) -> Result<Arc<Inode>, SystemError> {
match sock_type {
Type::Stream | Type::Datagram => stream::StreamSocket::new_inode(),
Type::SeqPacket => seqpacket::SeqpacketSocket::new_inode(false),
_ => Err(EPROTONOSUPPORT),
}
}
impl family::Family for Unix {
fn socket(stype: Type, _protocol: u32) -> Result<Arc<Inode>, SystemError> {
let socket = create_unix_socket(stype)?;
Ok(socket)
}
}
impl Unix {
pub fn new_pairs(socket_type: Type) -> Result<(Arc<Inode>, Arc<Inode>), SystemError> {
log::debug!("socket_type {:?}", socket_type);
match socket_type {
Type::SeqPacket => seqpacket::SeqpacketSocket::new_pairs(),
Type::Stream | Type::Datagram => stream::StreamSocket::new_pairs(),
_ => todo!(),
}
}
}

View File

@ -0,0 +1,260 @@
use alloc::string::String;
use alloc::{collections::VecDeque, sync::Arc};
use core::sync::atomic::{AtomicUsize, Ordering};
use super::SeqpacketSocket;
use crate::{
libs::mutex::Mutex,
net::socket::{buffer::Buffer, endpoint::Endpoint, Inode, ShutdownTemp},
};
use system_error::SystemError::{self, *};
#[derive(Debug)]
pub(super) struct Init {
inode: Option<Endpoint>,
}
impl Init {
pub(super) fn new() -> Self {
Self { inode: None }
}
pub(super) fn bind(&mut self, epoint_to_bind: Endpoint) -> Result<(), SystemError> {
if self.inode.is_some() {
log::error!("the socket is already bound");
return Err(EINVAL);
}
match epoint_to_bind {
Endpoint::Inode(_) => self.inode = Some(epoint_to_bind),
_ => return Err(EINVAL),
}
return Ok(());
}
pub fn bind_path(&mut self, sun_path: String) -> Result<Endpoint, SystemError> {
if self.inode.is_none() {
log::error!("the socket is not bound");
return Err(EINVAL);
}
if let Some(Endpoint::Inode((inode, mut path))) = self.inode.take() {
path = sun_path;
let epoint = Endpoint::Inode((inode, path));
self.inode.replace(epoint.clone());
return Ok(epoint);
};
return Err(SystemError::EINVAL);
}
pub fn endpoint(&self) -> Option<&Endpoint> {
return self.inode.as_ref();
}
}
#[derive(Debug)]
pub(super) struct Listener {
inode: Endpoint,
backlog: AtomicUsize,
incoming_conns: Mutex<VecDeque<Arc<Inode>>>,
}
impl Listener {
pub(super) fn new(inode: Endpoint, backlog: usize) -> Self {
log::debug!("backlog {}", backlog);
let back = if backlog > 1024 {
1024 as usize
} else {
backlog
};
return Self {
inode,
backlog: AtomicUsize::new(back),
incoming_conns: Mutex::new(VecDeque::with_capacity(back)),
};
}
pub(super) fn endpoint(&self) -> &Endpoint {
return &self.inode;
}
pub(super) fn try_accept(&self) -> Result<(Arc<Inode>, Endpoint), SystemError> {
let mut incoming_conns = self.incoming_conns.lock();
log::debug!(" incom len {}", incoming_conns.len());
let conn = incoming_conns
.pop_front()
.ok_or_else(|| SystemError::EAGAIN_OR_EWOULDBLOCK)?;
let socket =
Arc::downcast::<SeqpacketSocket>(conn.inner()).map_err(|_| SystemError::EINVAL)?;
let peer = match &*socket.inner.read() {
Inner::Connected(connected) => connected.peer_endpoint().unwrap().clone(),
_ => return Err(SystemError::ENOTCONN),
};
return Ok((Inode::new(socket), peer));
}
pub(super) fn listen(&self, backlog: usize) -> Result<(), SystemError> {
self.backlog.store(backlog, Ordering::Relaxed);
Ok(())
}
pub(super) fn push_incoming(
&self,
client_epoint: Option<Endpoint>,
) -> Result<Connected, SystemError> {
let mut incoming_conns = self.incoming_conns.lock();
if incoming_conns.len() >= self.backlog.load(Ordering::Relaxed) {
log::error!("the pending connection queue on the listening socket is full");
return Err(SystemError::EAGAIN_OR_EWOULDBLOCK);
}
let new_server = SeqpacketSocket::new(false);
let new_inode = Inode::new(new_server.clone());
// log::debug!("new inode {:?},client_epoint {:?}",new_inode,client_epoint);
let path = match &self.inode {
Endpoint::Inode((_, path)) => path.clone(),
_ => return Err(SystemError::EINVAL),
};
let (server_conn, client_conn) = Connected::new_pair(
Some(Endpoint::Inode((new_inode.clone(), path))),
client_epoint,
);
*new_server.inner.write() = Inner::Connected(server_conn);
incoming_conns.push_back(new_inode);
// TODO: epollin
Ok(client_conn)
}
pub(super) fn is_acceptable(&self) -> bool {
return self.incoming_conns.lock().len() != 0;
}
}
#[derive(Debug)]
pub struct Connected {
inode: Option<Endpoint>,
peer_inode: Option<Endpoint>,
buffer: Arc<Buffer>,
}
impl Connected {
/// 默认的缓冲区大小
pub const DEFAULT_BUF_SIZE: usize = 64 * 1024;
pub fn new_pair(
inode: Option<Endpoint>,
peer_inode: Option<Endpoint>,
) -> (Connected, Connected) {
let this = Connected {
inode: inode.clone(),
peer_inode: peer_inode.clone(),
buffer: Buffer::new(),
};
let peer = Connected {
inode: peer_inode,
peer_inode: inode,
buffer: Buffer::new(),
};
(this, peer)
}
pub fn set_peer_inode(&mut self, peer_epoint: Option<Endpoint>) {
self.peer_inode = peer_epoint;
}
pub fn set_inode(&mut self, epoint: Option<Endpoint>) {
self.inode = epoint;
}
pub fn endpoint(&self) -> Option<&Endpoint> {
self.inode.as_ref()
}
pub fn peer_endpoint(&self) -> Option<&Endpoint> {
self.peer_inode.as_ref()
}
pub fn try_read(&self, buf: &mut [u8]) -> Result<usize, SystemError> {
if self.can_recv() {
return self.recv_slice(buf);
} else {
return Err(SystemError::EINVAL);
}
}
pub fn try_write(&self, buf: &[u8]) -> Result<usize, SystemError> {
if self.can_send()? {
return self.send_slice(buf);
} else {
log::debug!("can not send {:?}", String::from_utf8_lossy(&buf[..]));
return Err(SystemError::ENOBUFS);
}
}
pub fn can_recv(&self) -> bool {
return !self.buffer.is_read_buf_empty();
}
// 检查发送缓冲区是否满了
pub fn can_send(&self) -> Result<bool, SystemError> {
// let sebuffer = self.sebuffer.lock(); // 获取锁
// sebuffer.capacity()-sebuffer.len() ==0;
let peer_inode = match self.peer_inode.as_ref().unwrap() {
Endpoint::Inode((inode, _)) => inode,
_ => return Err(SystemError::EINVAL),
};
let peer_socket = Arc::downcast::<SeqpacketSocket>(peer_inode.inner())
.map_err(|_| SystemError::EINVAL)?;
let is_full = match &*peer_socket.inner.read() {
Inner::Connected(connected) => connected.buffer.is_read_buf_full(),
_ => return Err(SystemError::EINVAL),
};
Ok(!is_full)
}
pub fn recv_slice(&self, buf: &mut [u8]) -> Result<usize, SystemError> {
return self.buffer.read_read_buffer(buf);
}
pub fn send_slice(&self, buf: &[u8]) -> Result<usize, SystemError> {
//找到peer_inode并将write_buffer的内容写入对端的read_buffer
let peer_inode = match self.peer_inode.as_ref().unwrap() {
Endpoint::Inode((inode, _)) => inode,
_ => return Err(SystemError::EINVAL),
};
let peer_socket = Arc::downcast::<SeqpacketSocket>(peer_inode.inner())
.map_err(|_| SystemError::EINVAL)?;
let usize = match &*peer_socket.inner.write() {
Inner::Connected(connected) => {
let usize = connected.buffer.write_read_buffer(buf)?;
usize
}
_ => return Err(SystemError::EINVAL),
};
peer_socket.wait_queue.wakeup(None);
Ok(usize)
}
pub fn shutdown(&self, how: ShutdownTemp) -> Result<(), SystemError> {
if how.is_empty() {
return Err(SystemError::EINVAL);
} else if how.is_send_shutdown() {
unimplemented!("unimplemented!");
} else if how.is_recv_shutdown() {
unimplemented!("unimplemented!");
}
Ok(())
}
}
#[derive(Debug)]
pub(super) enum Inner {
Init(Init),
Listen(Listener),
Connected(Connected),
}

View File

@ -0,0 +1,483 @@
pub mod inner;
use alloc::{
string::String,
sync::{Arc, Weak},
};
use core::sync::atomic::{AtomicBool, Ordering};
use crate::sched::SchedMode;
use crate::{libs::rwlock::RwLock, net::socket::*};
use inner::*;
use system_error::SystemError;
use super::INODE_MAP;
type EP = EPollEventType;
#[derive(Debug)]
pub struct SeqpacketSocket {
inner: RwLock<Inner>,
shutdown: Shutdown,
is_nonblocking: AtomicBool,
wait_queue: WaitQueue,
self_ref: Weak<Self>,
}
impl SeqpacketSocket {
/// 默认的元数据缓冲区大小
pub const DEFAULT_METADATA_BUF_SIZE: usize = 1024;
/// 默认的缓冲区大小
pub const DEFAULT_BUF_SIZE: usize = 64 * 1024;
pub fn new(is_nonblocking: bool) -> Arc<Self> {
Arc::new_cyclic(|me| Self {
inner: RwLock::new(Inner::Init(Init::new())),
shutdown: Shutdown::new(),
is_nonblocking: AtomicBool::new(is_nonblocking),
wait_queue: WaitQueue::default(),
self_ref: me.clone(),
})
}
pub fn new_inode(is_nonblocking: bool) -> Result<Arc<Inode>, SystemError> {
let socket = SeqpacketSocket::new(is_nonblocking);
let inode = Inode::new(socket.clone());
// 建立时绑定自身为后续能正常获取本端地址
let _ = match &mut *socket.inner.write() {
Inner::Init(init) => init.bind(Endpoint::Inode((inode.clone(), String::from("")))),
_ => return Err(SystemError::EINVAL),
};
return Ok(inode);
}
pub fn new_connected(connected: Connected, is_nonblocking: bool) -> Arc<Self> {
Arc::new_cyclic(|me| Self {
inner: RwLock::new(Inner::Connected(connected)),
shutdown: Shutdown::new(),
is_nonblocking: AtomicBool::new(is_nonblocking),
wait_queue: WaitQueue::default(),
self_ref: me.clone(),
})
}
pub fn new_pairs() -> Result<(Arc<Inode>, Arc<Inode>), SystemError> {
let socket0 = SeqpacketSocket::new(false);
let socket1 = SeqpacketSocket::new(false);
let inode0 = Inode::new(socket0.clone());
let inode1 = Inode::new(socket1.clone());
let (conn_0, conn_1) = Connected::new_pair(
Some(Endpoint::Inode((inode0.clone(), String::from("")))),
Some(Endpoint::Inode((inode1.clone(), String::from("")))),
);
*socket0.inner.write() = Inner::Connected(conn_0);
*socket1.inner.write() = Inner::Connected(conn_1);
return Ok((inode0, inode1));
}
fn try_accept(&self) -> Result<(Arc<Inode>, Endpoint), SystemError> {
match &*self.inner.read() {
Inner::Listen(listen) => listen.try_accept() as _,
_ => {
log::error!("the socket is not listening");
return Err(SystemError::EINVAL);
}
}
}
fn is_acceptable(&self) -> bool {
match &*self.inner.read() {
Inner::Listen(listen) => listen.is_acceptable(),
_ => {
panic!("the socket is not listening");
}
}
}
fn is_peer_shutdown(&self) -> Result<bool, SystemError> {
let peer_shutdown = match self.get_peer_name()? {
Endpoint::Inode((inode, _)) => Arc::downcast::<SeqpacketSocket>(inode.inner())
.map_err(|_| SystemError::EINVAL)?
.shutdown
.get()
.is_both_shutdown(),
_ => return Err(SystemError::EINVAL),
};
Ok(peer_shutdown)
}
fn can_recv(&self) -> Result<bool, SystemError> {
let can = match &*self.inner.read() {
Inner::Connected(connected) => connected.can_recv(),
_ => return Err(SystemError::ENOTCONN),
};
Ok(can)
}
fn is_nonblocking(&self) -> bool {
self.is_nonblocking.load(Ordering::Relaxed)
}
fn set_nonblocking(&self, nonblocking: bool) {
self.is_nonblocking.store(nonblocking, Ordering::Relaxed);
}
}
impl Socket for SeqpacketSocket {
fn connect(&self, endpoint: Endpoint) -> Result<(), SystemError> {
let peer_inode = match endpoint {
Endpoint::Inode((inode, _)) => inode,
Endpoint::Unixpath((inode_id, _)) => {
let inode_guard = INODE_MAP.read_irqsave();
let inode = inode_guard.get(&inode_id).unwrap();
match inode {
Endpoint::Inode((inode, _)) => inode.clone(),
_ => return Err(SystemError::EINVAL),
}
}
_ => return Err(SystemError::EINVAL),
};
// 远端为服务端
let remote_socket = Arc::downcast::<SeqpacketSocket>(peer_inode.inner())
.map_err(|_| SystemError::EINVAL)?;
let client_epoint = match &mut *self.inner.write() {
Inner::Init(init) => match init.endpoint().cloned() {
Some(end) => {
log::debug!("bind when connect");
Some(end)
}
None => {
log::debug!("not bind when connect");
let inode = Inode::new(self.self_ref.upgrade().unwrap().clone());
let epoint = Endpoint::Inode((inode.clone(), String::from("")));
let _ = init.bind(epoint.clone());
Some(epoint)
}
},
Inner::Listen(_) => return Err(SystemError::EINVAL),
Inner::Connected(_) => return Err(SystemError::EISCONN),
};
// ***阻塞与非阻塞处理还未实现
// 客户端与服务端建立连接将服务端inode推入到自身的listen_incom队列中
// accept时从中获取推出对应的socket
match &*remote_socket.inner.read() {
Inner::Listen(listener) => match listener.push_incoming(client_epoint) {
Ok(connected) => {
*self.inner.write() = Inner::Connected(connected);
log::debug!("try to wake up");
remote_socket.wait_queue.wakeup(None);
return Ok(());
}
// ***错误处理
Err(_) => todo!(),
},
Inner::Init(_) => {
log::debug!("init einval");
return Err(SystemError::EINVAL);
}
Inner::Connected(_) => return Err(SystemError::EISCONN),
};
}
fn bind(&self, endpoint: Endpoint) -> Result<(), SystemError> {
// 将自身socket的inode与用户端提供路径的文件indoe_id进行绑定
match endpoint {
Endpoint::Unixpath((inodeid, path)) => {
let inode = match &mut *self.inner.write() {
Inner::Init(init) => init.bind_path(path)?,
_ => {
log::error!("socket has listen or connected");
return Err(SystemError::EINVAL);
}
};
INODE_MAP.write_irqsave().insert(inodeid, inode);
Ok(())
}
_ => return Err(SystemError::EINVAL),
}
}
fn shutdown(&self, how: ShutdownTemp) -> Result<(), SystemError> {
log::debug!("seqpacket shutdown");
match &*self.inner.write() {
Inner::Connected(connected) => connected.shutdown(how),
_ => Err(SystemError::EINVAL),
}
}
fn listen(&self, backlog: usize) -> Result<(), SystemError> {
let mut state = self.inner.write();
log::debug!("listen into socket");
let epoint = match &*state {
Inner::Init(init) => init.endpoint().ok_or(SystemError::EINVAL)?.clone(),
Inner::Listen(listener) => return listener.listen(backlog),
Inner::Connected(_) => {
log::error!("the socket is connected");
return Err(SystemError::EINVAL);
}
};
let listener = Listener::new(epoint, backlog);
*state = Inner::Listen(listener);
Ok(())
}
fn accept(&self) -> Result<(Arc<Inode>, Endpoint), SystemError> {
if !self.is_nonblocking() {
loop {
wq_wait_event_interruptible!(self.wait_queue, self.is_acceptable(), {})?;
match self
.try_accept()
.map(|(seqpacket_socket, remote_endpoint)| {
(seqpacket_socket, Endpoint::from(remote_endpoint))
}) {
Ok((socket, epoint)) => return Ok((socket, epoint)),
Err(_) => continue,
}
}
} else {
// ***非阻塞状态
todo!()
}
}
fn set_option(
&self,
_level: crate::net::socket::OptionsLevel,
_optname: usize,
_optval: &[u8],
) -> Result<(), SystemError> {
log::warn!("setsockopt is not implemented");
Ok(())
}
fn wait_queue(&self) -> &WaitQueue {
return &self.wait_queue;
}
fn close(&self) -> Result<(), SystemError> {
log::debug!("seqpacket close");
self.shutdown.recv_shutdown();
self.shutdown.send_shutdown();
Ok(())
}
fn get_peer_name(&self) -> Result<Endpoint, SystemError> {
// 获取对端地址
let endpoint = match &*self.inner.read() {
Inner::Connected(connected) => connected.peer_endpoint().cloned(),
_ => return Err(SystemError::ENOTCONN),
};
if let Some(endpoint) = endpoint {
return Ok(Endpoint::from(endpoint));
} else {
return Err(SystemError::EAGAIN_OR_EWOULDBLOCK);
}
}
fn get_name(&self) -> Result<Endpoint, SystemError> {
// 获取本端地址
let endpoint = match &*self.inner.read() {
Inner::Init(init) => init.endpoint().cloned(),
Inner::Listen(listener) => Some(listener.endpoint().clone()),
Inner::Connected(connected) => connected.endpoint().cloned(),
};
if let Some(endpoint) = endpoint {
return Ok(Endpoint::from(endpoint));
} else {
return Err(SystemError::EAGAIN_OR_EWOULDBLOCK);
}
}
fn get_option(
&self,
_level: crate::net::socket::OptionsLevel,
_name: usize,
_value: &mut [u8],
) -> Result<usize, SystemError> {
log::warn!("getsockopt is not implemented");
Ok(0)
}
fn read(&self, buffer: &mut [u8]) -> Result<usize, SystemError> {
self.recv(buffer, crate::net::socket::MessageFlag::empty())
}
fn recv(
&self,
buffer: &mut [u8],
flags: crate::net::socket::MessageFlag,
) -> Result<usize, SystemError> {
if flags.contains(MessageFlag::OOB) {
return Err(SystemError::EOPNOTSUPP_OR_ENOTSUP);
}
if !flags.contains(MessageFlag::DONTWAIT) {
loop {
wq_wait_event_interruptible!(
self.wait_queue,
self.can_recv()? || self.is_peer_shutdown()?,
{}
)?;
// connect锁和flag判断顺序不正确应该先判断在
match &*self.inner.write() {
Inner::Connected(connected) => match connected.try_read(buffer) {
Ok(usize) => {
log::debug!("recv from successfully");
return Ok(usize);
}
Err(_) => continue,
},
_ => {
log::error!("the socket is not connected");
return Err(SystemError::ENOTCONN);
}
}
}
} else {
unimplemented!("unimplemented non_block")
}
}
fn recv_msg(
&self,
_msg: &mut crate::net::syscall::MsgHdr,
_flags: crate::net::socket::MessageFlag,
) -> Result<usize, SystemError> {
Err(SystemError::ENOSYS)
}
fn send(
&self,
buffer: &[u8],
flags: crate::net::socket::MessageFlag,
) -> Result<usize, SystemError> {
if flags.contains(MessageFlag::OOB) {
return Err(SystemError::EOPNOTSUPP_OR_ENOTSUP);
}
if self.is_peer_shutdown()? {
return Err(SystemError::EPIPE);
}
if !flags.contains(MessageFlag::DONTWAIT) {
loop {
match &*self.inner.write() {
Inner::Connected(connected) => match connected.try_write(buffer) {
Ok(usize) => {
log::debug!("send successfully");
return Ok(usize);
}
Err(_) => continue,
},
_ => {
log::error!("the socket is not connected");
return Err(SystemError::ENOTCONN);
}
}
}
} else {
unimplemented!("unimplemented non_block")
}
}
fn send_msg(
&self,
_msg: &crate::net::syscall::MsgHdr,
_flags: crate::net::socket::MessageFlag,
) -> Result<usize, SystemError> {
Err(SystemError::ENOSYS)
}
fn write(&self, buffer: &[u8]) -> Result<usize, SystemError> {
self.send(buffer, crate::net::socket::MessageFlag::empty())
}
fn recv_from(
&self,
buffer: &mut [u8],
flags: MessageFlag,
_address: Option<Endpoint>,
) -> Result<(usize, Endpoint), SystemError> {
log::debug!("recvfrom flags {:?}", flags);
if flags.contains(MessageFlag::OOB) {
return Err(SystemError::EOPNOTSUPP_OR_ENOTSUP);
}
if !flags.contains(MessageFlag::DONTWAIT) {
loop {
wq_wait_event_interruptible!(
self.wait_queue,
self.can_recv()? || self.is_peer_shutdown()?,
{}
)?;
// connect锁和flag判断顺序不正确应该先判断在
match &*self.inner.write() {
Inner::Connected(connected) => match connected.recv_slice(buffer) {
Ok(usize) => {
log::debug!("recvs from successfully");
return Ok((usize, connected.peer_endpoint().unwrap().clone()));
}
Err(_) => continue,
},
_ => {
log::error!("the socket is not connected");
return Err(SystemError::ENOTCONN);
}
}
}
} else {
unimplemented!("unimplemented non_block")
}
//Err(SystemError::ENOSYS)
}
fn send_buffer_size(&self) -> usize {
log::warn!("using default buffer size");
SeqpacketSocket::DEFAULT_BUF_SIZE
}
fn recv_buffer_size(&self) -> usize {
log::warn!("using default buffer size");
SeqpacketSocket::DEFAULT_BUF_SIZE
}
fn poll(&self) -> usize {
let mut mask = EP::empty();
let shutdown = self.shutdown.get();
// 参考linux的unix_poll https://code.dragonos.org.cn/xref/linux-6.1.9/net/unix/af_unix.c#3152
// 用关闭读写端表示连接断开
if shutdown.is_both_shutdown() || self.is_peer_shutdown().unwrap() {
mask |= EP::EPOLLHUP;
}
if shutdown.is_recv_shutdown() {
mask |= EP::EPOLLRDHUP | EP::EPOLLIN | EP::EPOLLRDNORM;
}
match &*self.inner.read() {
Inner::Connected(connected) => {
if connected.can_recv() {
mask |= EP::EPOLLIN | EP::EPOLLRDNORM;
}
// if (sk_is_readable(sk))
// mask |= EPOLLIN | EPOLLRDNORM;
// TODO:处理紧急情况 EPOLLPRI
// TODO:处理连接是否关闭 EPOLLHUP
if !shutdown.is_send_shutdown() {
if connected.can_send().unwrap() {
mask |= EP::EPOLLOUT | EP::EPOLLWRNORM | EP::EPOLLWRBAND;
} else {
todo!("poll: buffer space not enough");
}
}
}
Inner::Listen(_) => mask |= EP::EPOLLIN,
Inner::Init(_) => mask |= EP::EPOLLOUT,
}
mask.bits() as usize
}
}

View File

@ -0,0 +1,243 @@
use core::sync::atomic::{AtomicUsize, Ordering};
use log::debug;
use system_error::SystemError;
use crate::libs::mutex::Mutex;
use crate::net::socket::buffer::Buffer;
use crate::net::socket::unix::stream::StreamSocket;
use crate::net::socket::{Endpoint, Inode, ShutdownTemp};
use alloc::collections::VecDeque;
use alloc::{string::String, sync::Arc};
#[derive(Debug)]
pub enum Inner {
Init(Init),
Connected(Connected),
Listener(Listener),
}
#[derive(Debug)]
pub struct Init {
addr: Option<Endpoint>,
}
impl Init {
pub(super) fn new() -> Self {
Self { addr: None }
}
pub(super) fn bind(&mut self, endpoint_to_bind: Endpoint) -> Result<(), SystemError> {
if self.addr.is_some() {
log::error!("the socket is already bound");
return Err(SystemError::EINVAL);
}
match endpoint_to_bind {
Endpoint::Inode(_) => self.addr = Some(endpoint_to_bind),
_ => return Err(SystemError::EINVAL),
}
return Ok(());
}
pub fn bind_path(&mut self, sun_path: String) -> Result<Endpoint, SystemError> {
if self.addr.is_none() {
log::error!("the socket is not bound");
return Err(SystemError::EINVAL);
}
if let Some(Endpoint::Inode((inode, mut path))) = self.addr.take() {
path = sun_path;
let epoint = Endpoint::Inode((inode, path));
self.addr.replace(epoint.clone());
return Ok(epoint);
};
return Err(SystemError::EINVAL);
}
pub(super) fn endpoint(&self) -> Option<&Endpoint> {
self.addr.as_ref()
}
}
#[derive(Debug, Clone)]
pub struct Connected {
addr: Option<Endpoint>,
peer_addr: Option<Endpoint>,
buffer: Arc<Buffer>,
}
impl Connected {
pub fn new_pair(addr: Option<Endpoint>, peer_addr: Option<Endpoint>) -> (Self, Self) {
let this = Connected {
addr: addr.clone(),
peer_addr: peer_addr.clone(),
buffer: Buffer::new(),
};
let peer = Connected {
addr: peer_addr,
peer_addr: addr,
buffer: Buffer::new(),
};
return (this, peer);
}
pub fn endpoint(&self) -> Option<&Endpoint> {
self.addr.as_ref()
}
pub fn set_addr(&mut self, addr: Option<Endpoint>) {
self.addr = addr;
}
pub fn peer_endpoint(&self) -> Option<&Endpoint> {
self.peer_addr.as_ref()
}
pub fn set_peer_addr(&mut self, peer: Option<Endpoint>) {
self.peer_addr = peer;
}
pub fn send_slice(&self, buf: &[u8]) -> Result<usize, SystemError> {
//写入对端buffer
let peer_inode = match self.peer_addr.as_ref().unwrap() {
Endpoint::Inode((inode, _)) => inode,
_ => return Err(SystemError::EINVAL),
};
let peer_socket =
Arc::downcast::<StreamSocket>(peer_inode.inner()).map_err(|_| SystemError::EINVAL)?;
let usize = match &*peer_socket.inner.read() {
Inner::Connected(conntected) => {
let usize = conntected.buffer.write_read_buffer(buf)?;
usize
}
_ => {
debug!("no! is not connested!");
return Err(SystemError::EINVAL);
}
};
peer_socket.wait_queue.wakeup(None);
Ok(usize)
}
pub fn can_send(&self) -> Result<bool, SystemError> {
//查看连接体里的buf是否非满
let peer_inode = match self.peer_addr.as_ref().unwrap() {
Endpoint::Inode((inode, _)) => inode,
_ => return Err(SystemError::EINVAL),
};
let peer_socket =
Arc::downcast::<StreamSocket>(peer_inode.inner()).map_err(|_| SystemError::EINVAL)?;
let is_full = match &*peer_socket.inner.read() {
Inner::Connected(connected) => connected.buffer.is_read_buf_full(),
_ => return Err(SystemError::EINVAL),
};
debug!("can send? :{}", !is_full);
Ok(!is_full)
}
pub fn can_recv(&self) -> bool {
//查看连接体里的buf是否非空
return !self.buffer.is_read_buf_empty();
}
pub fn try_send(&self, buf: &[u8]) -> Result<usize, SystemError> {
if self.can_send()? {
return self.send_slice(buf);
} else {
return Err(SystemError::ENOBUFS);
}
}
fn recv_slice(&self, buf: &mut [u8]) -> Result<usize, SystemError> {
return self.buffer.read_read_buffer(buf);
}
pub fn try_recv(&self, buf: &mut [u8]) -> Result<usize, SystemError> {
if self.can_recv() {
return self.recv_slice(buf);
} else {
return Err(SystemError::EINVAL);
}
}
pub fn shutdown(&self, how: ShutdownTemp) -> Result<(), SystemError> {
if how.is_empty() {
return Err(SystemError::EINVAL);
} else if how.is_send_shutdown() {
unimplemented!("unimplemented!");
} else if how.is_recv_shutdown() {
unimplemented!("unimplemented!");
}
Ok(())
}
}
#[derive(Debug)]
pub struct Listener {
addr: Option<Endpoint>,
incoming_connects: Mutex<VecDeque<Arc<Inode>>>,
backlog: AtomicUsize,
}
impl Listener {
pub fn new(addr: Option<Endpoint>, backlog: usize) -> Self {
Self {
addr,
incoming_connects: Mutex::new(VecDeque::new()),
backlog: AtomicUsize::new(backlog),
}
}
pub fn listen(&self, backlog: usize) -> Result<(), SystemError> {
self.backlog.store(backlog, Ordering::Relaxed);
return Ok(());
}
pub fn push_incoming(&self, server_inode: Arc<Inode>) -> Result<(), SystemError> {
let mut incoming_connects = self.incoming_connects.lock();
if incoming_connects.len() >= self.backlog.load(Ordering::Relaxed) {
debug!("unix stream listen socket connected queue is full!");
return Err(SystemError::EAGAIN_OR_EWOULDBLOCK);
}
incoming_connects.push_back(server_inode);
return Ok(());
}
pub fn pop_incoming(&self) -> Option<Arc<Inode>> {
let mut incoming_connects = self.incoming_connects.lock();
return incoming_connects.pop_front();
}
pub(super) fn endpoint(&self) -> Option<&Endpoint> {
self.addr.as_ref()
}
pub(super) fn is_acceptable(&self) -> bool {
return self.incoming_connects.lock().len() != 0;
}
pub(super) fn try_accept(&self) -> Result<(Arc<Inode>, Endpoint), SystemError> {
let mut incoming_connecteds = self.incoming_connects.lock();
debug!("incom len {}", incoming_connecteds.len());
let connected = incoming_connecteds
.pop_front()
.ok_or(SystemError::EAGAIN_OR_EWOULDBLOCK)?;
let socket =
Arc::downcast::<StreamSocket>(connected.inner()).map_err(|_| SystemError::EINVAL)?;
let peer = match &*socket.inner.read() {
Inner::Connected(connected) => connected.peer_endpoint().unwrap().clone(),
_ => return Err(SystemError::ENOTCONN),
};
debug!("server accept!");
return Ok((Inode::new(socket), peer));
}
}

View File

@ -0,0 +1,478 @@
use crate::sched::SchedMode;
use alloc::{
string::String,
sync::{Arc, Weak},
};
use inner::{Connected, Init, Inner, Listener};
use log::debug;
use system_error::SystemError;
use unix::INODE_MAP;
use crate::{
libs::rwlock::RwLock,
net::socket::{self, *},
};
type EP = EPollEventType;
pub mod inner;
#[derive(Debug)]
pub struct StreamSocket {
inner: RwLock<Inner>,
shutdown: Shutdown,
_epitems: EPollItems,
wait_queue: WaitQueue,
self_ref: Weak<Self>,
}
impl StreamSocket {
/// 默认的元数据缓冲区大小
pub const DEFAULT_METADATA_BUF_SIZE: usize = 1024;
/// 默认的缓冲区大小
pub const DEFAULT_BUF_SIZE: usize = 64 * 1024;
pub fn new() -> Arc<Self> {
Arc::new_cyclic(|me| Self {
inner: RwLock::new(Inner::Init(Init::new())),
shutdown: Shutdown::new(),
_epitems: EPollItems::default(),
wait_queue: WaitQueue::default(),
self_ref: me.clone(),
})
}
pub fn new_pairs() -> Result<(Arc<Inode>, Arc<Inode>), SystemError> {
let socket0 = StreamSocket::new();
let socket1 = StreamSocket::new();
let inode0 = Inode::new(socket0.clone());
let inode1 = Inode::new(socket1.clone());
let (conn_0, conn_1) = Connected::new_pair(
Some(Endpoint::Inode((inode0.clone(), String::from("")))),
Some(Endpoint::Inode((inode1.clone(), String::from("")))),
);
*socket0.inner.write() = Inner::Connected(conn_0);
*socket1.inner.write() = Inner::Connected(conn_1);
return Ok((inode0, inode1));
}
pub fn new_connected(connected: Connected) -> Arc<Self> {
Arc::new_cyclic(|me| Self {
inner: RwLock::new(Inner::Connected(connected)),
shutdown: Shutdown::new(),
_epitems: EPollItems::default(),
wait_queue: WaitQueue::default(),
self_ref: me.clone(),
})
}
pub fn new_inode() -> Result<Arc<Inode>, SystemError> {
let socket = StreamSocket::new();
let inode = Inode::new(socket.clone());
let _ = match &mut *socket.inner.write() {
Inner::Init(init) => init.bind(Endpoint::Inode((inode.clone(), String::from("")))),
_ => return Err(SystemError::EINVAL),
};
return Ok(inode);
}
fn is_acceptable(&self) -> bool {
match &*self.inner.read() {
Inner::Listener(listener) => listener.is_acceptable(),
_ => {
panic!("the socket is not listening");
}
}
}
pub fn try_accept(&self) -> Result<(Arc<Inode>, Endpoint), SystemError> {
match &*self.inner.read() {
Inner::Listener(listener) => listener.try_accept() as _,
_ => {
log::error!("the socket is not listening");
return Err(SystemError::EINVAL);
}
}
}
fn is_peer_shutdown(&self) -> Result<bool, SystemError> {
let peer_shutdown = match self.get_peer_name()? {
Endpoint::Inode((inode, _)) => Arc::downcast::<StreamSocket>(inode.inner())
.map_err(|_| SystemError::EINVAL)?
.shutdown
.get()
.is_both_shutdown(),
_ => return Err(SystemError::EINVAL),
};
Ok(peer_shutdown)
}
fn can_recv(&self) -> Result<bool, SystemError> {
let can = match &*self.inner.read() {
Inner::Connected(connected) => connected.can_recv(),
_ => return Err(SystemError::ENOTCONN),
};
Ok(can)
}
}
impl Socket for StreamSocket {
fn connect(&self, server_endpoint: Endpoint) -> Result<(), SystemError> {
//获取客户端地址
let client_endpoint = match &mut *self.inner.write() {
Inner::Init(init) => match init.endpoint().cloned() {
Some(endpoint) => {
debug!("bind when connected");
Some(endpoint)
}
None => {
debug!("not bind when connected");
let inode = Inode::new(self.self_ref.upgrade().unwrap().clone());
let epoint = Endpoint::Inode((inode.clone(), String::from("")));
let _ = init.bind(epoint.clone());
Some(epoint)
}
},
Inner::Connected(_) => return Err(SystemError::EISCONN),
Inner::Listener(_) => return Err(SystemError::EINVAL),
};
//获取服务端地址
// let peer_inode = match server_endpoint.clone() {
// Endpoint::Inode(socket) => socket,
// _ => return Err(SystemError::EINVAL),
// };
//找到对端socket
let (peer_inode, sun_path) = match server_endpoint {
Endpoint::Inode((inode, path)) => (inode, path),
Endpoint::Unixpath((inode_id, path)) => {
let inode_guard = INODE_MAP.read_irqsave();
let inode = inode_guard.get(&inode_id).unwrap();
match inode {
Endpoint::Inode((inode, _)) => (inode.clone(), path),
_ => return Err(SystemError::EINVAL),
}
}
_ => return Err(SystemError::EINVAL),
};
let remote_socket: Arc<StreamSocket> =
Arc::downcast::<StreamSocket>(peer_inode.inner()).map_err(|_| SystemError::EINVAL)?;
//创建新的对端socket
let new_server_socket = StreamSocket::new();
let new_server_inode = Inode::new(new_server_socket.clone());
let new_server_endpoint = Some(Endpoint::Inode((new_server_inode.clone(), sun_path)));
//获取connect pair
let (client_conn, server_conn) =
Connected::new_pair(client_endpoint, new_server_endpoint.clone());
*new_server_socket.inner.write() = Inner::Connected(server_conn);
//查看remote_socket是否处于监听状态
let remote_listener = remote_socket.inner.write();
match &*remote_listener {
Inner::Listener(listener) => {
//往服务端socket的连接队列中添加connected
listener.push_incoming(new_server_inode)?;
*self.inner.write() = Inner::Connected(client_conn);
remote_socket.wait_queue.wakeup(None);
}
_ => return Err(SystemError::EINVAL),
}
return Ok(());
}
fn bind(&self, endpoint: Endpoint) -> Result<(), SystemError> {
match endpoint {
Endpoint::Unixpath((inodeid, path)) => {
let inode = match &mut *self.inner.write() {
Inner::Init(init) => init.bind_path(path)?,
_ => {
log::error!("socket has listen or connected");
return Err(SystemError::EINVAL);
}
};
INODE_MAP.write_irqsave().insert(inodeid, inode);
Ok(())
}
_ => return Err(SystemError::EINVAL),
}
}
fn shutdown(&self, _stype: ShutdownTemp) -> Result<(), SystemError> {
todo!();
}
fn listen(&self, backlog: usize) -> Result<(), SystemError> {
let mut inner = self.inner.write();
let epoint = match &*inner {
Inner::Init(init) => init.endpoint().ok_or(SystemError::EINVAL)?.clone(),
Inner::Connected(_) => {
return Err(SystemError::EINVAL);
}
Inner::Listener(listener) => {
return listener.listen(backlog);
}
};
let listener = Listener::new(Some(epoint), backlog);
*inner = Inner::Listener(listener);
return Ok(());
}
fn accept(&self) -> Result<(Arc<socket::Inode>, Endpoint), SystemError> {
debug!("stream server begin accept");
//目前只实现了阻塞式实现
loop {
wq_wait_event_interruptible!(self.wait_queue, self.is_acceptable(), {})?;
match self.try_accept().map(|(stream_socket, remote_endpoint)| {
(stream_socket, remote_endpoint)
}) {
Ok((socket, endpoint)) => {
debug!("server accept!:{:?}", endpoint);
return Ok((socket, endpoint));
}
Err(_) => continue,
}
}
}
fn set_option(
&self,
_level: OptionsLevel,
_optname: usize,
_optval: &[u8],
) -> Result<(), SystemError> {
log::warn!("setsockopt is not implemented");
Ok(())
}
fn wait_queue(&self) -> &WaitQueue {
return &self.wait_queue;
}
fn poll(&self) -> usize {
let mut mask = EP::empty();
let shutdown = self.shutdown.get();
// 参考linux的unix_poll https://code.dragonos.org.cn/xref/linux-6.1.9/net/unix/af_unix.c#3152
// 用关闭读写端表示连接断开
if shutdown.is_both_shutdown() || self.is_peer_shutdown().unwrap() {
mask |= EP::EPOLLHUP;
}
if shutdown.is_recv_shutdown() {
mask |= EP::EPOLLRDHUP | EP::EPOLLIN | EP::EPOLLRDNORM;
}
match &*self.inner.read() {
Inner::Connected(connected) => {
if connected.can_recv() {
mask |= EP::EPOLLIN | EP::EPOLLRDNORM;
}
// if (sk_is_readable(sk))
// mask |= EPOLLIN | EPOLLRDNORM;
// TODO:处理紧急情况 EPOLLPRI
// TODO:处理连接是否关闭 EPOLLHUP
if !shutdown.is_send_shutdown() {
if connected.can_send().unwrap() {
mask |= EP::EPOLLOUT | EP::EPOLLWRNORM | EP::EPOLLWRBAND;
} else {
todo!("poll: buffer space not enough");
}
}
}
Inner::Listener(_) => mask |= EP::EPOLLIN,
Inner::Init(_) => mask |= EP::EPOLLOUT,
}
mask.bits() as usize
}
fn close(&self) -> Result<(), SystemError> {
self.shutdown.recv_shutdown();
self.shutdown.send_shutdown();
Ok(())
}
fn get_peer_name(&self) -> Result<Endpoint, SystemError> {
//获取对端地址
let endpoint = match &*self.inner.read() {
Inner::Connected(connected) => connected.peer_endpoint().cloned(),
_ => return Err(SystemError::ENOTCONN),
};
if let Some(endpoint) = endpoint {
return Ok(endpoint);
} else {
return Err(SystemError::EAGAIN_OR_EWOULDBLOCK);
}
}
fn get_name(&self) -> Result<Endpoint, SystemError> {
//获取本端地址
let endpoint = match &*self.inner.read() {
Inner::Init(init) => init.endpoint().cloned(),
Inner::Connected(connected) => connected.endpoint().cloned(),
Inner::Listener(listener) => listener.endpoint().cloned(),
};
if let Some(endpoint) = endpoint {
return Ok(endpoint);
} else {
return Err(SystemError::EAGAIN_OR_EWOULDBLOCK);
}
}
fn get_option(
&self,
_level: OptionsLevel,
_name: usize,
_value: &mut [u8],
) -> Result<usize, SystemError> {
log::warn!("getsockopt is not implemented");
Ok(0)
}
fn read(&self, buffer: &mut [u8]) -> Result<usize, SystemError> {
self.recv(buffer, socket::MessageFlag::empty())
}
fn recv(&self, buffer: &mut [u8], flags: socket::MessageFlag) -> Result<usize, SystemError> {
if !flags.contains(MessageFlag::DONTWAIT) {
loop {
log::debug!("socket try recv");
wq_wait_event_interruptible!(
self.wait_queue,
self.can_recv()? || self.is_peer_shutdown()?,
{}
)?;
// connect锁和flag判断顺序不正确应该先判断在
match &*self.inner.write() {
Inner::Connected(connected) => match connected.try_recv(buffer) {
Ok(usize) => {
log::debug!("recv successfully");
return Ok(usize);
}
Err(_) => continue,
},
_ => {
log::error!("the socket is not connected");
return Err(SystemError::ENOTCONN);
}
}
}
} else {
unimplemented!("unimplemented non_block")
}
}
fn recv_from(
&self,
buffer: &mut [u8],
flags: socket::MessageFlag,
_address: Option<Endpoint>,
) -> Result<(usize, Endpoint), SystemError> {
if flags.contains(MessageFlag::OOB) {
return Err(SystemError::EOPNOTSUPP_OR_ENOTSUP);
}
if !flags.contains(MessageFlag::DONTWAIT) {
loop {
log::debug!("socket try recv from");
wq_wait_event_interruptible!(
self.wait_queue,
self.can_recv()? || self.is_peer_shutdown()?,
{}
)?;
// connect锁和flag判断顺序不正确应该先判断在
log::debug!("try recv");
match &*self.inner.write() {
Inner::Connected(connected) => match connected.try_recv(buffer) {
Ok(usize) => {
log::debug!("recvs from successfully");
return Ok((usize, connected.peer_endpoint().unwrap().clone()));
}
Err(_) => continue,
},
_ => {
log::error!("the socket is not connected");
return Err(SystemError::ENOTCONN);
}
}
}
} else {
unimplemented!("unimplemented non_block")
}
}
fn recv_msg(
&self,
_msg: &mut crate::net::syscall::MsgHdr,
_flags: socket::MessageFlag,
) -> Result<usize, SystemError> {
Err(SystemError::ENOSYS)
}
fn send(&self, buffer: &[u8], flags: socket::MessageFlag) -> Result<usize, SystemError> {
if self.is_peer_shutdown()? {
return Err(SystemError::EPIPE);
}
if !flags.contains(MessageFlag::DONTWAIT) {
loop {
match &*self.inner.write() {
Inner::Connected(connected) => match connected.try_send(buffer) {
Ok(usize) => {
log::debug!("send successfully");
return Ok(usize);
}
Err(_) => continue,
},
_ => {
log::error!("the socket is not connected");
return Err(SystemError::ENOTCONN);
}
}
}
} else {
unimplemented!("unimplemented non_block")
}
}
fn send_msg(
&self,
_msg: &crate::net::syscall::MsgHdr,
_flags: socket::MessageFlag,
) -> Result<usize, SystemError> {
todo!()
}
fn send_to(
&self,
_buffer: &[u8],
_flags: socket::MessageFlag,
_address: Endpoint,
) -> Result<usize, SystemError> {
Err(SystemError::ENOSYS)
}
fn write(&self, buffer: &[u8]) -> Result<usize, SystemError> {
self.send(buffer, socket::MessageFlag::empty())
}
fn send_buffer_size(&self) -> usize {
log::warn!("using default buffer size");
StreamSocket::DEFAULT_BUF_SIZE
}
fn recv_buffer_size(&self) -> usize {
log::warn!("using default buffer size");
StreamSocket::DEFAULT_BUF_SIZE
}
}

View File

@ -0,0 +1,28 @@
use crate::net::socket;
use alloc::sync::Arc;
use socket::Family;
use system_error::SystemError;
pub fn create_socket(
family: socket::AddressFamily,
socket_type: socket::Type,
protocol: u32,
is_nonblock: bool,
is_close_on_exec: bool,
) -> Result<Arc<socket::Inode>, SystemError> {
type AF = socket::AddressFamily;
let inode = match family {
AF::INet => socket::inet::Inet::socket(socket_type, protocol)?,
AF::INet6 => {
todo!("AF_INET6 unimplemented");
}
AF::Unix => socket::unix::Unix::socket(socket_type, protocol)?,
AF::Netlink => socket::netlink::Netlink::socket(socket_type, protocol)?,
_ => {
todo!("unsupport address family");
}
};
// inode.set_nonblock(is_nonblock);
// inode.set_close_on_exec(is_close_on_exec);
return Ok(inode);
}

View File

@ -1,9 +1,11 @@
use core::{cmp::min, ffi::CStr}; use core::{cmp::min, ffi::CStr};
use acpi::address;
use alloc::{boxed::Box, sync::Arc}; use alloc::{boxed::Box, sync::Arc};
use log::debug;
use num_traits::{FromPrimitive, ToPrimitive}; use num_traits::{FromPrimitive, ToPrimitive};
use smoltcp::wire; use smoltcp::wire;
use system_error::SystemError; use system_error::SystemError::{self, *};
use crate::{ use crate::{
filesystem::vfs::{ filesystem::vfs::{
@ -13,15 +15,15 @@ use crate::{
}, },
libs::spinlock::SpinLockGuard, libs::spinlock::SpinLockGuard,
mm::{verify_area, VirtAddr}, mm::{verify_area, VirtAddr},
net::socket::{AddressFamily, SOL_SOCKET}, // net::socket::{netlink::af_netlink::NetlinkSock, AddressFamily},
process::ProcessManager, process::ProcessManager,
syscall::Syscall, syscall::Syscall,
}; };
use super::{ use super::socket::{self, Endpoint, Socket};
socket::{new_socket, PosixSocketType, Socket, SocketInode}, use super::socket::{netlink::endpoint, unix::Unix, AddressFamily as AF};
Endpoint, Protocol, ShutdownType,
}; pub use super::syscall_util::*;
/// Flags for socket, socketpair, accept4 /// Flags for socket, socketpair, accept4
const SOCK_CLOEXEC: FileMode = FileMode::O_CLOEXEC; const SOCK_CLOEXEC: FileMode = FileMode::O_CLOEXEC;
@ -38,18 +40,34 @@ impl Syscall {
socket_type: usize, socket_type: usize,
protocol: usize, protocol: usize,
) -> Result<usize, SystemError> { ) -> Result<usize, SystemError> {
let address_family = AddressFamily::try_from(address_family as u16)?; // 打印收到的参数
let socket_type = PosixSocketType::try_from((socket_type & 0xf) as u8)?; log::debug!(
let protocol = Protocol::from(protocol as u8); "socket: address_family={:?}, socket_type={:?}, protocol={:?}",
address_family,
socket_type,
protocol
);
let address_family = socket::AddressFamily::try_from(address_family as u16)?;
let type_arg = SysArgSocketType::from_bits_truncate(socket_type as u32);
let is_nonblock = type_arg.is_nonblock();
let is_close_on_exec = type_arg.is_cloexec();
let stype = socket::Type::try_from(type_arg)?;
log::debug!("type_arg {:?} stype {:?}", type_arg, stype);
let socket = new_socket(address_family, socket_type, protocol)?; let inode = socket::create_socket(
address_family,
stype,
protocol as u32,
is_nonblock,
is_close_on_exec,
)?;
let socketinode: Arc<SocketInode> = SocketInode::new(socket); let file = File::new(inode, FileMode::O_RDWR)?;
let f = File::new(socketinode, FileMode::O_RDWR)?;
// 把socket添加到当前进程的文件描述符表中 // 把socket添加到当前进程的文件描述符表中
let binding = ProcessManager::current_pcb().fd_table(); let binding = ProcessManager::current_pcb().fd_table();
let mut fd_table_guard = binding.write(); let mut fd_table_guard = binding.write();
let fd = fd_table_guard.alloc_fd(f, None).map(|x| x as usize); let fd: Result<usize, SystemError> =
fd_table_guard.alloc_fd(file, None).map(|x| x as usize);
drop(fd_table_guard); drop(fd_table_guard);
return fd; return fd;
} }
@ -67,27 +85,43 @@ impl Syscall {
protocol: usize, protocol: usize,
fds: &mut [i32], fds: &mut [i32],
) -> Result<usize, SystemError> { ) -> Result<usize, SystemError> {
let address_family = AddressFamily::try_from(address_family as u16)?; let address_family = AF::try_from(address_family as u16)?;
let socket_type = PosixSocketType::try_from((socket_type & 0xf) as u8)?; let socket_type = SysArgSocketType::from_bits_truncate(socket_type as u32);
let protocol = Protocol::from(protocol as u8); let stype = socket::Type::try_from(socket_type)?;
let binding = ProcessManager::current_pcb().fd_table(); let binding = ProcessManager::current_pcb().fd_table();
let mut fd_table_guard = binding.write(); let mut fd_table_guard = binding.write();
// 创建一对socket // check address family, only support AF_UNIX
let inode0 = SocketInode::new(new_socket(address_family, socket_type, protocol)?); if address_family != AF::Unix {
let inode1 = SocketInode::new(new_socket(address_family, socket_type, protocol)?); return Err(SystemError::EAFNOSUPPORT);
// 进行pair
unsafe {
inode0
.inner_no_preempt()
.connect(Endpoint::Inode(Some(inode1.clone())))?;
inode1
.inner_no_preempt()
.connect(Endpoint::Inode(Some(inode0.clone())))?;
} }
// 创建一对socket
// let inode0 = socket::create_socket(
// address_family,
// stype,
// protocol as u32,
// socket_type.is_nonblock(),
// socket_type.is_cloexec(),
// )?;
// let inode1 = socket::create_socket(
// address_family,
// stype,
// protocol as u32,
// socket_type.is_nonblock(),
// socket_type.is_cloexec(),
// )?;
// // 进行pair
// unsafe {
// inode0.connect(socket::Endpoint::Inode(inode1.clone()))?;
// inode1.connect(socket::Endpoint::Inode(inode0.clone()))?;
// }
// 创建一对新的unix socket pair
let (inode0, inode1) = Unix::new_pairs(stype)?;
fds[0] = fd_table_guard.alloc_fd(File::new(inode0, FileMode::O_RDWR)?, None)?; fds[0] = fd_table_guard.alloc_fd(File::new(inode0, FileMode::O_RDWR)?, None)?;
fds[1] = fd_table_guard.alloc_fd(File::new(inode1, FileMode::O_RDWR)?, None)?; fds[1] = fd_table_guard.alloc_fd(File::new(inode1, FileMode::O_RDWR)?, None)?;
@ -108,12 +142,12 @@ impl Syscall {
optname: usize, optname: usize,
optval: &[u8], optval: &[u8],
) -> Result<usize, SystemError> { ) -> Result<usize, SystemError> {
let socket_inode: Arc<SocketInode> = ProcessManager::current_pcb() let sol = socket::OptionsLevel::try_from(level as u32)?;
let socket: Arc<socket::Inode> = ProcessManager::current_pcb()
.get_socket(fd as i32) .get_socket(fd as i32)
.ok_or(SystemError::EBADF)?; .ok_or(SystemError::EBADF)?;
// 获取内层的socket真正的数据 debug!("setsockopt: level={:?}", level);
let socket: SpinLockGuard<Box<dyn Socket>> = socket_inode.inner(); return socket.set_option(sol, optname, optval).map(|_| 0);
return socket.setsockopt(level, optname, optval).map(|_| 0);
} }
/// @brief sys_getsockopt系统调用的实际执行函数 /// @brief sys_getsockopt系统调用的实际执行函数
@ -134,33 +168,35 @@ impl Syscall {
) -> Result<usize, SystemError> { ) -> Result<usize, SystemError> {
// 获取socket // 获取socket
let optval = optval as *mut u32; let optval = optval as *mut u32;
let binding: Arc<SocketInode> = ProcessManager::current_pcb() let socket: Arc<socket::Inode> = ProcessManager::current_pcb()
.get_socket(fd as i32) .get_socket(fd as i32)
.ok_or(SystemError::EBADF)?; .ok_or(EBADF)?;
let socket = binding.inner();
if level as u8 == SOL_SOCKET { let level = socket::OptionsLevel::try_from(level as u32)?;
let optname = PosixSocketOption::try_from(optname as i32)
.map_err(|_| SystemError::ENOPROTOOPT)?; use socket::Options as SO;
use socket::OptionsLevel as SOL;
if matches!(level, SOL::SOCKET) {
let optname = SO::try_from(optname as u32).map_err(|_| ENOPROTOOPT)?;
match optname { match optname {
PosixSocketOption::SO_SNDBUF => { SO::SNDBUF => {
// 返回发送缓冲区大小 // 返回发送缓冲区大小
unsafe { unsafe {
*optval = socket.metadata().tx_buf_size as u32; *optval = socket.send_buffer_size() as u32;
*optlen = core::mem::size_of::<u32>() as u32; *optlen = core::mem::size_of::<u32>() as u32;
} }
return Ok(0); return Ok(0);
} }
PosixSocketOption::SO_RCVBUF => { SO::RCVBUF => {
// 返回默认的接收缓冲区大小 // 返回默认的接收缓冲区大小
unsafe { unsafe {
*optval = socket.metadata().rx_buf_size as u32; *optval = socket.recv_buffer_size() as u32;
*optlen = core::mem::size_of::<u32>() as u32; *optlen = core::mem::size_of::<u32>() as u32;
} }
return Ok(0); return Ok(0);
} }
_ => { _ => {
return Err(SystemError::ENOPROTOOPT); return Err(ENOPROTOOPT);
} }
} }
} }
@ -172,19 +208,17 @@ impl Syscall {
// to be interpreted by the TCP protocol, level should be set to the // to be interpreted by the TCP protocol, level should be set to the
// protocol number of TCP. // protocol number of TCP.
let posix_protocol = if matches!(level, SOL::TCP) {
PosixIpProtocol::try_from(level as u16).map_err(|_| SystemError::ENOPROTOOPT)?; let optname =
if posix_protocol == PosixIpProtocol::TCP { PosixTcpSocketOptions::try_from(optname as i32).map_err(|_| ENOPROTOOPT)?;
let optname = PosixTcpSocketOptions::try_from(optname as i32)
.map_err(|_| SystemError::ENOPROTOOPT)?;
match optname { match optname {
PosixTcpSocketOptions::Congestion => return Ok(0), PosixTcpSocketOptions::Congestion => return Ok(0),
_ => { _ => {
return Err(SystemError::ENOPROTOOPT); return Err(ENOPROTOOPT);
} }
} }
} }
return Err(SystemError::ENOPROTOOPT); return Err(ENOPROTOOPT);
} }
/// @brief sys_connect系统调用的实际执行函数 /// @brief sys_connect系统调用的实际执行函数
@ -194,12 +228,11 @@ impl Syscall {
/// @param addrlen 地址长度 /// @param addrlen 地址长度
/// ///
/// @return 成功返回0失败返回错误码 /// @return 成功返回0失败返回错误码
pub fn connect(fd: usize, addr: *const SockAddr, addrlen: usize) -> Result<usize, SystemError> { pub fn connect(fd: usize, addr: *const SockAddr, addrlen: u32) -> Result<usize, SystemError> {
let endpoint: Endpoint = SockAddr::to_endpoint(addr, addrlen)?; let endpoint: Endpoint = SockAddr::to_endpoint(addr, addrlen)?;
let socket: Arc<SocketInode> = ProcessManager::current_pcb() let socket: Arc<socket::Inode> = ProcessManager::current_pcb()
.get_socket(fd as i32) .get_socket(fd as i32)
.ok_or(SystemError::EBADF)?; .ok_or(SystemError::EBADF)?;
let mut socket = unsafe { socket.inner_no_preempt() };
socket.connect(endpoint)?; socket.connect(endpoint)?;
Ok(0) Ok(0)
} }
@ -211,12 +244,19 @@ impl Syscall {
/// @param addrlen 地址长度 /// @param addrlen 地址长度
/// ///
/// @return 成功返回0失败返回错误码 /// @return 成功返回0失败返回错误码
pub fn bind(fd: usize, addr: *const SockAddr, addrlen: usize) -> Result<usize, SystemError> { pub fn bind(fd: usize, addr: *const SockAddr, addrlen: u32) -> Result<usize, SystemError> {
// 打印收到的参数
// log::debug!(
// "bind: fd={:?}, family={:?}, addrlen={:?}",
// fd,
// (unsafe { addr.as_ref().unwrap().family }),
// addrlen
// );
let endpoint: Endpoint = SockAddr::to_endpoint(addr, addrlen)?; let endpoint: Endpoint = SockAddr::to_endpoint(addr, addrlen)?;
let socket: Arc<SocketInode> = ProcessManager::current_pcb() let socket: Arc<socket::Inode> = ProcessManager::current_pcb()
.get_socket(fd as i32) .get_socket(fd as i32)
.ok_or(SystemError::EBADF)?; .ok_or(SystemError::EBADF)?;
let mut socket = unsafe { socket.inner_no_preempt() }; log::debug!("bind: socket={:?}", socket);
socket.bind(endpoint)?; socket.bind(endpoint)?;
Ok(0) Ok(0)
} }
@ -233,9 +273,9 @@ impl Syscall {
pub fn sendto( pub fn sendto(
fd: usize, fd: usize,
buf: &[u8], buf: &[u8],
_flags: u32, flags: u32,
addr: *const SockAddr, addr: *const SockAddr,
addrlen: usize, addrlen: u32,
) -> Result<usize, SystemError> { ) -> Result<usize, SystemError> {
let endpoint = if addr.is_null() { let endpoint = if addr.is_null() {
None None
@ -243,11 +283,17 @@ impl Syscall {
Some(SockAddr::to_endpoint(addr, addrlen)?) Some(SockAddr::to_endpoint(addr, addrlen)?)
}; };
let socket: Arc<SocketInode> = ProcessManager::current_pcb() let flags = socket::MessageFlag::from_bits_truncate(flags);
let socket: Arc<socket::Inode> = ProcessManager::current_pcb()
.get_socket(fd as i32) .get_socket(fd as i32)
.ok_or(SystemError::EBADF)?; .ok_or(SystemError::EBADF)?;
let socket = unsafe { socket.inner_no_preempt() };
return socket.write(buf, endpoint); if let Some(endpoint) = endpoint {
return socket.send_to(buf, endpoint, flags);
} else {
return socket.send(buf, flags);
}
} }
/// @brief sys_recvfrom系统调用的实际执行函数 /// @brief sys_recvfrom系统调用的实际执行函数
@ -262,28 +308,37 @@ impl Syscall {
pub fn recvfrom( pub fn recvfrom(
fd: usize, fd: usize,
buf: &mut [u8], buf: &mut [u8],
_flags: u32, flags: u32,
addr: *mut SockAddr, addr: *mut SockAddr,
addrlen: *mut u32, addr_len: *mut u32,
) -> Result<usize, SystemError> { ) -> Result<usize, SystemError> {
let socket: Arc<SocketInode> = ProcessManager::current_pcb() let socket: Arc<socket::Inode> = ProcessManager::current_pcb()
.get_socket(fd as i32) .get_socket(fd as i32)
.ok_or(SystemError::EBADF)?; .ok_or(SystemError::EBADF)?;
let socket = unsafe { socket.inner_no_preempt() }; let flags = socket::MessageFlag::from_bits_truncate(flags);
let (n, endpoint) = socket.read(buf); if addr.is_null() {
drop(socket); let (n, _) = socket.recv_from(buf, flags, None)?;
return Ok(n);
}
let n: usize = n?; // address is not null
let address = unsafe { addr.as_ref() }.ok_or(EINVAL)?;
// 如果有地址信息,将地址信息写入用户空间 if unsafe { address.is_empty() } {
if !addr.is_null() { let (recv_len, endpoint) = socket.recv_from(buf, flags, None)?;
let sockaddr_in = SockAddr::from(endpoint); let sockaddr_in = SockAddr::from(endpoint);
unsafe { unsafe {
sockaddr_in.write_to_user(addr, addrlen)?; sockaddr_in.write_to_user(addr, addr_len)?;
} }
} return Ok(recv_len);
return Ok(n); } else {
// 从socket中读取数据
let addr_len = *unsafe { addr_len.as_ref() }.ok_or(EINVAL)?;
let address = SockAddr::to_endpoint(addr, addr_len)?;
let (recv_len, _) = socket.recv_from(buf, flags, Some(address))?;
return Ok(recv_len);
};
} }
/// @brief sys_recvmsg系统调用的实际执行函数 /// @brief sys_recvmsg系统调用的实际执行函数
@ -293,30 +348,30 @@ impl Syscall {
/// @param flags 标志,暂时未使用 /// @param flags 标志,暂时未使用
/// ///
/// @return 成功返回接收的字节数,失败返回错误码 /// @return 成功返回接收的字节数,失败返回错误码
pub fn recvmsg(fd: usize, msg: &mut MsgHdr, _flags: u32) -> Result<usize, SystemError> { pub fn recvmsg(fd: usize, msg: &mut MsgHdr, flags: u32) -> Result<usize, SystemError> {
// 检查每个缓冲区地址是否合法生成iovecs todo!("recvmsg, fd={}, msg={:?}, flags={}", fd, msg, flags);
let mut iovs = unsafe { IoVecs::from_user(msg.msg_iov, msg.msg_iovlen, true)? }; // // 检查每个缓冲区地址是否合法生成iovecs
// let mut iovs = unsafe { IoVecs::from_user(msg.msg_iov, msg.msg_iovlen, true)? };
let socket: Arc<SocketInode> = ProcessManager::current_pcb() // let socket: Arc<socket::Inode> = ProcessManager::current_pcb()
.get_socket(fd as i32) // .get_socket(fd as i32)
.ok_or(SystemError::EBADF)?; // .ok_or(SystemError::EBADF)?;
let socket = unsafe { socket.inner_no_preempt() };
let mut buf = iovs.new_buf(true); // let flags = socket::MessageFlag::from_bits_truncate(flags as u32);
// 从socket中读取数据
let (n, endpoint) = socket.read(&mut buf);
drop(socket);
let n: usize = n?; // let mut buf = iovs.new_buf(true);
// // 从socket中读取数据
// let recv_size = socket.recv_msg(&mut buf, flags)?;
// drop(socket);
// 将数据写入用户空间的iovecs // // 将数据写入用户空间的iovecs
iovs.scatter(&buf[..n]); // iovs.scatter(&buf[..recv_size]);
let sockaddr_in = SockAddr::from(endpoint); // // let sockaddr_in = SockAddr::from(endpoint);
unsafe { // // unsafe {
sockaddr_in.write_to_user(msg.msg_name, &mut msg.msg_namelen)?; // // sockaddr_in.write_to_user(msg.msg_name, &mut msg.msg_namelen)?;
} // // }
return Ok(n); // return Ok(recv_size);
} }
/// @brief sys_listen系统调用的实际执行函数 /// @brief sys_listen系统调用的实际执行函数
@ -326,12 +381,10 @@ impl Syscall {
/// ///
/// @return 成功返回0失败返回错误码 /// @return 成功返回0失败返回错误码
pub fn listen(fd: usize, backlog: usize) -> Result<usize, SystemError> { pub fn listen(fd: usize, backlog: usize) -> Result<usize, SystemError> {
let socket: Arc<SocketInode> = ProcessManager::current_pcb() let socket: Arc<socket::Inode> = ProcessManager::current_pcb()
.get_socket(fd as i32) .get_socket(fd as i32)
.ok_or(SystemError::EBADF)?; .ok_or(SystemError::EBADF)?;
let mut socket = unsafe { socket.inner_no_preempt() }; socket.listen(backlog).map(|_| 0)
socket.listen(backlog)?;
return Ok(0);
} }
/// @brief sys_shutdown系统调用的实际执行函数 /// @brief sys_shutdown系统调用的实际执行函数
@ -341,11 +394,10 @@ impl Syscall {
/// ///
/// @return 成功返回0失败返回错误码 /// @return 成功返回0失败返回错误码
pub fn shutdown(fd: usize, how: usize) -> Result<usize, SystemError> { pub fn shutdown(fd: usize, how: usize) -> Result<usize, SystemError> {
let socket: Arc<SocketInode> = ProcessManager::current_pcb() let socket: Arc<socket::Inode> = ProcessManager::current_pcb()
.get_socket(fd as i32) .get_socket(fd as i32)
.ok_or(SystemError::EBADF)?; .ok_or(SystemError::EBADF)?;
let mut socket = unsafe { socket.inner_no_preempt() }; socket.shutdown(socket::ShutdownTemp::from_how(how))?;
socket.shutdown(ShutdownType::from_bits_truncate((how + 1) as u8))?;
return Ok(0); return Ok(0);
} }
@ -401,18 +453,16 @@ impl Syscall {
addrlen: *mut u32, addrlen: *mut u32,
flags: u32, flags: u32,
) -> Result<usize, SystemError> { ) -> Result<usize, SystemError> {
let socket: Arc<SocketInode> = ProcessManager::current_pcb() let socket: Arc<socket::Inode> = ProcessManager::current_pcb()
.get_socket(fd as i32) .get_socket(fd as i32)
.ok_or(SystemError::EBADF)?; .ok_or(SystemError::EBADF)?;
// debug!("accept: socket={:?}", socket);
let mut socket = unsafe { socket.inner_no_preempt() };
// 从socket中接收连接 // 从socket中接收连接
let (new_socket, remote_endpoint) = socket.accept()?; let (new_socket, remote_endpoint) = socket.accept()?;
drop(socket); drop(socket);
// debug!("accept: new_socket={:?}", new_socket); // debug!("accept: new_socket={:?}", new_socket);
// Insert the new socket into the file descriptor vector // Insert the new socket into the file descriptor vector
let new_socket: Arc<SocketInode> = SocketInode::new(new_socket);
let mut file_mode = FileMode::O_RDWR; let mut file_mode = FileMode::O_RDWR;
if flags & SOCK_NONBLOCK.bits() != 0 { if flags & SOCK_NONBLOCK.bits() != 0 {
@ -456,12 +506,10 @@ impl Syscall {
if addr.is_null() { if addr.is_null() {
return Err(SystemError::EINVAL); return Err(SystemError::EINVAL);
} }
let socket: Arc<SocketInode> = ProcessManager::current_pcb() let endpoint = ProcessManager::current_pcb()
.get_socket(fd as i32) .get_socket(fd as i32)
.ok_or(SystemError::EBADF)?; .ok_or(SystemError::EBADF)?
let socket = socket.inner(); .get_name()?;
let endpoint: Endpoint = socket.endpoint().ok_or(SystemError::EINVAL)?;
drop(socket);
let sockaddr_in = SockAddr::from(endpoint); let sockaddr_in = SockAddr::from(endpoint);
unsafe { unsafe {
@ -486,11 +534,11 @@ impl Syscall {
return Err(SystemError::EINVAL); return Err(SystemError::EINVAL);
} }
let socket: Arc<SocketInode> = ProcessManager::current_pcb() let socket: Arc<socket::Inode> = ProcessManager::current_pcb()
.get_socket(fd as i32) .get_socket(fd as i32)
.ok_or(SystemError::EBADF)?; .ok_or(SystemError::EBADF)?;
let socket = socket.inner();
let endpoint: Endpoint = socket.peer_endpoint().ok_or(SystemError::EINVAL)?; let endpoint: Endpoint = socket.get_peer_name()?;
drop(socket); drop(socket);
let sockaddr_in = SockAddr::from(endpoint); let sockaddr_in = SockAddr::from(endpoint);
@ -501,449 +549,6 @@ impl Syscall {
} }
} }
// 参考资料: https://pubs.opengroup.org/onlinepubs/9699919799/basedefs/netinet_in.h.html#tag_13_32
#[repr(C)]
#[derive(Debug, Clone, Copy)]
pub struct SockAddrIn {
pub sin_family: u16,
pub sin_port: u16,
pub sin_addr: u32,
pub sin_zero: [u8; 8],
}
#[repr(C)]
#[derive(Debug, Clone, Copy)]
pub struct SockAddrUn {
pub sun_family: u16,
pub sun_path: [u8; 108],
}
#[repr(C)]
#[derive(Debug, Clone, Copy)]
pub struct SockAddrLl {
pub sll_family: u16,
pub sll_protocol: u16,
pub sll_ifindex: u32,
pub sll_hatype: u16,
pub sll_pkttype: u8,
pub sll_halen: u8,
pub sll_addr: [u8; 8],
}
#[repr(C)]
#[derive(Debug, Clone, Copy)]
pub struct SockAddrNl {
nl_family: u16,
nl_pad: u16,
nl_pid: u32,
nl_groups: u32,
}
#[repr(C)]
#[derive(Debug, Clone, Copy)]
pub struct SockAddrPlaceholder {
pub family: u16,
pub data: [u8; 14],
}
#[repr(C)]
#[derive(Clone, Copy)]
pub union SockAddr {
pub family: u16,
pub addr_in: SockAddrIn,
pub addr_un: SockAddrUn,
pub addr_ll: SockAddrLl,
pub addr_nl: SockAddrNl,
pub addr_ph: SockAddrPlaceholder,
}
impl SockAddr {
/// @brief 把用户传入的SockAddr转换为Endpoint结构体
pub fn to_endpoint(addr: *const SockAddr, len: usize) -> Result<Endpoint, SystemError> {
verify_area(
VirtAddr::new(addr as usize),
core::mem::size_of::<SockAddr>(),
)
.map_err(|_| SystemError::EFAULT)?;
let addr = unsafe { addr.as_ref() }.ok_or(SystemError::EFAULT)?;
unsafe {
match AddressFamily::try_from(addr.family)? {
AddressFamily::INet => {
if len < addr.len()? {
return Err(SystemError::EINVAL);
}
let addr_in: SockAddrIn = addr.addr_in;
let ip: wire::IpAddress = wire::IpAddress::from(wire::Ipv4Address::from_bytes(
&u32::from_be(addr_in.sin_addr).to_be_bytes()[..],
));
let port = u16::from_be(addr_in.sin_port);
return Ok(Endpoint::Ip(Some(wire::IpEndpoint::new(ip, port))));
}
AddressFamily::Unix => {
let addr_un: SockAddrUn = addr.addr_un;
let path = CStr::from_bytes_until_nul(&addr_un.sun_path)
.map_err(|_| SystemError::EINVAL)?
.to_str()
.map_err(|_| SystemError::EINVAL)?;
let fd = Syscall::open(path.as_ptr(), FileMode::O_RDWR.bits(), 0o755, true)?;
let binding = ProcessManager::current_pcb().fd_table();
let fd_table_guard = binding.read();
let file = fd_table_guard.get_file_by_fd(fd as i32).unwrap();
if file.file_type() != FileType::Socket {
return Err(SystemError::ENOTSOCK);
}
let inode = file.inode();
let socketinode = inode.as_any_ref().downcast_ref::<Arc<SocketInode>>();
return Ok(Endpoint::Inode(socketinode.cloned()));
}
AddressFamily::Packet => {
// TODO: support packet socket
return Err(SystemError::EINVAL);
}
AddressFamily::Netlink => {
// TODO: support netlink socket
return Err(SystemError::EINVAL);
}
_ => {
return Err(SystemError::EINVAL);
}
}
}
}
/// @brief 获取地址长度
pub fn len(&self) -> Result<usize, SystemError> {
let ret = match AddressFamily::try_from(unsafe { self.family })? {
AddressFamily::INet => Ok(core::mem::size_of::<SockAddrIn>()),
AddressFamily::Packet => Ok(core::mem::size_of::<SockAddrLl>()),
AddressFamily::Netlink => Ok(core::mem::size_of::<SockAddrNl>()),
AddressFamily::Unix => Err(SystemError::EINVAL),
_ => Err(SystemError::EINVAL),
};
return ret;
}
/// @brief 把SockAddr的数据写入用户空间
///
/// @param addr 用户空间的SockAddr的地址
/// @param len 要写入的长度
///
/// @return 成功返回写入的长度,失败返回错误码
pub unsafe fn write_to_user(
&self,
addr: *mut SockAddr,
addr_len: *mut u32,
) -> Result<usize, SystemError> {
// 当用户传入的地址或者长度为空时直接返回0
if addr.is_null() || addr_len.is_null() {
return Ok(0);
}
// 检查用户传入的地址是否合法
verify_area(
VirtAddr::new(addr as usize),
core::mem::size_of::<SockAddr>(),
)
.map_err(|_| SystemError::EFAULT)?;
verify_area(
VirtAddr::new(addr_len as usize),
core::mem::size_of::<u32>(),
)
.map_err(|_| SystemError::EFAULT)?;
let to_write = min(self.len()?, *addr_len as usize);
if to_write > 0 {
let buf = core::slice::from_raw_parts_mut(addr as *mut u8, to_write);
buf.copy_from_slice(core::slice::from_raw_parts(
self as *const SockAddr as *const u8,
to_write,
));
}
*addr_len = self.len()? as u32;
return Ok(to_write);
}
}
impl From<Endpoint> for SockAddr {
fn from(value: Endpoint) -> Self {
match value {
Endpoint::Ip(ip_endpoint) => {
// 未指定地址
if ip_endpoint.is_none() {
return SockAddr {
addr_ph: SockAddrPlaceholder {
family: AddressFamily::Unspecified as u16,
data: [0; 14],
},
};
}
// 指定了地址
let ip_endpoint = ip_endpoint.unwrap();
match ip_endpoint.addr {
wire::IpAddress::Ipv4(ipv4_addr) => {
let addr_in = SockAddrIn {
sin_family: AddressFamily::INet as u16,
sin_port: ip_endpoint.port.to_be(),
sin_addr: u32::from_be_bytes(ipv4_addr.0).to_be(),
sin_zero: [0; 8],
};
return SockAddr { addr_in };
}
_ => {
unimplemented!("not support ipv6");
}
}
}
Endpoint::LinkLayer(link_endpoint) => {
let addr_ll = SockAddrLl {
sll_family: AddressFamily::Packet as u16,
sll_protocol: 0,
sll_ifindex: link_endpoint.interface as u32,
sll_hatype: 0,
sll_pkttype: 0,
sll_halen: 0,
sll_addr: [0; 8],
};
return SockAddr { addr_ll };
}
_ => {
// todo: support other endpoint, like Netlink...
unimplemented!("not support {value:?}");
}
}
}
}
#[repr(C)]
#[derive(Debug, Clone, Copy)]
pub struct MsgHdr {
/// 指向一个SockAddr结构体的指针
pub msg_name: *mut SockAddr,
/// SockAddr结构体的大小
pub msg_namelen: u32,
/// scatter/gather array
pub msg_iov: *mut IoVec,
/// elements in msg_iov
pub msg_iovlen: usize,
/// 辅助数据
pub msg_control: *mut u8,
/// 辅助数据长度
pub msg_controllen: usize,
/// 接收到的消息的标志
pub msg_flags: u32,
}
#[derive(Debug, Clone, Copy, FromPrimitive, ToPrimitive, PartialEq, Eq)]
pub enum PosixIpProtocol {
/// Dummy protocol for TCP.
IP = 0,
/// Internet Control Message Protocol.
ICMP = 1,
/// Internet Group Management Protocol.
IGMP = 2,
/// IPIP tunnels (older KA9Q tunnels use 94).
IPIP = 4,
/// Transmission Control Protocol.
TCP = 6,
/// Exterior Gateway Protocol.
EGP = 8,
/// PUP protocol.
PUP = 12,
/// User Datagram Protocol.
UDP = 17,
/// XNS IDP protocol.
IDP = 22,
/// SO Transport Protocol Class 4.
TP = 29,
/// Datagram Congestion Control Protocol.
DCCP = 33,
/// IPv6-in-IPv4 tunnelling.
IPv6 = 41,
/// RSVP Protocol.
RSVP = 46,
/// Generic Routing Encapsulation. (Cisco GRE) (rfc 1701, 1702)
GRE = 47,
/// Encapsulation Security Payload protocol
ESP = 50,
/// Authentication Header protocol
AH = 51,
/// Multicast Transport Protocol.
MTP = 92,
/// IP option pseudo header for BEET
BEETPH = 94,
/// Encapsulation Header.
ENCAP = 98,
/// Protocol Independent Multicast.
PIM = 103,
/// Compression Header Protocol.
COMP = 108,
/// Stream Control Transport Protocol
SCTP = 132,
/// UDP-Lite protocol (RFC 3828)
UDPLITE = 136,
/// MPLS in IP (RFC 4023)
MPLSINIP = 137,
/// Ethernet-within-IPv6 Encapsulation
ETHERNET = 143,
/// Raw IP packets
RAW = 255,
/// Multipath TCP connection
MPTCP = 262,
}
impl TryFrom<u16> for PosixIpProtocol {
type Error = SystemError;
fn try_from(value: u16) -> Result<Self, Self::Error> {
match <Self as FromPrimitive>::from_u16(value) {
Some(p) => Ok(p),
None => Err(SystemError::EPROTONOSUPPORT),
}
}
}
impl From<PosixIpProtocol> for u16 {
fn from(value: PosixIpProtocol) -> Self {
<PosixIpProtocol as ToPrimitive>::to_u16(&value).unwrap()
}
}
#[allow(non_camel_case_types)]
#[derive(Debug, Clone, Copy, FromPrimitive, ToPrimitive, PartialEq, Eq)]
pub enum PosixSocketOption {
SO_DEBUG = 1,
SO_REUSEADDR = 2,
SO_TYPE = 3,
SO_ERROR = 4,
SO_DONTROUTE = 5,
SO_BROADCAST = 6,
SO_SNDBUF = 7,
SO_RCVBUF = 8,
SO_SNDBUFFORCE = 32,
SO_RCVBUFFORCE = 33,
SO_KEEPALIVE = 9,
SO_OOBINLINE = 10,
SO_NO_CHECK = 11,
SO_PRIORITY = 12,
SO_LINGER = 13,
SO_BSDCOMPAT = 14,
SO_REUSEPORT = 15,
SO_PASSCRED = 16,
SO_PEERCRED = 17,
SO_RCVLOWAT = 18,
SO_SNDLOWAT = 19,
SO_RCVTIMEO_OLD = 20,
SO_SNDTIMEO_OLD = 21,
SO_SECURITY_AUTHENTICATION = 22,
SO_SECURITY_ENCRYPTION_TRANSPORT = 23,
SO_SECURITY_ENCRYPTION_NETWORK = 24,
SO_BINDTODEVICE = 25,
/// 与SO_GET_FILTER相同
SO_ATTACH_FILTER = 26,
SO_DETACH_FILTER = 27,
SO_PEERNAME = 28,
SO_ACCEPTCONN = 30,
SO_PEERSEC = 31,
SO_PASSSEC = 34,
SO_MARK = 36,
SO_PROTOCOL = 38,
SO_DOMAIN = 39,
SO_RXQ_OVFL = 40,
/// 与SCM_WIFI_STATUS相同
SO_WIFI_STATUS = 41,
SO_PEEK_OFF = 42,
/* Instruct lower device to use last 4-bytes of skb data as FCS */
SO_NOFCS = 43,
SO_LOCK_FILTER = 44,
SO_SELECT_ERR_QUEUE = 45,
SO_BUSY_POLL = 46,
SO_MAX_PACING_RATE = 47,
SO_BPF_EXTENSIONS = 48,
SO_INCOMING_CPU = 49,
SO_ATTACH_BPF = 50,
// SO_DETACH_BPF = SO_DETACH_FILTER,
SO_ATTACH_REUSEPORT_CBPF = 51,
SO_ATTACH_REUSEPORT_EBPF = 52,
SO_CNX_ADVICE = 53,
SCM_TIMESTAMPING_OPT_STATS = 54,
SO_MEMINFO = 55,
SO_INCOMING_NAPI_ID = 56,
SO_COOKIE = 57,
SCM_TIMESTAMPING_PKTINFO = 58,
SO_PEERGROUPS = 59,
SO_ZEROCOPY = 60,
/// 与SCM_TXTIME相同
SO_TXTIME = 61,
SO_BINDTOIFINDEX = 62,
SO_TIMESTAMP_OLD = 29,
SO_TIMESTAMPNS_OLD = 35,
SO_TIMESTAMPING_OLD = 37,
SO_TIMESTAMP_NEW = 63,
SO_TIMESTAMPNS_NEW = 64,
SO_TIMESTAMPING_NEW = 65,
SO_RCVTIMEO_NEW = 66,
SO_SNDTIMEO_NEW = 67,
SO_DETACH_REUSEPORT_BPF = 68,
SO_PREFER_BUSY_POLL = 69,
SO_BUSY_POLL_BUDGET = 70,
SO_NETNS_COOKIE = 71,
SO_BUF_LOCK = 72,
SO_RESERVE_MEM = 73,
SO_TXREHASH = 74,
SO_RCVMARK = 75,
}
impl TryFrom<i32> for PosixSocketOption {
type Error = SystemError;
fn try_from(value: i32) -> Result<Self, Self::Error> {
match <Self as FromPrimitive>::from_i32(value) {
Some(p) => Ok(p),
None => Err(SystemError::EINVAL),
}
}
}
impl From<PosixSocketOption> for i32 {
fn from(value: PosixSocketOption) -> Self {
<PosixSocketOption as ToPrimitive>::to_i32(&value).unwrap()
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, FromPrimitive, ToPrimitive)] #[derive(Debug, Clone, Copy, PartialEq, Eq, FromPrimitive, ToPrimitive)]
pub enum PosixTcpSocketOptions { pub enum PosixTcpSocketOptions {
/// Turn off Nagle's algorithm. /// Turn off Nagle's algorithm.

View File

@ -0,0 +1,347 @@
bitflags::bitflags! {
// #[derive(PartialEq, Eq, Debug, Clone, Copy)]
pub struct SysArgSocketType: u32 {
const DGRAM = 1; // 0b0000_0001
const STREAM = 2; // 0b0000_0010
const RAW = 3; // 0b0000_0011
const RDM = 4; // 0b0000_0100
const SEQPACKET = 5; // 0b0000_0101
const DCCP = 6; // 0b0000_0110
const PACKET = 10; // 0b0000_1010
const NONBLOCK = crate::filesystem::vfs::file::FileMode::O_NONBLOCK.bits();
const CLOEXEC = crate::filesystem::vfs::file::FileMode::O_CLOEXEC.bits();
}
}
impl SysArgSocketType {
#[inline(always)]
pub fn types(&self) -> SysArgSocketType {
SysArgSocketType::from_bits(self.bits() & 0b_1111).unwrap()
}
#[inline(always)]
pub fn is_nonblock(&self) -> bool {
self.contains(SysArgSocketType::NONBLOCK)
}
#[inline(always)]
pub fn is_cloexec(&self) -> bool {
self.contains(SysArgSocketType::CLOEXEC)
}
}
use alloc::sync::Arc;
use core::ffi::CStr;
use unix::INODE_MAP;
use crate::{
filesystem::vfs::{
file::FileMode, FileType, IndexNode, MAX_PATHLEN, ROOT_INODE, VFS_MAX_FOLLOW_SYMLINK_TIMES,
},
libs::casting::DowncastArc,
mm::{verify_area, VirtAddr},
net::socket::{self, *},
process::ProcessManager,
syscall::Syscall,
};
use smoltcp;
use system_error::SystemError::{self, *};
// 参考资料: https://pubs.opengroup.org/onlinepubs/9699919799/basedefs/netinet_in.h.html#tag_13_32
#[repr(C)]
#[derive(Debug, Clone, Copy)]
pub struct SockAddrIn {
pub sin_family: u16,
pub sin_port: u16,
pub sin_addr: u32,
pub sin_zero: [u8; 8],
}
#[repr(C)]
#[derive(Debug, Clone, Copy)]
pub struct SockAddrUn {
pub sun_family: u16,
pub sun_path: [u8; 108],
}
#[repr(C)]
#[derive(Debug, Clone, Copy)]
pub struct SockAddrLl {
pub sll_family: u16,
pub sll_protocol: u16,
pub sll_ifindex: u32,
pub sll_hatype: u16,
pub sll_pkttype: u8,
pub sll_halen: u8,
pub sll_addr: [u8; 8],
}
#[repr(C)]
#[derive(Debug, Clone, Copy)]
pub struct SockAddrNl {
pub nl_family: AddressFamily,
pub nl_pad: u16,
pub nl_pid: u32,
pub nl_groups: u32,
}
#[repr(C)]
#[derive(Debug, Clone, Copy)]
pub struct SockAddrPlaceholder {
pub family: u16,
pub data: [u8; 14],
}
#[repr(C)]
#[derive(Clone, Copy)]
pub union SockAddr {
pub family: u16,
pub addr_in: SockAddrIn,
pub addr_un: SockAddrUn,
pub addr_ll: SockAddrLl,
pub addr_nl: SockAddrNl,
pub addr_ph: SockAddrPlaceholder,
}
impl SockAddr {
/// @brief 把用户传入的SockAddr转换为Endpoint结构体
pub fn to_endpoint(addr: *const SockAddr, len: u32) -> Result<Endpoint, SystemError> {
use crate::net::socket::AddressFamily;
let addr = unsafe { addr.as_ref() }.ok_or(SystemError::EFAULT)?;
unsafe {
match AddressFamily::try_from(addr.family)? {
AddressFamily::INet => {
if len < addr.len()? {
log::error!("len < addr.len()");
return Err(SystemError::EINVAL);
}
let addr_in: SockAddrIn = addr.addr_in;
use smoltcp::wire;
let ip: wire::IpAddress = wire::IpAddress::from(wire::Ipv4Address::from_bytes(
&u32::from_be(addr_in.sin_addr).to_be_bytes()[..],
));
let port = u16::from_be(addr_in.sin_port);
return Ok(Endpoint::Ip(wire::IpEndpoint::new(ip, port)));
}
AddressFamily::Unix => {
let addr_un: SockAddrUn = addr.addr_un;
let path = CStr::from_bytes_until_nul(&addr_un.sun_path)
.map_err(|_| {
log::error!("CStr::from_bytes_until_nul fail");
SystemError::EINVAL
})?
.to_str()
.map_err(|_| {
log::error!("CStr::to_str fail");
SystemError::EINVAL
})?;
let (inode_begin, path) = crate::filesystem::vfs::utils::user_path_at(
&ProcessManager::current_pcb(),
crate::filesystem::vfs::fcntl::AtFlags::AT_FDCWD.bits(),
path.trim(),
)?;
let inode0: Result<Arc<dyn IndexNode>, SystemError> =
inode_begin.lookup_follow_symlink(&path, VFS_MAX_FOLLOW_SYMLINK_TIMES);
let inode = match inode0 {
Ok(inode) => inode,
Err(_) => {
let (filename, parent_path) =
crate::filesystem::vfs::utils::rsplit_path(&path);
// 查找父目录
log::debug!("filename {:?} parent_path {:?}", filename, parent_path);
let parent_inode: Arc<dyn IndexNode> =
ROOT_INODE().lookup(parent_path.unwrap_or("/"))?;
// 创建文件
let inode: Arc<dyn IndexNode> = match parent_inode.create(
filename,
FileType::File,
crate::filesystem::vfs::syscall::ModeType::from_bits_truncate(
0o755,
),
) {
Ok(inode) => inode,
Err(e) => {
log::debug!("inode create fail {:?}", e);
return Err(e);
}
};
inode
}
};
return Ok(Endpoint::Unixpath((inode.metadata()?.inode_id, path)));
}
AddressFamily::Packet => {
// TODO: support packet socket
log::warn!("not support address family {:?}", addr.family);
return Err(SystemError::EINVAL);
}
AddressFamily::Netlink => {
// TODO: support netlink socket
let addr: SockAddrNl = addr.addr_nl;
return Ok(Endpoint::Netlink(NetlinkEndpoint::new(addr)));
}
_ => {
log::warn!("not support address family {:?}", addr.family);
return Err(SystemError::EINVAL);
}
}
}
}
/// @brief 获取地址长度
pub fn len(&self) -> Result<u32, SystemError> {
match AddressFamily::try_from(unsafe { self.family })? {
AddressFamily::INet => Ok(core::mem::size_of::<SockAddrIn>()),
AddressFamily::Packet => Ok(core::mem::size_of::<SockAddrLl>()),
AddressFamily::Netlink => Ok(core::mem::size_of::<SockAddrNl>()),
AddressFamily::Unix => Ok(core::mem::size_of::<SockAddrUn>()),
_ => Err(SystemError::EINVAL),
}
.map(|x| x as u32)
}
/// @brief 把SockAddr的数据写入用户空间
///
/// @param addr 用户空间的SockAddr的地址
/// @param len 要写入的长度
///
/// @return 成功返回写入的长度,失败返回错误码
pub unsafe fn write_to_user(
&self,
addr: *mut SockAddr,
addr_len: *mut u32,
) -> Result<u32, SystemError> {
// 当用户传入的地址或者长度为空时直接返回0
if addr.is_null() || addr_len.is_null() {
return Ok(0);
}
// 检查用户传入的地址是否合法
verify_area(
VirtAddr::new(addr as usize),
core::mem::size_of::<SockAddr>(),
)
.map_err(|_| SystemError::EFAULT)?;
verify_area(
VirtAddr::new(addr_len as usize),
core::mem::size_of::<u32>(),
)
.map_err(|_| SystemError::EFAULT)?;
let to_write = core::cmp::min(self.len()?, *addr_len);
if to_write > 0 {
let buf = core::slice::from_raw_parts_mut(addr as *mut u8, to_write as usize);
buf.copy_from_slice(core::slice::from_raw_parts(
self as *const SockAddr as *const u8,
to_write as usize,
));
}
*addr_len = self.len()?;
return Ok(to_write);
}
pub unsafe fn is_empty(&self) -> bool {
unsafe { self.family == 0 && self.addr_ph.data == [0; 14] }
}
}
impl From<Endpoint> for SockAddr {
fn from(value: Endpoint) -> Self {
match value {
Endpoint::Ip(ip_endpoint) => match ip_endpoint.addr {
smoltcp::wire::IpAddress::Ipv4(ipv4_addr) => {
let addr_in = SockAddrIn {
sin_family: AddressFamily::INet as u16,
sin_port: ip_endpoint.port.to_be(),
sin_addr: u32::from_be_bytes(ipv4_addr.0).to_be(),
sin_zero: [0; 8],
};
return SockAddr { addr_in };
}
_ => {
unimplemented!("not support ipv6");
}
},
Endpoint::LinkLayer(link_endpoint) => {
let addr_ll = SockAddrLl {
sll_family: AddressFamily::Packet as u16,
sll_protocol: 0,
sll_ifindex: link_endpoint.interface as u32,
sll_hatype: 0,
sll_pkttype: 0,
sll_halen: 0,
sll_addr: [0; 8],
};
return SockAddr { addr_ll };
}
Endpoint::Netlink(netlink_endpoint) => {
let addr_nl = SockAddrNl {
nl_family: AddressFamily::Netlink,
nl_pad: 0,
nl_pid: netlink_endpoint.addr.nl_pid,
nl_groups: netlink_endpoint.addr.nl_groups,
};
return SockAddr { addr_nl };
}
Endpoint::Inode((_, path)) => {
log::debug!("from unix path {:?}", path);
let bytes = path.as_bytes();
let mut sun_path = [0u8; 108];
if bytes.len() <= 108 {
sun_path[..bytes.len()].copy_from_slice(bytes);
} else {
panic!("unix address path too long!");
}
let addr_un = SockAddrUn {
sun_family: AddressFamily::Unix as u16,
sun_path: sun_path,
};
return SockAddr { addr_un };
}
_ => {
// todo: support other endpoint, like Netlink...
unimplemented!("not support {value:?}");
}
}
}
}
#[repr(C)]
#[derive(Debug, Clone, Copy)]
pub struct MsgHdr {
/// 指向一个SockAddr结构体的指针
pub msg_name: *mut SockAddr,
/// SockAddr结构体的大小
pub msg_namelen: u32,
/// scatter/gather array
pub msg_iov: *mut crate::filesystem::vfs::syscall::IoVec,
/// elements in msg_iov
pub msg_iovlen: usize,
/// 辅助数据
pub msg_control: *mut u8,
/// 辅助数据长度
pub msg_controllen: u32,
/// 接收到的消息的标志
pub msg_flags: u32,
}
// TODO: 从用户态读取MsgHdr以及写入MsgHdr

View File

@ -50,7 +50,8 @@ use crate::{
ucontext::AddressSpace, ucontext::AddressSpace,
VirtAddr, VirtAddr,
}, },
net::socket::SocketInode, net::socket::Inode as SocketInode,
// net::socket::SocketInode,
sched::completion::Completion, sched::completion::Completion,
sched::{ sched::{
cpu_rq, fair::FairSchedEntity, prio::MAX_PRIO, DequeueFlag, EnqueueFlag, OnRq, SchedMode, cpu_rq, fair::FairSchedEntity, prio::MAX_PRIO, DequeueFlag, EnqueueFlag, OnRq, SchedMode,

View File

@ -456,9 +456,10 @@ impl Syscall {
// 地址空间超出了用户空间的范围,不合法 // 地址空间超出了用户空间的范围,不合法
Err(SystemError::EFAULT) Err(SystemError::EFAULT)
} else { } else {
Self::connect(args[0], addr, addrlen) Self::connect(args[0], addr, addrlen as u32)
} }
} }
SYS_BIND => { SYS_BIND => {
let addr = args[1] as *const SockAddr; let addr = args[1] as *const SockAddr;
let addrlen = args[2]; let addrlen = args[2];
@ -468,7 +469,7 @@ impl Syscall {
// 地址空间超出了用户空间的范围,不合法 // 地址空间超出了用户空间的范围,不合法
Err(SystemError::EFAULT) Err(SystemError::EFAULT)
} else { } else {
Self::bind(args[0], addr, addrlen) Self::bind(args[0], addr, addrlen as u32)
} }
} }
@ -486,7 +487,7 @@ impl Syscall {
Err(SystemError::EFAULT) Err(SystemError::EFAULT)
} else { } else {
let data: &[u8] = unsafe { core::slice::from_raw_parts(buf, len) }; let data: &[u8] = unsafe { core::slice::from_raw_parts(buf, len) };
Self::sendto(args[0], data, flags, addr, addrlen) Self::sendto(args[0], data, flags, addr, addrlen as u32)
} }
} }
@ -495,7 +496,7 @@ impl Syscall {
let len = args[2]; let len = args[2];
let flags = args[3] as u32; let flags = args[3] as u32;
let addr = args[4] as *mut SockAddr; let addr = args[4] as *mut SockAddr;
let addrlen = args[5] as *mut usize; let addrlen = args[5] as *mut u32;
let virt_buf = VirtAddr::new(buf as usize); let virt_buf = VirtAddr::new(buf as usize);
let virt_addrlen = VirtAddr::new(addrlen as usize); let virt_addrlen = VirtAddr::new(addrlen as usize);
let virt_addr = VirtAddr::new(addr as usize); let virt_addr = VirtAddr::new(addr as usize);
@ -507,7 +508,7 @@ impl Syscall {
} }
// 验证addrlen的地址是否合法 // 验证addrlen的地址是否合法
if verify_area(virt_addrlen, core::mem::size_of::<u32>()).is_err() { if verify_area(virt_addrlen, core::mem::size_of::<usize>()).is_err() {
// 地址空间超出了用户空间的范围,不合法 // 地址空间超出了用户空间的范围,不合法
return Err(SystemError::EFAULT); return Err(SystemError::EFAULT);
} }
@ -518,12 +519,11 @@ impl Syscall {
} }
return Ok(()); return Ok(());
}; };
let r = security_check(); if let Err(e) = security_check() {
if let Err(e) = r {
Err(e) Err(e)
} else { } else {
let buf = unsafe { core::slice::from_raw_parts_mut(buf, len) }; let buf = unsafe { core::slice::from_raw_parts_mut(buf, len) };
Self::recvfrom(args[0], buf, flags, addr, addrlen as *mut u32) Self::recvfrom(args[0], buf, flags, addr, addrlen)
} }
} }

View File

@ -143,7 +143,7 @@ while true;do
# ps: 下面这条使用tap的方式无法dhcp获取到ip暂时不知道为什么 # ps: 下面这条使用tap的方式无法dhcp获取到ip暂时不知道为什么
# QEMU_DEVICES="-device ahci,id=ahci -device ide-hd,drive=disk,bus=ahci.0 -net nic,netdev=nic0 -netdev tap,id=nic0,model=virtio-net-pci,script=qemu/ifup-nat,downscript=qemu/ifdown-nat -usb -device qemu-xhci,id=xhci,p2=8,p3=4 " # QEMU_DEVICES="-device ahci,id=ahci -device ide-hd,drive=disk,bus=ahci.0 -net nic,netdev=nic0 -netdev tap,id=nic0,model=virtio-net-pci,script=qemu/ifup-nat,downscript=qemu/ifdown-nat -usb -device qemu-xhci,id=xhci,p2=8,p3=4 "
QEMU_DEVICES+="${QEMU_DEVICES_DISK} " QEMU_DEVICES+="${QEMU_DEVICES_DISK} "
QEMU_DEVICES+=" -netdev user,id=hostnet0,hostfwd=tcp::12580-:12580 -device virtio-net-pci,vectors=5,netdev=hostnet0,id=net0 -usb -device qemu-xhci,id=xhci,p2=8,p3=4 " QEMU_DEVICES+=" -netdev user,id=hostnet0,hostfwd=tcp::12580-:12580,hostfwd=udp::12549-:12549 -device virtio-net-pci,vectors=5,netdev=hostnet0,id=net0 -usb -device qemu-xhci,id=xhci,p2=8,p3=4 "
# E1000E # E1000E
# QEMU_DEVICES="-device ahci,id=ahci -device ide-hd,drive=disk,bus=ahci.0 -netdev user,id=hostnet0,hostfwd=tcp::12580-:12580 -net nic,model=e1000e,netdev=hostnet0,id=net0 -netdev user,id=hostnet1,hostfwd=tcp::12581-:12581 -device virtio-net-pci,vectors=5,netdev=hostnet1,id=net1 -usb -device qemu-xhci,id=xhci,p2=8,p3=4 " # QEMU_DEVICES="-device ahci,id=ahci -device ide-hd,drive=disk,bus=ahci.0 -netdev user,id=hostnet0,hostfwd=tcp::12580-:12580 -net nic,model=e1000e,netdev=hostnet0,id=net0 -netdev user,id=hostnet1,hostfwd=tcp::12581-:12581 -device virtio-net-pci,vectors=5,netdev=hostnet1,id=net1 -usb -device qemu-xhci,id=xhci,p2=8,p3=4 "
QEMU_ARGUMENT+="-d ${QEMU_DISK_IMAGE} -m ${QEMU_MEMORY} -smp ${QEMU_SMP} -boot order=d ${QEMU_MONITOR} -d ${qemu_trace_std} " QEMU_ARGUMENT+="-d ${QEMU_DISK_IMAGE} -m ${QEMU_MEMORY} -smp ${QEMU_SMP} -boot order=d ${QEMU_MONITOR} -d ${qemu_trace_std} "

3
user/apps/ping/.gitignore vendored Normal file
View File

@ -0,0 +1,3 @@
/target
Cargo.lock
/install/

18
user/apps/ping/Cargo.toml Normal file
View File

@ -0,0 +1,18 @@
[package]
name = "ping"
version = "0.1.0"
edition = "2021"
description = "ping for dragonOS"
authors = [ "smallc <2628035541@qq.com>" ]
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
anyhow = "1.0.86"
clap = { version = "4.5.11", features = ["derive"] }
crossbeam-channel = "0.5.13"
pnet = "0.35.0"
rand = "0.8.5"
signal-hook = "0.3.17"
socket2 = "0.5.7"
thiserror = "1.0.63"

56
user/apps/ping/Makefile Normal file
View File

@ -0,0 +1,56 @@
TOOLCHAIN=
RUSTFLAGS=
ifdef DADK_CURRENT_BUILD_DIR
# 如果是在dadk中编译那么安装到dadk的安装目录中
INSTALL_DIR = $(DADK_CURRENT_BUILD_DIR)
else
# 如果是在本地编译那么安装到当前目录下的install目录中
INSTALL_DIR = ./install
endif
ifeq ($(ARCH), x86_64)
export RUST_TARGET=x86_64-unknown-linux-musl
else ifeq ($(ARCH), riscv64)
export RUST_TARGET=riscv64gc-unknown-linux-gnu
else
# 默认为x86_86用于本地编译
export RUST_TARGET=x86_64-unknown-linux-musl
endif
run:
RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) run --target $(RUST_TARGET)
build:
RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) build --target $(RUST_TARGET)
clean:
RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) clean --target $(RUST_TARGET)
test:
RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) test --target $(RUST_TARGET)
doc:
RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) doc --target $(RUST_TARGET)
fmt:
RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) fmt
fmt-check:
RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) fmt --check
run-release:
RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) run --target $(RUST_TARGET) --release
build-release:
RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) build --target $(RUST_TARGET) --release
clean-release:
RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) clean --target $(RUST_TARGET) --release
test-release:
RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) test --target $(RUST_TARGET) --release
.PHONY: install
install:
RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) install --target $(RUST_TARGET) --path . --no-track --root $(INSTALL_DIR) --force

23
user/apps/ping/README.md Normal file
View File

@ -0,0 +1,23 @@
# PING
为DragonOS实现ping
## NAME
ping - 向网络主机发送ICMP ECHO_REQUEST
## SYNOPSIS
[-c count] 指定 ping 的次数。例如,`-c 4` 会向目标主机发送 4 个 ping 请求。
[-i interval]:指定两次 ping 请求之间的时间间隔,单位是秒。例如,`-i 2` 会每 2 秒发送一次 ping 请求。
[-w timeout] 指定等待 ping 响应的超时时间,单位是秒。例如,`-w 5` 会在 5 秒后超时。
[-s packetsize]:指定发送的 ICMP Packet 的大小,单位是字节。例如,`-s 64` 会发送大小为 64 字节的 ICMP Packet。
[-t ttl]:指定 ping 的 TTL (Time to Live)。例如,`-t 64` 会设置 TTL 为 64。
{destination}:指定要 ping 的目标主机。可以是 IP 地址或者主机名。例如,`192.168.1.1``www.example.com`
## DESCRIPTION
ping 使用 ICMP 协议的必需的 ECHO_REQUEST 数据报来引发主机或网关的 ICMP ECHO_RESPONSE。ECHO_REQUEST 数据报“ping”具有 IP 和 ICMP 头,后面跟着一个 struct timeval然后是用于填充数据包的任意数量的“填充”字节。
ping 支持 IPv4 和 IPv6。可以通过指定 -4 或 -6 来强制只使用其中一个。
ping 还可以发送 IPv6 节点信息查询RFC4620。可能不允许中间跳跃因为 IPv6 源路由已被弃用RFC5095

View File

@ -0,0 +1,50 @@
use clap::{arg, command, Parser};
use rand::random;
use crate::config::{Config, IpAddress};
/// # Args结构体
/// 使用clap库对命令行输入进行pasing产生参数配置
#[derive(Parser, Debug, Clone)]
#[command(author, version, about, long_about = None)]
pub struct Args {
// Count of ping times
#[arg(short, default_value_t = 4)]
count: u16,
// Ping packet size
#[arg(short = 's', default_value_t = 64)]
packet_size: usize,
// Ping ttl
#[arg(short = 't', default_value_t = 64)]
ttl: u32,
// Ping timeout seconds
#[arg(short = 'w', default_value_t = 1)]
timeout: u64,
// Ping interval duration milliseconds
#[arg(short = 'i', default_value_t = 1000)]
interval: u64,
// Ping destination, ip or domain
#[arg(value_parser=IpAddress::parse)]
destination: IpAddress,
}
impl Args {
/// # 将Args结构体转换为config结构体
pub fn as_config(&self) -> Config {
Config {
count: self.count,
packet_size: self.packet_size,
ttl: self.ttl,
timeout: self.timeout,
interval: self.interval,
id: random::<u16>(),
sequence: 1,
address: self.destination.clone(),
}
}
}

View File

@ -0,0 +1,45 @@
use anyhow::bail;
use std::{
ffi::CString,
net::{self},
};
use crate::error;
///# Config结构体
/// 记录ping指令的一些参数值
#[derive(Debug, Clone)]
pub struct Config {
pub count: u16,
pub packet_size: usize,
pub ttl: u32,
pub timeout: u64,
pub interval: u64,
pub id: u16,
pub sequence: u16,
pub address: IpAddress,
}
///# 目标地址ip结构体
/// ip负责提供给socket使用
/// raw负责打印输出
#[derive(Debug, Clone)]
pub struct IpAddress {
pub ip: net::IpAddr,
pub raw: String,
}
impl IpAddress {
pub fn parse(host: &str) -> anyhow::Result<Self> {
let raw = String::from(host);
let opt = host.parse::<net::IpAddr>().ok();
match opt {
Some(ip) => Ok(Self { ip, raw }),
None => {
bail!(error::PingError::InvalidConfig(
"Invalid Address".to_string()
));
}
}
}
}

View File

@ -0,0 +1,10 @@
#![allow(dead_code)]
#[derive(Debug, Clone, thiserror::Error)]
pub enum PingError {
#[error("invaild config")]
InvalidConfig(String),
#[error("invaild packet")]
InvalidPacket,
}

View File

@ -0,0 +1,23 @@
use args::Args;
use clap::Parser;
use std::format;
mod args;
mod config;
mod error;
mod ping;
///# ping入口主函数
fn main() {
let args = Args::parse();
match ping::Ping::new(args.as_config()) {
Ok(pinger) => pinger.run().unwrap_or_else(|e| {
exit(format!("Error on run ping: {}", e));
}),
Err(e) => exit(format!("Error on init: {}", e)),
}
}
fn exit(msg: String) {
eprintln!("{}", msg);
std::process::exit(1);
}

151
user/apps/ping/src/ping.rs Normal file
View File

@ -0,0 +1,151 @@
use crossbeam_channel::{bounded, select, Receiver};
use pnet::packet::{
icmp::{
echo_reply::{EchoReplyPacket, IcmpCodes},
echo_request::MutableEchoRequestPacket,
IcmpTypes,
},
util, Packet,
};
use signal_hook::consts::{SIGINT, SIGTERM};
use socket2::{Domain, Protocol, Socket, Type};
use std::{
io,
net::{self, Ipv4Addr, SocketAddr},
sync::{
atomic::{AtomicU64, Ordering},
Arc,
},
thread::{self},
time::{Duration, Instant},
};
use crate::{config::Config, error::PingError};
#[derive(Clone)]
pub struct Ping {
config: Config,
socket: Arc<Socket>,
dest: SocketAddr,
}
impl Ping {
///# ping创建函数
/// 使用config进行ping的配置
pub fn new(config: Config) -> std::io::Result<Self> {
let socket = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::ICMPV4))?;
let src = SocketAddr::new(net::IpAddr::V4(Ipv4Addr::UNSPECIFIED), 12549);
let dest = SocketAddr::new(config.address.ip, 12549);
socket.bind(&src.into())?;
// socket.set_ttl(64)?;
// socket.set_read_timeout(Some(Duration::from_secs(config.timeout)))?;
// socket.set_write_timeout(Some(Duration::from_secs(config.timeout)))?;
Ok(Self {
config,
dest,
socket: Arc::new(socket),
})
}
///# ping主要执行逻辑
/// 创建icmpPacket发送给socket
pub fn ping(&self, seq_offset: u16) -> anyhow::Result<()> {
//创建 icmp request packet
let mut buf = vec![0; self.config.packet_size];
let mut icmp = MutableEchoRequestPacket::new(&mut buf[..]).expect("InvalidBuffferSize");
icmp.set_icmp_type(IcmpTypes::EchoRequest);
icmp.set_icmp_code(IcmpCodes::NoCode);
icmp.set_identifier(self.config.id);
icmp.set_sequence_number(self.config.sequence + seq_offset);
icmp.set_checksum(util::checksum(icmp.packet(), 1));
let start = Instant::now();
//发送 request
self.socket.send_to(icmp.packet(), &self.dest.into())?;
//处理 recv
let mut mem_buf =
unsafe { &mut *(buf.as_mut_slice() as *mut [u8] as *mut [std::mem::MaybeUninit<u8>]) };
let (size, _) = self.socket.recv_from(&mut mem_buf)?;
let duration = start.elapsed().as_micros() as f64 / 1000.0;
let reply = EchoReplyPacket::new(&buf).ok_or(PingError::InvalidPacket)?;
println!(
"{} bytes from {}: icmp_seq={} ttl={} time={:.2}ms",
size,
self.config.address.ip,
reply.get_sequence_number(),
self.config.ttl,
duration
);
Ok(())
}
///# ping指令多线程运行
/// 创建多个线程负责不同的ping函数的执行
pub fn run(&self) -> io::Result<()> {
println!(
"PING {}({})",
self.config.address.raw, self.config.address.ip
);
let _now = Instant::now();
let send = Arc::new(AtomicU64::new(0));
let _send = send.clone();
let this = Arc::new(self.clone());
let success = Arc::new(AtomicU64::new(0));
let _success = success.clone();
let mut handles = vec![];
for i in 0..this.config.count {
let _this = this.clone();
let handle = thread::spawn(move||{
_this.ping(i).unwrap();
});
_send.fetch_add(1, Ordering::SeqCst);
handles.push(handle);
if i < this.config.count - 1 {
thread::sleep(Duration::from_millis(this.config.interval));
}
}
for handle in handles {
if handle.join().is_ok() {
_success.fetch_add(1, Ordering::SeqCst);
}
}
let total = _now.elapsed().as_micros() as f64 / 1000.0;
let send = send.load(Ordering::SeqCst);
let success = success.load(Ordering::SeqCst);
let loss_rate = if send > 0 {
(send - success) * 100 / send
} else {
0
};
println!("\n--- {} ping statistics ---", self.config.address.raw);
println!(
"{} packets transmitted, {} received, {}% packet loss, time {}ms",
send, success, loss_rate, total,
);
Ok(())
}
}
//TODO: 等待添加ctrl+c发送信号后添加该特性
// /// # 创建一个进程用于监听用户是否提前退出程序
// fn signal_notify() -> std::io::Result<Receiver<i32>> {
// let (s, r) = bounded(1);
// let mut signals = signal_hook::iterator::Signals::new(&[SIGINT, SIGTERM])?;
// thread::spawn(move || {
// for signal in signals.forever() {
// s.send(signal).unwrap();
// break;
// }
// });
// Ok(r)
// }

View File

@ -0,0 +1,2 @@
[build]
target = "x86_64-unknown-linux-musl"

3
user/apps/test-uevent/.gitignore vendored Normal file
View File

@ -0,0 +1,3 @@
/target
Cargo.lock
/install/

View File

@ -0,0 +1,12 @@
[package]
name = "test-uevent"
version = "0.1.0"
edition = "2021"
description = "test for uevent"
authors = [ "val213 <val213666@gmail.com>" ]
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
netlink-sys = "0.5"
nix = "0.24"

View File

@ -0,0 +1,56 @@
TOOLCHAIN="+nightly-2023-08-15-x86_64-unknown-linux-gnu"
RUSTFLAGS+=""
ifdef DADK_CURRENT_BUILD_DIR
# 如果是在dadk中编译那么安装到dadk的安装目录中
INSTALL_DIR = $(DADK_CURRENT_BUILD_DIR)
else
# 如果是在本地编译那么安装到当前目录下的install目录中
INSTALL_DIR = ./install
endif
ifeq ($(ARCH), x86_64)
export RUST_TARGET=x86_64-unknown-linux-musl
else ifeq ($(ARCH), riscv64)
export RUST_TARGET=riscv64gc-unknown-linux-gnu
else
# 默认为x86_86用于本地编译
export RUST_TARGET=x86_64-unknown-linux-musl
endif
run:
RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) run --target $(RUST_TARGET)
build:
RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) build --target $(RUST_TARGET)
clean:
RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) clean --target $(RUST_TARGET)
test:
RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) test --target $(RUST_TARGET)
doc:
RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) doc --target $(RUST_TARGET)
fmt:
RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) fmt
fmt-check:
RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) fmt --check
run-release:
RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) run --target $(RUST_TARGET) --release
build-release:
RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) build --target $(RUST_TARGET) --release
clean-release:
RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) clean --target $(RUST_TARGET) --release
test-release:
RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) test --target $(RUST_TARGET) --release
.PHONY: install
install:
RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) install --target $(RUST_TARGET) --path . --no-track --root $(INSTALL_DIR) --force

View File

@ -0,0 +1,14 @@
# DragonOS Rust-Application Template
您可以使用此模板来创建DragonOS应用程序。
## 使用方法
1. 使用DragonOS的tools目录下的`bootstrap.sh`脚本初始化环境
2. 在终端输入`cargo install cargo-generate`
3. 在终端输入`cargo generate --git https://github.com/DragonOS-Community/Rust-App-Template`即可创建项目
如果您的网络较慢,请使用镜像站`cargo generate --git https://git.mirrors.dragonos.org/DragonOS-Community/Rust-App-Template`
4. 使用`cargo run`来运行项目
5. 在DragonOS的`user/dadk/config`目录下,使用`dadk new`命令,创建编译配置,安装到DragonOS的`/`目录下。
(在dadk的编译命令选项处请使用Makefile里面的`make install`配置进行编译、安装)
6. 编译DragonOS即可安装

View File

@ -0,0 +1,150 @@
use libc::{sockaddr, sockaddr_storage, recvfrom, bind, sendto, socket, AF_NETLINK, SOCK_DGRAM, SOCK_CLOEXEC, getpid, c_void};
use nix::libc;
use std::os::unix::io::RawFd;
use std::{ mem, io};
#[repr(C)]
struct Nlmsghdr {
nlmsg_len: u32,
nlmsg_type: u16,
nlmsg_flags: u16,
nlmsg_seq: u32,
nlmsg_pid: u32,
}
fn create_netlink_socket() -> io::Result<RawFd> {
let sockfd = unsafe {
socket(AF_NETLINK, SOCK_DGRAM | SOCK_CLOEXEC, libc::NETLINK_KOBJECT_UEVENT)
};
if sockfd < 0 {
println!("Error: {}", io::Error::last_os_error());
return Err(io::Error::last_os_error());
}
Ok(sockfd)
}
fn bind_netlink_socket(sock: RawFd) -> io::Result<()> {
let pid = unsafe { getpid() };
let mut addr: libc::sockaddr_nl = unsafe { mem::zeroed() };
addr.nl_family = AF_NETLINK as u16;
addr.nl_pid = pid as u32;
addr.nl_groups = 0;
let ret = unsafe {
bind(sock, &addr as *const _ as *const sockaddr, mem::size_of::<libc::sockaddr_nl>() as u32)
};
if ret < 0 {
println!("Error: {}", io::Error::last_os_error());
return Err(io::Error::last_os_error());
}
Ok(())
}
fn send_uevent(sock: RawFd, message: &str) -> io::Result<()> {
let mut addr: libc::sockaddr_nl = unsafe { mem::zeroed() };
addr.nl_family = AF_NETLINK as u16;
addr.nl_pid = 0;
addr.nl_groups = 0;
let nlmsghdr = Nlmsghdr {
nlmsg_len: (mem::size_of::<Nlmsghdr>() + message.len()) as u32,
nlmsg_type: 0,
nlmsg_flags: 0,
nlmsg_seq: 0,
nlmsg_pid: 0,
};
let mut buffer = Vec::with_capacity(nlmsghdr.nlmsg_len as usize);
buffer.extend_from_slice(unsafe {
std::slice::from_raw_parts(
&nlmsghdr as *const Nlmsghdr as *const u8,
mem::size_of::<Nlmsghdr>(),
)
});
buffer.extend_from_slice(message.as_bytes());
let ret = unsafe {
sendto(
sock,
buffer.as_ptr() as *const c_void,
buffer.len(),
0,
&addr as *const _ as *const sockaddr,
mem::size_of::<libc::sockaddr_nl>() as u32,
)
};
if ret < 0 {
println!("Error: {}", io::Error::last_os_error());
return Err(io::Error::last_os_error());
}
Ok(())
}
fn receive_uevent(sock: RawFd) -> io::Result<String> {
// 检查套接字文件描述符是否有效
if sock < 0 {
println!("Invalid socket file descriptor: {}", sock);
return Err(io::Error::new(io::ErrorKind::InvalidInput, "Invalid socket file descriptor"));
}
let mut buf = [0u8; 1024];
// let mut addr: sockaddr_storage = unsafe { mem::zeroed() };
// let mut addr_len = mem::size_of::<sockaddr_storage>() as u32;
// 检查缓冲区指针和长度是否有效
if buf.is_empty() {
println!("Buffer is empty");
return Err(io::Error::new(io::ErrorKind::InvalidInput, "Buffer is empty"));
}
let len = unsafe {
recvfrom(
sock,
buf.as_mut_ptr() as *mut c_void,
buf.len(),
0,
core::ptr::null_mut(), // 不接收发送方地址
core::ptr::null_mut(), // 不接收发送方地址长度
)
};
println!("Received {} bytes", len);
println!("Received message: {:?}", &buf[..len as usize]);
if len < 0 {
println!("Error: {}", io::Error::last_os_error());
return Err(io::Error::last_os_error());
}
let nlmsghdr_size = mem::size_of::<Nlmsghdr>();
if (len as usize) < nlmsghdr_size {
println!("Received message is too short");
return Err(io::Error::new(io::ErrorKind::InvalidData, "Received message is too short"));
}
let nlmsghdr = unsafe { &*(buf.as_ptr() as *const Nlmsghdr) };
if nlmsghdr.nlmsg_len as isize > len {
println!("Received message is incomplete");
return Err(io::Error::new(io::ErrorKind::InvalidData, "Received message is incomplete"));
}
let message_data = &buf[nlmsghdr_size..nlmsghdr.nlmsg_len as usize];
Ok(String::from_utf8_lossy(message_data).to_string())
}
fn main() {
let socket = create_netlink_socket().expect("Failed to create Netlink socket");
println!("Netlink socket created successfully");
bind_netlink_socket(socket).expect("Failed to bind Netlink socket");
println!("Netlink socket created and bound successfully");
send_uevent(socket, "add@/devices/virtual/block/loop0").expect("Failed to send uevent message");
println!("Custom uevent message sent successfully");
let message = receive_uevent(socket).expect("Failed to receive uevent message");
println!("Received uevent message: {}", message);
}

3
user/apps/test_seqpacket/.gitignore vendored Normal file
View File

@ -0,0 +1,3 @@
/target
Cargo.lock
/install/

View File

@ -0,0 +1,12 @@
[package]
name = "test_seqpacket"
version = "0.1.0"
edition = "2021"
description = "测试seqpacket的socket"
authors = [ "Saga <1750226968@qq.com>" ]
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
nix = "0.26"
libc = "0.2"

View File

@ -0,0 +1,56 @@
TOOLCHAIN=
RUSTFLAGS=
ifdef DADK_CURRENT_BUILD_DIR
# 如果是在dadk中编译那么安装到dadk的安装目录中
INSTALL_DIR = $(DADK_CURRENT_BUILD_DIR)
else
# 如果是在本地编译那么安装到当前目录下的install目录中
INSTALL_DIR = ./install
endif
ifeq ($(ARCH), x86_64)
export RUST_TARGET=x86_64-unknown-linux-musl
else ifeq ($(ARCH), riscv64)
export RUST_TARGET=riscv64gc-unknown-linux-gnu
else
# 默认为x86_86用于本地编译
export RUST_TARGET=x86_64-unknown-linux-musl
endif
run:
RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) run --target $(RUST_TARGET)
build:
RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) build --target $(RUST_TARGET)
clean:
RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) clean --target $(RUST_TARGET)
test:
RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) test --target $(RUST_TARGET)
doc:
RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) doc --target $(RUST_TARGET)
fmt:
RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) fmt
fmt-check:
RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) fmt --check
run-release:
RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) run --target $(RUST_TARGET) --release
build-release:
RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) build --target $(RUST_TARGET) --release
clean-release:
RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) clean --target $(RUST_TARGET) --release
test-release:
RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) test --target $(RUST_TARGET) --release
.PHONY: install
install:
RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) install --target $(RUST_TARGET) --path . --no-track --root $(INSTALL_DIR) --force

View File

@ -0,0 +1,14 @@
# DragonOS Rust-Application Template
您可以使用此模板来创建DragonOS应用程序。
## 使用方法
1. 使用DragonOS的tools目录下的`bootstrap.sh`脚本初始化环境
2. 在终端输入`cargo install cargo-generate`
3. 在终端输入`cargo generate --git https://github.com/DragonOS-Community/Rust-App-Template`即可创建项目
如果您的网络较慢,请使用镜像站`cargo generate --git https://git.mirrors.dragonos.org/DragonOS-Community/Rust-App-Template`
4. 使用`cargo run`来运行项目
5. 在DragonOS的`user/dadk/config`目录下,使用`dadk new`命令,创建编译配置,安装到DragonOS的`/`目录下。
(在dadk的编译命令选项处请使用Makefile里面的`make install`配置进行编译、安装)
6. 编译DragonOS即可安装

View File

@ -0,0 +1,190 @@
mod seq_socket;
mod seq_pair;
use seq_socket::test_seq_socket;
use seq_pair::test_seq_pair;
fn main() -> Result<(), std::io::Error> {
if let Err(e) = test_seq_socket() {
println!("[ fault ] test_seq_socket, err: {}", e);
} else {
println!("[success] test_seq_socket");
}
if let Err(e) = test_seq_pair() {
println!("[ fault ] test_seq_pair, err: {}", e);
} else {
println!("[success] test_seq_pair");
}
Ok(())
}
// use nix::sys::socket::{socketpair, AddressFamily, SockFlag, SockType};
// use std::fs::File;
// use std::io::{Read, Write};
// use std::os::fd::FromRawFd;
// use std::{fs, str};
// use libc::*;
// use std::ffi::CString;
// use std::io::Error;
// use std::mem;
// use std::os::unix::io::RawFd;
// use std::ptr;
// const SOCKET_PATH: &str = "/test.seqpacket";
// const MSG: &str = "Hello, Unix SEQPACKET socket!";
// fn create_seqpacket_socket() -> Result<RawFd, Error> {
// unsafe {
// let fd = socket(AF_UNIX, SOCK_SEQPACKET, 0);
// if fd == -1 {
// return Err(Error::last_os_error());
// }
// Ok(fd)
// }
// }
// fn bind_socket(fd: RawFd) -> Result<(), Error> {
// unsafe {
// let mut addr = sockaddr_un {
// sun_family: AF_UNIX as u16,
// sun_path: [0; 108],
// };
// let path_cstr = CString::new(SOCKET_PATH).unwrap();
// let path_bytes = path_cstr.as_bytes();
// for (i, &byte) in path_bytes.iter().enumerate() {
// addr.sun_path[i] = byte as i8;
// }
// if bind(fd, &addr as *const _ as *const sockaddr, mem::size_of_val(&addr) as socklen_t) == -1 {
// return Err(Error::last_os_error());
// }
// }
// Ok(())
// }
// fn listen_socket(fd: RawFd) -> Result<(), Error> {
// unsafe {
// if listen(fd, 5) == -1 {
// return Err(Error::last_os_error());
// }
// }
// Ok(())
// }
// fn accept_connection(fd: RawFd) -> Result<RawFd, Error> {
// unsafe {
// // let mut addr = sockaddr_un {
// // sun_family: AF_UNIX as u16,
// // sun_path: [0; 108],
// // };
// // let mut len = mem::size_of_val(&addr) as socklen_t;
// let client_fd = accept(fd, std::ptr::null_mut(), std::ptr::null_mut());
// if client_fd == -1 {
// return Err(Error::last_os_error());
// }
// Ok(client_fd)
// }
// }
// fn send_message(fd: RawFd, msg: &str) -> Result<(), Error> {
// unsafe {
// let msg_bytes = msg.as_bytes();
// if send(fd, msg_bytes.as_ptr() as *const libc::c_void, msg_bytes.len(), 0) == -1 {
// return Err(Error::last_os_error());
// }
// }
// Ok(())
// }
// fn receive_message(fd: RawFd) -> Result<String, Error> {
// let mut buffer = [0; 1024];
// unsafe {
// let len = recv(fd, buffer.as_mut_ptr() as *mut libc::c_void, buffer.len(), 0);
// if len == -1 {
// return Err(Error::last_os_error());
// }
// Ok(String::from_utf8_lossy(&buffer[..len as usize]).into_owned())
// }
// }
// fn main() -> Result<(), Error> {
// // Create and bind the server socket
// fs::remove_file(&SOCKET_PATH).ok();
// let server_fd = create_seqpacket_socket()?;
// bind_socket(server_fd)?;
// listen_socket(server_fd)?;
// // Accept connection in a separate thread
// let server_thread = std::thread::spawn(move || {
// let client_fd = accept_connection(server_fd).expect("Failed to accept connection");
// // Receive and print message
// let received_msg = receive_message(client_fd).expect("Failed to receive message");
// println!("Server: Received message: {}", received_msg);
// // Close client connection
// unsafe { close(client_fd) };
// });
// // Create and connect the client socket
// let client_fd = create_seqpacket_socket()?;
// unsafe {
// let mut addr = sockaddr_un {
// sun_family: AF_UNIX as u16,
// sun_path: [0; 108],
// };
// let path_cstr = CString::new(SOCKET_PATH).unwrap();
// let path_bytes = path_cstr.as_bytes();
// // Convert u8 to i8
// for (i, &byte) in path_bytes.iter().enumerate() {
// addr.sun_path[i] = byte as i8;
// }
// if connect(client_fd, &addr as *const _ as *const sockaddr, mem::size_of_val(&addr) as socklen_t) == -1 {
// return Err(Error::last_os_error());
// }
// }
// send_message(client_fd, MSG)?;
// // Close client connection
// unsafe { close(client_fd) };
// // Wait for server thread to complete
// server_thread.join().expect("Server thread panicked");
// fs::remove_file(&SOCKET_PATH).ok();
// // 创建 socket pair
// let (sock1, sock2) = socketpair(
// AddressFamily::Unix,
// SockType::SeqPacket, // 使用 SeqPacket 类型
// None, // 协议默认
// SockFlag::empty(),
// ).expect("Failed to create socket pair");
// let mut socket1 = unsafe { File::from_raw_fd(sock1) };
// let mut socket2 = unsafe { File::from_raw_fd(sock2) };
// // sock1 写入数据
// let msg = b"hello from sock1";
// socket1.write_all(msg)?;
// println!("sock1 send: {:?}", String::from_utf8_lossy(&msg[..]));
// // 因os read和write时会调整file的offset,write会对offset和meta size(目前返回的都是0)进行比较,
// // 而read不会故双socket都先send,后recv
// // sock2 回复数据
// let reply = b"hello from sock2";
// socket2.write_all(reply)?;
// println!("sock2 send: {:?}", String::from_utf8_lossy(reply));
// // sock2 读取数据
// let mut buf = [0u8; 128];
// let len = socket2.read(&mut buf)?;
// println!("sock2 receive: {:?}", String::from_utf8_lossy(&buf[..len]));
// // sock1 读取回复
// let len = socket1.read(&mut buf)?;
// println!("sock1 receive: {:?}", String::from_utf8_lossy(&buf[..len]));
// Ok(())
// }

View File

@ -0,0 +1,39 @@
use nix::sys::socket::{socketpair, AddressFamily, SockFlag, SockType};
use std::fs::File;
use std::io::{Read, Write,Error};
use std::os::fd::FromRawFd;
pub fn test_seq_pair()->Result<(),Error>{
// 创建 socket pair
let (sock1, sock2) = socketpair(
AddressFamily::Unix,
SockType::SeqPacket, // 使用 SeqPacket 类型
None, // 协议默认
SockFlag::empty(),
).expect("Failed to create socket pair");
let mut socket1 = unsafe { File::from_raw_fd(sock1) };
let mut socket2 = unsafe { File::from_raw_fd(sock2) };
// sock1 写入数据
let msg = b"hello from sock1";
socket1.write_all(msg)?;
println!("sock1 send: {:?}", String::from_utf8_lossy(&msg[..]));
// 因os read和write时会调整file的offset,write会对offset和meta size(目前返回的都是0)进行比较,
// 而read不会故双socket都先send,后recv
// sock2 回复数据
let reply = b"hello from sock2";
socket2.write_all(reply)?;
println!("sock2 send: {:?}", String::from_utf8_lossy(reply));
// sock2 读取数据
let mut buf = [0u8; 128];
let len = socket2.read(&mut buf)?;
println!("sock2 receive: {:?}", String::from_utf8_lossy(&buf[..len]));
// sock1 读取回复
let len = socket1.read(&mut buf)?;
println!("sock1 receive: {:?}", String::from_utf8_lossy(&buf[..len]));
Ok(())
}

View File

@ -0,0 +1,155 @@
use libc::*;
use std::{fs, str};
use std::ffi::CString;
use std::io::Error;
use std::mem;
use std::os::unix::io::RawFd;
const SOCKET_PATH: &str = "/test.seqpacket";
const MSG1: &str = "Hello, Unix SEQPACKET socket from Client!";
const MSG2: &str = "Hello, Unix SEQPACKET socket from Server!";
fn create_seqpacket_socket() -> Result<RawFd, Error> {
unsafe {
let fd = socket(AF_UNIX, SOCK_SEQPACKET, 0);
if fd == -1 {
return Err(Error::last_os_error());
}
Ok(fd)
}
}
fn bind_socket(fd: RawFd) -> Result<(), Error> {
unsafe {
let mut addr = sockaddr_un {
sun_family: AF_UNIX as u16,
sun_path: [0; 108],
};
let path_cstr = CString::new(SOCKET_PATH).unwrap();
let path_bytes = path_cstr.as_bytes();
for (i, &byte) in path_bytes.iter().enumerate() {
addr.sun_path[i] = byte as i8;
}
if bind(fd, &addr as *const _ as *const sockaddr, mem::size_of_val(&addr) as socklen_t) == -1 {
return Err(Error::last_os_error());
}
}
Ok(())
}
fn listen_socket(fd: RawFd) -> Result<(), Error> {
unsafe {
if listen(fd, 5) == -1 {
return Err(Error::last_os_error());
}
}
Ok(())
}
fn accept_connection(fd: RawFd) -> Result<RawFd, Error> {
unsafe {
// let mut addr = sockaddr_un {
// sun_family: AF_UNIX as u16,
// sun_path: [0; 108],
// };
// let mut len = mem::size_of_val(&addr) as socklen_t;
// let client_fd = accept(fd, &mut addr as *mut _ as *mut sockaddr, &mut len);
let client_fd = accept(fd, std::ptr::null_mut(), std::ptr::null_mut());
if client_fd == -1 {
return Err(Error::last_os_error());
}
Ok(client_fd)
}
}
fn send_message(fd: RawFd, msg: &str) -> Result<(), Error> {
unsafe {
let msg_bytes = msg.as_bytes();
if send(fd, msg_bytes.as_ptr() as *const libc::c_void, msg_bytes.len(), 0) == -1 {
return Err(Error::last_os_error());
}
}
Ok(())
}
fn receive_message(fd: RawFd) -> Result<String, Error> {
let mut buffer = [0; 1024];
unsafe {
let len = recv(fd, buffer.as_mut_ptr() as *mut libc::c_void, buffer.len(), 0);
if len == -1 {
return Err(Error::last_os_error());
}
Ok(String::from_utf8_lossy(&buffer[..len as usize]).into_owned())
}
}
pub fn test_seq_socket() ->Result<(), Error>{
// Create and bind the server socket
fs::remove_file(&SOCKET_PATH).ok();
let server_fd = create_seqpacket_socket()?;
bind_socket(server_fd)?;
listen_socket(server_fd)?;
// Accept connection in a separate thread
let server_thread = std::thread::spawn(move || {
let client_fd = accept_connection(server_fd).expect("Failed to accept connection");
// Receive and print message
let received_msg = receive_message(client_fd).expect("Failed to receive message");
println!("Server: Received message: {}", received_msg);
send_message(client_fd, MSG2).expect("Failed to send message");
// Close client connection
unsafe { close(client_fd) };
});
// Create and connect the client socket
let client_fd = create_seqpacket_socket()?;
unsafe {
let mut addr = sockaddr_un {
sun_family: AF_UNIX as u16,
sun_path: [0; 108],
};
let path_cstr = CString::new(SOCKET_PATH).unwrap();
let path_bytes = path_cstr.as_bytes();
// Convert u8 to i8
for (i, &byte) in path_bytes.iter().enumerate() {
addr.sun_path[i] = byte as i8;
}
if connect(client_fd, &addr as *const _ as *const sockaddr, mem::size_of_val(&addr) as socklen_t) == -1 {
return Err(Error::last_os_error());
}
}
send_message(client_fd, MSG1)?;
let received_msg = receive_message(client_fd).expect("Failed to receive message");
println!("Client: Received message: {}", received_msg);
// get peer_name
unsafe {
let mut addrss = sockaddr_un {
sun_family: AF_UNIX as u16,
sun_path: [0; 108],
};
let mut len = mem::size_of_val(&addrss) as socklen_t;
let res = getpeername(client_fd, &mut addrss as *mut _ as *mut sockaddr, &mut len);
if res == -1 {
return Err(Error::last_os_error());
}
let sun_path = addrss.sun_path.clone();
let peer_path:[u8;108] = sun_path.iter().map(|&x| x as u8).collect::<Vec<u8>>().try_into().unwrap();
println!("Client: Connected to server at path: {}", String::from_utf8_lossy(&peer_path));
}
server_thread.join().expect("Server thread panicked");
let received_msg = receive_message(client_fd).expect("Failed to receive message");
println!("Client: Received message: {}", received_msg);
// Close client connection
unsafe { close(client_fd) };
fs::remove_file(&SOCKET_PATH).ok();
Ok(())
}

View File

@ -0,0 +1,3 @@
/target
Cargo.lock
/install/

View File

@ -0,0 +1,11 @@
[package]
name = "test_unix_stream_socket"
version = "0.1.0"
edition = "2021"
description = "test for unix stream socket"
authors = [ "smallcjy <2628035541@qq.com>" ]
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
libc = "0.2.158"

View File

@ -0,0 +1,56 @@
TOOLCHAIN=
RUSTFLAGS=
ifdef DADK_CURRENT_BUILD_DIR
# 如果是在dadk中编译那么安装到dadk的安装目录中
INSTALL_DIR = $(DADK_CURRENT_BUILD_DIR)
else
# 如果是在本地编译那么安装到当前目录下的install目录中
INSTALL_DIR = ./install
endif
ifeq ($(ARCH), x86_64)
export RUST_TARGET=x86_64-unknown-linux-musl
else ifeq ($(ARCH), riscv64)
export RUST_TARGET=riscv64gc-unknown-linux-gnu
else
# 默认为x86_86用于本地编译
export RUST_TARGET=x86_64-unknown-linux-musl
endif
run:
RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) run --target $(RUST_TARGET)
build:
RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) build --target $(RUST_TARGET)
clean:
RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) clean --target $(RUST_TARGET)
test:
RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) test --target $(RUST_TARGET)
doc:
RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) doc --target $(RUST_TARGET)
fmt:
RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) fmt
fmt-check:
RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) fmt --check
run-release:
RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) run --target $(RUST_TARGET) --release
build-release:
RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) build --target $(RUST_TARGET) --release
clean-release:
RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) clean --target $(RUST_TARGET) --release
test-release:
RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) test --target $(RUST_TARGET) --release
.PHONY: install
install:
RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) install --target $(RUST_TARGET) --path . --no-track --root $(INSTALL_DIR) --force

View File

@ -0,0 +1,5 @@
# unix stream socket 测试程序
## 测试思路
跨线程通信,一个线程作为服务端监听一个测试文件,另一个线程作为客户端连接监听的文件。若连接成功,测试能够正常通信。

View File

@ -0,0 +1,153 @@
use std::io::Error;
use std::os::fd::RawFd;
use std::fs;
use libc::*;
use std::ffi::CString;
use std::mem;
const SOCKET_PATH: &str = "/test.stream";
const MSG1: &str = "Hello, unix stream socket from Client!";
const MSG2: &str = "Hello, unix stream socket from Server!";
fn create_stream_socket() -> Result<RawFd, Error>{
unsafe {
let fd = socket(AF_UNIX, SOCK_STREAM, 0);
if fd == -1 {
return Err(Error::last_os_error())
}
Ok(fd)
}
}
fn bind_socket(fd: RawFd) -> Result<(), Error> {
unsafe {
let mut addr = sockaddr_un {
sun_family: AF_UNIX as u16,
sun_path: [0; 108],
};
let path_cstr = CString::new(SOCKET_PATH).unwrap();
let path_bytes = path_cstr.as_bytes();
for (i, &byte) in path_bytes.iter().enumerate() {
addr.sun_path[i] = byte as i8;
}
if bind(fd, &addr as *const _ as *const sockaddr, mem::size_of_val(&addr) as socklen_t) == -1 {
return Err(Error::last_os_error());
}
}
Ok(())
}
fn listen_socket(fd: RawFd) -> Result<(), Error> {
unsafe {
if listen(fd, 5) == -1 {
return Err(Error::last_os_error());
}
}
Ok(())
}
fn accept_conn(fd: RawFd) -> Result<RawFd, Error> {
unsafe {
let client_fd = accept(fd, std::ptr::null_mut(), std::ptr::null_mut());
if client_fd == -1 {
return Err(Error::last_os_error());
}
Ok(client_fd)
}
}
fn send_message(fd: RawFd, msg: &str) -> Result<(), Error> {
unsafe {
let msg_bytes = msg.as_bytes();
if send(fd, msg_bytes.as_ptr() as *const libc::c_void, msg_bytes.len(), 0)== -1 {
return Err(Error::last_os_error());
}
}
Ok(())
}
fn recv_message(fd: RawFd) -> Result<String, Error> {
let mut buffer = [0; 1024];
unsafe {
let len = recv(fd, buffer.as_mut_ptr() as *mut libc::c_void, buffer.len(),0);
if len == -1 {
return Err(Error::last_os_error());
}
Ok(String::from_utf8_lossy(&buffer[..len as usize]).into_owned())
}
}
fn test_stream() -> Result<(), Error> {
fs::remove_file(&SOCKET_PATH).ok();
let server_fd = create_stream_socket()?;
bind_socket(server_fd)?;
listen_socket(server_fd)?;
let server_thread = std::thread::spawn(move || {
let client_fd = accept_conn(server_fd).expect("Failed to accept connection");
println!("accept success!");
let recv_msg = recv_message(client_fd).expect("Failed to receive message");
println!("Server: Received message: {}", recv_msg);
send_message(client_fd, MSG2).expect("Failed to send message");
println!("Server send finish");
unsafe {close(client_fd)};
});
let client_fd = create_stream_socket()?;
unsafe {
let mut addr = sockaddr_un {
sun_family: AF_UNIX as u16,
sun_path: [0; 108],
};
let path_cstr = CString::new(SOCKET_PATH).unwrap();
let path_bytes = path_cstr.as_bytes();
for (i, &byte) in path_bytes.iter().enumerate() {
addr.sun_path[i] = byte as i8;
}
if connect(client_fd, &addr as *const _ as *const sockaddr, mem::size_of_val(&addr) as socklen_t) == -1 {
return Err(Error::last_os_error());
}
}
send_message(client_fd, MSG1)?;
// get peer_name
unsafe {
let mut addrss = sockaddr_un {
sun_family: AF_UNIX as u16,
sun_path: [0; 108],
};
let mut len = mem::size_of_val(&addrss) as socklen_t;
let res = getpeername(client_fd, &mut addrss as *mut _ as *mut sockaddr, &mut len);
if res == -1 {
return Err(Error::last_os_error());
}
let sun_path = addrss.sun_path.clone();
let peer_path:[u8;108] = sun_path.iter().map(|&x| x as u8).collect::<Vec<u8>>().try_into().unwrap();
println!("Client: Connected to server at path: {}", String::from_utf8_lossy(&peer_path));
}
server_thread.join().expect("Server thread panicked");
println!("Client try recv!");
let recv_msg = recv_message(client_fd).expect("Failed to receive message from server");
println!("Client Received message: {}", recv_msg);
unsafe {close(client_fd)};
fs::remove_file(&SOCKET_PATH).ok();
Ok(())
}
fn main() {
match test_stream() {
Ok(_) => println!("test for unix stream success"),
Err(_) => println!("test for unix stream failed")
}
}

View File

@ -0,0 +1,24 @@
{
"name": "ping",
"version": "0.1.0",
"description": "ping用户程序",
"task_type": {
"BuildFromSource": {
"Local": {
"path": "apps/ping"
}
}
},
"depends": [],
"build": {
"build_command": "make install"
},
"install": {
"in_dragonos_path": "/usr"
},
"clean": {
"clean_command": "make clean"
},
"envs": [],
"target_arch": ["x86_64"]
}

View File

@ -0,0 +1,29 @@
{
"name": "test_seqpacket",
"version": "0.1.0",
"description": "对seqpacket_pair的简单测试",
"rust_target": null,
"task_type": {
"BuildFromSource": {
"Local": {
"path": "apps/test_seqpacket"
}
}
},
"depends": [],
"build": {
"build_command": "make install"
},
"install": {
"in_dragonos_path": "/"
},
"clean": {
"clean_command": "make clean"
},
"envs": [],
"build_once": false,
"install_once": false,
"target_arch": [
"x86_64"
]
}

View File

@ -0,0 +1,29 @@
{
"name": "test_stream_socket",
"version": "0.1.0",
"description": "test for unix stream socket",
"rust_target": null,
"task_type": {
"BuildFromSource": {
"Local": {
"path": "apps/test_unix_stream_socket"
}
}
},
"depends": [],
"build": {
"build_command": "make install"
},
"install": {
"in_dragonos_path": "/"
},
"clean": {
"clean_command": "make clean"
},
"envs": [],
"build_once": false,
"install_once": false,
"target_arch": [
"x86_64"
]
}

Some files were not shown because too many files have changed in this diff Show More