From 801d05c71c0855fe0639594298bea02dcb6537bf Mon Sep 17 00:00:00 2001 From: mike Date: Fri, 13 Sep 2024 05:10:04 +0200 Subject: [PATCH] Fixed buid instructions & fixed bug in kern arg preprocessor + added attn fwd autotune test --- CMakeLists.txt | 12 +- README.md | 6 +- attn_fwd_autotune.ttir | 252 +++++++++++++++++++++++++++++++++++++++++ src/main.cpp | 2 +- third_party/triton | 2 +- 5 files changed, 267 insertions(+), 7 deletions(-) create mode 100644 attn_fwd_autotune.ttir diff --git a/CMakeLists.txt b/CMakeLists.txt index 7ec32e1..cb1c1e8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -26,7 +26,6 @@ add_executable(tritonc src/main.cpp src/tinyscript.cpp src/templating.cpp) target_link_libraries(tritonc PRIVATE argparse) -target_link_libraries(tritonc PRIVATE cuda curand) get_property(triton_libs GLOBAL PROPERTY TRITON_LIBS) @@ -111,4 +110,13 @@ set(TRITON_GENERATED_INCLUDE_DIRS ) target_include_directories(tritonc PRIVATE ${TRITON_INCLUDE_DIRS}) -target_include_directories(tritonc PRIVATE ${TRITON_GENERATED_INCLUDE_DIRS}) \ No newline at end of file +target_include_directories(tritonc PRIVATE ${TRITON_GENERATED_INCLUDE_DIRS}) + +# include cuda headers +set(CMAKE_CUDA_ARCHITECTURES 75) +set(CMAKE_CUDA_COMPILER /usr/local/cuda/bin/nvcc) +enable_language(CUDA) +find_library(CUDA_LIBRARY cuda ${CMAKE_CUDA_IMPLICIT_LINK_DIRECTORIES}) +find_library(CURAND_LIBRARY curand ${CMAKE_CUDA_IMPLICIT_LINK_DIRECTORIES}) +target_include_directories(tritonc PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) +target_link_libraries(tritonc PRIVATE ${CUDA_LIBRARY} ${CURAND_LIBRARY}) \ No newline at end of file diff --git a/README.md b/README.md index f7bb5de..9a07bf6 100644 --- a/README.md +++ b/README.md @@ -46,16 +46,16 @@ tritonc add_kernel.ttir --compute-capability 89 --num-stages 3 --num-warps 4 -o #### Build LLVM The required llvm version is a submodule of this repository under `third_party/llvm-project`. - Build-install the llvm-project system-wide. +Make sure to install `ninja-build` for fast builds! ```commandline cd third_party/llvm-project mkdir build cd build -cmake -G Ninja ../llvm -DLLVM_ENABLE_PROJECTS=mlir -DLLVM_BUILD_EXAMPLES=ON -DLLVM_TARGETS_TO_BUILD="Native;NVPTX;AMDGPU" -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=ON +cmake -G Ninja ../llvm -DLLVM_ENABLE_PROJECTS=mlir -DLLVM_BUILD_EXAMPLES=ON -DLLVM_TARGETS_TO_BUILD="Native;NVPTX;AMDGPU" -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=ON -DCMAKE_INSTALL_PREFIX=/usr/local cmake --build . -sudo make install +sudo cmake --build . --target install ``` ##### Hot to build tritonc diff --git a/attn_fwd_autotune.ttir b/attn_fwd_autotune.ttir new file mode 100644 index 0000000..ad5be15 --- /dev/null +++ b/attn_fwd_autotune.ttir @@ -0,0 +1,252 @@ +#pragma autotune BLOCK_SIZE_M {128, 256} default 128 +#pragma autotune BLOCK_SIZE_N {128, 256} default 128 + +#pragma autotune intrinsic num_warps {4, 8} default 4 +#pragma autotune intrinsic num_stages {1, 2, 3, 4, 5, 6, 7, 8} default 2 + +// this is more a compile time constant, but there is no #pragma constant yet... +#pragma autotune HEAD_DIM {64} default 64 + + +// we still have to allocate 3 * per each arg. to simulate split with jumpy stride access patterns in the kernel +// 48 = batch size +// 1024 = seq len +// 12 = num heads +// 64 head dim +// 2 = sizeof(float16) +#pragma argument 0 ptr cuMalloc(3 * 48 * 1024 * 12 * 64 * 2) +#pragma argument 1 ptr cuMalloc(3 * 48 * 1024 * 12 * 64 * 2) +#pragma argument 2 ptr cuMalloc(3 * 48 * 1024 * 12 * 64 * 2) +#pragma argument 3 ptr cuMalloc(48 * 1024 * 12 * 64 * 2) + +#pragma argument 4 ptr cuMalloc(48 * 12 * 1024 * 2) + +// this means 1.0f / sqrt(64*12) when interpreted as float +#pragma argument 5 i32 1024707898 + +// qstride 0, 2, 1 +#pragma argument 6 i32 2359296 +#pragma argument 7 i32 64 +#pragma argument 8 i32 2304 + +// kstride 0, 2, 1 +#pragma argument 9 i32 2359296 +#pragma argument 10 i32 64 +#pragma argument 11 i32 2304 + +// vstride 0, 2, 1 +#pragma argument 12 i32 2359296 +#pragma argument 13 i32 64 +#pragma argument 14 i32 2304 + +// bias strides +#pragma argument 15 i32 0 +#pragma argument 16 i32 0 +#pragma argument 17 i32 0 + +// output strides +#pragma argument 18 i32 786432 +#pragma argument 19 i32 64 +#pragma argument 20 i32 768 + +// nheads +#pragma argument 21 i32 12 + +// seq length q +#pragma argument 22 i32 1024 + +// seq length k +#pragma argument 23 i32 1024 + +// seq length q rounded +#pragma argument 24 i32 1024 + +// n_embed +#pragma argument 25 i32 (12*64) + +// cache key dummy 1 +#pragma argument 26 i32 32 + +// cache key dummy 2 +#pragma argument 27 i32 32 + +#pragma launch attn_fwd_kernel + +// seq len / block_m +#pragma grid x (1024/${BLOCK_SIZE_M}) + +// batch * n_heads +#pragma grid y (48 * 12) + +module { + tt.func public @attn_fwd_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, + %arg1: !tt.ptr {tt.divisibility = 16 : i32}, + %arg2: !tt.ptr {tt.divisibility = 16 : i32}, + %arg3: !tt.ptr {tt.divisibility = 16 : i32}, + %arg4: !tt.ptr {tt.divisibility = 16 : i32}, + %arg5: f32, %arg6: i32 {tt.divisibility = 16 : i32}, + + %arg7: i32 {tt.divisibility = 16 : i32}, + %arg8: i32 {tt.divisibility = 16 : i32}, + %arg9: i32 {tt.divisibility = 16 : i32}, + + %arg10: i32 {tt.divisibility = 16 : i32}, + %arg11: i32 {tt.divisibility = 16 : i32}, + %arg12: i32 {tt.divisibility = 16 : i32}, + + %arg13: i32 {tt.divisibility = 16 : i32}, + %arg14: i32 {tt.divisibility = 16 : i32}, + %arg15: i32 {tt.divisibility = 16 : i32}, + + %arg16: i32 {tt.divisibility = 16 : i32}, + %arg17: i32 {tt.divisibility = 16 : i32}, + %arg18: i32 {tt.divisibility = 16 : i32}, + + %arg19: i32 {tt.divisibility = 16 : i32}, + %arg20: i32 {tt.divisibility = 16 : i32}, + %arg21: i32, + %arg22: i32 {tt.divisibility = 16 : i32}, + %arg23: i32 {tt.divisibility = 16 : i32}, + %arg24: i32 {tt.divisibility = 16 : i32}, + %arg25: i32 {tt.divisibility = 16 : i32}, + %arg26: i32 {tt.divisibility = 16 : i32}, + %arg27: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<0xFF800000> : tensor<${BLOCK_SIZE_M}xf32> + %cst_0 = arith.constant dense<0.000000e+00> : tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_N}xf32> + %c128_i32 = arith.constant ${BLOCK_SIZE_N} : i32 + %c0_i32 = arith.constant 0 : i32 + %cst_1 = arith.constant dense<0.000000e+00> : tensor<${BLOCK_SIZE_M}x${HEAD_DIM}xf16> + %cst_2 = arith.constant dense<0xFF800000> : tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_N}xf32> + %cst_3 = arith.constant dense<0.000000e+00> : tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_N}xf16> + %c1_i32 = arith.constant 1 : i32 + %c256_i32 = arith.constant ${BLOCK_SIZE_M} : i32 + %0 = tt.get_program_id x : i32 + %1 = tt.get_program_id y : i32 + %2 = arith.divsi %1, %arg21 : i32 + %3 = arith.remsi %1, %arg21 : i32 + %4 = arith.muli %0, %c256_i32 : i32 + %5 = tt.make_range {end = ${BLOCK_SIZE_M} : i32, start = 0 : i32} : tensor<${BLOCK_SIZE_M}xi32> + %6 = tt.splat %4 : i32 -> tensor<${BLOCK_SIZE_M}xi32> + %7 = arith.addi %6, %5 : tensor<${BLOCK_SIZE_M}xi32> + %8 = tt.make_range {end = ${BLOCK_SIZE_N} : i32, start = 0 : i32} : tensor<${BLOCK_SIZE_N}xi32> + %9 = tt.make_range {end = ${HEAD_DIM} : i32, start = 0 : i32} : tensor<${HEAD_DIM}xi32> + %10 = arith.muli %2, %arg6 : i32 + %11 = tt.addptr %arg0, %10 : !tt.ptr, i32 + %12 = arith.muli %3, %arg7 : i32 + %13 = tt.addptr %11, %12 : !tt.ptr, i32 + %14 = tt.expand_dims %7 {axis = 1 : i32} : tensor<${BLOCK_SIZE_M}xi32> -> tensor<${BLOCK_SIZE_M}x1xi32> + %15 = tt.splat %arg8 : i32 -> tensor<${BLOCK_SIZE_M}x1xi32> + %16 = arith.muli %14, %15 : tensor<${BLOCK_SIZE_M}x1xi32> + %17 = tt.expand_dims %9 {axis = 0 : i32} : tensor<${HEAD_DIM}xi32> -> tensor<1x${HEAD_DIM}xi32> + %18 = tt.broadcast %16 : tensor<${BLOCK_SIZE_M}x1xi32> -> tensor<${BLOCK_SIZE_M}x${HEAD_DIM}xi32> + %19 = tt.broadcast %17 : tensor<1x${HEAD_DIM}xi32> -> tensor<${BLOCK_SIZE_M}x${HEAD_DIM}xi32> + %20 = arith.addi %18, %19 : tensor<${BLOCK_SIZE_M}x${HEAD_DIM}xi32> + %21 = tt.splat %13 : !tt.ptr -> tensor<${BLOCK_SIZE_M}x${HEAD_DIM}x!tt.ptr> + %22 = tt.addptr %21, %20 : tensor<${BLOCK_SIZE_M}x${HEAD_DIM}x!tt.ptr>, tensor<${BLOCK_SIZE_M}x${HEAD_DIM}xi32> + %23 = arith.muli %2, %arg9 : i32 + %24 = tt.addptr %arg1, %23 : !tt.ptr, i32 + %25 = arith.muli %3, %arg10 : i32 + %26 = tt.addptr %24, %25 : !tt.ptr, i32 + %27 = tt.expand_dims %8 {axis = 1 : i32} : tensor<${BLOCK_SIZE_N}xi32> -> tensor<${BLOCK_SIZE_N}x1xi32> + %28 = tt.splat %arg11 : i32 -> tensor<${BLOCK_SIZE_N}x1xi32> + %29 = arith.muli %27, %28 : tensor<${BLOCK_SIZE_N}x1xi32> + %30 = tt.broadcast %29 : tensor<${BLOCK_SIZE_N}x1xi32> -> tensor<${BLOCK_SIZE_N}x${HEAD_DIM}xi32> + %31 = tt.broadcast %17 : tensor<1x${HEAD_DIM}xi32> -> tensor<${BLOCK_SIZE_N}x${HEAD_DIM}xi32> + %32 = arith.addi %30, %31 : tensor<${BLOCK_SIZE_N}x${HEAD_DIM}xi32> + %33 = tt.splat %26 : !tt.ptr -> tensor<${BLOCK_SIZE_N}x${HEAD_DIM}x!tt.ptr> + %34 = tt.addptr %33, %32 : tensor<${BLOCK_SIZE_N}x${HEAD_DIM}x!tt.ptr>, tensor<${BLOCK_SIZE_N}x${HEAD_DIM}xi32> + %35 = arith.muli %2, %arg12 : i32 + %36 = tt.addptr %arg2, %35 : !tt.ptr, i32 + %37 = arith.muli %3, %arg13 : i32 + %38 = tt.addptr %36, %37 : !tt.ptr, i32 + %39 = tt.splat %arg14 : i32 -> tensor<${BLOCK_SIZE_N}x1xi32> + %40 = arith.muli %27, %39 : tensor<${BLOCK_SIZE_N}x1xi32> + %41 = tt.broadcast %40 : tensor<${BLOCK_SIZE_N}x1xi32> -> tensor<${BLOCK_SIZE_N}x${HEAD_DIM}xi32> + %42 = arith.addi %41, %31 : tensor<${BLOCK_SIZE_N}x${HEAD_DIM}xi32> + %43 = tt.splat %38 : !tt.ptr -> tensor<${BLOCK_SIZE_N}x${HEAD_DIM}x!tt.ptr> + %44 = tt.addptr %43, %42 : tensor<${BLOCK_SIZE_N}x${HEAD_DIM}x!tt.ptr>, tensor<${BLOCK_SIZE_N}x${HEAD_DIM}xi32> + %45 = tt.load %22 : tensor<${BLOCK_SIZE_M}x${HEAD_DIM}x!tt.ptr> + %46 = arith.addi %0, %c1_i32 : i32 + %47 = arith.muli %46, %c256_i32 : i32 + %48 = arith.minsi %47, %arg23 : i32 + %49 = tt.broadcast %14 : tensor<${BLOCK_SIZE_M}x1xi32> -> tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_N}xi32> + %50 = tt.splat %arg5 : f32 -> tensor<${BLOCK_SIZE_M}xf32> + %51 = tt.splat %arg5 : f32 -> tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_N}xf32> + %52:3 = scf.for %arg28 = %c0_i32 to %48 step %c128_i32 iter_args(%arg29 = %cst_1, %arg30 = %cst, %arg31 = %cst) -> (tensor<${BLOCK_SIZE_M}x${HEAD_DIM}xf16>, tensor<${BLOCK_SIZE_M}xf32>, tensor<${BLOCK_SIZE_M}xf32>) : i32 { + %75 = arith.muli %arg28, %arg11 : i32 + %76 = tt.splat %75 : i32 -> tensor<${BLOCK_SIZE_N}x${HEAD_DIM}xi32> + %77 = tt.addptr %34, %76 : tensor<${BLOCK_SIZE_N}x${HEAD_DIM}x!tt.ptr>, tensor<${BLOCK_SIZE_N}x${HEAD_DIM}xi32> + %78 = tt.load %77 : tensor<${BLOCK_SIZE_N}x${HEAD_DIM}x!tt.ptr> + %79 = tt.trans %78 {order = array} : tensor<${BLOCK_SIZE_N}x${HEAD_DIM}xf16> -> tensor<${HEAD_DIM}x${BLOCK_SIZE_N}xf16> + %80 = tt.dot %45, %79, %cst_3, inputPrecision = tf32 : tensor<${BLOCK_SIZE_M}x${HEAD_DIM}xf16> * tensor<${HEAD_DIM}x${BLOCK_SIZE_N}xf16> -> tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_N}xf16> + %81 = tt.splat %arg28 : i32 -> tensor<${BLOCK_SIZE_N}xi32> + %82 = arith.addi %81, %8 : tensor<${BLOCK_SIZE_N}xi32> + %83 = tt.expand_dims %82 {axis = 0 : i32} : tensor<${BLOCK_SIZE_N}xi32> -> tensor<1x${BLOCK_SIZE_N}xi32> + %84 = tt.broadcast %83 : tensor<1x${BLOCK_SIZE_N}xi32> -> tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_N}xi32> + %85 = arith.cmpi sge, %49, %84 : tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_N}xi32> + %86 = arith.select %85, %cst_0, %cst_2 : tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_N}xi1>, tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_N}xf32> + %87 = arith.extf %80 : tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_N}xf16> to tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_N}xf32> + %88 = arith.addf %87, %86 : tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_N}xf32> + %89 = "tt.reduce"(%88) <{axis = 1 : i32}> ({ + ^bb0(%arg32: f32 loc(unknown), %arg33: f32 loc(unknown)): + %115 = arith.maxnumf %arg32, %arg33 : f32 + tt.reduce.return %115 : f32 + }) : (tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_N}xf32>) -> tensor<${BLOCK_SIZE_M}xf32> + %90 = arith.mulf %89, %50 : tensor<${BLOCK_SIZE_M}xf32> + %91 = arith.maxnumf %90, %arg31 : tensor<${BLOCK_SIZE_M}xf32> + %92 = arith.mulf %88, %51 : tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_N}xf32> + %93 = tt.expand_dims %91 {axis = 1 : i32} : tensor<${BLOCK_SIZE_M}xf32> -> tensor<${BLOCK_SIZE_M}x1xf32> + %94 = tt.broadcast %93 : tensor<${BLOCK_SIZE_M}x1xf32> -> tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_N}xf32> + %95 = arith.subf %92, %94 : tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_N}xf32> + %96 = math.exp %95 : tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_N}xf32> + %97 = "tt.reduce"(%96) <{axis = 1 : i32}> ({ + ^bb0(%arg32: f32 loc(unknown), %arg33: f32 loc(unknown)): + %115 = arith.addf %arg32, %arg33 : f32 + tt.reduce.return %115 : f32 + }) : (tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_N}xf32>) -> tensor<${BLOCK_SIZE_M}xf32> + %98 = arith.subf %arg30, %91 : tensor<${BLOCK_SIZE_M}xf32> + %99 = math.exp %98 : tensor<${BLOCK_SIZE_M}xf32> + %100 = tt.expand_dims %99 {axis = 1 : i32} : tensor<${BLOCK_SIZE_M}xf32> -> tensor<${BLOCK_SIZE_M}x1xf32> + %101 = arith.truncf %100 : tensor<${BLOCK_SIZE_M}x1xf32> to tensor<${BLOCK_SIZE_M}x1xf16> + %102 = tt.broadcast %101 : tensor<${BLOCK_SIZE_M}x1xf16> -> tensor<${BLOCK_SIZE_M}x${HEAD_DIM}xf16> + %103 = arith.mulf %arg29, %102 : tensor<${BLOCK_SIZE_M}x${HEAD_DIM}xf16> + %104 = arith.muli %arg28, %arg14 : i32 + %105 = tt.splat %104 : i32 -> tensor<${BLOCK_SIZE_N}x${HEAD_DIM}xi32> + %106 = tt.addptr %44, %105 : tensor<${BLOCK_SIZE_N}x${HEAD_DIM}x!tt.ptr>, tensor<${BLOCK_SIZE_N}x${HEAD_DIM}xi32> + %107 = tt.load %106 : tensor<${BLOCK_SIZE_N}x${HEAD_DIM}x!tt.ptr> + %108 = arith.truncf %96 : tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_N}xf32> to tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_N}xf16> + %109 = tt.dot %108, %107, %103, inputPrecision = tf32 : tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_N}xf16> * tensor<${BLOCK_SIZE_N}x${HEAD_DIM}xf16> -> tensor<${BLOCK_SIZE_M}x${HEAD_DIM}xf16> + %110 = arith.subf %arg31, %91 : tensor<${BLOCK_SIZE_M}xf32> + %111 = math.exp %110 : tensor<${BLOCK_SIZE_M}xf32> + %112 = arith.addf %111, %97 : tensor<${BLOCK_SIZE_M}xf32> + %113 = math.log %112 : tensor<${BLOCK_SIZE_M}xf32> + %114 = arith.addf %91, %113 : tensor<${BLOCK_SIZE_M}xf32> + scf.yield %109, %91, %114 : tensor<${BLOCK_SIZE_M}x${HEAD_DIM}xf16>, tensor<${BLOCK_SIZE_M}xf32>, tensor<${BLOCK_SIZE_M}xf32> + } {tt.divisibility_arg1 = dense<128> : tensor<1xi32>} + %53 = arith.subf %52#1, %52#2 : tensor<${BLOCK_SIZE_M}xf32> + %54 = math.exp %53 : tensor<${BLOCK_SIZE_M}xf32> + %55 = tt.expand_dims %54 {axis = 1 : i32} : tensor<${BLOCK_SIZE_M}xf32> -> tensor<${BLOCK_SIZE_M}x1xf32> + %56 = tt.broadcast %55 : tensor<${BLOCK_SIZE_M}x1xf32> -> tensor<${BLOCK_SIZE_M}x${HEAD_DIM}xf32> + %57 = arith.extf %52#0 : tensor<${BLOCK_SIZE_M}x${HEAD_DIM}xf16> to tensor<${BLOCK_SIZE_M}x${HEAD_DIM}xf32> + %58 = arith.mulf %57, %56 : tensor<${BLOCK_SIZE_M}x${HEAD_DIM}xf32> + %59 = arith.muli %1, %arg24 : i32 + %60 = tt.addptr %arg4, %59 : !tt.ptr, i32 + %61 = tt.splat %60 : !tt.ptr -> tensor<${BLOCK_SIZE_M}x!tt.ptr> + %62 = tt.addptr %61, %7 : tensor<${BLOCK_SIZE_M}x!tt.ptr>, tensor<${BLOCK_SIZE_M}xi32> + %63 = arith.truncf %52#2 : tensor<${BLOCK_SIZE_M}xf32> to tensor<${BLOCK_SIZE_M}xf16> + tt.store %62, %63 : tensor<${BLOCK_SIZE_M}x!tt.ptr> + %64 = arith.muli %2, %arg18 : i32 + %65 = tt.addptr %arg3, %64 : !tt.ptr, i32 + %66 = arith.muli %3, %arg19 : i32 + %67 = tt.addptr %65, %66 : !tt.ptr, i32 + %68 = tt.splat %arg20 : i32 -> tensor<${BLOCK_SIZE_M}x1xi32> + %69 = arith.muli %14, %68 : tensor<${BLOCK_SIZE_M}x1xi32> + %70 = tt.broadcast %69 : tensor<${BLOCK_SIZE_M}x1xi32> -> tensor<${BLOCK_SIZE_M}x${HEAD_DIM}xi32> + %71 = arith.addi %70, %19 : tensor<${BLOCK_SIZE_M}x${HEAD_DIM}xi32> + %72 = tt.splat %67 : !tt.ptr -> tensor<${BLOCK_SIZE_M}x${HEAD_DIM}x!tt.ptr> + %73 = tt.addptr %72, %71 : tensor<${BLOCK_SIZE_M}x${HEAD_DIM}x!tt.ptr>, tensor<${BLOCK_SIZE_M}x${HEAD_DIM}xi32> + %74 = arith.truncf %58 : tensor<${BLOCK_SIZE_M}x${HEAD_DIM}xf32> to tensor<${BLOCK_SIZE_M}x${HEAD_DIM}xf16> + tt.store %73, %74 : tensor<${BLOCK_SIZE_M}x${HEAD_DIM}x!tt.ptr> + tt.return + } +} \ No newline at end of file diff --git a/src/main.cpp b/src/main.cpp index 8ad620b..683f12b 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -289,7 +289,7 @@ int main(int argc, char *argv[]) { const std::string &dtype_str = args[2]; std::string value = args[3]; - if (arg_idx <= kernel_arguments.size()) { + if (arg_idx >= kernel_arguments.size()) { kernel_arguments.resize(arg_idx + 1); } KerArgDtype dtype{}; diff --git a/third_party/triton b/third_party/triton index 7697cbe..df26ec6 160000 --- a/third_party/triton +++ b/third_party/triton @@ -1 +1 @@ -Subproject commit 7697cbe71cfb48e0320a59a2c6454be061d1a50a +Subproject commit df26ec64aedd222af7f1988006a2b709d9e84f07