From 191b219997927681ad0d41d78e5597da58fcb5be Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Wed, 15 Jan 2025 01:26:48 +0800 Subject: [PATCH] Fix prediction error. (#11167) --- include/xgboost/predictor.h | 4 ++-- src/learner.cc | 34 ++++++++++++++++------------------ 2 files changed, 18 insertions(+), 20 deletions(-) diff --git a/include/xgboost/predictor.h b/include/xgboost/predictor.h index f40abdf4faa6..ad89e54891c6 100644 --- a/include/xgboost/predictor.h +++ b/include/xgboost/predictor.h @@ -51,12 +51,12 @@ class PredictionContainer : public DMatrixCache { public: PredictionContainer() : DMatrixCache{DefaultSize()} {} - PredictionCacheEntry& Cache(std::shared_ptr m, DeviceOrd device) { + std::shared_ptr Cache(std::shared_ptr m, DeviceOrd device) { auto p_cache = this->CacheItem(m); if (!device.IsCPU()) { p_cache->predictions.SetDevice(device); } - return *p_cache; + return p_cache; } }; diff --git a/src/learner.cc b/src/learner.cc index 1dcd0fcfc7eb..34f395beb34b 100644 --- a/src/learner.cc +++ b/src/learner.cc @@ -23,12 +23,10 @@ #include // for numeric_limits #include // for allocator, unique_ptr, shared_ptr, operator== #include // for mutex, lock_guard -#include // for set #include // for operator<<, basic_ostream, basic_ostream::opera... #include // for stack #include // for basic_string, char_traits, operator<, string #include // for errc -#include // for get #include // for operator!=, unordered_map #include // for pair, as_const, move, swap #include // for vector @@ -1299,19 +1297,19 @@ class LearnerImpl : public LearnerIO { this->ValidateDMatrix(train.get(), true); - auto& predt = prediction_container_.Cache(train, ctx_.Device()); + auto predt = prediction_container_.Cache(train, ctx_.Device()); monitor_.Start("PredictRaw"); - this->PredictRaw(train.get(), &predt, true, 0, 0); - TrainingObserver::Instance().Observe(predt.predictions, "Predictions"); + this->PredictRaw(train.get(), predt.get(), true, 0, 0); + TrainingObserver::Instance().Observe(predt->predictions, "Predictions"); monitor_.Stop("PredictRaw"); monitor_.Start("GetGradient"); - GetGradient(predt.predictions, train->Info(), iter, &gpair_); + GetGradient(predt->predictions, train->Info(), iter, &gpair_); monitor_.Stop("GetGradient"); TrainingObserver::Instance().Observe(*gpair_.Data(), "Gradients"); - gbm_->DoBoost(train.get(), &gpair_, &predt, obj_.get()); + gbm_->DoBoost(train.get(), &gpair_, predt.get(), obj_.get()); monitor_.Stop("UpdateOneIter"); } @@ -1329,8 +1327,8 @@ class LearnerImpl : public LearnerIO { CHECK_EQ(this->learner_model_param_.OutputLength(), in_gpair->Shape(1)) << "The number of columns in gradient should be equal to the number of targets/classes in " "the model."; - auto& predt = prediction_container_.Cache(train, ctx_.Device()); - gbm_->DoBoost(train.get(), in_gpair, &predt, obj_.get()); + auto predt = prediction_container_.Cache(train, ctx_.Device()); + gbm_->DoBoost(train.get(), in_gpair, predt.get(), obj_.get()); monitor_.Stop("BoostOneIter"); } @@ -1355,13 +1353,13 @@ class LearnerImpl : public LearnerIO { for (size_t i = 0; i < data_sets.size(); ++i) { std::shared_ptr m = data_sets[i]; - auto &predt = prediction_container_.Cache(m, ctx_.Device()); + auto predt = prediction_container_.Cache(m, ctx_.Device()); this->ValidateDMatrix(m.get(), false); - this->PredictRaw(m.get(), &predt, false, 0, 0); + this->PredictRaw(m.get(), predt.get(), false, 0, 0); - auto &out = output_predictions_.Cache(m, ctx_.Device()).predictions; - out.Resize(predt.predictions.Size()); - out.Copy(predt.predictions); + auto &out = output_predictions_.Cache(m, ctx_.Device())->predictions; + out.Resize(predt->predictions.Size()); + out.Copy(predt->predictions); obj_->EvalTransform(&out); for (auto& ev : metrics_) { @@ -1395,12 +1393,12 @@ class LearnerImpl : public LearnerIO { } else if (pred_leaf) { gbm_->PredictLeaf(data.get(), out_preds, layer_begin, layer_end); } else { - auto& prediction = prediction_container_.Cache(data, ctx_.Device()); - this->PredictRaw(data.get(), &prediction, training, layer_begin, layer_end); + auto predt = prediction_container_.Cache(data, ctx_.Device()); + this->PredictRaw(data.get(), predt.get(), training, layer_begin, layer_end); // Copy the prediction cache to output prediction. out_preds comes from C API out_preds->SetDevice(ctx_.Device()); - out_preds->Resize(prediction.predictions.Size()); - out_preds->Copy(prediction.predictions); + out_preds->Resize(predt->predictions.Size()); + out_preds->Copy(predt->predictions); if (!output_margin) { obj_->PredTransform(out_preds); }