From e2bae91e04b8bce8d204ed0d9e9a64410e56f126 Mon Sep 17 00:00:00 2001 From: Krzysztof Drewniak Date: Tue, 26 Sep 2023 19:38:03 +0000 Subject: [PATCH 1/2] [mlir] Apply is_mlir_conv predicate in standalone MLIr offloading Currently, the is_mlir_conv predicate wasn't being used when offloading standalone convolutions to MLIR on Navi3x, which caused failures relating to being unable to construct the MLIR program when a 3D convlolution was passed in. Fixes https://github.com/ROCmSoftwarePlatform/rocMLIR-internal/issues/1153 This commit amends the standalone lowering to use said predicate, as well as to include quant_convolution and quant_dot into the set of operations that get a standalone lowering. --- src/targets/gpu/fuse_mlir.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index f2f2ccc8015..91b22b8593f 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -327,12 +327,12 @@ struct find_mlir_standalone_op struct find_mlir_standalone_convolution_op : find_mlir_standalone_op { - auto matcher() const { return match::name("convolution"); } + auto matcher() const { return is_mlir_conv; } }; struct find_mlir_standalone_dot_op : find_mlir_standalone_op { - auto matcher() const { return match::name("dot"); } + auto matcher() const { return match::any_of(match::name("dot"), match::name("quant_dot")); } }; /** @@ -365,7 +365,7 @@ bool is_enabled(std::string_view op_name, context* ctx) { return true; } - else if(op_name == "convolution") + else if(op_name == "convolution" || op_name == "quant_convolution") { if(ctx == nullptr) { From d8d417cddec1fa7a92e8e5e6ac15bc4116e1be40 Mon Sep 17 00:00:00 2001 From: Krzysztof Drewniak Date: Tue, 26 Sep 2023 20:55:12 +0000 Subject: [PATCH 2/2] cppcheck fix --- src/targets/gpu/fuse_mlir.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index 91b22b8593f..e40b31ddd1d 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -365,7 +365,7 @@ bool is_enabled(std::string_view op_name, context* ctx) { return true; } - else if(op_name == "convolution" || op_name == "quant_convolution") + else if(op_name == "convolution" or op_name == "quant_convolution") { if(ctx == nullptr) {