forked from mynlp/cst_captioning
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathstandalize_format.py
195 lines (162 loc) · 5.36 KB
/
standalize_format.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
"""
Convert video-captioning dataset metadata (yt2t, msrvtt2016, msrvtt2017,
tvvtt) to a standard JSON format
"""
import os
import json
import argparse
import string
import itertools
import logging
from datetime import datetime
# Module-level logger used by the functions in this file; its level and
# message format are configured by logging.basicConfig in the __main__ block.
logger = logging.getLogger(__name__)
def standalize_yt2t(input_file):
    """
    Convert a tab-separated yt2t caption file to the standard JSON layout.

    Use data splits provided by the NAACL2015 paper.

    Each non-blank line of ``input_file`` must hold at least two
    tab-separated fields: a video key followed by one caption. Video keys
    are assumed to look like ``vid<N>``: the first three characters are
    dropped and the remainder is parsed as the integer video id.

    Args:
        input_file: path to the tab-separated caption file.

    Returns:
        A dict with the keys ``'info'`` (empty dict), ``'videos'`` (one entry
        per distinct video key, in order of first appearance) and
        ``'captions'`` (one entry per caption, numbered from 0).

    Raises:
        IndexError: if a non-blank line has no tab-separated caption field.
        ValueError: if the part of a video key after its first three
            characters is not an integer.
    """
    # Local import: io.open accepts encoding/errors on both Python 2 and 3,
    # unlike the Python 2 builtin open.
    import io

    logger.info('Reading file: %s', input_file)
    # Decode as ASCII and silently drop every other byte. This reproduces
    # what ``unicode(caption, errors='ignore')`` did on Python 2 while also
    # working on Python 3, where ``unicode`` no longer exists.
    with io.open(input_file, encoding='ascii', errors='ignore') as infile:
        rows = [
            line.rstrip('\n').split('\t')
            for line in infile
            if line.strip()  # blank lines carry no caption; skip them
        ]

    logger.info('Building caption dictionary for each video key')
    video_ids = []  # keeps the order in which video keys first appear
    capdict = {}
    for row in rows:
        video_id, caption = row[0], row[1]
        if video_id not in capdict:
            capdict[video_id] = []
            video_ids.append(video_id)
        capdict[video_id].append(caption)

    # create the json blob
    videos = []
    captions = []
    counter = itertools.count()  # caption ids are global across all videos
    for video_id in video_ids:
        vid = int(video_id[3:])
        videos.append({
            'category': 'unknown',
            'video_id': video_id,
            'id': vid,
            'start_time': -1,
            'end_time': -1,
            'url': '',
        })
        for caption in capdict[video_id]:
            captions.append({
                'id': next(counter),
                'video_id': vid,
                'caption': caption,
            })

    return {'info': {}, 'videos': videos, 'captions': captions}
def standalize_msrvtt(
        input_file,
        dataset='msrvtt2016',
        split='train',
        val2016_json=None):
    """
    Convert an MSR-VTT annotation file to the standard JSON layout.

    Supports both msrvtt2016 and msrvtt2017
    There is no official train/val set in MSRVTT2017:
    -> train2017 = train2016 + test2016
    -> val2017 = val2016

    Args:
        input_file: path to a JSON file with the keys ``'info'``,
            ``'videos'`` (entries with ``'video_id'``, ``'id'``, ``'split'``)
            and ``'sentences'`` (entries with ``'sen_id'``, ``'video_id'``,
            ``'caption'``).
        dataset: ``'msrvtt2016'`` or ``'msrvtt2017'``.
        split: split to extract; ``'val'`` is accepted as an alias of
            ``'validate'``.
        val2016_json: path to the msrvtt2016 annotation file. Only read, and
            then required, when ``dataset == 'msrvtt2017'`` and
            ``split == 'train'``.

    Returns:
        A dict with the keys ``'info'`` (copied from the input),
        ``'videos'`` (the selected video entries) and ``'captions'`` (dicts
        with ``'id'``, ``'video_id'`` and ``'caption'`` for the selected
        videos only).

    Raises:
        ValueError: if the msrvtt2017 training split is requested without
            ``val2016_json``.
    """
    # Use the function's own parameters rather than the script-level ``args``
    # namespace, so the function also works when imported from other code.
    with open(input_file) as infile:
        info = json.load(infile)

    # Videos are matched against the label 'validate' below, so translate the
    # shorter alias first.
    if split == 'val':
        split = 'validate'

    out = {'info': info['info']}

    if dataset == 'msrvtt2017' and split == 'train':
        if val2016_json is None:
            raise ValueError(
                'val2016_json is required for the msrvtt2017 train split')
        # loading all training videos and removing those that are in the
        # val2016 set
        logger.info('Loading val2016 info: %s', val2016_json)
        with open(val2016_json) as infile:
            info2016 = json.load(infile)
        val2016_keys = set(
            v['video_id'] for v in info2016['videos']
            if v['split'] == 'validate')
        out['videos'] = [
            v for v in info['videos'] if v['video_id'] not in val2016_keys]
    else:
        out['videos'] = [v for v in info['videos'] if v['split'] == split]

    # Map string video keys to integer ids; also used to drop sentences that
    # belong to videos outside the selected split.
    id_by_key = {v['video_id']: v['id'] for v in out['videos']}
    out['captions'] = [
        {
            'id': c['sen_id'],
            'video_id': id_by_key[c['video_id']],
            'caption': c['caption'],
        }
        for c in info['sentences']
        if c['video_id'] in id_by_key
    ]
    return out
