diff --git a/quicktex/Matrix.h b/quicktex/Matrix.h index dd7daa8..908eb99 100644 --- a/quicktex/Matrix.h +++ b/quicktex/Matrix.h @@ -361,46 +361,74 @@ class Matrix : public VecBase>, M> { row_type sqr_mag() const { return dot(*this); } Matrix abs() const { - Matrix ret; - for (unsigned i = 0; i < N * M; i++) { ret.element(i) = quicktex::abs(element(i)); } - return ret; - } - - Matrix clamp(T low, T high) { - Matrix ret; - for (unsigned i = 0; i < N * M; i++) { ret.element(i) = quicktex::clamp(element(i), low, high); } - return ret; - } - - Matrix clamp(const Matrix &low, const Matrix &high) { - Matrix ret; - for (unsigned i = 0; i < N * M; i++) { - ret.element(i) = quicktex::clamp(element(i), low.element(i), high.element(i)); + Matrix res; + if constexpr (_batched) { + auto lb = _batch_type::load_unaligned(&this->at(0)); + lb = xsimd::abs(lb); + lb.store_unaligned(&res[0]); + } else { + for (unsigned i = 0; i < N * M; i++) { res.element(i) = quicktex::abs(element(i)); } } - return ret; + return res; + } + + Matrix clamp(T low, T high) { return clamp(Matrix(low), Matrix(high)); } + Matrix clamp(const Matrix &low, const Matrix &high) { + Matrix res; + if constexpr (_batched) { + auto vb = _batch_type::load_unaligned(&this->at(0)); + auto lb = _batch_type::load_unaligned(&low[0]); + auto hb = _batch_type::load_unaligned(&high[0]); + vb = quicktex::clamp(vb, lb, hb); + vb.store_unaligned(&res[0]); + } else { + for (unsigned m = 0; m < M; m++) { + res[m] = quicktex::clamp(get_row(m), low.get_row(m), high.get_row(m)); + } + } + return res; } protected: - template static inline Matrix map(Matrix &lhs, Op f) { - Matrix ret; - for (unsigned i = 0; i < lhs.size(); i++) { ret[i] = f(lhs[i]); } - return ret; + template static inline Matrix map(const Matrix &lhs, Op f) { + Matrix res; + if constexpr (_batched) { + auto lb = _batch_type::load_unaligned(&lhs[0]); + auto resb = f(lb); + resb.store_unaligned(&res[0]); + } else { + for (unsigned i = 0; i < lhs.size(); i++) { res[i] = f(lhs[i]); } + } + return res; } template requires operable static inline Matrix map(const Matrix &lhs, const R &rhs, Op f) { - Matrix r; - for (unsigned i = 0; i < lhs.size(); i++) { r[i] = f(lhs[i], rhs); } - return r; + Matrix res; + if constexpr (_batched && operable<_batch_type, R, Op>) { + auto lb = _batch_type::load_unaligned(&lhs[0]); + auto resb = f(lb, rhs); + resb.store_unaligned(&res[0]); + } else { + for (unsigned i = 0; i < lhs.size(); i++) { res[i] = f(lhs[i], rhs); } + } + return res; } template requires operable static inline Matrix map(const Matrix &lhs, const Matrix &rhs, Op f) { - Matrix r; - for (unsigned i = 0; i < lhs.size(); i++) { r[i] = f(lhs[i], rhs[i]); } - return r; + Matrix res; + if constexpr (_batched && operable<_batch_type, _batch_type, Op>) { + auto lb = _batch_type::load_unaligned(&lhs[0]); + auto rb = xsimd::load_as(&rhs[0], xsimd::unaligned_mode{}); + auto resb = f(lb, rb); + resb.store_unaligned(&res[0]); + } else { + for (unsigned i = 0; i < lhs.size(); i++) { res[i] = f(lhs[i], rhs[i]); } + } + return res; } class column_iterator : public index_iterator_base { @@ -438,26 +466,32 @@ class Matrix : public VecBase>, M> { private: V *_matrix; }; -}; -template class BatchVec : Vec, M> { - template - static BatchVec load_columns(const Matrix &matrix, size_t column) { - const size_t batch_size = xsimd::batch::size; - assert(column + batch_size <= N); + private: + using _batch_type = std::conditional_t::type, void>; + static constexpr bool _batched = !std::is_void_v<_batch_type>; - BatchVec ret; - for (unsigned i = 0; i < M; i++) { ret[i] = xsimd::load(&(matrix[column][i]), U{}); } - return ret; + // right now batched types are always the whole vector but that might change + template using _chunk_type = std::conditional_t; + + template static constexpr size_t _chunk_count = b && _batched ? 1 : M; + + template inline _chunk_type get_chunk(size_t i) const { + assert(i < _chunk_count); + if constexpr (b && _batched) { + return _chunk_type::load_unaligned(&(this->at(0))); + } else { + return get_row(i); + } } - template - void store_columns(Matrix &matrix, size_t column) { - const size_t batch_size = xsimd::batch::size; - assert(column + batch_size <= N); - - for (unsigned i = 0; i < M; i++) { this->at(i).store((&(matrix[column][i]), U{})); } + template inline void set_chunk(size_t i, _chunk_type &value) const { + assert(i < _chunk_count); + if constexpr (b && _batched) { + xsimd::store_unaligned(&(this->at(0)), value); + } else { + set_row(i, value); + } } }; - } // namespace quicktex \ No newline at end of file diff --git a/quicktex/util/math.h b/quicktex/util/math.h index 172c1b4..113b227 100644 --- a/quicktex/util/math.h +++ b/quicktex/util/math.h @@ -27,17 +27,49 @@ #include #include +#include "util/ranges.h" #include "xsimd/xsimd.hpp" namespace quicktex { +namespace detail { using std::abs; // abs overload for builtin types using xsimd::abs; // abs overload for xsimd buffers +} // namespace detail -template constexpr S clamp(S value, S low, S high) { +template + requires requires(S &s) { s.abs(); } +constexpr S abs(S value) { + return value.abs(); +} + +template + requires requires(S &s) { detail::abs(s); } +constexpr S abs(S value) { + return detail::abs(value); +} + +template + requires requires(S &s) { s.clamp(s, s); } +constexpr S clamp(S value, S low, S high) { + assert(low <= high); + return value.clamp(low, high); +} + +template + requires std::is_scalar_v +constexpr S clamp(S value, S low, S high) { assert(low <= high); if (value < low) return low; if (value > high) return high; return value; } + +template +constexpr S clamp(xsimd::batch value, const xsimd::batch &low, const xsimd::batch &high) { + value = xsimd::select(xsimd::lt(low), low, value); + value = xsimd::select(xsimd::gt(high), high, value); + return value; +} + } // namespace quicktex \ No newline at end of file diff --git a/quicktex/util/ranges.h b/quicktex/util/ranges.h index f53a88c..c9cc6c3 100644 --- a/quicktex/util/ranges.h +++ b/quicktex/util/ranges.h @@ -20,6 +20,7 @@ #pragma once #include +#include #include #include #include diff --git a/tests/ctest/TestMatrix.cpp b/tests/ctest/TestMatrix.cpp index 3ee866e..2002cb2 100644 --- a/tests/ctest/TestMatrix.cpp +++ b/tests/ctest/TestMatrix.cpp @@ -147,5 +147,31 @@ TEST(Vec_int, copy) { EXPECT_EQ(out, arr); } + +TEST(Vec_int, neg) { + auto a = Vec{1, 2, 3, 4}; + + expect_matrix_eq(-a, {-1, -2, -3, -4}); +} + +TEST(Vec_int, add) { + auto a = Vec{1, 2, 3, 4}; + auto b = Vec{5, 6, 7, 8}; + + expect_matrix_eq(a + b, {6, 8, 10, 12}); +} + +TEST(Vec_int, sub) { + auto b = Vec{1, 2, 3, 4}; + auto a = Vec{5, 6, 7, 8}; + + expect_matrix_eq(a - b, {4, 4, 4, 4}); +} + +TEST(Vec_int, abs) { + auto a = Vec{1, -5, -1, 0}; + + expect_matrix_eq(a.abs(), {1, 5, 1, 0}); +} // endregion } // namespace quicktex::tests \ No newline at end of file