From 9395632036906ce0a4c1d8c3231f6abaa47ab0fd Mon Sep 17 00:00:00 2001 From: Jay Zhang Date: Mon, 11 Mar 2024 02:22:05 +0000 Subject: [PATCH] Update the condition to check output node's consumers. Signed-off-by: Jay Zhang --- tf2onnx/optimizer/transpose_optimizer.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tf2onnx/optimizer/transpose_optimizer.py b/tf2onnx/optimizer/transpose_optimizer.py index cc16e5f05..73f6adc85 100644 --- a/tf2onnx/optimizer/transpose_optimizer.py +++ b/tf2onnx/optimizer/transpose_optimizer.py @@ -509,14 +509,15 @@ def _add_handler(self, trans, node): return True return self._handle_node_having_branches(trans, node) - def _output_node_has_single_consumer_node(self, node): + def _output_has_no_multiple_consumers(self, node): output_node = self._g.get_node_by_name(node.output[0]) - return output_node and output_node.output and self._nodes_has_single_consumer_node([output_node]) + return True if output_node is None \ + else (output_node.output and self._nodes_has_single_consumer_node([output_node])) def _transpose_handler(self, trans, node): perm = trans.get_attr_value("perm") perm_inv = invert_perm(perm) - if is_tranpose_of_type(node, perm_inv) and self._output_node_has_single_consumer_node(node): + if is_tranpose_of_type(node, perm_inv) and self._output_has_no_multiple_consumers(node): for g in {self._g, node.graph}: g.replace_all_inputs(node.output[0], trans.input[0]) # ops=g.get_nodes()