From 4ffbbf63382ac59fa3aa8d7c1d18f42feae62cbe Mon Sep 17 00:00:00 2001 From: Mikhail Katliar Date: Tue, 1 Oct 2024 13:52:20 +0200 Subject: [PATCH 1/4] Added RegisterMatrixTest.testAxpy --- include/blast/math/reference/Axpy.hpp | 75 +++++++++++++++++++++ test/blast/math/simd/RegisterMatrixTest.cpp | 44 +++++++++--- 2 files changed, 111 insertions(+), 8 deletions(-) create mode 100644 include/blast/math/reference/Axpy.hpp diff --git a/include/blast/math/reference/Axpy.hpp b/include/blast/math/reference/Axpy.hpp new file mode 100644 index 0000000..8c6b375 --- /dev/null +++ b/include/blast/math/reference/Axpy.hpp @@ -0,0 +1,75 @@ +// Copyright (c) 2019-2020 Mikhail Katliar All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#pragma once + +#include +#include +#include +#include +#include + + +namespace blast :: reference +{ + /** + * @brief Constant times matrix plus matrix, matrix pointer arguments + * + * C := alpha*A + B + * + * where alpha is a scalar, and A, B and C are M by N matrices + * + * @tparam ST scalar type for @a alpha + * @tparam MPA matrix pointer type for @a A + * @tparam MPB matrix pointer type for @a B + * @tparam MPC matrix pointer type for @a C + * + * @param M the number of rows of the matrices A, B, and C. + * @param N the number of columns of the matrices A, B and C. + * @param alpha the scalar alpha + * @param A pointer to the top left element of matrix A + * @param B pointer to the top left element of matrix B + * @param C pointer to the top left element of matrix C + */ + template + inline void axpy(size_t M, size_t N, ST alpha, MPA A, MPB B, MPC C) + { + for (size_t i = 0; i < M; ++i) + for (size_t j = 0; j < N; ++j) + *(~C)(i, j) = alpha * *(~A)(i, j) + *(~B)(i, j); + } + + + /** + * @brief Constant times matrix plus matrix, matrix arguments + * + * C := alpha*A + B + * + * where alpha is a scalar, and A, B and C are M by N matrices + * + * @tparam ST scalar type for @a alpha + * @tparam MTA matrix type for @a A + * @tparam MTB matrix type for @a B + * @tparam MTC matrix type for @a C + * + * @param alpha the scalar alpha + * @param A pointer to the top left element of matrix A + * @param B pointer to the top left element of matrix B + * @param C pointer to the top left element of matrix C + * + * @throw @a std::invalid_argument if matrix sizes are not consistent + */ + template + inline void axpy(ST alpha, MTA const& A, MTB const& B, MTC& C) + { + size_t const M = rows(C); + size_t const N = columns(C); + + if (rows(A) != M || columns(A) != N || + rows(B) != M || columns(B) != N) + BLAST_THROW_EXCEPTION(std::invalid_argument {"Inconsistent matrix sizes"}); + + reference::axpy(M, N, alpha, ptr(A), ptr(B), ptr(C)); + } +} diff --git a/test/blast/math/simd/RegisterMatrixTest.cpp b/test/blast/math/simd/RegisterMatrixTest.cpp index ab1ef32..3fd1be5 100644 --- a/test/blast/math/simd/RegisterMatrixTest.cpp +++ b/test/blast/math/simd/RegisterMatrixTest.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include #include @@ -28,14 +29,18 @@ namespace blast :: testing using MyTypes = Types< - RegisterMatrix, - RegisterMatrix, - RegisterMatrix, - RegisterMatrix, - RegisterMatrix, - RegisterMatrix, - RegisterMatrix, - RegisterMatrix + RegisterMatrix, 4, columnMajor>, + RegisterMatrix, 2, columnMajor>, + RegisterMatrix, 1, columnMajor>, + RegisterMatrix, 4, columnMajor>, + RegisterMatrix, 2, columnMajor>, + RegisterMatrix, 1, columnMajor>, + RegisterMatrix, 4, columnMajor>, + RegisterMatrix, 2, columnMajor>, + RegisterMatrix, 1, columnMajor>, + RegisterMatrix, 4, columnMajor>, + RegisterMatrix, 4, columnMajor>, + RegisterMatrix, 4, columnMajor> >; @@ -629,4 +634,27 @@ namespace blast :: testing // TODO: should be strictly equal? BLAST_ASSERT_APPROX_EQ(ker, alpha * B * A, absTol(), relTol()); } + + + TYPED_TEST(RegisterMatrixTest, testAxpy) + { + using RM = TypeParam; + using ET = ElementType_t; + + StaticMatrix A, B; + randomize(A); + randomize(B); + + ET alpha {}; + randomize(alpha); + + RM ker; + ker.load(ptr(B)); + ker.axpy(alpha, ptr(A), ker.rows(), ker.columns()); + + StaticMatrix C; + reference::axpy(alpha, A, B, C); + + EXPECT_EQ(ker, C); + } } From 94d903c1eb2637cc994cca916e3b8c9fbd96fcda Mon Sep 17 00:00:00 2001 From: Mikhail Katliar Date: Tue, 1 Oct 2024 14:26:13 +0200 Subject: [PATCH 2/4] Fixed bug in RegisterMatrix::axpy() --- include/blast/math/reference/Axpy.hpp | 33 +++++++++++++-- .../math/register_matrix/RegisterMatrix.hpp | 2 +- test/blast/math/simd/RegisterMatrixTest.cpp | 40 ++++++++++++++++++- 3 files changed, 70 insertions(+), 5 deletions(-) diff --git a/include/blast/math/reference/Axpy.hpp b/include/blast/math/reference/Axpy.hpp index 8c6b375..90661c0 100644 --- a/include/blast/math/reference/Axpy.hpp +++ b/include/blast/math/reference/Axpy.hpp @@ -54,9 +54,9 @@ namespace blast :: reference * @tparam MTC matrix type for @a C * * @param alpha the scalar alpha - * @param A pointer to the top left element of matrix A - * @param B pointer to the top left element of matrix B - * @param C pointer to the top left element of matrix C + * @param A matrix A + * @param B matrix B + * @param C matrix C * * @throw @a std::invalid_argument if matrix sizes are not consistent */ @@ -72,4 +72,31 @@ namespace blast :: reference reference::axpy(M, N, alpha, ptr(A), ptr(B), ptr(C)); } + + + /** + * @brief Constant times matrix plus matrix, matrix arguments, rvalue reference output argument + * + * C := alpha*A + B + * + * where alpha is a scalar, and A, B and C are M by N matrices + * + * @tparam ST scalar type for @a alpha + * @tparam MTA matrix type for @a A + * @tparam MTB matrix type for @a B + * @tparam MTC matrix view type for @a C + * + * @param alpha the scalar alpha + * @param A matrix A + * @param B matrix B + * @param C matrix view C + * + * @throw @a std::invalid_argument if matrix sizes are not consistent + */ + template + requires IsView_v + inline void axpy(ST alpha, MTA const& A, MTB const& B, MTC&& C) + { + reference::axpy(alpha, A, B, C); + } } diff --git a/include/blast/math/register_matrix/RegisterMatrix.hpp b/include/blast/math/register_matrix/RegisterMatrix.hpp index 4eb935d..ae1a6da 100644 --- a/include/blast/math/register_matrix/RegisterMatrix.hpp +++ b/include/blast/math/register_matrix/RegisterMatrix.hpp @@ -154,7 +154,7 @@ namespace blast #pragma unroll for (size_t j = 0; j < N; ++j) if (j < n) #pragma unroll - for (size_t i = 0; i < RM; ++i) if (i * RM < m) + for (size_t i = 0; i < RM; ++i) if (SS * i < m) v_[i][j] = fmadd(beta_simd, a(SS * i, j).load(), v_[i][j]); } diff --git a/test/blast/math/simd/RegisterMatrixTest.cpp b/test/blast/math/simd/RegisterMatrixTest.cpp index 3fd1be5..19aee26 100644 --- a/test/blast/math/simd/RegisterMatrixTest.cpp +++ b/test/blast/math/simd/RegisterMatrixTest.cpp @@ -650,11 +650,49 @@ namespace blast :: testing RM ker; ker.load(ptr(B)); - ker.axpy(alpha, ptr(A), ker.rows(), ker.columns()); + ker.axpy(alpha, ptr(A)); StaticMatrix C; reference::axpy(alpha, A, B, C); EXPECT_EQ(ker, C); } + + + TYPED_TEST(RegisterMatrixTest, testAxpyWithSize) + { + using RM = TypeParam; + using ET = ElementType_t; + + StaticMatrix A, B; + randomize(A); + randomize(B); + + ET alpha {}; + randomize(alpha); + + StaticMatrix C; + + for (size_t m = 0; m < RM::rows(); ++m) + { + for (size_t n = 0; n < RM::columns(); ++n) + { + RM ker; + ker.load(ptr(B)); + ker.axpy(alpha, ptr(A), ker.rows(), ker.columns()); + + reference::axpy(alpha, + submatrix(A, 0, 0, m, n), + submatrix(B, 0, 0, m, n), + submatrix(C, 0, 0, m, n) + ); + + for (size_t i = 0; i < m; ++i) + for (size_t j = 0; j < n; ++j) + ASSERT_EQ(ker(i, j), C(i, j)) + << "element mismatch at (" << i << ", " << j << "), " + << "axpy() size = " << m << "x" << n; + } + } + } } From eb05f23db7aa98a0f16d20a87453978fcd77a7ca Mon Sep 17 00:00:00 2001 From: Mikhail Katliar Date: Tue, 1 Oct 2024 14:56:59 +0200 Subject: [PATCH 3/4] gemm() tests passing on ARM --- include/blast/math/algorithm/Gemm.hpp | 2 +- include/blast/math/reference/Gemm.hpp | 2 +- test/blast/math/dense/GemmTest.cpp | 49 +++++++++++++++------------ 3 files changed, 30 insertions(+), 23 deletions(-) diff --git a/include/blast/math/algorithm/Gemm.hpp b/include/blast/math/algorithm/Gemm.hpp index 77bbe58..108e327 100644 --- a/include/blast/math/algorithm/Gemm.hpp +++ b/include/blast/math/algorithm/Gemm.hpp @@ -93,7 +93,7 @@ namespace blast if (rows(D) != M || columns(D) != N) BLAST_THROW_EXCEPTION(std::invalid_argument {"Matrix sizes do not match"}); - gemm(M, N, K, alpha, ptr(*A), ptr(*B), beta, ptr(*C), ptr(*D)); + gemm(M, N, K, alpha, ptr(A), ptr(B), beta, ptr(C), ptr(D)); } diff --git a/include/blast/math/reference/Gemm.hpp b/include/blast/math/reference/Gemm.hpp index 671656a..329469c 100644 --- a/include/blast/math/reference/Gemm.hpp +++ b/include/blast/math/reference/Gemm.hpp @@ -88,6 +88,6 @@ namespace blast :: reference if (rows(D) != M || columns(D) != N) BLAST_THROW_EXCEPTION(std::invalid_argument {"Matrix sizes do not match"}); - gemm(M, N, K, alpha, ptr(A), ptr(B), beta, ptr(C), ptr(D)); + reference::gemm(M, N, K, alpha, ptr(A), ptr(B), beta, ptr(C), ptr(D)); } } diff --git a/test/blast/math/dense/GemmTest.cpp b/test/blast/math/dense/GemmTest.cpp index d0310eb..36ebc17 100644 --- a/test/blast/math/dense/GemmTest.cpp +++ b/test/blast/math/dense/GemmTest.cpp @@ -5,11 +5,12 @@ #define BLAST_USER_ASSERTION 1 #include - -#include #include +#include +#include +#include -#include +#include namespace blast :: testing @@ -29,22 +30,25 @@ namespace blast :: testing for (size_t n = 1; n <= 20; n += 1) for (size_t k = 1; k <= 20; ++k) { - // Init Blaze matrices + // Init matrices // - blaze::DynamicMatrix A(m, k), C(m, n), D(m, n); - blaze::DynamicMatrix B(k, n); + DynamicMatrix A(m, k), C(m, n), D(m, n); + DynamicMatrix B(k, n); randomize(A); randomize(B); randomize(C); Real alpha {}, beta {}; - blaze::randomize(alpha); - blaze::randomize(beta); + randomize(alpha); + randomize(beta); - /// Do gemm + // Do gemm gemm(alpha, A, B, beta, C, D); - BLAST_ASSERT_APPROX_EQ(D, evaluate(beta * C + alpha * A * B), 1e-10, 1e-10) + DynamicMatrix D_ref(m, n); + reference::gemm(alpha, A, B, beta, C, D_ref); + + BLAST_ASSERT_APPROX_EQ(D, D_ref, 1e-10, 1e-10) << "gemm error at size m,n,k=" << m << "," << n << "," << k; } } @@ -54,31 +58,34 @@ namespace blast :: testing void testUnalignedImpl() { size_t constexpr S_MAX = 20; - blaze::DynamicMatrix AA(S_MAX, S_MAX), CC(S_MAX, S_MAX), DD(S_MAX, S_MAX); - blaze::DynamicMatrix BB(S_MAX, S_MAX); + DynamicMatrix AA(S_MAX, S_MAX), CC(S_MAX, S_MAX), DD(S_MAX, S_MAX); + DynamicMatrix BB(S_MAX, S_MAX); for (size_t m = 1; m <= S_MAX; m += 1) for (size_t n = 1; n <= S_MAX; n += 1) for (size_t k = 1; k <= S_MAX; ++k) { - // Init Blaze matrices + // Init matrices // - auto A = submatrix(AA, rows(AA) - m, columns(AA) - k, m, k); - auto C = submatrix(CC, rows(CC) - m, columns(CC) - n, m, n); - auto D = submatrix(DD, rows(DD) - m, columns(DD) - n, m, n); - auto B = submatrix(BB, rows(BB) - k, columns(BB) - n, k, n); + auto A = submatrix(AA, rows(AA) - m, columns(AA) - k, m, k); + auto C = submatrix(CC, rows(CC) - m, columns(CC) - n, m, n); + auto D = submatrix(DD, rows(DD) - m, columns(DD) - n, m, n); + auto B = submatrix(BB, rows(BB) - k, columns(BB) - n, k, n); randomize(A); randomize(B); randomize(C); Real alpha {}, beta {}; - blaze::randomize(alpha); - blaze::randomize(beta); + randomize(alpha); + randomize(beta); - /// Do gemm + // Do gemm gemm(alpha, A, B, beta, C, D); - BLAST_ASSERT_APPROX_EQ(D, evaluate(beta * C + alpha * A * B), 1e-10, 1e-10) + DynamicMatrix D_ref(m, n); + reference::gemm(alpha, A, B, beta, C, D_ref); + + BLAST_ASSERT_APPROX_EQ(D, D_ref, 1e-10, 1e-10) << "gemm error at size m,n,k=" << m << "," << n << "," << k; } } From 0822b47a3893f400f5b602edb1620c32a1cfddea Mon Sep 17 00:00:00 2001 From: Mikhail Katliar Date: Wed, 2 Oct 2024 09:21:16 +0200 Subject: [PATCH 4/4] Removed superfluous comments --- test/blast/math/dense/GemmTest.cpp | 4 ---- 1 file changed, 4 deletions(-) diff --git a/test/blast/math/dense/GemmTest.cpp b/test/blast/math/dense/GemmTest.cpp index 36ebc17..926fe28 100644 --- a/test/blast/math/dense/GemmTest.cpp +++ b/test/blast/math/dense/GemmTest.cpp @@ -30,8 +30,6 @@ namespace blast :: testing for (size_t n = 1; n <= 20; n += 1) for (size_t k = 1; k <= 20; ++k) { - // Init matrices - // DynamicMatrix A(m, k), C(m, n), D(m, n); DynamicMatrix B(k, n); randomize(A); @@ -65,8 +63,6 @@ namespace blast :: testing for (size_t n = 1; n <= S_MAX; n += 1) for (size_t k = 1; k <= S_MAX; ++k) { - // Init matrices - // auto A = submatrix(AA, rows(AA) - m, columns(AA) - k, m, k); auto C = submatrix(CC, rows(CC) - m, columns(CC) - n, m, n); auto D = submatrix(DD, rows(DD) - m, columns(DD) - n, m, n);