Skip to content

Commit

Permalink
Add ONNX test
Browse files Browse the repository at this point in the history
(cherry picked from commit 5bf3be2)
  • Loading branch information
KodiaqQ committed Nov 5, 2024
1 parent 640efc8 commit 5dc2b42
Show file tree
Hide file tree
Showing 5 changed files with 192 additions and 0 deletions.
12 changes: 12 additions & 0 deletions nncf/onnx/graph/metatypes/onnx_metatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,6 +677,18 @@ class ONNXErfMetatype(ONNXOpMetatype):
op_names = ["Erf"]


@ONNX_OPERATION_METATYPES.register()
class ONNXCosMetatype(ONNXOpMetatype):
    """Operator metatype matching ONNX nodes whose op type is ``Cos``."""

    op_names = ["Cos"]
    name = "CosOp"


@ONNX_OPERATION_METATYPES.register()
class ONNXSinMetatype(ONNXOpMetatype):
    """Operator metatype matching ONNX nodes whose op type is ``Sin``."""

    op_names = ["Sin"]
    name = "SinOp"


def get_operator_metatypes() -> List[Type[OperatorMetatype]]:
"""
Returns a list of the operator metatypes.
Expand Down
22 changes: 22 additions & 0 deletions nncf/onnx/quantization/ignored_patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,3 +157,25 @@ def create_se_block() -> GraphPattern:
multiply_node = node_id
pattern.add_edge(non_pattern_node, multiply_node)
return pattern


@ONNX_IGNORED_PATTERNS.register(IgnoredPatternNames.ROPE)
def create_rope() -> GraphPattern:
    """
    Builds the graph pattern registered under ``IgnoredPatternNames.ROPE``.

    The pattern is the chain ``MatMul -> Transpose -> Concat`` whose Concat
    output fans out to both a ``Cos`` and a ``Sin`` node.

    :return: The constructed graph pattern.
    """
    rope_pattern = GraphPattern()

    # Nodes are added in this exact order: MATMUL, TRANSPOSE, CONCAT, COS, SIN.
    node_specs = (
        ("MATMUL", om.ONNXMatMulMetatype),
        ("TRANSPOSE", om.ONNXTransposeMetatype),
        ("CONCAT", om.ONNXConcatMetatype),
        ("COS", om.ONNXCosMetatype),
        ("SIN", om.ONNXSinMetatype),
    )
    matmul, transpose, concat, cos, sin = [
        rope_pattern.add_node(**{GraphPattern.LABEL_ATTR: label, GraphPattern.METATYPE_ATTR: metatype})
        for label, metatype in node_specs
    ]

    # Linear chain first, then the two branches leaving the Concat node.
    for source, target in (
        (matmul, transpose),
        (transpose, concat),
        (concat, cos),
        (concat, sin),
    ):
        rope_pattern.add_edge(source, target)

    return rope_pattern
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
strict digraph {
"0 unsqueeze" [id=0, type=Unsqueeze];
"1 cast" [id=1, type=Cast];
"2 reshape" [id=2, type=Reshape];
"3 matmul" [id=3, type=MatMul];
"4 transpose" [id=4, type=Transpose];
"5 concat" [id=5, type=Concat];
"6 sin" [id=6, type=Sin];
"7 cos" [id=7, type=Cos];
"8 nncf_model_input_0" [id=8, type=nncf_model_input];
"9 nncf_model_output_0" [id=9, type=nncf_model_output];
"10 nncf_model_output_1" [id=10, type=nncf_model_output];
"0 unsqueeze" -> "1 cast" [label="[1, 10, 1]", style=dashed];
"1 cast" -> "3 matmul" [label="[1, 10, 1]", style=solid];
"2 reshape" -> "3 matmul" [label="[1, 5]", style=solid];
"3 matmul" -> "4 transpose" [label="[1, 10, 5]", style=solid];
"4 transpose" -> "5 concat" [label="[1, 5, 10]", style=solid];
"5 concat" -> "6 sin" [label="[1, 5, 10]", style=solid];
"5 concat" -> "7 cos" [label="[1, 5, 10]", style=solid];
"6 sin" -> "9 nncf_model_output_0" [label="[1, 5, 10]", style=solid];
"7 cos" -> "10 nncf_model_output_1" [label="[1, 5, 10]", style=solid];
"8 nncf_model_input_0" -> "0 unsqueeze" [label="[1, 10]", style=dashed];
}
123 changes: 123 additions & 0 deletions tests/onnx/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1757,3 +1757,126 @@ def __init__(self):
model = onnx.helper.make_model(graph_def, opset_imports=[op])
onnx.checker.check_model(model)
super().__init__(model, [input_shape], "unified_embedding_model.dot")


class RoPEModel(ONNXReferenceModel):
    """
    Synthetic ONNX model with the MatMul -> Transpose -> Concat -> Sin/Cos tail.

    Data flow (shapes taken from the reference graph ``rope_model.dot``)::

        model_in [1, 10] (INT64)
          -> Unsqueeze(axes=[2])        -> [1, 10, 1]
          -> Cast(to=FLOAT)             -> [1, 10, 1]
          -> MatMul(., Reshape(const))  -> [1, 10, 5]
          -> Transpose(perm=[0, 2, 1])  -> [1, 5, 10]
          -> Concat(axis=-1)            -> [1, 5, 10]
          -> Sin -> sin_out, Cos -> cos_out
    """

    def __init__(self):
        float_type = onnx.TensorProto.FLOAT
        int64_type = onnx.TensorProto.INT64
        generator = np.random.default_rng(seed=0)

        # ---- Tensor names -------------------------------------------------
        input_name = "model_in"
        sin_out_name = "sin_out"
        cos_out_name = "cos_out"
        axes_name = "un_tensor"
        unsqueezed_name = "un_out"
        casted_name = "cast_out"
        target_shape_name = "re_shape"
        weights_name = "re_tensor"
        reshaped_name = "re_out"
        matmul_name = "mm_out"
        transposed_name = "trans_out"
        concatenated_name = "concat_out"

        # ---- Graph inputs / outputs ---------------------------------------
        input_shape = [1, 10]
        output_shape = [1, 5, 10]
        graph_input = onnx.helper.make_tensor_value_info(input_name, int64_type, input_shape)
        sin_output = onnx.helper.make_tensor_value_info(sin_out_name, float_type, output_shape)
        cos_output = onnx.helper.make_tensor_value_info(cos_out_name, float_type, output_shape)

        # ---- Initializers -------------------------------------------------
        # Axis along which Unsqueeze inserts the new dimension.
        axes_init = create_initializer_tensor(
            name=axes_name,
            tensor_array=np.int64([2]),
            data_type=int64_type,
        )
        # Target shape that turns the 5-element constant into a [1, 5] matrix.
        target_shape_init = create_initializer_tensor(
            name=target_shape_name,
            tensor_array=np.array([1, 5], dtype=np.int64),
            data_type=int64_type,
        )
        # Seeded random constant, so the model is reproducible between runs.
        weights_init = create_initializer_tensor(
            name=weights_name,
            tensor_array=generator.uniform(0, 1, 5).astype(np.float32),
            data_type=float_type,
        )

        # ---- Nodes, listed in the order they are placed in the graph ------
        make_node = onnx.helper.make_node
        graph_nodes = [
            make_node(
                op_type="Unsqueeze",
                name="unsqueeze",
                inputs=[input_name, axes_name],
                outputs=[unsqueezed_name],
            ),
            make_node(
                op_type="Cast",
                name="cast",
                inputs=[unsqueezed_name],
                outputs=[casted_name],
                to=float_type,
            ),
            make_node(
                op_type="Reshape",
                name="reshape",
                inputs=[weights_name, target_shape_name],
                outputs=[reshaped_name],
            ),
            make_node(
                op_type="MatMul",
                name="matmul",
                inputs=[casted_name, reshaped_name],
                outputs=[matmul_name],
            ),
            make_node(
                op_type="Transpose",
                name="transpose",
                inputs=[matmul_name],
                outputs=[transposed_name],
                perm=[0, 2, 1],
            ),
            # Single-input Concat: present so the graph contains a Concat node
            # between Transpose and Sin/Cos.
            make_node(
                op_type="Concat",
                name="concat",
                inputs=[transposed_name],
                outputs=[concatenated_name],
                axis=-1,
            ),
            make_node(
                op_type="Sin",
                name="sin",
                inputs=[concatenated_name],
                outputs=[sin_out_name],
            ),
            make_node(
                op_type="Cos",
                name="cos",
                inputs=[concatenated_name],
                outputs=[cos_out_name],
            ),
        ]

        graph_def = onnx.helper.make_graph(
            nodes=graph_nodes,
            name="RoPEModel",
            inputs=[graph_input],
            outputs=[sin_output, cos_output],
            initializer=[axes_init, weights_init, target_shape_init],
        )

        opset = onnx.OperatorSetIdProto()
        opset.version = OPSET_VERSION
        model = onnx.helper.make_model(graph_def, opset_imports=[opset])
        onnx.checker.check_model(model)
        super().__init__(model, [input_shape], "rope_model.dot")
12 changes: 12 additions & 0 deletions tests/onnx/quantization/test_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@

import pytest

from nncf.parameters import ModelType
from tests.onnx.models import ALL_SYNTHETIC_MODELS
from tests.onnx.models import RoPEModel
from tests.onnx.quantization.common import compare_nncf_graph
from tests.onnx.quantization.common import min_max_quantize_model
from tests.onnx.quantization.common import mock_collect_statistics
Expand All @@ -23,3 +25,13 @@ def test_synthetic_models_graph(model_cls_to_test, mocker):
model_to_test = model_cls_to_test()
quantized_model = min_max_quantize_model(model_to_test.onnx_model)
compare_nncf_graph(quantized_model, "synthetic/" + model_to_test.path_ref_graph)


@pytest.mark.parametrize("model_cls_to_test", [RoPEModel])
def test_synthetic_models_graph_transformer(model_cls_to_test, mocker):
    """
    Quantizes each synthetic model with ``ModelType.TRANSFORMER`` and compares
    the resulting NNCF graph with the stored reference under ``synthetic/``.
    """
    mock_collect_statistics(mocker)

    reference_model = model_cls_to_test()
    transformer_params = {"model_type": ModelType.TRANSFORMER}
    quantized = min_max_quantize_model(reference_model.onnx_model, quantization_params=transformer_params)

    reference_graph_path = "synthetic/" + reference_model.path_ref_graph
    compare_nncf_graph(quantized, reference_graph_path)

0 comments on commit 5dc2b42

Please sign in to comment.