-
Notifications
You must be signed in to change notification settings - Fork 0
/
run_script.py
47 lines (39 loc) · 1.27 KB
/
run_script.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import numpy as np
import pandas as pd
import os
import argparse
from bnn_ensemble import EnsembleMLPCV
from bnn_vi import BayesianMLPCV
from bnn_sghmc import ExactBayesianMLPCV
from sklearn.model_selection import RepeatedKFold
from data_utils import load_dataset
#%%
# Command-line interface; all three options are required.
#   --dataset  : dataset identifier, passed to load_dataset() below
#   --save_dir : root directory for results
#   --n_jobs   : forwarded as n_jobs to every model wrapper below
parser = argparse.ArgumentParser('Testing UQ in simple neural networks')
parser.add_argument('--dataset', type=str, required=True)
parser.add_argument('--save_dir', type=str, required=True)
parser.add_argument('--n_jobs', type=int, required=True)
args = parser.parse_args()

# Results for this run are written to <save_dir>/<dataset>/.
save_path = os.path.join(args.save_dir, args.dataset)
# exist_ok=True replaces the previous exists()-then-makedirs() sequence, which
# could raise FileExistsError if another job created the directory in between
# the check and the creation.
os.makedirs(save_path, exist_ok=True)
#%%
# Load the dataset selected on the command line
# (presumably X holds the inputs and y the targets -- confirm in data_utils).
X, y = load_dataset(args.dataset)

# 10-fold cross-validation repeated twice; the fixed random_state keeps the
# splits identical from run to run.
cv = RepeatedKFold(
    n_splits=10,
    n_repeats=2,
    random_state=1,
)
#%%
# Model wrappers to evaluate. Each one is later asked for its cross-validated
# scores via .cvloss(cv).
tests = [
    BayesianMLPCV(X, y, 'independent', n_jobs=args.n_jobs),  # independent VI
    BayesianMLPCV(X, y, 'lowrank', n_jobs=args.n_jobs),
    EnsembleMLPCV(X, y, n_models=10, n_jobs=args.n_jobs),
    ExactBayesianMLPCV(X, y, n_jobs=args.n_jobs),
]
# Output file stems, in the same order as `tests`.
names = ['vi_ind', 'vi_lowrank', 'ensemble', 'sghmc']

# Iterate names and models in lockstep rather than indexing names[i].
for name, test in zip(names, tests):
    try:
        scores = test.cvloss(cv)
    except Exception as e:
        # Best-effort run: a failing method must not prevent the remaining
        # ones from being evaluated. Report which method failed so the log
        # can be attributed, then move on.
        print('{} failed: {!r}'.format(name, e))
        continue
    # One CSV per method: <save_path>/<name>.csv, written without the index.
    scores = pd.DataFrame(scores)
    scores.to_csv(
        os.path.join(save_path, name + '.csv'), index=False
    )