273 lines
6.6 KiB
Rust

extern crate num_integer;
extern crate num_traits;
use num_integer::Roots;
use num_traits::checked_pow;
use num_traits::{AsPrimitive, PrimInt, Signed};
use std::f64::MANTISSA_DIGITS;
use std::fmt::Debug;
use std::mem;
/// Bundle of bounds that every integer type under test must satisfy.
///
/// `Roots` supplies the functions being tested, `PrimInt` the arithmetic used
/// to verify their results, `Debug` lets values appear in assertion output,
/// and `AsPrimitive<f64>` provides the cast to `f64` used by `mantissa_max`.
trait TestInteger: Roots + PrimInt + Debug + AsPrimitive<f64> + 'static {}

// Blanket implementation: any type meeting the bounds is a `TestInteger`.
impl<T: Roots + PrimInt + Debug + AsPrimitive<f64> + 'static> TestInteger for T {}
/// Check that each root is correct.
///
/// For every `x` in `v`, the result `r = x.nth_root(n)` must satisfy:
///
/// * if `x` is non-negative, `rⁿ ≤ x < (r+1)ⁿ`;
/// * if `x` is negative, `r < 0` and `(r-1)ⁿ < x ≤ rⁿ`.
///
/// The outer bound (`(r+1)ⁿ` or `(r-1)ⁿ`) is only checked when it is
/// representable in `T`; when `checked_pow` reports overflow it is skipped.
/// For `n == 2` and `n == 3` the result is additionally compared against
/// `sqrt` and `cbrt` respectively.
///
/// # Panics
///
/// Panics on the first violated check, naming the offending value, the
/// degree and the root that was returned.
fn check<T>(v: &[T], n: u32)
where
    T: TestInteger,
{
    for &x in v {
        let root = x.nth_root(n);

        // The specialised entry points must agree with the generic one.
        match n {
            2 => {
                let sqrt = x.sqrt();
                assert!(
                    root == sqrt,
                    "nth_root({:?}, 2) = {:?} but sqrt = {:?}",
                    x,
                    root,
                    sqrt
                );
            }
            3 => {
                let cbrt = x.cbrt();
                assert!(
                    root == cbrt,
                    "nth_root({:?}, 3) = {:?} but cbrt = {:?}",
                    x,
                    root,
                    cbrt
                );
            }
            _ => {}
        }

        if x >= T::zero() {
            assert!(
                root.pow(n) <= x,
                "nth_root({:?}, {}) = {:?} is too large",
                x,
                n,
                root
            );
            // Skipped when `(root + 1)ⁿ` does not fit in `T`.
            if let Some(upper) = checked_pow(root + T::one(), n as usize) {
                assert!(
                    x < upper,
                    "nth_root({:?}, {}) = {:?} is too small",
                    x,
                    n,
                    root
                );
            }
        } else {
            assert!(
                root < T::zero(),
                "nth_root({:?}, {}) = {:?} should be negative",
                x,
                n,
                root
            );
            assert!(
                x <= root.pow(n),
                "nth_root({:?}, {}) = {:?} is too negative",
                x,
                n,
                root
            );
            // Skipped when `(root - 1)ⁿ` does not fit in `T`.
            if let Some(lower) = checked_pow(root - T::one(), n as usize) {
                assert!(
                    lower < x,
                    "nth_root({:?}, {}) = {:?} is not negative enough",
                    x,
                    n,
                    root
                );
            }
        }
    }
}
/// Find the pair of values straddling the `f64` rounding boundary just below
/// `T::max_value()`, if `T` has more value bits than an `f64` mantissa.
///
/// Returns `Some((x, x + 1))` where `x` converts to a strictly smaller `f64`
/// than `x + 1` does, and `x + 1` and `x + 2` convert to the same `f64`
/// (both properties are asserted here). Returns `None` when every value of
/// `T` fits in the mantissa.
///
/// Important because the `std` implementations cast to `f64` to
/// get a close approximation of the roots.
fn mantissa_max<T>() -> Option<(T, T)>
where
    T: TestInteger,
{
    // Signed types spend one bit on the sign; unsigned types have `min == 0`.
    let sign_bits = if T::min_value().is_zero() { 0 } else { 1 };
    let value_bits = 8 * mem::size_of::<T>() - sign_bits;
    let mantissa_bits = MANTISSA_DIGITS as usize;

    if value_bits <= mantissa_bits {
        return None;
    }

    // The highest bit that no longer fits in the mantissa of values near the
    // maximum of `T`.
    let rounding_bit = T::one() << (value_bits - mantissa_bits - 1);
    let below = T::max_value() - rounding_bit;
    let at = below + T::one();
    let above = at + T::one();

    let below_f: f64 = below.as_();
    let at_f: f64 = at.as_();
    let above_f: f64 = above.as_();
    assert!(below_f < at_f);
    assert_eq!(at_f, above_f);

    Some((below, at))
}
/// Append every value from `start` up to and including `end` to `v`.
///
/// If `start` is not below `end`, only `start` is appended. The counter is
/// never advanced past `end`, so `end == T::max_value()` does not overflow.
fn extend<T>(v: &mut Vec<T>, start: T, end: T)
where
    T: TestInteger,
{
    let mut current = start;
    loop {
        v.push(current);
        if current >= end {
            break;
        }
        current = current + T::one();
    }
}
/// Append `start`, then each successive left shift by one bit (masked with
/// `mask`), stopping before the value that equals `end`.
///
/// `end` itself is never appended; nothing is appended if `start == end`.
fn extend_shl<T>(v: &mut Vec<T>, start: T, end: T, mask: T)
where
    T: TestInteger,
{
    let mut current = start;
    loop {
        if current == end {
            return;
        }
        v.push(current);
        current = (current << 1) & mask;
    }
}
/// Append `start`, then each successive right shift by one bit, stopping
/// before the value that equals `end`.
///
/// `end` itself is never appended; nothing is appended if `start == end`.
fn extend_shr<T>(v: &mut Vec<T>, start: T, end: T)
where
    T: TestInteger,
{
    let mut current = start;
    loop {
        if current == end {
            return;
        }
        v.push(current);
        current = current >> 1;
    }
}
/// Build the list of non-negative test inputs for `T`.
///
/// One-byte types get every value from zero to the maximum. Wider types get:
///
/// * the `i8::max_value() + 1` smallest values, starting at zero;
/// * the `i8::max_value() + 1` largest values, ending at `T::max_value()`;
/// * the pair from `mantissa_max`, when there is one;
/// * `T::max_value()` repeatedly shifted left (masked to stay non-negative)
///   and repeatedly shifted right, each until it reaches zero.
fn pos<T>() -> Vec<T>
where
    T: TestInteger,
    i8: AsPrimitive<T>,
{
    let zero = T::zero();
    let max = T::max_value();
    let mut values: Vec<T> = Vec::new();

    if mem::size_of::<T>() == 1 {
        // Small enough to test exhaustively.
        extend(&mut values, zero, max);
        return values;
    }

    let span: T = i8::max_value().as_();
    extend(&mut values, zero, span);
    extend(&mut values, max - span, max);

    if let Some((round_down, round_up)) = mantissa_max::<T>() {
        values.push(round_down);
        values.push(round_up);
    }

    // `!T::min_value()` clears the sign bit of signed types and is all ones
    // for unsigned types.
    extend_shl(&mut values, max, zero, !T::min_value());
    extend_shr(&mut values, max, zero);

    values
}
/// Build the list of non-positive test inputs for the signed type `T`.
///
/// One-byte types get every value from the minimum to zero. Wider types get:
///
/// * every value from `i8::min_value()` to zero;
/// * the `-i8::min_value() + 1` values starting at `T::min_value()`;
/// * the negated pair from `mantissa_max`, when there is one;
/// * `-1` repeatedly shifted left, stopping before `T::min_value()`;
/// * `T::min_value()` repeatedly shifted right, stopping before `-1`.
fn neg<T>() -> Vec<T>
where
    T: TestInteger + Signed,
    i8: AsPrimitive<T>,
{
    let zero = T::zero();
    let min = T::min_value();
    let minus_one = -T::one();
    let mut values: Vec<T> = Vec::new();

    if mem::size_of::<T>() <= 1 {
        // Small enough to test exhaustively.
        extend(&mut values, min, zero);
        return values;
    }

    let span: T = i8::min_value().as_();
    extend(&mut values, span, zero);
    extend(&mut values, min, min - span);

    if let Some((round_down, round_up)) = mantissa_max::<T>() {
        values.push(-round_down);
        values.push(-round_up);
    }

    extend_shl(&mut values, minus_one, min, !zero);
    extend_shr(&mut values, min, minus_one);

    values
}
/// Generate the root test suites for one signed/unsigned integer pair.
///
/// `test_roots!(I, U)` expands to two modules named after the types: `mod I`
/// tests the signed type `I` (including negative inputs for odd degrees) and
/// `mod U` tests the unsigned type `U`.
macro_rules! test_roots {
    ($I:ident, $U:ident) => {
        // Tests for the signed type.
        mod $I {
            use check;
            use neg;
            use num_integer::Roots;
            use pos;
            use std::mem;

            #[test]
            #[should_panic]
            fn zeroth_root() {
                (123 as $I).nth_root(0);
            }

            #[test]
            fn sqrt() {
                check(&pos::<$I>(), 2);
            }

            #[test]
            #[should_panic]
            fn sqrt_neg() {
                (-123 as $I).sqrt();
            }

            #[test]
            fn cbrt() {
                check(&pos::<$I>(), 3);
            }

            #[test]
            fn cbrt_neg() {
                check(&neg::<$I>(), 3);
            }

            #[test]
            fn nth_root() {
                // Number of value bits: the sign bit is excluded.
                let bits = 8 * mem::size_of::<$I>() as u32 - 1;
                let pos = pos::<$I>();
                for n in 4..bits {
                    check(&pos, n);
                }
            }

            #[test]
            fn nth_root_neg() {
                let bits = 8 * mem::size_of::<$I>() as u32 - 1;
                let neg = neg::<$I>();
                // Odd degrees 5, 7, ...; degree 3 is covered by `cbrt_neg`.
                for n in 2..bits / 2 {
                    check(&neg, 2 * n + 1);
                }
            }

            #[test]
            fn bit_size() {
                let bits = 8 * mem::size_of::<$I>() as u32 - 1;
                assert_eq!($I::max_value().nth_root(bits - 1), 2);
                assert_eq!($I::max_value().nth_root(bits), 1);
                assert_eq!($I::min_value().nth_root(bits), -2);
                assert_eq!(($I::min_value() + 1).nth_root(bits), -1);
            }
        }

        // Tests for the unsigned type.
        mod $U {
            use check;
            use num_integer::Roots;
            use pos;
            use std::mem;

            #[test]
            #[should_panic]
            fn zeroth_root() {
                (123 as $U).nth_root(0);
            }

            #[test]
            fn sqrt() {
                check(&pos::<$U>(), 2);
            }

            #[test]
            fn cbrt() {
                check(&pos::<$U>(), 3);
            }

            #[test]
            fn nth_root() {
                // Exercise the unsigned type itself (this previously re-ran
                // the signed test); unsigned types have no sign bit, so every
                // bit is a value bit.
                let bits = 8 * mem::size_of::<$U>() as u32;
                let pos = pos::<$U>();
                for n in 4..bits {
                    check(&pos, n);
                }
            }

            #[test]
            fn bit_size() {
                let bits = 8 * mem::size_of::<$U>() as u32;
                assert_eq!($U::max_value().nth_root(bits - 1), 2);
                assert_eq!($U::max_value().nth_root(bits), 1);
            }
        }
    };
}
// Instantiate the test suites for each signed/unsigned pair of primitive
// integer types.
test_roots!(i8, u8);
test_roots!(i16, u16);
test_roots!(i32, u32);
test_roots!(i64, u64);
// The 128-bit pair is only compiled when the `has_i128` cfg is set
// (presumably by the crate's build script on compilers that support
// `i128` — TODO confirm).
#[cfg(has_i128)]
test_roots!(i128, u128);
test_roots!(isize, usize);