Skip to content

Commit

Permalink
[RELAX][ONNX][FIX] add a parser to handle expression in the shape dim…
Browse files Browse the repository at this point in the history
… names (#17505)

* added a simple parser that can handle onnx variable names containing
expressions such as past_sequence_length + sequence_length where each
variable becomes a tvm.tir.SizeVar

* added doc strings

updated binary base to completely unpack relax.PrimValue if it contains
tir.IntImm or tir.FloatImm

added regression tests

* updated formatting
  • Loading branch information
PatrikPerssonInceptron authored Nov 10, 2024
1 parent 3e386fd commit d5b9f5c
Show file tree
Hide file tree
Showing 2 changed files with 296 additions and 16 deletions.
109 changes: 93 additions & 16 deletions python/tvm/relax/frontend/onnx/onnx_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
github.com/apache/tvm/issues if you hit an error with dynamic kernels.
"""
import math
import operator
import re
import warnings
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

Expand Down Expand Up @@ -101,6 +103,83 @@ def get_constant(
return var


def get_value(token, value_dict: Dict[str, tvm.tir.SizeVar]) -> Union[int, tvm.tir.SizeVar]:
"""Converts to token to an integer value if it a constant, otherwise it generates a SizeVar
Parameters
----------
token: str
current token to decode.
value_dict: Dict
The Dictionary mapping from the name of ValueInfoProto to SizeVar.
Returns
-------
Union[int, tvm.tir.SizeVar]
The decoded token
"""

try:
return int(token)
except ValueError:
if token not in value_dict or token == "?":
value_dict[token] = tvm.tir.SizeVar(token, "int64")
value = value_dict[token]
return value


def parse_shape_name(
name: str, value_dict: Dict[str, tvm.tir.SizeVar]
) -> Union[tir.PrimExpr, tvm.tir.SizeVar]:
"""Converts expressions in the shape dimension name to prim expressions.
Parameters
----------
name: str
name of shape dimension.
value_dict: Dict
The Dictionary mapping from the name of ValueInfoProto to SizeVar.
Returns
-------
Union[tir.PrimExpr, tvm.tir.SizeVar]
The expression of the shape dimension.
"""

tokens = re.split(r"(\+|\-|\*|\/\/|\/)", name.replace(" ", ""))

operators = {
"+": operator.add,
"-": operator.sub,
"*": operator.mul,
"/": operator.floordiv, # is floordiv since the operands are always int
"//": operator.floordiv,
}

value_stack = []
operator_stack = []

for token in tokens:
if token in operators:
operator_stack.append(token)
else:
value = get_value(token, value_dict)
if value_stack and operator_stack:
prev_value = value_stack.pop()
op = operator_stack.pop()
result = operators[op](prev_value, value)
value_stack.append(result)
else:
value_stack.append(value)

if value_stack:
return value_stack[0]
else:
raise Exception("Shape dimension could not be inferred")


def get_info(
info_proto: onnx.onnx_ml_pb2.ValueInfoProto, value_dict: Dict[str, tvm.tir.SizeVar]
) -> Tuple[str, List, str, List, Dict]:
Expand All @@ -126,9 +205,7 @@ def get_info(
name = dim.dim_param
value = dim.dim_value
if value is None or value == 0:
if name not in value_dict or name == "?":
value_dict[name] = tvm.tir.SizeVar(name, "int64")
value = value_dict[name]
value = parse_shape_name(name, value_dict)
shape_name.append(name)
else:
shape_name.append(value)
Expand All @@ -145,9 +222,7 @@ def get_info(
def get_numpy(tensor_proto: onnx.onnx_ml_pb2.TensorProto) -> _np.ndarray:
"""Grab data in TensorProto and convert to numpy array."""
try:
from onnx.numpy_helper import ( # pylint: disable=import-outside-toplevel
to_array,
)
from onnx.numpy_helper import to_array # pylint: disable=import-outside-toplevel
except ImportError as exception:
raise ImportError("Unable to import onnx which is required {}".format(exception))
return to_array(tensor_proto)
Expand Down Expand Up @@ -237,6 +312,16 @@ def _impl_v13(cls, bb, inputs, attr, params):
return relax.op.matmul(inputs[0], inputs[1])


def _to_numpy(x):
if isinstance(x, relax.PrimValue):
x = x.value
if isinstance(x, (tir.IntImm, tir.FloatImm)):
x = x.value
return _np.array(x)
else:
return x.data.numpy()


class BinaryBase(OnnxOpConverter):
"""Converts an onnx BinaryBase node into an equivalent Relax expression."""

Expand All @@ -254,16 +339,8 @@ def base_impl(cls, bb, inputs, attr, params):
)
return relax.const(output, inputs[0].struct_info.dtype)
if any([isinstance(inp, relax.PrimValue) for inp in inputs]):
x = (
_np.array(inputs[0].value)
if isinstance(inputs[0], relax.PrimValue)
else inputs[0].data.numpy()
)
y = (
_np.array(inputs[1].value)
if isinstance(inputs[1], relax.PrimValue)
else inputs[1].data.numpy()
)
x = _to_numpy(inputs[0])
y = _to_numpy(inputs[1])
return relax.PrimValue(cls.numpy_op(x, y)) # pylint: disable=not-callable

return cls.relax_op(inputs[0], inputs[1]) # pylint: disable=not-callable
Expand Down
203 changes: 203 additions & 0 deletions tests/python/relax/test_frontend_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from tvm.relax.frontend.onnx import from_onnx
from tvm.script import relax as R
from tvm.script import tir as T
from tvm.script import ir as I

bg = np.random.MT19937(0)
rg = np.random.Generator(bg)
Expand Down Expand Up @@ -2752,5 +2753,207 @@ def test_params_names_start_with_onnx():
check_correctness(model)


def test_shape_dim_string_expression():
def _verify(x_shape, example_shape):

identity_node = helper.make_node("Identity", ["x"], ["y"])

graph = helper.make_graph(
[identity_node],
"test_var_shape_dim_containing_expressions_onnx",
inputs=[
helper.make_tensor_value_info("x", TensorProto.FLOAT, x_shape),
],
outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, x_shape)],
)
model = helper.make_model(
graph, producer_name="test_var_shape_dim_containing_expressions_onnx"
)

inputs = {"x": generate_random_value(example_shape, TensorProto.FLOAT)}
check_correctness(model, inputs)

_verify(["A", "B", "A + B"], [3, 9, 12])
_verify(["A", "B", "A - B"], [9, 3, 6])
_verify(["A", "B", "A * B"], [9, 3, 27])
_verify(["A", "B", "A // B"], [9, 3, 3])


def test_shape_dim_string_expression_graph_add():

identity_node = helper.make_node("Identity", ["x"], ["y"])

x_shape = ["A", "B", "A + B"]

graph = helper.make_graph(
[identity_node],
"test_var_shape_dim_containing_expressions_onnx",
inputs=[
helper.make_tensor_value_info("x", TensorProto.FLOAT, x_shape),
],
outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, x_shape)],
)
model = helper.make_model(graph, producer_name="test_var_shape_dim_containing_expressions_onnx")

tvm_model = from_onnx(model, opset=14, keep_params_in_input=True)

# fmt: off
@I.ir_module
class Expected:
@R.function
def main(x: R.Tensor(("A", "B", "A + B"), dtype="float32")) -> R.Tensor(("A", "B", "A + B"), dtype="float32"):
A = T.int64(is_size_var=True)
B = T.int64(is_size_var=True)
R.func_attr({"num_input": 1})
with R.dataflow():
gv: R.Tensor((A, B, A + B), dtype="float32") = x
R.output(gv)
return gv
# fmt: on

tvm.ir.assert_structural_equal(tvm_model, Expected)


def test_shape_dim_string_expression_graph_subtract():

identity_node = helper.make_node("Identity", ["x"], ["y"])

x_shape = ["A", "B", "A - B"]

graph = helper.make_graph(
[identity_node],
"test_var_shape_dim_containing_expressions_onnx",
inputs=[
helper.make_tensor_value_info("x", TensorProto.FLOAT, x_shape),
],
outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, x_shape)],
)
model = helper.make_model(graph, producer_name="test_var_shape_dim_containing_expressions_onnx")

tvm_model = from_onnx(model, opset=14, keep_params_in_input=True)

# fmt: off
@I.ir_module
class Expected:
@R.function
def main(x: R.Tensor(("A", "B", "A - B"), dtype="float32")) -> R.Tensor(("A", "B", "A - B"), dtype="float32"):
A = T.int64(is_size_var=True)
B = T.int64(is_size_var=True)
R.func_attr({"num_input": 1})
with R.dataflow():
gv: R.Tensor((A, B, A - B), dtype="float32") = x
R.output(gv)
return gv
# fmt: on

tvm.ir.assert_structural_equal(tvm_model, Expected)


def test_shape_dim_string_expression_graph_mul():

identity_node = helper.make_node("Identity", ["x"], ["y"])

x_shape = ["A", "B", "A * B"]

graph = helper.make_graph(
[identity_node],
"test_var_shape_dim_containing_expressions_onnx",
inputs=[
helper.make_tensor_value_info("x", TensorProto.FLOAT, x_shape),
],
outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, x_shape)],
)
model = helper.make_model(graph, producer_name="test_var_shape_dim_containing_expressions_onnx")

tvm_model = from_onnx(model, opset=14, keep_params_in_input=True)

# fmt: off
@I.ir_module
class Expected:
@R.function
def main(x: R.Tensor(("A", "B", "A * B"), dtype="float32")) -> R.Tensor(("A", "B", "A * B"), dtype="float32"):
A = T.int64(is_size_var=True)
B = T.int64(is_size_var=True)
R.func_attr({"num_input": 1})
with R.dataflow():
gv: R.Tensor((A, B, A * B), dtype="float32") = x
R.output(gv)
return gv
# fmt: on

tvm.ir.assert_structural_equal(tvm_model, Expected)


def test_shape_dim_string_expression_graph_div_1():

identity_node = helper.make_node("Identity", ["x"], ["y"])

# this will result in a floordiv despite not using // since the operands are always int
x_shape = ["A", "B", "A / B"]

graph = helper.make_graph(
[identity_node],
"test_var_shape_dim_containing_expressions_onnx",
inputs=[
helper.make_tensor_value_info("x", TensorProto.FLOAT, x_shape),
],
outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, x_shape)],
)
model = helper.make_model(graph, producer_name="test_var_shape_dim_containing_expressions_onnx")

tvm_model = from_onnx(model, opset=14, keep_params_in_input=True)

# fmt: off
@I.ir_module
class Expected:
@R.function
def main(x: R.Tensor(("A", "B", "A // B"), dtype="float32")) -> R.Tensor(("A", "B", "A // B"), dtype="float32"):
A = T.int64(is_size_var=True)
B = T.int64(is_size_var=True)
R.func_attr({"num_input": 1})
with R.dataflow():
gv: R.Tensor((A, B, A // B), dtype="float32") = x
R.output(gv)
return gv
# fmt: on

tvm.ir.assert_structural_equal(tvm_model, Expected)


def test_shape_dim_string_expression_graph_div_2():

identity_node = helper.make_node("Identity", ["x"], ["y"])

x_shape = ["A", "B", "A // B"]

graph = helper.make_graph(
[identity_node],
"test_var_shape_dim_containing_expressions_onnx",
inputs=[
helper.make_tensor_value_info("x", TensorProto.FLOAT, x_shape),
],
outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, x_shape)],
)
model = helper.make_model(graph, producer_name="test_var_shape_dim_containing_expressions_onnx")

tvm_model = from_onnx(model, opset=14, keep_params_in_input=True)

# fmt: off
@I.ir_module
class Expected:
@R.function
def main(x: R.Tensor(("A", "B", "A // B"), dtype="float32")) -> R.Tensor(("A", "B", "A // B"), dtype="float32"):
A = T.int64(is_size_var=True)
B = T.int64(is_size_var=True)
R.func_attr({"num_input": 1})
with R.dataflow():
gv: R.Tensor((A, B, A // B), dtype="float32") = x
R.output(gv)
return gv
# fmt: on

tvm.ir.assert_structural_equal(tvm_model, Expected)


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit d5b9f5c

Please sign in to comment.