From 682b3345c2736b635fb27af46bb3e32cda51561c Mon Sep 17 00:00:00 2001 From: Mikhail Katliar Date: Mon, 30 Sep 2024 11:58:00 +0200 Subject: [PATCH] trmm() refactoring --- bench/blast/math/dense/StaticTrmm.cpp | 4 +- include/blast/math/algorithm/Trmm.hpp | 341 +++++++++++++++++--------- include/blast/util/Exception.hpp | 3 - test/blast/math/dense/TrmmTest.cpp | 5 +- 4 files changed, 226 insertions(+), 127 deletions(-) diff --git a/bench/blast/math/dense/StaticTrmm.cpp b/bench/blast/math/dense/StaticTrmm.cpp index ec709322..b0594805 100644 --- a/bench/blast/math/dense/StaticTrmm.cpp +++ b/bench/blast/math/dense/StaticTrmm.cpp @@ -24,7 +24,7 @@ namespace blast :: benchmark for (auto _ : state) { - trmmLeftUpper(1., A, B, C); + trmm(1., A, UpLo::Upper, false, B, C); DoNotOptimize(A); DoNotOptimize(B); DoNotOptimize(C); @@ -48,7 +48,7 @@ namespace blast :: benchmark for (auto _ : state) { - trmmRightLower(1., B, A, C); + trmm(1., B, A, UpLo::Lower, false, C); DoNotOptimize(A); DoNotOptimize(B); DoNotOptimize(C); diff --git a/include/blast/math/algorithm/Trmm.hpp b/include/blast/math/algorithm/Trmm.hpp index 01eafb67..f1d7ccb9 100644 --- a/include/blast/math/algorithm/Trmm.hpp +++ b/include/blast/math/algorithm/Trmm.hpp @@ -7,8 +7,12 @@ #include #include #include +#include #include #include +#include + +#include namespace blast @@ -67,159 +71,256 @@ namespace blast } } } - /// @brief C = alpha * A * B + C; A upper-triangular + + + /// @brief C = alpha * A * B; A upper- or lower-triangular. Matrix pointer arguments. /// - template - requires Matrix && Matrix && Matrix - && (StorageOrder_v == columnMajor) && (StorageOrder_v == columnMajor) - inline void trmmLeftUpper( - ST alpha, - MT1 const& A, MT2 const& B, - MT3& C) + /// See https://netlib.org/lapack/explore-html-3.6.1/d1/d54/group__double__blas__level3_gaf07edfbb2d2077687522652c9e283e1e.html + /// + /// @tparam MPA matrix pointer type for matrix A + /// @tparam MPB matrix pointer type for matrix B + /// @tparam MPC matrix pointer type for matrix C + /// + /// @param M the number of rows of B + /// @param N the number of columns of B + /// @param alpha the scalar alpha + /// @param A pointer to top left element of matrix A + /// @param uplo specifies whether the matrix A is an upper or lower triangular + /// @param diag specifies whether or not A is unit triangular + /// @param B pointer to top left element of matrix B + /// @param C pointer to top left element of matrix C + /// + template + requires MatrixPointer && MatrixPointer && MatrixPointer + && (StorageOrder_v == columnMajor) && (StorageOrder_v == columnMajor) + inline void trmm(size_t M, size_t N, ST alpha, MPA A, UpLo uplo, bool diag, MPB B, MPC C) { using ET = ST; size_t constexpr TILE_SIZE = TileSize_v; - size_t const M = rows(B); - size_t const N = columns(B); - - if (rows(A) != M || columns(A) != M) - throw std::invalid_argument {"Matrix sizes do not match"}; + if (diag) + BLAST_THROW_EXCEPTION(std::logic_error {"Unit-triangular matrices support not implemented in trmm()"}); - if (rows(C) != M || columns(C) != N) - throw std::invalid_argument {"Matrix sizes do not match"}; - - size_t i = 0; + if (uplo == UpLo::Upper) + { + size_t i = 0; - // i + 4 * TILE_SIZE != M is to improve performance in case when the remaining number of rows is 4 * TILE_SIZE: - // it is more efficient to apply 2 * TILE_SIZE kernel 2 times than 3 * TILE_SIZE + 1 * TILE_SIZE kernel. - for (; i + 2 * TILE_SIZE < M && i + 4 * TILE_SIZE != M; i += 3 * TILE_SIZE) - detail::trmmLeftUpper_backend<3 * TILE_SIZE, TILE_SIZE>( - M - i, N, alpha, ptr(A, i, i), ptr(B, i, 0), ptr(C, i, 0)); + // i + 4 * TILE_SIZE != M is to improve performance in case when the remaining number of rows is 4 * TILE_SIZE: + // it is more efficient to apply 2 * TILE_SIZE kernel 2 times than 3 * TILE_SIZE + 1 * TILE_SIZE kernel. + for (; i + 2 * TILE_SIZE < M && i + 4 * TILE_SIZE != M; i += 3 * TILE_SIZE) + detail::trmmLeftUpper_backend<3 * TILE_SIZE, TILE_SIZE>( + M - i, N, alpha, A(i, i), B(i, 0), C(i, 0)); - for (; i + 1 * TILE_SIZE < M; i += 2 * TILE_SIZE) - detail::trmmLeftUpper_backend<2 * TILE_SIZE, TILE_SIZE>( - M - i, N, alpha, ptr(A, i, i), ptr(B, i, 0), ptr(C, i, 0)); + for (; i + 1 * TILE_SIZE < M; i += 2 * TILE_SIZE) + detail::trmmLeftUpper_backend<2 * TILE_SIZE, TILE_SIZE>( + M - i, N, alpha, A(i, i), B(i, 0), C(i, 0)); - for (; i + 0 * TILE_SIZE < M; i += 1 * TILE_SIZE) - detail::trmmLeftUpper_backend<1 * TILE_SIZE, TILE_SIZE>( - M - i, N, alpha, ptr(A, i, i), ptr(B, i, 0), ptr(C, i, 0)); + for (; i + 0 * TILE_SIZE < M; i += 1 * TILE_SIZE) + detail::trmmLeftUpper_backend<1 * TILE_SIZE, TILE_SIZE>( + M - i, N, alpha, A(i, i), B(i, 0), C(i, 0)); + } + else + { + BLAST_THROW_EXCEPTION(std::logic_error {"Left product with lower-triangular matrices not implemented in trmm()"}); + } } - /// @brief C = alpha * B * A + C; A lower-triangular + /// @brief C = alpha * B * A; A upper- or lower-triangular. Matrix pointer arguments. /// - template - requires Matrix && Matrix && Matrix - && (StorageOrder_v == columnMajor) && (StorageOrder_v == columnMajor) - inline void trmmRightLower( - ET alpha, - MTB const& B, MTA const& A, - MTC& C) + /// See https://netlib.org/lapack/explore-html-3.6.1/d1/d54/group__double__blas__level3_gaf07edfbb2d2077687522652c9e283e1e.html + /// + /// @tparam MPB matrix pointer type for matrix B + /// @tparam MPA matrix pointer type for matrix A + /// @tparam MPC matrix pointer type for matrix C + /// + /// @param M the number of rows of B + /// @param N the number of columns of B + /// @param alpha the scalar alpha + /// @param B pointer to top left element of matrix B + /// @param A pointer to top left element of matrix A + /// @param uplo specifies whether the matrix A is an upper or lower triangular + /// @param diag specifies whether or not A is unit triangular + /// @param C pointer to top left element of matrix C + /// + template + requires MatrixPointer && MatrixPointer && MatrixPointer + && (StorageOrder_v == columnMajor) && (StorageOrder_v == columnMajor) + inline void trmm(size_t M, size_t N, ST alpha, MPB B, MPA A, UpLo uplo, bool diag, MPC C) { + using ET = ST; size_t constexpr TILE_SIZE = TileSize_v; - size_t const M = rows(B); - size_t const N = columns(B); + if (diag) + BLAST_THROW_EXCEPTION(std::logic_error {"Unit-triangular matrices support not implemented in trmm()"}); - if (rows(A) != N || columns(A) != N) - BLAZE_THROW_INVALID_ARGUMENT("Matrix sizes do not match"); + if (uplo == UpLo::Lower) + { + size_t j = 0; - if (rows(C) != M || columns(C) != N) - BLAZE_THROW_INVALID_ARGUMENT("Matrix sizes do not match"); + // Main part + for (; j + TILE_SIZE <= N; j += TILE_SIZE) + { + // size_t const K = N - j - TILE_SIZE; + size_t i = 0; - size_t j = 0; + // i + 4 * TILE_SIZE != M is to improve performance in case when the remaining number of rows is 4 * TILE_SIZE: + // it is more efficient to apply 2 * TILE_SIZE kernel 2 times than 3 * TILE_SIZE + 1 * TILE_SIZE kernel. + for (; i + 3 * TILE_SIZE <= M && i + 4 * TILE_SIZE != M; i += 3 * TILE_SIZE) + { + RegisterMatrix ker; + gemm(ker, N - j, alpha, B(i, j), A(j, j)); + /* + ker.trmmRightLower(alpha, ptr(B, i, j), ptr(A, j, j)); + ker.gemm(K, alpha, ptr(B, i, j + TILE_SIZE), ptr(A, j + TILE_SIZE, j)); + */ + ker.store(C(i, j)); + } - // Main part - for (; j + TILE_SIZE <= N; j += TILE_SIZE) - { - // size_t const K = N - j - TILE_SIZE; - size_t i = 0; + for (; i + 2 * TILE_SIZE <= M; i += 2 * TILE_SIZE) + { + RegisterMatrix ker; + gemm(ker, N - j, alpha, B(i, j), A(j, j)); + /* + ker.trmmRightLower(alpha, ptr(B, i, j), ptr(A, j, j)); + ker.gemm(K, alpha, ptr(B, i, j + TILE_SIZE), ptr(A, j + TILE_SIZE, j)); + */ + ker.store(C(i, j)); + } - // i + 4 * TILE_SIZE != M is to improve performance in case when the remaining number of rows is 4 * TILE_SIZE: - // it is more efficient to apply 2 * TILE_SIZE kernel 2 times than 3 * TILE_SIZE + 1 * TILE_SIZE kernel. - for (; i + 3 * TILE_SIZE <= M && i + 4 * TILE_SIZE != M; i += 3 * TILE_SIZE) - { - RegisterMatrix ker; - gemm(ker, N - j, alpha, ptr(B, i, j), ptr(A, j, j)); - /* - ker.trmmRightLower(alpha, ptr(B, i, j), ptr(A, j, j)); - ker.gemm(K, alpha, ptr(B, i, j + TILE_SIZE), ptr(A, j + TILE_SIZE, j)); - */ - ker.store(ptr(C, i, j)); - } + for (; i + 1 * TILE_SIZE <= M; i += 1 * TILE_SIZE) + { + RegisterMatrix ker; + gemm(ker, N - j, alpha, B(i, j), A(j, j)); + /* + ker.trmmRightLower(alpha, ptr(B, i, j), ptr(A, j, j)); + ker.gemm(K, alpha, ptr(B, i, j + TILE_SIZE), ptr(A, j + TILE_SIZE, j)); + */ + ker.store(C(i, j)); + } - for (; i + 2 * TILE_SIZE <= M; i += 2 * TILE_SIZE) - { - RegisterMatrix ker; - gemm(ker, N - j, alpha, ptr(B, i, j), ptr(A, j, j)); - /* - ker.trmmRightLower(alpha, ptr(B, i, j), ptr(A, j, j)); - ker.gemm(K, alpha, ptr(B, i, j + TILE_SIZE), ptr(A, j + TILE_SIZE, j)); - */ - ker.store(ptr(C, i, j)); + // Bottom side + if (i < M) + { + RegisterMatrix ker; + gemm(ker, N - j, alpha, B(i, j), A(j, j), M - i, ker.columns()); + /* + ker.trmmRightLower(alpha, ptr(B, i, j), ptr(A, j, j)); + ker.gemm(K, alpha, ptr(B, i, j + TILE_SIZE), ptr(A, j + TILE_SIZE, j), M - i, ker.columns()); + */ + ker.store(C(i, j), M - i, ker.columns()); + } } - for (; i + 1 * TILE_SIZE <= M; i += 1 * TILE_SIZE) - { - RegisterMatrix ker; - gemm(ker, N - j, alpha, ptr(B, i, j), ptr(A, j, j)); - /* - ker.trmmRightLower(alpha, ptr(B, i, j), ptr(A, j, j)); - ker.gemm(K, alpha, ptr(B, i, j + TILE_SIZE), ptr(A, j + TILE_SIZE, j)); - */ - ker.store(ptr(C, i, j)); - } - // Bottom side - if (i < M) + // Right side + if (j < N) { - RegisterMatrix ker; - gemm(ker, N - j, alpha, ptr(B, i, j), ptr(A, j, j), M - i, ker.columns()); - /* - ker.trmmRightLower(alpha, ptr(B, i, j), ptr(A, j, j)); - ker.gemm(K, alpha, ptr(B, i, j + TILE_SIZE), ptr(A, j + TILE_SIZE, j), M - i, ker.columns()); - */ - ker.store(ptr(C, i, j), M - i, ker.columns()); + size_t i = 0; + + // i + 4 * TILE_SIZE != M is to improve performance in case when the remaining number of rows is 4 * TILE_SIZE: + // it is more efficient to apply 2 * TILE_SIZE kernel 2 times than 3 * TILE_SIZE + 1 * TILE_SIZE kernel. + for (; i + 3 * TILE_SIZE <= M && i + 4 * TILE_SIZE != M; i += 3 * TILE_SIZE) + { + RegisterMatrix ker; + gemm(ker, N - j, alpha, B(i, j), A(j, j), ker.rows(), N - j); + ker.store(C(i, j), ker.rows(), N - j); + } + + for (; i + 2 * TILE_SIZE <= M; i += 2 * TILE_SIZE) + { + RegisterMatrix ker; + gemm(ker, N - j, alpha, B(i, j), A(j, j), ker.rows(), N - j); + ker.store(C(i, j), ker.rows(), N - j); + } + + for (; i + 1 * TILE_SIZE <= M; i += 1 * TILE_SIZE) + { + RegisterMatrix ker; + gemm(ker, N - j, alpha, B(i, j), A(j, j), ker.rows(), N - j); + ker.store(C(i, j), ker.rows(), N - j); + } + + // Bottom-right corner + if (i < M) + { + RegisterMatrix ker; + gemm(ker, N - j, alpha, B(i, j), A(j, j), M - i, N - j); + ker.store(C(i, j), M - i, N - j); + } } } + else + { + BLAST_THROW_EXCEPTION(std::logic_error {"Right product with upper-triangular matrices not implemented in trmm()"}); + } + } - // Right side - if (j < N) - { - size_t i = 0; + /// @brief C = alpha * A * B; A upper- or lower-triangular. Matrix arguments. + /// + /// See https://netlib.org/lapack/explore-html-3.6.1/d1/d54/group__double__blas__level3_gaf07edfbb2d2077687522652c9e283e1e.html + /// + /// @tparam MT1 matrix type for matrix A + /// @tparam MT2 matrix type for matrix B + /// @tparam MT3 matrix type for matrix C + /// + /// @param alpha the scalar alpha + /// @param A matrix A + /// @param uplo specifies whether the matrix A is an upper or lower triangular + /// @param diag specifies whether or not A is unit triangular + /// @param B matrix B + /// @param C matrix C + /// + template + requires Matrix && Matrix && Matrix + inline void trmm(ST alpha, MT1 const& A, UpLo uplo, bool diag, MT2 const& B, MT3& C) + { + using ET = ST; + size_t constexpr TILE_SIZE = TileSize_v; - // i + 4 * TILE_SIZE != M is to improve performance in case when the remaining number of rows is 4 * TILE_SIZE: - // it is more efficient to apply 2 * TILE_SIZE kernel 2 times than 3 * TILE_SIZE + 1 * TILE_SIZE kernel. - for (; i + 3 * TILE_SIZE <= M && i + 4 * TILE_SIZE != M; i += 3 * TILE_SIZE) - { - RegisterMatrix ker; - gemm(ker, N - j, alpha, ptr(B, i, j), ptr(A, j, j), ker.rows(), N - j); - ker.store(ptr(C, i, j), ker.rows(), N - j); - } + size_t const M = rows(B); + size_t const N = columns(B); - for (; i + 2 * TILE_SIZE <= M; i += 2 * TILE_SIZE) - { - RegisterMatrix ker; - gemm(ker, N - j, alpha, ptr(B, i, j), ptr(A, j, j), ker.rows(), N - j); - ker.store(ptr(C, i, j), ker.rows(), N - j); - } + if (rows(A) != M || columns(A) != M) + BLAST_THROW_EXCEPTION(std::invalid_argument {"Matrix sizes do not match"}); - for (; i + 1 * TILE_SIZE <= M; i += 1 * TILE_SIZE) - { - RegisterMatrix ker; - gemm(ker, N - j, alpha, ptr(B, i, j), ptr(A, j, j), ker.rows(), N - j); - ker.store(ptr(C, i, j), ker.rows(), N - j); - } + if (rows(C) != M || columns(C) != N) + BLAST_THROW_EXCEPTION(std::invalid_argument {"Matrix sizes do not match"}); - // Bottom-right corner - if (i < M) - { - RegisterMatrix ker; - gemm(ker, N - j, alpha, ptr(B, i, j), ptr(A, j, j), M - i, N - j); - ker.store(ptr(C, i, j), M - i, N - j); - } - } + trmm(M, N, alpha, ptr(A), uplo, diag, ptr(B), ptr(C)); + } + + + /// @brief C = alpha * B * A + C; A lower-triangular. Matrix arguments. + /// + /// See https://netlib.org/lapack/explore-html-3.6.1/d1/d54/group__double__blas__level3_gaf07edfbb2d2077687522652c9e283e1e.html + /// + /// @tparam MTB matrix type for matrix B + /// @tparam MTA matrix type for matrix A + /// @tparam MTC matrix type for matrix C + /// + /// @param alpha the scalar alpha + /// @param B matrix B + /// @param A matrix A + /// @param uplo specifies whether the matrix A is an upper or lower triangular + /// @param diag specifies whether or not A is unit triangular + /// @param C matrix C + /// + template + requires Matrix && Matrix && Matrix + && (StorageOrder_v == columnMajor) && (StorageOrder_v == columnMajor) + inline void trmm(ET alpha, MTB const& B, MTA const& A, UpLo uplo, bool diag, MTC& C) + { + size_t const M = rows(B); + size_t const N = columns(B); + + if (rows(A) != N || columns(A) != N) + BLAZE_THROW_INVALID_ARGUMENT("Matrix sizes do not match"); + + if (rows(C) != M || columns(C) != N) + BLAZE_THROW_INVALID_ARGUMENT("Matrix sizes do not match"); + + trmm(M, N, alpha, ptr(B), ptr(A), uplo, diag, ptr(C)); } } diff --git a/include/blast/util/Exception.hpp b/include/blast/util/Exception.hpp index 77143e1f..03608f0c 100644 --- a/include/blast/util/Exception.hpp +++ b/include/blast/util/Exception.hpp @@ -7,7 +7,4 @@ #include -#include - - #define BLAST_THROW_EXCEPTION BOOST_THROW_EXCEPTION diff --git a/test/blast/math/dense/TrmmTest.cpp b/test/blast/math/dense/TrmmTest.cpp index 7d755b03..e7e9b816 100644 --- a/test/blast/math/dense/TrmmTest.cpp +++ b/test/blast/math/dense/TrmmTest.cpp @@ -2,6 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +#include "blast/math/UpLo.hpp" #include #include #include @@ -34,7 +35,7 @@ namespace blast :: testing randomize(alpha); // Do trmm - trmmLeftUpper(alpha, A, B, C); + trmm(alpha, A, UpLo::Upper, false, B, C); DynamicMatrix C_ref(m, n); reference::trmm(alpha, A, UpLo::Upper, false, B, C_ref); @@ -66,7 +67,7 @@ namespace blast :: testing randomize(alpha); // Do trmm - trmmRightLower(alpha, B, A, C); + trmm(alpha, B, A, UpLo::Lower, false, C); DynamicMatrix C_ref(m, n); reference::trmm(alpha, B, A, UpLo::Lower, false, C_ref);