Skip to content

Commit

Permalink
Merge pull request #108 from Alnusjaponica/suppress-warnings
Browse files Browse the repository at this point in the history
Suppress `ExperimentalWarning`s
  • Loading branch information
eukaryo authored Apr 7, 2024
2 parents f4007ca + e69e515 commit 59ce1ff
Show file tree
Hide file tree
Showing 9 changed files with 486 additions and 250 deletions.
76 changes: 51 additions & 25 deletions tests/importance_tests/test_init.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import warnings

import numpy as np
import optuna
from optuna.importance import get_param_importances
Expand Down Expand Up @@ -36,7 +38,8 @@ def objective(trial: Trial) -> tuple[float, float]:
study = create_study(directions=["minimize", "minimize"], storage=storage)
study.optimize(objective, n_trials=3)

with pytest.raises(ValueError):
with pytest.raises(ValueError), warnings.catch_warnings():
warnings.simplefilter("ignore", optuna.exceptions.ExperimentalWarning)
get_param_importances(study, evaluator=ShapleyImportanceEvaluator())


Expand All @@ -63,9 +66,11 @@ def objective(trial: Trial) -> float:
study = create_study(storage=storage, sampler=RandomSampler())
study.optimize(objective, n_trials=3)

param_importance = get_param_importances(
study, evaluator=ShapleyImportanceEvaluator(), normalize=normalize
)
with warnings.catch_warnings():
warnings.simplefilter("ignore", optuna.exceptions.ExperimentalWarning)
param_importance = get_param_importances(
study, evaluator=ShapleyImportanceEvaluator(), normalize=normalize
)

assert isinstance(param_importance, dict)
assert len(param_importance) == 6
Expand Down Expand Up @@ -107,9 +112,11 @@ def objective(trial: Trial) -> float:
study = create_study(storage=storage)
study.optimize(objective, n_trials=10)

param_importance = get_param_importances(
study, evaluator=ShapleyImportanceEvaluator(), params=params, normalize=normalize
)
with warnings.catch_warnings():
warnings.simplefilter("ignore", optuna.exceptions.ExperimentalWarning)
param_importance = get_param_importances(
study, evaluator=ShapleyImportanceEvaluator(), params=params, normalize=normalize
)

assert isinstance(param_importance, dict)
assert len(param_importance) == len(params)
Expand Down Expand Up @@ -143,12 +150,14 @@ def objective(trial: Trial) -> float:
study = create_study(storage=storage)
study.optimize(objective, n_trials=3)

param_importance = get_param_importances(
study,
evaluator=ShapleyImportanceEvaluator(),
target=lambda t: t.params["x1"] + t.params["x2"],
normalize=normalize,
)
with warnings.catch_warnings():
warnings.simplefilter("ignore", optuna.exceptions.ExperimentalWarning)
param_importance = get_param_importances(
study,
evaluator=ShapleyImportanceEvaluator(),
target=lambda t: t.params["x1"] + t.params["x2"],
normalize=normalize,
)

assert isinstance(param_importance, dict)
assert len(param_importance) == 3
Expand All @@ -169,12 +178,14 @@ def objective(trial: Trial) -> float:
def test_get_param_importances_invalid_empty_study() -> None:
study = create_study()

with pytest.raises(ValueError):
with pytest.raises(ValueError), warnings.catch_warnings():
warnings.simplefilter("ignore", optuna.exceptions.ExperimentalWarning)
get_param_importances(study, evaluator=ShapleyImportanceEvaluator())

study.optimize(pruned_objective, n_trials=3)

with pytest.raises(ValueError):
with pytest.raises(ValueError), warnings.catch_warnings():
warnings.simplefilter("ignore", optuna.exceptions.ExperimentalWarning)
get_param_importances(study, evaluator=ShapleyImportanceEvaluator())


Expand All @@ -186,7 +197,8 @@ def objective(trial: Trial) -> float:
study = create_study()
study.optimize(objective, n_trials=1)

