Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added ignored RoPE block #3049

Merged
merged 6 commits into from
Nov 6, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions nncf/common/graph/patterns/patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,3 +406,4 @@ class IgnoredPatternNames(Enum):
SE_BLOCK = PatternDesc("se_block")
FC_BN_HSWISH_ACTIVATION = PatternDesc("fc_bn_hswish_activation")
EQUAL_LOGICALNOT = PatternDesc("equal_logicalnot")
ROPE = PatternDesc("rope", model_types=[ModelType.TRANSFORMER])
12 changes: 12 additions & 0 deletions nncf/onnx/graph/metatypes/onnx_metatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,6 +677,18 @@ class ONNXErfMetatype(ONNXOpMetatype):
op_names = ["Erf"]


@ONNX_OPERATION_METATYPES.register()
class ONNXCosMetatype(ONNXOpMetatype):
name = "CosOp"
op_names = ["Cos"]


@ONNX_OPERATION_METATYPES.register()
class ONNXSinMetatype(ONNXOpMetatype):
name = "SinOp"
op_names = ["Sin"]


def get_operator_metatypes() -> List[Type[OperatorMetatype]]:
"""
Returns a list of the operator metatypes.
Expand Down
22 changes: 22 additions & 0 deletions nncf/onnx/quantization/ignored_patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,3 +157,25 @@ def create_se_block() -> GraphPattern:
multiply_node = node_id
pattern.add_edge(non_pattern_node, multiply_node)
return pattern


@ONNX_IGNORED_PATTERNS.register(IgnoredPatternNames.ROPE)
def create_rope() -> GraphPattern:
Copy link
Contributor

@alexsu52 alexsu52 Nov 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kshpv, please review this pattern.

pattern = GraphPattern()
matmul_node = pattern.add_node(
**{GraphPattern.LABEL_ATTR: "MATMUL", GraphPattern.METATYPE_ATTR: om.ONNXMatMulMetatype}
)
transpose_node = pattern.add_node(
**{GraphPattern.LABEL_ATTR: "TRANSPOSE", GraphPattern.METATYPE_ATTR: om.ONNXTransposeMetatype}
)
concat_node = pattern.add_node(
**{GraphPattern.LABEL_ATTR: "CONCAT", GraphPattern.METATYPE_ATTR: om.ONNXConcatMetatype}
)
cos_node = pattern.add_node(**{GraphPattern.LABEL_ATTR: "COS", GraphPattern.METATYPE_ATTR: om.ONNXCosMetatype})
sin_node = pattern.add_node(**{GraphPattern.LABEL_ATTR: "SIN", GraphPattern.METATYPE_ATTR: om.ONNXSinMetatype})

pattern.add_edge(matmul_node, transpose_node)
pattern.add_edge(transpose_node, concat_node)
pattern.add_edge(concat_node, cos_node)
pattern.add_edge(concat_node, sin_node)
return pattern
12 changes: 12 additions & 0 deletions nncf/openvino/graph/metatypes/openvino_metatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -726,6 +726,18 @@ class OVScaledDotProductAttentionMetatype(OVOpMetatype):
target_input_ports = [0, 1]


@OV_OPERATOR_METATYPES.register()
class OVCosMetatype(OVOpMetatype):
name = "CosOp"
op_names = ["Cos"]


@OV_OPERATOR_METATYPES.register()
class OVSinMetatype(OVOpMetatype):
name = "SinOp"
op_names = ["Sin"]


def get_operator_metatypes() -> List[Type[OperatorMetatype]]:
"""
Returns a list of the operator metatypes.
Expand Down
22 changes: 22 additions & 0 deletions nncf/openvino/quantization/ignored_patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,3 +169,25 @@ def create_se_block() -> GraphPattern:
pattern.add_edge(activation_node_2, multiply_node)
pattern.add_edge(any_node, multiply_node)
return pattern


@OPENVINO_IGNORED_PATTERNS.register(IgnoredPatternNames.ROPE)
def create_rope() -> GraphPattern:
pattern = GraphPattern()
matmul_node = pattern.add_node(
**{GraphPattern.LABEL_ATTR: "MATMUL", GraphPattern.METATYPE_ATTR: om.OVMatMulMetatype}
)
transpose_node = pattern.add_node(
**{GraphPattern.LABEL_ATTR: "TRANSPOSE", GraphPattern.METATYPE_ATTR: om.OVTransposeMetatype}
)
concat_node = pattern.add_node(
**{GraphPattern.LABEL_ATTR: "CONCAT", GraphPattern.METATYPE_ATTR: om.OVConcatMetatype}
)
cos_node = pattern.add_node(**{GraphPattern.LABEL_ATTR: "COS", GraphPattern.METATYPE_ATTR: om.OVCosMetatype})
sin_node = pattern.add_node(**{GraphPattern.LABEL_ATTR: "SIN", GraphPattern.METATYPE_ATTR: om.OVSinMetatype})

