Skip to content

Commit

Permalink
Do not prefix unset inputs when inlining (#160)
Browse files Browse the repository at this point in the history

Co-authored-by: Christian Bourjau <[email protected]>
  • Loading branch information
ReadyShowShow and cbourjau authored Jun 18, 2024
1 parent 65e91aa commit 42e0aa3
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 4 deletions.
7 changes: 6 additions & 1 deletion CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,14 @@
Change log
==========

0.12.1 (unreleased)
0.12.1 (2024-06-18)
-------------------

**Bug fix**

- Unset optional inputs are no longer erroneously prefixed by :func:`~spox.inline`.


**Other changes**

- The node-naming algorithm now has constant rather than quadratic time complexity.
Expand Down
2 changes: 2 additions & 0 deletions src/spox/_inline.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,8 @@ def to_onnx(
inner_node_renames: Dict[str, str] = {}

def reserve_prefixed(name: str) -> str:
if not name:
return name
return scope.var.reserve(
scope.var.maybe_enum(f"{scope.node[self]}__{name}")
)
Expand Down
22 changes: 19 additions & 3 deletions tests/test_inline.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,10 @@
import pytest

import spox.opset.ai.onnx.v17 as op
from spox import Tensor, Var, argument, build, inline
from spox._graph import arguments, results
from spox._inline import rename_in_graph
from spox._public import inline
from spox._type_system import Tensor
from spox._utils import from_array
from spox._var import Var


@pytest.fixture
Expand Down Expand Up @@ -342,3 +340,21 @@ def example_rename(n: str) -> str:
_duplicate_subgraphs_to_list(relu_proto.graph), example_rename
)
assert rename_then_duplicate.node == duplicate_then_rename.node


def test_subgraph_with_nodes_with_optional_inputs():
"""Unset optional inputs must not be prefixed by `inline`."""

def inline_model() -> onnx.ModelProto:
a = argument(Tensor(numpy.float64, ("N",)))
return build({"a": a}, {"b": op.clip(a, None, op.const(1.0, numpy.float64))})

foo = argument(Tensor(numpy.float64, ("N",)))
(bar,) = inline(inline_model())(foo).values()

model_proto = build({"foo": foo}, {"bar": bar})

(clip_node,) = (n for n in model_proto.graph.node if n.op_type == "Clip")
assert len(clip_node.input) == 3
assert clip_node.input[1] == ""
assert clip_node.input[2] != ""

0 comments on commit 42e0aa3

Please sign in to comment.