Skip to content

Commit

Permalink
Dont use standalone convolution on fp by default
Browse files Browse the repository at this point in the history
  • Loading branch information
pfultz2 committed Oct 14, 2023
1 parent 91c1fd1 commit dd6ff56
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions src/targets/gpu/fuse_mlir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,8 @@ fuse_input_ops_and_gemm_based_op(module_ref mm, instruction_ref gemm_based_op)
enum class mlir_mode
{
all,
int8,
fast,
int8,
none
};

Expand Down Expand Up @@ -192,7 +192,9 @@ auto is_mlir_conv(mlir_mode mode)
return false;
if(ins->get_shape().type() == shape::int8_type)
return true;
if(mode != mlir_mode::fast)
if(mode == mlir_mode::int8)
return false;
if(mode == mlir_mode::all)
return true;
auto w = ins->inputs().at(1)->get_shape();
if(w.lens().size() != 4)
Expand Down Expand Up @@ -416,6 +418,7 @@ void fuse_mlir::apply(module_pass_manager& mpm) const
match::find_matches(mpm,
find_mlir_fused_ops{.conv_mode = get_mode("fused", mlir_mode::fast),
.dot_mode = get_mode("fused", mode)});

match::find_matches(
mpm,
find_mlir_standalone_convolution_op{get_mode("convolution", mlir_mode::int8)},
Expand Down

0 comments on commit dd6ff56

Please sign in to comment.