mirror of
https://github.com/drewcassidy/vector-victor.git
synced 2024-09-01 14:58:35 +00:00
Seperate LU tests into their own file
This commit is contained in:
parent
df3c2b4ba9
commit
8fcb032b1a
57
tests/common/mod.rs
Normal file
57
tests/common/mod.rs
Normal file
@ -0,0 +1,57 @@
|
||||
use num_traits::Float;
|
||||
use std::iter::zip;
|
||||
use vector_victor::Matrix;
|
||||
|
||||
pub trait Approx: PartialEq {
|
||||
fn approx(left: &Self, right: &Self) -> bool {
|
||||
left == right
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! multi_impl { ($name:ident for $($t:ty),*) => ($( impl $name for $t {} )*) }
|
||||
multi_impl!(Approx for i8, i16, i32, i64, i128, isize, u8, u16, u32, u64, u128, usize);
|
||||
|
||||
impl Approx for f32 {
|
||||
fn approx(left: &f32, right: &f32) -> bool {
|
||||
f32::abs(left - right) <= f32::epsilon()
|
||||
}
|
||||
}
|
||||
|
||||
impl Approx for f64 {
|
||||
fn approx(left: &f64, right: &f64) -> bool {
|
||||
f64::abs(left - right) <= f32::epsilon() as f64
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Copy + Approx, const M: usize, const N: usize> Approx for Matrix<T, M, N> {
|
||||
fn approx(left: &Self, right: &Self) -> bool {
|
||||
zip(left.elements(), right.elements()).all(|(l, r)| T::approx(l, r))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn approx<T: Approx>(left: &T, right: &T) -> bool {
|
||||
T::approx(left, right)
|
||||
}
|
||||
|
||||
macro_rules! assert_approx {
|
||||
($left:expr, $right:expr $(,)?) => {
|
||||
match (&$left, &$right) {
|
||||
(_left_val, _right_val) => {
|
||||
assert_approx!($left, $right, "Difference is less than epsilon")
|
||||
}
|
||||
}
|
||||
};
|
||||
($left:expr, $right:expr, $($arg:tt)+) => {
|
||||
match (&$left, &$right) {
|
||||
(left_val, right_val) => {
|
||||
pub fn approx<T: Approx>(left: &T, right: &T) -> bool {
|
||||
T::approx(left, right)
|
||||
}
|
||||
|
||||
if !approx(left_val, right_val){
|
||||
assert_eq!(left_val, right_val, $($arg)+) // done this way to get nice errors
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
117
tests/decomposition.rs
Normal file
117
tests/decomposition.rs
Normal file
@ -0,0 +1,117 @@
|
||||
#[macro_use]
|
||||
mod common;
|
||||
|
||||
use common::Approx;
|
||||
use generic_parameterize::parameterize;
|
||||
use num_traits::real::Real;
|
||||
use num_traits::Zero;
|
||||
use std::fmt::Debug;
|
||||
use std::iter::{zip, Product, Sum};
|
||||
use vector_victor::{LUSolve, Matrix, Vector};
|
||||
|
||||
#[parameterize(S = (f32, f64), M = [1,2,3,4])]
|
||||
#[test]
|
||||
/// The LU decomposition of the identity matrix should produce
|
||||
/// the identity matrix with no permutations and parity 1
|
||||
fn test_lu_identity<S: Default + Approx + Real + Debug + Product + Sum, const M: usize>() {
|
||||
// let a: Matrix<f32, 3, 3> = Matrix::<f32, 3, 3>::identity();
|
||||
let i = Matrix::<S, M, M>::identity();
|
||||
let ones = Vector::<S, M>::fill(S::one());
|
||||
let decomp = i.lu().expect("Singular matrix encountered");
|
||||
let (lu, idx, d) = decomp;
|
||||
assert_eq!(lu, i, "Incorrect LU decomposition");
|
||||
assert!(
|
||||
(0..M).eq(idx.elements().cloned()),
|
||||
"Incorrect permutation matrix",
|
||||
);
|
||||
assert_approx!(d, S::one(), "Incorrect permutation parity");
|
||||
|
||||
// Check determinant calculation which uses LU decomposition
|
||||
assert_approx!(
|
||||
i.det(),
|
||||
S::one(),
|
||||
"Identity matrix should have determinant of 1"
|
||||
);
|
||||
|
||||
// Check inverse calculation with uses LU decomposition
|
||||
assert_eq!(
|
||||
i.inverse(),
|
||||
Some(i),
|
||||
"Identity matrix should be its own inverse"
|
||||
);
|
||||
assert_eq!(
|
||||
i.solve(&ones),
|
||||
Some(ones),
|
||||
"Failed to solve using identity matrix"
|
||||
);
|
||||
|
||||
// Check triangle separation
|
||||
assert_eq!(decomp.separate(), (i, i));
|
||||
}
|
||||
|
||||
#[parameterize(S = (f32, f64), M = [2,3,4])]
|
||||
#[test]
|
||||
/// The LU decomposition of any singular matrix should be `None`
|
||||
fn test_lu_singular<S: Default + Real + Debug + Product + Sum, const M: usize>() {
|
||||
// let a: Matrix<f32, 3, 3> = Matrix::<f32, 3, 3>::identity();
|
||||
let mut a = Matrix::<S, M, M>::zero();
|
||||
let ones = Vector::<S, M>::fill(S::one());
|
||||
a.set_row(0, &ones);
|
||||
|
||||
assert_eq!(a.lu(), None, "Matrix should be singular");
|
||||
assert_eq!(
|
||||
a.det(),
|
||||
S::zero(),
|
||||
"Singular matrix should have determinant of zero"
|
||||
);
|
||||
assert_eq!(a.inverse(), None, "Singular matrix should have no inverse");
|
||||
assert_eq!(
|
||||
a.solve(&ones),
|
||||
None,
|
||||
"Singular matrix should not be solvable"
|
||||
)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_lu_2x2() {
|
||||
let a = Matrix::new([[1.0, 2.0], [3.0, 0.0]]);
|
||||
let decomp = a.lu().expect("Singular matrix encountered");
|
||||
let (_lu, idx, _d) = decomp;
|
||||
// the decomposition is non-unique, due to the combination of lu and idx.
|
||||
// Instead of checking the exact value, we only check the results.
|
||||
// Also check if they produce the same results with both methods, since the
|
||||
// Matrix<> methods use shortcuts the decomposition methods don't
|
||||
|
||||
let (l, u) = decomp.separate();
|
||||
assert_approx!(l.mmul(&u), a.permute_rows(&idx));
|
||||
|
||||
assert_approx!(a.det(), -6.0);
|
||||
assert_approx!(a.det(), decomp.det());
|
||||
|
||||
assert_approx!(
|
||||
a.inverse().unwrap(),
|
||||
Matrix::new([[0.0, 2.0], [3.0, -1.0]]) * (1.0 / 6.0)
|
||||
);
|
||||
assert_approx!(a.inverse().unwrap(), decomp.inverse());
|
||||
assert_approx!(a.inverse().unwrap().inverse().unwrap(), a)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_lu_3x3() {
|
||||
let a = Matrix::new([[1.0, -5.0, 8.0], [1.0, -2.0, 1.0], [2.0, -1.0, -4.0]]);
|
||||
let decomp = a.lu().expect("Singular matrix encountered");
|
||||
let (_lu, idx, _d) = decomp;
|
||||
|
||||
let (l, u) = decomp.separate();
|
||||
assert_approx!(l.mmul(&u), a.permute_rows(&idx));
|
||||
|
||||
assert_approx!(a.det(), 3.0);
|
||||
assert_approx!(a.det(), decomp.det());
|
||||
|
||||
assert_approx!(
|
||||
a.inverse().unwrap(),
|
||||
Matrix::new([[9.0, -28.0, 11.0], [6.0, -20.0, 7.0], [3.0, -9.0, 3.0]]) * (1.0 / 3.0)
|
||||
);
|
||||
assert_approx!(a.inverse().unwrap(), decomp.inverse());
|
||||
assert_approx!(a.inverse().unwrap().inverse().unwrap(), a)
|
||||
}
|
135
tests/ops.rs
135
tests/ops.rs
@ -1,3 +1,7 @@
|
||||
#[macro_use]
|
||||
mod common;
|
||||
|
||||
use common::Approx;
|
||||
use generic_parameterize::parameterize;
|
||||
use num_traits::real::Real;
|
||||
use num_traits::Zero;
|
||||
@ -6,56 +10,7 @@ use std::iter::{zip, Product, Sum};
|
||||
use std::ops;
|
||||
use vector_victor::{LUSolve, Matrix, Vector};
|
||||
|
||||
macro_rules! scalar_eq {
|
||||
($left:expr, $right:expr $(,)?) => {
|
||||
match (&$left, &$right) {
|
||||
(_left_val, _right_val) => {
|
||||
scalar_eq!($left, $right, "Difference is less than epsilon")
|
||||
}
|
||||
}
|
||||
};
|
||||
($left:expr, $right:expr, $($arg:tt)+) => {
|
||||
match (&$left, &$right) {
|
||||
(left_val, right_val) => {
|
||||
let epsilon = f32::epsilon() as f64;
|
||||
let lf : f64 = (*left_val).into();
|
||||
let rf : f64 = (*right_val).into();
|
||||
let diff : f64 = (lf - rf).abs();
|
||||
if diff >= epsilon {
|
||||
assert_eq!(left_val, right_val, $($arg)+) // done this way to get nice errors
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
macro_rules! matrix_eq {
|
||||
($left:expr, $right:expr $(,)?) => {
|
||||
match (&$left, &$right) {
|
||||
(_left_val, _right_val) => {
|
||||
matrix_eq!($left, $right, "Difference is less than epsilon")
|
||||
}
|
||||
}
|
||||
};
|
||||
($left:expr, $right:expr, $($arg:tt)+) => {
|
||||
match (&$left, &$right) {
|
||||
(left_val, right_val) => {
|
||||
let epsilon = f32::epsilon() as f64;
|
||||
for (l, r) in zip(left_val.elements(), right_val.elements()) {
|
||||
let lf : f64 = (*l).into();
|
||||
let rf : f64 = (*r).into();
|
||||
let diff : f64 = (lf - rf).abs();
|
||||
if diff >= epsilon {
|
||||
assert_eq!($left, $right, $($arg)+) // done this way to get nice errors
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#[parameterize(S = (i32, f32, u32), M = [1,4], N = [1,4])]
|
||||
#[parameterize(S = (i32, f32, f64, u32), M = [1,4], N = [1,4])]
|
||||
#[test]
|
||||
fn test_add<S: Copy + From<u16> + PartialEq + Debug, const M: usize, const N: usize>()
|
||||
where
|
||||
@ -64,85 +19,7 @@ where
|
||||
let a = Matrix::<S, M, N>::fill(S::from(1));
|
||||
let b = Matrix::<S, M, N>::fill(S::from(3));
|
||||
let c: Matrix<S, M, N> = a + b;
|
||||
for (i, ci) in c.elements().enumerate() {
|
||||
for (_, ci) in c.elements().enumerate() {
|
||||
assert_eq!(*ci, S::from(4));
|
||||
}
|
||||
}
|
||||
|
||||
#[parameterize(S = (f32, f64), M = [1,2,3,4])]
|
||||
#[test]
|
||||
fn test_lu_identity<S: Default + Real + Debug + Product + Sum + Into<f64>, const M: usize>() {
|
||||
// let a: Matrix<f32, 3, 3> = Matrix::<f32, 3, 3>::identity();
|
||||
let i = Matrix::<S, M, M>::identity();
|
||||
let ones = Vector::<S, M>::fill(S::one());
|
||||
let decomp = i.lu().expect("Singular matrix encountered");
|
||||
let (lu, idx, d) = decomp;
|
||||
assert_eq!(lu, i, "Incorrect LU decomposition");
|
||||
assert!(
|
||||
(0..M).eq(idx.elements().cloned()),
|
||||
"Incorrect permutation matrix",
|
||||
);
|
||||
scalar_eq!(d, S::one(), "Incorrect permutation parity");
|
||||
scalar_eq!(i.det(), S::one());
|
||||
assert_eq!(i.inverse(), Some(i));
|
||||
assert_eq!(i.solve(&ones), Some(ones));
|
||||
assert_eq!(decomp.separate(), (i, i));
|
||||
}
|
||||
|
||||
#[parameterize(S = (f32, f64), M = [2,3,4])]
|
||||
#[test]
|
||||
fn test_lu_singular<S: Default + Real + Debug + Product + Sum, const M: usize>() {
|
||||
// let a: Matrix<f32, 3, 3> = Matrix::<f32, 3, 3>::identity();
|
||||
let mut a = Matrix::<S, M, M>::zero();
|
||||
let ones = Vector::<S, M>::fill(S::one());
|
||||
a.set_row(0, &ones);
|
||||
|
||||
assert_eq!(a.lu(), None, "Matrix should be singular");
|
||||
assert_eq!(a.det(), S::zero());
|
||||
assert_eq!(a.inverse(), None);
|
||||
assert_eq!(a.solve(&ones), None)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_lu_2x2() {
|
||||
let a = Matrix::new([[1.0, 2.0], [3.0, 0.0]]);
|
||||
let decomp = a.lu().expect("Singular matrix encountered");
|
||||
let (_lu, idx, _d) = decomp;
|
||||
// the decomposition is non-unique, due to the combination of lu and idx.
|
||||
// Instead of checking the exact value, we only check the results.
|
||||
// Also check if they produce the same results with both methods, since the
|
||||
// Matrix<> methods use shortcuts the decomposition methods don't
|
||||
|
||||
let (l, u) = decomp.separate();
|
||||
matrix_eq!(l.mmul(&u), a.permute_rows(&idx));
|
||||
|
||||
scalar_eq!(a.det(), -6.0);
|
||||
scalar_eq!(a.det(), decomp.det());
|
||||
|
||||
matrix_eq!(
|
||||
a.inverse().unwrap(),
|
||||
Matrix::new([[0.0, 2.0], [3.0, -1.0]]) * (1.0 / 6.0)
|
||||
);
|
||||
matrix_eq!(a.inverse().unwrap(), decomp.inverse());
|
||||
matrix_eq!(a.inverse().unwrap().inverse().unwrap(), a)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_lu_3x3() {
|
||||
let a = Matrix::new([[1.0, -5.0, 8.0], [1.0, -2.0, 1.0], [2.0, -1.0, -4.0]]);
|
||||
let decomp = a.lu().expect("Singular matrix encountered");
|
||||
let (_lu, idx, _d) = decomp;
|
||||
|
||||
let (l, u) = decomp.separate();
|
||||
matrix_eq!(l.mmul(&u), a.permute_rows(&idx));
|
||||
|
||||
scalar_eq!(a.det(), 3.0);
|
||||
scalar_eq!(a.det(), decomp.det());
|
||||
|
||||
matrix_eq!(
|
||||
a.inverse().unwrap(),
|
||||
Matrix::new([[9.0, -28.0, 11.0], [6.0, -20.0, 7.0], [3.0, -9.0, 3.0]]) * (1.0 / 3.0)
|
||||
);
|
||||
matrix_eq!(a.inverse().unwrap(), decomp.inverse());
|
||||
matrix_eq!(a.inverse().unwrap().inverse().unwrap(), a)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user