From 57636dc8dd5dfdfe98bf2fafc54218488a51a461 Mon Sep 17 00:00:00 2001 From: Andrew Cassidy Date: Sun, 27 Nov 2022 21:20:15 -0800 Subject: [PATCH] Fix Solve and add unit tests on an identity matrix --- src/matrix.rs | 52 +++++++++++++-------------------------------------- tests/ops.rs | 51 ++++++++++++++++++++++++++++++++++++++++++-------- 2 files changed, 56 insertions(+), 47 deletions(-) diff --git a/src/matrix.rs b/src/matrix.rs index f546353..20b149e 100644 --- a/src/matrix.rs +++ b/src/matrix.rs @@ -2,11 +2,11 @@ use crate::impl_matrix_op; use crate::index::Index2D; use num_traits::real::Real; -use num_traits::{Num, NumOps, One, Signed, Zero}; +use num_traits::{Num, NumOps, One, Zero}; use std::fmt::Debug; use std::iter::{zip, Flatten, Product, Sum}; -use std::ops::{Add, AddAssign, Deref, DerefMut, Index, IndexMut, Mul, MulAssign, Neg, Sub}; +use std::ops::{Add, AddAssign, Deref, DerefMut, Index, IndexMut, Mul, MulAssign, Neg}; /// A 2D array of values which can be operated upon. /// @@ -23,38 +23,6 @@ where /// An alias for a [Matrix] with a single column pub type Vector = Matrix; -pub trait MatrixLike { - type Scalar: Copy; - const WIDTH: usize; - const HEIGHT: usize; -} -impl MatrixLike for Matrix { - type Scalar = T; - const WIDTH: usize = N; - const HEIGHT: usize = M; -} - -pub trait SquareMatrix {} -impl SquareMatrix for Matrix {} - -// pub trait Dot: MatrixLike { -// #[must_use] -// fn dot(&self, rhs: &R) -> ::Scalar; -// } -// -// pub trait Cross { -// #[must_use] -// fn cross_r(&self, rhs: &R) -> Self; -// #[must_use] -// fn cross_l(&self, rhs: &R) -> Self; -// } -// -// pub trait MMul { -// type Output; -// #[must_use] -// fn mmul(&self, rhs: &R) -> Self::Output; -// } - // Simple access functions that only require T be copyable impl Matrix { /// Generate a new matrix from a 2D Array @@ -73,6 +41,8 @@ impl Matrix { /// ``` #[must_use] pub fn new(data: [[T; N]; M]) -> Self { + assert!(M > 0, "Matrix must have at least 1 row"); + assert!(N > 0, "Matrix must have at least 1 column"); Matrix:: { data } } @@ -94,6 +64,8 @@ impl Matrix { /// ``` #[must_use] pub fn fill(scalar: T) -> Matrix { + assert!(M > 0, "Matrix must have at least 1 row"); + assert!(N > 0, "Matrix must have at least 1 column"); Matrix:: { data: [[scalar; N]; M], } @@ -329,7 +301,7 @@ impl Matrix { } // 1D vector implementations -impl Vector { +impl Vector { /// Create a vector from a 1D array. /// Note that vectors are always column vectors unless explicitly instantiated as row vectors /// @@ -347,8 +319,9 @@ impl Vector { /// // is equivalent to /// assert_eq!(my_vector, Matrix::new([[1],[2],[3],[4]])); /// ``` - pub fn vec(data: [T; M]) -> Self { - return Vector:: { + pub fn vec(data: [T; N]) -> Self { + assert!(N > 0, "Vector must have at least 1 element"); + return Vector:: { data: data.map(|e| [e]), }; } @@ -508,7 +481,7 @@ impl Matrix { T: Real + Default + Product + Sum, { match N { - 0 => T::zero(), + 0 => T::one(), 1 => self[0], 2 => (self[(0, 0)] * self[(1, 1)]) - (self[(0, 1)] * self[(1, 0)]), 3 => { @@ -572,6 +545,7 @@ where x[i] = sum - (lu.row(i).expect("Invalid row reached") * x) .elements() + .take(i) .skip(ii - 1) .cloned() .sum() @@ -587,7 +561,7 @@ where let sum = x[i] - (lu.row(i).expect("Invalid row reached") * x) .elements() - .skip(ii - 1) + .skip(i + 1) .cloned() .sum(); diff --git a/tests/ops.rs b/tests/ops.rs index e41a05b..0fc4908 100644 --- a/tests/ops.rs +++ b/tests/ops.rs @@ -1,10 +1,9 @@ use generic_parameterize::parameterize; -use std::convert::identity; +use num_traits::real::Real; use std::fmt::Debug; +use std::iter::{Product, Sum}; use std::ops; -use std::thread::sleep; -use std::time::Duration; -use vector_victor::Matrix; +use vector_victor::{Matrix, Vector}; #[parameterize(S = (i32, f32, u32), M = [1,4], N = [1,4])] #[test] @@ -20,10 +19,46 @@ where } } +#[parameterize(S = (f32, f64), M = [1,2,3,4])] #[test] -fn test_lu() { +fn test_lu_identity() { // let a: Matrix = Matrix::::identity(); - let a = Matrix::new([[1.0, 2.0], [3.0, 4.0]]); - let (lu, _idx, _d) = a.lu().expect("What"); - println!("{:?}", lu); + let i = Matrix::::identity(); + let ones = Vector::::fill(S::one()); + let (lu, idx, d) = i.lu().expect("Singular matrix encountered"); + assert_eq!( + lu, + i, + "Incorrect LU decomposition matrix for {m}x{m} identity matrix", + m = M + ); + assert!( + (0..M).eq(idx.elements().cloned()), + "Incorrect permutation matrix result for {m}x{m} identity matrix", + m = M + ); + assert_eq!( + d, + S::one(), + "Incorrect permutation parity for {m}x{m} identity matrix", + m = M + ); + assert_eq!( + i.det(), + S::one(), + "Incorrect determinant for {m}x{m} identity matrix", + m = M + ); + assert_eq!( + i.inverse(), + Some(i), + "Incorrect inverse for {m}x{m} identity matrix", + m = M + ); + assert_eq!( + i.solve(&ones), + Some(ones), + "Incorrect solve result for {m}x{m} identity matrix", + m = M + ) }