Skip to content

Commit

Permalink
Fix prediction error. (#11167)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis authored Jan 14, 2025
1 parent 461d27c commit 191b219
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 20 deletions.
4 changes: 2 additions & 2 deletions include/xgboost/predictor.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,12 @@ class PredictionContainer : public DMatrixCache<PredictionCacheEntry> {

public:
PredictionContainer() : DMatrixCache<PredictionCacheEntry>{DefaultSize()} {}
PredictionCacheEntry& Cache(std::shared_ptr<DMatrix> m, DeviceOrd device) {
std::shared_ptr<PredictionCacheEntry> Cache(std::shared_ptr<DMatrix> m, DeviceOrd device) {
auto p_cache = this->CacheItem(m);
if (!device.IsCPU()) {
p_cache->predictions.SetDevice(device);
}
return *p_cache;
return p_cache;
}
};

Expand Down
34 changes: 16 additions & 18 deletions src/learner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,10 @@
#include <limits> // for numeric_limits
#include <memory> // for allocator, unique_ptr, shared_ptr, operator==
#include <mutex> // for mutex, lock_guard
#include <set> // for set
#include <sstream> // for operator<<, basic_ostream, basic_ostream::opera...
#include <stack> // for stack
#include <string> // for basic_string, char_traits, operator<, string
#include <system_error> // for errc
#include <tuple> // for get
#include <unordered_map> // for operator!=, unordered_map
#include <utility> // for pair, as_const, move, swap
#include <vector> // for vector
Expand Down Expand Up @@ -1299,19 +1297,19 @@ class LearnerImpl : public LearnerIO {

this->ValidateDMatrix(train.get(), true);

auto& predt = prediction_container_.Cache(train, ctx_.Device());
auto predt = prediction_container_.Cache(train, ctx_.Device());

monitor_.Start("PredictRaw");
this->PredictRaw(train.get(), &predt, true, 0, 0);
TrainingObserver::Instance().Observe(predt.predictions, "Predictions");
this->PredictRaw(train.get(), predt.get(), true, 0, 0);
TrainingObserver::Instance().Observe(predt->predictions, "Predictions");
monitor_.Stop("PredictRaw");

monitor_.Start("GetGradient");
GetGradient(predt.predictions, train->Info(), iter, &gpair_);
GetGradient(predt->predictions, train->Info(), iter, &gpair_);
monitor_.Stop("GetGradient");
TrainingObserver::Instance().Observe(*gpair_.Data(), "Gradients");

gbm_->DoBoost(train.get(), &gpair_, &predt, obj_.get());
gbm_->DoBoost(train.get(), &gpair_, predt.get(), obj_.get());
monitor_.Stop("UpdateOneIter");
}

Expand All @@ -1329,8 +1327,8 @@ class LearnerImpl : public LearnerIO {
CHECK_EQ(this->learner_model_param_.OutputLength(), in_gpair->Shape(1))
<< "The number of columns in gradient should be equal to the number of targets/classes in "
"the model.";
auto& predt = prediction_container_.Cache(train, ctx_.Device());
gbm_->DoBoost(train.get(), in_gpair, &predt, obj_.get());
auto predt = prediction_container_.Cache(train, ctx_.Device());
gbm_->DoBoost(train.get(), in_gpair, predt.get(), obj_.get());
monitor_.Stop("BoostOneIter");
}

Expand All @@ -1355,13 +1353,13 @@ class LearnerImpl : public LearnerIO {

for (size_t i = 0; i < data_sets.size(); ++i) {
std::shared_ptr<DMatrix> m = data_sets[i];
auto &predt = prediction_container_.Cache(m, ctx_.Device());
auto predt = prediction_container_.Cache(m, ctx_.Device());
this->ValidateDMatrix(m.get(), false);
this->PredictRaw(m.get(), &predt, false, 0, 0);
this->PredictRaw(m.get(), predt.get(), false, 0, 0);

auto &out = output_predictions_.Cache(m, ctx_.Device()).predictions;
out.Resize(predt.predictions.Size());
out.Copy(predt.predictions);
auto &out = output_predictions_.Cache(m, ctx_.Device())->predictions;
out.Resize(predt->predictions.Size());
out.Copy(predt->predictions);

obj_->EvalTransform(&out);
for (auto& ev : metrics_) {
Expand Down Expand Up @@ -1395,12 +1393,12 @@ class LearnerImpl : public LearnerIO {
} else if (pred_leaf) {
gbm_->PredictLeaf(data.get(), out_preds, layer_begin, layer_end);
} else {
auto& prediction = prediction_container_.Cache(data, ctx_.Device());
this->PredictRaw(data.get(), &prediction, training, layer_begin, layer_end);
auto predt = prediction_container_.Cache(data, ctx_.Device());
this->PredictRaw(data.get(), predt.get(), training, layer_begin, layer_end);
// Copy the prediction cache to output prediction. out_preds comes from C API
out_preds->SetDevice(ctx_.Device());
out_preds->Resize(prediction.predictions.Size());
out_preds->Copy(prediction.predictions);
out_preds->Resize(predt->predictions.Size());
out_preds->Copy(predt->predictions);
if (!output_margin) {
obj_->PredTransform(out_preds);
}
Expand Down

0 comments on commit 191b219

Please sign in to comment.