diff --git a/mmdnn/conversion/pytorch/pytorch_emitter.py b/mmdnn/conversion/pytorch/pytorch_emitter.py index 86413c12..bed71607 100644 --- a/mmdnn/conversion/pytorch/pytorch_emitter.py +++ b/mmdnn/conversion/pytorch/pytorch_emitter.py @@ -254,7 +254,7 @@ def emit_Pool(self, IR_node): ceil_mode = self.is_ceil_mode(IR_node.get_attr('pads')) # input_node = self._defuse_padding(IR_node, exstr) - code = "{:<15} = F.{}({}, kernel_size={}, stride={}, padding={}, ceil_mode={})".format( + code = "{:<15} = F.{}({}, kernel_size={}, stride={}, padding={}, ceil_mode={}, count_include_pad=False)".format( IR_node.variable_name, pool_name, self.parent_variable_name(IR_node), diff --git a/mmdnn/conversion/tensorflow/tensorflow_frozenparser.py b/mmdnn/conversion/tensorflow/tensorflow_frozenparser.py index d89227da..8af486ee 100644 --- a/mmdnn/conversion/tensorflow/tensorflow_frozenparser.py +++ b/mmdnn/conversion/tensorflow/tensorflow_frozenparser.py @@ -8,6 +8,9 @@ from mmdnn.conversion.common.utils import * from mmdnn.conversion.common.DataStructure.parser import Parser from distutils.version import LooseVersion +import tempfile +import os +import shutil class TensorflowParser2(Parser): @@ -120,11 +123,14 @@ def __init__(self, frozen_file, inputshape, in_nodes, dest_nodes): output_node_names = dest_nodes, placeholder_type_enum = dtypes.float32.as_datatype_enum) # Save it to an output file - frozen_model_file = './frozen.pb' + tempdir = tempfile.mkdtemp() + frozen_model_file = os.path.join(tempdir, 'frozen.pb') with gfile.GFile(frozen_model_file, "wb") as f: f.write(original_gdef.SerializeToString()) with open(frozen_model_file, 'rb') as f: serialized = f.read() + shutil.rmtree(tempdir) + tensorflow.reset_default_graph() model = tensorflow.GraphDef() model.ParseFromString(serialized) @@ -149,14 +155,15 @@ def __init__(self, frozen_file, inputshape, in_nodes, dest_nodes): x = tensorflow.placeholder(dtype) input_map[in_nodes[i] + ':0'] = x - + tensorflow.import_graph_def(model, name='', input_map=input_map) with tensorflow.Session(graph = g) as sess: - meta_graph_def = tensorflow.train.export_meta_graph(filename='./my-model.meta') + tempdir = tempfile.mkdtemp() + meta_graph_def = tensorflow.train.export_meta_graph(filename=os.path.join(tempdir, 'my-model.meta')) model = meta_graph_def.graph_def - + shutil.rmtree((tempdir)) self.tf_graph = TensorflowGraph(model) self.tf_graph.build() @@ -346,7 +353,7 @@ def _add_constant_node(self, source_node): parent_node = self.tf_graph.get_node(s) if parent_node.type == 'Const': self._rename_Const(parent_node) - + def _rename_Const(self, source_node): IR_node = self._convert_identity_operation(source_node, end_idx=0, new_op='Constant') # Constant value = source_node.get_attr('value') @@ -369,13 +376,13 @@ def gen_IR(self): continue node_type = current_node.type - + if hasattr(self, "rename_" + node_type): - + func = getattr(self, "rename_" + node_type) func(current_node) else: - + self.rename_UNKNOWN(current_node) @@ -800,9 +807,9 @@ def rename_Gather(self, source_node): assign_IRnode_values(IR_node, kwargs) return IR_node - + def rename_GatherV2(self, source_node): - + IR_node = self.rename_Gather(source_node) kwargs = {} @@ -1013,7 +1020,7 @@ def rename_Rank(self, source_node): def rename_Transpose(self, source_node): IR_node = self._convert_identity_operation(source_node, end_idx=1, new_op = 'Transpose') - + input_node_perm = self.get_parent(source_node.name, [1]) # input_node_perm = self.check_const(self.get_parent(source_node.name, [1], True)) tensor_content = input_node_perm.get_attr('value') @@ -1142,6 +1149,6 @@ def rename_Tanh(self, source_node): kwargs['shape'] = self.tensor_shape_to_list(input_node.get_attr('_output_shapes'))[0] assign_IRnode_values(IR_node, kwargs) - + def rename_Log(self, source_node): IR_node = self._convert_identity_operation(source_node, new_op = 'Log') \ No newline at end of file diff --git a/mmdnn/conversion/tensorflow/tensorflow_parser.py b/mmdnn/conversion/tensorflow/tensorflow_parser.py index 658b4e89..4db7a4ad 100644 --- a/mmdnn/conversion/tensorflow/tensorflow_parser.py +++ b/mmdnn/conversion/tensorflow/tensorflow_parser.py @@ -14,6 +14,9 @@ from mmdnn.conversion.common.DataStructure.parser import Parser from tensorflow.tools.graph_transforms import TransformGraph from mmdnn.conversion.rewriter.utils import * +import tempfile +import os +import shutil class TensorflowParser(Parser): @@ -308,9 +311,10 @@ def __init__(self, meta_file, checkpoint_file, dest_nodes, inputShape = None, in tensorflow.import_graph_def(transformed_graph_def, name='', input_map=input_map) with tensorflow.Session(graph = g) as sess: - - meta_graph_def = tensorflow.train.export_meta_graph(filename='./my-model.meta') + tempdir = tempfile.mkdtemp() + meta_graph_def = tensorflow.train.export_meta_graph(filename=os.path.join(tempdir, 'my-model.meta')) model = meta_graph_def.graph_def + shutil.rmtree(tempdir) self.tf_graph = TensorflowGraph(model) self.tf_graph.build() diff --git a/tests/conversion_imagenet.py b/tests/conversion_imagenet.py index 43e68f1c..5a71d0bc 100644 --- a/tests/conversion_imagenet.py +++ b/tests/conversion_imagenet.py @@ -691,12 +691,6 @@ def onnx_emit(original_framework, architecture_name, architecture_path, weight_p predict = tf_rep.run(input_data)[0] - return predict - - except ImportError: - print('Please install Onnx! Or Onnx is not supported in your platform.', file=sys.stderr) - - finally: del prepare del model_converted del tf_rep @@ -705,6 +699,10 @@ def onnx_emit(original_framework, architecture_name, architecture_path, weight_p os.remove(converted_file + '.py') os.remove(converted_file + '.npy') + return predict + + except ImportError: + print('Please install Onnx! Or Onnx is not supported in your platform.', file=sys.stderr) # In case of odd number add the extra padding at the end for SAME_UPPER(eg. pads:[0, 2, 2, 0, 0, 3, 3, 0]) and at the beginning for SAME_LOWER(eg. pads:[0, 3, 3, 0, 0, 2, 2, 0])