From 640efc839cde9da434dea3f2d05a7b2326757de6 Mon Sep 17 00:00:00 2001 From: Nikita Malinin Date: Mon, 4 Nov 2024 14:15:17 +0100 Subject: [PATCH] Added Torch test (cherry picked from commit 28ba70545f404b2e5223346b4840298c4741b3b1) --- nncf/torch/graph/operator_metatypes.py | 12 +++++++++++ nncf/torch/quantization/ignored_patterns.py | 20 ++++++++++++++++++ tests/cross_fw/test_templates/helpers.py | 19 +++++++++++++++++ .../quantized/ptq/symmetric/rope_model.dot | 21 +++++++++++++++++++ tests/torch/ptq/test_graphs.py | 3 +++ 5 files changed, 75 insertions(+) create mode 100644 tests/torch/data/reference_graphs/quantized/ptq/symmetric/rope_model.dot diff --git a/nncf/torch/graph/operator_metatypes.py b/nncf/torch/graph/operator_metatypes.py index 8eda5611049..86c1b8f689c 100644 --- a/nncf/torch/graph/operator_metatypes.py +++ b/nncf/torch/graph/operator_metatypes.py @@ -1110,6 +1110,18 @@ class PTScaledDotProductAttentionMetatype(PTOperatorMetatype): target_input_ports = [0, 1] +@PT_OPERATOR_METATYPES.register() +class PTCosMetatype(PTOperatorMetatype): + name = "CosOp" + module_to_function_names = {NamespaceTarget.TORCH: ["cos"]} + + +@PT_OPERATOR_METATYPES.register() +class PTSinMetatype(PTOperatorMetatype): + name = "SinOp" + module_to_function_names = {NamespaceTarget.TORCH: ["sin"]} + + def get_operator_metatypes() -> List[Type[OperatorMetatype]]: """ Returns a list of the operator metatypes. diff --git a/nncf/torch/quantization/ignored_patterns.py b/nncf/torch/quantization/ignored_patterns.py index 895849c244e..e5bd1d93e16 100644 --- a/nncf/torch/quantization/ignored_patterns.py +++ b/nncf/torch/quantization/ignored_patterns.py @@ -227,3 +227,23 @@ def get_se_block_with_bias_and_reshape() -> GraphPattern: main_pattern.add_pattern_alternative(get_se_block_with_reshape()) main_pattern.add_pattern_alternative(get_se_block_with_bias_and_reshape()) return main_pattern + + +@PT_IGNORED_PATTERNS.register(IgnoredPatternNames.ROPE) +def create_rope() -> GraphPattern: + pattern = GraphPattern() + matmul_node = pattern.add_node( + **{GraphPattern.LABEL_ATTR: "MATMUL", GraphPattern.METATYPE_ATTR: om.PTMatMulMetatype} + ) + transpose_node = pattern.add_node( + **{GraphPattern.LABEL_ATTR: "TRANSPOSE", GraphPattern.METATYPE_ATTR: om.PTTransposeMetatype} + ) + concat_node = pattern.add_node(**{GraphPattern.LABEL_ATTR: "CONCAT", GraphPattern.METATYPE_ATTR: om.PTCatMetatype}) + cos_node = pattern.add_node(**{GraphPattern.LABEL_ATTR: "COS", GraphPattern.METATYPE_ATTR: om.PTCosMetatype}) + sin_node = pattern.add_node(**{GraphPattern.LABEL_ATTR: "SIN", GraphPattern.METATYPE_ATTR: om.PTSinMetatype}) + + pattern.add_edge(matmul_node, transpose_node) + pattern.add_edge(transpose_node, concat_node) + pattern.add_edge(concat_node, cos_node) + pattern.add_edge(concat_node, sin_node) + return pattern diff --git a/tests/cross_fw/test_templates/helpers.py b/tests/cross_fw/test_templates/helpers.py index d578428b912..2bb21dac6d9 100644 --- a/tests/cross_fw/test_templates/helpers.py +++ b/tests/cross_fw/test_templates/helpers.py @@ -466,3 +466,22 @@ def __init__(self): def forward(self, x): return self.conv(x) + + +class RoPEModel(nn.Module): + INPUT_SIZE = [1, 10] + + def __init__(self): + super().__init__() + with set_torch_seed(): + self.data = torch.randn([5]) + + def forward(self, x): + x = torch.unsqueeze(x, dim=0) + reshape = torch.reshape(self.data, [1, 5, 1]) + x = torch.matmul(reshape, x) + x = torch.transpose(x, 2, 1) + x = torch.cat([x], dim=2) + x1 = x.sin() + x2 = x.cos() + return x1, x2 diff --git a/tests/torch/data/reference_graphs/quantized/ptq/symmetric/rope_model.dot b/tests/torch/data/reference_graphs/quantized/ptq/symmetric/rope_model.dot new file mode 100644 index 00000000000..96b5a48c2e8 --- /dev/null +++ b/tests/torch/data/reference_graphs/quantized/ptq/symmetric/rope_model.dot @@ -0,0 +1,21 @@ +strict digraph { +"0 /nncf_model_input_0" [id=0, type=nncf_model_input]; +"1 RoPEModel/unsqueeze_0" [id=1, type=unsqueeze]; +"2 RoPEModel/reshape_0" [id=2, type=reshape]; +"3 RoPEModel/matmul_0" [id=3, type=matmul]; +"4 RoPEModel/transpose_0" [id=4, type=transpose]; +"5 RoPEModel/cat_0" [id=5, type=cat]; +"6 RoPEModel/sin_0" [id=6, type=sin]; +"7 RoPEModel/cos_0" [id=7, type=cos]; +"8 /nncf_model_output_0" [id=8, type=nncf_model_output]; +"9 /nncf_model_output_1" [id=9, type=nncf_model_output]; +"0 /nncf_model_input_0" -> "1 RoPEModel/unsqueeze_0"; +"1 RoPEModel/unsqueeze_0" -> "3 RoPEModel/matmul_0"; +"2 RoPEModel/reshape_0" -> "3 RoPEModel/matmul_0"; +"3 RoPEModel/matmul_0" -> "4 RoPEModel/transpose_0"; +"4 RoPEModel/transpose_0" -> "5 RoPEModel/cat_0"; +"5 RoPEModel/cat_0" -> "6 RoPEModel/sin_0"; +"5 RoPEModel/cat_0" -> "7 RoPEModel/cos_0"; +"6 RoPEModel/sin_0" -> "8 /nncf_model_output_0"; +"7 RoPEModel/cos_0" -> "9 /nncf_model_output_1"; +} diff --git a/tests/torch/ptq/test_graphs.py b/tests/torch/ptq/test_graphs.py index 2dd649712f7..eba35163c7c 100644 --- a/tests/torch/ptq/test_graphs.py +++ b/tests/torch/ptq/test_graphs.py @@ -16,6 +16,7 @@ import torch from nncf import Dataset +from nncf.parameters import ModelType from nncf.parameters import TargetDevice from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters from nncf.quantization.algorithms.post_training.algorithm import PostTrainingQuantization @@ -23,6 +24,7 @@ from nncf.torch.layers import NNCF_RNN from nncf.torch.layers import LSTMCellNNCF from tests.cross_fw.test_templates.helpers import EmbeddingModel +from tests.cross_fw.test_templates.helpers import RoPEModel from tests.cross_fw.test_templates.helpers import ScaledDotProductAttentionModel from tests.torch import test_models from tests.torch.quantization.test_algo_quantization import SharedLayersModel @@ -50,6 +52,7 @@ def get_model_name(description): TEST_MODELS_DESC = [ (ModelDesc("embedding_model", EmbeddingModel, [1, 10]), {}), + (ModelDesc("rope_model", RoPEModel, [1, 10]), {"model_type": ModelType.TRANSFORMER}), ( ModelDesc( "scaled_dot_product_attention_model",