Skip to content

Commit

Permalink
Merge pull request #12792 from hkthorn/develop
Browse files Browse the repository at this point in the history
Belos: Fix Belos solver behavior to handle NaNs
  • Loading branch information
hkthorn authored Mar 2, 2024
2 parents a469209 + 9977ba0 commit 2c85546
Show file tree
Hide file tree
Showing 21 changed files with 527 additions and 16 deletions.
9 changes: 9 additions & 0 deletions packages/belos/src/BelosBiCGStabSolMgr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -782,6 +782,15 @@ ReturnType BiCGStabSolMgr<ScalarType,MV,OP>::solve ()
"Belos::BiCGStabSolMgr::solve(): Invalid return from BiCGStabIter::iterate().");
}
}
catch (const StatusTestNaNError& e) {
// A NaN was detected in the solver. Set the solution to zero and return unconverged.
achievedTol_ = MT::one();
Teuchos::RCP<MV> X = problem_->getLHS();
MVT::MvInit( *X, SCT::zero() );
printer_->stream(Warnings) << "Belos::BiCGStabSolMgr::solve(): Warning! NaN has been detected!"
<< std::endl;
return Unconverged;
}
catch (const std::exception &e) {
printer_->stream(Errors) << "Error! Caught std::exception in BiCGStabIter::iterate() at iteration "
<< bicgstab_iter->getNumIters() << std::endl
Expand Down
9 changes: 9 additions & 0 deletions packages/belos/src/BelosBlockCGSolMgr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1015,6 +1015,15 @@ ReturnType BlockCGSolMgr<ScalarType,MV,OP,true>::solve() {
"to the Belos developers.");
}
}
catch (const StatusTestNaNError& e) {
// A NaN was detected in the solver. Set the solution to zero and return unconverged.
achievedTol_ = MT::one();
Teuchos::RCP<MV> X = problem_->getLHS();
MVT::MvInit( *X, SCT::zero() );
printer_->stream(Warnings) << "Belos::BlockCGSolMgr::solve(): Warning! NaN has been detected!"
<< std::endl;
return Unconverged;
}
catch (const std::exception &e) {
std::ostream& err = printer_->stream (Errors);
err << "Error! Caught std::exception in CGIteration::iterate() at "
Expand Down
13 changes: 11 additions & 2 deletions packages/belos/src/BelosBlockGCRODRSolMgr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1964,7 +1964,7 @@ ReturnType BlockGCRODRSolMgr<ScalarType,MV,OP>::solve() {
tmpNumBlocks = dim / blockSize_; // Allow for a good breakdown.
else{
tmpNumBlocks = ( dim - blockSize_) / blockSize_; // Allow for restarting.
printer_->stream(Warnings) << "Belos::BlockGmresSolMgr::solve(): Warning! Requested Krylov subspace dimension is larger than operator dimension!"
printer_->stream(Warnings) << "Belos::BlockGCRODRSolMgr::solve(): Warning! Requested Krylov subspace dimension is larger than operator dimension!"
<< std::endl << "The maximum number of blocks allowed for the Krylov subspace will be adjusted to " << tmpNumBlocks << std::endl;
primeList.set("Num Blocks",Teuchos::as<int>(tmpNumBlocks));
}
Expand Down Expand Up @@ -2019,7 +2019,7 @@ ReturnType BlockGCRODRSolMgr<ScalarType,MV,OP>::solve() {
if ( expConvTest_->getLOADetected() ) {
// we don't have convergence
loaDetected_ = true;
printer_->stream(Warnings) << "Belos::BlockGmresSolMgr::solve(): Warning! Solver has experienced a loss of accuracy!" << std::endl;
printer_->stream(Warnings) << "Belos::BlockGCRODRSolMgr::solve(): Warning! Solver has experienced a loss of accuracy!" << std::endl;
}
}
// *******************************************
Expand Down Expand Up @@ -2054,6 +2054,15 @@ ReturnType BlockGCRODRSolMgr<ScalarType,MV,OP>::solve() {
isConverged = false;
}
} // end catch (const GmresIterationOrthoFailure &e)
catch (const StatusTestNaNError& e) {
// A NaN was detected in the solver. Set the solution to zero and return unconverged.
achievedTol_ = MT::one();
Teuchos::RCP<MV> X = problem_->getLHS();
MVT::MvInit( *X, SCT::zero() );
printer_->stream(Warnings) << "Belos::BlockGCRODRSolMgr::solve(): Warning! NaN has been detected!"
<< std::endl;
return Unconverged;
}
catch (const std::exception &e) {
printer_->stream(Errors) << "Error! Caught std::exception in BlockGmresIter::iterate() at iteration "
<< block_gmres_iter->getNumIters() << std::endl
Expand Down
9 changes: 9 additions & 0 deletions packages/belos/src/BelosBlockGmresSolMgr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1142,6 +1142,15 @@ ReturnType BlockGmresSolMgr<ScalarType,MV,OP>::solve() {
break;
}
}
catch (const StatusTestNaNError& e) {
// A NaN was detected in the solver. Set the solution to zero and return unconverged.
achievedTol_ = MT::one();
Teuchos::RCP<MV> X = problem_->getLHS();
MVT::MvInit( *X, SCT::zero() );
printer_->stream(Warnings) << "Belos::BlockGmresSolMgr::solve(): Warning! NaN has been detected!"
<< std::endl;
return Unconverged;
}
catch (const std::exception &e) {
printer_->stream(Errors) << "Error! Caught std::exception in BlockGmresIter::iterate() at iteration "
<< block_gmres_iter->getNumIters() << std::endl
Expand Down
9 changes: 9 additions & 0 deletions packages/belos/src/BelosFixedPointSolMgr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -744,6 +744,15 @@ ReturnType FixedPointSolMgr<ScalarType,MV,OP>::solve() {
"to the Belos developers.");
}
}
catch (const StatusTestNaNError& e) {
// A NaN was detected in the solver. Set the solution to zero and return unconverged.
achievedTol_ = MT::one();
Teuchos::RCP<MV> X = problem_->getLHS();
MVT::MvInit( *X, SCT::zero() );
printer_->stream(Warnings) << "Belos::FixedPointSolMgr::solve(): Warning! NaN has been detected!"
<< std::endl;
return Unconverged;
}
catch (const std::exception &e) {
std::ostream& err = printer_->stream (Errors);
err << "Error! Caught std::exception in FixedPointIteration::iterate() at "
Expand Down
9 changes: 9 additions & 0 deletions packages/belos/src/BelosGCRODRSolMgr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1518,6 +1518,15 @@ ReturnType GCRODRSolMgr<ScalarType,MV,OP,true>::solve() {
if (convTest_->getStatus() == Passed)
primeConverged = true;
}
catch (const StatusTestNaNError& e) {
// A NaN was detected in the solver. Set the solution to zero and return unconverged.
achievedTol_ = MT::one();
Teuchos::RCP<MV> X = problem_->getLHS();
MVT::MvInit( *X, SCT::zero() );
printer_->stream(Warnings) << "Belos::GCRODRSolMgr::solve(): Warning! NaN has been detected!"
<< std::endl;
return Unconverged;
}
catch (const std::exception &e) {
printer_->stream(Errors) << "Error! Caught exception in GCRODRIter::iterate() at iteration "
<< gcrodr_prime_iter->getNumIters() << std::endl
Expand Down
12 changes: 11 additions & 1 deletion packages/belos/src/BelosMinresSolMgr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -763,7 +763,17 @@ namespace Belos {
"nor reached the maximum number of iterations " << maxIters_
<< ". That means something went wrong.");
}
} catch (const std::exception &e) {
}
catch (const StatusTestNaNError& e) {
// A NaN was detected in the solver. Set the solution to zero and return unconverged.
achievedTol_ = MST::one();
Teuchos::RCP<MV> X = problem_->getLHS();
MVT::MvInit( *X, SCT::zero() );
printer_->stream(Warnings) << "Belos::MinresSolMgr::solve(): Warning! NaN has been detected!"
<< std::endl;
return Unconverged;
}
catch (const std::exception &e) {
printer_->stream (Errors)
<< "Error! Caught std::exception in MinresIter::iterate() at "
<< "iteration " << minres_iter->getNumIters() << endl
Expand Down
9 changes: 9 additions & 0 deletions packages/belos/src/BelosPCPGSolMgr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -891,6 +891,15 @@ ReturnType PCPGSolMgr<ScalarType,MV,OP,true>::solve() {
"Belos::PCPGSolMgr::solve(): Invalid return from PCPGIter::iterate().");
} // end if
} // end try
catch (const StatusTestNaNError& e) {
  // A status test detected a NaN while iterating. The current iterate cannot
  // be trusted, so report an achieved tolerance of one, reset the solution
  // (LHS) to zero, and return Unconverged instead of propagating the error.
  achievedTol_ = MT::one();
  Teuchos::RCP<MV> X = problem_->getLHS();
  MVT::MvInit( *X, SCT::zero() );
  // Use the full solver manager class name so the warning can be traced to
  // this class, as the other Belos solver managers do.
  printer_->stream(Warnings) << "Belos::PCPGSolMgr::solve(): Warning! NaN has been detected!"
                             << std::endl;
  return Unconverged;
}
catch (const std::exception &e) {
printer_->stream(Errors) << "Error! Caught exception in PCPGIter::iterate() at iteration "
<< pcpg_iter->getNumIters() << std::endl
Expand Down
9 changes: 9 additions & 0 deletions packages/belos/src/BelosPseudoBlockCGSolMgr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -920,6 +920,15 @@ ReturnType PseudoBlockCGSolMgr<ScalarType,MV,OP,true>::solve ()
"Belos::PseudoBlockCGSolMgr::solve(): Invalid return from PseudoBlockCGIter::iterate().");
}
}
catch (const StatusTestNaNError& e) {
// A NaN was detected in the solver. Set the solution to zero and return unconverged.
achievedTol_ = MT::one();
Teuchos::RCP<MV> X = problem_->getLHS();
MVT::MvInit( *X, SCT::zero() );
printer_->stream(Warnings) << "Belos::PseudoBlockCGSolMgr::solve(): Warning! NaN has been detected!"
<< std::endl;
return Unconverged;
}
catch (const std::exception &e) {
printer_->stream(Errors) << "Error! Caught std::exception in PseudoBlockCGIter::iterate() at iteration "
<< block_cg_iter->getNumIters() << std::endl
Expand Down
9 changes: 9 additions & 0 deletions packages/belos/src/BelosPseudoBlockGmresSolMgr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1469,6 +1469,15 @@ ReturnType PseudoBlockGmresSolMgr<ScalarType,MV,OP>::solve() {
isConverged = false;
break;
}
catch (const StatusTestNaNError& e) {
// A NaN was detected in the solver. Set the solution to zero and return unconverged.
achievedTol_ = MT::one();
Teuchos::RCP<MV> X = problem_->getLHS();
MVT::MvInit( *X, SCT::zero() );
printer_->stream(Warnings) << "Belos::PseudoBlockGmresSolMgr::solve(): Warning! NaN has been detected!"
<< std::endl;
return Unconverged;
}
catch (const std::exception &e) {
printer_->stream(Errors) << "Error! Caught std::exception in PseudoBlockGmresIter::iterate() at iteration "
<< block_gmres_iter->getNumIters() << std::endl
Expand Down
9 changes: 9 additions & 0 deletions packages/belos/src/BelosPseudoBlockTFQMRSolMgr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -807,6 +807,15 @@ ReturnType PseudoBlockTFQMRSolMgr<ScalarType,MV,OP>::solve() {
"Belos::PseudoBlockTFQMRSolMgr::solve(): Invalid return from PseudoBlockTFQMRIter::iterate().");
}
}
catch (const StatusTestNaNError& e) {
// A NaN was detected in the solver. Set the solution to zero and return unconverged.
achievedTol_ = MT::one();
Teuchos::RCP<MV> X = problem_->getLHS();
MVT::MvInit( *X, SCT::zero() );
printer_->stream(Warnings) << "Belos::PseudoBlockTFQMRSolMgr::solve(): Warning! NaN has been detected!"
<< std::endl;
return Unconverged;
}
catch (const std::exception &e) {
printer_->stream(Errors) << "Error! Caught std::exception in PseudoBlockTFQMRIter::iterate() at iteration "
<< block_tfqmr_iter->getNumIters() << std::endl
Expand Down
9 changes: 9 additions & 0 deletions packages/belos/src/BelosRCGSolMgr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1748,6 +1748,15 @@ ReturnType RCGSolMgr<ScalarType,MV,OP,true>::solve() {
"Belos::RCGSolMgr::solve(): Invalid return from RCGIter::iterate().");
}
}
catch (const StatusTestNaNError& e) {
// A NaN was detected in the solver. Set the solution to zero and return unconverged.
achievedTol_ = MT::one();
Teuchos::RCP<MV> X = problem_->getLHS();
MVT::MvInit( *X, SCT::zero() );
printer_->stream(Warnings) << "Belos::RCGSolMgr::solve(): Warning! NaN has been detected!"
<< std::endl;
return Unconverged;
}
catch (const std::exception &e) {
printer_->stream(Errors) << "Error! Caught std::exception in RCGIter::iterate() at iteration "
<< rcg_iter->getNumIters() << std::endl
Expand Down
3 changes: 3 additions & 0 deletions packages/belos/src/BelosStatusTest.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@ namespace Belos {
class StatusTestError : public BelosError
{
public:
  // Passes the description straight through to the BelosError base class.
  StatusTestError(const std::string& what_arg)
    : BelosError(what_arg)
  {}
};

/** \brief Exception thrown by a status test's checkStatus() when a NaN is detected.
 *
 * Derives from StatusTestError, so handlers written for StatusTestError still
 * catch it; a dedicated type lets callers handle the NaN case separately.
 */
class StatusTestNaNError : public StatusTestError
{
public:
  // Passes the description straight through to the StatusTestError base class.
  StatusTestNaNError(const std::string& what_arg)
    : StatusTestError(what_arg)
  {}
};

//@}

template <class ScalarType, class MV, class OP>
Expand Down
2 changes: 1 addition & 1 deletion packages/belos/src/BelosStatusTestGenResNorm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,7 @@ StatusType StatusTestGenResNorm<ScalarType,MV,OP>::checkStatus( Iteration<Scalar
} else {
// Throw a StatusTestNaNError (a StatusTestError, hence a std::exception) if a NaN is found.
status_ = Failed;
TEUCHOS_TEST_FOR_EXCEPTION(true,StatusTestError,"StatusTestGenResNorm::checkStatus(): NaN has been detected.");
TEUCHOS_TEST_FOR_EXCEPTION(true,StatusTestNaNError,"StatusTestGenResNorm::checkStatus(): NaN has been detected.");
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion packages/belos/src/BelosStatusTestGenResSubNorm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -493,7 +493,7 @@ class StatusTestGenResSubNorm<ScalarType,Thyra::MultiVectorBase<ScalarType>,Thyr
} else {
// Throw a StatusTestNaNError (a StatusTestError, hence a std::exception) if a NaN is found.
status_ = Failed;
TEUCHOS_TEST_FOR_EXCEPTION(true,StatusTestError,"StatusTestGenResSubNorm::checkStatus(): NaN has been detected.");
TEUCHOS_TEST_FOR_EXCEPTION(true,StatusTestNaNError,"StatusTestGenResSubNorm::checkStatus(): NaN has been detected.");
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion packages/belos/src/BelosStatusTestImpResNorm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -630,7 +630,7 @@ checkStatus (Iteration<ScalarType,MV,OP>* iSolver)
// tolerance is NaN; we assume the former. We also mark the
// test as failed, in case you want to catch the exception.
status_ = Failed;
TEUCHOS_TEST_FOR_EXCEPTION(true, StatusTestError, "Belos::"
TEUCHOS_TEST_FOR_EXCEPTION(true, StatusTestNaNError, "Belos::"
"StatusTestImpResNorm::checkStatus(): One or more of the current "
"implicit residual norms is NaN.");
}
Expand Down
9 changes: 9 additions & 0 deletions packages/belos/src/BelosTFQMRSolMgr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -752,6 +752,15 @@ ReturnType TFQMRSolMgr<ScalarType,MV,OP>::solve() {
"Belos::TFQMRSolMgr::solve(): Invalid return from TFQMRIter::iterate().");
}
}
catch (const StatusTestNaNError& e) {
// A NaN was detected in the solver. Set the solution to zero and return unconverged.
achievedTol_ = MT::one();
Teuchos::RCP<MV> X = problem_->getLHS();
MVT::MvInit( *X, SCT::zero() );
printer_->stream(Warnings) << "Belos::TFQMRSolMgr::solve(): Warning! NaN has been detected!"
<< std::endl;
return Unconverged;
}
catch (const std::exception &e) {
printer_->stream(Errors) << "Error! Caught std::exception in TFQMRIter::iterate() at iteration "
<< tfqmr_iter->getNumIters() << std::endl
Expand Down
7 changes: 7 additions & 0 deletions packages/belos/tpetra/test/LinearSolverFactory/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,13 @@ TRIBITS_ADD_EXECUTABLE_AND_TEST(
COMM serial mpi
)

# Builds the SolverFactoryNaN executable from SolverFactoryNaN.cpp plus the
# sources named by ${TEUCHOS_STD_UNIT_TEST_MAIN}, and registers it as a test
# with no extra command-line arguments for the serial and mpi COMM settings.
# NOTE(review): presumably exercises the Belos solver managers' NaN handling --
# confirm against SolverFactoryNaN.cpp.
TRIBITS_ADD_EXECUTABLE_AND_TEST(
SolverFactoryNaN
SOURCES SolverFactoryNaN.cpp ${TEUCHOS_STD_UNIT_TEST_MAIN}
ARGS
COMM serial mpi
)

TRIBITS_ADD_EXECUTABLE_AND_TEST(
CustomSolverFactory
SOURCES CustomSolverFactory.cpp ${TEUCHOS_STD_UNIT_TEST_MAIN}
Expand Down
Loading

0 comments on commit 2c85546

Please sign in to comment.