-
Notifications
You must be signed in to change notification settings - Fork 6
/
tuning.py
38 lines (31 loc) · 1.52 KB
/
tuning.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import time
import argparse
import numpy as np
import pickle as pkl
import os
from math import log
from citation import train_regression
from models import get_model
from utils import sgc_precompute, load_citation, set_seed
from args import get_citation_args
import torch
from hyperopt import fmin, tpe, hp, STATUS_OK, Trials
# Command-line configuration for this tuning run.
args = get_citation_args()

# Fix the random seeds (the cuda flag is forwarded to set_seed as well).
set_seed(args.seed, args.cuda)

# Search space for hyperopt: a single hyperparameter, weight decay, sampled
# log-uniformly between 1e-10 and 1e-4 (hp.loguniform takes the bounds in
# log space, hence the log() calls).
space = {
    'weight_decay': hp.loguniform('weight_decay', log(1e-10), log(1e-4)),
}

# Load the dataset selected on the command line. The unpacked names suggest
# an adjacency structure, node features, labels and train/val/test indices
# -- NOTE(review): confirm against utils.load_citation.
adj, features, labels, idx_train, idx_val, idx_test = load_citation(
    args.dataset,
    args.normalization,
    args.cuda,
)

# For the SGC model, replace the loaded features with the output of
# sgc_precompute. This runs once here, before any optimization trial.
if args.model == "SGC":
    features, precompute_time = sgc_precompute(features, adj, args.degree)
def sgc_objective(space):
    """Evaluate one sampled hyperparameter point for hyperopt.

    Builds a fresh model, trains it on the training split with the sampled
    weight decay, prints the result, and reports the negated accuracy so
    that hyperopt's minimizer ends up maximizing accuracy.

    Reads the module-level ``args``, ``features``, ``labels``,
    ``idx_train`` and ``idx_val``.

    Args:
        space: Mapping of sampled hyperparameters. Only
            ``space['weight_decay']`` is read.

    Returns:
        dict: ``{'loss': -acc_val, 'status': STATUS_OK}`` where ``acc_val``
        is the second value returned by ``train_regression`` (presumably the
        validation accuracy -- confirm against ``citation.train_regression``).
    """
    weight_decay = space['weight_decay']

    # A new model is created on every call, so each trial starts from
    # scratch. The number of classes is taken as the largest label value + 1.
    model = get_model(
        args.model,
        features.size(1),
        labels.max().item()+1,
        args.hidden,
        args.dropout,
        args.cuda,
    )
    model, acc_val, _ = train_regression(
        model,
        features[idx_train], labels[idx_train],
        features[idx_val], labels[idx_val],
        args.epochs, weight_decay, args.lr, args.dropout,
    )
    print('weight decay: {:.2e} '.format(weight_decay) + 'accuracy: {:.4f}'.format(acc_val))
    # hyperopt minimizes 'loss', so negate the accuracy.
    return {'loss': -acc_val, 'status': STATUS_OK}
# Run the search: 60 evaluations of sgc_objective, with new points proposed
# by hyperopt's TPE algorithm.
best = fmin(
    sgc_objective,
    space=space,
    algo=tpe.suggest,
    max_evals=60,
)
print("Best weight decay: {:.2e}".format(best["weight_decay"]))

# Save the winning setting under "<model>-tuning/<dataset>.txt".
# Despite the .txt extension the file holds a binary pickle of `best`.
os.makedirs("./{}-tuning".format(args.model), exist_ok=True)
path = '{}-tuning/{}.txt'.format(args.model, args.dataset)
with open(path, 'wb') as f:
    pkl.dump(best, f)