Skip to content

Commit

Permalink
[Relax] Update ONNX frontend for unique, nonzero and compress
Browse files Browse the repository at this point in the history
This PR updates the ONNX frontend:

- Add match cast for unique and nonzero operators, enabling further import of ONNX models.
- Add support for compress operator.
- Fix the shape of the output tensor for nonzero operator.
  • Loading branch information
Hzfengsy committed Nov 12, 2024
1 parent d5b9f5c commit 166a482
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 9 deletions.
52 changes: 49 additions & 3 deletions python/tvm/relax/frontend/onnx/onnx_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -833,6 +833,32 @@ def _impl_v18(cls, bb, inputs, attr, params):
return relax.op.scatter_nd(inputs[0], inputs[1], inputs[2], reduction)


class Compress(OnnxOpConverter):
"""Convert an onnx Compress node into an equivalent Relax expression."""

@classmethod
def _impl_v11(cls, bb, inputs, attr, params):
tensor, condition = inputs
axis = attr.get("axis", None)

# Change one hot tensor to indices e.g. [0, 1, 1, 0, 1] -> [1, 2, 4]
if condition.struct_info.dtype != "bool":
raise ValueError("Condition tensor is expected to be a boolean tensor")
if condition.struct_info.ndim != 1:
raise ValueError("Condition tensor is expected to be a 1D boolean tensor")
indices = relax.op.nonzero(condition)
num_nonzero = tir.Var("num_nonzero", "int64")
indices = bb.match_cast(indices, relax.TensorStructInfo([1, num_nonzero], "int64"))
indices = relax.op.reshape(indices, [-1])

if axis is not None:
return relax.op.take(tensor, indices, axis=axis)

# if axis is None, flatten input tensor before selection
tensor = relax.op.reshape(tensor, (-1,))
return relax.op.take(tensor, indices, axis=0)


class Size(OnnxOpConverter):
"""Convert an onnx Size node into an equivalent Relax expression."""

Expand Down Expand Up @@ -2726,15 +2752,35 @@ def _impl_v11(cls, bb, inputs, attr, params):
axis = attr.get("axis", None)
sorted = bool(attr.get("sorted", 1))
# TODO(tvm-team): Add support for return_index, return_inverse, return_counts
return relax.op.unique(data, sorted=sorted, axis=axis)
unique = relax.op.unique(data, sorted=sorted, axis=axis)
unique_numbers = tir.Var("unique_numbers", "int64")
input_shape = data.struct_info.shape
dtype = data.struct_info.dtype

if axis is None:
# flatten the input tensor
return bb.match_cast(unique, relax.TensorStructInfo((unique_numbers,), dtype))

axis = axis if axis >= 0 else len(input_shape) + axis
if axis < 0 or axis >= len(input_shape):
raise ValueError(f"Axis {axis} is out of bounds")
output_shape = [
input_shape[i] if i != axis else unique_numbers for i in range(len(input_shape))
]
return bb.match_cast(unique, relax.TensorStructInfo(output_shape, dtype))


class NonZero(OnnxOpConverter):
"""Converts an onnx NonZero node into an equivalent Relax expression."""

@classmethod
def _impl_v9(cls, bb, inputs, attr, params):
return relax.op.nonzero(inputs[0])
ndim = inputs[0].struct_info.ndim
ndim = 1 if ndim == 0 else ndim
nonzero_numbers = tir.Var("nonzero_numbers", "int64")
return bb.match_cast(
relax.op.nonzero(inputs[0]), relax.TensorStructInfo((ndim, nonzero_numbers), "int64")
)


class HardSigmoid(OnnxOpConverter):
Expand Down Expand Up @@ -3075,7 +3121,7 @@ def _get_convert_map():
"Scatter": Scatter,
"ScatterElements": ScatterElements,
"ScatterND": ScatterND,
# "Compress": Compress,
"Compress": Compress,
"Size": Size,
"EyeLike": EyeLike,
# Normalization
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relax/op/set.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def nonzero(x: Expr) -> Expr:
Returns
-------
result : relax.Expr
A (n+1)-D tensor containing indices of non-zero elements.
A 2-D tensor containing indices of non-zero elements.
Note
----
Expand Down
4 changes: 1 addition & 3 deletions src/relax/op/tensor/set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,7 @@ TVM_REGISTER_GLOBAL("relax.op.nonzero").set_body_typed(nonzero);

StructInfo InferStructInfoNonzero(const Call& call, const BlockBuilder& ctx) {
TensorStructInfo data_sinfo = GetInputTensorStructInfo(call, 0, ctx);
// Cheat zero dim scalar as 1-dim.
int dim = data_sinfo->IsUnknownNdim() ? kUnknownNDim : std::max(1, data_sinfo->ndim) + 1;
return TensorStructInfo(DataType::Int(64), dim, data_sinfo->vdevice);
return TensorStructInfo(DataType::Int(64), 2, data_sinfo->vdevice);
}

TVM_REGISTER_OP("relax.nonzero")
Expand Down
30 changes: 29 additions & 1 deletion tests/python/relax/test_frontend_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,6 +601,34 @@ def verify_scatter_nd(data_shape, indices_shape, updates_shape):
verify_scatter_nd([10], [5, 1], [5])


@pytest.mark.parametrize("tensor_shape", [[32, 32]])
@pytest.mark.parametrize("condition_shape", [None, [8], [16]])
@pytest.mark.parametrize("axis", [None, 0, 1])
def test_compress(
tensor_shape: List[int],
condition_shape: Optional[List[int]],
axis: Optional[int],
):
if condition_shape is None and axis is None:
pytest.skip("Either condition_shape or axis must be specified")
if condition_shape is None:
condition_shape = [tensor_shape[axis]]
compress_node = helper.make_node("Compress", ["tensor", "condition"], ["output"], axis=axis)
graph = helper.make_graph(
[compress_node],
"compress_test",
inputs=[
helper.make_tensor_value_info("tensor", TensorProto.FLOAT, tensor_shape),
helper.make_tensor_value_info("condition", TensorProto.BOOL, condition_shape),
],
outputs=[
helper.make_tensor_value_info("output", TensorProto.FLOAT, [])
], # shape is unknown
)
model = helper.make_model(graph, producer_name="compress_test")
check_correctness(model, opset=11)


def test_size():
test_node = helper.make_node("Size", ["x"], ["y"])
graph = helper.make_graph(
Expand Down Expand Up @@ -2478,7 +2506,7 @@ def test_unique(axis: Optional[int], sorted: int):
check_correctness(model)


@pytest.mark.parametrize("shape", [(), (1,), (2, 3), (4, 5, 6)])
@pytest.mark.parametrize("shape", [(), (1,), (2, 3), (4, 5, 6), (7, 8, 9, 10)])
def test_nonzero(shape):
verify_unary("NonZero", shape, input_dtype=TensorProto.BOOL, output_dtype=TensorProto.INT64)

Expand Down
2 changes: 1 addition & 1 deletion tests/python/relax/test_op_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -875,7 +875,7 @@ def test_nonzero_infer_struct_info(shape):
_check_inference(
bb,
relax.op.nonzero(x0),
relax.TensorStructInfo(ndim=len(shape) + 1, dtype="int64"),
relax.TensorStructInfo(ndim=2, dtype="int64"),
)


Expand Down

0 comments on commit 166a482

Please sign in to comment.