forked from XieBinghui/INSET
-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
30 lines (25 loc) · 983 Bytes
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from model.EquiVSet import EquiVSet
from utils.config import MOONS_CONFIG, GAUSSIAN_CONFIG, AMAZON_CONFIG, CELEBA_CONFIG
if __name__ == "__main__":
    # Script entry point: parse the command-line options, overlay the preset
    # for the chosen dataset, then either train a model or load and test one.
    parser = EquiVSet.get_model_specific_argparser()
    hparams = parser.parse_args()

    # Supported dataset names and the preset merged into hparams for each.
    dataset_presets = {
        'moons': MOONS_CONFIG,
        'gaussian': GAUSSIAN_CONFIG,
        'amazon': AMAZON_CONFIG,
        'celeba': CELEBA_CONFIG,
    }
    dataset = hparams.data_name
    if dataset not in dataset_presets:
        raise ValueError('invalid dataset...')
    # Preset entries replace any same-named attributes already on hparams.
    vars(hparams).update(dataset_presets[dataset])

    model = EquiVSet(hparams)

    if not hparams.train:
        # NOTE(review): evaluation is assumed to belong to the load branch
        # only -- confirm against the original file layout.
        model.load()
        print('Loaded model with: %s' % model.flag_hparams())
        val_score, test_score = model.run_test()
        print('Val: {:8.2f}'.format(val_score))
        print('Test: {:8.2f}'.format(test_score))
    else:
        model.run_training_sessions()