diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index bb3074d54a1..190fee46012 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -150,8 +150,8 @@ fuse_input_ops_and_gemm_based_op(module_ref mm, instruction_ref gemm_based_op) enum class mlir_mode { all, - int8, fast, + int8, none }; @@ -192,7 +192,9 @@ auto is_mlir_conv(mlir_mode mode) return false; if(ins->get_shape().type() == shape::int8_type) return true; - if(mode != mlir_mode::fast) + if(mode == mlir_mode::int8) + return false; + if(mode == mlir_mode::all) return true; auto w = ins->inputs().at(1)->get_shape(); if(w.lens().size() != 4) @@ -416,6 +418,7 @@ void fuse_mlir::apply(module_pass_manager& mpm) const match::find_matches(mpm, find_mlir_fused_ops{.conv_mode = get_mode("fused", mlir_mode::fast), .dot_mode = get_mode("fused", mode)}); + match::find_matches( mpm, find_mlir_standalone_convolution_op{get_mode("convolution", mlir_mode::int8)},