Skip to content

Commit

Permalink
MueLu CoalesceDropFactory_kokkos: Correctly handle "filtered matrix: …
Browse files Browse the repository at this point in the history
…Dirichlet threshold"

Signed-off-by: Christian Glusa <[email protected]>
  • Loading branch information
cgcgcg committed Jan 14, 2025
1 parent 114d081 commit 6247bb7
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,8 @@ std::tuple<GlobalOrdinal, typename MueLu::LWGraph_kokkos<LocalOrdinal, GlobalOrd

const bool useRootStencil = pL.get<bool>("filtered matrix: use root stencil");
const bool useSpreadLumping = pL.get<bool>("filtered matrix: use spread lumping");

const MT filteringDirichletThreshold = as<MT>(pL.get<double>("filtered matrix: Dirichlet threshold"));
TEUCHOS_ASSERT(!useRootStencil);
TEUCHOS_ASSERT(!useSpreadLumping);

Expand Down Expand Up @@ -692,18 +694,18 @@ std::tuple<GlobalOrdinal, typename MueLu::LWGraph_kokkos<LocalOrdinal, GlobalOrd

if (lumping) {
if (reuseGraph) {
auto fillFunctor = MatrixConstruction::PointwiseFillReuseFunctor<local_matrix_type, local_graph_type, true>(lclA, results, lclFilteredA, lclGraph);
auto fillFunctor = MatrixConstruction::PointwiseFillReuseFunctor<local_matrix_type, local_graph_type, true>(lclA, results, lclFilteredA, lclGraph, filteringDirichletThreshold);
Kokkos::parallel_for("MueLu::CoalesceDrop::Fill_lumped_reuse", range, fillFunctor);
} else {
auto fillFunctor = MatrixConstruction::PointwiseFillNoReuseFunctor<local_matrix_type, true>(lclA, results, lclFilteredA);
auto fillFunctor = MatrixConstruction::PointwiseFillNoReuseFunctor<local_matrix_type, true>(lclA, results, lclFilteredA, filteringDirichletThreshold);
Kokkos::parallel_for("MueLu::CoalesceDrop::Fill_lumped_noreuse", range, fillFunctor);
}
} else {
if (reuseGraph) {
auto fillFunctor = MatrixConstruction::PointwiseFillReuseFunctor<local_matrix_type, local_graph_type, false>(lclA, results, lclFilteredA, lclGraph);
auto fillFunctor = MatrixConstruction::PointwiseFillReuseFunctor<local_matrix_type, local_graph_type, false>(lclA, results, lclFilteredA, lclGraph, filteringDirichletThreshold);
Kokkos::parallel_for("MueLu::CoalesceDrop::Fill_unlumped_reuse", range, fillFunctor);
} else {
auto fillFunctor = MatrixConstruction::PointwiseFillNoReuseFunctor<local_matrix_type, false>(lclA, results, lclFilteredA);
auto fillFunctor = MatrixConstruction::PointwiseFillNoReuseFunctor<local_matrix_type, false>(lclA, results, lclFilteredA, filteringDirichletThreshold);
Kokkos::parallel_for("MueLu::CoalesceDrop::Fill_unlumped_noreuse", range, fillFunctor);
}
}
Expand Down Expand Up @@ -854,6 +856,9 @@ std::tuple<GlobalOrdinal, typename MueLu::LWGraph_kokkos<LocalOrdinal, GlobalOrd

const bool useRootStencil = pL.get<bool>("filtered matrix: use root stencil");
const bool useSpreadLumping = pL.get<bool>("filtered matrix: use spread lumping");

const MT filteringDirichletThreshold = as<MT>(pL.get<double>("filtered matrix: Dirichlet threshold"));

TEUCHOS_ASSERT(!useRootStencil);
TEUCHOS_ASSERT(!useSpreadLumping);

Expand Down Expand Up @@ -1095,18 +1100,18 @@ std::tuple<GlobalOrdinal, typename MueLu::LWGraph_kokkos<LocalOrdinal, GlobalOrd

