diff --git a/quicktex/Matrix.h b/quicktex/Matrix.h index dbffff5..8988274 100644 --- a/quicktex/Matrix.h +++ b/quicktex/Matrix.h @@ -342,26 +342,48 @@ class Matrix : public VecBase>, M> { } // sum up all columns - column_type hsum() const { return std::accumulate(column_begin(), column_end(), column_type{}); } + column_type hsum() const { + if constexpr (N == 1) { return *this; } + if constexpr (M == 1) { return sum(); } + for (unsigned i = 0; i < M; i++) {} + return _map([](auto row) { return quicktex::sum(row); }, *this); + } // sum up all rows - row_type vsum() const { return std::accumulate(begin(), end(), row_type{}); } + row_type vsum() const { + if constexpr (N == 1) { return sum(); } + if constexpr (M == 1) { return *this; } + return std::accumulate(begin(), end(), row_type{}); + } // sum up all values - T sum() const { return std::accumulate(all_begin(), all_end(), T{}); } + T sum() const { + constexpr bool b = _batched && operable<_chunk_type<>, _chunk_type<>, std::plus<>>; + _chunk_type s = _get_chunk(0); + + // add up all chunks in parallel + // chunk is scalar -> this is doing the actual sum + // chunk is vector -> this is doing a vadd + // chunk is simd -> this is doing an add across the simd + for (unsigned i = 1; i < _chunk_count; i++) { + auto c = _get_chunk(i); + s += c; + } + + // now hadd the sum + // if chunk is scalar, this is a noop + return quicktex::sum(s); + } template requires operable> Matrix mult(const Matrix &rhs) const { - auto rt = rhs.transpose(); Matrix res(0); for (unsigned p = 0; p < P; p++) { // for each column of the RHS/Result for (unsigned m = 0; m < M; m++) { // for each row of the LHS/Result - for (unsigned n = 0; n < N; n++) { - res.element(m, p) += element(m, n) * rhs.element(n, p); - } + for (unsigned n = 0; n < N; n++) { res.element(m, p) += element(m, n) * rhs.element(n, p); } } } return res; @@ -378,9 +400,7 @@ class Matrix : public VecBase>, M> { Matrix mirror() const { Matrix result = *this; for (unsigned n = 0; n < N - 1; n++) { - for (unsigned m = (n + 1); m < M; m++) { - result.element(m,n) = result.element(n,m); - } + for (unsigned m = (n + 1); m < M; m++) { result.element(m, n) = result.element(n, m); } } return result; } @@ -388,19 +408,23 @@ class Matrix : public VecBase>, M> { // dot product of two compatible matrices template requires(N == 1) && operable> && operable> - row_type dot(const Matrix &rhs) const { + inline row_type dot(const Matrix &rhs) const { // technically this is Lt * R, but the vsum method is probably faster/more readable // than allocationg a new transpose matrix Matrix product = *this * rhs; return product.vsum(); } - row_type sqr_mag() const { return dot(*this); } + inline row_type sqr_mag() const { return dot(*this); } - Matrix abs() const { return _map(&quicktex::abs<_chunk_type<>>, *this); } + inline Matrix abs() const { + return _map([](auto c) { return quicktex::abs(c); }, *this); + } - Matrix clamp(T low, T high) { return _map(&quicktex::clamp<_chunk_type<>>, *this, low, high); } - Matrix clamp(const Matrix &low, const Matrix &high) { + inline Matrix clamp(T low, T high) { + return _map([low, high](_chunk_type<> c) { return quicktex::clamp(c, low, high); }, *this); + } + inline Matrix clamp(const Matrix &low, const Matrix &high) { return _map(&quicktex::clamp<_chunk_type<>>, *this, low, high); } @@ -476,13 +500,14 @@ class Matrix : public VecBase>, M> { * @param args additional scalar arguments * @return vector mapped with f(lhs[i], args) */ - template inline static Matrix _map(Op f, const Matrix &lhs, Args... args) { - Matrix result; - constexpr bool b = _batched && std::is_invocable_r_v<_chunk_type<>, Op, _chunk_type<>, Args...>; + template + inline static Result _map(Op f, const Matrix &lhs, Args... args) { + Result result; + constexpr bool b = _batched && std::is_invocable_v<_chunk_type<>, Op, _chunk_type<>, Args...>; for (unsigned i = 0; i < _chunk_count; i++) { auto c = lhs._get_chunk(i); auto resultc = f(c, args...); - result._set_chunk(i, resultc); + result.template _set_chunk(i, resultc); } return result; } @@ -497,15 +522,15 @@ class Matrix : public VecBase>, M> { * @param args additional scalar arguments * @return vector mapped with f(lhs[i], rhs[i], args) */ - template - inline static Matrix _map(Op f, const Matrix &lhs, const Matrix &rhs, Args... args) { - Matrix result; + template + inline static Result _map(Op f, const Matrix &lhs, const Matrix &rhs, Args... args) { + Result result; constexpr bool b = _batched && std::is_invocable_r_v<_chunk_type<>, Op, _chunk_type<>, _chunk_type<>, Args...>; for (unsigned i = 0; i < _chunk_count; i++) { auto lc = lhs._get_chunk(i); auto rc = rhs._get_chunk(i); auto resultc = f(lc, rc, args...); - result._set_chunk(i, resultc); + result.template _set_chunk(i, resultc); } return result; } @@ -521,9 +546,9 @@ class Matrix : public VecBase>, M> { * @param args additional scalar arguments * @return vector mapped with f(lhs[i], rhs1[i], rhs2[i], args) */ - template - inline static Matrix _map(Op f, const Matrix &lhs, const Matrix &rhs1, const Matrix &rhs2, Args... args) { - Matrix result; + template + inline static Result _map(Op f, const Matrix &lhs, const Matrix &rhs1, const Matrix &rhs2, Args... args) { + Result result; constexpr bool b = _batched && std::is_invocable_r_v<_chunk_type<>, Op, _chunk_type<>, _chunk_type<>, _chunk_type<>, Args...>; for (unsigned i = 0; i < _chunk_count; i++) { @@ -531,7 +556,7 @@ class Matrix : public VecBase>, M> { auto r1c = rhs1._get_chunk(i); auto r2c = rhs2._get_chunk(i); auto resultc = f(lc, r1c, r2c, args...); - result._set_chunk(i, resultc); + result.template _set_chunk(i, resultc); } return result; } diff --git a/quicktex/util/math.h b/quicktex/util/math.h index 113b227..474e620 100644 --- a/quicktex/util/math.h +++ b/quicktex/util/math.h @@ -32,10 +32,8 @@ namespace quicktex { -namespace detail { using std::abs; // abs overload for builtin types using xsimd::abs; // abs overload for xsimd buffers -} // namespace detail template requires requires(S &s) { s.abs(); } @@ -43,19 +41,6 @@ constexpr S abs(S value) { return value.abs(); } -template - requires requires(S &s) { detail::abs(s); } -constexpr S abs(S value) { - return detail::abs(value); -} - -template - requires requires(S &s) { s.clamp(s, s); } -constexpr S clamp(S value, S low, S high) { - assert(low <= high); - return value.clamp(low, high); -} - template requires std::is_scalar_v constexpr S clamp(S value, S low, S high) { @@ -66,10 +51,28 @@ constexpr S clamp(S value, S low, S high) { } template -constexpr S clamp(xsimd::batch value, const xsimd::batch &low, const xsimd::batch &high) { - value = xsimd::select(xsimd::lt(low), low, value); - value = xsimd::select(xsimd::gt(high), high, value); - return value; +constexpr xsimd::batch clamp(xsimd::batch value, const xsimd::batch &low, + const xsimd::batch &high) { + return xsimd::clip(value, low, high); } +template +constexpr xsimd::batch clamp(xsimd::batch value, const S &low, const S &high) { + return clamp(value, xsimd::broadcast(low), xsimd::broadcast(high)); +} + +template + requires requires(S &s) { s.sum(); } +constexpr auto sum(S value) { + return value.sum(); +} + +template + requires std::is_scalar_v +constexpr auto sum(S value) { + return value; + // horizontally adding a scalar is a noop +} + +template constexpr auto sum(xsimd::batch value) { return xsimd::hadd(value); } } // namespace quicktex \ No newline at end of file