Skip to content

Commit

Permalink
Merge with microsoft#5964
Browse files Browse the repository at this point in the history
  • Loading branch information
neNasko1 committed Aug 3, 2024
1 parent 48e6b96 commit 26b9859
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 10 deletions.
22 changes: 13 additions & 9 deletions src/io/tree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -416,12 +416,16 @@ std::string Tree::ToJSON() const {
str_buf << "\"num_cat\":" << num_cat_ << "," << '\n';
str_buf << "\"shrinkage\":" << shrinkage_ << "," << '\n';
if (num_leaves_ == 1) {
str_buf << "\"tree_structure\":{";
if (is_linear_) {
str_buf << "\"tree_structure\":{" << "\"leaf_value\":" << leaf_value_[0] << ", " << "\n";
str_buf << LinearModelToJSON(0) << "}" << "\n";
str_buf << "\"leaf_value\":" << leaf_value_[0] << ", " << '\n';
str_buf << "\"leaf_count\":" << leaf_count_[0] << ", " << '\n';
str_buf << LinearModelToJSON(0);
} else {
str_buf << "\"tree_structure\":{" << "\"leaf_value\":" << leaf_value_[0] << "}" << '\n';
str_buf << "\"leaf_value\":" << leaf_value_[0] << ", " << '\n';
str_buf << "\"leaf_count\":" << leaf_count_[0];
}
str_buf << "}" << '\n';
} else {
str_buf << "\"tree_structure\":" << NodeToJSON(0) << '\n';
}
Expand Down Expand Up @@ -731,6 +735,12 @@ Tree::Tree(const char* str, size_t* used_len) {
is_linear_ = false;
}

if (key_vals.count("leaf_count")) {
leaf_count_ = CommonC::StringToArrayFast<int>(key_vals["leaf_count"], num_leaves_);
} else {
leaf_count_.resize(num_leaves_);
}

#ifdef USE_CUDA
is_cuda_tree_ = false;
#endif // USE_CUDA
Expand Down Expand Up @@ -793,12 +803,6 @@ Tree::Tree(const char* str, size_t* used_len) {
leaf_weight_.resize(num_leaves_);
}

if (key_vals.count("leaf_count")) {
leaf_count_ = CommonC::StringToArrayFast<int>(key_vals["leaf_count"], num_leaves_);
} else {
leaf_count_.resize(num_leaves_);
}

if (key_vals.count("decision_type")) {
decision_type_ = CommonC::StringToArrayFast<int8_t>(key_vals["decision_type"], num_leaves_ - 1);
} else {
Expand Down
29 changes: 28 additions & 1 deletion tests/python_package_test/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3853,10 +3853,37 @@ def test_reset_params_works_with_metric_num_class_and_boosting():
assert new_bst.params == expected_params


def test_dump_model_stump():
X, y = load_breast_cancer(return_X_y=True)
# intentionally create a stump (tree with only a root-node)
# using restricted # samples
subidx = random.sample(range(len(y)), 30)
X = X[subidx]
y = y[subidx]

train_data = lgb.Dataset(X, label=y)
params = {
"objective": "binary",
"verbose": -1,
"n_jobs": 1,
}
bst = lgb.train(params, train_data, num_boost_round=5)
dumped_model = bst.dump_model(5, 0)
print(dumped_model)
assert len(dumped_model["tree_info"]) == 1
tree_structure = dumped_model["tree_info"][0]["tree_structure"]
assert "leaf_value" in tree_structure
assert "leaf_count" in tree_structure
assert tree_structure["leaf_count"] == 30


def test_dump_model():
X, y = load_breast_cancer(return_X_y=True)
train_data = lgb.Dataset(X, label=y)
params = {"objective": "binary", "verbose": -1}
params = {
"objective": "binary",
"verbose": -1,
}
bst = lgb.train(params, train_data, num_boost_round=5)
dumped_model = bst.dump_model(5, 0)
dumped_model_str = str(dumped_model)
Expand Down

0 comments on commit 26b9859

Please sign in to comment.