mirror of
https://github.com/drewcassidy/vector-victor.git
synced 2024-09-01 14:58:35 +00:00
Compare commits
2 Commits
df3c2b4ba9
...
543769f691
Author | SHA1 | Date | |
---|---|---|---|
543769f691 | |||
8fcb032b1a |
@ -3,6 +3,7 @@ extern crate core;
|
|||||||
pub mod index;
|
pub mod index;
|
||||||
mod macros;
|
mod macros;
|
||||||
mod matrix;
|
mod matrix;
|
||||||
|
pub mod solve;
|
||||||
mod util;
|
mod util;
|
||||||
|
|
||||||
pub use matrix::{LUSolve, Matrix, Vector};
|
pub use matrix::{Matrix, Vector};
|
||||||
|
181
src/matrix.rs
181
src/matrix.rs
@ -7,6 +7,7 @@ use num_traits::{Num, NumOps, One, Zero};
|
|||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
use std::iter::{zip, Flatten, Product, Sum};
|
use std::iter::{zip, Flatten, Product, Sum};
|
||||||
|
|
||||||
|
use crate::solve::{LUDecomp, LUSolve};
|
||||||
use std::ops::{Add, AddAssign, Deref, DerefMut, Index, IndexMut, Mul, MulAssign, Neg};
|
use std::ops::{Add, AddAssign, Deref, DerefMut, Index, IndexMut, Mul, MulAssign, Neg};
|
||||||
|
|
||||||
/// A 2D array of values which can be operated upon.
|
/// A 2D array of values which can be operated upon.
|
||||||
@ -398,85 +399,17 @@ impl<T: Copy, const N: usize> Matrix<T, N, N> {
|
|||||||
pub fn subdiagonals<'s>(&'s self) -> impl Iterator<Item = T> + 's {
|
pub fn subdiagonals<'s>(&'s self) -> impl Iterator<Item = T> + 's {
|
||||||
(0..N - 1).map(|n| self[(n, n + 1)])
|
(0..N - 1).map(|n| self[(n, n + 1)])
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Returns `Some(lu, idx, d)`, or [None] if the matrix is singular.
|
impl<T, const N: usize> LUSolve<T, N> for Matrix<T, N, N>
|
||||||
///
|
where
|
||||||
/// Where:
|
T: Copy + Default + Real + Sum + Product,
|
||||||
/// * `lu`: The LU decomposition of `self`. The upper and lower matrices are combined into a single matrix
|
{
|
||||||
/// * `idx`: The permutation of rows on the original matrix needed to perform the decomposition.
|
fn lu(&self) -> Option<LUDecomp<T, N>> {
|
||||||
/// Each element is the corresponding row index in the original matrix
|
LUDecomp::decompose(self)
|
||||||
/// * `d`: The permutation parity of `idx`, either `1` for even or `-1` for odd
|
|
||||||
///
|
|
||||||
/// The resulting tuple (once unwrapped) has the [LUSolve] trait, allowing it to be used for
|
|
||||||
/// solving multiple matrices without having to repeat the LU decomposition process
|
|
||||||
#[must_use]
|
|
||||||
pub fn lu(&self) -> Option<(Self, Vector<usize, N>, T)>
|
|
||||||
where
|
|
||||||
T: Real + Default,
|
|
||||||
{
|
|
||||||
// Implementation from Numerical Recipes §2.3
|
|
||||||
let mut lu = self.clone();
|
|
||||||
let mut idx: Vector<usize, N> = (0..N).collect();
|
|
||||||
let mut d = T::one();
|
|
||||||
|
|
||||||
let mut vv: Vector<T, N> = self
|
|
||||||
.rows()
|
|
||||||
.map(|row| {
|
|
||||||
let m = row.elements().cloned().reduce(|acc, x| acc.max(x.abs()))?;
|
|
||||||
match m < T::epsilon() {
|
|
||||||
true => None,
|
|
||||||
false => Some(T::one() / m),
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.collect::<Option<_>>()?; // get the inverse maxabs value in each row
|
|
||||||
|
|
||||||
for k in 0..N {
|
|
||||||
// search for the pivot element and its index
|
|
||||||
let (ipivot, _) = (lu.col(k) * vv)
|
|
||||||
.abs()
|
|
||||||
.elements()
|
|
||||||
.enumerate()
|
|
||||||
.skip(k) // below the diagonal
|
|
||||||
.reduce(|(imax, xmax), (i, x)| match x > xmax {
|
|
||||||
// Is the figure of merit for the pivot better than the best so far?
|
|
||||||
true => (i, x),
|
|
||||||
false => (imax, xmax),
|
|
||||||
})?;
|
|
||||||
|
|
||||||
// do we need to interchange rows?
|
|
||||||
if k != ipivot {
|
|
||||||
lu.pivot_row(ipivot, k); // yes, we do
|
|
||||||
idx.pivot_row(ipivot, k);
|
|
||||||
d = -d; // change parity of d
|
|
||||||
vv[ipivot] = vv[k] //interchange scale factor
|
|
||||||
}
|
|
||||||
|
|
||||||
let pivot = lu[(k, k)];
|
|
||||||
if pivot.abs() < T::epsilon() {
|
|
||||||
// if the pivot is zero, the matrix is singular
|
|
||||||
return None;
|
|
||||||
};
|
|
||||||
|
|
||||||
for i in (k + 1)..N {
|
|
||||||
// divide by the pivot element
|
|
||||||
let dpivot = lu[(i, k)] / pivot;
|
|
||||||
lu[(i, k)] = dpivot;
|
|
||||||
for j in (k + 1)..N {
|
|
||||||
// reduce remaining submatrix
|
|
||||||
lu[(i, j)] = lu[(i, j)] - (dpivot * lu[(k, j)]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return Some((lu, idx, d));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Computes the inverse matrix of `self`, or [None] if the matrix cannot be inverted.
|
fn inverse(&self) -> Option<Matrix<T, N, N>> {
|
||||||
#[must_use]
|
|
||||||
pub fn inverse(&self) -> Option<Self>
|
|
||||||
where
|
|
||||||
T: Real + Default + Sum + Product,
|
|
||||||
{
|
|
||||||
match N {
|
match N {
|
||||||
1 => Some(Self::fill(checked_inv(self[0])?)),
|
1 => Some(Self::fill(checked_inv(self[0])?)),
|
||||||
2 => {
|
2 => {
|
||||||
@ -491,12 +424,7 @@ impl<T: Copy, const N: usize> Matrix<T, N, N> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Computes the determinant of `self`.
|
fn det(&self) -> T {
|
||||||
#[must_use]
|
|
||||||
pub fn det(&self) -> T
|
|
||||||
where
|
|
||||||
T: Real + Default + Product + Sum,
|
|
||||||
{
|
|
||||||
match N {
|
match N {
|
||||||
1 => self[0],
|
1 => self[0],
|
||||||
2 => (self[(0, 0)] * self[(1, 1)]) - (self[(0, 1)] * self[(1, 0)]),
|
2 => (self[(0, 0)] * self[(1, 1)]) - (self[(0, 1)] * self[(1, 0)]),
|
||||||
@ -520,95 +448,6 @@ impl<T: Copy, const N: usize> Matrix<T, N, N> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
/// Solves a system of `Ax = b` using `self` for `A`, or [None] if there is no solution.
|
|
||||||
#[must_use]
|
|
||||||
pub fn solve<const M: usize>(&self, b: &Matrix<T, N, M>) -> Option<Matrix<T, N, M>>
|
|
||||||
where
|
|
||||||
T: Real + Default + Sum + Product,
|
|
||||||
{
|
|
||||||
Some(self.lu()?.solve(b))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Trait for the result of [Matrix::lu()],
|
|
||||||
/// allowing a single LU decomposition to be used to solve multiple equations
|
|
||||||
pub trait LUSolve<T, const N: usize>: Copy
|
|
||||||
where
|
|
||||||
T: Real + Copy,
|
|
||||||
{
|
|
||||||
/// Solves a system of `Ax = b` using an LU decomposition.
|
|
||||||
fn solve<const M: usize>(&self, rhs: &Matrix<T, N, M>) -> Matrix<T, N, M>;
|
|
||||||
|
|
||||||
/// Solves the determinant using an LU decomposition,
|
|
||||||
/// by multiplying the product of the diagonals by the permutation parity
|
|
||||||
fn det(&self) -> T;
|
|
||||||
|
|
||||||
/// Solves the inverse of the matrix that the LU decomposition represents.
|
|
||||||
fn inverse(&self) -> Matrix<T, N, N> {
|
|
||||||
return self.solve(&Matrix::<T, N, N>::identity());
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Separate the lu decomposition into L and U matrices, such that `L*U = P*A`.
|
|
||||||
fn separate(&self) -> (Matrix<T, N, N>, Matrix<T, N, N>);
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<T: Copy, const N: usize> LUSolve<T, N> for (Matrix<T, N, N>, Vector<usize, N>, T)
|
|
||||||
where
|
|
||||||
T: Real + Default + Sum + Product,
|
|
||||||
{
|
|
||||||
#[must_use]
|
|
||||||
fn solve<const M: usize>(&self, b: &Matrix<T, N, M>) -> Matrix<T, N, M> {
|
|
||||||
let (lu, idx, _) = self;
|
|
||||||
let bp = b.permute_rows(idx);
|
|
||||||
|
|
||||||
Matrix::from_cols(bp.cols().map(|mut x| {
|
|
||||||
// Implementation from Numerical Recipes §2.3
|
|
||||||
// When ii is set to a positive value,
|
|
||||||
// it will become the index of the first nonvanishing element of b
|
|
||||||
let mut ii = 0usize;
|
|
||||||
for i in 0..N {
|
|
||||||
// forward substitution using L
|
|
||||||
let mut sum = x[i];
|
|
||||||
if ii != 0 {
|
|
||||||
for j in (ii - 1)..i {
|
|
||||||
sum = sum - (lu[(i, j)] * x[j]);
|
|
||||||
}
|
|
||||||
} else if sum.abs() > T::epsilon() {
|
|
||||||
ii = i + 1;
|
|
||||||
}
|
|
||||||
x[i] = sum;
|
|
||||||
}
|
|
||||||
for i in (0..N).rev() {
|
|
||||||
// back substitution using U
|
|
||||||
let mut sum = x[i];
|
|
||||||
for j in (i + 1)..N {
|
|
||||||
sum = sum - (lu[(i, j)] * x[j]);
|
|
||||||
}
|
|
||||||
x[i] = sum / lu[(i, i)]
|
|
||||||
}
|
|
||||||
x
|
|
||||||
}))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn det(&self) -> T {
|
|
||||||
let (lu, _, d) = self;
|
|
||||||
*d * lu.diagonals().product()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn separate(&self) -> (Matrix<T, N, N>, Matrix<T, N, N>) {
|
|
||||||
let mut l = Matrix::<T, N, N>::identity();
|
|
||||||
let mut u = self.0; // lu
|
|
||||||
|
|
||||||
for m in 1..N {
|
|
||||||
for n in 0..m {
|
|
||||||
// iterate over lower diagonal
|
|
||||||
l[(m, n)] = u[(m, n)];
|
|
||||||
u[(m, n)] = T::zero();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
(l, u)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Index
|
// Index
|
||||||
|
153
src/solve.rs
Normal file
153
src/solve.rs
Normal file
@ -0,0 +1,153 @@
|
|||||||
|
use crate::util::checked_inv;
|
||||||
|
use crate::Matrix;
|
||||||
|
use crate::Vector;
|
||||||
|
use num_traits::real::Real;
|
||||||
|
use num_traits::{One, Zero};
|
||||||
|
use std::iter::{Product, Sum};
|
||||||
|
use std::ops::Index;
|
||||||
|
|
||||||
|
#[derive(Copy, Clone, Debug, PartialEq)]
|
||||||
|
pub struct LUDecomp<T: Copy, const N: usize> {
|
||||||
|
pub lu: Matrix<T, N, N>,
|
||||||
|
pub idx: Vector<usize, N>,
|
||||||
|
pub parity: T,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: Copy + Default, const N: usize> LUDecomp<T, N>
|
||||||
|
where
|
||||||
|
T: Real + Default + Sum + Product,
|
||||||
|
{
|
||||||
|
#[must_use]
|
||||||
|
pub fn decompose(m: &Matrix<T, N, N>) -> Option<Self> {
|
||||||
|
// Implementation from Numerical Recipes §2.3
|
||||||
|
let mut lu = m.clone();
|
||||||
|
let mut idx: Vector<usize, N> = (0..N).collect();
|
||||||
|
let mut parity = T::one();
|
||||||
|
|
||||||
|
let mut vv: Vector<T, N> = m
|
||||||
|
.rows()
|
||||||
|
.map(|row| {
|
||||||
|
let m = row.elements().cloned().reduce(|acc, x| acc.max(x.abs()))?;
|
||||||
|
match m < T::epsilon() {
|
||||||
|
true => None,
|
||||||
|
false => Some(T::one() / m),
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect::<Option<_>>()?; // get the inverse maxabs value in each row
|
||||||
|
|
||||||
|
for k in 0..N {
|
||||||
|
// search for the pivot element and its index
|
||||||
|
let (ipivot, _) = (lu.col(k) * vv)
|
||||||
|
.abs()
|
||||||
|
.elements()
|
||||||
|
.enumerate()
|
||||||
|
.skip(k) // below the diagonal
|
||||||
|
.reduce(|(imax, xmax), (i, x)| match x > xmax {
|
||||||
|
// Is the figure of merit for the pivot better than the best so far?
|
||||||
|
true => (i, x),
|
||||||
|
false => (imax, xmax),
|
||||||
|
})?;
|
||||||
|
|
||||||
|
// do we need to interchange rows?
|
||||||
|
if k != ipivot {
|
||||||
|
lu.pivot_row(ipivot, k); // yes, we do
|
||||||
|
idx.pivot_row(ipivot, k);
|
||||||
|
parity = -parity; // change parity of d
|
||||||
|
vv[ipivot] = vv[k] //interchange scale factor
|
||||||
|
}
|
||||||
|
|
||||||
|
let pivot = lu[(k, k)];
|
||||||
|
if pivot.abs() < T::epsilon() {
|
||||||
|
// if the pivot is zero, the matrix is singular
|
||||||
|
return None;
|
||||||
|
};
|
||||||
|
|
||||||
|
for i in (k + 1)..N {
|
||||||
|
// divide by the pivot element
|
||||||
|
let dpivot = lu[(i, k)] / pivot;
|
||||||
|
lu[(i, k)] = dpivot;
|
||||||
|
for j in (k + 1)..N {
|
||||||
|
// reduce remaining submatrix
|
||||||
|
lu[(i, j)] = lu[(i, j)] - (dpivot * lu[(k, j)]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return Some(Self { lu, idx, parity });
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn solve<const M: usize>(&self, b: &Matrix<T, N, M>) -> Matrix<T, N, M> {
|
||||||
|
let b_permuted = b.permute_rows(&self.idx);
|
||||||
|
|
||||||
|
Matrix::from_cols(b_permuted.cols().map(|mut x| {
|
||||||
|
// Implementation from Numerical Recipes §2.3
|
||||||
|
// When ii is set to a positive value,
|
||||||
|
// it will become the index of the first nonvanishing element of b
|
||||||
|
let mut ii = 0usize;
|
||||||
|
for i in 0..N {
|
||||||
|
// forward substitution using L
|
||||||
|
let mut sum = x[i];
|
||||||
|
if ii != 0 {
|
||||||
|
for j in (ii - 1)..i {
|
||||||
|
sum = sum - (self.lu[(i, j)] * x[j]);
|
||||||
|
}
|
||||||
|
} else if sum.abs() > T::epsilon() {
|
||||||
|
ii = i + 1;
|
||||||
|
}
|
||||||
|
x[i] = sum;
|
||||||
|
}
|
||||||
|
for i in (0..N).rev() {
|
||||||
|
// back substitution using U
|
||||||
|
let mut sum = x[i];
|
||||||
|
for j in (i + 1)..N {
|
||||||
|
sum = sum - (self.lu[(i, j)] * x[j]);
|
||||||
|
}
|
||||||
|
x[i] = sum / self.lu[(i, i)]
|
||||||
|
}
|
||||||
|
x
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn det(&self) -> T {
|
||||||
|
self.parity * self.lu.diagonals().product()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn inverse(&self) -> Matrix<T, N, N> {
|
||||||
|
return self.solve(&Matrix::<T, N, N>::identity());
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn separate(&self) -> (Matrix<T, N, N>, Matrix<T, N, N>) {
|
||||||
|
let mut l = Matrix::<T, N, N>::identity();
|
||||||
|
let mut u = self.lu; // lu
|
||||||
|
|
||||||
|
for m in 1..N {
|
||||||
|
for n in 0..m {
|
||||||
|
// iterate over lower diagonal
|
||||||
|
l[(m, n)] = u[(m, n)];
|
||||||
|
u[(m, n)] = T::zero();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
(l, u)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait LUSolve<T, const N: usize>
|
||||||
|
where
|
||||||
|
T: Copy + Default + Real + Product + Sum,
|
||||||
|
{
|
||||||
|
#[must_use]
|
||||||
|
fn lu(&self) -> Option<LUDecomp<T, N>>;
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
fn inverse(&self) -> Option<Matrix<T, N, N>>;
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
fn det(&self) -> T;
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
fn solve<const M: usize>(&self, b: &Matrix<T, N, M>) -> Option<Matrix<T, N, M>> {
|
||||||
|
Some(self.lu()?.solve(b))
|
||||||
|
}
|
||||||
|
}
|
57
tests/common/mod.rs
Normal file
57
tests/common/mod.rs
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
use num_traits::Float;
|
||||||
|
use std::iter::zip;
|
||||||
|
use vector_victor::Matrix;
|
||||||
|
|
||||||
|
pub trait Approx: PartialEq {
|
||||||
|
fn approx(left: &Self, right: &Self) -> bool {
|
||||||
|
left == right
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
macro_rules! multi_impl { ($name:ident for $($t:ty),*) => ($( impl $name for $t {} )*) }
|
||||||
|
multi_impl!(Approx for i8, i16, i32, i64, i128, isize, u8, u16, u32, u64, u128, usize);
|
||||||
|
|
||||||
|
impl Approx for f32 {
|
||||||
|
fn approx(left: &f32, right: &f32) -> bool {
|
||||||
|
f32::abs(left - right) <= f32::epsilon()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Approx for f64 {
|
||||||
|
fn approx(left: &f64, right: &f64) -> bool {
|
||||||
|
f64::abs(left - right) <= f32::epsilon() as f64
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: Copy + Approx, const M: usize, const N: usize> Approx for Matrix<T, M, N> {
|
||||||
|
fn approx(left: &Self, right: &Self) -> bool {
|
||||||
|
zip(left.elements(), right.elements()).all(|(l, r)| T::approx(l, r))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn approx<T: Approx>(left: &T, right: &T) -> bool {
|
||||||
|
T::approx(left, right)
|
||||||
|
}
|
||||||
|
|
||||||
|
macro_rules! assert_approx {
|
||||||
|
($left:expr, $right:expr $(,)?) => {
|
||||||
|
match (&$left, &$right) {
|
||||||
|
(_left_val, _right_val) => {
|
||||||
|
assert_approx!($left, $right, "Difference is less than epsilon")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
($left:expr, $right:expr, $($arg:tt)+) => {
|
||||||
|
match (&$left, &$right) {
|
||||||
|
(left_val, right_val) => {
|
||||||
|
pub fn approx<T: Approx>(left: &T, right: &T) -> bool {
|
||||||
|
T::approx(left, right)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !approx(left_val, right_val){
|
||||||
|
assert_eq!(left_val, right_val, $($arg)+) // done this way to get nice errors
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
137
tests/ops.rs
137
tests/ops.rs
@ -1,61 +1,16 @@
|
|||||||
|
#[macro_use]
|
||||||
|
mod common;
|
||||||
|
|
||||||
|
use common::Approx;
|
||||||
use generic_parameterize::parameterize;
|
use generic_parameterize::parameterize;
|
||||||
use num_traits::real::Real;
|
use num_traits::real::Real;
|
||||||
use num_traits::Zero;
|
use num_traits::Zero;
|
||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
use std::iter::{zip, Product, Sum};
|
use std::iter::{zip, Product, Sum};
|
||||||
use std::ops;
|
use std::ops;
|
||||||
use vector_victor::{LUSolve, Matrix, Vector};
|
use vector_victor::{Matrix, Vector};
|
||||||
|
|
||||||
macro_rules! scalar_eq {
|
#[parameterize(S = (i32, f32, f64, u32), M = [1,4], N = [1,4])]
|
||||||
($left:expr, $right:expr $(,)?) => {
|
|
||||||
match (&$left, &$right) {
|
|
||||||
(_left_val, _right_val) => {
|
|
||||||
scalar_eq!($left, $right, "Difference is less than epsilon")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
($left:expr, $right:expr, $($arg:tt)+) => {
|
|
||||||
match (&$left, &$right) {
|
|
||||||
(left_val, right_val) => {
|
|
||||||
let epsilon = f32::epsilon() as f64;
|
|
||||||
let lf : f64 = (*left_val).into();
|
|
||||||
let rf : f64 = (*right_val).into();
|
|
||||||
let diff : f64 = (lf - rf).abs();
|
|
||||||
if diff >= epsilon {
|
|
||||||
assert_eq!(left_val, right_val, $($arg)+) // done this way to get nice errors
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
macro_rules! matrix_eq {
|
|
||||||
($left:expr, $right:expr $(,)?) => {
|
|
||||||
match (&$left, &$right) {
|
|
||||||
(_left_val, _right_val) => {
|
|
||||||
matrix_eq!($left, $right, "Difference is less than epsilon")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
($left:expr, $right:expr, $($arg:tt)+) => {
|
|
||||||
match (&$left, &$right) {
|
|
||||||
(left_val, right_val) => {
|
|
||||||
let epsilon = f32::epsilon() as f64;
|
|
||||||
for (l, r) in zip(left_val.elements(), right_val.elements()) {
|
|
||||||
let lf : f64 = (*l).into();
|
|
||||||
let rf : f64 = (*r).into();
|
|
||||||
let diff : f64 = (lf - rf).abs();
|
|
||||||
if diff >= epsilon {
|
|
||||||
assert_eq!($left, $right, $($arg)+) // done this way to get nice errors
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
#[parameterize(S = (i32, f32, u32), M = [1,4], N = [1,4])]
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_add<S: Copy + From<u16> + PartialEq + Debug, const M: usize, const N: usize>()
|
fn test_add<S: Copy + From<u16> + PartialEq + Debug, const M: usize, const N: usize>()
|
||||||
where
|
where
|
||||||
@ -64,85 +19,7 @@ where
|
|||||||
let a = Matrix::<S, M, N>::fill(S::from(1));
|
let a = Matrix::<S, M, N>::fill(S::from(1));
|
||||||
let b = Matrix::<S, M, N>::fill(S::from(3));
|
let b = Matrix::<S, M, N>::fill(S::from(3));
|
||||||
let c: Matrix<S, M, N> = a + b;
|
let c: Matrix<S, M, N> = a + b;
|
||||||
for (i, ci) in c.elements().enumerate() {
|
for (_, ci) in c.elements().enumerate() {
|
||||||
assert_eq!(*ci, S::from(4));
|
assert_eq!(*ci, S::from(4));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[parameterize(S = (f32, f64), M = [1,2,3,4])]
|
|
||||||
#[test]
|
|
||||||
fn test_lu_identity<S: Default + Real + Debug + Product + Sum + Into<f64>, const M: usize>() {
|
|
||||||
// let a: Matrix<f32, 3, 3> = Matrix::<f32, 3, 3>::identity();
|
|
||||||
let i = Matrix::<S, M, M>::identity();
|
|
||||||
let ones = Vector::<S, M>::fill(S::one());
|
|
||||||
let decomp = i.lu().expect("Singular matrix encountered");
|
|
||||||
let (lu, idx, d) = decomp;
|
|
||||||
assert_eq!(lu, i, "Incorrect LU decomposition");
|
|
||||||
assert!(
|
|
||||||
(0..M).eq(idx.elements().cloned()),
|
|
||||||
"Incorrect permutation matrix",
|
|
||||||
);
|
|
||||||
scalar_eq!(d, S::one(), "Incorrect permutation parity");
|
|
||||||
scalar_eq!(i.det(), S::one());
|
|
||||||
assert_eq!(i.inverse(), Some(i));
|
|
||||||
assert_eq!(i.solve(&ones), Some(ones));
|
|
||||||
assert_eq!(decomp.separate(), (i, i));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[parameterize(S = (f32, f64), M = [2,3,4])]
|
|
||||||
#[test]
|
|
||||||
fn test_lu_singular<S: Default + Real + Debug + Product + Sum, const M: usize>() {
|
|
||||||
// let a: Matrix<f32, 3, 3> = Matrix::<f32, 3, 3>::identity();
|
|
||||||
let mut a = Matrix::<S, M, M>::zero();
|
|
||||||
let ones = Vector::<S, M>::fill(S::one());
|
|
||||||
a.set_row(0, &ones);
|
|
||||||
|
|
||||||
assert_eq!(a.lu(), None, "Matrix should be singular");
|
|
||||||
assert_eq!(a.det(), S::zero());
|
|
||||||
assert_eq!(a.inverse(), None);
|
|
||||||
assert_eq!(a.solve(&ones), None)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_lu_2x2() {
|
|
||||||
let a = Matrix::new([[1.0, 2.0], [3.0, 0.0]]);
|
|
||||||
let decomp = a.lu().expect("Singular matrix encountered");
|
|
||||||
let (_lu, idx, _d) = decomp;
|
|
||||||
// the decomposition is non-unique, due to the combination of lu and idx.
|
|
||||||
// Instead of checking the exact value, we only check the results.
|
|
||||||
// Also check if they produce the same results with both methods, since the
|
|
||||||
// Matrix<> methods use shortcuts the decomposition methods don't
|
|
||||||
|
|
||||||
let (l, u) = decomp.separate();
|
|
||||||
matrix_eq!(l.mmul(&u), a.permute_rows(&idx));
|
|
||||||
|
|
||||||
scalar_eq!(a.det(), -6.0);
|
|
||||||
scalar_eq!(a.det(), decomp.det());
|
|
||||||
|
|
||||||
matrix_eq!(
|
|
||||||
a.inverse().unwrap(),
|
|
||||||
Matrix::new([[0.0, 2.0], [3.0, -1.0]]) * (1.0 / 6.0)
|
|
||||||
);
|
|
||||||
matrix_eq!(a.inverse().unwrap(), decomp.inverse());
|
|
||||||
matrix_eq!(a.inverse().unwrap().inverse().unwrap(), a)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_lu_3x3() {
|
|
||||||
let a = Matrix::new([[1.0, -5.0, 8.0], [1.0, -2.0, 1.0], [2.0, -1.0, -4.0]]);
|
|
||||||
let decomp = a.lu().expect("Singular matrix encountered");
|
|
||||||
let (_lu, idx, _d) = decomp;
|
|
||||||
|
|
||||||
let (l, u) = decomp.separate();
|
|
||||||
matrix_eq!(l.mmul(&u), a.permute_rows(&idx));
|
|
||||||
|
|
||||||
scalar_eq!(a.det(), 3.0);
|
|
||||||
scalar_eq!(a.det(), decomp.det());
|
|
||||||
|
|
||||||
matrix_eq!(
|
|
||||||
a.inverse().unwrap(),
|
|
||||||
Matrix::new([[9.0, -28.0, 11.0], [6.0, -20.0, 7.0], [3.0, -9.0, 3.0]]) * (1.0 / 3.0)
|
|
||||||
);
|
|
||||||
matrix_eq!(a.inverse().unwrap(), decomp.inverse());
|
|
||||||
matrix_eq!(a.inverse().unwrap().inverse().unwrap(), a)
|
|
||||||
}
|
|
||||||
|
116
tests/solve.rs
Normal file
116
tests/solve.rs
Normal file
@ -0,0 +1,116 @@
|
|||||||
|
#[macro_use]
|
||||||
|
mod common;
|
||||||
|
|
||||||
|
use common::Approx;
|
||||||
|
use generic_parameterize::parameterize;
|
||||||
|
use num_traits::real::Real;
|
||||||
|
use num_traits::Zero;
|
||||||
|
use std::fmt::Debug;
|
||||||
|
use std::iter::{zip, Product, Sum};
|
||||||
|
use vector_victor::solve::{LUDecomp, LUSolve};
|
||||||
|
use vector_victor::{Matrix, Vector};
|
||||||
|
|
||||||
|
#[parameterize(S = (f32, f64), M = [1,2,3,4])]
|
||||||
|
#[test]
|
||||||
|
/// The LU decomposition of the identity matrix should produce
|
||||||
|
/// the identity matrix with no permutations and parity 1
|
||||||
|
fn test_lu_identity<S: Default + Approx + Real + Debug + Product + Sum, const M: usize>() {
|
||||||
|
// let a: Matrix<f32, 3, 3> = Matrix::<f32, 3, 3>::identity();
|
||||||
|
let i = Matrix::<S, M, M>::identity();
|
||||||
|
let ones = Vector::<S, M>::fill(S::one());
|
||||||
|
let decomp = i.lu().expect("Singular matrix encountered");
|
||||||
|
let LUDecomp { lu, idx, parity } = decomp;
|
||||||
|
assert_eq!(lu, i, "Incorrect LU decomposition");
|
||||||
|
assert!(
|
||||||
|
(0..M).eq(idx.elements().cloned()),
|
||||||
|
"Incorrect permutation matrix",
|
||||||
|
);
|
||||||
|
assert_approx!(parity, S::one(), "Incorrect permutation parity");
|
||||||
|
|
||||||
|
// Check determinant calculation which uses LU decomposition
|
||||||
|
assert_approx!(
|
||||||
|
i.det(),
|
||||||
|
S::one(),
|
||||||
|
"Identity matrix should have determinant of 1"
|
||||||
|
);
|
||||||
|
|
||||||
|
// Check inverse calculation with uses LU decomposition
|
||||||
|
assert_eq!(
|
||||||
|
i.inverse(),
|
||||||
|
Some(i),
|
||||||
|
"Identity matrix should be its own inverse"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
i.solve(&ones),
|
||||||
|
Some(ones),
|
||||||
|
"Failed to solve using identity matrix"
|
||||||
|
);
|
||||||
|
|
||||||
|
// Check triangle separation
|
||||||
|
assert_eq!(decomp.separate(), (i, i));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[parameterize(S = (f32, f64), M = [2,3,4])]
|
||||||
|
#[test]
|
||||||
|
/// The LU decomposition of any singular matrix should be `None`
|
||||||
|
fn test_lu_singular<S: Default + Real + Debug + Product + Sum, const M: usize>() {
|
||||||
|
// let a: Matrix<f32, 3, 3> = Matrix::<f32, 3, 3>::identity();
|
||||||
|
let mut a = Matrix::<S, M, M>::zero();
|
||||||
|
let ones = Vector::<S, M>::fill(S::one());
|
||||||
|
a.set_row(0, &ones);
|
||||||
|
|
||||||
|
assert_eq!(a.lu(), None, "Matrix should be singular");
|
||||||
|
assert_eq!(
|
||||||
|
a.det(),
|
||||||
|
S::zero(),
|
||||||
|
"Singular matrix should have determinant of zero"
|
||||||
|
);
|
||||||
|
assert_eq!(a.inverse(), None, "Singular matrix should have no inverse");
|
||||||
|
assert_eq!(
|
||||||
|
a.solve(&ones),
|
||||||
|
None,
|
||||||
|
"Singular matrix should not be solvable"
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_lu_2x2() {
|
||||||
|
let a = Matrix::new([[1.0, 2.0], [3.0, 0.0]]);
|
||||||
|
let decomp = a.lu().expect("Singular matrix encountered");
|
||||||
|
// the decomposition is non-unique, due to the combination of lu and idx.
|
||||||
|
// Instead of checking the exact value, we only check the results.
|
||||||
|
// Also check if they produce the same results with both methods, since the
|
||||||
|
// Matrix<> methods use shortcuts the decomposition methods don't
|
||||||
|
|
||||||
|
let (l, u) = decomp.separate();
|
||||||
|
assert_approx!(l.mmul(&u), a.permute_rows(&decomp.idx));
|
||||||
|
|
||||||
|
assert_approx!(a.det(), -6.0);
|
||||||
|
assert_approx!(a.det(), decomp.det());
|
||||||
|
|
||||||
|
assert_approx!(
|
||||||
|
a.inverse().unwrap(),
|
||||||
|
Matrix::new([[0.0, 2.0], [3.0, -1.0]]) * (1.0 / 6.0)
|
||||||
|
);
|
||||||
|
assert_approx!(a.inverse().unwrap(), decomp.inverse());
|
||||||
|
assert_approx!(a.inverse().unwrap().inverse().unwrap(), a)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_lu_3x3() {
|
||||||
|
let a = Matrix::new([[1.0, -5.0, 8.0], [1.0, -2.0, 1.0], [2.0, -1.0, -4.0]]);
|
||||||
|
let decomp = a.lu().expect("Singular matrix encountered");
|
||||||
|
|
||||||
|
let (l, u) = decomp.separate();
|
||||||
|
assert_approx!(l.mmul(&u), a.permute_rows(&decomp.idx));
|
||||||
|
|
||||||
|
assert_approx!(a.det(), 3.0);
|
||||||
|
assert_approx!(a.det(), decomp.det());
|
||||||
|
|
||||||
|
assert_approx!(
|
||||||
|
a.inverse().unwrap(),
|
||||||
|
Matrix::new([[9.0, -28.0, 11.0], [6.0, -20.0, 7.0], [3.0, -9.0, 3.0]]) * (1.0 / 3.0)
|
||||||
|
);
|
||||||
|
assert_approx!(a.inverse().unwrap(), decomp.inverse());
|
||||||
|
assert_approx!(a.inverse().unwrap().inverse().unwrap(), a)
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user