diff --git a/Cargo.toml b/Cargo.toml index fe8c8ce..65df4db 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,3 +7,4 @@ edition = "2021" [dependencies] generic_parameterize = "0.1.0" +num-traits = "0.2.15" diff --git a/src/lib.rs b/src/lib.rs index 83c6a21..8cdc4bb 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,5 +3,6 @@ extern crate core; pub mod index; mod macros; mod matrix; +mod matrix_traits; pub use matrix::{Matrix, Scalar, Vector}; diff --git a/src/macros/ops.rs b/src/macros/ops.rs index 434ff88..9613382 100644 --- a/src/macros/ops.rs +++ b/src/macros/ops.rs @@ -102,7 +102,7 @@ macro_rules! _impl_op_m_internal { ($ops_trait:ident, $ops_fn:ident, $lhs:ty, $out:ty) => { impl ::std::ops::$ops_trait for $lhs where - L: ::std::ops::$ops_trait + Scalar, + L: ::std::ops::$ops_trait + Copy, { type Output = $out; @@ -125,8 +125,8 @@ macro_rules! _impl_op_mm_internal { ($ops_trait:ident, $ops_fn:ident, $lhs:ty, $rhs:ty, $out:ty) => { impl ::std::ops::$ops_trait<$rhs> for $lhs where - L: ::std::ops::$ops_trait + Scalar, - R: Scalar, + L: ::std::ops::$ops_trait + Copy, + R: Copy, { type Output = $out; @@ -149,8 +149,8 @@ macro_rules! _impl_opassign_mm_internal { ($ops_trait:ident, $ops_fn:ident, $lhs:ty, $rhs:ty, $out:ty) => { impl ::std::ops::$ops_trait<$rhs> for $lhs where - L: ::std::ops::$ops_trait + Scalar, - R: Scalar, + L: ::std::ops::$ops_trait + Copy, + R: Copy, { #[inline(always)] fn $ops_fn(&mut self, other: $rhs) { @@ -169,8 +169,8 @@ macro_rules! _impl_op_ms_internal { ($ops_trait:ident, $ops_fn:ident, $lhs:ty, $rhs:ty, $out:ty) => { impl ::std::ops::$ops_trait<$rhs> for $lhs where - L: ::std::ops::$ops_trait + Scalar, - R: Scalar, + L: ::std::ops::$ops_trait + Copy, + R: Copy + Num, { type Output = $out; @@ -193,8 +193,8 @@ macro_rules! _impl_opassign_ms_internal { ($ops_trait:ident, $ops_fn:ident, $lhs:ty, $rhs:ty, $out:ty) => { impl ::std::ops::$ops_trait<$rhs> for $lhs where - L: ::std::ops::$ops_trait + Scalar, - R: Scalar, + L: ::std::ops::$ops_trait + Copy, + R: Copy + Num, { #[inline(always)] fn $ops_fn(&mut self, r: $rhs) { diff --git a/src/matrix.rs b/src/matrix.rs index 7c78fb3..cbcb3bf 100644 --- a/src/matrix.rs +++ b/src/matrix.rs @@ -1,8 +1,11 @@ use crate::impl_matrix_op; use crate::index::Index2D; +use crate::matrix_traits::Mult; +use num_traits::{Num, One, Zero}; use std::fmt::Debug; use std::iter::{zip, Flatten, Product, Sum}; -use std::ops::{AddAssign, Deref, DerefMut, Index, IndexMut, MulAssign}; +use std::ops::{Add, AddAssign, Deref, DerefMut, Index, IndexMut, Mul, MulAssign, Neg, Sub}; +use std::process::Output; /// A Scalar that a [Matrix] can be made up of. /// @@ -25,15 +28,29 @@ where #[derive(Debug, Copy, Clone, PartialEq)] pub struct Matrix where - T: Scalar, + T: Copy, { - data: [[T; N]; M], + data: [[T; N]; M], // Column-Major order } /// An alias for a [Matrix] with a single column pub type Vector = Matrix; -impl Matrix { +pub trait Dot { + type Output; + #[must_use] + fn dot(&self, rhs: &R) -> Output; +} + +pub trait Cross { + #[must_use] + fn cross_r(&self, rhs: &R) -> Self; + #[must_use] + fn cross_l(&self, rhs: &R) -> Self; +} + +// Simple access functions that only require T be copyable +impl Matrix { /// Generate a new matrix from a 2D Array /// /// # Arguments @@ -95,8 +112,8 @@ impl Matrix { #[must_use] pub fn from_rows(iter: I) -> Self where - Self: Default, I: IntoIterator>, + Self: Default, { let mut result = Self::default(); for (m, row) in iter.into_iter().enumerate().take(M) { @@ -124,8 +141,8 @@ impl Matrix { #[must_use] pub fn from_cols(iter: I) -> Self where - Self: Default, I: IntoIterator>, + Self: Default, { let mut result = Self::default(); for (n, col) in iter.into_iter().enumerate().take(N) { @@ -143,17 +160,17 @@ impl Matrix { /// assert!(vec![1,2,3,4].iter().eq(my_matrix.elements())) /// ``` #[must_use] - pub fn elements<'a>(&'a self) -> impl Iterator + 'a { + pub fn elements<'a>(&'a self) -> impl Iterator + 'a { self.data.iter().flatten() } /// Returns a mutable iterator over the elements of the matrix in row-major order. #[must_use] - pub fn elements_mut<'a>(&'a mut self) -> impl Iterator + 'a { + pub fn elements_mut<'a>(&'a mut self) -> impl Iterator + 'a { self.data.iter_mut().flatten() } - /// Returns a reference to the element at that position in the matrix or `None` if out of bounds. + /// Returns a reference to the element at that position in the matrix, or `None` if out of bounds. /// /// # Examples /// @@ -163,6 +180,11 @@ impl Matrix { /// /// // element at index 2 is the same as the element at (row 1, column 0). /// assert_eq!(my_matrix.get(2), my_matrix.get((1,0))); + /// + /// // my_matrix.get() is equivalent to my_matrix[], + /// // but returns an Option instead of panicking + /// assert_eq!(my_matrix.get(2), Some(&my_matrix[2])); + /// /// // index 4 is out of range, so get(4) returns None. /// assert_eq!(my_matrix.get(4), None); /// ``` @@ -173,6 +195,7 @@ impl Matrix { Some(&self.data[m][n]) } + /// Returns a mutable reference to the element at that position in the matrix, or `None` if out of bounds. #[inline] #[must_use] pub fn get_mut(&mut self, index: impl Index2D) -> Option<&mut T> { @@ -180,14 +203,28 @@ impl Matrix { Some(&mut self.data[m][n]) } + /// Returns a row of the matrix. panics if index is out of bounds + /// + /// # Examples + /// + /// ``` + /// # use vector_victor::{Matrix, Vector}; + /// let my_matrix = Matrix::new([[1,2],[3,4]]); + /// + /// // row at index 1 + /// assert_eq!(my_matrix.row(1), Vector::vec([3,4])); + /// ``` #[inline] #[must_use] - pub fn row(&self, m: usize) -> Option> { - if m < M { - Some(Vector::::vec(self.data[m])) - } else { - None - } + pub fn row(&self, m: usize) -> Vector { + assert!( + m < M, + "Row index {} out of bounds for {}x{} matrix", + m, + M, + N + ); + Vector::::vec(self.data[m]) } #[inline] @@ -231,17 +268,41 @@ impl Matrix { #[must_use] pub fn rows<'a>(&'a self) -> impl Iterator> + 'a { - (0..M).map(|m| self.row(m).expect("invalid row reached while iterating")) + (0..M).map(|m| self.row(m)) } #[must_use] pub fn cols<'a>(&'a self) -> impl Iterator> + 'a { (0..N).map(|n| self.col(n).expect("invalid column reached while iterating")) } + + pub fn transpose(&self) -> Matrix + where + Matrix: Default, + { + Matrix::::from_rows(self.cols()) + } + + // pub fn mmul(&self, rhs: &Matrix) -> Matrix + // where + // R: Num, + // T: Scalar + Mul, + // Vector: Dot, Output = T>, + // { + // let mut result: Matrix = Zero::zero(); + // + // for (m, a) in self.rows().enumerate() { + // for (n, b) in rhs.cols().enumerate() { + // // result[(m, n)] = a.dot(b) + // } + // } + // + // return result; + // } } // 1D vector implementations -impl Matrix { +impl Vector { /// Create a vector from a 1D array. /// Note that vectors are always column vectors unless explicitly instantiated as row vectors /// @@ -249,7 +310,7 @@ impl Matrix { /// /// * `data`: A 1D array of elements to copy into the new vector /// - /// returns: Matrix + /// returns: Vector /// /// # Examples /// @@ -260,17 +321,49 @@ impl Matrix { /// assert_eq!(my_vector, Matrix::new([[1],[2],[3],[4]])); /// ``` pub fn vec(data: [T; M]) -> Self { - return Matrix:: { - data: data.map(|e| [e; 1]), + return Vector:: { + data: data.map(|e| [e]), }; } } +impl Dot> for Vector +where + for<'a> Output: Sum<&'a T>, + for<'b> &'b Self: Mul<&'b Vector, Output = Self>, +{ + type Output = T; + fn dot(&self, rhs: &Matrix) -> Output { + (self * rhs).elements().sum::() + } +} + +impl Vector { + pub fn cross_r(&self, rhs: Vector) -> Self + where + T: Mul + Sub, + { + Self::vec([ + (self[1] * rhs[2]) - (self[2] * rhs[1]), + (self[2] * rhs[0]) - (self[0] * rhs[2]), + (self[0] * rhs[1]) - (self[1] * rhs[0]), + ]) + } + + pub fn cross_l(&self, rhs: Vector) -> Self + where + T: Mul + Sub, + Self: Neg, + { + -self.cross_r(rhs) + } +} + // Index impl Index for Matrix where I: Index2D, - T: Scalar, + T: Copy, { type Output = T; @@ -295,13 +388,28 @@ where )) } } - // Default -impl Default for Matrix { +impl Default for Matrix { fn default() -> Self { - Matrix { - data: [[T::default(); N]; M], - } + Matrix::new([[T::default(); N]; M]) + } +} + +// Zero +impl Zero for Matrix { + fn zero() -> Self { + Matrix::new([[T::zero(); N]; M]) + } + + fn is_zero(&self) -> bool { + self.elements().all(|e| e.is_zero()) + } +} + +// One +impl One for Matrix { + fn one() -> Self { + Matrix::new([[T::one(); N]; M]) } } @@ -363,9 +471,12 @@ where } } -impl Sum for Matrix { +impl Sum for Matrix +where + Self: Zero + AddAssign, +{ fn sum>(iter: I) -> Self { - let mut sum = Self::default(); + let mut sum = Self::zero(); for m in iter { sum += m; @@ -375,9 +486,12 @@ impl Sum for Matrix Product for Matrix { +impl Product for Matrix +where + Self: One + MulAssign, +{ fn product>(iter: I) -> Self { - let mut prod = Self::default(); + let mut prod = Self::one(); for m in iter { prod *= m; diff --git a/src/matrix_traits.rs b/src/matrix_traits.rs new file mode 100644 index 0000000..fb42659 --- /dev/null +++ b/src/matrix_traits.rs @@ -0,0 +1,22 @@ +use num_traits::Pow; + +pub trait Dot { + type Output; + fn dot(&self, other: &RHS) -> >::Output; +} + +pub trait Cross { + type Output; + fn cross(&self, other: &RHS) -> >::Output; +} + +pub trait Mult { + type Output; + fn mult(&self, other: &RHS) -> >::Output; +} + +pub trait Magnitude> { + fn sqrmag(&self) -> T; + fn mag(&self) -> >::Output; + fn norm(&self) -> Self; +}