Added Torch test
(cherry picked from commit 28ba705)
KodiaqQ committed Nov 5, 2024
1 parent 82c19f1 commit 640efc8
Showing 5 changed files with 75 additions and 0 deletions.
12 changes: 12 additions & 0 deletions nncf/torch/graph/operator_metatypes.py
@@ -1110,6 +1110,18 @@ class PTScaledDotProductAttentionMetatype(PTOperatorMetatype):
target_input_ports = [0, 1]


@PT_OPERATOR_METATYPES.register()
class PTCosMetatype(PTOperatorMetatype):
name = "CosOp"
module_to_function_names = {NamespaceTarget.TORCH: ["cos"]}


@PT_OPERATOR_METATYPES.register()
class PTSinMetatype(PTOperatorMetatype):
name = "SinOp"
module_to_function_names = {NamespaceTarget.TORCH: ["sin"]}


def get_operator_metatypes() -> List[Type[OperatorMetatype]]:
"""
Returns a list of the operator metatypes.
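With these registrations, torch.cos and torch.sin calls map to dedicated metatypes, which the ROPE ignored pattern below relies on. A minimal sanity check (not part of the commit), assuming get_operator_metatypes() returns everything registered via @PT_OPERATOR_METATYPES.register():

from nncf.torch.graph.operator_metatypes import (
    PTCosMetatype,
    PTSinMetatype,
    get_operator_metatypes,
)

# Both new metatypes should appear in the collected list.
assert PTCosMetatype in get_operator_metatypes()
assert PTSinMetatype in get_operator_metatypes()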
20 changes: 20 additions & 0 deletions nncf/torch/quantization/ignored_patterns.py
@@ -227,3 +227,23 @@ def get_se_block_with_bias_and_reshape() -> GraphPattern:
main_pattern.add_pattern_alternative(get_se_block_with_reshape())
main_pattern.add_pattern_alternative(get_se_block_with_bias_and_reshape())
return main_pattern


@PT_IGNORED_PATTERNS.register(IgnoredPatternNames.ROPE)
def create_rope() -> GraphPattern:
pattern = GraphPattern()
matmul_node = pattern.add_node(
**{GraphPattern.LABEL_ATTR: "MATMUL", GraphPattern.METATYPE_ATTR: om.PTMatMulMetatype}
)
transpose_node = pattern.add_node(
**{GraphPattern.LABEL_ATTR: "TRANSPOSE", GraphPattern.METATYPE_ATTR: om.PTTransposeMetatype}
)
concat_node = pattern.add_node(**{GraphPattern.LABEL_ATTR: "CONCAT", GraphPattern.METATYPE_ATTR: om.PTCatMetatype})
cos_node = pattern.add_node(**{GraphPattern.LABEL_ATTR: "COS", GraphPattern.METATYPE_ATTR: om.PTCosMetatype})
sin_node = pattern.add_node(**{GraphPattern.LABEL_ATTR: "SIN", GraphPattern.METATYPE_ATTR: om.PTSinMetatype})

pattern.add_edge(matmul_node, transpose_node)
pattern.add_edge(transpose_node, concat_node)
pattern.add_edge(concat_node, cos_node)
pattern.add_edge(concat_node, sin_node)
return pattern
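
The registered pattern targets the rotary positional embedding subgraph: a MATMUL feeding a TRANSPOSE, then a CONCAT that fans out into COS and SIN branches. As a hedged illustration (not from the commit), a forward pass with exactly this shape, whose traced nodes the matcher would place on the ignored scope so they stay unquantized:

import torch

def rope_like(x: torch.Tensor, inv_freq: torch.Tensor):
    # Mirrors the pattern above: MATMUL -> TRANSPOSE -> CONCAT -> {COS, SIN}
    t = torch.matmul(inv_freq, x)
    t = torch.transpose(t, 2, 1)
    t = torch.cat([t], dim=2)
    return t.cos(), t.sin()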
19 changes: 19 additions & 0 deletions tests/cross_fw/test_templates/helpers.py
@@ -466,3 +466,22 @@ def __init__(self):

def forward(self, x):
return self.conv(x)


class RoPEModel(nn.Module):
INPUT_SIZE = [1, 10]

def __init__(self):
super().__init__()
with set_torch_seed():
self.data = torch.randn([5])

def forward(self, x):
x = torch.unsqueeze(x, dim=0)
reshape = torch.reshape(self.data, [1, 5, 1])
x = torch.matmul(reshape, x)
x = torch.transpose(x, 2, 1)
x = torch.cat([x], dim=2)
x1 = x.sin()
x2 = x.cos()
return x1, x2
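
A usage sketch for the new helper; the input shape follows RoPEModel.INPUT_SIZE, and the output shapes fall out of the forward pass above ([1, 10] is unsqueezed to [1, 1, 10], the matmul with the [1, 5, 1] reshaped weights gives [1, 5, 10], and the transpose yields [1, 10, 5]):

import torch

model = RoPEModel()
x = torch.randn(RoPEModel.INPUT_SIZE)  # shape [1, 10]
sin_out, cos_out = model(x)
assert sin_out.shape == cos_out.shape == torch.Size([1, 10, 5])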
@@ -0,0 +1,21 @@
strict digraph {
"0 /nncf_model_input_0" [id=0, type=nncf_model_input];
"1 RoPEModel/unsqueeze_0" [id=1, type=unsqueeze];
"2 RoPEModel/reshape_0" [id=2, type=reshape];
"3 RoPEModel/matmul_0" [id=3, type=matmul];
"4 RoPEModel/transpose_0" [id=4, type=transpose];
"5 RoPEModel/cat_0" [id=5, type=cat];
"6 RoPEModel/sin_0" [id=6, type=sin];
"7 RoPEModel/cos_0" [id=7, type=cos];
"8 /nncf_model_output_0" [id=8, type=nncf_model_output];
"9 /nncf_model_output_1" [id=9, type=nncf_model_output];
"0 /nncf_model_input_0" -> "1 RoPEModel/unsqueeze_0";
"1 RoPEModel/unsqueeze_0" -> "3 RoPEModel/matmul_0";
"2 RoPEModel/reshape_0" -> "3 RoPEModel/matmul_0";
"3 RoPEModel/matmul_0" -> "4 RoPEModel/transpose_0";
"4 RoPEModel/transpose_0" -> "5 RoPEModel/cat_0";
"5 RoPEModel/cat_0" -> "6 RoPEModel/sin_0";
"5 RoPEModel/cat_0" -> "7 RoPEModel/cos_0";
"6 RoPEModel/sin_0" -> "8 /nncf_model_output_0";
"7 RoPEModel/cos_0" -> "9 /nncf_model_output_1";
}
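
This new reference .dot file (its path is not shown in this view) pins down the expected NNCF graph for RoPEModel: input through unsqueeze, reshape, matmul, transpose, and cat, fanning out into sin and cos outputs. To eyeball it, one hypothetical option (not part of the commit) is the third-party graphviz package; the file name here is an assumption:

import graphviz  # pip install graphviz

with open("rope_model.dot") as f:  # path is an assumption
    graphviz.Source(f.read()).render("rope_model", format="png", cleanup=True)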
3 changes: 3 additions & 0 deletions tests/torch/ptq/test_graphs.py
@@ -16,13 +16,15 @@
import torch

from nncf import Dataset
from nncf.parameters import ModelType
from nncf.parameters import TargetDevice
from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters
from nncf.quantization.algorithms.post_training.algorithm import PostTrainingQuantization
from nncf.torch import wrap_model
from nncf.torch.layers import NNCF_RNN
from nncf.torch.layers import LSTMCellNNCF
from tests.cross_fw.test_templates.helpers import EmbeddingModel
from tests.cross_fw.test_templates.helpers import RoPEModel
from tests.cross_fw.test_templates.helpers import ScaledDotProductAttentionModel
from tests.torch import test_models
from tests.torch.quantization.test_algo_quantization import SharedLayersModel
@@ -50,6 +52,7 @@ def get_model_name(description):

TEST_MODELS_DESC = [
(ModelDesc("embedding_model", EmbeddingModel, [1, 10]), {}),
(ModelDesc("rope_model", RoPEModel, [1, 10]), {"model_type": ModelType.TRANSFORMER}),
(
ModelDesc(
"scaled_dot_product_attention_model",
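The new TEST_MODELS_DESC entry runs RoPEModel through post-training quantization with ModelType.TRANSFORMER and compares the resulting graph against the reference above. Roughly the same flow through the public API, as a sketch (the test itself drives PostTrainingQuantization directly; the dataset contents here are placeholders):

import torch
import nncf
from nncf.parameters import ModelType
from tests.cross_fw.test_templates.helpers import RoPEModel

model = RoPEModel()
dataset = nncf.Dataset([torch.randn(RoPEModel.INPUT_SIZE) for _ in range(3)])
quantized = nncf.quantize(model, dataset, model_type=ModelType.TRANSFORMER)
# With ModelType.TRANSFORMER, the ROPE ignored pattern should keep the
# matmul/transpose/cat/cos/sin chain out of the quantized scope.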
