Fixed build instructions & fixed bug in kern arg preprocessor + added attn fwd autotune test
mikex86 committed Sep 13, 2024
1 parent a7cf675 commit 801d05c
Showing 5 changed files with 267 additions and 7 deletions.
12 changes: 10 additions & 2 deletions CMakeLists.txt
@@ -26,7 +26,6 @@ add_executable(tritonc src/main.cpp
src/tinyscript.cpp
src/templating.cpp)
target_link_libraries(tritonc PRIVATE argparse)
-target_link_libraries(tritonc PRIVATE cuda curand)

get_property(triton_libs GLOBAL PROPERTY TRITON_LIBS)

@@ -111,4 +110,13 @@ set(TRITON_GENERATED_INCLUDE_DIRS
)

target_include_directories(tritonc PRIVATE ${TRITON_INCLUDE_DIRS})
-target_include_directories(tritonc PRIVATE ${TRITON_GENERATED_INCLUDE_DIRS})
+target_include_directories(tritonc PRIVATE ${TRITON_GENERATED_INCLUDE_DIRS})
+
+# include cuda headers
+set(CMAKE_CUDA_ARCHITECTURES 75)
+set(CMAKE_CUDA_COMPILER /usr/local/cuda/bin/nvcc)
+enable_language(CUDA)
+find_library(CUDA_LIBRARY cuda ${CMAKE_CUDA_IMPLICIT_LINK_DIRECTORIES})
+find_library(CURAND_LIBRARY curand ${CMAKE_CUDA_IMPLICIT_LINK_DIRECTORIES})
+target_include_directories(tritonc PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
+target_link_libraries(tritonc PRIVATE ${CUDA_LIBRARY} ${CURAND_LIBRARY})
6 changes: 3 additions & 3 deletions README.md
@@ -46,16 +46,16 @@ tritonc add_kernel.ttir --compute-capability 89 --num-stages 3 --num-warps 4 -o
#### Build LLVM

The required LLVM version is included as a submodule of this repository under `third_party/llvm-project`.

Build and install llvm-project system-wide.
Make sure `ninja-build` is installed for fast builds!

```commandline
cd third_party/llvm-project
mkdir build
cd build
-cmake -G Ninja ../llvm -DLLVM_ENABLE_PROJECTS=mlir -DLLVM_BUILD_EXAMPLES=ON -DLLVM_TARGETS_TO_BUILD="Native;NVPTX;AMDGPU" -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=ON
+cmake -G Ninja ../llvm -DLLVM_ENABLE_PROJECTS=mlir -DLLVM_BUILD_EXAMPLES=ON -DLLVM_TARGETS_TO_BUILD="Native;NVPTX;AMDGPU" -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=ON -DCMAKE_INSTALL_PREFIX=/usr/local
cmake --build .
-sudo make install
+sudo cmake --build . --target install
```

##### How to build tritonc
252 changes: 252 additions & 0 deletions attn_fwd_autotune.ttir
@@ -0,0 +1,252 @@
#pragma autotune BLOCK_SIZE_M {128, 256} default 128
#pragma autotune BLOCK_SIZE_N {128, 256} default 128

#pragma autotune intrinsic num_warps {4, 8} default 4
#pragma autotune intrinsic num_stages {1, 2, 3, 4, 5, 6, 7, 8} default 2

// this is more of a compile-time constant, but there is no #pragma constant yet...
#pragma autotune HEAD_DIM {64} default 64


// we still have to allocate 3x the size for each arg to simulate the QKV split's jumpy stride access pattern in the kernel
// 48 = batch size
// 1024 = seq len
// 12 = num heads
// 64 head dim
// 2 = sizeof(float16)
#pragma argument 0 ptr cuMalloc(3 * 48 * 1024 * 12 * 64 * 2)
#pragma argument 1 ptr cuMalloc(3 * 48 * 1024 * 12 * 64 * 2)
#pragma argument 2 ptr cuMalloc(3 * 48 * 1024 * 12 * 64 * 2)
#pragma argument 3 ptr cuMalloc(48 * 1024 * 12 * 64 * 2)

#pragma argument 4 ptr cuMalloc(48 * 12 * 1024 * 2)

// the i32 bit pattern of 1.0f / sqrt(64*12) when reinterpreted as a float
#pragma argument 5 i32 1024707898
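
// To sanity-check that constant, a minimal C++ sketch (not part of the repository)
// that reinterprets the i32 bit pattern as a float:
//
// ```cpp
// #include <cmath>
// #include <cstdint>
// #include <cstdio>
// #include <cstring>
//
// int main() {
//     std::uint32_t bits = 1024707898;          // value passed via "#pragma argument 5 i32 ..."
//     float scale;
//     std::memcpy(&scale, &bits, sizeof scale); // reinterpret the raw i32 bits as an IEEE-754 float
//     std::printf("%.8f vs 1/sqrt(64*12) = %.8f\n", scale, 1.0f / std::sqrt(64.0f * 12.0f));
// }
// ```
//
// Both values print as roughly 0.03608439, matching the comment above.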

// qstride 0, 2, 1
#pragma argument 6 i32 2359296
#pragma argument 7 i32 64
#pragma argument 8 i32 2304

// kstride 0, 2, 1
#pragma argument 9 i32 2359296
#pragma argument 10 i32 64
#pragma argument 11 i32 2304

// vstride 0, 2, 1
#pragma argument 12 i32 2359296
#pragma argument 13 i32 64
#pragma argument 14 i32 2304

// bias strides
#pragma argument 15 i32 0
#pragma argument 16 i32 0
#pragma argument 17 i32 0

// output strides
#pragma argument 18 i32 786432
#pragma argument 19 i32 64
#pragma argument 20 i32 768
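
// The stride constants above match a packed QKV layout of shape (batch, seq, 3 * heads * head_dim)
// in fp16 (which is why each of the first three arguments is allocated at 3x the per-tensor size),
// plus a separate (batch, seq, heads * head_dim) output buffer. That layout is inferred from the
// numbers, not stated in this file. A small C++ check of the arithmetic:
//
// ```cpp
// #include <cstdio>
//
// int main() {
//     const long batch = 48, seq = 1024, heads = 12, head_dim = 64;
//
//     // Assumed packed QKV layout: (batch, seq, 3 * heads * head_dim), fp16 elements.
//     long qkv_token_stride = 3 * heads * head_dim;   // 2304    (args 8, 11, 14)
//     long qkv_batch_stride = seq * qkv_token_stride; // 2359296 (args 6, 9, 12)
//     long head_stride      = head_dim;               // 64      (args 7, 10, 13, 19)
//
//     // Separate output buffer: (batch, seq, heads * head_dim), fp16 elements.
//     long out_token_stride = heads * head_dim;       // 768     (arg 20)
//     long out_batch_stride = seq * out_token_stride; // 786432  (arg 18)
//
//     // Allocation sizes in bytes (2 bytes per fp16 element), as in the cuMalloc pragmas above.
//     long qkv_bytes = 3 * batch * seq * heads * head_dim * 2;
//     long out_bytes = batch * seq * heads * head_dim * 2;
//     long lse_bytes = batch * heads * seq * 2;
//
//     std::printf("strides: %ld %ld %ld | %ld %ld\n", qkv_batch_stride, qkv_token_stride,
//                 head_stride, out_batch_stride, out_token_stride);
//     std::printf("bytes:   %ld %ld %ld\n", qkv_bytes, out_bytes, lse_bytes);
// }
// ```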

// nheads
#pragma argument 21 i32 12

// seq length q
#pragma argument 22 i32 1024

// seq length k
#pragma argument 23 i32 1024

// seq length q rounded
#pragma argument 24 i32 1024

// n_embed
#pragma argument 25 i32 (12*64)

// cache key dummy 1
#pragma argument 26 i32 32

// cache key dummy 2
#pragma argument 27 i32 32

#pragma launch attn_fwd_kernel

// seq len / block_m
#pragma grid x (1024/${BLOCK_SIZE_M})

// batch * n_heads
#pragma grid y (48 * 12)
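
// For reference, the launch grid that results if the expressions above are evaluated as written,
// shown as a quick C++ check (not project code):
//
// ```cpp
// #include <cstdio>
// #include <initializer_list>
//
// int main() {
//     const int seq_len = 1024, batch = 48, n_heads = 12;
//     for (int block_m : {128, 256}) {      // the two BLOCK_SIZE_M candidates
//         int grid_x = seq_len / block_m;   // "#pragma grid x": query blocks per (batch, head)
//         int grid_y = batch * n_heads;     // "#pragma grid y": one program per (batch, head) pair
//         std::printf("BLOCK_SIZE_M=%d -> grid %d x %d (%d programs)\n",
//                     block_m, grid_x, grid_y, grid_x * grid_y);
//     }
// }
// ```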

module {
tt.func public @attn_fwd_kernel(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32},
%arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32},
%arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32},
%arg3: !tt.ptr<f16> {tt.divisibility = 16 : i32},
%arg4: !tt.ptr<f16> {tt.divisibility = 16 : i32},
%arg5: f32, %arg6: i32 {tt.divisibility = 16 : i32},

%arg7: i32 {tt.divisibility = 16 : i32},
%arg8: i32 {tt.divisibility = 16 : i32},
%arg9: i32 {tt.divisibility = 16 : i32},

%arg10: i32 {tt.divisibility = 16 : i32},
%arg11: i32 {tt.divisibility = 16 : i32},
%arg12: i32 {tt.divisibility = 16 : i32},

%arg13: i32 {tt.divisibility = 16 : i32},
%arg14: i32 {tt.divisibility = 16 : i32},
%arg15: i32 {tt.divisibility = 16 : i32},

%arg16: i32 {tt.divisibility = 16 : i32},
%arg17: i32 {tt.divisibility = 16 : i32},
%arg18: i32 {tt.divisibility = 16 : i32},

%arg19: i32 {tt.divisibility = 16 : i32},
%arg20: i32 {tt.divisibility = 16 : i32},
%arg21: i32,
%arg22: i32 {tt.divisibility = 16 : i32},
%arg23: i32 {tt.divisibility = 16 : i32},
%arg24: i32 {tt.divisibility = 16 : i32},
%arg25: i32 {tt.divisibility = 16 : i32},
%arg26: i32 {tt.divisibility = 16 : i32},
%arg27: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
%cst = arith.constant dense<0xFF800000> : tensor<${BLOCK_SIZE_M}xf32>
%cst_0 = arith.constant dense<0.000000e+00> : tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_N}xf32>
%c128_i32 = arith.constant ${BLOCK_SIZE_N} : i32
%c0_i32 = arith.constant 0 : i32
%cst_1 = arith.constant dense<0.000000e+00> : tensor<${BLOCK_SIZE_M}x${HEAD_DIM}xf16>
%cst_2 = arith.constant dense<0xFF800000> : tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_N}xf32>
%cst_3 = arith.constant dense<0.000000e+00> : tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_N}xf16>
%c1_i32 = arith.constant 1 : i32
%c256_i32 = arith.constant ${BLOCK_SIZE_M} : i32
%0 = tt.get_program_id x : i32
%1 = tt.get_program_id y : i32
%2 = arith.divsi %1, %arg21 : i32
%3 = arith.remsi %1, %arg21 : i32
%4 = arith.muli %0, %c256_i32 : i32
%5 = tt.make_range {end = ${BLOCK_SIZE_M} : i32, start = 0 : i32} : tensor<${BLOCK_SIZE_M}xi32>
%6 = tt.splat %4 : i32 -> tensor<${BLOCK_SIZE_M}xi32>
%7 = arith.addi %6, %5 : tensor<${BLOCK_SIZE_M}xi32>
%8 = tt.make_range {end = ${BLOCK_SIZE_N} : i32, start = 0 : i32} : tensor<${BLOCK_SIZE_N}xi32>
%9 = tt.make_range {end = ${HEAD_DIM} : i32, start = 0 : i32} : tensor<${HEAD_DIM}xi32>
%10 = arith.muli %2, %arg6 : i32
%11 = tt.addptr %arg0, %10 : !tt.ptr<f16>, i32
%12 = arith.muli %3, %arg7 : i32
%13 = tt.addptr %11, %12 : !tt.ptr<f16>, i32
%14 = tt.expand_dims %7 {axis = 1 : i32} : tensor<${BLOCK_SIZE_M}xi32> -> tensor<${BLOCK_SIZE_M}x1xi32>
%15 = tt.splat %arg8 : i32 -> tensor<${BLOCK_SIZE_M}x1xi32>
%16 = arith.muli %14, %15 : tensor<${BLOCK_SIZE_M}x1xi32>
%17 = tt.expand_dims %9 {axis = 0 : i32} : tensor<${HEAD_DIM}xi32> -> tensor<1x${HEAD_DIM}xi32>
%18 = tt.broadcast %16 : tensor<${BLOCK_SIZE_M}x1xi32> -> tensor<${BLOCK_SIZE_M}x${HEAD_DIM}xi32>
%19 = tt.broadcast %17 : tensor<1x${HEAD_DIM}xi32> -> tensor<${BLOCK_SIZE_M}x${HEAD_DIM}xi32>
%20 = arith.addi %18, %19 : tensor<${BLOCK_SIZE_M}x${HEAD_DIM}xi32>
%21 = tt.splat %13 : !tt.ptr<f16> -> tensor<${BLOCK_SIZE_M}x${HEAD_DIM}x!tt.ptr<f16>>
%22 = tt.addptr %21, %20 : tensor<${BLOCK_SIZE_M}x${HEAD_DIM}x!tt.ptr<f16>>, tensor<${BLOCK_SIZE_M}x${HEAD_DIM}xi32>
%23 = arith.muli %2, %arg9 : i32
%24 = tt.addptr %arg1, %23 : !tt.ptr<f16>, i32
%25 = arith.muli %3, %arg10 : i32
%26 = tt.addptr %24, %25 : !tt.ptr<f16>, i32
%27 = tt.expand_dims %8 {axis = 1 : i32} : tensor<${BLOCK_SIZE_N}xi32> -> tensor<${BLOCK_SIZE_N}x1xi32>
%28 = tt.splat %arg11 : i32 -> tensor<${BLOCK_SIZE_N}x1xi32>
%29 = arith.muli %27, %28 : tensor<${BLOCK_SIZE_N}x1xi32>
%30 = tt.broadcast %29 : tensor<${BLOCK_SIZE_N}x1xi32> -> tensor<${BLOCK_SIZE_N}x${HEAD_DIM}xi32>
%31 = tt.broadcast %17 : tensor<1x${HEAD_DIM}xi32> -> tensor<${BLOCK_SIZE_N}x${HEAD_DIM}xi32>
%32 = arith.addi %30, %31 : tensor<${BLOCK_SIZE_N}x${HEAD_DIM}xi32>
%33 = tt.splat %26 : !tt.ptr<f16> -> tensor<${BLOCK_SIZE_N}x${HEAD_DIM}x!tt.ptr<f16>>
%34 = tt.addptr %33, %32 : tensor<${BLOCK_SIZE_N}x${HEAD_DIM}x!tt.ptr<f16>>, tensor<${BLOCK_SIZE_N}x${HEAD_DIM}xi32>
%35 = arith.muli %2, %arg12 : i32
%36 = tt.addptr %arg2, %35 : !tt.ptr<f16>, i32
%37 = arith.muli %3, %arg13 : i32
%38 = tt.addptr %36, %37 : !tt.ptr<f16>, i32
%39 = tt.splat %arg14 : i32 -> tensor<${BLOCK_SIZE_N}x1xi32>
%40 = arith.muli %27, %39 : tensor<${BLOCK_SIZE_N}x1xi32>
%41 = tt.broadcast %40 : tensor<${BLOCK_SIZE_N}x1xi32> -> tensor<${BLOCK_SIZE_N}x${HEAD_DIM}xi32>
%42 = arith.addi %41, %31 : tensor<${BLOCK_SIZE_N}x${HEAD_DIM}xi32>
%43 = tt.splat %38 : !tt.ptr<f16> -> tensor<${BLOCK_SIZE_N}x${HEAD_DIM}x!tt.ptr<f16>>
%44 = tt.addptr %43, %42 : tensor<${BLOCK_SIZE_N}x${HEAD_DIM}x!tt.ptr<f16>>, tensor<${BLOCK_SIZE_N}x${HEAD_DIM}xi32>
%45 = tt.load %22 : tensor<${BLOCK_SIZE_M}x${HEAD_DIM}x!tt.ptr<f16>>
%46 = arith.addi %0, %c1_i32 : i32
%47 = arith.muli %46, %c256_i32 : i32
%48 = arith.minsi %47, %arg23 : i32
%49 = tt.broadcast %14 : tensor<${BLOCK_SIZE_M}x1xi32> -> tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_N}xi32>
%50 = tt.splat %arg5 : f32 -> tensor<${BLOCK_SIZE_M}xf32>
%51 = tt.splat %arg5 : f32 -> tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_N}xf32>
%52:3 = scf.for %arg28 = %c0_i32 to %48 step %c128_i32 iter_args(%arg29 = %cst_1, %arg30 = %cst, %arg31 = %cst) -> (tensor<${BLOCK_SIZE_M}x${HEAD_DIM}xf16>, tensor<${BLOCK_SIZE_M}xf32>, tensor<${BLOCK_SIZE_M}xf32>) : i32 {
%75 = arith.muli %arg28, %arg11 : i32
%76 = tt.splat %75 : i32 -> tensor<${BLOCK_SIZE_N}x${HEAD_DIM}xi32>
%77 = tt.addptr %34, %76 : tensor<${BLOCK_SIZE_N}x${HEAD_DIM}x!tt.ptr<f16>>, tensor<${BLOCK_SIZE_N}x${HEAD_DIM}xi32>
%78 = tt.load %77 : tensor<${BLOCK_SIZE_N}x${HEAD_DIM}x!tt.ptr<f16>>
%79 = tt.trans %78 {order = array<i32: 1, 0>} : tensor<${BLOCK_SIZE_N}x${HEAD_DIM}xf16> -> tensor<${HEAD_DIM}x${BLOCK_SIZE_N}xf16>
%80 = tt.dot %45, %79, %cst_3, inputPrecision = tf32 : tensor<${BLOCK_SIZE_M}x${HEAD_DIM}xf16> * tensor<${HEAD_DIM}x${BLOCK_SIZE_N}xf16> -> tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_N}xf16>
%81 = tt.splat %arg28 : i32 -> tensor<${BLOCK_SIZE_N}xi32>
%82 = arith.addi %81, %8 : tensor<${BLOCK_SIZE_N}xi32>
%83 = tt.expand_dims %82 {axis = 0 : i32} : tensor<${BLOCK_SIZE_N}xi32> -> tensor<1x${BLOCK_SIZE_N}xi32>
%84 = tt.broadcast %83 : tensor<1x${BLOCK_SIZE_N}xi32> -> tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_N}xi32>
%85 = arith.cmpi sge, %49, %84 : tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_N}xi32>
%86 = arith.select %85, %cst_0, %cst_2 : tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_N}xi1>, tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_N}xf32>
%87 = arith.extf %80 : tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_N}xf16> to tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_N}xf32>
%88 = arith.addf %87, %86 : tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_N}xf32>
%89 = "tt.reduce"(%88) <{axis = 1 : i32}> ({
^bb0(%arg32: f32 loc(unknown), %arg33: f32 loc(unknown)):
%115 = arith.maxnumf %arg32, %arg33 : f32
tt.reduce.return %115 : f32
}) : (tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_N}xf32>) -> tensor<${BLOCK_SIZE_M}xf32>
%90 = arith.mulf %89, %50 : tensor<${BLOCK_SIZE_M}xf32>
%91 = arith.maxnumf %90, %arg31 : tensor<${BLOCK_SIZE_M}xf32>
%92 = arith.mulf %88, %51 : tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_N}xf32>
%93 = tt.expand_dims %91 {axis = 1 : i32} : tensor<${BLOCK_SIZE_M}xf32> -> tensor<${BLOCK_SIZE_M}x1xf32>
%94 = tt.broadcast %93 : tensor<${BLOCK_SIZE_M}x1xf32> -> tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_N}xf32>
%95 = arith.subf %92, %94 : tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_N}xf32>
%96 = math.exp %95 : tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_N}xf32>
%97 = "tt.reduce"(%96) <{axis = 1 : i32}> ({
^bb0(%arg32: f32 loc(unknown), %arg33: f32 loc(unknown)):
%115 = arith.addf %arg32, %arg33 : f32
tt.reduce.return %115 : f32
}) : (tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_N}xf32>) -> tensor<${BLOCK_SIZE_M}xf32>
%98 = arith.subf %arg30, %91 : tensor<${BLOCK_SIZE_M}xf32>
%99 = math.exp %98 : tensor<${BLOCK_SIZE_M}xf32>
%100 = tt.expand_dims %99 {axis = 1 : i32} : tensor<${BLOCK_SIZE_M}xf32> -> tensor<${BLOCK_SIZE_M}x1xf32>
%101 = arith.truncf %100 : tensor<${BLOCK_SIZE_M}x1xf32> to tensor<${BLOCK_SIZE_M}x1xf16>
%102 = tt.broadcast %101 : tensor<${BLOCK_SIZE_M}x1xf16> -> tensor<${BLOCK_SIZE_M}x${HEAD_DIM}xf16>
%103 = arith.mulf %arg29, %102 : tensor<${BLOCK_SIZE_M}x${HEAD_DIM}xf16>
%104 = arith.muli %arg28, %arg14 : i32
%105 = tt.splat %104 : i32 -> tensor<${BLOCK_SIZE_N}x${HEAD_DIM}xi32>
%106 = tt.addptr %44, %105 : tensor<${BLOCK_SIZE_N}x${HEAD_DIM}x!tt.ptr<f16>>, tensor<${BLOCK_SIZE_N}x${HEAD_DIM}xi32>
%107 = tt.load %106 : tensor<${BLOCK_SIZE_N}x${HEAD_DIM}x!tt.ptr<f16>>
%108 = arith.truncf %96 : tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_N}xf32> to tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_N}xf16>
%109 = tt.dot %108, %107, %103, inputPrecision = tf32 : tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_N}xf16> * tensor<${BLOCK_SIZE_N}x${HEAD_DIM}xf16> -> tensor<${BLOCK_SIZE_M}x${HEAD_DIM}xf16>
%110 = arith.subf %arg31, %91 : tensor<${BLOCK_SIZE_M}xf32>
%111 = math.exp %110 : tensor<${BLOCK_SIZE_M}xf32>
%112 = arith.addf %111, %97 : tensor<${BLOCK_SIZE_M}xf32>
%113 = math.log %112 : tensor<${BLOCK_SIZE_M}xf32>
%114 = arith.addf %91, %113 : tensor<${BLOCK_SIZE_M}xf32>
scf.yield %109, %91, %114 : tensor<${BLOCK_SIZE_M}x${HEAD_DIM}xf16>, tensor<${BLOCK_SIZE_M}xf32>, tensor<${BLOCK_SIZE_M}xf32>
} {tt.divisibility_arg1 = dense<128> : tensor<1xi32>}
%53 = arith.subf %52#1, %52#2 : tensor<${BLOCK_SIZE_M}xf32>
%54 = math.exp %53 : tensor<${BLOCK_SIZE_M}xf32>
%55 = tt.expand_dims %54 {axis = 1 : i32} : tensor<${BLOCK_SIZE_M}xf32> -> tensor<${BLOCK_SIZE_M}x1xf32>
%56 = tt.broadcast %55 : tensor<${BLOCK_SIZE_M}x1xf32> -> tensor<${BLOCK_SIZE_M}x${HEAD_DIM}xf32>
%57 = arith.extf %52#0 : tensor<${BLOCK_SIZE_M}x${HEAD_DIM}xf16> to tensor<${BLOCK_SIZE_M}x${HEAD_DIM}xf32>
%58 = arith.mulf %57, %56 : tensor<${BLOCK_SIZE_M}x${HEAD_DIM}xf32>
%59 = arith.muli %1, %arg24 : i32
%60 = tt.addptr %arg4, %59 : !tt.ptr<f16>, i32
%61 = tt.splat %60 : !tt.ptr<f16> -> tensor<${BLOCK_SIZE_M}x!tt.ptr<f16>>
%62 = tt.addptr %61, %7 : tensor<${BLOCK_SIZE_M}x!tt.ptr<f16>>, tensor<${BLOCK_SIZE_M}xi32>
%63 = arith.truncf %52#2 : tensor<${BLOCK_SIZE_M}xf32> to tensor<${BLOCK_SIZE_M}xf16>
tt.store %62, %63 : tensor<${BLOCK_SIZE_M}x!tt.ptr<f16>>
%64 = arith.muli %2, %arg18 : i32
%65 = tt.addptr %arg3, %64 : !tt.ptr<f16>, i32
%66 = arith.muli %3, %arg19 : i32
%67 = tt.addptr %65, %66 : !tt.ptr<f16>, i32
%68 = tt.splat %arg20 : i32 -> tensor<${BLOCK_SIZE_M}x1xi32>
%69 = arith.muli %14, %68 : tensor<${BLOCK_SIZE_M}x1xi32>
%70 = tt.broadcast %69 : tensor<${BLOCK_SIZE_M}x1xi32> -> tensor<${BLOCK_SIZE_M}x${HEAD_DIM}xi32>
%71 = arith.addi %70, %19 : tensor<${BLOCK_SIZE_M}x${HEAD_DIM}xi32>
%72 = tt.splat %67 : !tt.ptr<f16> -> tensor<${BLOCK_SIZE_M}x${HEAD_DIM}x!tt.ptr<f16>>
%73 = tt.addptr %72, %71 : tensor<${BLOCK_SIZE_M}x${HEAD_DIM}x!tt.ptr<f16>>, tensor<${BLOCK_SIZE_M}x${HEAD_DIM}xi32>
%74 = arith.truncf %58 : tensor<${BLOCK_SIZE_M}x${HEAD_DIM}xf32> to tensor<${BLOCK_SIZE_M}x${HEAD_DIM}xf16>
tt.store %73, %74 : tensor<${BLOCK_SIZE_M}x${HEAD_DIM}x!tt.ptr<f16>>
tt.return
}
}
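
For readers decoding the `scf.for` loop above: it appears to implement the usual online-softmax (flash-attention forward) recurrence, with `%arg29` as the output accumulator, `%arg30` as the running max and `%arg31` as the running log-sum-exp. Read directly off the IR, with α the scale passed in `%arg5`, the per-block update is sketched below.

```latex
\begin{aligned}
S_j            &= Q K_j^{\top} + \text{mask}_j \\
m_j            &= \max\bigl(\ell_{j-1},\; \alpha \cdot \mathrm{rowmax}(S_j)\bigr) \\
P_j            &= \exp\bigl(\alpha S_j - m_j\bigr) \\
\mathrm{acc}_j &= e^{\,m_{j-1}-m_j}\,\mathrm{acc}_{j-1} + P_j V_j \\
\ell_j         &= m_j + \log\bigl(e^{\,\ell_{j-1}-m_j} + \mathrm{rowsum}(P_j)\bigr)
\end{aligned}
```

with m and ℓ initialized to negative infinity and acc to zero; after the loop the kernel rescales the accumulator by exp(m − ℓ) before storing the output, and stores ℓ (truncated to f16) through `%arg4`.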
2 changes: 1 addition & 1 deletion src/main.cpp
@@ -289,7 +289,7 @@ int main(int argc, char *argv[]) {

const std::string &dtype_str = args[2];
std::string value = args[3];
-if (arg_idx <= kernel_arguments.size()) {
+if (arg_idx >= kernel_arguments.size()) {
kernel_arguments.resize(arg_idx + 1);
}
KerArgDtype dtype{};
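
The one-character change above is the "kern arg preprocessor" bug fix from the commit title: with the old `<=`, an in-range index still triggered a resize (which could shrink the vector and drop already-parsed arguments), while an index past the end triggered no resize at all. A minimal sketch of the corrected guard, using hypothetical `KernelArg`/`set_argument` names in place of the real types in `src/main.cpp`:

```cpp
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-ins for the real types in src/main.cpp.
struct KernelArg { std::string dtype; std::string value; };

void set_argument(std::vector<KernelArg> &kernel_arguments,
                  std::size_t arg_idx, KernelArg arg) {
    if (arg_idx >= kernel_arguments.size()) {   // was `<=` before this commit
        kernel_arguments.resize(arg_idx + 1);   // grow only when the index is out of range
    }
    kernel_arguments[arg_idx] = std::move(arg);
}

int main() {
    std::vector<KernelArg> kernel_arguments;
    set_argument(kernel_arguments, 5, {"i32", "1024707898"});   // out-of-order index grows the vector
    set_argument(kernel_arguments, 0, {"ptr", "cuMalloc(...)"}); // in-range index no longer shrinks it
    return kernel_arguments.size() == 6 ? 0 : 1;
}
```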
2 changes: 1 addition & 1 deletion third_party/triton
