From cde5492f725681ed89abe1e6eb088e05d943d793 Mon Sep 17 00:00:00 2001 From: login Date: Wed, 19 Apr 2023 18:05:02 +0800 Subject: [PATCH] =?UTF-8?q?=E6=96=B0=E5=A2=9E=E7=BD=91=E7=BB=9Csocket?= =?UTF-8?q?=E7=9A=84=E7=B3=BB=E7=BB=9F=E8=B0=83=E7=94=A8=E6=8E=A5=E5=8F=A3?= =?UTF-8?q?=20(#247)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1.修复spinlock忘记恢复rflags的问题 2.WaitQueue增加wakeup_all的功能 3.完善tcp,udp,raw socket 4.把PollStatus结构体改为使用bitflags 5.新增iovec结构体 6.完成网络的系统调用 7.在bootstrap里面添加dnsmasq bridge-utils iptables --------- Co-authored-by: guanjinquan <1666320330@qq.com> --- .gitignore | 1 + kernel/src/driver/disk/ahci/ahci_inode.rs | 4 +- kernel/src/driver/keyboard/ps2_keyboard.rs | 4 +- kernel/src/driver/net/virtio_net.rs | 2 +- kernel/src/exception/trap.c | 5 +- kernel/src/filesystem/devfs/mod.rs | 4 +- kernel/src/filesystem/devfs/null_dev.rs | 4 +- kernel/src/filesystem/devfs/zero_dev.rs | 4 +- kernel/src/filesystem/fat/fs.rs | 4 +- kernel/src/filesystem/procfs/mod.rs | 4 +- kernel/src/filesystem/ramfs/mod.rs | 4 +- kernel/src/filesystem/vfs/file.rs | 6 + kernel/src/filesystem/vfs/mod.rs | 51 +- kernel/src/filesystem/vfs/syscall.rs | 105 +- kernel/src/lib.rs | 6 +- kernel/src/libs/casting.rs | 2 +- kernel/src/libs/spinlock.rs | 6 +- kernel/src/libs/wait_queue.rs | 29 +- kernel/src/mm/allocator.rs | 4 +- kernel/src/net/mod.rs | 53 +- kernel/src/net/net_core.rs | 77 +- kernel/src/net/socket.rs | 729 ++++++++++--- kernel/src/net/syscall.rs | 1122 ++++++++++++++++++++ kernel/src/process/process.rs | 24 +- kernel/src/syscall/syscall.c | 37 +- kernel/src/syscall/syscall_num.h | 14 + kernel/src/time/timer.rs | 50 +- tools/bootstrap.sh | 2 +- tools/qemu/ifdown-nat | 24 + tools/qemu/ifup-nat | 85 ++ tools/run-qemu.sh | 39 +- user/apps/test_relibc/main.c | 256 ++++- user/libs/libsystem/syscall.h | 14 + 33 files changed, 2535 insertions(+), 240 deletions(-) create mode 100644 kernel/src/net/syscall.rs create mode 100755 
tools/qemu/ifdown-nat create mode 100755 tools/qemu/ifup-nat diff --git a/.gitignore b/.gitignore index 58fdc1b7..289ac261 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ !.gitkeep DragonOS.iso .idea/ +/tmp/ kernel/kernel .DS_Store diff --git a/kernel/src/driver/disk/ahci/ahci_inode.rs b/kernel/src/driver/disk/ahci/ahci_inode.rs index 7d07286f..1dca48e3 100644 --- a/kernel/src/driver/disk/ahci/ahci_inode.rs +++ b/kernel/src/driver/disk/ahci/ahci_inode.rs @@ -108,9 +108,7 @@ impl IndexNode for LockedAhciInode { } fn poll(&self) -> Result { - return Ok(PollStatus { - flags: PollStatus::READ_MASK | PollStatus::WRITE_MASK, - }); + return Ok(PollStatus::READ | PollStatus::WRITE); } /// 读设备 - 应该调用设备的函数读写,而不是通过文件系统读写 diff --git a/kernel/src/driver/keyboard/ps2_keyboard.rs b/kernel/src/driver/keyboard/ps2_keyboard.rs index 082f34a0..c669adca 100644 --- a/kernel/src/driver/keyboard/ps2_keyboard.rs +++ b/kernel/src/driver/keyboard/ps2_keyboard.rs @@ -150,9 +150,7 @@ impl IndexNode for LockedPS2KeyBoardInode { } fn poll(&self) -> Result { - return Ok(PollStatus { - flags: PollStatus::READ_MASK, - }); + return Ok(PollStatus::READ); } fn metadata(&self) -> Result { diff --git a/kernel/src/driver/net/virtio_net.rs b/kernel/src/driver/net/virtio_net.rs index 4456e38c..417e60da 100644 --- a/kernel/src/driver/net/virtio_net.rs +++ b/kernel/src/driver/net/virtio_net.rs @@ -10,7 +10,7 @@ use virtio_drivers::{device::net::VirtIONet, transport::Transport}; use crate::{ driver::{virtio::virtio_impl::HalImpl, Driver}, - kdebug, kerror, kinfo, + kerror, kinfo, libs::spinlock::SpinLock, net::{generate_iface_id, NET_DRIVERS}, syscall::SystemError, diff --git a/kernel/src/exception/trap.c b/kernel/src/exception/trap.c index 68acd3cb..f42dd178 100644 --- a/kernel/src/exception/trap.c +++ b/kernel/src/exception/trap.c @@ -245,8 +245,9 @@ void do_page_fault(struct pt_regs *regs, unsigned long error_code) printk_color(RED, BLACK, "CR2:%#018lx\n", cr2); traceback(regs); - 
current_pcb->state = PROC_STOPPED; - sched(); + process_do_exit(-1); + // current_pcb->state = PROC_STOPPED; + // sched(); } // 15 Intel保留,请勿使用 diff --git a/kernel/src/filesystem/devfs/mod.rs b/kernel/src/filesystem/devfs/mod.rs index 2e4b8000..cb68b590 100644 --- a/kernel/src/filesystem/devfs/mod.rs +++ b/kernel/src/filesystem/devfs/mod.rs @@ -465,9 +465,7 @@ impl IndexNode for LockedDevFSInode { return Err(SystemError::EISDIR); } - return Ok(PollStatus { - flags: PollStatus::READ_MASK | PollStatus::WRITE_MASK, - }); + return Ok(PollStatus::READ | PollStatus::WRITE); } /// 读设备 - 应该调用设备的函数读写,而不是通过文件系统读写 diff --git a/kernel/src/filesystem/devfs/null_dev.rs b/kernel/src/filesystem/devfs/null_dev.rs index 3331fbf1..2a1ed6e7 100644 --- a/kernel/src/filesystem/devfs/null_dev.rs +++ b/kernel/src/filesystem/devfs/null_dev.rs @@ -102,9 +102,7 @@ impl IndexNode for LockedNullInode { } fn poll(&self) -> Result { - return Ok(PollStatus { - flags: PollStatus::READ_MASK | PollStatus::WRITE_MASK, - }); + return Ok(PollStatus::READ | PollStatus::WRITE); } /// 读设备 - 应该调用设备的函数读写,而不是通过文件系统读写 diff --git a/kernel/src/filesystem/devfs/zero_dev.rs b/kernel/src/filesystem/devfs/zero_dev.rs index 5d98616b..fd37bda5 100644 --- a/kernel/src/filesystem/devfs/zero_dev.rs +++ b/kernel/src/filesystem/devfs/zero_dev.rs @@ -102,9 +102,7 @@ impl IndexNode for LockedZeroInode { } fn poll(&self) -> Result { - return Ok(PollStatus { - flags: PollStatus::READ_MASK | PollStatus::WRITE_MASK, - }); + return Ok(PollStatus::READ | PollStatus::WRITE); } /// 读设备 - 应该调用设备的函数读写,而不是通过文件系统读写 diff --git a/kernel/src/filesystem/fat/fs.rs b/kernel/src/filesystem/fat/fs.rs index 7503cf7a..bf3dc90e 100644 --- a/kernel/src/filesystem/fat/fs.rs +++ b/kernel/src/filesystem/fat/fs.rs @@ -1411,9 +1411,7 @@ impl IndexNode for LockedFATInode { return Err(SystemError::EISDIR); } - return Ok(PollStatus { - flags: PollStatus::READ_MASK | PollStatus::WRITE_MASK, - }); + return Ok(PollStatus::READ | PollStatus::WRITE); } fn 
create( diff --git a/kernel/src/filesystem/procfs/mod.rs b/kernel/src/filesystem/procfs/mod.rs index a26bacb0..b46f954c 100644 --- a/kernel/src/filesystem/procfs/mod.rs +++ b/kernel/src/filesystem/procfs/mod.rs @@ -425,9 +425,7 @@ impl IndexNode for LockedProcFSInode { return Err(SystemError::EISDIR); } - return Ok(PollStatus { - flags: PollStatus::READ_MASK, - }); + return Ok(PollStatus::READ); } fn fs(&self) -> Arc { diff --git a/kernel/src/filesystem/ramfs/mod.rs b/kernel/src/filesystem/ramfs/mod.rs index e087095a..3dd893b9 100644 --- a/kernel/src/filesystem/ramfs/mod.rs +++ b/kernel/src/filesystem/ramfs/mod.rs @@ -186,9 +186,7 @@ impl IndexNode for LockedRamFSInode { return Err(SystemError::EISDIR); } - return Ok(PollStatus { - flags: PollStatus::READ_MASK | PollStatus::WRITE_MASK, - }); + return Ok(PollStatus::READ | PollStatus::WRITE); } fn fs(&self) -> Arc { diff --git a/kernel/src/filesystem/vfs/file.rs b/kernel/src/filesystem/vfs/file.rs index e8b93bd5..6c2394c0 100644 --- a/kernel/src/filesystem/vfs/file.rs +++ b/kernel/src/filesystem/vfs/file.rs @@ -290,6 +290,12 @@ impl File { return Some(res); } + + /// @brief 获取文件的类型 + #[inline] + pub fn file_type(&self) -> FileType { + return self.file_type; + } } impl Drop for File { diff --git a/kernel/src/filesystem/vfs/mod.rs b/kernel/src/filesystem/vfs/mod.rs index e0b4d484..aabe202d 100644 --- a/kernel/src/filesystem/vfs/mod.rs +++ b/kernel/src/filesystem/vfs/mod.rs @@ -3,14 +3,14 @@ pub mod core; pub mod file; pub mod mount; -mod syscall; +pub mod syscall; mod utils; use ::core::{any::Any, fmt::Debug}; use alloc::{string::String, sync::Arc, vec::Vec}; -use crate::{syscall::SystemError, time::TimeSpec}; +use crate::{libs::casting::DowncastArc, syscall::SystemError, time::TimeSpec}; use self::{core::generate_inode_id, file::FileMode}; pub use self::{core::ROOT_INODE, file::FilePrivateData, mount::MountFS}; @@ -36,6 +36,8 @@ pub enum FileType { Pipe, /// 符号链接 SymLink, + /// 套接字 + Socket, } /* these are defined by 
POSIX and also present in glibc's dirent.h */ @@ -68,20 +70,18 @@ impl FileType { FileType::CharDevice => DT_CHR, FileType::Pipe => DT_FIFO, FileType::SymLink => DT_LNK, + FileType::Socket => DT_SOCK, }; } } -/// @brief inode的状态(由poll方法返回) -#[derive(Debug, Default, PartialEq)] -pub struct PollStatus { - pub flags: u8, -} - -impl PollStatus { - pub const WRITE_MASK: u8 = (1u8 << 0); - pub const READ_MASK: u8 = (1u8 << 1); - pub const ERR_MASK: u8 = (1u8 << 2); +bitflags! { + /// @brief inode的状态(由poll方法返回) + pub struct PollStatus: u8 { + const WRITE = 1u8 << 0; + const READ = 1u8 << 1; + const ERROR = 1u8 << 2; + } } pub trait IndexNode: Any + Sync + Send + Debug { @@ -336,6 +336,12 @@ pub trait IndexNode: Any + Sync + Send + Debug { } } +impl DowncastArc for dyn IndexNode { + fn as_any_arc(self: Arc) -> Arc { + self + } +} + impl dyn IndexNode { /// @brief 将当前Inode转换为一个具体的结构体(类型由T指定) /// 如果类型正确,则返回Some,否则返回None @@ -482,6 +488,27 @@ pub struct Metadata { pub raw_dev: usize, } +impl Default for Metadata { + fn default() -> Self { + return Self { + dev_id: 0, + inode_id: 0, + size: 0, + blk_size: 0, + blocks: 0, + atime: TimeSpec::default(), + mtime: TimeSpec::default(), + ctime: TimeSpec::default(), + file_type: FileType::File, + mode: 0, + nlinks: 1, + uid: 0, + gid: 0, + raw_dev: 0, + }; + } +} + /// @brief 所有文件系统都应该实现的trait pub trait FileSystem: Any + Sync + Send + Debug { /// @brief 获取当前文件系统的root inode的指针 diff --git a/kernel/src/filesystem/vfs/syscall.rs b/kernel/src/filesystem/vfs/syscall.rs index a9ef927e..b26a7d29 100644 --- a/kernel/src/filesystem/vfs/syscall.rs +++ b/kernel/src/filesystem/vfs/syscall.rs @@ -1,6 +1,6 @@ use core::ffi::{c_char, CStr}; -use alloc::{boxed::Box, string::ToString}; +use alloc::{boxed::Box, string::ToString, vec::Vec}; use crate::{ arch::asm::{current::current_pcb, ptrace::user_mode}, @@ -435,3 +435,106 @@ pub extern "C" fn sys_dup2(regs: &pt_regs) -> u64 { return r.unwrap_err().to_posix_errno() as u64; } } + +#[repr(C)] 
+#[derive(Debug, Clone, Copy)] +pub struct IoVec { + /// 缓冲区的起始地址 + pub iov_base: *mut u8, + /// 缓冲区的长度 + pub iov_len: usize, +} + +/// 用于存储多个来自用户空间的IoVec +/// +/// 由于目前内核中的文件系统还不支持分散读写,所以暂时只支持将用户空间的IoVec聚合成一个缓冲区,然后进行操作。 +/// TODO:支持分散读写 +#[derive(Debug)] +pub struct IoVecs(Vec<&'static mut [u8]>); + +impl IoVecs { + /// 从用户空间的IoVec中构造IoVecs + /// + /// @param iov 用户空间的IoVec + /// @param iovcnt 用户空间的IoVec的数量 + /// @param readv 是否为readv系统调用 + /// + /// @return 构造成功返回IoVecs,否则返回错误码 + pub unsafe fn from_user( + iov: *const IoVec, + iovcnt: usize, + _readv: bool, + ) -> Result { + // 检查iov指针所在空间是否合法 + if !verify_area( + iov as usize as u64, + (iovcnt * core::mem::size_of::()) as u64, + ) { + return Err(SystemError::EFAULT); + } + + // 将用户空间的IoVec转换为引用(注意:这里的引用是静态的,因为用户空间的IoVec不会被释放) + let iovs: &[IoVec] = core::slice::from_raw_parts(iov, iovcnt); + + let mut slices: Vec<&mut [u8]> = vec![]; + slices.reserve(iovs.len()); + + for iov in iovs.iter() { + if iov.iov_len == 0 { + continue; + } + + if !verify_area(iov.iov_base as usize as u64, iov.iov_len as u64) { + return Err(SystemError::EFAULT); + } + + slices.push(core::slice::from_raw_parts_mut(iov.iov_base, iov.iov_len)); + } + + return Ok(Self(slices)); + } + + /// @brief 将IoVecs中的数据聚合到一个缓冲区中 + /// + /// @return 返回聚合后的缓冲区 + pub fn gather(&self) -> Vec { + let mut buf = Vec::new(); + for slice in self.0.iter() { + buf.extend_from_slice(slice); + } + return buf; + } + + /// @brief 将给定的数据分散写入到IoVecs中 + pub fn scatter(&mut self, data: &[u8]) { + let mut data: &[u8] = data; + for slice in self.0.iter_mut() { + let len = core::cmp::min(slice.len(), data.len()); + if len == 0 { + continue; + } + + slice[..len].copy_from_slice(&data[..len]); + data = &data[len..]; + } + } + + /// @brief 创建与IoVecs等长的缓冲区 + /// + /// @param set_len 是否设置返回的Vec的len。 + /// 如果为true,则返回的Vec的len为所有IoVec的长度之和; + /// 否则返回的Vec的len为0,capacity为所有IoVec的长度之和. 
+ /// + /// @return 返回创建的缓冲区 + pub fn new_buf(&self, set_len: bool) -> Vec { + let total_len: usize = self.0.iter().map(|slice| slice.len()).sum(); + let mut buf: Vec = Vec::with_capacity(total_len); + + if set_len { + unsafe { + buf.set_len(total_len); + } + } + return buf; + } +} diff --git a/kernel/src/lib.rs b/kernel/src/lib.rs index 10b2745e..09098b4f 100644 --- a/kernel/src/lib.rs +++ b/kernel/src/lib.rs @@ -6,6 +6,7 @@ #![feature(panic_info_message)] #![feature(drain_filter)] // 允许Vec的drain_filter特性 #![feature(c_void_variant)] +#![feature(trait_upcasting)] #[allow(non_upper_case_globals)] #[allow(non_camel_case_types)] #[allow(non_snake_case)] @@ -99,6 +100,9 @@ pub fn panic(info: &PanicInfo) -> ! { #[no_mangle] pub extern "C" fn __rust_demo_func() -> i32 { printk_color!(GREEN, BLACK, "__rust_demo_func()\n"); - net_init().expect("Failed to init network"); + let r = net_init(); + if r.is_err() { + kwarn!("net_init() failed: {:?}", r.err().unwrap()); + } return 0; } diff --git a/kernel/src/libs/casting.rs b/kernel/src/libs/casting.rs index 20c0fc41..a80dc261 100644 --- a/kernel/src/libs/casting.rs +++ b/kernel/src/libs/casting.rs @@ -53,7 +53,7 @@ use alloc::sync::Arc; /// assert!(a_arc2.is_some()); /// } /// ``` -trait DowncastArc: Any + Send + Sync { +pub trait DowncastArc: Any + Send + Sync { /// 请在具体类型中实现这个函数,返回self fn as_any_arc(self: Arc) -> Arc; diff --git a/kernel/src/libs/spinlock.rs b/kernel/src/libs/spinlock.rs index 3893d9b6..a0e1a7a3 100644 --- a/kernel/src/libs/spinlock.rs +++ b/kernel/src/libs/spinlock.rs @@ -244,6 +244,10 @@ impl DerefMut for SpinLockGuard<'_, T> { /// @brief 为SpinLockGuard实现Drop方法,那么,一旦守卫的生命周期结束,就会自动释放自旋锁,避免了忘记放锁的情况 impl Drop for SpinLockGuard<'_, T> { fn drop(&mut self) { - self.lock.lock.unlock(); + if self.flag != 0 { + self.lock.lock.unlock_irqrestore(&self.flag); + } else { + self.lock.lock.unlock(); + } } } diff --git a/kernel/src/libs/wait_queue.rs b/kernel/src/libs/wait_queue.rs index 610732a7..0e1fd588 100644 --- 
a/kernel/src/libs/wait_queue.rs +++ b/kernel/src/libs/wait_queue.rs @@ -1,5 +1,5 @@ #![allow(dead_code)] -use alloc::collections::LinkedList; +use alloc::{collections::LinkedList, vec::Vec}; use crate::{ arch::{asm::current::current_pcb, sched::sched}, @@ -127,6 +127,33 @@ impl WaitQueue { } } + /// @brief 唤醒在队列中,符合条件的所有进程。 + /// + /// @param state 用于判断的state,如果队列中第一个进程的state与它进行and操作之后,结果不为0,则唤醒这个进程。 + pub fn wakeup_all(&self, state: u64) { + let mut guard: SpinLockGuard = self.0.lock_irqsave(); + // 如果队列为空,则返回 + if guard.wait_list.is_empty() { + return; + } + + let mut to_push_back: Vec<&mut process_control_block> = Vec::new(); + // 如果队列头部的pcb的state与给定的state相与,结果不为0,则唤醒 + while let Some(to_wakeup) = guard.wait_list.pop_front() { + if (to_wakeup.state & state) != 0 { + unsafe { + process_wakeup(to_wakeup); + } + } else { + to_push_back.push(to_wakeup); + } + } + + for to_wakeup in to_push_back { + guard.wait_list.push_back(to_wakeup); + } + } + /// @brief 获得当前等待队列的大小 pub fn len(&self) -> usize { return self.0.lock().wait_list.len(); diff --git a/kernel/src/mm/allocator.rs b/kernel/src/mm/allocator.rs index 4bee32b7..c1020d51 100644 --- a/kernel/src/mm/allocator.rs +++ b/kernel/src/mm/allocator.rs @@ -50,6 +50,6 @@ unsafe impl GlobalAlloc for KernelAllocator { /// 内存分配错误处理函数 #[alloc_error_handler] -pub fn global_alloc_err_handler(_layout: Layout) -> ! { - panic!("global_alloc_error"); +pub fn global_alloc_err_handler(layout: Layout) -> ! 
{ + panic!("global_alloc_error, layout: {:?}", layout); } diff --git a/kernel/src/net/mod.rs b/kernel/src/net/mod.rs index 1440f764..b8e71107 100644 --- a/kernel/src/net/mod.rs +++ b/kernel/src/net/mod.rs @@ -5,7 +5,7 @@ use core::{ use alloc::{boxed::Box, collections::BTreeMap, sync::Arc}; -use crate::{driver::net::NetDriver, libs::rwlock::RwLock, syscall::SystemError}; +use crate::{driver::net::NetDriver, kwarn, libs::rwlock::RwLock, syscall::SystemError}; use smoltcp::wire::IpEndpoint; use self::socket::SocketMetadata; @@ -13,6 +13,7 @@ use self::socket::SocketMetadata; pub mod endpoints; pub mod net_core; pub mod socket; +pub mod syscall; lazy_static! { /// @brief 所有网络接口的列表 @@ -28,11 +29,28 @@ pub fn generate_iface_id() -> usize { } /// @brief 用于指定socket的关闭类型 -#[derive(Debug, Clone)] +/// 参考:https://pubs.opengroup.org/onlinepubs/9699919799/functions/shutdown.html +#[derive(Debug, Clone, Copy, PartialEq, Eq, FromPrimitive, ToPrimitive)] pub enum ShutdownType { - ShutRd, // Disables further receive operations. - ShutWr, // Disables further send operations. - ShutRdwr, // Disables further send and receive operations. + ShutRd = 0, // Disables further receive operations. + ShutWr = 1, // Disables further send operations. + ShutRdwr = 2, // Disables further send and receive operations. +} + +impl TryFrom for ShutdownType { + type Error = SystemError; + + fn try_from(value: i32) -> Result { + use num_traits::FromPrimitive; + ::from_i32(value).ok_or(SystemError::EINVAL) + } +} + +impl Into for ShutdownType { + fn into(self) -> i32 { + use num_traits::ToPrimitive; + ::to_i32(&self).unwrap() + } } #[derive(Debug, Clone)] @@ -40,7 +58,7 @@ pub enum Endpoint { /// 链路层端点 LinkLayer(endpoints::LinkLayerEndpoint), /// 网络层端点 - Ip(IpEndpoint), + Ip(Option), // todo: 增加NetLink机制后,增加NetLink端点 } @@ -51,7 +69,7 @@ pub trait Socket: Sync + Send + Debug { /// /// @return - 成功:(返回读取的数据的长度,读取数据的端点). 
/// - 失败:错误码 - fn read(&self, buf: &mut [u8]) -> Result<(usize, Endpoint), SystemError>; + fn read(&self, buf: &mut [u8]) -> (Result, Endpoint); /// @brief 向socket中写入数据。如果socket是阻塞的,那么直到写入的数据全部写入socket中才返回 /// @@ -80,7 +98,7 @@ pub trait Socket: Sync + Send + Debug { /// @param endpoint 要绑定的端点 /// /// @return 返回绑定是否成功 - fn bind(&self, _endpoint: Endpoint) -> Result<(), SystemError> { + fn bind(&mut self, _endpoint: Endpoint) -> Result<(), SystemError> { return Err(SystemError::ENOSYS); } @@ -106,7 +124,7 @@ pub trait Socket: Sync + Send + Debug { /// @brief 对应于POSIX的accept函数,用于接受连接 /// - /// @param endpoint 用于返回连接的端点 + /// @param endpoint 对端的端点 /// /// @return 返回接受连接是否成功 fn accept(&mut self) -> Result<(Box, Endpoint), SystemError> { @@ -164,6 +182,23 @@ pub trait Socket: Sync + Send + Debug { fn metadata(&self) -> Result; fn box_clone(&self) -> Box; + + /// @brief 设置socket的选项 + /// + /// @param level 选项的层次 + /// @param optname 选项的名称 + /// @param optval 选项的值 + /// + /// @return 返回设置是否成功, 如果不支持该选项,返回ENOSYS + fn setsockopt( + &self, + _level: usize, + _optname: usize, + _optval: &[u8], + ) -> Result<(), SystemError> { + kwarn!("setsockopt is not implemented"); + return Err(SystemError::ENOSYS); + } } impl Clone for Box { diff --git a/kernel/src/net/net_core.rs b/kernel/src/net/net_core.rs index 7b203f93..7d2105f7 100644 --- a/kernel/src/net/net_core.rs +++ b/kernel/src/net/net_core.rs @@ -1,15 +1,42 @@ +use alloc::{boxed::Box, collections::BTreeMap, sync::Arc}; use smoltcp::{socket::dhcpv4, wire}; -use crate::{kdebug, kinfo, net::NET_DRIVERS, syscall::SystemError}; +use crate::{ + driver::net::NetDriver, + kdebug, kinfo, kwarn, + libs::rwlock::RwLockReadGuard, + net::NET_DRIVERS, + syscall::SystemError, + time::timer::{next_n_ms_timer_jiffies, Timer, TimerFunction}, +}; + +use super::socket::{SOCKET_SET, SOCKET_WAITQUEUE}; + +/// The network poll function, which will be called by timer. +/// +/// The main purpose of this function is to poll all network interfaces. 
+struct NetWorkPollFunc(); +impl TimerFunction for NetWorkPollFunc { + fn run(&mut self) { + poll_ifaces_try_lock(10).ok(); + let next_time = next_n_ms_timer_jiffies(10); + let timer = Timer::new(Box::new(NetWorkPollFunc()), next_time); + timer.activate(); + } +} pub fn net_init() -> Result<(), SystemError> { dhcp_query()?; + // Init poll timer function + let next_time = next_n_ms_timer_jiffies(5); + let timer = Timer::new(Box::new(NetWorkPollFunc()), next_time); + timer.activate(); return Ok(()); } fn dhcp_query() -> Result<(), SystemError> { let binding = NET_DRIVERS.write(); - let net_face = binding.get(&0).unwrap().clone(); + let net_face = binding.get(&0).ok_or(SystemError::ENODEV)?.clone(); drop(binding); @@ -85,3 +112,49 @@ fn dhcp_query() -> Result<(), SystemError> { return Err(SystemError::ETIMEDOUT); } + +pub fn poll_ifaces() { + let guard: RwLockReadGuard>> = NET_DRIVERS.read(); + if guard.len() == 0 { + kwarn!("poll_ifaces: No net driver found!"); + return; + } + let mut sockets = SOCKET_SET.lock(); + for (_, iface) in guard.iter() { + iface.poll(&mut sockets).ok(); + } + SOCKET_WAITQUEUE.wakeup_all((-1i64) as u64); +} + +/// 对ifaces进行轮询,最多对SOCKET_SET尝试times次加锁。 +/// +/// @return 轮询成功,返回Ok(()) +/// @return 加锁超时,返回SystemError::EAGAIN_OR_EWOULDBLOCK +/// @return 没有网卡,返回SystemError::ENODEV +pub fn poll_ifaces_try_lock(times: u16) -> Result<(), SystemError> { + let mut i = 0; + while i < times { + let guard: RwLockReadGuard>> = NET_DRIVERS.read(); + if guard.len() == 0 { + kwarn!("poll_ifaces: No net driver found!"); + // 没有网卡,返回错误 + return Err(SystemError::ENODEV); + } + let sockets = SOCKET_SET.try_lock(); + // 加锁失败,继续尝试 + if sockets.is_err() { + i += 1; + continue; + } + + let mut sockets = sockets.unwrap(); + for (_, iface) in guard.iter() { + iface.poll(&mut sockets).ok(); + } + SOCKET_WAITQUEUE.wakeup_all((-1i64) as u64); + return Ok(()); + } + + // 尝试次数用完,返回错误 + return Err(SystemError::EAGAIN_OR_EWOULDBLOCK); +} diff --git a/kernel/src/net/socket.rs 
b/kernel/src/net/socket.rs index b5b5fb8d..daee14da 100644 --- a/kernel/src/net/socket.rs +++ b/kernel/src/net/socket.rs @@ -1,23 +1,36 @@ #![allow(dead_code)] -use alloc::{boxed::Box, vec::Vec}; +use alloc::{boxed::Box, sync::Arc, vec::Vec}; use smoltcp::{ iface::{SocketHandle, SocketSet}, socket::{raw, tcp, udp}, - wire::{IpAddress, IpEndpoint, IpProtocol, Ipv4Address, Ipv4Packet, Ipv6Address}, + wire, }; use crate::{ - arch::rand::rand, kdebug, kerror, kwarn, libs::spinlock::SpinLock, syscall::SystemError, + arch::rand::rand, + driver::net::NetDriver, + filesystem::vfs::{FileType, IndexNode, Metadata, PollStatus}, + kerror, kwarn, + libs::{ + spinlock::{SpinLock, SpinLockGuard}, + wait_queue::WaitQueue, + }, + syscall::SystemError, }; -use super::{Endpoint, Protocol, Socket, NET_DRIVERS}; +use super::{net_core::poll_ifaces, Endpoint, Protocol, Socket, NET_DRIVERS}; lazy_static! { /// 所有socket的集合 /// TODO: 优化这里,自己实现SocketSet!!!现在这样的话,不管全局有多少个网卡,每个时间点都只会有1个进程能够访问socket pub static ref SOCKET_SET: SpinLock> = SpinLock::new(SocketSet::new(vec![])); + pub static ref SOCKET_WAITQUEUE: WaitQueue = WaitQueue::INIT; } +/* For setsockopt(2) */ +// See: linux-5.19.10/include/uapi/asm-generic/socket.h#9 +pub const SOL_SOCKET: u8 = 1; + /// @brief socket的句柄管理组件。 /// 它在smoltcp的SocketHandle上封装了一层,增加更多的功能。 /// 比如,在socket被关闭时,自动释放socket的资源,通知系统的其他组件。 @@ -41,6 +54,7 @@ impl Drop for GlobalSocketHandle { let mut socket_set_guard = SOCKET_SET.lock(); socket_set_guard.remove(self.0); // 删除的时候,会发送一条FINISH的信息? drop(socket_set_guard); + poll_ifaces(); } } @@ -76,15 +90,15 @@ bitflags! 
{ /// @brief 在trait Socket的metadata函数中返回该结构体供外部使用 pub struct SocketMetadata { /// socket的类型 - socket_type: SocketType, + pub socket_type: SocketType, /// 发送缓冲区的大小 - send_buf_size: usize, + pub send_buf_size: usize, /// 接收缓冲区的大小 - recv_buf_size: usize, + pub recv_buf_size: usize, /// 元数据的缓冲区的大小 - metadata_buf_size: usize, + pub metadata_buf_size: usize, /// socket的选项 - options: SocketOptions, + pub options: SocketOptions, } /// @brief 表示原始的socket。原始套接字绕过传输层协议(如 TCP 或 UDP)并提供对网络层协议(如 IP)的直接访问。 @@ -127,7 +141,7 @@ impl RawSocket { let protocol: u8 = protocol.into(); let socket = raw::Socket::new( smoltcp::wire::IpVersion::Ipv4, - IpProtocol::from(protocol), + wire::IpProtocol::from(protocol), tx_buffer, rx_buffer, ); @@ -144,7 +158,8 @@ impl RawSocket { } impl Socket for RawSocket { - fn read(&self, buf: &mut [u8]) -> Result<(usize, Endpoint), SystemError> { + fn read(&self, buf: &mut [u8]) -> (Result, Endpoint) { + poll_ifaces(); loop { // 如何优化这里? let mut socket_set_guard = SOCKET_SET.lock(); @@ -152,24 +167,25 @@ impl Socket for RawSocket { match socket.recv_slice(buf) { Ok(len) => { - let packet = Ipv4Packet::new_unchecked(buf); - return Ok(( - len, - Endpoint::Ip(smoltcp::wire::IpEndpoint { - addr: IpAddress::Ipv4(packet.src_addr()), + let packet = wire::Ipv4Packet::new_unchecked(buf); + return ( + Ok(len), + Endpoint::Ip(Some(smoltcp::wire::IpEndpoint { + addr: wire::IpAddress::Ipv4(packet.src_addr()), port: 0, - }), - )); + })), + ); } Err(smoltcp::socket::raw::RecvError::Exhausted) => { if !self.options.contains(SocketOptions::BLOCK) { // 如果是非阻塞的socket,就返回错误 - return Err(SystemError::EAGAIN_OR_EWOULDBLOCK); + return (Err(SystemError::EAGAIN_OR_EWOULDBLOCK), Endpoint::Ip(None)); } } } drop(socket); drop(socket_set_guard); + SOCKET_WAITQUEUE.sleep(); } } @@ -189,7 +205,7 @@ impl Socket for RawSocket { } else { // 如果用户发送的数据包,不包含IP头,则需要自己构造IP头 - if let Some(Endpoint::Ip(endpoint)) = to { + if let Some(Endpoint::Ip(Some(endpoint))) = to { let mut socket_set_guard = 
SOCKET_SET.lock(); let socket: &mut raw::Socket = socket_set_guard.get_mut::(self.handle.0); @@ -205,13 +221,13 @@ impl Socket for RawSocket { } let ipv4_src_addr = ipv4_src_addr.unwrap(); - if let IpAddress::Ipv4(ipv4_dst) = endpoint.addr { + if let wire::IpAddress::Ipv4(ipv4_dst) = endpoint.addr { let len = buf.len(); // 创建20字节的IPv4头部 let mut buffer: Vec = vec![0u8; len + 20]; - let mut packet: Ipv4Packet<&mut Vec> = - Ipv4Packet::new_unchecked(&mut buffer); + let mut packet: wire::Ipv4Packet<&mut Vec> = + wire::Ipv4Packet::new_unchecked(&mut buffer); // 封装ipv4 header packet.set_version(4); @@ -234,9 +250,10 @@ impl Socket for RawSocket { socket.send_slice(&buffer).unwrap(); drop(socket); - drop(socket_set_guard); - // poll? + iface.poll(&mut socket_set_guard).ok(); + + drop(socket_set_guard); return Ok(len); } else { kwarn!("Unsupport Ip protocol type!"); @@ -306,113 +323,143 @@ impl UdpSocket { options, }; } + + fn do_bind(&self, socket: &mut udp::Socket, endpoint: Endpoint) -> Result<(), SystemError> { + if let Endpoint::Ip(Some(ip)) = endpoint { + let bind_res = if ip.addr.is_unspecified() { + socket.bind(ip.port) + } else { + socket.bind(ip) + }; + + match bind_res { + Ok(()) => return Ok(()), + Err(_) => return Err(SystemError::EINVAL), + } + } else { + return Err(SystemError::EINVAL); + }; + } } impl Socket for UdpSocket { /// @brief 在read函数执行之前,请先bind到本地的指定端口 - fn read(&self, buf: &mut [u8]) -> Result<(usize, Endpoint), SystemError> { + fn read(&self, buf: &mut [u8]) -> (Result, Endpoint) { loop { - kdebug!("Wait22 to Read"); - + // kdebug!("Wait22 to Read"); + poll_ifaces(); let mut socket_set_guard = SOCKET_SET.lock(); let socket = socket_set_guard.get_mut::(self.handle.0); - kdebug!("Wait to Read"); + // kdebug!("Wait to Read"); if socket.can_recv() { - kdebug!("Can Receive"); - if let Ok((size, endpoint)) = socket.recv_slice(buf) { + if let Ok((size, remote_endpoint)) = socket.recv_slice(buf) { drop(socket); drop(socket_set_guard); - - return 
Ok((size, Endpoint::Ip(endpoint))); + poll_ifaces(); + return (Ok(size), Endpoint::Ip(Some(remote_endpoint))); } } else { - // 没有数据可以读取. 如果没有bind到指定端口,也会导致rx_buf为空 - return Err(SystemError::EAGAIN_OR_EWOULDBLOCK); + // 如果socket没有连接,则忙等 + // return (Err(SystemError::ENOTCONN), Endpoint::Ip(None)); } + drop(socket); + drop(socket_set_guard); + SOCKET_WAITQUEUE.sleep(); } } fn write(&self, buf: &[u8], to: Option) -> Result { - let endpoint: &IpEndpoint = { - if let Some(Endpoint::Ip(ref endpoint)) = to { + // kdebug!("udp to send: {:?}, len={}", to, buf.len()); + let remote_endpoint: &wire::IpEndpoint = { + if let Some(Endpoint::Ip(Some(ref endpoint))) = to { endpoint - } else if let Some(Endpoint::Ip(ref endpoint)) = self.remote_endpoint { + } else if let Some(Endpoint::Ip(Some(ref endpoint))) = self.remote_endpoint { endpoint } else { return Err(SystemError::ENOTCONN); } }; + // kdebug!("udp write: remote = {:?}", remote_endpoint); let mut socket_set_guard = SOCKET_SET.lock(); let socket = socket_set_guard.get_mut::(self.handle.0); - + // kdebug!("is open()={}", socket.is_open()); + // kdebug!("socket endpoint={:?}", socket.endpoint()); if socket.endpoint().port == 0 { let temp_port = get_ephemeral_port(); - match endpoint.addr { + let local_ep = match remote_endpoint.addr { // 远程remote endpoint使用什么协议,发送的时候使用的协议是一样的吧 // 否则就用 self.endpoint().addr.unwrap() - IpAddress::Ipv4(_) => { - socket - .bind(IpEndpoint::new( - smoltcp::wire::IpAddress::Ipv4(Ipv4Address::UNSPECIFIED), - temp_port, - )) - .unwrap(); - } - IpAddress::Ipv6(_) => { - socket - .bind(IpEndpoint::new( - smoltcp::wire::IpAddress::Ipv6(Ipv6Address::UNSPECIFIED), - temp_port, - )) - .unwrap(); - } - } + wire::IpAddress::Ipv4(_) => Endpoint::Ip(Some(wire::IpEndpoint::new( + smoltcp::wire::IpAddress::Ipv4(wire::Ipv4Address::UNSPECIFIED), + temp_port, + ))), + wire::IpAddress::Ipv6(_) => Endpoint::Ip(Some(wire::IpEndpoint::new( + smoltcp::wire::IpAddress::Ipv6(wire::Ipv6Address::UNSPECIFIED), + temp_port, + 
))), + }; + // kdebug!("udp write: local_ep = {:?}", local_ep); + self.do_bind(socket, local_ep)?; } - - return if socket.can_send() { - match socket.send_slice(&buf, *endpoint) { + // kdebug!("is open()={}", socket.is_open()); + if socket.can_send() { + // kdebug!("udp write: can send"); + match socket.send_slice(&buf, *remote_endpoint) { Ok(()) => { - // avoid deadlock + // kdebug!("udp write: send ok"); drop(socket); drop(socket_set_guard); - - Ok(buf.len()) + poll_ifaces(); + return Ok(buf.len()); + } + Err(_) => { + // kdebug!("udp write: send err"); + return Err(SystemError::ENOBUFS); } - Err(_) => Err(SystemError::ENOBUFS), } } else { - Err(SystemError::ENOBUFS) + // kdebug!("udp write: can not send"); + return Err(SystemError::ENOBUFS); }; } - fn bind(&self, endpoint: Endpoint) -> Result<(), SystemError> { + fn bind(&mut self, endpoint: Endpoint) -> Result<(), SystemError> { let mut sockets = SOCKET_SET.lock(); let socket = sockets.get_mut::(self.handle.0); + // kdebug!("UDP Bind to {:?}", endpoint); + return self.do_bind(socket, endpoint); + } - return if let Endpoint::Ip(ip) = endpoint { - match socket.bind(ip) { - Ok(()) => Ok(()), - Err(_) => Err(SystemError::EINVAL), - } - } else { - Err(SystemError::EINVAL) - }; + fn poll(&self) -> (bool, bool, bool) { + let sockets = SOCKET_SET.lock(); + let socket = sockets.get::(self.handle.0); + + return (socket.can_send(), socket.can_recv(), false); } /// @brief fn connect(&mut self, endpoint: super::Endpoint) -> Result<(), SystemError> { - return if let Endpoint::Ip(_) = endpoint { + if let Endpoint::Ip(_) = endpoint { self.remote_endpoint = Some(endpoint); - Ok(()) + return Ok(()); } else { - Err(SystemError::EINVAL) + return Err(SystemError::EINVAL); }; } + fn ioctl( + &self, + _cmd: usize, + _arg0: usize, + _arg1: usize, + _arg2: usize, + ) -> Result { + todo!() + } fn metadata(&self) -> Result { todo!() } @@ -420,6 +467,31 @@ impl Socket for UdpSocket { fn box_clone(&self) -> alloc::boxed::Box { return 
Box::new(self.clone());
     }
+
+    fn endpoint(&self) -> Option {
+        let sockets = SOCKET_SET.lock();
+        let socket = sockets.get::(self.handle.0);
+        let listen_endpoint = socket.endpoint();
+
+        if listen_endpoint.port == 0 {
+            return None;
+        } else {
+            // If listen_endpoint's address is None, it means "listen on all addresses".
+            // All addresses are assumed to be IPv4 here.
+            // TODO: support IPv6
+            let result = wire::IpEndpoint::new(
+                listen_endpoint
+                    .addr
+                    .unwrap_or(wire::IpAddress::v4(0, 0, 0, 0)),
+                listen_endpoint.port,
+            );
+            return Some(Endpoint::Ip(Some(result)));
+        }
+    }
+
+    fn peer_endpoint(&self) -> Option {
+        return self.remote_endpoint.clone();
+    }
 }

 /// @brief 表示 tcp socket
@@ -428,7 +500,7 @@ impl Socket for UdpSocket {
 #[derive(Debug, Clone)]
 pub struct TcpSocket {
     handle: GlobalSocketHandle,
-    local_endpoint: Option, // save local endpoint for bind()
+    local_endpoint: Option, // save local endpoint for bind()
     is_listening: bool,
     options: SocketOptions,
 }
@@ -437,9 +509,9 @@ impl TcpSocket {
     /// 元数据的缓冲区的大小
     pub const DEFAULT_METADATA_BUF_SIZE: usize = 1024;
     /// 默认的发送缓冲区的大小 transmiss
-    pub const DEFAULT_RX_BUF_SIZE: usize = 64 * 1024;
+    pub const DEFAULT_RX_BUF_SIZE: usize = 512 * 1024;
     /// 默认的接收缓冲区的大小 receive
-    pub const DEFAULT_TX_BUF_SIZE: usize = 64 * 1024;
+    pub const DEFAULT_TX_BUF_SIZE: usize = 512 * 1024;

     /// @brief 创建一个原始的socket
     ///
@@ -462,33 +534,83 @@ impl TcpSocket {
             options,
         };
     }
+    fn do_listen(
+        &mut self,
+        socket: &mut smoltcp::socket::tcp::Socket,
+        local_endpoint: smoltcp::wire::IpEndpoint,
+    ) -> Result<(), SystemError> {
+        let listen_result = if local_endpoint.addr.is_unspecified() {
+            // kdebug!("Tcp Socket Listen on port {}", local_endpoint.port);
+            socket.listen(local_endpoint.port)
+        } else {
+            // kdebug!("Tcp Socket Listen on {local_endpoint}");
+            socket.listen(local_endpoint)
+        };
+        // TODO: add a check for ports that are already in use
+        return match listen_result {
+            Ok(()) => {
+                // kdebug!(
+                //     "Tcp Socket Listen on {local_endpoint}, open?:{}",
+                //     socket.is_open()
+                // );
+                self.is_listening = true;
+
+                Ok(())
+            }
+            Err(_) => Err(SystemError::EINVAL),
+        };
+    }
 }

 impl Socket for TcpSocket {
-    /// @breif
-    fn read(&self, buf: &mut [u8]) -> Result<(usize, Endpoint), SystemError> {
+    fn read(&self, buf: &mut [u8]) -> (Result, Endpoint) {
+        // kdebug!("tcp socket: read, buf len={}", buf.len());
+
         loop {
+            poll_ifaces();
             let mut socket_set_guard = SOCKET_SET.lock();
             let socket = socket_set_guard.get_mut::(self.handle.0);

+            // If the socket has already been closed, return an error
+            if !socket.is_active() {
+                // kdebug!("Tcp Socket Read Error, socket is closed");
+                return (Err(SystemError::ENOTCONN), Endpoint::Ip(None));
+            }
+
             if socket.may_recv() {
-                if let Ok(size) = socket.recv_slice(buf) {
+                let recv_res = socket.recv_slice(buf);
+
+                if let Ok(size) = recv_res {
                     if size > 0 {
                         let endpoint = if let Some(p) = socket.remote_endpoint() {
                             p
                         } else {
-                            return Err(SystemError::ENOTCONN);
+                            return (Err(SystemError::ENOTCONN), Endpoint::Ip(None));
                         };
                         drop(socket);
                         drop(socket_set_guard);
-
-                        return Ok((size, Endpoint::Ip(endpoint)));
+                        poll_ifaces();
+                        return (Ok(size), Endpoint::Ip(Some(endpoint)));
+                    }
+                } else {
+                    let err = recv_res.unwrap_err();
+                    match err {
+                        tcp::RecvError::InvalidState => {
+                            kwarn!("Tcp Socket Read Error, InvalidState");
+                            return (Err(SystemError::ENOTCONN), Endpoint::Ip(None));
+                        }
+                        tcp::RecvError::Finished => {
+                            return (Err(SystemError::ENOTCONN), Endpoint::Ip(None));
+                        }
                     }
                 }
             } else {
-                return Err(SystemError::EAGAIN_OR_EWOULDBLOCK);
+                return (Err(SystemError::ENOTCONN), Endpoint::Ip(None));
             }
+            drop(socket);
+            drop(socket_set_guard);
+            SOCKET_WAITQUEUE.sleep();
         }
     }

@@ -502,7 +624,7 @@ impl Socket for TcpSocket {
                     Ok(size) => {
                         drop(socket);
                         drop(socket_set_guard);
-
+                        poll_ifaces();
                         return Ok(size);
                     }
                     Err(e) => {
@@ -518,54 +640,83 @@ impl Socket for TcpSocket {
         return Err(SystemError::ENOTCONN);
     }

-    fn connect(&mut self, _endpoint: super::Endpoint) -> Result<(), SystemError> {
-        // let mut sockets = SOCKET_SET.lock();
-        // let mut socket = sockets.get::(self.handle.0);
+    fn poll(&self) -> (bool, bool, bool) {
+        let mut socket_set_guard = SOCKET_SET.lock();
+        let socket = socket_set_guard.get_mut::(self.handle.0);

-        // if let Endpoint::Ip(ip) = endpoint {
-        //     let temp_port = if ip.port == 0 {
-        //         get_ephemeral_port()
-        //     } else {
-        //         ip.port
-        //     };
+        let mut input = false;
+        let mut output = false;
+        let mut error = false;
+        if self.is_listening && socket.is_active() {
+            input = true;
+        } else if !socket.is_open() {
+            error = true;
+        } else {
+            if socket.may_recv() {
+                input = true;
+            }
+            if socket.can_send() {
+                output = true;
+            }
+        }

-        //     return match socket.connect(iface.context(), temp_port) {
-        //         Ok(()) => {
-        //             // avoid deadlock
-        //             drop(socket);
-        //             drop(sockets);
+        return (input, output, error);
+    }

-        //             // wait for connection result
-        //             loop {
-        //                 // poll_ifaces();
+    fn connect(&mut self, endpoint: Endpoint) -> Result<(), SystemError> {
+        let mut sockets = SOCKET_SET.lock();
+        let socket = sockets.get_mut::(self.handle.0);

-        //                 let mut sockets = SOCKET_SET.lock();
-        //                 let socket = sockets.get::(self.handle.0);
-        //                 match socket.state() {
-        //                     State::SynSent => {
-        //                         // still connecting
-        //                         drop(socket);
-        //                         kdebug!("poll for connection wait");
-        //                         // SOCKET_ACTIVITY.wait(sockets);
-        //                     }
-        //                     State::Established => {
-        //                         break Ok(());
-        //                     }
-        //                     _ => {
-        //                         break Err(SystemError::ECONNREFUSED);
-        //                     }
-        //                 }
-        //             }
-        //         }
-        //         Err(_) => Err(SystemError::ENOBUFS),
-        //     };
-        // } else {
-        //     return Err(SystemError::EINVAL);
-        // }
-        return Err(SystemError::EINVAL);
+        if let Endpoint::Ip(Some(ip)) = endpoint {
+            let temp_port = get_ephemeral_port();
+            // kdebug!("temp_port: {}", temp_port);
+            let iface: Arc = NET_DRIVERS.write().get(&0).unwrap().clone();
+            let mut inner_iface = iface.inner_iface().lock();
+            // kdebug!("to connect: {ip:?}");
+
+            match socket.connect(&mut inner_iface.context(), ip, temp_port) {
+                Ok(()) => {
+                    // avoid deadlock
+                    drop(inner_iface);
+                    drop(iface);
+                    drop(socket);
+                    drop(sockets);
+                    loop {
+                        poll_ifaces();
+                        let mut sockets = SOCKET_SET.lock();
+                        let socket = sockets.get_mut::(self.handle.0);
+
+                        match socket.state() {
+                            tcp::State::Established => {
+                                return Ok(());
+                            }
+                            tcp::State::SynSent => {
+                                drop(socket);
+                                drop(sockets);
+                                SOCKET_WAITQUEUE.sleep();
+                            }
+                            _ => {
+                                return Err(SystemError::ECONNREFUSED);
+                            }
+                        }
+                    }
+                }
+                Err(e) => {
+                    // kerror!("Tcp Socket Connect Error {e:?}");
+                    match e {
+                        tcp::ConnectError::InvalidState => return Err(SystemError::EISCONN),
+                        tcp::ConnectError::Unaddressable => return Err(SystemError::EADDRNOTAVAIL),
+                    }
+                }
+            }
+        } else {
+            return Err(SystemError::EINVAL);
+        }
     }

     /// @brief tcp socket 监听 local_endpoint 端口
+    ///
+    /// @param backlog Maximum length of the queue of pending connections. smoltcp does not support a backlog, so this parameter currently has no effect
     fn listen(&mut self, _backlog: usize) -> Result<(), SystemError> {
         if self.is_listening {
             return Ok(());
@@ -576,16 +727,101 @@ impl Socket for TcpSocket {
         let socket = sockets.get_mut::(self.handle.0);

         if socket.is_listening() {
+            // kdebug!("Tcp Socket is already listening on {local_endpoint}");
             return Ok(());
         }
+        // kdebug!("Tcp Socket before listen, open={}", socket.is_open());
+        return self.do_listen(socket, local_endpoint);
+    }

-        return match socket.listen(local_endpoint) {
-            Ok(()) => {
-                self.is_listening = true;
-                Ok(())
+    fn bind(&mut self, endpoint: Endpoint) -> Result<(), SystemError> {
+        if let Endpoint::Ip(Some(mut ip)) = endpoint {
+            if ip.port == 0 {
+                ip.port = get_ephemeral_port();
             }
-            Err(_) => Err(SystemError::EINVAL),
-        };
+
+            self.local_endpoint = Some(ip);
+            self.is_listening = false;
+            return Ok(());
+        }
+        return Err(SystemError::EINVAL);
+    }
+
+    fn shutdown(&self, _type: super::ShutdownType) -> Result<(), SystemError> {
+        let mut sockets = SOCKET_SET.lock();
+        let socket = sockets.get_mut::(self.handle.0);
+        socket.close();
+        return Ok(());
+    }
+
+    fn accept(&mut self) -> Result<(Box, Endpoint), SystemError> {
+        let endpoint = self.local_endpoint.ok_or(SystemError::EINVAL)?;
+        loop {
+            // kdebug!("tcp accept: poll_ifaces()");
+            poll_ifaces();
+
+            let mut sockets = SOCKET_SET.lock();
+
+            let socket = sockets.get_mut::(self.handle.0);
+
+            if socket.is_active() {
+                // kdebug!("tcp accept: socket.is_active()");
+                let remote_ep = socket.remote_endpoint().ok_or(SystemError::ENOTCONN)?;
+                drop(socket);
+
+                let new_socket = {
+                    // Initialize the TCP socket's buffers.
+                    let rx_buffer = tcp::SocketBuffer::new(vec![0; Self::DEFAULT_RX_BUF_SIZE]);
+                    let tx_buffer = tcp::SocketBuffer::new(vec![0; Self::DEFAULT_TX_BUF_SIZE]);
+                    // The new TCP socket used for sending and receiving data.
+                    let mut tcp_socket = tcp::Socket::new(rx_buffer, tx_buffer);
+                    self.do_listen(&mut tcp_socket, endpoint)
+                        .expect("do_listen failed");
+
+                    // tcp_socket.listen(endpoint).unwrap();
+
+                    // old_handle is stored in new_socket because, at this point, smoltcp has already associated the socket behind old_handle with the remote endpoint,
+                    // so a new handle has to be allocated for the current socket
+                    let new_handle = GlobalSocketHandle::new(sockets.add(tcp_socket));
+                    let old_handle = ::core::mem::replace(&mut self.handle, new_handle);
+
+                    Box::new(TcpSocket {
+                        handle: old_handle,
+                        local_endpoint: self.local_endpoint,
+                        is_listening: false,
+                        options: self.options,
+                    })
+                };
+                // kdebug!("tcp accept: new socket: {:?}", new_socket);
+                drop(sockets);
+                poll_ifaces();
+
+                return Ok((new_socket, Endpoint::Ip(Some(remote_ep))));
+            }
+            drop(socket);
+            drop(sockets);
+            SOCKET_WAITQUEUE.sleep();
+        }
+    }
+
+    fn endpoint(&self) -> Option {
+        let mut result: Option =
+            self.local_endpoint.clone().map(|x| Endpoint::Ip(Some(x)));
+
+        if result.is_none() {
+            let sockets = SOCKET_SET.lock();
+            let socket = sockets.get::(self.handle.0);
+            if let Some(ep) = socket.local_endpoint() {
+                result = Some(Endpoint::Ip(Some(ep)));
+            }
+        }
+        return result;
+    }
+
+    fn peer_endpoint(&self) -> Option {
+        let sockets = SOCKET_SET.lock();
+        let socket = sockets.get::(self.handle.0);
+        return socket.remote_endpoint().map(|x| Endpoint::Ip(Some(x)));
     }

     fn metadata(&self) -> Result {
@@ -598,8 +834,11 @@ impl Socket for TcpSocket {
 }

 /// @breif 自动分配一个未被使用的PORT
+///
+/// TODO: add a ListenTable to check whether a port is already in use
 pub fn get_ephemeral_port() -> u16 {
     // TODO selects non-conflict high port
+
     static mut EPHEMERAL_PORT: u16 = 0;
     unsafe {
         if EPHEMERAL_PORT == 0 {
@@ -613,3 +852,223 @@ pub fn get_ephemeral_port() -> u16 {
         EPHEMERAL_PORT
     }
 }
+
+/// @brief Enumeration of address families
+///
+/// Reference: https://opengrok.ringotek.cn/xref/linux-5.19.10/include/linux/socket.h#180
+#[derive(Debug, Clone, Copy, PartialEq, Eq, FromPrimitive, ToPrimitive)]
+pub enum AddressFamily {
+    /// AF_UNSPEC: address family not specified
+    Unspecified = 0,
+    /// AF_UNIX: Unix domain sockets (same as AF_LOCAL)
+    Unix = 1,
+    /// AF_INET: IPv4 sockets
+    INet = 2,
+    /// AF_AX25: AMPR AX.25 sockets
+    AX25 = 3,
+    /// AF_IPX: IPX sockets
+    IPX = 4,
+    /// AF_APPLETALK: Appletalk sockets
+    Appletalk = 5,
+    /// AF_NETROM: AMPR NET/ROM sockets
+    Netrom = 6,
+    /// AF_BRIDGE: multiprotocol bridge sockets
+    Bridge = 7,
+    /// AF_ATMPVC: ATM PVCs sockets
+    Atmpvc = 8,
+    /// AF_X25: X.25 sockets
+    X25 = 9,
+    /// AF_INET6: IPv6 sockets
+    INet6 = 10,
+    /// AF_ROSE: AMPR ROSE sockets
+    Rose = 11,
+    /// AF_DECnet Reserved for DECnet project
+    Decnet = 12,
+    /// AF_NETBEUI Reserved for 802.2LLC project
+    Netbeui = 13,
+    /// AF_SECURITY: pseudo address family for the Security callback
+    Security = 14,
+    /// AF_KEY: Key management API
+    Key = 15,
+    /// AF_NETLINK: Netlink sockets
+    Netlink = 16,
+    /// AF_PACKET: Low level packet interface
+    Packet = 17,
+    /// AF_ASH: Ash
+    Ash = 18,
+    /// AF_ECONET: Acorn Econet
+    Econet = 19,
+    /// AF_ATMSVC: ATM SVCs
+    Atmsvc = 20,
+    /// AF_RDS: Reliable Datagram Sockets
+    Rds = 21,
+    /// AF_SNA: Linux SNA Project
+    Sna = 22,
+    /// AF_IRDA: IRDA sockets
+    Irda = 23,
+    /// AF_PPPOX: PPPoX sockets
+    Pppox = 24,
+    /// AF_WANPIPE: WANPIPE API sockets
+    WanPipe = 25,
+    /// AF_LLC: Linux LLC
+    Llc = 26,
+    /// AF_IB: Native InfiniBand address
+    /// Introduction: https://access.redhat.com/documentation/en-us/red_hat_enterprise_linux/9/html-single/configuring_infiniband_and_rdma_networks/index#understanding-infiniband-and-rdma_configuring-infiniband-and-rdma-networks
+    Ib = 27,
+    /// AF_MPLS: MPLS
+    Mpls = 28,
+    /// AF_CAN: Controller Area Network
+    Can = 29,
+    /// AF_TIPC: TIPC sockets
+    Tipc = 30,
+    /// AF_BLUETOOTH: Bluetooth sockets
+    Bluetooth = 31,
+    /// AF_IUCV: IUCV sockets
+    Iucv = 32,
+    /// AF_RXRPC: RxRPC sockets
+    Rxrpc = 33,
+    /// AF_ISDN: mISDN sockets
+    Isdn = 34,
+    /// AF_PHONET: Phonet sockets
+    Phonet = 35,
+    /// AF_IEEE802154: IEEE 802.15.4 sockets
+    Ieee802154 = 36,
+    /// AF_CAIF: CAIF sockets
+    Caif = 37,
+    /// AF_ALG: Algorithm sockets
+    Alg = 38,
+    /// AF_NFC: NFC sockets
+    Nfc = 39,
+    /// AF_VSOCK: vSockets
+    Vsock = 40,
+    /// AF_KCM: Kernel Connection Multiplexor
+    Kcm = 41,
+    /// AF_QIPCRTR: Qualcomm IPC Router
+    Qipcrtr = 42,
+    /// AF_SMC: SMC-R sockets.
+    /// reserve number for PF_SMC protocol family that reuses AF_INET address family
+    Smc = 43,
+    /// AF_XDP: XDP sockets
+    Xdp = 44,
+    /// AF_MCTP: Management Component Transport Protocol
+    Mctp = 45,
+    /// AF_MAX: the largest address family
+    Max = 46,
+}
+
+impl TryFrom for AddressFamily {
+    type Error = SystemError;
+    fn try_from(x: u16) -> Result {
+        use num_traits::FromPrimitive;
+        return ::from_u16(x).ok_or_else(|| SystemError::EINVAL);
+    }
+}
+
+/// @brief Enumeration of POSIX socket types (these values match the ones used in the Linux kernel)
+#[derive(Debug, Clone, Copy, PartialEq, Eq, FromPrimitive, ToPrimitive)]
+pub enum PosixSocketType {
+    Stream = 1,
+    Datagram = 2,
+    Raw = 3,
+    Rdm = 4,
+    SeqPacket = 5,
+    Dccp = 6,
+    Packet = 10,
+}
+
+impl TryFrom for PosixSocketType {
+    type Error = SystemError;
+    fn try_from(x: u8) -> Result {
+        use num_traits::FromPrimitive;
+        return ::from_u8(x).ok_or_else(|| SystemError::EINVAL);
+    }
+}
+
+/// @brief Inode wrapper for a Socket in the file system
+#[derive(Debug)]
+pub struct SocketInode(SpinLock>);
+
+impl SocketInode {
+    pub fn new(socket: Box) -> Arc {
+        return Arc::new(Self(SpinLock::new(socket)));
+    }
+
+    #[inline]
+    pub fn inner(&self) -> SpinLockGuard> {
+        return self.0.lock();
+    }
+}
+
+impl IndexNode for SocketInode {
+    fn open(
+        &self,
+        _data: &mut crate::filesystem::vfs::FilePrivateData,
+        _mode: &crate::filesystem::vfs::file::FileMode,
+    ) -> Result<(), SystemError> {
+        return Ok(());
+    }
+
+    fn close(
+        &self,
+        _data: &mut crate::filesystem::vfs::FilePrivateData,
+    ) -> Result<(), SystemError> {
+        return Ok(());
+    }
+
+    fn read_at(
+        &self,
+        _offset: usize,
+        len: usize,
+        buf: &mut [u8],
+        _data: &mut crate::filesystem::vfs::FilePrivateData,
+    ) -> Result {
+        return self.0.lock().read(&mut buf[0..len]).0;
+    }
+
+    fn write_at(
+        &self,
+        _offset: usize,
+        len: usize,
+        buf: &[u8],
+        _data: &mut crate::filesystem::vfs::FilePrivateData,
+    ) -> Result {
+        return self.0.lock().write(&buf[0..len], None);
+    }
+
+    fn poll(&self) -> Result {
+        let (read, write, error) = self.0.lock().poll();
+        let mut result = PollStatus::empty();
+        if read {
+            result.insert(PollStatus::READ);
+        }
+        if write {
+            result.insert(PollStatus::WRITE);
+        }
+        if error {
+            result.insert(PollStatus::ERROR);
+        }
+        return Ok(result);
+    }
+
+    fn fs(&self) -> alloc::sync::Arc {
+        todo!()
+    }
+
+    fn as_any_ref(&self) -> &dyn core::any::Any {
+        self
+    }
+
+    fn list(&self) -> Result, SystemError> {
+        return Err(SystemError::ENOTDIR);
+    }
+
+    fn metadata(&self) -> Result {
+        let meta = Metadata {
+            mode: 0o777,
+            file_type: FileType::Socket,
+            ..Default::default()
+        };
+
+        return Ok(meta);
+    }
+}
diff --git a/kernel/src/net/syscall.rs b/kernel/src/net/syscall.rs
new file mode 100644
index 00000000..b3f3e6b3
--- /dev/null
+++ b/kernel/src/net/syscall.rs
@@ -0,0 +1,1122 @@
+use core::cmp::min;
+
+use alloc::{boxed::Box, sync::Arc};
+use num_traits::{FromPrimitive, ToPrimitive};
+use smoltcp::wire;
+
+use crate::{
+    arch::asm::current::current_pcb,
+    filesystem::vfs::{
+        file::{File, FileMode},
+        syscall::{IoVec, IoVecs},
+    },
+    include::bindings::bindings::{pt_regs, verify_area},
+    net::socket::{AddressFamily, SOL_SOCKET},
+    syscall::SystemError,
+};
+
+use super::{
+    socket::{PosixSocketType, RawSocket, SocketInode, SocketOptions, TcpSocket, UdpSocket},
+    Endpoint, Protocol, ShutdownType, Socket,
+};
+
+#[no_mangle]
+pub extern "C" fn sys_socket(regs: &pt_regs) -> u64 {
+    let address_family = regs.r8 as usize;
+    let socket_type = regs.r9 as usize;
+    let protocol: usize = regs.r10 as usize;
+    // kdebug!("sys_socket: address_family: {address_family}, socket_type: {socket_type}, protocol: {protocol}");
+    return do_socket(address_family, socket_type, protocol)
+        .map(|x| x as u64)
+        .unwrap_or_else(|e| e.to_posix_errno() as u64);
+}
+
+/// @brief Actual implementation of the sys_socket system call
+///
+/// @param address_family Address family
+/// @param socket_type Socket type
+/// @param protocol Transport protocol
+pub fn do_socket(
+    address_family: usize,
+    socket_type: usize,
+    protocol: usize,
+) -> Result {
+    let address_family = AddressFamily::try_from(address_family as u16)?;
+    let socket_type = PosixSocketType::try_from((socket_type & 0xf) as u8)?;
+    // kdebug!("do_socket: address_family: {address_family:?}, socket_type: {socket_type:?}, protocol: {protocol}");
+    // Create the socket according to the address family and the socket type
+    let socket: Box = match address_family {
+        AddressFamily::Unix | AddressFamily::INet => match socket_type {
+            PosixSocketType::Stream => Box::new(TcpSocket::new(SocketOptions::default())),
+            PosixSocketType::Datagram => Box::new(UdpSocket::new(SocketOptions::default())),
+            PosixSocketType::Raw => Box::new(RawSocket::new(
+                Protocol::from(protocol as u8),
+                SocketOptions::default(),
+            )),
+            _ => {
+                // kdebug!("do_socket: EINVAL");
+                return Err(SystemError::EINVAL);
+            }
+        },
+        _ => {
+            // kdebug!("do_socket: EAFNOSUPPORT");
+            return Err(SystemError::EAFNOSUPPORT);
+        }
+    };
+    // kdebug!("do_socket: socket: {socket:?}");
+    let socketinode: Arc = SocketInode::new(socket);
+    let f = File::new(socketinode, FileMode::O_RDWR)?;
+    // kdebug!("do_socket: f: {f:?}");
+    // Add the socket to the file descriptor table of the current process
+    let fd = current_pcb().alloc_fd(f, None).map(|x| x as i64);
+    // kdebug!("do_socket: fd: {fd:?}");
+    return fd;
+}
+
+#[no_mangle]
+pub extern "C" fn sys_setsockopt(regs: &pt_regs) -> u64 {
+    let fd = regs.r8 as usize;
+    let level = regs.r9 as usize;
+    let optname = regs.r10 as usize;
+    let optval = regs.r11 as usize;
+    let optlen = regs.r12 as usize;
+    return do_setsockopt(fd, level, optname, optval as *const u8, optlen)
+        .map(|x| x as u64)
+        .unwrap_or_else(|e| e.to_posix_errno() as u64);
+}
+
+/// @brief Actual implementation of the sys_setsockopt system call
+///
+/// @param fd File descriptor
+/// @param level Option level
+/// @param optname Option name
+/// @param optval Option value
+/// @param optlen Length of the optval buffer
+pub fn do_setsockopt(
+    fd: usize,
+    level: usize,
+    optname: usize,
+    optval: *const u8,
+    optlen: usize,
+) -> Result {
+    // Check that the optval address is valid
+    if unsafe { verify_area(optval as u64, optlen as u64) } == false {
+        // The address range lies outside user space, so it is invalid
+        return Err(SystemError::EFAULT);
+    }
+
+    let socket_inode: Arc = current_pcb()
+        .get_socket(fd as i32)
+        .ok_or(SystemError::EBADF)?;
+    let data: &[u8] = unsafe { core::slice::from_raw_parts(optval, optlen) };
+    // Get the inner socket (the actual data)
+    let socket = socket_inode.inner();
+    return socket.setsockopt(level, optname, data).map(|_| 0);
+}
+
+#[no_mangle]
+pub extern "C" fn sys_getsockopt(regs: &pt_regs) -> u64 {
+    let fd = regs.r8 as usize;
+    let level = regs.r9 as usize;
+    let optname = regs.r10 as usize;
+    let optval = regs.r11 as usize;
+    let optlen = regs.r12 as usize;
+    return do_getsockopt(fd, level, optname, optval as *mut u8, optlen as *mut u32)
+        .map(|x| x as u64)
+        .unwrap_or_else(|e| e.to_posix_errno() as u64);
+}
+
+/// @brief Actual implementation of the sys_getsockopt system call
+///
+/// Reference: https://man7.org/linux/man-pages/man2/setsockopt.2.html
+///
+/// @param fd File descriptor
+/// @param level Option level
+/// @param optname Option name
+/// @param optval Returned option value
+/// @param optlen Returned length of the optval buffer
+pub fn do_getsockopt(
+    fd: usize,
+    level: usize,
+    optname: usize,
+    optval: *mut u8,
+    optlen: *mut u32,
+) -> Result {
+    // Check that the optval address is valid
+    if unsafe { verify_area(optval as u64, core::mem::size_of::() as u64) } == false {
+        // The address range lies outside user space, so it is invalid
+        return Err(SystemError::EFAULT);
+    }
+
+    // Check that the optlen address is valid
+    if unsafe { verify_area(optlen as u64, core::mem::size_of::() as u64) } == false {
+        // The address range lies outside user space, so it is invalid
+        return Err(SystemError::EFAULT);
+    }
+
+    // Get the socket
+    let optval = optval as *mut u32;
+    let binding: Arc = current_pcb()
+        .get_socket(fd as i32)
+        .ok_or(SystemError::EBADF)?;
+    let socket = binding.inner();
+
+    if level as u8 == SOL_SOCKET {
+        let optname =
+            PosixSocketOption::try_from(optname as i32).map_err(|_| SystemError::ENOPROTOOPT)?;
+        match optname {
+            PosixSocketOption::SO_SNDBUF => {
+                // Return the send buffer size
+                unsafe {
+                    *optval = socket.metadata()?.send_buf_size as u32;
+                    *optlen = core::mem::size_of::() as u32;
+                }
+                return Ok(0);
+            }
+            PosixSocketOption::SO_RCVBUF => {
+                let optval = optval as *mut u32;
+                // Return the default receive buffer size
+                unsafe {
+                    *optval = socket.metadata()?.recv_buf_size as u32;
+                    *optlen = core::mem::size_of::() as u32;
+                }
+                return Ok(0);
+            }
+            _ => {
+                return Err(SystemError::ENOPROTOOPT);
+            }
+        }
+    }
+    drop(socket);
+
+    // To manipulate options at any other level the
+    // protocol number of the appropriate protocol controlling the
+    // option is supplied. For example, to indicate that an option is
+    // to be interpreted by the TCP protocol, level should be set to the
+    // protocol number of TCP.
+
+    let posix_protocol =
+        PosixIpProtocol::try_from(level as u16).map_err(|_| SystemError::ENOPROTOOPT)?;
+    if posix_protocol == PosixIpProtocol::TCP {
+        let optname = PosixTcpSocketOptions::try_from(optname as i32)
+            .map_err(|_| SystemError::ENOPROTOOPT)?;
+        match optname {
+            PosixTcpSocketOptions::Congestion => return Ok(0),
+            _ => {
+                return Err(SystemError::ENOPROTOOPT);
+            }
+        }
+    }
+    return Err(SystemError::ENOPROTOOPT);
+}
+
+#[no_mangle]
+pub extern "C" fn sys_connect(regs: &pt_regs) -> u64 {
+    let fd = regs.r8 as usize;
+    let addr = regs.r9 as usize;
+    let addrlen = regs.r10 as usize;
+    return do_connect(fd, addr as *const SockAddr, addrlen)
+        .map(|x| x as u64)
+        .unwrap_or_else(|e| e.to_posix_errno() as u64);
+}
+
+/// @brief Actual implementation of the sys_connect system call
+///
+/// @param fd File descriptor
+/// @param addr SockAddr
+/// @param addrlen Address length
+///
+/// @return 0 on success, an error code on failure
+pub fn do_connect(fd: usize, addr: *const SockAddr, addrlen: usize) -> Result {
+    let endpoint: Endpoint = SockAddr::to_endpoint(addr, addrlen)?;
+    let socket: Arc = current_pcb()
+        .get_socket(fd as i32)
+        .ok_or(SystemError::EBADF)?;
+    let mut socket = socket.inner();
+    // kdebug!("connect to {:?}...", endpoint);
+    socket.connect(endpoint)?;
+    return Ok(0);
+}
+
+#[no_mangle]
+pub extern "C" fn sys_bind(regs: &pt_regs) -> u64 {
+    let fd = regs.r8 as usize;
+    let addr = regs.r9 as usize;
+    let addrlen = regs.r10 as usize;
+    return do_bind(fd, addr as *const SockAddr, addrlen)
+        .map(|x| x as u64)
+        .unwrap_or_else(|e| e.to_posix_errno() as u64);
+}
+
+/// @brief Actual implementation of the sys_bind system call
+///
+/// @param fd File descriptor
+/// @param addr SockAddr
+/// @param addrlen Address length
+///
+/// @return 0 on success, an error code on failure
+pub fn do_bind(fd: usize, addr: *const SockAddr, addrlen: usize) -> Result {
+    let endpoint: Endpoint = SockAddr::to_endpoint(addr, addrlen)?;
+    let socket: Arc = current_pcb()
+        .get_socket(fd as i32)
+        .ok_or(SystemError::EBADF)?;
+    let mut socket = socket.inner();
+    socket.bind(endpoint)?;
+    return Ok(0);
+}
+
+#[no_mangle]
+pub extern "C" fn sys_sendto(regs: &pt_regs) -> u64 {
+    let fd = regs.r8 as usize;
+    let buf = regs.r9 as usize;
+    let len = regs.r10 as usize;
+    let flags = regs.r11 as usize;
+    let addr = regs.r12 as usize;
+    let addrlen = regs.r13 as usize;
+    return do_sendto(
+        fd,
+        buf as *const u8,
+        len,
+        flags,
+        addr as *const SockAddr,
+        addrlen,
+    )
+    .map(|x| x as u64)
+    .unwrap_or_else(|e| e.to_posix_errno() as u64);
+}
+
+/// @brief Actual implementation of the sys_sendto system call
+///
+/// @param fd File descriptor
+/// @param buf Send buffer
+/// @param len Length of the send buffer
+/// @param flags Flags
+/// @param addr SockAddr
+/// @param addrlen Address length
+///
+/// @return The number of bytes sent on success, an error code on failure
+pub fn do_sendto(
+    fd: usize,
+    buf: *const u8,
+    len: usize,
+    _flags: usize,
+    addr: *const SockAddr,
+    addrlen: usize,
+) -> Result {
+    if unsafe { verify_area(buf as usize as u64, len as u64) } == false {
+        return Err(SystemError::EFAULT);
+    }
+    let buf = unsafe { core::slice::from_raw_parts(buf, len) };
+    let endpoint = if addr.is_null() {
+        None
+    } else {
+        Some(SockAddr::to_endpoint(addr, addrlen)?)
+    };
+
+    let socket: Arc = current_pcb()
+        .get_socket(fd as i32)
+        .ok_or(SystemError::EBADF)?;
+    let socket = socket.inner();
+    return socket.write(buf, endpoint).map(|n| n as i64);
+}
+
+#[no_mangle]
+pub extern "C" fn sys_recvfrom(regs: &pt_regs) -> u64 {
+    let fd = regs.r8 as usize;
+    let buf = regs.r9 as usize;
+    let len = regs.r10 as usize;
+    let flags = regs.r11 as usize;
+    let addr = regs.r12 as usize;
+    let addrlen = regs.r13 as usize;
+    return do_recvfrom(
+        fd,
+        buf as *mut u8,
+        len,
+        flags,
+        addr as *mut SockAddr,
+        addrlen as *mut u32,
+    )
+    .map(|x| x as u64)
+    .unwrap_or_else(|e| e.to_posix_errno() as u64);
+}
+
+/// @brief Actual implementation of the sys_recvfrom system call
+///
+/// @param fd File descriptor
+/// @param buf Receive buffer
+/// @param len Length of the receive buffer
+/// @param flags Flags
+/// @param addr SockAddr
+/// @param addrlen Address length
+///
+/// @return The number of bytes received on success, an error code on failure
+pub fn do_recvfrom(
+    fd: usize,
+    buf: *mut u8,
+    len: usize,
+    _flags: usize,
+    addr: *mut SockAddr,
+    addrlen: *mut u32,
+) -> Result {
+    if unsafe { verify_area(buf as usize as u64, len as u64) } == false {
+        return Err(SystemError::EFAULT);
+    }
+    // kdebug!(
+    //     "do_recvfrom: fd: {}, buf: {:x}, len: {}, addr: {:x}, addrlen: {:x}",
+    //     fd,
+    //     buf as usize,
+    //     len,
+    //     addr as usize,
+    //     addrlen as usize
+    // );
+
+    let buf = unsafe { core::slice::from_raw_parts_mut(buf, len) };
+    let socket: Arc = current_pcb()
+        .get_socket(fd as i32)
+        .ok_or(SystemError::EBADF)?;
+    let socket = socket.inner();
+
+    let (n, endpoint) = socket.read(buf);
+    drop(socket);
+
+    let n: usize = n?;
+
+    // If address information is wanted (addr is non-null), write it to user space
+    if !addr.is_null() {
+        let sockaddr_in = SockAddr::from(endpoint);
+        unsafe {
+            sockaddr_in.write_to_user(addr, addrlen)?;
+        }
+    }
+    return Ok(n as i64);
+}
+
+#[no_mangle]
+pub extern "C" fn sys_recvmsg(regs: &pt_regs) -> i64 {
+    let fd = regs.r8 as usize;
+    let msg = regs.r9 as usize;
+    let flags = regs.r10 as usize;
+    return do_recvmsg(fd, msg as *mut MsgHdr, flags)
+        .map(|x| x as i64)
+        .unwrap_or_else(|e| e.to_posix_errno() as i64);
+}
+
+/// @brief Actual implementation of the sys_recvmsg system call
+///
+/// @param fd File descriptor
+/// @param msg MsgHdr
+/// @param flags Flags
+///
+/// @return The number of bytes received on success, an error code on failure
+pub fn do_recvmsg(fd: usize, msg: *mut MsgHdr, _flags: usize) -> Result {
+    // Check that the pointer is valid
+    if unsafe { verify_area(msg as usize as u64, core::mem::size_of::() as u64) } == false {
+        return Err(SystemError::EFAULT);
+    }
+    let msg: &mut MsgHdr = unsafe { msg.as_mut() }.ok_or(SystemError::EFAULT)?;
+    // Check that every buffer address is valid and build the iovecs
+    let mut iovs = unsafe { IoVecs::from_user(msg.msg_iov, msg.msg_iovlen, true)? };
+
+    let socket: Arc = current_pcb()
+        .get_socket(fd as i32)
+        .ok_or(SystemError::EBADF)?;
+    let socket = socket.inner();
+
+    let mut buf = iovs.new_buf(true);
+    // Read data from the socket
+    let (n, endpoint) = socket.read(&mut buf);
+    drop(socket);
+
+    let n: usize = n?;
+
+    // Write the data into the user-space iovecs
+    iovs.scatter(&buf[..n]);
+
+    let sockaddr_in = SockAddr::from(endpoint);
+    unsafe {
+        sockaddr_in.write_to_user(msg.msg_name, &mut msg.msg_namelen)?;
+    }
+    return Ok(n as i64);
+}
+
+#[no_mangle]
+pub extern "C" fn sys_listen(regs: &pt_regs) -> i64 {
+    let fd = regs.r8 as usize;
+    let backlog = regs.r9 as usize;
+    return do_listen(fd, backlog)
+        .map(|x| x as i64)
+        .unwrap_or_else(|e| e.to_posix_errno() as i64);
+}
+
+/// @brief Actual implementation of the sys_listen system call
+///
+/// @param fd File descriptor
+/// @param backlog Maximum number of connections
+///
+/// @return 0 on success, an error code on failure
+pub fn do_listen(fd: usize, backlog: usize) -> Result {
+    let socket: Arc = current_pcb()
+        .get_socket(fd as i32)
+        .ok_or(SystemError::EBADF)?;
+    let mut socket = socket.inner();
+    socket.listen(backlog)?;
+    return Ok(0);
+}
+
+#[no_mangle]
+pub extern "C" fn sys_shutdown(regs: &pt_regs) -> u64 {
+    let fd = regs.r8 as usize;
+    let how = regs.r9 as usize;
+    return do_shutdown(fd, how)
+        .map(|x| x as u64)
+        .unwrap_or_else(|e| e.to_posix_errno() as u64);
+}
+
+/// @brief Actual implementation of the sys_shutdown system call
+///
+/// @param fd File descriptor
+/// @param how How the socket is to be shut down
+///
+/// @return 0 on success, an error code on failure
+pub fn do_shutdown(fd: usize, how: usize) -> Result {
+    let socket: Arc = current_pcb()
+        .get_socket(fd as i32)
+        .ok_or(SystemError::EBADF)?;
+    let socket = socket.inner();
+    socket.shutdown(ShutdownType::try_from(how as i32)?)?;
+    return Ok(0);
+}
+
+#[no_mangle]
+pub extern "C" fn sys_accept(regs: &pt_regs) -> u64 {
+    let fd = regs.r8 as usize;
+    let addr = regs.r9 as usize;
+    let addrlen = regs.r10 as usize;
+    return do_accept(fd, addr as *mut SockAddr, addrlen as *mut u32)
+        .map(|x| x as u64)
+        .unwrap_or_else(|e| e.to_posix_errno() as u64);
+}
+
+/// @brief Actual implementation of the sys_accept system call
+///
+/// @param fd File descriptor
+/// @param addr SockAddr
+/// @param addrlen Address length
+///
+/// @return The new file descriptor on success, an error code on failure
+pub fn do_accept(fd: usize, addr: *mut SockAddr, addrlen: *mut u32) -> Result {
+    let socket: Arc = current_pcb()
+        .get_socket(fd as i32)
+        .ok_or(SystemError::EBADF)?;
+    // kdebug!("accept: socket={:?}", socket);
+    let mut socket = socket.inner();
+    // Accept a connection from the socket
+    let (new_socket, remote_endpoint) = socket.accept()?;
+    drop(socket);
+
+    // kdebug!("accept: new_socket={:?}", new_socket);
+    // Insert the new socket into the file descriptor vector
+    let new_socket: Arc = SocketInode::new(new_socket);
+    let new_fd = current_pcb().alloc_fd(File::new(new_socket, FileMode::O_RDWR)?, None)?;
+    // kdebug!("accept: new_fd={}", new_fd);
+    if !addr.is_null() {
+        // kdebug!("accept: write remote_endpoint to user");
+        // Write the peer address to user space
+        let sockaddr_in = SockAddr::from(remote_endpoint);
+        unsafe {
+            sockaddr_in.write_to_user(addr, addrlen)?;
+        }
+    }
+    return Ok(new_fd as i64);
+}
+
+#[no_mangle]
+pub extern "C" fn sys_getsockname(regs: &pt_regs) -> i64 {
+    let fd = regs.r8 as usize;
+    let addr = regs.r9 as usize;
+    let addrlen = regs.r10 as usize;
+    return do_getsockname(fd, addr as *mut SockAddr, addrlen as *mut u32)
+        .map(|x| x as i64)
+        .unwrap_or_else(|e| e.to_posix_errno() as i64);
+}
+
+/// @brief Actual implementation of the sys_getsockname system call
+///
+/// Returns the current address to which the socket
+/// sockfd is bound, in the buffer pointed to by addr.
+///
+/// @param fd File descriptor
+/// @param addr SockAddr
+/// @param addrlen Address length
+///
+/// @return 0 on success, an error code on failure
+pub fn do_getsockname(
+    fd: usize,
+    addr: *mut SockAddr,
+    addrlen: *mut u32,
+) -> Result {
+    if addr.is_null() {
+        return Err(SystemError::EINVAL);
+    }
+    let socket: Arc = current_pcb()
+        .get_socket(fd as i32)
+        .ok_or(SystemError::EBADF)?;
+    let socket = socket.inner();
+    let endpoint: Endpoint = socket.endpoint().ok_or(SystemError::EINVAL)?;
+    drop(socket);
+
+    let sockaddr_in = SockAddr::from(endpoint);
+    unsafe {
+        sockaddr_in.write_to_user(addr, addrlen)?;
+    }
+    return Ok(0);
+}
+
+#[no_mangle]
+pub extern "C" fn sys_getpeername(regs: &pt_regs) -> u64 {
+    let fd = regs.r8 as usize;
+    let addr = regs.r9 as usize;
+    let addrlen = regs.r10 as usize;
+    return do_getpeername(fd, addr as *mut SockAddr, addrlen as *mut u32)
+        .map(|x| x as u64)
+        .unwrap_or_else(|e| e.to_posix_errno() as u64);
+}
+
+/// @brief Actual implementation of the sys_getpeername system call
+///
+/// @param fd File descriptor
+/// @param addr SockAddr
+/// @param addrlen Address length
+///
+/// @return 0 on success, an error code on failure
+pub fn do_getpeername(
+    fd: usize,
+    addr: *mut SockAddr,
+    addrlen: *mut u32,
+) -> Result {
+    if addr.is_null() {
+        return Err(SystemError::EINVAL);
+    }
+
+    let socket: Arc = current_pcb()
+        .get_socket(fd as i32)
+        .ok_or(SystemError::EBADF)?;
+    let socket = socket.inner();
+    let endpoint: Endpoint = socket.peer_endpoint().ok_or(SystemError::EINVAL)?;
+    drop(socket);
+
+    let sockaddr_in = SockAddr::from(endpoint);
+    unsafe {
+        sockaddr_in.write_to_user(addr, addrlen)?;
+    }
+    return Ok(0);
+}
+
+// Reference: https://pubs.opengroup.org/onlinepubs/9699919799/basedefs/netinet_in.h.html#tag_13_32
+#[repr(C)]
+#[derive(Debug, Clone, Copy)]
+pub struct SockAddrIn {
+    pub sin_family: u16,
+    pub sin_port: u16,
+    pub sin_addr: u32,
+    pub sin_zero: [u8; 8],
+}
+
+#[repr(C)]
+#[derive(Debug, Clone, Copy)]
+pub struct SockAddrUn {
+    pub sun_family: u16,
+    pub
sun_path: [u8; 108], +} + +#[repr(C)] +#[derive(Debug, Clone, Copy)] +pub struct SockAddrLl { + pub sll_family: u16, + pub sll_protocol: u16, + pub sll_ifindex: u32, + pub sll_hatype: u16, + pub sll_pkttype: u8, + pub sll_halen: u8, + pub sll_addr: [u8; 8], +} + +#[repr(C)] +#[derive(Debug, Clone, Copy)] +pub struct SockAddrNl { + nl_family: u16, + nl_pad: u16, + nl_pid: u32, + nl_groups: u32, +} + +#[repr(C)] +#[derive(Debug, Clone, Copy)] +pub struct SockAddrPlaceholder { + pub family: u16, + pub data: [u8; 14], +} + +#[repr(C)] +#[derive(Clone, Copy)] +pub union SockAddr { + pub family: u16, + pub addr_in: SockAddrIn, + pub addr_un: SockAddrUn, + pub addr_ll: SockAddrLl, + pub addr_nl: SockAddrNl, + pub addr_ph: SockAddrPlaceholder, +} + +impl SockAddr { + /// @brief 把用户传入的SockAddr转换为Endpoint结构体 + pub fn to_endpoint(addr: *const SockAddr, len: usize) -> Result { + if unsafe { + verify_area( + addr as usize as u64, + core::mem::size_of::() as u64, + ) + } == false + { + return Err(SystemError::EFAULT); + } + + let addr = unsafe { addr.as_ref() }.ok_or(SystemError::EFAULT)?; + if len < addr.len()? { + return Err(SystemError::EINVAL); + } + unsafe { + match AddressFamily::try_from(addr.family)? { + AddressFamily::INet => { + let addr_in: SockAddrIn = addr.addr_in; + + let ip: wire::IpAddress = wire::IpAddress::from(wire::Ipv4Address::from_bytes( + &u32::from_be(addr_in.sin_addr).to_be_bytes()[..], + )); + let port = u16::from_be(addr_in.sin_port); + + return Ok(Endpoint::Ip(Some(wire::IpEndpoint::new(ip, port)))); + } + AddressFamily::Packet => { + // TODO: support packet socket + return Err(SystemError::EINVAL); + } + AddressFamily::Netlink => { + // TODO: support netlink socket + return Err(SystemError::EINVAL); + } + AddressFamily::Unix => { + return Err(SystemError::EINVAL); + } + _ => { + return Err(SystemError::EINVAL); + } + } + } + } + + /// @brief 获取地址长度 + pub fn len(&self) -> Result { + let ret = match AddressFamily::try_from(unsafe { self.family })? 
{ + AddressFamily::INet => Ok(core::mem::size_of::()), + AddressFamily::Packet => Ok(core::mem::size_of::()), + AddressFamily::Netlink => Ok(core::mem::size_of::()), + AddressFamily::Unix => Err(SystemError::EINVAL), + _ => Err(SystemError::EINVAL), + }; + + return ret; + } + + /// @brief 把SockAddr的数据写入用户空间 + /// + /// @param addr 用户空间的SockAddr的地址 + /// @param len 要写入的长度 + /// + /// @return 成功返回写入的长度,失败返回错误码 + pub unsafe fn write_to_user( + &self, + addr: *mut SockAddr, + addr_len: *mut u32, + ) -> Result { + // 当用户传入的地址或者长度为空时,直接返回0 + if addr.is_null() || addr_len.is_null() { + return Ok(0); + } + // 检查用户传入的地址是否合法 + if !verify_area( + addr as usize as u64, + core::mem::size_of::() as u64, + ) || !verify_area(addr_len as usize as u64, core::mem::size_of::() as u64) + { + return Err(SystemError::EFAULT); + } + + let to_write = min(self.len()?, *addr_len as usize); + if to_write > 0 { + let buf = core::slice::from_raw_parts_mut(addr as *mut u8, to_write); + buf.copy_from_slice(core::slice::from_raw_parts( + self as *const SockAddr as *const u8, + to_write, + )); + } + *addr_len = self.len()? 
as u32; + return Ok(to_write); + } +} + +impl From for SockAddr { + fn from(value: Endpoint) -> Self { + match value { + Endpoint::Ip(ip_endpoint) => { + // 未指定地址 + if let None = ip_endpoint { + return SockAddr { + addr_ph: SockAddrPlaceholder { + family: AddressFamily::Unspecified as u16, + data: [0; 14], + }, + }; + } + // 指定了地址 + let ip_endpoint = ip_endpoint.unwrap(); + match ip_endpoint.addr { + wire::IpAddress::Ipv4(ipv4_addr) => { + let addr_in = SockAddrIn { + sin_family: AddressFamily::INet as u16, + sin_port: ip_endpoint.port.to_be(), + sin_addr: u32::from_be_bytes(ipv4_addr.0).to_be(), + sin_zero: [0; 8], + }; + + return SockAddr { addr_in }; + } + _ => { + unimplemented!("not support ipv6"); + } + } + } + + Endpoint::LinkLayer(link_endpoint) => { + let addr_ll = SockAddrLl { + sll_family: AddressFamily::Packet as u16, + sll_protocol: 0, + sll_ifindex: link_endpoint.interface as u32, + sll_hatype: 0, + sll_pkttype: 0, + sll_halen: 0, + sll_addr: [0; 8], + }; + + return SockAddr { addr_ll }; + } // _ => { + // // todo: support other endpoint, like Netlink... + // unimplemented!("not support {value:?}"); + // } + } + } +} + +#[repr(C)] +#[derive(Debug, Clone, Copy)] +pub struct MsgHdr { + /// 指向一个SockAddr结构体的指针 + pub msg_name: *mut SockAddr, + /// SockAddr结构体的大小 + pub msg_namelen: u32, + /// scatter/gather array + pub msg_iov: *mut IoVec, + /// elements in msg_iov + pub msg_iovlen: usize, + /// 辅助数据 + pub msg_control: *mut u8, + /// 辅助数据长度 + pub msg_controllen: usize, + /// 接收到的消息的标志 + pub msg_flags: u32, +} + +#[derive(Debug, Clone, Copy, FromPrimitive, ToPrimitive, PartialEq, Eq)] +pub enum PosixIpProtocol { + /// Dummy protocol for TCP. + IP = 0, + /// Internet Control Message Protocol. + ICMP = 1, + /// Internet Group Management Protocol. + IGMP = 2, + /// IPIP tunnels (older KA9Q tunnels use 94). + IPIP = 4, + /// Transmission Control Protocol. + TCP = 6, + /// Exterior Gateway Protocol. + EGP = 8, + /// PUP protocol. 
+ PUP = 12, + /// User Datagram Protocol. + UDP = 17, + /// XNS IDP protocol. + IDP = 22, + /// SO Transport Protocol Class 4. + TP = 29, + /// Datagram Congestion Control Protocol. + DCCP = 33, + /// IPv6-in-IPv4 tunnelling. + IPv6 = 41, + /// RSVP Protocol. + RSVP = 46, + /// Generic Routing Encapsulation. (Cisco GRE) (rfc 1701, 1702) + GRE = 47, + /// Encapsulation Security Payload protocol + ESP = 50, + /// Authentication Header protocol + AH = 51, + /// Multicast Transport Protocol. + MTP = 92, + /// IP option pseudo header for BEET + BEETPH = 94, + /// Encapsulation Header. + ENCAP = 98, + /// Protocol Independent Multicast. + PIM = 103, + /// Compression Header Protocol. + COMP = 108, + /// Stream Control Transport Protocol + SCTP = 132, + /// UDP-Lite protocol (RFC 3828) + UDPLITE = 136, + /// MPLS in IP (RFC 4023) + MPLSINIP = 137, + /// Ethernet-within-IPv6 Encapsulation + ETHERNET = 143, + /// Raw IP packets + RAW = 255, + /// Multipath TCP connection + MPTCP = 262, +} + +impl TryFrom for PosixIpProtocol { + type Error = SystemError; + + fn try_from(value: u16) -> Result { + match ::from_u16(value) { + Some(p) => Ok(p), + None => Err(SystemError::EPROTONOSUPPORT), + } + } +} + +impl Into for PosixIpProtocol { + fn into(self) -> u16 { + ::to_u16(&self).unwrap() + } +} + +#[allow(non_camel_case_types)] +#[derive(Debug, Clone, Copy, FromPrimitive, ToPrimitive, PartialEq, Eq)] +pub enum PosixSocketOption { + SO_DEBUG = 1, + SO_REUSEADDR = 2, + SO_TYPE = 3, + SO_ERROR = 4, + SO_DONTROUTE = 5, + SO_BROADCAST = 6, + SO_SNDBUF = 7, + SO_RCVBUF = 8, + SO_SNDBUFFORCE = 32, + SO_RCVBUFFORCE = 33, + SO_KEEPALIVE = 9, + SO_OOBINLINE = 10, + SO_NO_CHECK = 11, + SO_PRIORITY = 12, + SO_LINGER = 13, + SO_BSDCOMPAT = 14, + SO_REUSEPORT = 15, + SO_PASSCRED = 16, + SO_PEERCRED = 17, + SO_RCVLOWAT = 18, + SO_SNDLOWAT = 19, + SO_RCVTIMEO_OLD = 20, + SO_SNDTIMEO_OLD = 21, + + SO_SECURITY_AUTHENTICATION = 22, + SO_SECURITY_ENCRYPTION_TRANSPORT = 23, + 
SO_SECURITY_ENCRYPTION_NETWORK = 24, + + SO_BINDTODEVICE = 25, + + /// 与SO_GET_FILTER相同 + SO_ATTACH_FILTER = 26, + SO_DETACH_FILTER = 27, + + SO_PEERNAME = 28, + + SO_ACCEPTCONN = 30, + + SO_PEERSEC = 31, + SO_PASSSEC = 34, + + SO_MARK = 36, + + SO_PROTOCOL = 38, + SO_DOMAIN = 39, + + SO_RXQ_OVFL = 40, + + /// 与SCM_WIFI_STATUS相同 + SO_WIFI_STATUS = 41, + SO_PEEK_OFF = 42, + + /* Instruct lower device to use last 4-bytes of skb data as FCS */ + SO_NOFCS = 43, + + SO_LOCK_FILTER = 44, + SO_SELECT_ERR_QUEUE = 45, + SO_BUSY_POLL = 46, + SO_MAX_PACING_RATE = 47, + SO_BPF_EXTENSIONS = 48, + SO_INCOMING_CPU = 49, + SO_ATTACH_BPF = 50, + // SO_DETACH_BPF = SO_DETACH_FILTER, + SO_ATTACH_REUSEPORT_CBPF = 51, + SO_ATTACH_REUSEPORT_EBPF = 52, + + SO_CNX_ADVICE = 53, + SCM_TIMESTAMPING_OPT_STATS = 54, + SO_MEMINFO = 55, + SO_INCOMING_NAPI_ID = 56, + SO_COOKIE = 57, + SCM_TIMESTAMPING_PKTINFO = 58, + SO_PEERGROUPS = 59, + SO_ZEROCOPY = 60, + /// 与SCM_TXTIME相同 + SO_TXTIME = 61, + + SO_BINDTOIFINDEX = 62, + + SO_TIMESTAMP_OLD = 29, + SO_TIMESTAMPNS_OLD = 35, + SO_TIMESTAMPING_OLD = 37, + SO_TIMESTAMP_NEW = 63, + SO_TIMESTAMPNS_NEW = 64, + SO_TIMESTAMPING_NEW = 65, + + SO_RCVTIMEO_NEW = 66, + SO_SNDTIMEO_NEW = 67, + + SO_DETACH_REUSEPORT_BPF = 68, + + SO_PREFER_BUSY_POLL = 69, + SO_BUSY_POLL_BUDGET = 70, + + SO_NETNS_COOKIE = 71, + SO_BUF_LOCK = 72, + SO_RESERVE_MEM = 73, + SO_TXREHASH = 74, + SO_RCVMARK = 75, +} + +impl TryFrom for PosixSocketOption { + type Error = SystemError; + + fn try_from(value: i32) -> Result { + match ::from_i32(value) { + Some(p) => Ok(p), + None => Err(SystemError::EINVAL), + } + } +} + +impl Into for PosixSocketOption { + fn into(self) -> i32 { + ::to_i32(&self).unwrap() + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, FromPrimitive, ToPrimitive)] +pub enum PosixTcpSocketOptions { + /// Turn off Nagle's algorithm. + NoDelay = 1, + /// Limit MSS. + MaxSegment = 2, + /// Never send partially complete segments. 
+ Cork = 3, + /// Start keeplives after this period. + KeepIdle = 4, + /// Interval between keepalives. + KeepIntvl = 5, + /// Number of keepalives before death. + KeepCnt = 6, + /// Number of SYN retransmits. + Syncnt = 7, + /// Lifetime for orphaned FIN-WAIT-2 state. + Linger2 = 8, + /// Wake up listener only when data arrive. + DeferAccept = 9, + /// Bound advertised window + WindowClamp = 10, + /// Information about this connection. + Info = 11, + /// Block/reenable quick acks. + QuickAck = 12, + /// Congestion control algorithm. + Congestion = 13, + /// TCP MD5 Signature (RFC2385). + Md5Sig = 14, + /// Use linear timeouts for thin streams + ThinLinearTimeouts = 16, + /// Fast retrans. after 1 dupack. + ThinDupack = 17, + /// How long for loss retry before timeout. + UserTimeout = 18, + /// TCP sock is under repair right now. + Repair = 19, + RepairQueue = 20, + QueueSeq = 21, + RepairOptions = 22, + /// Enable FastOpen on listeners + FastOpen = 23, + Timestamp = 24, + /// Limit number of unsent bytes in write queue. + NotSentLowat = 25, + /// Get Congestion Control (optional) info. + CCInfo = 26, + /// Record SYN headers for new connections. + SaveSyn = 27, + /// Get SYN headers recorded for connection. + SavedSyn = 28, + /// Get/set window parameters. + RepairWindow = 29, + /// Attempt FastOpen with connect. + FastOpenConnect = 30, + /// Attach a ULP to a TCP connection. + ULP = 31, + /// TCP MD5 Signature with extensions. + Md5SigExt = 32, + /// Set the key for Fast Open(cookie). + FastOpenKey = 33, + /// Enable TFO without a TFO cookie. + FastOpenNoCookie = 34, + ZeroCopyReceive = 35, + /// Notify bytes available to read as a cmsg on read. 
+ /// 与TCP_CM_INQ相同 + INQ = 36, + /// delay outgoing packets by XX usec + TxDelay = 37, +} + +impl TryFrom for PosixTcpSocketOptions { + type Error = SystemError; + + fn try_from(value: i32) -> Result { + match ::from_i32(value) { + Some(p) => Ok(p), + None => Err(SystemError::EINVAL), + } + } +} + +impl Into for PosixTcpSocketOptions { + fn into(self) -> i32 { + ::to_i32(&self).unwrap() + } +} diff --git a/kernel/src/process/process.rs b/kernel/src/process/process.rs index 3a264510..4f718645 100644 --- a/kernel/src/process/process.rs +++ b/kernel/src/process/process.rs @@ -3,18 +3,20 @@ use core::{ ptr::{null_mut, read_volatile, write_volatile}, }; -use alloc::boxed::Box; +use alloc::{boxed::Box, sync::Arc}; use crate::{ arch::{asm::current::current_pcb, fpu::FpState}, filesystem::vfs::{ file::{File, FileDescriptorVec, FileMode}, - ROOT_INODE, + FileType, ROOT_INODE, }, include::bindings::bindings::{ process_control_block, CLONE_FS, PROC_INTERRUPTIBLE, PROC_RUNNING, PROC_STOPPED, PROC_UNINTERRUPTIBLE, }, + libs::casting::DowncastArc, + net::socket::SocketInode, sched::core::{cpu_executing, sched_enqueue}, smp::core::{smp_get_processor_id, smp_send_reschedule}, syscall::SystemError, @@ -291,6 +293,24 @@ impl process_control_block { pub unsafe fn mark_sleep_uninterruptible(&mut self) { self.state = PROC_UNINTERRUPTIBLE as u64; } + + /// @brief 根据文件描述符序号,获取socket对象的可变引用 + /// + /// @param fd 文件描述符序号 + /// + /// @return Option(&mut Box) socket对象的可变引用. 
如果文件描述符不是socket,那么返回None + pub fn get_socket(&self, fd: i32) -> Option> { + let f = self.get_file_mut_by_fd(fd)?; + + if f.file_type() != FileType::Socket { + return None; + } + let socket: Arc = f + .inode() + .downcast_arc::() + .expect("Not a socket inode"); + return Some(socket); + } } // =========== 导出到C的函数,在将来,进程管理模块被完全重构之后,需要删掉他们 BEGIN ============ diff --git a/kernel/src/syscall/syscall.c b/kernel/src/syscall/syscall.c index d47b17af..ed3a5590 100644 --- a/kernel/src/syscall/syscall.c +++ b/kernel/src/syscall/syscall.c @@ -22,6 +22,21 @@ extern uint64_t sys_sigaction(struct pt_regs *regs); extern uint64_t sys_rt_sigreturn(struct pt_regs *regs); extern uint64_t sys_getpid(struct pt_regs *regs); extern uint64_t sys_sched(struct pt_regs *regs); +extern int sys_dup(int oldfd); +extern int sys_dup2(int oldfd, int newfd); +extern uint64_t sys_socket(struct pt_regs *regs); +extern uint64_t sys_setsockopt(struct pt_regs *regs); +extern uint64_t sys_getsockopt(struct pt_regs *regs); +extern uint64_t sys_connect(struct pt_regs *regs); +extern uint64_t sys_bind(struct pt_regs *regs); +extern uint64_t sys_sendto(struct pt_regs *regs); +extern uint64_t sys_recvfrom(struct pt_regs *regs); +extern uint64_t sys_recvmsg(struct pt_regs *regs); +extern uint64_t sys_listen(struct pt_regs *regs); +extern uint64_t sys_shutdown(struct pt_regs *regs); +extern uint64_t sys_accept(struct pt_regs *regs); +extern uint64_t sys_getsockname(struct pt_regs *regs); +extern uint64_t sys_getpeername(struct pt_regs *regs); /** * @brief 关闭文件系统调用 @@ -179,9 +194,9 @@ uint64_t sys_brk(struct pt_regs *regs) // kdebug("sys_brk input= %#010lx , new_brk= %#010lx bytes current_pcb->mm->brk_start=%#018lx // current->end_brk=%#018lx", regs->r8, new_brk, current_pcb->mm->brk_start, current_pcb->mm->brk_end); struct mm_struct *mm = current_pcb->mm; - if (new_brk < mm->brk_start || new_brk> new_brk >= current_pcb->addr_limit) + if (new_brk < mm->brk_start || new_brk > new_brk >= current_pcb->addr_limit) 
return mm->brk_end; - + if (mm->brk_end == new_brk) return new_brk; @@ -392,9 +407,6 @@ uint64_t sys_pipe(struct pt_regs *regs) extern uint64_t sys_mkdir(struct pt_regs *regs); -extern int sys_dup(int oldfd); -extern int sys_dup2(int oldfd, int newfd); - system_call_t system_call_table[MAX_SYSTEM_CALL_NUM] = { [0] = system_call_not_exists, [1] = sys_put_string, @@ -426,5 +438,18 @@ system_call_t system_call_table[MAX_SYSTEM_CALL_NUM] = { [27] = sys_sched, [28] = sys_dup, [29] = sys_dup2, - [30 ... 255] = system_call_not_exists, + [30] = sys_socket, + [31] = sys_setsockopt, + [32] = sys_getsockopt, + [33] = sys_connect, + [34] = sys_bind, + [35] = sys_sendto, + [36] = sys_recvfrom, + [37] = sys_recvmsg, + [38] = sys_listen, + [39] = sys_shutdown, + [40] = sys_accept, + [41] = sys_getsockname, + [42] = sys_getpeername, + [43 ... 255] = system_call_not_exists, }; diff --git a/kernel/src/syscall/syscall_num.h b/kernel/src/syscall/syscall_num.h index 9fe42d37..f13e877a 100644 --- a/kernel/src/syscall/syscall_num.h +++ b/kernel/src/syscall/syscall_num.h @@ -41,5 +41,19 @@ #define SYS_SCHED 27 // 让系统立即运行调度器(该系统调用不能由运行在Ring3的程序发起) #define SYS_DUP 28 #define SYS_DUP2 29 +#define SYS_SOCKET 30 // 创建一个socket + +#define SYS_SETSOCKOPT 31 // 设置socket的选项 +#define SYS_GETSOCKOPT 32 // 获取socket的选项 +#define SYS_CONNECT 33 // 连接到一个socket +#define SYS_BIND 34 // 绑定一个socket +#define SYS_SENDTO 35 // 向一个socket发送数据 +#define SYS_RECVFROM 36 // 从一个socket接收数据 +#define SYS_RECVMSG 37 // 从一个socket接收消息 +#define SYS_LISTEN 38 // 监听一个socket +#define SYS_SHUTDOWN 39 // 关闭socket +#define SYS_ACCEPT 40 // 接受一个socket连接 +#define SYS_GETSOCKNAME 41 // 获取socket的名字 +#define SYS_GETPEERNAME 42 // 获取socket的对端名字 #define SYS_AHCI_END_REQ 255 // AHCI DMA请求结束end_request的系统调用 \ No newline at end of file diff --git a/kernel/src/time/timer.rs b/kernel/src/time/timer.rs index 731e7e2e..e02f44d8 100644 --- a/kernel/src/time/timer.rs +++ b/kernel/src/time/timer.rs @@ -1,4 +1,4 @@ -use 
core::sync::atomic::{AtomicBool, Ordering}; +use core::sync::atomic::{compiler_fence, AtomicBool, Ordering}; use alloc::{ boxed::Box, @@ -75,15 +75,21 @@ impl Timer { /// @brief 将定时器插入到定时器链表中 pub fn activate(&self) { - let timer_list = &mut TIMER_LIST.lock(); let inner_guard = self.0.lock(); + let timer_list = &mut TIMER_LIST.lock(); + // 链表为空,则直接插入 if timer_list.is_empty() { // FIXME push_timer + timer_list.push_back(inner_guard.self_ref.upgrade().unwrap()); + + drop(inner_guard); + drop(timer_list); + compiler_fence(Ordering::SeqCst); + return; } - let mut split_pos: usize = 0; for (pos, elt) in timer_list.iter().enumerate() { if elt.0.lock().expire_jiffies > inner_guard.expire_jiffies { @@ -94,6 +100,8 @@ impl Timer { let mut temp_list: LinkedList> = timer_list.split_off(split_pos); timer_list.push_back(inner_guard.self_ref.upgrade().unwrap()); timer_list.append(&mut temp_list); + drop(inner_guard); + drop(timer_list); } #[inline] @@ -147,20 +155,38 @@ impl SoftirqVec for DoTimerSoftirq { // 最多只处理TIMER_RUN_CYCLE_THRESHOLD个计时器 for _ in 0..TIMER_RUN_CYCLE_THRESHOLD { // kdebug!("DoTimerSoftirq run"); - - let timer_list = &mut TIMER_LIST.lock(); + let timer_list = TIMER_LIST.try_lock(); + if timer_list.is_err() { + continue; + } + let mut timer_list = timer_list.unwrap(); if timer_list.is_empty() { break; } - if timer_list.front().unwrap().0.lock().expire_jiffies - <= unsafe { TIMER_JIFFIES as u64 } - { - let timer = timer_list.pop_front().unwrap(); - drop(timer_list); - timer.run(); + let timer_list_front = timer_list.pop_front().unwrap(); + // kdebug!("to lock timer_list_front"); + let mut timer_list_front_guard = None; + for _ in 0..10 { + let x = timer_list_front.0.try_lock(); + if x.is_err() { + continue; + } + timer_list_front_guard = Some(x.unwrap()); } + if timer_list_front_guard.is_none() { + continue; + } + let timer_list_front_guard = timer_list_front_guard.unwrap(); + if timer_list_front_guard.expire_jiffies > unsafe { TIMER_JIFFIES as u64 } { + 
drop(timer_list_front_guard); + timer_list.push_front(timer_list_front); + break; + } + drop(timer_list_front_guard); + drop(timer_list); + timer_list_front.run(); } self.clear_run(); @@ -293,7 +319,7 @@ pub extern "C" fn rs_timer_next_n_us_jiffies(expire_us: u64) -> u64 { pub extern "C" fn rs_timer_get_first_expire() -> i64 { match timer_get_first_expire() { Ok(v) => return v as i64, - Err(e) => return e.to_posix_errno() as i64, + Err(_) => return 0, } } diff --git a/tools/bootstrap.sh b/tools/bootstrap.sh index ad38a2cb..28a2b79d 100644 --- a/tools/bootstrap.sh +++ b/tools/bootstrap.sh @@ -43,7 +43,7 @@ install_ubuntu_debian_pkg() gnupg \ lsb-release \ llvm-dev libclang-dev clang gcc-multilib \ - gcc build-essential fdisk dosfstools + gcc build-essential fdisk dosfstools dnsmasq bridge-utils iptables if [ -z "$(which docker)" ] && [ -n ${dockerInstall} ]; then echo "正在安装docker..." diff --git a/tools/qemu/ifdown-nat b/tools/qemu/ifdown-nat new file mode 100755 index 00000000..94d00b49 --- /dev/null +++ b/tools/qemu/ifdown-nat @@ -0,0 +1,24 @@ +#!/bin/bash +BRIDGE=dragonos-bridge +if [ -n "$1" ]; then + echo "正在断开接口 $1" + ip link set $1 down + brctl delif "$BRIDGE" $1 + tap=`brctl show | grep natnet | awk '{print $4}'` + if [[ $tap != tap* ]];then + ip link set "$BRIDGE" down + brctl delbr "$BRIDGE" + iptables -t nat -F + kill `ps aux | grep dnsmasq | grep -v grep | awk '{print $2}'` + echo "断开接口 $1 成功" + echo "网桥 $BRIDGE 卸载成功" + echo "dnsmasq 服务停止成功" + exit 0 + else + echo "断开接口 $1 成功" + exit 0 + fi +else + echo "删除错误:未指定接口" + exit 1 +fi diff --git a/tools/qemu/ifup-nat b/tools/qemu/ifup-nat new file mode 100755 index 00000000..67e39e93 --- /dev/null +++ b/tools/qemu/ifup-nat @@ -0,0 +1,85 @@ +#!/bin/bash +# 设置 bridge 名称 +BRIDGE=dragonos-bridge +# 设置网络信息 +NETWORK=192.168.137.0 +NETMASK=255.255.255.0 +GATEWAY=192.168.137.1 +DHCPRANGE=192.168.137.100,192.168.137.200 +# 启用PXE支持的可选参数 +TFTPROOT= +BOOTP= + +function check_bridge() +{ + if brctl show | grep "^$BRIDGE" &> 
/dev/null; then + return 1 + else + return 0 + fi +} + +function create_bridge() +{ + brctl addbr "$BRIDGE" + brctl stp "$BRIDGE" on + brctl setfd "$BRIDGE" 0 + ifconfig "$BRIDGE" "$GATEWAY" netmask "$NETMASK" up +} + +function enable_ip_forward() +{ + echo 1 > /proc/sys/net/ipv4/ip_forward +} + +function add_filter_rules() +{ + iptables -t nat -A POSTROUTING -s "$NETWORK"/"$NETMASK" \ + ! -d "$NETWORK"/"$NETMASK" -j MASQUERADE +} + +function start_dnsmasq() +{ +# 禁止重复运行dnsmasq + ps -ef | grep "dnsmasq" | grep -v "grep" &> /dev/null + if [ $? -eq 0 ]; then + echo "dnsmasq 已经在运行" + return 1 + fi + dnsmasq \ + --strict-order \ + --except-interface=lo \ + --interface=$BRIDGE \ + --listen-address=$GATEWAY \ + --bind-interfaces \ + --dhcp-range=$DHCPRANGE \ + --conf-file="" \ + --pid-file=/var/run/qemu-dhcp-$BRIDGE.pid \ + --dhcp-leasefile=/var/run/qemu-dhcp-$BRIDGE.leases \ + --dhcp-no-override \ + ${TFTPROOT:+"--enable-tftp"} \ + ${TFTPROOT:+"--tftp-root=$TFTPROOT"} \ + ${BOOTP:+"--dhcp-boot=$BOOTP"} +} + +function setup_bridge_nat() +{ + check_bridge "$BRIDGE" + if [ $? -eq 0 ]; then + create_bridge + fi + enable_ip_forward + add_filter_rules "$BRIDGE" + start_dnsmasq "$BRIDGE" +} + +# 安装前需要检查$1参数 +if [ -n "$1" ]; then + setup_bridge_nat + brctl addif "$BRIDGE" "$1" + ifconfig "$1" 0.0.0.0 up + exit 0 +else + echo "发现错误:没有指定接口" + exit 1 +fi diff --git a/tools/run-qemu.sh b/tools/run-qemu.sh index 7b5a0f63..7594bda5 100644 --- a/tools/run-qemu.sh +++ b/tools/run-qemu.sh @@ -1,9 +1,40 @@ +check_dependencies() +{ + # Check if qemu is installed + if [ -z "$(which qemu-system-x86_64)" ]; then + echo "Please install qemu first!" + exit 1 + fi + + # Check if brctl is installed + if [ -z "$(which brctl)" ]; then + echo "Please install bridge-utils first!" + exit 1 + fi + + # Check if dnsmasq is installed + if [ -z "$(which dnsmasq)" ]; then + echo "Please install dnsmasq first!" 
+ exit 1 + fi + + # Check if iptable is installed + if [ -z "$(which iptables)" ]; then + echo "Please install iptables first!" + exit 1 + fi + +} + +check_dependencies + # 进行启动前检查 flag_can_run=1 ARGS=`getopt -o p -l bios:,display: -- "$@"` eval set -- "${ARGS}" echo "$@" -allflags=$(qemu-system-x86_64 -cpu help | awk '/flags/ {y=1; getline}; y {print}' | tr ' ' '\n' | grep -Ev "^$" | sed -r 's|^|+|' | tr '\n' ',' | sed -r "s|,$||") +allflags= +# allflags=$(qemu-system-x86_64 -cpu help | awk '/flags/ {y=1; getline}; y {print}' | tr ' ' '\n' | grep -Ev "^$" | sed -r 's|^|+|' | tr '\n' ',' | sed -r "s|,$||") ARCH="x86_64" #ARCH="i386" # 请根据自己的需要,在-d 后方加入所需的 trace 事件 @@ -17,7 +48,6 @@ if [ $(uname) == Darwin ]; then qemu_accel=hvf fi - QEMU=qemu-system-x86_64 QEMU_DISK_IMAGE="../bin/disk.img" QEMU_MEMORY="512M" @@ -29,7 +59,10 @@ QEMU_RTC_CLOCK="clock=host,base=localtime" QEMU_SERIAL="file:../serial_opt.txt" QEMU_DRIVE="id=disk,file=${QEMU_DISK_IMAGE},if=none" -QEMU_DEVICES="-device ahci,id=ahci -device ide-hd,drive=disk,bus=ahci.0 -nic user,model=virtio-net-pci -usb -device qemu-xhci,id=xhci,p2=8,p3=4 -machine accel=${qemu_accel} -machine q35" + +# ps: 下面这条使用tap的方式,无法dhcp获取到ip,暂时不知道为什么 +# QEMU_DEVICES="-device ahci,id=ahci -device ide-hd,drive=disk,bus=ahci.0 -net nic,netdev=nic0 -netdev tap,id=nic0,model=virtio-net-pci,script=qemu/ifup-nat,downscript=qemu/ifdown-nat -usb -device qemu-xhci,id=xhci,p2=8,p3=4 -machine accel=${qemu_accel} -machine q35 " +QEMU_DEVICES="-device ahci,id=ahci -device ide-hd,drive=disk,bus=ahci.0 -nic user,model=virtio-net-pci,hostfwd=tcp::12580-:12580 -usb -device qemu-xhci,id=xhci,p2=8,p3=4 -machine accel=${qemu_accel} -machine q35 " QEMU_ARGUMENT="-d ${QEMU_DISK_IMAGE} -m ${QEMU_MEMORY} -smp ${QEMU_SMP} -boot order=d -monitor ${QEMU_MONITOR} -d ${qemu_trace_std} " diff --git a/user/apps/test_relibc/main.c b/user/apps/test_relibc/main.c index fd129279..c44a5500 100644 --- a/user/apps/test_relibc/main.c +++ b/user/apps/test_relibc/main.c @@ 
-1,34 +1,242 @@ -/** - * @file main.c - * @author longjin (longjin@RinGoTek.cn) - * @brief 测试signal用的程序 - * @version 0.1 - * @date 2022-12-06 - * - * @copyright Copyright (c) 2022 - * - */ - -/** - * 测试signal的kill命令的方法: - * 1.在DragonOS的控制台输入 exec bin/test_signal.elf & - * 请注意,一定要输入末尾的 '&',否则进程不会后台运行 - * 2.然后kill对应的进程的pid (上一条命令执行后,将会输出这样一行:"[1] 生成的pid") - * - */ - -#include +#include +#include #include #include #include -#include +#include #include +#define CONN_QUEUE_SIZE 20 +#define BUFFER_SIZE 1024 +#define SERVER_PORT 12580 -int main() +int server_sockfd; +int conn; + +void signal_handler(int signo) { - printf("Test Relibc printf!\n"); - printf("Test Relibc printf ok!\n"); + + printf("Server is exiting...\n"); + close(conn); + close(server_sockfd); + exit(0); +} + +static char logo[] = + " ____ ___ ____ \n| _ \\ _ __ __ _ __ _ ___ _ __ / _ \\ / ___| " + "\n| | | || '__| / _` | / _` | / _ \\ | '_ \\ | | | |\\___ \\ \n| |_| || | | (_| || (_| || (_) || | | || |_| | " + "___) |\n|____/ |_| \\__,_| \\__, | \\___/ |_| |_| \\___/ |____/ \n |___/ \n"; + +void tcp_server() +{ + printf("TCP Server is running...\n"); + server_sockfd = socket(AF_INET, SOCK_STREAM, 0); + printf("socket() ok, server_sockfd=%d\n", server_sockfd); + struct sockaddr_in server_sockaddr; + server_sockaddr.sin_family = AF_INET; + server_sockaddr.sin_port = htons(SERVER_PORT); + server_sockaddr.sin_addr.s_addr = htonl(INADDR_ANY); + + if (bind(server_sockfd, (struct sockaddr *)&server_sockaddr, sizeof(server_sockaddr))) + { + perror("Server bind error.\n"); + exit(1); + } + + printf("TCP Server is listening...\n"); + if (listen(server_sockfd, CONN_QUEUE_SIZE) == -1) + { + perror("Server listen error.\n"); + exit(1); + } + + printf("listen() ok\n"); + + char buffer[BUFFER_SIZE]; + struct sockaddr_in client_addr; + socklen_t client_length = sizeof(client_addr); + /* + Await a connection on socket FD. 
+ When a connection arrives, open a new socket to communicate with it, + set *ADDR (which is *ADDR_LEN bytes long) to the address of the connecting + peer and *ADDR_LEN to the address's actual length, and return the + new socket's descriptor, or -1 for errors. + */ + conn = accept(server_sockfd, (struct sockaddr *)&client_addr, &client_length); + printf("Connection established.\n"); + if (conn < 0) + { + printf("Create connection failed, code=%d\n", conn); + exit(1); + } + send(conn, logo, sizeof(logo), 0); + while (1) + { + memset(buffer, 0, sizeof(buffer)); + int len = recv(conn, buffer, sizeof(buffer), 0); + if (len <= 0) + { + printf("Receive data failed! len=%d\n", len); + break; + } + if (strcmp(buffer, "exit\n") == 0) + { + break; + } + + printf("Received: %s\n", buffer); + send(conn, buffer, len, 0); + } + close(conn); + close(server_sockfd); +} + +void udp_server() +{ + printf("UDP Server is running...\n"); + server_sockfd = socket(AF_INET, SOCK_DGRAM, 0); + printf("socket() ok, server_sockfd=%d\n", server_sockfd); + struct sockaddr_in server_sockaddr; + server_sockaddr.sin_family = AF_INET; + server_sockaddr.sin_port = htons(SERVER_PORT); + server_sockaddr.sin_addr.s_addr = htonl(INADDR_ANY); + + if (bind(server_sockfd, (struct sockaddr *)&server_sockaddr, sizeof(server_sockaddr))) + { + perror("Server bind error.\n"); + exit(1); + } + + printf("UDP Server is listening...\n"); + + char buffer[BUFFER_SIZE]; + struct sockaddr_in client_addr; + socklen_t client_length = sizeof(client_addr); + + while (1) + { + memset(buffer, 0, sizeof(buffer)); + int len = recvfrom(server_sockfd, buffer, sizeof(buffer), 0, (struct sockaddr *)&client_addr, &client_length); + if (len <= 0) + { + printf("Receive data failed! 
len=%d", len); + break; + } + if (strcmp(buffer, "exit\n") == 0) + { + break; + } + + printf("Received: %s", buffer); + sendto(server_sockfd, buffer, len, 0, (struct sockaddr *)&client_addr, client_length); + printf("Send: %s", buffer); + } + close(conn); + close(server_sockfd); +} + +void tcp_client() +{ + printf("Client is running...\n"); + int client_sockfd = socket(AF_INET, SOCK_STREAM, 0); + + struct sockaddr_in server_addr = {0}; + server_addr.sin_family = AF_INET; + server_addr.sin_port = htons(12581); + server_addr.sin_addr.s_addr = inet_addr("192.168.199.129"); + printf("to connect\n"); + if (connect(client_sockfd, (struct sockaddr *)&server_addr, sizeof(server_addr)) < 0) + { + perror("Failed to establish connection to server\n"); + exit(1); + } + printf("connected to server\n"); + + char sendbuf[BUFFER_SIZE] = {0}; + char recvbuf[BUFFER_SIZE] = {0}; + + int x = recv(client_sockfd, recvbuf, sizeof(recvbuf), 0); + + fputs(recvbuf, stdout); + + memset(recvbuf, 0, sizeof(recvbuf)); + + while (1) + { + fgets(sendbuf, sizeof(sendbuf), stdin); + sendbuf[0] = 'a'; + + // printf("to send\n"); + send(client_sockfd, sendbuf, strlen(sendbuf), 0); + // printf("send ok\n"); + if (strcmp(sendbuf, "exit\n") == 0) + { + break; + } + + int x = recv(client_sockfd, recvbuf, sizeof(recvbuf), 0); + if (x < 0) + { + printf("recv error, retval=%d\n", x); + break; + } + + fputs(recvbuf, stdout); + + memset(recvbuf, 0, sizeof(recvbuf)); + memset(sendbuf, 0, sizeof(sendbuf)); + } + close(client_sockfd); +} + +void udp_client() +{ + struct sockaddr_in addr; + int sockfd, len = 0; + int addr_len = sizeof(struct sockaddr_in); + char buffer[256]; + + /* 建立socket,注意必须是SOCK_DGRAM */ + if ((sockfd = socket(AF_INET, SOCK_DGRAM, 0)) < 0) + { + perror("socket"); + exit(1); + } + + /* 填写sockaddr_in*/ + bzero(&addr, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_port = htons(12581); + addr.sin_addr.s_addr = inet_addr("192.168.199.129"); + + printf("to send logo\n"); + sendto(sockfd, 
logo, sizeof(logo), 0, (struct sockaddr *)&addr, addr_len); + printf("send logo ok\n"); + while (1) + { + bzero(buffer, sizeof(buffer)); + + printf("Please enter a string to send to server: \n"); + + /* 从标准输入设备取得字符串*/ + len = read(STDIN_FILENO, buffer, sizeof(buffer)); + printf("to send: %d\n", len); + /* 将字符串传送给server端*/ + sendto(sockfd, buffer, len, 0, (struct sockaddr *)&addr, addr_len); + + /* 接收server端返回的字符串*/ + len = recvfrom(sockfd, buffer, sizeof(buffer), 0, (struct sockaddr *)&addr, &addr_len); + printf("Receive from server: %s\n", buffer); + } return 0; +} +void main() +{ + // signal(SIGKILL, signal_handler); + // signal(SIGINT, signal_handler); + tcp_server(); + // udp_server(); + // tcp_client(); + // udp_client(); } \ No newline at end of file diff --git a/user/libs/libsystem/syscall.h b/user/libs/libsystem/syscall.h index 8a12d45e..268f6df8 100644 --- a/user/libs/libsystem/syscall.h +++ b/user/libs/libsystem/syscall.h @@ -34,6 +34,20 @@ #define SYS_GETPID 26 // 获取当前进程的pid(进程标识符) #define SYS_DUP 28 #define SYS_DUP2 29 +#define SYS_SOCKET 30 // 创建一个socket + +#define SYS_SETSOCKOPT 31 // 设置socket的选项 +#define SYS_GETSOCKOPT 32 // 获取socket的选项 +#define SYS_CONNECT 33 // 连接到一个socket +#define SYS_BIND 34 // 绑定一个socket +#define SYS_SENDTO 35 // 向一个socket发送数据 +#define SYS_RECVFROM 36 // 从一个socket接收数据 +#define SYS_RECVMSG 37 // 从一个socket接收消息 +#define SYS_LISTEN 38 // 监听一个socket +#define SYS_SHUTDOWN 39 // 关闭socket +#define SYS_ACCEPT 40 // 接受一个socket连接 +#define SYS_GETSOCKNAME 41 // 获取socket的名字 +#define SYS_GETPEERNAME 42 // 获取socket的对端名字 /** * @brief 用户态系统调用函数