Files
asterinas/kernel/libs/aster-util/src/coeff.rs
2024-08-23 23:37:50 +08:00

142 lines
4.8 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// SPDX-License-Identifier: MPL-2.0
//! This module provides an abstraction, `Coeff`, for efficient and accurate multiplication
//! of an integer by a fraction.
use core::ops::Mul;
/// A `Coeff` is used to multiply an unsigned integer by a fraction.
/// It achieves accurate and efficient calculation while avoiding numeric overflow.
///
/// # Example
///
/// Let's say we want to multiply a fraction (23456 / 56789) with the target integer `a`,
/// which will be no larger than `1_000_000_000`. The following code snippet gets an
/// accurate result:
///
/// ```ignore
/// // `a` is a `u64` that never exceeds `1_000_000_000`.
/// let coeff = Coeff::new(23456, 56789, 1_000_000_000);
/// let result = coeff * a;
/// ```
///
/// # How it works
///
/// `Coeff` is used in the calculation of a fraction value multiplied by an integer.
/// Here is a simple example of such a calculation:
///
/// ```text
/// result = (a / b) * c
/// ```
///
/// In this equation, `a`, `b`, `c` and `result` are all integers. To acquire a more precise
/// result, we generally calculate `a * c` first and then divide the product by `b`.
/// However, this simple calculation has two complications:
/// - The calculation of `a * c` may overflow if the operands are too large.
/// - Integer division is much more expensive than integer multiplication, which can easily
///   create performance bottlenecks.
///
/// `Coeff` is implemented to address these two issues. It can be used to replace the
/// fraction in this calculation. For example, a `Coeff` generated from (a / b) turns the
/// calculation above into:
///
/// ```text
/// coeff * c ~= (a / b) * c
/// ```
///
/// In principle, `Coeff` turns the multiplication and division into a combination of a
/// multiplication and a bit shift. When creating a `Coeff`, it needs to know the numerator
/// and denominator of the represented fraction and the maximum multiplier it will be
/// multiplied by. Then, a `mult` and a `shift` are chosen so that `coeff * c` is computed
/// as `(mult * c) >> shift`, ensuring that:
///
/// ```text
/// (mult * c) >> shift ~= (a / b) * c
/// ```
///
/// and that `mult * c` does not overflow (i.e., it stays within `u64::MAX`) as long as `c`
/// does not exceed the maximum multiplier.
///
/// This is how `Coeff` achieves accuracy and efficiency at the same time.
#[derive(Debug, Copy, Clone)]
pub struct Coeff {
    /// The fraction scaled by `2^shift`, rounded to the nearest integer.
    mult: u32,
    /// The number of bits the product `mult * multiplier` is shifted right by.
    shift: u32,
    /// The largest multiplier this coeff may be multiplied by.
    max_multiplier: u64,
}

impl Coeff {
    /// Creates a new `Coeff`, which is essentially equivalent to `numerator / denominator`
    /// when being multiplied by an integer.
    ///
    /// Users should make sure that the multiplied integer is never larger than
    /// `max_multiplier`; this is only checked in debug builds.
    ///
    /// # Panics
    ///
    /// Panics if `denominator` is zero. It also panics if the fraction is too large to be
    /// represented by a `mult` that keeps `mult * max_multiplier` within a `u64`, even with
    /// `shift == 0`.
    pub fn new(numerator: u64, denominator: u64, max_multiplier: u64) -> Self {
        assert_ne!(denominator, 0, "the denominator of a `Coeff` must not be zero");
        // Too large `max_multiplier` will make the generated coeff imprecise.
        debug_assert!(max_multiplier < (1 << 40));

        // `shift_acc` is the maximum number of bits `mult` can have so that
        // `mult * max_multiplier` cannot overflow a `u64`. Every significant bit of
        // `max_multiplier` above its lowest 32 bits costs `mult` one bit.
        let high_bits = u64::BITS - (max_multiplier >> 32).leading_zeros();
        let shift_acc = 32 - high_bits;

        // Try `shift` from 32 down to 0 (inclusive) and take the first, i.e. most precise,
        // one whose rounded `mult` fits in `shift_acc` bits. The arithmetic is done in
        // `u128` so that `numerator << shift` never silently drops the high bits of a large
        // `numerator`.
        let wide_denominator = u128::from(denominator);
        let (mult, shift) = (0..=32_u32)
            .rev()
            .map(|shift| {
                let mult =
                    ((u128::from(numerator) << shift) + wide_denominator / 2) / wide_denominator;
                (mult, shift)
            })
            .find(|&(mult, _)| mult >> shift_acc == 0)
            .expect("the fraction is too large to be represented by a `Coeff`");

        Self {
            // Cannot truncate: `mult < 2^shift_acc <= 2^32`.
            mult: mult as u32,
            shift,
            max_multiplier,
        }
    }

    /// Returns the `mult` of the Coeff.
    /// Only used for the VdsoData and will be removed in the future.
    pub fn mult(&self) -> u32 {
        self.mult
    }

    /// Returns the `shift` of the Coeff.
    /// Only used for the VdsoData and will be removed in the future.
    pub fn shift(&self) -> u32 {
        self.shift
    }
}
impl Mul<u64> for Coeff {
type Output = u64;
fn mul(self, rhs: u64) -> Self::Output {
debug_assert!(rhs <= self.max_multiplier);
(rhs * self.mult as u64) >> self.shift
}
}
impl Mul<u32> for Coeff {
type Output = u32;
fn mul(self, rhs: u32) -> Self::Output {
debug_assert!(rhs as u64 <= self.max_multiplier);
((rhs as u64 * self.mult as u64) >> self.shift) as u32
}
}
#[cfg(ktest)]
mod test {
    use ostd::prelude::*;

    use super::*;

    /// Checks `Coeff` multiplication against exact integer arithmetic at zero, at a small
    /// multiplier and at the maximum declared multiplier.
    #[ktest]
    fn calculation() {
        let coeff = Coeff::new(23456, 56789, 1_000_000_000);
        for value in [0_u64, 100, 1_000_000_000] {
            assert_eq!(coeff * value, value * 23456 / 56789);
        }
    }
}