mirror of
https://github.com/asterinas/asterinas.git
synced 2025-06-25 18:33:24 +00:00
Make Pauser
work in kernel threads
This commit is contained in:
committed by
Tate, Hongliang Tian
parent
9e5f3123e1
commit
8f64a1cb90
@ -11,7 +11,11 @@ use ostd::sync::WaitQueue;
|
|||||||
|
|
||||||
use super::{sig_mask::SigMask, SigEvents, SigEventsFilter};
|
use super::{sig_mask::SigMask, SigEvents, SigEventsFilter};
|
||||||
use crate::{
|
use crate::{
|
||||||
events::Observer, prelude::*, process::posix_thread::PosixThreadExt, time::wait::WaitTimeout,
|
events::Observer,
|
||||||
|
prelude::*,
|
||||||
|
process::posix_thread::{PosixThread, PosixThreadExt},
|
||||||
|
thread::Thread,
|
||||||
|
time::wait::WaitTimeout,
|
||||||
};
|
};
|
||||||
|
|
||||||
/// A `Pauser` allows pausing the execution of the current thread until certain conditions are reached.
|
/// A `Pauser` allows pausing the execution of the current thread until certain conditions are reached.
|
||||||
@ -120,66 +124,30 @@ impl Pauser {
|
|||||||
F: FnMut() -> Option<R>,
|
F: FnMut() -> Option<R>,
|
||||||
{
|
{
|
||||||
let current_thread = current_thread!();
|
let current_thread = current_thread!();
|
||||||
let posix_thread = current_thread.as_posix_thread().unwrap();
|
let sig_queue_waiter =
|
||||||
|
SigObserverRegistrar::new(¤t_thread, self.sig_mask, self.clone());
|
||||||
|
|
||||||
// Block `self.sig_mask`
|
let cond = || {
|
||||||
let (old_mask, filter) = {
|
if let Some(res) = cond() {
|
||||||
let mut current_mask = posix_thread.sig_mask().lock();
|
return Some(Ok(res));
|
||||||
let old_mask = *current_mask;
|
|
||||||
|
|
||||||
let new_mask = {
|
|
||||||
current_mask.block(self.sig_mask.as_u64());
|
|
||||||
*current_mask
|
|
||||||
};
|
|
||||||
|
|
||||||
(old_mask, SigEventsFilter::new(new_mask))
|
|
||||||
};
|
|
||||||
|
|
||||||
// Register observer on sigqueue
|
|
||||||
let observer = SigQueueObserver::new(self.clone());
|
|
||||||
let weak_observer = Arc::downgrade(&observer) as Weak<dyn Observer<SigEvents>>;
|
|
||||||
posix_thread.register_sigqueue_observer(weak_observer.clone(), filter);
|
|
||||||
|
|
||||||
// Some signal may come before we register observer, so we do another check here.
|
|
||||||
if posix_thread.has_pending() {
|
|
||||||
observer.set_interrupted();
|
|
||||||
}
|
|
||||||
|
|
||||||
enum Res<R> {
|
|
||||||
Ok(R),
|
|
||||||
Interrupted,
|
|
||||||
}
|
|
||||||
|
|
||||||
let cond = {
|
|
||||||
let cloned_observer = observer.clone();
|
|
||||||
move || {
|
|
||||||
if let Some(res) = cond() {
|
|
||||||
return Some(Res::Ok(res));
|
|
||||||
}
|
|
||||||
|
|
||||||
if cloned_observer.is_interrupted() {
|
|
||||||
return Some(Res::Interrupted);
|
|
||||||
}
|
|
||||||
|
|
||||||
None
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if sig_queue_waiter.is_interrupted() {
|
||||||
|
return Some(Err(Error::with_message(
|
||||||
|
Errno::EINTR,
|
||||||
|
"the current thread is interrupted by a signal",
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
|
None
|
||||||
};
|
};
|
||||||
|
|
||||||
let res = if let Some(timeout) = timeout {
|
if let Some(timeout) = timeout {
|
||||||
self.wait_queue
|
self.wait_queue
|
||||||
.wait_until_or_timeout(cond, timeout)
|
.wait_until_or_timeout(cond, timeout)
|
||||||
.ok_or_else(|| Error::with_message(Errno::ETIME, "timeout is reached"))
|
.ok_or_else(|| Error::with_message(Errno::ETIME, "the time limit is reached"))?
|
||||||
} else {
|
} else {
|
||||||
Ok(self.wait_queue.wait_until(cond))
|
self.wait_queue.wait_until(cond)
|
||||||
};
|
|
||||||
|
|
||||||
// Restore the state
|
|
||||||
posix_thread.unregiser_sigqueue_observer(&weak_observer);
|
|
||||||
posix_thread.sig_mask().lock().set(old_mask.as_u64());
|
|
||||||
|
|
||||||
match res? {
|
|
||||||
Res::Ok(r) => Ok(r),
|
|
||||||
Res::Interrupted => return_errno_with_message!(Errno::EINTR, "interrupted by signal"),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -194,6 +162,78 @@ impl Pauser {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
enum SigObserverRegistrar<'a> {
|
||||||
|
// A POSIX thread may be interrupted by a signal if the signal is not masked.
|
||||||
|
PosixThread {
|
||||||
|
thread: &'a PosixThread,
|
||||||
|
old_mask: SigMask,
|
||||||
|
observer: Arc<SigQueueObserver>,
|
||||||
|
},
|
||||||
|
// A kernel thread ignores all signals. It is not necessary to wait for them.
|
||||||
|
KernelThread,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> SigObserverRegistrar<'a> {
|
||||||
|
fn new(current_thread: &'a Arc<Thread>, sig_mask: SigMask, pauser: Arc<Pauser>) -> Self {
|
||||||
|
let Some(thread) = current_thread.as_posix_thread() else {
|
||||||
|
return Self::KernelThread;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Block `sig_mask`.
|
||||||
|
let (old_mask, filter) = {
|
||||||
|
let mut locked_mask = thread.sig_mask().lock();
|
||||||
|
|
||||||
|
let old_mask = *locked_mask;
|
||||||
|
let new_mask = {
|
||||||
|
locked_mask.block(sig_mask.as_u64());
|
||||||
|
*locked_mask
|
||||||
|
};
|
||||||
|
|
||||||
|
(old_mask, SigEventsFilter::new(new_mask))
|
||||||
|
};
|
||||||
|
|
||||||
|
// Register `SigQueueObserver`.
|
||||||
|
let observer = SigQueueObserver::new(pauser);
|
||||||
|
thread.register_sigqueue_observer(Arc::downgrade(&observer) as _, filter);
|
||||||
|
|
||||||
|
// Check pending signals after registering the observer to avoid race conditions.
|
||||||
|
if thread.has_pending() {
|
||||||
|
observer.set_interrupted();
|
||||||
|
}
|
||||||
|
|
||||||
|
Self::PosixThread {
|
||||||
|
thread,
|
||||||
|
old_mask,
|
||||||
|
observer,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_interrupted(&self) -> bool {
|
||||||
|
match self {
|
||||||
|
Self::PosixThread { observer, .. } => observer.is_interrupted(),
|
||||||
|
Self::KernelThread => false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> Drop for SigObserverRegistrar<'a> {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
let Self::PosixThread {
|
||||||
|
thread,
|
||||||
|
old_mask,
|
||||||
|
observer,
|
||||||
|
} = self
|
||||||
|
else {
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Restore the state, assuming no one else can modify the current thread's signal mask
|
||||||
|
// during the pause.
|
||||||
|
thread.unregiser_sigqueue_observer(&(Arc::downgrade(observer) as _));
|
||||||
|
thread.sig_mask().lock().set(old_mask.as_u64());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
struct SigQueueObserver {
|
struct SigQueueObserver {
|
||||||
is_interrupted: AtomicBool,
|
is_interrupted: AtomicBool,
|
||||||
pauser: Arc<Pauser>,
|
pauser: Arc<Pauser>,
|
||||||
@ -222,3 +262,38 @@ impl Observer<SigEvents> for SigQueueObserver {
|
|||||||
self.pauser.wait_queue.wake_all();
|
self.pauser.wait_queue.wake_all();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(ktest)]
|
||||||
|
mod test {
|
||||||
|
use ostd::prelude::*;
|
||||||
|
|
||||||
|
use super::*;
|
||||||
|
use crate::thread::{
|
||||||
|
kernel_thread::{KernelThreadExt, ThreadOptions},
|
||||||
|
Thread,
|
||||||
|
};
|
||||||
|
|
||||||
|
#[ktest]
|
||||||
|
fn test_pauser() {
|
||||||
|
let pauser = Pauser::new();
|
||||||
|
let pauser_cloned = pauser.clone();
|
||||||
|
|
||||||
|
let boolean = Arc::new(AtomicBool::new(false));
|
||||||
|
let boolean_cloned = boolean.clone();
|
||||||
|
|
||||||
|
let thread1 = Thread::spawn_kernel_thread(ThreadOptions::new(move || {
|
||||||
|
pauser
|
||||||
|
.pause_until(|| boolean.load(Ordering::Relaxed).then_some(()))
|
||||||
|
.unwrap();
|
||||||
|
}));
|
||||||
|
|
||||||
|
let thread2 = Thread::spawn_kernel_thread(ThreadOptions::new(move || {
|
||||||
|
Thread::yield_now();
|
||||||
|
boolean_cloned.store(true, Ordering::Relaxed);
|
||||||
|
pauser_cloned.resume_all();
|
||||||
|
}));
|
||||||
|
|
||||||
|
thread1.join();
|
||||||
|
thread2.join();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Reference in New Issue
Block a user