Add support for Rcu to store an Either

This commit is contained in:
Ruihan Li 2025-04-11 10:11:08 +08:00 committed by Tate, Hongliang Tian
parent a1f81df263
commit de69fd6c31
5 changed files with 233 additions and 5 deletions

View File

@ -43,7 +43,7 @@ pub mod task;
pub mod timer;
pub mod trap;
pub mod user;
pub(crate) mod util;
pub mod util;
use core::sync::atomic::{AtomicBool, Ordering};

View File

@ -0,0 +1,162 @@
// SPDX-License-Identifier: MPL-2.0
use core::{marker::PhantomData, ptr::NonNull};
use super::NonNullPtr;
use crate::util::Either;
// If both `L` and `R` have at least one alignment bit (i.e., their alignments are at least 2), we
// can use the alignment bit to indicate whether a pointer is `L` or `R`, so it's possible to
// implement `NonNullPtr` for `Either<L, R>`.
unsafe impl<L: NonNullPtr, R: NonNullPtr> NonNullPtr for Either<L, R> {
type Target = PhantomData<Self>;
type Ref<'a>
= Either<L::Ref<'a>, R::Ref<'a>>
where
Self: 'a;
const ALIGN_BITS: u32 = min(L::ALIGN_BITS, R::ALIGN_BITS)
.checked_sub(1)
.expect("`L` and `R` alignments should be at least 2 to pack `Either` into one pointer");
fn into_raw(self) -> NonNull<Self::Target> {
match self {
Self::Left(left) => left.into_raw().cast(),
Self::Right(right) => right
.into_raw()
.map_addr(|addr| addr | (1 << Self::ALIGN_BITS))
.cast(),
}
}
unsafe fn from_raw(ptr: NonNull<Self::Target>) -> Self {
// SAFETY: The caller ensures that the pointer comes from `Self::into_raw`, which
// guarantees that `real_ptr` is a non-null pointer.
let (is_right, real_ptr) = unsafe { remove_bits(ptr, 1 << Self::ALIGN_BITS) };
if is_right == 0 {
// SAFETY: `Self::into_raw` guarantees that `real_ptr` comes from `L::into_raw`. Other
// safety requirements are upheld by the caller.
Either::Left(unsafe { L::from_raw(real_ptr.cast()) })
} else {
// SAFETY: `Self::into_raw` guarantees that `real_ptr` comes from `R::into_raw`. Other
// safety requirements are upheld by the caller.
Either::Right(unsafe { R::from_raw(real_ptr.cast()) })
}
}
unsafe fn raw_as_ref<'a>(raw: NonNull<Self::Target>) -> Self::Ref<'a> {
// SAFETY: The caller ensures that the pointer comes from `Self::into_raw`, which
// guarantees that `real_ptr` is a non-null pointer.
let (is_right, real_ptr) = unsafe { remove_bits(raw, 1 << Self::ALIGN_BITS) };
if is_right == 0 {
// SAFETY: `Self::into_raw` guarantees that `real_ptr` comes from `L::into_raw`. Other
// safety requirements are upheld by the caller.
Either::Left(unsafe { L::raw_as_ref(real_ptr.cast()) })
} else {
// SAFETY: `Self::into_raw` guarantees that `real_ptr` comes from `R::into_raw`. Other
// safety requirements are upheld by the caller.
Either::Right(unsafe { R::raw_as_ref(real_ptr.cast()) })
}
}
fn ref_as_raw(ptr_ref: Self::Ref<'_>) -> NonNull<Self::Target> {
match ptr_ref {
Either::Left(left) => L::ref_as_raw(left).cast(),
Either::Right(right) => R::ref_as_raw(right)
.map_addr(|addr| addr | (1 << Self::ALIGN_BITS))
.cast(),
}
}
}
// A `min` implementation for use in constant evaluation.
const fn min(a: u32, b: u32) -> u32 {
if a < b {
a
} else {
b
}
}
/// # Safety
///
/// The caller must ensure that removing the bits from the non-null pointer will result in another
/// non-null pointer.
unsafe fn remove_bits<T>(ptr: NonNull<T>, bits: usize) -> (usize, NonNull<T>) {
use core::num::NonZeroUsize;
let removed_bits = ptr.addr().get() & bits;
let result_ptr = ptr.map_addr(|addr|
// SAFETY: The safety is upheld by the caller.
unsafe { NonZeroUsize::new_unchecked(addr.get() & !bits) });
(removed_bits, result_ptr)
}
#[cfg(ktest)]
mod test {
use alloc::{boxed::Box, sync::Arc};
use super::*;
use crate::{prelude::ktest, sync::RcuOption};
type Either32 = Either<Arc<u32>, Box<u32>>;
type Either16 = Either<Arc<u32>, Box<u16>>;
#[ktest]
fn alignment() {
assert_eq!(<Either32 as NonNullPtr>::ALIGN_BITS, 1);
assert_eq!(<Either16 as NonNullPtr>::ALIGN_BITS, 0);
}
#[ktest]
fn left_pointer() {
let val: Either16 = Either::Left(Arc::new(123));
let ptr = NonNullPtr::into_raw(val);
assert_eq!(ptr.addr().get() & 1, 0);
let ref_ = unsafe { <Either16 as NonNullPtr>::raw_as_ref(ptr) };
assert!(matches!(ref_, Either::Left(ref r) if ***r == 123));
let ptr2 = <Either16 as NonNullPtr>::ref_as_raw(ref_);
assert_eq!(ptr, ptr2);
let val = unsafe { <Either16 as NonNullPtr>::from_raw(ptr) };
assert!(matches!(val, Either::Left(ref r) if **r == 123));
drop(val);
}
#[ktest]
fn right_pointer() {
let val: Either16 = Either::Right(Box::new(456));
let ptr = NonNullPtr::into_raw(val);
assert_eq!(ptr.addr().get() & 1, 1);
let ref_ = unsafe { <Either16 as NonNullPtr>::raw_as_ref(ptr) };
assert!(matches!(ref_, Either::Right(ref r) if ***r == 456));
let ptr2 = <Either16 as NonNullPtr>::ref_as_raw(ref_);
assert_eq!(ptr, ptr2);
let val = unsafe { <Either16 as NonNullPtr>::from_raw(ptr) };
assert!(matches!(val, Either::Right(ref r) if **r == 456));
drop(val);
}
#[ktest]
fn rcu_store_load() {
let rcu: RcuOption<Either32> = RcuOption::new_none();
assert!(rcu.read().get().is_none());
rcu.update(Some(Either::Left(Arc::new(888))));
assert!(matches!(rcu.read().get().unwrap(), Either::Left(r) if **r == 888));
rcu.update(Some(Either::Right(Box::new(999))));
assert!(matches!(rcu.read().get().unwrap(), Either::Right(r) if **r == 999));
}
}

