Skip to content

Commit

Permalink
fp16 acc autotune example
Browse files Browse the repository at this point in the history
  • Loading branch information
mikex86 committed Sep 10, 2024
1 parent 7cecfcb commit f85e333
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 4 deletions.
6 changes: 3 additions & 3 deletions matmul_autotune.ttir
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#pragma autotune BLOCK_SIZE_M {128, 256} default 256
#pragma autotune BLOCK_SIZE_N {128, 256} default 128
#pragma autotune BLOCK_SIZE_M {128, 256} default 128
#pragma autotune BLOCK_SIZE_N {128, 256} default 256
#pragma autotune BLOCK_SIZE_K {32, 64} default 32
#pragma autotune GROUP_SIZE_M {8, 12, 16, 20, 24} default 32
#pragma autotune GROUP_SIZE_M {8, 12, 16, 20, 24} default 8
#pragma autotune intrinsic num_warps {4, 8} default 8
#pragma autotune intrinsic num_stages {3, 4, 5} default 3

Expand Down
131 changes: 131 additions & 0 deletions matmul_autotune_fp16acc.ttir
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
#pragma autotune BLOCK_SIZE_M {128, 256} default 128
#pragma autotune BLOCK_SIZE_N {128, 256} default 256
#pragma autotune BLOCK_SIZE_K {32, 64} default 64
#pragma autotune GROUP_SIZE_M {8, 12, 16, 20, 24} default 8
#pragma autotune intrinsic num_warps {4, 8} default 8
#pragma autotune intrinsic num_stages {3, 4, 5} default 3

#pragma argument 0 ptr cuMalloc(8192 * 8192 * 2)
#pragma argument 1 ptr cuMalloc(8192 * 8192 * 2)
#pragma argument 2 ptr cuMalloc(8192 * 8192 * 2)

#pragma argument 3 i32 8192
#pragma argument 4 i32 8192
#pragma argument 5 i32 8192

#pragma argument 6 i32 8192
#pragma argument 7 i32 8192
#pragma argument 8 i32 8192

#pragma grid x ((8192 / ${BLOCK_SIZE_M}) * (8192 / ${BLOCK_SIZE_N}))

#pragma launch matmul_kernel