def standalize_tvvtt(input_file, split='train2016'):
    """
    Standardize TRECVID V2T task metadata.

    Read from a metadata file generated in the v2t2017 project
    Basically there is no split in the v2t dataset,
    Just consider each provided set as an independent dataset

    Args:
        input_file: path to a JSON file whose top-level keys are set names
            (e.g. ``'train2016'``); each set holds a ``'videos'`` list and a
            ``'captions'`` list.
        split: either a short alias (``'train'``, ``'val'``, ``'test'``) or
            the set name itself as it appears in the file.

    Returns:
        A dict with the keys ``'info'`` (empty dict), ``'videos'`` (one entry
        per video of the chosen set) and ``'captions'`` (the set's captions,
        passed through unchanged).

    Raises:
        KeyError: if the resolved set name is not present in the file.
    """
    split_mapping = {
        'train': 'train2016',
        'val': 'test2016',
        'test': 'test2017'
    }
    # Translate aliases but let real set names through untouched; a plain
    # ``split_mapping[split]`` rejected this function's own default value.
    split = split_mapping.get(split, split)

    logger.info('Loading file: %s, split: %s', input_file, split)
    with open(input_file) as infile:
        info = json.load(infile)[split]

    videos = [
        {
            'category': 'unknown',
            'video_id': str(v),
            'id': v,
            'start_time': -1,
            'end_time': -1,
            'url': '',
        }
        for v in info['videos']
    ]
    return {'info': {}, 'videos': videos, 'captions': info['captions']}
if __name__ == "__main__":
    # Command-line entry point: parse the arguments, run the converter that
    # matches --dataset and write the result to output_json.
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s:%(levelname)s: %(message)s')

    parser = argparse.ArgumentParser()
    parser.add_argument('input_file', type=str, help='')
    parser.add_argument('output_json', type=str, help='')
    # Default to 'train' so the converters never receive None, which would
    # select no msrvtt videos and raise KeyError for tvvtt.
    parser.add_argument('--split', type=str, default='train', help='')
    parser.add_argument(
        '--dataset',
        type=str,
        default='yt2t',
        choices=[
            'yt2t',
            'msrvtt2016',
            'msrvtt2017',
            'tvvtt'],
        help='Choose dataset')
    parser.add_argument(
        '--val2016_json',
        type=str,
        help='use valset from msrvtt2016 contest')

    args = parser.parse_args()
    logger.info('Input arguments: %s', args)

    start = datetime.now()

    if args.dataset == 'msrvtt2016':
        out = standalize_msrvtt(
            args.input_file,
            dataset=args.dataset,
            split=args.split)
    elif args.dataset == 'msrvtt2017':
        out = standalize_msrvtt(
            args.input_file,
            dataset=args.dataset,
            split=args.split,
            val2016_json=args.val2016_json)
    elif args.dataset == 'yt2t':
        out = standalize_yt2t(args.input_file)
    elif args.dataset == 'tvvtt':
        out = standalize_tvvtt(args.input_file, split=args.split)
    else:
        # Defensive only: argparse ``choices`` already rejects other values.
        raise ValueError('Unknown dataset: %s' % args.dataset)

    # dirname is '' when output_json is a bare file name; os.makedirs('')
    # would raise, so only create a directory when there is one to create.
    output_dir = os.path.dirname(args.output_json)
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir)

    with open(args.output_json, 'w') as outfile:
        json.dump(out, outfile)

    logger.info('Time: %s', datetime.now() - start)