Use UnsafeCell to store TaskContext

Author:    Ruihan Li
Date:      2024-05-09 01:16:54 +08:00
Committer: Tate, Hongliang Tian

parent 5189f889a3
commit a215cb54d9

2 changed files with 25 additions and 19 deletions


@@ -97,12 +97,9 @@ fn switch_to_task(next_task: Arc<Task>) {
     let current_task_cx_ptr = match current_task() {
         None => PROCESSOR.lock().get_idle_task_cx_ptr(),
         Some(current_task) => {
-            let mut task = current_task.inner_exclusive_access();
-            // FIXME: `task.ctx` should be put in a separate `UnsafeCell`, not as a part of
-            // `TaskInner`. Otherwise, it violates the sematics of `SpinLock` and Rust's memory
-            // model which requires that mutable references must be exclusive.
-            let cx_ptr = &mut task.ctx as *mut TaskContext;
+            let cx_ptr = current_task.ctx().get();
+
+            let mut task = current_task.inner_exclusive_access();
 
             debug_assert_ne!(task.task_status, TaskStatus::Sleeping);
             if task.task_status == TaskStatus::Runnable {
@@ -116,10 +113,15 @@ fn switch_to_task(next_task: Arc<Task>) {
         }
     };
-    let next_task_cx_ptr = &next_task.inner_ctx() as *const TaskContext;
+    let next_task_cx_ptr = next_task.ctx().get().cast_const();
+
     // change the current task to the next task
     PROCESSOR.lock().current = Some(next_task.clone());
 
+    // SAFETY:
+    // 1. `ctx` is only used in `schedule()`. We have exclusive access to both the current task
+    //    context and the next task context.
+    // 2. The next task context is a valid task context.
     unsafe {
         context_switch(current_task_cx_ptr, next_task_cx_ptr);
     }
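
The key trick in this hunk is that `UnsafeCell::get` turns `&UnsafeCell<T>` into `*mut T` without ever materializing a `&mut T`. The old code created a `&mut TaskContext` through the `SpinLock` guard and kept a pointer derived from it alive after the guard was released, which is exactly the violation the removed FIXME described. Below is a minimal, self-contained sketch of the new pattern; the `Ctx` type is hypothetical, not taken from this commit.

use core::cell::UnsafeCell;

struct Ctx {
    rip: usize,
}

fn main() {
    let cell = UnsafeCell::new(Ctx { rip: 0 });

    // `get` goes straight from `&UnsafeCell<Ctx>` to `*mut Ctx` without ever
    // creating a `&mut Ctx`, so no exclusive borrow is asserted to the compiler.
    let cx_ptr: *mut Ctx = cell.get();

    // The read-only alias for the next task's context, as in the hunk above.
    let next_ptr: *const Ctx = cell.get().cast_const();

    // SAFETY: nothing else accesses the cell here, and both pointers are valid
    // for the cell's lifetime.
    unsafe {
        (*cx_ptr).rip = 0xdead_beef;
        assert_eq!((*next_ptr).rip, 0xdead_beef);
    }
}

The dereference, and with it the whole exclusivity obligation, moves into the `unsafe` block, which is why the commit can document it with a single SAFETY comment at the `context_switch` call site.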


@@ -1,5 +1,7 @@
 // SPDX-License-Identifier: MPL-2.0
 
+use core::cell::UnsafeCell;
+
 use intrusive_collections::{intrusive_adapter, LinkedListAtomicLink};
 
 use super::{
@@ -116,21 +118,24 @@ pub struct Task {
     data: Box<dyn Any + Send + Sync>,
     user_space: Option<Arc<UserSpace>>,
     task_inner: SpinLock<TaskInner>,
-    exit_code: usize,
+    ctx: UnsafeCell<TaskContext>,
     /// kernel stack, note that the top is SyscallFrame/TrapFrame
     kstack: KernelStack,
     link: LinkedListAtomicLink,
     priority: Priority,
-    // TODO:: add multiprocessor support
+    // TODO: add multiprocessor support
     cpu_affinity: CpuSet,
 }
 
 // TaskAdapter struct is implemented for building relationships between doubly linked list and Task struct
 intrusive_adapter!(pub TaskAdapter = Arc<Task>: Task { link: LinkedListAtomicLink });
 
+// SAFETY: `UnsafeCell<TaskContext>` is not `Sync`. However, we only use it in `schedule()` where
+// we have exclusive access to the field.
+unsafe impl Sync for Task {}
+
 pub(crate) struct TaskInner {
     pub task_status: TaskStatus,
-    pub ctx: TaskContext,
 }
 
 impl Task {
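
Embedding the `UnsafeCell` has a knock-on effect: `UnsafeCell<T>` is `!Sync`, so without the manual `unsafe impl Sync`, `Task` would no longer be shareable across threads via `Arc<Task>`. A toy sketch of the same pattern follows; the `Shared` type is hypothetical, and the safety argument has the same shape as the commit's, namely that callers promise exclusive access to `ctx`.

use core::cell::UnsafeCell;
use std::{sync::Arc, thread};

struct Shared {
    ctx: UnsafeCell<u64>,
}

// SAFETY: every user promises that `ctx` is only accessed while it holds
// exclusive access to it.
unsafe impl Sync for Shared {}

fn main() {
    let s = Arc::new(Shared { ctx: UnsafeCell::new(0) });
    let s2 = Arc::clone(&s);
    // Without the `unsafe impl Sync`, this `spawn` would not compile:
    // `Arc<Shared>` is `Send` only if `Shared` is `Sync`.
    thread::spawn(move || {
        // SAFETY: in this toy program, only the spawned thread touches `ctx`.
        unsafe { *s2.ctx.get() = 1 };
    })
    .join()
    .unwrap();
}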
@@ -144,9 +149,8 @@ impl Task {
         self.task_inner.lock_irq_disabled()
     }
 
-    /// get inner
-    pub(crate) fn inner_ctx(&self) -> TaskContext {
-        self.task_inner.lock_irq_disabled().ctx
+    pub(super) fn ctx(&self) -> &UnsafeCell<TaskContext> {
+        &self.ctx
     }
 
     /// Yields execution so that another task may be scheduled.
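
There is a subtle win here (an observation about the old code, not spelled out in the commit): `inner_ctx()` returned `TaskContext` by value, so the caller's `&next_task.inner_ctx() as *const TaskContext` took the address of a temporary copy that was gone by the end of the statement. The new accessor copies nothing and is entirely safe; only the eventual dereference of `.get()` is unsafe. A sketch contrasting the two shapes, using hypothetical `Ctx`/`Task` mirror types:

use core::cell::UnsafeCell;

#[derive(Clone, Copy)]
struct Ctx {
    rip: usize,
}

struct Task {
    ctx: UnsafeCell<Ctx>,
}

impl Task {
    // Old shape: hands out a *copy*, so `&task.inner_ctx()` is the address of
    // a temporary that is freed at the end of the enclosing statement.
    fn inner_ctx(&self) -> Ctx {
        unsafe { *self.ctx.get() }
    }

    // New shape: a completely safe accessor. No data is copied and no borrow
    // of the inner `Ctx` is created; the caller takes on the unsafety only
    // when it dereferences `.get()`.
    fn ctx(&self) -> &UnsafeCell<Ctx> {
        &self.ctx
    }
}

fn main() {
    let task = Task { ctx: UnsafeCell::new(Ctx { rip: 1 }) };
    let _copy = task.inner_ctx(); // fine as a value, dangerous as `&`-then-cast
    let ptr = task.ctx().get();
    // SAFETY: `task` is uniquely owned here, so `ptr` has no aliases.
    unsafe { assert_eq!((*ptr).rip, 1) };
}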
@@ -273,32 +277,32 @@ impl TaskOptions {
             current_task.func.call(());
             current_task.exit();
         }
 
-        let result = Task {
+        let mut result = Task {
             func: self.func.unwrap(),
             data: self.data.unwrap(),
             user_space: self.user_space,
             task_inner: SpinLock::new(TaskInner {
                 task_status: TaskStatus::Runnable,
-                ctx: TaskContext::default(),
             }),
-            exit_code: 0,
+            ctx: UnsafeCell::new(TaskContext::default()),
             kstack: KernelStack::new_with_guard_page()?,
             link: LinkedListAtomicLink::new(),
             priority: self.priority,
             cpu_affinity: self.cpu_affinity,
         };
 
-        result.task_inner.lock().task_status = TaskStatus::Runnable;
-        result.task_inner.lock().ctx.rip = kernel_task_entry as usize;
+        let ctx = result.ctx.get_mut();
+        ctx.rip = kernel_task_entry as usize;
         // We should reserve space for the return address in the stack, otherwise
         // we will write across the page boundary due to the implementation of
         // the context switch.
+        //
         // According to the System V AMD64 ABI, the stack pointer should be aligned
         // to at least 16 bytes. And a larger alignment is needed if larger arguments
         // are passed to the function. The `kernel_task_entry` function does not
         // have any arguments, so we only need to align the stack pointer to 16 bytes.
-        result.task_inner.lock().ctx.regs.rsp =
-            (crate::vm::paddr_to_vaddr(result.kstack.end_paddr() - 16)) as u64;
+        ctx.regs.rsp = (crate::vm::paddr_to_vaddr(result.kstack.end_paddr() - 16)) as u64;
 
         Ok(Arc::new(result))
     }
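
The constructor stays free of `unsafe` because, until the `Task` is wrapped in `Arc`, it is uniquely owned: `UnsafeCell::get_mut` takes `&mut self` (hence the `let mut result` change) and hands back a plain `&mut TaskContext`. A minimal sketch of that window, assuming nothing beyond `core` and `std`:

use core::cell::UnsafeCell;
use std::sync::Arc;

fn main() {
    let mut cell = UnsafeCell::new(0u64);

    // While the cell is still uniquely owned, `get_mut` returns an ordinary
    // `&mut u64`; initialization needs no unsafe code at all. This is exactly
    // the window the constructor exploits before wrapping the task in `Arc`.
    *cell.get_mut() = 42;

    // Once shared, only `get()` (a raw pointer) remains available.
    let shared = Arc::new(cell);
    // SAFETY: no other pointer to the data is live here.
    unsafe { assert_eq!(*shared.get(), 42) };
}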