Separate LU decomposition into its own file where other solving stuff will live

This commit is contained in:
Andrew Cassidy 2023-05-06 01:34:31 -07:00
parent 8fcb032b1a
commit 543769f691
5 changed files with 172 additions and 180 deletions

View File

@ -3,6 +3,7 @@ extern crate core;
pub mod index;
mod macros;
mod matrix;
pub mod solve;
mod util;
pub use matrix::{LUSolve, Matrix, Vector};
pub use matrix::{Matrix, Vector};

View File

@ -7,6 +7,7 @@ use num_traits::{Num, NumOps, One, Zero};
use std::fmt::Debug;
use std::iter::{zip, Flatten, Product, Sum};
use crate::solve::{LUDecomp, LUSolve};
use std::ops::{Add, AddAssign, Deref, DerefMut, Index, IndexMut, Mul, MulAssign, Neg};
/// A 2D array of values which can be operated upon.
@ -398,85 +399,17 @@ impl<T: Copy, const N: usize> Matrix<T, N, N> {
pub fn subdiagonals<'s>(&'s self) -> impl Iterator<Item = T> + 's {
(0..N - 1).map(|n| self[(n, n + 1)])
}
}
/// Returns `Some(lu, idx, d)`, or [None] if the matrix is singular.
///
/// Where:
/// * `lu`: The LU decomposition of `self`. The upper and lower matrices are combined into a single matrix
/// * `idx`: The permutation of rows on the original matrix needed to perform the decomposition.
/// Each element is the corresponding row index in the original matrix
/// * `d`: The permutation parity of `idx`, either `1` for even or `-1` for odd
///
/// The resulting tuple (once unwrapped) has the [LUSolve] trait, allowing it to be used for
/// solving multiple matrices without having to repeat the LU decomposition process
#[must_use]
pub fn lu(&self) -> Option<(Self, Vector<usize, N>, T)>
where
T: Real + Default,
{
// Implementation from Numerical Recipes §2.3
let mut lu = self.clone();
let mut idx: Vector<usize, N> = (0..N).collect();
let mut d = T::one();
let mut vv: Vector<T, N> = self
.rows()
.map(|row| {
let m = row.elements().cloned().reduce(|acc, x| acc.max(x.abs()))?;
match m < T::epsilon() {
true => None,
false => Some(T::one() / m),
}
})
.collect::<Option<_>>()?; // get the inverse maxabs value in each row
for k in 0..N {
// search for the pivot element and its index
let (ipivot, _) = (lu.col(k) * vv)
.abs()
.elements()
.enumerate()
.skip(k) // below the diagonal
.reduce(|(imax, xmax), (i, x)| match x > xmax {
// Is the figure of merit for the pivot better than the best so far?
true => (i, x),
false => (imax, xmax),
})?;
// do we need to interchange rows?
if k != ipivot {
lu.pivot_row(ipivot, k); // yes, we do
idx.pivot_row(ipivot, k);
d = -d; // change parity of d
vv[ipivot] = vv[k] //interchange scale factor
}
let pivot = lu[(k, k)];
if pivot.abs() < T::epsilon() {
// if the pivot is zero, the matrix is singular
return None;
};
for i in (k + 1)..N {
// divide by the pivot element
let dpivot = lu[(i, k)] / pivot;
lu[(i, k)] = dpivot;
for j in (k + 1)..N {
// reduce remaining submatrix
lu[(i, j)] = lu[(i, j)] - (dpivot * lu[(k, j)]);
}
}
}
return Some((lu, idx, d));
impl<T, const N: usize> LUSolve<T, N> for Matrix<T, N, N>
where
T: Copy + Default + Real + Sum + Product,
{
fn lu(&self) -> Option<LUDecomp<T, N>> {
LUDecomp::decompose(self)
}
/// Computes the inverse matrix of `self`, or [None] if the matrix cannot be inverted.
#[must_use]
pub fn inverse(&self) -> Option<Self>
where
T: Real + Default + Sum + Product,
{
fn inverse(&self) -> Option<Matrix<T, N, N>> {
match N {
1 => Some(Self::fill(checked_inv(self[0])?)),
2 => {
@ -491,12 +424,7 @@ impl<T: Copy, const N: usize> Matrix<T, N, N> {
}
}
/// Computes the determinant of `self`.
#[must_use]
pub fn det(&self) -> T
where
T: Real + Default + Product + Sum,
{
fn det(&self) -> T {
match N {
1 => self[0],
2 => (self[(0, 0)] * self[(1, 1)]) - (self[(0, 1)] * self[(1, 0)]),
@ -520,95 +448,6 @@ impl<T: Copy, const N: usize> Matrix<T, N, N> {
}
}
}
/// Solves a system of `Ax = b` using `self` for `A`, or [None] if there is no solution.
#[must_use]
pub fn solve<const M: usize>(&self, b: &Matrix<T, N, M>) -> Option<Matrix<T, N, M>>
where
T: Real + Default + Sum + Product,
{
Some(self.lu()?.solve(b))
}
}
/// Trait for the result of [Matrix::lu()],
/// allowing a single LU decomposition to be used to solve multiple equations
pub trait LUSolve<T, const N: usize>: Copy
where
T: Real + Copy,
{
/// Solves a system of `Ax = b` using an LU decomposition.
fn solve<const M: usize>(&self, rhs: &Matrix<T, N, M>) -> Matrix<T, N, M>;
/// Solves the determinant using an LU decomposition,
/// by multiplying the product of the diagonals by the permutation parity
fn det(&self) -> T;
/// Solves the inverse of the matrix that the LU decomposition represents.
fn inverse(&self) -> Matrix<T, N, N> {
return self.solve(&Matrix::<T, N, N>::identity());
}
/// Separate the lu decomposition into L and U matrices, such that `L*U = P*A`.
fn separate(&self) -> (Matrix<T, N, N>, Matrix<T, N, N>);
}
impl<T: Copy, const N: usize> LUSolve<T, N> for (Matrix<T, N, N>, Vector<usize, N>, T)
where
T: Real + Default + Sum + Product,
{
#[must_use]
fn solve<const M: usize>(&self, b: &Matrix<T, N, M>) -> Matrix<T, N, M> {
let (lu, idx, _) = self;
let bp = b.permute_rows(idx);
Matrix::from_cols(bp.cols().map(|mut x| {
// Implementation from Numerical Recipes §2.3
// When ii is set to a positive value,
// it will become the index of the first nonvanishing element of b
let mut ii = 0usize;
for i in 0..N {
// forward substitution using L
let mut sum = x[i];
if ii != 0 {
for j in (ii - 1)..i {
sum = sum - (lu[(i, j)] * x[j]);
}
} else if sum.abs() > T::epsilon() {
ii = i + 1;
}
x[i] = sum;
}
for i in (0..N).rev() {
// back substitution using U
let mut sum = x[i];
for j in (i + 1)..N {
sum = sum - (lu[(i, j)] * x[j]);
}
x[i] = sum / lu[(i, i)]
}
x
}))
}
fn det(&self) -> T {
let (lu, _, d) = self;
*d * lu.diagonals().product()
}
fn separate(&self) -> (Matrix<T, N, N>, Matrix<T, N, N>) {
let mut l = Matrix::<T, N, N>::identity();
let mut u = self.0; // lu
for m in 1..N {
for n in 0..m {
// iterate over lower diagonal
l[(m, n)] = u[(m, n)];
u[(m, n)] = T::zero();
}
}
(l, u)
}
}
// Index

153
src/solve.rs Normal file
View File

@ -0,0 +1,153 @@
use crate::util::checked_inv;
use crate::Matrix;
use crate::Vector;
use num_traits::real::Real;
use num_traits::{One, Zero};
use std::iter::{Product, Sum};
use std::ops::Index;
#[derive(Copy, Clone, Debug, PartialEq)]
pub struct LUDecomp<T: Copy, const N: usize> {
pub lu: Matrix<T, N, N>,
pub idx: Vector<usize, N>,
pub parity: T,
}
impl<T: Copy + Default, const N: usize> LUDecomp<T, N>
where
T: Real + Default + Sum + Product,
{
#[must_use]
pub fn decompose(m: &Matrix<T, N, N>) -> Option<Self> {
// Implementation from Numerical Recipes §2.3
let mut lu = m.clone();
let mut idx: Vector<usize, N> = (0..N).collect();
let mut parity = T::one();
let mut vv: Vector<T, N> = m
.rows()
.map(|row| {
let m = row.elements().cloned().reduce(|acc, x| acc.max(x.abs()))?;
match m < T::epsilon() {
true => None,
false => Some(T::one() / m),
}
})
.collect::<Option<_>>()?; // get the inverse maxabs value in each row
for k in 0..N {
// search for the pivot element and its index
let (ipivot, _) = (lu.col(k) * vv)
.abs()
.elements()
.enumerate()
.skip(k) // below the diagonal
.reduce(|(imax, xmax), (i, x)| match x > xmax {
// Is the figure of merit for the pivot better than the best so far?
true => (i, x),
false => (imax, xmax),
})?;
// do we need to interchange rows?
if k != ipivot {
lu.pivot_row(ipivot, k); // yes, we do
idx.pivot_row(ipivot, k);
parity = -parity; // change parity of d
vv[ipivot] = vv[k] //interchange scale factor
}
let pivot = lu[(k, k)];
if pivot.abs() < T::epsilon() {
// if the pivot is zero, the matrix is singular
return None;
};
for i in (k + 1)..N {
// divide by the pivot element
let dpivot = lu[(i, k)] / pivot;
lu[(i, k)] = dpivot;
for j in (k + 1)..N {
// reduce remaining submatrix
lu[(i, j)] = lu[(i, j)] - (dpivot * lu[(k, j)]);
}
}
}
return Some(Self { lu, idx, parity });
}
#[must_use]
pub fn solve<const M: usize>(&self, b: &Matrix<T, N, M>) -> Matrix<T, N, M> {
let b_permuted = b.permute_rows(&self.idx);
Matrix::from_cols(b_permuted.cols().map(|mut x| {
// Implementation from Numerical Recipes §2.3
// When ii is set to a positive value,
// it will become the index of the first nonvanishing element of b
let mut ii = 0usize;
for i in 0..N {
// forward substitution using L
let mut sum = x[i];
if ii != 0 {
for j in (ii - 1)..i {
sum = sum - (self.lu[(i, j)] * x[j]);
}
} else if sum.abs() > T::epsilon() {
ii = i + 1;
}
x[i] = sum;
}
for i in (0..N).rev() {
// back substitution using U
let mut sum = x[i];
for j in (i + 1)..N {
sum = sum - (self.lu[(i, j)] * x[j]);
}
x[i] = sum / self.lu[(i, i)]
}
x
}))
}
pub fn det(&self) -> T {
self.parity * self.lu.diagonals().product()
}
pub fn inverse(&self) -> Matrix<T, N, N> {
return self.solve(&Matrix::<T, N, N>::identity());
}
pub fn separate(&self) -> (Matrix<T, N, N>, Matrix<T, N, N>) {
let mut l = Matrix::<T, N, N>::identity();
let mut u = self.lu; // lu
for m in 1..N {
for n in 0..m {
// iterate over lower diagonal
l[(m, n)] = u[(m, n)];
u[(m, n)] = T::zero();
}
}
(l, u)
}
}
pub trait LUSolve<T, const N: usize>
where
T: Copy + Default + Real + Product + Sum,
{
#[must_use]
fn lu(&self) -> Option<LUDecomp<T, N>>;
#[must_use]
fn inverse(&self) -> Option<Matrix<T, N, N>>;
#[must_use]
fn det(&self) -> T;
#[must_use]
fn solve<const M: usize>(&self, b: &Matrix<T, N, M>) -> Option<Matrix<T, N, M>> {
Some(self.lu()?.solve(b))
}
}

View File

@ -8,7 +8,7 @@ use num_traits::Zero;
use std::fmt::Debug;
use std::iter::{zip, Product, Sum};
use std::ops;
use vector_victor::{LUSolve, Matrix, Vector};
use vector_victor::{Matrix, Vector};
#[parameterize(S = (i32, f32, f64, u32), M = [1,4], N = [1,4])]
#[test]

View File

@ -7,7 +7,8 @@ use num_traits::real::Real;
use num_traits::Zero;
use std::fmt::Debug;
use std::iter::{zip, Product, Sum};
use vector_victor::{LUSolve, Matrix, Vector};
use vector_victor::solve::{LUDecomp, LUSolve};
use vector_victor::{Matrix, Vector};
#[parameterize(S = (f32, f64), M = [1,2,3,4])]
#[test]
@ -18,13 +19,13 @@ fn test_lu_identity<S: Default + Approx + Real + Debug + Product + Sum, const M:
let i = Matrix::<S, M, M>::identity();
let ones = Vector::<S, M>::fill(S::one());
let decomp = i.lu().expect("Singular matrix encountered");
let (lu, idx, d) = decomp;
let LUDecomp { lu, idx, parity } = decomp;
assert_eq!(lu, i, "Incorrect LU decomposition");
assert!(
(0..M).eq(idx.elements().cloned()),
"Incorrect permutation matrix",
);
assert_approx!(d, S::one(), "Incorrect permutation parity");
assert_approx!(parity, S::one(), "Incorrect permutation parity");
// Check determinant calculation which uses LU decomposition
assert_approx!(
@ -76,14 +77,13 @@ fn test_lu_singular<S: Default + Real + Debug + Product + Sum, const M: usize>()
fn test_lu_2x2() {
let a = Matrix::new([[1.0, 2.0], [3.0, 0.0]]);
let decomp = a.lu().expect("Singular matrix encountered");
let (_lu, idx, _d) = decomp;
// the decomposition is non-unique, due to the combination of lu and idx.
// Instead of checking the exact value, we only check the results.
// Also check if they produce the same results with both methods, since the
// Matrix<> methods use shortcuts the decomposition methods don't
let (l, u) = decomp.separate();
assert_approx!(l.mmul(&u), a.permute_rows(&idx));
assert_approx!(l.mmul(&u), a.permute_rows(&decomp.idx));
assert_approx!(a.det(), -6.0);
assert_approx!(a.det(), decomp.det());
@ -100,10 +100,9 @@ fn test_lu_2x2() {
fn test_lu_3x3() {
let a = Matrix::new([[1.0, -5.0, 8.0], [1.0, -2.0, 1.0], [2.0, -1.0, -4.0]]);
let decomp = a.lu().expect("Singular matrix encountered");
let (_lu, idx, _d) = decomp;
let (l, u) = decomp.separate();
assert_approx!(l.mmul(&u), a.permute_rows(&idx));
assert_approx!(l.mmul(&u), a.permute_rows(&decomp.idx));
assert_approx!(a.det(), 3.0);
assert_approx!(a.det(), decomp.det());