diff --git a/framework/aster-frame/src/sync/wait.rs b/framework/aster-frame/src/sync/wait.rs
index 74e924a2a..1eacc07b3 100644
--- a/framework/aster-frame/src/sync/wait.rs
+++ b/framework/aster-frame/src/sync/wait.rs
@@ -1,7 +1,10 @@
 // SPDX-License-Identifier: MPL-2.0
 
 use alloc::{collections::VecDeque, sync::Arc};
-use core::time::Duration;
+use core::{
+    sync::atomic::{AtomicBool, Ordering},
+    time::Duration,
+};
 
 use super::SpinLock;
 use crate::{
@@ -16,13 +19,13 @@ use crate::{
 /// Other threads may invoke the `wake`-family methods of a wait queue to
 /// wake up one or many waiter threads.
 pub struct WaitQueue {
-    waiters: SpinLock<VecDeque<Arc<Waiter>>>,
+    wakers: SpinLock<VecDeque<Arc<Waker>>>,
 }
 
 impl WaitQueue {
     pub const fn new() -> Self {
         WaitQueue {
-            waiters: SpinLock::new(VecDeque::new()),
+            wakers: SpinLock::new(VecDeque::new()),
         }
     }
 
@@ -61,7 +64,7 @@ impl WaitQueue {
             return Some(res);
         }
 
-        let waiter = Arc::new(Waiter::new());
+        let (waiter, waker) = Waiter::new_pair();
 
         let timer_callback = timeout.map(|timeout| {
             let remaining_ticks = {
@@ -76,16 +79,16 @@ impl WaitQueue {
                 (timeout.as_millis() as u64 + ms_per_tick - 1) / ms_per_tick
             };
 
-            add_timeout_list(remaining_ticks, waiter.clone(), |timer_call_back| {
-                let waiter = timer_call_back
-                    .data()
-                    .downcast_ref::<Arc<Waiter>>()
-                    .unwrap();
-                waiter.wake_up();
+            add_timeout_list(remaining_ticks, waker.clone(), |timer_call_back| {
+                let waker = timer_call_back.data().downcast_ref::<Arc<Waker>>().unwrap();
+                waker.wake_up();
             })
         });
 
         loop {
+            // Enqueue the waker before checking `cond()` to avoid races
+            self.enqueue(waker.clone());
+
             if let Some(res) = cond() {
                 if let Some(timer_callback) = timer_callback {
                     timer_callback.cancel();
@@ -97,19 +100,20 @@ impl WaitQueue {
             if let Some(ref timer_callback) = timer_callback
                 && timer_callback.is_expired()
             {
+                // Drop the waiter and check again to avoid missing a wake event
+                drop(waiter);
                 return cond();
             }
 
-            self.enqueue(&waiter);
             waiter.wait();
         }
     }
 
     /// Wake up one waiting thread.
     pub fn wake_one(&self) {
-        while let Some(waiter) = self.waiters.lock_irq_disabled().pop_front() {
+        while let Some(waker) = self.wakers.lock_irq_disabled().pop_front() {
             // Avoid holding lock when calling `wake_up`
-            if waiter.wake_up() {
+            if waker.wake_up() {
                 return;
             }
         }
     }
@@ -117,63 +121,124 @@ impl WaitQueue {
 
     /// Wake up all waiting threads.
     pub fn wake_all(&self) {
-        while let Some(waiter) = self.waiters.lock_irq_disabled().pop_front() {
+        while let Some(waker) = self.wakers.lock_irq_disabled().pop_front() {
             // Avoid holding lock when calling `wake_up`
-            waiter.wake_up();
+            waker.wake_up();
         }
     }
 
     pub fn is_empty(&self) -> bool {
-        self.waiters.lock_irq_disabled().is_empty()
+        self.wakers.lock_irq_disabled().is_empty()
     }
 
-    // Enqueue a waiter into current waitqueue. If waiter is exclusive, add to the back of waitqueue.
-    // Otherwise, add to the front of waitqueue
-    fn enqueue(&self, waiter: &Arc<Waiter>) {
-        self.waiters.lock_irq_disabled().push_back(waiter.clone());
+    fn enqueue(&self, waker: Arc<Waker>) {
+        self.wakers.lock_irq_disabled().push_back(waker);
     }
 }
 
+/// A waiter that can put the current thread to sleep until it is woken up by the associated
+/// [`Waker`].
+///
+/// By definition, a waiter belongs to the current thread, so it cannot be sent to another thread
+/// and its reference cannot be shared between threads.
 struct Waiter {
-    /// The `Task` held by the waiter.
+    waker: Arc<Waker>,
+}
+
+impl !Send for Waiter {}
+impl !Sync for Waiter {}
+
+/// A waker that can wake up the associated [`Waiter`].
+///
+/// A waker can be created by calling [`Waiter::new_pair`].
+/// This method creates an `Arc<Waker>` that can be used across different threads.
+struct Waker {
+    has_woken: AtomicBool,
     task: Arc<Task>,
 }
 
 impl Waiter {
-    pub fn new() -> Self {
-        Waiter {
+    /// Creates a waiter and its associated [`Waker`].
+    pub fn new_pair() -> (Self, Arc<Waker>) {
+        let waker = Arc::new(Waker {
+            has_woken: AtomicBool::new(false),
             task: current_task().unwrap(),
-        }
+        });
+        let waiter = Self {
+            waker: waker.clone(),
+        };
+        (waiter, waker)
     }
 
-    /// Wait until being woken up
+    /// Waits until the waiter is woken up by calling [`Waker::wake_up`] on the associated
+    /// [`Waker`].
+    ///
+    /// This method returns immediately if the waiter has been woken since the end of the last call
+    /// to this method (or since the waiter was created, if this method has not been called
+    /// before). Otherwise, it puts the current thread to sleep until the waiter is woken up.
     pub fn wait(&self) {
-        debug_assert_eq!(
-            self.task.inner_exclusive_access().task_status,
-            TaskStatus::Runnable
-        );
-        self.task.inner_exclusive_access().task_status = TaskStatus::Sleeping;
-        while self.task.inner_exclusive_access().task_status == TaskStatus::Sleeping {
-            schedule();
-        }
-    }
-
-    /// Wake up a waiting task.
-    /// If the task is waiting before being woken, return true;
-    /// Otherwise return false.
-    pub fn wake_up(&self) -> bool {
-        let mut task = self.task.inner_exclusive_access();
-        if task.task_status == TaskStatus::Sleeping {
-            task.task_status = TaskStatus::Runnable;
-
-            // Avoid holding lock when doing `add_task`
-            drop(task);
-
-            add_task(self.task.clone());
-
-            true
-        } else {
-            false
-        }
+        self.waker.do_wait();
+    }
+}
+
+impl Drop for Waiter {
+    fn drop(&mut self) {
+        // When dropping the waiter, we need to close the waker to ensure that if someone wants to
+        // wake up the waiter afterwards, they will perform a no-op.
+        self.waker.close();
+    }
+}
+
+impl Waker {
+    /// Wakes up the associated [`Waiter`].
+    ///
+    /// This method returns `true` if the waiter is woken by this call. It returns `false` if the
+    /// waiter has already been woken by a previous call to the method, or if the waiter has been
+    /// dropped.
+    ///
+    /// Note that if this method returns `true`, it implies that the wake event will be properly
+    /// delivered, _or_ that the waiter will be dropped after being woken. It's up to the caller to
+    /// handle the latter case properly to avoid missing the wake event.
+    pub fn wake_up(&self) -> bool {
+        if self.has_woken.swap(true, Ordering::AcqRel) {
+            return false;
+        }
+
+        let mut task = self.task.inner_exclusive_access();
+        match task.task_status {
+            TaskStatus::Sleepy => {
+                task.task_status = TaskStatus::Runnable;
+            }
+            TaskStatus::Sleeping => {
+                task.task_status = TaskStatus::Runnable;
+
+                // Avoid holding the lock when doing `add_task`
+                drop(task);
+                add_task(self.task.clone());
+            }
+            _ => (),
+        }
+
+        true
+    }
+
+    fn do_wait(&self) {
+        while !self.has_woken.load(Ordering::Acquire) {
+            let mut task = self.task.inner_exclusive_access();
+            // After holding the lock, check again to avoid races
+            if self.has_woken.load(Ordering::Acquire) {
+                break;
+            }
+            task.task_status = TaskStatus::Sleepy;
+            drop(task);
+
+            schedule();
+        }
+
+        self.has_woken.store(false, Ordering::Release);
+    }
+
+    fn close(&self) {
+        self.has_woken.store(true, Ordering::Release);
+    }
 }
diff --git a/framework/aster-frame/src/task/processor.rs b/framework/aster-frame/src/task/processor.rs
index db972c1c1..adc6b1ed8 100644
--- a/framework/aster-frame/src/task/processor.rs
+++ b/framework/aster-frame/src/task/processor.rs
@@ -91,26 +91,33 @@ fn switch_to_task(next_task: Arc<Task>) {
             "Calling schedule() while holding {} locks",
             PREEMPT_COUNT.num_locks()
         );
-        //GLOBAL_SCHEDULER.lock_irq_disabled().enqueue(next_task);
-        //return;
     }
-    let current_task_option = current_task();
-    let next_task_cx_ptr = &next_task.inner_ctx() as *const TaskContext;
-    let current_task: Arc<Task>;
-    let current_task_cx_ptr = match current_task_option {
+
+    let current_task_cx_ptr = match current_task() {
         None => PROCESSOR.lock().get_idle_task_cx_ptr(),
         Some(current_task) => {
-            if current_task.status() == TaskStatus::Runnable {
-                GLOBAL_SCHEDULER
-                    .lock_irq_disabled()
-                    .enqueue(current_task.clone());
+            let mut task = current_task.inner_exclusive_access();
+
+            // FIXME: `task.ctx` should be put in a separate `UnsafeCell`, not as a part of
+            // `TaskInner`. Otherwise, it violates the semantics of `SpinLock` and Rust's memory
+            // model, which requires that mutable references must be exclusive.
+            let cx_ptr = &mut task.ctx as *mut TaskContext;
+
+            debug_assert_ne!(task.task_status, TaskStatus::Sleeping);
+            if task.task_status == TaskStatus::Runnable {
+                drop(task);
+                GLOBAL_SCHEDULER.lock_irq_disabled().enqueue(current_task);
+            } else if task.task_status == TaskStatus::Sleepy {
+                task.task_status = TaskStatus::Sleeping;
             }
-            &mut current_task.inner_exclusive_access().ctx as *mut TaskContext
+
+            cx_ptr
         }
     };
-    // change the current task to the next task
+    let next_task_cx_ptr = &next_task.inner_ctx() as *const TaskContext;
+
     // change the current task to the next task
     PROCESSOR.lock().current = Some(next_task.clone());
     unsafe {
         context_switch(current_task_cx_ptr, next_task_cx_ptr);
diff --git a/framework/aster-frame/src/task/task.rs b/framework/aster-frame/src/task/task.rs
index 660d2513d..1171858ce 100644
--- a/framework/aster-frame/src/task/task.rs
+++ b/framework/aster-frame/src/task/task.rs
@@ -190,7 +190,9 @@ impl Task {
 pub enum TaskStatus {
     /// The task is runnable.
     Runnable,
-    /// The task is sleeping.
+    /// The task is running in the foreground but will sleep when it goes to the background.
+    Sleepy,
+    /// The task is sleeping in the background.
     Sleeping,
     /// The task has exited.
     Exited,
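
A note on the protocol above (not part of the patch): the wake event is now recorded in the
`has_woken` flag, which the waker sets with `swap` and the waiter consumes in `do_wait`, so a
wake-up that races with the waiter going to sleep can no longer be lost. The sketch below is a
minimal user-space analogue of this waiter/waker split, built on `std::thread::park`/`unpark`
instead of the kernel scheduler and `TaskStatus`. Only the `Waiter`/`Waker` names and the
`new_pair` shape mirror the patch; everything else is an illustrative assumption, not aster-frame
API.

// Minimal sketch of the waiter/waker hand-off, assuming `std` is available.
use std::sync::{
    atomic::{AtomicBool, Ordering},
    Arc,
};
use std::thread;

struct Waker {
    has_woken: AtomicBool,
    thread: thread::Thread,
}

struct Waiter {
    waker: Arc<Waker>,
}

impl Waiter {
    // Mirrors `Waiter::new_pair` in the patch: the waiter stays on the creating
    // thread, while the `Arc<Waker>` may be handed to other threads.
    fn new_pair() -> (Self, Arc<Waker>) {
        let waker = Arc::new(Waker {
            has_woken: AtomicBool::new(false),
            thread: thread::current(),
        });
        let waiter = Self {
            waker: waker.clone(),
        };
        (waiter, waker)
    }

    // Returns immediately if a wake-up has arrived since the last call; otherwise blocks.
    fn wait(&self) {
        while !self.waker.has_woken.load(Ordering::Acquire) {
            // `park` may return spuriously, so the flag is the single source of
            // truth, just like `has_woken` in the patch.
            thread::park();
        }
        // Consume the wake event so that the next call to `wait` blocks again.
        self.waker.has_woken.store(false, Ordering::Release);
    }
}

impl Waker {
    // Returns `true` only for the call that actually delivers the wake event.
    fn wake_up(&self) -> bool {
        if self.has_woken.swap(true, Ordering::AcqRel) {
            return false;
        }
        self.thread.unpark();
        true
    }
}

fn main() {
    let (waiter, waker) = Waiter::new_pair();
    let handle = thread::spawn(move || {
        // The flag is set before unparking, so this wake-up cannot be lost even
        // if it happens before the main thread reaches `wait`.
        waker.wake_up();
    });
    waiter.wait();
    handle.join().unwrap();
}

In the patch itself, the intermediate `Sleepy` status plays the role that `park` plays here: the
task announces its intent to sleep before it is actually taken off the run path, so a concurrent
`wake_up` can still flip it back to `Runnable` before or after `switch_to_task` commits the sleep.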