From 785914cf3f05860ef6d6b7ae69514c7fd6b2588b Mon Sep 17 00:00:00 2001 From: Mikhail Katliar Date: Thu, 23 May 2024 14:48:42 +0200 Subject: [PATCH] Simplify implementations with tile() function (#5) * Added cachePreferredTraversal to matrix pointers * Removed PanelGemmTest.testNT_submatrix * Added traversal_order argument to tile() * gemm_nt_backend() taking matrix pointers * Removed padding in minor direction of DynamicPanelMatrix * Removed obsolete functions from DynamicPanelMatrix * - Include quadratic term in gemm() flops; - call blast::gemm() with alpha and beta in the benchmark, for fair comparison --- bench/blasfeo/Gemm.cpp | 2 +- bench/blast/math/dense/StaticGemm.cpp | 5 +- include/blast/math/StorageOrder.hpp | 6 + include/blast/math/TypeTraits.hpp | 1 + include/blast/math/algorithm/Gemm.hpp | 4 +- include/blast/math/algorithm/Tile.hpp | 158 ++++++++++++------ .../blast/math/dense/DynamicMatrixPointer.hpp | 1 + .../blast/math/dense/StaticMatrixPointer.hpp | 2 + .../blast/math/expressions/PanelMatrix.hpp | 10 +- .../blast/math/panel/DynamicPanelMatrix.hpp | 95 ++--------- .../math/panel/DynamicPanelMatrixPointer.hpp | 10 +- include/blast/math/panel/Gemm.hpp | 72 +++----- .../math/panel/StaticPanelMatrixPointer.hpp | 2 + include/blast/math/typetraits/IsAligned.hpp | 28 ++++ .../blast/math/typetraits/MatrixPointer.hpp | 1 + include/blast/math/views/submatrix/Panel.hpp | 27 +-- test/blast/math/dense/MatrixPointerTest.cpp | 18 ++ .../math/panel/DynamicPanelMatrixTest.cpp | 4 +- test/blast/math/panel/GemmTest.cpp | 46 +---- test/blast/math/views/SubmatrixTest.cpp | 12 +- 20 files changed, 225 insertions(+), 279 deletions(-) create mode 100644 include/blast/math/typetraits/IsAligned.hpp diff --git a/bench/blasfeo/Gemm.cpp b/bench/blasfeo/Gemm.cpp index 3801eabb..f3635cb0 100644 --- a/bench/blasfeo/Gemm.cpp +++ b/bench/blasfeo/Gemm.cpp @@ -53,7 +53,7 @@ namespace blast :: benchmark for (auto _ : state) gemm_nt(m, n, k, 1., A, 0, 0, B, 0, 0, 1., C, 0, 0, C, 0, 0); - state.counters["flops"] = Counter(2 * m * m * m, Counter::kIsIterationInvariantRate); + state.counters["flops"] = Counter(2 * m * n * k + 3 * m * n, Counter::kIsIterationInvariantRate); state.counters["m"] = m; } diff --git a/bench/blast/math/dense/StaticGemm.cpp b/bench/blast/math/dense/StaticGemm.cpp index 6504ff65..1e332b13 100644 --- a/bench/blast/math/dense/StaticGemm.cpp +++ b/bench/blast/math/dense/StaticGemm.cpp @@ -32,15 +32,14 @@ namespace blast :: benchmark for (auto _ : state) { - // gemm(1., A, trans(B), 1., C, D); - gemm(A, B, C, D); + gemm(0.5, A, B, 0.1, C, D); DoNotOptimize(A); DoNotOptimize(B); DoNotOptimize(C); DoNotOptimize(D); } - state.counters["flops"] = Counter(2 * M * N * K, Counter::kIsIterationInvariantRate); + state.counters["flops"] = Counter(2 * M * N * K + 3 * M * N, Counter::kIsIterationInvariantRate); state.counters["m"] = M; } diff --git a/include/blast/math/StorageOrder.hpp b/include/blast/math/StorageOrder.hpp index 3aef823d..c40f8169 100644 --- a/include/blast/math/StorageOrder.hpp +++ b/include/blast/math/StorageOrder.hpp @@ -24,4 +24,10 @@ namespace blast rowMajor = blaze::rowMajor, columnMajor = blaze::columnMajor }; + + + inline constexpr StorageOrder operator!(StorageOrder so) + { + return so == rowMajor ? columnMajor : rowMajor; + } } \ No newline at end of file diff --git a/include/blast/math/TypeTraits.hpp b/include/blast/math/TypeTraits.hpp index 9d79a2c7..a65b0ac3 100644 --- a/include/blast/math/TypeTraits.hpp +++ b/include/blast/math/TypeTraits.hpp @@ -15,6 +15,7 @@ #pragma once #include +#include #include #include #include diff --git a/include/blast/math/algorithm/Gemm.hpp b/include/blast/math/algorithm/Gemm.hpp index 2991ccee..211c3acb 100644 --- a/include/blast/math/algorithm/Gemm.hpp +++ b/include/blast/math/algorithm/Gemm.hpp @@ -61,7 +61,9 @@ namespace blast BLAZE_CONSTRAINT_MUST_BE_SAME_TYPE(std::remove_cv_t>, ET); BLAZE_CONSTRAINT_MUST_BE_SAME_TYPE(std::remove_cv_t>, ET); - tile)>(M, N, + tile)>( + D.cachePreferredTraversal, + M, N, [&] (auto& ker, size_t i, size_t j) { gemm(ker, K, alpha, A(i, 0), B(0, j), beta, C(i, j), D(i, j)); diff --git a/include/blast/math/algorithm/Tile.hpp b/include/blast/math/algorithm/Tile.hpp index 015cca0f..41d7aadb 100644 --- a/include/blast/math/algorithm/Tile.hpp +++ b/include/blast/math/algorithm/Tile.hpp @@ -14,6 +14,7 @@ #include #include +#include #include #include @@ -22,6 +23,34 @@ namespace blast { + template + BLAZE_ALWAYS_INLINE void tile_backend(size_t m, size_t n, size_t i, FF&& f_full, FP&& f_partial) + { + RegisterMatrix ker; + + if (i + KM <= m) + { + size_t j = 0; + + for (; j + KN <= n; j += KN) + f_full(ker, i, j); + + if (j < n) + f_partial(ker, i, j, KM, n - j); + } + else + { + size_t j = 0; + + for (; j + KN <= n; j += KN) + f_partial(ker, i, j, m - i, KN); + + if (j < n) + f_partial(ker, i, j, m - i, n - j); + } + } + + /** * @brief Cover a matrix with tiles of different sizes in a performance-efficient way. * @@ -48,77 +77,98 @@ namespace blast * @param f_partial functor to call on partial tiles */ template - BLAST_ALWAYS_INLINE void tile(std::size_t m, std::size_t n, FF&& f_full, FP&& f_partial) + BLAST_ALWAYS_INLINE void tile(StorageOrder traversal_order, std::size_t m, std::size_t n, FF&& f_full, FP&& f_partial) { - size_t constexpr TILE_SIZE = TileSize_v; + size_t constexpr SS = SimdSize_v; + size_t constexpr TILE_STEP = 4; // TODO: this is almost arbitrary and needs to be ppoperly determined - size_t j = 0; + static_assert(SO == columnMajor, "tile() for row-major matrices not implemented"); - // Main part - for (; j + TILE_SIZE <= n; j += TILE_SIZE) + if (traversal_order == columnMajor) { - size_t i = 0; - - // i + 4 * TILE_SIZE != M is to improve performance in case when the remaining number of rows is 4 * TILE_SIZE: - // it is more efficient to apply 2 * TILE_SIZE kernel 2 times than 3 * TILE_SIZE + 1 * TILE_SIZE kernel. - for (; i + 3 * TILE_SIZE <= m && i + 4 * TILE_SIZE != m; i += 3 * TILE_SIZE) - { - RegisterMatrix ker; - f_full(ker, i, j); - } + size_t j = 0; - for (; i + 2 * TILE_SIZE <= m; i += 2 * TILE_SIZE) + // Main part + for (; j + TILE_STEP <= n; j += TILE_STEP) { - RegisterMatrix ker; - f_full(ker, i, j); + size_t i = 0; + + // i + 4 * TILE_SIZE != M is to improve performance in case when the remaining number of rows is 4 * TILE_SIZE: + // it is more efficient to apply 2 * TILE_SIZE kernel 2 times than 3 * TILE_SIZE + 1 * TILE_SIZE kernel. + for (; i + 3 * SS <= m && i + 4 * SS != m; i += 3 * SS) + { + RegisterMatrix ker; + f_full(ker, i, j); + } + + for (; i + 2 * SS <= m; i += 2 * SS) + { + RegisterMatrix ker; + f_full(ker, i, j); + } + + for (; i + 1 * SS <= m; i += 1 * SS) + { + RegisterMatrix ker; + f_full(ker, i, j); + } + + // Bottom side + if (i < m) + { + RegisterMatrix ker; + f_partial(ker, i, j, m - i, ker.columns()); + } } - for (; i + 1 * TILE_SIZE <= m; i += 1 * TILE_SIZE) - { - RegisterMatrix ker; - f_full(ker, i, j); - } - // Bottom side - if (i < m) + // Right side + if (j < n) { - RegisterMatrix ker; - f_partial(ker, i, j, m - i, ker.columns()); + size_t i = 0; + + // i + 4 * TILE_STEP != M is to improve performance in case when the remaining number of rows is 4 * TILE_STEP: + // it is more efficient to apply 2 * TILE_STEP kernel 2 times than 3 * TILE_STEP + 1 * TILE_STEP kernel. + for (; i + 3 * SS <= m && i + 4 * SS != m; i += 3 * SS) + { + RegisterMatrix ker; + f_partial(ker, i, j, ker.rows(), n - j); + } + + for (; i + 2 * SS <= m; i += 2 * SS) + { + RegisterMatrix ker; + f_partial(ker, i, j, ker.rows(), n - j); + } + + for (; i + 1 * SS <= m; i += 1 * SS) + { + RegisterMatrix ker; + f_partial(ker, i, j, ker.rows(), n - j); + } + + // Bottom-right corner + if (i < m) + { + RegisterMatrix ker; + f_partial(ker, i, j, m - i, n - j); + } } } - - - // Right side - if (j < n) + else { size_t i = 0; - // i + 4 * TILE_SIZE != M is to improve performance in case when the remaining number of rows is 4 * TILE_SIZE: - // it is more efficient to apply 2 * TILE_SIZE kernel 2 times than 3 * TILE_SIZE + 1 * TILE_SIZE kernel. - for (; i + 3 * TILE_SIZE <= m && i + 4 * TILE_SIZE != m; i += 3 * TILE_SIZE) - { - RegisterMatrix ker; - f_partial(ker, i, j, ker.rows(), n - j); - } - - for (; i + 2 * TILE_SIZE <= m; i += 2 * TILE_SIZE) - { - RegisterMatrix ker; - f_partial(ker, i, j, ker.rows(), n - j); - } + // i + 4 * SS != M is to improve performance in case when the remaining number of rows is 4 * SS: + // it is more efficient to apply 2 * SS kernel 2 times than 3 * SS + 1 * SS kernel. + for (; i + 2 * SS < m && i + 4 * SS != m; i += 3 * SS) + tile_backend(m, n, i, f_full, f_partial); - for (; i + 1 * TILE_SIZE <= m; i += 1 * TILE_SIZE) - { - RegisterMatrix ker; - f_partial(ker, i, j, ker.rows(), n - j); - } + for (; i + 1 * SS < m; i += 2 * SS) + tile_backend(m, n, i, f_full, f_partial); - // Bottom-right corner - if (i < m) - { - RegisterMatrix ker; - f_partial(ker, i, j, m - i, n - j); - } + for (; i + 0 * SS < m; i += 1 * SS) + tile_backend(m, n, i, f_full, f_partial); } } } \ No newline at end of file diff --git a/include/blast/math/dense/DynamicMatrixPointer.hpp b/include/blast/math/dense/DynamicMatrixPointer.hpp index 61716bb3..e6540576 100644 --- a/include/blast/math/dense/DynamicMatrixPointer.hpp +++ b/include/blast/math/dense/DynamicMatrixPointer.hpp @@ -31,6 +31,7 @@ namespace blast static bool constexpr aligned = AF; static bool constexpr padded = PF; static bool constexpr isStatic = false; + static StorageOrder constexpr cachePreferredTraversal = SO == columnMajor ? columnMajor : rowMajor; /** diff --git a/include/blast/math/dense/StaticMatrixPointer.hpp b/include/blast/math/dense/StaticMatrixPointer.hpp index 901da1e7..2db833b8 100644 --- a/include/blast/math/dense/StaticMatrixPointer.hpp +++ b/include/blast/math/dense/StaticMatrixPointer.hpp @@ -4,6 +4,7 @@ #pragma once +#include #include #include #include @@ -31,6 +32,7 @@ namespace blast static bool constexpr aligned = AF; static bool constexpr padded = PF; static bool constexpr isStatic = true; + static StorageOrder constexpr cachePreferredTraversal = SO == columnMajor ? columnMajor : rowMajor; /** diff --git a/include/blast/math/expressions/PanelMatrix.hpp b/include/blast/math/expressions/PanelMatrix.hpp index 64d3edbf..4a556a33 100644 --- a/include/blast/math/expressions/PanelMatrix.hpp +++ b/include/blast/math/expressions/PanelMatrix.hpp @@ -108,7 +108,7 @@ namespace blast for (size_t i = 0; i + SS <= m; i += SS) { - ET2 const * pr = (*rhs).ptr(i, 0); + ET2 const * pr = &(*rhs)(i, 0); ET1 * pl = data(lhs) + i; for (size_t j = 0; j < n; ++j) @@ -119,7 +119,7 @@ namespace blast { MaskType const mask = SIMD::index() < rem; size_t const i = m - rem; - ET2 const * pr = (*rhs).ptr(i, 0); + ET2 const * pr = &(*rhs)(i, 0); ET1 * pl = data(lhs) + i; for (size_t j = 0; j < n; ++j) @@ -197,7 +197,7 @@ namespace blast for (size_t i = 0; i + SS <= m; i += SS) { ET2 const * pr = data(rhs) + i; - ET1 * pl = (*lhs).ptr(i, 0); + ET1 * pl = &(*lhs)(i, 0); for (size_t j = 0; j < n; ++j) store(pl + PANEL_SIZE * j, load(pr + s * j)); @@ -207,7 +207,7 @@ namespace blast { size_t const i = m - rem; ET2 const * pr = data(rhs) + i; - ET1 * pl = (*lhs).ptr(i, 0); + ET1 * pl = &(*lhs)(i, 0); for (size_t j = 0; j < n; ++j) for (size_t i1 = 0; i1 < rem; ++i1) @@ -230,7 +230,7 @@ namespace blast for (size_t i = 0; i < m; ++i) { ET2 const * pr = data(rhs) + s * i; - ET1 * pl = (*lhs).ptr(i, 0); + ET1 * pl = &(*lhs)(i, 0); for (size_t j = 0; j < n; ++j) pl[PANEL_SIZE * j] = pr[j]; diff --git a/include/blast/math/panel/DynamicPanelMatrix.hpp b/include/blast/math/panel/DynamicPanelMatrix.hpp index f2a05c29..f82e51e6 100644 --- a/include/blast/math/panel/DynamicPanelMatrix.hpp +++ b/include/blast/math/panel/DynamicPanelMatrix.hpp @@ -6,7 +6,7 @@ #include #include -#include +#include #include #include @@ -53,17 +53,14 @@ namespace blast explicit DynamicPanelMatrix(size_t m, size_t n) - : m_(m) - , n_(n) - , spacing_( - SO == columnMajor - ? panelSize_ * nextMultiple(n, panelSize_) - : nextMultiple(m, panelSize_) * panelSize_ - ) - , capacity_(nextMultiple(m, panelSize_) * nextMultiple(n, panelSize_)) + : m_ {m} + , n_ {n} + , spacing_ {SS * (SO == columnMajor ? n : m)} + , capacity_ {spacing_ * nextMultiple(SO == columnMajor ? m : n, SS)} + // Initialize padding elements to 0 to prevent denorms in calculations. // Initialize padding elements to 0 to prevent denorms in calculations. // Denorms can significantly impair performance, see https://github.com/giaf/blasfeo/issues/103 - , v_(new(std::align_val_t {alignment_}) Type[capacity_] {}) + , v_ {new(std::align_val_t {alignment_}) Type[capacity_] {}} { } @@ -114,36 +111,7 @@ namespace blast , bool SO2 > // Storage order of the right-hand side matrix DynamicPanelMatrix& operator=(Matrix const& rhs) { - // using blaze::assign; - - // using TT = decltype( trans( *this ) ); - // using CT = decltype( ctrans( *this ) ); - // using IT = decltype( inv( *this ) ); - - // if( (*rhs).rows() != M || (*rhs).columns() != N ) { - // BLAZE_THROW_INVALID_ARGUMENT( "Invalid assignment to static matrix" ); - // } - - // if( IsSame_v && (*rhs).isAliased( this ) ) { - // transpose( typename IsSquare::Type() ); - // } - // else if( IsSame_v && (*rhs).isAliased( this ) ) { - // ctranspose( typename IsSquare::Type() ); - // } - // else if( !IsSame_v && (*rhs).canAlias( this ) ) { - // StaticPanelMatrix tmp( *rhs ); - // assign( *this, tmp ); - // } - // else { - // if( IsSparseMatrix_v ) - // reset(); - // assign( *this, *rhs ); - // } - - // BLAZE_INTERNAL_ASSERT( isIntact(), "Invariant violation detected" ); - assign(*this, *rhs); - return *this; } @@ -178,23 +146,6 @@ namespace blast } - /// @brief Offset of the first matrix element from the start of the panel. - /// - /// In rows for column-major matrices, in columns for row-major matrices. - size_t constexpr offset() const - { - return 0; - } - - - void unpackLower(Type * data, size_t lda) const - { - for (size_t i = 0; i < m_; ++i) - for (size_t j = 0; j <= i; ++j) - data[i + lda * j] = (*this)(i, j); - } - - Type * data() noexcept { return v_; @@ -207,35 +158,9 @@ namespace blast } - Type * ptr(size_t i, size_t j) - { - // BLAST_USER_ASSERT(i % panelSize_ == 0, "Row index not aligned to panel boundary"); - return v_ + elementIndex(i, j); - } - - - Type const * ptr(size_t i, size_t j) const - { - // BLAST_USER_ASSERT(i % panelSize_ == 0, "Row index not aligned to panel boundary"); - return v_ + elementIndex(i, j); - } - - - template - auto load(size_t i, size_t j) const - { - BLAZE_INTERNAL_ASSERT(i < m_, "Invalid row access index"); - BLAZE_INTERNAL_ASSERT(j < n_, "Invalid column access index"); - BLAZE_INTERNAL_ASSERT(i % panelSize_ == 0 || SO == rowMajor, "Row index not aligned to panel boundary"); - BLAZE_INTERNAL_ASSERT(j % panelSize_ == 0 || SO == columnMajor, "Column index not aligned to panel boundary"); - - return blast::load(v_ + elementIndex(i, j)); - } - - private: static size_t constexpr alignment_ = CACHE_LINE_SIZE; - static size_t constexpr panelSize_ = PanelSize_v; + static size_t constexpr SS = SimdSize_v; size_t m_; size_t n_; @@ -248,8 +173,8 @@ namespace blast size_t elementIndex(size_t i, size_t j) const noexcept { return SO == columnMajor - ? i / panelSize_ * spacing_ + i % panelSize_ + j * panelSize_ - : j / panelSize_ * spacing_ + j % panelSize_ + i * panelSize_; + ? i / SS * spacing_ + i % SS + j * SS + : j / SS * spacing_ + j % SS + i * SS; } }; } diff --git a/include/blast/math/panel/DynamicPanelMatrixPointer.hpp b/include/blast/math/panel/DynamicPanelMatrixPointer.hpp index 105e16c2..f965f4e1 100644 --- a/include/blast/math/panel/DynamicPanelMatrixPointer.hpp +++ b/include/blast/math/panel/DynamicPanelMatrixPointer.hpp @@ -13,6 +13,8 @@ #include #include +#include + namespace blast { @@ -29,6 +31,7 @@ namespace blast static bool constexpr aligned = AF; static bool constexpr padded = PF; static bool constexpr isStatic = false; + static StorageOrder constexpr cachePreferredTraversal = SO == columnMajor ? rowMajor : columnMajor; /** @@ -59,7 +62,7 @@ namespace blast else { // NOTE: non-optimized! - ElementType tmp[SS]; + std::remove_cv_t tmp[SS]; for (size_t i = 0; i < SS; ++i) tmp[i] = storageOrder == columnMajor ? *(~*this)(i, 0) : *(~*this)(0, i); return SimdVecType {tmp, false}; @@ -76,7 +79,10 @@ namespace blast SimdVecType load(TransposeFlag orientation) const { if (orientation == majorOrientation) - return SimdVecType {ptr_, AF}; + if constexpr (AF) + return SimdVecType {ptr_, AF}; + else + static_assert(AF, "load() crossing panel boundary not implemented"); else BLAZE_THROW_LOGIC_ERROR("Cross-load not implemented"); } diff --git a/include/blast/math/panel/Gemm.hpp b/include/blast/math/panel/Gemm.hpp index 59df3d46..a8991e84 100644 --- a/include/blast/math/panel/Gemm.hpp +++ b/include/blast/math/panel/Gemm.hpp @@ -7,7 +7,7 @@ #include #include #include -#include +#include #include #include @@ -31,14 +31,6 @@ namespace blast } - /// Returns the index of first unprocessed row. - template - void gemm_nt_backend( - size_t i, ST1 alpha, ST2 beta, - PanelMatrix const& A, PanelMatrix const& B, - PanelMatrix const& C, PanelMatrix& D); - - template BLAZE_ALWAYS_INLINE void gemm_nt( ST1 alpha, ST2 beta, @@ -46,7 +38,8 @@ namespace blast PanelMatrix const& C, PanelMatrix& D) { using ET = ElementType_t; - size_t constexpr PANEL_SIZE = PanelSize_v; + size_t constexpr SS = SimdSize_v; + size_t constexpr TILE_STEP = 4; // TODO: this is almost arbitrary and needs to be ppoperly determined BLAZE_CONSTRAINT_MUST_BE_SAME_TYPE(ElementType_t, ET); BLAZE_CONSTRAINT_MUST_BE_SAME_TYPE(ElementType_t, ET); @@ -69,56 +62,37 @@ namespace blast // i + 4 * PANEL_SIZE != M is to improve performance in case when the remaining number of rows is 4 * PANEL_SIZE: // it is more efficient to apply 2 * PANEL_SIZE kernel 2 times than 3 * PANEL_SIZE + 1 * PANEL_SIZE kernel. - for (; i + 2 * PANEL_SIZE < M && i + 4 * PANEL_SIZE != M; i += 3 * PANEL_SIZE) - gemm_nt_backend<3 * PANEL_SIZE, 4>(i, alpha, beta, *A, *B, *C, *D); + for (; i + 2 * SS < M && i + 4 * SS != M; i += 3 * SS) + gemm_nt_backend<3 * SS, TILE_STEP>(M, N, K, i, alpha, ptr(*A), ~trans(ptr(*B)), beta, ptr(*C), ptr(*D)); - for (; i + 1 * PANEL_SIZE < M; i += 2 * PANEL_SIZE) - gemm_nt_backend<2 * PANEL_SIZE, 4>(i, alpha, beta, *A, *B, *C, *D); + for (; i + 1 * SS < M; i += 2 * SS) + gemm_nt_backend<2 * SS, TILE_STEP>(M, N, K, i, alpha, ptr(*A), ~trans(ptr(*B)), beta, ptr(*C), ptr(*D)); - for (; i + 0 * PANEL_SIZE < M; i += 1 * PANEL_SIZE) - gemm_nt_backend<1 * PANEL_SIZE, 4>(i, alpha, beta, *A, *B, *C, *D); + for (; i + 0 * SS < M; i += 1 * SS) + gemm_nt_backend<1 * SS, TILE_STEP>(M, N, K, i, alpha, ptr(*A), ~trans(ptr(*B)), beta, ptr(*C), ptr(*D)); } - template - BLAZE_ALWAYS_INLINE void gemm_nt_backend( - size_t i, ST1 alpha, ST2 beta, - PanelMatrix const& A, PanelMatrix const& B, - PanelMatrix const& C, PanelMatrix& D) + template < + size_t KM, size_t KN, typename T, + typename MPA, typename MPB, typename MPC, typename MPD + > + requires MatrixPointer && MatrixPointer && MatrixPointer && MatrixPointer + BLAZE_ALWAYS_INLINE void gemm_nt_backend(size_t M, size_t N, size_t K, size_t i, T alpha, MPA A, MPB B, T beta, MPC C, MPD D) { - using ET = ElementType_t; - size_t constexpr PANEL_SIZE = PanelSize_v; - - BLAZE_STATIC_ASSERT(KM % PANEL_SIZE == 0); - - BLAZE_CONSTRAINT_MUST_BE_SAME_TYPE(ElementType_t, ET); - BLAZE_CONSTRAINT_MUST_BE_SAME_TYPE(ElementType_t, ET); - BLAZE_CONSTRAINT_MUST_BE_SAME_TYPE(ElementType_t, ET); - - size_t const M = rows(A); - size_t const N = rows(B); - size_t const K = columns(A); - - BLAST_USER_ASSERT(columns(B) == K, "Matrix sizes do not match"); - BLAST_USER_ASSERT(rows(C) == M && columns(C) == N, "Matrix sizes do not match"); - BLAST_USER_ASSERT(rows(D) == M && columns(D) == N, "Matrix sizes do not match"); - + using ET = ElementType_t; RegisterMatrix ker; if (i + KM <= M) { size_t j = 0; - auto a = ptr(A, i, 0); + auto a = A(i, 0); for (; j + KN <= N; j += KN) - gemm(ker, K, alpha, - a, trans(ptr(B, j, 0)), - beta, ptr(C, i, j), ptr(D, i, j)); + gemm(ker, K, alpha, a, B(0, j), beta, C(i, j), D(i, j)); if (j < N) - gemm(ker, K, alpha, - a, trans(ptr(B, j, 0)), - beta, ptr(C, i, j), ptr(D, i, j), KM, N - j); + gemm(ker, K, alpha, a, B(0, j), beta, C(i, j), D(i, j), KM, N - j); } else { @@ -126,14 +100,10 @@ namespace blast size_t j = 0; for (; j + KN <= N; j += KN) - gemm(ker, K, alpha, - ptr(A, i, 0), trans(ptr(B, j, 0)), - beta, ptr(C, i, j), ptr(D, i, j), M - i, KN); + gemm(ker, K, alpha, A(i, 0), B(0, j), beta, C(i, j), D(i, j), M - i, KN); if (j < N) - gemm(ker, K, alpha, - ptr(A, i, 0), trans(ptr(B, j, 0)), - beta, ptr(C, i, j), ptr(D, i, j), M - i, N - j); + gemm(ker, K, alpha, A(i, 0), B(0, j), beta, C(i, j), D(i, j), M - i, N - j); } } } \ No newline at end of file diff --git a/include/blast/math/panel/StaticPanelMatrixPointer.hpp b/include/blast/math/panel/StaticPanelMatrixPointer.hpp index 6fcabe8c..50544d02 100644 --- a/include/blast/math/panel/StaticPanelMatrixPointer.hpp +++ b/include/blast/math/panel/StaticPanelMatrixPointer.hpp @@ -5,6 +5,7 @@ #pragma once #include +#include #include #include #include @@ -28,6 +29,7 @@ namespace blast static bool constexpr aligned = AF; static bool constexpr padded = PF; static bool constexpr isStatic = true; + static StorageOrder constexpr cachePreferredTraversal = SO == columnMajor ? rowMajor : columnMajor; /** diff --git a/include/blast/math/typetraits/IsAligned.hpp b/include/blast/math/typetraits/IsAligned.hpp new file mode 100644 index 00000000..5b3e7d8b --- /dev/null +++ b/include/blast/math/typetraits/IsAligned.hpp @@ -0,0 +1,28 @@ +// Copyright 2023 Mikhail Katliar +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + + +namespace blast +{ + template + struct IsAligned : blaze::IsAligned {}; + + + template + bool constexpr IsAligned_v = IsAligned::value; +} \ No newline at end of file diff --git a/include/blast/math/typetraits/MatrixPointer.hpp b/include/blast/math/typetraits/MatrixPointer.hpp index 6d961d68..50800749 100644 --- a/include/blast/math/typetraits/MatrixPointer.hpp +++ b/include/blast/math/typetraits/MatrixPointer.hpp @@ -24,6 +24,7 @@ namespace blast trans(p); ~p; *p; + p.cachePreferredTraversal; // {p.get()} -> std::same_as; }; diff --git a/include/blast/math/views/submatrix/Panel.hpp b/include/blast/math/views/submatrix/Panel.hpp index c1de07d6..c922bba7 100644 --- a/include/blast/math/views/submatrix/Panel.hpp +++ b/include/blast/math/views/submatrix/Panel.hpp @@ -91,7 +91,7 @@ namespace blast , j_(j) , m_(m) , n_(n) - , data_(matrix_.ptr(row(), column())) + , data_(&matrix_(row(), column())) { if( !Contains_v< TypeList, Unchecked > ) { @@ -204,19 +204,6 @@ namespace blast } - template - auto load(size_t i, size_t j) const - { - if( i >= rows() ) { - BLAZE_THROW_OUT_OF_RANGE( "Invalid row access index" ); - } - if( j >= columns() ) { - BLAZE_THROW_OUT_OF_RANGE( "Invalid column access index" ); - } - return matrix_.template load(row() + i, column() + j); - } - - Reference at( size_t i, size_t j ) { if( i >= rows() ) { @@ -253,18 +240,6 @@ namespace blast } - Pointer ptr(size_t i, size_t j) - { - return matrix_.ptr(i + row(), j + column()); - } - - - ConstPointer ptr(size_t i, size_t j) const - { - return matrix_.ptr(i + row(), j + column()); - } - - private: static size_t constexpr panelSize_ = PanelSize_v; diff --git a/test/blast/math/dense/MatrixPointerTest.cpp b/test/blast/math/dense/MatrixPointerTest.cpp index d02bdff1..5bc45c2f 100644 --- a/test/blast/math/dense/MatrixPointerTest.cpp +++ b/test/blast/math/dense/MatrixPointerTest.cpp @@ -162,6 +162,24 @@ namespace blast :: testing } + TYPED_TEST(MatrixPointerTest, testIsAlignedByDefault) + { + auto const p = ptr(this->m_); + EXPECT_EQ(p.aligned, IsAligned_vm_)>); + } + + + TYPED_TEST(MatrixPointerTest, testIsAligned) + { + for (size_t i = 0; i < rows(this->m_); i += this->incI_) + for (size_t j = 0; j < columns(this->m_); j += this->incJ_) + { + auto const p = ptr(this->m_, i, j); + EXPECT_EQ(p.aligned, TestFixture::isAligned); + } + } + + TYPED_TEST(MatrixPointerTest, testGet) { for (size_t i = 0; i < rows(this->m_); i += this->incI_) diff --git a/test/blast/math/panel/DynamicPanelMatrixTest.cpp b/test/blast/math/panel/DynamicPanelMatrixTest.cpp index c0c4c6b8..03ad02e2 100644 --- a/test/blast/math/panel/DynamicPanelMatrixTest.cpp +++ b/test/blast/math/panel/DynamicPanelMatrixTest.cpp @@ -37,7 +37,7 @@ namespace blast :: testing { { DynamicPanelMatrix m(5, 2); - EXPECT_EQ(m.spacing(), 4 * 4); + EXPECT_EQ(m.spacing(), 4 * 2); } { @@ -47,7 +47,7 @@ namespace blast :: testing { DynamicPanelMatrix m(5, 7); - EXPECT_EQ(m.spacing(), 4 * 8); + EXPECT_EQ(m.spacing(), 4 * 7); } } diff --git a/test/blast/math/panel/GemmTest.cpp b/test/blast/math/panel/GemmTest.cpp index 2ba580ca..82311d22 100644 --- a/test/blast/math/panel/GemmTest.cpp +++ b/test/blast/math/panel/GemmTest.cpp @@ -57,19 +57,12 @@ namespace blast :: testing B = blaze_B; C = blaze_C; - // std::cout << "A=\n" << A << std::endl; - // std::cout << "B=\n" << B << std::endl; - // std::cout << "C=\n" << C << std::endl; - // Do gemm with BLAST gemm_nt(A, B, C, D); // Copy the resulting D matrix from BLAST to Blaze blaze_D = D; - // Print the result from BLAST - // std::cout << "blaze_D=\n" << blaze_blasfeo_D; - BLAST_ASSERT_APPROX_EQ(blaze_D, evaluate(blaze_C + blaze_A * trans(blaze_B)), absTol(), relTol()) << "gemm error at size m,n,k=" << M << "," << N << "," << K; } @@ -78,46 +71,9 @@ namespace blast :: testing } - TYPED_TEST_P(PanelGemmTest, testNT_submatrix) - { - using Real = TypeParam; - size_t const M = 8, N = 8, K = 3 * 8; - - // Init Blaze matrices - // - blaze::DynamicMatrix blaze_A(M, K), blaze_B(N, K), blaze_C(M, N), blaze_D(M, N); - randomize(blaze_A); - randomize(blaze_B); - randomize(blaze_C); - - // Init BLAST matrices - // - StaticPanelMatrix A; - StaticPanelMatrix B; - StaticPanelMatrix C; - StaticPanelMatrix D; - - A = blaze_A; - B = blaze_B; - C = blaze_C; - - // Do gemm with BLAST - auto D1 = submatrix(D, 0, 0, M, N); - gemm_nt(submatrix(A, 0, 0, M, K), submatrix(B, 0, 0, N, K), submatrix(C, 0, 0, M, N), D1); - - // Copy the resulting D matrix from BLAST to Blaze - blaze_D = D; - - // Print the result from BLAST - // std::cout << "blaze_D=\n" << blaze_blasfeo_D; - - BLAST_EXPECT_APPROX_EQ(blaze_D, evaluate(blaze_C + blaze_A * trans(blaze_B)), absTol(), relTol()); - } - REGISTER_TYPED_TEST_SUITE_P(PanelGemmTest, - testNT, - testNT_submatrix + testNT ); diff --git a/test/blast/math/views/SubmatrixTest.cpp b/test/blast/math/views/SubmatrixTest.cpp index 6bb7e9bc..801537ed 100644 --- a/test/blast/math/views/SubmatrixTest.cpp +++ b/test/blast/math/views/SubmatrixTest.cpp @@ -16,6 +16,7 @@ namespace blast :: testing { StaticPanelMatrix A; auto B = submatrix(A, 4, 0, 8, 8); + EXPECT_EQ(IsAligned_v, false); } @@ -23,6 +24,7 @@ namespace blast :: testing { StaticPanelMatrix const A; auto B = submatrix(A, 4, 0, 8, 8); + EXPECT_EQ(IsAligned_v, false); } @@ -31,8 +33,8 @@ namespace blast :: testing DynamicPanelMatrix A(12, 12); auto B = submatrix(A, 4, 0, 8, 8); - static_assert(std::is_same_v); - B.ptr(0, 0); + static_assert(std::is_same_v); + EXPECT_EQ(IsAligned_v, false); } @@ -41,8 +43,8 @@ namespace blast :: testing DynamicPanelMatrix const A(12, 12); auto B = submatrix(A, 4, 0, 8, 8); - static_assert(std::is_same_v); - B.ptr(0, 0); + static_assert(std::is_same_v); + EXPECT_EQ(IsAligned_v, false); } @@ -74,5 +76,7 @@ namespace blast :: testing ASSERT_EQ(B1(i, j), val); ASSERT_EQ(A(i + B1.row(), j + B1.column()), val); } + + EXPECT_EQ(IsAligned_v, false); } } \ No newline at end of file