This repository has been archived by the owner on Aug 16, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 5
/
benchmark.py
78 lines (57 loc) · 2.48 KB
/
benchmark.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import itertools
import math
import time
import optuna
from optuna.importance import FanovaImportanceEvaluator
from optuna_fast_fanova import FanovaImportanceEvaluator as FastFanovaImportanceEvaluator
# Only surface ERROR-level Optuna logs. Otherwise Optuna's lower-level log
# output (presumably per-trial INFO lines from run_optimize) would bury the
# benchmark results printed by main().
optuna.logging.set_verbosity(optuna.logging.ERROR)
def run_optimize(storage, trials, params):
    """Return a study that has at least ``trials`` completed trials.

    The study is persisted in ``storage`` under a name derived from ``trials``
    and ``params`` and opened with ``load_if_exists=True``. Re-running the
    benchmark therefore reuses earlier optimization work and only runs the
    missing trials.

    Args:
        storage: Optuna storage URL, e.g. ``"sqlite:///benchmark-fanova.db"``.
        trials: Target number of completed trials.
        params: Number of float parameters, named ``"0"`` .. ``str(params - 1)``,
            each sampled from ``[-4, 4]``.

    Returns:
        The loaded (and, if needed, topped-up) ``optuna.Study``.
    """
    study_name = f"study-trial{trials}-params{params}"
    study = optuna.create_study(storage=storage, study_name=study_name, load_if_exists=True)

    def objective(trial: optuna.Trial):
        # Sum of squared distances from 2 over every parameter; every parameter
        # contributes to the objective.
        return sum((trial.suggest_float(str(i), -4, 4) - 2) ** 2 for i in range(params))

    # Count only COMPLETE trials. An interrupted earlier run can leave RUNNING
    # or FAIL trials in the database. Those carry no objective value, so
    # counting them would leave the study short of the trial count in its name.
    completed = study.get_trials(deepcopy=False, states=(optuna.trial.TrialState.COMPLETE,))
    remaining = trials - len(completed)
    if remaining > 0:
        study.optimize(objective, n_trials=remaining)
    return study
def print_markdown_table(results):
    """Print benchmark timings as a Markdown table.

    Args:
        results: Iterable of ``(n_trials, n_params, n_trees, s1, s2)`` tuples.
            ``s1`` is the elapsed seconds of Optuna's fANOVA and ``s2`` is the
            elapsed seconds of fast-fanova.

    The fast-fanova column also shows the signed relative change versus
    Optuna's fANOVA: negative means fast-fanova was faster, positive means it
    was slower. ``n/a`` is shown when the baseline time is not positive.
    """
    print("| n_trials | n_params | n_trees | fANOVA (Optuna) | fast-fanova |")
    print("| -------- | -------- | ------- | --------------- | ----------- |")
    for n_trials, n_params, n_trees, s1, s2 in results:
        if s1 > 0:
            # Let the format spec supply the sign, so a slowdown renders as
            # "+x%" instead of a doubled "--x%".
            change = f"{(s2 - s1) / s1 * 100:+.1f}%"
        else:
            # A zero baseline (possible with a coarse clock) has no meaningful ratio.
            change = "n/a"
        print(f"| {n_trials} | {n_params} | {n_trees} | {s1:.3f}s | {s2:.3f}s ({change}) |")
def is_importance_close(a, b, *, rel_tol=1e-9, abs_tol=0.0):
    """Check that two parameter-importance mappings agree.

    Both mappings must contain exactly the same parameter names, and each pair
    of importances must satisfy ``math.isclose``. Parameters are matched by
    name rather than by position, so two evaluators that emit near-tied
    parameters in a different order do not cause a spurious failure.

    Args:
        a: Mapping from parameter name to importance.
        b: Mapping from parameter name to importance, compared against ``a``.
        rel_tol: Relative tolerance forwarded to ``math.isclose``.
        abs_tol: Absolute tolerance forwarded to ``math.isclose``.

    Returns:
        ``True`` when the mappings agree.

    Raises:
        AssertionError: If the parameter names differ or any importance
            differs beyond the tolerances. Raised explicitly rather than via
            ``assert`` so the check still runs under ``python -O``.
    """
    if set(a) != set(b):
        raise AssertionError(f"parameter names differ: {set(a) ^ set(b)}")
    for name, value in a.items():
        other = b[name]
        if not math.isclose(value, other, rel_tol=rel_tol, abs_tol=abs_tol):
            raise AssertionError(f"importance of {name!r} differs: {value} != {other}")
    return True
def _timed_importances(study, evaluator_cls, n_trees):
    """Compute parameter importances with ``evaluator_cls`` and time the call.

    The evaluator is constructed inside the timed region, with ``seed=0`` so
    both implementations use the same seed.

    Returns:
        ``(importances, elapsed_seconds)``.
    """
    # perf_counter is monotonic and high-resolution; time.time() can jump
    # with wall-clock adjustments and is coarse on some platforms.
    start = time.perf_counter()
    importances = optuna.importance.get_param_importances(
        study, evaluator=evaluator_cls(n_trees=n_trees, seed=0)
    )
    return importances, time.perf_counter() - start


def main():
    """Benchmark Optuna's fANOVA evaluator against optuna-fast-fanova.

    For every combination of trial count, parameter count and forest size,
    both evaluators compute parameter importances on the same cached study.
    Their results must agree (``is_importance_close`` raises otherwise), and
    the timings are printed as a Markdown table at the end.
    """
    storage = "sqlite:///benchmark-fanova.db"
    results = []
    for n_trials, n_params in itertools.product([100, 1000], [2, 8, 32]):
        study = run_optimize(storage, n_trials, n_params)
        for n_trees in [32, 64]:
            importances_before, elapsed_before = _timed_importances(
                study, FanovaImportanceEvaluator, n_trees
            )
            importances_after, elapsed_after = _timed_importances(
                study, FastFanovaImportanceEvaluator, n_trees
            )
            # Print both results as plain dicts so the two lines are directly comparable.
            print(
                f"Before: n_trees={n_trees} elapsed={elapsed_before:.3f}\t{dict(importances_before)}"
            )
            print(
                f"After: n_trees={n_trees} elapsed={elapsed_after:.3f}\t{dict(importances_after)}"
            )
            is_importance_close(importances_before, importances_after)
            results.append((n_trials, n_params, n_trees, elapsed_before, elapsed_after))
    print_markdown_table(results)


if __name__ == "__main__":
    main()