fix rmsnorm template function invocation problem (template function partial specialization is not allowed in C++) and luckily pass e2e precision test
SunflowerAries committed Mar 13, 2024
1 parent 6fd355a commit c14eede
Showing 2 changed files with 79 additions and 35 deletions.
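
The root cause: C++ permits partial specialization of class templates but not of function templates, so the kernels' float-only variants could not be expressed as specializations of the half/bf16 templates. This commit instead gives the float path its own general_* kernel names and routes to them with a new dispatch macro keyed on the element size. A minimal standalone sketch of the language rule and the rename workaround (illustrative names, not repository code):

#include <cstdio>

template <typename T, int N>
void scale(T* p) {  // primary function template
  for (int i = 0; i < N; ++i) p[i] *= T(2);
}

// Ill-formed: a function template cannot be partially specialized.
// template <int N>
// void scale<float, N>(float* p) { /* ... */ }

// Legal workaround (the approach this commit takes): a separately named
// template for the float path, selected at the call site.
template <typename T, int N>
void general_scale(T* p) {
  for (int i = 0; i < N; ++i) p[i] *= T(2);
}

int main() {
  float x[4] = {1.f, 2.f, 3.f, 4.f};
  scale<float, 4>(x);          // generic template
  general_scale<float, 4>(x);  // renamed float-path template
  printf("%f\n", x[0]);        // 4.000000
  return 0;
}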
100 changes: 70 additions & 30 deletions extensions/csrc/cuda/rms_layernorm_kernel.cu
@@ -12,6 +12,34 @@
#include "../common/micros.h"
#include "../common/cuda_type_utils.h"

+#define DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(DATA_SIZE, TYPE, NAME, ...) \
+  if (DATA_SIZE == 2) { \
+    switch (TYPE) { \
+      case at::ScalarType::Half: { \
+        using scalar_t = at::Half; \
+        __VA_ARGS__; \
+        break; \
+      } \
+      case at::ScalarType::BFloat16: { \
+        using scalar_t = at::BFloat16; \
+        __VA_ARGS__; \
+        break; \
+      } \
+      default: \
+        AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
+    } \
+  } else { \
+    switch (TYPE) { \
+      case at::ScalarType::Float: { \
+        using scalar_t = float; \
+        general_##__VA_ARGS__; \
+        break; \
+      } \
+      default: \
+        AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
+    } \
+  }

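// (annotation, not part of the committed file) How the new macro works: for
// 2-byte dtypes (Half/BFloat16) the kernel launch passed in __VA_ARGS__ is
// emitted as written; for float, `general_##__VA_ARGS__` token-pastes onto
// the first token of that launch, so a call such as
//   DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
//       input.element_size(), input.scalar_type(), "rms_layernorm_kernel",
//       rms_layernorm_kernel<scalar_t, 8><<<grid, block, 0, stream>>>(/*args*/);)
// expands in its float branch to roughly
//   using scalar_t = float;
//   general_rms_layernorm_kernel<scalar_t, 8><<<grid, block, 0, stream>>>(/*args*/);
// This is why every call site must pass the kernel name as the first token of
// the variadic arguments.
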
// optimized for half and bf16
template<typename scalar_t, int unroll_factor>
__global__ void rms_layernorm_kernel(
@@ -63,11 +91,11 @@ __global__ void rms_layernorm_kernel(
}
}

-template<int unroll_factor>
-__global__ void rms_layernorm_kernel(
-float* __restrict__ out, // [..., hidden_size]
-const float* __restrict__ input, // [..., hidden_size]
-const float* __restrict__ weight, // [hidden_size]
+template<typename scalar_t, int unroll_factor>
+__global__ void general_rms_layernorm_kernel(
+scalar_t* __restrict__ out, // [..., hidden_size]
+const scalar_t* __restrict__ input, // [..., hidden_size]
+const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon,
const int num_tokens,
const int hidden_size) {
@@ -80,7 +108,7 @@ __global__ void rms_layernorm_kernel(
#pragma unroll unroll_factor
for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) {
int id = row_offset + idx;
-x_local[cnt] = input[id];
+x_local[cnt] = (float) input[id];
variance += x_local[cnt] * x_local[cnt];
}
variance = blockReduceSum<float>(variance);
@@ -92,7 +120,7 @@ __global__ void rms_layernorm_kernel(
#pragma unroll unroll_factor
for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) {
int id = row_offset + idx;
-out[id] = ((x_local[cnt] * s_variance)) * weight[idx];
+out[id] = ((scalar_t) (x_local[cnt] * s_variance)) * weight[idx];
}
}

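// (annotation, not part of the committed file) In the general_* kernels (the
// one above and the fused one below) the per-thread buffer x_local[] and the
// variance accumulator are kept in float regardless of scalar_t, with casts
// back to scalar_t on store. The dispatch macro only routes float tensors
// here, so the casts are no-ops at runtime; they keep the template
// well-formed for any scalar_t it might be instantiated with.
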
@@ -140,11 +168,11 @@ __global__ void fused_add_rms_layernorm_kernel(
}
}

-template<int unroll_factor>
-__global__ void fused_add_rms_layernorm_kernel(
-float* __restrict__ input, // [..., hidden_size]
-float* __restrict__ residual, // [..., hidden_size]
-const float* __restrict__ weight, // [hidden_size]
+template<typename scalar_t, int unroll_factor>
+__global__ void general_fused_add_rms_layernorm_kernel(
+scalar_t* __restrict__ input, // [..., hidden_size]
+scalar_t* __restrict__ residual, // [..., hidden_size]
+const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon,
const int num_tokens,
const int hidden_size) {
@@ -157,10 +185,10 @@ __global__ void fused_add_rms_layernorm_kernel(
#pragma unroll unroll_factor
for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) {
int id = row_offset + idx;
-x_local[cnt] = input[id];
-x_local[cnt] += residual[id];
+x_local[cnt] = (float) input[id];
+x_local[cnt] += (float) residual[id];
variance += x_local[cnt] * x_local[cnt];
-residual[id] = x_local[cnt];
+residual[id] = (scalar_t) x_local[cnt];
}
variance = blockReduceSum<float>(variance);
if (threadIdx.x == 0) {
@@ -171,7 +199,7 @@ __global__ void fused_add_rms_layernorm_kernel(
#pragma unroll unroll_factor
for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) {
int id = row_offset + idx;
-input[id] = ((x_local[cnt] * s_variance)) * weight[idx];
+input[id] = ((scalar_t) (x_local[cnt] * s_variance)) * weight[idx];
}
}

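// (annotation, not part of the committed file) Host-side dispatch below: for
// num_tokens >= 512 the kernels launch with hidden_size / 8 threads per block
// and a fixed unroll factor (8 for float, 4 for 16-bit dtypes, presumably
// because the optimized kernel consumes two 16-bit elements per thread per
// step); otherwise block.x and unroll_factor are computed per call, halving
// hidden_size for 16-bit dtypes for the same reason. The unroll_factor
// recomputation is where the shadowed-variable fix (see the note after this
// file's diff) lands.
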
@@ -190,7 +218,8 @@ void rms_layernorm(

if (num_tokens >= 512) {
if (input.scalar_type() == at::ScalarType::Float) {
-DISPATCH_FLOAT_HALF_AND_BFLOAT(
+DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
+input.element_size(),
input.scalar_type(),
"rms_layernorm_kernel",
rms_layernorm_kernel<scalar_t, 8><<<grid, hidden_size / 8, 0, stream>>>(
@@ -201,7 +230,8 @@ void rms_layernorm(
num_tokens,
hidden_size);)
} else {
-DISPATCH_FLOAT_HALF_AND_BFLOAT(
+DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
+input.element_size(),
input.scalar_type(),
"rms_layernorm_kernel",
rms_layernorm_kernel<scalar_t, 4><<<grid, hidden_size / 8, 0, stream>>>(
@@ -216,11 +246,12 @@ void rms_layernorm(
int unroll_factor = (hidden_size + block.x - 1) / block.x;
if (input.scalar_type() != at::ScalarType::Float) {
block.x = std::min(hidden_size / 2, 1024);
-int unroll_factor = (hidden_size / 2 + block.x - 1) / block.x;
+unroll_factor = (hidden_size / 2 + block.x - 1) / block.x;
}
switch (unroll_factor) {
case 1:
-DISPATCH_FLOAT_HALF_AND_BFLOAT(
+DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
+input.element_size(),
input.scalar_type(),
"rms_layernorm_kernel",
rms_layernorm_kernel<scalar_t, 1><<<grid, block, 0, stream>>>(
@@ -232,7 +263,8 @@ void rms_layernorm(
hidden_size);)
break;
case 2:
-DISPATCH_FLOAT_HALF_AND_BFLOAT(
+DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
+input.element_size(),
input.scalar_type(),
"rms_layernorm_kernel",
rms_layernorm_kernel<scalar_t, 2><<<grid, block, 0, stream>>>(
@@ -244,7 +276,8 @@ void rms_layernorm(
hidden_size);)
break;
case 4:
-DISPATCH_FLOAT_HALF_AND_BFLOAT(
+DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
+input.element_size(),
input.scalar_type(),
"rms_layernorm_kernel",
rms_layernorm_kernel<scalar_t, 4><<<grid, block, 0, stream>>>(
@@ -256,7 +289,8 @@ void rms_layernorm(
hidden_size);)
break;
case 8:
-DISPATCH_FLOAT_HALF_AND_BFLOAT(
+DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
+input.element_size(),
input.scalar_type(),
"rms_layernorm_kernel",
rms_layernorm_kernel<scalar_t, 8><<<grid, block, 0, stream>>>(
@@ -288,7 +322,8 @@ void fused_add_rms_layernorm(

if (num_tokens >= 512) {
if (input.scalar_type() == at::ScalarType::Float) {
-DISPATCH_FLOAT_HALF_AND_BFLOAT(
+DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
+input.element_size(),
input.scalar_type(),
"fused_add_rms_layernorm_kernel",
fused_add_rms_layernorm_kernel<scalar_t, 8><<<grid, hidden_size / 8, 0, stream>>>(
@@ -299,7 +334,8 @@ void fused_add_rms_layernorm(
num_tokens,
hidden_size);)
} else {
-DISPATCH_FLOAT_HALF_AND_BFLOAT(
+DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
+input.element_size(),
input.scalar_type(),
"fused_add_rms_layernorm_kernel",
fused_add_rms_layernorm_kernel<scalar_t, 4><<<grid, hidden_size / 8, 0, stream>>>(
@@ -314,11 +350,12 @@ void fused_add_rms_layernorm(
int unroll_factor = (hidden_size + block.x - 1) / block.x;
if (input.scalar_type() != at::ScalarType::Float) {
block.x = std::min(hidden_size / 2, 1024);
-int unroll_factor = (hidden_size / 2 + block.x - 1) / block.x;
+unroll_factor = (hidden_size / 2 + block.x - 1) / block.x;
}
switch (unroll_factor) {
case 1:
-DISPATCH_FLOAT_HALF_AND_BFLOAT(
+DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
+input.element_size(),
input.scalar_type(),
"fused_add_rms_layernorm_kernel",
fused_add_rms_layernorm_kernel<scalar_t, 1><<<grid, block, 0, stream>>>(
@@ -330,7 +367,8 @@ void fused_add_rms_layernorm(
hidden_size);)
break;
case 2:
-DISPATCH_FLOAT_HALF_AND_BFLOAT(
+DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
+input.element_size(),
input.scalar_type(),
"fused_add_rms_layernorm_kernel",
fused_add_rms_layernorm_kernel<scalar_t, 2><<<grid, block, 0, stream>>>(
@@ -342,7 +380,8 @@ void fused_add_rms_layernorm(
hidden_size);)
break;
case 4:
-DISPATCH_FLOAT_HALF_AND_BFLOAT(
+DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
+input.element_size(),
input.scalar_type(),
"fused_add_rms_layernorm_kernel",
fused_add_rms_layernorm_kernel<scalar_t, 4><<<grid, block, 0, stream>>>(
@@ -354,7 +393,8 @@ void fused_add_rms_layernorm(
hidden_size);)
break;
case 8:
-DISPATCH_FLOAT_HALF_AND_BFLOAT(
+DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
+input.element_size(),
input.scalar_type(),
"fused_add_rms_layernorm_kernel",
fused_add_rms_layernorm_kernel<scalar_t, 8><<<grid, block, 0, stream>>>(
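Besides the rename, the commit fixes a variable-shadowing bug in both host functions: the old code re-declared `int unroll_factor` inside the 16-bit branch, so the recomputed value went out of scope at the closing brace and the following `switch` dispatched on the stale value. A reduced, self-contained illustration (hypothetical sizes, not repository code):

#include <algorithm>
#include <cstdio>

int main() {
  const int hidden_size = 4096;
  int block_x = std::min(hidden_size, 1024);
  int unroll_factor = (hidden_size + block_x - 1) / block_x;    // 4
  const bool is_16bit = true;                                   // half/bf16 path
  if (is_16bit) {
    block_x = std::min(hidden_size / 2, 1024);
    // Before the fix: `int unroll_factor = ...` declared a new variable that
    // shadowed the outer one; after the fix it is a plain assignment.
    unroll_factor = (hidden_size / 2 + block_x - 1) / block_x;  // 2
  }
  printf("unroll_factor = %d\n", unroll_factor);  // 2 (stale 4 with the bug)
  return 0;
}
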
14 changes: 9 additions & 5 deletions tests/test_infer/test_inference_engine.py
@@ -22,11 +22,15 @@ def setup_seed(seed):
def check_inference_engine(use_engine=False, prompt_template=None):
setup_seed(20)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
-    model = LlamaForCausalLM(
-        LlamaConfig(
-            vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=16
+    model = (
+        LlamaForCausalLM(
+            LlamaConfig(
+                vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=16
+            )
        )
-    ).cuda()
+        .cuda()
+        .half()
+    )
model = model.eval()

inputs = [
@@ -40,7 +44,7 @@ def check_inference_engine(use_engine=False, prompt_template=None):
top_k = 50

if use_engine:
-inference_config = InferenceConfig(max_output_len=output_len, prompt_template=prompt_template, dtype="fp32")
+inference_config = InferenceConfig(max_output_len=output_len, prompt_template=prompt_template)
inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
assert inference_engine.generation_config.max_new_tokens == output_len
inference_engine.add_request(prompts=inputs)
