From 06f50f89eed5c760f4a946a1de81e7eccd232e3f Mon Sep 17 00:00:00 2001 From: Alan Turner Date: Thu, 30 May 2024 10:50:15 -0700 Subject: [PATCH 1/2] Add mlir conv3d support --- requirements.txt | 2 +- src/targets/gpu/fuse_mlir.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index 915e644c4c7..19dab3c0c42 100755 --- a/requirements.txt +++ b/requirements.txt @@ -28,4 +28,4 @@ pybind/pybind11@3e9dfa2866941655c56877882565e7577de6fc7b --build msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off sqlite3@3.43.2 -DCMAKE_POSITION_INDEPENDENT_CODE=On ROCm/composable_kernel@57cdd70b7cb14e5e3b60cd9a5f96ba8dc343763e -DCK_BUILD_JIT_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On -ROCm/rocMLIR@3612396bca1139abf25e2ed0085fe481d275af89 -DBUILD_FAT_LIBROCKCOMPILER=On +ROCm/rocMLIR@c6182c64fd4cd7988d80fa7622b8142ae92b0c5f -DBUILD_FAT_LIBROCKCOMPILER=On diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index 70163a61365..a1f6cd6993c 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -257,7 +257,7 @@ auto is_mlir_conv(mlir_mode mode) value v = ins->get_operator().to_value(); auto group = v.at("group").to<int>(); // Avoid MLIR assertion: Index < Length && "Invalid index!" 
- if(ins->get_shape().lens().size() != 4) + if(ins->get_shape().lens().size() != 4 and group > 1) return false; if(contains({shape::fp8e4m3fnuz_type, shape::int8_type}, input.type())) return true; From 28ce3f562a43de01c1f2d351c0d265197b606aa2 Mon Sep 17 00:00:00 2001 From: Artur Wojcik Date: Fri, 7 Jun 2024 09:38:43 +0200 Subject: [PATCH 2/2] temporarily enable only on Windows --- src/targets/gpu/fuse_mlir.cpp | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index a1f6cd6993c..d8a76be8641 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -257,8 +257,14 @@ auto is_mlir_conv(mlir_mode mode) value v = ins->get_operator().to_value(); auto group = v.at("group").to<int>(); // Avoid MLIR assertion: Index < Length && "Invalid index!" +#ifdef _WIN32 + // Temporarily make it available only on Windows if(ins->get_shape().lens().size() != 4 and group > 1) return false; +#else + if(ins->get_shape().lens().size() != 4) + return false; +#endif if(contains({shape::fp8e4m3fnuz_type, shape::int8_type}, input.type())) return true; if(mode == mlir_mode::all)