diff --git a/Cargo.toml b/Cargo.toml index fe1632e..be1ac90 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,12 +6,15 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -generic_parameterize = "0.1.0" +generic_parameterize = "0.2.0" itertools = "0.10.5" num-traits = "0.2.15" +[dev-dependencies] +impls = "1.0.3" + [package.metadata.docs.rs] rustdoc-args = [ "--html-in-header", "doc/mathjax.html", -] \ No newline at end of file +] diff --git a/src/ops.rs b/src/ops.rs index 4f02c45..59cc9d1 100644 --- a/src/ops.rs +++ b/src/ops.rs @@ -5,118 +5,88 @@ use num_traits::Num; #[doc(hidden)] macro_rules! impl_matrix_op { (neg) => { - _impl_op_m_internal_ex!(Neg, neg); + _impl_op_unary_ex!(Neg::neg); }; (!) => { - _impl_op_m_internal_ex!(Not, not); + _impl_op_unary_ex!(Not::not); }; (+) => { - _impl_op_mm_internal_ex!(Add, add); - _impl_opassign_mm_internal_ex!(AddAssign, add_assign); + _impl_op_binary_ex!(Add::add, AddAssign::add_assign); }; (-) => { - _impl_op_mm_internal_ex!(Sub, sub); - _impl_opassign_mm_internal_ex!(SubAssign, sub_assign); + _impl_op_binary_ex!(Sub::sub, SubAssign::sub_assign); }; (*) => { - _impl_op_mm_internal_ex!(Mul, mul); - _impl_op_ms_internal_ex!(Mul, mul); - _impl_opassign_mm_internal_ex!(MulAssign, mul_assign); - _impl_opassign_ms_internal_ex!(MulAssign, mul_assign); + _impl_op_binary_ex!(Mul::mul, MulAssign::mul_assign); }; (/) => { - _impl_op_mm_internal_ex!(Div, div); - _impl_op_ms_internal_ex!(Div, div); - _impl_opassign_mm_internal_ex!(DivAssign, div_assign); - _impl_opassign_ms_internal_ex!(DivAssign, div_assign); + _impl_op_binary_ex!(Div::div, DivAssign::div_assign); }; (%) => { - _impl_op_mm_internal_ex!(Rem, rem); - _impl_op_ms_internal_ex!(Rem, rem); - _impl_opassign_mm_internal_ex!(RemAssign, rem_assign); - _impl_opassign_ms_internal_ex!(RemAssign, rem_assign); + _impl_op_binary_ex!(Rem::rem, RemAssign::rem_assign); }; (&) => { - _impl_op_mm_internal_ex!(BitAnd, bitand); - _impl_opassign_mm_internal_ex!(BitAndAssign, bitand_assign); + _impl_op_binary_ex!(BitAnd::bitand, BitAndAssign::bitand_assign); }; (|) => { - _impl_op_mm_internal_ex!(BitOr, bitor); - _impl_opassign_mm_internal_ex!(BitOrAssign, bitor_assign); + _impl_op_binary_ex!(BitOr::bitor, BitOrAssign::bitor_assign); }; (^) => { - _impl_op_mm_internal_ex!(BitXor, bitxor); - _impl_opassign_mm_internal_ex!(BitXorAssign, bitxor_assign); + _impl_op_binary_ex!(BitXor::bitxor, BitXorAssign::bitxor_assign); }; (<<) => { - _impl_op_ms_internal_ex!(Shl, shl); - _impl_opassign_ms_internal_ex!(ShlAssign, shl_assign); + _impl_op_binary_ex!(Shl::shl, ShlAssign::shl_assign); }; (>>) => { - _impl_op_ms_internal_ex!(Shr, shr); - _impl_opassign_ms_internal_ex!(ShrAssign, shr_assign); + _impl_op_binary_ex!(Shr::shr, ShrAssign::shr_assign); }; } #[doc(hidden)] #[macro_export] -macro_rules! _impl_op_m_internal_ex { - ($ops_trait:ident, $ops_fn:ident) => { - _impl_op_m_internal!($ops_trait, $ops_fn, Matrix, Matrix); - _impl_op_m_internal!($ops_trait, $ops_fn, &Matrix, Matrix); +macro_rules! _impl_op_unary_ex { + ($op_trait:ident::$op_fn:ident) => { + _impl_op_m_internal!($op_trait, $op_fn, Matrix, Matrix); + _impl_op_m_internal!($op_trait, $op_fn, &Matrix, Matrix); } } #[doc(hidden)] #[macro_export] -macro_rules! _impl_op_mm_internal_ex { - ($ops_trait:ident, $ops_fn:ident) => { - _impl_op_mm_internal!($ops_trait, $ops_fn, Matrix, Matrix, Matrix); - _impl_op_mm_internal!($ops_trait, $ops_fn, &Matrix, Matrix, Matrix); - _impl_op_mm_internal!($ops_trait, $ops_fn, Matrix, &Matrix, Matrix); - _impl_op_mm_internal!($ops_trait, $ops_fn, &Matrix, &Matrix, Matrix); - } -} +macro_rules! _impl_op_binary_ex { + ($op_trait:ident::$op_fn:ident, $op_assign_trait:ident::$op_assign_fn:ident) => { + _impl_op_mm_internal!($op_trait, $op_fn, Matrix, Matrix, Matrix); + _impl_op_mm_internal!($op_trait, $op_fn, &Matrix, Matrix, Matrix); + _impl_op_mm_internal!($op_trait, $op_fn, Matrix, &Matrix, Matrix); + _impl_op_mm_internal!($op_trait, $op_fn, &Matrix, &Matrix, Matrix); -#[doc(hidden)] -#[macro_export] -macro_rules! _impl_opassign_mm_internal_ex { - ($ops_trait:ident, $ops_fn:ident) => { - _impl_opassign_mm_internal!($ops_trait, $ops_fn, Matrix, Matrix, Matrix); - _impl_opassign_mm_internal!($ops_trait, $ops_fn, Matrix, &Matrix, Matrix); - } -} + _impl_op_ms_internal!($op_trait, $op_fn, Matrix, R, Matrix); + _impl_op_ms_internal!($op_trait, $op_fn, &Matrix, R, Matrix); -#[doc(hidden)] -macro_rules! _impl_op_ms_internal_ex { - ($ops_trait:ident, $ops_fn:ident) => { - _impl_op_ms_internal!($ops_trait, $ops_fn, Matrix, R, Matrix); - _impl_op_ms_internal!($ops_trait, $ops_fn, &Matrix, R, Matrix); - } -} + _impl_opassign_mm_internal!($op_assign_trait, $op_assign_fn, Matrix, Matrix, Matrix); + _impl_opassign_mm_internal!($op_assign_trait, $op_assign_fn, Matrix, &Matrix, Matrix); + + _impl_opassign_ms_internal!($op_assign_trait, $op_assign_fn, Matrix, R, Matrix); -#[doc(hidden)] -macro_rules! _impl_opassign_ms_internal_ex { - ($ops_trait:ident, $ops_fn:ident) => { - _impl_opassign_ms_internal!($ops_trait, $ops_fn, Matrix, R, Matrix); } } #[doc(hidden)] macro_rules! _impl_op_m_internal { - ($ops_trait:ident, $ops_fn:ident, $lhs:ty, $out:ty) => { - impl ::std::ops::$ops_trait for $lhs + ($op_trait:ident, $op_fn:ident, $lhs:ty, $out:ty) => { + impl ::std::ops::$op_trait for $lhs where - L: ::std::ops::$ops_trait + Copy, + L: ::std::ops::$op_trait + Copy, { type Output = $out; #[inline(always)] - fn $ops_fn(self) -> Self::Output { + fn $op_fn(self) -> Self::Output { let mut result = self.clone(); + // we arnt using iterators because they dont seem to always vectorize correctly for m in 0..M { for n in 0..N { - result.data[m][n] = self.data[m][n].$ops_fn(); + result.data[m][n] = self.data[m][n].$op_fn(); } } result @@ -127,20 +97,20 @@ macro_rules! _impl_op_m_internal { #[doc(hidden)] macro_rules! _impl_op_mm_internal { - ($ops_trait:ident, $ops_fn:ident, $lhs:ty, $rhs:ty, $out:ty) => { - impl ::std::ops::$ops_trait<$rhs> for $lhs + ($op_trait:ident, $op_fn:ident, $lhs:ty, $rhs:ty, $out:ty) => { + impl ::std::ops::$op_trait<$rhs> for $lhs where - L: ::std::ops::$ops_trait + Copy, + L: ::std::ops::$op_trait + Copy, R: Copy, { type Output = $out; #[inline(always)] - fn $ops_fn(self, other: $rhs) -> Self::Output { + fn $op_fn(self, other: $rhs) -> Self::Output { let mut result = self.clone(); for m in 0..M { for n in 0..N { - result.data[m][n] = self.data[m][n].$ops_fn(other.data[m][n]); + result.data[m][n] = self.data[m][n].$op_fn(other.data[m][n]); } } result @@ -151,17 +121,17 @@ macro_rules! _impl_op_mm_internal { #[doc(hidden)] macro_rules! _impl_opassign_mm_internal { - ($ops_trait:ident, $ops_fn:ident, $lhs:ty, $rhs:ty, $out:ty) => { - impl ::std::ops::$ops_trait<$rhs> for $lhs + ($op_trait:ident, $op_fn:ident, $lhs:ty, $rhs:ty, $out:ty) => { + impl ::std::ops::$op_trait<$rhs> for $lhs where - L: ::std::ops::$ops_trait + Copy, + L: ::std::ops::$op_trait + Copy, R: Copy, { #[inline(always)] - fn $ops_fn(&mut self, other: $rhs) { + fn $op_fn(&mut self, other: $rhs) { for m in 0..M { for n in 0..N { - self.data[m][n].$ops_fn(other.data[m][n]); + self.data[m][n].$op_fn(other.data[m][n]); } } } @@ -171,20 +141,20 @@ macro_rules! _impl_opassign_mm_internal { #[doc(hidden)] macro_rules! _impl_op_ms_internal { - ($ops_trait:ident, $ops_fn:ident, $lhs:ty, $rhs:ty, $out:ty) => { - impl ::std::ops::$ops_trait<$rhs> for $lhs + ($op_trait:ident, $op_fn:ident, $lhs:ty, $rhs:ty, $out:ty) => { + impl ::std::ops::$op_trait<$rhs> for $lhs where - L: ::std::ops::$ops_trait + Copy, + L: ::std::ops::$op_trait + Copy, R: Copy + Num, { type Output = $out; #[inline(always)] - fn $ops_fn(self, other: $rhs) -> Self::Output { + fn $op_fn(self, other: $rhs) -> Self::Output { let mut result = self.clone(); for m in 0..M { for n in 0..N { - result.data[m][n] = self.data[m][n].$ops_fn(other); + result.data[m][n] = self.data[m][n].$op_fn(other); } } result @@ -195,17 +165,17 @@ macro_rules! _impl_op_ms_internal { #[doc(hidden)] macro_rules! _impl_opassign_ms_internal { - ($ops_trait:ident, $ops_fn:ident, $lhs:ty, $rhs:ty, $out:ty) => { - impl ::std::ops::$ops_trait<$rhs> for $lhs + ($op_trait:ident, $op_fn:ident, $lhs:ty, $rhs:ty, $out:ty) => { + impl ::std::ops::$op_trait<$rhs> for $lhs where - L: ::std::ops::$ops_trait + Copy, + L: ::std::ops::$op_trait + Copy, R: Copy + Num, { #[inline(always)] - fn $ops_fn(&mut self, r: $rhs) { + fn $op_fn(&mut self, r: $rhs) { for m in 0..M { for n in 0..N { - self.data[m][n].$ops_fn(r); + self.data[m][n].$op_fn(r); } } } diff --git a/tests/common/mod.rs b/tests/common/mod.rs index dadbb35..a5e031f 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -1,5 +1,7 @@ -use num_traits::Float; use std::iter::zip; + +use num_traits::{Float, NumCast, NumOps}; + use vector_victor::Matrix; pub trait Approx: PartialEq { @@ -9,7 +11,7 @@ pub trait Approx: PartialEq { } macro_rules! multi_impl { ($name:ident for $($t:ty),*) => ($( impl $name for $t {} )*) } -multi_impl!(Approx for i8, i16, i32, i64, i128, isize, u8, u16, u32, u64, u128, usize); +multi_impl!(Approx for i8, i16, i32, i64, i128, isize, u8, u16, u32, u64, u128, usize, bool); impl Approx for f32 { fn approx(left: &f32, right: &f32) -> bool { @@ -29,6 +31,10 @@ impl Approx for Matrix(left: &T, right: &T) -> bool { + T::approx(left, right) +} + #[macro_export] macro_rules! assert_approx { ($left:expr, $right:expr $(,)?) => { @@ -41,14 +47,20 @@ macro_rules! assert_approx { ($left:expr, $right:expr, $($arg:tt)+) => { match (&$left, &$right) { (left_val, right_val) => { - pub fn approx(left: &T, right: &T) -> bool { - T::approx(left, right) - } - if !approx(left_val, right_val){ + + if !common::approx(left_val, right_val){ assert_eq!(left_val, right_val, $($arg)+) // done this way to get nice errors } } } }; } + +pub fn step(start: U, step: U) -> impl Iterator +where + T: NumCast, + U: NumOps + NumCast + Copy, +{ + (0usize..).map_while(move |i| T::from((U::from(i)? + start) * step)) +} diff --git a/tests/decompose.rs b/tests/decompose.rs index 3a07d56..7ad8a77 100644 --- a/tests/decompose.rs +++ b/tests/decompose.rs @@ -11,7 +11,7 @@ use std::fmt::Debug; use vector_victor::decompose::{LUDecompose, LUDecomposition, Parity}; use vector_victor::{Matrix, Vector}; -#[parameterize(S = (f32, f64), M = [1,2,3,4])] +#[parameterize(S = (f32,), M = [1,2,3,4], fmt="{fn}_{M}x{M}")] #[test] /// The LU decomposition of the identity matrix should produce /// the identity matrix with no permutations and parity 1 @@ -55,7 +55,7 @@ where assert_eq!(decomp.separate(), (i, i)); } -#[parameterize(S = (f32, f64), M = [2,3,4])] +#[parameterize(S = (f32,), M = [2,3,4], fmt="{fn}_{M}x{M}")] #[test] /// The LU decomposition of any singular matrix should be `None` fn test_lu_singular() diff --git a/tests/ops.rs b/tests/ops.rs index 5ccd12e..ed4198c 100644 --- a/tests/ops.rs +++ b/tests/ops.rs @@ -1,21 +1,146 @@ #[macro_use] mod common; +use crate::common::{step, Approx}; use generic_parameterize::parameterize; +use num_traits::{NumAssign, NumCast}; use std::fmt::Debug; -use std::ops; +use std::iter::zip; +use std::ops::*; use vector_victor::Matrix; -#[parameterize(S = (i32, f32, f64, u32), M = [1,4], N = [1,4])] +#[parameterize(S = (i32, f32), M = [1,4], N = [1,4], fmt = "{fn}_{S}_{M}x{N}")] #[test] -fn test_add + PartialEq + Debug, const M: usize, const N: usize>() -where - Matrix: ops::Add>, -{ - let a = Matrix::::fill(S::from(1)); - let b = Matrix::::fill(S::from(3)); - let c: Matrix = a + b; - for (_, ci) in c.elements().enumerate() { - assert_eq!(*ci, S::from(4)); - } +fn neg< + S: Copy + NumCast + NumAssign + Approx + Default + Debug + Neg, + const M: usize, + const N: usize, +>() { + let a: Matrix = step(-2, 2).collect(); + let expected: Matrix = a.elements().map(|&a| -a).collect(); + + assert_approx!(-a, expected, "Incorrect value for negation"); +} + +#[parameterize(S = (i32, u32), M = [1,4], N = [1,4], fmt = "{fn}_{S}_{M}x{N}")] +#[test] +fn not< + S: Copy + NumCast + NumAssign + Approx + Default + Debug + Not, + const M: usize, + const N: usize, +>() { + let a: Matrix = step(-2, 2).collect(); + let expected: Matrix = a.elements().map(|&a| !a).collect(); + + assert_approx!(!a, expected, "Incorrect value for inversion"); +} + +#[parameterize(M = [1,4], N = [1,4], fmt="{fn}_{M}x{N}")] +#[test] +fn not_bool() { + let a: Matrix = [true, true, false].iter().cycle().copied().collect(); + let expected: Matrix = [false, false, true].iter().cycle().copied().collect(); + + assert_approx!(!a, expected, "Incorrect value for inversion"); +} + +macro_rules! test_op { + {$op_trait:ident::$op_fn:ident, $op_assign_trait:ident::$op_assign_fn:ident, + $op_name:literal, $t:ty} => { + #[parameterize(S = $t, M = [1,4], N = [1,4], fmt="{fn}_{S}_{M}x{N}")] + #[test] + fn $op_fn< + S: Copy + NumCast + NumAssign + Approx + Default + Debug + + $op_trait, + const M: usize, + const N: usize, + >() { + let a: Matrix = step(2, 3).collect(); + let b: Matrix = step(1, 2).collect(); + let expected: Matrix = zip(a, b).map(|(aa, bb)| $op_trait::$op_fn(aa, bb)).collect(); + + assert_approx!($op_trait::$op_fn(a, b), expected, "Incorrect value for {}", $op_name); + assert_approx!($op_trait::$op_fn(a, &b), expected, "Incorrect value for {}", $op_name); + assert_approx!($op_trait::$op_fn(&a, b), expected, "Incorrect value for {}", $op_name); + assert_approx!($op_trait::$op_fn(&a, &b), expected, "Incorrect value for {}", $op_name); + + let s: S = S::from(2).unwrap(); + let expected: Matrix = a.elements().map(|&aa| $op_trait::$op_fn(aa, s)).collect(); + + assert_approx!($op_trait::$op_fn(a, s), expected, "Incorrect value for {} by scalar", $op_name); + assert_approx!($op_trait::$op_fn(&a, s), expected, "Incorrect value for {} by scalar", $op_name); + + } + + #[parameterize(S = $t, M = [1,4], N = [1,4])] + #[test] + fn $op_assign_fn< + S: Copy + NumCast + NumAssign + Approx + Default + Debug + + $op_trait + $op_assign_trait, + const M: usize, + const N: usize, + >() { + let a: Matrix = step(2, 3).collect(); + let b: Matrix = step(1, 2).collect(); + let expected: Matrix = zip(a, b).map(|(aa, bb)| $op_trait::$op_fn(aa, bb)).collect(); + + let mut c = a; + $op_assign_trait::$op_assign_fn(&mut c, b); + assert_approx!(c, expected, "Incorrect value for {}-assignment", $op_name); + + let mut c = a; + $op_assign_trait::$op_assign_fn(&mut c, &b); + assert_approx!(c, expected, "Incorrect value for {}-assignment", $op_name); + + let s: S = S::from(2).unwrap(); + let expected: Matrix = a.elements().map(|&aa| $op_trait::$op_fn(aa, s)).collect(); + + let mut c = a; + $op_assign_trait::$op_assign_fn(&mut c, s); + assert_approx!(c, expected, "Incorrect value for {}-assignment by scalar", $op_name); + } + }; } + +test_op!(Add::add, AddAssign::add_assign, "addition", (i32, u32, f32)); + +test_op!( + Sub::sub, + SubAssign::sub_assign, + "subtraction", + (i32, u32, f32) +); + +test_op!( + Mul::mul, + MulAssign::mul_assign, + "multiplication", + (i32, u32, f32) +); +test_op!(Div::div, DivAssign::div_assign, "division", (i32, u32, f32)); + +test_op!( + Rem::rem, + RemAssign::rem_assign, + "remainder", + (i32, u32, f32) +); + +test_op!(BitOr::bitor, BitOrAssign::bitor_assign, "or", (i32, u32)); + +test_op!( + BitAnd::bitand, + BitAndAssign::bitand_assign, + "and", + (i32, u32) +); + +test_op!( + BitXor::bitxor, + BitXorAssign::bitxor_assign, + "xor", + (i32, u32) +); + +test_op!(Shl::shl, ShlAssign::shl_assign, "shift-left", (usize,)); +test_op!(Shr::shr, ShrAssign::shr_assign, "shift-right", (usize,));