diff --git a/include/blast/math/dense/Trmm.hpp b/include/blast/math/dense/Trmm.hpp index c82bd67..f3bce70 100644 --- a/include/blast/math/dense/Trmm.hpp +++ b/include/blast/math/dense/Trmm.hpp @@ -4,40 +4,34 @@ #pragma once +#include #include #include -#include -#include -#include - -#include - namespace blast { /// @brief C = alpha * A * B + C; A upper-triangular /// - template + template + requires Matrix && Matrix && Matrix + && (StorageOrder_v == columnMajor) && (StorageOrder_v == columnMajor) inline void trmmLeftUpper( ST alpha, - DenseMatrix const& A, DenseMatrix const& B, - DenseMatrix& C) + MT1 const& A, MT2 const& B, + MT3& C) { - using ET = ElementType_t; + using ET = ST; size_t constexpr TILE_SIZE = TileSize_v; - BLAZE_CONSTRAINT_MUST_BE_SAME_TYPE(ElementType_t, ET); - BLAZE_CONSTRAINT_MUST_BE_SAME_TYPE(ElementType_t, ET); - size_t const M = rows(B); size_t const N = columns(B); if (rows(A) != M || columns(A) != M) - BLAZE_THROW_INVALID_ARGUMENT("Matrix sizes do not match"); + throw std::invalid_argument {"Matrix sizes do not match"}; if (rows(C) != M || columns(C) != N) - BLAZE_THROW_INVALID_ARGUMENT("Matrix sizes do not match"); + throw std::invalid_argument {"Matrix sizes do not match"}; size_t i = 0; @@ -45,32 +39,30 @@ namespace blast // it is more efficient to apply 2 * TILE_SIZE kernel 2 times than 3 * TILE_SIZE + 1 * TILE_SIZE kernel. for (; i + 2 * TILE_SIZE < M && i + 4 * TILE_SIZE != M; i += 3 * TILE_SIZE) trmmLeftUpper_backend<3 * TILE_SIZE, TILE_SIZE>( - M - i, N, alpha, ptr(*A, i, i), ptr(*B, i, 0), ptr(*C, i, 0)); + M - i, N, alpha, ptr(A, i, i), ptr(B, i, 0), ptr(C, i, 0)); for (; i + 1 * TILE_SIZE < M; i += 2 * TILE_SIZE) trmmLeftUpper_backend<2 * TILE_SIZE, TILE_SIZE>( - M - i, N, alpha, ptr(*A, i, i), ptr(*B, i, 0), ptr(*C, i, 0)); + M - i, N, alpha, ptr(A, i, i), ptr(B, i, 0), ptr(C, i, 0)); for (; i + 0 * TILE_SIZE < M; i += 1 * TILE_SIZE) trmmLeftUpper_backend<1 * TILE_SIZE, TILE_SIZE>( - M - i, N, alpha, ptr(*A, i, i), ptr(*B, i, 0), ptr(*C, i, 0)); + M - i, N, alpha, ptr(A, i, i), ptr(B, i, 0), ptr(C, i, 0)); } /// @brief C = alpha * B * A + C; A lower-triangular /// - template + template + requires Matrix && Matrix && Matrix + && (StorageOrder_v == columnMajor) && (StorageOrder_v == columnMajor) inline void trmmRightLower( ET alpha, - DenseMatrix const& B, DenseMatrix const& A, - DenseMatrix& C) + MTB const& B, MTA const& A, + MTC& C) { size_t constexpr TILE_SIZE = TileSize_v; - BLAZE_CONSTRAINT_MUST_BE_SAME_TYPE(ElementType_t, ET); - BLAZE_CONSTRAINT_MUST_BE_SAME_TYPE(ElementType_t, ET); - BLAZE_CONSTRAINT_MUST_BE_SAME_TYPE(ElementType_t, ET); - size_t const M = rows(B); size_t const N = columns(B); @@ -93,46 +85,46 @@ namespace blast for (; i + 3 * TILE_SIZE <= M && i + 4 * TILE_SIZE != M; i += 3 * TILE_SIZE) { RegisterMatrix ker; - gemm(ker, N - j, alpha, ptr(*B, i, j), ptr(*A, j, j)); + gemm(ker, N - j, alpha, ptr(B, i, j), ptr(A, j, j)); /* ker.trmmRightLower(alpha, ptr(B, i, j), ptr(A, j, j)); ker.gemm(K, alpha, ptr(B, i, j + TILE_SIZE), ptr(A, j + TILE_SIZE, j)); */ - ker.store(ptr(*C, i, j)); + ker.store(ptr(C, i, j)); } for (; i + 2 * TILE_SIZE <= M; i += 2 * TILE_SIZE) { RegisterMatrix ker; - gemm(ker, N - j, alpha, ptr(*B, i, j), ptr(*A, j, j)); + gemm(ker, N - j, alpha, ptr(B, i, j), ptr(A, j, j)); /* ker.trmmRightLower(alpha, ptr(B, i, j), ptr(A, j, j)); ker.gemm(K, alpha, ptr(B, i, j + TILE_SIZE), ptr(A, j + TILE_SIZE, j)); */ - ker.store(ptr(*C, i, j)); + ker.store(ptr(C, i, j)); } for (; i + 1 * TILE_SIZE <= M; i += 1 * TILE_SIZE) { RegisterMatrix ker; - gemm(ker, N - j, alpha, ptr(*B, i, j), ptr(*A, j, j)); + gemm(ker, N - j, alpha, ptr(B, i, j), ptr(A, j, j)); /* ker.trmmRightLower(alpha, ptr(B, i, j), ptr(A, j, j)); ker.gemm(K, alpha, ptr(B, i, j + TILE_SIZE), ptr(A, j + TILE_SIZE, j)); */ - ker.store(ptr(*C, i, j)); + ker.store(ptr(C, i, j)); } // Bottom side if (i < M) { RegisterMatrix ker; - gemm(ker, N - j, alpha, ptr(*B, i, j), ptr(*A, j, j), M - i, ker.columns()); + gemm(ker, N - j, alpha, ptr(B, i, j), ptr(A, j, j), M - i, ker.columns()); /* ker.trmmRightLower(alpha, ptr(B, i, j), ptr(A, j, j)); ker.gemm(K, alpha, ptr(B, i, j + TILE_SIZE), ptr(A, j + TILE_SIZE, j), M - i, ker.columns()); */ - ker.store(ptr(*C, i, j), M - i, ker.columns()); + ker.store(ptr(C, i, j), M - i, ker.columns()); } } @@ -147,30 +139,30 @@ namespace blast for (; i + 3 * TILE_SIZE <= M && i + 4 * TILE_SIZE != M; i += 3 * TILE_SIZE) { RegisterMatrix ker; - gemm(ker, N - j, alpha, ptr(*B, i, j), ptr(*A, j, j), ker.rows(), N - j); - ker.store(ptr(*C, i, j), ker.rows(), N - j); + gemm(ker, N - j, alpha, ptr(B, i, j), ptr(A, j, j), ker.rows(), N - j); + ker.store(ptr(C, i, j), ker.rows(), N - j); } for (; i + 2 * TILE_SIZE <= M; i += 2 * TILE_SIZE) { RegisterMatrix ker; - gemm(ker, N - j, alpha, ptr(*B, i, j), ptr(*A, j, j), ker.rows(), N - j); - ker.store(ptr(*C, i, j), ker.rows(), N - j); + gemm(ker, N - j, alpha, ptr(B, i, j), ptr(A, j, j), ker.rows(), N - j); + ker.store(ptr(C, i, j), ker.rows(), N - j); } for (; i + 1 * TILE_SIZE <= M; i += 1 * TILE_SIZE) { RegisterMatrix ker; - gemm(ker, N - j, alpha, ptr(*B, i, j), ptr(*A, j, j), ker.rows(), N - j); - ker.store(ptr(*C, i, j), ker.rows(), N - j); + gemm(ker, N - j, alpha, ptr(B, i, j), ptr(A, j, j), ker.rows(), N - j); + ker.store(ptr(C, i, j), ker.rows(), N - j); } // Bottom-right corner if (i < M) { RegisterMatrix ker; - gemm(ker, N - j, alpha, ptr(*B, i, j), ptr(*A, j, j), M - i, N - j); - ker.store(ptr(*C, i, j), M - i, N - j); + gemm(ker, N - j, alpha, ptr(B, i, j), ptr(A, j, j), M - i, N - j); + ker.store(ptr(C, i, j), M - i, N - j); } } } diff --git a/include/blast/math/dense/TrmmBackend.hpp b/include/blast/math/dense/TrmmBackend.hpp index 07d6b82..00ed745 100644 --- a/include/blast/math/dense/TrmmBackend.hpp +++ b/include/blast/math/dense/TrmmBackend.hpp @@ -17,7 +17,7 @@ namespace blast BLAZE_ALWAYS_INLINE void trmmLeftUpper_backend(size_t M, size_t N, T alpha, P1 a, P2 b, P3 c) { size_t constexpr TILE_SIZE = TileSize_v; - BLAZE_STATIC_ASSERT(KM % TILE_SIZE == 0); + static_assert(KM % TILE_SIZE == 0); RegisterMatrix ker; diff --git a/include/blast/math/reference/Trmm.hpp b/include/blast/math/reference/Trmm.hpp new file mode 100644 index 0000000..0bc9926 --- /dev/null +++ b/include/blast/math/reference/Trmm.hpp @@ -0,0 +1,221 @@ +// Copyright (c) 2019-2024 Mikhail Katliar All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +#pragma once + +#include +#include +#include + + +namespace blast :: reference +{ + /** + * @brief Reference implementation of left triangular matrix multiplication. + * + * Performs the matrix-matrix operation + * + * C := alpha * A * B + * + * where alpha is a scalar, B is an m by n matrix, A is a unit, or + * non-unit, upper or lower triangular matrix. + * + * LAPACK reference: https://netlib.org/lapack/explore-html-3.6.1/d1/d54/group__double__blas__level3_gaf07edfbb2d2077687522652c9e283e1e.html + * + * @tparam Real real number type + * @tparam MPA matrix pointer type for the matrix @a A + * @tparam MPB matrix pointer type for the matrix @a B + * @tparam MPC matrix pointer type for the matrix @a C + * + * @param m number of rows in @a B and @a C + * @param n number of columns in @a B and @a C + * @param alpha scalar multiplier + * @param A pointer to a matrix of dimension ( @a m, @a m ). Depending on the value of @a uplo, the + * upper (lower) triangular part of @a A must contain the upper (lower) triangular matrix + * and the strictly lower (upper) triangular part of @a A is not referenced. When @a diag == true, the diagonal elements of + * @a A are not referenced either, but are assumed to be unity. + * @param uplo specifies whether the matrix @a A is an upper or lower triangular matrix + * @param diag specifies whether or not @a A is unit triangular + * @param B pointer to a matrix of dimension ( @a m, @a n ). + * @param C pointer to a matrix of dimension ( @a m, @a n ) for the result. Can be equal to @a B. + */ + template + requires MatrixPointer && MatrixPointer && MatrixPointer + inline void trmm(size_t m, size_t n, Real alpha, MPA A, UpLo uplo, bool diag, MPB B, MPC C) + { + for (size_t j = 0; j < n; ++j) + { + if (uplo == UpLo::Upper) + { + for (size_t i = 0; i < m; ++i) + { + Real v {}; + for (size_t k = i; k < m; ++k) + v += (diag && k == i) ? *(~B)(k, j) : *(~A)(i, k) * *(~B)(k, j); + + *(~C)(i, j) = alpha * v; + } + } + else + { + for (size_t i = m; i-- > 0; ) + { + Real v {}; + for (size_t k = 0; k <= i; ++k) + v += (diag && k == i) ? *(~B)(k, j) : *(~A)(i, k) * *(~B)(k, j); + + *(~C)(i, j) = alpha * v; + } + } + } + } + + + /** + * @brief Reference implementation of right triangular matrix multiplication. + * + * Performs the matrix-matrix operation + * + * C := alpha * B * A + * + * where alpha is a scalar, B is an m by n matrix, A is a unit, or + * non-unit, upper or lower triangular matrix. + * + * LAPACK reference: https://netlib.org/lapack/explore-html-3.6.1/d1/d54/group__double__blas__level3_gaf07edfbb2d2077687522652c9e283e1e.html + * + * @tparam Real real number type + * @tparam MPB matrix pointer type for the matrix @a B + * @tparam MPA matrix pointer type for the matrix @a A + * @tparam MPC matrix pointer type for the matrix @a C + * + * @param m number of rows in @a B and @a C + * @param n number of columns in @a B and @a C + * @param alpha scalar multiplier + * @param B pointer to a matrix of dimension ( @a m, @a n ). + * @param A pointer to a matrix of dimension ( @a n, @a n ). Depending on the value of @a uplo, the + * upper (lower) triangular part of @a A must contain the upper (lower) triangular matrix + * and the strictly lower (upper) triangular part of @a A is not referenced. When @a diag == true, the diagonal elements of + * @a A are not referenced either, but are assumed to be unity. + * @param uplo specifies whether the matrix @a A is an upper or lower triangular matrix + * @param diag specifies whether or not @a A is unit triangular + * @param C pointer to a matrix of dimension ( @a m, @a n ) for the result. Can be equal to @a B. + */ + template + requires MatrixPointer && MatrixPointer && MatrixPointer + inline void trmm(size_t m, size_t n, Real alpha, MPB B, MPA A, UpLo uplo, bool diag, MPC C) + { + for (size_t i = 0; i < m; ++i) + { + if (uplo == UpLo::Lower) + { + for (size_t j = 0; j < n; ++j) + { + Real v {}; + for (size_t k = j; k < n; ++k) + v += (diag && k == j) ? *(~B)(i, k) : *(~B)(i, k) * *(~A)(k, j); + + *(~C)(i, j) = alpha * v; + } + } + else + { + for (size_t j = n; j-- > 0; ) + { + Real v {}; + for (size_t k = 0; k <= j; ++k) + v += (diag && k == j) ? *(~B)(i, k) : *(~B)(i, k) * *(~A)(k, j); + + *(~C)(i, j) = alpha * v; + } + } + } + } + + + /** + * @brief Reference implementation of left triangular matrix multiplication. + * + * Performs the matrix-matrix operation + * + * C := alpha * A * B + * + * where alpha is a scalar, B is an m by n matrix, A is a unit, or + * non-unit, upper or lower triangular matrix. + * + * LAPACK reference: https://netlib.org/lapack/explore-html-3.6.1/d1/d54/group__double__blas__level3_gaf07edfbb2d2077687522652c9e283e1e.html + * + * @tparam Real real number type + * @tparam MTA matrix type for the matrix @a A + * @tparam MTB matrix type for the matrix @a B + * @tparam MTC matrix type for the matrix @a C + * + * @param alpha scalar multiplier + * @param A a matrix of dimension (m, m). Depending on the value of @a uplo, the + * upper (lower) triangular part of @a A must contain the upper (lower) triangular matrix + * and the strictly lower (upper) triangular part of @a A is not referenced. When @a diag == true, the diagonal elements of + * @a A are not referenced either, but are assumed to be unity. + * @param uplo specifies whether the matrix @a A is an upper or lower triangular matrix + * @param diag specifies whether or not @a A is unit triangular + * @param B a matrix of dimension (m, n). + * @param C a matrix of dimension (m, n) for the result. Can be the same matrix as @a B. + * + * @throw @a std::invalid_argument if matrix sizes are inconsistent + */ + template + requires Matrix && Matrix && Matrix + inline void trmm(Real alpha, MTA const& A, UpLo uplo, bool diag, MTB const& B, MTC& C) + { + size_t const m = rows(B); + size_t const n = columns(B); + + if (rows(A) != m || columns(A) != m || + rows(C) != m || columns(C) != n) + throw std::invalid_argument {"Inconsistent matrix sizes"}; + + trmm(m, n, alpha, ptr(A), uplo, diag, ptr(B), ptr(C)); + } + + + /** + * @brief Reference implementation of right triangular matrix multiplication. + * + * Performs the matrix-matrix operation + * + * C := alpha * B * A + * + * where alpha is a scalar, B is an m by n matrix, A is a unit, or + * non-unit, upper or lower triangular matrix. + * + * LAPACK reference: https://netlib.org/lapack/explore-html-3.6.1/d1/d54/group__double__blas__level3_gaf07edfbb2d2077687522652c9e283e1e.html + * + * @tparam Real real number type + * @tparam MTB matrix type for the matrix @a B + * @tparam MTA matrix type for the matrix @a A + * @tparam MTC matrix type for the matrix @a C + * + * @param alpha scalar multiplier + * @param B a matrix of dimension (m, n). + * @param A a matrix of dimension (m, m). Depending on the value of @a uplo, the + * upper (lower) triangular part of @a A must contain the upper (lower) triangular matrix + * and the strictly lower (upper) triangular part of @a A is not referenced. When @a diag == true, the diagonal elements of + * @a A are not referenced either, but are assumed to be unity. + * @param uplo specifies whether the matrix @a A is an upper or lower triangular matrix + * @param diag specifies whether or not @a A is unit triangular + * @param C a matrix of dimension (m, n) for the result. Can be the same matrix as @a B. + * + * @throw @a std::invalid_argument if matrix sizes are inconsistent + */ + template + requires Matrix && Matrix && Matrix + inline void trmm(Real alpha, MTB const& B, MTA const& A, UpLo uplo, bool diag, MTC& C) + { + size_t const m = rows(B); + size_t const n = columns(B); + + if (rows(A) != n || columns(A) != n || + rows(C) != m || columns(C) != n) + throw std::invalid_argument {"Inconsistent matrix sizes"}; + + trmm(m, n, alpha, ptr(B), ptr(A), uplo, diag, ptr(C)); + } +} diff --git a/test/blast/math/dense/TrmmTest.cpp b/test/blast/math/dense/TrmmTest.cpp index 04005df..cd5fc30 100644 --- a/test/blast/math/dense/TrmmTest.cpp +++ b/test/blast/math/dense/TrmmTest.cpp @@ -4,7 +4,8 @@ #include #include -#include +#include +#include #include #include @@ -35,7 +36,9 @@ namespace blast :: testing // Do trmm trmmLeftUpper(alpha, A, B, C); - BLAST_ASSERT_APPROX_EQ(C, evaluate(alpha * A * B), 1e-10, 1e-10) + DynamicMatrix C_ref(m, n); + reference::trmm(alpha, A, UpLo::Upper, false, B, C_ref); + BLAST_ASSERT_APPROX_EQ(C, C_ref, 1e-10, 1e-10) << "trmm error at size m,n=" << m << "," << n; } } @@ -65,7 +68,9 @@ namespace blast :: testing // Do trmm trmmRightLower(alpha, B, A, C); - BLAST_ASSERT_APPROX_EQ(C, evaluate(alpha * B * A), 1e-10, 1e-10) + DynamicMatrix C_ref(m, n); + reference::trmm(alpha, B, A, UpLo::Lower, false, C_ref); + BLAST_ASSERT_APPROX_EQ(C, C_ref, 1e-10, 1e-10) << "trmm error at size m,n=" << m << "," << n; } }