-
Notifications
You must be signed in to change notification settings - Fork 8
/
test.py
87 lines (64 loc) · 2.41 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
import argparse
from tqdm import tqdm
import numpy as np
import cv2
from config import cfg
import torch
from base import Tester
from utils.vis import vis_keypoints
import torch.backends.cudnn as cudnn
from utils.transforms import flip
import time
def parse_args():
    """Parse the command-line options of the evaluation script.

    Options:
        --gpu: GPU ids, stored as ``gpu_ids`` (default ``'6,7'``). Either a
            comma-separated list, which is kept as given, or an inclusive
            range written ``'start-end'``, which is expanded so that
            ``'0-3'`` becomes ``'0,1,2,3'``.
        --test_set: stored as ``test_set`` (default ``'test'``).
        --test_epoch: stored as ``test_epoch`` (default ``'0'``).

    Returns:
        argparse.Namespace: the parsed options; all three values are strings.

    Raises:
        SystemExit: raised by ``parser.error`` (exit status 2) when
            ``--gpu`` is empty or its range form is not exactly two
            integers separated by ``-``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', type=str, default='6,7', dest='gpu_ids')
    parser.add_argument('--test_set', type=str, default='test', dest='test_set')
    parser.add_argument('--test_epoch', type=str, default='0', dest='test_epoch')
    args = parser.parse_args()

    # Report bad input through argparse instead of `assert`: assertions are
    # removed when Python runs with -O, which would let an empty id through.
    if not args.gpu_ids:
        parser.error("Please set proper gpu ids")

    if '-' in args.gpu_ids:
        try:
            first, last = (int(part) for part in args.gpu_ids.split('-'))
        except ValueError:
            # Covers non-numeric bounds and more/fewer than two bounds.
            parser.error(
                "--gpu range must look like 'start-end', got %r" % args.gpu_ids
            )
        # The range is inclusive on both ends, hence `last + 1`.
        args.gpu_ids = ','.join(str(gpu) for gpu in range(first, last + 1))

    return args
def test():
    """Run the model over the selected split and print the evaluation result.

    Steps:
        1. Parse the command line and pass the GPU ids to ``cfg.set_args``.
        2. Build a ``Tester`` for ``--test_epoch``, then its batch generator
           (for ``--test_set``) and its model.
        3. Under ``torch.no_grad()``, call the model on every batch in
           ``'test'`` mode and collect the ``joint_coord``, ``inv_trans`` and
           ``joint_valid`` outputs as NumPy arrays.
        4. Concatenate the per-batch arrays, hand them to
           ``tester._evaluate`` and print the first returned value together
           with the mean wall-clock time of the model calls.
    """
    args = parse_args()
    cfg.set_args(args.gpu_ids)
    cudnn.benchmark = True

    # Only InterHand2.6M lets the caller choose the split; any other dataset
    # is always evaluated on 'test'.
    if cfg.dataset != 'InterHand2.6M':
        args.test_set = 'test'
    else:
        assert args.test_set, 'Test set is required. Select one of test/val'

    tester = Tester(args.test_epoch)
    tester._make_batch_generator(args.test_set)
    tester._make_model()

    output_keys = ('joint_coord', 'inv_trans', 'joint_valid')
    collected = {key: [] for key in output_keys}
    batch_times = []

    with torch.no_grad():
        for inputs, targets, meta_info in tqdm(tester.batch_generator, ncols=150):
            # Time only the forward call, not the host-side copies below.
            # NOTE(review): if the model runs on CUDA, work may still be in
            # flight when the clock is read - confirm whether a device
            # synchronisation is wanted here.
            tic = time.time()
            out = tester.model(inputs, targets, meta_info, 'test')
            toc = time.time()

            for key in output_keys:
                collected[key].append(out[key].cpu().numpy())
            batch_times.append(toc - tic)

    # Merge the per-batch arrays into one array per output before evaluating.
    preds = {key: np.concatenate(chunks) for key, chunks in collected.items()}
    mpjpe_dict, hand_accuracy, mrrpe = tester._evaluate(preds)
    print(mpjpe_dict)
    print('time per batch is', np.mean(batch_times))
# Script entry point: run the evaluation only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    test()