use generic_parameterize::parameterize; use num_traits::real::Real; use num_traits::Zero; use std::fmt::Debug; use std::iter::{Product, Sum}; use std::ops; use vector_victor::{LUSolve, Matrix, Vector}; #[parameterize(S = (i32, f32, u32), M = [1,4], N = [1,4])] #[test] fn test_add + PartialEq + Debug, const M: usize, const N: usize>() where Matrix: ops::Add>, { let a = Matrix::::fill(S::from(1)); let b = Matrix::::fill(S::from(3)); let c: Matrix = a + b; for (i, ci) in c.elements().enumerate() { assert_eq!(*ci, S::from(4)); } } #[parameterize(S = (f32, f64), M = [1,2,3,4])] #[test] fn test_lu_identity() { // let a: Matrix = Matrix::::identity(); let i = Matrix::::identity(); let ones = Vector::::fill(S::one()); let decomp = i.lu().expect("Singular matrix encountered"); let (lu, idx, d) = decomp; assert_eq!(lu, i, "Incorrect LU decomposition"); assert!( (0..M).eq(idx.elements().cloned()), "Incorrect permutation matrix", ); assert_eq!(d, S::one(), "Incorrect permutation parity"); assert_eq!(i.det(), S::one()); assert_eq!(i.inverse(), Some(i)); assert_eq!(i.solve(&ones), Some(ones)); assert_eq!(decomp.separate(), (i, i)); } #[parameterize(S = (f32, f64), M = [2,3,4])] #[test] fn test_lu_singular() { // let a: Matrix = Matrix::::identity(); let mut a = Matrix::::zero(); let ones = Vector::::fill(S::one()); a.set_row(0, &ones); assert_eq!(a.lu(), None, "Matrix should be singular"); assert_eq!(a.det(), S::zero()); assert_eq!(a.inverse(), None); assert_eq!(a.solve(&ones), None) } #[test] fn test_lu_2x2() { let a = Matrix::new([[1.0, 2.0], [3.0, 0.0]]); let decomp = a.lu().expect("Singular matrix encountered"); let (lu, idx, d) = decomp; // the decomposition is non-unique, due to the combination of lu and idx. // Instead of checking the exact value, we only check the results. // Also check if they produce the same results with both methods, since the // Matrix<> methods use shortcuts the decomposition methods don't let (l, u) = decomp.separate(); assert_eq!(l.mmul(&u), a.permute_rows(&idx)); assert_eq!(a.det(), -6.0); assert_eq!(a.det(), decomp.det()); assert_eq!( a.inverse(), Some(Matrix::new([[0.0, 2.0], [3.0, -1.0]]) * (1.0 / 6.0)) ); assert_eq!(a.inverse(), Some(decomp.inverse())); assert_eq!(a.inverse().unwrap().inverse().unwrap(), a) }