From 4c5ac0523e46e0b828ad241a4478eab76a063225 Mon Sep 17 00:00:00 2001 From: Christian Bourjau Date: Thu, 5 Oct 2023 17:53:12 +0200 Subject: [PATCH] Forego version adaption of inlined models if no nodes are from the default domain (#105) --- src/spox/_adapt.py | 4 +++ tests/test_adapt.py | 84 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 88 insertions(+) diff --git a/src/spox/_adapt.py b/src/spox/_adapt.py index 5b20885c..ce65c2a9 100644 --- a/src/spox/_adapt.py +++ b/src/spox/_adapt.py @@ -73,6 +73,10 @@ def adapt_inline( source_version = max({v for d, v in node.opset_req if d in ("", "ai.onnx")}) target_version = target_opsets[""] + # convert_version fails if the inlined model does not import the default domain + seen_domains = {prot.domain for prot in protos} + if not seen_domains & {"", "ai.onnx"}: + return protos if source_version != target_version: target_model = onnx.version_converter.convert_version( node.model, target_version diff --git a/tests/test_adapt.py b/tests/test_adapt.py index edd2b10c..a91f8c06 100644 --- a/tests/test_adapt.py +++ b/tests/test_adapt.py @@ -149,3 +149,87 @@ def test_adapt_node_with_repeating_input_names(): c = op19.identity(a) build({"a": a}, {"b": b, "c": c}) + + +def test_inline_model_custom_node_only(): + """Inline a model which only consists of a custom node. + + Such models do not import from the default domain. + """ + domain = "foo.ai" + node = onnx.helper.make_node("FooOp", ["a"], ["b"], domain=domain) + value_infos_input = [ + onnx.helper.make_value_info( + "a", onnx.helper.make_tensor_type_proto(onnx.TensorProto.STRING, ("N",)) + ), + ] + value_infos_output = [ + onnx.helper.make_value_info( + "b", onnx.helper.make_tensor_type_proto(onnx.TensorProto.STRING, ("N",)) + ) + ] + + model = onnx.helper.make_model( + onnx.helper.make_graph( + [node], + "graph", + value_infos_input, + value_infos_output, + ), + opset_imports=[onnx.helper.make_opsetid(domain, 1)], + ) + + # Ensure that our model is valid + onnx.checker.check_model(model, full_check=True) + + (a,) = arguments(data=Tensor(numpy.str_, ("N",))) + (b,) = inline(model)(a).values() + + # Add another node to the model to trigger the adaption logic + c = op18.identity(b) + build({"a": a}, {"c": c}) + + +@pytest.mark.skip( + reason="Adapting custom nodes (including their subgraphs) is currently not supported" +) +def test_inline_model_custom_node_nested(old_squeeze: onnx.ModelProto): + """A singleton custom node with a old standard node in its attribute.""" + domain = "foo.ai" + + node = onnx.helper.make_node( + "FooOp", ["a"], ["b"], domain=domain, **{"nested_graph": old_squeeze.graph} + ) + value_infos_input = [ + onnx.helper.make_value_info( + "a", onnx.helper.make_tensor_type_proto(onnx.TensorProto.FLOAT, ("N",)) + ), + ] + value_infos_output = [ + onnx.helper.make_value_info( + "b", onnx.helper.make_tensor_type_proto(onnx.TensorProto.FLOAT, ("N",)) + ) + ] + + model = onnx.helper.make_model( + onnx.helper.make_graph( + [node], + "graph", + value_infos_input, + value_infos_output, + ), + opset_imports=[ + onnx.helper.make_opsetid(domain, 1), + onnx.helper.make_opsetid("", 12), + ], + ) + + # Ensure that our model is valid + onnx.checker.check_model(model, full_check=True) + + (a,) = arguments(data=Tensor(numpy.float32, ("N",))) + (b,) = inline(model)(a).values() + + # Add another node to the model to trigger the adaption logic + c = op18.identity(b) + build({"a": a}, {"c": c})