Skip to content

Commit

Permalink
Upgrade to ONNX 1.13
Browse files Browse the repository at this point in the history
* Bump ONNX versions

* Bump ORT version for improved test coverage

* Fix generate script importlib usage on 3.11

The previous usage seemingly relied on undefined behaviour. `importlib.resources.path` is also deprecated as of Python 3.11, but its replacement (`importlib.resources.files`) is only available from Python 3.9, which we do not yet require.

* Regenerate ai.onnx@17

Small changes to docstrings

* Add rtol capability to testing routine

* Run more function tests after ORT fixes

There are still very odd build errors on more complex cases, due to colliding (pseudo-random?) identifiers.

* Use non-deprecated onnx.mapping alternative

* Avoid usage of onnx.mapping in utils

`onnx.mapping` is being deprecated, so it is no longer used here.
This brings up the issue of `object` versus `str` as the string dtype in Spox.

* Integrate reference implementation of operators

* Remove ORT value prop

* Minor membership test fix

* Create module for propagated value wrapper

* Fix type-hints around PropValue abstraction

* Apply wrapper more often and fix ref integration

* Patch found broken implementations

* Add switch for value propagation runtime

* Add more consistency checks

* Be more strict with PropValue types

* Update src/spox/_standard.py

Co-authored-by: Christian Bourjau <[email protected]>

* Comment on value representation cases

Co-authored-by: Jakub Bachurski <[email protected]>
Co-authored-by: Christian Bourjau <[email protected]>

Co-authored-by: Jakub Bachurski <[email protected]>
Co-authored-by: Christian Bourjau <[email protected]>
  • Loading branch information
3 people authored Dec 20, 2022
1 parent 7647b62 commit a6c7cb0
Show file tree
Hide file tree
Showing 20 changed files with 390 additions and 194 deletions.
2 changes: 1 addition & 1 deletion conda.recipe/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ requirements:
run:
- python >=3.8
- numpy
- onnx
- onnx >=1.13

test:
requires:
Expand Down
4 changes: 2 additions & 2 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ dependencies:
- matplotlib
- numpy
- numpydoc
- onnx
- onnxruntime>=1.12.0
- onnx>=1.13.0
- onnxruntime>=1.13.1
- pip
- pre-commit
- pytest>=6
Expand Down
5 changes: 3 additions & 2 deletions src/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@
onnx.AttributeProto.TYPE_PROTOS: "AttrTypes",
}

_TEMPLATE_DIR = Path(str(importlib.resources.path("spox", "."))).parent / "templates"
with importlib.resources.path("spox", ".") as path:
_TEMPLATE_DIR = path.parent / "templates"


@dataclass
Expand Down Expand Up @@ -153,7 +154,7 @@ def _get_default_value(attr, attr_type_overrides) -> Optional[str]:
attr.name in attr_type_overrides
and attr_type_overrides[attr.name][0] == "npt.DTypeLike"
):
return f"np.{onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[default].name}"
return f"np.{onnx.helper.tensor_dtype_to_np_dtype(default).name}"

# Strings are bytes at this point and they
# need to be wrapped in quotes.
Expand Down
4 changes: 4 additions & 0 deletions src/spox/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,14 @@
except ModuleNotFoundError:
pass

from spox import _patch_ref_impl

# Public interface
from spox._type_system import Optional, Sequence, Tensor, Type
from spox._var import Var

_patch_ref_impl.patch_reference_implementations()

