From a6c7cb0904ea4cb248bfd4f433014e0b23c7b009 Mon Sep 17 00:00:00 2001
From: Jakub Bachurski <108397875+JakubBachurskiQC@users.noreply.github.com>
Date: Tue, 20 Dec 2022 18:06:59 +0100
Subject: [PATCH] Upgrade to ONNX 1.13

* Bump ONNX versions

* Bump ORT version for improved test coverage

* Fix generate script importlib usage on 3.11

This usage seemingly relied on undefined behaviour.
importlib.resources.path is also deprecated since 3.11, but its
replacement is only available from 3.9, which we do not yet target.

* Regenerate ai.onnx@17

Small changes to docstrings

* Add rtol capability to testing routine

* Run more function tests after ORT fixes

There are still very odd build errors on more complex cases, due to
colliding (pseudo-random?) identifiers.

* Use non-deprecated onnx.mapping alternative

* Avoid usage of onnx.mapping in utils

As it is being deprecated. This brings up the issue of object vs. str
for the string dtype in Spox.

* Integrate reference implementation of operators

* Remove ORT value prop

* Minor membership test fix

* Create module for propagated value wrapper

* Fix type-hints around PropValue abstraction

* Apply wrapper more often and fix ref integration

* Patch found broken implementations

* Add switch for value propagation runtime

* Add more consistency checks

* Be more strict with PropValue types

* Update src/spox/_standard.py

Co-authored-by: Christian Bourjau

* Comment on value representation cases

Co-authored-by: Jakub Bachurski
Co-authored-by: Christian Bourjau

---
 conda.recipe/meta.yaml          |   2 +-
 environment.yml                 |   4 +-
 src/generate.py                 |   5 +-
 src/spox/__init__.py            |   4 +
 src/spox/_graph.py              |   2 +-
 src/spox/_node.py               |  20 ++++-
 src/spox/_patch_ref_impl.py     |  28 ++++++
 src/spox/_standard.py           | 112 ++++++++++++-----------
 src/spox/_type_system.py        |   2 +-
 src/spox/_utils.py              |  30 +++----
 src/spox/_value_prop.py         | 151 ++++++++++++++++++++++++++++++++
 src/spox/_var.py                |  56 ++++--------
 src/spox/opset/ai/onnx/ml/v3.py |   1 +
 src/spox/opset/ai/onnx/v17.py   |  55 ++++++++----
 src/templates/class.jinja2      |   2 +-
 src/templates/preamble.jinja2   |   1 +
 tests/conftest.py               |   8 +-
 tests/test_custom_operator.py   |   4 +-
 tests/test_function.py          |  13 +--
 tests/test_value_propagation.py |  84 +++++++++---------
 20 files changed, 390 insertions(+), 194 deletions(-)
 create mode 100644 src/spox/_patch_ref_impl.py
 create mode 100644 src/spox/_value_prop.py

diff --git a/conda.recipe/meta.yaml b/conda.recipe/meta.yaml
index 05fa3ffa..0b212a5b 100644
--- a/conda.recipe/meta.yaml
+++ b/conda.recipe/meta.yaml
@@ -21,7 +21,7 @@ requirements:
   run:
     - python >=3.8
     - numpy
-    - onnx
+    - onnx >=1.13
 
 test:
   requires:
diff --git a/environment.yml b/environment.yml
index edc1688c..27d7c52a 100644
--- a/environment.yml
+++ b/environment.yml
@@ -10,8 +10,8 @@ dependencies:
   - matplotlib
   - numpy
   - numpydoc
-  - onnx
-  - onnxruntime>=1.12.0
+  - onnx>=1.13.0
+  - onnxruntime>=1.13.1
   - pip
   - pre-commit
   - pytest>=6
diff --git a/src/generate.py b/src/generate.py
index b7e61e10..5e13a5e5 100644
--- a/src/generate.py
+++ b/src/generate.py
@@ -50,7 +50,8 @@
     onnx.AttributeProto.TYPE_PROTOS: "AttrTypes",
 }
 
-_TEMPLATE_DIR = Path(str(importlib.resources.path("spox", "."))).parent / "templates"
+with importlib.resources.path("spox", ".") as path:
+    _TEMPLATE_DIR = path.parent / "templates"
 
 
 @dataclass
@@ -153,7 +154,7 @@ def _get_default_value(attr, attr_type_overrides) -> Optional[str]:
         attr.name in attr_type_overrides
         and attr_type_overrides[attr.name][0] == "npt.DTypeLike"
     ):
-        return f"np.{onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[default].name}"
+        return f"np.{onnx.helper.tensor_dtype_to_np_dtype(default).name}"
 
     # Strings are bytes at this point and they
     # need to be wrapped in quotes.
diff --git a/src/spox/__init__.py b/src/spox/__init__.py
index 3442cb60..5a3090d8 100644
--- a/src/spox/__init__.py
+++ b/src/spox/__init__.py
@@ -10,10 +10,14 @@
 except ModuleNotFoundError:
     pass
 
+from spox import _patch_ref_impl
+
 # Public interface
 from spox._type_system import Optional, Sequence, Tensor, Type
 from spox._var import Var
 
+_patch_ref_impl.patch_reference_implementations()
+
 __all__ = [
     "Var",
     "Type",
diff --git a/src/spox/_graph.py b/src/spox/_graph.py
index f1b8f993..b60f9a0f 100644
--- a/src/spox/_graph.py
+++ b/src/spox/_graph.py
@@ -389,7 +389,7 @@ def to_onnx_model(
             "Consider adding an Identity operator if you are just copying arguments."
         )
 
-    opset_req: list[tuple[str, int]] = list(opsets.items())  # type: ignore
+    opset_req: List[tuple[str, int]] = list(opsets.items())  # type: ignore
     function_protos: Dict[Tuple[str, str], onnx.FunctionProto] = {}
     for fun in self._get_build_result().functions:
         proto = fun.to_onnx_function(extra_opset_req=opset_req)
diff --git a/src/spox/_node.py b/src/spox/_node.py
index 1664a506..b1de0df9 100644
--- a/src/spox/_node.py
+++ b/src/spox/_node.py
@@ -24,6 +24,7 @@
 from ._attributes import AttrGraph
 from ._exceptions import InferenceWarning
 from ._type_system import Type
+from ._value_prop import PropValue, PropValueType
 from ._var import Var
 from ._varfields import VarFields
 
@@ -224,7 +225,7 @@ def pre_init(self, **kwargs):
     def post_init(self, **kwargs):
         """Post-initialization hook. Called at the end of ``__init__`` after other default fields are set."""
 
-    def propagate_values(self) -> Dict[str, Any]:
+    def propagate_values(self) -> Dict[str, PropValueType]:
         """
         Propagate values from inputs, and, if possible, compute values for outputs as well.
         This method is used to implement ONNX partial data propagation - for example so that
@@ -308,7 +309,7 @@ def _list_types(self, source):
             yield name, value_type
 
     def _init_output_vars(
-        self, types: Dict[str, Type], values: Dict[str, Any]
+        self, types: Dict[str, Type], values: Dict[str, PropValueType]
     ) -> VarFields:
         """
         Initialize empty output vars bound to this Node and return the respective Fields object.
@@ -317,7 +318,20 @@
         """
 
         def arr(name):
-            return Var(self, types.get(name), values.get(name))
+            typ: Optional[Type] = types.get(name)
+            val: Optional[PropValue]
+            if typ is not None and name in values:
+                val = PropValue(typ, values.get(name))
+                if not val.check():
+                    warnings.warn(
+                        InferenceWarning(
+                            f"PropValue of {val.value} does not match the expected type {val.type}, dropping."
+                        )
+                    )
+                    val = None
+            else:
+                val = None
+            return Var(self, typ, val)
 
         var = self.Outputs.get_variadic_name()
         outputs: Dict[str, Union[Var, Sequence[Var]]] = {
diff --git a/src/spox/_patch_ref_impl.py b/src/spox/_patch_ref_impl.py
new file mode 100644
index 00000000..dd1f6cd2
--- /dev/null
+++ b/src/spox/_patch_ref_impl.py
@@ -0,0 +1,28 @@
+import numpy
+import onnx
+import onnx.reference.op_run
+import onnx.reference.ops._op_list
+import onnx.reference.ops.op_cast
+from onnx.reference.op_run import OpRun
+
+
+class PatchedOptionalHasElement(OpRun):
+    def _run(self, x):
+        return (numpy.array(not ((isinstance(x, list) and x == [None]) or x is None)),)
+
+
+class PatchedCast(OpRun):
+    def _run(self, x, to=None):  # type: ignore
+        if to == onnx.TensorProto.STRING:
+            return (x.astype(numpy.str_),)
+        return (onnx.reference.ops.op_cast.cast_to(x, to),)
+
+
+def patch_reference_implementations():
+    """Patch known broken reference implementations in ONNX.
+
+    As the reference implementation module in ONNX is quite new, it still has bugs which are a nuisance in Spox.
+    This function replaces the special cases known to be faulty with fixed implementations.
+    """
+    onnx.reference.ops._op_list.OptionalHasElement = PatchedOptionalHasElement
+    onnx.reference.ops._op_list.Cast = PatchedCast
diff --git a/src/spox/_standard.py b/src/spox/_standard.py
index 509bf509..d3cb40b8 100644
--- a/src/spox/_standard.py
+++ b/src/spox/_standard.py
@@ -1,13 +1,16 @@
 """Module implementing a base for standard ONNX operators, which use the functionality of ONNX node-level inference."""
 
+import enum
 import typing
-from typing import Any, Dict, Tuple, Union
+from typing import Dict, Tuple
 
 import numpy
 import onnx
+import onnx.reference
 import onnx.shape_inference
 from onnx.defs import OpSchema
 
+from . import _value_prop
 from ._exceptions import InferenceError
 from ._node import Node
 from ._schemas import SCHEMAS
@@ -15,17 +18,20 @@
 from ._shape import SimpleShape
 from ._type_system import Optional, Sequence, Tensor, Type
 from ._utils import from_array
-from ._var import Nothing, _nil
+from ._value_prop import PropValue, PropValueType
+from ._var import _nil
 
 if typing.TYPE_CHECKING:
     from ._graph import Graph
 
-try:
-    import onnxruntime
-except ImportError:
-    onnxruntime = None  # type: ignore
 
-_USE_ONNXRUNTIME_VALUE_PROP = False
+class ValuePropBackend(enum.Enum):
+    NONE = 0
+    REFERENCE = 1
+    ONNXRUNTIME = 2
+
+
+_VALUE_PROP_BACKEND: ValuePropBackend = ValuePropBackend.REFERENCE
 
 
 class StandardNode(Node):
@@ -53,9 +59,13 @@ def min_output(self) -> int:
         return self.schema.min_output
 
     def to_singleton_onnx_model(
-        self, *, dummy_outputs: bool = True, with_subgraphs: bool = True
+        self, *, dummy_outputs: bool = True, with_dummy_subgraphs: bool = True
     ) -> Tuple[onnx.ModelProto, Scope]:
-        """Build a singleton model consisting of just this StandardNode. Used for type inference."""
+        """
+        Build a singleton model consisting of just this StandardNode. Used for type inference.
+        Dummy subgraphs are typed, but have no graph body, so that we can avoid the build cost.
+        They refer to non-existent nodes, but ONNX does not raise an error (for now?).
+        """
         # Prepare names for the values in scope of the node
         scope = Scope()
         scope.node[self] = "_this_"
@@ -81,7 +91,7 @@
             # Subgraphs are not fully built for possibly significant performance gains.
             # However, this uses a trick so that they type correctly.
             # This may throw if we are building ``not with_subgraphs``.
-            build_subgraph = _make_dummy_subgraph if with_subgraphs else None
+            build_subgraph = _make_dummy_subgraph if with_dummy_subgraphs else None
             (node_proto,) = self.to_onnx(scope, build_subgraph=build_subgraph)
         finally:
             self.attrs = self_attrs
@@ -107,9 +117,9 @@ def out_value_info(curr_key, curr_var):
         # Initializers, passed in to allow partial data propagation
         # - used so that operators like Reshape are aware of constant shapes
         initializers = [
-            from_array(var._value, key)
+            from_array(var._value.value, key)
             for key, var in self.inputs.as_dict().items()
-            if isinstance(var._value, numpy.ndarray)
+            if var._value and isinstance(var._value.value, numpy.ndarray)
         ]
         # Graph and model
         graph = onnx.helper.make_graph(
@@ -157,45 +167,52 @@ def infer_output_types_onnx(self) -> Dict[str, Type]:
         # Strips some unuseful type data (unknown dimensions become global-scoped dimension parameters).
         return {key: _strip_unk_param(type_) for key, type_ in results.items()}
 
-    def propagate_values_onnx(self) -> Dict[str, Any]:
-        """
-        Perform value propagation by evaluating singleton models with ONNX Runtime.
+    def propagate_values_onnx(self) -> Dict[str, PropValueType]:
+        """Perform value propagation by evaluating a singleton model.
 
-        Assumes onnxruntime was imported successfully. Does not support subgraphs.
+        The backend used for the propagation can be configured with the `spox._standard._VALUE_PROP_BACKEND` variable.
         """
-        if any(var and var._value is None for var in self.inputs.as_dict().values()):
-            # Cannot do propagation when some inputs were not propagated
+        # Cannot do propagation when some inputs were not propagated/inferred
+        if any(
+            var and (var.type is None or var._value is None)
+            for var in self.inputs.as_dict().values()
+        ):
            return {}
         if next(iter(self.subgraphs), None) is not None:
             # Cannot do propagation with subgraphs implicitly for performance - should be reimplemented
             return {}
-        # Silence possible warnings during execution (especially constant folding)
-        options = onnxruntime.SessionOptions()
-        options.log_severity_level = 3
-        # Set everything up for evaluation
-        model, scope = self.to_singleton_onnx_model(with_subgraphs=False)
-        session = onnxruntime.InferenceSession(model.SerializeToString(), options)
+        if _VALUE_PROP_BACKEND == ValuePropBackend.REFERENCE:
+            wrap_feed = PropValue.to_ref_value
+            run = _value_prop._run_reference_implementation
+            unwrap_feed = PropValue.from_ref_value
+        elif _VALUE_PROP_BACKEND == ValuePropBackend.ONNXRUNTIME:
+            wrap_feed = PropValue.to_ort_value
+            run = _value_prop._run_onnxruntime
+            unwrap_feed = PropValue.from_ort_value
+        else:
+            raise RuntimeError(
+                f"Not a valid value propagation backend: {_VALUE_PROP_BACKEND}."
+ ) + model, scope = self.to_singleton_onnx_model(with_dummy_subgraphs=False) input_feed = { - scope.var[var]: _value_prop_to_ort(var._value) + scope.var[var]: wrap_feed(var._value) for var in self.inputs.as_dict().values() + if var and var._value } - # Get outputs and give a map from output field names - output_feed = dict(zip(session.get_outputs(), session.run(None, input_feed))) - return { - scope.var[output.name]._which_output: _value_prop_from_ort(result) - for output, result in output_feed.items() + output_feed = run(model, input_feed) + results = { + scope.var[str(name)] + ._which_output: unwrap_feed(scope.var[str(name)].unwrap_type(), result) + .value + for name, result in output_feed.items() } + return {k: v for k, v in results.items() if k is not None} def infer_output_types(self) -> Dict[str, Type]: return self.infer_output_types_onnx() - def propagate_values(self) -> Dict[str, Any]: - if _USE_ONNXRUNTIME_VALUE_PROP: - if onnxruntime is None: - raise RuntimeError( - "Cannot use ONNX Runtime value prop when ONNX Runtime isn't available " - "(ImportError was raised)." - ) + def propagate_values(self) -> Dict[str, PropValueType]: + if _VALUE_PROP_BACKEND != ValuePropBackend.NONE: return self.propagate_values_onnx() return {} @@ -250,24 +267,3 @@ def _make_dummy_subgraph(_node: Node, key: str, graph: "Graph") -> onnx.GraphPro nodes.append(onnx.helper.make_node("Identity", [outer], [out])) return onnx.helper.make_graph(nodes, f"__dummy_{key}", inputs, outputs) - - -def _value_prop_to_ort(value) -> Union[numpy.ndarray, list, None]: - if value is Nothing: - return None - return value - - -def _value_prop_from_ort(value: Union[numpy.ndarray, list, None]): - if value is None: - return Nothing - elif isinstance(value, list): - return [_value_prop_from_ort(elem) for elem in value] - elif isinstance(value, numpy.ndarray): - # This looks ridiculous, but is required to normalise numpy.longlong back into a fixed size type. - # ORT sometimes returns non-sized types (like longlong) and Var's value typecheck will fail because of it. 
- # - numpy.dtype(longlong).type is longlong, but - # - numpy.dtype(longlong) == numpy.dtype(int64), while - # - longlong != int64 - return value.astype(numpy.dtype(value.dtype.name)) - raise TypeError(f"Cannot handle ORT value: {value}") diff --git a/src/spox/_type_system.py b/src/spox/_type_system.py index aac6b9d0..a88100f3 100644 --- a/src/spox/_type_system.py +++ b/src/spox/_type_system.py @@ -153,7 +153,7 @@ def __le__(self, other: "Type") -> bool: """ if not isinstance(other, Type): return NotImplemented - return self == Type() or other == Type() or self == other + return other == Type() or self == other @dataclass(frozen=True) diff --git a/src/spox/_utils.py b/src/spox/_utils.py index e054f16a..268d56f9 100644 --- a/src/spox/_utils.py +++ b/src/spox/_utils.py @@ -1,25 +1,16 @@ -from typing import Dict, Optional +from typing import Optional import numpy as np import numpy.typing as npt +import onnx from onnx import TensorProto -from onnx.helper import make_tensor, mapping - -_DTYPE_TO_TENSOR_TYPE: Dict[np.dtype, int] = { - **{ - dtype: ttype - for dtype, ttype in mapping.NP_TYPE_TO_TENSOR_TYPE.items() - if dtype != np.object_ - }, - np.dtype(str): TensorProto.STRING, -} - -_TENSOR_TYPE_TO_DTYPE = {ttype: dtype for dtype, ttype in _DTYPE_TO_TENSOR_TYPE.items()} def tensor_type_to_dtype(ttype: int) -> np.dtype: """Convert integer tensor types to ``numpy.dtype`` objects.""" - return _TENSOR_TYPE_TO_DTYPE[ttype] + if ttype == onnx.TensorProto.STRING: + return np.dtype(str) # Spox uses the str datatype for strings, not object + return onnx.helper.tensor_dtype_to_np_dtype(ttype) def dtype_to_tensor_type(dtype_like: npt.DTypeLike) -> int: @@ -35,8 +26,15 @@ def dtype_to_tensor_type(dtype_like: npt.DTypeLike) -> int: raise TypeError(err_msg) # normalize in the case of aliases like ``long`` which are missing in the lookup dtype = np.dtype(np.dtype(dtype_like).type) + if dtype == np.dtype(object): + raise TypeError( + "`np.dtype('object')` is not supported as a tensor element type. " + "Hint: Spox uses `np.dtype('str')` for the string datatype." + ) + elif dtype == np.dtype(str): + return onnx.TensorProto.STRING try: - return _DTYPE_TO_TENSOR_TYPE[dtype] + return onnx.helper.np_dtype_to_tensor_dtype(dtype) except KeyError: raise TypeError(err_msg) @@ -53,7 +51,7 @@ def from_array(arr: np.ndarray, name: Optional[str] = None) -> TensorProto: cast_to_bytes = False if arr.dtype.type in [np.str_, np.object_]: cast_to_bytes = True - return make_tensor( + return onnx.helper.make_tensor( name=name or "", data_type=dtype_to_tensor_type(arr.dtype), dims=arr.shape, diff --git a/src/spox/_value_prop.py b/src/spox/_value_prop.py new file mode 100644 index 00000000..c645ca68 --- /dev/null +++ b/src/spox/_value_prop.py @@ -0,0 +1,151 @@ +import warnings +from dataclasses import dataclass +from typing import Dict, List, Union + +import numpy +import onnx +import onnx.reference + +from ._exceptions import InferenceWarning +from ._shape import Shape +from ._type_system import Optional, Sequence, Tensor, Type + +""" +The internal representation for runtime values. 
+
+- numpy.ndarray -> Tensor
+- list[PropValue] -> Sequence
+- PropValue -> Optional, Some (has value)
+- None -> Optional, Nothing (no value)
+"""
+PropValueType = Union[numpy.ndarray, List["PropValue"], "PropValue", None]
+ORTValue = Union[numpy.ndarray, list, None]
+RefValue = Union[numpy.ndarray, list, float]
+
+VALUE_PROP_STRICT_CHECK: bool = False
+
+
+@dataclass(frozen=True)
+class PropValue:
+    """Propagated value given to a Var, which has a run-time value known at compile-time.
+
+    Wrapper for a few Python types which are used to represent values of ONNX types.
+
+    Implements routines for conversion to and from:
+
+    - ONNX Runtime (ORT)
+    - Reference implementations (Ref).
+    """
+
+    type: Type
+    value: PropValueType
+
+    def __post_init__(self):
+        if VALUE_PROP_STRICT_CHECK and not self.check():
+            raise ValueError(
+                f"Attempt to construct PropValue of {self.value}, "
+                f"which does not match the expected type {self.type}."
+            )
+
+    def __str__(self):
+        return f"<Propagated {self.value}: {self.type}>"
+
+    def check(self) -> bool:
+        if isinstance(self.type, Tensor):
+            return (
+                isinstance(self.value, numpy.ndarray)
+                and self.value.dtype.type is self.type.dtype.type
+                and Shape.from_simple(self.value.shape) <= self.type._shape
+            )
+        elif isinstance(self.type, Sequence):
+            return isinstance(self.value, list) and all(
+                elem.type <= self.type.elem_type for elem in self.value
+            )
+        elif isinstance(self.type, Optional):
+            return self.value is None or isinstance(self.value, PropValue)
+        warnings.warn(
+            InferenceWarning(
+                f"Unknown or unspecified type for propagated value: {self.type!r}"
+            )
+        )
+        return True
+
+    @classmethod
+    def from_ref_value(cls, typ: Type, value: RefValue) -> "PropValue":
+        # Sometimes non-Sequence values are wrapped in a list.
+        if (
+            not isinstance(typ, Sequence)
+            and isinstance(value, list)
+            and len(value) == 1
+        ):
+            (value,) = value
+        if value is None:  # Optional, Nothing
+            return cls(typ, None)
+        elif isinstance(typ, Optional):  # Optional, Some
+            return cls(typ, cls.from_ref_value(typ.elem_type, value))
+        elif isinstance(value, list):  # Sequence
+            elem_type = typ.unwrap_sequence().elem_type
+            return cls(typ, [cls.from_ref_value(elem_type, elem) for elem in value])
+        else:  # otherwise must have Tensor (sometimes this is just a scalar)
+            return cls(typ, numpy.array(value))
+        # No fail branch because representations of Tensor are inconsistent
+
+    @classmethod
+    def from_ort_value(cls, typ: Type, value: ORTValue) -> "PropValue":
+        if value is None:  # Optional, Nothing
+            return cls(typ, None)
+        elif isinstance(typ, Optional):  # Optional, Some
+            return cls(typ, cls.from_ort_value(typ.elem_type, value))
+        elif isinstance(value, list):  # Sequence
+            elem_type = typ.unwrap_sequence().elem_type
+            return cls(typ, [cls.from_ort_value(elem_type, elem) for elem in value])
+        elif isinstance(value, numpy.ndarray):  # Tensor
+            # Normalise the dtype in case we got an alias (like longlong)
+            value = value.astype(numpy.dtype(value.dtype.name))
+            if value.dtype == numpy.dtype(object):
+                value = value.astype(str)
+            return cls(typ, value)
+        raise TypeError(f"No handler for ORT value: {value}")
+
+    def to_ref_value(self) -> RefValue:
+        if self.value is None:  # Optional, Nothing
+            return [None]  # Optionals are wrapped in a singleton list
+        elif isinstance(self.value, PropValue):  # Optional, Some
+            return [self.value.to_ref_value()]
+        elif isinstance(self.value, list):  # Sequence
+            return [elem.to_ref_value() for elem in self.value]
+        else:  # Tensor
+            return self.value
+
+    def to_ort_value(self) -> ORTValue:
+        if self.value is None:  # Optional, Nothing
+            return None
+        elif isinstance(self.value, PropValue):  # Optional, Some
+            return self.value.to_ort_value()
+        elif isinstance(self.value, list):  # Sequence
+            return [elem.to_ort_value() for elem in self.value]
+        else:  # Tensor
+            return self.value
+
+
+def _run_reference_implementation(
+    model: onnx.ModelProto, input_feed: Dict[str, RefValue]
+) -> Dict[str, RefValue]:
+    session = onnx.reference.ReferenceEvaluator(model)
+    output_feed = dict(zip(session.output_names, session.run(None, input_feed)))
+    return output_feed
+
+
+def _run_onnxruntime(
+    model: onnx.ModelProto, input_feed: Dict[str, ORTValue]
+) -> Dict[str, ORTValue]:
+    import onnxruntime
+
+    # Silence possible warnings during execution (especially constant folding)
+    options = onnxruntime.SessionOptions()
+    options.log_severity_level = 3
+    session = onnxruntime.InferenceSession(model.SerializeToString(), options)
+    output_names = [output.name for output in session.get_outputs()]
+    output_feed = dict(zip(output_names, session.run(None, input_feed)))
+    return output_feed
diff --git a/src/spox/_var.py b/src/spox/_var.py
index 1ee27b67..3f361926 100644
--- a/src/spox/_var.py
+++ b/src/spox/_var.py
@@ -1,11 +1,10 @@
 import typing
-from typing import Any, Optional, Union
+from typing import Optional, Union
 
 import numpy
 
-from . import _type_system
+from . import _type_system, _value_prop
 from ._config import get_default_opset
-from ._shape import Shape
 
 if typing.TYPE_CHECKING:
     from ._node import Node
@@ -28,7 +27,7 @@ class Var:
     """
 
     type: Optional[_type_system.Type]
-    _value: Optional[Any]
+    _value: Optional[_value_prop.PropValue]
     _op: "Node"
    _name: Optional[str]
 
     def __init__(
         self,
         op: "Node",
         type_: Optional[_type_system.Type],
-        value: Optional[Any] = None,
+        value: Optional[_value_prop.PropValue] = None,
     ):
+        if type_ is not None and not isinstance(type_, _type_system.Type):
+            raise TypeError("The type field of a Var must be a Spox Type.")
+        if value is not None and not isinstance(value, _value_prop.PropValue):
+            raise TypeError("The propagated value field of a Var must be a PropValue.")
+        if value is not None and value.type != type_:
+            raise ValueError(
+                f"The propagated value type ({value.type}) and actual Var type ({type_}) must be the same."
+            )
+
         self.type = type_
         self._value = value
         self._op = op
         self._name = None
-        if not self._value_matches_type(value, type_):
-            raise TypeError(
-                f"Propagated value {value} of type {type(value)} does not match expected type {type_}."
-            )
 
     def _rename(self, name: Optional[str]):
         """Mutates the internal state of the Var, overriding its name as given."""
         self._name = name
 
-    @staticmethod
-    def _value_matches_type(
-        value: Optional[Any], type_: Optional[_type_system.Type]
-    ) -> bool:
-        if value is None or type_ is None:
-            return True
-        if isinstance(type_, _type_system.Tensor):
-            return (
-                isinstance(value, numpy.ndarray)
-                and value.dtype.type is type_.dtype.type
-                and Shape.from_simple(value.shape) <= type_._shape
-            )
-        elif isinstance(type_, _type_system.Optional):
-            return value is Nothing or Var._value_matches_type(value, type_.elem_type)
-        elif isinstance(type_, _type_system.Sequence):
-            return isinstance(value, list) and all(
-                Var._value_matches_type(elem, type_.elem_type) for elem in value
-            )
-        return True
-
     @property
     def _which_output(self) -> Optional[str]:
         """Return the name of the output field that this var is stored in under ``self._op``."""
+        if self._op is None:
+            return None
         op_outs = self._op.outputs.as_dict()
         candidates = [key for key, var in op_outs.items() if var is self]
         return candidates[0] if candidates else None
@@ -86,7 +72,7 @@ def __repr__(self) -> str:
         which_repr = "->??" if which is None else (f"->{which}" if is_unary else "")
         return (
             f"<Var from {self._op}{which_repr} of {self.type}"
             f"{'' if self._value is None else ' = ' + str(self._value)}>"
         )
 
-class _NothingType:
-    def __str__(self) -> str:
-        return "Nothing"
-
-
-Nothing = _NothingType()
diff --git a/src/spox/opset/ai/onnx/ml/v3.py b/src/spox/opset/ai/onnx/ml/v3.py
index fcfcf3cd..46dc54ad 100644
--- a/src/spox/opset/ai/onnx/ml/v3.py
+++ b/src/spox/opset/ai/onnx/ml/v3.py
@@ -35,6 +35,7 @@
 from spox._standard import InferenceError, StandardNode  # noqa: F401
 from spox._type_system import Sequence as SpoxSequence  # noqa: F401
 from spox._type_system import Tensor, Type
+from spox._value_prop import PropValueType
 from spox._var import Var, _nil, result_type  # noqa: F401
 from spox._varfields import NoVars, VarFields  # noqa: F401
 
diff --git a/src/spox/opset/ai/onnx/v17.py b/src/spox/opset/ai/onnx/v17.py
index 90367385..c81b1753 100644
--- a/src/spox/opset/ai/onnx/v17.py
+++ b/src/spox/opset/ai/onnx/v17.py
@@ -35,6 +35,7 @@
 from spox._standard import InferenceError, StandardNode  # noqa: F401
 from spox._type_system import Sequence as SpoxSequence  # noqa: F401
 from spox._type_system import Tensor, Type
+from spox._value_prop import PropValueType
 from spox._var import Var, _nil, result_type  # noqa: F401
 from spox._varfields import NoVars, VarFields  # noqa: F401
 
@@ -514,7 +515,7 @@ class Attributes:
     class Outputs(VarFields):
         output: Var
 
-    def propagate_values(self) -> Dict[str, Any]:
+    def propagate_values(self) -> Dict[str, PropValueType]:
         ((key, raw),) = (
             (k, v.value) for k, v in self.attrs.__dict__.items() if v is not None
         )
@@ -4384,6 +4385,20 @@ def cast(
     User must be aware of precision loss and value change caused by range difference between two types.
     For example, a 64-bit float 3.1415926459 may be round to a 32-bit float 3.141592. Similarly, converting
     an integer 36 to Boolean may produce 1 because we truncate bits which can't be stored in the targeted type.
+    In more detail, the conversion among numerical types should follow these rules:
+    * Casting from floating point to:
+    * floating point: +/- infinity if OOR (out of range).
+    * fixed point: undefined if OOR.
+    * bool: +/- 0.0 to False; all else to True.
+    * Casting from fixed point to:
+    * floating point: +/- infinity if OOR. (+ infinity in the case of uint)
+    * fixed point: when OOR, discard higher bits and reinterpret (with respect to two's complement representation for
+    signed types). For example, 200 (int16) -> -56 (int8).
+ * bool: zero to False; nonzero to True. + * Casting from bool to: + * floating point: `{1.0, 0.0}`. + * fixed point: `{1, 0}`. + * bool: no change. Parameters ========== @@ -6674,11 +6689,14 @@ def grid_sample( padding_mode: str = "zeros", ) -> Var: r""" - Given an `input` and a flow-field `grid`, computes the `output` using `input` values and pixel locations from `grid`. - Currently, only spatial (4-D) inputs are supported. For `input` with shape (N, C, H, W) and `grid` with shape (N, H_out, W_out, 2), - the `output` will have shape (N, C, H_out, W_out). - For each output location `output[N, C, H_out, W_out]`, the size-2 vector `grid[N, H_out, W_out]` specifies `input` pixel locations `x` and `y`, - which are used to interpolate the output value `output[N, C, H_out, W_out]`. + Given an input `X` and a flow-field `grid`, computes the output `Y` using `X` values and pixel locations from `grid`. + Currently, only spatial (4-D) inputs are supported. For input `X` with shape (N, C, H, W) and `grid` with shape (N, H_out, W_out, 2), + the output `Y` will have shape (N, C, H_out, W_out). + The tensor `X` contains values at centers of square pixels in a H by W 2-dimensional image. + The tensor `grid` describes normalized positions where the output `Y` is to be computed + using a specified interpolation method (the mode) and a padding mode (for grid positions falling outside the 2-dimensional image). + Elements in `grid[N, H_out, W_out]` are size-2 vectors specifying positions in the 2-dimensional space of `X`. + They are used to interpolate output values of `Y[N, C, H_out, W_out]`. The GridSample operator is often used in doing grid generator and sampler in the Spatial Transformer Networks (https://arxiv.org/abs/1506.02025). See also in torch.nn.functional.grid_sample (https://pytorch.org/docs/master/generated/torch.nn.functional.grid_sample.html#torch-nn-functional-grid-sample). @@ -6688,7 +6706,7 @@ def grid_sample( Type T1. 4-D tensor of shape (N, C, H, W), where N is the batch size, C is the numbers of channels, H and W are the height and width of the input data. grid - Type T1. + Type T2. Input offset, 4-D tensor of shape (N, H_out, W_out, 2), where H_out and W_out are the height and width of grid and output, Grid specifies the sampling pixel locations normalized by the input spatial dimensions. Therefore, it should have most values in the range of [-1, 1]. If grid has values outside the range of [-1, 1], the corresponding outputs will be handled as defined by padding_mode. align_corners Attribute. @@ -6703,8 +6721,8 @@ def grid_sample( Returns ======= Y : Var - Type T2. - 4-D tensor of shape (N, C, H_out, W_out). + Type T1. + 4-D tensor of shape (N, C, H_out, W_out) of sampled values. For integer input types, intermediate values are computed as floating point and cast to integer at the end. Notes ===== @@ -7005,7 +7023,7 @@ def if_( ======= outputs : Sequence[Var] Type V. - Values that are live-out to the enclosing scope. The return values in the `then_branch` and `else_branch` must be of the same data type. The `then_branch` and `else_branch` may produce tensors with the same element type and different shapes. 
If corresponding outputs from the then-branch and the else-branch have static shapes S1 and S2, then the shape of the corresponding output variable of the if-node (if present) must be compatible with both S1 and S2 as it represents the union of both possible shapes.For example, if in a model file, the the first output of `then_branch` is typed float tensor with shape [2] and the first output of `else_branch` is another float tensor with shape [3], If's first output should have (a) no shape set, or (b) a shape of rank 1 with neither `dim_value` nor `dim_param` set, or (c) a shape of rank 1 with a unique `dim_param`. In contrast, the first output cannot have the shape [2] since [2] and [3] are not compatible. + Values that are live-out to the enclosing scope. The return values in the `then_branch` and `else_branch` must be of the same data type. The `then_branch` and `else_branch` may produce tensors with the same element type and different shapes. If corresponding outputs from the then-branch and the else-branch have static shapes S1 and S2, then the shape of the corresponding output variable of the if-node (if present) must be compatible with both S1 and S2 as it represents the union of both possible shapes.For example, if in a model file, the first output of `then_branch` is typed float tensor with shape [2] and the first output of `else_branch` is another float tensor with shape [3], If's first output should have (a) no shape set, or (b) a shape of rank 1 with neither `dim_value` nor `dim_param` set, or (c) a shape of rank 1 with a unique `dim_param`. In contrast, the first output cannot have the shape [2] since [2] and [3] are not compatible. Notes ===== @@ -7405,7 +7423,7 @@ def layer_normalization( ``` Mean = ReduceMean(X) D = Sub(X, Mean) - DD = Mul(Diff, Diff) + DD = Mul(D, D) Var = ReduceMean(DD) VarEps = Add(Var, epsilon) StdDev = Sqrt(VarEps) @@ -8159,7 +8177,7 @@ def max_pool( Padding for the beginning and ending along each spatial axis, it can take any value greater than or equal to 0. The value represent the number of pixels added to the beginning and end part of the corresponding axis. `pads` format should be as follow [x1_begin, x2_begin...x1_end, x2_end,...], where xi_begin the number of pixels added at the beginning of axis `i` and xi_end, the number of pixels added at the end of axis `i`. This attribute cannot be used simultaneously with auto_pad attribute. If not present, the padding defaults to 0 along start and end of each spatial axis. storage_order Attribute. - The storage order of the tensor. 0 is row major, and 1 is column major. + The storage order of the tensor. 0 is row major, and 1 is column major. This attribute is used only to convert an n-tuple index value into a single integer value for producing the second output. strides Attribute. Stride along each spatial axis. If not present, the stride defaults to 1 along each spatial axis. @@ -8260,7 +8278,7 @@ def max_unpool( ) -> Var: r""" MaxUnpool essentially computes the partial inverse of the MaxPool op. - The input information to this op is typically the the output information from a MaxPool op. The first + The input information to this op is typically the output information from a MaxPool op. The first input tensor X is the tensor that needs to be unpooled, which is typically the pooled tensor (first output) from MaxPool. The second input tensor, I, contains the indices to the (locally maximal) elements corrsponding to the elements in the first input tensor X. 
Input tensor I is typically the second output of the MaxPool op. @@ -10474,7 +10492,7 @@ def reduce_sum( Keep the reduced dimension or not, default 1 means keep reduced dimension. noop_with_empty_axes Attribute. - Defines behaviour if 'axes' is empty. Default behaviour with 'false' is to reduce all axes. When axes is empty and this attribute is set to true, input tensor will not be reduced,and the output tensor would be equivalent to input tensor. + Defines behavior if 'axes' is empty. Default behavior with 'false' is to reduce all axes. When axes is empty and this attribute is set to true, input tensor will not be reduced,and the output tensor would be equivalent to input tensor. Returns ======= @@ -10601,6 +10619,9 @@ def reshape( dimension will be set explicitly to zero (i.e. not taken from input tensor). Shape (second input) could be an empty shape, which means converting to a scalar. The input tensor's shape and the output tensor's shape are required to have the same number of elements. + If the attribute 'allowzero' is set, it is invalid for the specified shape to + contain both a zero value and -1, as the value of the dimension corresponding + to -1 cannot be determined uniquely. Parameters ========== @@ -11408,13 +11429,13 @@ def scatter_nd( and `updates` tensor of rank q + r - indices.shape[-1] - 1. The output of the operation is produced by creating a copy of the input `data`, and then updating its value to values specified by `updates` at specific index positions specified by `indices`. Its output shape - is the same as the shape of `data`. Note that `indices` should not have duplicate entries. - That is, two or more `updates` for the same index-location is not supported. + is the same as the shape of `data`. `indices` is an integer tensor. Let k denote indices.shape[-1], the last dimension in the shape of `indices`. `indices` is treated as a (q-1)-dimensional tensor of k-tuples, where each k-tuple is a partial-index into `data`. Hence, k can be a value at most the rank of `data`. When k equals rank(data), each update entry specifies an update to a single element of the tensor. When k is less than rank(data) each update entry specifies an - update to a slice of the tensor. + update to a slice of the tensor. Index values are allowed to be negative, as per the usual + convention for counting backwards from the end, but are expected in the valid range. `updates` is treated as a (q-1)-dimensional tensor of replacement-slice-values. Thus, the first (q-1) dimensions of updates.shape must match the first (q-1) dimensions of indices.shape. 
The remaining dimensions of `updates` correspond to the dimensions of the diff --git a/src/templates/class.jinja2 b/src/templates/class.jinja2 index 04704ae0..3fdc5b66 100644 --- a/src/templates/class.jinja2 +++ b/src/templates/class.jinja2 @@ -52,7 +52,7 @@ class _{{ schema.name }}(StandardNode): {% endif %} {% if value_propagation %} - def propagate_values(self) -> Dict[str, Any]: + def propagate_values(self) -> Dict[str, PropValueType]: {% filter indent(width=8) %} {%+ include value_propagation %} {% endfilter %} diff --git a/src/templates/preamble.jinja2 b/src/templates/preamble.jinja2 index 80d2891b..427fe4f8 100644 --- a/src/templates/preamble.jinja2 +++ b/src/templates/preamble.jinja2 @@ -36,3 +36,4 @@ from spox._internal_op import intro # noqa: F401 from spox._node import OpType # noqa: F401 from spox._standard import InferenceError, StandardNode # noqa: F401 from spox._type_system import Tensor, Type, Sequence as SpoxSequence # noqa: F401 +from spox._value_prop import PropValueType diff --git a/tests/conftest.py b/tests/conftest.py index cb16b01e..81c86e60 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,11 +6,13 @@ import pytest import spox.opset.ai.onnx.v17 +from spox import _value_prop from spox._debug import show_construction_tracebacks from spox._graph import Graph from spox._node import TypeWarningLevel, set_type_warning_level set_type_warning_level(TypeWarningLevel.CRITICAL) +_value_prop.VALUE_PROP_STRICT_CHECK = True class ONNXRuntimeHelper: @@ -52,18 +54,18 @@ def run(self, graph: Graph, unwrap: Optional[str] = None, **kwargs): return result @staticmethod - def assert_close(given, expected): + def assert_close(given, expected, rtol=1e-7): if given is None: assert expected is None else: if isinstance(given, list): for subarray in given: numpy.testing.assert_allclose( - given, numpy.array(expected, dtype=subarray.dtype) + given, numpy.array(expected, dtype=subarray.dtype), rtol=rtol ) else: numpy.testing.assert_allclose( - given, numpy.array(expected, dtype=given.dtype) + given, numpy.array(expected, dtype=given.dtype), rtol=rtol ) diff --git a/tests/test_custom_operator.py b/tests/test_custom_operator.py index f56e4dcc..c0c83c9e 100644 --- a/tests/test_custom_operator.py +++ b/tests/test_custom_operator.py @@ -6,7 +6,7 @@ Of these, ``propagate_values`` is probably least common. """ from dataclasses import dataclass -from typing import Any, Dict +from typing import Dict import numpy @@ -48,7 +48,7 @@ def infer_output_types(self) -> Dict[str, Type]: ) return {"Y": t} - def propagate_values(self) -> Dict[str, Any]: + def propagate_values(self) -> Dict[str, numpy.ndarray]: # This is optional and implements value propagation ('partial data propagation' in ONNX). # In essence constant folding carried through for purposes of type inference. 
return ( diff --git a/tests/test_function.py b/tests/test_function.py index 2a571920..a725f1cd 100644 --- a/tests/test_function.py +++ b/tests/test_function.py @@ -314,7 +314,6 @@ def test_onnxruntime_nested_function_attr_support(): assert (session.run(None, {"x": x})[0] == y).all() -@pytest.mark.skip("ONNXRuntime does not fully support local functions") def test_minimal_wrapped(onnx_helper, wrapped_linear_graph): a = numpy.random.rand(8).astype(numpy.float32) onnx_helper.assert_close( @@ -322,22 +321,24 @@ def test_minimal_wrapped(onnx_helper, wrapped_linear_graph): ) -@pytest.mark.skip("ONNXRuntime does not fully support local functions") def test_simple_nested_calls_session(onnx_helper, cubic_graph): model = cubic_graph.to_onnx_model() onnxruntime.InferenceSession(model.SerializeToString()) -@pytest.mark.skip("ONNXRuntime does not fully support local functions") def test_simple_nested_calls(onnx_helper, cubic_graph): a = numpy.random.rand(8).astype(numpy.float32) + # increase rtol due to small *non-deterministic* discrepancies onnx_helper.assert_close( onnx_helper.run(cubic_graph, "y", x=a), (1 + a * (2 + a * (3 + a * 5))), + rtol=1e-6, ) -@pytest.mark.skip("ONNXRuntime does not fully support local functions") +@pytest.mark.skip( + "ONNX Runtime generates colliding internal identifiers for nested function nodes." +) def test_nested_calls(onnx_helper, cubic_rational_graph): a = numpy.random.rand(8).astype(numpy.float32) onnx_helper.assert_close( @@ -346,7 +347,9 @@ def test_nested_calls(onnx_helper, cubic_rational_graph): ) -@pytest.mark.skip("ONNXRuntime does not fully support local functions") +@pytest.mark.skip( + "ONNX Runtime generates colliding internal identifiers for function nodes." +) def test_complex_nested_calls(onnx_helper, cubic_rational_graph_2x3): a = numpy.random.rand(8).astype(numpy.float32) onnx_helper.assert_close( diff --git a/tests/test_value_propagation.py b/tests/test_value_propagation.py index 1bccd4f7..1e3290b4 100644 --- a/tests/test_value_propagation.py +++ b/tests/test_value_propagation.py @@ -1,19 +1,10 @@ import numpy -import onnxruntime.capi.onnxruntime_pybind11_state import pytest -from spox import Var, _standard, _type_system, _var +from spox import Var, _type_system from spox._graph import arguments, results from spox._shape import Shape - - -@pytest.fixture(scope="function") -def enable_onnx_value_propagation(): - """Fixture for enabling ONNX Runtime value propagation for tests that use it.""" - prev = _standard._USE_ONNXRUNTIME_VALUE_PROP - _standard._USE_ONNXRUNTIME_VALUE_PROP = True - yield - _standard._USE_ONNXRUNTIME_VALUE_PROP = prev +from spox._value_prop import ORTValue, PropValue def dummy_var(typ=None, value=None): @@ -21,7 +12,7 @@ def dummy_var(typ=None, value=None): return Var(None, typ, value) # type: ignore -def assert_equal_value(arr, expected): +def assert_equal_value(var: Var, expected: ORTValue): """ Convenience function for comparing a ``var``'s propagated value and an expected one. 
Expected Types vs value types: @@ -30,46 +21,55 @@ def assert_equal_value(arr, expected): - Optional - spox.var.Nothing or the underlying type - Sequence - list of underlying type """ - assert arr._value is not None, "var.value expected to be known" - if isinstance(arr.type, _type_system.Tensor): + assert var._value is not None, "var.value expected to be known" + value = var._value.to_ort_value() + if isinstance(var.type, _type_system.Tensor): expected = numpy.array(expected) - assert arr.type.dtype.type == expected.dtype.type, "element type must match" - assert Shape.from_simple(expected.shape) <= arr.type._shape, "shape must match" - numpy.testing.assert_allclose(arr._value, expected) - elif isinstance(arr.type, _type_system.Optional): + assert var.type.dtype.type == expected.dtype.type, "element type must match" + assert Shape.from_simple(expected.shape) <= var.type._shape, "shape must match" + numpy.testing.assert_allclose(value, expected) + elif isinstance(var.type, _type_system.Optional): if expected is None: - assert ( - arr._value is _var.Nothing - ), "value must be Nothing when optional is empty" + assert value is None, "value must be Nothing when optional is empty" else: - assert_equal_value(dummy_var(arr.type.elem_type, arr._value), expected) - elif isinstance(arr.type, _type_system.Sequence): - assert isinstance(arr._value, list), "value must be list when it is a Sequence" - assert len(arr._value) == len(expected), "sequence length must match" - for a, b in zip(arr._value, expected): - assert_equal_value(dummy_var(arr.type.elem_type, a), b) + assert_equal_value( + dummy_var(var.type.elem_type, var._value.value), expected + ) + elif isinstance(var.type, _type_system.Sequence): + assert isinstance(value, list), "value must be list when it is a Sequence" + assert isinstance( + expected, list + ), "expected value must be list when tested type is a Sequence" + assert len(value) == len(expected), "sequence length must match" + for a, b in zip(value, expected): + assert_equal_value( + dummy_var( + var.type.elem_type, PropValue.from_ort_value(var.type.elem_type, a) + ), + b, + ) else: - raise NotImplementedError(f"Datatype {arr.type}") + raise NotImplementedError(f"Datatype {var.type}") -def test_sanity_no_prop(enable_onnx_value_propagation, op): +def test_sanity_no_prop(op): (x,) = arguments(x=_type_system.Tensor(numpy.int64, ())) op.add(x, x) -def test_sanity_const(enable_onnx_value_propagation, op): +def test_sanity_const(op): assert_equal_value(op.const(2), numpy.int64(2)) -def test_add(enable_onnx_value_propagation, op): +def test_add(op): assert_equal_value(op.add(op.const(2), op.const(2)), numpy.int64(4)) -def test_div(enable_onnx_value_propagation, op): +def test_div(op): assert_equal_value(op.div(op.const(5.0), op.const(2.0)), numpy.float32(2.5)) -def test_identity(enable_onnx_value_propagation, op): +def test_identity(op): for x in [ 5, [1, 2, 3], @@ -79,21 +79,21 @@ def test_identity(enable_onnx_value_propagation, op): assert_equal_value(op.const(x), x) -def test_reshape(enable_onnx_value_propagation, op): +def test_reshape(op): assert_equal_value( op.reshape(op.const([1, 2, 3, 4]), op.const([2, 2])), [[1, 2], [3, 4]] ) -def test_optional(enable_onnx_value_propagation, op): +def test_optional(op): assert_equal_value(op.optional(op.const(2.0)), numpy.float32(2.0)) -def test_empty_optional(enable_onnx_value_propagation, op): +def test_empty_optional(op): assert_equal_value(op.optional(type=_type_system.Tensor(numpy.float32, ())), None) -def 
test_empty_optional_has_no_element(enable_onnx_value_propagation, op): +def test_empty_optional_has_no_element(op): assert_equal_value( op.optional_has_element( op.optional(type=_type_system.Tensor(numpy.float32, ())) @@ -102,18 +102,18 @@ def test_empty_optional_has_no_element(enable_onnx_value_propagation, op): ) -def test_sequence_empty(enable_onnx_value_propagation, op): +def test_sequence_empty(op): assert_equal_value(op.sequence_empty(dtype=numpy.float32), []) -def test_sequence_append(enable_onnx_value_propagation, op): +def test_sequence_append(op): emp = op.sequence_empty(dtype=numpy.int64) assert_equal_value( op.sequence_insert(op.sequence_insert(emp, op.const(2)), op.const(1)), [2, 1] ) -def test_with_reconstruct(enable_onnx_value_propagation, op): +def test_with_reconstruct(op): a, b = arguments( a=_type_system.Tensor(numpy.int64, ()), b=_type_system.Tensor(numpy.int64, ()), @@ -127,7 +127,7 @@ def test_with_reconstruct(enable_onnx_value_propagation, op): ) -def test_bad_reshape_raises(enable_onnx_value_propagation, op): +def test_bad_reshape_raises(op): op.reshape(op.const([1, 2]), op.const([2])) # sanity - with pytest.raises(onnxruntime.capi.onnxruntime_pybind11_state.RuntimeException): + with pytest.raises(Exception): op.reshape(op.const([1, 2, 3]), op.const([2]))
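
A minimal usage sketch of the value-propagation switch introduced in this patch (the `op` alias mirrors the test suite's use of `spox.opset.ai.onnx.v17` and its `const`/`reshape` helpers; the backend assignment targets the module-level `_VALUE_PROP_BACKEND` added in src/spox/_standard.py, and the printed output is illustrative):

    import numpy

    import spox.opset.ai.onnx.v17 as op
    from spox import _standard
    from spox._standard import ValuePropBackend

    # Select the backend read by StandardNode.propagate_values_onnx:
    # NONE disables propagation, REFERENCE (the default) uses onnx.reference,
    # and ONNXRUNTIME requires the optional onnxruntime dependency.
    _standard._VALUE_PROP_BACKEND = ValuePropBackend.ONNXRUNTIME

    x = op.const(numpy.array([1, 2, 3, 4], dtype=numpy.int64))
    y = op.reshape(x, op.const(numpy.array([2, 2], dtype=numpy.int64)))

    # The propagated value is a PropValue pairing the computed array with its Spox type.
    assert y._value is not None
    print(y._value.value)  # -> [[1 2]
                           #     [3 4]]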