diff --git a/tests/openvino/native/data/2024.1/reference_graphs/quantized/scaled_dot_product_attention.dot b/tests/openvino/native/data/2024.1/reference_graphs/quantized/scaled_dot_product_attention.dot
index 9891a2b675a..7feaaaedb9d 100644
--- a/tests/openvino/native/data/2024.1/reference_graphs/quantized/scaled_dot_product_attention.dot
+++ b/tests/openvino/native/data/2024.1/reference_graphs/quantized/scaled_dot_product_attention.dot
@@ -1,33 +1,25 @@
 strict digraph {
-"0 Input_1" [id=0, type=Parameter];
-"1 Input_2" [id=1, type=Parameter];
-"2 Input_3" [id=2, type=Parameter];
-"3 Input_4" [id=3, type=Parameter];
-"4 Input_1/fq_output_0" [id=4, type=FakeQuantize];
-"5 Input_2/fq_output_0" [id=5, type=FakeQuantize];
-"6 ScaledDotProductAttention_5" [id=6, type=ScaledDotProductAttention];
-"7 Result" [id=7, type=Result];
-"8 Constant_2553" [id=8, type=Constant];
-"9 Constant_2552" [id=9, type=Constant];
-"10 Constant_2551" [id=10, type=Constant];
-"11 Constant_2550" [id=11, type=Constant];
-"12 Constant_2548" [id=12, type=Constant];
-"13 Constant_2547" [id=13, type=Constant];
-"14 Constant_2546" [id=14, type=Constant];
-"15 Constant_2545" [id=15, type=Constant];
-"0 Input_1" -> "4 Input_1/fq_output_0" [label="[1, 1, 1, 64]", style=solid];
-"1 Input_2" -> "5 Input_2/fq_output_0" [label="[1, 1, 1, 64]", style=solid];
-"2 Input_3" -> "6 ScaledDotProductAttention_5" [label="[1, 1, 1, 64]", style=solid];
-"3 Input_4" -> "6 ScaledDotProductAttention_5" [label="[1, 1, 1, 1]", style=solid];
-"4 Input_1/fq_output_0" -> "6 ScaledDotProductAttention_5" [label="[1, 1, 1, 64]", style=solid];
-"5 Input_2/fq_output_0" -> "6 ScaledDotProductAttention_5" [label="[1, 1, 1, 64]", style=solid];
-"6 ScaledDotProductAttention_5" -> "7 Result" [label="[1, 1, 1, 64]", style=solid];
-"8 Constant_2553" -> "5 Input_2/fq_output_0" [label="[]", style=solid];
-"9 Constant_2552" -> "5 Input_2/fq_output_0" [label="[]", style=solid];
-"10 Constant_2551" -> "5 Input_2/fq_output_0" [label="[]", style=solid];
-"11 Constant_2550" -> "5 Input_2/fq_output_0" [label="[]", style=solid];
-"12 Constant_2548" -> "4 Input_1/fq_output_0" [label="[]", style=solid];
-"13 Constant_2547" -> "4 Input_1/fq_output_0" [label="[]", style=solid];
-"14 Constant_2546" -> "4 Input_1/fq_output_0" [label="[]", style=solid];
-"15 Constant_2545" -> "4 Input_1/fq_output_0" [label="[]", style=solid];
+"0 x" [id=0, type=Parameter];
+"1 x/fq_output_0" [id=1, type=FakeQuantize];
+"2 aten^^view/Reshape" [id=2, label="2 aten::view/Reshape", type=Reshape];
+"3 aten^^view/Reshape_1" [id=3, label="3 aten::view/Reshape_1", type=Reshape];
+"4 aten^^scaled_dot_product_attention/ScaledDotProductAttention" [id=4, label="4 aten::scaled_dot_product_attention/ScaledDotProductAttention", type=ScaledDotProductAttention];
+"5 Result_1784" [id=5, type=Result];
+"6 prim^^ListConstruct/Concat" [id=6, label="6 prim::ListConstruct/Concat", type=Constant];
+"7 Constant_2000" [id=7, type=Constant];
+"8 x/fq_output_0/output_high" [id=8, type=Constant];
+"9 x/fq_output_0/output_low" [id=9, type=Constant];
+"10 x/fq_output_0/input_high" [id=10, type=Constant];
+"11 x/fq_output_0/input_low" [id=11, type=Constant];
+"0 x" -> "1 x/fq_output_0" [label="[2, 1, 12]", style=solid];
+"1 x/fq_output_0" -> "2 aten^^view/Reshape" [label="[2, 1, 12]", style=solid];
+"2 aten^^view/Reshape" -> "3 aten^^view/Reshape_1" [label="[24]", style=solid];
+"3 aten^^view/Reshape_1" -> "4 aten^^scaled_dot_product_attention/ScaledDotProductAttention" [label="parallel_input_port_ids:[1, 0], shape:[2, 1, 12]", style=solid];
+"4 aten^^scaled_dot_product_attention/ScaledDotProductAttention" -> "5 Result_1784" [label="[2, 1, 12]", style=solid];
+"6 prim^^ListConstruct/Concat" -> "3 aten^^view/Reshape_1" [label="[3]", style=dashed];
+"7 Constant_2000" -> "2 aten^^view/Reshape" [label="[1]", style=dashed];
+"8 x/fq_output_0/output_high" -> "1 x/fq_output_0" [label="[]", style=solid];
+"9 x/fq_output_0/output_low" -> "1 x/fq_output_0" [label="[]", style=solid];
+"10 x/fq_output_0/input_high" -> "1 x/fq_output_0" [label="[]", style=solid];
+"11 x/fq_output_0/input_low" -> "1 x/fq_output_0" [label="[]", style=solid];
 }
diff --git a/tests/openvino/native/models.py b/tests/openvino/native/models.py
index e0b87ebb6b9..d2b5665aacb 100644
--- a/tests/openvino/native/models.py
+++ b/tests/openvino/native/models.py
@@ -871,15 +871,17 @@ def _create_ov_model(self):
 
 
 class ScaledDotProductAttentionModel(OVReferenceModel):
     def _create_ov_model(self):
-        query = opset.parameter([1, 1, 1, 64], name="Input_1")
-        key = opset.parameter([1, 1, 1, 64], name="Input_2")
-        value = opset.parameter([1, 1, 1, 64], name="Input_3")
-        attn_mask = opset.parameter([1, 1, 1, 1], name="Input_4")
+        import openvino
+        import torch
 
-        attn = opset.scaled_dot_product_attention(query, key, value, attn_mask)
-        result = opset.result(attn, name="Result")
-        result.get_output_tensor(0).set_names(set(["Result"]))
-        model = ov.Model([result], [query, key, value, attn_mask])
-        return model
+        from tests.torch.test_models.synthetic import ScaledDotProductModel
+
+        # Build the OV model by converting the shared Torch reference model, so
+        # both backends are checked against the same graph.
+        return openvino.convert_model(
+            ScaledDotProductModel(),
+            input=ScaledDotProductModel.INPUT_SIZES,
+            example_input=torch.ones(ScaledDotProductModel.INPUT_SIZES),
+        )
 
diff --git a/tests/torch/data/reference_graphs/quantized/synthetic_model/ScaledDotProductModel.dot b/tests/torch/data/reference_graphs/quantized/synthetic_model/ScaledDotProductModel.dot
new file mode 100644
index 00000000000..a0dc9b5b9a9
--- /dev/null
+++ b/tests/torch/data/reference_graphs/quantized/synthetic_model/ScaledDotProductModel.dot
@@ -0,0 +1,13 @@
+strict digraph {
+"0 /nncf_model_input_0" [id=0, type=nncf_model_input];
+"1 SymmetricQuantizer/symmetric_quantize_0" [id=1, type=symmetric_quantize];
+"2 ScaledDotProductModel/view_0" [id=2, type=view];
+"3 ScaledDotProductModel/view_1" [id=3, type=view];
+"4 ScaledDotProductModel/scaled_dot_product_attention_0" [id=4, type=scaled_dot_product_attention];
+"5 /nncf_model_output_0" [id=5, type=nncf_model_output];
+"0 /nncf_model_input_0" -> "1 SymmetricQuantizer/symmetric_quantize_0";
+"1 SymmetricQuantizer/symmetric_quantize_0" -> "2 ScaledDotProductModel/view_0";
+"2 ScaledDotProductModel/view_0" -> "3 ScaledDotProductModel/view_1";
+"3 ScaledDotProductModel/view_1" -> "4 ScaledDotProductModel/scaled_dot_product_attention_0" [label="parallel_input_port_ids:[1, 2]"];
+"4 ScaledDotProductModel/scaled_dot_product_attention_0" -> "5 /nncf_model_output_0";
+}
diff --git a/tests/torch/test_compressed_graph.py b/tests/torch/test_compressed_graph.py
index 1bce431f64e..e49dbfcb6d0 100644
--- a/tests/torch/test_compressed_graph.py
+++ b/tests/torch/test_compressed_graph.py
@@ -70,6 +70,7 @@
 from tests.torch.test_models.synthetic import OrdinaryModelWithRecurrentInName
 from tests.torch.test_models.synthetic import PoolUnPool
 from tests.torch.test_models.synthetic import ReshapeModel
+from tests.torch.test_models.synthetic import ScaledDotProductModel
 from tests.torch.test_models.synthetic import ShiftScaleParametrized
 from tests.torch.test_models.synthetic import TransposeModel
 
@@ -798,6 +799,7 @@ def forward(self, x):
         wrap_inputs_fn=partial(n_inputs_fn, nargs=3),
     ),
     GeneralModelDesc(model_builder=MHA_single_input, input_sample_sizes=(MHA_single_input.INPUT_SIZES,)),
+    GeneralModelDesc(model_builder=ScaledDotProductModel, input_sample_sizes=(ScaledDotProductModel.INPUT_SIZES,)),
     GeneralModelDesc(
         model_name="OrdinaryModelWithRecurrentInName",
         model_builder=OrdinaryModelWithRecurrentInName,
diff --git a/tests/torch/test_models/synthetic.py b/tests/torch/test_models/synthetic.py
index 42fc1f6e2e0..e3024fc0967 100644
--- a/tests/torch/test_models/synthetic.py
+++ b/tests/torch/test_models/synthetic.py
@@ -325,6 +325,18 @@ def forward(self, x, y, z):
         return torch.baddbmm(x, y, z)
 
 
+class ScaledDotProductModel(nn.Module):
+    EMBED_DIM = 4 * 3
+    INPUT_SIZES = [2, 1, EMBED_DIM]
+
+    def forward(self, x):
+        # Flatten and restore the input through view() so that the traced
+        # graph contains reshape nodes in front of the attention op.
+        shape = x.shape
+        x = x.view(-1).view(shape)
+        return nn.functional.scaled_dot_product_attention(x, x, x)
+
+
 class MHA_single_input(torch.nn.Module):
     EMBED_DIM = 4
     INPUT_SIZES = [2, 1, EMBED_DIM]
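Note: the regenerated OpenVINO reference graph above can be reproduced along the following lines. This is a minimal sketch against NNCF's public API (nncf.quantize, nncf.Dataset) rather than the internal helpers the test suites actually use, and the single-sample calibration set is only illustrative.

import openvino
import torch

import nncf
from tests.torch.test_models.synthetic import ScaledDotProductModel

example = torch.ones(ScaledDotProductModel.INPUT_SIZES)

# Same conversion path that ScaledDotProductAttentionModel._create_ov_model
# now takes in tests/openvino/native/models.py.
ov_model = openvino.convert_model(
    ScaledDotProductModel(),
    input=ScaledDotProductModel.INPUT_SIZES,
    example_input=example,
)

# Post-training quantization is what inserts the FakeQuantize subgraph
# recorded in scaled_dot_product_attention.dot.
quantized_model = nncf.quantize(ov_model, nncf.Dataset([example.numpy()]))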