-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy patheval_ensemble.py
70 lines (54 loc) · 2.45 KB
/
eval_ensemble.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import argparse
import numpy as np
import os
import sys
import torch
import torch.nn.functional as F
import data
import models
import utils
def main(args):
    """Evaluate each checkpoint on its own and as a cumulative ensemble.

    For every path in ``args.ckpt`` (in the order given), the checkpoint's
    ``'model_state'`` is loaded into a single, reused model instance. The
    model's predictions on ``loaders['test']`` are computed and added to a
    running sum. The function then prints the path, that model's accuracy,
    and the accuracy of the ensemble formed by all checkpoints processed so
    far. The ensemble's prediction is the argmax of the summed predictions.

    Args:
        args: Parsed command-line namespace. It must provide ``dataset``,
            ``data_path``, ``batch_size``, ``num_workers``, ``transform``,
            ``use_test``, ``model`` (the name of an attribute of the local
            ``models`` package) and ``ckpt`` (a list of checkpoint paths).
    """
    torch.backends.cudnn.benchmark = True

    loaders, num_classes = data.loaders(
        args.dataset,
        args.data_path,
        args.batch_size,
        args.num_workers,
        args.transform,
        args.use_test,
    )

    # The named architecture exposes a `base` constructor and default `kwargs`.
    architecture = getattr(models, args.model)
    model = architecture.base(num_classes=num_classes, **architecture.kwargs)
    model.cuda()

    # Running sum of per-model predictions, with one row per test example.
    # NOTE(review): assumes utils.predictions returns an array of shape
    # (len(test dataset), num_classes), e.g. class probabilities -- confirm.
    predictions_sum = np.zeros((len(loaders['test'].dataset), num_classes))

    for path in args.ckpt:
        print(path)
        checkpoint = torch.load(path)
        # The same model object is reused; each checkpoint overwrites its weights.
        model.load_state_dict(checkpoint['model_state'])

        predictions, targets = utils.predictions(loaders['test'], model)
        acc = 100.0 * np.mean(np.argmax(predictions, axis=1) == targets)

        # The ensemble prediction is the argmax of the accumulated scores.
        predictions_sum += predictions
        ens_acc = 100.0 * np.mean(np.argmax(predictions_sum, axis=1) == targets)

        print('Model accuracy: %8.4f. Ensemble accuracy: %8.4f' % (acc, ens_acc))
# Script entry point. When this file is run directly, Python sets __name__ to
# "__main__". The original comparison against "main" was never true, so the
# script silently did nothing.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Ensemble evaluation')

    parser.add_argument('--dataset', type=str, default='CIFAR10', metavar='DATASET',
                        help='dataset name (default: CIFAR10)')
    parser.add_argument('--use_test', action='store_true',
                        help='switches between validation and test set (default: validation)')
    parser.add_argument('--transform', type=str, default='VGG', metavar='TRANSFORM',
                        help='transform name (default: VGG)')
    parser.add_argument('--data_path', type=str, default=None, metavar='PATH',
                        help='path to datasets location (default: None)')
    parser.add_argument('--batch_size', type=int, default=128, metavar='N',
                        help='input batch size (default: 128)')
    # The hyphenated spelling is kept for existing callers. The underscore alias
    # matches the naming of every other flag. Both store into args.num_workers.
    parser.add_argument('--num-workers', '--num_workers', dest='num_workers',
                        type=int, default=4, metavar='N',
                        help='number of workers (default: 4)')
    # Required: main() does getattr(models, args.model), which raises a
    # TypeError when args.model is None.
    parser.add_argument('--model', type=str, required=True, metavar='MODEL',
                        help='model name (required)')
    # action='append' collects every --ckpt occurrence into a list, in order.
    parser.add_argument('--ckpt', type=str, action='append', metavar='CKPT', required=True,
                        help='checkpoint to eval, pass all the models through this parameter')

    args = parser.parse_args()

    main(args)