Commit
[Triton] Modify disabled mixed-precision mmav2 swizzling to be enabled with a different configuration.

PiperOrigin-RevId: 619236546
Moerafaat authored and tensorflower-gardener committed Mar 26, 2024
1 parent 5289b17 commit f5a03f1
Showing 4 changed files with 106 additions and 0 deletions.
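
The core change is in TritonGPUAttrDefs.td: previously, when a dot operand's natural vector width (32 / typeWidthInBit) did not match its kWidth and the operand had order[0] == inner (the "transposed" mixed-precision mmav2 case), the shared encoding fell back to an unswizzled layout (vec = perPhase = maxPhase = 1). The patch keeps swizzling enabled for that case and only raises perPhase to at least 2 * vecWidth. Below is a minimal C++ sketch of the new rule, with inputs roughly matching the regression test added in the patch; the helper function and its hard-coded values are illustrative, not code from the patch.

#include <algorithm>
#include <cstdio>

// Illustrative sketch of the patched per-phase computation for an mmav2
// dot operand (Ampere path); it mirrors the .td logic below, not verbatim.
int swizzledPerPhase(int contigDimSize,      // shapePerCTA[order[0]]
                     int typeWidthInBit, int kWidth,
                     bool mixedPrecisionInner /* vecWidth != kWidth &&
                                                 order[0] == inner */) {
  int perPhase = 128 / (contigDimSize * 4 / kWidth);
  perPhase = std::max<int>(perPhase, 1);
  int vecWidth = 32 / typeWidthInBit;
  // Old behaviour: return an unswizzled layout here.
  // New behaviour: keep swizzling, with at least 2 * vecWidth rows per phase.
  if (mixedPrecisionInner)
    perPhase = std::max<int>(perPhase, 2 * vecWidth);
  return perPhase;
}

int main() {
  // f16 (16-bit) 16x256 A operand with kWidth = 4 and order = [0, 1], as in
  // the reduce-data-duplication.mlir test below: vecWidth = 2 != kWidth = 4.
  // Prints "perPhase = 8", matching the CHECKed shared layout.
  std::printf("perPhase = %d\n",
              swizzledPerPhase(/*contigDimSize=*/16, /*typeWidthInBit=*/16,
                               /*kWidth=*/4, /*mixedPrecisionInner=*/true));
  return 0;
}
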
52 changes: 52 additions & 0 deletions third_party/triton/cl619146327.patch
@@ -0,0 +1,52 @@
This patch can be removed once this commit is included:
https://github.com/openai/triton/commit/6ea5b56015db9e0bcff45ec7116cfcbfa729a516

diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
--- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
+++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
@@ -305,9 +305,10 @@ compared to 1*64 when the hasLeadingOffs
int perPhase = 128 / (shapePerCTA[order[0]] * 4 / dotOpEnc.getKWidth());
perPhase = std::max<int>(perPhase, 1);
std::vector<size_t> matShape = {8, 8, 4 * dotOpEnc.getKWidth()};
- // for now, disable swizzle when using transposed int8 tensor cores
- if ((32 / typeWidthInBit != dotOpEnc.getKWidth()) && order[0] == inner)
- return get(context, 1, 1, 1, order, CTALayout);
+ int vecWidth = 32 / typeWidthInBit;
+ if (vecWidth != dotOpEnc.getKWidth() && order[0] == inner) {
+ perPhase = std::max<int>(perPhase, 2 * vecWidth);
+ }
int rank = order.size();
// --- handle A operand ---
if (opIdx == 0) { // compute swizzling for A operand
diff --git a/test/TritonGPU/reduce-data-duplication.mlir b/test/TritonGPU/reduce-data-duplication.mlir
new file mode 100644
--- /dev/null
+++ b/test/TritonGPU/reduce-data-duplication.mlir
@@ -0,0 +1,14 @@
+// RUN: triton-opt %s -split-input-file -tritongpu-reduce-data-duplication | FileCheck %s
+
+// CHECK: #[[SHARED:.*]] = #triton_gpu.shared<{vec = 8, perPhase = 8, maxPhase = 2, order = [0, 1], hasLeadingOffset = false}
+// CHECK: apply_swizzle
+// CHECK: %{{.*}} = triton_gpu.local_alloc %{{.*}} : (tensor<16x256xf16, #{{.*}}>) -> !tt.memdesc<16x256xf16, #[[SHARED]]>
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 4], order = [0, 1]}>
+#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [16, 8]}>
+module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
+ tt.func @apply_swizzle(%arg0: tensor<16x256xf16, #blocked>) {
+ %0 = triton_gpu.convert_layout %arg0 : tensor<16x256xf16, #blocked> -> tensor<16x256xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
+ tt.return
+ }
+}
\ No newline at end of file
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp
@@ -541,8 +541,6 @@ getLoadMatrixFn(MemDescType descTy, cons
const int elemBytes = descTy.getElementTypeBitWidth() / 8;
auto order = sharedLayout.getOrder();

- if (kWidth != (4 / elemBytes))
- assert(vecPhase == 1 || vecPhase == 4 * kWidth);
int nPerWarp =
std::max<int>(shapePerCTA[2] / mmaLayout.getWarpsPerCTA()[2], 8);

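The companion change in SharedToDotOperandMMAv2.cpp drops an assertion that encoded the old assumption: for mixed-precision operands (kWidth != 4 / elemBytes) the shared layout was unswizzled, so its vec (vecPhase in that function) had to be 1. With swizzling now enabled for that case, vec can be larger (8 in the layout CHECKed by the test above), and the old assertion would trip. A small sketch of why, using the test's values; this is illustrative only, not code from the patch.

#include <cstdio>

int main() {
  const int elemBytes = 2;  // f16
  const int kWidth = 4;     // dot_op kWidth in the test above
  const int vecPhase = 8;   // "vec = 8" in the new swizzled shared layout
  // Old invariant from the removed assert: mixed precision implied either
  // an unswizzled layout (vecPhase == 1) or vecPhase == 4 * kWidth.
  bool oldInvariantHolds =
      kWidth == (4 / elemBytes) || vecPhase == 1 || vecPhase == 4 * kWidth;
  std::printf("old assert would %s\n", oldInvariantHolds ? "hold" : "fire");
  return 0;
}
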
1 change: 1 addition & 0 deletions third_party/triton/workspace.bzl
@@ -16,5 +16,6 @@ def repo():
patch_file = [
"//third_party/triton:cl607293980.patch", # long standing :(
"//third_party/triton:cl617812302.patch",
"//third_party/triton:cl619146327.patch",
],
)
52 changes: 52 additions & 0 deletions third_party/xla/third_party/triton/cl619146327.patch
@@ -0,0 +1,52 @@
This patch can be removed once this commit is included:
https://github.com/openai/triton/commit/6ea5b56015db9e0bcff45ec7116cfcbfa729a516

diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
--- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
+++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
@@ -305,9 +305,10 @@ compared to 1*64 when the hasLeadingOffs
int perPhase = 128 / (shapePerCTA[order[0]] * 4 / dotOpEnc.getKWidth());
perPhase = std::max<int>(perPhase, 1);
std::vector<size_t> matShape = {8, 8, 4 * dotOpEnc.getKWidth()};
- // for now, disable swizzle when using transposed int8 tensor cores
- if ((32 / typeWidthInBit != dotOpEnc.getKWidth()) && order[0] == inner)
- return get(context, 1, 1, 1, order, CTALayout);
+ int vecWidth = 32 / typeWidthInBit;
+ if (vecWidth != dotOpEnc.getKWidth() && order[0] == inner) {
+ perPhase = std::max<int>(perPhase, 2 * vecWidth);
+ }
int rank = order.size();
// --- handle A operand ---
if (opIdx == 0) { // compute swizzling for A operand
diff --git a/test/TritonGPU/reduce-data-duplication.mlir b/test/TritonGPU/reduce-data-duplication.mlir
new file mode 100644
--- /dev/null
+++ b/test/TritonGPU/reduce-data-duplication.mlir
@@ -0,0 +1,14 @@
+// RUN: triton-opt %s -split-input-file -tritongpu-reduce-data-duplication | FileCheck %s
+
+// CHECK: #[[SHARED:.*]] = #triton_gpu.shared<{vec = 8, perPhase = 8, maxPhase = 2, order = [0, 1], hasLeadingOffset = false}
+// CHECK: apply_swizzle
+// CHECK: %{{.*}} = triton_gpu.local_alloc %{{.*}} : (tensor<16x256xf16, #{{.*}}>) -> !tt.memdesc<16x256xf16, #[[SHARED]]>
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 4], order = [0, 1]}>
+#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [16, 8]}>
+module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
+ tt.func @apply_swizzle(%arg0: tensor<16x256xf16, #blocked>) {
+ %0 = triton_gpu.convert_layout %arg0 : tensor<16x256xf16, #blocked> -> tensor<16x256xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
+ tt.return
+ }
+}
\ No newline at end of file
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp
@@ -541,8 +541,6 @@ getLoadMatrixFn(MemDescType descTy, cons
const int elemBytes = descTy.getElementTypeBitWidth() / 8;
auto order = sharedLayout.getOrder();

- if (kWidth != (4 / elemBytes))
- assert(vecPhase == 1 || vecPhase == 4 * kWidth);
int nPerWarp =
std::max<int>(shapePerCTA[2] / mmaLayout.getWarpsPerCTA()[2], 8);

1 change: 1 addition & 0 deletions third_party/xla/third_party/triton/workspace.bzl
@@ -16,5 +16,6 @@ def repo():
patch_file = [
"//third_party/triton:cl607293980.patch", # long standing :(
"//third_party/triton:cl617812302.patch",
"//third_party/triton:cl619146327.patch",
],
)
