From e86c9fc352daf46cbaa126a2b6f65c1099143b79 Mon Sep 17 00:00:00 2001 From: Sergey Lyalin Date: Thu, 25 Jul 2024 17:05:22 +0000 Subject: [PATCH 1/2] Check MatMul expected attributes in a predicate instead of assert in the transformation callback. Fixes PyTorch addmm layer tests. --- .../src/transformations/mlir/op/matmul.cpp | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/common/transformations/src/transformations/mlir/op/matmul.cpp b/src/common/transformations/src/transformations/mlir/op/matmul.cpp index 229b4734358cfd..0e5c3ca68c84ec 100644 --- a/src/common/transformations/src/transformations/mlir/op/matmul.cpp +++ b/src/common/transformations/src/transformations/mlir/op/matmul.cpp @@ -18,12 +18,6 @@ using namespace ov::mlir; struct ConvertMatMul { void operator()(ConversionContext& context, NodePtr node) { - auto matmul_node = std::dynamic_pointer_cast(node); - assert(matmul_node); - - // FIXME: current code limitation - assert(!matmul_node->get_transpose_a() && matmul_node->get_transpose_b()); - auto loc = createLocation(context.context, node); auto& builder = context.builder(); // TODO: Support broadcasts @@ -50,7 +44,12 @@ using namespace ov::pass::pattern; using namespace ov::op; MatMulPattern::MatMulPattern() : MarkPattern( - wrap_type({any_input(), any_input()}), + wrap_type({any_input(), any_input()}, [](const Output& output) { + auto matmul_node = std::dynamic_pointer_cast(output.get_node_shared_ptr()); + assert(matmul_node); + // FIXME: current code limitation + return !matmul_node->get_transpose_a() && matmul_node->get_transpose_b(); + }), ConvertMatMul()) { } From cd88a94691e524fb8d940e9403201a9f8b3ece70 Mon Sep 17 00:00:00 2001 From: Sergey Lyalin Date: Thu, 25 Jul 2024 17:26:55 +0000 Subject: [PATCH 2/2] Put rank == 2 restriction on MatMul conversion. Fixes PyTorch linear layer tests. --- .../src/transformations/mlir/op/matmul.cpp | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/common/transformations/src/transformations/mlir/op/matmul.cpp b/src/common/transformations/src/transformations/mlir/op/matmul.cpp index 0e5c3ca68c84ec..635779c0bf52fe 100644 --- a/src/common/transformations/src/transformations/mlir/op/matmul.cpp +++ b/src/common/transformations/src/transformations/mlir/op/matmul.cpp @@ -45,10 +45,14 @@ using namespace ov::op; MatMulPattern::MatMulPattern() : MarkPattern( wrap_type({any_input(), any_input()}, [](const Output& output) { - auto matmul_node = std::dynamic_pointer_cast(output.get_node_shared_ptr()); - assert(matmul_node); + auto node = std::dynamic_pointer_cast(output.get_node_shared_ptr()); + assert(node); // FIXME: current code limitation - return !matmul_node->get_transpose_a() && matmul_node->get_transpose_b(); + return + !has_dynamic_rank(node) && + !node->get_transpose_a() && node->get_transpose_b() && + node->get_input_partial_shape(0).rank().get_length() == 2 && + node->get_input_partial_shape(1).rank().get_length() == 2; }), ConvertMatMul()) { }