2022-08-01 07:33:43 +00:00
|
|
|
use crate::impl_matrix_op;
|
|
|
|
use crate::index::Index2D;
|
|
|
|
use std::iter::{zip, Enumerate, Flatten};
|
|
|
|
use std::ops::{Add, Deref, DerefMut, Index, IndexMut, Range};
|
|
|
|
use std::option::IntoIter;
|
2022-07-29 01:49:48 +00:00
|
|
|
|
2022-08-01 07:33:43 +00:00
|
|
|
pub trait Get2D {
|
|
|
|
type Scalar: Sized + Copy;
|
|
|
|
const HEIGHT: usize;
|
|
|
|
const WIDTH: usize;
|
2022-07-29 01:49:48 +00:00
|
|
|
|
2022-08-01 07:33:43 +00:00
|
|
|
fn get<I: Index2D>(&self, i: I) -> Option<&Self::Scalar>;
|
2022-07-29 01:49:48 +00:00
|
|
|
}
|
|
|
|
|
2022-08-01 07:33:43 +00:00
|
|
|
pub trait Get2DMut: Get2D {
|
|
|
|
fn get_mut<I: Index2D>(&mut self, i: I) -> Option<&mut Self::Scalar>;
|
2022-07-29 01:49:48 +00:00
|
|
|
}
|
|
|
|
|
2022-08-01 07:33:43 +00:00
|
|
|
/// Marker trait for types usable as matrix elements: plain `Copy` values that
/// hold no non-`'static` borrows.
trait Scalar: 'static + Copy {}
|
|
|
|
/// Emits an empty `impl Marker for Type {}` for every type in a comma-separated list.
///
/// Usage: `multi_impl!(SomeMarker for u8, u16, u32);`
macro_rules! multi_impl {
    ($marker:ident for $($target:ty),*) => {
        $(impl $marker for $target {})*
    };
}
|
|
|
|
multi_impl!(Scalar for i8, i16, i32, i64, isize, u8, u16, u32, u64, usize, f32, f64);
|
|
|
|
// A `'static` shared reference to a scalar is itself a scalar: `&'static T`
// is always `Copy` and `'static`.
impl<T: Scalar> Scalar for &'static T {}
|
|
|
|
|
|
|
|
#[derive(Debug, Copy, Clone)]
|
|
|
|
struct Matrix<T, const M: usize, const N: usize>
|
2022-07-29 01:49:48 +00:00
|
|
|
where
|
2022-08-01 07:33:43 +00:00
|
|
|
T: Scalar,
|
2022-07-29 01:49:48 +00:00
|
|
|
{
|
2022-08-01 07:33:43 +00:00
|
|
|
data: [[T; N]; M],
|
2022-07-29 01:49:48 +00:00
|
|
|
}
|
|
|
|
|
2022-08-01 07:33:43 +00:00
|
|
|
/// A column vector of `LEN` elements, represented as a `LEN`×1 matrix.
type Vector<S, const LEN: usize> = Matrix<S, LEN, 1>;
|
|
|
|
|
|
|
|
impl<T: Scalar, const M: usize, const N: usize> Matrix<T, M, N> {
|
|
|
|
fn new(data: [[T; N]; M]) -> Self {
|
|
|
|
return Matrix::<T, M, N> { data };
|
|
|
|
}
|
|
|
|
|
|
|
|
fn from_rows<I>(iter: &I) -> Self
|
|
|
|
where
|
|
|
|
Self: Default,
|
|
|
|
I: Iterator<Item = Vector<T, N>> + Copy,
|
|
|
|
{
|
|
|
|
let mut result = Self::default();
|
|
|
|
for (m, row) in iter.enumerate().filter(|(m, _)| *m <= M) {
|
|
|
|
result.set_row(m, &row)
|
|
|
|
}
|
|
|
|
result
|
|
|
|
}
|
|
|
|
|
|
|
|
fn from_cols<I>(iter: &I) -> Self
|
|
|
|
where
|
|
|
|
Self: Default,
|
|
|
|
I: Iterator<Item = Vector<T, M>> + Copy,
|
|
|
|
{
|
|
|
|
let mut result = Self::default();
|
|
|
|
for (n, col) in iter.enumerate().filter(|(n, _)| *n <= N) {
|
|
|
|
result.set_col(n, &col)
|
|
|
|
}
|
|
|
|
result
|
|
|
|
}
|
|
|
|
|
|
|
|
fn elements<'a>(&'a self) -> impl Iterator<Item = &T> + 'a {
|
|
|
|
self.data.iter().flatten()
|
|
|
|
}
|
|
|
|
|
|
|
|
fn elements_mut<'a>(&'a mut self) -> impl Iterator<Item = &mut T> + 'a {
|
|
|
|
self.data.iter_mut().flatten()
|
|
|
|
}
|
|
|
|
|
|
|
|
fn get(&self, index: impl Index2D) -> Option<&T> {
|
|
|
|
let (m, n) = index.to_2d(M, N)?;
|
|
|
|
Some(&self.data[m][n])
|
|
|
|
}
|
|
|
|
|
|
|
|
fn get_mut(&mut self, index: impl Index2D) -> Option<&mut T> {
|
|
|
|
let (m, n) = index.to_2d(M, N)?;
|
|
|
|
Some(&mut self.data[m][n])
|
|
|
|
}
|
|
|
|
|
|
|
|
fn row(&self, m: usize) -> Option<Vector<T, N>> {
|
|
|
|
if m < M {
|
|
|
|
Some(Vector::<T, N>::new_vector(self.data[m]))
|
|
|
|
} else {
|
|
|
|
None
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
fn set_row(&mut self, m: usize, val: &Vector<T, N>) {
|
|
|
|
assert!(
|
|
|
|
m < M,
|
|
|
|
"Row index {} out of bounds for {}x{} matrix",
|
|
|
|
m,
|
|
|
|
M,
|
|
|
|
N
|
|
|
|
);
|
|
|
|
for (n, v) in val.elements().enumerate() {
|
|
|
|
self.data[m][n] = *v;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
fn col(&self, n: usize) -> Option<Vector<T, M>> {
|
|
|
|
if n < N {
|
|
|
|
Some(Vector::<T, M>::new_vector(self.data.map(|r| r[n])))
|
|
|
|
} else {
|
|
|
|
None
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
fn set_col(&mut self, n: usize, val: &Vector<T, M>) {
|
|
|
|
assert!(
|
|
|
|
n < N,
|
|
|
|
"Column index {} out of bounds for {}x{} matrix",
|
|
|
|
n,
|
|
|
|
M,
|
|
|
|
N
|
|
|
|
);
|
2022-07-29 01:49:48 +00:00
|
|
|
|
2022-08-01 07:33:43 +00:00
|
|
|
for (m, v) in val.elements().enumerate() {
|
|
|
|
self.data[m][n] = *v;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
fn rows<'a>(&'a self) -> impl Iterator<Item = Vector<T, N>> + 'a {
|
|
|
|
(0..M).map(|m| self.row(m).expect("invalid row reached while iterating"))
|
|
|
|
}
|
|
|
|
|
|
|
|
fn cols<'a>(&'a self) -> impl Iterator<Item = Vector<T, M>> + 'a {
|
|
|
|
(0..N).map(|n| self.col(n).expect("invalid column reached while iterating"))
|
2022-07-29 01:49:48 +00:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2022-08-01 07:33:43 +00:00
|
|
|
// constructor for column vectors
|
|
|
|
impl<T: Scalar, const N: usize> Vector<T, N> {
|
|
|
|
fn new_vector(data: [T; N]) -> Self {
|
|
|
|
return Vector::<T, N> {
|
|
|
|
data: data.map(|e| [e]),
|
|
|
|
};
|
2022-07-29 01:49:48 +00:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2022-08-01 07:33:43 +00:00
|
|
|
// default constructor
|
|
|
|
impl<T, const M: usize, const N: usize> Default for Matrix<T, M, N>
|
|
|
|
where
|
|
|
|
[[T; N]; M]: Default,
|
|
|
|
T: Scalar,
|
2022-07-29 01:49:48 +00:00
|
|
|
{
|
2022-08-01 07:33:43 +00:00
|
|
|
fn default() -> Self {
|
|
|
|
Matrix {
|
|
|
|
data: Default::default(),
|
|
|
|
}
|
2022-07-29 01:49:48 +00:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2022-08-01 07:33:43 +00:00
|
|
|
// deref 1x1 matrices to a scalar automatically
|
|
|
|
impl<T: Scalar> Deref for Matrix<T, 1, 1> {
|
|
|
|
type Target = T;
|
|
|
|
|
|
|
|
fn deref(&self) -> &Self::Target {
|
|
|
|
&self.data[0][0]
|
2022-07-29 01:49:48 +00:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2022-08-01 07:33:43 +00:00
|
|
|
// deref 1x1 matrices to a mutable scalar automatically
|
|
|
|
impl<T: Scalar> DerefMut for Matrix<T, 1, 1> {
|
|
|
|
fn deref_mut(&mut self) -> &mut Self::Target {
|
|
|
|
&mut self.data[0][0]
|
|
|
|
}
|
2022-07-29 01:49:48 +00:00
|
|
|
}
|
|
|
|
|
2022-08-01 07:33:43 +00:00
|
|
|
impl<T: Scalar, const M: usize, const N: usize> IntoIterator for Matrix<T, M, N> {
|
|
|
|
type Item = T;
|
|
|
|
type IntoIter = Flatten<std::array::IntoIter<[T; N], M>>;
|
|
|
|
|
|
|
|
fn into_iter(self) -> Self::IntoIter {
|
|
|
|
self.data.into_iter().flatten()
|
|
|
|
}
|
2022-07-29 01:49:48 +00:00
|
|
|
}
|
2022-08-01 07:33:43 +00:00
|
|
|
|
|
|
|
// Operator impls generated by `impl_matrix_op!`, which is imported from the crate root.
// Each invocation names an operator and a closure that computes one result value:
// - `l` is the left operand.
// - `r` is the right operand (binary operators only).
// NOTE(review): presumably the macro applies the closure element-wise and binds
// `L` to the operand type. Confirm against the macro's definition.

// Unary operators: negation (spelled `neg`) and `!`.
impl_matrix_op!(neg, |l: L| { -l });
impl_matrix_op!(!, |l: L| { !l });

// Arithmetic operators.
impl_matrix_op!(+, |l,r| {l + r});
impl_matrix_op!(-, |l,r| {l - r});
impl_matrix_op!(*, |l,r| {l * r});
impl_matrix_op!(/, |l,r| {l / r});
impl_matrix_op!(%, |l,r| {l % r});

// Bitwise operators.
impl_matrix_op!(&, |l,r| {l & r});
impl_matrix_op!(|, |l,r| {l | r});
impl_matrix_op!(^, |l,r| {l ^ r});

// Shift operators.
impl_matrix_op!(<<, |l,r| {l << r});
impl_matrix_op!(>>, |l,r| {l >> r});
|