Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
neNasko1 committed Jan 15, 2025
1 parent b11d761 commit a612264
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 15 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ Change log

- Adds missing shape inference logic for :func:`spox.opsets.ai.v19.loop` and :func:`spox.opsets.ai.v21.loop`.

**Deprecation**

- Using :func:`spox._future.set_value_prop_backend`` now triggers a ``DeprecationWarning``. The function will be removed in a future release.

**Other changes**

- Propagated values may now be garbage collected if their associated `Var` object goes out of scope.
Expand Down
40 changes: 26 additions & 14 deletions src/spox/_future.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
import spox._value_prop
from spox._graph import initializer as _initializer
from spox._type_system import Tensor
from spox._value_prop import ValuePropBackend
from spox._value_prop_backend import BaseValuePropBackend
from spox._var import Var

TypeWarningLevel = spox._node.TypeWarningLevel
Expand All @@ -35,35 +37,44 @@ def type_warning_level(level: TypeWarningLevel) -> Iterator[None]:
set_type_warning_level(prev_level)


ValuePropBackend = spox._value_prop.ValuePropBackend


def set_value_prop_backend(backend: ValuePropBackend) -> None:
warnings.warn(
"using '_future.set_value_prop_backend' with a '_future.ValuePropBackend' is deprecated and will be removed in the future",
DeprecationWarning,
)
new_backend: spox._value_prop_backend.BaseValuePropBackend | None = None
def set_value_prop_backend(
backend: ValuePropBackend | BaseValuePropBackend | None,
) -> None:
if not isinstance(backend, BaseValuePropBackend):
warnings.warn(
"using '_future.set_value_prop_backend' with a '_future.ValuePropBackend' is deprecated and will be removed in the future",
DeprecationWarning,
)

new_backend: BaseValuePropBackend | None = None
if backend == spox._value_prop.ValuePropBackend.REFERENCE:
new_backend = spox._value_prop_backend.ReferenceValuePropBackend()
elif backend == spox._value_prop.ValuePropBackend.ONNXRUNTIME:
new_backend = spox._value_prop_backend.OnnxruntimeValuePropBackend()
elif isinstance(backend, BaseValuePropBackend):
new_backend = backend

spox._value_prop_backend.set_value_prop_backend(new_backend)


@contextmanager
def value_prop_backend(backend: ValuePropBackend) -> Iterator[None]:
warnings.warn(
"using '_future.value_prop_backend' with a 'spox._future.ValuePropBackend' is deprecated and will be removed in the future",
DeprecationWarning,
)
def value_prop_backend(
backend: ValuePropBackend | BaseValuePropBackend | None,
) -> Iterator[None]:
if not isinstance(backend, BaseValuePropBackend):
warnings.warn(
"using '_future.value_prop_backend' with a 'spox._future.ValuePropBackend' is deprecated and will be removed in the future",
DeprecationWarning,
)

new_backend: spox._value_prop_backend.BaseValuePropBackend | None = None
if backend == spox._value_prop.ValuePropBackend.REFERENCE:
new_backend = spox._value_prop_backend.ReferenceValuePropBackend()
elif backend == spox._value_prop.ValuePropBackend.ONNXRUNTIME:
new_backend = spox._value_prop_backend.OnnxruntimeValuePropBackend()
elif isinstance(backend, BaseValuePropBackend):
new_backend = backend

with spox._value_prop_backend.value_prop_backend(new_backend):
yield

Expand Down Expand Up @@ -255,6 +266,7 @@ def __getattr__(name: str) -> Any:
"initializer",
# Value propagation backend
"ValuePropBackend",
"BaseValuePropBackend",
"set_value_prop_backend",
"value_prop_backend",
]
1 change: 1 addition & 0 deletions src/spox/_value_prop_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def get_value_prop_backend() -> BaseValuePropBackend | None:


def set_value_prop_backend(backend: BaseValuePropBackend | None) -> None:
global _VALUE_PROP_BACKEND
_VALUE_PROP_BACKEND = backend


Expand Down
47 changes: 46 additions & 1 deletion tests/test_value_propagation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) QuantCo 2023-2024
# SPDX-License-Identifier: BSD-3-Clause


import numpy as np
import onnx
import pytest
Expand All @@ -9,7 +10,7 @@
import spox._future
import spox.opset.ai.onnx.ml.v3 as ml
import spox.opset.ai.onnx.v20 as op
from spox import Var, _type_system
from spox import Type, Var, _type_system
from spox._graph import arguments, results
from spox._shape import Shape
from spox._value_prop import ORTValue, PropValue
Expand Down Expand Up @@ -233,3 +234,47 @@ def test_strings(value_prop_backend):

x, y = op.const(["foo"]), op.const(["bar"])
np.testing.assert_equal(op.string_concat(x, y)._value.value, np.array(["foobar"])) # type: ignore


def test_value_prop_backend_class():
class CustomValuePropBackend(spox._future.BaseValuePropBackend[ORTValue]):
def __init__(self) -> None:
import onnxruntime

self.session_options = onnxruntime.SessionOptions()
self.session_options.log_severity_level = 3
self.session_counter = 0

def wrap_feed(self, value: PropValue) -> ORTValue:
return value.to_ort_value()

def run(
self, model: onnx.ModelProto, input_feed: dict[str, ORTValue]
) -> dict[str, ORTValue]:
import onnxruntime

self.session_counter += 1
session = onnxruntime.InferenceSession(
model.SerializeToString(), self.session_options
)
output_names = [output.name for output in session.get_outputs()]
output_feed = dict(zip(output_names, session.run(None, input_feed)))
return output_feed

def unwrap_feed(self, typ: Type, value: ORTValue) -> PropValue:
return PropValue.from_ref_value(typ, value)

custom_backend = CustomValuePropBackend()
spox._future.set_value_prop_backend(custom_backend)

MAX_ITER = 10
for iter in range(1, MAX_ITER + 1):
x = op.add(op.const(1), op.const(1))
assert x._value.value == 2 # type: ignore
assert custom_backend.session_counter == iter

other_custom_backend = CustomValuePropBackend()
with spox._future.value_prop_backend(other_custom_backend):
x = op.add(op.const(1), op.const(1))
assert x._value.value == 2 # type: ignore
assert other_custom_backend.session_counter == 1

0 comments on commit a612264

Please sign in to comment.