Skip to content

Commit

Permalink
Patch type inference for OneHot
Browse files Browse the repository at this point in the history
  • Loading branch information
JakubBachurskiQC authored Jan 12, 2023
1 parent f729274 commit 1121c6f
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,7 +552,7 @@ def main(
"ai.onnx",
17,
extras=["const"],
type_inference={"Compress": "compress11"},
type_inference={"Compress": "compress11", "OneHot": "onehot11"},
value_propagation={"Constant": "constant13"},
out_variadic_solutions=V16_OUT_VARIADIC_SOLUTIONS,
subgraphs_solutions=V16_SUBGRAPH_SOLUTIONS,
Expand Down
29 changes: 29 additions & 0 deletions src/spox/opset/ai/onnx/v17.py
Original file line number Diff line number Diff line change
Expand Up @@ -2041,6 +2041,35 @@ class Inputs(VarFields):
class Outputs(VarFields):
output: Var

def infer_output_types(self) -> Dict[str, Type]:
self.infer_output_types_onnx()
if not (
self.inputs.indices.type
and self.inputs.depth.type
and self.inputs.values.type
):
return {}
indices = self.inputs.indices.unwrap_tensor()
depth = self.inputs.depth.unwrap_tensor()
values = self.inputs.values.unwrap_tensor()
if depth.shape is not None and len(depth.shape) != 0:
raise InferenceError("Number of classes must be a scalar.")
if values.shape is not None and len(values.shape) != 1:
raise InferenceError("Number of values must be a vector (of length 2).")
if indices.shape is not None:
axis = self.attrs.axis.value
if not (-len(indices.shape) - 1 <= axis <= len(indices.shape)):
raise InferenceError(
f"Attribute axis={axis} out of range [-r-1, r] for indices rank r={len(indices.shape)}."
)
if axis < 0:
# + 1 because slices on negatives are still right-open
axis += len(indices.shape) + 1
shape = indices.shape[:axis] + (None,) + indices.shape[axis:]
else:
shape = None
return {"output": Tensor(values.dtype, shape)}

op_type = OpType("OneHot", "", 11)

attrs: Attributes
Expand Down
21 changes: 21 additions & 0 deletions src/templates/type_inference/onehot11.jinja2
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
self.infer_output_types_onnx()
if not (self.inputs.indices.type and self.inputs.depth.type and self.inputs.values.type):
return {}
indices = self.inputs.indices.unwrap_tensor()
depth = self.inputs.depth.unwrap_tensor()
values = self.inputs.values.unwrap_tensor()
if depth.shape is not None and len(depth.shape) != 0:
raise InferenceError("Number of classes must be a scalar.")
if values.shape is not None and len(values.shape) != 1:
raise InferenceError("Number of values must be a vector (of length 2).")
if indices.shape is not None:
axis = self.attrs.axis.value
if not (-len(indices.shape) - 1 <= axis <= len(indices.shape)):
raise InferenceError(f"Attribute axis={axis} out of range [-r-1, r] for indices rank r={len(indices.shape)}.")
if axis < 0:
# + 1 because slices on negatives are still right-open
axis += len(indices.shape) + 1
shape = indices.shape[:axis] + (None,) + indices.shape[axis:]
else:
shape = None
return {'output': Tensor(values.dtype, shape)}
46 changes: 46 additions & 0 deletions tests/type_inference/test_one_hot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import numpy
import pytest

import spox.opset.ai.onnx.v17 as op
from spox._graph import arguments
from spox._standard import InferenceError
from spox._type_system import Tensor


def test_one_hot_inference():
x, y, z = arguments(
x=Tensor(int, ("N", "M")), y=Tensor(int, ()), z=Tensor(float, (2,))
)
assert op.one_hot(x, y, z).unwrap_tensor() == Tensor(float, ("N", "M", None))
assert op.one_hot(x, y, z, axis=0).unwrap_tensor() == Tensor(
float, (None, "N", "M")
)
assert op.one_hot(x, y, z, axis=1).unwrap_tensor() == Tensor(
float, ("N", None, "M")
)
assert op.one_hot(x, y, z, axis=-1).unwrap_tensor() == Tensor(
float, ("N", "M", None)
)
assert op.one_hot(x, y, z, axis=-2).unwrap_tensor() == Tensor(
float, ("N", None, "M")
)


def test_one_hot_inference_checks_depth_scalar():
with pytest.raises(InferenceError):
op.one_hot(op.const([]), op.const([1]), op.const([0, 1]))


def test_one_hot_inference_checks_values_vector():
with pytest.raises(InferenceError):
op.one_hot(op.const([]), op.const(1), op.const(numpy.array([[0], [1]])))


def test_one_hot_inference_checks_axis_in_range():
x, y, z = arguments(
x=Tensor(int, ("N", "M")), y=Tensor(int, ()), z=Tensor(float, (2,))
)
with pytest.raises(InferenceError):
assert op.one_hot(x, y, z, axis=-4)
with pytest.raises(InferenceError):
assert op.one_hot(x, y, z, axis=3)

0 comments on commit 1121c6f

Please sign in to comment.