From 8fcb032b1abe87dceebafaa2b755a2e211034fdc Mon Sep 17 00:00:00 2001 From: Andrew Cassidy Date: Sat, 6 May 2023 00:39:06 -0700 Subject: [PATCH] Seperate LU tests into their own file --- tests/common/mod.rs | 57 +++++++++++++++++ tests/decomposition.rs | 117 +++++++++++++++++++++++++++++++++++ tests/ops.rs | 135 ++--------------------------------------- 3 files changed, 180 insertions(+), 129 deletions(-) create mode 100644 tests/common/mod.rs create mode 100644 tests/decomposition.rs diff --git a/tests/common/mod.rs b/tests/common/mod.rs new file mode 100644 index 0000000..f301ac3 --- /dev/null +++ b/tests/common/mod.rs @@ -0,0 +1,57 @@ +use num_traits::Float; +use std::iter::zip; +use vector_victor::Matrix; + +pub trait Approx: PartialEq { + fn approx(left: &Self, right: &Self) -> bool { + left == right + } +} + +macro_rules! multi_impl { ($name:ident for $($t:ty),*) => ($( impl $name for $t {} )*) } +multi_impl!(Approx for i8, i16, i32, i64, i128, isize, u8, u16, u32, u64, u128, usize); + +impl Approx for f32 { + fn approx(left: &f32, right: &f32) -> bool { + f32::abs(left - right) <= f32::epsilon() + } +} + +impl Approx for f64 { + fn approx(left: &f64, right: &f64) -> bool { + f64::abs(left - right) <= f32::epsilon() as f64 + } +} + +impl Approx for Matrix { + fn approx(left: &Self, right: &Self) -> bool { + zip(left.elements(), right.elements()).all(|(l, r)| T::approx(l, r)) + } +} + +pub fn approx(left: &T, right: &T) -> bool { + T::approx(left, right) +} + +macro_rules! assert_approx { + ($left:expr, $right:expr $(,)?) => { + match (&$left, &$right) { + (_left_val, _right_val) => { + assert_approx!($left, $right, "Difference is less than epsilon") + } + } + }; + ($left:expr, $right:expr, $($arg:tt)+) => { + match (&$left, &$right) { + (left_val, right_val) => { + pub fn approx(left: &T, right: &T) -> bool { + T::approx(left, right) + } + + if !approx(left_val, right_val){ + assert_eq!(left_val, right_val, $($arg)+) // done this way to get nice errors + } + } + } + }; +} diff --git a/tests/decomposition.rs b/tests/decomposition.rs new file mode 100644 index 0000000..5f42744 --- /dev/null +++ b/tests/decomposition.rs @@ -0,0 +1,117 @@ +#[macro_use] +mod common; + +use common::Approx; +use generic_parameterize::parameterize; +use num_traits::real::Real; +use num_traits::Zero; +use std::fmt::Debug; +use std::iter::{zip, Product, Sum}; +use vector_victor::{LUSolve, Matrix, Vector}; + +#[parameterize(S = (f32, f64), M = [1,2,3,4])] +#[test] +/// The LU decomposition of the identity matrix should produce +/// the identity matrix with no permutations and parity 1 +fn test_lu_identity() { + // let a: Matrix = Matrix::::identity(); + let i = Matrix::::identity(); + let ones = Vector::::fill(S::one()); + let decomp = i.lu().expect("Singular matrix encountered"); + let (lu, idx, d) = decomp; + assert_eq!(lu, i, "Incorrect LU decomposition"); + assert!( + (0..M).eq(idx.elements().cloned()), + "Incorrect permutation matrix", + ); + assert_approx!(d, S::one(), "Incorrect permutation parity"); + + // Check determinant calculation which uses LU decomposition + assert_approx!( + i.det(), + S::one(), + "Identity matrix should have determinant of 1" + ); + + // Check inverse calculation with uses LU decomposition + assert_eq!( + i.inverse(), + Some(i), + "Identity matrix should be its own inverse" + ); + assert_eq!( + i.solve(&ones), + Some(ones), + "Failed to solve using identity matrix" + ); + + // Check triangle separation + assert_eq!(decomp.separate(), (i, i)); +} + +#[parameterize(S = (f32, f64), M = [2,3,4])] +#[test] +/// The LU decomposition of any singular matrix should be `None` +fn test_lu_singular() { + // let a: Matrix = Matrix::::identity(); + let mut a = Matrix::::zero(); + let ones = Vector::::fill(S::one()); + a.set_row(0, &ones); + + assert_eq!(a.lu(), None, "Matrix should be singular"); + assert_eq!( + a.det(), + S::zero(), + "Singular matrix should have determinant of zero" + ); + assert_eq!(a.inverse(), None, "Singular matrix should have no inverse"); + assert_eq!( + a.solve(&ones), + None, + "Singular matrix should not be solvable" + ) +} + +#[test] +fn test_lu_2x2() { + let a = Matrix::new([[1.0, 2.0], [3.0, 0.0]]); + let decomp = a.lu().expect("Singular matrix encountered"); + let (_lu, idx, _d) = decomp; + // the decomposition is non-unique, due to the combination of lu and idx. + // Instead of checking the exact value, we only check the results. + // Also check if they produce the same results with both methods, since the + // Matrix<> methods use shortcuts the decomposition methods don't + + let (l, u) = decomp.separate(); + assert_approx!(l.mmul(&u), a.permute_rows(&idx)); + + assert_approx!(a.det(), -6.0); + assert_approx!(a.det(), decomp.det()); + + assert_approx!( + a.inverse().unwrap(), + Matrix::new([[0.0, 2.0], [3.0, -1.0]]) * (1.0 / 6.0) + ); + assert_approx!(a.inverse().unwrap(), decomp.inverse()); + assert_approx!(a.inverse().unwrap().inverse().unwrap(), a) +} + +#[test] +fn test_lu_3x3() { + let a = Matrix::new([[1.0, -5.0, 8.0], [1.0, -2.0, 1.0], [2.0, -1.0, -4.0]]); + let decomp = a.lu().expect("Singular matrix encountered"); + let (_lu, idx, _d) = decomp; + + let (l, u) = decomp.separate(); + assert_approx!(l.mmul(&u), a.permute_rows(&idx)); + + assert_approx!(a.det(), 3.0); + assert_approx!(a.det(), decomp.det()); + + assert_approx!( + a.inverse().unwrap(), + Matrix::new([[9.0, -28.0, 11.0], [6.0, -20.0, 7.0], [3.0, -9.0, 3.0]]) * (1.0 / 3.0) + ); + assert_approx!(a.inverse().unwrap(), decomp.inverse()); + assert_approx!(a.inverse().unwrap().inverse().unwrap(), a) +} diff --git a/tests/ops.rs b/tests/ops.rs index 219864a..140c69a 100644 --- a/tests/ops.rs +++ b/tests/ops.rs @@ -1,3 +1,7 @@ +#[macro_use] +mod common; + +use common::Approx; use generic_parameterize::parameterize; use num_traits::real::Real; use num_traits::Zero; @@ -6,56 +10,7 @@ use std::iter::{zip, Product, Sum}; use std::ops; use vector_victor::{LUSolve, Matrix, Vector}; -macro_rules! scalar_eq { - ($left:expr, $right:expr $(,)?) => { - match (&$left, &$right) { - (_left_val, _right_val) => { - scalar_eq!($left, $right, "Difference is less than epsilon") - } - } - }; - ($left:expr, $right:expr, $($arg:tt)+) => { - match (&$left, &$right) { - (left_val, right_val) => { - let epsilon = f32::epsilon() as f64; - let lf : f64 = (*left_val).into(); - let rf : f64 = (*right_val).into(); - let diff : f64 = (lf - rf).abs(); - if diff >= epsilon { - assert_eq!(left_val, right_val, $($arg)+) // done this way to get nice errors - } - } - } - }; -} - -macro_rules! matrix_eq { - ($left:expr, $right:expr $(,)?) => { - match (&$left, &$right) { - (_left_val, _right_val) => { - matrix_eq!($left, $right, "Difference is less than epsilon") - } - } - }; - ($left:expr, $right:expr, $($arg:tt)+) => { - match (&$left, &$right) { - (left_val, right_val) => { - let epsilon = f32::epsilon() as f64; - for (l, r) in zip(left_val.elements(), right_val.elements()) { - let lf : f64 = (*l).into(); - let rf : f64 = (*r).into(); - let diff : f64 = (lf - rf).abs(); - if diff >= epsilon { - assert_eq!($left, $right, $($arg)+) // done this way to get nice errors - } - } - - } - } - }; -} - -#[parameterize(S = (i32, f32, u32), M = [1,4], N = [1,4])] +#[parameterize(S = (i32, f32, f64, u32), M = [1,4], N = [1,4])] #[test] fn test_add + PartialEq + Debug, const M: usize, const N: usize>() where @@ -64,85 +19,7 @@ where let a = Matrix::::fill(S::from(1)); let b = Matrix::::fill(S::from(3)); let c: Matrix = a + b; - for (i, ci) in c.elements().enumerate() { + for (_, ci) in c.elements().enumerate() { assert_eq!(*ci, S::from(4)); } } - -#[parameterize(S = (f32, f64), M = [1,2,3,4])] -#[test] -fn test_lu_identity, const M: usize>() { - // let a: Matrix = Matrix::::identity(); - let i = Matrix::::identity(); - let ones = Vector::::fill(S::one()); - let decomp = i.lu().expect("Singular matrix encountered"); - let (lu, idx, d) = decomp; - assert_eq!(lu, i, "Incorrect LU decomposition"); - assert!( - (0..M).eq(idx.elements().cloned()), - "Incorrect permutation matrix", - ); - scalar_eq!(d, S::one(), "Incorrect permutation parity"); - scalar_eq!(i.det(), S::one()); - assert_eq!(i.inverse(), Some(i)); - assert_eq!(i.solve(&ones), Some(ones)); - assert_eq!(decomp.separate(), (i, i)); -} - -#[parameterize(S = (f32, f64), M = [2,3,4])] -#[test] -fn test_lu_singular() { - // let a: Matrix = Matrix::::identity(); - let mut a = Matrix::::zero(); - let ones = Vector::::fill(S::one()); - a.set_row(0, &ones); - - assert_eq!(a.lu(), None, "Matrix should be singular"); - assert_eq!(a.det(), S::zero()); - assert_eq!(a.inverse(), None); - assert_eq!(a.solve(&ones), None) -} - -#[test] -fn test_lu_2x2() { - let a = Matrix::new([[1.0, 2.0], [3.0, 0.0]]); - let decomp = a.lu().expect("Singular matrix encountered"); - let (_lu, idx, _d) = decomp; - // the decomposition is non-unique, due to the combination of lu and idx. - // Instead of checking the exact value, we only check the results. - // Also check if they produce the same results with both methods, since the - // Matrix<> methods use shortcuts the decomposition methods don't - - let (l, u) = decomp.separate(); - matrix_eq!(l.mmul(&u), a.permute_rows(&idx)); - - scalar_eq!(a.det(), -6.0); - scalar_eq!(a.det(), decomp.det()); - - matrix_eq!( - a.inverse().unwrap(), - Matrix::new([[0.0, 2.0], [3.0, -1.0]]) * (1.0 / 6.0) - ); - matrix_eq!(a.inverse().unwrap(), decomp.inverse()); - matrix_eq!(a.inverse().unwrap().inverse().unwrap(), a) -} - -#[test] -fn test_lu_3x3() { - let a = Matrix::new([[1.0, -5.0, 8.0], [1.0, -2.0, 1.0], [2.0, -1.0, -4.0]]); - let decomp = a.lu().expect("Singular matrix encountered"); - let (_lu, idx, _d) = decomp; - - let (l, u) = decomp.separate(); - matrix_eq!(l.mmul(&u), a.permute_rows(&idx)); - - scalar_eq!(a.det(), 3.0); - scalar_eq!(a.det(), decomp.det()); - - matrix_eq!( - a.inverse().unwrap(), - Matrix::new([[9.0, -28.0, 11.0], [6.0, -20.0, 7.0], [3.0, -9.0, 3.0]]) * (1.0 / 3.0) - ); - matrix_eq!(a.inverse().unwrap(), decomp.inverse()); - matrix_eq!(a.inverse().unwrap().inverse().unwrap(), a) -}