[Triton] Modify disabled mixed-precision mmav2 swizzling to be enabled with a different configuration.

PiperOrigin-RevId: 619236546
1 parent 5289b17 · commit f5a03f1
Showing 4 changed files with 106 additions and 0 deletions.
@@ -0,0 +1,52 @@
This patch can be removed once this commit is included:
https://github.com/openai/triton/commit/6ea5b56015db9e0bcff45ec7116cfcbfa729a516

diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
--- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
+++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
@@ -305,9 +305,10 @@ compared to 1*64 when the hasLeadingOffs
       int perPhase = 128 / (shapePerCTA[order[0]] * 4 / dotOpEnc.getKWidth());
       perPhase = std::max<int>(perPhase, 1);
       std::vector<size_t> matShape = {8, 8, 4 * dotOpEnc.getKWidth()};
-      // for now, disable swizzle when using transposed int8 tensor cores
-      if ((32 / typeWidthInBit != dotOpEnc.getKWidth()) && order[0] == inner)
-        return get(context, 1, 1, 1, order, CTALayout);
+      int vecWidth = 32 / typeWidthInBit;
+      if (vecWidth != dotOpEnc.getKWidth() && order[0] == inner) {
+        perPhase = std::max<int>(perPhase, 2 * vecWidth);
+      }
       int rank = order.size();
       // --- handle A operand ---
       if (opIdx == 0) { // compute swizzling for A operand
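
The replaced branch fell back to the unswizzled get(context, 1, 1, 1, order, CTALayout) layout whenever the 32-bit load vector width (32 / typeWidthInBit) disagreed with the dot operand's kWidth; the new branch keeps swizzling and widens perPhase instead. A minimal standalone sketch of that arithmetic, with hypothetical parameter names standing in for dotOpEnc, shapePerCTA, and order:

#include <algorithm>
#include <cstdio>

// Hypothetical free-function form of the perPhase computation above.
int computePerPhase(int contigDimSize,   // shapePerCTA[order[0]]
                    int kWidth,          // dotOpEnc.getKWidth()
                    int typeWidthInBit,  // element width in bits
                    bool innerContig) {  // order[0] == inner
  int perPhase = 128 / (contigDimSize * 4 / kWidth);
  perPhase = std::max(perPhase, 1);
  int vecWidth = 32 / typeWidthInBit;
  // Mixed precision with the inner dim contiguous: previously this returned
  // the unswizzled (1, 1, 1) layout, now it just forces a larger perPhase.
  if (vecWidth != kWidth && innerContig)
    perPhase = std::max(perPhase, 2 * vecWidth);
  return perPhase;
}

int main() {
  // The f16 A operand from the test below: 128 / (16 * 4 / 4) = 8,
  // then max(8, 2 * (32 / 16)) = 8.
  std::printf("perPhase = %d\n", computePerPhase(16, 4, 16, true));
}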
diff --git a/test/TritonGPU/reduce-data-duplication.mlir b/test/TritonGPU/reduce-data-duplication.mlir
new file mode 100644
--- /dev/null
+++ b/test/TritonGPU/reduce-data-duplication.mlir
@@ -0,0 +1,14 @@
+// RUN: triton-opt %s -split-input-file -tritongpu-reduce-data-duplication | FileCheck %s
+
+// CHECK: #[[SHARED:.*]] = #triton_gpu.shared<{vec = 8, perPhase = 8, maxPhase = 2, order = [0, 1], hasLeadingOffset = false}
+// CHECK: apply_swizzle
+// CHECK: %{{.*}} = triton_gpu.local_alloc %{{.*}} : (tensor<16x256xf16, #{{.*}}>) -> !tt.memdesc<16x256xf16, #[[SHARED]]>
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 4], order = [0, 1]}>
+#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [16, 8]}>
+module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
+  tt.func @apply_swizzle(%arg0: tensor<16x256xf16, #blocked>) {
+    %0 = triton_gpu.convert_layout %arg0 : tensor<16x256xf16, #blocked> -> tensor<16x256xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
+    tt.return
+  }
+}
\ No newline at end of file
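
To see where the CHECK'd constants come from, here is a worked sketch for this test case (a 16x256xf16 A operand with kWidth = 4 and order = [0, 1]). The vec and maxPhase formulas for the A operand live in the same .td method but outside the hunk above, so they are paraphrased here and should be treated as an assumption:

#include <algorithm>
#include <cassert>

int main() {
  const int kWidth = 4;                         // from the dot_op encoding
  const int typeWidthInBit = 16;                // f16
  const int contigDimSize = 16;                 // shapePerCTA[order[0]], dim 0
  const int matShape[3] = {8, 8, 4 * kWidth};   // {8, 8, 16}

  int perPhase = std::max(128 / (contigDimSize * 4 / kWidth), 1);  // 8
  const int vecWidth = 32 / typeWidthInBit;                        // 2
  if (vecWidth != kWidth)                       // the new branch above fires
    perPhase = std::max(perPhase, 2 * vecWidth);                   // still 8

  // Assumed A-operand formulas (order[0] != rank - 1, i.e. M contiguous):
  const int vec = matShape[0];                  // 8
  const int maxPhase = matShape[2] / perPhase;  // 16 / 8 = 2

  // Matches: #triton_gpu.shared<{vec = 8, perPhase = 8, maxPhase = 2, ...}>
  assert(vec == 8 && perPhase == 8 && maxPhase == 2);
  return 0;
}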
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp
@@ -541,8 +541,6 @@ getLoadMatrixFn(MemDescType descTy, cons
   const int elemBytes = descTy.getElementTypeBitWidth() / 8;
   auto order = sharedLayout.getOrder();

-  if (kWidth != (4 / elemBytes))
-    assert(vecPhase == 1 || vecPhase == 4 * kWidth);
   int nPerWarp =
       std::max<int>(shapePerCTA[2] / mmaLayout.getWarpsPerCTA()[2], 8);
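
The dropped assert encoded the old invariant: in the mixed-precision case (kWidth != 4 / elemBytes), the shared layout's vec (vecPhase here, presumably read from the shared encoding) was either 1 (swizzling disabled) or 4 * kWidth. After the .td change above, other values are legal, so the assert would fire on valid IR. A sketch with the values from the new test:

#include <cassert>

int main() {
  // Values from the test above: f16 A operand, kWidth = 4, new vec = 8.
  const int kWidth = 4;
  const int vecPhase = 8;  // shared layout's vec under the new swizzling
  // The removed invariant no longer holds for this now-legal layout:
  assert(!(vecPhase == 1 || vecPhase == 4 * kWidth));
  return 0;
}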