diff --git a/src/common/transformations/src/transformations/mlir/op/matmul.cpp b/src/common/transformations/src/transformations/mlir/op/matmul.cpp index 635779c0bf52fe..b8c50db1ed4f31 100644 --- a/src/common/transformations/src/transformations/mlir/op/matmul.cpp +++ b/src/common/transformations/src/transformations/mlir/op/matmul.cpp @@ -29,8 +29,25 @@ struct ConvertMatMul { auto empty = builder.create(loc, outType, dynamic_dimensions); auto zero = getConstant(builder, ov_output_element_type, 0); auto fill = builder.create(loc, mlir::ValueRange{zero}, mlir::ValueRange{empty}); - // TODO: Add other variants of transpose_a/transpose_b - auto matmul = builder.create(loc, mlir::ValueRange{inputs[0], inputs[1]}, mlir::ValueRange{fill.getResult(0)}); + + mlir::SmallVector ins{inputs[0], inputs[1]}; + mlir::SmallVector outs{fill.getResult(0)}; + + auto matmul_node = std::dynamic_pointer_cast(node); + assert(matmul_node); + bool isTransposedA = matmul_node->get_transpose_a(); + bool isTransposedB = matmul_node->get_transpose_b(); + assert(!(isTransposedA && isTransposedB)); + + Operation* matmul; + if (isTransposedA) { + matmul = builder.create(loc, ins, outs); + } else if (isTransposedB) { + matmul = builder.create(loc, ins, outs); + } else { + matmul = builder.create(loc, ins, outs); + } + context.addOutputs(node, matmul); } }; @@ -48,11 +65,9 @@ MatMulPattern::MatMulPattern() : MarkPattern( auto node = std::dynamic_pointer_cast(output.get_node_shared_ptr()); assert(node); // FIXME: current code limitation - return - !has_dynamic_rank(node) && - !node->get_transpose_a() && node->get_transpose_b() && - node->get_input_partial_shape(0).rank().get_length() == 2 && - node->get_input_partial_shape(1).rank().get_length() == 2; + return !has_dynamic_rank(node) && !(node->get_transpose_a() && node->get_transpose_b()) && + node->get_input_partial_shape(0).rank().get_length() == 2 && + node->get_input_partial_shape(1).rank().get_length() == 2; }), ConvertMatMul()) { }