Make Pauser work in kernel threads

This commit is contained in:
Ruihan Li
2024-07-06 18:20:02 +08:00
committed by Tate, Hongliang Tian
parent 9e5f3123e1
commit 8f64a1cb90

View File

@ -11,7 +11,11 @@ use ostd::sync::WaitQueue;
use super::{sig_mask::SigMask, SigEvents, SigEventsFilter}; use super::{sig_mask::SigMask, SigEvents, SigEventsFilter};
use crate::{ use crate::{
events::Observer, prelude::*, process::posix_thread::PosixThreadExt, time::wait::WaitTimeout, events::Observer,
prelude::*,
process::posix_thread::{PosixThread, PosixThreadExt},
thread::Thread,
time::wait::WaitTimeout,
}; };
/// A `Pauser` allows pausing the execution of the current thread until certain conditions are reached. /// A `Pauser` allows pausing the execution of the current thread until certain conditions are reached.
@ -120,66 +124,30 @@ impl Pauser {
F: FnMut() -> Option<R>, F: FnMut() -> Option<R>,
{ {
let current_thread = current_thread!(); let current_thread = current_thread!();
let posix_thread = current_thread.as_posix_thread().unwrap(); let sig_queue_waiter =
SigObserverRegistrar::new(&current_thread, self.sig_mask, self.clone());
// Block `self.sig_mask` let cond = || {
let (old_mask, filter) = { if let Some(res) = cond() {
let mut current_mask = posix_thread.sig_mask().lock(); return Some(Ok(res));
let old_mask = *current_mask;
let new_mask = {
current_mask.block(self.sig_mask.as_u64());
*current_mask
};
(old_mask, SigEventsFilter::new(new_mask))
};
// Register observer on sigqueue
let observer = SigQueueObserver::new(self.clone());
let weak_observer = Arc::downgrade(&observer) as Weak<dyn Observer<SigEvents>>;
posix_thread.register_sigqueue_observer(weak_observer.clone(), filter);
// Some signal may come before we register observer, so we do another check here.
if posix_thread.has_pending() {
observer.set_interrupted();
}
enum Res<R> {
Ok(R),
Interrupted,
}
let cond = {
let cloned_observer = observer.clone();
move || {
if let Some(res) = cond() {
return Some(Res::Ok(res));
}
if cloned_observer.is_interrupted() {
return Some(Res::Interrupted);
}
None
} }
if sig_queue_waiter.is_interrupted() {
return Some(Err(Error::with_message(
Errno::EINTR,
"the current thread is interrupted by a signal",
)));
}
None
}; };
let res = if let Some(timeout) = timeout { if let Some(timeout) = timeout {
self.wait_queue self.wait_queue
.wait_until_or_timeout(cond, timeout) .wait_until_or_timeout(cond, timeout)
.ok_or_else(|| Error::with_message(Errno::ETIME, "timeout is reached")) .ok_or_else(|| Error::with_message(Errno::ETIME, "the time limit is reached"))?
} else { } else {
Ok(self.wait_queue.wait_until(cond)) self.wait_queue.wait_until(cond)
};
// Restore the state
posix_thread.unregiser_sigqueue_observer(&weak_observer);
posix_thread.sig_mask().lock().set(old_mask.as_u64());
match res? {
Res::Ok(r) => Ok(r),
Res::Interrupted => return_errno_with_message!(Errno::EINTR, "interrupted by signal"),
} }
} }
@ -194,6 +162,78 @@ impl Pauser {
} }
} }
enum SigObserverRegistrar<'a> {
// A POSIX thread may be interrupted by a signal if the signal is not masked.
PosixThread {
thread: &'a PosixThread,
old_mask: SigMask,
observer: Arc<SigQueueObserver>,
},
// A kernel thread ignores all signals. It is not necessary to wait for them.
KernelThread,
}
impl<'a> SigObserverRegistrar<'a> {
fn new(current_thread: &'a Arc<Thread>, sig_mask: SigMask, pauser: Arc<Pauser>) -> Self {
let Some(thread) = current_thread.as_posix_thread() else {
return Self::KernelThread;
};
// Block `sig_mask`.
let (old_mask, filter) = {
let mut locked_mask = thread.sig_mask().lock();
let old_mask = *locked_mask;
let new_mask = {
locked_mask.block(sig_mask.as_u64());
*locked_mask
};
(old_mask, SigEventsFilter::new(new_mask))
};
// Register `SigQueueObserver`.
let observer = SigQueueObserver::new(pauser);
thread.register_sigqueue_observer(Arc::downgrade(&observer) as _, filter);
// Check pending signals after registering the observer to avoid race conditions.
if thread.has_pending() {
observer.set_interrupted();
}
Self::PosixThread {
thread,
old_mask,
observer,
}
}
fn is_interrupted(&self) -> bool {
match self {
Self::PosixThread { observer, .. } => observer.is_interrupted(),
Self::KernelThread => false,
}
}
}
impl<'a> Drop for SigObserverRegistrar<'a> {
fn drop(&mut self) {
let Self::PosixThread {
thread,
old_mask,
observer,
} = self
else {
return;
};
// Restore the state, assuming no one else can modify the current thread's signal mask
// during the pause.
thread.unregiser_sigqueue_observer(&(Arc::downgrade(observer) as _));
thread.sig_mask().lock().set(old_mask.as_u64());
}
}
struct SigQueueObserver { struct SigQueueObserver {
is_interrupted: AtomicBool, is_interrupted: AtomicBool,
pauser: Arc<Pauser>, pauser: Arc<Pauser>,
@ -222,3 +262,38 @@ impl Observer<SigEvents> for SigQueueObserver {
self.pauser.wait_queue.wake_all(); self.pauser.wait_queue.wake_all();
} }
} }
#[cfg(ktest)]
mod test {
use ostd::prelude::*;
use super::*;
use crate::thread::{
kernel_thread::{KernelThreadExt, ThreadOptions},
Thread,
};
#[ktest]
fn test_pauser() {
let pauser = Pauser::new();
let pauser_cloned = pauser.clone();
let boolean = Arc::new(AtomicBool::new(false));
let boolean_cloned = boolean.clone();
let thread1 = Thread::spawn_kernel_thread(ThreadOptions::new(move || {
pauser
.pause_until(|| boolean.load(Ordering::Relaxed).then_some(()))
.unwrap();
}));
let thread2 = Thread::spawn_kernel_thread(ThreadOptions::new(move || {
Thread::yield_now();
boolean_cloned.store(true, Ordering::Relaxed);
pauser_cloned.resume_all();
}));
thread1.join();
thread2.join();
}
}