diff --git a/include/xgboost/global_config.h b/include/xgboost/global_config.h index bb6c211f1089..1a3d1b711d78 100644 --- a/include/xgboost/global_config.h +++ b/include/xgboost/global_config.h @@ -16,6 +16,7 @@ namespace xgboost { struct GlobalConfiguration : public XGBoostParameter { std::int32_t verbosity{1}; bool use_rmm{false}; + // This is not a dmlc parameter to avoid conflict with the context class. std::int32_t nthread{0}; DMLC_DECLARE_PARAMETER(GlobalConfiguration) { DMLC_DECLARE_FIELD(verbosity) @@ -24,7 +25,6 @@ struct GlobalConfiguration : public XGBoostParameter { .describe("Flag to print out detailed breakdown of runtime."); DMLC_DECLARE_FIELD(use_rmm).set_default(false).describe( "Whether to use RAPIDS Memory Manager to allocate GPU memory in XGBoost"); - DMLC_DECLARE_FIELD(nthread).set_lower_bound(-1).set_default(0); } }; diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index bb2324fec91b..008c2e4aff67 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -143,7 +143,19 @@ XGB_DLL int XGBSetGlobalConfig(const char* json_str) { xgboost_CHECK_C_ARG_PTR(json_str); Json config{Json::Load(StringView{json_str})}; - for (auto& items : get(config)) { + // handle nthread, it's not a dmlc parameter. + auto& obj = get(config); + auto it = obj.find("nthread"); + if (it != obj.cend()) { + auto nthread = OptionalArg(config, "nthread", Integer::Int{0}); + if (nthread > 0) { + omp_set_num_threads(nthread); + GlobalConfigThreadLocalStore::Get()->nthread = nthread; + } + get(config).erase("nthread"); + } + + for (auto &items : obj) { switch (items.second.GetValue().Type()) { case xgboost::Value::ValueKind::kInteger: { items.second = String{std::to_string(get(items.second))}; @@ -183,10 +195,7 @@ XGB_DLL int XGBSetGlobalConfig(const char* json_str) { } LOG(FATAL) << ss.str() << " }"; } - // The default is 0, we call omp rt only if the default is modified. - if (GlobalConfigThreadLocalStore::Get()->nthread > 0) { - omp_set_num_threads(GlobalConfigThreadLocalStore::Get()->nthread); - } + API_END(); } @@ -220,6 +229,7 @@ XGB_DLL int XGBGetGlobalConfig(const char** json_str) { } } + config["nthread"] = GlobalConfigThreadLocalStore::Get()->nthread; auto& local = *GlobalConfigAPIThreadLocalStore::Get(); Json::Dump(config, &local.ret_str);