Skip to content

Commit

Permalink
Add MLIR conv3d support (#3138)
Browse files Browse the repository at this point in the history
  • Loading branch information
turneram authored Jun 7, 2024
1 parent 2690ecf commit fd9af84
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions src/targets/gpu/fuse_mlir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -257,8 +257,14 @@ auto is_mlir_conv(mlir_mode mode)
value v = ins->get_operator().to_value();
auto group = v.at("group").to<int>();
// Avoid MLIR assertion: Index < Length && "Invalid index!"
#ifdef _WIN32
// Temporarily make it available only on Windows
if(ins->get_shape().lens().size() != 4 and group > 1)
return false;
#else
if(ins->get_shape().lens().size() != 4)
return false;
#endif
if(contains({shape::fp8e4m3fnuz_type, shape::int8_type}, input.type()))
return true;
if(mode == mlir_mode::all)
Expand Down

0 comments on commit fd9af84

Please sign in to comment.