Torch and OV solver bug repro
daniil-lyakhov committed Jul 29, 2024
1 parent d94b93b commit d641ada
Showing 5 changed files with 68 additions and 37 deletions.
@@ -1,33 +1,25 @@
 strict digraph {
-"0 Input_1" [id=0, type=Parameter];
-"1 Input_2" [id=1, type=Parameter];
-"2 Input_3" [id=2, type=Parameter];
-"3 Input_4" [id=3, type=Parameter];
-"4 Input_1/fq_output_0" [id=4, type=FakeQuantize];
-"5 Input_2/fq_output_0" [id=5, type=FakeQuantize];
-"6 ScaledDotProductAttention_5" [id=6, type=ScaledDotProductAttention];
-"7 Result" [id=7, type=Result];
-"8 Constant_2553" [id=8, type=Constant];
-"9 Constant_2552" [id=9, type=Constant];
-"10 Constant_2551" [id=10, type=Constant];
-"11 Constant_2550" [id=11, type=Constant];
-"12 Constant_2548" [id=12, type=Constant];
-"13 Constant_2547" [id=13, type=Constant];
-"14 Constant_2546" [id=14, type=Constant];
-"15 Constant_2545" [id=15, type=Constant];
-"0 Input_1" -> "4 Input_1/fq_output_0" [label="[1, 1, 1, 64]", style=solid];
-"1 Input_2" -> "5 Input_2/fq_output_0" [label="[1, 1, 1, 64]", style=solid];
-"2 Input_3" -> "6 ScaledDotProductAttention_5" [label="[1, 1, 1, 64]", style=solid];
-"3 Input_4" -> "6 ScaledDotProductAttention_5" [label="[1, 1, 1, 1]", style=solid];
-"4 Input_1/fq_output_0" -> "6 ScaledDotProductAttention_5" [label="[1, 1, 1, 64]", style=solid];
-"5 Input_2/fq_output_0" -> "6 ScaledDotProductAttention_5" [label="[1, 1, 1, 64]", style=solid];
-"6 ScaledDotProductAttention_5" -> "7 Result" [label="[1, 1, 1, 64]", style=solid];
-"8 Constant_2553" -> "5 Input_2/fq_output_0" [label="[]", style=solid];
-"9 Constant_2552" -> "5 Input_2/fq_output_0" [label="[]", style=solid];
-"10 Constant_2551" -> "5 Input_2/fq_output_0" [label="[]", style=solid];
-"11 Constant_2550" -> "5 Input_2/fq_output_0" [label="[]", style=solid];
-"12 Constant_2548" -> "4 Input_1/fq_output_0" [label="[]", style=solid];
-"13 Constant_2547" -> "4 Input_1/fq_output_0" [label="[]", style=solid];
-"14 Constant_2546" -> "4 Input_1/fq_output_0" [label="[]", style=solid];
-"15 Constant_2545" -> "4 Input_1/fq_output_0" [label="[]", style=solid];
+"0 x" [id=0, type=Parameter];
+"1 x/fq_output_0" [id=1, type=FakeQuantize];
+"2 aten^^view/Reshape" [id=2, label="2 aten::view/Reshape", type=Reshape];
+"3 aten^^view/Reshape_1" [id=3, label="3 aten::view/Reshape_1", type=Reshape];
+"4 aten^^scaled_dot_product_attention/ScaledDotProductAttention" [id=4, label="4 aten::scaled_dot_product_attention/ScaledDotProductAttention", type=ScaledDotProductAttention];
+"5 Result_1784" [id=5, type=Result];
+"6 prim^^ListConstruct/Concat" [id=6, label="6 prim::ListConstruct/Concat", type=Constant];
+"7 Constant_2000" [id=7, type=Constant];
+"8 x/fq_output_0/output_high" [id=8, type=Constant];
+"9 x/fq_output_0/output_low" [id=9, type=Constant];
+"10 x/fq_output_0/input_high" [id=10, type=Constant];
+"11 x/fq_output_0/input_low" [id=11, type=Constant];
+"0 x" -> "1 x/fq_output_0" [label="[2, 1, 12]", style=solid];
+"1 x/fq_output_0" -> "2 aten^^view/Reshape" [label="[2, 1, 12]", style=solid];
+"2 aten^^view/Reshape" -> "3 aten^^view/Reshape_1" [label="[24]", style=solid];
+"3 aten^^view/Reshape_1" -> "4 aten^^scaled_dot_product_attention/ScaledDotProductAttention" [label="parallel_input_port_ids:[1, 0], shape:[2, 1, 12]", style=solid];
+"4 aten^^scaled_dot_product_attention/ScaledDotProductAttention" -> "5 Result_1784" [label="[2, 1, 12]", style=solid];
+"6 prim^^ListConstruct/Concat" -> "3 aten^^view/Reshape_1" [label="[3]", style=dashed];
+"7 Constant_2000" -> "2 aten^^view/Reshape" [label="[1]", style=dashed];
+"8 x/fq_output_0/output_high" -> "1 x/fq_output_0" [label="[]", style=solid];
+"9 x/fq_output_0/output_low" -> "1 x/fq_output_0" [label="[]", style=solid];
+"10 x/fq_output_0/input_high" -> "1 x/fq_output_0" [label="[]", style=solid];
+"11 x/fq_output_0/input_low" -> "1 x/fq_output_0" [label="[]", style=solid];
 }
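Side note on the diff above: the removed OV reference graph carries two FakeQuantize nodes (on Input_1 and Input_2), while the torch-converted replacement carries a single FakeQuantize whose output fans out to the SDPA ports via parallel_input_port_ids. A minimal sketch for inspecting both graphs, assuming networkx and pydot are installed; the .dot paths are placeholders, not the repository's actual reference-graph paths:

import networkx as nx

# Placeholder paths -- point these at the removed and added reference graphs.
old_graph = nx.drawing.nx_pydot.read_dot("old_reference.dot")
new_graph = nx.drawing.nx_pydot.read_dot("new_reference.dot")

for name, graph in (("old", old_graph), ("new", new_graph)):
    # Node attributes come straight from the dot source, so FakeQuantize
    # nodes can be filtered on their "type" attribute.
    fq = [n for n, attrs in graph.nodes(data=True) if attrs.get("type") == "FakeQuantize"]
    print(name, graph.number_of_nodes(), "nodes; FakeQuantize at:", fq)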
23 changes: 17 additions & 6 deletions tests/openvino/native/models.py
@@ -871,15 +871,26 @@ def _create_ov_model(self):
 
 class ScaledDotProductAttentionModel(OVReferenceModel):
     def _create_ov_model(self):
-        query = opset.parameter([1, 1, 1, 64], name="Input_1")
-        key = opset.parameter([1, 1, 1, 64], name="Input_2")
-        value = opset.parameter([1, 1, 1, 64], name="Input_3")
-        attn_mask = opset.parameter([1, 1, 1, 1], name="Input_4")
+        import openvino
+        import torch
 
-        attn = opset.scaled_dot_product_attention(query, key, value, attn_mask)
+        from tests.torch.test_models.synthetic import ScaledDotProductModel
+
+        return openvino.convert_model(
+            ScaledDotProductModel(),
+            input=ScaledDotProductModel.INPUT_SIZES,
+            example_input=torch.ones(ScaledDotProductModel.INPUT_SIZES),
+        )
+
+        x = opset.parameter([1, 1, 1, 64], name="Input_1")
+        attn_mask = opset.parameter([1, 1, 1, 1], name="Input_2")
+        x = opset.reshape(x, [64], False)
+        x = opset.reshape(x, [1, 1, 1, 64], False)
+
+        attn = opset.scaled_dot_product_attention(x, x, x, attn_mask)
         result = opset.result(attn, name="Result")
         result.get_output_tensor(0).set_names(set(["Result"]))
-        model = ov.Model([result], [query, key, value, attn_mask])
+        model = ov.Model([result], [x, attn_mask])
         return model
 
 
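For context, a sketch of how a quantized graph like the reference above is typically obtained from this model: convert it, then run NNCF post-training quantization. The random calibration data and the subset_size value are illustrative assumptions, not the actual test pipeline:

import nncf
import numpy as np
import openvino
import torch

from tests.torch.test_models.synthetic import ScaledDotProductModel

ov_model = openvino.convert_model(
    ScaledDotProductModel(),
    input=ScaledDotProductModel.INPUT_SIZES,
    example_input=torch.ones(ScaledDotProductModel.INPUT_SIZES),
)
# A handful of random samples is enough to exercise quantizer placement.
calibration = nncf.Dataset(
    [np.random.rand(*ScaledDotProductModel.INPUT_SIZES).astype(np.float32) for _ in range(3)]
)
quantized_model = nncf.quantize(ov_model, calibration, subset_size=3)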
@@ -0,0 +1,13 @@
+strict digraph {
+"0 /nncf_model_input_0" [id=0, type=nncf_model_input];
+"1 SymmetricQuantizer/symmetric_quantize_0" [id=1, type=symmetric_quantize];
+"2 ScaledDotProductModel/view_0" [id=2, type=view];
+"3 ScaledDotProductModel/view_1" [id=3, type=view];
+"4 ScaledDotProductModel/scaled_dot_product_attention_0" [id=4, type=scaled_dot_product_attention];
+"5 /nncf_model_output_0" [id=5, type=nncf_model_output];
+"0 /nncf_model_input_0" -> "1 SymmetricQuantizer/symmetric_quantize_0";
+"1 SymmetricQuantizer/symmetric_quantize_0" -> "2 ScaledDotProductModel/view_0";
+"2 ScaledDotProductModel/view_0" -> "3 ScaledDotProductModel/view_1";
+"3 ScaledDotProductModel/view_1" -> "4 ScaledDotProductModel/scaled_dot_product_attention_0" [label="parallel_input_port_ids:[1, 2]"];
+"4 ScaledDotProductModel/scaled_dot_product_attention_0" -> "5 /nncf_model_output_0";
+}
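The PT reference graph above is produced by NNCF's torch backend; note it reports parallel_input_port_ids:[1, 2] where the converted OV graph earlier reports [1, 0], which appears to be the solver inconsistency this commit reproduces. A hedged sketch of regenerating a graph like it, assuming a minimal INT8 quantization config (the actual harness in test_compressed_graph.py differs):

import torch
from nncf import NNCFConfig
from nncf.torch import create_compressed_model

from tests.torch.test_models.synthetic import ScaledDotProductModel

config = NNCFConfig.from_dict(
    {
        "input_info": {"sample_size": ScaledDotProductModel.INPUT_SIZES},
        "compression": {"algorithm": "quantization"},
    }
)
# create_compressed_model returns (compression_ctrl, wrapped_model); range
# init is skipped here since no init data loader is registered.
_, compressed = create_compressed_model(ScaledDotProductModel(), config)
# Dump the traced NNCFGraph as dot; the dump helper name may vary by NNCF version.
compressed.nncf.get_graph().visualize_graph("ScaledDotProductModel.dot")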
2 changes: 2 additions & 0 deletions tests/torch/test_compressed_graph.py
@@ -70,6 +70,7 @@
 from tests.torch.test_models.synthetic import OrdinaryModelWithRecurrentInName
 from tests.torch.test_models.synthetic import PoolUnPool
 from tests.torch.test_models.synthetic import ReshapeModel
+from tests.torch.test_models.synthetic import ScaledDotProductModel
 from tests.torch.test_models.synthetic import ShiftScaleParametrized
 from tests.torch.test_models.synthetic import TransposeModel
 
@@ -798,6 +799,7 @@ def forward(self, x):
         wrap_inputs_fn=partial(n_inputs_fn, nargs=3),
     ),
     GeneralModelDesc(model_builder=MHA_single_input, input_sample_sizes=(MHA_single_input.INPUT_SIZES,)),
+    GeneralModelDesc(model_builder=ScaledDotProductModel, input_sample_sizes=(ScaledDotProductModel.INPUT_SIZES,)),
     GeneralModelDesc(
         model_name="OrdinaryModelWithRecurrentInName",
         model_builder=OrdinaryModelWithRecurrentInName,
13 changes: 13 additions & 0 deletions tests/torch/test_models/synthetic.py
@@ -325,6 +325,19 @@ def forward(self, x, y, z):
         return torch.baddbmm(x, y, z)
 
 
+class ScaledDotProductModel(nn.Module):
+    # EMBED_DIM = 4
+    EMBED_DIM = 4 * 3
+    INPUT_SIZES = [2, 1, EMBED_DIM]
+
+    def forward(self, x):
+        shape = x.shape
+        x = x.view(-1).view(shape)
+        # k, q, v = torch.split(x, 4, -1)
+        # return nn.functional.scaled_dot_product_attention(k, q, v)
+        return nn.functional.scaled_dot_product_attention(x, x, x)
+
+
 class MHA_single_input(torch.nn.Module):
     EMBED_DIM = 4
     INPUT_SIZES = [2, 1, EMBED_DIM]
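As a sanity check on the repro model above: the view round-trip flattens the input to [24] and restores the original shape without touching values, so the module reduces to scaled dot-product attention of x with itself. A minimal sketch using plain torch:

import torch

from tests.torch.test_models.synthetic import ScaledDotProductModel

model = ScaledDotProductModel()
x = torch.ones(ScaledDotProductModel.INPUT_SIZES)
# view(-1) flattens [2, 1, 12] to [24]; view(shape) restores [2, 1, 12].
out = model(x)
assert out.shape == torch.Size(ScaledDotProductModel.INPUT_SIZES)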