with pytest.raises(ValueError):
with pytest.raises(ValueError), warnings.catch_warnings():
warnings.simplefilter("ignore", optuna.exceptions.ExperimentalWarning)
get_param_importances(study, evaluator=ShapleyImportanceEvaluator())


Expand All @@ -202,15 +214,18 @@ def objective(trial: Trial) -> float:
study.optimize(objective, n_trials=3)

# None of the trials with `x2` are completed.
with pytest.raises(ValueError):
with pytest.raises(ValueError), warnings.catch_warnings():
warnings.simplefilter("ignore", optuna.exceptions.ExperimentalWarning)
get_param_importances(study, evaluator=ShapleyImportanceEvaluator(), params=["x2"])

# None of the trials with `x2` are completed. Adding "x1" should not matter.
with pytest.raises(ValueError):
with pytest.raises(ValueError), warnings.catch_warnings():
warnings.simplefilter("ignore", optuna.exceptions.ExperimentalWarning)
get_param_importances(study, evaluator=ShapleyImportanceEvaluator(), params=["x1", "x2"])

# None of the trials contain `x3`.
with pytest.raises(ValueError):
with pytest.raises(ValueError), warnings.catch_warnings():
warnings.simplefilter("ignore", optuna.exceptions.ExperimentalWarning)
get_param_importances(study, evaluator=ShapleyImportanceEvaluator(), params=["x3"])


Expand All @@ -222,7 +237,8 @@ def objective(trial: Trial) -> float:
study = create_study()
study.optimize(objective, n_trials=3)

with pytest.raises(ValueError):
with pytest.raises(ValueError), warnings.catch_warnings():
warnings.simplefilter("ignore", optuna.exceptions.ExperimentalWarning)
get_param_importances(study, evaluator=ShapleyImportanceEvaluator(), params=["x1"])


Expand All @@ -235,7 +251,9 @@ def objective(trial: Trial) -> float:
study = create_study()
study.optimize(objective, n_trials=3)

param_importance = get_param_importances(study, evaluator=ShapleyImportanceEvaluator())
with warnings.catch_warnings():
warnings.simplefilter("ignore", optuna.exceptions.ExperimentalWarning)
param_importance = get_param_importances(study, evaluator=ShapleyImportanceEvaluator())

assert len(param_importance) == 2
assert all([param in param_importance for param in ["x", "y"]])
Expand All @@ -253,14 +271,20 @@ def objective(trial: Trial) -> float:
study = create_study(sampler=RandomSampler(seed=0))
study.optimize(objective, n_trials=3)

evaluator = ShapleyImportanceEvaluator(seed=2)
with warnings.catch_warnings():
warnings.simplefilter("ignore", optuna.exceptions.ExperimentalWarning)
evaluator = ShapleyImportanceEvaluator(seed=2)
param_importance = evaluator.evaluate(study)

evaluator = ShapleyImportanceEvaluator(seed=2)
with warnings.catch_warnings():
warnings.simplefilter("ignore", optuna.exceptions.ExperimentalWarning)
evaluator = ShapleyImportanceEvaluator(seed=2)
param_importance_same_seed = evaluator.evaluate(study)
assert param_importance == param_importance_same_seed

evaluator = ShapleyImportanceEvaluator(seed=3)
with warnings.catch_warnings():
warnings.simplefilter("ignore", optuna.exceptions.ExperimentalWarning)
evaluator = ShapleyImportanceEvaluator(seed=3)
param_importance_different_seed = evaluator.evaluate(study)
assert param_importance != param_importance_different_seed

Expand All @@ -276,7 +300,9 @@ def objective(trial: Trial) -> float:
study = create_study(sampler=RandomSampler(seed=0))
study.optimize(objective, n_trials=3)

evaluator = ShapleyImportanceEvaluator(seed=0)
with warnings.catch_warnings():
warnings.simplefilter("ignore", optuna.exceptions.ExperimentalWarning)
evaluator = ShapleyImportanceEvaluator(seed=0)
param_importance = evaluator.evaluate(study)
param_importance_with_target = evaluator.evaluate(
study,
Expand Down
Loading

0 comments on commit 59ce1ff

Please sign in to comment.