Skip to content

Commit

Permalink
Add new tensor operation
Browse files Browse the repository at this point in the history
  • Loading branch information
Christoph Gerum committed Feb 14, 2024
2 parents ecb7f46 + 2b813ec commit 3136140
Show file tree
Hide file tree
Showing 1,329 changed files with 290,072 additions and 9,948 deletions.
2 changes: 1 addition & 1 deletion .asf.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ github:
main:
required_status_checks:
contexts:
# Require passing runs from Jenkins for all platforms
- unity/pr-head
- arm/pr-head
- cortexm/pr-head
- cpu/pr-head
Expand Down
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,6 @@
[submodule "3rdparty/libflash_attn"]
path = 3rdparty/libflash_attn
url = https://github.com/tlc-pack/libflash_attn
[submodule "3rdparty/flashinfer"]
path = 3rdparty/flashinfer
url = https://github.com/flashinfer-ai/flashinfer.git
2 changes: 1 addition & 1 deletion 3rdparty/cutlass
Submodule cutlass updated 1946 files
2 changes: 1 addition & 1 deletion 3rdparty/cutlass_fpA_intB_gemm
Submodule cutlass_fpA_intB_gemm updated 74 files
+78 −0 .clang-format
+6 −0 CMakeLists.txt
+66 −0 cmake/utils/Utils.cmake
+78 −4 cutlass_extensions/include/cutlass_extensions/arch/mma.h
+23 −12 cutlass_extensions/include/cutlass_extensions/compute_occupancy.h
+14 −10 cutlass_extensions/include/cutlass_extensions/epilogue/thread/fused_activations.h
+0 −390 cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h
+65 −55 cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h
+67 −48 cutlass_extensions/include/cutlass_extensions/epilogue_helpers.h
+57 −42 cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h
+253 −176 cutlass_extensions/include/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h
+501 −0 cutlass_extensions/include/cutlass_extensions/gemm/kernel/fpA_intB_gemm_with_broadcast.h
+86 −0 cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h
+58 −33 cutlass_extensions/include/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h
+534 −0 cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h
+358 −0 cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_problem_visitor.h
+44 −25 cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma.h
+148 −197 cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h
+106 −172 cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h
+71 −207 cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma.h
+89 −263 cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma_bf16.h
+56 −35 cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_base.h
+16 −505 cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h
+691 −0 cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h
+636 −0 cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage_percol.h
+81 −69 cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h
+19 −39 cutlass_extensions/include/cutlass_extensions/gemm/warp/default_mma_tensor_op.h
+57 −69 cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h
+274 −98 cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h
+24 −10 cutlass_extensions/include/cutlass_extensions/gemm_configs.h
+73 −55 cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
+16 −11 cutlass_extensions/include/cutlass_extensions/tile_interleaved_layout.h
+248 −0 cutlass_extensions/include/cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h
+23 −13 cutlass_extensions/include/cutlass_extensions/weight_only_quant_op.h
+19 −2 cutlass_kernels/CMakeLists.txt
+79 −73 cutlass_kernels/cutlass_heuristic.cc
+12 −14 cutlass_kernels/cutlass_heuristic.h
+517 −403 cutlass_kernels/cutlass_preprocessors.cc
+5 −4 cutlass_kernels/cutlass_preprocessors.h
+0 −33 cutlass_kernels/fpA_intB_gemm.cu
+21 −17 cutlass_kernels/fpA_intB_gemm.h
+30 −92 cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h
+26 −0 cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_finegrained.cu
+0 −21 cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_fp16_int8.cu
+117 −0 cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_impl.h
+26 −0 cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_per_col.cu
+544 −389 cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h
+68 −0 cutlass_kernels/moe_gemm/moe_gemm_kernels.h
+32 −0 cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_fp16.cu
+32 −0 cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint4.cu
+32 −0 cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint8.cu
+561 −0 cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h
+179 −0 cutlass_kernels/moe_gemm/moe_gemv_kernels.cu
+75 −0 cutlass_kernels/moe_gemm/moe_gemv_kernels.h
+36 −0 tvm_binding/CMakeLists.txt
+140 −0 tvm_binding/tvm_binding.cu
+20 −4 utils/activation_types.h
+17 −12 utils/cuda_utils.h
+30 −22 utils/logger.h
+13 −11 utils/string_utils.h
+84 −0 weightOnlyBatchedGemv/common.h
+91 −0 weightOnlyBatchedGemv/enabled.h
+440 −0 weightOnlyBatchedGemv/kernel.h
+224 −0 weightOnlyBatchedGemv/kernelLauncher.cu
+11 −5 weightOnlyBatchedGemv/kernelLauncher.h
+99 −0 weightOnlyBatchedGemv/utility.h
+98 −0 weightOnlyBatchedGemv/weightOnlyBatchedGemvBs1Int4b.cu
+98 −0 weightOnlyBatchedGemv/weightOnlyBatchedGemvBs1Int8b.cu
+97 −0 weightOnlyBatchedGemv/weightOnlyBatchedGemvBs2Int4b.cu
+97 −0 weightOnlyBatchedGemv/weightOnlyBatchedGemvBs2Int8b.cu
+98 −0 weightOnlyBatchedGemv/weightOnlyBatchedGemvBs3Int4b.cu
+98 −0 weightOnlyBatchedGemv/weightOnlyBatchedGemvBs3Int8b.cu
+97 −0 weightOnlyBatchedGemv/weightOnlyBatchedGemvBs4Int4b.cu
+98 −0 weightOnlyBatchedGemv/weightOnlyBatchedGemvBs4Int8b.cu
1 change: 1 addition & 0 deletions 3rdparty/flashinfer
Submodule flashinfer added at 47686e
2 changes: 1 addition & 1 deletion 3rdparty/libflash_attn
5 changes: 4 additions & 1 deletion 3rdparty/picojson/picojson.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@
* POSSIBILITY OF SUCH DAMAGE.
*/
#pragma once
#ifndef PICOJSON_USE_INT64
#define PICOJSON_USE_INT64
#define __STDC_FORMAT_MACROS 1
#endif

