Refactor ops.

all binary operations now support operating by a scalar
This commit is contained in:
Andrew Cassidy 2023-06-14 00:04:24 -07:00
parent b05a8172c0
commit b29bcc867d
5 changed files with 216 additions and 106 deletions

View File

@ -6,12 +6,15 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
generic_parameterize = "0.1.0"
generic_parameterize = "0.2.0"
itertools = "0.10.5"
num-traits = "0.2.15"
[dev-dependencies]
impls = "1.0.3"
[package.metadata.docs.rs]
rustdoc-args = [
"--html-in-header",
"doc/mathjax.html",
]
]

View File

@ -5,118 +5,88 @@ use num_traits::Num;
#[doc(hidden)]
macro_rules! impl_matrix_op {
(neg) => {
_impl_op_m_internal_ex!(Neg, neg);
_impl_op_unary_ex!(Neg::neg);
};
(!) => {
_impl_op_m_internal_ex!(Not, not);
_impl_op_unary_ex!(Not::not);
};
(+) => {
_impl_op_mm_internal_ex!(Add, add);
_impl_opassign_mm_internal_ex!(AddAssign, add_assign);
_impl_op_binary_ex!(Add::add, AddAssign::add_assign);
};
(-) => {
_impl_op_mm_internal_ex!(Sub, sub);
_impl_opassign_mm_internal_ex!(SubAssign, sub_assign);
_impl_op_binary_ex!(Sub::sub, SubAssign::sub_assign);
};
(*) => {
_impl_op_mm_internal_ex!(Mul, mul);
_impl_op_ms_internal_ex!(Mul, mul);
_impl_opassign_mm_internal_ex!(MulAssign, mul_assign);
_impl_opassign_ms_internal_ex!(MulAssign, mul_assign);
_impl_op_binary_ex!(Mul::mul, MulAssign::mul_assign);
};
(/) => {
_impl_op_mm_internal_ex!(Div, div);
_impl_op_ms_internal_ex!(Div, div);
_impl_opassign_mm_internal_ex!(DivAssign, div_assign);
_impl_opassign_ms_internal_ex!(DivAssign, div_assign);
_impl_op_binary_ex!(Div::div, DivAssign::div_assign);
};
(%) => {
_impl_op_mm_internal_ex!(Rem, rem);
_impl_op_ms_internal_ex!(Rem, rem);
_impl_opassign_mm_internal_ex!(RemAssign, rem_assign);
_impl_opassign_ms_internal_ex!(RemAssign, rem_assign);
_impl_op_binary_ex!(Rem::rem, RemAssign::rem_assign);
};
(&) => {
_impl_op_mm_internal_ex!(BitAnd, bitand);
_impl_opassign_mm_internal_ex!(BitAndAssign, bitand_assign);
_impl_op_binary_ex!(BitAnd::bitand, BitAndAssign::bitand_assign);
};
(|) => {
_impl_op_mm_internal_ex!(BitOr, bitor);
_impl_opassign_mm_internal_ex!(BitOrAssign, bitor_assign);
_impl_op_binary_ex!(BitOr::bitor, BitOrAssign::bitor_assign);
};
(^) => {
_impl_op_mm_internal_ex!(BitXor, bitxor);
_impl_opassign_mm_internal_ex!(BitXorAssign, bitxor_assign);
_impl_op_binary_ex!(BitXor::bitxor, BitXorAssign::bitxor_assign);
};
(<<) => {
_impl_op_ms_internal_ex!(Shl, shl);
_impl_opassign_ms_internal_ex!(ShlAssign, shl_assign);
_impl_op_binary_ex!(Shl::shl, ShlAssign::shl_assign);
};
(>>) => {
_impl_op_ms_internal_ex!(Shr, shr);
_impl_opassign_ms_internal_ex!(ShrAssign, shr_assign);
_impl_op_binary_ex!(Shr::shr, ShrAssign::shr_assign);
};
}
#[doc(hidden)]
#[macro_export]
macro_rules! _impl_op_m_internal_ex {
($ops_trait:ident, $ops_fn:ident) => {
_impl_op_m_internal!($ops_trait, $ops_fn, Matrix<L,M,N>, Matrix<L,M,N>);
_impl_op_m_internal!($ops_trait, $ops_fn, &Matrix<L,M,N>, Matrix<L,M,N>);
macro_rules! _impl_op_unary_ex {
($op_trait:ident::$op_fn:ident) => {
_impl_op_m_internal!($op_trait, $op_fn, Matrix<L,M,N>, Matrix<L,M,N>);
_impl_op_m_internal!($op_trait, $op_fn, &Matrix<L,M,N>, Matrix<L,M,N>);
}
}
#[doc(hidden)]
#[macro_export]
macro_rules! _impl_op_mm_internal_ex {
($ops_trait:ident, $ops_fn:ident) => {
_impl_op_mm_internal!($ops_trait, $ops_fn, Matrix<L,M,N>, Matrix<R,M,N>, Matrix<L,M,N>);
_impl_op_mm_internal!($ops_trait, $ops_fn, &Matrix<L,M,N>, Matrix<R,M,N>, Matrix<L,M,N>);
_impl_op_mm_internal!($ops_trait, $ops_fn, Matrix<L,M,N>, &Matrix<R,M,N>, Matrix<L,M,N>);
_impl_op_mm_internal!($ops_trait, $ops_fn, &Matrix<L,M,N>, &Matrix<R,M,N>, Matrix<L,M,N>);
}
}
macro_rules! _impl_op_binary_ex {
($op_trait:ident::$op_fn:ident, $op_assign_trait:ident::$op_assign_fn:ident) => {
_impl_op_mm_internal!($op_trait, $op_fn, Matrix<L,M,N>, Matrix<R,M,N>, Matrix<L,M,N>);
_impl_op_mm_internal!($op_trait, $op_fn, &Matrix<L,M,N>, Matrix<R,M,N>, Matrix<L,M,N>);
_impl_op_mm_internal!($op_trait, $op_fn, Matrix<L,M,N>, &Matrix<R,M,N>, Matrix<L,M,N>);
_impl_op_mm_internal!($op_trait, $op_fn, &Matrix<L,M,N>, &Matrix<R,M,N>, Matrix<L,M,N>);
#[doc(hidden)]
#[macro_export]
macro_rules! _impl_opassign_mm_internal_ex {
($ops_trait:ident, $ops_fn:ident) => {
_impl_opassign_mm_internal!($ops_trait, $ops_fn, Matrix<L,M,N>, Matrix<R,M,N>, Matrix<L,M,N>);
_impl_opassign_mm_internal!($ops_trait, $ops_fn, Matrix<L,M,N>, &Matrix<R,M,N>, Matrix<L,M,N>);
}
}
_impl_op_ms_internal!($op_trait, $op_fn, Matrix<L,M,N>, R, Matrix<L,M,N>);
_impl_op_ms_internal!($op_trait, $op_fn, &Matrix<L,M,N>, R, Matrix<L,M,N>);
#[doc(hidden)]
macro_rules! _impl_op_ms_internal_ex {
($ops_trait:ident, $ops_fn:ident) => {
_impl_op_ms_internal!($ops_trait, $ops_fn, Matrix<L,M,N>, R, Matrix<L,M,N>);
_impl_op_ms_internal!($ops_trait, $ops_fn, &Matrix<L,M,N>, R, Matrix<L,M,N>);
}
}
_impl_opassign_mm_internal!($op_assign_trait, $op_assign_fn, Matrix<L,M,N>, Matrix<R,M,N>, Matrix<L,M,N>);
_impl_opassign_mm_internal!($op_assign_trait, $op_assign_fn, Matrix<L,M,N>, &Matrix<R,M,N>, Matrix<L,M,N>);
_impl_opassign_ms_internal!($op_assign_trait, $op_assign_fn, Matrix<L,M,N>, R, Matrix<L,M,N>);
#[doc(hidden)]
macro_rules! _impl_opassign_ms_internal_ex {
($ops_trait:ident, $ops_fn:ident) => {
_impl_opassign_ms_internal!($ops_trait, $ops_fn, Matrix<L,M,N>, R, Matrix<L,M,N>);
}
}
#[doc(hidden)]
macro_rules! _impl_op_m_internal {
($ops_trait:ident, $ops_fn:ident, $lhs:ty, $out:ty) => {
impl<L, const M: usize, const N: usize> ::std::ops::$ops_trait for $lhs
($op_trait:ident, $op_fn:ident, $lhs:ty, $out:ty) => {
impl<L, const M: usize, const N: usize> ::std::ops::$op_trait for $lhs
where
L: ::std::ops::$ops_trait<Output = L> + Copy,
L: ::std::ops::$op_trait<Output = L> + Copy,
{
type Output = $out;
#[inline(always)]
fn $ops_fn(self) -> Self::Output {
fn $op_fn(self) -> Self::Output {
let mut result = self.clone();
// we arnt using iterators because they dont seem to always vectorize correctly
for m in 0..M {
for n in 0..N {
result.data[m][n] = self.data[m][n].$ops_fn();
result.data[m][n] = self.data[m][n].$op_fn();
}
}
result
@ -127,20 +97,20 @@ macro_rules! _impl_op_m_internal {
#[doc(hidden)]
macro_rules! _impl_op_mm_internal {
($ops_trait:ident, $ops_fn:ident, $lhs:ty, $rhs:ty, $out:ty) => {
impl<L, R, const M: usize, const N: usize> ::std::ops::$ops_trait<$rhs> for $lhs
($op_trait:ident, $op_fn:ident, $lhs:ty, $rhs:ty, $out:ty) => {
impl<L, R, const M: usize, const N: usize> ::std::ops::$op_trait<$rhs> for $lhs
where
L: ::std::ops::$ops_trait<R, Output = L> + Copy,
L: ::std::ops::$op_trait<R, Output = L> + Copy,
R: Copy,
{
type Output = $out;
#[inline(always)]
fn $ops_fn(self, other: $rhs) -> Self::Output {
fn $op_fn(self, other: $rhs) -> Self::Output {
let mut result = self.clone();
for m in 0..M {
for n in 0..N {
result.data[m][n] = self.data[m][n].$ops_fn(other.data[m][n]);
result.data[m][n] = self.data[m][n].$op_fn(other.data[m][n]);
}
}
result
@ -151,17 +121,17 @@ macro_rules! _impl_op_mm_internal {
#[doc(hidden)]
macro_rules! _impl_opassign_mm_internal {
($ops_trait:ident, $ops_fn:ident, $lhs:ty, $rhs:ty, $out:ty) => {
impl<L, R, const M: usize, const N: usize> ::std::ops::$ops_trait<$rhs> for $lhs
($op_trait:ident, $op_fn:ident, $lhs:ty, $rhs:ty, $out:ty) => {
impl<L, R, const M: usize, const N: usize> ::std::ops::$op_trait<$rhs> for $lhs
where
L: ::std::ops::$ops_trait<R> + Copy,
L: ::std::ops::$op_trait<R> + Copy,
R: Copy,
{
#[inline(always)]
fn $ops_fn(&mut self, other: $rhs) {
fn $op_fn(&mut self, other: $rhs) {
for m in 0..M {
for n in 0..N {
self.data[m][n].$ops_fn(other.data[m][n]);
self.data[m][n].$op_fn(other.data[m][n]);
}
}
}
@ -171,20 +141,20 @@ macro_rules! _impl_opassign_mm_internal {
#[doc(hidden)]
macro_rules! _impl_op_ms_internal {
($ops_trait:ident, $ops_fn:ident, $lhs:ty, $rhs:ty, $out:ty) => {
impl<L, R, const M: usize, const N: usize> ::std::ops::$ops_trait<$rhs> for $lhs
($op_trait:ident, $op_fn:ident, $lhs:ty, $rhs:ty, $out:ty) => {
impl<L, R, const M: usize, const N: usize> ::std::ops::$op_trait<$rhs> for $lhs
where
L: ::std::ops::$ops_trait<R, Output = L> + Copy,
L: ::std::ops::$op_trait<R, Output = L> + Copy,
R: Copy + Num,
{
type Output = $out;
#[inline(always)]
fn $ops_fn(self, other: $rhs) -> Self::Output {
fn $op_fn(self, other: $rhs) -> Self::Output {
let mut result = self.clone();
for m in 0..M {
for n in 0..N {
result.data[m][n] = self.data[m][n].$ops_fn(other);
result.data[m][n] = self.data[m][n].$op_fn(other);
}
}
result
@ -195,17 +165,17 @@ macro_rules! _impl_op_ms_internal {
#[doc(hidden)]
macro_rules! _impl_opassign_ms_internal {
($ops_trait:ident, $ops_fn:ident, $lhs:ty, $rhs:ty, $out:ty) => {
impl<L, R, const M: usize, const N: usize> ::std::ops::$ops_trait<$rhs> for $lhs
($op_trait:ident, $op_fn:ident, $lhs:ty, $rhs:ty, $out:ty) => {
impl<L, R, const M: usize, const N: usize> ::std::ops::$op_trait<$rhs> for $lhs
where
L: ::std::ops::$ops_trait<R> + Copy,
L: ::std::ops::$op_trait<R> + Copy,
R: Copy + Num,
{
#[inline(always)]
fn $ops_fn(&mut self, r: $rhs) {
fn $op_fn(&mut self, r: $rhs) {
for m in 0..M {
for n in 0..N {
self.data[m][n].$ops_fn(r);
self.data[m][n].$op_fn(r);
}
}
}

View File

@ -1,5 +1,7 @@
use num_traits::Float;
use std::iter::zip;
use num_traits::{Float, NumCast, NumOps};
use vector_victor::Matrix;
pub trait Approx: PartialEq {
@ -9,7 +11,7 @@ pub trait Approx: PartialEq {
}
macro_rules! multi_impl { ($name:ident for $($t:ty),*) => ($( impl $name for $t {} )*) }
multi_impl!(Approx for i8, i16, i32, i64, i128, isize, u8, u16, u32, u64, u128, usize);
multi_impl!(Approx for i8, i16, i32, i64, i128, isize, u8, u16, u32, u64, u128, usize, bool);
impl Approx for f32 {
fn approx(left: &f32, right: &f32) -> bool {
@ -29,6 +31,10 @@ impl<T: Copy + Approx, const M: usize, const N: usize> Approx for Matrix<T, M, N
}
}
pub fn approx<T: Approx>(left: &T, right: &T) -> bool {
T::approx(left, right)
}
#[macro_export]
macro_rules! assert_approx {
($left:expr, $right:expr $(,)?) => {
@ -41,14 +47,20 @@ macro_rules! assert_approx {
($left:expr, $right:expr, $($arg:tt)+) => {
match (&$left, &$right) {
(left_val, right_val) => {
pub fn approx<T: Approx>(left: &T, right: &T) -> bool {
T::approx(left, right)
}
if !approx(left_val, right_val){
if !common::approx(left_val, right_val){
assert_eq!(left_val, right_val, $($arg)+) // done this way to get nice errors
}
}
}
};
}
pub fn step<T, U>(start: U, step: U) -> impl Iterator<Item = T>
where
T: NumCast,
U: NumOps + NumCast + Copy,
{
(0usize..).map_while(move |i| T::from((U::from(i)? + start) * step))
}

View File

@ -11,7 +11,7 @@ use std::fmt::Debug;
use vector_victor::decompose::{LUDecompose, LUDecomposition, Parity};
use vector_victor::{Matrix, Vector};
#[parameterize(S = (f32, f64), M = [1,2,3,4])]
#[parameterize(S = (f32,), M = [1,2,3,4], fmt="{fn}_{M}x{M}")]
#[test]
/// The LU decomposition of the identity matrix should produce
/// the identity matrix with no permutations and parity 1
@ -55,7 +55,7 @@ where
assert_eq!(decomp.separate(), (i, i));
}
#[parameterize(S = (f32, f64), M = [2,3,4])]
#[parameterize(S = (f32,), M = [2,3,4], fmt="{fn}_{M}x{M}")]
#[test]
/// The LU decomposition of any singular matrix should be `None`
fn test_lu_singular<S, const M: usize>()

View File

@ -1,21 +1,146 @@
#[macro_use]
mod common;
use crate::common::{step, Approx};
use generic_parameterize::parameterize;
use num_traits::{NumAssign, NumCast};
use std::fmt::Debug;
use std::ops;
use std::iter::zip;
use std::ops::*;
use vector_victor::Matrix;
#[parameterize(S = (i32, f32, f64, u32), M = [1,4], N = [1,4])]
#[parameterize(S = (i32, f32), M = [1,4], N = [1,4], fmt = "{fn}_{S}_{M}x{N}")]
#[test]
fn test_add<S: Copy + From<u16> + PartialEq + Debug, const M: usize, const N: usize>()
where
Matrix<S, M, N>: ops::Add<Output = Matrix<S, M, N>>,
{
let a = Matrix::<S, M, N>::fill(S::from(1));
let b = Matrix::<S, M, N>::fill(S::from(3));
let c: Matrix<S, M, N> = a + b;
for (_, ci) in c.elements().enumerate() {
assert_eq!(*ci, S::from(4));
}
fn neg<
S: Copy + NumCast + NumAssign + Approx + Default + Debug + Neg<Output = S>,
const M: usize,
const N: usize,
>() {
let a: Matrix<S, M, N> = step(-2, 2).collect();
let expected: Matrix<S, M, N> = a.elements().map(|&a| -a).collect();
assert_approx!(-a, expected, "Incorrect value for negation");
}
#[parameterize(S = (i32, u32), M = [1,4], N = [1,4], fmt = "{fn}_{S}_{M}x{N}")]
#[test]
fn not<
S: Copy + NumCast + NumAssign + Approx + Default + Debug + Not<Output = S>,
const M: usize,
const N: usize,
>() {
let a: Matrix<S, M, N> = step(-2, 2).collect();
let expected: Matrix<S, M, N> = a.elements().map(|&a| !a).collect();
assert_approx!(!a, expected, "Incorrect value for inversion");
}
#[parameterize(M = [1,4], N = [1,4], fmt="{fn}_{M}x{N}")]
#[test]
fn not_bool<const M: usize, const N: usize>() {
let a: Matrix<bool, M, N> = [true, true, false].iter().cycle().copied().collect();
let expected: Matrix<bool, M, N> = [false, false, true].iter().cycle().copied().collect();
assert_approx!(!a, expected, "Incorrect value for inversion");
}
macro_rules! test_op {
{$op_trait:ident::$op_fn:ident, $op_assign_trait:ident::$op_assign_fn:ident,
$op_name:literal, $t:ty} => {
#[parameterize(S = $t, M = [1,4], N = [1,4], fmt="{fn}_{S}_{M}x{N}")]
#[test]
fn $op_fn<
S: Copy + NumCast + NumAssign + Approx + Default + Debug
+ $op_trait<S, Output=S>,
const M: usize,
const N: usize,
>() {
let a: Matrix<S, M, N> = step(2, 3).collect();
let b: Matrix<S, M, N> = step(1, 2).collect();
let expected: Matrix<S, M, N> = zip(a, b).map(|(aa, bb)| $op_trait::$op_fn(aa, bb)).collect();
assert_approx!($op_trait::$op_fn(a, b), expected, "Incorrect value for {}", $op_name);
assert_approx!($op_trait::$op_fn(a, &b), expected, "Incorrect value for {}", $op_name);
assert_approx!($op_trait::$op_fn(&a, b), expected, "Incorrect value for {}", $op_name);
assert_approx!($op_trait::$op_fn(&a, &b), expected, "Incorrect value for {}", $op_name);
let s: S = S::from(2).unwrap();
let expected: Matrix<S, M, N> = a.elements().map(|&aa| $op_trait::$op_fn(aa, s)).collect();
assert_approx!($op_trait::$op_fn(a, s), expected, "Incorrect value for {} by scalar", $op_name);
assert_approx!($op_trait::$op_fn(&a, s), expected, "Incorrect value for {} by scalar", $op_name);
}
#[parameterize(S = $t, M = [1,4], N = [1,4])]
#[test]
fn $op_assign_fn<
S: Copy + NumCast + NumAssign + Approx + Default + Debug
+ $op_trait<S, Output=S> + $op_assign_trait<S>,
const M: usize,
const N: usize,
>() {
let a: Matrix<S, M, N> = step(2, 3).collect();
let b: Matrix<S, M, N> = step(1, 2).collect();
let expected: Matrix<S, M, N> = zip(a, b).map(|(aa, bb)| $op_trait::$op_fn(aa, bb)).collect();
let mut c = a;
$op_assign_trait::$op_assign_fn(&mut c, b);
assert_approx!(c, expected, "Incorrect value for {}-assignment", $op_name);
let mut c = a;
$op_assign_trait::$op_assign_fn(&mut c, &b);
assert_approx!(c, expected, "Incorrect value for {}-assignment", $op_name);
let s: S = S::from(2).unwrap();
let expected: Matrix<S, M, N> = a.elements().map(|&aa| $op_trait::$op_fn(aa, s)).collect();
let mut c = a;
$op_assign_trait::$op_assign_fn(&mut c, s);
assert_approx!(c, expected, "Incorrect value for {}-assignment by scalar", $op_name);
}
};
}
test_op!(Add::add, AddAssign::add_assign, "addition", (i32, u32, f32));
test_op!(
Sub::sub,
SubAssign::sub_assign,
"subtraction",
(i32, u32, f32)
);
test_op!(
Mul::mul,
MulAssign::mul_assign,
"multiplication",
(i32, u32, f32)
);
test_op!(Div::div, DivAssign::div_assign, "division", (i32, u32, f32));
test_op!(
Rem::rem,
RemAssign::rem_assign,
"remainder",
(i32, u32, f32)
);
test_op!(BitOr::bitor, BitOrAssign::bitor_assign, "or", (i32, u32));
test_op!(
BitAnd::bitand,
BitAndAssign::bitand_assign,
"and",
(i32, u32)
);
test_op!(
BitXor::bitxor,
BitXorAssign::bitxor_assign,
"xor",
(i32, u32)
);
test_op!(Shl::shl, ShlAssign::shl_assign, "shift-left", (usize,));
test_op!(Shr::shr, ShrAssign::shr_assign, "shift-right", (usize,));