-
Notifications
You must be signed in to change notification settings - Fork 1
/
hpo.py
61 lines (48 loc) · 2.09 KB
/
hpo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import optuna
from utils import (
compute_nonconformity_score,
get_conformal_set,
coverage_size
)
from average_k import get_average_k_sets
def run_hpo(objective_str, return_obj_fn, direction, init_h_params, cfg):
hpo_cfg = init_h_params.copy()
def objective(trial):
h_params = {}
for param, value in hpo_cfg.items():
if isinstance(value, int):
new_val = trial.suggest_int(param, value-1, value+1)
elif isinstance(value, float):
new_val = trial.suggest_float(param, value/10, value*10, log=True)
else:
raise ValueError(f"Hyper to tune is type {type(value)}, expected int or float")
h_params[param] = new_val
return return_obj_fn(h_params)
if not cfg["verbose"]:
optuna.logging.set_verbosity(optuna.logging.WARNING)
sampler = optuna.samplers.TPESampler(seed=cfg["seed"])
study = optuna.create_study(direction=direction, sampler=sampler)
study.optimize(objective, n_trials=cfg["hpo_iterations"])
best_value = study.best_trial.values
best_h_params = study.best_params
print(f"Result of HPO tuning for {objective_str} is {best_value}")
return best_h_params
def run_hpo_conformal(logits_calib, targets_calib, logits_val, targets_val, used_labels, init_h_params, cfg, conformal_category, bins_calib=None, bins_test=None):
# Define tuning objective
def return_size(h_params):
non_conf_scores_calib, non_conf_scores_val = compute_nonconformity_score(
logits_calib, targets_calib, logits_val, h_params, cfg
)
conformal_preds_val = get_conformal_set(
non_conf_scores_calib,
non_conf_scores_val,
labels=used_labels,
confidence=1-cfg["alpha"],
conformal_category=conformal_category,
bins_cal=bins_calib,
bins_test=bins_test
)
_, size = coverage_size(conformal_preds_val, targets_val)
return size
best_h_params = run_hpo("set size", return_size, "minimize", init_h_params, cfg)
return best_h_params