if (lumping) {
if (reuseGraph) {
auto fillFunctor = MatrixConstruction::VectorFillFunctor<local_matrix_type, true, true>(lclA, blkPartSize, colTranslation, results, lclFilteredA, lclGraph);
auto fillFunctor = MatrixConstruction::VectorFillFunctor<local_matrix_type, true, true>(lclA, blkPartSize, colTranslation, results, lclFilteredA, lclGraph, filteringDirichletThreshold);
Kokkos::parallel_for("MueLu::CoalesceDrop::Fill_lumped_reuse", range, fillFunctor);
} else {
auto fillFunctor = MatrixConstruction::VectorFillFunctor<local_matrix_type, true, false>(lclA, blkPartSize, colTranslation, results, lclFilteredA, lclGraph);
auto fillFunctor = MatrixConstruction::VectorFillFunctor<local_matrix_type, true, false>(lclA, blkPartSize, colTranslation, results, lclFilteredA, lclGraph, filteringDirichletThreshold);
Kokkos::parallel_for("MueLu::CoalesceDrop::Fill_lumped_noreuse", range, fillFunctor);
}
} else {
if (reuseGraph) {
auto fillFunctor = MatrixConstruction::VectorFillFunctor<local_matrix_type, false, true>(lclA, blkSize, colTranslation, results, lclFilteredA, lclGraph);
auto fillFunctor = MatrixConstruction::VectorFillFunctor<local_matrix_type, false, true>(lclA, blkSize, colTranslation, results, lclFilteredA, lclGraph, filteringDirichletThreshold);
Kokkos::parallel_for("MueLu::CoalesceDrop::Fill_unlumped_reuse", range, fillFunctor);
} else {
auto fillFunctor = MatrixConstruction::VectorFillFunctor<local_matrix_type, false, false>(lclA, blkSize, colTranslation, results, lclFilteredA, lclGraph);
auto fillFunctor = MatrixConstruction::VectorFillFunctor<local_matrix_type, false, false>(lclA, blkSize, colTranslation, results, lclFilteredA, lclGraph, filteringDirichletThreshold);
Kokkos::parallel_for("MueLu::CoalesceDrop::Fill_unlumped_noreuse", range, fillFunctor);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -250,19 +250,23 @@ class PointwiseFillReuseFunctor {
using memory_space = typename local_matrix_type::memory_space;
using results_view = Kokkos::View<DecisionType*, memory_space>;
using ATS = Kokkos::ArithTraits<scalar_type>;
using magnitudeType = typename ATS::magnitudeType;

local_matrix_type A;
results_view results;
local_matrix_type filteredA;
local_graph_type graph;
magnitudeType dirichletThreshold;
const scalar_type zero = ATS::zero();
const scalar_type one = ATS::one();

public:
PointwiseFillReuseFunctor(local_matrix_type& A_, results_view& results_, local_matrix_type& filteredA_, local_graph_type& graph_)
PointwiseFillReuseFunctor(local_matrix_type& A_, results_view& results_, local_matrix_type& filteredA_, local_graph_type& graph_, magnitudeType dirichletThreshold_)
: A(A_)
, results(results_)
, filteredA(filteredA_)
, graph(graph_) {}
, graph(graph_)
, dirichletThreshold(dirichletThreshold_) {}

KOKKOS_INLINE_FUNCTION
void operator()(const local_ordinal_type rlid) const {
Expand Down Expand Up @@ -300,6 +304,8 @@ class PointwiseFillReuseFunctor {
}
if constexpr (lumping) {
rowFilteredA.value(diagOffset) += diagCorrection;
if ((dirichletThreshold >= 0.0) && (ATS::real(rowFilteredA.value(diagOffset)) <= dirichletThreshold))
rowFilteredA.value(diagOffset) = one;
}
}
};
Expand All @@ -319,17 +325,21 @@ class PointwiseFillNoReuseFunctor {
using memory_space = typename local_matrix_type::memory_space;
using results_view = Kokkos::View<DecisionType*, memory_space>;
using ATS = Kokkos::ArithTraits<scalar_type>;
using magnitudeType = typename ATS::magnitudeType;

local_matrix_type A;
results_view results;
local_matrix_type filteredA;
magnitudeType dirichletThreshold;
const scalar_type zero = ATS::zero();
const scalar_type one = ATS::one();

public:
PointwiseFillNoReuseFunctor(local_matrix_type& A_, results_view& results_, local_matrix_type& filteredA_)
PointwiseFillNoReuseFunctor(local_matrix_type& A_, results_view& results_, local_matrix_type& filteredA_, magnitudeType dirichletThreshold_)
: A(A_)
, results(results_)
, filteredA(filteredA_) {}
, filteredA(filteredA_)
, dirichletThreshold(dirichletThreshold_) {}

KOKKOS_INLINE_FUNCTION
void operator()(const local_ordinal_type rlid) const {
Expand All @@ -356,6 +366,8 @@ class PointwiseFillNoReuseFunctor {
}
if constexpr (lumping) {
rowFilteredA.value(diagOffset) += diagCorrection;
if ((dirichletThreshold >= 0.0) && (ATS::real(rowFilteredA.value(diagOffset)) <= dirichletThreshold))
rowFilteredA.value(diagOffset) = one;
}
}
};
Expand Down Expand Up @@ -785,26 +797,30 @@ class VectorFillFunctor {
using OTS = Kokkos::ArithTraits<local_ordinal_type>;
using block_indices_view_type = Kokkos::View<local_ordinal_type*, memory_space>;
using permutation_type = Kokkos::View<local_ordinal_type*, memory_space>;
using magnitudeType = typename ATS::magnitudeType;

local_matrix_type A;
local_ordinal_type blockSize;
block_indices_view_type ghosted_point_to_block;
results_view results;
local_matrix_type filteredA;
local_graph_type graph;
magnitudeType dirichletThreshold;
const scalar_type zero = ATS::zero();
const scalar_type one = ATS::one();

BlockRowComparison<local_matrix_type> comparison;
permutation_type permutation;

public:
VectorFillFunctor(local_matrix_type& A_, local_ordinal_type blockSize_, block_indices_view_type ghosted_point_to_block_, results_view& results_, local_matrix_type& filteredA_, local_graph_type& graph_)
VectorFillFunctor(local_matrix_type& A_, local_ordinal_type blockSize_, block_indices_view_type ghosted_point_to_block_, results_view& results_, local_matrix_type& filteredA_, local_graph_type& graph_, magnitudeType dirichletThreshold_)
: A(A_)
, blockSize(blockSize_)
, ghosted_point_to_block(ghosted_point_to_block_)
, results(results_)
, filteredA(filteredA_)
, graph(graph_)
, dirichletThreshold(dirichletThreshold_)
, comparison(BlockRowComparison(A, blockSize_, ghosted_point_to_block)) {
permutation = permutation_type("permutation", A.nnz());
}
Expand Down Expand Up @@ -844,6 +860,8 @@ class VectorFillFunctor {
}
if constexpr (lumping) {
rowFilteredA.value(diagOffset) += diagCorrection;
if ((dirichletThreshold >= 0.0) && (ATS::real(rowFilteredA.value(diagOffset)) <= dirichletThreshold))
rowFilteredA.value(diagOffset) = one;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1059,6 +1059,11 @@ void ParameterListInterpreter<Scalar, LocalOrdinal, GlobalOrdinal, Node>::
MUELU_TEST_AND_SET_PARAM_2LIST(paramList, defaultList, "filtered matrix: use lumping", bool, dropParams);
MUELU_TEST_AND_SET_PARAM_2LIST(paramList, defaultList, "filtered matrix: reuse graph", bool, dropParams);
MUELU_TEST_AND_SET_PARAM_2LIST(paramList, defaultList, "filtered matrix: reuse eigenvalue", bool, dropParams);
MUELU_TEST_AND_SET_PARAM_2LIST(paramList, defaultList, "filtered matrix: use root stencil", bool, dropParams);
MUELU_TEST_AND_SET_PARAM_2LIST(paramList, defaultList, "filtered matrix: Dirichlet threshold", double, dropParams);
MUELU_TEST_AND_SET_PARAM_2LIST(paramList, defaultList, "filtered matrix: use spread lumping", bool, dropParams);
MUELU_TEST_AND_SET_PARAM_2LIST(paramList, defaultList, "filtered matrix: spread lumping diag dom growth factor", double, dropParams);
MUELU_TEST_AND_SET_PARAM_2LIST(paramList, defaultList, "filtered matrix: spread lumping diag dom cap", double, dropParams);
}

if (!amalgFact.is_null())
Expand Down

0 comments on commit 6247bb7

Please sign in to comment.