forked from qraleq/models
-
Notifications
You must be signed in to change notification settings - Fork 0
/
oid_hierarchical_labels_expansion.py
233 lines (197 loc) · 8.33 KB
/
oid_hierarchical_labels_expansion.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
# Lint as: python2, python3
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""An executable to expand image-level labels, boxes and segments.
The expansion is performed using class hierarchy, provided in JSON file.
The expected file formats are the following:
- for box and segment files: CSV file is expected to have LabelName field
- for image-level labels: CSV file is expected to have LabelName and Confidence
fields
Note that LabelName is the only field rewritten by the expansion; for
image-level labels the Confidence field only selects the expansion direction.
Example usage:
python models/research/object_detection/dataset_tools/\
oid_hierarchical_labels_expansion.py \
--json_hierarchy_file=<path to JSON hierarchy> \
--input_annotations=<input csv file> \
--output_annotations=<output csv file> \
--annotation_type=<1 (for boxes and segments) or 2 (for image-level labels)>
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import json
from absl import app
from absl import flags
import six
# Command-line flags of the expansion tool. All of them default to None; the
# script entry point marks every one of them as required.
# NOTE: adjacent string literals are concatenated verbatim, so each fragment
# that continues on the next line must end with an explicit space.
flags.DEFINE_string(
    'json_hierarchy_file', None,
    'Path to the file containing label hierarchy in JSON format.')
flags.DEFINE_string(
    'input_annotations', None, 'Path to Open Images annotations file '
    '(either bounding boxes, segments or image-level labels).')
flags.DEFINE_string('output_annotations', None, 'Path to the output file.')
flags.DEFINE_integer(
    'annotation_type', None,
    'Type of the input annotations: 1 - boxes or segments, '
    '2 - image-level labels.'
)
FLAGS = flags.FLAGS
def _update_dict(initial_dict, update):
"""Updates dictionary with update content.
Args:
initial_dict: initial dictionary.
update: updated dictionary.
"""
for key, value_list in update.items():
if key in initial_dict:
initial_dict[key].update(value_list)
else:
initial_dict[key] = set(value_list)
def _build_plain_hierarchy(hierarchy, skip_root=False):
  """Flattens a nested label tree into ancestor/descendant lookup tables.

  The tree is walked recursively through the optional 'Subcategory' list of
  every node; each node is identified by its 'LabelName'.

  Args:
    hierarchy: labels hierarchy node (a dictionary as loaded from the JSON
      file) with a 'LabelName' entry and, optionally, a 'Subcategory' list of
      child nodes.
    skip_root: if true, the node passed in is not registered itself (used when
      all classes are collected under a virtual root node). Only applies to
      the top-level call; recursive calls always register their node.

  Returns:
    A tuple (keyed_parent, keyed_child, children) where
      keyed_parent maps a label to the set of all labels below it,
      keyed_child maps a label to the set of all labels above it,
      children is the set of all labels in this subtree (including the node
        itself unless skip_root is set).
  """
  descendants = set()
  parent_to_children = {}
  child_to_parents = {}

  subtrees = hierarchy['Subcategory'] if 'Subcategory' in hierarchy else ()
  for subtree in subtrees:
    sub_parents, sub_children, sub_descendants = _build_plain_hierarchy(
        subtree)
    # A plain dict.update() would overwrite entries: some labels have several
    # parents in the hierarchy, so the per-label sets have to be unioned.
    _update_dict(parent_to_children, sub_parents)
    _update_dict(child_to_parents, sub_children)
    descendants.update(sub_descendants)

  if skip_root:
    return parent_to_children, child_to_parents, descendants

  root_label = hierarchy['LabelName']
  # Snapshot the descendants before the node adds itself to the set.
  parent_to_children[root_label] = copy.deepcopy(descendants)
  descendants.add(root_label)
  # Every label collected so far lies below this node.
  for ancestors in child_to_parents.values():
    ancestors.add(root_label)
  child_to_parents[root_label] = set()
  return parent_to_children, child_to_parents, descendants
