Simplify implementations with tile() function (#5)
* Added cachePreferredTraversal to matrix pointers

* Removed PanelGemmTest.testNT_submatrix

* Added traversal_order argument to tile()

* gemm_nt_backend() taking matrix pointers

* Removed padding in minor direction of DynamicPanelMatrix

* Removed obsolete functions from DynamicPanelMatrix

* Include quadratic term in gemm() flops; call blast::gemm() with alpha and beta in the benchmark, for fair comparison
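
Note: tile() now takes the traversal order as a runtime argument rather than deriving it from template parameters alone. A minimal sketch of the new calling convention, assuming the blast headers; the functor bodies below are illustrative placeholders, while the real call site is in include/blast/math/algorithm/Gemm.hpp further down:

    #include <blast/math/algorithm/Tile.hpp>

    #include <cstdio>

    void cover(std::size_t m, std::size_t n)
    {
        using namespace blast;

        tile<double, columnMajor>(
            columnMajor,  // traversal_order: chosen at run time, e.g. from D.cachePreferredTraversal
            m, n,
            [] (auto& ker, std::size_t i, std::size_t j)
            {
                // called for every full register-sized tile with top-left corner (i, j)
                std::printf("full tile at (%zu, %zu)\n", i, j);
            },
            [] (auto& ker, std::size_t i, std::size_t j, std::size_t md, std::size_t nd)
            {
                // called for boundary tiles that are only md x nd elements
                std::printf("partial %zu x %zu tile at (%zu, %zu)\n", md, nd, i, j);
            });
    }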
mkatliar authored May 23, 2024
1 parent 48cdb5f commit 785914c
Showing 20 changed files with 225 additions and 279 deletions.
2 changes: 1 addition & 1 deletion bench/blasfeo/Gemm.cpp
@@ -53,7 +53,7 @@ namespace blast :: benchmark
         for (auto _ : state)
             gemm_nt(m, n, k, 1., A, 0, 0, B, 0, 0, 1., C, 0, 0, C, 0, 0);
 
-        state.counters["flops"] = Counter(2 * m * m * m, Counter::kIsIterationInvariantRate);
+        state.counters["flops"] = Counter(2 * m * n * k + 3 * m * n, Counter::kIsIterationInvariantRate);
         state.counters["m"] = m;
     }

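Note on the corrected counter: gemm computes D = alpha * A * B + beta * C, so for A of size m x k and B of size k x n the work is 2 * m * n * k flops for the matrix product plus a quadratic term of about 3 * m * n for the scaling by alpha and beta and the final addition. The old expression 2 * m * m * m counted only the cubic term, and only correctly for square matrices. A hypothetical helper (not part of the benchmark code) expressing the corrected count:

    #include <cstddef>

    // flops for D = alpha * A * B + beta * C, with A (m x k) and B (k x n)
    inline std::size_t gemm_flops(std::size_t m, std::size_t n, std::size_t k)
    {
        return 2 * m * n * k   // multiply-adds for the matrix product
             + 3 * m * n;      // scale by alpha, scale by beta, add
    }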
5 changes: 2 additions & 3 deletions bench/blast/math/dense/StaticGemm.cpp
@@ -32,15 +32,14 @@ namespace blast :: benchmark
 
         for (auto _ : state)
         {
-            // gemm(1., A, trans(B), 1., C, D);
-            gemm(A, B, C, D);
+            gemm(0.5, A, B, 0.1, C, D);
             DoNotOptimize(A);
             DoNotOptimize(B);
             DoNotOptimize(C);
             DoNotOptimize(D);
         }
 
-        state.counters["flops"] = Counter(2 * M * N * K, Counter::kIsIterationInvariantRate);
+        state.counters["flops"] = Counter(2 * M * N * K + 3 * M * N, Counter::kIsIterationInvariantRate);
         state.counters["m"] = M;
     }

6 changes: 6 additions & 0 deletions include/blast/math/StorageOrder.hpp
@@ -24,4 +24,10 @@ namespace blast
         rowMajor = blaze::rowMajor,
         columnMajor = blaze::columnMajor
     };
+
+
+    inline constexpr StorageOrder operator!(StorageOrder so)
+    {
+        return so == rowMajor ? columnMajor : rowMajor;
+    }
 }
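
The new operator! flips a storage order, which is handy when a kernel needs the opposite layout of its argument. A small usage sketch (illustrative, not from the commit):

    #include <blast/math/StorageOrder.hpp>

    static_assert(!blast::rowMajor == blast::columnMajor);
    static_assert(!blast::columnMajor == blast::rowMajor);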
1 change: 1 addition & 0 deletions include/blast/math/TypeTraits.hpp
@@ -15,6 +15,7 @@
 #pragma once
 
 #include <blast/math/typetraits/IsStatic.hpp>
+#include <blast/math/typetraits/IsAligned.hpp>
 #include <blast/math/typetraits/IsPadded.hpp>
 #include <blast/math/typetraits/ElementType.hpp>
 #include <blast/math/typetraits/StorageOrder.hpp>
4 changes: 3 additions & 1 deletion include/blast/math/algorithm/Gemm.hpp
@@ -61,7 +61,9 @@ namespace blast
         BLAZE_CONSTRAINT_MUST_BE_SAME_TYPE(std::remove_cv_t<ElementType_t<MPC>>, ET);
         BLAZE_CONSTRAINT_MUST_BE_SAME_TYPE(std::remove_cv_t<ElementType_t<MPD>>, ET);
 
-        tile<ET, StorageOrder(StorageOrder_v<MPD>)>(M, N,
+        tile<ET, StorageOrder(StorageOrder_v<MPD>)>(
+            D.cachePreferredTraversal,
+            M, N,
             [&] (auto& ker, size_t i, size_t j)
             {
                 gemm(ker, K, alpha, A(i, 0), B(0, j), beta, C(i, j), D(i, j));
158 changes: 104 additions & 54 deletions include/blast/math/algorithm/Tile.hpp
@@ -14,6 +14,7 @@
 
 #include <blast/system/Tile.hpp>
 #include <blast/system/Inline.hpp>
+#include <blast/math/simd/SimdSize.hpp>
 #include <blast/math/StorageOrder.hpp>
 #include <blast/math/RegisterMatrix.hpp>
 
@@ -22,6 +23,34 @@
 
 namespace blast
 {
+    template <typename ET, size_t KM, size_t KN, StorageOrder SO, typename FF, typename FP>
+    BLAZE_ALWAYS_INLINE void tile_backend(size_t m, size_t n, size_t i, FF&& f_full, FP&& f_partial)
+    {
+        RegisterMatrix<ET, KM, KN, SO> ker;
+
+        if (i + KM <= m)
+        {
+            size_t j = 0;
+
+            for (; j + KN <= n; j += KN)
+                f_full(ker, i, j);
+
+            if (j < n)
+                f_partial(ker, i, j, KM, n - j);
+        }
+        else
+        {
+            size_t j = 0;
+
+            for (; j + KN <= n; j += KN)
+                f_partial(ker, i, j, m - i, KN);
+
+            if (j < n)
+                f_partial(ker, i, j, m - i, n - j);
+        }
+    }
+
+
     /**
      * @brief Cover a matrix with tiles of different sizes in a performance-efficient way.
      *
@@ -48,77 +77,98 @@ namespace blast
      * @param f_partial functor to call on partial tiles
      */
     template <typename ET, StorageOrder SO, typename FF, typename FP>
-    BLAST_ALWAYS_INLINE void tile(std::size_t m, std::size_t n, FF&& f_full, FP&& f_partial)
+    BLAST_ALWAYS_INLINE void tile(StorageOrder traversal_order, std::size_t m, std::size_t n, FF&& f_full, FP&& f_partial)
     {
-        size_t constexpr TILE_SIZE = TileSize_v<ET>;
+        size_t constexpr SS = SimdSize_v<ET>;
+        size_t constexpr TILE_STEP = 4;  // TODO: this is almost arbitrary and needs to be properly determined
 
-        size_t j = 0;
+        static_assert(SO == columnMajor, "tile() for row-major matrices not implemented");
 
-        // Main part
-        for (; j + TILE_SIZE <= n; j += TILE_SIZE)
+        if (traversal_order == columnMajor)
         {
-            size_t i = 0;
+            size_t j = 0;
 
-            // i + 4 * TILE_SIZE != M is to improve performance in case when the remaining number of rows is 4 * TILE_SIZE:
-            // it is more efficient to apply 2 * TILE_SIZE kernel 2 times than 3 * TILE_SIZE + 1 * TILE_SIZE kernel.
-            for (; i + 3 * TILE_SIZE <= m && i + 4 * TILE_SIZE != m; i += 3 * TILE_SIZE)
+            // Main part
+            for (; j + TILE_STEP <= n; j += TILE_STEP)
             {
-                RegisterMatrix<ET, 3 * TILE_SIZE, TILE_SIZE, columnMajor> ker;
-                f_full(ker, i, j);
-            }
+                size_t i = 0;
 
-            for (; i + 2 * TILE_SIZE <= m; i += 2 * TILE_SIZE)
-            {
-                RegisterMatrix<ET, 2 * TILE_SIZE, TILE_SIZE, columnMajor> ker;
-                f_full(ker, i, j);
-            }
+                // i + 4 * SS != m is to improve performance in case when the remaining number of rows is 4 * SS:
+                // it is more efficient to apply the 2 * SS kernel twice than the 3 * SS + 1 * SS kernels.
+                for (; i + 3 * SS <= m && i + 4 * SS != m; i += 3 * SS)
+                {
+                    RegisterMatrix<ET, 3 * SS, TILE_STEP, SO> ker;
+                    f_full(ker, i, j);
+                }
 
-            for (; i + 1 * TILE_SIZE <= m; i += 1 * TILE_SIZE)
-            {
-                RegisterMatrix<ET, 1 * TILE_SIZE, TILE_SIZE, columnMajor> ker;
-                f_full(ker, i, j);
-            }
+                for (; i + 2 * SS <= m; i += 2 * SS)
+                {
+                    RegisterMatrix<ET, 2 * SS, TILE_STEP, SO> ker;
+                    f_full(ker, i, j);
+                }
 
-            // Bottom side
-            if (i < m)
-            {
-                RegisterMatrix<ET, TILE_SIZE, TILE_SIZE, columnMajor> ker;
-                f_partial(ker, i, j, m - i, ker.columns());
+                for (; i + 1 * SS <= m; i += 1 * SS)
+                {
+                    RegisterMatrix<ET, 1 * SS, TILE_STEP, SO> ker;
+                    f_full(ker, i, j);
+                }
+
+                // Bottom side
+                if (i < m)
+                {
+                    RegisterMatrix<ET, SS, TILE_STEP, SO> ker;
+                    f_partial(ker, i, j, m - i, ker.columns());
+                }
             }
-        }
 
-        // Right side
-        if (j < n)
-        {
-            size_t i = 0;
+            // Right side
+            if (j < n)
+            {
+                size_t i = 0;
 
-            // i + 4 * TILE_SIZE != M is to improve performance in case when the remaining number of rows is 4 * TILE_SIZE:
-            // it is more efficient to apply 2 * TILE_SIZE kernel 2 times than 3 * TILE_SIZE + 1 * TILE_SIZE kernel.
-            for (; i + 3 * TILE_SIZE <= m && i + 4 * TILE_SIZE != m; i += 3 * TILE_SIZE)
-            {
-                RegisterMatrix<ET, 3 * TILE_SIZE, TILE_SIZE, columnMajor> ker;
-                f_partial(ker, i, j, ker.rows(), n - j);
-            }
+                // i + 4 * SS != m is to improve performance in case when the remaining number of rows is 4 * SS:
+                // it is more efficient to apply the 2 * SS kernel twice than the 3 * SS + 1 * SS kernels.
+                for (; i + 3 * SS <= m && i + 4 * SS != m; i += 3 * SS)
+                {
+                    RegisterMatrix<ET, 3 * SS, TILE_STEP, SO> ker;
+                    f_partial(ker, i, j, ker.rows(), n - j);
+                }
 
-            for (; i + 2 * TILE_SIZE <= m; i += 2 * TILE_SIZE)
-            {
-                RegisterMatrix<ET, 2 * TILE_SIZE, TILE_SIZE, columnMajor> ker;
-                f_partial(ker, i, j, ker.rows(), n - j);
-            }
+                for (; i + 2 * SS <= m; i += 2 * SS)
+                {
+                    RegisterMatrix<ET, 2 * SS, TILE_STEP, SO> ker;
+                    f_partial(ker, i, j, ker.rows(), n - j);
+                }
 
-            for (; i + 1 * TILE_SIZE <= m; i += 1 * TILE_SIZE)
-            {
-                RegisterMatrix<ET, 1 * TILE_SIZE, TILE_SIZE, columnMajor> ker;
-                f_partial(ker, i, j, ker.rows(), n - j);
-            }
+                for (; i + 1 * SS <= m; i += 1 * SS)
+                {
+                    RegisterMatrix<ET, 1 * SS, TILE_STEP, SO> ker;
+                    f_partial(ker, i, j, ker.rows(), n - j);
+                }
 
-            // Bottom-right corner
-            if (i < m)
-            {
-                RegisterMatrix<ET, TILE_SIZE, TILE_SIZE, columnMajor> ker;
-                f_partial(ker, i, j, m - i, n - j);
+                // Bottom-right corner
+                if (i < m)
+                {
+                    RegisterMatrix<ET, SS, TILE_STEP, SO> ker;
+                    f_partial(ker, i, j, m - i, n - j);
+                }
             }
         }
+        else
+        {
+            size_t i = 0;
+
+            // i + 4 * SS != m is to improve performance in case when the remaining number of rows is 4 * SS:
+            // it is more efficient to apply the 2 * SS kernel twice than the 3 * SS + 1 * SS kernels.
+            for (; i + 2 * SS < m && i + 4 * SS != m; i += 3 * SS)
+                tile_backend<ET, 3 * SS, TILE_STEP, SO>(m, n, i, f_full, f_partial);
+
+            for (; i + 1 * SS < m; i += 2 * SS)
+                tile_backend<ET, 2 * SS, TILE_STEP, SO>(m, n, i, f_full, f_partial);
+
+            for (; i + 0 * SS < m; i += 1 * SS)
+                tile_backend<ET, 1 * SS, TILE_STEP, SO>(m, n, i, f_full, f_partial);
+        }
     }
 }
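
How the two traversal orders fit together: for columnMajor traversal, tile() sweeps the matrix in panels of TILE_STEP columns and descends each panel in strips of 3, 2, or 1 SIMD registers; for rowMajor traversal it walks strips of rows and hands each strip to tile_backend(), which sweeps the columns of the strip and routes boundary tiles to f_partial. A self-contained plain-C++ model of the strip logic (simplified: fixed strip height, no blast types, printf in place of the kernel functors):

    #include <cstdio>

    // Model of tile_backend<ET, KM, KN, SO>: cover one strip of KM rows
    // starting at row i of an m x n matrix with KM x KN tiles.
    template <int KM, int KN>
    void strip(int m, int n, int i)
    {
        if (i + KM <= m)  // strip lies fully inside the matrix
        {
            int j = 0;

            for (; j + KN <= n; j += KN)
                std::printf("full %d x %d tile at (%d, %d)\n", KM, KN, i, j);

            if (j < n)  // right edge: fewer than KN columns remain
                std::printf("partial %d x %d tile at (%d, %d)\n", KM, n - j, i, j);
        }
        else  // bottom strip: fewer than KM rows remain
        {
            int j = 0;

            for (; j + KN <= n; j += KN)
                std::printf("partial %d x %d tile at (%d, %d)\n", m - i, KN, i, j);

            if (j < n)  // bottom-right corner
                std::printf("partial %d x %d tile at (%d, %d)\n", m - i, n - j, i, j);
        }
    }

    int main()
    {
        // cover a 10 x 7 matrix with strips of 4 rows and tiles of 4 columns
        for (int i = 0; i < 10; i += 4)
            strip<4, 4>(10, 7, i);
    }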
1 change: 1 addition & 0 deletions include/blast/math/dense/DynamicMatrixPointer.hpp
@@ -31,6 +31,7 @@ namespace blast
         static bool constexpr aligned = AF;
         static bool constexpr padded = PF;
         static bool constexpr isStatic = false;
+        static StorageOrder constexpr cachePreferredTraversal = SO == columnMajor ? columnMajor : rowMajor;
 
 
         /**
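A note on the cachePreferredTraversal initializer: the matrix pointer classes follow the Blaze convention in which the storage-order template parameter SO is a bool (assumed here: rowMajor == false, columnMajor == true), so the conditional is not a tautology; it converts the bool flag into a blast::StorageOrder enumerator:

    // assumed Blaze convention: constexpr bool rowMajor = false, columnMajor = true;
    // SO == columnMajor ? columnMajor : rowMajor   maps false -> blast::rowMajor
    //                                              and  true  -> blast::columnMajor

For a dense matrix pointer the cache-preferred traversal therefore simply matches its storage order.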
2 changes: 2 additions & 0 deletions include/blast/math/dense/StaticMatrixPointer.hpp
@@ -4,6 +4,7 @@
 
 #pragma once
 
+#include <blast/math/StorageOrder.hpp>
 #include <blast/math/TransposeFlag.hpp>
 #include <blast/math/simd/Simd.hpp>
 #include <blast/math/simd/SimdVec.hpp>
@@ -31,6 +32,7 @@ namespace blast
         static bool constexpr aligned = AF;
         static bool constexpr padded = PF;
         static bool constexpr isStatic = true;
+        static StorageOrder constexpr cachePreferredTraversal = SO == columnMajor ? columnMajor : rowMajor;
 
 
         /**
10 changes: 5 additions & 5 deletions include/blast/math/expressions/PanelMatrix.hpp
@@ -108,7 +108,7 @@ namespace blast
 
         for (size_t i = 0; i + SS <= m; i += SS)
         {
-            ET2 const * pr = (*rhs).ptr(i, 0);
+            ET2 const * pr = &(*rhs)(i, 0);
             ET1 * pl = data(lhs) + i;
 
             for (size_t j = 0; j < n; ++j)
@@ -119,7 +119,7 @@
         {
             MaskType const mask = SIMD::index() < rem;
             size_t const i = m - rem;
-            ET2 const * pr = (*rhs).ptr(i, 0);
+            ET2 const * pr = &(*rhs)(i, 0);
             ET1 * pl = data(lhs) + i;
 
             for (size_t j = 0; j < n; ++j)
@@ -197,7 +197,7 @@ namespace blast
         for (size_t i = 0; i + SS <= m; i += SS)
         {
             ET2 const * pr = data(rhs) + i;
-            ET1 * pl = (*lhs).ptr(i, 0);
+            ET1 * pl = &(*lhs)(i, 0);
 
             for (size_t j = 0; j < n; ++j)
                 store<aligned>(pl + PANEL_SIZE * j, load<aligned, SS>(pr + s * j));
@@ -207,7 +207,7 @@
         {
             size_t const i = m - rem;
             ET2 const * pr = data(rhs) + i;
-            ET1 * pl = (*lhs).ptr(i, 0);
+            ET1 * pl = &(*lhs)(i, 0);
 
             for (size_t j = 0; j < n; ++j)
                 for (size_t i1 = 0; i1 < rem; ++i1)
@@ -230,7 +230,7 @@
         for (size_t i = 0; i < m; ++i)
         {
             ET2 const * pr = data(rhs) + s * i;
-            ET1 * pl = (*lhs).ptr(i, 0);
+            ET1 * pl = &(*lhs)(i, 0);
 
             for (size_t j = 0; j < n; ++j)
                 pl[PANEL_SIZE * j] = pr[j];
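
The repeated change in this file replaces the panel-matrix ptr() member with taking the address of an element, consistent with the removal of obsolete functions from DynamicPanelMatrix listed in the commit message. Both expressions yield a pointer to element (i, 0):

    ET2 const * pr = (*rhs).ptr(i, 0);   // before: dedicated ptr() accessor
    ET2 const * pr = &(*rhs)(i, 0);      // after: address of the element returned by operator()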
(diffs for the remaining changed files are not shown)
