Skip to content

Commit

Permalink
don't use dmlc param
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Jan 18, 2025
1 parent 91c663f commit cc9bd8a
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 6 deletions.
2 changes: 1 addition & 1 deletion include/xgboost/global_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ namespace xgboost {
struct GlobalConfiguration : public XGBoostParameter<GlobalConfiguration> {
std::int32_t verbosity{1};
bool use_rmm{false};
// This is not a dmlc parameter to avoid conflict with the context class.
std::int32_t nthread{0};
DMLC_DECLARE_PARAMETER(GlobalConfiguration) {
DMLC_DECLARE_FIELD(verbosity)
Expand All @@ -24,7 +25,6 @@ struct GlobalConfiguration : public XGBoostParameter<GlobalConfiguration> {
.describe("Flag to print out detailed breakdown of runtime.");
DMLC_DECLARE_FIELD(use_rmm).set_default(false).describe(
"Whether to use RAPIDS Memory Manager to allocate GPU memory in XGBoost");
DMLC_DECLARE_FIELD(nthread).set_lower_bound(-1).set_default(0);
}
};

Expand Down
20 changes: 15 additions & 5 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,19 @@ XGB_DLL int XGBSetGlobalConfig(const char* json_str) {
xgboost_CHECK_C_ARG_PTR(json_str);
Json config{Json::Load(StringView{json_str})};

for (auto& items : get<Object>(config)) {
// handle nthread, it's not a dmlc parameter.
auto& obj = get<Object>(config);
auto it = obj.find("nthread");
if (it != obj.cend()) {
auto nthread = OptionalArg<Integer>(config, "nthread", Integer::Int{0});
if (nthread > 0) {
omp_set_num_threads(nthread);
GlobalConfigThreadLocalStore::Get()->nthread = nthread;
}
get<Object>(config).erase("nthread");
}

for (auto &items : obj) {
switch (items.second.GetValue().Type()) {
case xgboost::Value::ValueKind::kInteger: {
items.second = String{std::to_string(get<Integer const>(items.second))};
Expand Down Expand Up @@ -183,10 +195,7 @@ XGB_DLL int XGBSetGlobalConfig(const char* json_str) {
}
LOG(FATAL) << ss.str() << " }";
}
// The default is 0, we call omp rt only if the default is modified.
if (GlobalConfigThreadLocalStore::Get()->nthread > 0) {
omp_set_num_threads(GlobalConfigThreadLocalStore::Get()->nthread);
}

API_END();
}

Expand Down Expand Up @@ -220,6 +229,7 @@ XGB_DLL int XGBGetGlobalConfig(const char** json_str) {
}
}

config["nthread"] = GlobalConfigThreadLocalStore::Get()->nthread;
auto& local = *GlobalConfigAPIThreadLocalStore::Get();
Json::Dump(config, &local.ret_str);

Expand Down

0 comments on commit cc9bd8a

Please sign in to comment.