diff --git a/src/generate.py b/src/generate.py index 3aba2a8a..133f6767 100644 --- a/src/generate.py +++ b/src/generate.py @@ -552,7 +552,7 @@ def main( "ai.onnx", 17, extras=["const"], - type_inference={"Compress": "compress11"}, + type_inference={"Compress": "compress11", "OneHot": "onehot11"}, value_propagation={"Constant": "constant13"}, out_variadic_solutions=V16_OUT_VARIADIC_SOLUTIONS, subgraphs_solutions=V16_SUBGRAPH_SOLUTIONS, diff --git a/src/spox/opset/ai/onnx/v17.py b/src/spox/opset/ai/onnx/v17.py index 93067ceb..66e4ab1a 100644 --- a/src/spox/opset/ai/onnx/v17.py +++ b/src/spox/opset/ai/onnx/v17.py @@ -2041,6 +2041,35 @@ class Inputs(VarFields): class Outputs(VarFields): output: Var + def infer_output_types(self) -> Dict[str, Type]: + self.infer_output_types_onnx() + if not ( + self.inputs.indices.type + and self.inputs.depth.type + and self.inputs.values.type + ): + return {} + indices = self.inputs.indices.unwrap_tensor() + depth = self.inputs.depth.unwrap_tensor() + values = self.inputs.values.unwrap_tensor() + if depth.shape is not None and len(depth.shape) != 0: + raise InferenceError("Number of classes must be a scalar.") + if values.shape is not None and len(values.shape) != 1: + raise InferenceError("Number of values must be a vector (of length 2).") + if indices.shape is not None: + axis = self.attrs.axis.value + if not (-len(indices.shape) - 1 <= axis <= len(indices.shape)): + raise InferenceError( + f"Attribute axis={axis} out of range [-r-1, r] for indices rank r={len(indices.shape)}." + ) + if axis < 0: + # + 1 because slices on negatives are still right-open + axis += len(indices.shape) + 1 + shape = indices.shape[:axis] + (None,) + indices.shape[axis:] + else: + shape = None + return {"output": Tensor(values.dtype, shape)} + op_type = OpType("OneHot", "", 11) attrs: Attributes diff --git a/src/templates/type_inference/onehot11.jinja2 b/src/templates/type_inference/onehot11.jinja2 new file mode 100644 index 00000000..7fd911c0 --- /dev/null +++ b/src/templates/type_inference/onehot11.jinja2 @@ -0,0 +1,21 @@ +self.infer_output_types_onnx() +if not (self.inputs.indices.type and self.inputs.depth.type and self.inputs.values.type): + return {} +indices = self.inputs.indices.unwrap_tensor() +depth = self.inputs.depth.unwrap_tensor() +values = self.inputs.values.unwrap_tensor() +if depth.shape is not None and len(depth.shape) != 0: + raise InferenceError("Number of classes must be a scalar.") +if values.shape is not None and len(values.shape) != 1: + raise InferenceError("Number of values must be a vector (of length 2).") +if indices.shape is not None: + axis = self.attrs.axis.value + if not (-len(indices.shape) - 1 <= axis <= len(indices.shape)): + raise InferenceError(f"Attribute axis={axis} out of range [-r-1, r] for indices rank r={len(indices.shape)}.") + if axis < 0: + # + 1 because slices on negatives are still right-open + axis += len(indices.shape) + 1 + shape = indices.shape[:axis] + (None,) + indices.shape[axis:] +else: + shape = None +return {'output': Tensor(values.dtype, shape)} \ No newline at end of file diff --git a/tests/type_inference/test_one_hot.py b/tests/type_inference/test_one_hot.py new file mode 100644 index 00000000..6eb6524e --- /dev/null +++ b/tests/type_inference/test_one_hot.py @@ -0,0 +1,46 @@ +import numpy +import pytest + +import spox.opset.ai.onnx.v17 as op +from spox._graph import arguments +from spox._standard import InferenceError +from spox._type_system import Tensor + + +def test_one_hot_inference(): + x, y, z = arguments( + x=Tensor(int, ("N", "M")), y=Tensor(int, ()), z=Tensor(float, (2,)) + ) + assert op.one_hot(x, y, z).unwrap_tensor() == Tensor(float, ("N", "M", None)) + assert op.one_hot(x, y, z, axis=0).unwrap_tensor() == Tensor( + float, (None, "N", "M") + ) + assert op.one_hot(x, y, z, axis=1).unwrap_tensor() == Tensor( + float, ("N", None, "M") + ) + assert op.one_hot(x, y, z, axis=-1).unwrap_tensor() == Tensor( + float, ("N", "M", None) + ) + assert op.one_hot(x, y, z, axis=-2).unwrap_tensor() == Tensor( + float, ("N", None, "M") + ) + + +def test_one_hot_inference_checks_depth_scalar(): + with pytest.raises(InferenceError): + op.one_hot(op.const([]), op.const([1]), op.const([0, 1])) + + +def test_one_hot_inference_checks_values_vector(): + with pytest.raises(InferenceError): + op.one_hot(op.const([]), op.const(1), op.const(numpy.array([[0], [1]]))) + + +def test_one_hot_inference_checks_axis_in_range(): + x, y, z = arguments( + x=Tensor(int, ("N", "M")), y=Tensor(int, ()), z=Tensor(float, (2,)) + ) + with pytest.raises(InferenceError): + assert op.one_hot(x, y, z, axis=-4) + with pytest.raises(InferenceError): + assert op.one_hot(x, y, z, axis=3)