Merge pull request #1829 from Sherry-XLL/master
FIX: fix issue #1825 and outdated parameters in LightGBM
Sherry-XLL authored Aug 12, 2023
2 parents 11f4399 + 35636ef commit 2accad5
Showing 3 changed files with 3 additions and 24 deletions.
2 changes: 1 addition & 1 deletion recbole/data/dataset/kg_dataset.py
@@ -390,7 +390,7 @@ def _add_auxiliary_relation(self):
         reverse_kg_data = {
             self.head_entity_field: original_tids,
             self.relation_field: reverse_rels,
-            self.head_entity_field: original_hids,
+            self.tail_entity_field: original_hids,
         }
         reverse_kg_feat = pd.DataFrame(reverse_kg_data)
         self.kg_feat = pd.concat([self.kg_feat, reverse_kg_feat])
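Note on the fix above: the buggy version used self.head_entity_field twice as a dict key, and duplicate keys in a Python dict raise no error, the later value simply overwrites the earlier one. The reverse triples therefore lost their tail-entity column entirely (issue #1825). A minimal standalone sketch of the pitfall, with hypothetical field names:

# Duplicate dict keys are silently collapsed; the last value wins.
reverse_kg_data = {
    "head_id": [10, 11],      # intended: the original tail ids
    "relation_id": [7, 8],
    "head_id": [20, 21],      # duplicate key: overwrites the entry above
}
print(reverse_kg_data)        # {'head_id': [20, 21], 'relation_id': [7, 8]}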
5 changes: 0 additions & 5 deletions recbole/properties/model/lightgbm.yaml
@@ -1,10 +1,8 @@
 # Dataset
 convert_token_to_onehot: False   # (bool) Whether to convert token type features into one-hot form.
 token_num_threhold: 10000        # (int) The threshold of one-hot conversion.
-lgb_silent: False                # (bool) Whether to print messages during construction.
 
 # Train
-lgb_model: ~                     # (file name of stored lgb model or 'Booster' instance)
 lgb_params:                      # (dict) Booster params.
   boosting: gbdt
   num_leaves: 90
@@ -15,7 +13,4 @@ lgb_params:   # (dict) Booster params.
   lambda_l1: 0.1
   metric: ['auc', 'binary_logloss']
   force_row_wise: True
-lgb_learning_rates: ~            # (list, callable or None) List of learning rates or a customized function.
 lgb_num_boost_round: 300         # (int) Number of boosting iterations.
-lgb_early_stopping_rounds: ~     # (int or None) Activates early stopping.
-lgb_verbose_eval: 100            # (bool or int) Requires at least one validation data to print evaluation metrics.
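Most of the keys removed from this file fed arguments that LightGBM 4.0 dropped from its API (silent, learning_rates, early_stopping_rounds, verbose_eval); in current LightGBM the same behaviour is configured through callbacks instead. A rough, non-authoritative mapping, with illustrative values that are not part of the RecBole config:

import lightgbm as lgb

callbacks = [
    lgb.early_stopping(stopping_rounds=50),   # was lgb_early_stopping_rounds
    lgb.log_evaluation(period=100),           # was lgb_verbose_eval
    lgb.reset_parameter(                      # was lgb_learning_rates
        learning_rate=lambda i: 0.1 * (0.99 ** i)
    ),
]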
20 changes: 2 additions & 18 deletions recbole/trainer/trainer.py
@@ -1150,18 +1150,11 @@ def __init__(self, config, model):
         super(LightGBMTrainer, self).__init__(config, model)
 
         self.lgb = __import__("lightgbm")
-        self.boost_model = config["lgb_model"]
-        self.silent = config["lgb_silent"]
 
         # train params
         self.params = config["lgb_params"]
         self.num_boost_round = config["lgb_num_boost_round"]
         self.evals = ()
-        self.early_stopping_rounds = config["lgb_early_stopping_rounds"]
-        self.evals_result = {}
-        self.verbose_eval = config["lgb_verbose_eval"]
-        self.learning_rates = config["lgb_learning_rates"]
-        self.callbacks = None
         self.deval_data = self.deval_label = None
         self.eval_pred = self.eval_true = None

@@ -1174,7 +1167,7 @@ def _interaction_to_lib_datatype(self, dataloader):
             dataset(lgb.Dataset): Data in the form of 'lgb.Dataset'.
         """
         data, label = self._interaction_to_sparse(dataloader)
-        return self.lgb.Dataset(data=data, label=label, silent=self.silent)
+        return self.lgb.Dataset(data=data, label=label)
 
     def _train_at_once(self, train_data, valid_data):
         r"""
@@ -1187,16 +1180,7 @@ def _train_at_once(self, train_data, valid_data):
         self.dvalid = self._interaction_to_lib_datatype(valid_data)
         self.evals = [self.dtrain, self.dvalid]
         self.model = self.lgb.train(
-            self.params,
-            self.dtrain,
-            self.num_boost_round,
-            self.evals,
-            early_stopping_rounds=self.early_stopping_rounds,
-            evals_result=self.evals_result,
-            verbose_eval=self.verbose_eval,
-            learning_rates=self.learning_rates,
-            init_model=self.boost_model,
-            callbacks=self.callbacks,
+            self.params, self.dtrain, self.num_boost_round, self.evals
         )
 
         self.model.save_model(self.temp_file)
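The simplified call keeps only the positional arguments that survive in LightGBM 4.0. If early stopping, metric recording, or periodic logging were wanted again under the new API, the callback-based form would look roughly like the sketch below (synthetic data, illustrative values; not the RecBole implementation):

import lightgbm as lgb
import numpy as np

params = {"objective": "binary", "metric": ["auc", "binary_logloss"]}
dtrain = lgb.Dataset(np.random.rand(200, 5), label=np.random.randint(0, 2, size=200))
dvalid = lgb.Dataset(
    np.random.rand(50, 5), label=np.random.randint(0, 2, size=50), reference=dtrain
)

evals_result = {}                                # filled by record_evaluation
model = lgb.train(
    params,
    dtrain,
    num_boost_round=300,
    valid_sets=[dtrain, dvalid],
    callbacks=[
        lgb.early_stopping(stopping_rounds=50),  # replaces early_stopping_rounds=
        lgb.log_evaluation(period=100),          # replaces verbose_eval=
        lgb.record_evaluation(evals_result),     # replaces evals_result=
    ],
)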
