Skip to content

Commit

Permalink
- Computational complexity calculated in the same way for all gemm tests
Browse files Browse the repository at this point in the history
- No hard-coded alpha and beta in static gemm tests
  • Loading branch information
mkatliar committed Aug 5, 2024
1 parent 51bf00a commit e6e519b
Show file tree
Hide file tree
Showing 12 changed files with 52 additions and 37 deletions.
2 changes: 2 additions & 0 deletions bench/analysis/dgemm_performance.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import json

Expand Down
2 changes: 2 additions & 0 deletions bench/analysis/dgemm_performance_ratio.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import json

Expand Down
10 changes: 4 additions & 6 deletions bench/blas/Gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.



#include <benchmark/benchmark.h>
#include <bench/Gemm.hpp>

#include <blaze/Math.h>

Expand All @@ -31,10 +29,10 @@ namespace blast :: benchmark
for (auto _ : state)
gemm(C, trans(A), B, 1.0, 1.0);

state.counters["flops"] = Counter(2 * m * m * m, Counter::kIsIterationInvariantRate);
setCounters(state.counters, complexityGemm(m, m, m));
state.counters["m"] = m;
}

BENCHMARK_TEMPLATE(BM_gemm, double)->DenseRange(1, 50);
BENCHMARK_TEMPLATE(BM_gemm, float)->DenseRange(1, 50);
BENCHMARK_TEMPLATE(BM_gemm, double)->DenseRange(1, BENCHMARK_MAX_GEMM);
BENCHMARK_TEMPLATE(BM_gemm, float)->DenseRange(1, BENCHMARK_MAX_GEMM);
}
3 changes: 1 addition & 2 deletions bench/blasfeo/Gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
#include <bench/Gemm.hpp>

#include <random>
#include <memory>


namespace blast :: benchmark
Expand Down Expand Up @@ -53,7 +52,7 @@ namespace blast :: benchmark
for (auto _ : state)
gemm_nt(m, n, k, 1., A, 0, 0, B, 0, 0, 1., C, 0, 0, C, 0, 0);

state.counters["flops"] = Counter(2 * m * n * k + 3 * m * n, Counter::kIsIterationInvariantRate);
setCounters(state.counters, complexityGemm(m, n, k));
state.counters["m"] = m;
}

Expand Down
13 changes: 7 additions & 6 deletions bench/blast/math/dense/DynamicGemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,12 @@

#include <blast/math/dense/Gemm.hpp>

#include <bench/Gemm.hpp>

#include <blaze/math/DynamicMatrix.h>

#include <bench/Gemm.hpp>
#include <test/Randomize.hpp>

#include <random>
#include <memory>


namespace blast :: benchmark
{
Expand All @@ -26,21 +24,24 @@ namespace blast :: benchmark
DynamicMatrix<Real, columnMajor> B(N, K);
DynamicMatrix<Real, columnMajor> C(M, N);
DynamicMatrix<Real, columnMajor> D(M, N);
Real alpha, beta;

randomize(A);
randomize(B);
randomize(C);
randomize(alpha);
randomize(beta);

for (auto _ : state)
{
gemm(1., A, trans(B), 1., C, D);
gemm(alpha, A, B, beta, C, D);
DoNotOptimize(A);
DoNotOptimize(B);
DoNotOptimize(C);
DoNotOptimize(D);
}

state.counters["flops"] = Counter(2 * M * N * K, Counter::kIsIterationInvariantRate);
setCounters(state.counters, complexityGemm(M, N, K));
state.counters["m"] = M;
}

Expand Down
5 changes: 3 additions & 2 deletions bench/blast/math/dense/StaticGemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@

#include <blast/math/dense/Gemm.hpp>

#include <bench/Gemm.hpp>

#include <blaze/math/StaticMatrix.h>

#include <bench/Gemm.hpp>
#include <test/Randomize.hpp>


Expand Down Expand Up @@ -39,7 +40,7 @@ namespace blast :: benchmark
DoNotOptimize(D);
}

state.counters["flops"] = Counter(2 * M * N * K + 3 * M * N, Counter::kIsIterationInvariantRate);
setCounters(state.counters, complexityGemm(M, N, K));
state.counters["m"] = M;
}

Expand Down
4 changes: 2 additions & 2 deletions bench/blast/math/panel/DynamicGemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

#include "blast/math/StorageOrder.hpp"
#include <blast/math/DynamicPanelMatrix.hpp>
#include <blast/math/panel/Gemm.hpp>

#include <bench/Gemm.hpp>

#include <test/Randomize.hpp>


Expand Down Expand Up @@ -40,7 +40,7 @@ namespace blast :: benchmark
DoNotOptimize(D);
}

state.counters["flops"] = Counter(2 * M * N * K, Counter::kIsIterationInvariantRate);
setCounters(state.counters, complexityGemm(M, N, K));
state.counters["m"] = M;
}

Expand Down
5 changes: 2 additions & 3 deletions bench/blast/math/panel/StaticGemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

#include "blast/math/StorageOrder.hpp"
#include <blast/math/StaticPanelMatrix.hpp>
#include <blast/math/panel/Gemm.hpp>

#include <bench/Gemm.hpp>
#include <blaze/util/Random.h>

#include <test/Randomize.hpp>


Expand Down Expand Up @@ -40,7 +39,7 @@ namespace blast :: benchmark
DoNotOptimize(D);
}

state.counters["flops"] = Counter(2 * M * N * K, Counter::kIsIterationInvariantRate);
setCounters(state.counters, complexityGemm(M, N, K));
state.counters["m"] = M;
}

Expand Down
8 changes: 4 additions & 4 deletions bench/blaze/Gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ namespace blast :: benchmark
DoNotOptimize(C);
}

