[CUDA] upgrade cutlass to 3.5.0 (microsoft#20940)
### Description
Upgrade cutlass to 3.5.0 to fix build errors with CUDA 12.4 or 12.5 on
Windows:
- [x] Upgrade cutlass to 3.5.0.
- [x] Fix flash attention build errors with the latest cutlass header files
and APIs. This fix was provided by @wangyems.
- [x] Update efficient attention to use the new cutlass fmha interface.
- [x] Patch cutlass to fix the `hrsqrt` not-found error for sm < 53.
- [x] Disable TF32 staged accumulation to fix the blkq4_fp16_gemm_sm80_test
build error for CUDA 11.8 to 12.3.
- [x] Disable TRT 10 deprecation warnings.

The following are not included in this PR:
* Replacing the deprecated APIs in the TRT provider.
* Fixing the blkq4_fp16_gemm_sm80_test build error for CUDA 12.4 or 12.5. This
test is not built by default unless you add `--cmake_extra_defines
onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS=ON` to the build command.

To integrate into rel-1.18.1: either bring in other changes (like onnx
1.16.1), or generate a manifest and upload a new ONNX Runtime Build Time
Deps artifact based on rel-1.18.1.

### Motivation and Context
microsoft#19891
microsoft#20924
microsoft#20953
tianleiwu authored Jun 11, 2024
1 parent dd805ff commit b3fc9b5
Showing 22 changed files with 124 additions and 60 deletions.
2 changes: 1 addition & 1 deletion cgmanifests/generated/cgmanifest.json
@@ -306,7 +306,7 @@
"component": {
"type": "git",
"git": {
"commitHash": "6f47420213f757831fae65c686aa471749fa8d60",
"commitHash": "7d49e6c7e2f8896c47f586706e67e1fb215529dc",
"repositoryUrl": "https://github.com/NVIDIA/cutlass.git"
},
"comments": "cutlass"
5 changes: 5 additions & 0 deletions cmake/CMakeLists.txt
@@ -46,6 +46,11 @@ else()
set(CMAKE_CXX_STANDARD 17)
endif()

if (MSVC)
# Make sure Visual Studio sets __cplusplus macro correctly: https://learn.microsoft.com/en-us/cpp/build/reference/zc-cplusplus
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /Zc:__cplusplus")
endif()

set_property(GLOBAL PROPERTY USE_FOLDERS ON)
# NOTE: POSITION INDEPENDENT CODE hurts performance, and it only make sense on POSIX systems
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
2 changes: 1 addition & 1 deletion cmake/deps.txt
@@ -53,7 +53,7 @@ pytorch_cpuinfo;https://github.com/pytorch/cpuinfo/archive/959002f82d7962a473d8b
re2;https://github.com/google/re2/archive/refs/tags/2024-05-01.tar.gz;206cfee5ee0b4c6844680ba66275e9e8faa77405
safeint;https://github.com/dcleblanc/SafeInt/archive/refs/tags/3.0.28.zip;23f252040ff6cb9f1fd18575b32fa8fb5928daac
tensorboard;https://github.com/tensorflow/tensorboard/archive/373eb09e4c5d2b3cc2493f0949dc4be6b6a45e81.zip;67b833913605a4f3f499894ab11528a702c2b381
cutlass;https://github.com/NVIDIA/cutlass/archive/refs/tags/v3.1.0.zip;757f90a795034a89d4f48a79d1f009f7a04c8dee
cutlass;https://github.com/NVIDIA/cutlass/archive/refs/tags/v3.5.0.zip;ae038931b9fc2c416c17d9cda91d9706b343f56d
utf8_range;https://github.com/protocolbuffers/utf8_range/archive/72c943dea2b9240cd09efde15191e144bc7c7d38.zip;9925739c9debc0efa2adcb194d371a35b6a03156
extensions;https://github.com/microsoft/onnxruntime-extensions/archive/94142d8391c9791ec71c38336436319a2d4ac7a0.zip;4365ac5140338b4cb75a39944a4be276e3829b3c
composable_kernel;https://github.com/ROCmSoftwarePlatform/composable_kernel/archive/5356c4a943a35e74d7cdc69486afcb8703b9a59a.zip;522382c2af437e09124287e5879ab64af5b2e299
1 change: 1 addition & 0 deletions cmake/external/cutlass.cmake
@@ -3,6 +3,7 @@ FetchContent_Declare(
cutlass
URL ${DEP_URL_cutlass}
URL_HASH SHA1=${DEP_SHA1_cutlass}
PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/cutlass/cutlass_3.5.0.patch
)

FetchContent_GetProperties(cutlass)
4 changes: 4 additions & 0 deletions cmake/onnxruntime_providers_cuda.cmake
@@ -175,6 +175,10 @@
endif()
endif()

if(MSVC)
target_compile_options(${target} PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:-Xcompiler /Zc:__cplusplus>")
endif()

onnxruntime_add_include_to_target(${target} onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers)
if (onnxruntime_ENABLE_TRAINING_OPS)
onnxruntime_add_include_to_target(${target} onnxruntime_training)
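Both CMake changes above pass MSVC's `/Zc:__cplusplus` switch: directly for C++ sources in `CMakeLists.txt`, and via `-Xcompiler` for the host compiler during CUDA compilation. Without the switch, MSVC reports `__cplusplus` as `199711L` regardless of the selected standard, so headers that gate C++17 features on that macro (which the newer cutlass/CuTe headers appear to do) take the wrong branch. A minimal sketch of a guard that would surface a missing flag at compile time; the static_assert is illustrative and not part of this PR:

```cpp
// Minimal sketch: fail early if the host compiler was not given /Zc:__cplusplus.
// Without the switch, MSVC keeps __cplusplus at 199711L even in C++17 mode,
// so standard-version checks in third-party headers take the wrong branch.
#if defined(_MSC_VER) && !defined(__clang__)
static_assert(__cplusplus >= 201703L,
              "Stale __cplusplus value; build with /Zc:__cplusplus "
              "(or -Xcompiler /Zc:__cplusplus when compiling CUDA sources).");
#endif
```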
25 changes: 25 additions & 0 deletions cmake/patches/cutlass/cutlass_3.5.0.patch
@@ -0,0 +1,25 @@
diff --git a/include/cutlass/functional.h b/include/cutlass/functional.h
index 964d2ff3..b366bc14 100644
--- a/include/cutlass/functional.h
+++ b/include/cutlass/functional.h
@@ -39,6 +39,7 @@
#include "cutlass/numeric_types.h"

#include <cuda_runtime.h>
+#include <cuda_fp16.h>

#if defined(CUTLASS_ARCH_WMMA_ENABLED)
#include <mma.h>
@@ -230,8 +231,12 @@ struct inverse_square_root<half_t> {
CUTLASS_HOST_DEVICE
half_t operator()(half_t const &lhs) const {
#if defined(__CUDA_ARCH__)
+#if (__CUDA_ARCH__ >= 530)
auto result = hrsqrt(reinterpret_cast<__half const &>(lhs));
return reinterpret_cast<half_t const &>(result);
+#else
+ return half_t::convert((rsqrtf(half_t::convert(lhs))));
+#endif
#else
return half_t(1.f / std::sqrt(half_t::convert(lhs)));
#endif
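For context on the patch above: `hrsqrt` is a half-precision device intrinsic that only exists for `__CUDA_ARCH__ >= 530`, so instantiating cutlass's `inverse_square_root<half_t>` for sm_50/52 fails to compile. The patch falls back to computing the reciprocal square root in float via `rsqrtf` and converting back to half. A standalone sketch of the same arch-guarded pattern (the helper and kernel names are illustrative, not from the PR); compile as a `.cu` file with nvcc:

```cpp
#include <cuda_fp16.h>

// Illustrative device helper mirroring the patched cutlass logic: use the
// native half intrinsic on sm_53+, otherwise round-trip through float.
__device__ __half half_rsqrt(__half x) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
  return hrsqrt(x);                               // native half-precision rsqrt
#else
  return __float2half(rsqrtf(__half2float(x)));   // float fallback for sm < 53
#endif
}

__global__ void rsqrt_kernel(const __half* in, __half* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = half_rsqrt(in[i]);
}
```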
4 changes: 1 addition & 3 deletions onnxruntime/contrib_ops/cuda/bert/attention.cc
@@ -219,11 +219,9 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
!disable_memory_efficient_attention_ &&
nullptr == past &&
nullptr == present &&
(parameters.head_size & 7) == 0 &&
(parameters.v_head_size & 7) == 0 &&
(nullptr == mask_index || parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START) &&
(sizeof(T) == 2 || parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32) &&
has_memory_efficient_attention(sm, sizeof(T) == 2);
has_memory_efficient_attention(sm, sizeof(T) == 2, parameters.head_size, parameters.v_head_size);

if (use_memory_efficient_attention) {
bool is_good_for_rpb = relative_position_bias != nullptr && parameters.sequence_length % (4 * sizeof(T)) == 0;
@@ -75,12 +75,8 @@ struct RightPaddingBatchHook {
batch_id * lse_dim * p.num_heads + head_id * lse_dim + query_start;
}

// Custom masking
if (p.causal_diagonal_ptr) {
p.causal_diagonal_offset = p.causal_diagonal_ptr[batch_id];
}
if (p.custom_mask_type == AttentionKernel::CausalFromBottomRight) {
p.causal_diagonal_offset += p.num_keys - p.num_queries;
p.causal_diagonal_offset = p.num_keys - p.num_queries;
}
if (p.custom_mask_type == AttentionKernel::CausalFromTopLeft ||
p.custom_mask_type == AttentionKernel::CausalFromBottomRight) {
@@ -143,9 +139,10 @@ __global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm)
AK::attention_kernel(p);
}

template <typename T, typename ArchTag, bool is_aligned, int queries_per_block, int keys_per_block, bool single_value_iteration>
template <typename T, typename ArchTag, bool is_aligned, int queries_per_block, int keys_per_block, int max_head_size>
void LaunchCutlassFmha(const MemoryEfficientAttentionParams& params) {
using Attention = AttentionKernel<T, ArchTag, is_aligned, queries_per_block, keys_per_block, single_value_iteration>;
constexpr bool dropout = false;
using Attention = AttentionKernel<T, ArchTag, is_aligned, queries_per_block, keys_per_block, max_head_size, dropout>;
typename Attention::Params p;
{ // set parameters
p.query_ptr = const_cast<T*>(reinterpret_cast<const T*>(params.query));
@@ -220,6 +217,7 @@ void LaunchCutlassFmha(const MemoryEfficientAttentionParams& params) {
}

auto kernel_fn = attention_kernel_batched_impl<Attention>;

if (params.has_custom_right_padding) {
kernel_fn = attention_kernel_batched_impl_right_padding<Attention, queries_per_block>;
}
@@ -237,20 +235,23 @@ void LaunchCutlassFmha(const MemoryEfficientAttentionParams& params) {
kernel_fn<<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes, params.stream>>>(p);
}

template <typename T, typename ArchTag, int queries_per_block, int keys_per_block, bool single_value_iteration>
template <typename T, typename ArchTag, int queries_per_block, int keys_per_block, int max_head_size>
void DispatchIsAligned(const MemoryEfficientAttentionParams& params) {
using AlignedAK = AttentionKernel<T, ArchTag, true, queries_per_block, keys_per_block, single_value_iteration>;
using AlignedAK = AttentionKernel<T, ArchTag, true, queries_per_block, keys_per_block, max_head_size>;
#if defined(_MSC_VER) && !defined(__clang__)
#pragma warning(push)
#pragma warning(disable : 6287 4189) // kAligned is used via capture so 4189 warning seems incorrect
#endif

// Run a more efficient kernel with `isAligned=True` when memory is correctly aligned.
bool is_aligned = params.qk_head_size % AlignedAK::kAlignmentQ == 0 &&
params.qk_head_size % AlignedAK::kAlignmentK == 0 &&
params.v_head_size % AlignedAK::kAlignmentV == 0;

DISPATCH_BOOL(is_aligned, kIsAligned, ([&]() {
LaunchCutlassFmha<T, ArchTag, kIsAligned, queries_per_block, keys_per_block, single_value_iteration>(params);
LaunchCutlassFmha<T, ArchTag, kIsAligned, queries_per_block, keys_per_block, max_head_size>(params);
}));

#if defined(_MSC_VER) && !defined(__clang__)
#pragma warning(pop)
#endif
@@ -259,11 +260,11 @@ void DispatchIsAligned(const MemoryEfficientAttentionParams& params) {
template <typename T, typename ArchTag>
void DispatchBlockSize(const MemoryEfficientAttentionParams& params) {
if (params.v_head_size <= 64) {
DispatchIsAligned<T, ArchTag, 64, 64, true>(params);
DispatchIsAligned<T, ArchTag, 64, 64, 64>(params);
} else if (params.v_head_size <= 128) {
DispatchIsAligned<T, ArchTag, 32, 128, true>(params);
DispatchIsAligned<T, ArchTag, 32, 128, 128>(params);
} else {
DispatchIsAligned<T, ArchTag, 32, 128, false>(params);
DispatchIsAligned<T, ArchTag, 32, 128, kEfficientAttentionMaxHeadSize>(params);
}
}

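In cutlass 3.5 the fused multi-head attention `AttentionKernel` template takes a maximum head size (an integer) plus a compile-time dropout flag instead of the old `single_value_iteration` bool, so the launcher above now maps the runtime head size onto one of a few compiled specializations (the 64 / 128 / `kEfficientAttentionMaxHeadSize` tiers). A simplified, self-contained sketch of that runtime-to-compile-time dispatch pattern; the kernel body and names are placeholders, not ONNX Runtime code. Compile as a `.cu` file with nvcc:

```cpp
#include <cstdio>
#include <cuda_runtime.h>

// Placeholder kernel specialized on a compile-time head-size bound.
template <int kQueriesPerBlock, int kKeysPerBlock, int kMaxHeadSize>
__global__ void fmha_like_kernel(int head_size) {
  if (threadIdx.x == 0 && blockIdx.x == 0) {
    printf("compiled for head_size <= %d (tile %dx%d), runtime head_size = %d\n",
           kMaxHeadSize, kQueriesPerBlock, kKeysPerBlock, head_size);
  }
}

// Map the runtime head size to one of a few template instances, mirroring the
// 64 / 128 / max tiers used by DispatchBlockSize in the diff above.
void dispatch_by_head_size(int v_head_size, cudaStream_t stream) {
  constexpr int kMaxSupportedHeadSize = 1024;  // analogous to kEfficientAttentionMaxHeadSize
  if (v_head_size <= 64) {
    fmha_like_kernel<64, 64, 64><<<1, 32, 0, stream>>>(v_head_size);
  } else if (v_head_size <= 128) {
    fmha_like_kernel<32, 128, 128><<<1, 32, 0, stream>>>(v_head_size);
  } else {
    fmha_like_kernel<32, 128, kMaxSupportedHeadSize><<<1, 32, 0, stream>>>(v_head_size);
  }
}

int main() {
  dispatch_by_head_size(96, /*stream=*/0);
  cudaDeviceSynchronize();
  return 0;
}
```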
@@ -11,6 +11,8 @@ namespace onnxruntime {
namespace contrib {
namespace cuda {

constexpr int kEfficientAttentionMaxHeadSize = 1024;

struct MemoryEfficientAttentionParams {
int32_t sm;
bool is_half;
@@ -49,8 +51,11 @@ struct MemoryEfficientAttentionParams {

void run_memory_efficient_attention(const MemoryEfficientAttentionParams& params);

inline bool has_memory_efficient_attention(int32_t sm, bool is_half) {
return sm >= (is_half ? 53 : 50);
inline bool has_memory_efficient_attention(int32_t sm, bool is_half, int qk_head_size, int v_head_size) {
return sm >= (is_half ? 53 : 50) &&
(qk_head_size & 7) == 0 &&
(v_head_size & 7) == 0 &&
qk_head_size <= kEfficientAttentionMaxHeadSize && v_head_size <= kEfficientAttentionMaxHeadSize;
}

void run_memory_efficient_attention_sm80(const MemoryEfficientAttentionParams& params);
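With this change, the head-size divisibility and upper-bound checks that individual operators used to carry are folded into `has_memory_efficient_attention`: both head sizes must be multiples of 8 and at most `kEfficientAttentionMaxHeadSize` (1024), on top of the existing SM check. A small self-contained sketch of the updated check; the helper body is copied from the diff above and the sample inputs are made up:

```cpp
#include <cstdint>
#include <cstdio>

// Standalone copy of the updated helper, so the gating logic can be exercised
// outside the ONNX Runtime build.
constexpr int kEfficientAttentionMaxHeadSize = 1024;

inline bool has_memory_efficient_attention(int32_t sm, bool is_half, int qk_head_size, int v_head_size) {
  return sm >= (is_half ? 53 : 50) &&
         (qk_head_size & 7) == 0 &&
         (v_head_size & 7) == 0 &&
         qk_head_size <= kEfficientAttentionMaxHeadSize &&
         v_head_size <= kEfficientAttentionMaxHeadSize;
}

int main() {
  // sm is major * 10 + minor of the GPU compute capability (e.g. 75 = T4, 80 = A100).
  printf("%d\n", has_memory_efficient_attention(80, true, 64, 64));    // 1: all checks pass
  printf("%d\n", has_memory_efficient_attention(80, true, 68, 64));    // 0: qk_head_size not a multiple of 8
  printf("%d\n", has_memory_efficient_attention(80, true, 2048, 64));  // 0: exceeds the 1024 head-size cap
  printf("%d\n", has_memory_efficient_attention(52, true, 64, 64));    // 0: fp16 requires sm >= 53
  return 0;
}
```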
@@ -10,8 +10,7 @@
#endif

#include <cmath>
#include <cute/algorithm/copy.hpp>
#include <cute/algorithm/gemm.hpp>
#include <cute/tensor.hpp>

#include <cutlass/cutlass.h>
#include <cutlass/array.h>
@@ -98,7 +97,6 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi
constexpr int kBlockN = Kernel_traits::kBlockN;
constexpr int kHeadDim = Kernel_traits::kHeadDim;
constexpr int kNWarps = Kernel_traits::kNWarps;
constexpr int MMA_M = kBlockM / decltype(cute::size<0>(typename Kernel_traits::TiledMma::TiledShape_MNK{}))::value;

const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
if (m_block * kBlockM >= binfo.actual_seqlen_q || binfo.actual_seqlen_k == 0) return;
@@ -3,7 +3,7 @@
******************************************************************************/
#pragma once

#include <cute/algorithm/copy.hpp>
#include <cute/tensor.hpp>

#include <cutlass/cutlass.h>
#include <cutlass/layout/layout.h>
@@ -32,10 +32,8 @@ struct Flash_kernel_traits {
std::is_same_v<elem_type, cutlass::half_t>,
MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>,
MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>>;
using ValLayoutMNK = cute::Layout<cute::Shape<_1, _2, _1>>;
#else
using MMA_Atom_Arch = MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>;
using ValLayoutMNK = cute::Layout<cute::Shape<_1, _2, _2>>;
#endif

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
@@ -77,7 +75,7 @@ struct Flash_fwd_kernel_traits : public Base {
using TiledMma = TiledMMA<
typename Base::MMA_Atom_Arch,
Layout<Shape<Int<kNWarps>, _1, _1>>, // 4x1x1 or 8x1x1 thread group
typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM
Tile<Int<16 * kNWarps>, _16, _16>>;

using SmemLayoutAtomQ = decltype(composition(Swizzle<kSwizzle, 3, 3>{},
// This has to be kBlockKSmem, using kHeadDim gives wrong results for d=128
@@ -208,17 +206,17 @@ struct Flash_bwd_kernel_traits : public Base {
using TiledMmaSdP = TiledMMA<
typename Base::MMA_Atom_Arch,
cute::Layout<cute::Shape<cute::Int<AtomLayoutMSdP>, cute::Int<kNWarps / AtomLayoutMSdP>, _1>>,
typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM
Tile<Int<16 * AtomLayoutMSdP>, Int<16 * kNWarps / AtomLayoutMSdP>, _16>>;

using TiledMmadKV = TiledMMA<
typename Base::MMA_Atom_Arch,
cute::Layout<cute::Shape<cute::Int<AtomLayoutNdKV>, cute::Int<kNWarps / AtomLayoutNdKV>, _1>>,
typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM
Tile<Int<16 * AtomLayoutNdKV>, Int<16 * kNWarps / AtomLayoutNdKV>, _16>>;

using TiledMmadQ = TiledMMA<
typename Base::MMA_Atom_Arch,
cute::Layout<cute::Shape<cute::Int<AtomLayoutMdQ>, cute::Int<kNWarps / AtomLayoutMdQ>, _1>>, // 2x4x1 or 4x2x1 thread group
typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM
Tile<Int<16 * AtomLayoutMdQ>, Int<16 * kNWarps / AtomLayoutMdQ>, _16>>;

using SmemLayoutAtomQdO = decltype(composition(Swizzle<kSwizzle, 3, 3>{},
cute::Layout<cute::Shape<_8, cute::Int<kBlockKSmem>>,
3 changes: 1 addition & 2 deletions onnxruntime/contrib_ops/cuda/bert/flash_attention/utils.h
@@ -13,8 +13,7 @@
#include <cuda_bf16.h>
#endif

#include <cute/algorithm/copy.hpp>
#include <cute/algorithm/gemm.hpp>
#include <cute/tensor.hpp>

#include <cutlass/array.h>
#include <cutlass/cutlass.h>
3 changes: 1 addition & 2 deletions onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
@@ -161,9 +161,8 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
!use_flash_attention &&
!disable_memory_efficient_attention_ &&
local_window_size_ == -1 &&
(parameters.head_size & 7) == 0 &&
(sizeof(T) == 2 || parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32) &&
has_memory_efficient_attention(sm, sizeof(T) == 2);
has_memory_efficient_attention(sm, sizeof(T) == 2, parameters.head_size, parameters.head_size);
if (!use_flash_attention && !use_memory_efficient_attention && local_window_size_ != -1) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Local attention UNSUPPORTED for sm < 80 on CUDA.");
21 changes: 10 additions & 11 deletions onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc
@@ -235,17 +235,16 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {

bool is_good_for_rpb = relative_position_bias != nullptr && parameters.sequence_length % (4 * sizeof(T)) == 0;

bool use_memory_efficient_attention = !use_flash_attention &&
fused_runner == nullptr &&
fused_cross_attention_kernel == nullptr &&
!disable_memory_efficient_attention_ &&
(parameters.head_size & 7) == 0 &&
(parameters.v_head_size & 7) == 0 &&
is_long_sequence &&
!past_no_bias &&
(relative_position_bias == nullptr || is_good_for_rpb) &&
(nullptr == key_padding_mask || parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START) &&
has_memory_efficient_attention(sm, sizeof(T) == 2);
bool use_memory_efficient_attention =
!use_flash_attention &&
fused_runner == nullptr &&
fused_cross_attention_kernel == nullptr &&
!disable_memory_efficient_attention_ &&
is_long_sequence &&
!past_no_bias &&
(relative_position_bias == nullptr || is_good_for_rpb) &&
(nullptr == key_padding_mask || parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START) &&
has_memory_efficient_attention(sm, sizeof(T) == 2, parameters.head_size, parameters.v_head_size);
#else
constexpr bool use_memory_efficient_attention = false;
#endif
9 changes: 4 additions & 5 deletions onnxruntime/contrib_ops/cuda/bert/packed_attention.cc
@@ -288,11 +288,10 @@ Status PackedAttention<T>::ComputeInternal(OpKernelContext* context) const {
if (nullptr == fused_runner) {
int sm = device_prop.major * 10 + device_prop.minor;
bool is_good_for_rpb = !parameters.has_relative_position_bias || parameters.sequence_length % (4 * sizeof(T)) == 0;
use_memory_efficient_attention = is_good_for_rpb &&
sizeof(T) == 2 && // only enable for fp16
(parameters.head_size & 7) == 0 &&
(parameters.v_head_size & 7) == 0 &&
has_memory_efficient_attention(sm, sizeof(T) == 2);
use_memory_efficient_attention =
is_good_for_rpb &&
sizeof(T) == 2 && // only enable for fp16
has_memory_efficient_attention(sm, sizeof(T) == 2, parameters.head_size, parameters.v_head_size);
}
#endif

@@ -272,9 +272,7 @@ Status PackedMultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) co
use_memory_efficient_attention =
is_good_for_rpb &&
(sizeof(T) == 2 || parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32) &&
(parameters.head_size & 7) == 0 &&
(parameters.v_head_size & 7) == 0 &&
has_memory_efficient_attention(sm, sizeof(T) == 2);
has_memory_efficient_attention(sm, sizeof(T) == 2, parameters.head_size, parameters.v_head_size);
}
#endif

@@ -490,7 +490,10 @@ class QuantBMmaMultistage :
// accuracy, where each mainloop iteration first accumulates into a temporary
// set of freshly-cleared accumulators, which are subsequently added to the
// final accumulator set.
static bool const kStagedAccumulation = arch::UseStagedAccumulation<typename Operator::MathOperator>::value;

// Change the following to avoid build error: class "cutlass::arch::OpMultiplyAdd" has no member "ElementA".
// kStagedAccumulation = arch::detail::UseStagedAccumulation<typename Operator::MathOperator>::value;
static bool const kStagedAccumulation = false;
};

private:
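For background on the change above: staged accumulation is the accuracy refinement described in the surrounding comment, where each mainloop iteration accumulates into a freshly cleared temporary tile and that partial result is then folded into the long-lived accumulator, limiting rounding error in any single accumulation chain (cutlass enables it for fast TF32 math). Forcing `kStagedAccumulation = false` sidesteps the cutlass 3.5 build error at the cost of that refinement on this block-quantized GEMM path. A scalar sketch of the idea, not the cutlass implementation:

```cpp
#include <cstdio>

// Scalar sketch of staged accumulation: accumulate each K-stage into a freshly
// cleared temporary, then fold the partial sum into the final accumulator,
// instead of adding every product into one long accumulation chain.
float dot_staged(const float* a, const float* b, int n, int stage_len) {
  float total = 0.0f;
  for (int s = 0; s < n; s += stage_len) {
    float partial = 0.0f;                      // freshly cleared per stage
    int end = (s + stage_len < n) ? s + stage_len : n;
    for (int k = s; k < end; ++k) {
      partial += a[k] * b[k];
    }
    total += partial;                          // add the stage result to the final accumulator
  }
  return total;
}

int main() {
  float a[8] = {1, 2, 3, 4, 5, 6, 7, 8};
  float b[8] = {8, 7, 6, 5, 4, 3, 2, 1};
  printf("%f\n", dot_staged(a, b, 8, 4));      // 120.000000
  return 0;
}
```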
7 changes: 3 additions & 4 deletions onnxruntime/core/providers/tensorrt/nv_includes.h
@@ -2,12 +2,11 @@
// Licensed under the MIT License.
#pragma once

// File to include the required TRT headers with workarounds for warnings we can't fix.

// Ignore warning C4100: unreferenced formal parameter
// File to include the required TRT headers with workarounds for warnings we can't fix or not fixed yet.
#if defined(_MSC_VER)
#pragma warning(push)
#pragma warning(disable : 4100)
#pragma warning(disable : 4100) // Ignore warning C4100: unreferenced formal parameter
#pragma warning(disable : 4996) // Ignore warning C4996: 'nvinfer1::IPluginV2' was declared deprecated
#endif

#include <NvInfer.h>