diff --git a/src/solve.rs b/src/decompose.rs similarity index 70% rename from src/solve.rs rename to src/decompose.rs index 39a6c47..2cf12c0 100644 --- a/src/solve.rs +++ b/src/decompose.rs @@ -2,18 +2,16 @@ use crate::util::checked_inv; use crate::Matrix; use crate::Vector; use num_traits::real::Real; -use num_traits::{One, Zero}; use std::iter::{Product, Sum}; -use std::ops::Index; #[derive(Copy, Clone, Debug, PartialEq)] -pub struct LUDecomp { +pub struct LUDecomposition { pub lu: Matrix, pub idx: Vector, pub parity: T, } -impl LUDecomp +impl LUDecomposition where T: Real + Default + Sum + Product, { @@ -133,12 +131,12 @@ where } } -pub trait LUSolve +pub trait LUDecomposable where T: Copy + Default + Real + Product + Sum, { #[must_use] - fn lu(&self) -> Option>; + fn lu(&self) -> Option>; #[must_use] fn inverse(&self) -> Option>; @@ -147,6 +145,57 @@ where fn det(&self) -> T; #[must_use] + fn solve(&self, b: &Matrix) -> Option>; +} + +impl LUDecomposable for Matrix +where + T: Copy + Default + Real + Sum + Product, +{ + fn lu(&self) -> Option> { + LUDecomposition::decompose(self) + } + + fn inverse(&self) -> Option> { + match N { + 1 => Some(Self::fill(checked_inv(self[0])?)), + 2 => { + let mut result = Self::default(); + result[(0, 0)] = self[(1, 1)]; + result[(1, 1)] = self[(0, 0)]; + result[(1, 0)] = -self[(1, 0)]; + result[(0, 1)] = -self[(0, 1)]; + Some(result * checked_inv(self.det())?) + } + _ => Some(self.lu()?.inverse()), + } + } + + fn det(&self) -> T { + match N { + 1 => self[0], + 2 => (self[(0, 0)] * self[(1, 1)]) - (self[(0, 1)] * self[(1, 0)]), + 3 => { + // use rule of Sarrus + (0..N) // starting column + .map(|i| { + let dn = (0..N) + .map(|j| -> T { self[(j, (j + i) % N)] }) + .product::(); + let up = (0..N) + .map(|j| -> T { self[(N - j - 1, (j + i) % N)] }) + .product::(); + dn - up + }) + .sum::() + } + _ => { + // use LU decomposition + self.lu().map_or(T::zero(), |lu| lu.det()) + } + } + } + fn solve(&self, b: &Matrix) -> Option> { Some(self.lu()?.solve(b)) } diff --git a/src/lib.rs b/src/lib.rs index 5538c77..30bbab5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,9 +1,9 @@ extern crate core; +pub mod decompose; pub mod index; mod macros; mod matrix; -pub mod solve; mod util; pub use matrix::{Matrix, Vector}; diff --git a/src/matrix.rs b/src/matrix.rs index 105d714..5f97e3a 100644 --- a/src/matrix.rs +++ b/src/matrix.rs @@ -1,13 +1,10 @@ use crate::impl_matrix_op; use crate::index::Index2D; -use crate::util::checked_inv; -use num_traits::real::Real; use num_traits::{Num, NumOps, One, Zero}; use std::fmt::Debug; use std::iter::{zip, Flatten, Product, Sum}; -use crate::solve::{LUDecomp, LUSolve}; use std::ops::{Add, AddAssign, Deref, DerefMut, Index, IndexMut, Mul, MulAssign, Neg}; /// A 2D array of values which can be operated upon. @@ -401,55 +398,6 @@ impl Matrix { } } -impl LUSolve for Matrix -where - T: Copy + Default + Real + Sum + Product, -{ - fn lu(&self) -> Option> { - LUDecomp::decompose(self) - } - - fn inverse(&self) -> Option> { - match N { - 1 => Some(Self::fill(checked_inv(self[0])?)), - 2 => { - let mut result = Self::default(); - result[(0, 0)] = self[(1, 1)]; - result[(1, 1)] = self[(0, 0)]; - result[(1, 0)] = -self[(1, 0)]; - result[(0, 1)] = -self[(0, 1)]; - Some(result * checked_inv(self.det())?) - } - _ => Some(self.lu()?.inverse()), - } - } - - fn det(&self) -> T { - match N { - 1 => self[0], - 2 => (self[(0, 0)] * self[(1, 1)]) - (self[(0, 1)] * self[(1, 0)]), - 3 => { - // use rule of Sarrus - (0..N) // starting column - .map(|i| { - let dn = (0..N) - .map(|j| -> T { self[(j, (j + i) % N)] }) - .product::(); - let up = (0..N) - .map(|j| -> T { self[(N - j - 1, (j + i) % N)] }) - .product::(); - dn - up - }) - .sum::() - } - _ => { - // use LU decomposition - self.lu().map_or(T::zero(), |lu| lu.det()) - } - } - } -} - // Index impl Index for Matrix where diff --git a/src/util.rs b/src/util.rs index 2db0147..5649688 100644 --- a/src/util.rs +++ b/src/util.rs @@ -1,4 +1,4 @@ -use num_traits::{Num, NumOps, One, Zero}; +use num_traits::{Num, One, Zero}; use std::ops::Div; pub fn checked_div, R: Num + Zero, T>(num: L, den: R) -> Option { diff --git a/tests/common/mod.rs b/tests/common/mod.rs index f301ac3..936c7da 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -29,10 +29,6 @@ impl Approx for Matrix(left: &T, right: &T) -> bool { - T::approx(left, right) -} - macro_rules! assert_approx { ($left:expr, $right:expr $(,)?) => { match (&$left, &$right) { diff --git a/tests/solve.rs b/tests/decompose.rs similarity index 96% rename from tests/solve.rs rename to tests/decompose.rs index 35bee25..e159e50 100644 --- a/tests/solve.rs +++ b/tests/decompose.rs @@ -6,8 +6,8 @@ use generic_parameterize::parameterize; use num_traits::real::Real; use num_traits::Zero; use std::fmt::Debug; -use std::iter::{zip, Product, Sum}; -use vector_victor::solve::{LUDecomp, LUSolve}; +use std::iter::{Product, Sum}; +use vector_victor::decompose::{LUDecomposable, LUDecomposition}; use vector_victor::{Matrix, Vector}; #[parameterize(S = (f32, f64), M = [1,2,3,4])] @@ -19,7 +19,7 @@ fn test_lu_identity::identity(); let ones = Vector::::fill(S::one()); let decomp = i.lu().expect("Singular matrix encountered"); - let LUDecomp { lu, idx, parity } = decomp; + let LUDecomposition { lu, idx, parity } = decomp; assert_eq!(lu, i, "Incorrect LU decomposition"); assert!( (0..M).eq(idx.elements().cloned()), diff --git a/tests/ops.rs b/tests/ops.rs index 180e8e4..5ccd12e 100644 --- a/tests/ops.rs +++ b/tests/ops.rs @@ -1,14 +1,10 @@ #[macro_use] mod common; -use common::Approx; use generic_parameterize::parameterize; -use num_traits::real::Real; -use num_traits::Zero; use std::fmt::Debug; -use std::iter::{zip, Product, Sum}; use std::ops; -use vector_victor::{Matrix, Vector}; +use vector_victor::Matrix; #[parameterize(S = (i32, f32, f64, u32), M = [1,4], N = [1,4])] #[test]