Refactor LUDecompose slightly

This commit is contained in:
Andrew Cassidy 2023-05-21 20:55:37 -07:00
parent bc1b3f199d
commit 2b303892f7
2 changed files with 70 additions and 67 deletions

View File

@ -1,6 +1,7 @@
use crate::util::checked_inv; use crate::util::checked_inv;
use crate::{Matrix, Vector}; use crate::{Matrix, Vector};
use num_traits::real::Real; use num_traits::real::Real;
use num_traits::Signed;
use std::iter::{Product, Sum}; use std::iter::{Product, Sum};
use std::ops::{Mul, Neg, Not}; use std::ops::{Mul, Neg, Not};
@ -37,7 +38,7 @@ impl Not for Parity {
} }
} }
/// The result of the [LU decomposition](LUDecomposable::lu) of a matrix. /// The result of the [LU decomposition](LUDecompose::lu) of a matrix.
/// ///
/// This struct provides a convenient way to reuse one LU decomposition to solve multiple /// This struct provides a convenient way to reuse one LU decomposition to solve multiple
/// matrix equations. You likely do not need to worry about its contents. /// matrix equations. You likely do not need to worry about its contents.
@ -46,26 +47,26 @@ impl Not for Parity {
/// on wikipedia for more information /// on wikipedia for more information
#[derive(Copy, Clone, Debug, PartialEq)] #[derive(Copy, Clone, Debug, PartialEq)]
pub struct LUDecomposition<T: Copy, const N: usize> { pub struct LUDecomposition<T: Copy, const N: usize> {
/// The $L$ and $U$ matrices combined into one /// The $bbL$ and $bbU$ matrices combined into one
/// ///
/// for example if /// for example if
/// ///
/// $ U = [[u_{11}, u_{12}, cdots, u_{1n} ], /// $ bbU = [[u_{11}, u_{12}, cdots, u_{1n} ],
/// [0, u_{22}, cdots, u_{2n} ], /// [0, u_{22}, cdots, u_{2n} ],
/// [vdots, vdots, ddots, vdots ], /// [vdots, vdots, ddots, vdots ],
/// [0, 0, cdots, u_{mn} ]] $ /// [0, 0, cdots, u_{mn} ]] $
/// and /// and
/// $ L = [[1, 0, cdots, 0 ], /// $ bbL = [[1, 0, cdots, 0 ],
/// [l_{21}, 1, cdots, 0 ], /// [l_{21}, 1, cdots, 0 ],
/// [vdots, vdots, ddots, vdots ], /// [vdots, vdots, ddots, vdots ],
/// [l_{m1}, l_{m2}, cdots, 1 ]] $, /// [l_{m1}, l_{m2}, cdots, 1 ]] $,
/// then /// then
/// $ LU = [[u_{11}, u_{12}, cdots, u_{1n} ], /// $ bb{LU} = [[u_{11}, u_{12}, cdots, u_{1n} ],
/// [l_{21}, u_{22}, cdots, u_{2n} ], /// [l_{21}, u_{22}, cdots, u_{2n} ],
/// [vdots, vdots, ddots, vdots ], /// [vdots, vdots, ddots, vdots ],
/// [l_{m1}, l_{m2}, cdots, u_{mn} ]] $ /// [l_{m1}, l_{m2}, cdots, u_{mn} ]] $
/// ///
/// note that the diagonals of the $L$ matrix are always 1, so no information is lost /// note that the diagonals of the $bbL$ matrix are always 1, so no information is lost
pub lu: Matrix<T, N, N>, pub lu: Matrix<T, N, N>,
/// The indices of the permutation matrix $P$, such that $PxxA$ = $LxxU$ /// The indices of the permutation matrix $P$, such that $PxxA$ = $LxxU$
@ -79,13 +80,10 @@ pub struct LUDecomposition<T: Copy, const N: usize> {
pub parity: Parity, pub parity: Parity,
} }
impl<T: Copy + Default, const N: usize> LUDecomposition<T, N> impl<T: Copy + Default + Real, const N: usize> LUDecomposition<T, N> {
where /// Solve for $x$ in $bbM xx x = b$, where $bbM$ is the original matrix this is a decomposition of.
T: Real + Default + Sum + Product,
{
/// Solve for $x$ in $M xx x = b$, where $M$ is the original matrix this is a decomposition of.
/// ///
/// This is equivalent to [`LUDecomposable::solve`] while allowing the LU decomposition /// This is equivalent to [`LUDecompose::solve`] while allowing the LU decomposition
/// to be reused /// to be reused
#[must_use] #[must_use]
pub fn solve<const M: usize>(&self, b: &Matrix<T, N, M>) -> Matrix<T, N, M> { pub fn solve<const M: usize>(&self, b: &Matrix<T, N, M>) -> Matrix<T, N, M> {
@ -123,17 +121,17 @@ where
/// Calculate the determinant $|M|$ of the matrix $M$. /// Calculate the determinant $|M|$ of the matrix $M$.
/// If the matrix is singular, the determinant is 0. /// If the matrix is singular, the determinant is 0.
/// ///
/// This is equivalent to [`LUDecomposable::det`] while allowing the LU decomposition /// This is equivalent to [`LUDecompose::det`] while allowing the LU decomposition
/// to be reused /// to be reused
pub fn det(&self) -> T { pub fn det(&self) -> T {
self.parity * self.lu.diagonals().product() self.parity * self.lu.diagonals().fold(T::one(), T::mul)
} }
/// Calculate the inverse of the original matrix, such that $MxxM^{-1} = I$ /// Calculate the inverse of the original matrix, such that $bbM xx bbM^{-1} = bbI$
/// ///
/// This is equivalent to [`Matrix::inverse`] while allowing the LU decomposition to be reused /// This is equivalent to [`Matrix::inv`] while allowing the LU decomposition to be reused
#[must_use] #[must_use]
pub fn inverse(&self) -> Matrix<T, N, N> { pub fn inv(&self) -> Matrix<T, N, N> {
return self.solve(&Matrix::<T, N, N>::identity()); return self.solve(&Matrix::<T, N, N>::identity());
} }
@ -160,17 +158,14 @@ where
/// ///
/// See [LU decomposition](https://en.wikipedia.org/wiki/LU_decomposition) /// See [LU decomposition](https://en.wikipedia.org/wiki/LU_decomposition)
/// on wikipedia for more information /// on wikipedia for more information
pub trait LUDecomposable<T, const N: usize> pub trait LUDecompose<T: Copy, const N: usize> {
where
T: Copy + Default + Real + Product + Sum,
{
/// return this matrix's [`LUDecomposition`], or [`None`] if the matrix is singular. /// return this matrix's [`LUDecomposition`], or [`None`] if the matrix is singular.
/// This can be used to solve for multiple results /// This can be used to solve for multiple results
/// ///
/// ``` /// ```
/// # use vector_victor::decompose::LUDecomposable; /// # use vector_victor::decompose::LUDecompose;
/// # use vector_victor::{Matrix, Vector}; /// # use vector_victor::{Matrix, Vector};
/// let m = Matrix::new([[1.0,3.0],[2.0,4.0]]); /// let m = Matrix::mat([[1.0,3.0],[2.0,4.0]]);
/// let lu = m.lu().expect("Cannot decompose a signular matrix"); /// let lu = m.lu().expect("Cannot decompose a signular matrix");
/// ///
/// let b = Vector::vec([7.0,10.0]); /// let b = Vector::vec([7.0,10.0]);
@ -183,34 +178,35 @@ where
#[must_use] #[must_use]
fn lu(&self) -> Option<LUDecomposition<T, N>>; fn lu(&self) -> Option<LUDecomposition<T, N>>;
/// Calculate the inverse of the matrix, such that $MxxM^{-1} = I$, or [`None`] if the matrix is singular. /// Calculate the inverse of the matrix, such that $bbMxxbbM^{-1} = bbI$,
/// or [`None`] if the matrix is singular.
/// ///
/// ``` /// ```
/// # use vector_victor::decompose::LUDecomposable; /// # use vector_victor::decompose::LUDecompose;
/// # use vector_victor::Matrix; /// # use vector_victor::Matrix;
/// let m = Matrix::new([[1.0,3.0],[2.0,4.0]]); /// let m = Matrix::mat([[1.0,3.0],[2.0,4.0]]);
/// let mi = m.inverse().expect("Cannot invert a singular matrix"); /// let mi = m.inv().expect("Cannot invert a singular matrix");
/// ///
/// assert_eq!(mi, Matrix::new([[-2.0, 1.5],[1.0, -0.5]]), "unexpected inverse matrix"); /// assert_eq!(mi, Matrix::mat([[-2.0, 1.5],[1.0, -0.5]]), "unexpected inverse matrix");
/// ///
/// // multiplying a matrix by its inverse yields the identity matrix /// // multiplying a matrix by its inverse yields the identity matrix
/// assert_eq!(m.mmul(&mi), Matrix::identity()) /// assert_eq!(m.mmul(&mi), Matrix::identity())
/// ``` /// ```
#[must_use] #[must_use]
fn inverse(&self) -> Option<Matrix<T, N, N>>; fn inv(&self) -> Option<Matrix<T, N, N>>;
/// Calculate the determinant $|M|$ of the matrix $M$. /// Calculate the determinant $|M|$ of the matrix $M$.
/// If the matrix is singular, the determinant is 0 /// If the matrix is singular, the determinant is 0
#[must_use] #[must_use]
fn det(&self) -> T; fn det(&self) -> T;
/// Solve for $x$ in $M xx x = b$ /// Solve for $x$ in $bbM xx x = b$
/// ///
/// ``` /// ```
/// # use vector_victor::decompose::LUDecomposable; /// # use vector_victor::decompose::LUDecompose;
/// # use vector_victor::{Matrix, Vector}; /// # use vector_victor::{Matrix, Vector};
/// ///
/// let m = Matrix::new([[1.0,3.0],[2.0,4.0]]); /// let m = Matrix::mat([[1.0,3.0],[2.0,4.0]]);
/// let b = Vector::vec([7.0,10.0]); /// let b = Vector::vec([7.0,10.0]);
/// let x = m.solve(&b).expect("Cannot solve a singular matrix"); /// let x = m.solve(&b).expect("Cannot solve a singular matrix");
/// ///
@ -219,26 +215,26 @@ where
/// ``` /// ```
/// ///
/// $x$ does not need to be a column-vector, it can also be a 2D matrix. For example, /// $x$ does not need to be a column-vector, it can also be a 2D matrix. For example,
/// the following is another way to calculate the [inverse](LUDecomposable::inverse()) by solving for the identity matrix $I$. /// the following is another way to calculate the [inverse](LUDecompose::inv()) by solving for the identity matrix $I$.
/// ///
/// ``` /// ```
/// # use vector_victor::decompose::LUDecomposable; /// # use vector_victor::decompose::LUDecompose;
/// # use vector_victor::{Matrix, Vector}; /// # use vector_victor::{Matrix, Vector};
/// ///
/// let m = Matrix::new([[1.0,3.0],[2.0,4.0]]); /// let m = Matrix::mat([[1.0,3.0],[2.0,4.0]]);
/// let i = Matrix::<f64,2,2>::identity(); /// let i = Matrix::<f64,2,2>::identity();
/// let mi = m.solve(&i).expect("Cannot solve a singular matrix"); /// let mi = m.solve(&i).expect("Cannot solve a singular matrix");
/// ///
/// assert_eq!(mi, Matrix::new([[-2.0, 1.5],[1.0, -0.5]])); /// assert_eq!(mi, Matrix::mat([[-2.0, 1.5],[1.0, -0.5]]));
/// assert_eq!(m.mmul(&mi), i, "M x M^-1 = I"); /// assert_eq!(m.mmul(&mi), i, "M x M^-1 = I");
/// ``` /// ```
#[must_use] #[must_use]
fn solve<const M: usize>(&self, b: &Matrix<T, N, M>) -> Option<Matrix<T, N, M>>; fn solve<const M: usize>(&self, b: &Matrix<T, N, M>) -> Option<Matrix<T, N, M>>;
} }
impl<T, const N: usize> LUDecomposable<T, N> for Matrix<T, N, N> impl<T, const N: usize> LUDecompose<T, N> for Matrix<T, N, N>
where where
T: Copy + Default + Real + Sum + Product, T: Copy + Default + Real + Sum + Product + Signed,
{ {
fn lu(&self) -> Option<LUDecomposition<T, N>> { fn lu(&self) -> Option<LUDecomposition<T, N>> {
// Implementation from Numerical Recipes §2.3 // Implementation from Numerical Recipes §2.3
@ -300,7 +296,7 @@ where
return Some(LUDecomposition { lu, idx, parity }); return Some(LUDecomposition { lu, idx, parity });
} }
fn inverse(&self) -> Option<Matrix<T, N, N>> { fn inv(&self) -> Option<Matrix<T, N, N>> {
match N { match N {
1 => Some(Self::fill(checked_inv(self[0])?)), 1 => Some(Self::fill(checked_inv(self[0])?)),
2 => { 2 => {
@ -311,7 +307,7 @@ where
result[(0, 1)] = -self[(0, 1)]; result[(0, 1)] = -self[(0, 1)];
Some(result * checked_inv(self.det())?) Some(result * checked_inv(self.det())?)
} }
_ => Some(self.lu()?.inverse()), _ => Some(self.lu()?.inv()),
} }
} }

View File

@ -4,18 +4,21 @@ mod common;
use common::Approx; use common::Approx;
use generic_parameterize::parameterize; use generic_parameterize::parameterize;
use num_traits::real::Real; use num_traits::real::Real;
use num_traits::Zero; use num_traits::{Float, One, Signed, Zero};
use std::fmt::Debug; use std::fmt::Debug;
use std::iter::{Product, Sum}; use std::iter::{Product, Sum};
use vector_victor::decompose::Parity::Even; use vector_victor::decompose::{LUDecompose, LUDecomposition, Parity};
use vector_victor::decompose::{LUDecomposable, LUDecomposition};
use vector_victor::{Matrix, Vector}; use vector_victor::{Matrix, Vector};
#[parameterize(S = (f32, f64), M = [1,2,3,4])] #[parameterize(S = (f32, f64), M = [1,2,3,4])]
#[test] #[test]
/// The LU decomposition of the identity matrix should produce /// The LU decomposition of the identity matrix should produce
/// the identity matrix with no permutations and parity 1 /// the identity matrix with no permutations and parity 1
fn test_lu_identity<S: Default + Approx + Real + Debug + Product + Sum, const M: usize>() { fn test_lu_identity<S, const M: usize>()
where
Matrix<S, M, M>: LUDecompose<S, M>,
S: Copy + Real + Debug + Approx + Default,
{
// let a: Matrix<f32, 3, 3> = Matrix::<f32, 3, 3>::identity(); // let a: Matrix<f32, 3, 3> = Matrix::<f32, 3, 3>::identity();
let i = Matrix::<S, M, M>::identity(); let i = Matrix::<S, M, M>::identity();
let ones = Vector::<S, M>::fill(S::one()); let ones = Vector::<S, M>::fill(S::one());
@ -26,7 +29,7 @@ fn test_lu_identity<S: Default + Approx + Real + Debug + Product + Sum, const M:
(0..M).eq(idx.elements().cloned()), (0..M).eq(idx.elements().cloned()),
"Incorrect permutation matrix", "Incorrect permutation matrix",
); );
assert_eq!(parity, Even, "Incorrect permutation parity"); assert_eq!(parity, Parity::Even, "Incorrect permutation parity");
// Check determinant calculation which uses LU decomposition // Check determinant calculation which uses LU decomposition
assert_approx!( assert_approx!(
@ -37,7 +40,7 @@ fn test_lu_identity<S: Default + Approx + Real + Debug + Product + Sum, const M:
// Check inverse calculation with uses LU decomposition // Check inverse calculation with uses LU decomposition
assert_eq!( assert_eq!(
i.inverse(), i.inv(),
Some(i), Some(i),
"Identity matrix should be its own inverse" "Identity matrix should be its own inverse"
); );
@ -54,7 +57,11 @@ fn test_lu_identity<S: Default + Approx + Real + Debug + Product + Sum, const M:
#[parameterize(S = (f32, f64), M = [2,3,4])] #[parameterize(S = (f32, f64), M = [2,3,4])]
#[test] #[test]
/// The LU decomposition of any singular matrix should be `None` /// The LU decomposition of any singular matrix should be `None`
fn test_lu_singular<S: Default + Real + Debug + Product + Sum, const M: usize>() { fn test_lu_singular<S, const M: usize>()
where
Matrix<S, M, M>: LUDecompose<S, M>,
S: Copy + Real + Debug + Approx + Default,
{
// let a: Matrix<f32, 3, 3> = Matrix::<f32, 3, 3>::identity(); // let a: Matrix<f32, 3, 3> = Matrix::<f32, 3, 3>::identity();
let mut a = Matrix::<S, M, M>::zero(); let mut a = Matrix::<S, M, M>::zero();
let ones = Vector::<S, M>::fill(S::one()); let ones = Vector::<S, M>::fill(S::one());
@ -66,7 +73,7 @@ fn test_lu_singular<S: Default + Real + Debug + Product + Sum, const M: usize>()
S::zero(), S::zero(),
"Singular matrix should have determinant of zero" "Singular matrix should have determinant of zero"
); );
assert_eq!(a.inverse(), None, "Singular matrix should have no inverse"); assert_eq!(a.inv(), None, "Singular matrix should have no inverse");
assert_eq!( assert_eq!(
a.solve(&ones), a.solve(&ones),
None, None,
@ -76,7 +83,7 @@ fn test_lu_singular<S: Default + Real + Debug + Product + Sum, const M: usize>()
#[test] #[test]
fn test_lu_2x2() { fn test_lu_2x2() {
let a = Matrix::new([[1.0, 2.0], [3.0, 0.0]]); let a = Matrix::mat([[1.0, 2.0], [3.0, 0.0]]);
let decomp = a.lu().expect("Singular matrix encountered"); let decomp = a.lu().expect("Singular matrix encountered");
// the decomposition is non-unique, due to the combination of lu and idx. // the decomposition is non-unique, due to the combination of lu and idx.
// Instead of checking the exact value, we only check the results. // Instead of checking the exact value, we only check the results.
@ -90,16 +97,16 @@ fn test_lu_2x2() {
assert_approx!(a.det(), decomp.det()); assert_approx!(a.det(), decomp.det());
assert_approx!( assert_approx!(
a.inverse().unwrap(), a.inv().unwrap(),
Matrix::new([[0.0, 2.0], [3.0, -1.0]]) * (1.0 / 6.0) Matrix::mat([[0.0, 2.0], [3.0, -1.0]]) * (1.0 / 6.0)
); );
assert_approx!(a.inverse().unwrap(), decomp.inverse()); assert_approx!(a.inv().unwrap(), decomp.inv());
assert_approx!(a.inverse().unwrap().inverse().unwrap(), a) assert_approx!(a.inv().unwrap().inv().unwrap(), a)
} }
#[test] #[test]
fn test_lu_3x3() { fn test_lu_3x3() {
let a = Matrix::new([[1.0, -5.0, 8.0], [1.0, -2.0, 1.0], [2.0, -1.0, -4.0]]); let a = Matrix::mat([[1.0, -5.0, 8.0], [1.0, -2.0, 1.0], [2.0, -1.0, -4.0]]);
let decomp = a.lu().expect("Singular matrix encountered"); let decomp = a.lu().expect("Singular matrix encountered");
let (l, u) = decomp.separate(); let (l, u) = decomp.separate();
@ -109,9 +116,9 @@ fn test_lu_3x3() {
assert_approx!(a.det(), decomp.det()); assert_approx!(a.det(), decomp.det());
assert_approx!( assert_approx!(
a.inverse().unwrap(), a.inv().unwrap(),
Matrix::new([[9.0, -28.0, 11.0], [6.0, -20.0, 7.0], [3.0, -9.0, 3.0]]) * (1.0 / 3.0) Matrix::mat([[9.0, -28.0, 11.0], [6.0, -20.0, 7.0], [3.0, -9.0, 3.0]]) * (1.0 / 3.0)
); );
assert_approx!(a.inverse().unwrap(), decomp.inverse()); assert_approx!(a.inv().unwrap(), decomp.inv());
assert_approx!(a.inverse().unwrap().inverse().unwrap(), a) assert_approx!(a.inv().unwrap().inv().unwrap(), a)
} }