Skip to content

Commit

Permalink
WIP: get rid of Blaze in RegisterMatrixTest
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikhail Katliar committed Aug 8, 2024
1 parent 9ce75c6 commit b5d18dd
Show file tree
Hide file tree
Showing 7 changed files with 279 additions and 58 deletions.
176 changes: 176 additions & 0 deletions include/blast/math/dense/StaticMatrix.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
// Copyright (c) 2019-2020 Mikhail Katliar All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

#pragma once

#include <blast/math/Forward.hpp>
#include <blast/math/StorageOrder.hpp>
#include <blast/math/Simd.hpp>
#include <blast/math/TypeTraits.hpp>
#include <blast/system/CacheLine.hpp>
#include <blast/util/NextMultiple.hpp>
#include <blast/util/Types.hpp>

#include <initializer_list>


namespace blast
{
/// @brief Matrix with statically defined size.
///
/// @tparam T element type of the matrix
/// @tparam M number of rows
/// @tparam N number of columns
/// @tparam SO storage order
template <typename T, size_t M, size_t N, bool SO>
class StaticMatrix
{
public:
using ElementType = T;
static bool constexpr storageOrder = SO;


StaticMatrix() noexcept
{
// Initialize padding elements to 0 to prevent denorms in calculations.
// Denorms can significantly impair performance, see https://github.com/giaf/blasfeo/issues/103
std::fill_n(v_, capacity_, T {});
}


StaticMatrix(T const& v) noexcept
{
std::fill_n(v_, capacity_, v);
}


constexpr StaticMatrix(std::initializer_list<std::initializer_list<T>> list)
{
std::fill_n(v_, capacity_, T {});

if (list.size() != M || determineColumns(list) > N)
throw std::invalid_argument {"Invalid setup of static matrix"};

size_t i = 0;

for (auto const& row : list)
{
size_t j = 0;

for (const auto& element : row)
{
v_[elementIndex(i, j)] = element;
++j;
}

++i;
}
}


StaticMatrix& operator=(T val) noexcept
{
for (size_t i = 0; i < M; ++i)
for (size_t j = 0; j < N; ++j)
(*this)(i, j) = val;

return *this;
}


constexpr T const& operator()(size_t i, size_t j) const noexcept
{
return v_[elementIndex(i, j)];
}


constexpr T& operator()(size_t i, size_t j)
{
return v_[elementIndex(i, j)];
}


static size_t constexpr rows() noexcept
{
return M;
}


static size_t constexpr columns() noexcept
{
return N;
}


static size_t constexpr spacing() noexcept
{
return spacing_;
}


T * data() noexcept
{
return v_;
}


T const * data() const noexcept
{
return v_;
}


private:
static size_t constexpr spacing_ = nextMultiple(SO == columnMajor ? M : N, SimdSize_v<T>);
static size_t constexpr capacity_ = spacing_ * (SO == columnMajor ? N : M);

// Alignment of the data elements.
static size_t constexpr alignment_ = CACHE_LINE_SIZE;

// Aligned element storage.
alignas(alignment_) T v_[capacity_];


size_t elementIndex(size_t i, size_t j) const
{
return SO == columnMajor ? i + spacing_ * j : spacing_ * i + j;
}
};


template <typename T, size_t M, size_t N, bool SO>
inline size_t constexpr rows(StaticMatrix<T, M, N, SO> const& m) noexcept
{
return m.rows();
}


template <typename T, size_t M, size_t N, bool SO>
inline size_t constexpr columns(StaticMatrix<T, M, N, SO> const& m) noexcept
{
return m.columns();
}


template <typename T, size_t M, size_t N, bool SO>
inline constexpr T * data(StaticMatrix<T, M, N, SO>& m) noexcept
{
return m.data();
}


template <typename T, size_t M, size_t N, bool SO>
inline constexpr T const * data(StaticMatrix<T, M, N, SO> const& m) noexcept
{
return m.data();
}


template <typename T, size_t M, size_t N, bool SO>
struct IsDenseMatrix<StaticMatrix<T, M, N, SO>> : std::true_type {};


template <typename T, size_t M, size_t N, bool SO>
struct IsStatic<StaticMatrix<T, M, N, SO>> : std::true_type {};
}
12 changes: 5 additions & 7 deletions include/blast/math/expressions/PanelMatrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#pragma once