View File

@ -3,6 +3,8 @@
//! This module provides a trait and some auxiliary types to help abstract and
//! work with non-null pointers.
mod either;
use alloc::sync::Weak;
use core::{marker::PhantomData, mem::ManuallyDrop, ops::Deref, ptr::NonNull};
@ -27,14 +29,21 @@ pub unsafe trait NonNullPtr: Send + 'static {
type Target;
/// A type that behaves just like a shared reference to the `NonNullPtr`.
type Ref<'a>: Deref<Target = Self>
type Ref<'a>
where
Self: 'a;
/// The power of two of the pointer alignment.
const ALIGN_BITS: u32;
/// Converts to a raw pointer.
///
/// Each call to `into_raw` must be paired with a call to `from_raw`
/// in order to avoid memory leakage.
///
/// The lower [`Self::ALIGN_BITS`] of the raw pointer is guaranteed to
/// be zero. In other words, the pointer is guaranteed to be aligned to
/// `1 << Self::ALIGN_BITS`.
fn into_raw(self) -> NonNull<Self::Target>;
/// Converts back from a raw pointer.
@ -101,6 +110,8 @@ unsafe impl<T: Send + 'static> NonNullPtr for Box<T> {
where
Self: 'a;
const ALIGN_BITS: u32 = core::mem::align_of::<T>().trailing_zeros();
fn into_raw(self) -> NonNull<Self::Target> {
let ptr = Box::into_raw(self);
@ -161,6 +172,8 @@ unsafe impl<T: Send + Sync + 'static> NonNullPtr for Arc<T> {
where
Self: 'a;
const ALIGN_BITS: u32 = core::mem::align_of::<T>().trailing_zeros();
fn into_raw(self) -> NonNull<Self::Target> {
let ptr = Arc::into_raw(self).cast_mut();
@ -213,6 +226,10 @@ unsafe impl<T: Send + Sync + 'static> NonNullPtr for Weak<T> {
where
Self: 'a;
// The alignment of `Weak<T>` is 1 instead of `align_of::<T>()`.
// This is because `Weak::new()` uses a dangling pointer that is _not_ aligned.
const ALIGN_BITS: u32 = 0;
fn into_raw(self) -> NonNull<Self::Target> {
let ptr = Weak::into_raw(self).cast_mut();

44
ostd/src/util/either.rs Normal file
View File

@ -0,0 +1,44 @@
// SPDX-License-Identifier: MPL-2.0
/// A type containing either a [`Left`] value `L` or a [`Right`] value `R`.
///
/// [`Left`]: Self::Left
/// [`Right`]: Self::Right
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum Either<L, R> {
/// Contains the left value
Left(L),
/// Contains the right value
Right(R),
}
impl<L, R> Either<L, R> {
/// Converts to the left value, if any.
pub fn left(self) -> Option<L> {
match self {
Self::Left(left) => Some(left),
Self::Right(_) => None,
}
}
/// Converts to the right value, if any.
pub fn right(self) -> Option<R> {
match self {
Self::Left(_) => None,
Self::Right(right) => Some(right),
}
}
/// Returns true if the left value is present.
pub fn is_left(&self) -> bool {
matches!(self, Self::Left(_))
}
/// Returns true if the right value is present.
pub fn is_right(&self) -> bool {
matches!(self, Self::Right(_))
}
// TODO: Add other utility methods (e.g. `as_ref`, `as_mut`) as needed.
// As a good reference, check what methods `Result` provides.
}

View File

@ -1,6 +1,11 @@
// SPDX-License-Identifier: MPL-2.0
//! Utility types and methods.
mod either;
mod macros;
pub mod marker;
pub mod ops;
pub mod range_alloc;
pub(crate) mod marker;
pub(crate) mod ops;
pub(crate) mod range_alloc;
pub use either::Either;