mirror of
https://github.com/drewcassidy/vector-victor.git
synced 2024-09-01 14:58:35 +00:00
Flesh out LU solving and add more tests
This commit is contained in:
parent
57636dc8dd
commit
1e8399eb41
@ -1,6 +1,7 @@
|
|||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
|
|
||||||
pub trait Index2D: Copy + Debug {
|
pub trait Index2D: Copy + Debug {
|
||||||
|
#[inline(always)]
|
||||||
fn to_1d(self, height: usize, width: usize) -> Option<usize> {
|
fn to_1d(self, height: usize, width: usize) -> Option<usize> {
|
||||||
let (r, c) = self.to_2d(height, width)?;
|
let (r, c) = self.to_2d(height, width)?;
|
||||||
Some(r * width + c)
|
Some(r * width + c)
|
||||||
@ -10,6 +11,7 @@ pub trait Index2D: Copy + Debug {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Index2D for usize {
|
impl Index2D for usize {
|
||||||
|
#[inline(always)]
|
||||||
fn to_2d(self, height: usize, width: usize) -> Option<(usize, usize)> {
|
fn to_2d(self, height: usize, width: usize) -> Option<(usize, usize)> {
|
||||||
match self < (height * width) {
|
match self < (height * width) {
|
||||||
true => Some((self / width, self % width)),
|
true => Some((self / width, self % width)),
|
||||||
@ -19,6 +21,7 @@ impl Index2D for usize {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Index2D for (usize, usize) {
|
impl Index2D for (usize, usize) {
|
||||||
|
#[inline(always)]
|
||||||
fn to_2d(self, height: usize, width: usize) -> Option<(usize, usize)> {
|
fn to_2d(self, height: usize, width: usize) -> Option<(usize, usize)> {
|
||||||
match self.0 < height && self.1 < width {
|
match self.0 < height && self.1 < width {
|
||||||
true => Some(self),
|
true => Some(self),
|
||||||
|
@ -3,5 +3,6 @@ extern crate core;
|
|||||||
pub mod index;
|
pub mod index;
|
||||||
mod macros;
|
mod macros;
|
||||||
mod matrix;
|
mod matrix;
|
||||||
|
mod util;
|
||||||
|
|
||||||
pub use matrix::{LUSolve, Matrix, Vector};
|
pub use matrix::{LUSolve, Matrix, Vector};
|
||||||
|
227
src/matrix.rs
227
src/matrix.rs
@ -1,12 +1,15 @@
|
|||||||
use crate::impl_matrix_op;
|
use crate::impl_matrix_op;
|
||||||
use crate::index::Index2D;
|
use crate::index::Index2D;
|
||||||
|
use crate::util::{checked_div, checked_inv};
|
||||||
|
|
||||||
use num_traits::real::Real;
|
use num_traits::real::Real;
|
||||||
use num_traits::{Num, NumOps, One, Zero};
|
use num_traits::{Num, NumOps, One, Zero};
|
||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
use std::iter::{zip, Flatten, Product, Sum};
|
use std::iter::{zip, Flatten, Product, Sum};
|
||||||
|
use std::mem::swap;
|
||||||
|
|
||||||
use std::ops::{Add, AddAssign, Deref, DerefMut, Index, IndexMut, Mul, MulAssign, Neg};
|
use std::ops::{Add, AddAssign, Deref, DerefMut, Index, IndexMut, Mul, MulAssign, Neg};
|
||||||
|
use std::process::id;
|
||||||
|
|
||||||
/// A 2D array of values which can be operated upon.
|
/// A 2D array of values which can be operated upon.
|
||||||
///
|
///
|
||||||
@ -181,7 +184,7 @@ impl<T: Copy, const M: usize, const N: usize> Matrix<T, M, N> {
|
|||||||
Some(&mut self.data[m][n])
|
Some(&mut self.data[m][n])
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a row of the matrix. panics if index is out of bounds
|
/// Returns a row of the matrix. or [None] if index is out of bounds
|
||||||
///
|
///
|
||||||
/// # Examples
|
/// # Examples
|
||||||
///
|
///
|
||||||
@ -194,12 +197,15 @@ impl<T: Copy, const M: usize, const N: usize> Matrix<T, M, N> {
|
|||||||
/// ```
|
/// ```
|
||||||
#[inline]
|
#[inline]
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn row(&self, m: usize) -> Option<Vector<T, N>> {
|
pub fn row(&self, m: usize) -> Vector<T, N> {
|
||||||
if m < M {
|
assert!(
|
||||||
Some(Vector::<T, N>::vec(self.data[m]))
|
m < M,
|
||||||
} else {
|
"Row index {} out of bounds for {}x{} matrix",
|
||||||
None
|
m,
|
||||||
}
|
M,
|
||||||
|
N
|
||||||
|
);
|
||||||
|
Vector::<T, N>::vec(self.data[m])
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
@ -211,25 +217,28 @@ impl<T: Copy, const M: usize, const N: usize> Matrix<T, M, N> {
|
|||||||
M,
|
M,
|
||||||
N
|
N
|
||||||
);
|
);
|
||||||
for (n, v) in val.elements().enumerate() {
|
for n in 0..N {
|
||||||
self.data[m][n] = *v;
|
self.data[m][n] = val.data[n][0];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn pivot_row(&mut self, m1: usize, m2: usize) {
|
pub fn pivot_row(&mut self, m1: usize, m2: usize) {
|
||||||
let tmp = self.row(m2).expect("Invalid row index");
|
let tmp = self.row(m2);
|
||||||
self.set_row(m2, &self.row(m1).expect("Invalid row index"));
|
self.set_row(m2, &self.row(m1));
|
||||||
self.set_row(m1, &tmp);
|
self.set_row(m1, &tmp);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn col(&self, n: usize) -> Option<Vector<T, M>> {
|
pub fn col(&self, n: usize) -> Vector<T, M> {
|
||||||
if n < N {
|
assert!(
|
||||||
Some(Vector::<T, M>::vec(self.data.map(|r| r[n])))
|
n < N,
|
||||||
} else {
|
"Column index {} out of bounds for {}x{} matrix",
|
||||||
None
|
n,
|
||||||
}
|
M,
|
||||||
|
N
|
||||||
|
);
|
||||||
|
Vector::<T, M>::vec(self.data.map(|r| r[n]))
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
@ -242,25 +251,41 @@ impl<T: Copy, const M: usize, const N: usize> Matrix<T, M, N> {
|
|||||||
N
|
N
|
||||||
);
|
);
|
||||||
|
|
||||||
for (m, v) in val.elements().enumerate() {
|
for m in 0..M {
|
||||||
self.data[m][n] = *v;
|
self.data[m][n] = val.data[m][0];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn pivot_col(&mut self, n1: usize, n2: usize) {
|
pub fn pivot_col(&mut self, n1: usize, n2: usize) {
|
||||||
let tmp = self.col(n2).expect("Invalid column index");
|
let tmp = self.col(n2);
|
||||||
self.set_col(n2, &self.col(n1).expect("Invalid column index"));
|
self.set_col(n2, &self.col(n1));
|
||||||
self.set_col(n1, &tmp);
|
self.set_col(n1, &tmp);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn rows<'a>(&'a self) -> impl Iterator<Item = Vector<T, N>> + 'a {
|
pub fn rows<'a>(&'a self) -> impl Iterator<Item = Vector<T, N>> + 'a {
|
||||||
(0..M).map(|m| self.row(m).expect("invalid row reached while iterating"))
|
(0..M).map(|m| self.row(m))
|
||||||
}
|
}
|
||||||
|
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn cols<'a>(&'a self) -> impl Iterator<Item = Vector<T, M>> + 'a {
|
pub fn cols<'a>(&'a self) -> impl Iterator<Item = Vector<T, M>> + 'a {
|
||||||
(0..N).map(|n| self.col(n).expect("invalid column reached while iterating"))
|
(0..N).map(|n| self.col(n))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn permute_rows(&self, ms: &Vector<usize, M>) -> Self
|
||||||
|
where
|
||||||
|
T: Default,
|
||||||
|
{
|
||||||
|
Self::from_rows(ms.elements().map(|&m| self.row(m)))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn permute_cols(&self, ns: &Vector<usize, N>) -> Self
|
||||||
|
where
|
||||||
|
T: Default,
|
||||||
|
{
|
||||||
|
Self::from_cols(ns.elements().map(|&n| self.col(n)))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn transpose(&self) -> Matrix<T, N, M>
|
pub fn transpose(&self) -> Matrix<T, N, M>
|
||||||
@ -305,14 +330,7 @@ impl<T: Copy, const N: usize> Vector<T, N> {
|
|||||||
/// Create a vector from a 1D array.
|
/// Create a vector from a 1D array.
|
||||||
/// Note that vectors are always column vectors unless explicitly instantiated as row vectors
|
/// Note that vectors are always column vectors unless explicitly instantiated as row vectors
|
||||||
///
|
///
|
||||||
/// # Arguments
|
|
||||||
///
|
|
||||||
/// * `data`: A 1D array of elements to copy into the new vector
|
|
||||||
///
|
|
||||||
/// returns: Vector<T, M>
|
|
||||||
///
|
|
||||||
/// # Examples
|
/// # Examples
|
||||||
///
|
|
||||||
/// ```
|
/// ```
|
||||||
/// # use vector_victor::{Matrix, Vector};
|
/// # use vector_victor::{Matrix, Vector};
|
||||||
/// let my_vector = Vector::vec([1,2,3,4]);
|
/// let my_vector = Vector::vec([1,2,3,4]);
|
||||||
@ -374,8 +392,9 @@ impl<T: Copy, const M: usize, const N: usize> Matrix<T, M, N> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Square matrix impls
|
// Square matrix implementations
|
||||||
impl<T: Copy, const N: usize> Matrix<T, N, N> {
|
impl<T: Copy, const N: usize> Matrix<T, N, N> {
|
||||||
|
/// Create an identity matrix
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn identity() -> Self
|
pub fn identity() -> Self
|
||||||
where
|
where
|
||||||
@ -388,31 +407,36 @@ impl<T: Copy, const N: usize> Matrix<T, N, N> {
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// returns an iterator over the elements along the diagonal of a square matrix
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn diagonals<'s>(&'s self) -> impl Iterator<Item = T> + 's {
|
pub fn diagonals<'s>(&'s self) -> impl Iterator<Item = T> + 's {
|
||||||
(0..N).map(|n| self[(n, n)])
|
(0..N).map(|n| self[(n, n)])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns an iterator over the elements directly below the diagonal of a square matrix
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn subdiagonals<'s>(&'s self) -> impl Iterator<Item = T> + 's {
|
pub fn subdiagonals<'s>(&'s self) -> impl Iterator<Item = T> + 's {
|
||||||
(0..N - 1).map(|n| self[(n, n + 1)])
|
(0..N - 1).map(|n| self[(n, n + 1)])
|
||||||
}
|
}
|
||||||
|
|
||||||
#[must_use]
|
/// Returns `Some(lu, idx, d)`, or [None] if the matrix is singular.
|
||||||
///
|
///
|
||||||
/// <math xmlns="http://www.w3.org/1998/Math/MathML" alttext="a^{3}" display="block">
|
/// Where:
|
||||||
/// <msup>
|
/// * `lu`: The LU decomposition of `self`. The upper and lower matrices are combined into a single matrix
|
||||||
/// <mi>a</mi>
|
/// * `idx`: The permutation of rows on the original matrix needed to perform the decomposition.
|
||||||
/// <mn>3</mn>
|
/// Each element is the corresponding row index in the original matrix
|
||||||
/// </msup>
|
/// * `d`: The permutation parity of `idx`, either `1` for even or `-1` for odd
|
||||||
/// </math>
|
///
|
||||||
|
/// The resulting tuple (once unwrapped) has the [LUSolve] trait, allowing it to be used for
|
||||||
|
/// solving multiple matrices without having to repeat the LU decomposition process
|
||||||
|
#[must_use]
|
||||||
pub fn lu(&self) -> Option<(Self, Vector<usize, N>, T)>
|
pub fn lu(&self) -> Option<(Self, Vector<usize, N>, T)>
|
||||||
where
|
where
|
||||||
T: Real + Default,
|
T: Real + Default,
|
||||||
{
|
{
|
||||||
// Implementation from Numerical Recipes §2.3
|
// Implementation from Numerical Recipes §2.3
|
||||||
let mut lu = self.clone();
|
let mut lu = self.clone();
|
||||||
let mut idx: Vector<usize, N> = Default::default();
|
let mut idx: Vector<usize, N> = (0..N).collect();
|
||||||
let mut d = T::one();
|
let mut d = T::one();
|
||||||
|
|
||||||
let mut vv: Vector<T, N> = self
|
let mut vv: Vector<T, N> = self
|
||||||
@ -428,7 +452,7 @@ impl<T: Copy, const N: usize> Matrix<T, N, N> {
|
|||||||
|
|
||||||
for k in 0..N {
|
for k in 0..N {
|
||||||
// search for the pivot element and its index
|
// search for the pivot element and its index
|
||||||
let (ipivot, _) = (lu.col(k)? * vv)
|
let (ipivot, _) = (lu.col(k) * vv)
|
||||||
.abs()
|
.abs()
|
||||||
.elements()
|
.elements()
|
||||||
.enumerate()
|
.enumerate()
|
||||||
@ -442,11 +466,11 @@ impl<T: Copy, const N: usize> Matrix<T, N, N> {
|
|||||||
// do we need to interchange rows?
|
// do we need to interchange rows?
|
||||||
if k != ipivot {
|
if k != ipivot {
|
||||||
lu.pivot_row(ipivot, k); // yes, we do
|
lu.pivot_row(ipivot, k); // yes, we do
|
||||||
|
idx.pivot_row(ipivot, k);
|
||||||
d = -d; // change parity of d
|
d = -d; // change parity of d
|
||||||
vv[ipivot] = vv[k] //interchange scale factor
|
vv[ipivot] = vv[k] //interchange scale factor
|
||||||
}
|
}
|
||||||
|
|
||||||
idx[k] = ipivot;
|
|
||||||
let pivot = lu[(k, k)];
|
let pivot = lu[(k, k)];
|
||||||
if pivot.abs() < T::epsilon() {
|
if pivot.abs() < T::epsilon() {
|
||||||
// if the pivot is zero, the matrix is singular
|
// if the pivot is zero, the matrix is singular
|
||||||
@ -467,21 +491,33 @@ impl<T: Copy, const N: usize> Matrix<T, N, N> {
|
|||||||
return Some((lu, idx, d));
|
return Some((lu, idx, d));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Computes the inverse matrix of `self`, or [None] if the matrix cannot be inverted.
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn inverse(&self) -> Option<Self>
|
pub fn inverse(&self) -> Option<Self>
|
||||||
where
|
where
|
||||||
T: Real + Default + Sum,
|
T: Real + Default + Sum + Product,
|
||||||
{
|
{
|
||||||
self.solve(&Self::identity())
|
match N {
|
||||||
|
1 => Some(Self::fill(checked_inv(self[0])?)),
|
||||||
|
2 => {
|
||||||
|
let mut result = Self::default();
|
||||||
|
result[(0, 0)] = self[(1, 1)];
|
||||||
|
result[(1, 1)] = self[(0, 0)];
|
||||||
|
result[(1, 0)] = -self[(1, 0)];
|
||||||
|
result[(0, 1)] = -self[(0, 1)];
|
||||||
|
Some(result * checked_inv(self.det())?)
|
||||||
|
}
|
||||||
|
_ => Some(self.lu()?.inverse()),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Computes the determinant of `self`.
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn det(&self) -> T
|
pub fn det(&self) -> T
|
||||||
where
|
where
|
||||||
T: Real + Default + Product + Sum,
|
T: Real + Default + Product + Sum,
|
||||||
{
|
{
|
||||||
match N {
|
match N {
|
||||||
0 => T::one(),
|
|
||||||
1 => self[0],
|
1 => self[0],
|
||||||
2 => (self[(0, 0)] * self[(1, 1)]) - (self[(0, 1)] * self[(1, 0)]),
|
2 => (self[(0, 0)] * self[(1, 1)]) - (self[(0, 1)] * self[(1, 0)]),
|
||||||
3 => {
|
3 => {
|
||||||
@ -500,37 +536,52 @@ impl<T: Copy, const N: usize> Matrix<T, N, N> {
|
|||||||
}
|
}
|
||||||
_ => {
|
_ => {
|
||||||
// use LU decomposition
|
// use LU decomposition
|
||||||
if let Some((lu, _, d)) = self.lu() {
|
self.lu().map_or(T::zero(), |lu| lu.det())
|
||||||
d * lu.diagonals().product()
|
|
||||||
} else {
|
|
||||||
T::zero()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
/// Solves a system of `Ax = b` using `self` for `A`, or [None] if there is no solution.
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn solve<const M: usize>(&self, b: &Matrix<T, N, M>) -> Option<Matrix<T, N, M>>
|
pub fn solve<const M: usize>(&self, b: &Matrix<T, N, M>) -> Option<Matrix<T, N, M>>
|
||||||
where
|
where
|
||||||
T: Real + Default + Sum,
|
T: Real + Default + Sum + Product,
|
||||||
{
|
{
|
||||||
Some(self.lu()?.solve(b))
|
Some(self.lu()?.solve(b))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait LUSolve<R>: Copy {
|
/// Trait for the result of [Matrix::lu()],
|
||||||
fn solve(&self, rhs: &R) -> R;
|
/// allowing a single LU decomposition to be used to solve multiple equations
|
||||||
|
pub trait LUSolve<T, const N: usize>: Copy
|
||||||
|
where
|
||||||
|
T: Real + Copy,
|
||||||
|
{
|
||||||
|
/// Solves a system of `Ax = b` using an LU decomposition.
|
||||||
|
fn solve<const M: usize>(&self, rhs: &Matrix<T, N, M>) -> Matrix<T, N, M>;
|
||||||
|
|
||||||
|
/// Solves the determinant using an LU decomposition,
|
||||||
|
/// by multiplying the product of the diagonals by the permutation parity
|
||||||
|
fn det(&self) -> T;
|
||||||
|
|
||||||
|
/// Solves the inverse of the matrix that the LU decomposition represents.
|
||||||
|
fn inverse(&self) -> Matrix<T, N, N> {
|
||||||
|
return self.solve(&Matrix::<T, N, N>::identity());
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Separate the lu decomposition into L and U matrices, such that `L*U = P*A`.
|
||||||
|
fn separate(&self) -> (Matrix<T, N, N>, Matrix<T, N, N>);
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: Copy, const N: usize, const M: usize> LUSolve<Matrix<T, N, M>>
|
impl<T: Copy, const N: usize> LUSolve<T, N> for (Matrix<T, N, N>, Vector<usize, N>, T)
|
||||||
for (Matrix<T, N, N>, Vector<usize, N>, T)
|
|
||||||
where
|
where
|
||||||
for<'t> T: Real + Default + Sum,
|
T: Real + Default + Sum + Product,
|
||||||
{
|
{
|
||||||
#[must_use]
|
#[must_use]
|
||||||
fn solve(&self, b: &Matrix<T, N, M>) -> Matrix<T, N, M> {
|
fn solve<const M: usize>(&self, b: &Matrix<T, N, M>) -> Matrix<T, N, M> {
|
||||||
let (lu, idx, _) = self;
|
let (lu, idx, _) = self;
|
||||||
Matrix::<T, N, M>::from_cols(b.cols().map(|mut x| {
|
let bp = b.permute_rows(idx);
|
||||||
|
|
||||||
|
Matrix::from_cols(bp.cols().map(|mut x| {
|
||||||
// Implementation from Numerical Recipes §2.3
|
// Implementation from Numerical Recipes §2.3
|
||||||
|
|
||||||
// When ii is set to a positive value,
|
// When ii is set to a positive value,
|
||||||
@ -538,42 +589,48 @@ where
|
|||||||
let mut ii = 0usize;
|
let mut ii = 0usize;
|
||||||
for i in 0..N {
|
for i in 0..N {
|
||||||
// forward substitution
|
// forward substitution
|
||||||
let ip = idx[i]; // i permuted
|
let mut sum = x[i];
|
||||||
let sum = x[ip];
|
if ii != 0 {
|
||||||
x[ip] = x[i]; // unscramble as we go
|
for j in (ii - 1)..i {
|
||||||
if ii > 0 {
|
sum = sum - (lu[(i, j)] * x[j]);
|
||||||
x[i] = sum
|
|
||||||
- (lu.row(i).expect("Invalid row reached") * x)
|
|
||||||
.elements()
|
|
||||||
.take(i)
|
|
||||||
.skip(ii - 1)
|
|
||||||
.cloned()
|
|
||||||
.sum()
|
|
||||||
} else {
|
|
||||||
x[i] = sum;
|
|
||||||
if sum.abs() > T::epsilon() {
|
|
||||||
ii = i + 1;
|
|
||||||
}
|
}
|
||||||
|
} else if sum.abs() > T::epsilon() {
|
||||||
|
ii = i + 1;
|
||||||
}
|
}
|
||||||
|
x[i] = sum;
|
||||||
}
|
}
|
||||||
for i in (0..(N - 1)).rev() {
|
for i in (0..N).rev() {
|
||||||
// back substitution
|
// back substitution
|
||||||
let sum = x[i]
|
let mut sum = x[i];
|
||||||
- (lu.row(i).expect("Invalid row reached") * x)
|
for j in (i + 1)..N {
|
||||||
.elements()
|
sum = sum - (lu[(i, j)] * x[j]);
|
||||||
.skip(i + 1)
|
}
|
||||||
.cloned()
|
|
||||||
.sum();
|
|
||||||
|
|
||||||
x[i] = sum / lu[(i, i)]
|
x[i] = sum / lu[(i, i)]
|
||||||
}
|
}
|
||||||
x
|
x
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// Square matrices
|
fn det(&self) -> T {
|
||||||
impl<T: Copy, const N: usize> Matrix<T, N, N> {}
|
let (lu, _, d) = self;
|
||||||
|
*d * lu.diagonals().product()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn separate(&self) -> (Matrix<T, N, N>, Matrix<T, N, N>) {
|
||||||
|
let mut l = Matrix::<T, N, N>::identity();
|
||||||
|
let mut u = self.0; // lu
|
||||||
|
|
||||||
|
for m in 1..N {
|
||||||
|
for n in 0..m {
|
||||||
|
// iterate over lower diagonal
|
||||||
|
l[(m, n)] = u[(m, n)];
|
||||||
|
u[(m, n)] = T::zero();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
(l, u)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Index
|
// Index
|
||||||
impl<I, T, const M: usize, const N: usize> Index<I> for Matrix<T, M, N>
|
impl<I, T, const M: usize, const N: usize> Index<I> for Matrix<T, M, N>
|
||||||
@ -583,6 +640,7 @@ where
|
|||||||
{
|
{
|
||||||
type Output = T;
|
type Output = T;
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
fn index(&self, index: I) -> &Self::Output {
|
fn index(&self, index: I) -> &Self::Output {
|
||||||
self.get(index).expect(&*format!(
|
self.get(index).expect(&*format!(
|
||||||
"index {:?} out of range for {}x{} Matrix",
|
"index {:?} out of range for {}x{} Matrix",
|
||||||
@ -597,6 +655,7 @@ where
|
|||||||
I: Index2D,
|
I: Index2D,
|
||||||
T: Copy,
|
T: Copy,
|
||||||
{
|
{
|
||||||
|
#[inline(always)]
|
||||||
fn index_mut(&mut self, index: I) -> &mut Self::Output {
|
fn index_mut(&mut self, index: I) -> &mut Self::Output {
|
||||||
self.get_mut(index).expect(&*format!(
|
self.get_mut(index).expect(&*format!(
|
||||||
"index {:?} out of range for {}x{} Matrix",
|
"index {:?} out of range for {}x{} Matrix",
|
||||||
|
13
src/util.rs
Normal file
13
src/util.rs
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
use num_traits::{Num, NumOps, One, Zero};
|
||||||
|
use std::ops::Div;
|
||||||
|
|
||||||
|
pub fn checked_div<L: Num + Div<R, Output = T>, R: Num + Zero, T>(num: L, den: R) -> Option<T> {
|
||||||
|
if den.is_zero() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
return Some(num / den);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn checked_inv<T: Num + Div<T, Output = T> + Zero + One>(den: T) -> Option<T> {
|
||||||
|
return checked_div(T::one(), den);
|
||||||
|
}
|
83
tests/ops.rs
83
tests/ops.rs
@ -1,9 +1,10 @@
|
|||||||
use generic_parameterize::parameterize;
|
use generic_parameterize::parameterize;
|
||||||
use num_traits::real::Real;
|
use num_traits::real::Real;
|
||||||
|
use num_traits::Zero;
|
||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
use std::iter::{Product, Sum};
|
use std::iter::{Product, Sum};
|
||||||
use std::ops;
|
use std::ops;
|
||||||
use vector_victor::{Matrix, Vector};
|
use vector_victor::{LUSolve, Matrix, Vector};
|
||||||
|
|
||||||
#[parameterize(S = (i32, f32, u32), M = [1,4], N = [1,4])]
|
#[parameterize(S = (i32, f32, u32), M = [1,4], N = [1,4])]
|
||||||
#[test]
|
#[test]
|
||||||
@ -25,40 +26,54 @@ fn test_lu_identity<S: Default + Real + Debug + Product + Sum, const M: usize>()
|
|||||||
// let a: Matrix<f32, 3, 3> = Matrix::<f32, 3, 3>::identity();
|
// let a: Matrix<f32, 3, 3> = Matrix::<f32, 3, 3>::identity();
|
||||||
let i = Matrix::<S, M, M>::identity();
|
let i = Matrix::<S, M, M>::identity();
|
||||||
let ones = Vector::<S, M>::fill(S::one());
|
let ones = Vector::<S, M>::fill(S::one());
|
||||||
let (lu, idx, d) = i.lu().expect("Singular matrix encountered");
|
let decomp = i.lu().expect("Singular matrix encountered");
|
||||||
assert_eq!(
|
let (lu, idx, d) = decomp;
|
||||||
lu,
|
assert_eq!(lu, i, "Incorrect LU decomposition");
|
||||||
i,
|
|
||||||
"Incorrect LU decomposition matrix for {m}x{m} identity matrix",
|
|
||||||
m = M
|
|
||||||
);
|
|
||||||
assert!(
|
assert!(
|
||||||
(0..M).eq(idx.elements().cloned()),
|
(0..M).eq(idx.elements().cloned()),
|
||||||
"Incorrect permutation matrix result for {m}x{m} identity matrix",
|
"Incorrect permutation matrix",
|
||||||
m = M
|
|
||||||
);
|
);
|
||||||
assert_eq!(
|
assert_eq!(d, S::one(), "Incorrect permutation parity");
|
||||||
d,
|
assert_eq!(i.det(), S::one());
|
||||||
S::one(),
|
assert_eq!(i.inverse(), Some(i));
|
||||||
"Incorrect permutation parity for {m}x{m} identity matrix",
|
assert_eq!(i.solve(&ones), Some(ones));
|
||||||
m = M
|
assert_eq!(decomp.separate(), (i, i));
|
||||||
);
|
}
|
||||||
assert_eq!(
|
|
||||||
i.det(),
|
#[parameterize(S = (f32, f64), M = [2,3,4])]
|
||||||
S::one(),
|
#[test]
|
||||||
"Incorrect determinant for {m}x{m} identity matrix",
|
fn test_lu_singular<S: Default + Real + Debug + Product + Sum, const M: usize>() {
|
||||||
m = M
|
// let a: Matrix<f32, 3, 3> = Matrix::<f32, 3, 3>::identity();
|
||||||
);
|
let mut a = Matrix::<S, M, M>::zero();
|
||||||
assert_eq!(
|
let ones = Vector::<S, M>::fill(S::one());
|
||||||
i.inverse(),
|
a.set_row(0, &ones);
|
||||||
Some(i),
|
|
||||||
"Incorrect inverse for {m}x{m} identity matrix",
|
assert_eq!(a.lu(), None, "Matrix should be singular");
|
||||||
m = M
|
assert_eq!(a.det(), S::zero());
|
||||||
);
|
assert_eq!(a.inverse(), None);
|
||||||
assert_eq!(
|
assert_eq!(a.solve(&ones), None)
|
||||||
i.solve(&ones),
|
}
|
||||||
Some(ones),
|
|
||||||
"Incorrect solve result for {m}x{m} identity matrix",
|
#[test]
|
||||||
m = M
|
fn test_lu_2x2() {
|
||||||
)
|
let a = Matrix::new([[1.0, 2.0], [3.0, 0.0]]);
|
||||||
|
let decomp = a.lu().expect("Singular matrix encountered");
|
||||||
|
let (lu, idx, d) = decomp;
|
||||||
|
// the decomposition is non-unique, due to the combination of lu and idx.
|
||||||
|
// Instead of checking the exact value, we only check the results.
|
||||||
|
// Also check if they produce the same results with both methods, since the
|
||||||
|
// Matrix<> methods use shortcuts the decomposition methods don't
|
||||||
|
|
||||||
|
let (l, u) = decomp.separate();
|
||||||
|
assert_eq!(l.mmul(&u), a.permute_rows(&idx));
|
||||||
|
|
||||||
|
assert_eq!(a.det(), -6.0);
|
||||||
|
assert_eq!(a.det(), decomp.det());
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
a.inverse(),
|
||||||
|
Some(Matrix::new([[0.0, 2.0], [3.0, -1.0]]) * (1.0 / 6.0))
|
||||||
|
);
|
||||||
|
assert_eq!(a.inverse(), Some(decomp.inverse()));
|
||||||
|
assert_eq!(a.inverse().unwrap().inverse().unwrap(), a)
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user