#include <blast/math/typetraits/IsPanelMatrix.hpp>
#include <blast/math/typetraits/ElementType.hpp>
#include <blast/math/simd/SimdSize.hpp>
#include <blast/math/simd/SimdMask.hpp>
#include <blast/math/simd/SimdIndex.hpp>
Expand All @@ -21,12 +22,9 @@

namespace blast
{
using namespace blaze;


template <typename Derived, bool SO>
struct PanelMatrix
: public Matrix<Derived, SO>
: public blaze::Matrix<Derived, SO>
{
public:
using TagType = Group0;
Expand Down Expand Up @@ -80,14 +78,14 @@ namespace blast
template <typename MT1, typename MT2, typename MT3, bool SO>
inline auto assign(PanelMatrix<MT1, SO>& lhs,
blaze::DMatTDMatMultExpr<MT2, MT3, false, false, false, false> const& rhs)
-> blaze::EnableIf_t<IsPanelMatrix_v<MT2> && IsRowMajorMatrix_v<MT2> && IsPanelMatrix_v<MT3> && IsRowMajorMatrix_v<MT3>>
-> blaze::EnableIf_t<IsPanelMatrix_v<MT2> && blaze::IsRowMajorMatrix_v<MT2> && IsPanelMatrix_v<MT3> && blaze::IsRowMajorMatrix_v<MT3>>
{
BLAZE_THROW_LOGIC_ERROR("Not implemented 2");
}


template <typename MT1, bool SO1, typename MT2, bool SO2>
inline void assign(DenseMatrix<MT1, SO1>& lhs, PanelMatrix<MT2, SO2> const& rhs)
inline void assign(blaze::DenseMatrix<MT1, SO1>& lhs, PanelMatrix<MT2, SO2> const& rhs)
{
BLAZE_INTERNAL_ASSERT( (*lhs).rows() == (*rhs).rows() , "Invalid number of rows" );
BLAZE_INTERNAL_ASSERT( (*lhs).columns() == (*rhs).columns(), "Invalid number of columns" );
Expand Down Expand Up @@ -176,7 +174,7 @@ namespace blast


template <typename MT1, bool SO1, typename MT2, bool SO2>
inline void assign(PanelMatrix<MT1, SO1>& lhs, DenseMatrix<MT2, SO2> const& rhs)
inline void assign(PanelMatrix<MT1, SO1>& lhs, blaze::DenseMatrix<MT2, SO2> const& rhs)
{
size_t const m = (*rhs).rows();
size_t const n = (*rhs).columns();
Expand Down
21 changes: 21 additions & 0 deletions include/blast/math/reference/Ger.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// Copyright (c) 2019-2024 Mikhail Katliar All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
#pragma once

#include <blast/math/typetraits/MatrixPointer.hpp>
#include <blast/math/typetraits/VectorPointer.hpp>
#include <blast/util/Types.hpp>


namespace blast :: reference
{
template <typename Real, typename VPX, typename VPY, typename MPA>
requires VectorPointer<VPX, Real> && VectorPointer<VPY, Real> && MatrixPointer<MPA, Real>
inline void ger(size_t m, size_t n, Real alpha, VPX x, VPY y, MPA a)
{
for (size_t i = 0; i < m; ++i)
for (size_t j = 0; j < n; ++j)
*a(i, j) += alpha * *x(i) * *y(j);
}
}
Loading

0 comments on commit b5d18dd

Please sign in to comment.