From de69fd6c313d5a3da887724e9b985304ab961595 Mon Sep 17 00:00:00 2001 From: Ruihan Li Date: Fri, 11 Apr 2025 10:11:08 +0800 Subject: [PATCH] Add support for `Rcu` to store an `Either` --- ostd/src/lib.rs | 2 +- ostd/src/sync/rcu/non_null/either.rs | 162 ++++++++++++++++++ .../sync/rcu/{non_null.rs => non_null/mod.rs} | 19 +- ostd/src/util/either.rs | 44 +++++ ostd/src/util/mod.rs | 11 +- 5 files changed, 233 insertions(+), 5 deletions(-) create mode 100644 ostd/src/sync/rcu/non_null/either.rs rename ostd/src/sync/rcu/{non_null.rs => non_null/mod.rs} (92%) create mode 100644 ostd/src/util/either.rs diff --git a/ostd/src/lib.rs b/ostd/src/lib.rs index 8727d1ef..a4affc65 100644 --- a/ostd/src/lib.rs +++ b/ostd/src/lib.rs @@ -43,7 +43,7 @@ pub mod task; pub mod timer; pub mod trap; pub mod user; -pub(crate) mod util; +pub mod util; use core::sync::atomic::{AtomicBool, Ordering}; diff --git a/ostd/src/sync/rcu/non_null/either.rs b/ostd/src/sync/rcu/non_null/either.rs new file mode 100644 index 00000000..bb0326cd --- /dev/null +++ b/ostd/src/sync/rcu/non_null/either.rs @@ -0,0 +1,162 @@ +// SPDX-License-Identifier: MPL-2.0 + +use core::{marker::PhantomData, ptr::NonNull}; + +use super::NonNullPtr; +use crate::util::Either; + +// If both `L` and `R` have at least one alignment bit (i.e., their alignments are at least 2), we +// can use the alignment bit to indicate whether a pointer is `L` or `R`, so it's possible to +// implement `NonNullPtr` for `Either`. +unsafe impl NonNullPtr for Either { + type Target = PhantomData; + + type Ref<'a> + = Either, R::Ref<'a>> + where + Self: 'a; + + const ALIGN_BITS: u32 = min(L::ALIGN_BITS, R::ALIGN_BITS) + .checked_sub(1) + .expect("`L` and `R` alignments should be at least 2 to pack `Either` into one pointer"); + + fn into_raw(self) -> NonNull { + match self { + Self::Left(left) => left.into_raw().cast(), + Self::Right(right) => right + .into_raw() + .map_addr(|addr| addr | (1 << Self::ALIGN_BITS)) + .cast(), + } + } + + unsafe fn from_raw(ptr: NonNull) -> Self { + // SAFETY: The caller ensures that the pointer comes from `Self::into_raw`, which + // guarantees that `real_ptr` is a non-null pointer. + let (is_right, real_ptr) = unsafe { remove_bits(ptr, 1 << Self::ALIGN_BITS) }; + + if is_right == 0 { + // SAFETY: `Self::into_raw` guarantees that `real_ptr` comes from `L::into_raw`. Other + // safety requirements are upheld by the caller. + Either::Left(unsafe { L::from_raw(real_ptr.cast()) }) + } else { + // SAFETY: `Self::into_raw` guarantees that `real_ptr` comes from `R::into_raw`. Other + // safety requirements are upheld by the caller. + Either::Right(unsafe { R::from_raw(real_ptr.cast()) }) + } + } + + unsafe fn raw_as_ref<'a>(raw: NonNull) -> Self::Ref<'a> { + // SAFETY: The caller ensures that the pointer comes from `Self::into_raw`, which + // guarantees that `real_ptr` is a non-null pointer. + let (is_right, real_ptr) = unsafe { remove_bits(raw, 1 << Self::ALIGN_BITS) }; + + if is_right == 0 { + // SAFETY: `Self::into_raw` guarantees that `real_ptr` comes from `L::into_raw`. Other + // safety requirements are upheld by the caller. + Either::Left(unsafe { L::raw_as_ref(real_ptr.cast()) }) + } else { + // SAFETY: `Self::into_raw` guarantees that `real_ptr` comes from `R::into_raw`. Other + // safety requirements are upheld by the caller. + Either::Right(unsafe { R::raw_as_ref(real_ptr.cast()) }) + } + } + + fn ref_as_raw(ptr_ref: Self::Ref<'_>) -> NonNull { + match ptr_ref { + Either::Left(left) => L::ref_as_raw(left).cast(), + Either::Right(right) => R::ref_as_raw(right) + .map_addr(|addr| addr | (1 << Self::ALIGN_BITS)) + .cast(), + } + } +} + +// A `min` implementation for use in constant evaluation. +const fn min(a: u32, b: u32) -> u32 { + if a < b { + a + } else { + b + } +} + +/// # Safety +/// +/// The caller must ensure that removing the bits from the non-null pointer will result in another +/// non-null pointer. +unsafe fn remove_bits(ptr: NonNull, bits: usize) -> (usize, NonNull) { + use core::num::NonZeroUsize; + + let removed_bits = ptr.addr().get() & bits; + let result_ptr = ptr.map_addr(|addr| + // SAFETY: The safety is upheld by the caller. + unsafe { NonZeroUsize::new_unchecked(addr.get() & !bits) }); + + (removed_bits, result_ptr) +} + +#[cfg(ktest)] +mod test { + use alloc::{boxed::Box, sync::Arc}; + + use super::*; + use crate::{prelude::ktest, sync::RcuOption}; + + type Either32 = Either, Box>; + type Either16 = Either, Box>; + + #[ktest] + fn alignment() { + assert_eq!(::ALIGN_BITS, 1); + assert_eq!(::ALIGN_BITS, 0); + } + + #[ktest] + fn left_pointer() { + let val: Either16 = Either::Left(Arc::new(123)); + + let ptr = NonNullPtr::into_raw(val); + assert_eq!(ptr.addr().get() & 1, 0); + + let ref_ = unsafe { ::raw_as_ref(ptr) }; + assert!(matches!(ref_, Either::Left(ref r) if ***r == 123)); + + let ptr2 = ::ref_as_raw(ref_); + assert_eq!(ptr, ptr2); + + let val = unsafe { ::from_raw(ptr) }; + assert!(matches!(val, Either::Left(ref r) if **r == 123)); + drop(val); + } + + #[ktest] + fn right_pointer() { + let val: Either16 = Either::Right(Box::new(456)); + + let ptr = NonNullPtr::into_raw(val); + assert_eq!(ptr.addr().get() & 1, 1); + + let ref_ = unsafe { ::raw_as_ref(ptr) }; + assert!(matches!(ref_, Either::Right(ref r) if ***r == 456)); + + let ptr2 = ::ref_as_raw(ref_); + assert_eq!(ptr, ptr2); + + let val = unsafe { ::from_raw(ptr) }; + assert!(matches!(val, Either::Right(ref r) if **r == 456)); + drop(val); + } + + #[ktest] + fn rcu_store_load() { + let rcu: RcuOption = RcuOption::new_none(); + assert!(rcu.read().get().is_none()); + + rcu.update(Some(Either::Left(Arc::new(888)))); + assert!(matches!(rcu.read().get().unwrap(), Either::Left(r) if **r == 888)); + + rcu.update(Some(Either::Right(Box::new(999)))); + assert!(matches!(rcu.read().get().unwrap(), Either::Right(r) if **r == 999)); + } +} diff --git a/ostd/src/sync/rcu/non_null.rs b/ostd/src/sync/rcu/non_null/mod.rs similarity index 92% rename from ostd/src/sync/rcu/non_null.rs rename to ostd/src/sync/rcu/non_null/mod.rs index 22a2c04b..7ed0d773 100644 --- a/ostd/src/sync/rcu/non_null.rs +++ b/ostd/src/sync/rcu/non_null/mod.rs @@ -3,6 +3,8 @@ //! This module provides a trait and some auxiliary types to help abstract and //! work with non-null pointers. +mod either; + use alloc::sync::Weak; use core::{marker::PhantomData, mem::ManuallyDrop, ops::Deref, ptr::NonNull}; @@ -27,14 +29,21 @@ pub unsafe trait NonNullPtr: Send + 'static { type Target; /// A type that behaves just like a shared reference to the `NonNullPtr`. - type Ref<'a>: Deref + type Ref<'a> where Self: 'a; + /// The power of two of the pointer alignment. + const ALIGN_BITS: u32; + /// Converts to a raw pointer. /// /// Each call to `into_raw` must be paired with a call to `from_raw` /// in order to avoid memory leakage. + /// + /// The lower [`Self::ALIGN_BITS`] of the raw pointer is guaranteed to + /// be zero. In other words, the pointer is guaranteed to be aligned to + /// `1 << Self::ALIGN_BITS`. fn into_raw(self) -> NonNull; /// Converts back from a raw pointer. @@ -101,6 +110,8 @@ unsafe impl NonNullPtr for Box { where Self: 'a; + const ALIGN_BITS: u32 = core::mem::align_of::().trailing_zeros(); + fn into_raw(self) -> NonNull { let ptr = Box::into_raw(self); @@ -161,6 +172,8 @@ unsafe impl NonNullPtr for Arc { where Self: 'a; + const ALIGN_BITS: u32 = core::mem::align_of::().trailing_zeros(); + fn into_raw(self) -> NonNull { let ptr = Arc::into_raw(self).cast_mut(); @@ -213,6 +226,10 @@ unsafe impl NonNullPtr for Weak { where Self: 'a; + // The alignment of `Weak` is 1 instead of `align_of::()`. + // This is because `Weak::new()` uses a dangling pointer that is _not_ aligned. + const ALIGN_BITS: u32 = 0; + fn into_raw(self) -> NonNull { let ptr = Weak::into_raw(self).cast_mut(); diff --git a/ostd/src/util/either.rs b/ostd/src/util/either.rs new file mode 100644 index 00000000..7afd0d7e --- /dev/null +++ b/ostd/src/util/either.rs @@ -0,0 +1,44 @@ +// SPDX-License-Identifier: MPL-2.0 + +/// A type containing either a [`Left`] value `L` or a [`Right`] value `R`. +/// +/// [`Left`]: Self::Left +/// [`Right`]: Self::Right +#[derive(Clone, Copy, PartialEq, Eq, Debug)] +pub enum Either { + /// Contains the left value + Left(L), + /// Contains the right value + Right(R), +} + +impl Either { + /// Converts to the left value, if any. + pub fn left(self) -> Option { + match self { + Self::Left(left) => Some(left), + Self::Right(_) => None, + } + } + + /// Converts to the right value, if any. + pub fn right(self) -> Option { + match self { + Self::Left(_) => None, + Self::Right(right) => Some(right), + } + } + + /// Returns true if the left value is present. + pub fn is_left(&self) -> bool { + matches!(self, Self::Left(_)) + } + + /// Returns true if the right value is present. + pub fn is_right(&self) -> bool { + matches!(self, Self::Right(_)) + } + + // TODO: Add other utility methods (e.g. `as_ref`, `as_mut`) as needed. + // As a good reference, check what methods `Result` provides. +} diff --git a/ostd/src/util/mod.rs b/ostd/src/util/mod.rs index bbda5005..bea61104 100644 --- a/ostd/src/util/mod.rs +++ b/ostd/src/util/mod.rs @@ -1,6 +1,11 @@ // SPDX-License-Identifier: MPL-2.0 +//! Utility types and methods. + +mod either; mod macros; -pub mod marker; -pub mod ops; -pub mod range_alloc; +pub(crate) mod marker; +pub(crate) mod ops; +pub(crate) mod range_alloc; + +pub use either::Either;