forked from chenxin061/pdarts
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_imagenet.py
102 lines (81 loc) · 3.26 KB
/
test_imagenet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import os
import sys
import numpy as np
import torch
import utils
import glob
import random
import logging
import argparse
import torch.nn as nn
import genotypes
import torch.utils
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torch.backends.cudnn as cudnn
from model import NetworkImageNet as Network
# Class count handed to the network constructor in main().
CLASSES = 1000

# Command-line options for the evaluation script.
parser = argparse.ArgumentParser("imagenet")
parser.add_argument(
    '--data',
    type=str,
    default='../data/imagenet/',
    help='location of the data corpus',
)
parser.add_argument(
    '--batch_size',
    type=int,
    default=128,
    help='batch size',
)
# Parsed as a float; infer() uses it as the modulus against the step counter.
parser.add_argument(
    '--report_freq',
    type=float,
    default=100,
    help='report frequency',
)
parser.add_argument(
    '--init_channels',
    type=int,
    default=48,
    help='num of init channels',
)
parser.add_argument(
    '--layers',
    type=int,
    default=14,
    help='total number of layers',
)
parser.add_argument(
    '--model_path',
    type=str,
    default='../models/imagenet.pth.tar',
    help='path of pretrained model',
)
parser.add_argument(
    '--auxiliary',
    action='store_true',
    default=False,
    help='use auxiliary tower',
)
# Name of an attribute of the `genotypes` module.
parser.add_argument(
    '--arch',
    type=str,
    default='PDARTS',
    help='which architecture to use',
)
args = parser.parse_args()

# Send timestamped INFO-level records to stdout.
log_format = '%(asctime)s %(message)s'
logging.basicConfig(
    format=log_format,
    datefmt='%m/%d %I:%M:%S %p',
    level=logging.INFO,
    stream=sys.stdout,
)
def main():
    """Evaluate a pretrained network on the ImageNet validation split.

    Configuration comes from the module-level ``args``. The function builds
    the network for the genotype named by ``args.arch``, wraps it in
    ``nn.DataParallel``, loads the ``'state_dict'`` entry of the checkpoint at
    ``args.model_path``, runs :func:`infer` over the images found under
    ``<args.data>/val`` and logs the resulting top-1 and top-5 accuracy.

    Raises:
        SystemExit: with status 1 when no CUDA device is available.
        AttributeError: when ``args.arch`` is not an attribute of the
            ``genotypes`` module.
    """
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    cudnn.enabled = True
    logging.info("args = %s", args)

    # Resolve the genotype by attribute lookup instead of eval(): the name
    # comes straight from the command line, and eval() would execute whatever
    # expression was passed. getattr() yields the same object for a plain
    # attribute name and raises the same AttributeError for an unknown one.
    genotype = getattr(genotypes, args.arch)
    model = Network(args.init_channels, CLASSES, args.layers, args.auxiliary, genotype)
    # The checkpoint is loaded into the DataParallel wrapper, so its keys are
    # presumably stored with the "module." prefix -- TODO confirm.
    model = nn.DataParallel(model)
    model = model.cuda()
    model.load_state_dict(torch.load(args.model_path)['state_dict'])

    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()

    # Validation images: resize, centre-crop to 224, convert to tensor, then
    # normalise each channel with the fixed mean/std below.
    validdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    valid_data = dset.ImageFolder(
        validdir,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]))

    valid_queue = torch.utils.data.DataLoader(
        valid_data, batch_size=args.batch_size, shuffle=False, pin_memory=False, num_workers=4)

    # Set on the wrapped network (model.module), not on the DataParallel
    # wrapper; presumably disables drop-path for evaluation -- TODO confirm.
    model.module.drop_path_prob = 0.0
    # The average loss returned by infer() is not reported here.
    valid_acc_top1, valid_acc_top5, _ = infer(valid_queue, model, criterion)
    logging.info('Valid_acc_top1 %f', valid_acc_top1)
    logging.info('Valid_acc_top5 %f', valid_acc_top5)
def infer(valid_queue, model, criterion):
    """Run one evaluation pass and return averaged metrics.

    Args:
        valid_queue: iterable yielding ``(images, labels)`` batches.
        model: callable returning a pair whose first element is the logits;
            the second element is ignored.
        criterion: loss function called as ``criterion(logits, labels)``.

    Returns:
        Tuple ``(top1_avg, top5_avg, loss_avg)`` read from the ``.avg``
        attribute of three ``utils.AvgrageMeter`` instances (presumably
        sample-weighted running means -- verify against ``utils``).
    """
    loss_meter = utils.AvgrageMeter()
    top1_meter = utils.AvgrageMeter()
    top5_meter = utils.AvgrageMeter()
    model.eval()

    for batch_idx, batch in enumerate(valid_queue):
        images, labels = batch
        images = images.cuda()
        labels = labels.cuda()

        # Forward pass and loss are computed without building an autograd graph.
        with torch.no_grad():
            logits, _ = model(images)
            loss = criterion(logits, labels)

        prec1, prec5 = utils.accuracy(logits, labels, topk=(1, 5))
        batch_size = images.size(0)
        # Each meter receives the scalar value together with the batch size.
        for meter, value in (
            (loss_meter, loss),
            (top1_meter, prec1),
            (top5_meter, prec5),
        ):
            meter.update(value.data.item(), batch_size)

        # Periodic progress line: step, average loss, average top-1, average top-5.
        if batch_idx % args.report_freq == 0:
            logging.info(
                'Valid %03d %e %f %f',
                batch_idx, loss_meter.avg, top1_meter.avg, top5_meter.avg,
            )

    return top1_meter.avg, top5_meter.avg, loss_meter.avg
# Run the evaluation only when this file is executed as a script.
if __name__ == "__main__":
    main()