diff --git a/kernel/aster-nix/src/process/signal/pauser.rs b/kernel/aster-nix/src/process/signal/pauser.rs index f02ae5346..b64fe7db5 100644 --- a/kernel/aster-nix/src/process/signal/pauser.rs +++ b/kernel/aster-nix/src/process/signal/pauser.rs @@ -11,7 +11,11 @@ use ostd::sync::WaitQueue; use super::{sig_mask::SigMask, SigEvents, SigEventsFilter}; use crate::{ - events::Observer, prelude::*, process::posix_thread::PosixThreadExt, time::wait::WaitTimeout, + events::Observer, + prelude::*, + process::posix_thread::{PosixThread, PosixThreadExt}, + thread::Thread, + time::wait::WaitTimeout, }; /// A `Pauser` allows pausing the execution of the current thread until certain conditions are reached. @@ -120,66 +124,30 @@ impl Pauser { F: FnMut() -> Option, { let current_thread = current_thread!(); - let posix_thread = current_thread.as_posix_thread().unwrap(); + let sig_queue_waiter = + SigObserverRegistrar::new(¤t_thread, self.sig_mask, self.clone()); - // Block `self.sig_mask` - let (old_mask, filter) = { - let mut current_mask = posix_thread.sig_mask().lock(); - let old_mask = *current_mask; - - let new_mask = { - current_mask.block(self.sig_mask.as_u64()); - *current_mask - }; - - (old_mask, SigEventsFilter::new(new_mask)) - }; - - // Register observer on sigqueue - let observer = SigQueueObserver::new(self.clone()); - let weak_observer = Arc::downgrade(&observer) as Weak>; - posix_thread.register_sigqueue_observer(weak_observer.clone(), filter); - - // Some signal may come before we register observer, so we do another check here. - if posix_thread.has_pending() { - observer.set_interrupted(); - } - - enum Res { - Ok(R), - Interrupted, - } - - let cond = { - let cloned_observer = observer.clone(); - move || { - if let Some(res) = cond() { - return Some(Res::Ok(res)); - } - - if cloned_observer.is_interrupted() { - return Some(Res::Interrupted); - } - - None + let cond = || { + if let Some(res) = cond() { + return Some(Ok(res)); } + + if sig_queue_waiter.is_interrupted() { + return Some(Err(Error::with_message( + Errno::EINTR, + "the current thread is interrupted by a signal", + ))); + } + + None }; - let res = if let Some(timeout) = timeout { + if let Some(timeout) = timeout { self.wait_queue .wait_until_or_timeout(cond, timeout) - .ok_or_else(|| Error::with_message(Errno::ETIME, "timeout is reached")) + .ok_or_else(|| Error::with_message(Errno::ETIME, "the time limit is reached"))? } else { - Ok(self.wait_queue.wait_until(cond)) - }; - - // Restore the state - posix_thread.unregiser_sigqueue_observer(&weak_observer); - posix_thread.sig_mask().lock().set(old_mask.as_u64()); - - match res? { - Res::Ok(r) => Ok(r), - Res::Interrupted => return_errno_with_message!(Errno::EINTR, "interrupted by signal"), + self.wait_queue.wait_until(cond) } } @@ -194,6 +162,78 @@ impl Pauser { } } +enum SigObserverRegistrar<'a> { + // A POSIX thread may be interrupted by a signal if the signal is not masked. + PosixThread { + thread: &'a PosixThread, + old_mask: SigMask, + observer: Arc, + }, + // A kernel thread ignores all signals. It is not necessary to wait for them. + KernelThread, +} + +impl<'a> SigObserverRegistrar<'a> { + fn new(current_thread: &'a Arc, sig_mask: SigMask, pauser: Arc) -> Self { + let Some(thread) = current_thread.as_posix_thread() else { + return Self::KernelThread; + }; + + // Block `sig_mask`. + let (old_mask, filter) = { + let mut locked_mask = thread.sig_mask().lock(); + + let old_mask = *locked_mask; + let new_mask = { + locked_mask.block(sig_mask.as_u64()); + *locked_mask + }; + + (old_mask, SigEventsFilter::new(new_mask)) + }; + + // Register `SigQueueObserver`. + let observer = SigQueueObserver::new(pauser); + thread.register_sigqueue_observer(Arc::downgrade(&observer) as _, filter); + + // Check pending signals after registering the observer to avoid race conditions. + if thread.has_pending() { + observer.set_interrupted(); + } + + Self::PosixThread { + thread, + old_mask, + observer, + } + } + + fn is_interrupted(&self) -> bool { + match self { + Self::PosixThread { observer, .. } => observer.is_interrupted(), + Self::KernelThread => false, + } + } +} + +impl<'a> Drop for SigObserverRegistrar<'a> { + fn drop(&mut self) { + let Self::PosixThread { + thread, + old_mask, + observer, + } = self + else { + return; + }; + + // Restore the state, assuming no one else can modify the current thread's signal mask + // during the pause. + thread.unregiser_sigqueue_observer(&(Arc::downgrade(observer) as _)); + thread.sig_mask().lock().set(old_mask.as_u64()); + } +} + struct SigQueueObserver { is_interrupted: AtomicBool, pauser: Arc, @@ -222,3 +262,38 @@ impl Observer for SigQueueObserver { self.pauser.wait_queue.wake_all(); } } + +#[cfg(ktest)] +mod test { + use ostd::prelude::*; + + use super::*; + use crate::thread::{ + kernel_thread::{KernelThreadExt, ThreadOptions}, + Thread, + }; + + #[ktest] + fn test_pauser() { + let pauser = Pauser::new(); + let pauser_cloned = pauser.clone(); + + let boolean = Arc::new(AtomicBool::new(false)); + let boolean_cloned = boolean.clone(); + + let thread1 = Thread::spawn_kernel_thread(ThreadOptions::new(move || { + pauser + .pause_until(|| boolean.load(Ordering::Relaxed).then_some(())) + .unwrap(); + })); + + let thread2 = Thread::spawn_kernel_thread(ThreadOptions::new(move || { + Thread::yield_now(); + boolean_cloned.store(true, Ordering::Relaxed); + pauser_cloned.resume_all(); + })); + + thread1.join(); + thread2.join(); + } +}