state.counters["flops"] = Counter(2 * M * N * K, Counter::kIsIterationInvariantRate);
setCounters(state.counters, complexityGemm(M, N, K));
state.counters["m"] = M;
state.counters["n"] = N;
state.counters["k"] = K;
Expand Down Expand Up @@ -63,7 +63,7 @@ namespace blast :: benchmark
DoNotOptimize(C);
}

state.counters["flops"] = Counter(2 * m * m * m, Counter::kIsIterationInvariantRate);
setCounters(state.counters, complexityGemm(m, m, m));
state.counters["m"] = m;
}

Expand All @@ -90,7 +90,7 @@ namespace blast :: benchmark
}
}

state.counters["flops"] = Counter(2 * m * m * m, Counter::kIsIterationInvariantRate);
setCounters(state.counters, complexityGemm(m, m, m));
state.counters["m"] = m;
}

Expand Down Expand Up @@ -123,7 +123,7 @@ namespace blast :: benchmark
}
}

state.counters["flops"] = Counter(2 * m * m * m, Counter::kIsIterationInvariantRate);
setCounters(state.counters, complexityGemm(m, m, m));
state.counters["m"] = m;
}

Expand Down
8 changes: 3 additions & 5 deletions bench/eigen/Gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@

#include <bench/Gemm.hpp>

#include <vector>


namespace blast :: benchmark
{
Expand Down Expand Up @@ -43,7 +41,7 @@ namespace blast :: benchmark
::benchmark::DoNotOptimize(C);
}

state.counters["flops"] = Counter(2 * M * N * K, Counter::kIsIterationInvariantRate);
setCounters(state.counters, complexityGemm(M, N, K));
state.counters["m"] = M;
state.counters["n"] = N;
state.counters["k"] = K;
Expand Down Expand Up @@ -72,12 +70,12 @@ namespace blast :: benchmark
::benchmark::DoNotOptimize(C);
}

state.counters["flops"] = Counter(2 * m * m * m, Counter::kIsIterationInvariantRate);
setCounters(state.counters, complexityGemm(m, m, m));
state.counters["m"] = m;
}


BENCHMARK_TEMPLATE(BM_gemm_dynamic, double)->DenseRange(1, 50);
BENCHMARK_TEMPLATE(BM_gemm_dynamic, double)->DenseRange(1, BENCHMARK_MAX_GEMM);


#define BOOST_PP_LOCAL_LIMITS (1, BENCHMARK_MAX_GEMM)
Expand Down
15 changes: 8 additions & 7 deletions bench/libxsmm/Gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

#include <bench/Benchmark.hpp>
#include <bench/Gemm.hpp>

#include <test/Randomize.hpp>

#include <libxsmm.h>
Expand Down Expand Up @@ -39,7 +40,7 @@ namespace blast :: benchmark
for (auto _ : state)
kernel(a.data(), b.data(), c.data());

state.counters["flops"] = Counter(m * n * k, Counter::kIsIterationInvariantRate);
setCounters(state.counters, complexityGemm(m, n, k));
state.counters["m"] = m;
}

Expand Down Expand Up @@ -69,14 +70,14 @@ namespace blast :: benchmark
for (auto _ : state)
kernel(a.data(), b.data(), c.data());

state.counters["flops"] = Counter(2 * m * n * k, Counter::kIsIterationInvariantRate);
setCounters(state.counters, complexityGemm(m, n, k));
state.counters["m"] = m;
}


BENCHMARK_TEMPLATE(BM_gemm_nn, double)->DenseRange(1, 50);
BENCHMARK_TEMPLATE(BM_gemm_nt, double)->DenseRange(1, 50);
BENCHMARK_TEMPLATE(BM_gemm_nn, double)->DenseRange(1, BENCHMARK_MAX_GEMM);
BENCHMARK_TEMPLATE(BM_gemm_nt, double)->DenseRange(1, BENCHMARK_MAX_GEMM);

BENCHMARK_TEMPLATE(BM_gemm_nn, float)->DenseRange(1, 50);
BENCHMARK_TEMPLATE(BM_gemm_nt, float)->DenseRange(1, 50);
BENCHMARK_TEMPLATE(BM_gemm_nn, float)->DenseRange(1, BENCHMARK_MAX_GEMM);
BENCHMARK_TEMPLATE(BM_gemm_nt, float)->DenseRange(1, BENCHMARK_MAX_GEMM);
}
14 changes: 14 additions & 0 deletions include/bench/Gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,19 @@
#pragma once

#include <bench/Benchmark.hpp>
#include <bench/Complexity.hpp>

#define BENCHMARK_MAX_GEMM 50


namespace blast :: benchmark
{
    /// @brief Operation counts attributed to one gemm call of size (m, n, k).
    ///
    /// Builds a @a Complexity from two named entries:
    ///   - "add": m * n * (k + 2)
    ///   - "mul": m * n * (k + 1)
    ///
    /// The two entries sum to m * n * (2 * k + 3).
    ///
    /// NOTE(review): presumably m, n, k are the usual gemm dimensions
    /// (result is m-by-n, inner dimension k) -- confirm against callers.
    ///
    /// @param m first problem dimension
    /// @param n second problem dimension
    /// @param k third problem dimension
    ///
    /// @return the "add" and "mul" counts packed into a @a Complexity
    inline Complexity complexityGemm(std::size_t m, std::size_t n, std::size_t k)
    {
        // Both counts share the m * n factor.
        std::size_t const mn = m * n;
        std::size_t const additions = mn * (k + 2);
        std::size_t const multiplications = mn * (k + 1);

        return {
            {"add", additions},
            {"mul", multiplications},
        };
    }
}

0 comments on commit e6e519b

Please sign in to comment.