From 80df08b19d5ba1f0a8b44362865297604aad96a8 Mon Sep 17 00:00:00 2001
From: Nikita Malinin
Date: Thu, 7 Nov 2024 16:13:45 +0100
Subject: [PATCH] Added ignored RoPE for ONNX/PT (#3059)

On top of https://github.com/openvinotoolkit/nncf/pull/3049

### Changes

- Added a RoPE (rotary position embedding) ignored pattern for the ONNX and PyTorch backends.
- Added Sin and Cos operation metatypes for ONNX/PT.

### Reason for changes

- Accuracy improvement for models quantized with ModelType.TRANSFORMER.

### Related tickets

- 155511

### Tests

- Added graph tests for ONNX/PT.
---
 nncf/onnx/graph/metatypes/onnx_metatypes.py |  12 ++
 nncf/onnx/quantization/ignored_patterns.py  |  22 ++++
 nncf/torch/graph/operator_metatypes.py      |  12 ++
 nncf/torch/quantization/ignored_patterns.py |  20 +++
 tests/cross_fw/test_templates/helpers.py    |  19 +++
 .../quantization/synthetic/rope_model.dot   |  23 ++++
 tests/onnx/models.py                        | 123 ++++++++++++++++++
 tests/onnx/quantization/test_graphs.py      |  12 ++
 tests/onnx/test_pattern_manager.py          |   1 -
 .../quantized/ptq/symmetric/rope_model.dot  |  21 +++
 tests/torch/ptq/test_graphs.py              |   3 +
 tests/torch/test_pattern_manager.py         |   1 -
 12 files changed, 267 insertions(+), 2 deletions(-)
 create mode 100644 tests/onnx/data/reference_graphs/quantization/synthetic/rope_model.dot
 create mode 100644 tests/torch/data/reference_graphs/quantized/ptq/symmetric/rope_model.dot

diff --git a/nncf/onnx/graph/metatypes/onnx_metatypes.py b/nncf/onnx/graph/metatypes/onnx_metatypes.py
index 9e45490cd92..075b2597643 100644
--- a/nncf/onnx/graph/metatypes/onnx_metatypes.py
+++ b/nncf/onnx/graph/metatypes/onnx_metatypes.py
@@ -677,6 +677,18 @@ class ONNXErfMetatype(ONNXOpMetatype):
     op_names = ["Erf"]
 
 
+@ONNX_OPERATION_METATYPES.register()
+class ONNXCosMetatype(ONNXOpMetatype):
+    name = "CosOp"
+    op_names = ["Cos"]
+
+
+@ONNX_OPERATION_METATYPES.register()
+class ONNXSinMetatype(ONNXOpMetatype):
+    name = "SinOp"
+    op_names = ["Sin"]
+
+
 def get_operator_metatypes() -> List[Type[OperatorMetatype]]:
     """
     Returns a list of the operator metatypes.
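For context, the subgraph matched by the ignored pattern introduced below is the positional-angle branch of rotary position embeddings. A minimal PyTorch sketch of that branch (shapes loosely follow the HuggingFace LLaMA implementation; `rope_angles`, `dim`, and `base` are illustrative names, not NNCF code):

```python
import torch

def rope_angles(position_ids: torch.Tensor, dim: int = 64, base: float = 10000.0):
    # Illustrative only: mirrors the MatMul -> Transpose -> Concat -> Cos/Sin
    # chain that the new ignored pattern matches.
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    inv_freq = inv_freq[None, :, None]       # [1, dim/2, 1]
    pos = position_ids[:, None, :].float()   # [batch, 1, seq]
    freqs = torch.matmul(inv_freq, pos)      # MatMul    -> [batch, dim/2, seq]
    freqs = freqs.transpose(1, 2)            # Transpose -> [batch, seq, dim/2]
    emb = torch.cat((freqs, freqs), dim=-1)  # Concat    -> [batch, seq, dim]
    return emb.cos(), emb.sin()              # Cos / Sin stay un-quantized
```

Quantizing the inputs of `cos`/`sin` distorts the angles, and hence every position encoding downstream, which motivates excluding this chain from quantization under `ModelType.TRANSFORMER`.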
diff --git a/nncf/onnx/quantization/ignored_patterns.py b/nncf/onnx/quantization/ignored_patterns.py
index 782ecdf7c3e..8e1255f3ca1 100644
--- a/nncf/onnx/quantization/ignored_patterns.py
+++ b/nncf/onnx/quantization/ignored_patterns.py
@@ -157,3 +157,25 @@ def create_se_block() -> GraphPattern:
             multiply_node = node_id
     pattern.add_edge(non_pattern_node, multiply_node)
     return pattern
+
+
+@ONNX_IGNORED_PATTERNS.register(IgnoredPatternNames.ROPE)
+def create_rope() -> GraphPattern:
+    pattern = GraphPattern()
+    matmul_node = pattern.add_node(
+        **{GraphPattern.LABEL_ATTR: "MATMUL", GraphPattern.METATYPE_ATTR: om.ONNXMatMulMetatype}
+    )
+    transpose_node = pattern.add_node(
+        **{GraphPattern.LABEL_ATTR: "TRANSPOSE", GraphPattern.METATYPE_ATTR: om.ONNXTransposeMetatype}
+    )
+    concat_node = pattern.add_node(
+        **{GraphPattern.LABEL_ATTR: "CONCAT", GraphPattern.METATYPE_ATTR: om.ONNXConcatMetatype}
+    )
+    cos_node = pattern.add_node(**{GraphPattern.LABEL_ATTR: "COS", GraphPattern.METATYPE_ATTR: om.ONNXCosMetatype})
+    sin_node = pattern.add_node(**{GraphPattern.LABEL_ATTR: "SIN", GraphPattern.METATYPE_ATTR: om.ONNXSinMetatype})
+
+    pattern.add_edge(matmul_node, transpose_node)
+    pattern.add_edge(transpose_node, concat_node)
+    pattern.add_edge(concat_node, cos_node)
+    pattern.add_edge(concat_node, sin_node)
+    return pattern
diff --git a/nncf/torch/graph/operator_metatypes.py b/nncf/torch/graph/operator_metatypes.py
index 8eda5611049..86c1b8f689c 100644
--- a/nncf/torch/graph/operator_metatypes.py
+++ b/nncf/torch/graph/operator_metatypes.py
@@ -1110,6 +1110,18 @@ class PTScaledDotProductAttentionMetatype(PTOperatorMetatype):
     target_input_ports = [0, 1]
 
 
+@PT_OPERATOR_METATYPES.register()
+class PTCosMetatype(PTOperatorMetatype):
+    name = "CosOp"
+    module_to_function_names = {NamespaceTarget.TORCH: ["cos"]}
+
+
+@PT_OPERATOR_METATYPES.register()
+class PTSinMetatype(PTOperatorMetatype):
+    name = "SinOp"
+    module_to_function_names = {NamespaceTarget.TORCH: ["sin"]}
+
+
 def get_operator_metatypes() -> List[Type[OperatorMetatype]]:
     """
     Returns a list of the operator metatypes.
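A quick structural sanity check of the registered ONNX pattern — illustrative, not part of the patch, and assuming `GraphPattern` exposes its underlying networkx graph via the `graph` property:

```python
# Hedged sketch: run against this branch to inspect the new pattern's shape.
from nncf.onnx.quantization.ignored_patterns import create_rope

pattern = create_rope()
# Five nodes (MATMUL, TRANSPOSE, CONCAT, COS, SIN) and four edges,
# with CONCAT fanning out to both COS and SIN.
assert len(pattern.graph.nodes) == 5
assert len(pattern.graph.edges) == 4
```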
diff --git a/nncf/torch/quantization/ignored_patterns.py b/nncf/torch/quantization/ignored_patterns.py
index 895849c244e..e5bd1d93e16 100644
--- a/nncf/torch/quantization/ignored_patterns.py
+++ b/nncf/torch/quantization/ignored_patterns.py
@@ -227,3 +227,23 @@ def get_se_block_with_bias_and_reshape() -> GraphPattern:
     main_pattern.add_pattern_alternative(get_se_block_with_reshape())
     main_pattern.add_pattern_alternative(get_se_block_with_bias_and_reshape())
     return main_pattern
+
+
+@PT_IGNORED_PATTERNS.register(IgnoredPatternNames.ROPE)
+def create_rope() -> GraphPattern:
+    pattern = GraphPattern()
+    matmul_node = pattern.add_node(
+        **{GraphPattern.LABEL_ATTR: "MATMUL", GraphPattern.METATYPE_ATTR: om.PTMatMulMetatype}
+    )
+    transpose_node = pattern.add_node(
+        **{GraphPattern.LABEL_ATTR: "TRANSPOSE", GraphPattern.METATYPE_ATTR: om.PTTransposeMetatype}
+    )
+    concat_node = pattern.add_node(**{GraphPattern.LABEL_ATTR: "CONCAT", GraphPattern.METATYPE_ATTR: om.PTCatMetatype})
+    cos_node = pattern.add_node(**{GraphPattern.LABEL_ATTR: "COS", GraphPattern.METATYPE_ATTR: om.PTCosMetatype})
+    sin_node = pattern.add_node(**{GraphPattern.LABEL_ATTR: "SIN", GraphPattern.METATYPE_ATTR: om.PTSinMetatype})
+
+    pattern.add_edge(matmul_node, transpose_node)
+    pattern.add_edge(transpose_node, concat_node)
+    pattern.add_edge(concat_node, cos_node)
+    pattern.add_edge(concat_node, sin_node)
+    return pattern
diff --git a/tests/cross_fw/test_templates/helpers.py b/tests/cross_fw/test_templates/helpers.py
index d578428b912..2bb21dac6d9 100644
--- a/tests/cross_fw/test_templates/helpers.py
+++ b/tests/cross_fw/test_templates/helpers.py
@@ -466,3 +466,22 @@ def __init__(self):
 
     def forward(self, x):
         return self.conv(x)
+
+
+class RoPEModel(nn.Module):
+    INPUT_SIZE = [1, 10]
+
+    def __init__(self):
+        super().__init__()
+        with set_torch_seed():
+            self.data = torch.randn([5])
+
+    def forward(self, x):
+        x = torch.unsqueeze(x, dim=0)
+        reshape = torch.reshape(self.data, [1, 5, 1])
+        x = torch.matmul(reshape, x)
+        x = torch.transpose(x, 2, 1)
+        x = torch.cat([x], dim=2)
+        x1 = x.sin()
+        x2 = x.cos()
+        return x1, x2
diff --git a/tests/onnx/data/reference_graphs/quantization/synthetic/rope_model.dot b/tests/onnx/data/reference_graphs/quantization/synthetic/rope_model.dot
new file mode 100644
index 00000000000..84ddf95e3de
--- /dev/null
+++ b/tests/onnx/data/reference_graphs/quantization/synthetic/rope_model.dot
@@ -0,0 +1,23 @@
+strict digraph {
+"0 unsqueeze" [id=0, type=Unsqueeze];
+"1 cast" [id=1, type=Cast];
+"2 reshape" [id=2, type=Reshape];
+"3 matmul" [id=3, type=MatMul];
+"4 transpose" [id=4, type=Transpose];
+"5 concat" [id=5, type=Concat];
+"6 sin" [id=6, type=Sin];
+"7 cos" [id=7, type=Cos];
+"8 nncf_model_input_0" [id=8, type=nncf_model_input];
+"9 nncf_model_output_0" [id=9, type=nncf_model_output];
+"10 nncf_model_output_1" [id=10, type=nncf_model_output];
+"0 unsqueeze" -> "1 cast" [label="[1, 10, 1]", style=dashed];
+"1 cast" -> "3 matmul" [label="[1, 10, 1]", style=solid];
+"2 reshape" -> "3 matmul" [label="[1, 5]", style=solid];
+"3 matmul" -> "4 transpose" [label="[1, 10, 5]", style=solid];
+"4 transpose" -> "5 concat" [label="[1, 5, 10]", style=solid];
+"5 concat" -> "6 sin" [label="[1, 5, 10]", style=solid];
+"5 concat" -> "7 cos" [label="[1, 5, 10]", style=solid];
+"6 sin" -> "9 nncf_model_output_0" [label="[1, 5, 10]", style=solid];
+"7 cos" -> "10 nncf_model_output_1" [label="[1, 5, 10]", style=solid];
+"8 nncf_model_input_0" -> "0 unsqueeze" [label="[1, 10]", style=dashed];
+}
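End to end, the pattern is picked up through the regular quantization entry point. A hedged usage sketch (not part of the patch) built on the new test helper:

```python
import torch

import nncf
from tests.cross_fw.test_templates.helpers import RoPEModel

# With model_type=TRANSFORMER, the matched RoPE subgraph is excluded
# from quantizer placement; a single random sample serves as calibration here.
calibration = nncf.Dataset([torch.randn(RoPEModel.INPUT_SIZE)])
quantized = nncf.quantize(RoPEModel(), calibration, model_type=nncf.ModelType.TRANSFORMER)
```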
diff --git a/tests/onnx/models.py b/tests/onnx/models.py
index 7cdad6ba6ea..6ff99dced56 100644
--- a/tests/onnx/models.py
+++ b/tests/onnx/models.py
@@ -1757,3 +1757,126 @@ def __init__(self):
         model = onnx.helper.make_model(graph_def, opset_imports=[op])
         onnx.checker.check_model(model)
         super().__init__(model, [input_shape], "unified_embedding_model.dot")
+
+
+class RoPEModel(ONNXReferenceModel):
+    def __init__(self):
+        rng = np.random.default_rng(seed=0)
+
+        input_shape = [1, 10]
+        input_name = "model_in"
+        model_in = onnx.helper.make_tensor_value_info(input_name, onnx.TensorProto.INT64, input_shape)
+
+        output_shape = [1, 5, 10]
+        cos_out_name = "cos_out"
+        cos_out = onnx.helper.make_tensor_value_info(cos_out_name, onnx.TensorProto.FLOAT, output_shape)
+        sin_out_name = "sin_out"
+        sin_out = onnx.helper.make_tensor_value_info(sin_out_name, onnx.TensorProto.FLOAT, output_shape)
+
+        unsqueeze_out_name = "un_out"
+        unsqueeze_tensor_name = "un_tensor"
+        unsqueeze_tensor = create_initializer_tensor(
+            name=unsqueeze_tensor_name, tensor_array=np.int64([2]), data_type=onnx.TensorProto.INT64
+        )
+        unsqueeze_node = onnx.helper.make_node(
+            name="unsqueeze",
+            op_type="Unsqueeze",
+            inputs=[input_name, unsqueeze_tensor_name],
+            outputs=[unsqueeze_out_name],
+        )
+
+        cast_out_name = "cast_out"
+        cast_node = onnx.helper.make_node(
+            name="cast",
+            op_type="Cast",
+            to=onnx.TensorProto.FLOAT,
+            inputs=[unsqueeze_out_name],
+            outputs=[cast_out_name],
+        )
+
+        reshape_shape_name = "re_shape"
+        reshape_shape = create_initializer_tensor(
+            name=reshape_shape_name,
+            tensor_array=np.array([1, 5]).astype(np.int64),
+            data_type=onnx.TensorProto.INT64,
+        )
+        reshape_tensor_name = "re_tensor"
+        reshape_tensor = create_initializer_tensor(
+            name=reshape_tensor_name,
+            tensor_array=rng.uniform(0, 1, (5)).astype(np.float32),
+            data_type=onnx.TensorProto.FLOAT,
+        )
+        reshape_out_name = "re_out"
+        reshape_node = onnx.helper.make_node(
+            name="reshape",
+            op_type="Reshape",
+            inputs=[reshape_tensor_name, reshape_shape_name],
+            outputs=[reshape_out_name],
+        )
+
+        matmul_out_name = "mm_out"
+        matmul_node = onnx.helper.make_node(
+            name="matmul",
+            op_type="MatMul",
+            inputs=[cast_out_name, reshape_out_name],
+            outputs=[matmul_out_name],
+        )
+
+        transpose_out_name = "trans_out"
+        transpose_node = onnx.helper.make_node(
+            name="transpose",
+            op_type="Transpose",
+            inputs=[matmul_out_name],
+            outputs=[transpose_out_name],
+            perm=[0, 2, 1],
+        )
+
+        concat_out_name = "concat_out"
+        concat_node = onnx.helper.make_node(
+            name="concat",
+            op_type="Concat",
+            inputs=[transpose_out_name],
+            outputs=[concat_out_name],
+            axis=-1,
+        )
+
+        sin_node = onnx.helper.make_node(
+            name="sin",
+            op_type="Sin",
+            inputs=[concat_out_name],
+            outputs=[sin_out_name],
+        )
+
+        cos_node = onnx.helper.make_node(
+            name="cos",
+            op_type="Cos",
+            inputs=[concat_out_name],
+            outputs=[cos_out_name],
+        )
+
+        graph_def = onnx.helper.make_graph(
+            nodes=[
+                unsqueeze_node,
+                cast_node,
+                reshape_node,
+                matmul_node,
+                transpose_node,
+                concat_node,
+                sin_node,
+                cos_node,
+            ],
+            name="RoPEModel",
+            inputs=[model_in],
+            outputs=[sin_out, cos_out],
+            initializer=[
+                unsqueeze_tensor,
+                reshape_tensor,
+                reshape_shape,
+            ],
+        )
+
+        op = onnx.OperatorSetIdProto()
+        op.version = OPSET_VERSION
+        model = onnx.helper.make_model(graph_def, opset_imports=[op])
+        onnx.checker.check_model(model)
+        super().__init__(model, [input_shape], "rope_model.dot")
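The new ONNX reference model can also be smoke-tested outside the graph test, e.g. with `onnxruntime` (a sketch assuming `onnxruntime` is installed; the graph test itself does not run inference):

```python
import numpy as np
import onnxruntime as ort

from tests.onnx.models import RoPEModel

# Feed the INT64 "model_in" input and check the declared [1, 5, 10] output shapes;
# outputs come back in graph order: sin first, then cos.
sess = ort.InferenceSession(RoPEModel().onnx_model.SerializeToString())
sin_out, cos_out = sess.run(None, {"model_in": np.zeros((1, 10), dtype=np.int64)})
assert sin_out.shape == cos_out.shape == (1, 5, 10)
```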
diff --git a/tests/onnx/quantization/test_graphs.py b/tests/onnx/quantization/test_graphs.py
index 9eccc145b42..ecdcf1dd747 100644
--- a/tests/onnx/quantization/test_graphs.py
+++ b/tests/onnx/quantization/test_graphs.py
@@ -11,7 +11,9 @@
 
 import pytest
 
+from nncf.parameters import ModelType
 from tests.onnx.models import ALL_SYNTHETIC_MODELS
+from tests.onnx.models import RoPEModel
 from tests.onnx.quantization.common import compare_nncf_graph
 from tests.onnx.quantization.common import min_max_quantize_model
 from tests.onnx.quantization.common import mock_collect_statistics
@@ -23,3 +25,13 @@ def test_synthetic_models_graph(model_cls_to_test, mocker):
     model_to_test = model_cls_to_test()
     quantized_model = min_max_quantize_model(model_to_test.onnx_model)
     compare_nncf_graph(quantized_model, "synthetic/" + model_to_test.path_ref_graph)
+
+
+@pytest.mark.parametrize("model_cls_to_test", [RoPEModel])
+def test_synthetic_models_graph_transformer(model_cls_to_test, mocker):
+    mock_collect_statistics(mocker)
+    model_to_test = model_cls_to_test()
+    quantized_model = min_max_quantize_model(
+        model_to_test.onnx_model, quantization_params={"model_type": ModelType.TRANSFORMER}
+    )
+    compare_nncf_graph(quantized_model, "synthetic/" + model_to_test.path_ref_graph)
diff --git a/tests/onnx/test_pattern_manager.py b/tests/onnx/test_pattern_manager.py
index d463aed769f..46f5a6366a8 100644
--- a/tests/onnx/test_pattern_manager.py
+++ b/tests/onnx/test_pattern_manager.py
@@ -53,7 +53,6 @@
 IGNORING_IGNORED_PATTERN_REASONS = {
     IgnoredPatternNames.FC_BN_HSWISH_ACTIVATION: "Not relevant for ONNX.",
     IgnoredPatternNames.EQUAL_LOGICALNOT: "Not relevant for ONNX.",
-    IgnoredPatternNames.ROPE: "Not relevant for ONNX.",
 }
diff --git a/tests/torch/data/reference_graphs/quantized/ptq/symmetric/rope_model.dot b/tests/torch/data/reference_graphs/quantized/ptq/symmetric/rope_model.dot
new file mode 100644
index 00000000000..96b5a48c2e8
--- /dev/null
+++ b/tests/torch/data/reference_graphs/quantized/ptq/symmetric/rope_model.dot
@@ -0,0 +1,21 @@
+strict digraph {
+"0 /nncf_model_input_0" [id=0, type=nncf_model_input];
+"1 RoPEModel/unsqueeze_0" [id=1, type=unsqueeze];
+"2 RoPEModel/reshape_0" [id=2, type=reshape];
+"3 RoPEModel/matmul_0" [id=3, type=matmul];
+"4 RoPEModel/transpose_0" [id=4, type=transpose];
+"5 RoPEModel/cat_0" [id=5, type=cat];
+"6 RoPEModel/sin_0" [id=6, type=sin];
+"7 RoPEModel/cos_0" [id=7, type=cos];
+"8 /nncf_model_output_0" [id=8, type=nncf_model_output];
+"9 /nncf_model_output_1" [id=9, type=nncf_model_output];
+"0 /nncf_model_input_0" -> "1 RoPEModel/unsqueeze_0";
+"1 RoPEModel/unsqueeze_0" -> "3 RoPEModel/matmul_0";
+"2 RoPEModel/reshape_0" -> "3 RoPEModel/matmul_0";
+"3 RoPEModel/matmul_0" -> "4 RoPEModel/transpose_0";
+"4 RoPEModel/transpose_0" -> "5 RoPEModel/cat_0";
+"5 RoPEModel/cat_0" -> "6 RoPEModel/sin_0";
+"5 RoPEModel/cat_0" -> "7 RoPEModel/cos_0";
+"6 RoPEModel/sin_0" -> "8 /nncf_model_output_0";
+"7 RoPEModel/cos_0" -> "9 /nncf_model_output_1";
+}
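Note that the Torch reference graph above contains no FakeQuantize nodes at all: the entire synthetic model falls inside the ignored RoPE pattern. A hedged programmatic check (assumes `networkx` with `pydot` available, which this patch does not require):

```python
import networkx as nx

# No quantizer should appear anywhere in the rope_model reference graph,
# since every node of the model is covered by the ignored pattern.
path = "tests/torch/data/reference_graphs/quantized/ptq/symmetric/rope_model.dot"
graph = nx.nx_pydot.read_dot(path)
assert not any("quantize" in name.lower() for name in graph.nodes)
```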
diff --git a/tests/torch/ptq/test_graphs.py b/tests/torch/ptq/test_graphs.py
index 2dd649712f7..eba35163c7c 100644
--- a/tests/torch/ptq/test_graphs.py
+++ b/tests/torch/ptq/test_graphs.py
@@ -16,6 +16,7 @@
 import torch
 
 from nncf import Dataset
+from nncf.parameters import ModelType
 from nncf.parameters import TargetDevice
 from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters
 from nncf.quantization.algorithms.post_training.algorithm import PostTrainingQuantization
@@ -23,6 +24,7 @@
 from nncf.torch.layers import NNCF_RNN
 from nncf.torch.layers import LSTMCellNNCF
 from tests.cross_fw.test_templates.helpers import EmbeddingModel
+from tests.cross_fw.test_templates.helpers import RoPEModel
 from tests.cross_fw.test_templates.helpers import ScaledDotProductAttentionModel
 from tests.torch import test_models
 from tests.torch.quantization.test_algo_quantization import SharedLayersModel
@@ -50,6 +52,7 @@ def get_model_name(description):
 
 TEST_MODELS_DESC = [
     (ModelDesc("embedding_model", EmbeddingModel, [1, 10]), {}),
+    (ModelDesc("rope_model", RoPEModel, [1, 10]), {"model_type": ModelType.TRANSFORMER}),
     (
         ModelDesc(
             "scaled_dot_product_attention_model",
diff --git a/tests/torch/test_pattern_manager.py b/tests/torch/test_pattern_manager.py
index 6806f7052b6..bbf070f32f3 100644
--- a/tests/torch/test_pattern_manager.py
+++ b/tests/torch/test_pattern_manager.py
@@ -74,7 +74,6 @@
 IGNORING_IGNORED_PATTERN_REASONS = {
     IgnoredPatternNames.FC_BN_HSWISH_ACTIVATION: "Not relevant for Torch.",
     IgnoredPatternNames.EQUAL_LOGICALNOT: "Not relevant for Torch.",
-    IgnoredPatternNames.ROPE: "Not relevant for Torch.",
 }
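The new cases can be run locally with the usual pytest keyword selection (a sketch; the keywords are inferred from the parametrization above, not fixed test IDs):

```python
import pytest

# Torch: the rope_model entry of TEST_MODELS_DESC; ONNX: the new transformer graph test.
pytest.main(["tests/torch/ptq/test_graphs.py", "-k", "rope_model"])
pytest.main(["tests/onnx/quantization/test_graphs.py", "-k", "transformer"])
```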