This repository has been archived by the owner on Oct 13, 2021. It is now read-only.

Unify the batch size dim in the model conversion. (#171)
wenbingl authored Jul 26, 2019
1 parent be1c504 commit ffba640
Showing 6 changed files with 29 additions and 33 deletions.
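
The gist of the commit: the converter previously threaded a module-level DEFAULT_BATCH_SIZE of None through parsing, leaving every dynamic batch dimension unnamed. It now assigns one shared symbolic dimension, 'N', to any model input whose leading dimension is unknown, so converted models expose a single, unified batch axis. A minimal sketch of the observable effect (not part of the commit; the toy model is hypothetical, and a keras2onnx checkout from this era with its Keras/TF 1.x dependencies is assumed):

    import keras
    import keras2onnx

    # Any model with an unspecified batch dimension will do.
    model = keras.models.Sequential([keras.layers.Dense(10, input_shape=(4,))])
    onnx_model = keras2onnx.convert_keras(model, name='toy')

    # After this change, the leading dimension should be exported as the
    # symbolic dim_param 'N' rather than being left unset.
    dim0 = onnx_model.graph.input[0].type.tensor_type.shape.dim[0]
    print(dim0.dim_param)  # expected: 'N'
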
keras2onnx/common/__init__.py (1 addition, 1 deletion)
@@ -3,7 +3,7 @@
 # Licensed under the MIT License. See License.txt in the project root for
 # license information.
 ###############################################################################
-from .utils import with_variable
+from .utils import with_variable, get_default_batch_size
 from .utils import k2o_logger, set_logger_level
 from .cvtfunc import cvtfunc
 from .intop import Operator
keras2onnx/common/interim.py (2 additions, 1 deletion)
@@ -166,7 +166,8 @@ def bind_all_ops(self):
         for op_ in oplist:
             setattr(self, op_, functools.partial(self.add_node, op_))

-    def _make_value_info(self, variable):
+    @staticmethod
+    def _make_value_info(variable):
         value_info = helper.ValueInfoProto()
         value_info.name = variable.full_name
         value_info.type.CopyFrom(variable.type.to_onnx_type())
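
Note: _make_value_info never reads self, so promoting it to a @staticmethod changes nothing for callers; Python dispatches self._make_value_info(...) and Class._make_value_info(...) identically for static methods.
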
keras2onnx/common/utils.py (5 additions, 0 deletions)
@@ -49,6 +49,11 @@ def set_logger_level(lvl):
     logger.setLevel(lvl)


+@with_variable('batch_size')
+def get_default_batch_size():
+    return 'N'
+
+
 def get_producer():
     """
     Internal helper function to return the producer
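
get_default_batch_size returns the symbolic name 'N' instead of a concrete integer, and the with_variable decorator presumably memoizes it so every caller in the pipeline sees the same value. A plausible sketch of such a call-once cache, purely an assumption about what with_variable does (the actual implementation lives elsewhere in common/utils.py and may differ):

    import functools

    def with_variable(name):
        # Hypothetical: compute the wrapped function's value once, stash it
        # under `name`, and return the cached value on every later call.
        def decorator(func):
            @functools.wraps(func)
            def wrapper():
                if not hasattr(wrapper, name):
                    setattr(wrapper, name, func())
                return getattr(wrapper, name)
            return wrapper
        return decorator

    @with_variable('batch_size')
    def get_default_batch_size():
        return 'N'  # symbolic ONNX batch dimension (becomes a dim_param)
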
keras2onnx/main.py (2 additions, 3 deletions)
@@ -12,7 +12,7 @@
 from .topology import convert_topology
 from .common import with_variable
 from .ke2onnx import static_set_ke2onnx_converters
-from .parser import parse_graph, DEFAULT_BATCH_SIZE, tsname_to_node
+from .parser import parse_graph, tsname_to_node
 from .topology import Topology
 from .common.utils import set_logger_level
 from ._builtin import set_converter, tf2onnx_builtin_conversion
@@ -94,7 +94,6 @@ def convert_keras(model, name=None, doc_string='', target_opset=None, channel_fi
         get_tensorboard_writer().add_graph(sess.graph)
     raw_model_container = KerasTfModelContainer(sess.graph, model)
     topology = Topology(raw_model_container,
-                        default_batch_size=DEFAULT_BATCH_SIZE,
                         target_opset=target_opset,
                         custom_op_dict=custom_op_conversions)
     topology.debug_mode = debug_mode
@@ -165,7 +164,7 @@ def convert_tensorflow(frozen_graph_def,

     custom_op_handlers = tf2onnx_builtin_conversion(target_opset)
     if custom_op_conversions:
-        custom_op_handlers += custom_op_conversions
+        custom_op_handlers.update(custom_op_conversions)
     with tf.Session(graph=tf_graph):
         g = tf2onnx.tfonnx.process_tf_graph(tf_graph,
                                             continue_on_error=debug_mode,
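
The last hunk is a standalone bug fix: tf2onnx_builtin_conversion evidently returns a dict of op handlers, and += is not defined for dicts, so merging user-supplied conversions would have raised a TypeError. dict.update performs the intended merge, with the caller's handlers overriding builtins on key collisions:

    # Standalone illustration of the merge semantics; the handler values are
    # placeholders, not real keras2onnx/tf2onnx handlers.
    builtin_handlers = {'Round': 'builtin_round'}
    custom_handlers = {'Round': 'user_round', 'MyOp': 'user_my_op'}
    builtin_handlers.update(custom_handlers)
    print(builtin_handlers)  # {'Round': 'user_round', 'MyOp': 'user_my_op'}
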
keras2onnx/parser.py (14 additions, 22 deletions)
@@ -8,7 +8,7 @@
 from six.moves import queue
 from collections.abc import Iterable
 from .proto import keras
-from .common import k2o_logger
+from .common import k2o_logger, get_default_batch_size
 from .ke2onnx import extract_inbound_nodes, list_input_tensors, list_output_tensors, build_opdict_from_keras
 from .common.data_types import Int32TensorType, Int64TensorType, FloatTensorType, DoubleTensorType, BooleanTensorType
 from .topology import Topology
@@ -17,19 +17,13 @@
 from .wrapper import tf2onnx_wrap, TFNODES


-DEFAULT_BATCH_SIZE = None
-
-
-def _infer_variable_type(tensor, default_batch_size=DEFAULT_BATCH_SIZE):
+def _infer_variable_type(tensor):
     if tensor.shape == tf.TensorShape(None):
         tensor_shape = []
     elif tensor.shape == tf.TensorShape([]):
         tensor_shape = []
     else:
         tensor_shape = [d.value for d in tensor.shape]
-        # Adjust batch size if needed
-        if tensor_shape[0] is None:
-            tensor_shape[0] = default_batch_size

     # Determine the tensor's element type
     tensor_type = tensor.dtype
@@ -99,15 +93,6 @@ def _is_relevant_keras_node(model, node):
     return False


-def _get_tensor_safe(graph, name):
-    try:
-        ts = graph.get_tensor_by_name(name)
-    except KeyError:
-        ts = None
-
-    return ts
-
-
 def _convert_keras_timedistributed(graph, node_list, layer, model, varset):
     """
     This conversion supports timedistributed wrapper partially where the layer itself can be converted by onnx.
@@ -172,6 +157,12 @@ def _convert_keras_timedistributed(graph, node_list, layer, model, varset):
     return operator


+def _adjust_input_batch_size(var_type):
+    if len(var_type.shape) > 0 and var_type.shape[0] is None:
+        var_type.shape = [get_default_batch_size()] + var_type.shape[1:]
+    return var_type
+
+
 def _convert_keras_scope(graph, node_list, layer, model, varset, prefix=None):
     operator = varset.declare_local_operator(type(layer), raw_model=layer, op_name=layer.name)
     operator.nodelist = node_list
@@ -193,7 +184,8 @@ def _convert_keras_scope(graph, node_list, layer, model, varset, prefix=None):
     for i_ in inputs:
         iname = prefix + i_.name
         k2o_logger().debug('input : ' + iname)
-        i0 = varset.get_local_variable_or_declare_one(iname, _infer_variable_type(i_))
+        var_type = _adjust_input_batch_size(_infer_variable_type(i_))
+        i0 = varset.get_local_variable_or_declare_one(iname, var_type)
         operator.add_input(i0)

     if hasattr(layer, 'input_mask') and layer.input_mask is not None:
@@ -623,9 +615,9 @@ def parse_graph(topo, graph, target_opset, output_names):
     """
     Build the node-layer mapper and parse the whole TF graph of Keras Model.
     """
-    keras_op_table = {}
+    keras_layer_ts_map = {}
    if topo.raw_model.model is not None:
-        keras_op_table = \
+        keras_layer_ts_map = \
             {tsname_to_node(nm_): x for (nm_, x) in
              six.iteritems(build_opdict_from_keras(topo.raw_model.model))}

@@ -636,7 +628,7 @@ def parse_graph(topo, graph, target_opset, output_names):
     for idx_, ts_ in enumerate(topo.raw_model.model.inputs):
         op = top_level.declare_local_operator('identity')
         input_ts = topo.raw_model.model.inputs[idx_]
-        var_type = _infer_variable_type(input_ts)
+        var_type = _adjust_input_batch_size(_infer_variable_type(input_ts))
         str_value = input_ts.name
         var0 = None
         if hasattr(topo.raw_model.model, 'input_names'):
@@ -654,4 +646,4 @@
         op.add_output(var1)
         topo.raw_model.add_input_name(str_value)

-    return _parse_graph_scope(graph, keras_op_table, topo, top_level, output_names)
+    return _parse_graph_scope(graph, keras_layer_ts_map, topo, top_level, output_names)
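
Three cleanups land in this file: the per-tensor batch-size rewrite is pulled out of _infer_variable_type and replaced by _adjust_input_batch_size, which is applied only at model inputs; the unused _get_tensor_safe helper is deleted; and keras_op_table is renamed to keras_layer_ts_map, a better description of a map from TF node names to Keras layer entries. A standalone illustration of the new adjustment on plain Python lists (the real function mutates a keras2onnx type object whose .shape holds the dimensions):

    def adjust_batch_dim(shape, default_batch_size='N'):
        # Replace an unknown leading (batch) dimension with the symbolic name.
        if len(shape) > 0 and shape[0] is None:
            return [default_batch_size] + shape[1:]
        return shape

    print(adjust_batch_dim([None, 224, 224, 3]))  # ['N', 224, 224, 3]
    print(adjust_batch_dim([]))                   # [] (unknown rank untouched)
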
keras2onnx/topology.py (5 additions, 6 deletions)
@@ -12,7 +12,7 @@

 class Topology:

-    def __init__(self, model, default_batch_size, target_opset=None, custom_op_dict=None,
+    def __init__(self, model, target_opset=None, custom_op_dict=None,
                  reserved_variable_names=None, reserved_operator_names=None):
         """
         Initialize a Topology object, which is an intermediate representation of a computational graph.
@@ -30,15 +30,14 @@ def __init__(self, model, default_batch_size, target_opset=None, custom_op_dict=
         self.scope_names = set()
         self.variable_name_set = reserved_variable_names if reserved_variable_names is not None else set()
         self.operator_name_set = reserved_operator_names if reserved_operator_names is not None else set()
-        self.default_batch_size = default_batch_size
         self.target_opset = target_opset
         self.debug_mode = False
         self.custom_op_dict = {} if custom_op_dict is None else custom_op_dict

         # This attribute is used in optimizing the graph structure. If root_names is not empty, only the variables
         # specified will be treated as the roots (i.e., set is_fed to True in the beginning of a graph evaluation) of
         # the graph. Specifying all root variables in this list and leaving it empty are equivalent. This attribute
-        # directly affects _initialize_graph_status_for_traversing function and indirectly affects _infer_all_shapes and
+        # directly affects initialize_graph_status_for_traversing function and indirectly affects _infer_all_shapes and
         # _prune functions.
         self.root_names = list()

@@ -67,7 +66,7 @@ def topological_operator_iterator(self):
         If you want to simply go though all operators without considering their topological structure, please use
         another function, unordered_operator_iterator.
         """
-        self._initialize_graph_status_for_traversing()
+        self.initialize_graph_status_for_traversing()
         while not all(operator.is_evaluated for scope in self.scopes for operator in scope.operators.values()):
             is_evaluation_happened = False
             for operator in self.unordered_operator_iterator():
@@ -120,7 +119,7 @@ def _check_structure(self):
         # if len(unused_operators) > 0:
         #     raise RuntimeError('Isolated operators exist: %s' % unused_operators)

-    def _initialize_graph_status_for_traversing(self):
+    def initialize_graph_status_for_traversing(self):
         """
         Initialize the status of all variables and operators for traversing the underline graph
         """
@@ -171,7 +170,7 @@ def convert_topology(topology, model_name, doc_string, target_opset, channel_fir
     :param channel_first_inputs: A list of channel first input.
     :return: a ONNX ModelProto
     """
-    topology._initialize_graph_status_for_traversing()
+    topology.initialize_graph_status_for_traversing()

     container = OnnxObjectContainer(target_opset)

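
Dropping default_batch_size from Topology.__init__ removes the last consumer of the old constant. The companion rename of _initialize_graph_status_for_traversing to initialize_graph_status_for_traversing acknowledges that the method is effectively public: the module-level convert_topology function already calls it from outside the class, which the leading underscore wrongly marked as private.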
