matrix multiplication and transposition

This commit is contained in:
Andrew Cassidy 2022-06-13 22:55:41 -07:00
parent 2c59419bf0
commit 3756f31e20

View File

@ -56,9 +56,9 @@ template <typename V> struct vector_stats {
template <typename V>
requires is_matrix<V>
struct vector_stats<V> {
static constexpr size_t width = V::width();
static constexpr size_t height = V::height();
static constexpr size_t dims = ((width > 1) ? 1 : 0) + ((height > 1) ? 1 : 0);
static constexpr size_t width = V::width;
static constexpr size_t height = V::height;
static constexpr size_t dims = V::dims;
};
template <typename V> constexpr size_t vector_width = vector_stats<V>::width;
@ -99,8 +99,8 @@ class Matrix : public VecBase<std::conditional_t<N == 1, T, VecBase<T, N>>, M> {
using base = VecBase<std::conditional_t<N == 1, T, VecBase<T, N>>, M>;
using value_type = T;
using row_type = std::conditional_t<N == 1, T, Vec<T, N>>;
using column_type = std::conditional_t<M == 1, T, Vec<T, M>>;
using row_type = matrix_row_type<T, N, M>;
using column_type = matrix_column_type<T, N, M>;
using base::base;
using base::begin;
@ -113,7 +113,7 @@ class Matrix : public VecBase<std::conditional_t<N == 1, T, VecBase<T, N>>, M> {
* Create a vector from an intializer list
* @param il values to populate with
*/
Matrix(std::initializer_list<T> il) {
Matrix(std::initializer_list<row_type> il) {
assert(il.size() == M); // ensure il is of the right size
std::copy_n(il.begin(), M, this->begin());
}
@ -132,7 +132,7 @@ class Matrix : public VecBase<std::conditional_t<N == 1, T, VecBase<T, N>>, M> {
template <typename II>
Matrix(const II input_iterator)
requires std::input_iterator<II> && std::convertible_to<std::iter_value_t<II>,
T> {
row_type> {
std::copy_n(input_iterator, M, this->begin());
}
@ -143,7 +143,7 @@ class Matrix : public VecBase<std::conditional_t<N == 1, T, VecBase<T, N>>, M> {
*/
template <typename R>
Matrix(const R &input_range)
requires range<R> && std::convertible_to<typename R::value_type, T>
requires range<R> && std::convertible_to<typename R::value_type, row_type>
: Matrix(input_range.begin()) {
assert(std::distance(input_range.begin(), input_range.end()) == M);
}
@ -151,8 +151,9 @@ class Matrix : public VecBase<std::conditional_t<N == 1, T, VecBase<T, N>>, M> {
// region iterators and accessors
static constexpr size_t size() { return M; }
static constexpr size_t width() { return N; }
static constexpr size_t height() { return M; }
static constexpr size_t width = N;
static constexpr size_t height = M;
static constexpr size_t dims = ((width > 1) ? 1 : 0) + ((height > 1) ? 1 : 0);
auto row_begin() { return this->begin(); }
auto row_begin() const { return this->begin(); }
@ -319,15 +320,38 @@ class Matrix : public VecBase<std::conditional_t<N == 1, T, VecBase<T, N>>, M> {
// sum up all rows
row_type vsum() const { return std::accumulate(row_begin(), row_end(), row_type{}); }
template <typename R, size_t P>
requires operable<R, T, std::multiplies<>>
Matrix<T, P, M> mult(const Matrix<R, P, N> &rhs) {
auto rt = rhs.transpose();
Matrix<T, P, M> res(0);
for (unsigned i = 0; i < P; i++) {
// for each column of the RHS/Result
for (unsigned j = 0; j < M; j++) {
// for each row of the LHS/Result
res.element(i, j) = get_row(i).dot(rt.get_row(j));
}
}
return res;
}
Matrix<T, M, N> transpose() {
Matrix<T, M, N> res;
for (unsigned m = 0; m < M; m++) { res.set_column(m, get_row(m)); }
return res;
}
// dot product of two compatible matrices
template <typename R>
requires operable<T, R, std::multiplies<>> && operable<T, T, std::plus<>>
requires(N == 1) && operable<T, R, std::multiplies<>> && operable<T, T, std::plus<>>
row_type dot(const Matrix<R, N, M> &rhs) const {
// technically this is Lt * R, but the vsum method is probably faster/more readable
// than allocationg a new transpose matrix
Matrix product = *this * rhs;
return product.vsum();
}
row_type sqr_mag() const { return this->dot(*this); }
row_type sqr_mag() const { return dot(*this); }
Matrix abs() const {
Matrix ret;