diff --git a/kernel/src/filesystem/poll.rs b/kernel/src/filesystem/poll.rs index e5a96d9a..66cba369 100644 --- a/kernel/src/filesystem/poll.rs +++ b/kernel/src/filesystem/poll.rs @@ -1,12 +1,18 @@ use core::ffi::c_int; use crate::{ - ipc::signal::{RestartBlock, RestartBlockData, RestartFn}, + arch::ipc::signal::SigSet, + ipc::signal::{ + restore_saved_sigmask_unless, set_user_sigmask, RestartBlock, RestartBlockData, RestartFn, + }, mm::VirtAddr, net::event_poll::{EPollCtlOption, EPollEvent, EPollEventType, EventPoll}, process::ProcessManager, - syscall::{user_access::UserBufferWriter, Syscall}, - time::{Duration, Instant}, + syscall::{ + user_access::{UserBufferReader, UserBufferWriter}, + Syscall, + }, + time::{Duration, Instant, PosixTimeSpec}, }; use super::vfs::file::{File, FileMode}; @@ -32,11 +38,15 @@ impl<'a> PollAdapter<'a> { } fn add_pollfds(&self) -> Result<(), SystemError> { - for pollfd in self.poll_fds.iter() { + for (i, pollfd) in self.poll_fds.iter().enumerate() { + if pollfd.fd < 0 { + continue; + } let mut epoll_event = EPollEvent::default(); let poll_flags = PollFlags::from_bits_truncate(pollfd.events); let ep_events: EPollEventType = poll_flags.into(); epoll_event.set_events(ep_events.bits()); + epoll_event.set_data(i as u64); EventPoll::epoll_ctl_with_epfile( self.ep_file.clone(), @@ -64,8 +74,13 @@ impl<'a> PollAdapter<'a> { remain_timeout, )?; - for (i, event) in epoll_events.iter().enumerate() { - self.poll_fds[i].revents = (event.events() & 0xffff) as u16; + for event in epoll_events.iter() { + let index = event.data() as usize; + if index >= self.poll_fds.len() { + log::warn!("poll_all_fds: Invalid index in epoll event: {}", index); + continue; + } + self.poll_fds[index].revents = (event.events() & 0xffff) as u16; } Ok(events) @@ -74,13 +89,14 @@ impl<'a> PollAdapter<'a> { impl Syscall { /// https://code.dragonos.org.cn/xref/linux-6.6.21/fs/select.c#1068 + #[inline(never)] pub fn poll(pollfd_ptr: usize, nfds: u32, timeout_ms: i32) -> Result { let pollfd_ptr = VirtAddr::new(pollfd_ptr); let len = nfds as usize * core::mem::size_of::(); let mut timeout: Option = None; if timeout_ms >= 0 { - timeout = poll_select_set_timeout(timeout_ms); + timeout = poll_select_set_timeout(timeout_ms as u64); } let mut poll_fds_writer = UserBufferWriter::new(pollfd_ptr.as_ptr::(), len, true)?; let mut r = do_sys_poll(poll_fds_writer.buffer(0)?, timeout); @@ -92,15 +108,58 @@ impl Syscall { return r; } -} -/// 计算超时的时刻 -fn poll_select_set_timeout(timeout_ms: i32) -> Option { - if timeout_ms == 0 { - return None; + /// 参考 https://code.dragonos.org.cn/xref/linux-6.1.9/fs/select.c#1101 + #[inline(never)] + pub fn ppoll( + pollfd_ptr: usize, + nfds: u32, + timespec_ptr: usize, + sigmask_ptr: usize, + ) -> Result { + let mut timeout_ts: Option = None; + let mut sigmask: Option = None; + let pollfd_ptr = VirtAddr::new(pollfd_ptr); + let pollfds_len = nfds as usize * core::mem::size_of::(); + let mut poll_fds_writer = + UserBufferWriter::new(pollfd_ptr.as_ptr::(), pollfds_len, true)?; + let poll_fds = poll_fds_writer.buffer(0)?; + if sigmask_ptr != 0 { + let sigmask_reader = + UserBufferReader::new(sigmask_ptr as *const SigSet, size_of::(), true)?; + sigmask = Some(*sigmask_reader.read_one_from_user(0)?); + } + + if timespec_ptr != 0 { + let tsreader = UserBufferReader::new( + timespec_ptr as *const PosixTimeSpec, + size_of::(), + true, + )?; + let ts: PosixTimeSpec = *tsreader.read_one_from_user(0)?; + let timeout_ms = ts.tv_sec * 1000 + ts.tv_nsec / 1_000_000; + + if timeout_ms >= 0 { + timeout_ts = + Some(poll_select_set_timeout(timeout_ms as u64).ok_or(SystemError::EINVAL)?); + } + } + + if let Some(mut sigmask) = sigmask { + set_user_sigmask(&mut sigmask); + } + // log::debug!( + // "ppoll: poll_fds: {:?}, nfds: {}, timeout_ts: {:?},sigmask: {:?}", + // poll_fds, + // nfds, + // timeout_ts, + // sigmask + // ); + + let r: Result = do_sys_poll(poll_fds, timeout_ts); + + return poll_select_finish(timeout_ts, timespec_ptr, PollTimeType::TimeSpec, r); } - - Some(Instant::now() + Duration::from_millis(timeout_ms as u64)) } fn do_sys_poll(poll_fds: &mut [PollFd], timeout: Option) -> Result { @@ -115,6 +174,75 @@ fn do_sys_poll(poll_fds: &mut [PollFd], timeout: Option) -> Result Option { + if timeout_ms == 0 { + return None; + } + + Some(Instant::now() + Duration::from_millis(timeout_ms)) +} + +/// 参考 https://code.dragonos.org.cn/xref/linux-6.1.9/fs/select.c#298 +fn poll_select_finish( + end_time: Option, + user_time_ptr: usize, + poll_time_type: PollTimeType, + mut result: Result, +) -> Result { + restore_saved_sigmask_unless(result == Err(SystemError::ERESTARTNOHAND)); + + if user_time_ptr == 0 { + return result; + } + + // todo: 处理sticky timeouts + + if end_time.is_none() { + return result; + } + + let end_time = end_time.unwrap(); + + // no update for zero timeout + if end_time.total_millis() <= 0 { + return result; + } + + let ts = Instant::now(); + let duration = end_time.saturating_sub(ts); + let rts: PosixTimeSpec = duration.into(); + + match poll_time_type { + PollTimeType::TimeSpec => { + let mut tswriter = UserBufferWriter::new( + user_time_ptr as *mut PosixTimeSpec, + size_of::(), + true, + )?; + if tswriter.copy_one_to_user(&rts, 0).is_err() { + return result; + } + } + _ => todo!(), + } + + if result == Err(SystemError::ERESTARTNOHAND) { + result = result.map_err(|_| SystemError::EINTR); + } + + return result; +} + +#[allow(unused)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum PollTimeType { + TimeVal, + OldTimeVal, + TimeSpec, + OldTimeSpec, +} + bitflags! { pub struct PollFlags: u16 { const POLLIN = 0x0001; diff --git a/kernel/src/ipc/signal.rs b/kernel/src/ipc/signal.rs index 1f5ee768..ce84408b 100644 --- a/kernel/src/ipc/signal.rs +++ b/kernel/src/ipc/signal.rs @@ -419,6 +419,16 @@ pub fn restore_saved_sigmask() { } } +pub fn restore_saved_sigmask_unless(interrupted: bool) { + if interrupted { + if !ProcessManager::current_pcb().has_pending_signal_fast() { + log::warn!("restore_saved_sigmask_unless: interrupted, but has NO pending signal"); + } + } else { + restore_saved_sigmask(); + } +} + /// 刷新指定进程的sighand的sigaction,将满足条件的sigaction恢复为默认状态。 /// 除非某个信号被设置为忽略且 `force_default` 为 `false`,否则都不会将其恢复。 /// diff --git a/kernel/src/ipc/signal_types.rs b/kernel/src/ipc/signal_types.rs index befeda02..124c3e5d 100644 --- a/kernel/src/ipc/signal_types.rs +++ b/kernel/src/ipc/signal_types.rs @@ -78,7 +78,11 @@ impl SignalStruct { let mut r = Self { inner: Box::::default(), }; - let sig_ign = Sigaction::default(); + let mut sig_ign = Sigaction::default(); + // 收到忽略的信号,重启系统调用 + // todo: 看看linux哪些 + sig_ign.flags_mut().insert(SigFlags::SA_RESTART); + r.inner.handlers[Signal::SIGCHLD as usize - 1] = sig_ign; r.inner.handlers[Signal::SIGURG as usize - 1] = sig_ign; r.inner.handlers[Signal::SIGWINCH as usize - 1] = sig_ign; diff --git a/kernel/src/ipc/syscall.rs b/kernel/src/ipc/syscall.rs index 90d4ddc4..d121821e 100644 --- a/kernel/src/ipc/syscall.rs +++ b/kernel/src/ipc/syscall.rs @@ -562,4 +562,28 @@ impl Syscall { return Ok(0); } } + + #[inline(never)] + pub fn rt_sigpending(user_sigset_ptr: usize, sigsetsize: usize) -> Result { + if sigsetsize != size_of::() { + return Err(SystemError::EINVAL); + } + + let mut user_buffer_writer = + UserBufferWriter::new(user_sigset_ptr as *mut SigSet, size_of::(), true)?; + + let pcb = ProcessManager::current_pcb(); + let siginfo_guard = pcb.sig_info_irqsave(); + let pending_set = siginfo_guard.sig_pending().signal(); + let shared_pending_set = siginfo_guard.sig_shared_pending().signal(); + let blocked_set = *siginfo_guard.sig_blocked(); + drop(siginfo_guard); + + let mut result = pending_set.union(shared_pending_set); + result = result.difference(blocked_set); + + user_buffer_writer.copy_one_to_user(&result, 0)?; + + Ok(0) + } } diff --git a/kernel/src/net/event_poll/mod.rs b/kernel/src/net/event_poll/mod.rs index 2536d11b..ba495fc6 100644 --- a/kernel/src/net/event_poll/mod.rs +++ b/kernel/src/net/event_poll/mod.rs @@ -531,8 +531,10 @@ impl EventPoll { continue; } - // 如果有未处理的信号则返回错误 - if current_pcb.has_pending_signal_fast() { + // 如果有未处理且未被屏蔽的信号则返回错误 + if current_pcb.has_pending_signal_fast() + && current_pcb.has_pending_not_masked_signal() + { return Err(SystemError::ERESTARTSYS); } @@ -858,6 +860,14 @@ impl EPollEvent { pub fn events(&self) -> u32 { self.events } + + pub fn set_data(&mut self, data: u64) { + self.data = data; + } + + pub fn data(&self) -> u64 { + self.data + } } /// ## epoll_ctl函数的参数 diff --git a/kernel/src/process/mod.rs b/kernel/src/process/mod.rs index d7f80057..8b46ee5e 100644 --- a/kernel/src/process/mod.rs +++ b/kernel/src/process/mod.rs @@ -1071,6 +1071,24 @@ impl ProcessControlBlock { self.flags.get().contains(ProcessFlags::HAS_PENDING_SIGNAL) } + /// 检查当前进程是否有未被阻塞的待处理信号。 + /// + /// 注:该函数较慢,因此需要与 has_pending_signal_fast 一起使用。 + pub fn has_pending_not_masked_signal(&self) -> bool { + let sig_info = self.sig_info_irqsave(); + let blocked: SigSet = *sig_info.sig_blocked(); + let mut pending: SigSet = sig_info.sig_pending().signal(); + drop(sig_info); + pending.remove(blocked); + // log::debug!( + // "pending and not masked:{:?}, masked: {:?}", + // pending, + // blocked + // ); + let has_not_masked = !pending.is_empty(); + return has_not_masked; + } + pub fn sig_struct(&self) -> SpinLockGuard { self.sig_struct.lock_irqsave() } diff --git a/kernel/src/syscall/mod.rs b/kernel/src/syscall/mod.rs index 3d264694..3d052dbe 100644 --- a/kernel/src/syscall/mod.rs +++ b/kernel/src/syscall/mod.rs @@ -883,10 +883,7 @@ impl Syscall { Self::poll(fds, nfds, timeout) } - SYS_PPOLL => { - log::warn!("SYS_PPOLL has not yet been implemented"); - Ok(0) - } + SYS_PPOLL => Self::ppoll(args[0], args[1] as u32, args[2], args[3]), SYS_SETPGID => { warn!("SYS_SETPGID has not yet been implemented"); @@ -1233,6 +1230,7 @@ impl Syscall { } SYS_SETRLIMIT => Ok(0), SYS_RESTART_SYSCALL => Self::restart_syscall(), + SYS_RT_SIGPENDING => Self::rt_sigpending(args[0], args[1]), _ => panic!("Unsupported syscall ID: {}", syscall_num), }; diff --git a/kernel/src/time/mod.rs b/kernel/src/time/mod.rs index 12947310..214bf118 100644 --- a/kernel/src/time/mod.rs +++ b/kernel/src/time/mod.rs @@ -288,6 +288,23 @@ impl Instant { let micros_diff = self.micros - earlier.micros; Some(Duration::from_micros(micros_diff as u64)) } + + /// Saturating subtraction. Computes `self - other`, returning [`Instant::ZERO`] if the result would be negative. + /// + /// # Arguments + /// + /// * `other` - The `Instant` to subtract from `self`. + /// + /// # Returns + /// + /// The duration between `self` and `other`, or [`Instant::ZERO`] if `other` is later than `self`. + pub fn saturating_sub(self, other: Instant) -> Duration { + if self.micros >= other.micros { + Duration::from_micros((self.micros - other.micros) as u64) + } else { + Duration::ZERO + } + } } impl fmt::Display for Instant { diff --git a/user/apps/test_poll/.gitignore b/user/apps/test_poll/.gitignore index 96903813..36fe7a1d 100644 --- a/user/apps/test_poll/.gitignore +++ b/user/apps/test_poll/.gitignore @@ -1 +1,3 @@ test_poll +test_ppoll +*.o \ No newline at end of file diff --git a/user/apps/test_poll/Makefile b/user/apps/test_poll/Makefile index 6604e069..437f0ef2 100644 --- a/user/apps/test_poll/Makefile +++ b/user/apps/test_poll/Makefile @@ -8,14 +8,17 @@ BIN_NAME=test_poll CC=$(CROSS_COMPILE)gcc .PHONY: all -all: main.c +all: main.c ppoll.c $(CC) -static -o $(BIN_NAME) main.c + $(CC) -static -o test_ppoll ppoll.c .PHONY: install clean install: all mv $(BIN_NAME) $(DADK_CURRENT_BUILD_DIR)/$(BIN_NAME) + mv test_ppoll $(DADK_CURRENT_BUILD_DIR)/test_ppoll clean: rm $(BIN_NAME) *.o + rm test_ppoll fmt: diff --git a/user/apps/test_poll/ppoll.c b/user/apps/test_poll/ppoll.c new file mode 100644 index 00000000..0300bdd1 --- /dev/null +++ b/user/apps/test_poll/ppoll.c @@ -0,0 +1,148 @@ +#include +#define _GNU_SOURCE + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define RED "\x1B[31m" +#define GREEN "\x1B[32m" +#define RESET "\x1B[0m" + +// 测试用例1:基本功能测试(管道I/O) +void test_basic_functionality() { + int pipefd[2]; + struct pollfd fds[1]; + struct timespec timeout = {5, 0}; // 5秒超时 + + printf("=== Test 1: Basic functionality test ===\n"); + + // 创建管道 + if (pipe(pipefd) == -1) { + perror("pipe creation failed"); + exit(EXIT_FAILURE); + } + + // 设置监听读端管道 + fds[0].fd = pipefd[0]; + fds[0].events = POLLIN; + + printf("Test scenario 1: Wait with no data (should timeout)\n"); + int ret = ppoll(fds, 1, &timeout, NULL); + if (ret == 0) { + printf(GREEN "Test passed: Correct timeout\n" RESET); + } else { + printf(RED "Test failed: Return value %d\n" RESET, ret); + } + + // 向管道写入数据 + const char *msg = "test data"; + write(pipefd[1], msg, strlen(msg)); + + printf( + "\nTest scenario 2: Should return immediately when data is available\n"); + timeout.tv_sec = 5; + ret = ppoll(fds, 1, &timeout, NULL); + if (ret > 0 && (fds[0].revents & POLLIN)) { + printf(GREEN "Test passed: Data detected\n" RESET); + } else { + printf(RED "Test failed: Return value %d, revents %d\n" RESET, ret, + fds[0].revents); + } + + close(pipefd[0]); + close(pipefd[1]); +} + +// 测试用例2:信号屏蔽测试 +void test_signal_handling() { + printf("\n=== Test 2: Signal handling test ===\n"); + sigset_t mask, orig_mask; + struct timespec timeout = {5, 0}; + struct pollfd fds[1]; + + fds[0].fd = -1; + fds[0].events = 0; + + // 设置信号屏蔽 + sigemptyset(&mask); + sigaddset(&mask, SIGUSR1); + // 阻塞SIGUSR1,并保存原来的信号掩码 + if (sigprocmask(SIG_BLOCK, &mask, &orig_mask)) { + perror("sigprocmask"); + exit(EXIT_FAILURE); + } + + printf("Test scenario: Signal should not interrupt when masked\n"); + pid_t pid = fork(); + if (pid == 0) { // 子进程 + sleep(2); // 等待父进程进入ppoll + kill(getppid(), SIGUSR1); + exit(0); + } + + int ret = ppoll(fds, 1, &timeout, &mask); + + if (ret == 0) { + printf(GREEN "Test passed: Completed full 5 second wait\n" RESET); + } else { + printf(RED "Test failed: Premature return %d\n" RESET, errno); + } + + waitpid(pid, NULL, 0); + + // 检查并消费挂起的SIGUSR1信号 + sigset_t pending; + sigpending(&pending); + if (sigismember(&pending, SIGUSR1)) { + int sig; + sigwait(&mask, &sig); // 主动消费信号 + + printf("Consumed pending SIGUSR1 signal\n"); + } + // 恢复原来的信号掩码 + sigprocmask(SIG_SETMASK, &orig_mask, NULL); +} + +// 测试用例3:精确超时测试 +void test_timeout_accuracy() { + printf("\n=== Test 3: Timeout accuracy test ===\n"); + struct timespec start, end, timeout = {0, 500000000}; + struct pollfd fds[1]; + fds[0].fd = -1; + fds[0].events = 0; + + clock_gettime(CLOCK_MONOTONIC, &start); + int ret = ppoll(fds, 1, &timeout, NULL); + clock_gettime(CLOCK_MONOTONIC, &end); + + long elapsed = (end.tv_sec - start.tv_sec) * 1000000 + + (end.tv_nsec - start.tv_nsec) / 1000; + + printf("Expected timeout: 500ms, Actual elapsed: %.3fms\n", elapsed / 1000.0); + if (labs(elapsed - 500000) < 50000) { // 允许±50ms误差 + printf(GREEN "Test passed: Timeout within acceptable range\n" RESET); + } else { + printf(RED "Test failed: Timeout deviation too large\n" RESET); + } +} + +int main() { + // 设置非阻塞标准输入 + fcntl(STDIN_FILENO, F_SETFL, O_NONBLOCK); + + test_basic_functionality(); + test_signal_handling(); + test_timeout_accuracy(); + + return 0; +}