-
Notifications
You must be signed in to change notification settings - Fork 24
/
eval_utils.py
117 lines (97 loc) · 4.44 KB
/
eval_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import torch
import json
import os
import misc.utils as utils
def language_eval(preds, model_id, split):
    """Score generated captions against the MSCOCO val2014 ground truth.

    Filters ``preds`` down to images present in the annotation file, runs the
    COCO caption evaluation toolkit, and caches both the filtered predictions
    and the final scores as JSON under ``eval_results/``.

    Args:
        preds: list of dicts, each with at least 'image_id' and 'caption'.
        model_id: identifier used to name the cache file.
        split: split name (e.g. 'val'), also part of the cache file name.

    Returns:
        dict mapping metric name (e.g. 'CIDEr', 'Bleu_4') to its score.
    """
    import sys
    # The coco-caption toolkit is vendored next to this repo, not pip-installed.
    sys.path.append("coco-caption")
    annFile = 'coco-caption/annotations/captions_val2014.json'
    from pycocotools.coco import COCO
    from pycocoevalcap.eval import COCOEvalCap

    if not os.path.isdir('eval_results'):
        os.mkdir('eval_results')
    cache_path = os.path.join('eval_results/', model_id + '_' + split + '.json')

    coco = COCO(annFile)
    valids = coco.getImgIds()

    # filter results to only those in MSCOCO validation set (will be about a third)
    preds_filt = [p for p in preds if p['image_id'] in valids]
    print('using %d/%d predictions' % (len(preds_filt), len(preds)))
    # Serialize to a temporary json file; the COCO API can only load from disk.
    # fix: use a context manager so the handle is closed (and flushed) before
    # coco.loadRes() re-reads the file — the original leaked the open() handle.
    with open(cache_path, 'w') as f:
        json.dump(preds_filt, f)

    cocoRes = coco.loadRes(cache_path)
    cocoEval = COCOEvalCap(coco, cocoRes)
    cocoEval.params['image_id'] = cocoRes.getImgIds()
    cocoEval.evaluate()

    # create output dictionary of overall scores
    out = {}
    for metric, score in cocoEval.eval.items():
        out[metric] = score

    # Attach each generated caption to its per-image score entry so the cached
    # file is self-contained (scores + the caption that produced them).
    imgToEval = cocoEval.imgToEval
    for p in preds_filt:
        image_id, caption = p['image_id'], p['caption']
        imgToEval[image_id]['caption'] = caption
    with open(cache_path, 'w') as outfile:
        json.dump({'overall': out, 'imgToEval': imgToEval}, outfile)
    return out
def eval_split(model, loader, eval_kwargs=None):
    """Generate captions for one data split and optionally run language eval.

    Iterates the loader over ``split``, samples a caption per image from
    ``model``, and (optionally) scores the predictions with language_eval().

    Args:
        model: captioning model, invoked in 'sample' mode; must expose
            ``done_beams`` when beam search is used.
        loader: dataset loader providing get_batch/reset_iterator/get_vocab.
        eval_kwargs: optional dict of options (verbose, split, beam_size,
            num_images, language_eval, dump_path, dump_images, ...).
            Defaults to an empty dict.

    Returns:
        (predictions, lang_stats): list of {'image_id', 'caption'} dicts and
        the metric dict from language_eval(), or None if lang eval is off.
    """
    # fix: the original used a mutable default argument (eval_kwargs={}),
    # which is shared across calls and is also passed to the model below.
    if eval_kwargs is None:
        eval_kwargs = {}
    verbose = eval_kwargs.get('verbose', True)
    verbose_beam = eval_kwargs.get('verbose_beam', 1)
    num_images = eval_kwargs.get('num_images', eval_kwargs.get('val_images_use', -1))
    split = eval_kwargs.get('split', 'val')
    lang_eval = eval_kwargs.get('language_eval', 0)
    beam_size = eval_kwargs.get('beam_size', 1)

    # Make sure in the evaluation mode
    model.eval()
    loader.reset_iterator(split)

    n = 0
    predictions = []
    while True:
        data = loader.get_batch(split)
        n = n + loader.batch_size

        # forward the model to get generated samples for each image
        tmp = [data['fc_feats'], data['att_feats'], data['att_masks']]
        tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
        fc_feats, att_feats, att_masks = tmp
        # Scene-graph tensors; None entries (absent features) stay None.
        sg_data = {key: data['sg_data'][key] if data['sg_data'][key] is None else torch.from_numpy(data['sg_data'][key]).cuda() for key in data['sg_data']}
        with torch.no_grad():
            seq = model(sg_data, fc_feats, att_feats, att_masks, opt=eval_kwargs, mode='sample')[0].data

        # Print every beam's decoded caption for inspection.
        if beam_size > 1 and verbose_beam:
            for i in range(loader.batch_size):
                print('\n'.join([utils.decode_sequence(loader.get_vocab(), _['seq'].unsqueeze(0))[0] for _ in model.done_beams[i]]))
                print('--' * 10)
        sents = utils.decode_sequence(loader.get_vocab(), seq)

        for k, sent in enumerate(sents):
            entry = {'image_id': data['infos'][k]['id'], 'caption': sent}
            if eval_kwargs.get('dump_path', 0) == 1:
                entry['file_name'] = data['infos'][k]['file_path']
            predictions.append(entry)
            if eval_kwargs.get('dump_images', 0) == 1:
                # dump the raw image to vis/ folder
                # NOTE(review): shell-string copy breaks on quotes in paths;
                # consider shutil.copyfile. Kept as-is to preserve behavior.
                cmd = 'cp "' + os.path.join(eval_kwargs['image_root'], data['infos'][k]['file_path']) + '" vis/imgs/img' + str(len(predictions)) + '.jpg' # bit gross
                print(cmd)
                os.system(cmd)

            if verbose and k % 20 == 0:
                print('image %s: %s' %(entry['image_id'], entry['caption']))

        # if we wrapped around the split or used up val imgs budget then bail
        ix0 = data['bounds']['it_pos_now']
        ix1 = data['bounds']['it_max']
        if num_images != -1:
            ix1 = min(ix1, num_images)
        # Drop the overshoot: the last batch may extend past the image budget.
        for _ in range(n - ix1):
            predictions.pop()

        if verbose:
            # fix: corrected "preformance" typo in the progress message
            print('evaluating validation performance... %d/%d' %(len(predictions), ix1))

        if data['bounds']['wrapped']:
            break
        if num_images >= 0 and n >= num_images:
            break

    lang_stats = None
    if lang_eval == 1:
        lang_stats = language_eval(predictions, eval_kwargs['id'], split)

    # Switch back to training mode
    model.train()
    return predictions, lang_stats