diff --git a/src/io/tree.cpp b/src/io/tree.cpp index ce45d20cf454..fd33335c8260 100644 --- a/src/io/tree.cpp +++ b/src/io/tree.cpp @@ -416,12 +416,16 @@ std::string Tree::ToJSON() const { str_buf << "\"num_cat\":" << num_cat_ << "," << '\n'; str_buf << "\"shrinkage\":" << shrinkage_ << "," << '\n'; if (num_leaves_ == 1) { + str_buf << "\"tree_structure\":{"; if (is_linear_) { - str_buf << "\"tree_structure\":{" << "\"leaf_value\":" << leaf_value_[0] << ", " << "\n"; - str_buf << LinearModelToJSON(0) << "}" << "\n"; + str_buf << "\"leaf_value\":" << leaf_value_[0] << ", " << '\n'; + str_buf << "\"leaf_count\":" << leaf_count_[0] << ", " << '\n'; + str_buf << LinearModelToJSON(0); } else { - str_buf << "\"tree_structure\":{" << "\"leaf_value\":" << leaf_value_[0] << "}" << '\n'; + str_buf << "\"leaf_value\":" << leaf_value_[0] << ", " << '\n'; + str_buf << "\"leaf_count\":" << leaf_count_[0]; } + str_buf << "}" << '\n'; } else { str_buf << "\"tree_structure\":" << NodeToJSON(0) << '\n'; } @@ -731,6 +735,12 @@ Tree::Tree(const char* str, size_t* used_len) { is_linear_ = false; } + if (key_vals.count("leaf_count")) { + leaf_count_ = CommonC::StringToArrayFast(key_vals["leaf_count"], num_leaves_); + } else { + leaf_count_.resize(num_leaves_); + } + #ifdef USE_CUDA is_cuda_tree_ = false; #endif // USE_CUDA @@ -793,12 +803,6 @@ Tree::Tree(const char* str, size_t* used_len) { leaf_weight_.resize(num_leaves_); } - if (key_vals.count("leaf_count")) { - leaf_count_ = CommonC::StringToArrayFast(key_vals["leaf_count"], num_leaves_); - } else { - leaf_count_.resize(num_leaves_); - } - if (key_vals.count("decision_type")) { decision_type_ = CommonC::StringToArrayFast(key_vals["decision_type"], num_leaves_ - 1); } else { diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py index b46526bcfaf6..4ef72888e767 100644 --- a/tests/python_package_test/test_engine.py +++ b/tests/python_package_test/test_engine.py @@ -3986,12 +3986,36 @@ def test_reset_params_works_with_metric_num_class_and_boosting(): assert new_bst.params == expected_params +def test_dump_model_stump(): + X, y = load_breast_cancer(return_X_y=True) + # intentionally create a stump (tree with only a root-node) + # using restricted # samples + subidx = random.sample(range(len(y)), 30) + X = X[subidx] + y = y[subidx] + + train_data = lgb.Dataset(X, label=y) + params = { + "objective": "binary", + "verbose": -1, + "n_jobs": 1, + } + bst = lgb.train(params, train_data, num_boost_round=5) + dumped_model = bst.dump_model(5, 0) + print(dumped_model) + assert len(dumped_model["tree_info"]) == 1 + tree_structure = dumped_model["tree_info"][0]["tree_structure"] + assert "leaf_value" in tree_structure + assert "leaf_count" in tree_structure + assert tree_structure["leaf_count"] == 30 + + def test_dump_model(): X, y = load_breast_cancer(return_X_y=True) train_data = lgb.Dataset(X, label=y) params = { "objective": "binary", - "verbose": -1 + "verbose": -1, } bst = lgb.train(params, train_data, num_boost_round=5) dumped_model_str = str(bst.dump_model(5, 0)) @@ -4000,6 +4024,7 @@ def test_dump_model(): assert "leaf_const" not in dumped_model_str assert "leaf_value" in dumped_model_str assert "leaf_count" in dumped_model_str + params['linear_tree'] = True train_data = lgb.Dataset(X, label=y) bst = lgb.train(params, train_data, num_boost_round=5)