From 2829c524873b2a8ebebba4fbdc4fc7809ee48383 Mon Sep 17 00:00:00 2001 From: Andrew Cassidy Date: Wed, 16 Nov 2022 22:53:53 -0800 Subject: [PATCH] Yeet `Scalar` A better solution may be needed for matrix*scalar ops than using `num` though --- src/lib.rs | 2 +- src/matrix.rs | 104 ++++++++++++++++++++++++++++++-------------------- tests/ops.rs | 4 +- 3 files changed, 66 insertions(+), 44 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 8cdc4bb..f0ac1a3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,4 +5,4 @@ mod macros; mod matrix; mod matrix_traits; -pub use matrix::{Matrix, Scalar, Vector}; +pub use matrix::{Matrix, Vector}; diff --git a/src/matrix.rs b/src/matrix.rs index cbcb3bf..b12013c 100644 --- a/src/matrix.rs +++ b/src/matrix.rs @@ -1,25 +1,23 @@ use crate::impl_matrix_op; use crate::index::Index2D; -use crate::matrix_traits::Mult; use num_traits::{Num, One, Zero}; use std::fmt::Debug; use std::iter::{zip, Flatten, Product, Sum}; -use std::ops::{Add, AddAssign, Deref, DerefMut, Index, IndexMut, Mul, MulAssign, Neg, Sub}; -use std::process::Output; +use std::ops::{AddAssign, Deref, DerefMut, Index, IndexMut, Mul, MulAssign, Neg, Sub}; /// A Scalar that a [Matrix] can be made up of. /// /// This trait has no associated functions and can be implemented on any type that is [Default] and /// [Copy] and has a static lifetime. -pub trait Scalar: Default + Copy + 'static {} -macro_rules! multi_impl { ($name:ident for $($t:ty),*) => ($( impl $name for $t {} )*) } -multi_impl!(Scalar for i8, i16, i32, i64, i128, isize, u8, u16, u32, u64, u128, usize, f32, f64); -impl Scalar for &'static T -where - T: Scalar, - &'static T: Default, -{ -} +// pub trait Scalar: Default + Copy + 'static {} +// macro_rules! multi_impl { ($name:ident for $($t:ty),*) => ($( impl $name for $t {} )*) } +// multi_impl!(Scalar for i8, i16, i32, i64, i128, isize, u8, u16, u32, u64, u128, usize, f32, f64); +// impl Scalar for &'static T +// where +// T: Scalar, +// &'static T: Default, +// { +// } /// A 2D array of values which can be operated upon. /// @@ -30,7 +28,7 @@ pub struct Matrix where T: Copy, { - data: [[T; N]; M], // Column-Major order + data: [[T; N]; M], // Row-Major order } /// An alias for a [Matrix] with a single column @@ -39,7 +37,7 @@ pub type Vector = Matrix; pub trait Dot { type Output; #[must_use] - fn dot(&self, rhs: &R) -> Output; + fn dot(&self, rhs: &R) -> Self::Output; } pub trait Cross { @@ -49,6 +47,12 @@ pub trait Cross { fn cross_l(&self, rhs: &R) -> Self; } +pub trait MMul { + type Output; + #[must_use] + fn mmul(&self, rhs: &R) -> Self::Output; +} + // Simple access functions that only require T be copyable impl Matrix { /// Generate a new matrix from a 2D Array @@ -327,22 +331,23 @@ impl Vector { } } -impl Dot> for Vector +impl Dot> for Vector where - for<'a> Output: Sum<&'a T>, + for<'a> T: Sum<&'a T>, for<'b> &'b Self: Mul<&'b Vector, Output = Self>, { type Output = T; - fn dot(&self, rhs: &Matrix) -> Output { - (self * rhs).elements().sum::() + fn dot(&self, rhs: &Matrix) -> Self::Output { + (self * rhs).elements().sum::() } } -impl Vector { - pub fn cross_r(&self, rhs: Vector) -> Self - where - T: Mul + Sub, - { +impl Cross> for Vector +where + T: Mul + Sub, + Self: Neg, +{ + fn cross_r(&self, rhs: &Vector) -> Self { Self::vec([ (self[1] * rhs[2]) - (self[2] * rhs[1]), (self[2] * rhs[0]) - (self[0] * rhs[2]), @@ -350,15 +355,32 @@ impl Vector { ]) } - pub fn cross_l(&self, rhs: Vector) -> Self - where - T: Mul + Sub, - Self: Neg, - { + fn cross_l(&self, rhs: &Vector) -> Self { -self.cross_r(rhs) } } +impl MMul> + for Matrix +where + T: Default, + Vector: Dot, Output = T>, +{ + type Output = Matrix; + + fn mmul(&self, rhs: &Matrix) -> Self::Output { + let mut result = Self::Output::default(); + + for (m, a) in self.rows().enumerate() { + for (n, b) in rhs.cols().enumerate() { + result[(m, n)] = a.dot(&b) + } + } + + return result; + } +} + // Index impl Index for Matrix where @@ -379,7 +401,7 @@ where impl IndexMut for Matrix where I: Index2D, - T: Scalar, + T: Copy, { fn index_mut(&mut self, index: I) -> &mut Self::Output { self.get_mut(index).expect(&*format!( @@ -391,14 +413,14 @@ where // Default impl Default for Matrix { fn default() -> Self { - Matrix::new([[T::default(); N]; M]) + Matrix::fill(T::default()) } } // Zero impl Zero for Matrix { fn zero() -> Self { - Matrix::new([[T::zero(); N]; M]) + Matrix::fill(T::zero()) } fn is_zero(&self) -> bool { @@ -409,30 +431,30 @@ impl Zero for Matrix { // One impl One for Matrix { fn one() -> Self { - Matrix::new([[T::one(); N]; M]) + Matrix::fill(T::one()) } } -impl From<[[T; N]; M]> for Matrix { +impl From<[[T; N]; M]> for Matrix { fn from(data: [[T; N]; M]) -> Self { Self::new(data) } } -impl From<[T; M]> for Vector { +impl From<[T; M]> for Vector { fn from(data: [T; M]) -> Self { Self::vec(data) } } -impl From for Matrix { +impl From for Matrix { fn from(scalar: T) -> Self { Self::fill(scalar) } } // deref 1x1 matrices to a scalar automatically -impl Deref for Matrix { +impl Deref for Matrix { type Target = T; fn deref(&self) -> &Self::Target { @@ -441,14 +463,14 @@ impl Deref for Matrix { } // deref 1x1 matrices to a mutable scalar automatically -impl DerefMut for Matrix { +impl DerefMut for Matrix { fn deref_mut(&mut self) -> &mut Self::Target { &mut self.data[0][0] } } // IntoIter -impl IntoIterator for Matrix { +impl IntoIterator for Matrix { type Item = T; type IntoIter = Flatten>; @@ -458,7 +480,7 @@ impl IntoIterator for Matrix } // FromIterator -impl FromIterator for Matrix +impl FromIterator for Matrix where Self: Default, { @@ -471,7 +493,7 @@ where } } -impl Sum for Matrix +impl Sum for Matrix where Self: Zero + AddAssign, { @@ -486,7 +508,7 @@ where } } -impl Product for Matrix +impl Product for Matrix where Self: One + MulAssign, { diff --git a/tests/ops.rs b/tests/ops.rs index 0afb902..2b450e2 100644 --- a/tests/ops.rs +++ b/tests/ops.rs @@ -1,11 +1,11 @@ use generic_parameterize::parameterize; use std::fmt::Debug; use std::ops; -use vector_victor::{Matrix, Scalar}; +use vector_victor::Matrix; #[parameterize(S = (i32, f32, u32), M = [1,4], N = [1,4])] #[test] -fn test_add + PartialEq + Debug, const M: usize, const N: usize>() +fn test_add + PartialEq + Debug, const M: usize, const N: usize>() where Matrix: ops::Add>, {