Skip to content

Commit

Permalink
Tpetra: enable CrsMatrix apply TPL for float/double & int
Browse files Browse the repository at this point in the history
  • Loading branch information
cwpearson committed May 20, 2024
1 parent 0a0903b commit 2b4141b
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 23 deletions.
14 changes: 14 additions & 0 deletions packages/tpetra/core/src/Tpetra_Details_KokkosCounter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,20 @@ KokkosRegionCounter::get_count_region_contains(const std::string &needle) {
}
return count;
}

/// \brief Write every entry of KokkosRegionCounterDetails::regions to \c os.
///
/// Each entry is followed by a newline, so the output is one label per line.
///
/// \param os [out] Teuchos fancy output stream that receives the labels.
void KokkosRegionCounter::dump_regions(Teuchos::FancyOStream &os)
{
  const auto &labels = KokkosRegionCounterDetails::regions;
  for (const auto &label : labels)
  {
    os << label << "\n";
  }
}

/// \brief Write every entry of KokkosRegionCounterDetails::regions to \c os.
///
/// Each entry is followed by a newline, so the output is one label per line.
///
/// \param os [out] Standard output stream that receives the labels.
void KokkosRegionCounter::dump_regions(std::ostream &os)
{
  const auto &labels = KokkosRegionCounterDetails::regions;
  for (const auto &label : labels)
  {
    os << label << "\n";
  }
}


// clang-format off


Expand Down
8 changes: 7 additions & 1 deletion packages/tpetra/core/src/Tpetra_Details_KokkosCounter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
/// types using the Kokkos Profiling Library

#include <string>
#include <Teuchos_FancyOStream.hpp>

namespace Tpetra {
namespace Details {
Expand Down Expand Up @@ -89,6 +90,7 @@ namespace FenceCounter {
}

// clang-format on

/// \brief Counter for Kokkos regions representing third-party library usage
namespace KokkosRegionCounter {
/// \brief Start the counter
Expand All @@ -102,9 +104,13 @@ void stop();

/// \brief How many regions containing `substr` have been seen
size_t get_count_region_contains(const std::string &substr);

/// \brief Print all observed region labels, one per line
/// \param os [out] Stream the labels are written to
void dump_regions(std::ostream &os);
/// \brief Print all observed region labels, one per line
///
/// Overload of dump_regions() for a Teuchos::FancyOStream.
/// \param os [out] Stream the labels are written to
void dump_regions(Teuchos::FancyOStream &os);
} // namespace KokkosRegionCounter
// clang-format off

// clang-format off



Expand Down
30 changes: 21 additions & 9 deletions packages/tpetra/core/test/CrsMatrix/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -490,16 +490,28 @@ TRIBITS_ADD_EXECUTABLE_AND_TEST(
STANDARD_PASS_OUTPUT
)

if ((Tpetra_ENABLE_CUDA AND TPL_ENABLE_CUSPARSE ) OR
(Tpetra_ENABLE_HIP AND TPL_ENABLE_ROCSPARSE))
TRIBITS_ADD_EXECUTABLE_AND_TEST(
CrsMatrix_ApplyUsesTPLs
SOURCES
CrsMatrix_ApplyUsesTPLs.cpp
${TEUCHOS_STD_UNIT_TEST_MAIN}
COMM serial mpi
STANDARD_PASS_OUTPUT
# Build and register the CrsMatrix_ApplyUsesTPLs test only when both hold:
#  - a supported sparse TPL is enabled together with its backend
#    (CUDA with cuSPARSE, or HIP with rocSPARSE), and
#  - at least one of the double / float scalar instantiations is enabled.
if (
# supported TPLs
(
(Tpetra_ENABLE_CUDA AND TPL_ENABLE_CUSPARSE ) OR
(Tpetra_ENABLE_HIP AND TPL_ENABLE_ROCSPARSE)
)

AND

# supported type combos
(
(Tpetra_INST_DOUBLE OR Tpetra_INST_FLOAT)
)
)
TRIBITS_ADD_EXECUTABLE_AND_TEST(
CrsMatrix_ApplyUsesTPLs
SOURCES
CrsMatrix_ApplyUsesTPLs.cpp
${TEUCHOS_STD_UNIT_TEST_MAIN}
COMM serial mpi
STANDARD_PASS_OUTPUT
)
endif()

SET(TIMING_INSTALLS "")
Expand Down
29 changes: 16 additions & 13 deletions packages/tpetra/core/test/CrsMatrix/CrsMatrix_ApplyUsesTPLs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,18 @@ namespace {
////
TEUCHOS_UNIT_TEST_TEMPLATE_4_DECL( CrsMatrix, NonSquare, LO, GO, Scalar, Node )
{

// skip test if Scalar is not (float or double)
if constexpr (!(std::is_same_v<Scalar, float> || std::is_same_v<Scalar, double>)) {
TEST_EQUALITY_CONST(1,1); // SKIP
return;
}
// skip test if LO != int
if constexpr (!std::is_same_v<LO, int>) {
TEST_EQUALITY_CONST(1,1); // SKIP
return;
}

typedef CrsMatrix<Scalar,LO,GO,Node> MAT;
typedef MultiVector<Scalar,LO,GO,Node> MV;
typedef Map<LO,GO,Node> map_type;
Expand Down Expand Up @@ -221,21 +233,16 @@ namespace {
X.replaceGlobalValue(i,j,static_cast<Scalar>(i+j*P));
}
}
// build the expected output multivector B
MV Bexp(rowmap,numVecs), Bout(rowmap,numVecs);
for (GO i=static_cast<GO>(myImageID*M); i<static_cast<GO>(myImageID*M+M); ++i) {
for (GO j=0; j<static_cast<GO>(numVecs); ++j) {
Bexp.replaceGlobalValue(i,j,static_cast<Scalar>(j*i*P*P + (i+j*M*N*P)*(P*P-P)/2 + M*N*P*(P-1)*(2*P-1)/6));
}
}
// allocate output multivec
MV Bout(rowmap,numVecs);
// test the action
Bout.randomize();
Tpetra::Details::KokkosRegionCounter::reset();
Tpetra::Details::KokkosRegionCounter::start();
A.apply(X,Bout);
Tpetra::Details::KokkosRegionCounter::stop();

TEST_COMPARE(Tpetra::Details::KokkosRegionCounter::get_count_region_contains("TPL_"), ==, 1);
TEST_COMPARE(Tpetra::Details::KokkosRegionCounter::get_count_region_contains("spmv[TPL_"), ==, 1);

using Teuchos::outArg;
using Teuchos::REDUCE_MIN;
Expand All @@ -249,17 +256,13 @@ namespace {
}
}



//
// INSTANTIATIONS
//

#define UNIT_TEST_GROUP( SCALAR, LO, GO, NODE ) \
TEUCHOS_UNIT_TEST_TEMPLATE_4_INSTANT( CrsMatrix, NonSquare, LO, GO, SCALAR, NODE )
TEUCHOS_UNIT_TEST_TEMPLATE_4_INSTANT( CrsMatrix, NonSquare, LO, GO, SCALAR, NODE )

TPETRA_ETI_MANGLING_TYPEDEFS()

TPETRA_INSTANTIATE_SLGN( UNIT_TEST_GROUP )

}

0 comments on commit 2b4141b

Please sign in to comment.