class OIDHierarchicalLabelsExpansion(object):
  """Main class to perform labels hierarchical expansion.

  Holds two lookup tables derived from the label hierarchy: one mapping a
  label to all labels above it and one mapping a label to all labels below it.
  Rows are expanded by re-emitting them once per related label.
  """

  def __init__(self, hierarchy):
    """Constructor.

    Args:
      hierarchy: labels hierarchy as JSON object.
    """
    # The top-level node is skipped, so it never shows up as a label itself.
    self._hierarchy_keyed_parent, self._hierarchy_keyed_child, _ = (
        _build_plain_hierarchy(hierarchy, skip_root=True))

  @staticmethod
  def _split_row(csv_row):
    """Splits a CSV row into its fields and its trailing line terminator.

    The terminator is separated first so that the last field never carries a
    newline; otherwise a LabelName in the last column could not be looked up.

    Args:
      csv_row: a single CSV row, with or without a trailing newline.

    Returns:
      A tuple (fields, line_ending): the list of comma-separated fields and
      the (possibly empty) run of trailing CR/LF characters.
    """
    text = six.ensure_str(csv_row)
    body = text.rstrip('\r\n')
    return body.split(','), text[len(body):]

  @staticmethod
  def _relabel(fields, line_ending, labelname_column_index, related_labels):
    """Builds one copy of the row for every label in `related_labels`.

    Args:
      fields: list of row fields; the label field is overwritten in place.
      line_ending: line terminator to append to every generated row.
      labelname_column_index: 0-based index of LabelName column.
      related_labels: iterable of labels to substitute into the row.

    Returns:
      A list of CSV rows, ordered by label so the output is reproducible
      (iterating the set directly depends on string hash randomisation).
    """
    rows = []
    for label in sorted(related_labels):
      fields[labelname_column_index] = label
      rows.append(','.join(fields) + line_ending)
    return rows

  def expand_boxes_or_segments_from_csv(self, csv_row,
                                        labelname_column_index=1):
    """Expands a row containing bounding boxes/segments from CSV file.

    The row is repeated once for every label above its own label in the
    hierarchy.

    Args:
      csv_row: a single row of Open Images released groundtruth file.
      labelname_column_index: 0-based index of LabelName column in CSV file.

    Returns:
      a list of strings (including the initial row) corresponding to the ground
      truth expanded to multiple annotation for evaluation with Open Images
      Challenge 2018/2019 metrics.

    Raises:
      AssertionError: if the row label is not present in the hierarchy.
    """
    # Row header is expected to be the following for boxes:
    # ImageID,LabelName,Confidence,XMin,XMax,YMin,YMax,IsGroupOf
    # Row header is expected to be the following for segments:
    # ImageID,LabelName,ImageWidth,ImageHeight,XMin,XMax,YMin,YMax,
    # IsGroupOf,Mask
    fields, line_ending = self._split_row(csv_row)
    label = fields[labelname_column_index]
    assert label in self._hierarchy_keyed_child, (
        'Label %r is not present in the hierarchy.' % label)
    return [csv_row] + self._relabel(fields, line_ending,
                                     labelname_column_index,
                                     self._hierarchy_keyed_child[label])

  def expand_labels_from_csv(self,
                             csv_row,
                             labelname_column_index=1,
                             confidence_column_index=2):
    """Expands a row containing labels from CSV file.

    A row with Confidence equal to 1 is repeated for every label above its own
    label in the hierarchy; any other row is repeated for every label below
    its own label.

    Args:
      csv_row: a single row of Open Images released groundtruth file.
      labelname_column_index: 0-based index of LabelName column in CSV file.
      confidence_column_index: 0-based index of Confidence column in CSV file.

    Returns:
      a list of strings (including the initial row) corresponding to the ground
      truth expanded to multiple annotation for evaluation with Open Images
      Challenge 2018/2019 metrics.

    Raises:
      AssertionError: if the row label is not present in the hierarchy.
      ValueError: if the Confidence field is not an integer.
    """
    # Row header is expected to be exactly:
    # ImageID,Source,LabelName,Confidence
    # NOTE: with that header LabelName is column 2 and Confidence is column 3,
    # so callers have to pass both indices explicitly; the defaults (1, 2) are
    # kept unchanged for backward compatibility.
    fields, line_ending = self._split_row(csv_row)
    label = fields[labelname_column_index]
    if int(fields[confidence_column_index]) == 1:
      # Positive label: it also holds for every label above it.
      related = self._hierarchy_keyed_child
    else:
      # Non-positive label: it is propagated to every label below it.
      related = self._hierarchy_keyed_parent
    assert label in related, (
        'Label %r is not present in the hierarchy.' % label)
    return [csv_row] + self._relabel(fields, line_ending,
                                     labelname_column_index, related[label])
def main(unused_args):
  """Expands every annotation row of the input CSV into the output CSV.

  Reads the label hierarchy from --json_hierarchy_file, copies the header of
  --input_annotations to --output_annotations and writes each data row
  followed by its hierarchy expansions.

  Args:
    unused_args: positional command-line arguments; ignored.

  Returns:
    -1 if --annotation_type is neither 1 nor 2, None otherwise.

  Raises:
    ValueError: if the header lacks the LabelName column (or the Confidence
      column for image-level labels).
  """
  del unused_args

  with open(FLAGS.json_hierarchy_file) as f:
    hierarchy = json.load(f)
  expansion_generator = OIDHierarchicalLabelsExpansion(hierarchy)

  if FLAGS.annotation_type not in (1, 2):
    print('--annotation_type expected value is 1 or 2.')
    return -1
  labels_file = FLAGS.annotation_type == 2

  with open(FLAGS.input_annotations, 'r') as source:
    with open(FLAGS.output_annotations, 'w') as target:
      header = source.readline()
      target.write(header)
      column_names = header.strip().split(',')
      labelname_column_index = column_names.index('LabelName')
      # The Confidence column only exists in (and matters for) label files.
      confidence_column_index = -1
      if labels_file:
        confidence_column_index = column_names.index('Confidence')
      for line in source:
        if not line.strip():
          # Blank lines (e.g. at the end of the file) carry no annotation.
          continue
        if not line.endswith('\n'):
          # A final line without a terminator would otherwise be glued to its
          # own expansions in the output, since rows are written verbatim.
          line += '\n'
        if labels_file:
          expanded_lines = expansion_generator.expand_labels_from_csv(
              line, labelname_column_index, confidence_column_index)
        else:
          expanded_lines = (
              expansion_generator.expand_boxes_or_segments_from_csv(
                  line, labelname_column_index))
        target.writelines(expanded_lines)
if __name__ == '__main__':
  # Every flag defaults to None, so all of them must be supplied explicitly
  # before main() is allowed to run.
  for required_flag in ('json_hierarchy_file', 'input_annotations',
                        'output_annotations', 'annotation_type'):
    flags.mark_flag_as_required(required_flag)
  app.run(main)