module {
tt.func public @matmul_kernel(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32},
%arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32},
%arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32},
%arg3: i32 {tt.divisibility = 16 : i32},
%arg4: i32 {tt.divisibility = 16 : i32},
%arg5: i32 {tt.divisibility = 16 : i32},
%arg6: i32 {tt.divisibility = 16 : i32},
%arg7: i32 {tt.divisibility = 16 : i32},
%arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
%c31_i32 = arith.constant 31 : i32
%cst = arith.constant dense<0.000000e+00> : tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_N}xf16>
%c255_i32 = arith.constant ${${BLOCK_SIZE_N} - 1} : i32
%c127_i32 = arith.constant ${${BLOCK_SIZE_M} - 1} : i32
%cst_0 = arith.constant dense<0.000000e+00> : tensor<${BLOCK_SIZE_K}x${BLOCK_SIZE_N}xf16>
%cst_1 = arith.constant dense<0.000000e+00> : tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_K}xf16>
%c1_i32 = arith.constant 1 : i32
%c0_i32 = arith.constant 0 : i32
%cst_2 = arith.constant dense<${BLOCK_SIZE_K}> : tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_K}xi32>
%c32_i32 = arith.constant ${BLOCK_SIZE_K} : i32
%c256_i32 = arith.constant ${BLOCK_SIZE_N} : i32
%c128_i32 = arith.constant ${BLOCK_SIZE_M} : i32
%c8_i32 = arith.constant ${GROUP_SIZE_M} : i32
%0 = tt.get_program_id x : i32
%1 = arith.addi %arg3, %c127_i32 : i32
%2 = arith.divsi %1, %c128_i32 : i32
%3 = arith.addi %arg4, %c255_i32 : i32
%4 = arith.divsi %3, %c256_i32 : i32
%5 = arith.muli %4, %c8_i32 : i32
%6 = arith.divsi %0, %5 : i32
%7 = arith.muli %6, %c8_i32 : i32
%8 = arith.subi %2, %7 : i32
%9 = arith.minsi %8, %c8_i32 : i32
%10 = arith.remsi %0, %9 : i32
%11 = arith.addi %7, %10 : i32
%12 = arith.remsi %0, %5 : i32
%13 = arith.divsi %12, %9 : i32
%14 = arith.muli %11, %c128_i32 : i32
%15 = tt.make_range {end = ${BLOCK_SIZE_M} : i32, start = 0 : i32} : tensor<${BLOCK_SIZE_M}xi32>
%16 = tt.splat %14 : i32 -> tensor<${BLOCK_SIZE_M}xi32>
%17 = arith.addi %16, %15 : tensor<${BLOCK_SIZE_M}xi32>
%18 = tt.splat %arg3 : i32 -> tensor<${BLOCK_SIZE_M}xi32>
%19 = arith.remsi %17, %18 : tensor<${BLOCK_SIZE_M}xi32>
%20 = arith.muli %13, %c256_i32 : i32
%21 = tt.make_range {end = ${BLOCK_SIZE_N} : i32, start = 0 : i32} : tensor<${BLOCK_SIZE_N}xi32>
%22 = tt.splat %20 : i32 -> tensor<${BLOCK_SIZE_N}xi32>
%23 = arith.addi %22, %21 : tensor<${BLOCK_SIZE_N}xi32>
%24 = tt.splat %arg4 : i32 -> tensor<${BLOCK_SIZE_N}xi32>
%25 = arith.remsi %23, %24 : tensor<${BLOCK_SIZE_N}xi32>
%26 = tt.make_range {end = ${BLOCK_SIZE_K} : i32, start = 0 : i32} : tensor<${BLOCK_SIZE_K}xi32>
%27 = tt.expand_dims %19 {axis = 1 : i32} : tensor<${BLOCK_SIZE_M}xi32> -> tensor<${BLOCK_SIZE_M}x1xi32>
%28 = tt.splat %arg6 : i32 -> tensor<${BLOCK_SIZE_M}x1xi32>
%29 = arith.muli %27, %28 : tensor<${BLOCK_SIZE_M}x1xi32>
%30 = tt.expand_dims %26 {axis = 0 : i32} : tensor<${BLOCK_SIZE_K}xi32> -> tensor<1x${BLOCK_SIZE_K}xi32>
%31 = tt.broadcast %29 : tensor<${BLOCK_SIZE_M}x1xi32> -> tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_K}xi32>
%32 = tt.broadcast %30 : tensor<1x${BLOCK_SIZE_K}xi32> -> tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_K}xi32>
%33 = arith.addi %31, %32 : tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_K}xi32>
%34 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_K}x!tt.ptr<f16>>
%35 = tt.addptr %34, %33 : tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_K}x!tt.ptr<f16>>, tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_K}xi32>
%36 = tt.expand_dims %26 {axis = 1 : i32} : tensor<${BLOCK_SIZE_K}xi32> -> tensor<${BLOCK_SIZE_K}x1xi32>
%37 = tt.splat %arg7 : i32 -> tensor<${BLOCK_SIZE_K}x1xi32>
%38 = arith.muli %36, %37 : tensor<${BLOCK_SIZE_K}x1xi32>
%39 = tt.expand_dims %25 {axis = 0 : i32} : tensor<${BLOCK_SIZE_N}xi32> -> tensor<1x${BLOCK_SIZE_N}xi32>
%40 = tt.broadcast %38 : tensor<${BLOCK_SIZE_K}x1xi32> -> tensor<${BLOCK_SIZE_K}x${BLOCK_SIZE_N}xi32>
%41 = tt.broadcast %39 : tensor<1x${BLOCK_SIZE_N}xi32> -> tensor<${BLOCK_SIZE_K}x${BLOCK_SIZE_N}xi32>
%42 = arith.addi %40, %41 : tensor<${BLOCK_SIZE_K}x${BLOCK_SIZE_N}xi32>
%43 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<${BLOCK_SIZE_K}x${BLOCK_SIZE_N}x!tt.ptr<f16>>
%44 = tt.addptr %43, %42 : tensor<${BLOCK_SIZE_K}x${BLOCK_SIZE_N}x!tt.ptr<f16>>, tensor<${BLOCK_SIZE_K}x${BLOCK_SIZE_N}xi32>
%45 = arith.addi %arg5, %c31_i32 : i32
%46 = arith.divsi %45, %c32_i32 : i32
%47 = arith.muli %arg7, %c32_i32 : i32
%48 = tt.splat %47 : i32 -> tensor<${BLOCK_SIZE_K}x${BLOCK_SIZE_N}xi32>
%49:3 = scf.for %arg9 = %c0_i32 to %46 step %c1_i32 iter_args(%arg10 = %cst, %arg11 = %35, %arg12 = %44) -> (tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_N}xf16>, tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_K}x!tt.ptr<f16>>, tensor<${BLOCK_SIZE_K}x${BLOCK_SIZE_N}x!tt.ptr<f16>>) : i32 {
%66 = arith.muli %arg9, %c32_i32 : i32
%67 = arith.subi %arg5, %66 : i32
%68 = tt.splat %67 : i32 -> tensor<1x${BLOCK_SIZE_K}xi32>
%69 = arith.cmpi slt, %30, %68 : tensor<1x${BLOCK_SIZE_K}xi32>
%70 = tt.broadcast %69 : tensor<1x${BLOCK_SIZE_K}xi1> -> tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_K}xi1>
%71 = tt.load %arg11, %70, %cst_1 : tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_K}x!tt.ptr<f16>>
%72 = tt.splat %67 : i32 -> tensor<${BLOCK_SIZE_K}x1xi32>
%73 = arith.cmpi slt, %36, %72 : tensor<${BLOCK_SIZE_K}x1xi32>
%74 = tt.broadcast %73 : tensor<${BLOCK_SIZE_K}x1xi1> -> tensor<${BLOCK_SIZE_K}x${BLOCK_SIZE_N}xi1>
%75 = tt.load %arg12, %74, %cst_0 : tensor<${BLOCK_SIZE_K}x${BLOCK_SIZE_N}x!tt.ptr<f16>>
%76 = tt.dot %71, %75, %arg10, inputPrecision = tf32 : tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_K}xf16> * tensor<${BLOCK_SIZE_K}x${BLOCK_SIZE_N}xf16> -> tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_N}xf16>
%77 = tt.addptr %arg11, %cst_2 : tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_K}x!tt.ptr<f16>>, tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_K}xi32>
%78 = tt.addptr %arg12, %48 : tensor<${BLOCK_SIZE_K}x${BLOCK_SIZE_N}x!tt.ptr<f16>>, tensor<${BLOCK_SIZE_K}x${BLOCK_SIZE_N}xi32>
scf.yield %76, %77, %78 : tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_N}xf16>, tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_K}x!tt.ptr<f16>>, tensor<${BLOCK_SIZE_K}x${BLOCK_SIZE_N}x!tt.ptr<f16>>
}
%50 = tt.expand_dims %17 {axis = 1 : i32} : tensor<${BLOCK_SIZE_M}xi32> -> tensor<${BLOCK_SIZE_M}x1xi32>
%51 = tt.splat %arg8 : i32 -> tensor<${BLOCK_SIZE_M}x1xi32>
%52 = arith.muli %51, %50 : tensor<${BLOCK_SIZE_M}x1xi32>
%53 = tt.splat %arg2 : !tt.ptr<f16> -> tensor<${BLOCK_SIZE_M}x1x!tt.ptr<f16>>
%54 = tt.addptr %53, %52 : tensor<${BLOCK_SIZE_M}x1x!tt.ptr<f16>>, tensor<${BLOCK_SIZE_M}x1xi32>
%55 = tt.expand_dims %23 {axis = 0 : i32} : tensor<${BLOCK_SIZE_N}xi32> -> tensor<1x${BLOCK_SIZE_N}xi32>
%56 = tt.broadcast %54 : tensor<${BLOCK_SIZE_M}x1x!tt.ptr<f16>> -> tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_N}x!tt.ptr<f16>>
%57 = tt.broadcast %55 : tensor<1x${BLOCK_SIZE_N}xi32> -> tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_N}xi32>
%58 = tt.addptr %56, %57 : tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_N}x!tt.ptr<f16>>, tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_N}xi32>
%59 = tt.splat %arg3 : i32 -> tensor<${BLOCK_SIZE_M}x1xi32>
%60 = arith.cmpi slt, %50, %59 : tensor<${BLOCK_SIZE_M}x1xi32>
%61 = tt.splat %arg4 : i32 -> tensor<1x${BLOCK_SIZE_N}xi32>
%62 = arith.cmpi slt, %55, %61 : tensor<1x${BLOCK_SIZE_N}xi32>
%63 = tt.broadcast %60 : tensor<${BLOCK_SIZE_M}x1xi1> -> tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_N}xi1>
%64 = tt.broadcast %62 : tensor<1x${BLOCK_SIZE_N}xi1> -> tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_N}xi1>
%65 = arith.andi %63, %64 : tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_N}xi1>
tt.store %58, %49#0, %65 : tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_N}x!tt.ptr<f16>>
tt.return
}
}
1 change: 0 additions & 1 deletion src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,6 @@ int main(int argc, char *argv[]) {
enableFpFusion,
libraryNames,
libraryPaths, "ptx");

std::ofstream ostream(outputPath);
ostream << "// Shared memory requirements: " << std::to_string(shared_mem_bytes) << '\n';
ostream << ptxCode;
Expand Down

0 comments on commit f85e333

Please sign in to comment.