3x3 matrix tests

This commit is contained in:
Andrew Cassidy 2022-12-04 20:24:57 -08:00
parent 1e8399eb41
commit 3989c5a8ec
2 changed files with 88 additions and 39 deletions

View File

@ -1,15 +1,13 @@
use crate::impl_matrix_op;
use crate::index::Index2D;
use crate::util::{checked_div, checked_inv};
use crate::util::checked_inv;
use num_traits::real::Real;
use num_traits::{Num, NumOps, One, Zero};
use std::fmt::Debug;
use std::iter::{zip, Flatten, Product, Sum};
use std::mem::swap;
use std::ops::{Add, AddAssign, Deref, DerefMut, Index, IndexMut, Mul, MulAssign, Neg};
use std::process::id;
/// A 2D array of values which can be operated upon.
///
@ -306,23 +304,6 @@ impl<T: Copy, const M: usize, const N: usize> Matrix<T, M, N> {
})
.collect()
}
// pub fn mmul<const P: usize, R, O>(&self, rhs: &Matrix<R, P, N>) -> Matrix<T, P, M>
// where
// R: Num,
// T: Scalar + Mul<R, Output = T>,
// Vector<T, N>: Dot<Vector<R, M>, Output = T>,
// {
// let mut result: Matrix<T, P, M> = Zero::zero();
//
// for (m, a) in self.rows().enumerate() {
// for (n, b) in rhs.cols().enumerate() {
// // result[(m, n)] = a.dot(b)
// }
// }
//
// return result;
// }
}
// 1D vector implementations
@ -366,11 +347,11 @@ impl<T: Copy> Vector<T, 3> {
])
}
pub fn cross_l<R: Copy>(&self, rhs: &Vector<R, 3>) -> Self
pub fn cross_l<R: Copy>(&self, rhs: &Vector<R, 3>) -> Vector<R, 3>
where
T: NumOps<R> + NumOps + Neg<Output = T>,
R: NumOps<T> + NumOps,
{
-self.cross_r(rhs)
rhs.cross_r(self)
}
}
@ -583,12 +564,11 @@ where
Matrix::from_cols(bp.cols().map(|mut x| {
// Implementation from Numerical Recipes §2.3
// When ii is set to a positive value,
// it will become the index of the first nonvanishing element of b
let mut ii = 0usize;
for i in 0..N {
// forward substitution
// forward substitution using L
let mut sum = x[i];
if ii != 0 {
for j in (ii - 1)..i {
@ -600,7 +580,7 @@ where
x[i] = sum;
}
for i in (0..N).rev() {
// back substitution
// back substitution using U
let mut sum = x[i];
for j in (i + 1)..N {
sum = sum - (lu[(i, j)] * x[j]);

View File

@ -2,10 +2,59 @@ use generic_parameterize::parameterize;
use num_traits::real::Real;
use num_traits::Zero;
use std::fmt::Debug;
use std::iter::{Product, Sum};
use std::iter::{zip, Product, Sum};
use std::ops;
use vector_victor::{LUSolve, Matrix, Vector};
macro_rules! scalar_eq {
($left:expr, $right:expr $(,)?) => {
match (&$left, &$right) {
(_left_val, _right_val) => {
scalar_eq!($left, $right, "Difference is less than epsilon")
}
}
};
($left:expr, $right:expr, $($arg:tt)+) => {
match (&$left, &$right) {
(left_val, right_val) => {
let epsilon = f32::epsilon() as f64;
let lf : f64 = (*left_val).into();
let rf : f64 = (*right_val).into();
let diff : f64 = (lf - rf).abs();
if diff >= epsilon {
assert_eq!(left_val, right_val, $($arg)+) // done this way to get nice errors
}
}
}
};
}
macro_rules! matrix_eq {
($left:expr, $right:expr $(,)?) => {
match (&$left, &$right) {
(_left_val, _right_val) => {
matrix_eq!($left, $right, "Difference is less than epsilon")
}
}
};
($left:expr, $right:expr, $($arg:tt)+) => {
match (&$left, &$right) {
(left_val, right_val) => {
let epsilon = f32::epsilon() as f64;
for (l, r) in zip(left_val.elements(), right_val.elements()) {
let lf : f64 = (*l).into();
let rf : f64 = (*r).into();
let diff : f64 = (lf - rf).abs();
if diff >= epsilon {
assert_eq!($left, $right, $($arg)+) // done this way to get nice errors
}
}
}
}
};
}
#[parameterize(S = (i32, f32, u32), M = [1,4], N = [1,4])]
#[test]
fn test_add<S: Copy + From<u16> + PartialEq + Debug, const M: usize, const N: usize>()
@ -22,7 +71,7 @@ where
#[parameterize(S = (f32, f64), M = [1,2,3,4])]
#[test]
fn test_lu_identity<S: Default + Real + Debug + Product + Sum, const M: usize>() {
fn test_lu_identity<S: Default + Real + Debug + Product + Sum + Into<f64>, const M: usize>() {
// let a: Matrix<f32, 3, 3> = Matrix::<f32, 3, 3>::identity();
let i = Matrix::<S, M, M>::identity();
let ones = Vector::<S, M>::fill(S::one());
@ -33,8 +82,8 @@ fn test_lu_identity<S: Default + Real + Debug + Product + Sum, const M: usize>()
(0..M).eq(idx.elements().cloned()),
"Incorrect permutation matrix",
);
assert_eq!(d, S::one(), "Incorrect permutation parity");
assert_eq!(i.det(), S::one());
scalar_eq!(d, S::one(), "Incorrect permutation parity");
scalar_eq!(i.det(), S::one());
assert_eq!(i.inverse(), Some(i));
assert_eq!(i.solve(&ones), Some(ones));
assert_eq!(decomp.separate(), (i, i));
@ -58,22 +107,42 @@ fn test_lu_singular<S: Default + Real + Debug + Product + Sum, const M: usize>()
fn test_lu_2x2() {
let a = Matrix::new([[1.0, 2.0], [3.0, 0.0]]);
let decomp = a.lu().expect("Singular matrix encountered");
let (lu, idx, d) = decomp;
let (_lu, idx, _d) = decomp;
// the decomposition is non-unique, due to the combination of lu and idx.
// Instead of checking the exact value, we only check the results.
// Also check if they produce the same results with both methods, since the
// Matrix<> methods use shortcuts the decomposition methods don't
let (l, u) = decomp.separate();
assert_eq!(l.mmul(&u), a.permute_rows(&idx));
matrix_eq!(l.mmul(&u), a.permute_rows(&idx));
assert_eq!(a.det(), -6.0);
assert_eq!(a.det(), decomp.det());
scalar_eq!(a.det(), -6.0);
scalar_eq!(a.det(), decomp.det());
assert_eq!(
a.inverse(),
Some(Matrix::new([[0.0, 2.0], [3.0, -1.0]]) * (1.0 / 6.0))
matrix_eq!(
a.inverse().unwrap(),
Matrix::new([[0.0, 2.0], [3.0, -1.0]]) * (1.0 / 6.0)
);
assert_eq!(a.inverse(), Some(decomp.inverse()));
assert_eq!(a.inverse().unwrap().inverse().unwrap(), a)
matrix_eq!(a.inverse().unwrap(), decomp.inverse());
matrix_eq!(a.inverse().unwrap().inverse().unwrap(), a)
}
#[test]
fn test_lu_3x3() {
let a = Matrix::new([[1.0, -5.0, 8.0], [1.0, -2.0, 1.0], [2.0, -1.0, -4.0]]);
let decomp = a.lu().expect("Singular matrix encountered");
let (_lu, idx, _d) = decomp;
let (l, u) = decomp.separate();
matrix_eq!(l.mmul(&u), a.permute_rows(&idx));
scalar_eq!(a.det(), 3.0);
scalar_eq!(a.det(), decomp.det());
matrix_eq!(
a.inverse().unwrap(),
Matrix::new([[9.0, -28.0, 11.0], [6.0, -20.0, 7.0], [3.0, -9.0, 3.0]]) * (1.0 / 3.0)
);
matrix_eq!(a.inverse().unwrap(), decomp.inverse());
matrix_eq!(a.inverse().unwrap().inverse().unwrap(), a)
}