diff --git a/src/index.rs b/src/index.rs index 079d60a..6a2aeda 100644 --- a/src/index.rs +++ b/src/index.rs @@ -1,6 +1,7 @@ use std::fmt::Debug; pub trait Index2D: Copy + Debug { + #[inline(always)] fn to_1d(self, height: usize, width: usize) -> Option { let (r, c) = self.to_2d(height, width)?; Some(r * width + c) @@ -10,6 +11,7 @@ pub trait Index2D: Copy + Debug { } impl Index2D for usize { + #[inline(always)] fn to_2d(self, height: usize, width: usize) -> Option<(usize, usize)> { match self < (height * width) { true => Some((self / width, self % width)), @@ -19,6 +21,7 @@ impl Index2D for usize { } impl Index2D for (usize, usize) { + #[inline(always)] fn to_2d(self, height: usize, width: usize) -> Option<(usize, usize)> { match self.0 < height && self.1 < width { true => Some(self), diff --git a/src/lib.rs b/src/lib.rs index 3747c73..514bf19 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,5 +3,6 @@ extern crate core; pub mod index; mod macros; mod matrix; +mod util; pub use matrix::{LUSolve, Matrix, Vector}; diff --git a/src/matrix.rs b/src/matrix.rs index 20b149e..cf87f8a 100644 --- a/src/matrix.rs +++ b/src/matrix.rs @@ -1,12 +1,15 @@ use crate::impl_matrix_op; use crate::index::Index2D; +use crate::util::{checked_div, checked_inv}; use num_traits::real::Real; use num_traits::{Num, NumOps, One, Zero}; use std::fmt::Debug; use std::iter::{zip, Flatten, Product, Sum}; +use std::mem::swap; use std::ops::{Add, AddAssign, Deref, DerefMut, Index, IndexMut, Mul, MulAssign, Neg}; +use std::process::id; /// A 2D array of values which can be operated upon. /// @@ -181,7 +184,7 @@ impl Matrix { Some(&mut self.data[m][n]) } - /// Returns a row of the matrix. panics if index is out of bounds + /// Returns a row of the matrix. or [None] if index is out of bounds /// /// # Examples /// @@ -194,12 +197,15 @@ impl Matrix { /// ``` #[inline] #[must_use] - pub fn row(&self, m: usize) -> Option> { - if m < M { - Some(Vector::::vec(self.data[m])) - } else { - None - } + pub fn row(&self, m: usize) -> Vector { + assert!( + m < M, + "Row index {} out of bounds for {}x{} matrix", + m, + M, + N + ); + Vector::::vec(self.data[m]) } #[inline] @@ -211,25 +217,28 @@ impl Matrix { M, N ); - for (n, v) in val.elements().enumerate() { - self.data[m][n] = *v; + for n in 0..N { + self.data[m][n] = val.data[n][0]; } } pub fn pivot_row(&mut self, m1: usize, m2: usize) { - let tmp = self.row(m2).expect("Invalid row index"); - self.set_row(m2, &self.row(m1).expect("Invalid row index")); + let tmp = self.row(m2); + self.set_row(m2, &self.row(m1)); self.set_row(m1, &tmp); } #[inline] #[must_use] - pub fn col(&self, n: usize) -> Option> { - if n < N { - Some(Vector::::vec(self.data.map(|r| r[n]))) - } else { - None - } + pub fn col(&self, n: usize) -> Vector { + assert!( + n < N, + "Column index {} out of bounds for {}x{} matrix", + n, + M, + N + ); + Vector::::vec(self.data.map(|r| r[n])) } #[inline] @@ -242,25 +251,41 @@ impl Matrix { N ); - for (m, v) in val.elements().enumerate() { - self.data[m][n] = *v; + for m in 0..M { + self.data[m][n] = val.data[m][0]; } } pub fn pivot_col(&mut self, n1: usize, n2: usize) { - let tmp = self.col(n2).expect("Invalid column index"); - self.set_col(n2, &self.col(n1).expect("Invalid column index")); + let tmp = self.col(n2); + self.set_col(n2, &self.col(n1)); self.set_col(n1, &tmp); } #[must_use] pub fn rows<'a>(&'a self) -> impl Iterator> + 'a { - (0..M).map(|m| self.row(m).expect("invalid row reached while iterating")) + (0..M).map(|m| self.row(m)) } #[must_use] pub fn cols<'a>(&'a self) -> impl Iterator> + 'a { - (0..N).map(|n| self.col(n).expect("invalid column reached while iterating")) + (0..N).map(|n| self.col(n)) + } + + #[must_use] + pub fn permute_rows(&self, ms: &Vector) -> Self + where + T: Default, + { + Self::from_rows(ms.elements().map(|&m| self.row(m))) + } + + #[must_use] + pub fn permute_cols(&self, ns: &Vector) -> Self + where + T: Default, + { + Self::from_cols(ns.elements().map(|&n| self.col(n))) } pub fn transpose(&self) -> Matrix @@ -305,14 +330,7 @@ impl Vector { /// Create a vector from a 1D array. /// Note that vectors are always column vectors unless explicitly instantiated as row vectors /// - /// # Arguments - /// - /// * `data`: A 1D array of elements to copy into the new vector - /// - /// returns: Vector - /// /// # Examples - /// /// ``` /// # use vector_victor::{Matrix, Vector}; /// let my_vector = Vector::vec([1,2,3,4]); @@ -374,8 +392,9 @@ impl Matrix { } } -// Square matrix impls +// Square matrix implementations impl Matrix { + /// Create an identity matrix #[must_use] pub fn identity() -> Self where @@ -388,31 +407,36 @@ impl Matrix { return result; } + /// returns an iterator over the elements along the diagonal of a square matrix #[must_use] pub fn diagonals<'s>(&'s self) -> impl Iterator + 's { (0..N).map(|n| self[(n, n)]) } + /// Returns an iterator over the elements directly below the diagonal of a square matrix #[must_use] pub fn subdiagonals<'s>(&'s self) -> impl Iterator + 's { (0..N - 1).map(|n| self[(n, n + 1)]) } - #[must_use] + /// Returns `Some(lu, idx, d)`, or [None] if the matrix is singular. + /// + /// Where: + /// * `lu`: The LU decomposition of `self`. The upper and lower matrices are combined into a single matrix + /// * `idx`: The permutation of rows on the original matrix needed to perform the decomposition. + /// Each element is the corresponding row index in the original matrix + /// * `d`: The permutation parity of `idx`, either `1` for even or `-1` for odd /// - /// - /// - /// a - /// 3 - /// - /// + /// The resulting tuple (once unwrapped) has the [LUSolve] trait, allowing it to be used for + /// solving multiple matrices without having to repeat the LU decomposition process + #[must_use] pub fn lu(&self) -> Option<(Self, Vector, T)> where T: Real + Default, { // Implementation from Numerical Recipes §2.3 let mut lu = self.clone(); - let mut idx: Vector = Default::default(); + let mut idx: Vector = (0..N).collect(); let mut d = T::one(); let mut vv: Vector = self @@ -428,7 +452,7 @@ impl Matrix { for k in 0..N { // search for the pivot element and its index - let (ipivot, _) = (lu.col(k)? * vv) + let (ipivot, _) = (lu.col(k) * vv) .abs() .elements() .enumerate() @@ -442,11 +466,11 @@ impl Matrix { // do we need to interchange rows? if k != ipivot { lu.pivot_row(ipivot, k); // yes, we do + idx.pivot_row(ipivot, k); d = -d; // change parity of d vv[ipivot] = vv[k] //interchange scale factor } - idx[k] = ipivot; let pivot = lu[(k, k)]; if pivot.abs() < T::epsilon() { // if the pivot is zero, the matrix is singular @@ -467,21 +491,33 @@ impl Matrix { return Some((lu, idx, d)); } + /// Computes the inverse matrix of `self`, or [None] if the matrix cannot be inverted. #[must_use] pub fn inverse(&self) -> Option where - T: Real + Default + Sum, + T: Real + Default + Sum + Product, { - self.solve(&Self::identity()) + match N { + 1 => Some(Self::fill(checked_inv(self[0])?)), + 2 => { + let mut result = Self::default(); + result[(0, 0)] = self[(1, 1)]; + result[(1, 1)] = self[(0, 0)]; + result[(1, 0)] = -self[(1, 0)]; + result[(0, 1)] = -self[(0, 1)]; + Some(result * checked_inv(self.det())?) + } + _ => Some(self.lu()?.inverse()), + } } + /// Computes the determinant of `self`. #[must_use] pub fn det(&self) -> T where T: Real + Default + Product + Sum, { match N { - 0 => T::one(), 1 => self[0], 2 => (self[(0, 0)] * self[(1, 1)]) - (self[(0, 1)] * self[(1, 0)]), 3 => { @@ -500,37 +536,52 @@ impl Matrix { } _ => { // use LU decomposition - if let Some((lu, _, d)) = self.lu() { - d * lu.diagonals().product() - } else { - T::zero() - } + self.lu().map_or(T::zero(), |lu| lu.det()) } } } - + /// Solves a system of `Ax = b` using `self` for `A`, or [None] if there is no solution. #[must_use] pub fn solve(&self, b: &Matrix) -> Option> where - T: Real + Default + Sum, + T: Real + Default + Sum + Product, { Some(self.lu()?.solve(b)) } } -pub trait LUSolve: Copy { - fn solve(&self, rhs: &R) -> R; +/// Trait for the result of [Matrix::lu()], +/// allowing a single LU decomposition to be used to solve multiple equations +pub trait LUSolve: Copy +where + T: Real + Copy, +{ + /// Solves a system of `Ax = b` using an LU decomposition. + fn solve(&self, rhs: &Matrix) -> Matrix; + + /// Solves the determinant using an LU decomposition, + /// by multiplying the product of the diagonals by the permutation parity + fn det(&self) -> T; + + /// Solves the inverse of the matrix that the LU decomposition represents. + fn inverse(&self) -> Matrix { + return self.solve(&Matrix::::identity()); + } + + /// Separate the lu decomposition into L and U matrices, such that `L*U = P*A`. + fn separate(&self) -> (Matrix, Matrix); } -impl LUSolve> - for (Matrix, Vector, T) +impl LUSolve for (Matrix, Vector, T) where - for<'t> T: Real + Default + Sum, + T: Real + Default + Sum + Product, { #[must_use] - fn solve(&self, b: &Matrix) -> Matrix { + fn solve(&self, b: &Matrix) -> Matrix { let (lu, idx, _) = self; - Matrix::::from_cols(b.cols().map(|mut x| { + let bp = b.permute_rows(idx); + + Matrix::from_cols(bp.cols().map(|mut x| { // Implementation from Numerical Recipes §2.3 // When ii is set to a positive value, @@ -538,42 +589,48 @@ where let mut ii = 0usize; for i in 0..N { // forward substitution - let ip = idx[i]; // i permuted - let sum = x[ip]; - x[ip] = x[i]; // unscramble as we go - if ii > 0 { - x[i] = sum - - (lu.row(i).expect("Invalid row reached") * x) - .elements() - .take(i) - .skip(ii - 1) - .cloned() - .sum() - } else { - x[i] = sum; - if sum.abs() > T::epsilon() { - ii = i + 1; + let mut sum = x[i]; + if ii != 0 { + for j in (ii - 1)..i { + sum = sum - (lu[(i, j)] * x[j]); } + } else if sum.abs() > T::epsilon() { + ii = i + 1; } + x[i] = sum; } - for i in (0..(N - 1)).rev() { + for i in (0..N).rev() { // back substitution - let sum = x[i] - - (lu.row(i).expect("Invalid row reached") * x) - .elements() - .skip(i + 1) - .cloned() - .sum(); - + let mut sum = x[i]; + for j in (i + 1)..N { + sum = sum - (lu[(i, j)] * x[j]); + } x[i] = sum / lu[(i, i)] } x })) } -} -// Square matrices -impl Matrix {} + fn det(&self) -> T { + let (lu, _, d) = self; + *d * lu.diagonals().product() + } + + fn separate(&self) -> (Matrix, Matrix) { + let mut l = Matrix::::identity(); + let mut u = self.0; // lu + + for m in 1..N { + for n in 0..m { + // iterate over lower diagonal + l[(m, n)] = u[(m, n)]; + u[(m, n)] = T::zero(); + } + } + + (l, u) + } +} // Index impl Index for Matrix @@ -583,6 +640,7 @@ where { type Output = T; + #[inline(always)] fn index(&self, index: I) -> &Self::Output { self.get(index).expect(&*format!( "index {:?} out of range for {}x{} Matrix", @@ -597,6 +655,7 @@ where I: Index2D, T: Copy, { + #[inline(always)] fn index_mut(&mut self, index: I) -> &mut Self::Output { self.get_mut(index).expect(&*format!( "index {:?} out of range for {}x{} Matrix", diff --git a/src/util.rs b/src/util.rs new file mode 100644 index 0000000..2db0147 --- /dev/null +++ b/src/util.rs @@ -0,0 +1,13 @@ +use num_traits::{Num, NumOps, One, Zero}; +use std::ops::Div; + +pub fn checked_div, R: Num + Zero, T>(num: L, den: R) -> Option { + if den.is_zero() { + return None; + } + return Some(num / den); +} + +pub fn checked_inv + Zero + One>(den: T) -> Option { + return checked_div(T::one(), den); +} diff --git a/tests/ops.rs b/tests/ops.rs index 0fc4908..89dc2fc 100644 --- a/tests/ops.rs +++ b/tests/ops.rs @@ -1,9 +1,10 @@ use generic_parameterize::parameterize; use num_traits::real::Real; +use num_traits::Zero; use std::fmt::Debug; use std::iter::{Product, Sum}; use std::ops; -use vector_victor::{Matrix, Vector}; +use vector_victor::{LUSolve, Matrix, Vector}; #[parameterize(S = (i32, f32, u32), M = [1,4], N = [1,4])] #[test] @@ -25,40 +26,54 @@ fn test_lu_identity() // let a: Matrix = Matrix::::identity(); let i = Matrix::::identity(); let ones = Vector::::fill(S::one()); - let (lu, idx, d) = i.lu().expect("Singular matrix encountered"); - assert_eq!( - lu, - i, - "Incorrect LU decomposition matrix for {m}x{m} identity matrix", - m = M - ); + let decomp = i.lu().expect("Singular matrix encountered"); + let (lu, idx, d) = decomp; + assert_eq!(lu, i, "Incorrect LU decomposition"); assert!( (0..M).eq(idx.elements().cloned()), - "Incorrect permutation matrix result for {m}x{m} identity matrix", - m = M - ); - assert_eq!( - d, - S::one(), - "Incorrect permutation parity for {m}x{m} identity matrix", - m = M - ); - assert_eq!( - i.det(), - S::one(), - "Incorrect determinant for {m}x{m} identity matrix", - m = M + "Incorrect permutation matrix", ); + assert_eq!(d, S::one(), "Incorrect permutation parity"); + assert_eq!(i.det(), S::one()); + assert_eq!(i.inverse(), Some(i)); + assert_eq!(i.solve(&ones), Some(ones)); + assert_eq!(decomp.separate(), (i, i)); +} + +#[parameterize(S = (f32, f64), M = [2,3,4])] +#[test] +fn test_lu_singular() { + // let a: Matrix = Matrix::::identity(); + let mut a = Matrix::::zero(); + let ones = Vector::::fill(S::one()); + a.set_row(0, &ones); + + assert_eq!(a.lu(), None, "Matrix should be singular"); + assert_eq!(a.det(), S::zero()); + assert_eq!(a.inverse(), None); + assert_eq!(a.solve(&ones), None) +} + +#[test] +fn test_lu_2x2() { + let a = Matrix::new([[1.0, 2.0], [3.0, 0.0]]); + let decomp = a.lu().expect("Singular matrix encountered"); + let (lu, idx, d) = decomp; + // the decomposition is non-unique, due to the combination of lu and idx. + // Instead of checking the exact value, we only check the results. + // Also check if they produce the same results with both methods, since the + // Matrix<> methods use shortcuts the decomposition methods don't + + let (l, u) = decomp.separate(); + assert_eq!(l.mmul(&u), a.permute_rows(&idx)); + + assert_eq!(a.det(), -6.0); + assert_eq!(a.det(), decomp.det()); + assert_eq!( - i.inverse(), - Some(i), - "Incorrect inverse for {m}x{m} identity matrix", - m = M + a.inverse(), + Some(Matrix::new([[0.0, 2.0], [3.0, -1.0]]) * (1.0 / 6.0)) ); - assert_eq!( - i.solve(&ones), - Some(ones), - "Incorrect solve result for {m}x{m} identity matrix", - m = M - ) + assert_eq!(a.inverse(), Some(decomp.inverse())); + assert_eq!(a.inverse().unwrap().inverse().unwrap(), a) }