Fix Solve and add unit tests on an identity matrix

This commit is contained in:
Andrew Cassidy 2022-11-27 21:20:15 -08:00
parent 2ba8d9d323
commit 57636dc8dd
2 changed files with 56 additions and 47 deletions

View File

@ -2,11 +2,11 @@ use crate::impl_matrix_op;
use crate::index::Index2D;
use num_traits::real::Real;
use num_traits::{Num, NumOps, One, Signed, Zero};
use num_traits::{Num, NumOps, One, Zero};
use std::fmt::Debug;
use std::iter::{zip, Flatten, Product, Sum};
use std::ops::{Add, AddAssign, Deref, DerefMut, Index, IndexMut, Mul, MulAssign, Neg, Sub};
use std::ops::{Add, AddAssign, Deref, DerefMut, Index, IndexMut, Mul, MulAssign, Neg};
/// A 2D array of values which can be operated upon.
///
@ -23,38 +23,6 @@ where
/// An alias for a [Matrix] with a single column
pub type Vector<T, const N: usize> = Matrix<T, N, 1>;
pub trait MatrixLike {
type Scalar: Copy;
const WIDTH: usize;
const HEIGHT: usize;
}
impl<T: Copy, const M: usize, const N: usize> MatrixLike for Matrix<T, M, N> {
type Scalar = T;
const WIDTH: usize = N;
const HEIGHT: usize = M;
}
pub trait SquareMatrix {}
impl<T: Copy, const M: usize> SquareMatrix for Matrix<T, M, M> {}
// pub trait Dot<R>: MatrixLike {
// #[must_use]
// fn dot(&self, rhs: &R) -> <Self as MatrixLike>::Scalar;
// }
//
// pub trait Cross<R> {
// #[must_use]
// fn cross_r(&self, rhs: &R) -> Self;
// #[must_use]
// fn cross_l(&self, rhs: &R) -> Self;
// }
//
// pub trait MMul<R> {
// type Output;
// #[must_use]
// fn mmul(&self, rhs: &R) -> Self::Output;
// }
// Simple access functions that only require T be copyable
impl<T: Copy, const M: usize, const N: usize> Matrix<T, M, N> {
/// Generate a new matrix from a 2D Array
@ -73,6 +41,8 @@ impl<T: Copy, const M: usize, const N: usize> Matrix<T, M, N> {
/// ```
#[must_use]
pub fn new(data: [[T; N]; M]) -> Self {
assert!(M > 0, "Matrix must have at least 1 row");
assert!(N > 0, "Matrix must have at least 1 column");
Matrix::<T, M, N> { data }
}
@ -94,6 +64,8 @@ impl<T: Copy, const M: usize, const N: usize> Matrix<T, M, N> {
/// ```
#[must_use]
pub fn fill(scalar: T) -> Matrix<T, M, N> {
assert!(M > 0, "Matrix must have at least 1 row");
assert!(N > 0, "Matrix must have at least 1 column");
Matrix::<T, M, N> {
data: [[scalar; N]; M],
}
@ -329,7 +301,7 @@ impl<T: Copy, const M: usize, const N: usize> Matrix<T, M, N> {
}
// 1D vector implementations
impl<T: Copy, const M: usize> Vector<T, M> {
impl<T: Copy, const N: usize> Vector<T, N> {
/// Create a vector from a 1D array.
/// Note that vectors are always column vectors unless explicitly instantiated as row vectors
///
@ -347,8 +319,9 @@ impl<T: Copy, const M: usize> Vector<T, M> {
/// // is equivalent to
/// assert_eq!(my_vector, Matrix::new([[1],[2],[3],[4]]));
/// ```
pub fn vec(data: [T; M]) -> Self {
return Vector::<T, M> {
pub fn vec(data: [T; N]) -> Self {
assert!(N > 0, "Vector must have at least 1 element");
return Vector::<T, N> {
data: data.map(|e| [e]),
};
}
@ -508,7 +481,7 @@ impl<T: Copy, const N: usize> Matrix<T, N, N> {
T: Real + Default + Product + Sum,
{
match N {
0 => T::zero(),
0 => T::one(),
1 => self[0],
2 => (self[(0, 0)] * self[(1, 1)]) - (self[(0, 1)] * self[(1, 0)]),
3 => {
@ -572,6 +545,7 @@ where
x[i] = sum
- (lu.row(i).expect("Invalid row reached") * x)
.elements()
.take(i)
.skip(ii - 1)
.cloned()
.sum()
@ -587,7 +561,7 @@ where
let sum = x[i]
- (lu.row(i).expect("Invalid row reached") * x)
.elements()
.skip(ii - 1)
.skip(i + 1)
.cloned()
.sum();

View File

@ -1,10 +1,9 @@
use generic_parameterize::parameterize;
use std::convert::identity;
use num_traits::real::Real;
use std::fmt::Debug;
use std::iter::{Product, Sum};
use std::ops;
use std::thread::sleep;
use std::time::Duration;
use vector_victor::Matrix;
use vector_victor::{Matrix, Vector};
#[parameterize(S = (i32, f32, u32), M = [1,4], N = [1,4])]
#[test]
@ -20,10 +19,46 @@ where
}
}
#[parameterize(S = (f32, f64), M = [1,2,3,4])]
#[test]
fn test_lu() {
fn test_lu_identity<S: Default + Real + Debug + Product + Sum, const M: usize>() {
// let a: Matrix<f32, 3, 3> = Matrix::<f32, 3, 3>::identity();
let a = Matrix::new([[1.0, 2.0], [3.0, 4.0]]);
let (lu, _idx, _d) = a.lu().expect("What");
println!("{:?}", lu);
let i = Matrix::<S, M, M>::identity();
let ones = Vector::<S, M>::fill(S::one());
let (lu, idx, d) = i.lu().expect("Singular matrix encountered");
assert_eq!(
lu,
i,
"Incorrect LU decomposition matrix for {m}x{m} identity matrix",
m = M
);
assert!(
(0..M).eq(idx.elements().cloned()),
"Incorrect permutation matrix result for {m}x{m} identity matrix",
m = M
);
assert_eq!(
d,
S::one(),
"Incorrect permutation parity for {m}x{m} identity matrix",
m = M
);
assert_eq!(
i.det(),
S::one(),
"Incorrect determinant for {m}x{m} identity matrix",
m = M
);
assert_eq!(
i.inverse(),
Some(i),
"Incorrect inverse for {m}x{m} identity matrix",
m = M
);
assert_eq!(
i.solve(&ones),
Some(ones),
"Incorrect solve result for {m}x{m} identity matrix",
m = M
)
}