diff --git a/include/LightGBM/cuda/cuda_tree.hpp b/include/LightGBM/cuda/cuda_tree.hpp index e2836baa2be5..7ab06190481b 100644 --- a/include/LightGBM/cuda/cuda_tree.hpp +++ b/include/LightGBM/cuda/cuda_tree.hpp @@ -77,7 +77,7 @@ class CUDATree : public Tree { const data_size_t* used_data_indices, data_size_t num_data, double* score) const override; - inline void AsConstantTree(double val) override; + inline void AsConstantTree(double val, int count) override; const int* cuda_leaf_parent() const { return cuda_leaf_parent_; } diff --git a/include/LightGBM/tree.h b/include/LightGBM/tree.h index 0c4a41f46a87..c28ddd140c48 100644 --- a/include/LightGBM/tree.h +++ b/include/LightGBM/tree.h @@ -228,13 +228,14 @@ class Tree { shrinkage_ = 1.0f; } - virtual inline void AsConstantTree(double val) { + virtual inline void AsConstantTree(double val, int count = 0) { num_leaves_ = 1; shrinkage_ = 1.0f; leaf_value_[0] = val; if (is_linear_) { leaf_const_[0] = val; } + leaf_count_[0] = count; } /*! \brief Serialize this object to string*/ @@ -563,7 +564,7 @@ inline void Tree::Split(int leaf, int feature, int real_feature, leaf_parent_[leaf] = new_node_idx; leaf_parent_[num_leaves_] = new_node_idx; // save current leaf value to internal node before change - internal_weight_[new_node_idx] = leaf_weight_[leaf]; + internal_weight_[new_node_idx] = left_weight + right_weight; internal_value_[new_node_idx] = leaf_value_[leaf]; internal_count_[new_node_idx] = left_cnt + right_cnt; leaf_value_[leaf] = std::isnan(left_value) ? 0.0f : left_value; diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index af4d757f480b..90b01f25fa80 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -3913,7 +3913,7 @@ def _get_split_feature( return feature_name def _is_single_node_tree(tree: Dict[str, Any]) -> bool: - return set(tree.keys()) == {"leaf_value"} + return set(tree.keys()) == {"leaf_value", "leaf_count"} # Create the node record, and populate universal data members node: Dict[str, Union[int, str, None]] = OrderedDict() diff --git a/src/boosting/gbdt.cpp b/src/boosting/gbdt.cpp index 937b44fcc8aa..f966a275ae8b 100644 --- a/src/boosting/gbdt.cpp +++ b/src/boosting/gbdt.cpp @@ -419,7 +419,10 @@ bool GBDT::TrainOneIter(const score_t* gradients, const score_t* hessians) { score_updater->AddScore(init_scores[cur_tree_id], cur_tree_id); } } - new_tree->AsConstantTree(init_scores[cur_tree_id]); + new_tree->AsConstantTree(init_scores[cur_tree_id], num_data_); + } else { + // extend init_scores with zeros + new_tree->AsConstantTree(0, num_data_); } } // add model diff --git a/src/boosting/rf.hpp b/src/boosting/rf.hpp index e6101dc30a39..3eb065b30f2d 100644 --- a/src/boosting/rf.hpp +++ b/src/boosting/rf.hpp @@ -168,7 +168,7 @@ class RF : public GBDT { output = init_scores_[cur_tree_id]; } } - new_tree->AsConstantTree(output); + new_tree->AsConstantTree(output, num_data_); MultiplyScore(cur_tree_id, (iter_ + num_init_iteration_)); UpdateScore(new_tree.get(), cur_tree_id); MultiplyScore(cur_tree_id, 1.0 / (iter_ + num_init_iteration_ + 1)); diff --git a/src/io/cuda/cuda_tree.cpp b/src/io/cuda/cuda_tree.cpp index 923e51961e0b..c5dee89ca3af 100644 --- a/src/io/cuda/cuda_tree.cpp +++ b/src/io/cuda/cuda_tree.cpp @@ -330,9 +330,10 @@ void CUDATree::SyncLeafOutputFromCUDAToHost() { CopyFromCUDADeviceToHost(leaf_value_.data(), cuda_leaf_value_, leaf_value_.size(), __FILE__, __LINE__); } -void CUDATree::AsConstantTree(double val) { - Tree::AsConstantTree(val); +void CUDATree::AsConstantTree(double val, int count) { + Tree::AsConstantTree(val, count); CopyFromHostToCUDADevice(cuda_leaf_value_, &val, 1, __FILE__, __LINE__); + CopyFromHostToCUDADevice(cuda_leaf_count_, &count, 1, __FILE__, __LINE__); } } // namespace LightGBM diff --git a/src/io/cuda/cuda_tree.cu b/src/io/cuda/cuda_tree.cu index 62020c3a09ae..87abfc1353b4 100644 --- a/src/io/cuda/cuda_tree.cu +++ b/src/io/cuda/cuda_tree.cu @@ -94,7 +94,7 @@ __global__ void SplitKernel( // split information split_gain[new_node_index] = static_cast(cuda_split_info->gain); } else if (thread_index == 4) { // save current leaf value to internal node before change - internal_weight[new_node_index] = leaf_weight[leaf_index]; + internal_weight[new_node_index] = cuda_split_info->left_sum_hessians + cuda_split_info->right_sum_hessians; leaf_weight[leaf_index] = cuda_split_info->left_sum_hessians; } else if (thread_index == 5) { internal_value[new_node_index] = leaf_value[leaf_index]; @@ -210,7 +210,7 @@ __global__ void SplitCategoricalKernel( // split information split_gain[new_node_index] = static_cast(cuda_split_info->gain); } else if (thread_index == 4) { // save current leaf value to internal node before change - internal_weight[new_node_index] = leaf_weight[leaf_index]; + internal_weight[new_node_index] = cuda_split_info->left_sum_hessians + cuda_split_info->right_sum_hessians; leaf_weight[leaf_index] = cuda_split_info->left_sum_hessians; } else if (thread_index == 5) { internal_value[new_node_index] = leaf_value[leaf_index]; diff --git a/src/io/tree.cpp b/src/io/tree.cpp index 4312b4f65002..975f09d209df 100644 --- a/src/io/tree.cpp +++ b/src/io/tree.cpp @@ -416,12 +416,15 @@ std::string Tree::ToJSON() const { str_buf << "\"num_cat\":" << num_cat_ << "," << '\n'; str_buf << "\"shrinkage\":" << shrinkage_ << "," << '\n'; if (num_leaves_ == 1) { + str_buf << "\"tree_structure\":{"; + str_buf << "\"leaf_value\":" << leaf_value_[0] << ", " << '\n'; if (is_linear_) { - str_buf << "\"tree_structure\":{" << "\"leaf_value\":" << leaf_value_[0] << ", " << "\n"; - str_buf << LinearModelToJSON(0) << "}" << "\n"; + str_buf << "\"leaf_count\":" << leaf_count_[0] << ", " << '\n'; + str_buf << LinearModelToJSON(0); } else { - str_buf << "\"tree_structure\":{" << "\"leaf_value\":" << leaf_value_[0] << "}" << '\n'; + str_buf << "\"leaf_count\":" << leaf_count_[0]; } + str_buf << "}" << '\n'; } else { str_buf << "\"tree_structure\":" << NodeToJSON(0) << '\n'; } @@ -731,6 +734,12 @@ Tree::Tree(const char* str, size_t* used_len) { is_linear_ = false; } + if (key_vals.count("leaf_count")) { + leaf_count_ = CommonC::StringToArrayFast(key_vals["leaf_count"], num_leaves_); + } else { + leaf_count_.resize(num_leaves_); + } + #ifdef USE_CUDA is_cuda_tree_ = false; #endif // USE_CUDA @@ -793,12 +802,6 @@ Tree::Tree(const char* str, size_t* used_len) { leaf_weight_.resize(num_leaves_); } - if (key_vals.count("leaf_count")) { - leaf_count_ = CommonC::StringToArrayFast(key_vals["leaf_count"], num_leaves_); - } else { - leaf_count_.resize(num_leaves_); - } - if (key_vals.count("decision_type")) { decision_type_ = CommonC::StringToArrayFast(key_vals["decision_type"], num_leaves_ - 1); } else { diff --git a/src/treelearner/cuda/cuda_leaf_splits.cpp b/src/treelearner/cuda/cuda_leaf_splits.cpp index 57b5b777c142..803d4674ee48 100644 --- a/src/treelearner/cuda/cuda_leaf_splits.cpp +++ b/src/treelearner/cuda/cuda_leaf_splits.cpp @@ -38,12 +38,14 @@ void CUDALeafSplits::InitValues( const double lambda_l1, const double lambda_l2, const score_t* cuda_gradients, const score_t* cuda_hessians, const data_size_t* cuda_bagging_data_indices, const data_size_t* cuda_data_indices_in_leaf, - const data_size_t num_used_indices, hist_t* cuda_hist_in_leaf, double* root_sum_hessians) { + const data_size_t num_used_indices, hist_t* cuda_hist_in_leaf, + double* root_sum_gradients, double* root_sum_hessians) { cuda_gradients_ = cuda_gradients; cuda_hessians_ = cuda_hessians; cuda_sum_of_gradients_buffer_.SetValue(0); cuda_sum_of_hessians_buffer_.SetValue(0); LaunchInitValuesKernal(lambda_l1, lambda_l2, cuda_bagging_data_indices, cuda_data_indices_in_leaf, num_used_indices, cuda_hist_in_leaf); + CopyFromCUDADeviceToHost(root_sum_gradients, cuda_sum_of_gradients_buffer_.RawData(), 1, __FILE__, __LINE__); CopyFromCUDADeviceToHost(root_sum_hessians, cuda_sum_of_hessians_buffer_.RawData(), 1, __FILE__, __LINE__); SynchronizeCUDADevice(__FILE__, __LINE__); } @@ -53,11 +55,12 @@ void CUDALeafSplits::InitValues( const int16_t* cuda_gradients_and_hessians, const data_size_t* cuda_bagging_data_indices, const data_size_t* cuda_data_indices_in_leaf, const data_size_t num_used_indices, - hist_t* cuda_hist_in_leaf, double* root_sum_hessians, + hist_t* cuda_hist_in_leaf, double* root_sum_gradients, double* root_sum_hessians, const score_t* grad_scale, const score_t* hess_scale) { cuda_gradients_ = reinterpret_cast(cuda_gradients_and_hessians); cuda_hessians_ = nullptr; LaunchInitValuesKernal(lambda_l1, lambda_l2, cuda_bagging_data_indices, cuda_data_indices_in_leaf, num_used_indices, cuda_hist_in_leaf, grad_scale, hess_scale); + CopyFromCUDADeviceToHost(root_sum_gradients, cuda_sum_of_gradients_buffer_.RawData(), 1, __FILE__, __LINE__); CopyFromCUDADeviceToHost(root_sum_hessians, cuda_sum_of_hessians_buffer_.RawData(), 1, __FILE__, __LINE__); SynchronizeCUDADevice(__FILE__, __LINE__); } diff --git a/src/treelearner/cuda/cuda_leaf_splits.hpp b/src/treelearner/cuda/cuda_leaf_splits.hpp index 33a9ea578a1f..c2635346098b 100644 --- a/src/treelearner/cuda/cuda_leaf_splits.hpp +++ b/src/treelearner/cuda/cuda_leaf_splits.hpp @@ -44,14 +44,14 @@ class CUDALeafSplits { const score_t* cuda_gradients, const score_t* cuda_hessians, const data_size_t* cuda_bagging_data_indices, const data_size_t* cuda_data_indices_in_leaf, const data_size_t num_used_indices, - hist_t* cuda_hist_in_leaf, double* root_sum_hessians); + hist_t* cuda_hist_in_leaf, double* root_sum_gradients, double* root_sum_hessians); void InitValues( const double lambda_l1, const double lambda_l2, const int16_t* cuda_gradients_and_hessians, const data_size_t* cuda_bagging_data_indices, const data_size_t* cuda_data_indices_in_leaf, const data_size_t num_used_indices, - hist_t* cuda_hist_in_leaf, double* root_sum_hessians, + hist_t* cuda_hist_in_leaf, double* root_sum_gradients, double* root_sum_hessians, const score_t* grad_scale, const score_t* hess_scale); void InitValues(); diff --git a/src/treelearner/cuda/cuda_single_gpu_tree_learner.cpp b/src/treelearner/cuda/cuda_single_gpu_tree_learner.cpp index 8f8ff15f0715..952ef52f8023 100644 --- a/src/treelearner/cuda/cuda_single_gpu_tree_learner.cpp +++ b/src/treelearner/cuda/cuda_single_gpu_tree_learner.cpp @@ -66,6 +66,7 @@ void CUDASingleGPUTreeLearner::Init(const Dataset* train_data, bool is_constant_ leaf_best_split_default_left_.resize(config_->num_leaves, 0); leaf_num_data_.resize(config_->num_leaves, 0); leaf_data_start_.resize(config_->num_leaves, 0); + leaf_sum_gradients_.resize(config_->num_leaves, 0.0f); leaf_sum_hessians_.resize(config_->num_leaves, 0.0f); if (!boosting_on_cuda_) { @@ -122,6 +123,7 @@ void CUDASingleGPUTreeLearner::BeforeTrain() { cuda_data_partition_->cuda_data_indices(), root_num_data, cuda_histogram_constructor_->cuda_hist_pointer(), + &leaf_sum_gradients_[0], &leaf_sum_hessians_[0], cuda_gradient_discretizer_->grad_scale_ptr(), cuda_gradient_discretizer_->hess_scale_ptr()); @@ -137,6 +139,7 @@ void CUDASingleGPUTreeLearner::BeforeTrain() { cuda_data_partition_->cuda_data_indices(), root_num_data, cuda_histogram_constructor_->cuda_hist_pointer(), + &leaf_sum_gradients_[0], &leaf_sum_hessians_[0]); } leaf_num_data_[0] = root_num_data; @@ -162,6 +165,12 @@ Tree* CUDASingleGPUTreeLearner::Train(const score_t* gradients, const bool track_branch_features = !(config_->interaction_constraints_vector.empty()); std::unique_ptr tree(new CUDATree(config_->num_leaves, track_branch_features, config_->linear_tree, config_->gpu_device_id, has_categorical_feature_)); + // set the root value by hand, as it is not handled by splits + tree->SetLeafOutput(0, CUDALeafSplits::CalculateSplittedLeafOutput( + leaf_sum_gradients_[smaller_leaf_index_], leaf_sum_hessians_[smaller_leaf_index_], + config_->lambda_l1, config_->lambda_l2, config_->path_smooth, + static_cast(num_data_), 0)); + tree->SyncLeafOutputFromHostToCUDA(); for (int i = 0; i < config_->num_leaves - 1; ++i) { global_timer.Start("CUDASingleGPUTreeLearner::ConstructHistogramForLeaf"); const data_size_t num_data_in_smaller_leaf = leaf_num_data_[smaller_leaf_index_]; @@ -293,8 +302,6 @@ Tree* CUDASingleGPUTreeLearner::Train(const score_t* gradients, best_split_info); } - double sum_left_gradients = 0.0f; - double sum_right_gradients = 0.0f; cuda_data_partition_->Split(best_split_info, best_leaf_index_, right_leaf_index, @@ -313,10 +320,10 @@ Tree* CUDASingleGPUTreeLearner::Train(const score_t* gradients, &leaf_data_start_[right_leaf_index], &leaf_sum_hessians_[best_leaf_index_], &leaf_sum_hessians_[right_leaf_index], - &sum_left_gradients, - &sum_right_gradients); + &leaf_sum_gradients_[best_leaf_index_], + &leaf_sum_gradients_[right_leaf_index]); #ifdef DEBUG - CheckSplitValid(best_leaf_index_, right_leaf_index, sum_left_gradients, sum_right_gradients); + CheckSplitValid(best_leaf_index_, right_leaf_index); #endif // DEBUG smaller_leaf_index_ = (leaf_num_data_[best_leaf_index_] < leaf_num_data_[right_leaf_index] ? best_leaf_index_ : right_leaf_index); larger_leaf_index_ = (smaller_leaf_index_ == best_leaf_index_ ? right_leaf_index : best_leaf_index_); @@ -374,6 +381,7 @@ void CUDASingleGPUTreeLearner::ResetConfig(const Config* config) { leaf_best_split_default_left_.resize(config_->num_leaves, 0); leaf_num_data_.resize(config_->num_leaves, 0); leaf_data_start_.resize(config_->num_leaves, 0); + leaf_sum_gradients_.resize(config_->num_leaves, 0.0f); leaf_sum_hessians_.resize(config_->num_leaves, 0.0f); } cuda_histogram_constructor_->ResetConfig(config); @@ -562,9 +570,7 @@ void CUDASingleGPUTreeLearner::SelectFeatureByNode(const Tree* tree) { #ifdef DEBUG void CUDASingleGPUTreeLearner::CheckSplitValid( const int left_leaf, - const int right_leaf, - const double split_sum_left_gradients, - const double split_sum_right_gradients) { + const int right_leaf) { std::vector left_data_indices(leaf_num_data_[left_leaf]); std::vector right_data_indices(leaf_num_data_[right_leaf]); CopyFromCUDADeviceToHost(left_data_indices.data(), @@ -585,9 +591,9 @@ void CUDASingleGPUTreeLearner::CheckSplitValid( sum_right_gradients += host_gradients_[index]; sum_right_hessians += host_hessians_[index]; } - CHECK_LE(std::fabs(sum_left_gradients - split_sum_left_gradients), 1e-6f); + CHECK_LE(std::fabs(sum_left_gradients - leaf_sum_gradients_[left_leaf]), 1e-6f); CHECK_LE(std::fabs(sum_left_hessians - leaf_sum_hessians_[left_leaf]), 1e-6f); - CHECK_LE(std::fabs(sum_right_gradients - split_sum_right_gradients), 1e-6f); + CHECK_LE(std::fabs(sum_right_gradients - leaf_sum_gradients_[right_leaf]), 1e-6f); CHECK_LE(std::fabs(sum_right_hessians - leaf_sum_hessians_[right_leaf]), 1e-6f); } #endif // DEBUG diff --git a/src/treelearner/cuda/cuda_single_gpu_tree_learner.hpp b/src/treelearner/cuda/cuda_single_gpu_tree_learner.hpp index a1ea79efa1a1..b50d3a8884ca 100644 --- a/src/treelearner/cuda/cuda_single_gpu_tree_learner.hpp +++ b/src/treelearner/cuda/cuda_single_gpu_tree_learner.hpp @@ -71,8 +71,7 @@ class CUDASingleGPUTreeLearner: public SerialTreeLearner { #ifdef DEBUG void CheckSplitValid( - const int left_leaf, const int right_leaf, - const double sum_left_gradients, const double sum_right_gradients); + const int left_leaf, const int right_leaf); #endif // DEBUG void RenewDiscretizedTreeLeaves(CUDATree* cuda_tree); @@ -103,6 +102,7 @@ class CUDASingleGPUTreeLearner: public SerialTreeLearner { std::vector leaf_best_split_default_left_; std::vector leaf_num_data_; std::vector leaf_data_start_; + std::vector leaf_sum_gradients_; std::vector leaf_sum_hessians_; int smaller_leaf_index_; int larger_leaf_index_; diff --git a/src/treelearner/serial_tree_learner.cpp b/src/treelearner/serial_tree_learner.cpp index f3a88bd18679..14ede072dc9e 100644 --- a/src/treelearner/serial_tree_learner.cpp +++ b/src/treelearner/serial_tree_learner.cpp @@ -201,6 +201,12 @@ Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians auto tree_ptr = tree.get(); constraints_->ShareTreePointer(tree_ptr); + // set the root value by hand, as it is not handled by splits + tree->SetLeafOutput(0, FeatureHistogram::CalculateSplittedLeafOutput( + smaller_leaf_splits_->sum_gradients(), smaller_leaf_splits_->sum_hessians(), + config_->lambda_l1, config_->lambda_l2, config_->max_delta_step, + BasicConstraint(), config_->path_smooth, static_cast(num_data_), 0)); + // root leaf int left_leaf = 0; int cur_depth = 1; diff --git a/tests/python_package_test/test_dask.py b/tests/python_package_test/test_dask.py index 247f2eb10a54..1929d4ba53c5 100644 --- a/tests/python_package_test/test_dask.py +++ b/tests/python_package_test/test_dask.py @@ -2,7 +2,6 @@ """Tests for lightgbm.dask module""" import inspect -import random import socket from itertools import groupby from os import getenv @@ -1451,21 +1450,27 @@ def test_init_score(task, output, cluster): _, _, _, _, dX, dy, dw, dg = _create_data(objective=task, output=output, group=None) model_factory = task_to_dask_factory[task] + rnd = np.random.RandomState(42) params = {"n_estimators": 1, "num_leaves": 2, "time_out": 5} - init_score = random.random() - size_factor = 1 - if task == "multiclass-classification": - size_factor = 3 # number of classes + num_classes = 3 if task == "multiclass-classification" else 1 if output.startswith("dataframe"): - init_scores = dy.map_partitions(lambda x: pd.DataFrame([[init_score] * size_factor] * x.size)) + init_scores = dy.map_partitions(lambda x: pd.DataFrame(rnd.uniform(size=(x.size, num_classes)))) else: - init_scores = dy.map_blocks(lambda x: np.full((x.size, size_factor), init_score)) + init_scores = dy.map_blocks(lambda x: rnd.uniform(size=(x.size, num_classes))) + model = model_factory(client=client, **params) - model.fit(dX, dy, sample_weight=dw, init_score=init_scores, group=dg) - # value of the root node is 0 when init_score is set - assert model.booster_.trees_to_dataframe()["value"][0] == 0 + model.fit(dX, dy, sample_weight=dw, group=dg) + pred = model.predict(dX, raw_score=True) + + model_init_score = model_factory(client=client, **params) + model_init_score.fit(dX, dy, sample_weight=dw, init_score=init_scores, group=dg) + pred_init_score = model_init_score.predict(dX, raw_score=True) + + # check if init score changes predictions + with pytest.raises(AssertionError): + assert_eq(pred, pred_init_score) def sklearn_checks_to_run(): diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py index 74f6939c8371..d3883a74fe89 100644 --- a/tests/python_package_test/test_engine.py +++ b/tests/python_package_test/test_engine.py @@ -24,6 +24,7 @@ from .utils import ( SERIALIZERS, + assert_all_trees_valid, assert_silent, dummy_obj, load_breast_cancer, @@ -3855,21 +3856,65 @@ def test_reset_params_works_with_metric_num_class_and_boosting(): assert new_bst.params == expected_params -def test_dump_model(): +@pytest.mark.parametrize("linear_tree", [False, True]) +def test_dump_model_stump(linear_tree): X, y = load_breast_cancer(return_X_y=True) - train_data = lgb.Dataset(X, label=y) - params = {"objective": "binary", "verbose": -1} + # intentionally create a stump (tree with only a root-node) + # using restricted # samples + subidx = random.sample(range(len(y)), 30) + + train_data = lgb.Dataset(X[subidx], label=y[subidx]) + params = { + "objective": "binary", + "verbose": -1, + "linear_tree": linear_tree, + } bst = lgb.train(params, train_data, num_boost_round=5) - dumped_model_str = str(bst.dump_model(5, 0)) + dumped_model = bst.dump_model(5, 0) + tree_structure = dumped_model["tree_info"][0]["tree_structure"] + assert len(dumped_model["tree_info"]) == 1 + assert "leaf_value" in tree_structure + assert tree_structure["leaf_count"] == 30 + + +def test_dump_model(): + offset = 100 + X, y = make_synthetic_regression() + train_data = lgb.Dataset(X, label=y + offset) + + params = { + "objective": "regression", + "verbose": -1, + "boost_from_average": True, + } + bst = lgb.train(params, train_data, num_boost_round=5) + dumped_model = bst.dump_model(5, 0) + dumped_model_str = str(dumped_model) assert "leaf_features" not in dumped_model_str assert "leaf_coeff" not in dumped_model_str assert "leaf_const" not in dumped_model_str assert "leaf_value" in dumped_model_str assert "leaf_count" in dumped_model_str - params["linear_tree"] = True + + for tree in dumped_model["tree_info"]: + assert not np.all(tree["tree_structure"]["internal_value"] == 0) + + np.testing.assert_allclose(dumped_model["tree_info"][0]["tree_structure"]["internal_value"], offset, atol=1) + assert_all_trees_valid(dumped_model) + + +def test_dump_model_linear(): + X, y = load_breast_cancer(return_X_y=True) + params = { + "objective": "binary", + "verbose": -1, + "linear_tree": True, + } train_data = lgb.Dataset(X, label=y) bst = lgb.train(params, train_data, num_boost_round=5) - dumped_model_str = str(bst.dump_model(5, 0)) + dumped_model = bst.dump_model(5, 0) + assert_all_trees_valid(dumped_model) + dumped_model_str = str(dumped_model) assert "leaf_features" in dumped_model_str assert "leaf_coeff" in dumped_model_str assert "leaf_const" in dumped_model_str diff --git a/tests/python_package_test/utils.py b/tests/python_package_test/utils.py index 8aacef13f7b8..c8ef7a9fa139 100644 --- a/tests/python_package_test/utils.py +++ b/tests/python_package_test/utils.py @@ -225,3 +225,38 @@ def np_assert_array_equal(*args, **kwargs): if not _numpy_testing_supports_strict_kwarg: kwargs.pop("strict") np.testing.assert_array_equal(*args, **kwargs) + + +def assert_subtree_valid(root): + """Recursively checks the validity of a subtree rooted at `root`. + + Currently it only checks whether weights and counts are consistent between + all parent nodes and their children. + + Parameters + ---------- + root : dict + A dictionary representing the root of the subtree. + It should be produced by dump_model() + + Returns + ------- + tuple + A tuple containing the weight and count of the subtree rooted at `root`. + """ + if "leaf_count" in root: + return (root["leaf_weight"], root["leaf_count"]) + + left_child = root["left_child"] + right_child = root["right_child"] + (l_w, l_c) = assert_subtree_valid(left_child) + (r_w, r_c) = assert_subtree_valid(right_child) + assert np.allclose(root["internal_weight"], l_w + r_w) + assert np.allclose(root["internal_count"], l_c + r_c) + return (root["internal_weight"], root["internal_count"]) + + +def assert_all_trees_valid(model_dump): + for idx, tree in enumerate(model_dump["tree_info"]): + assert tree["tree_index"] == idx + assert_subtree_valid(tree["tree_structure"])