diff --git a/src/lib.rs b/src/lib.rs index 487e657..3747c73 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,4 +4,4 @@ pub mod index; mod macros; mod matrix; -pub use matrix::{Matrix, Vector}; +pub use matrix::{LUSolve, Matrix, Vector}; diff --git a/src/macros/ops.rs b/src/macros/ops.rs index 9613382..18f9707 100644 --- a/src/macros/ops.rs +++ b/src/macros/ops.rs @@ -56,6 +56,7 @@ macro_rules! impl_matrix_op { }; } +#[doc(hidden)] #[macro_export] macro_rules! _impl_op_m_internal_ex { ($ops_trait:ident, $ops_fn:ident) => { @@ -64,6 +65,7 @@ macro_rules! _impl_op_m_internal_ex { } } +#[doc(hidden)] #[macro_export] macro_rules! _impl_op_mm_internal_ex { ($ops_trait:ident, $ops_fn:ident) => { @@ -74,6 +76,7 @@ macro_rules! _impl_op_mm_internal_ex { } } +#[doc(hidden)] #[macro_export] macro_rules! _impl_opassign_mm_internal_ex { ($ops_trait:ident, $ops_fn:ident) => { @@ -82,6 +85,7 @@ macro_rules! _impl_opassign_mm_internal_ex { } } +#[doc(hidden)] #[macro_export] macro_rules! _impl_op_ms_internal_ex { ($ops_trait:ident, $ops_fn:ident) => { @@ -90,6 +94,7 @@ macro_rules! _impl_op_ms_internal_ex { } } +#[doc(hidden)] #[macro_export] macro_rules! _impl_opassign_ms_internal_ex { ($ops_trait:ident, $ops_fn:ident) => { @@ -97,6 +102,7 @@ macro_rules! _impl_opassign_ms_internal_ex { } } +#[doc(hidden)] #[macro_export] macro_rules! _impl_op_m_internal { ($ops_trait:ident, $ops_fn:ident, $lhs:ty, $out:ty) => { @@ -120,6 +126,7 @@ macro_rules! _impl_op_m_internal { }; } +#[doc(hidden)] #[macro_export] macro_rules! _impl_op_mm_internal { ($ops_trait:ident, $ops_fn:ident, $lhs:ty, $rhs:ty, $out:ty) => { @@ -144,6 +151,7 @@ macro_rules! _impl_op_mm_internal { }; } +#[doc(hidden)] #[macro_export] macro_rules! _impl_opassign_mm_internal { ($ops_trait:ident, $ops_fn:ident, $lhs:ty, $rhs:ty, $out:ty) => { @@ -164,6 +172,7 @@ macro_rules! _impl_opassign_mm_internal { }; } +#[doc(hidden)] #[macro_export] macro_rules! _impl_op_ms_internal { ($ops_trait:ident, $ops_fn:ident, $lhs:ty, $rhs:ty, $out:ty) => { @@ -188,6 +197,7 @@ macro_rules! _impl_op_ms_internal { }; } +#[doc(hidden)] #[macro_export] macro_rules! _impl_opassign_ms_internal { ($ops_trait:ident, $ops_fn:ident, $lhs:ty, $rhs:ty, $out:ty) => { diff --git a/src/matrix.rs b/src/matrix.rs index 8929da4..f546353 100644 --- a/src/matrix.rs +++ b/src/matrix.rs @@ -8,20 +8,6 @@ use std::iter::{zip, Flatten, Product, Sum}; use std::ops::{Add, AddAssign, Deref, DerefMut, Index, IndexMut, Mul, MulAssign, Neg, Sub}; -/// A Scalar that a [Matrix] can be made up of. -/// -/// This trait has no associated functions and can be implemented on any type that is [Default] and -/// [Copy] and has a static lifetime. -// pub trait Scalar: Default + Copy + 'static {} -// macro_rules! multi_impl { ($name:ident for $($t:ty),*) => ($( impl $name for $t {} )*) } -// multi_impl!(Scalar for i8, i16, i32, i64, i128, isize, u8, u16, u32, u64, u128, usize, f32, f64); -// impl Scalar for &'static T -// where -// T: Scalar, -// &'static T: Default, -// { -// } - /// A 2D array of values which can be operated upon. /// /// Matrices have a fixed size known at compile time, and must be made up of types that implement @@ -415,12 +401,43 @@ impl Matrix { } } -// Matrix Decomposition and related functions -impl Matrix -where - T: Real + Default, -{ - pub fn lu(&self) -> Option<(Self, Vector, T)> { +// Square matrix impls +impl Matrix { + #[must_use] + pub fn identity() -> Self + where + T: Zero + One, + { + let mut result = Self::zero(); + for i in 0..N { + result[(i, i)] = T::one(); + } + return result; + } + + #[must_use] + pub fn diagonals<'s>(&'s self) -> impl Iterator + 's { + (0..N).map(|n| self[(n, n)]) + } + + #[must_use] + pub fn subdiagonals<'s>(&'s self) -> impl Iterator + 's { + (0..N - 1).map(|n| self[(n, n + 1)]) + } + + #[must_use] + /// + /// + /// + /// a + /// 3 + /// + /// + pub fn lu(&self) -> Option<(Self, Vector, T)> + where + T: Real + Default, + { + // Implementation from Numerical Recipes §2.3 let mut lu = self.clone(); let mut idx: Vector = Default::default(); let mut d = T::one(); @@ -438,7 +455,7 @@ where for k in 0..N { // search for the pivot element and its index - let ipivot = (lu.col(k)? * vv) + let (ipivot, _) = (lu.col(k)? * vv) .abs() .elements() .enumerate() @@ -447,8 +464,7 @@ where // Is the figure of merit for the pivot better than the best so far? true => (i, x), false => (imax, xmax), - })? - .0; + })?; // do we need to interchange rows? if k != ipivot { @@ -456,8 +472,8 @@ where d = -d; // change parity of d vv[ipivot] = vv[k] //interchange scale factor } - idx[k] = ipivot; + idx[k] = ipivot; let pivot = lu[(k, k)]; if pivot.abs() < T::epsilon() { // if the pivot is zero, the matrix is singular @@ -470,7 +486,7 @@ where lu[(i, k)] = dpivot; for j in (k + 1)..N { // reduce remaining submatrix - lu[(i, j)] = lu[(i, j)] - dpivot * lu[(k, j)]; + lu[(i, j)] = lu[(i, j)] - (dpivot * lu[(k, j)]); } } } @@ -478,36 +494,112 @@ where return Some((lu, idx, d)); } - fn inverse(&self) -> Option { - todo!() + #[must_use] + pub fn inverse(&self) -> Option + where + T: Real + Default + Sum, + { + self.solve(&Self::identity()) } - // fn det(&self) -> Self::Scalar { - // todo!() - // } + #[must_use] + pub fn det(&self) -> T + where + T: Real + Default + Product + Sum, + { + match N { + 0 => T::zero(), + 1 => self[0], + 2 => (self[(0, 0)] * self[(1, 1)]) - (self[(0, 1)] * self[(1, 0)]), + 3 => { + // use rule of Sarrus + (0..N) // starting column + .map(|i| { + let dn = (0..N) + .map(|j| -> T { self[(j, (j + i) % N)] }) + .product::(); + let up = (0..N) + .map(|j| -> T { self[(N - j - 1, (j + i) % N)] }) + .product::(); + dn - up + }) + .sum::() + } + _ => { + // use LU decomposition + if let Some((lu, _, d)) = self.lu() { + d * lu.diagonals().product() + } else { + T::zero() + } + } + } + } + + #[must_use] + pub fn solve(&self, b: &Matrix) -> Option> + where + T: Real + Default + Sum, + { + Some(self.lu()?.solve(b)) + } +} + +pub trait LUSolve: Copy { + fn solve(&self, rhs: &R) -> R; +} + +impl LUSolve> + for (Matrix, Vector, T) +where + for<'t> T: Real + Default + Sum, +{ + #[must_use] + fn solve(&self, b: &Matrix) -> Matrix { + let (lu, idx, _) = self; + Matrix::::from_cols(b.cols().map(|mut x| { + // Implementation from Numerical Recipes §2.3 + + // When ii is set to a positive value, + // it will become the index of the first nonvanishing element of b + let mut ii = 0usize; + for i in 0..N { + // forward substitution + let ip = idx[i]; // i permuted + let sum = x[ip]; + x[ip] = x[i]; // unscramble as we go + if ii > 0 { + x[i] = sum + - (lu.row(i).expect("Invalid row reached") * x) + .elements() + .skip(ii - 1) + .cloned() + .sum() + } else { + x[i] = sum; + if sum.abs() > T::epsilon() { + ii = i + 1; + } + } + } + for i in (0..(N - 1)).rev() { + // back substitution + let sum = x[i] + - (lu.row(i).expect("Invalid row reached") * x) + .elements() + .skip(ii - 1) + .cloned() + .sum(); + + x[i] = sum / lu[(i, i)] + } + x + })) + } } // Square matrices -impl Matrix { - pub fn identity() -> Self - where - T: Zero + One, - { - let mut result = Self::zero(); - for i in 0..N { - result[(i, i)] = T::one(); - } - return result; - } - - pub fn diagonals<'s>(&'s self) -> impl Iterator + 's { - (0..N).map(|n| self[(n, n)]) - } - - pub fn subdiagonals<'s>(&'s self) -> impl Iterator + 's { - (0..N - 1).map(|n| self[(n, n + 1)]) - } -} +impl Matrix {} // Index impl Index for Matrix