diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 13609704ccb7..9934d4f13269 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -4510,14 +4510,19 @@ def _impl_v1(cls, inputs, attr, params):
         # Add constants from both branches to parent graph.
         graph_scope._params.update(then_graph._params)
         graph_scope._nodes.update(then_graph._nodes)
+        graph_scope._params.update(else_graph._params)
+        graph_scope._nodes.update(else_graph._nodes)
+
         then_free_vars = analysis.free_vars(then_expr)
         for var in then_free_vars:
             graph_scope._nodes.update({var.name_hint: var})
-        graph_scope._params.update(else_graph._params)
-        graph_scope._nodes.update(else_graph._nodes)
+            if var.name_hint in graph_scope._inputs:
+                graph_scope._inputs.update({var.name_hint: var})
         else_free_vars = analysis.free_vars(else_expr)
         for var in else_free_vars:
             graph_scope._nodes.update({var.name_hint: var})
+            if var.name_hint in graph_scope._inputs:
+                graph_scope._inputs.update({var.name_hint: var})
 
         # Sometimes pytorch to onnx will insert silly if statements that produce dynamic ranks.
         # Often these dont contribute anything. If we see a dynamic rank output, try to unify
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 216732343028..b9f2d14b7888 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -5147,6 +5147,101 @@ def append_constant_nodes(nodes, outputs, expected, name):
     verify_if(cond_array=True, num_outputs=2)
 
 
+@tvm.testing.parametrize_targets
+def test_graph_input_use_in_if(target, dev):
+    """test_graph_input_use_in_if"""
+
+    def verify_if(num_nested, cond):
+        # Return "graph_input" if cond is True, else return constant(-1).
+
+        input_tensor = helper.make_tensor_value_info("graph_input", TensorProto.FLOAT, [1])
+        output_tensor = helper.make_tensor_value_info("graph_output", TensorProto.FLOAT, [1])
+        constant_node = make_constant_node("const_val", TensorProto.FLOAT, [1], [-1])
+        cond_tensor = helper.make_tensor_value_info("cond", TensorProto.BOOL, [1])
+        inner_if_node = None
+        for i in range(num_nested):
+            identity_node = helper.make_node(
+                "Identity",
+                inputs=["const_val"],
+                outputs=[f"const{i}"],
+                name=f"depth{i}'th else identity",
+            )
+            else_branch = helper.make_graph(
+                [identity_node],
+                f"else{i}_body",
+                inputs=[],
+                outputs=[helper.make_tensor_value_info(f"const{i}", TensorProto.FLOAT, [1])],
+            )
+            out_name = f"if_output{i}" if i != (num_nested - 1) else "graph_output"
+
+            if i == 0:
+                identity_node = helper.make_node(
+                    "Identity",
+                    inputs=["graph_input"],
+                    outputs=[f"input_identity{i}"],
+                    name=f"depth{i}'th then identity",
+                )
+                then_branch = helper.make_graph(
+                    [identity_node],
+                    f"then{i}_body",
+                    inputs=[],
+                    outputs=[
+                        helper.make_tensor_value_info(f"input_identity{i}", TensorProto.FLOAT, [1])
+                    ],
+                )
+                if_node = helper.make_node(
+                    "If",
+                    inputs=["cond"],
+                    outputs=[out_name],
+                    then_branch=then_branch,
+                    else_branch=else_branch,
+                    name=f"depth{i}'s If node",
+                )
+                inner_if_node = if_node
+            else:
+                then_branch = helper.make_graph(
+                    [inner_if_node],
+                    f"then{i}_body",
+                    inputs=[],
+                    outputs=[
+                        helper.make_tensor_value_info(f"if_output{i-1}", TensorProto.FLOAT, [1])
+                    ],
+                )
+                if_node = helper.make_node(
+                    "If",
+                    inputs=["cond"],
+                    outputs=[out_name],
+                    then_branch=then_branch,
+                    else_branch=else_branch,
+                    name=f"depth{i}'s If node",
+                )
+                inner_if_node = if_node
+        graph_nodes = [constant_node, inner_if_node]
+        graph = helper.make_graph(
+            graph_nodes,
+            "input_use_in_if_test",
+            inputs=[input_tensor, cond_tensor],
+            outputs=[output_tensor],
+        )
+        model = helper.make_model(graph, producer_name="input_use_in_if_test")
+
+        verify_with_ort_with_inputs(
+            model,
+            [np.array([3.0], dtype="float32"), np.array([cond])],
+            dtype="float32",
+            use_vm=True,
+            opset=14,
+            target=target,
+            dev=dev,
+        )
+
+    # Confirm that a graph input consumed only inside (nested) If branches
+    # works for both values of the condition.
+    verify_if(num_nested=1, cond=True)
+    verify_if(num_nested=1, cond=False)
+    verify_if(num_nested=2, cond=True)
+    verify_if(num_nested=2, cond=False)
+
+
 @tvm.testing.parametrize_targets
 def test_size(target, dev):
     """test_size"""
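
Why the converter change is needed: when the If importer inlines a branch
subgraph, any name the branch reads from the outer scope comes back as a free
Relay variable. The patch registers those variables in graph_scope._nodes
and, when the name is also a top-level graph input, refreshes
graph_scope._inputs so the final relay.Function is built from the same Var
the branch body references; previously the stale Var left "graph_input"
unbound. The snippet below is a minimal sketch, not part of the patch; it
assumes a TVM build where the Relay API is available, and the variable names
are invented for illustration.

# Minimal sketch: free_vars surfaces a graph input used only inside a branch.
import tvm
from tvm import relay

cond = relay.var("cond", shape=(), dtype="bool")
graph_input = relay.var("graph_input", shape=(1,), dtype="float32")

# Only the then-branch reads graph_input, mirroring the test's Identity node;
# the else-branch returns the constant -1.
if_expr = relay.If(cond, graph_input, relay.const([-1.0], dtype="float32"))

# Both "cond" and "graph_input" are free in the If expression. The converter
# must keep such names registered as graph inputs, or they disappear from the
# compiled function's signature.
print([v.name_hint for v in relay.analysis.free_vars(if_expr)])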
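
To reproduce locally (assuming a TVM dev checkout with pytest, onnx, and
onnxruntime installed; verify_with_ort_with_inputs compares TVM's result
against onnxruntime), the new test can be run on its own:

    pytest tests/python/frontend/onnx/test_forward.py -k test_graph_input_use_in_if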