Add a "check_orphaned_nodes" validation flag to ModelBuilder and switch it off
in the XGBoost JSON loader.

Review notes:
* EndTree() used `*itr` as the ID of the orphaned node. `itr` points into the
  `orphaned` flag container, so `*itr` is the flag's value (always true at the
  position std::find() stopped), never the node's index. The added lines now
  derive the ID from the iterator's position. The removed ("-") lines are kept
  exactly as they were, since they describe the pre-image.
* NOTE(review): this copy of the patch had lost its line breaks and
  indentation, and some context lines appear to have lost their template
  arguments (e.g. `std::vector const&`, `std::get>(`). The layout below is
  reconstructed to agree with the hunk line counts; confirm the context lines
  against the target files before applying.
* The `index` lines are those of the original patch; the post-image hash of
  src/model_builder/model_builder.cc no longer matches after the fix above.

diff --git a/include/treelite/model_builder.h b/include/treelite/model_builder.h
index 09c48c89..705bb5b4 100644
--- a/include/treelite/model_builder.h
+++ b/include/treelite/model_builder.h
@@ -39,6 +39,13 @@ class PostProcessorFunc;
  */
 class ModelBuilder {
  public:
+  /*!
+   * \brief Set a flag to control validation behavior.
+   * Currently, we support "check_orphaned_nodes" (defaults to true).
+   * \param flag Name of the flag
+   * \param value Value to set the flag
+   */
+  virtual void SetValidationFlag(std::string const& flag, bool value) = 0;
   /*!
    * \brief Start a new tree
    */
diff --git a/src/model_builder/model_builder.cc b/src/model_builder/model_builder.cc
index b7c2e6ba..5e77c628 100644
--- a/src/model_builder/model_builder.cc
+++ b/src/model_builder/model_builder.cc
@@ -66,7 +66,8 @@ class ModelBuilderImpl : public ModelBuilder {
         current_node_key_{},
         current_node_id_{},
         current_state_{ModelBuilderState::kExpectTree},
-        metadata_initialized_{false} {}
+        metadata_initialized_{false},
+        flag_check_orphaned_nodes_{true} {}
 
   ModelBuilderImpl(Metadata const& metadata, TreeAnnotation const& tree_annotation,
       PostProcessorFunc const& postprocessor, std::vector const& base_scores,
@@ -79,10 +80,17 @@ class ModelBuilderImpl : public ModelBuilder {
         current_node_key_{},
         current_node_id_{},
         current_state_{ModelBuilderState::kExpectTree},
-        metadata_initialized_{false} {
+        metadata_initialized_{false},
+        flag_check_orphaned_nodes_{true} {
     InitializeMetadataImpl(metadata, tree_annotation, postprocessor, base_scores, attributes);
   }
 
+  void SetValidationFlag(std::string const& flag, bool value) override {
+    if (flag == "check_orphaned_nodes") {
+      flag_check_orphaned_nodes_ = value;
+    }
+  }
+
   void StartTree() override {
     CheckStateWithDiagnostic("StartTree()", {ModelBuilderState::kExpectTree}, current_state_);
 
@@ -121,17 +129,20 @@ class ModelBuilderImpl : public ModelBuilder {
         orphaned[cright] = false;
       }
     }
-    auto itr = std::find(orphaned.begin(), orphaned.end(), true);
-    if (itr != orphaned.end()) {
-      auto orphaned_node_id = *itr;
-      for (auto [k, v] : node_id_map_) {
-        if (v == orphaned_node_id) {
-          TREELITE_LOG(FATAL) << "Node with key " << k << " is orphaned -- it cannot be reached "
-                              << "from the root node";
+    if (flag_check_orphaned_nodes_) {
+      auto itr = std::find(orphaned.begin(), orphaned.end(), true);
+      if (itr != orphaned.end()) {
+        // Node ID = position of the flag in `orphaned`; *itr is merely the flag's value (true)
+        auto orphaned_node_id = static_cast<int>(itr - orphaned.begin());
+        for (auto [k, v] : node_id_map_) {
+          if (v == orphaned_node_id) {
+            TREELITE_LOG(FATAL) << "Node with key " << k << " is orphaned -- it cannot be reached "
+                                << "from the root node";
+          }
         }
+        TREELITE_LOG(FATAL) << "Node at index " << orphaned_node_id << " is orphaned "
+                            << "-- it cannot be reached from the root node";
       }
-      TREELITE_LOG(FATAL) << "Node at index " << orphaned_node_id << " is orphaned "
-                          << "-- it cannot be reached from the root node";
     }
 
     auto& trees = std::get>(model_->variant_).trees;
@@ -292,6 +303,7 @@ class ModelBuilderImpl : public ModelBuilder {
   int current_node_id_;  // current node ID (internal)
   ModelBuilderState current_state_;
   bool metadata_initialized_{false};
+  bool flag_check_orphaned_nodes_{true};
 
   void CheckStateWithDiagnostic(std::string const& func_name,
       std::vector const& valid_states, ModelBuilderState actual_state) {
diff --git a/src/model_loader/xgboost_json.cc b/src/model_loader/xgboost_json.cc
index ba8537f5..2ef8879a 100644
--- a/src/model_loader/xgboost_json.cc
+++ b/src/model_loader/xgboost_json.cc
@@ -377,6 +377,7 @@ bool GBTreeModelHandler::StartArray() {
   if (this->should_ignore_upcoming_value()) {
     return push_handler();
   }
+  output.builder->SetValidationFlag("check_orphaned_nodes", false);
   return (push_key_handler>(
               "trees", reg_tree_params, *output.builder)
           || push_key_handler, std::vector>("tree_info", output.tree_info)