Implement a safe and race-free Waiter

This commit is contained in:
Ruihan Li
2024-04-11 15:11:17 +08:00
committed by Tate, Hongliang Tian
parent dac41e9a2f
commit 4851204059
3 changed files with 139 additions and 65 deletions

View File

@ -1,7 +1,10 @@
// SPDX-License-Identifier: MPL-2.0
use alloc::{collections::VecDeque, sync::Arc};
use core::time::Duration;
use core::{
sync::atomic::{AtomicBool, Ordering},
time::Duration,
};
use super::SpinLock;
use crate::{
@ -16,13 +19,13 @@ use crate::{
/// Other threads may invoke the `wake`-family methods of a wait queue to
/// wake up one or many waiter threads.
pub struct WaitQueue {
waiters: SpinLock<VecDeque<Arc<Waiter>>>,
wakers: SpinLock<VecDeque<Arc<Waker>>>,
}
impl WaitQueue {
pub const fn new() -> Self {
WaitQueue {
waiters: SpinLock::new(VecDeque::new()),
wakers: SpinLock::new(VecDeque::new()),
}
}
@ -61,7 +64,7 @@ impl WaitQueue {
return Some(res);
}
let waiter = Arc::new(Waiter::new());
let (waiter, waker) = Waiter::new_pair();
let timer_callback = timeout.map(|timeout| {
let remaining_ticks = {
@ -76,16 +79,16 @@ impl WaitQueue {
(timeout.as_millis() as u64 + ms_per_tick - 1) / ms_per_tick
};
add_timeout_list(remaining_ticks, waiter.clone(), |timer_call_back| {
let waiter = timer_call_back
.data()
.downcast_ref::<Arc<Waiter>>()
.unwrap();
waiter.wake_up();
add_timeout_list(remaining_ticks, waker.clone(), |timer_call_back| {
let waker = timer_call_back.data().downcast_ref::<Arc<Waker>>().unwrap();
waker.wake_up();
})
});
loop {
// Enqueue the waker before checking `cond()` to avoid races
self.enqueue(waker.clone());
if let Some(res) = cond() {
if let Some(timer_callback) = timer_callback {
timer_callback.cancel();
@ -97,19 +100,20 @@ impl WaitQueue {
if let Some(ref timer_callback) = timer_callback
&& timer_callback.is_expired()
{
// Drop the waiter and check again to avoid missing a wake event
drop(waiter);
return cond();
}
self.enqueue(&waiter);
waiter.wait();
}
}
/// Wake up one waiting thread.
pub fn wake_one(&self) {
while let Some(waiter) = self.waiters.lock_irq_disabled().pop_front() {
while let Some(waker) = self.wakers.lock_irq_disabled().pop_front() {
// Avoid holding lock when calling `wake_up`
if waiter.wake_up() {
if waker.wake_up() {
return;
}
}
@ -117,63 +121,124 @@ impl WaitQueue {
/// Wake up all waiting threads.
pub fn wake_all(&self) {
while let Some(waiter) = self.waiters.lock_irq_disabled().pop_front() {
while let Some(waker) = self.wakers.lock_irq_disabled().pop_front() {
// Avoid holding lock when calling `wake_up`
waiter.wake_up();
waker.wake_up();
}
}
pub fn is_empty(&self) -> bool {
self.waiters.lock_irq_disabled().is_empty()
self.wakers.lock_irq_disabled().is_empty()
}
// Enqueue a waiter into current waitqueue. If waiter is exclusive, add to the back of waitqueue.
// Otherwise, add to the front of waitqueue
fn enqueue(&self, waiter: &Arc<Waiter>) {
self.waiters.lock_irq_disabled().push_back(waiter.clone());
fn enqueue(&self, waker: Arc<Waker>) {
self.wakers.lock_irq_disabled().push_back(waker);
}
}
/// A waiter that can put the current thread to sleep until it is woken up by the associated
/// [`Waker`].
///
/// By definition, a waiter belongs to the current thread, so it cannot be sent to another thread
/// and its reference cannot be shared between threads.
struct Waiter {
/// The `Task` held by the waiter.
waker: Arc<Waker>,
}
impl !Send for Waiter {}
impl !Sync for Waiter {}
/// A waker that can wake up the associated [`Waiter`].
///
/// A waker can be created by calling [`Waiter::new`]. This method creates an `Arc<Waker>` that can
/// be used across different threads.
struct Waker {
has_woken: AtomicBool,
task: Arc<Task>,
}
impl Waiter {
pub fn new() -> Self {
Waiter {
/// Creates a waiter and its associated [`Waker`].
pub fn new_pair() -> (Self, Arc<Waker>) {
let waker = Arc::new(Waker {
has_woken: AtomicBool::new(false),
task: current_task().unwrap(),
}
});
let waiter = Self {
waker: waker.clone(),
};
(waiter, waker)
}
/// Wait until being woken up
/// Waits until the waiter is woken up by calling [`Waker::wake_up`] on the associated
/// [`Waker`].
///
/// This method returns immediately if the waiter has been woken since the end of the last call
/// to this method (or since the waiter was created, if this method has not been called
/// before). Otherwise, it puts the current thread to sleep until the waiter is woken up.
pub fn wait(&self) {
debug_assert_eq!(
self.task.inner_exclusive_access().task_status,
TaskStatus::Runnable
);
self.task.inner_exclusive_access().task_status = TaskStatus::Sleeping;
while self.task.inner_exclusive_access().task_status == TaskStatus::Sleeping {
schedule();
self.waker.do_wait();
}
}
/// Wake up a waiting task.
/// If the task is waiting before being woken, return true;
/// Otherwise return false.
impl Drop for Waiter {
fn drop(&mut self) {
// When dropping the waiter, we need to close the waker to ensure that if someone wants to
// wake up the waiter afterwards, they will perform a no-op.
self.waker.close();
}
}
impl Waker {
/// Wakes up the associated [`Waiter`].
///
/// This method returns `true` if the waiter is woken by this call. It returns `false` if the
/// waiter has already been woken by a previous call to the method, or if the waiter has been
/// dropped.
///
/// Note that if this method returns `true`, it implies that the wake event will be properly
/// delivered, _or_ that the waiter will be dropped after being woken. It's up to the caller to
/// handle the latter case properly to avoid missing the wake event.
pub fn wake_up(&self) -> bool {
if self.has_woken.swap(true, Ordering::AcqRel) {
return false;
}
let mut task = self.task.inner_exclusive_access();
if task.task_status == TaskStatus::Sleeping {
match task.task_status {
TaskStatus::Sleepy => {
task.task_status = TaskStatus::Runnable;
}
TaskStatus::Sleeping => {
task.task_status = TaskStatus::Runnable;
// Avoid holding lock when doing `add_task`
// Avoid holding the lock when doing `add_task`
drop(task);
add_task(self.task.clone());
}
_ => (),
}
true
} else {
false
}
}
fn do_wait(&self) {
while !self.has_woken.load(Ordering::Acquire) {
let mut task = self.task.inner_exclusive_access();
// After holding the lock, check again to avoid races
if self.has_woken.load(Ordering::Acquire) {
break;
}
task.task_status = TaskStatus::Sleepy;
drop(task);
schedule();
}
self.has_woken.store(false, Ordering::Release);
}
fn close(&self) {
self.has_woken.store(true, Ordering::Release);
}
}

View File

@ -91,26 +91,33 @@ fn switch_to_task(next_task: Arc<Task>) {
"Calling schedule() while holding {} locks",
PREEMPT_COUNT.num_locks()
);
//GLOBAL_SCHEDULER.lock_irq_disabled().enqueue(next_task);
//return;
}
let current_task_option = current_task();
let next_task_cx_ptr = &next_task.inner_ctx() as *const TaskContext;
let current_task: Arc<Task>;
let current_task_cx_ptr = match current_task_option {
let current_task_cx_ptr = match current_task() {
None => PROCESSOR.lock().get_idle_task_cx_ptr(),
Some(current_task) => {
if current_task.status() == TaskStatus::Runnable {
GLOBAL_SCHEDULER
.lock_irq_disabled()
.enqueue(current_task.clone());
let mut task = current_task.inner_exclusive_access();
// FIXME: `task.ctx` should be put in a separate `UnsafeCell`, not as a part of
// `TaskInner`. Otherwise, it violates the sematics of `SpinLock` and Rust's memory
// model which requires that mutable references must be exclusive.
let cx_ptr = &mut task.ctx as *mut TaskContext;
debug_assert_ne!(task.task_status, TaskStatus::Sleeping);
if task.task_status == TaskStatus::Runnable {
drop(task);
GLOBAL_SCHEDULER.lock_irq_disabled().enqueue(current_task);
} else if task.task_status == TaskStatus::Sleepy {
task.task_status = TaskStatus::Sleeping;
}
&mut current_task.inner_exclusive_access().ctx as *mut TaskContext
cx_ptr
}
};
// change the current task to the next task
let next_task_cx_ptr = &next_task.inner_ctx() as *const TaskContext;
// change the current task to the next task
PROCESSOR.lock().current = Some(next_task.clone());
unsafe {
context_switch(current_task_cx_ptr, next_task_cx_ptr);

View File

@ -190,7 +190,9 @@ impl Task {
pub enum TaskStatus {
/// The task is runnable.
Runnable,
/// The task is sleeping.
/// The task is running in the foreground but will sleep when it goes to the background.
Sleepy,
/// The task is sleeping in the background.
Sleeping,
/// The task has exited.
Exited,