diff --git a/include/flashinfer/attention/cascade.cuh b/include/flashinfer/attention/cascade.cuh
index 2e678ec6..9d71e7bf 100644
--- a/include/flashinfer/attention/cascade.cuh
+++ b/include/flashinfer/attention/cascade.cuh
@@ -160,7 +160,7 @@ __global__ void MergeStatesKernel(DTypeIn* __restrict__ V, float* __restrict__ S
 
   if (num_index_sets == 0) {
     vec_t<DTypeOut, vec_size> v;
-    v.fill(DTypeOut(0));
+    v.fill(DTypeOut(0.f));
     v.store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
     if (s_merged != nullptr) {
       s_merged[pos * num_heads + head_idx] = -5e4;
@@ -325,7 +325,7 @@ __global__ void PersistentVariableLengthMergeStatesKernel(DTypeIn* __restrict__
 
   if (num_index_sets == 0) {
     vec_t<DTypeOut, vec_size> v;
-    v.fill(DTypeOut(0));
+    v.fill(DTypeOut(0.f));
     v.store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
     if (s_merged != nullptr) {
       s_merged[pos * num_heads + head_idx] = -5e4;
diff --git a/include/flashinfer/norm.cuh b/include/flashinfer/norm.cuh
index 02c24a7e..6fa7b317 100644
--- a/include/flashinfer/norm.cuh
+++ b/include/flashinfer/norm.cuh
@@ -44,7 +44,7 @@ __global__ void RMSNormKernel(T* __restrict__ input, T* __restrict__ weight, T*
 
   for (uint32_t i = 0; i < rounds; i++) {
     vec_t<T, VEC_SIZE> input_vec;
-    input_vec.fill(0);
+    input_vec.fill(0.f);
     if ((i * num_threads + thread_id) * VEC_SIZE < d) {
       input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
     }
@@ -79,8 +79,8 @@ __global__ void RMSNormKernel(T* __restrict__ input, T* __restrict__ weight, T*
     vec_t<T, VEC_SIZE> input_vec;
     vec_t<T, VEC_SIZE> weight_vec;
     vec_t<T, VEC_SIZE> output_vec;
-    input_vec.fill(0);
-    weight_vec.fill(0);
+    input_vec.fill(0.f);
+    weight_vec.fill(0.f);
     if ((i * num_threads + thread_id) * VEC_SIZE < d) {
       input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
       weight_vec.load(weight + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
@@ -130,9 +130,9 @@ __global__ void FusedAddRMSNormKernel(T* __restrict__ input, T* __restrict__ res
 
   for (uint32_t i = 0; i < rounds; i++) {
     vec_t<T, VEC_SIZE> input_vec;
-    input_vec.fill(0);
+    input_vec.fill(0.f);
     vec_t<T, VEC_SIZE> residual_vec;
-    residual_vec.fill(0);
+    residual_vec.fill(0.f);
     if ((i * num_threads + thread_id) * VEC_SIZE < d) {
       input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
       residual_vec.load(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
@@ -174,9 +174,9 @@ __global__ void FusedAddRMSNormKernel(T* __restrict__ input, T* __restrict__ res
     vec_t<T, VEC_SIZE> input_vec;
     vec_t<T, VEC_SIZE> weight_vec;
     vec_t<T, VEC_SIZE> residual_vec;
-    input_vec.fill(0);
-    weight_vec.fill(0);
-    residual_vec.fill(0);
+    input_vec.fill(0.f);
+    weight_vec.fill(0.f);
+    residual_vec.fill(0.f);
     if ((i * num_threads + thread_id) * VEC_SIZE < d) {
       input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
       weight_vec.load(weight + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
diff --git a/include/flashinfer/utils.cuh b/include/flashinfer/utils.cuh
index a7a8dfd3..ce7a6015 100644
--- a/include/flashinfer/utils.cuh
+++ b/include/flashinfer/utils.cuh
@@ -17,6 +17,9 @@
 #define FLASHINFER_UTILS_CUH_
 #include
 #include
+#include
+#include
+#include
 
 #include
 #include
diff --git a/include/flashinfer/vec_dtypes.cuh b/include/flashinfer/vec_dtypes.cuh
index a40b4575..87f08581 100644
--- a/include/flashinfer/vec_dtypes.cuh
+++ b/include/flashinfer/vec_dtypes.cuh
@@ -31,6 +31,35 @@ namespace flashinfer {
 
 #define FLASHINFER_INLINE inline __attribute__((always_inline)) __device__
 
+#if (__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 < 120400) && \
+    (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
+// CUDA version < 12.4 and GPU architecture < 80
+FLASHINFER_INLINE __nv_bfloat162 make_bfloat162(const __nv_bfloat16 x, const __nv_bfloat16 y) {
+  __nv_bfloat162 t;
+  t.x = x;
+  t.y = y;
+  return t;
+}
+
+FLASHINFER_INLINE __nv_bfloat162 __floats2bfloat162_rn(const float a, const float b) {
+  __nv_bfloat162 val;
+  val = __nv_bfloat162(__float2bfloat16_rn(a), __float2bfloat16_rn(b));
+  return val;
+}
+
+FLASHINFER_INLINE __nv_bfloat162 __float22bfloat162_rn(const float2 a) {
+  __nv_bfloat162 val = __floats2bfloat162_rn(a.x, a.y);
+  return val;
+}
+FLASHINFER_INLINE float2 __bfloat1622float2(const __nv_bfloat162 a) {
+  float hi_float;
+  float lo_float;
+  lo_float = __internal_bfloat162float(((__nv_bfloat162_raw)a).x);
+  hi_float = __internal_bfloat162float(((__nv_bfloat162_raw)a).y);
+  return make_float2(lo_float, hi_float);
+}
+#endif
+
 /******************* vec_t type cast *******************/
 
 template