diff --git a/CHANGELOG.rst b/CHANGELOG.rst index a1b198f8..0fe17120 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -7,9 +7,14 @@ Change log ========== -0.12.1 (unreleased) +0.12.1 (2024-06-18) ------------------- +**Bug fix** + +- Unset optional inputs are no longer erroneously prefixed by :func:`~spox.inline`. + + **Other changes** - The node-naming algorithm now has constant rather than quadratic time complexity. diff --git a/src/spox/_inline.py b/src/spox/_inline.py index 275a33cb..4619701d 100644 --- a/src/spox/_inline.py +++ b/src/spox/_inline.py @@ -154,6 +154,8 @@ def to_onnx( inner_node_renames: Dict[str, str] = {} def reserve_prefixed(name: str) -> str: + if not name: + return name return scope.var.reserve( scope.var.maybe_enum(f"{scope.node[self]}__{name}") ) diff --git a/tests/test_inline.py b/tests/test_inline.py index eb2df6fc..732b4c9d 100644 --- a/tests/test_inline.py +++ b/tests/test_inline.py @@ -6,12 +6,10 @@ import pytest import spox.opset.ai.onnx.v17 as op +from spox import Tensor, Var, argument, build, inline from spox._graph import arguments, results from spox._inline import rename_in_graph -from spox._public import inline -from spox._type_system import Tensor from spox._utils import from_array -from spox._var import Var @pytest.fixture @@ -342,3 +340,21 @@ def example_rename(n: str) -> str: _duplicate_subgraphs_to_list(relu_proto.graph), example_rename ) assert rename_then_duplicate.node == duplicate_then_rename.node + + +def test_subgraph_with_nodes_with_optional_inputs(): + """Unset optional inputs must not be prefixed by `inline`.""" + + def inline_model() -> onnx.ModelProto: + a = argument(Tensor(numpy.float64, ("N",))) + return build({"a": a}, {"b": op.clip(a, None, op.const(1.0, numpy.float64))}) + + foo = argument(Tensor(numpy.float64, ("N",))) + (bar,) = inline(inline_model())(foo).values() + + model_proto = build({"foo": foo}, {"bar": bar}) + + (clip_node,) = (n for n in model_proto.graph.node if n.op_type == "Clip") + assert len(clip_node.input) == 3 + assert clip_node.input[1] == "" + assert clip_node.input[2] != ""