From 543769f6912ab5ea93ec4b3aa37522076fcddb35 Mon Sep 17 00:00:00 2001 From: Andrew Cassidy Date: Sat, 6 May 2023 01:34:31 -0700 Subject: [PATCH] Separate LU decomposition into its own file where other solving stuff will live --- src/lib.rs | 3 +- src/matrix.rs | 181 ++------------------------- src/solve.rs | 153 ++++++++++++++++++++++ tests/ops.rs | 2 +- tests/{decomposition.rs => solve.rs} | 13 +- 5 files changed, 172 insertions(+), 180 deletions(-) create mode 100644 src/solve.rs rename tests/{decomposition.rs => solve.rs} (91%) diff --git a/src/lib.rs b/src/lib.rs index 514bf19..5538c77 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,6 +3,7 @@ extern crate core; pub mod index; mod macros; mod matrix; +pub mod solve; mod util; -pub use matrix::{LUSolve, Matrix, Vector}; +pub use matrix::{Matrix, Vector}; diff --git a/src/matrix.rs b/src/matrix.rs index d8d5ead..105d714 100644 --- a/src/matrix.rs +++ b/src/matrix.rs @@ -7,6 +7,7 @@ use num_traits::{Num, NumOps, One, Zero}; use std::fmt::Debug; use std::iter::{zip, Flatten, Product, Sum}; +use crate::solve::{LUDecomp, LUSolve}; use std::ops::{Add, AddAssign, Deref, DerefMut, Index, IndexMut, Mul, MulAssign, Neg}; /// A 2D array of values which can be operated upon. @@ -398,85 +399,17 @@ impl Matrix { pub fn subdiagonals<'s>(&'s self) -> impl Iterator + 's { (0..N - 1).map(|n| self[(n, n + 1)]) } +} - /// Returns `Some(lu, idx, d)`, or [None] if the matrix is singular. - /// - /// Where: - /// * `lu`: The LU decomposition of `self`. The upper and lower matrices are combined into a single matrix - /// * `idx`: The permutation of rows on the original matrix needed to perform the decomposition. - /// Each element is the corresponding row index in the original matrix - /// * `d`: The permutation parity of `idx`, either `1` for even or `-1` for odd - /// - /// The resulting tuple (once unwrapped) has the [LUSolve] trait, allowing it to be used for - /// solving multiple matrices without having to repeat the LU decomposition process - #[must_use] - pub fn lu(&self) -> Option<(Self, Vector, T)> - where - T: Real + Default, - { - // Implementation from Numerical Recipes §2.3 - let mut lu = self.clone(); - let mut idx: Vector = (0..N).collect(); - let mut d = T::one(); - - let mut vv: Vector = self - .rows() - .map(|row| { - let m = row.elements().cloned().reduce(|acc, x| acc.max(x.abs()))?; - match m < T::epsilon() { - true => None, - false => Some(T::one() / m), - } - }) - .collect::>()?; // get the inverse maxabs value in each row - - for k in 0..N { - // search for the pivot element and its index - let (ipivot, _) = (lu.col(k) * vv) - .abs() - .elements() - .enumerate() - .skip(k) // below the diagonal - .reduce(|(imax, xmax), (i, x)| match x > xmax { - // Is the figure of merit for the pivot better than the best so far? - true => (i, x), - false => (imax, xmax), - })?; - - // do we need to interchange rows? - if k != ipivot { - lu.pivot_row(ipivot, k); // yes, we do - idx.pivot_row(ipivot, k); - d = -d; // change parity of d - vv[ipivot] = vv[k] //interchange scale factor - } - - let pivot = lu[(k, k)]; - if pivot.abs() < T::epsilon() { - // if the pivot is zero, the matrix is singular - return None; - }; - - for i in (k + 1)..N { - // divide by the pivot element - let dpivot = lu[(i, k)] / pivot; - lu[(i, k)] = dpivot; - for j in (k + 1)..N { - // reduce remaining submatrix - lu[(i, j)] = lu[(i, j)] - (dpivot * lu[(k, j)]); - } - } - } - - return Some((lu, idx, d)); +impl LUSolve for Matrix +where + T: Copy + Default + Real + Sum + Product, +{ + fn lu(&self) -> Option> { + LUDecomp::decompose(self) } - /// Computes the inverse matrix of `self`, or [None] if the matrix cannot be inverted. - #[must_use] - pub fn inverse(&self) -> Option - where - T: Real + Default + Sum + Product, - { + fn inverse(&self) -> Option> { match N { 1 => Some(Self::fill(checked_inv(self[0])?)), 2 => { @@ -491,12 +424,7 @@ impl Matrix { } } - /// Computes the determinant of `self`. - #[must_use] - pub fn det(&self) -> T - where - T: Real + Default + Product + Sum, - { + fn det(&self) -> T { match N { 1 => self[0], 2 => (self[(0, 0)] * self[(1, 1)]) - (self[(0, 1)] * self[(1, 0)]), @@ -520,95 +448,6 @@ impl Matrix { } } } - /// Solves a system of `Ax = b` using `self` for `A`, or [None] if there is no solution. - #[must_use] - pub fn solve(&self, b: &Matrix) -> Option> - where - T: Real + Default + Sum + Product, - { - Some(self.lu()?.solve(b)) - } -} - -/// Trait for the result of [Matrix::lu()], -/// allowing a single LU decomposition to be used to solve multiple equations -pub trait LUSolve: Copy -where - T: Real + Copy, -{ - /// Solves a system of `Ax = b` using an LU decomposition. - fn solve(&self, rhs: &Matrix) -> Matrix; - - /// Solves the determinant using an LU decomposition, - /// by multiplying the product of the diagonals by the permutation parity - fn det(&self) -> T; - - /// Solves the inverse of the matrix that the LU decomposition represents. - fn inverse(&self) -> Matrix { - return self.solve(&Matrix::::identity()); - } - - /// Separate the lu decomposition into L and U matrices, such that `L*U = P*A`. - fn separate(&self) -> (Matrix, Matrix); -} - -impl LUSolve for (Matrix, Vector, T) -where - T: Real + Default + Sum + Product, -{ - #[must_use] - fn solve(&self, b: &Matrix) -> Matrix { - let (lu, idx, _) = self; - let bp = b.permute_rows(idx); - - Matrix::from_cols(bp.cols().map(|mut x| { - // Implementation from Numerical Recipes §2.3 - // When ii is set to a positive value, - // it will become the index of the first nonvanishing element of b - let mut ii = 0usize; - for i in 0..N { - // forward substitution using L - let mut sum = x[i]; - if ii != 0 { - for j in (ii - 1)..i { - sum = sum - (lu[(i, j)] * x[j]); - } - } else if sum.abs() > T::epsilon() { - ii = i + 1; - } - x[i] = sum; - } - for i in (0..N).rev() { - // back substitution using U - let mut sum = x[i]; - for j in (i + 1)..N { - sum = sum - (lu[(i, j)] * x[j]); - } - x[i] = sum / lu[(i, i)] - } - x - })) - } - - fn det(&self) -> T { - let (lu, _, d) = self; - *d * lu.diagonals().product() - } - - fn separate(&self) -> (Matrix, Matrix) { - let mut l = Matrix::::identity(); - let mut u = self.0; // lu - - for m in 1..N { - for n in 0..m { - // iterate over lower diagonal - l[(m, n)] = u[(m, n)]; - u[(m, n)] = T::zero(); - } - } - - (l, u) - } } // Index diff --git a/src/solve.rs b/src/solve.rs new file mode 100644 index 0000000..39a6c47 --- /dev/null +++ b/src/solve.rs @@ -0,0 +1,153 @@ +use crate::util::checked_inv; +use crate::Matrix; +use crate::Vector; +use num_traits::real::Real; +use num_traits::{One, Zero}; +use std::iter::{Product, Sum}; +use std::ops::Index; + +#[derive(Copy, Clone, Debug, PartialEq)] +pub struct LUDecomp { + pub lu: Matrix, + pub idx: Vector, + pub parity: T, +} + +impl LUDecomp +where + T: Real + Default + Sum + Product, +{ + #[must_use] + pub fn decompose(m: &Matrix) -> Option { + // Implementation from Numerical Recipes §2.3 + let mut lu = m.clone(); + let mut idx: Vector = (0..N).collect(); + let mut parity = T::one(); + + let mut vv: Vector = m + .rows() + .map(|row| { + let m = row.elements().cloned().reduce(|acc, x| acc.max(x.abs()))?; + match m < T::epsilon() { + true => None, + false => Some(T::one() / m), + } + }) + .collect::>()?; // get the inverse maxabs value in each row + + for k in 0..N { + // search for the pivot element and its index + let (ipivot, _) = (lu.col(k) * vv) + .abs() + .elements() + .enumerate() + .skip(k) // below the diagonal + .reduce(|(imax, xmax), (i, x)| match x > xmax { + // Is the figure of merit for the pivot better than the best so far? + true => (i, x), + false => (imax, xmax), + })?; + + // do we need to interchange rows? + if k != ipivot { + lu.pivot_row(ipivot, k); // yes, we do + idx.pivot_row(ipivot, k); + parity = -parity; // change parity of d + vv[ipivot] = vv[k] //interchange scale factor + } + + let pivot = lu[(k, k)]; + if pivot.abs() < T::epsilon() { + // if the pivot is zero, the matrix is singular + return None; + }; + + for i in (k + 1)..N { + // divide by the pivot element + let dpivot = lu[(i, k)] / pivot; + lu[(i, k)] = dpivot; + for j in (k + 1)..N { + // reduce remaining submatrix + lu[(i, j)] = lu[(i, j)] - (dpivot * lu[(k, j)]); + } + } + } + + return Some(Self { lu, idx, parity }); + } + + #[must_use] + pub fn solve(&self, b: &Matrix) -> Matrix { + let b_permuted = b.permute_rows(&self.idx); + + Matrix::from_cols(b_permuted.cols().map(|mut x| { + // Implementation from Numerical Recipes §2.3 + // When ii is set to a positive value, + // it will become the index of the first nonvanishing element of b + let mut ii = 0usize; + for i in 0..N { + // forward substitution using L + let mut sum = x[i]; + if ii != 0 { + for j in (ii - 1)..i { + sum = sum - (self.lu[(i, j)] * x[j]); + } + } else if sum.abs() > T::epsilon() { + ii = i + 1; + } + x[i] = sum; + } + for i in (0..N).rev() { + // back substitution using U + let mut sum = x[i]; + for j in (i + 1)..N { + sum = sum - (self.lu[(i, j)] * x[j]); + } + x[i] = sum / self.lu[(i, i)] + } + x + })) + } + + pub fn det(&self) -> T { + self.parity * self.lu.diagonals().product() + } + + pub fn inverse(&self) -> Matrix { + return self.solve(&Matrix::::identity()); + } + + pub fn separate(&self) -> (Matrix, Matrix) { + let mut l = Matrix::::identity(); + let mut u = self.lu; // lu + + for m in 1..N { + for n in 0..m { + // iterate over lower diagonal + l[(m, n)] = u[(m, n)]; + u[(m, n)] = T::zero(); + } + } + + (l, u) + } +} + +pub trait LUSolve +where + T: Copy + Default + Real + Product + Sum, +{ + #[must_use] + fn lu(&self) -> Option>; + + #[must_use] + fn inverse(&self) -> Option>; + + #[must_use] + fn det(&self) -> T; + + #[must_use] + fn solve(&self, b: &Matrix) -> Option> { + Some(self.lu()?.solve(b)) + } +} diff --git a/tests/ops.rs b/tests/ops.rs index 140c69a..180e8e4 100644 --- a/tests/ops.rs +++ b/tests/ops.rs @@ -8,7 +8,7 @@ use num_traits::Zero; use std::fmt::Debug; use std::iter::{zip, Product, Sum}; use std::ops; -use vector_victor::{LUSolve, Matrix, Vector}; +use vector_victor::{Matrix, Vector}; #[parameterize(S = (i32, f32, f64, u32), M = [1,4], N = [1,4])] #[test] diff --git a/tests/decomposition.rs b/tests/solve.rs similarity index 91% rename from tests/decomposition.rs rename to tests/solve.rs index 5f42744..35bee25 100644 --- a/tests/decomposition.rs +++ b/tests/solve.rs @@ -7,7 +7,8 @@ use num_traits::real::Real; use num_traits::Zero; use std::fmt::Debug; use std::iter::{zip, Product, Sum}; -use vector_victor::{LUSolve, Matrix, Vector}; +use vector_victor::solve::{LUDecomp, LUSolve}; +use vector_victor::{Matrix, Vector}; #[parameterize(S = (f32, f64), M = [1,2,3,4])] #[test] @@ -18,13 +19,13 @@ fn test_lu_identity::identity(); let ones = Vector::::fill(S::one()); let decomp = i.lu().expect("Singular matrix encountered"); - let (lu, idx, d) = decomp; + let LUDecomp { lu, idx, parity } = decomp; assert_eq!(lu, i, "Incorrect LU decomposition"); assert!( (0..M).eq(idx.elements().cloned()), "Incorrect permutation matrix", ); - assert_approx!(d, S::one(), "Incorrect permutation parity"); + assert_approx!(parity, S::one(), "Incorrect permutation parity"); // Check determinant calculation which uses LU decomposition assert_approx!( @@ -76,14 +77,13 @@ fn test_lu_singular() fn test_lu_2x2() { let a = Matrix::new([[1.0, 2.0], [3.0, 0.0]]); let decomp = a.lu().expect("Singular matrix encountered"); - let (_lu, idx, _d) = decomp; // the decomposition is non-unique, due to the combination of lu and idx. // Instead of checking the exact value, we only check the results. // Also check if they produce the same results with both methods, since the // Matrix<> methods use shortcuts the decomposition methods don't let (l, u) = decomp.separate(); - assert_approx!(l.mmul(&u), a.permute_rows(&idx)); + assert_approx!(l.mmul(&u), a.permute_rows(&decomp.idx)); assert_approx!(a.det(), -6.0); assert_approx!(a.det(), decomp.det()); @@ -100,10 +100,9 @@ fn test_lu_2x2() { fn test_lu_3x3() { let a = Matrix::new([[1.0, -5.0, 8.0], [1.0, -2.0, 1.0], [2.0, -1.0, -4.0]]); let decomp = a.lu().expect("Singular matrix encountered"); - let (_lu, idx, _d) = decomp; let (l, u) = decomp.separate(); - assert_approx!(l.mmul(&u), a.permute_rows(&idx)); + assert_approx!(l.mmul(&u), a.permute_rows(&decomp.idx)); assert_approx!(a.det(), 3.0); assert_approx!(a.det(), decomp.det());