pattern.add_edge(matmul_node, transpose_node)
pattern.add_edge(transpose_node, concat_node)
pattern.add_edge(concat_node, cos_node)
pattern.add_edge(concat_node, sin_node)
return pattern
12 changes: 12 additions & 0 deletions nncf/torch/graph/operator_metatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1110,6 +1110,18 @@ class PTScaledDotProductAttentionMetatype(PTOperatorMetatype):
target_input_ports = [0, 1]


@PT_OPERATOR_METATYPES.register()
class PTCosMetatype(PTOperatorMetatype):
name = "CosOp"
module_to_function_names = {NamespaceTarget.TORCH: ["cos"]}


@PT_OPERATOR_METATYPES.register()
class PTSinMetatype(PTOperatorMetatype):
name = "SinOp"
module_to_function_names = {NamespaceTarget.TORCH: ["sin"]}


def get_operator_metatypes() -> List[Type[OperatorMetatype]]:
"""
Returns a list of the operator metatypes.
Expand Down
20 changes: 20 additions & 0 deletions nncf/torch/quantization/ignored_patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,3 +227,23 @@ def get_se_block_with_bias_and_reshape() -> GraphPattern:
main_pattern.add_pattern_alternative(get_se_block_with_reshape())
main_pattern.add_pattern_alternative(get_se_block_with_bias_and_reshape())
return main_pattern


@PT_IGNORED_PATTERNS.register(IgnoredPatternNames.ROPE)
def create_rope() -> GraphPattern:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@AlexanderDokuchaev, please review this pattern.

pattern = GraphPattern()
matmul_node = pattern.add_node(
**{GraphPattern.LABEL_ATTR: "MATMUL", GraphPattern.METATYPE_ATTR: om.PTMatMulMetatype}
)
transpose_node = pattern.add_node(
**{GraphPattern.LABEL_ATTR: "TRANSPOSE", GraphPattern.METATYPE_ATTR: om.PTTransposeMetatype}
)
concat_node = pattern.add_node(**{GraphPattern.LABEL_ATTR: "CONCAT", GraphPattern.METATYPE_ATTR: om.PTCatMetatype})
cos_node = pattern.add_node(**{GraphPattern.LABEL_ATTR: "COS", GraphPattern.METATYPE_ATTR: om.PTCosMetatype})
sin_node = pattern.add_node(**{GraphPattern.LABEL_ATTR: "SIN", GraphPattern.METATYPE_ATTR: om.PTSinMetatype})

pattern.add_edge(matmul_node, transpose_node)
pattern.add_edge(transpose_node, concat_node)
pattern.add_edge(concat_node, cos_node)
pattern.add_edge(concat_node, sin_node)
return pattern
19 changes: 19 additions & 0 deletions tests/cross_fw/test_templates/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,3 +466,22 @@ def __init__(self):

def forward(self, x):
return self.conv(x)


class RoPEModel(nn.Module):
alexsu52 marked this conversation as resolved.
Show resolved Hide resolved
INPUT_SIZE = [1, 10]

def __init__(self):
super().__init__()
with set_torch_seed():
self.data = torch.randn([5])

def forward(self, x):
x = torch.unsqueeze(x, dim=0)
reshape = torch.reshape(self.data, [1, 5, 1])
x = torch.matmul(reshape, x)
x = torch.transpose(x, 2, 1)
x = torch.cat([x], dim=2)
x1 = x.sin()
x2 = x.cos()
return x1, x2
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
strict digraph {
"0 unsqueeze" [id=0, type=Unsqueeze];
"1 cast" [id=1, type=Cast];
"2 reshape" [id=2, type=Reshape];
"3 matmul" [id=3, type=MatMul];
"4 transpose" [id=4, type=Transpose];
"5 concat" [id=5, type=Concat];
"6 sin" [id=6, type=Sin];
"7 cos" [id=7, type=Cos];
"8 nncf_model_input_0" [id=8, type=nncf_model_input];
"9 nncf_model_output_0" [id=9, type=nncf_model_output];
"10 nncf_model_output_1" [id=10, type=nncf_model_output];
"0 unsqueeze" -> "1 cast" [label="[1, 10, 1]", style=dashed];
"1 cast" -> "3 matmul" [label="[1, 10, 1]", style=solid];
"2 reshape" -> "3 matmul" [label="[1, 5]", style=solid];
"3 matmul" -> "4 transpose" [label="[1, 10, 5]", style=solid];
"4 transpose" -> "5 concat" [label="[1, 5, 10]", style=solid];
"5 concat" -> "6 sin" [label="[1, 5, 10]", style=solid];
"5 concat" -> "7 cos" [label="[1, 5, 10]", style=solid];
"6 sin" -> "9 nncf_model_output_0" [label="[1, 5, 10]", style=solid];
"7 cos" -> "10 nncf_model_output_1" [label="[1, 5, 10]", style=solid];
"8 nncf_model_input_0" -> "0 unsqueeze" [label="[1, 10]", style=dashed];
}
123 changes: 123 additions & 0 deletions tests/onnx/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1757,3 +1757,126 @@ def __init__(self):
model = onnx.helper.make_model(graph_def, opset_imports=[op])
onnx.checker.check_model(model)
super().__init__(model, [input_shape], "unified_embedding_model.dot")


class RoPEModel(ONNXReferenceModel):
def __init__(self):
rng = np.random.default_rng(seed=0)

input_shape = [1, 10]
input_name = "model_in"
model_in = onnx.helper.make_tensor_value_info(input_name, onnx.TensorProto.INT64, input_shape)

output_shape = [1, 5, 10]
cos_out_name = "cos_out"
cos_out = onnx.helper.make_tensor_value_info(cos_out_name, onnx.TensorProto.FLOAT, output_shape)
sin_out_name = "sin_out"
sin_out = onnx.helper.make_tensor_value_info(sin_out_name, onnx.TensorProto.FLOAT, output_shape)

unsqueeze_out_name = "un_out"
unsqueeze_tensor_name = "un_tensor"
unsqueeze_tensor = create_initializer_tensor(
name=unsqueeze_tensor_name, tensor_array=np.int64([2]), data_type=onnx.TensorProto.INT64
)
unsqueeze_node = onnx.helper.make_node(
name="unsqueeze",
op_type="Unsqueeze",
inputs=[input_name, unsqueeze_tensor_name],
outputs=[unsqueeze_out_name],
)

cast_out_name = "cast_out"
cast_node = onnx.helper.make_node(
name="cast",
op_type="Cast",
to=onnx.TensorProto.FLOAT,
inputs=[unsqueeze_out_name],
outputs=[cast_out_name],
)

reshape_shape_name = "re_shape"
reshape_shape = create_initializer_tensor(
name=reshape_shape_name,
tensor_array=np.array([1, 5]).astype(np.int64),
data_type=onnx.TensorProto.INT64,
)
reshape_tensor_name = "re_tensor"
reshape_tensor = create_initializer_tensor(
name=reshape_tensor_name,
tensor_array=rng.uniform(0, 1, (5)).astype(np.float32),
data_type=onnx.TensorProto.FLOAT,
)
reshape_out_name = "re_out"
reshape_node = onnx.helper.make_node(
name="reshape",
op_type="Reshape",
inputs=[reshape_tensor_name, reshape_shape_name],
outputs=[reshape_out_name],
)

matmul_out_name = "mm_out"
matmul_node = onnx.helper.make_node(
name="matmul",
op_type="MatMul",
inputs=[cast_out_name, reshape_out_name],
outputs=[matmul_out_name],
)

transpose_out_name = "trans_out"
transpose_node = onnx.helper.make_node(
name="transpose",
op_type="Transpose",
inputs=[matmul_out_name],
outputs=[transpose_out_name],
perm=[0, 2, 1],
)

concat_out_name = "concat_out"
concat_node = onnx.helper.make_node(
name="concat",
op_type="Concat",
inputs=[transpose_out_name],
outputs=[concat_out_name],
axis=-1,
)

sin_node = onnx.helper.make_node(
name="sin",
op_type="Sin",
inputs=[concat_out_name],
outputs=[sin_out_name],
)

cos_node = onnx.helper.make_node(
name="cos",
op_type="Cos",
inputs=[concat_out_name],
outputs=[cos_out_name],
)

graph_def = onnx.helper.make_graph(
nodes=[
unsqueeze_node,
cast_node,
reshape_node,
matmul_node,
transpose_node,
concat_node,
sin_node,
cos_node,
],
name="RoPEModel",
inputs=[model_in],
outputs=[sin_out, cos_out],
initializer=[
unsqueeze_tensor,
reshape_tensor,
reshape_shape,
],
)

op = onnx.OperatorSetIdProto()
op.version = OPSET_VERSION
model = onnx.helper.make_model(graph_def, opset_imports=[op])
onnx.checker.check_model(model)
super().__init__(model, [input_shape], "rope_model.dot")
12 changes: 12 additions & 0 deletions tests/onnx/quantization/test_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@

import pytest

from nncf.parameters import ModelType
from tests.onnx.models import ALL_SYNTHETIC_MODELS
from tests.onnx.models import RoPEModel
from tests.onnx.quantization.common import compare_nncf_graph
from tests.onnx.quantization.common import min_max_quantize_model
from tests.onnx.quantization.common import mock_collect_statistics
Expand All @@ -23,3 +25,13 @@ def test_synthetic_models_graph(model_cls_to_test, mocker):
model_to_test = model_cls_to_test()
quantized_model = min_max_quantize_model(model_to_test.onnx_model)
compare_nncf_graph(quantized_model, "synthetic/" + model_to_test.path_ref_graph)


@pytest.mark.parametrize("model_cls_to_test", [RoPEModel])
def test_synthetic_models_graph_transformer(model_cls_to_test, mocker):
mock_collect_statistics(mocker)
model_to_test = model_cls_to_test()
quantized_model = min_max_quantize_model(
model_to_test.onnx_model, quantization_params={"model_type": ModelType.TRANSFORMER}
)
compare_nncf_graph(quantized_model, "synthetic/" + model_to_test.path_ref_graph)
1 change: 1 addition & 0 deletions tests/onnx/test_pattern_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
IGNORING_IGNORED_PATTERN_REASONS = {
IgnoredPatternNames.FC_BN_HSWISH_ACTIVATION: "Not relevant for ONNX.",
IgnoredPatternNames.EQUAL_LOGICALNOT: "Not relevant for ONNX.",
IgnoredPatternNames.ROPE: "Not relevant for ONNX.",
}


Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
strict digraph {
"0 position_ids" [id=0, type=Parameter];
"1 Unsqueeze_3697" [id=1, type=Unsqueeze];
"2 convert" [id=2, type=Convert];
"3 MatMul" [id=3, type=MatMul];
"4 transpose" [id=4, type=Transpose];
"5 concat" [id=5, type=Concat];
"6 cos" [id=6, type=Cos];
"7 sin" [id=7, type=Sin];
"8 cos_result" [id=8, type=Result];
"9 sin_result" [id=9, type=Result];
"10 transpose/Constant_3703" [id=10, type=Constant];
"11 unsqueeze/Constant_3696" [id=11, type=Constant];
"12 broadcast" [id=12, type=Broadcast];
"13 broadcast/Constant_3700" [id=13, type=Constant];
"14 broadcast/Constant_3699" [id=14, type=Constant];
"0 position_ids" -> "1 Unsqueeze_3697" [label="[1, 10]", style=solid];
"1 Unsqueeze_3697" -> "2 convert" [label="[1, 1, 10]", style=solid];
"2 convert" -> "3 MatMul" [label="[1, 1, 10]", style=solid];
"3 MatMul" -> "4 transpose" [label="[1, 5, 10]", style=solid];
"4 transpose" -> "5 concat" [label="[1, 10, 5]", style=solid];
"5 concat" -> "6 cos" [label="[1, 10, 5]", style=solid];
"5 concat" -> "7 sin" [label="[1, 10, 5]", style=solid];
"6 cos" -> "8 cos_result" [label="[1, 10, 5]", style=solid];
"7 sin" -> "9 sin_result" [label="[1, 10, 5]", style=solid];
"10 transpose/Constant_3703" -> "4 transpose" [label="[3]", style=dashed];
"11 unsqueeze/Constant_3696" -> "1 Unsqueeze_3697" [label="[]", style=dashed];
"12 broadcast" -> "3 MatMul" [label="[1, 5, 1]", style=solid];
"13 broadcast/Constant_3700" -> "12 broadcast" [label="[3]", style=dashed];
"14 broadcast/Constant_3699" -> "12 broadcast" [label="[1, 5, 1]", style=solid];
}
Loading
Loading