hotfix: Fix sm75 compilation issue for bf16 on cuda 11.8 & 12.1 (#472)
For CUDA 11.8 and 12.1, functions such as `make_bfloat162` are not declared when `__CUDA_ARCH__ < 800`. This PR adds guarded fallback definitions (and switches vector `fill` values from integer to float literals) so that sm_75 builds compile.
yzh119 authored Aug 27, 2024
1 parent a23979b commit a836e7e
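
For context, a minimal sketch of the kind of device code affected (the file and kernel names here are hypothetical, not part of the commit): a kernel that packs two bf16 values with `make_bfloat162` compiles for sm_80 but fails for sm_75 on CUDA 11.8/12.1, because those toolkits only declare the helper when `__CUDA_ARCH__ >= 800`. The guarded fallbacks added to vec_dtypes.cuh below supply the missing definitions.

// repro_bf16_sm75.cu -- hypothetical reproducer, not part of this commit.
#include <cuda_bf16.h>

__global__ void PackBf16Kernel(const __nv_bfloat16* x, const __nv_bfloat16* y,
                               __nv_bfloat162* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    // Undeclared for __CUDA_ARCH__ < 800 on CUDA 11.8/12.1 without the fallback below.
    out[i] = make_bfloat162(x[i], y[i]);
  }
}

Compiling such a translation unit with something like `nvcc -std=c++17 -arch=sm_75 -c repro_bf16_sm75.cu` reproduces the error on the affected toolkits; with this commit's fallback definitions in scope, the same sm_75 target builds.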
Showing 4 changed files with 42 additions and 10 deletions.
4 changes: 2 additions & 2 deletions include/flashinfer/attention/cascade.cuh
@@ -160,7 +160,7 @@ __global__ void MergeStatesKernel(DTypeIn* __restrict__ V, float* __restrict__ S

  if (num_index_sets == 0) {
    vec_t<DTypeOut, vec_size> v;
-   v.fill(DTypeOut(0));
+   v.fill(DTypeOut(0.f));
    v.store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
    if (s_merged != nullptr) {
      s_merged[pos * num_heads + head_idx] = -5e4;
@@ -325,7 +325,7 @@ __global__ void PersistentVariableLengthMergeStatesKernel(DTypeIn* __restrict__

  if (num_index_sets == 0) {
    vec_t<DTypeOut, vec_size> v;
-   v.fill(DTypeOut(0));
+   v.fill(DTypeOut(0.f));
    v.store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
    if (s_merged != nullptr) {
      s_merged[pos * num_heads + head_idx] = -5e4;
16 changes: 8 additions & 8 deletions include/flashinfer/norm.cuh
@@ -44,7 +44,7 @@ __global__ void RMSNormKernel(T* __restrict__ input, T* __restrict__ weight, T*

  for (uint32_t i = 0; i < rounds; i++) {
    vec_t<T, VEC_SIZE> input_vec;
-   input_vec.fill(0);
+   input_vec.fill(0.f);
    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
      input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
    }
@@ -79,8 +79,8 @@ __global__ void RMSNormKernel(T* __restrict__ input, T* __restrict__ weight, T*
    vec_t<T, VEC_SIZE> input_vec;
    vec_t<T, VEC_SIZE> weight_vec;
    vec_t<T, VEC_SIZE> output_vec;
-   input_vec.fill(0);
-   weight_vec.fill(0);
+   input_vec.fill(0.f);
+   weight_vec.fill(0.f);
    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
      input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
      weight_vec.load(weight + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
@@ -130,9 +130,9 @@ __global__ void FusedAddRMSNormKernel(T* __restrict__ input, T* __restrict__ res

  for (uint32_t i = 0; i < rounds; i++) {
    vec_t<T, VEC_SIZE> input_vec;
-   input_vec.fill(0);
+   input_vec.fill(0.f);
    vec_t<T, VEC_SIZE> residual_vec;
-   residual_vec.fill(0);
+   residual_vec.fill(0.f);
    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
      input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
      residual_vec.load(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
@@ -174,9 +174,9 @@ __global__ void FusedAddRMSNormKernel(T* __restrict__ input, T* __restrict__ res
    vec_t<T, VEC_SIZE> input_vec;
    vec_t<T, VEC_SIZE> weight_vec;
    vec_t<T, VEC_SIZE> residual_vec;
-   input_vec.fill(0);
-   weight_vec.fill(0);
-   residual_vec.fill(0);
+   input_vec.fill(0.f);
+   weight_vec.fill(0.f);
+   residual_vec.fill(0.f);
    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
      input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
      weight_vec.load(weight + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
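
A note on the `fill(0)` → `fill(0.f)` changes above: the argument is converted to the element type `T`, and for `__nv_bfloat16` the float conversion path (e.g. `__float2bfloat16_rn`, which this commit's own sm_75 fallback in vec_dtypes.cuh calls) remains usable below sm_80, whereas an integer-sourced conversion presumably runs into the same missing declarations on CUDA 11.8/12.1. A rough illustration, under that assumption:

// Illustration only; the claim about the integer path is an assumption,
// not something stated in the commit.
#include <cuda_bf16.h>

__global__ void ZeroFillExample(__nv_bfloat16* out) {
  // Float literal: converts via __float2bfloat16_rn, which this commit's
  // sm_75 fallback also relies on, so it is available below sm_80.
  out[threadIdx.x] = __float2bfloat16_rn(0.f);
  // By contrast, an integer-sourced conversion such as __nv_bfloat16(0)
  // may resolve to helpers that are only declared for __CUDA_ARCH__ >= 800
  // on these toolkits, which is the error the 0 -> 0.f change sidesteps.
}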
3 changes: 3 additions & 0 deletions include/flashinfer/utils.cuh
@@ -17,6 +17,9 @@
#define FLASHINFER_UTILS_CUH_
#include <cuda_device_runtime_api.h>
#include <cuda_runtime.h>
+#include <cuda_bf16.h>
+#include <cuda_fp16.h>
+#include <cuda_fp8.h>

#include <iostream>
#include <sstream>
29 changes: 29 additions & 0 deletions include/flashinfer/vec_dtypes.cuh
@@ -31,6 +31,35 @@ namespace flashinfer {

#define FLASHINFER_INLINE inline __attribute__((always_inline)) __device__

+#if (__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 < 120400) && \
+    (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
+// CUDA version < 12.4 and GPU architecture < 80
+FLASHINFER_INLINE __nv_bfloat162 make_bfloat162(const __nv_bfloat16 x, const __nv_bfloat16 y) {
+  __nv_bfloat162 t;
+  t.x = x;
+  t.y = y;
+  return t;
+}
+
+FLASHINFER_INLINE __nv_bfloat162 __floats2bfloat162_rn(const float a, const float b) {
+  __nv_bfloat162 val;
+  val = __nv_bfloat162(__float2bfloat16_rn(a), __float2bfloat16_rn(b));
+  return val;
+}
+
+FLASHINFER_INLINE __nv_bfloat162 __float22bfloat162_rn(const float2 a) {
+  __nv_bfloat162 val = __floats2bfloat162_rn(a.x, a.y);
+  return val;
+}
+FLASHINFER_INLINE float2 __bfloat1622float2(const __nv_bfloat162 a) {
+  float hi_float;
+  float lo_float;
+  lo_float = __internal_bfloat162float(((__nv_bfloat162_raw)a).x);
+  hi_float = __internal_bfloat162float(((__nv_bfloat162_raw)a).y);
+  return make_float2(lo_float, hi_float);
+}
+#endif
+

/******************* vec_t type cast *******************/

template <typename dst_t, typename src_t>