#include <algorithm>
#include <cstddef>
Expand Down Expand Up @@ -76,7 +80,6 @@ extern "C" {

// experimental support for int64_t (see README.mkdn for detail)
#ifdef PICOJSON_USE_INT64
#define __STDC_FORMAT_MACROS
#include <errno.h>
#include <inttypes.h>
#endif
Expand Down
57 changes: 54 additions & 3 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ tvm_option(USE_MKL "MKL root path when use MKL blas" OFF)
tvm_option(USE_DNNL "Enable DNNL codegen" OFF)
tvm_option(USE_CUDNN "Build with cuDNN" OFF)
tvm_option(USE_CUBLAS "Build with cuBLAS" OFF)
tvm_option(USE_NVTX "Build with NVTX" OFF)
tvm_option(USE_CUTLASS "Build with CUTLASS" OFF)
tvm_option(USE_THRUST "Build with Thrust" OFF)
tvm_option(USE_CURAND "Build with cuRAND" OFF)
Expand Down Expand Up @@ -126,6 +127,7 @@ tvm_option(USE_CLML "Build with CLML Codegen support" OFF)
tvm_option(USE_CLML_GRAPH_EXECUTOR "Build with CLML graph runtime" OFF)
tvm_option(USE_UMA "Build with UMA support" OFF)
tvm_option(USE_VERILATOR "Build with Verilator support" OFF)
tvm_option(USE_MSC "Enable Multi-System Compiler" OFF)

# include directories
include_directories(${CMAKE_INCLUDE_PATH})
Expand Down Expand Up @@ -303,6 +305,18 @@ tvm_file_glob(GLOB_RECURSE COMPILER_SRCS
src/driver/*.cc
src/support/*.cc
src/script/*.cc
src/relax/ir/*.cc
src/relax/op/*.cc
src/relax/analysis/*.cc
src/relax/transform/*.cc
src/relax/backend/vm/*.cc
src/relax/backend/task_extraction.cc
src/relax/backend/pattern_registry.cc
src/relax/utils.cc
src/relax/distributed/*.cc
src/relax/distributed/transform/*.cc
src/relax/op/distributed/*.cc
src/relax/testing/*.cc
)

tvm_file_glob(GLOB CODEGEN_SRCS
Expand Down Expand Up @@ -351,6 +365,7 @@ tvm_file_glob(GLOB RUNTIME_SRCS
src/runtime/memory/*.cc
src/runtime/disco/*.cc
src/runtime/minrpc/*.cc
src/runtime/relax_vm/*.cc
)

if(BUILD_FOR_HEXAGON)
Expand Down Expand Up @@ -437,13 +452,15 @@ if(USE_CUDA AND USE_NCCL)
message(STATUS "Build with NCCL...")
find_nccl(${USE_NCCL})
tvm_file_glob(GLOB RUNTIME_NCCL_SRC src/runtime/disco/nccl/*.cc)
set_source_files_properties(src/runtime/disco/nccl/nccl.cc PROPERTIES COMPILE_DEFINITIONS "TVM_NCCL_RCCL_SWITCH=0")
list(APPEND RUNTIME_SRCS ${RUNTIME_NCCL_SRC})
endif()

if(USE_ROCM AND USE_RCCL)
message(STATUS "Build with RCCL...")
find_rccl(${USE_RCCL})
tvm_file_glob(GLOB RUNTIME_RCCL_SRC src/runtime/disco/rccl/*.cc)
tvm_file_glob(GLOB RUNTIME_RCCL_SRC src/runtime/disco/nccl/*.cc)
set_source_files_properties(src/runtime/disco/nccl/nccl.cc PROPERTIES COMPILE_DEFINITIONS "TVM_NCCL_RCCL_SWITCH=1")
list(APPEND RUNTIME_SRCS ${RUNTIME_RCCL_SRC})
endif()

Expand Down Expand Up @@ -559,6 +576,8 @@ include(cmake/modules/contrib/TensorRT.cmake)
include(cmake/modules/contrib/VitisAI.cmake)
include(cmake/modules/contrib/Verilator.cmake)
include(cmake/modules/contrib/UMA.cmake)
include(cmake/modules/contrib/MSC.cmake)
include(cmake/modules/contrib/vllm.cmake)
include(cmake/modules/Git.cmake)
include(cmake/modules/LibInfo.cmake)
include(cmake/modules/RustExt.cmake)
Expand Down Expand Up @@ -873,14 +892,46 @@ if(USE_CUDA AND USE_CUTLASS)
install(TARGETS fpA_intB_gemm EXPORT ${PROJECT_NAME}Targets DESTINATION lib${LIB_SUFFIX})
target_link_libraries(tvm PRIVATE fpA_intB_gemm)
target_link_libraries(tvm_runtime PRIVATE fpA_intB_gemm)
target_link_libraries(tvm PRIVATE fpA_intB_gemm_tvm)
target_link_libraries(tvm_runtime PRIVATE fpA_intB_gemm_tvm)

install(TARGETS flash_attn EXPORT ${PROJECT_NAME}Targets DESTINATION lib${LIB_SUFFIX})
target_link_libraries(tvm PRIVATE -Wl,--no-as-needed flash_attn)
target_link_libraries(tvm_runtime PRIVATE -Wl,--no-as-needed flash_attn)
endif()

if(USE_CUDA AND USE_NVTX)
set_source_files_properties(src/runtime/nvtx.cc PROPERTIES COMPILE_DEFINITIONS "TVM_NVTX_ENABLED=1")
endif()

if(USE_CUDA AND USE_NCCL)
target_link_libraries(tvm PRIVATE nccl)
target_link_libraries(tvm_runtime PRIVATE nccl)
find_library(LIBRT rt)
target_link_libraries(tvm PRIVATE nccl ${LIBRT})
target_link_libraries(tvm_runtime PRIVATE nccl ${LIBRT})
endif()

if(USE_ROCM AND USE_RCCL)
target_link_libraries(tvm PRIVATE rccl)
target_link_libraries(tvm_runtime PRIVATE rccl)
endif()


option(USE_FLASHINFER "Build TVM with FlashInfer" OFF)
if (USE_FLASHINFER STREQUAL "ON")
message(STATUS "Build with FlashInfer")
set(FLASHINFER_TVM_BINDING ON)
set(FLASHINFER_TVM_HOME ${PROJECT_SOURCE_DIR})
set(FLASHINFER_ENABLE_FP8 OFF)
set(FLASHINFER_PREFILL OFF)
set(FLASHINFER_DECODE OFF)
set(FLASHINFER_PAGE OFF)
add_subdirectory(3rdparty/flashinfer)
else ()
message(STATUS "Build without FlashInfer")
endif ()


if (USE_FLASHINFER STREQUAL "ON")
target_link_libraries(tvm PRIVATE flashinfer_tvm)
target_link_libraries(tvm_runtime PRIVATE flashinfer_tvm)
endif ()
59 changes: 59 additions & 0 deletions KEYS
Original file line number Diff line number Diff line change
Expand Up @@ -642,3 +642,62 @@ WTgrESErlqNLN5ZTTW/1jBELJCfJKxgHUip+Yo6qNZoWwNLP1BaIcoA3miSG3DXf
wS/UuN04NxDy7V6mPXE=
=MTba
-----END PGP PUBLIC KEY BLOCK-----
pub rsa4096 2024-01-15 [SC]
A4D9228E55761E665BF01CBB5CE869CB7DEC048C
uid [ultimate] Star Yuan (CODE SIGNING KEY) <[email protected]>
sig 3 5CE869CB7DEC048C 2024-01-15 Star Yuan (CODE SIGNING KEY) <[email protected]>
sub rsa4096 2024-01-15 [E]
sig 5CE869CB7DEC048C 2024-01-15 Star Yuan (CODE SIGNING KEY) <[email protected]>

-----BEGIN PGP PUBLIC KEY BLOCK-----

mQINBGWlStcBEADaslyfbUNARhWftJoRAChoak0cFU6NxahhvyZfyTGtSuwuHNDD
2eyvhnDIaYXVClxoNgikiQ5Nkd1jtbA4rFCw6Pdbq+98fkpcr8N4o+jlbpu6Ff3j
dJ2Qu000MV5qe9FZ4QasdfglJElvizgfNbJv/Fz1ERl/BS1U0c7lyQF9jGGh7EY2
1y+JFp5OMG6A9SpfaOd+iOw5/cfCQk8+sHQC4dp3hOJPK4NLvjotK+hlOhRsF7gU
goYYT2IP56kPQb6U/Uiv4/R6HbKugzqSMl6BMwAb9uG6UX0xUfAA8ciHoaITCJCQ
9e/jGWnDnqYlAMNqLkHEmW7THxJ3hHXcac/Z1C3PeLDJU0rpTxDcjuYkM5jFCu7H
TgT7lWBP/PyAAVSsLqMQbLJOWm0a14tb/oRoeYr/B2prIbJY5qJBM1nherKGMg0G
7Oqugo6A1VqgUxg7Chj73PledaNwvm5Lxpl6D+wPDSifhlz0vnwOCMoOon0pTjK4
DXDEXnEXZtzkZgXI6g7AkVyt0gkqyUi+01ibmlBfcVHh3PVvU4oNdkaywQd5s29R
DsA4WOqt9cLv+iqIzM1juygfR6ooA1jHDIyIPmmC/kOrcxKXEFvIGXDDCbXAvdXc
uXgZeZqI3pbKjQaU3fF8HwJ956HTM8rywtVGH9BWRl/i6qn5sq9CcukcuQARAQAB
tDBTdGFyIFl1YW4gKENPREUgU0lHTklORyBLRVkpIDx5c2gzMjlAYXBhY2hlLm9y
Zz6JAk4EEwEKADgWIQSk2SKOVXYeZlvwHLtc6GnLfewEjAUCZaVK1wIbAwULCQgH
AgYVCgkICwIEFgIDAQIeAQIXgAAKCRBc6GnLfewEjBAiD/0cfaYfQ0DL7CPsP0lS
yezPDDTnDPIo//G1cuSYG0gnXQ1SpbJSzDE7deew+P506/sWFneOY5Kuv6DuSE8J
nM6vv1EYR4/9x/XstA4F04lQPngKKBV+UKrWj8zIA2Drn345Ece1150bWvrUD7mT
+ps1gfe8SGYpOmR/kRc8qra2zizcWBC1Dl4qd+RcY7Ac6Cu3G/JG2KvZnrUSVev9
nzSl2V0JtFVIla2odSJqv0Zdj5E2vLvQd3Dxbf3BODCdL3iQqxrQhj+0T3QLEhPg
y2XOtqW7a96XosoQ44wUiHaS5LwFViG8LoiPADtSdXYb8m4FtMfB8t4mzXVqBjpz
2csMqOnNvo7bctfpJkjM14UKib39MR2wUv9fD6Qa+OAAIeXGTQH+wlXmlYjji9+A
4tgq/+d75qUC/tyHSgbZLNXobHF8v77g60cBvFXVL02W53xhVDZP4gwu5iSSN8BJ
a2hqwo4UO53mRUNkwFZONYxJE7MhLl22r08eu0xNYhoGtpHzDVoyHg26+2FUgFDd
TNsdqjMyJ+3GXEE3PdKVDTj9To+RoHLuCczk5uvtFYGhseRwIWbVhmTLKUL+wgSa
+b90slkv+CBJvLjvKbVCmCLXwiH8Cx+MZSu0oM5v8fbHuWOhkb7bJd1V+U7qV/OA
CCqBICt64F+ooQ0oEdC0oLvr2LkCDQRlpUrXARAA1DKsF2ZNUdPIn4VcsjRk/+qF
13VC9SaqMp+J+8m1XTIeXdr27uUa2vT4j8pAM4gwMVkpEqE0rmHK+S1SeEAlcizC
Bvp7vvso/glcOg9Sgt9PXvvEDPL/Hnsn1+3YX+Gye4cOTiDDgVW1RKcgGj9Xsir+
5BS9Secj5CGo92cuaqIo/mMjxGlsuW/LvTU5qQhz7aOaBibe5EHPlGMqM6XJN0BZ
MHRfBiGDs2n/egMnTPL0JcTlAeird+yxDPULKzhQWkd8rfQKpwcRiY6IcYFHlWdM
VhZkXNRrxh6+q3rR7FKmxlvG/12YyT6Y1BocGLgROzKIeoEp+6vsU5LJ90jy82ig
oGSHwNjm2RRukjV3eebovl1dCo6IaI/j4idCv7NlcBnln/Unk4YOZbneMT5r+3Zy
Q4azLB8KHfHOrUwAxRAGPygdLtqbjs4mF45HDe6h3IOVoiOQlZNpesrwEumlK+Il
taU0T8hfxyMpIcTLUZpIddSxo0sVby2XZ+z00En3JvtqbpRcfA87thxpsE7uHxwT
YT8mPPDxo1R4I4LSzsDnekD8EB/7woz4n5I1RBoPB1LSoo0B2os+4vHGkiwZ0TN0
ICcUYdM623Bv2wJQbVKEDvwjHZTkotjLx7R2lyqMRwFYrMXHxevOfbARJQCqrcY2
ouLzQme9rE5MPQbKj2cAEQEAAYkCNgQYAQoAIBYhBKTZIo5Vdh5mW/Acu1zoact9
7ASMBQJlpUrXAhsMAAoJEFzoact97ASMNsIP/3tlsvwUVfy19lUjxWT4rPw2GGz8
lbPiaetgigK1F1rlzYnIVo32Fcj/GNNwWEdxxEzeaQR/AJmZLWB8sBDThoTGeSDK
fjKXeDjZh+ElpIKWyk7f3ddHN2TpBz698kZ7fYCciRE9T4d3xgbqx2rCfupxUFSj
lxLFRkasByJnLdAZI50NZjW838IHMaGsvgbWEqRuvKZOES6gFhrK1NTSxj5iuiHk
Uxj1KzMhOW+m1eZ0pQcCVXJDY6KYhmrZzw9q6kzSO9ukmS5yRf0EnD7Fsca4iIXP
Y28xs3zBxYHV4IGU1PtcIwNewmTnjnEy0apHPz0zDplHi1meXuhA7bBMjs/AouJg
6FIDNSQqDuFXufqvVQ6LZZgob+LklMAoGcka4/5ZLPjipj5SWNeZZunJujSqWK7f
KJaIfn7ILXqxjaTFrjBN3cm60rO1+zEektrjtWMmSBn0L76pY2ucenrqewruYYdD
12VQra/6QAS5R0HG8gzOfsZcrHaiIuLoTbsOgnqLVcdb9lO7f3oMbKPwejZ5yhyz
SraXHvmixlhf4uUYwsWyhw3UgHrv1psB8Z9NfdH9/T2BvRg0qy6ZmI0n0OagPNgz
v+SZrqrWkSjyPdl6j7x8EmePfNidqw/CnncYI2rEVSmP28W0Uhg5JLgroGYmycv6
HeZaRpYvkV8UNmnE
=BtHq
-----END PGP PUBLIC KEY BLOCK-----
Loading

0 comments on commit 3136140

Please sign in to comment.