__all__ = [
"Var",
"Type",
Expand Down
2 changes: 1 addition & 1 deletion src/spox/_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ def to_onnx_model(
"Consider adding an Identity operator if you are just copying arguments."
)

opset_req: list[tuple[str, int]] = list(opsets.items()) # type: ignore
opset_req: List[tuple[str, int]] = list(opsets.items()) # type: ignore
function_protos: Dict[Tuple[str, str], onnx.FunctionProto] = {}
for fun in self._get_build_result().functions:
proto = fun.to_onnx_function(extra_opset_req=opset_req)
Expand Down
20 changes: 17 additions & 3 deletions src/spox/_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from ._attributes import AttrGraph
from ._exceptions import InferenceWarning
from ._type_system import Type
from ._value_prop import PropValue, PropValueType
from ._var import Var
from ._varfields import VarFields

Expand Down Expand Up @@ -224,7 +225,7 @@ def pre_init(self, **kwargs):
def post_init(self, **kwargs):
"""Post-initialization hook. Called at the end of ``__init__`` after other default fields are set."""

def propagate_values(self) -> Dict[str, Any]:
def propagate_values(self) -> Dict[str, PropValueType]:
"""
Propagate values from inputs, and, if possible, compute values for outputs as well.
This method is used to implement ONNX partial data propagation - for example so that
Expand Down Expand Up @@ -308,7 +309,7 @@ def _list_types(self, source):
yield name, value_type

def _init_output_vars(
self, types: Dict[str, Type], values: Dict[str, Any]
self, types: Dict[str, Type], values: Dict[str, PropValueType]
) -> VarFields:
"""
Initialize empty output vars bound to this Node and return the respective Fields object.
Expand All @@ -317,7 +318,20 @@ def _init_output_vars(
"""

def arr(name):
return Var(self, types.get(name), values.get(name))
typ: Optional[Type] = types.get(name)
val: Optional[PropValue]
if typ is not None and name in values:
val = PropValue(typ, values.get(name))
if not val.check():
warnings.warn(
InferenceWarning(
f"PropValue of {val.value} does not match the expected type {val.type}, dropping."
)
)
val = None
else:
val = None
return Var(self, typ, val)

var = self.Outputs.get_variadic_name()
outputs: Dict[str, Union[Var, Sequence[Var]]] = {
Expand Down
28 changes: 28 additions & 0 deletions src/spox/_patch_ref_impl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import numpy
import onnx
import onnx.reference.op_run
import onnx.reference.ops._op_list
import onnx.reference.ops.op_cast
from onnx.reference.op_run import OpRun


class PatchedOptionalHasElement(OpRun):
    """Replacement ``OptionalHasElement`` operator for the ONNX reference runtime."""

    def _run(self, x):
        """Return a 1-tuple holding a 0-d boolean array.

        The array is false when ``x`` is ``None`` or is exactly the list ``[None]``,
        and true for every other input.
        """
        # Both spellings of "no element" are treated as an empty optional.
        is_empty = x is None or (isinstance(x, list) and x == [None])
        has_element = not is_empty
        return (numpy.array(has_element),)


class PatchedCast(OpRun):
    """Replacement ``Cast`` operator for the ONNX reference runtime."""

    def _run(self, x, to=None):  # type: ignore
        """Cast ``x`` to the tensor element type ``to`` and return it in a 1-tuple.

        Every target type except ``STRING`` is delegated to the stock ``cast_to``
        helper; ``STRING`` is handled here by converting to ``numpy.str_``.
        """
        if to != onnx.TensorProto.STRING:
            converted = onnx.reference.ops.op_cast.cast_to(x, to)
            return (converted,)
        # Special case: strings are produced with the numpy ``str_`` dtype.
        return (x.astype(numpy.str_),)


def patch_reference_implementations():
    """Swap in patched operators for faulty ONNX reference implementations.

    The reference implementation module in ONNX is fairly new and still has bugs
    that are a nuisance in Spox. This rebinds the affected entries of
    ``onnx.reference.ops._op_list`` to the patched classes, which special-case the
    inputs known to be handled incorrectly.
    """
    replacements = {
        "OptionalHasElement": PatchedOptionalHasElement,
        "Cast": PatchedCast,
    }
    for op_name, patched_cls in replacements.items():
        setattr(onnx.reference.ops._op_list, op_name, patched_cls)
112 changes: 54 additions & 58 deletions src/spox/_standard.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,37 @@
"""Module implementing a base for standard ONNX operators, which use the functionality of ONNX node-level inference."""

import enum
import typing
from typing import Any, Dict, Tuple, Union
from typing import Dict, Tuple

import numpy
import onnx
import onnx.reference
import onnx.shape_inference
from onnx.defs import OpSchema

from . import _value_prop
from ._exceptions import InferenceError
from ._node import Node
from ._schemas import SCHEMAS
from ._scope import Scope
from ._shape import SimpleShape
from ._type_system import Optional, Sequence, Tensor, Type
from ._utils import from_array
from ._var import Nothing, _nil
from ._value_prop import PropValue, PropValueType
from ._var import _nil

if typing.TYPE_CHECKING:
from ._graph import Graph

try:
import onnxruntime
except ImportError:
onnxruntime = None # type: ignore

_USE_ONNXRUNTIME_VALUE_PROP = False
class ValuePropBackend(enum.Enum):
    """Selects how ``StandardNode`` performs value propagation.

    The active member is stored in the module-level ``_VALUE_PROP_BACKEND``.
    """

    # Value propagation is disabled: ``StandardNode.propagate_values`` returns an empty dict.
    NONE = 0
    # Evaluate the singleton model via ``_value_prop._run_reference_implementation``.
    REFERENCE = 1
    # Evaluate the singleton model via ``_value_prop._run_onnxruntime``.
    ONNXRUNTIME = 2


_VALUE_PROP_BACKEND: ValuePropBackend = ValuePropBackend.REFERENCE


class StandardNode(Node):
Expand Down Expand Up @@ -53,9 +59,13 @@ def min_output(self) -> int:
return self.schema.min_output

def to_singleton_onnx_model(
self, *, dummy_outputs: bool = True, with_subgraphs: bool = True
self, *, dummy_outputs: bool = True, with_dummy_subgraphs: bool = True
) -> Tuple[onnx.ModelProto, Scope]:
"""Build a singleton model consisting of just this StandardNode. Used for type inference."""
"""
Build a singleton model consisting of just this StandardNode. Used for type inference.
Dummy subgraphs are typed, but have no graph body, so that we can avoid the build cost.
They refer to non-existent nodes, but ONNX does not raise an error (for now?).
"""
# Prepare names for the values in scope of the node
scope = Scope()
scope.node[self] = "_this_"
Expand All @@ -81,7 +91,7 @@ def to_singleton_onnx_model(
# Subgraphs are not fully built for possibly significant performance gains.
# However, this uses a trick so that they type correctly.
# This may throw if we are building ``not with_dummy_subgraphs``.
build_subgraph = _make_dummy_subgraph if with_subgraphs else None
build_subgraph = _make_dummy_subgraph if with_dummy_subgraphs else None
(node_proto,) = self.to_onnx(scope, build_subgraph=build_subgraph)
finally:
self.attrs = self_attrs
Expand All @@ -107,9 +117,9 @@ def out_value_info(curr_key, curr_var):
# Initializers, passed in to allow partial data propagation
# - used so that operators like Reshape are aware of constant shapes
initializers = [
from_array(var._value, key)
from_array(var._value.value, key)
for key, var in self.inputs.as_dict().items()
if isinstance(var._value, numpy.ndarray)
if var._value and isinstance(var._value.value, numpy.ndarray)
]
# Graph and model
graph = onnx.helper.make_graph(
Expand Down Expand Up @@ -157,45 +167,52 @@ def infer_output_types_onnx(self) -> Dict[str, Type]:
# Strips some unuseful type data (unknown dimensions become global-scoped dimension parameters).
return {key: _strip_unk_param(type_) for key, type_ in results.items()}

def propagate_values_onnx(self) -> Dict[str, Any]:
"""
Perform value propagation by evaluating singleton models with ONNX Runtime.
def propagate_values_onnx(self) -> Dict[str, PropValueType]:
"""Perform value propagation by evaluating singleton model.
Assumes onnxruntime was imported successfully. Does not support subgraphs.
The backend used for the propagation is selected by the module-level `spox._standard._VALUE_PROP_BACKEND` variable (a `ValuePropBackend` member).
"""
if any(var and var._value is None for var in self.inputs.as_dict().values()):
# Cannot do propagation when some inputs were not propagated
# Cannot do propagation when some inputs were not propagated/inferred
if any(
var and (var.type is None or var._value is None)
for var in self.inputs.as_dict().values()
):
return {}
if next(iter(self.subgraphs), None) is not None:
# Cannot do propagation with subgraphs implicitly for performance - should be reimplemented
return {}
# Silence possible warnings during execution (especially constant folding)
options = onnxruntime.SessionOptions()
options.log_severity_level = 3
# Set everything up for evaluation
model, scope = self.to_singleton_onnx_model(with_subgraphs=False)
session = onnxruntime.InferenceSession(model.SerializeToString(), options)
if _VALUE_PROP_BACKEND == ValuePropBackend.REFERENCE:
wrap_feed = PropValue.to_ref_value
run = _value_prop._run_reference_implementation
unwrap_feed = PropValue.from_ref_value
elif _VALUE_PROP_BACKEND == ValuePropBackend.ONNXRUNTIME:
wrap_feed = PropValue.to_ort_value
run = _value_prop._run_onnxruntime
unwrap_feed = PropValue.from_ort_value
else:
raise RuntimeError(
f"Not a valid value propagation backend: {_VALUE_PROP_BACKEND}."
)
model, scope = self.to_singleton_onnx_model(with_dummy_subgraphs=False)
input_feed = {
scope.var[var]: _value_prop_to_ort(var._value)
scope.var[var]: wrap_feed(var._value)
for var in self.inputs.as_dict().values()
if var and var._value
}
# Get outputs and give a map from output field names
output_feed = dict(zip(session.get_outputs(), session.run(None, input_feed)))
return {
scope.var[output.name]._which_output: _value_prop_from_ort(result)
for output, result in output_feed.items()
output_feed = run(model, input_feed)
results = {
scope.var[str(name)]
._which_output: unwrap_feed(scope.var[str(name)].unwrap_type(), result)
.value
for name, result in output_feed.items()
}
return {k: v for k, v in results.items() if k is not None}

def infer_output_types(self) -> Dict[str, Type]:
return self.infer_output_types_onnx()

def propagate_values(self) -> Dict[str, Any]:
if _USE_ONNXRUNTIME_VALUE_PROP:
if onnxruntime is None:
raise RuntimeError(
"Cannot use ONNX Runtime value prop when ONNX Runtime isn't available "
"(ImportError was raised)."
)
def propagate_values(self) -> Dict[str, PropValueType]:
if _VALUE_PROP_BACKEND != ValuePropBackend.NONE:
return self.propagate_values_onnx()
return {}

Expand Down Expand Up @@ -250,24 +267,3 @@ def _make_dummy_subgraph(_node: Node, key: str, graph: "Graph") -> onnx.GraphPro
nodes.append(onnx.helper.make_node("Identity", [outer], [out]))

return onnx.helper.make_graph(nodes, f"__dummy_{key}", inputs, outputs)


def _value_prop_to_ort(value) -> Union[numpy.ndarray, list, None]:
if value is Nothing:
return None
return value


def _value_prop_from_ort(value: Union[numpy.ndarray, list, None]):
if value is None:
return Nothing
elif isinstance(value, list):
return [_value_prop_from_ort(elem) for elem in value]
elif isinstance(value, numpy.ndarray):
# This looks ridiculous, but is required to normalise numpy.longlong back into a fixed size type.
# ORT sometimes returns non-sized types (like longlong) and Var's value typecheck will fail because of it.
# - numpy.dtype(longlong).type is longlong, but
# - numpy.dtype(longlong) == numpy.dtype(int64), while
# - longlong != int64
return value.astype(numpy.dtype(value.dtype.name))
raise TypeError(f"Cannot handle ORT value: {value}")
2 changes: 1 addition & 1 deletion src/spox/_type_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def __le__(self, other: "Type") -> bool:
"""
if not isinstance(other, Type):
return NotImplemented
return self == Type() or other == Type() or self == other
return other == Type() or self == other


@dataclass(frozen=True)
Expand Down
30 changes: 14 additions & 16 deletions src/spox/_utils.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,16 @@
from typing import Dict, Optional
from typing import Optional

import numpy as np
import numpy.typing as npt
import onnx
from onnx import TensorProto
from onnx.helper import make_tensor, mapping

_DTYPE_TO_TENSOR_TYPE: Dict[np.dtype, int] = {
**{
dtype: ttype
for dtype, ttype in mapping.NP_TYPE_TO_TENSOR_TYPE.items()
if dtype != np.object_
},
np.dtype(str): TensorProto.STRING,
}

_TENSOR_TYPE_TO_DTYPE = {ttype: dtype for dtype, ttype in _DTYPE_TO_TENSOR_TYPE.items()}


def tensor_type_to_dtype(ttype: int) -> np.dtype:
"""Convert integer tensor types to ``numpy.dtype`` objects."""
return _TENSOR_TYPE_TO_DTYPE[ttype]
if ttype == onnx.TensorProto.STRING:
return np.dtype(str) # Spox uses the str datatype for strings, not object
return onnx.helper.tensor_dtype_to_np_dtype(ttype)


def dtype_to_tensor_type(dtype_like: npt.DTypeLike) -> int:
Expand All @@ -35,8 +26,15 @@ def dtype_to_tensor_type(dtype_like: npt.DTypeLike) -> int:
raise TypeError(err_msg)
# normalize in the case of aliases like ``long`` which are missing in the lookup
dtype = np.dtype(np.dtype(dtype_like).type)
if dtype == np.dtype(object):
raise TypeError(
"`np.dtype('object')` is not supported as a tensor element type. "
"Hint: Spox uses `np.dtype('str')` for the string datatype."
)
elif dtype == np.dtype(str):
return onnx.TensorProto.STRING
try:
return _DTYPE_TO_TENSOR_TYPE[dtype]
return onnx.helper.np_dtype_to_tensor_dtype(dtype)
except KeyError:
raise TypeError(err_msg)

Expand All @@ -53,7 +51,7 @@ def from_array(arr: np.ndarray, name: Optional[str] = None) -> TensorProto:
cast_to_bytes = False
if arr.dtype.type in [np.str_, np.object_]:
cast_to_bytes = True
return make_tensor(
return onnx.helper.make_tensor(
name=name or "",
data_type=dtype_to_tensor_type(arr.dtype),
dims=arr.shape,
Expand Down
Loading

0 comments on commit a6c7cb0

Please sign in to comment.