Skip to content

Commit

Permalink
Optuna to optuna-integration change (#729)
Browse files Browse the repository at this point in the history
  • Loading branch information
maciekmalachowski committed Jul 8, 2024
1 parent 95eda00 commit 4585419
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 3 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ shap>=0.42.1
seaborn>=0.11.1
wordcloud>=1.8.1
category_encoders>=2.2.2
optuna>=2.7.0
optuna-integration=>3.6.0
mljar-scikit-plot>=0.3.11
markdown
typing-extensions
Expand Down
3 changes: 2 additions & 1 deletion supervised/tuner/optuna/lightgbm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import lightgbm as lgb
import numpy as np
import optuna
import optuna_integration
import pandas as pd

from supervised.algorithms.lightgbm import lightgbm_eval_metric, lightgbm_objective
Expand Down Expand Up @@ -134,7 +135,7 @@ def __call__(self, trial):
metric_name = self.eval_metric_name
if metric_name == "custom":
metric_name = self.custom_eval_metric_name
pruning_callback = optuna.integration.LightGBMPruningCallback(
pruning_callback = optuna_integration.LightGBMPruningCallback(
trial, metric_name, "validation"
)
early_stopping_callback = lgb.early_stopping(
Expand Down
3 changes: 2 additions & 1 deletion supervised/tuner/optuna/xgboost.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
import optuna
import optuna_integration
import xgboost as xgb

from supervised.algorithms.registry import (
Expand Down Expand Up @@ -101,7 +102,7 @@ def __call__(self, trial):
if self.num_class is not None:
param["num_class"] = self.num_class
try:
pruning_callback = optuna.integration.XGBoostPruningCallback(
pruning_callback = optuna_integration.XGBoostPruningCallback(
trial, f"validation-{self.eval_metric_name}"
)
bst = xgb.train(
Expand Down

0 comments on commit 4585419

Please sign in to comment.