diff --git a/src/matrix.rs b/src/matrix.rs index cf87f8a..0835623 100644 --- a/src/matrix.rs +++ b/src/matrix.rs @@ -1,15 +1,13 @@ use crate::impl_matrix_op; use crate::index::Index2D; -use crate::util::{checked_div, checked_inv}; +use crate::util::checked_inv; use num_traits::real::Real; use num_traits::{Num, NumOps, One, Zero}; use std::fmt::Debug; use std::iter::{zip, Flatten, Product, Sum}; -use std::mem::swap; use std::ops::{Add, AddAssign, Deref, DerefMut, Index, IndexMut, Mul, MulAssign, Neg}; -use std::process::id; /// A 2D array of values which can be operated upon. /// @@ -306,23 +304,6 @@ impl Matrix { }) .collect() } - - // pub fn mmul(&self, rhs: &Matrix) -> Matrix - // where - // R: Num, - // T: Scalar + Mul, - // Vector: Dot, Output = T>, - // { - // let mut result: Matrix = Zero::zero(); - // - // for (m, a) in self.rows().enumerate() { - // for (n, b) in rhs.cols().enumerate() { - // // result[(m, n)] = a.dot(b) - // } - // } - // - // return result; - // } } // 1D vector implementations @@ -366,11 +347,11 @@ impl Vector { ]) } - pub fn cross_l(&self, rhs: &Vector) -> Self + pub fn cross_l(&self, rhs: &Vector) -> Vector where - T: NumOps + NumOps + Neg, + R: NumOps + NumOps, { - -self.cross_r(rhs) + rhs.cross_r(self) } } @@ -583,12 +564,11 @@ where Matrix::from_cols(bp.cols().map(|mut x| { // Implementation from Numerical Recipes ยง2.3 - // When ii is set to a positive value, // it will become the index of the first nonvanishing element of b let mut ii = 0usize; for i in 0..N { - // forward substitution + // forward substitution using L let mut sum = x[i]; if ii != 0 { for j in (ii - 1)..i { @@ -600,7 +580,7 @@ where x[i] = sum; } for i in (0..N).rev() { - // back substitution + // back substitution using U let mut sum = x[i]; for j in (i + 1)..N { sum = sum - (lu[(i, j)] * x[j]); diff --git a/tests/ops.rs b/tests/ops.rs index 89dc2fc..219864a 100644 --- a/tests/ops.rs +++ b/tests/ops.rs @@ -2,10 +2,59 @@ use generic_parameterize::parameterize; use num_traits::real::Real; use num_traits::Zero; use std::fmt::Debug; -use std::iter::{Product, Sum}; +use std::iter::{zip, Product, Sum}; use std::ops; use vector_victor::{LUSolve, Matrix, Vector}; +macro_rules! scalar_eq { + ($left:expr, $right:expr $(,)?) => { + match (&$left, &$right) { + (_left_val, _right_val) => { + scalar_eq!($left, $right, "Difference is less than epsilon") + } + } + }; + ($left:expr, $right:expr, $($arg:tt)+) => { + match (&$left, &$right) { + (left_val, right_val) => { + let epsilon = f32::epsilon() as f64; + let lf : f64 = (*left_val).into(); + let rf : f64 = (*right_val).into(); + let diff : f64 = (lf - rf).abs(); + if diff >= epsilon { + assert_eq!(left_val, right_val, $($arg)+) // done this way to get nice errors + } + } + } + }; +} + +macro_rules! matrix_eq { + ($left:expr, $right:expr $(,)?) => { + match (&$left, &$right) { + (_left_val, _right_val) => { + matrix_eq!($left, $right, "Difference is less than epsilon") + } + } + }; + ($left:expr, $right:expr, $($arg:tt)+) => { + match (&$left, &$right) { + (left_val, right_val) => { + let epsilon = f32::epsilon() as f64; + for (l, r) in zip(left_val.elements(), right_val.elements()) { + let lf : f64 = (*l).into(); + let rf : f64 = (*r).into(); + let diff : f64 = (lf - rf).abs(); + if diff >= epsilon { + assert_eq!($left, $right, $($arg)+) // done this way to get nice errors + } + } + + } + } + }; +} + #[parameterize(S = (i32, f32, u32), M = [1,4], N = [1,4])] #[test] fn test_add + PartialEq + Debug, const M: usize, const N: usize>() @@ -22,7 +71,7 @@ where #[parameterize(S = (f32, f64), M = [1,2,3,4])] #[test] -fn test_lu_identity() { +fn test_lu_identity, const M: usize>() { // let a: Matrix = Matrix::::identity(); let i = Matrix::::identity(); let ones = Vector::::fill(S::one()); @@ -33,8 +82,8 @@ fn test_lu_identity() (0..M).eq(idx.elements().cloned()), "Incorrect permutation matrix", ); - assert_eq!(d, S::one(), "Incorrect permutation parity"); - assert_eq!(i.det(), S::one()); + scalar_eq!(d, S::one(), "Incorrect permutation parity"); + scalar_eq!(i.det(), S::one()); assert_eq!(i.inverse(), Some(i)); assert_eq!(i.solve(&ones), Some(ones)); assert_eq!(decomp.separate(), (i, i)); @@ -58,22 +107,42 @@ fn test_lu_singular() fn test_lu_2x2() { let a = Matrix::new([[1.0, 2.0], [3.0, 0.0]]); let decomp = a.lu().expect("Singular matrix encountered"); - let (lu, idx, d) = decomp; + let (_lu, idx, _d) = decomp; // the decomposition is non-unique, due to the combination of lu and idx. // Instead of checking the exact value, we only check the results. // Also check if they produce the same results with both methods, since the // Matrix<> methods use shortcuts the decomposition methods don't let (l, u) = decomp.separate(); - assert_eq!(l.mmul(&u), a.permute_rows(&idx)); + matrix_eq!(l.mmul(&u), a.permute_rows(&idx)); + + scalar_eq!(a.det(), -6.0); + scalar_eq!(a.det(), decomp.det()); + + matrix_eq!( + a.inverse().unwrap(), + Matrix::new([[0.0, 2.0], [3.0, -1.0]]) * (1.0 / 6.0) + ); + matrix_eq!(a.inverse().unwrap(), decomp.inverse()); + matrix_eq!(a.inverse().unwrap().inverse().unwrap(), a) +} + +#[test] +fn test_lu_3x3() { + let a = Matrix::new([[1.0, -5.0, 8.0], [1.0, -2.0, 1.0], [2.0, -1.0, -4.0]]); + let decomp = a.lu().expect("Singular matrix encountered"); + let (_lu, idx, _d) = decomp; + + let (l, u) = decomp.separate(); + matrix_eq!(l.mmul(&u), a.permute_rows(&idx)); - assert_eq!(a.det(), -6.0); - assert_eq!(a.det(), decomp.det()); + scalar_eq!(a.det(), 3.0); + scalar_eq!(a.det(), decomp.det()); - assert_eq!( - a.inverse(), - Some(Matrix::new([[0.0, 2.0], [3.0, -1.0]]) * (1.0 / 6.0)) + matrix_eq!( + a.inverse().unwrap(), + Matrix::new([[9.0, -28.0, 11.0], [6.0, -20.0, 7.0], [3.0, -9.0, 3.0]]) * (1.0 / 3.0) ); - assert_eq!(a.inverse(), Some(decomp.inverse())); - assert_eq!(a.inverse().unwrap().inverse().unwrap(), a) + matrix_eq!(a.inverse().unwrap(), decomp.inverse()); + matrix_eq!(a.inverse().unwrap().inverse().unwrap(), a) }