mirror of
https://github.com/drewcassidy/vector-victor.git
synced 2024-09-01 14:58:35 +00:00
Fix Solve and add unit tests on an identity matrix
This commit is contained in:
@ -2,11 +2,11 @@ use crate::impl_matrix_op;
|
||||
use crate::index::Index2D;
|
||||
|
||||
use num_traits::real::Real;
|
||||
use num_traits::{Num, NumOps, One, Signed, Zero};
|
||||
use num_traits::{Num, NumOps, One, Zero};
|
||||
use std::fmt::Debug;
|
||||
use std::iter::{zip, Flatten, Product, Sum};
|
||||
|
||||
use std::ops::{Add, AddAssign, Deref, DerefMut, Index, IndexMut, Mul, MulAssign, Neg, Sub};
|
||||
use std::ops::{Add, AddAssign, Deref, DerefMut, Index, IndexMut, Mul, MulAssign, Neg};
|
||||
|
||||
/// A 2D array of values which can be operated upon.
|
||||
///
|
||||
@ -23,38 +23,6 @@ where
|
||||
/// An alias for a [Matrix] with a single column
|
||||
pub type Vector<T, const N: usize> = Matrix<T, N, 1>;
|
||||
|
||||
pub trait MatrixLike {
|
||||
type Scalar: Copy;
|
||||
const WIDTH: usize;
|
||||
const HEIGHT: usize;
|
||||
}
|
||||
impl<T: Copy, const M: usize, const N: usize> MatrixLike for Matrix<T, M, N> {
|
||||
type Scalar = T;
|
||||
const WIDTH: usize = N;
|
||||
const HEIGHT: usize = M;
|
||||
}
|
||||
|
||||
pub trait SquareMatrix {}
|
||||
impl<T: Copy, const M: usize> SquareMatrix for Matrix<T, M, M> {}
|
||||
|
||||
// pub trait Dot<R>: MatrixLike {
|
||||
// #[must_use]
|
||||
// fn dot(&self, rhs: &R) -> <Self as MatrixLike>::Scalar;
|
||||
// }
|
||||
//
|
||||
// pub trait Cross<R> {
|
||||
// #[must_use]
|
||||
// fn cross_r(&self, rhs: &R) -> Self;
|
||||
// #[must_use]
|
||||
// fn cross_l(&self, rhs: &R) -> Self;
|
||||
// }
|
||||
//
|
||||
// pub trait MMul<R> {
|
||||
// type Output;
|
||||
// #[must_use]
|
||||
// fn mmul(&self, rhs: &R) -> Self::Output;
|
||||
// }
|
||||
|
||||
// Simple access functions that only require T be copyable
|
||||
impl<T: Copy, const M: usize, const N: usize> Matrix<T, M, N> {
|
||||
/// Generate a new matrix from a 2D Array
|
||||
@ -73,6 +41,8 @@ impl<T: Copy, const M: usize, const N: usize> Matrix<T, M, N> {
|
||||
/// ```
|
||||
#[must_use]
|
||||
pub fn new(data: [[T; N]; M]) -> Self {
|
||||
assert!(M > 0, "Matrix must have at least 1 row");
|
||||
assert!(N > 0, "Matrix must have at least 1 column");
|
||||
Matrix::<T, M, N> { data }
|
||||
}
|
||||
|
||||
@ -94,6 +64,8 @@ impl<T: Copy, const M: usize, const N: usize> Matrix<T, M, N> {
|
||||
/// ```
|
||||
#[must_use]
|
||||
pub fn fill(scalar: T) -> Matrix<T, M, N> {
|
||||
assert!(M > 0, "Matrix must have at least 1 row");
|
||||
assert!(N > 0, "Matrix must have at least 1 column");
|
||||
Matrix::<T, M, N> {
|
||||
data: [[scalar; N]; M],
|
||||
}
|
||||
@ -329,7 +301,7 @@ impl<T: Copy, const M: usize, const N: usize> Matrix<T, M, N> {
|
||||
}
|
||||
|
||||
// 1D vector implementations
|
||||
impl<T: Copy, const M: usize> Vector<T, M> {
|
||||
impl<T: Copy, const N: usize> Vector<T, N> {
|
||||
/// Create a vector from a 1D array.
|
||||
/// Note that vectors are always column vectors unless explicitly instantiated as row vectors
|
||||
///
|
||||
@ -347,8 +319,9 @@ impl<T: Copy, const M: usize> Vector<T, M> {
|
||||
/// // is equivalent to
|
||||
/// assert_eq!(my_vector, Matrix::new([[1],[2],[3],[4]]));
|
||||
/// ```
|
||||
pub fn vec(data: [T; M]) -> Self {
|
||||
return Vector::<T, M> {
|
||||
pub fn vec(data: [T; N]) -> Self {
|
||||
assert!(N > 0, "Vector must have at least 1 element");
|
||||
return Vector::<T, N> {
|
||||
data: data.map(|e| [e]),
|
||||
};
|
||||
}
|
||||
@ -508,7 +481,7 @@ impl<T: Copy, const N: usize> Matrix<T, N, N> {
|
||||
T: Real + Default + Product + Sum,
|
||||
{
|
||||
match N {
|
||||
0 => T::zero(),
|
||||
0 => T::one(),
|
||||
1 => self[0],
|
||||
2 => (self[(0, 0)] * self[(1, 1)]) - (self[(0, 1)] * self[(1, 0)]),
|
||||
3 => {
|
||||
@ -572,6 +545,7 @@ where
|
||||
x[i] = sum
|
||||
- (lu.row(i).expect("Invalid row reached") * x)
|
||||
.elements()
|
||||
.take(i)
|
||||
.skip(ii - 1)
|
||||
.cloned()
|
||||
.sum()
|
||||
@ -587,7 +561,7 @@ where
|
||||
let sum = x[i]
|
||||
- (lu.row(i).expect("Invalid row reached") * x)
|
||||
.elements()
|
||||
.skip(ii - 1)
|
||||
.skip(i + 1)
|
||||
.cloned()
|
||||
.sum();
|
||||
|
||||
|
Reference in New Issue
Block a user