Skip to content

Commit

Permalink
Added ignored RoPE block
Browse files Browse the repository at this point in the history
  • Loading branch information
KodiaqQ committed Oct 31, 2024
1 parent 51a7fb6 commit 1ed567e
Show file tree
Hide file tree
Showing 6 changed files with 91 additions and 1 deletion.
1 change: 1 addition & 0 deletions nncf/common/graph/patterns/patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,3 +406,4 @@ class IgnoredPatternNames(Enum):
SE_BLOCK = PatternDesc("se_block")
FC_BN_HSWISH_ACTIVATION = PatternDesc("fc_bn_hswish_activation")
EQUAL_LOGICALNOT = PatternDesc("equal_logicalnot")
ROPE = PatternDesc("rope", model_types=[ModelType.TRANSFORMER])
12 changes: 12 additions & 0 deletions nncf/openvino/graph/metatypes/openvino_metatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -726,6 +726,18 @@ class OVScaledDotProductAttentionMetatype(OVOpMetatype):
target_input_ports = [0, 1]


@OV_OPERATOR_METATYPES.register()
class OVCosMetatype(OVOpMetatype):
name = "CosOp"
op_names = ["Cos"]


@OV_OPERATOR_METATYPES.register()
class OVSinMetatype(OVOpMetatype):
name = "SinOp"
op_names = ["Sin"]


def get_operator_metatypes() -> List[Type[OperatorMetatype]]:
"""
Returns a list of the operator metatypes.
Expand Down
22 changes: 22 additions & 0 deletions nncf/openvino/quantization/ignored_patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,3 +169,25 @@ def create_se_block() -> GraphPattern:
pattern.add_edge(activation_node_2, multiply_node)
pattern.add_edge(any_node, multiply_node)
return pattern


@OPENVINO_IGNORED_PATTERNS.register(IgnoredPatternNames.ROPE)
def create_rope() -> GraphPattern:
pattern = GraphPattern()
matmul_node = pattern.add_node(
**{GraphPattern.LABEL_ATTR: "MATMUL", GraphPattern.METATYPE_ATTR: om.OVMatMulMetatype}
)
transpose_node = pattern.add_node(
**{GraphPattern.LABEL_ATTR: "TRANSPOSE", GraphPattern.METATYPE_ATTR: om.OVTransposeMetatype}
)
concat_node = pattern.add_node(
**{GraphPattern.LABEL_ATTR: "CONCAT", GraphPattern.METATYPE_ATTR: om.OVConcatMetatype}
)
cos_node = pattern.add_node(**{GraphPattern.LABEL_ATTR: "COS", GraphPattern.METATYPE_ATTR: om.OVCosMetatype})
sin_node = pattern.add_node(**{GraphPattern.LABEL_ATTR: "SIN", GraphPattern.METATYPE_ATTR: om.OVSinMetatype})

pattern.add_edge(matmul_node, transpose_node)
pattern.add_edge(transpose_node, concat_node)
pattern.add_edge(concat_node, cos_node)
pattern.add_edge(concat_node, sin_node)
return pattern
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
strict digraph {
"0 position_ids" [id=0, type=Parameter];
"1 Unsqueeze_3697" [id=1, type=Unsqueeze];
"2 convert" [id=2, type=Convert];
"3 MatMul" [id=3, type=MatMul];
"4 transpose" [id=4, type=Transpose];
"5 concat" [id=5, type=Concat];
"6 cos" [id=6, type=Cos];
"7 sin" [id=7, type=Sin];
"8 cos_result" [id=8, type=Result];
"9 sin_result" [id=9, type=Result];
"10 transpose/Constant_3703" [id=10, type=Constant];
"11 unsqueeze/Constant_3696" [id=11, type=Constant];
"12 broadcast" [id=12, type=Broadcast];
"13 broadcast/Constant_3700" [id=13, type=Constant];
"14 broadcast/Constant_3699" [id=14, type=Constant];
"0 position_ids" -> "1 Unsqueeze_3697" [label="[1, 10]", style=solid];
"1 Unsqueeze_3697" -> "2 convert" [label="[1, 1, 10]", style=solid];
"2 convert" -> "3 MatMul" [label="[1, 1, 10]", style=solid];
"3 MatMul" -> "4 transpose" [label="[1, 5, 10]", style=solid];
"4 transpose" -> "5 concat" [label="[1, 10, 5]", style=solid];
"5 concat" -> "6 cos" [label="[1, 10, 5]", style=solid];
"5 concat" -> "7 sin" [label="[1, 10, 5]", style=solid];
"6 cos" -> "8 cos_result" [label="[1, 10, 5]", style=solid];
"7 sin" -> "9 sin_result" [label="[1, 10, 5]", style=solid];
"10 transpose/Constant_3703" -> "4 transpose" [label="[3]", style=dashed];
"11 unsqueeze/Constant_3696" -> "1 Unsqueeze_3697" [label="[]", style=dashed];
"12 broadcast" -> "3 MatMul" [label="[1, 5, 1]", style=solid];
"13 broadcast/Constant_3700" -> "12 broadcast" [label="[3]", style=dashed];
"14 broadcast/Constant_3699" -> "12 broadcast" [label="[1, 5, 1]", style=solid];
}
23 changes: 23 additions & 0 deletions tests/openvino/native/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1162,3 +1162,26 @@ def _create_ov_model(self):
result = opset.result(conv, name="Result")
model = ov.Model([result], [input])
return model


class RoPEModel(OVReferenceModel):
def _create_ov_model(self):
position_ids = opset.parameter([1, 10], name="position_ids")

unsqueeze = opset.unsqueeze(position_ids, 0, name="unsqueeze")
convert = opset.convert(unsqueeze, ov.Type.f32, name="convert")

broadcast_data = self._rng.random((1, 5, 1)).astype(np.float32)
broadcast_shape = [1, 5, 1]
broadcast = opset.broadcast(broadcast_data, broadcast_shape, name="broadcast")

matmul = opset.matmul(broadcast, convert, transpose_a=False, transpose_b=False, name="MatMul")
transpose = opset.transpose(matmul, [0, 2, 1], name="transpose")
concat = opset.concat([transpose], axis=0, name="concat")
sin = opset.sin(concat, name="sin")
cos = opset.cos(concat, name="cos")
sin_result = opset.result(sin, name="sin_result")
cos_result = opset.result(cos, name="cos_result")

model = ov.Model([sin_result, cos_result], [position_ids])
return model
3 changes: 2 additions & 1 deletion tests/openvino/native/quantization/test_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from tests.openvino.native.models import IfModel
from tests.openvino.native.models import IfModel_2
from tests.openvino.native.models import MatmulSoftmaxMatmulBlock
from tests.openvino.native.models import RoPEModel
from tests.openvino.native.models import ScaledDotProductAttentionModel
from tests.openvino.native.models import get_torch_model_info
from tests.openvino.native.quantization.test_fq_params_calculation import quantize_model
Expand Down Expand Up @@ -96,7 +97,7 @@ def test_real_models_fq_placement(model_name_params, tmp_path):
compare_nncf_graphs(quantized_model, path_ref_graph)


@pytest.mark.parametrize("model_creator_func", [MatmulSoftmaxMatmulBlock])
@pytest.mark.parametrize("model_creator_func", [MatmulSoftmaxMatmulBlock, RoPEModel])
def test_transformer_models_fq_placement(model_creator_func, tmp_path):
model = model_creator_func()
quantized_model = quantize_model(
Expand Down

0 comments on commit 1ed567e

Please sign in to comment.