# Lint as: python2, python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions to export object detection inference graph."""
import os
import tensorflow.compat.v2 as tf
from object_detection.builders import model_builder
from object_detection.core import standard_fields as fields
from object_detection.data_decoders import tf_example_decoder
from object_detection.utils import config_util


def _decode_image(encoded_image_string_tensor):
  """Decodes a single encoded image string into a uint8 image tensor."""
  image_tensor = tf.image.decode_image(encoded_image_string_tensor,
                                       channels=3)
  image_tensor.set_shape((None, None, 3))
  return image_tensor
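

# A minimal usage sketch for the helper above; the file path is hypothetical:
#
#   encoded = tf.io.read_file('/tmp/example.jpg')
#   image = _decode_image(encoded)  # uint8 tensor of shape [None, None, 3]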


def _decode_tf_example(tf_example_string_tensor):
  """Extracts the image tensor from a serialized tf.Example string."""
  tensor_dict = tf_example_decoder.TfExampleDecoder().decode(
      tf_example_string_tensor)
  image_tensor = tensor_dict[fields.InputDataFields.image]
  return image_tensor


class DetectionInferenceModule(tf.Module):
  """Detection Inference Module."""

  def __init__(self, detection_model):
    """Initializes a module for detection.

    Args:
      detection_model: The detection model to use for inference.
    """
    self._model = detection_model

  def _run_inference_on_images(self, image):
    """Casts the image to float32 and runs inference.

    Args:
      image: uint8 Tensor of shape [1, None, None, 3].

    Returns:
      Tensor dictionary holding detections.
    """
    # Convert the model's 0-indexed classes to the 1-indexed label map
    # convention used by the Object Detection API.
    label_id_offset = 1

    image = tf.cast(image, tf.float32)
    image, shapes = self._model.preprocess(image)
    prediction_dict = self._model.predict(image, shapes)
    detections = self._model.postprocess(prediction_dict, shapes)
    classes_field = fields.DetectionResultFields.detection_classes
    detections[classes_field] = (
        tf.cast(detections[classes_field], tf.float32) + label_id_offset)

    for key, val in detections.items():
      detections[key] = tf.cast(val, tf.float32)

    return detections
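

# The dictionary returned by _run_inference_on_images uses the standard
# detection result fields; under typical configs it contains (all cast to
# float32 above):
#
#   detections['detection_boxes']    # [1, max_detections, 4], normalized
#   detections['detection_scores']   # [1, max_detections]
#   detections['detection_classes']  # [1, max_detections], 1-based labels
#   detections['num_detections']     # [1]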


class DetectionFromImageModule(DetectionInferenceModule):
  """Detection Inference Module for image inputs."""

  @tf.function(
      input_signature=[
          tf.TensorSpec(shape=[1, None, None, 3], dtype=tf.uint8)])
  def __call__(self, input_tensor):
    return self._run_inference_on_images(input_tensor)


class DetectionFromFloatImageModule(DetectionInferenceModule):
  """Detection Inference Module for float image inputs."""

  @tf.function(
      input_signature=[
          tf.TensorSpec(shape=[1, None, None, 3], dtype=tf.float32)])
  def __call__(self, input_tensor):
    return self._run_inference_on_images(input_tensor)


class DetectionFromEncodedImageModule(DetectionInferenceModule):
  """Detection Inference Module for encoded image string inputs."""

  @tf.function(input_signature=[tf.TensorSpec(shape=[1], dtype=tf.string)])
  def __call__(self, input_tensor):
    with tf.device('cpu:0'):
      image = tf.map_fn(
          _decode_image,
          elems=input_tensor,
          dtype=tf.uint8,
          parallel_iterations=32,
          back_prop=False)
    return self._run_inference_on_images(image)
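

# Usage sketch for the encoded-string module; the file path is hypothetical:
#
#   encoded = tf.io.read_file('/tmp/example.jpg')
#   detections = module(encoded[tf.newaxis])  # a batch of one encoded string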


class DetectionFromTFExampleModule(DetectionInferenceModule):
  """Detection Inference Module for TF.Example inputs."""

  @tf.function(input_signature=[tf.TensorSpec(shape=[1], dtype=tf.string)])
  def __call__(self, input_tensor):
    with tf.device('cpu:0'):
      image = tf.map_fn(
          _decode_tf_example,
          elems=input_tensor,
          dtype=tf.uint8,
          parallel_iterations=32,
          back_prop=False)
    return self._run_inference_on_images(image)
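

# Sketch of building a serialized tf.Example input for this module. The JPEG
# bytes (encoded_jpeg_bytes) are assumed to be loaded elsewhere;
# 'image/encoded' and 'image/format' are the feature keys TfExampleDecoder
# reads:
#
#   example = tf.train.Example(features=tf.train.Features(feature={
#       'image/encoded': tf.train.Feature(
#           bytes_list=tf.train.BytesList(value=[encoded_jpeg_bytes])),
#       'image/format': tf.train.Feature(
#           bytes_list=tf.train.BytesList(value=[b'jpeg'])),
#   }))
#   detections = module(tf.constant([example.SerializeToString()]))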


DETECTION_MODULE_MAP = {
    'image_tensor': DetectionFromImageModule,
    'encoded_image_string_tensor': DetectionFromEncodedImageModule,
    'tf_example': DetectionFromTFExampleModule,
    'float_image_tensor': DetectionFromFloatImageModule
}


def export_inference_graph(input_type,
                           pipeline_config,
                           trained_checkpoint_dir,
                           output_directory):
  """Exports inference graph for the model specified in the pipeline config.

  This function creates `output_directory` if it does not already exist,
  which will hold a copy of the pipeline config with filename
  `pipeline.config`, and two subdirectories named `checkpoint` and
  `saved_model` (containing the exported checkpoint and SavedModel
  respectively).

  Args:
    input_type: Type of input for the graph. Can be one of ['image_tensor',
      'encoded_image_string_tensor', 'tf_example', 'float_image_tensor'].
    pipeline_config: pipeline_pb2.TrainEvalPipelineConfig proto.
    trained_checkpoint_dir: Path to the directory holding the trained
      checkpoint.
    output_directory: Path to write outputs.

  Raises:
    ValueError: if input_type is invalid.
  """
  output_checkpoint_directory = os.path.join(output_directory, 'checkpoint')
  output_saved_model_directory = os.path.join(output_directory, 'saved_model')

  detection_model = model_builder.build(pipeline_config.model,
                                        is_training=False)

  ckpt = tf.train.Checkpoint(model=detection_model)
  manager = tf.train.CheckpointManager(
      ckpt, trained_checkpoint_dir, max_to_keep=1)
  status = ckpt.restore(manager.latest_checkpoint).expect_partial()

  if input_type not in DETECTION_MODULE_MAP:
    raise ValueError('Unrecognized `input_type`: {}'.format(input_type))
  detection_module = DETECTION_MODULE_MAP[input_type](detection_model)

  # Getting the concrete function traces the graph and forces variables to
  # be constructed; only after this can we save the checkpoint and
  # SavedModel.
  concrete_function = detection_module.__call__.get_concrete_function()
  status.assert_existing_objects_matched()

  exported_checkpoint_manager = tf.train.CheckpointManager(
      ckpt, output_checkpoint_directory, max_to_keep=1)
  exported_checkpoint_manager.save(checkpoint_number=0)

  tf.saved_model.save(detection_module,
                      output_saved_model_directory,
                      signatures=concrete_function)
  config_util.save_pipeline_config(pipeline_config, output_directory)
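

# End-to-end usage sketch. The config path, checkpoint directory, and output
# directory below are assumptions for illustration, not part of this module:
#
#   from google.protobuf import text_format
#   from object_detection.protos import pipeline_pb2
#
#   pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
#   with tf.io.gfile.GFile('/tmp/pipeline.config', 'r') as f:
#     text_format.Merge(f.read(), pipeline_config)
#   export_inference_graph('image_tensor', pipeline_config,
#                          '/tmp/train_dir', '/tmp/exported_model')
#
# The exported SavedModel can then be reloaded for inference:
#
#   loaded = tf.saved_model.load('/tmp/exported_model/saved_model')
#   detections = loaded(tf.zeros([1, 320, 320, 3], dtype=tf.uint8))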