From a6122642788f82c58389d7786b6bedaf5684524c Mon Sep 17 00:00:00 2001 From: Atanas Dimitrov Date: Wed, 15 Jan 2025 15:56:02 +0200 Subject: [PATCH] Add tests --- CHANGELOG.rst | 4 +++ src/spox/_future.py | 40 ++++++++++++++++++---------- src/spox/_value_prop_backend.py | 1 + tests/test_value_propagation.py | 47 ++++++++++++++++++++++++++++++++- 4 files changed, 77 insertions(+), 15 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 63b0595..1365438 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -14,6 +14,10 @@ Change log - Adds missing shape inference logic for :func:`spox.opsets.ai.v19.loop` and :func:`spox.opsets.ai.v21.loop`. +**Deprecation** + +- Using :func:`spox._future.set_value_prop_backend`` now triggers a ``DeprecationWarning``. The function will be removed in a future release. + **Other changes** - Propagated values may now be garbage collected if their associated `Var` object goes out of scope. diff --git a/src/spox/_future.py b/src/spox/_future.py index b5486fb..8ebc5d7 100644 --- a/src/spox/_future.py +++ b/src/spox/_future.py @@ -18,6 +18,8 @@ import spox._value_prop from spox._graph import initializer as _initializer from spox._type_system import Tensor +from spox._value_prop import ValuePropBackend +from spox._value_prop_backend import BaseValuePropBackend from spox._var import Var TypeWarningLevel = spox._node.TypeWarningLevel @@ -35,35 +37,44 @@ def type_warning_level(level: TypeWarningLevel) -> Iterator[None]: set_type_warning_level(prev_level) -ValuePropBackend = spox._value_prop.ValuePropBackend - - -def set_value_prop_backend(backend: ValuePropBackend) -> None: - warnings.warn( - "using '_future.set_value_prop_backend' with a '_future.ValuePropBackend' is deprecated and will be removed in the future", - DeprecationWarning, - ) - new_backend: spox._value_prop_backend.BaseValuePropBackend | None = None +def set_value_prop_backend( + backend: ValuePropBackend | BaseValuePropBackend | None, +) -> None: + if not isinstance(backend, BaseValuePropBackend): + warnings.warn( + "using '_future.set_value_prop_backend' with a '_future.ValuePropBackend' is deprecated and will be removed in the future", + DeprecationWarning, + ) + new_backend: BaseValuePropBackend | None = None if backend == spox._value_prop.ValuePropBackend.REFERENCE: new_backend = spox._value_prop_backend.ReferenceValuePropBackend() elif backend == spox._value_prop.ValuePropBackend.ONNXRUNTIME: new_backend = spox._value_prop_backend.OnnxruntimeValuePropBackend() + elif isinstance(backend, BaseValuePropBackend): + new_backend = backend spox._value_prop_backend.set_value_prop_backend(new_backend) @contextmanager -def value_prop_backend(backend: ValuePropBackend) -> Iterator[None]: - warnings.warn( - "using '_future.value_prop_backend' with a 'spox._future.ValuePropBackend' is deprecated and will be removed in the future", - DeprecationWarning, - ) +def value_prop_backend( + backend: ValuePropBackend | BaseValuePropBackend | None, +) -> Iterator[None]: + if not isinstance(backend, BaseValuePropBackend): + warnings.warn( + "using '_future.value_prop_backend' with a 'spox._future.ValuePropBackend' is deprecated and will be removed in the future", + DeprecationWarning, + ) + new_backend: spox._value_prop_backend.BaseValuePropBackend | None = None if backend == spox._value_prop.ValuePropBackend.REFERENCE: new_backend = spox._value_prop_backend.ReferenceValuePropBackend() elif backend == spox._value_prop.ValuePropBackend.ONNXRUNTIME: new_backend = spox._value_prop_backend.OnnxruntimeValuePropBackend() + elif isinstance(backend, BaseValuePropBackend): + new_backend = backend + with spox._value_prop_backend.value_prop_backend(new_backend): yield @@ -255,6 +266,7 @@ def __getattr__(name: str) -> Any: "initializer", # Value propagation backend "ValuePropBackend", + "BaseValuePropBackend", "set_value_prop_backend", "value_prop_backend", ] diff --git a/src/spox/_value_prop_backend.py b/src/spox/_value_prop_backend.py index 2cfd617..276c8e7 100644 --- a/src/spox/_value_prop_backend.py +++ b/src/spox/_value_prop_backend.py @@ -111,6 +111,7 @@ def get_value_prop_backend() -> BaseValuePropBackend | None: def set_value_prop_backend(backend: BaseValuePropBackend | None) -> None: + global _VALUE_PROP_BACKEND _VALUE_PROP_BACKEND = backend diff --git a/tests/test_value_propagation.py b/tests/test_value_propagation.py index e33cb58..2a20314 100644 --- a/tests/test_value_propagation.py +++ b/tests/test_value_propagation.py @@ -1,6 +1,7 @@ # Copyright (c) QuantCo 2023-2024 # SPDX-License-Identifier: BSD-3-Clause + import numpy as np import onnx import pytest @@ -9,7 +10,7 @@ import spox._future import spox.opset.ai.onnx.ml.v3 as ml import spox.opset.ai.onnx.v20 as op -from spox import Var, _type_system +from spox import Type, Var, _type_system from spox._graph import arguments, results from spox._shape import Shape from spox._value_prop import ORTValue, PropValue @@ -233,3 +234,47 @@ def test_strings(value_prop_backend): x, y = op.const(["foo"]), op.const(["bar"]) np.testing.assert_equal(op.string_concat(x, y)._value.value, np.array(["foobar"])) # type: ignore + + +def test_value_prop_backend_class(): + class CustomValuePropBackend(spox._future.BaseValuePropBackend[ORTValue]): + def __init__(self) -> None: + import onnxruntime + + self.session_options = onnxruntime.SessionOptions() + self.session_options.log_severity_level = 3 + self.session_counter = 0 + + def wrap_feed(self, value: PropValue) -> ORTValue: + return value.to_ort_value() + + def run( + self, model: onnx.ModelProto, input_feed: dict[str, ORTValue] + ) -> dict[str, ORTValue]: + import onnxruntime + + self.session_counter += 1 + session = onnxruntime.InferenceSession( + model.SerializeToString(), self.session_options + ) + output_names = [output.name for output in session.get_outputs()] + output_feed = dict(zip(output_names, session.run(None, input_feed))) + return output_feed + + def unwrap_feed(self, typ: Type, value: ORTValue) -> PropValue: + return PropValue.from_ref_value(typ, value) + + custom_backend = CustomValuePropBackend() + spox._future.set_value_prop_backend(custom_backend) + + MAX_ITER = 10 + for iter in range(1, MAX_ITER + 1): + x = op.add(op.const(1), op.const(1)) + assert x._value.value == 2 # type: ignore + assert custom_backend.session_counter == iter + + other_custom_backend = CustomValuePropBackend() + with spox._future.value_prop_backend(other_custom_backend): + x = op.add(op.const(1), op.const(1)) + assert x._value.value == 2 # type: ignore + assert other_custom_backend.session_counter == 1