Skip to content

Commit

Permalink
Added ignored RoPE for ONNX/PT (#3059)
Browse files Browse the repository at this point in the history
On top of #3049

### Changes

- As stated in the title.
- Added Sin and Cos ONNX/PT operations.

### Reason for changes

- Accuracy improvement for models of the ModelType.TRANSFORMER type.

### Related tickets

- 155511

### Tests

- Added graph test for ONNX/PT.
  • Loading branch information
KodiaqQ authored Nov 7, 2024
1 parent 20ab35d commit 80df08b
Show file tree
Hide file tree
Showing 12 changed files with 267 additions and 2 deletions.
12 changes: 12 additions & 0 deletions nncf/onnx/graph/metatypes/onnx_metatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,6 +677,18 @@ class ONNXErfMetatype(ONNXOpMetatype):
op_names = ["Erf"]


@ONNX_OPERATION_METATYPES.register()
class ONNXCosMetatype(ONNXOpMetatype):
    """
    Operator metatype for the ONNX ``Cos`` operation.
    """

    name = "CosOp"
    op_names = ["Cos"]


@ONNX_OPERATION_METATYPES.register()
class ONNXSinMetatype(ONNXOpMetatype):
    """
    Operator metatype for the ONNX ``Sin`` operation.
    """

    name = "SinOp"
    op_names = ["Sin"]


def get_operator_metatypes() -> List[Type[OperatorMetatype]]:
"""
Returns a list of the operator metatypes.
Expand Down
22 changes: 22 additions & 0 deletions nncf/onnx/quantization/ignored_patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,3 +157,25 @@ def create_se_block() -> GraphPattern:
multiply_node = node_id
pattern.add_edge(non_pattern_node, multiply_node)
return pattern


@ONNX_IGNORED_PATTERNS.register(IgnoredPatternNames.ROPE)
def create_rope() -> GraphPattern:
    """
    Builds the graph pattern registered under ``IgnoredPatternNames.ROPE``.

    The pattern is the chain MatMul -> Transpose -> Concat, where the Concat
    output is consumed by both a Cos node and a Sin node.

    :return: The constructed graph pattern.
    """
    pattern = GraphPattern()

    def _add(label, metatype):
        # Every pattern node is described only by its label and its metatype.
        return pattern.add_node(**{GraphPattern.LABEL_ATTR: label, GraphPattern.METATYPE_ATTR: metatype})

    matmul = _add("MATMUL", om.ONNXMatMulMetatype)
    transpose = _add("TRANSPOSE", om.ONNXTransposeMetatype)
    concat = _add("CONCAT", om.ONNXConcatMetatype)
    cos = _add("COS", om.ONNXCosMetatype)
    sin = _add("SIN", om.ONNXSinMetatype)

    # Linear chain up to the concat, which then fans out to cos and sin.
    for source, target in ((matmul, transpose), (transpose, concat), (concat, cos), (concat, sin)):
        pattern.add_edge(source, target)
    return pattern
12 changes: 12 additions & 0 deletions nncf/torch/graph/operator_metatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1110,6 +1110,18 @@ class PTScaledDotProductAttentionMetatype(PTOperatorMetatype):
target_input_ports = [0, 1]


@PT_OPERATOR_METATYPES.register()
class PTCosMetatype(PTOperatorMetatype):
    """
    Operator metatype for the ``cos`` function of the ``NamespaceTarget.TORCH`` namespace.
    """

    name = "CosOp"
    module_to_function_names = {NamespaceTarget.TORCH: ["cos"]}


@PT_OPERATOR_METATYPES.register()
class PTSinMetatype(PTOperatorMetatype):
    """
    Operator metatype for the ``sin`` function of the ``NamespaceTarget.TORCH`` namespace.
    """

    name = "SinOp"
    module_to_function_names = {NamespaceTarget.TORCH: ["sin"]}


def get_operator_metatypes() -> List[Type[OperatorMetatype]]:
"""
Returns a list of the operator metatypes.
Expand Down
20 changes: 20 additions & 0 deletions nncf/torch/quantization/ignored_patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,3 +227,23 @@ def get_se_block_with_bias_and_reshape() -> GraphPattern:
main_pattern.add_pattern_alternative(get_se_block_with_reshape())
main_pattern.add_pattern_alternative(get_se_block_with_bias_and_reshape())
return main_pattern


@PT_IGNORED_PATTERNS.register(IgnoredPatternNames.ROPE)
def create_rope() -> GraphPattern:
    """
    Builds the graph pattern registered under ``IgnoredPatternNames.ROPE``.

    The pattern is the chain matmul -> transpose -> cat, where the cat
    output is consumed by both a cos node and a sin node.

    :return: The constructed graph pattern.
    """
    pattern = GraphPattern()

    def _add(label, metatype):
        # Every pattern node is described only by its label and its metatype.
        return pattern.add_node(**{GraphPattern.LABEL_ATTR: label, GraphPattern.METATYPE_ATTR: metatype})

    matmul = _add("MATMUL", om.PTMatMulMetatype)
    transpose = _add("TRANSPOSE", om.PTTransposeMetatype)
    concat = _add("CONCAT", om.PTCatMetatype)
    cos = _add("COS", om.PTCosMetatype)
    sin = _add("SIN", om.PTSinMetatype)

    # Linear chain up to the concat, which then fans out to cos and sin.
    for source, target in ((matmul, transpose), (transpose, concat), (concat, cos), (concat, sin)):
        pattern.add_edge(source, target)
    return pattern
19 changes: 19 additions & 0 deletions tests/cross_fw/test_templates/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,3 +466,22 @@ def __init__(self):

def forward(self, x):
return self.conv(x)


class RoPEModel(nn.Module):
    """
    Small model whose forward pass is unsqueeze -> matmul (with a reshaped
    constant) -> transpose -> cat, followed by separate sin and cos branches.
    """

    INPUT_SIZE = [1, 10]

    def __init__(self):
        super().__init__()
        # The constant vector is drawn inside the seeding context manager.
        with set_torch_seed():
            self.data = torch.randn([5])

    def forward(self, x):
        """
        :param x: Input tensor (``INPUT_SIZE`` is ``[1, 10]``).
        :return: Tuple ``(sin, cos)`` computed from the same concatenated tensor.
        """
        expanded = torch.unsqueeze(x, dim=0)
        column = torch.reshape(self.data, [1, 5, 1])
        product = torch.matmul(column, expanded)
        swapped = torch.transpose(product, 2, 1)
        # Single-tensor concatenation: it only puts a cat node into the graph.
        merged = torch.cat([swapped], dim=2)
        return merged.sin(), merged.cos()
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
strict digraph {
"0 unsqueeze" [id=0, type=Unsqueeze];
"1 cast" [id=1, type=Cast];
"2 reshape" [id=2, type=Reshape];
"3 matmul" [id=3, type=MatMul];
"4 transpose" [id=4, type=Transpose];
"5 concat" [id=5, type=Concat];
"6 sin" [id=6, type=Sin];
"7 cos" [id=7, type=Cos];
"8 nncf_model_input_0" [id=8, type=nncf_model_input];
"9 nncf_model_output_0" [id=9, type=nncf_model_output];
"10 nncf_model_output_1" [id=10, type=nncf_model_output];
"0 unsqueeze" -> "1 cast" [label="[1, 10, 1]", style=dashed];
"1 cast" -> "3 matmul" [label="[1, 10, 1]", style=solid];
"2 reshape" -> "3 matmul" [label="[1, 5]", style=solid];
"3 matmul" -> "4 transpose" [label="[1, 10, 5]", style=solid];
"4 transpose" -> "5 concat" [label="[1, 5, 10]", style=solid];
"5 concat" -> "6 sin" [label="[1, 5, 10]", style=solid];
"5 concat" -> "7 cos" [label="[1, 5, 10]", style=solid];
"6 sin" -> "9 nncf_model_output_0" [label="[1, 5, 10]", style=solid];
"7 cos" -> "10 nncf_model_output_1" [label="[1, 5, 10]", style=solid];
"8 nncf_model_input_0" -> "0 unsqueeze" [label="[1, 10]", style=dashed];
}
123 changes: 123 additions & 0 deletions tests/onnx/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1757,3 +1757,126 @@ def __init__(self):
model = onnx.helper.make_model(graph_def, opset_imports=[op])
onnx.checker.check_model(model)
super().__init__(model, [input_shape], "unified_embedding_model.dot")


class RoPEModel(ONNXReferenceModel):
    """
    Synthetic ONNX model: Unsqueeze -> Cast -> MatMul (second operand is a reshaped
    constant) -> Transpose -> Concat, whose output feeds separate Sin and Cos nodes
    that produce the two model outputs.
    """

    def __init__(self):
        rng = np.random.default_rng(seed=0)
        int64_type = onnx.TensorProto.INT64
        float_type = onnx.TensorProto.FLOAT

        # Tensor names shared between the nodes below.
        input_name = "model_in"
        sin_out_name = "sin_out"
        cos_out_name = "cos_out"
        unsqueeze_axes_name = "un_tensor"
        unsqueeze_out_name = "un_out"
        cast_out_name = "cast_out"
        reshape_shape_name = "re_shape"
        reshape_data_name = "re_tensor"
        reshape_out_name = "re_out"
        matmul_out_name = "mm_out"
        transpose_out_name = "trans_out"
        concat_out_name = "concat_out"

        # Graph input and outputs.
        input_shape = [1, 10]
        output_shape = [1, 5, 10]
        model_in = onnx.helper.make_tensor_value_info(input_name, int64_type, input_shape)
        cos_out = onnx.helper.make_tensor_value_info(cos_out_name, float_type, output_shape)
        sin_out = onnx.helper.make_tensor_value_info(sin_out_name, float_type, output_shape)

        # Initializers: unsqueeze axes, reshape target shape and the random constant.
        unsqueeze_axes = create_initializer_tensor(
            name=unsqueeze_axes_name, tensor_array=np.int64([2]), data_type=int64_type
        )
        reshape_shape = create_initializer_tensor(
            name=reshape_shape_name,
            tensor_array=np.array([1, 5], dtype=np.int64),
            data_type=int64_type,
        )
        reshape_data = create_initializer_tensor(
            name=reshape_data_name,
            tensor_array=rng.uniform(0, 1, 5).astype(np.float32),
            data_type=float_type,
        )

        # Nodes, listed in the order they are placed into the graph.
        nodes = [
            onnx.helper.make_node(
                name="unsqueeze",
                op_type="Unsqueeze",
                inputs=[input_name, unsqueeze_axes_name],
                outputs=[unsqueeze_out_name],
            ),
            onnx.helper.make_node(
                name="cast",
                op_type="Cast",
                to=float_type,
                inputs=[unsqueeze_out_name],
                outputs=[cast_out_name],
            ),
            onnx.helper.make_node(
                name="reshape",
                op_type="Reshape",
                inputs=[reshape_data_name, reshape_shape_name],
                outputs=[reshape_out_name],
            ),
            onnx.helper.make_node(
                name="matmul",
                op_type="MatMul",
                inputs=[cast_out_name, reshape_out_name],
                outputs=[matmul_out_name],
            ),
            onnx.helper.make_node(
                name="transpose",
                op_type="Transpose",
                inputs=[matmul_out_name],
                outputs=[transpose_out_name],
                perm=[0, 2, 1],
            ),
            onnx.helper.make_node(
                name="concat",
                op_type="Concat",
                inputs=[transpose_out_name],
                outputs=[concat_out_name],
                axis=-1,
            ),
            onnx.helper.make_node(
                name="sin",
                op_type="Sin",
                inputs=[concat_out_name],
                outputs=[sin_out_name],
            ),
            onnx.helper.make_node(
                name="cos",
                op_type="Cos",
                inputs=[concat_out_name],
                outputs=[cos_out_name],
            ),
        ]

        graph_def = onnx.helper.make_graph(
            nodes=nodes,
            name="RoPEModel",
            inputs=[model_in],
            outputs=[sin_out, cos_out],
            initializer=[unsqueeze_axes, reshape_data, reshape_shape],
        )

        opset = onnx.OperatorSetIdProto()
        opset.version = OPSET_VERSION
        model = onnx.helper.make_model(graph_def, opset_imports=[opset])
        onnx.checker.check_model(model)
        super().__init__(model, [input_shape], "rope_model.dot")
12 changes: 12 additions & 0 deletions tests/onnx/quantization/test_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@

import pytest

from nncf.parameters import ModelType
from tests.onnx.models import ALL_SYNTHETIC_MODELS
from tests.onnx.models import RoPEModel
from tests.onnx.quantization.common import compare_nncf_graph
from tests.onnx.quantization.common import min_max_quantize_model
from tests.onnx.quantization.common import mock_collect_statistics
Expand All @@ -23,3 +25,13 @@ def test_synthetic_models_graph(model_cls_to_test, mocker):
model_to_test = model_cls_to_test()
quantized_model = min_max_quantize_model(model_to_test.onnx_model)
compare_nncf_graph(quantized_model, "synthetic/" + model_to_test.path_ref_graph)


@pytest.mark.parametrize("model_cls_to_test", [RoPEModel])
def test_synthetic_models_graph_transformer(model_cls_to_test, mocker):
    """
    Quantizes the model with ``ModelType.TRANSFORMER`` and compares the resulting
    graph with the reference graph stored under ``synthetic/``.
    """
    mock_collect_statistics(mocker)
    model = model_cls_to_test()
    transformer_params = {"model_type": ModelType.TRANSFORMER}
    quantized = min_max_quantize_model(model.onnx_model, quantization_params=transformer_params)
    reference_graph = "synthetic/" + model.path_ref_graph
    compare_nncf_graph(quantized, reference_graph)
1 change: 0 additions & 1 deletion tests/onnx/test_pattern_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@
# Ignored-pattern names mapped to a reason string. Presumably these are the patterns the
# ONNX pattern-manager test skips — the consuming code is outside this view, so confirm there.
IGNORING_IGNORED_PATTERN_REASONS = {
IgnoredPatternNames.FC_BN_HSWISH_ACTIVATION: "Not relevant for ONNX.",
IgnoredPatternNames.EQUAL_LOGICALNOT: "Not relevant for ONNX.",
# NOTE(review): this hunk is reported as "0 additions & 1 deletion"; the ROPE entry below is
# presumably the deleted line — confirm against the raw diff.
IgnoredPatternNames.ROPE: "Not relevant for ONNX.",
}


Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
strict digraph {
"0 /nncf_model_input_0" [id=0, type=nncf_model_input];
"1 RoPEModel/unsqueeze_0" [id=1, type=unsqueeze];
"2 RoPEModel/reshape_0" [id=2, type=reshape];
"3 RoPEModel/matmul_0" [id=3, type=matmul];
"4 RoPEModel/transpose_0" [id=4, type=transpose];
"5 RoPEModel/cat_0" [id=5, type=cat];
"6 RoPEModel/sin_0" [id=6, type=sin];
"7 RoPEModel/cos_0" [id=7, type=cos];
"8 /nncf_model_output_0" [id=8, type=nncf_model_output];
"9 /nncf_model_output_1" [id=9, type=nncf_model_output];
"0 /nncf_model_input_0" -> "1 RoPEModel/unsqueeze_0";
"1 RoPEModel/unsqueeze_0" -> "3 RoPEModel/matmul_0";
"2 RoPEModel/reshape_0" -> "3 RoPEModel/matmul_0";
"3 RoPEModel/matmul_0" -> "4 RoPEModel/transpose_0";
"4 RoPEModel/transpose_0" -> "5 RoPEModel/cat_0";
"5 RoPEModel/cat_0" -> "6 RoPEModel/sin_0";
"5 RoPEModel/cat_0" -> "7 RoPEModel/cos_0";
"6 RoPEModel/sin_0" -> "8 /nncf_model_output_0";
"7 RoPEModel/cos_0" -> "9 /nncf_model_output_1";
}
3 changes: 3 additions & 0 deletions tests/torch/ptq/test_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,15 @@
import torch

from nncf import Dataset
from nncf.parameters import ModelType
from nncf.parameters import TargetDevice
from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters
from nncf.quantization.algorithms.post_training.algorithm import PostTrainingQuantization
from nncf.torch import wrap_model
from nncf.torch.layers import NNCF_RNN
from nncf.torch.layers import LSTMCellNNCF
from tests.cross_fw.test_templates.helpers import EmbeddingModel
from tests.cross_fw.test_templates.helpers import RoPEModel
from tests.cross_fw.test_templates.helpers import ScaledDotProductAttentionModel
from tests.torch import test_models
from tests.torch.quantization.test_algo_quantization import SharedLayersModel
Expand Down Expand Up @@ -50,6 +52,7 @@ def get_model_name(description):

TEST_MODELS_DESC = [
(ModelDesc("embedding_model", EmbeddingModel, [1, 10]), {}),
(ModelDesc("rope_model", RoPEModel, [1, 10]), {"model_type": ModelType.TRANSFORMER}),
(
ModelDesc(
"scaled_dot_product_attention_model",
Expand Down
1 change: 0 additions & 1 deletion tests/torch/test_pattern_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@
# Ignored-pattern names mapped to a reason string. Presumably these are the patterns the
# Torch pattern-manager test skips — the consuming code is outside this view, so confirm there.
IGNORING_IGNORED_PATTERN_REASONS = {
IgnoredPatternNames.FC_BN_HSWISH_ACTIVATION: "Not relevant for Torch.",
IgnoredPatternNames.EQUAL_LOGICALNOT: "Not relevant for Torch.",
# NOTE(review): this hunk is reported as "0 additions & 1 deletion"; the ROPE entry below is
# presumably the deleted line — confirm against the raw diff.
IgnoredPatternNames.ROPE: "Not relevant for Torch.",
}


Expand Down

0 comments on commit 80df08b

Please sign in to comment.