Skip to content

Commit

Permalink
Update fused_rotary_positional_embedding.h
Browse files Browse the repository at this point in the history
  • Loading branch information
StellaAthena authored Nov 27, 2023
1 parent d86c399 commit 417f55c
Showing 1 changed file with 68 additions and 47 deletions.
115 changes: 68 additions & 47 deletions megatron/fused_kernels/fused_rotary_positional_embedding.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,70 +25,83 @@
namespace {

template <typename scalar_t>
__global__ void fused_rope_forward(int sq, int b, int np, int hn, int hn2,
__global__ void fused_rope_forward(int h, int d, int d2, int stride_s,
int stride_b, int stride_h, int stride_d,
int o_stride_s, int o_stride_b,
int o_stride_h, int o_stride_d,
const scalar_t* src, const scalar_t* cos,
const scalar_t* sin, scalar_t* dst) {
int sq_id = blockIdx.x, b_id = blockIdx.y;
int offset_block = sq_id * b * np * hn + b_id * np * hn;
int s_id = blockIdx.x, b_id = blockIdx.y;
int offset_block = s_id * stride_s + b_id * stride_b;
int offset_block_dst = s_id * o_stride_s + b_id * o_stride_b;
#pragma unroll
for (int hn_id = threadIdx.x; hn_id < hn2; hn_id += blockDim.x) {
scalar_t v_cos = cos[sq_id * hn2 + hn_id];
scalar_t v_sin = sin[sq_id * hn2 + hn_id];
for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) {
scalar_t v_cos = cos[s_id * d2 + d_id];
scalar_t v_sin = sin[s_id * d2 + d_id];
#pragma unroll
for (int head_id = threadIdx.y; head_id < np; head_id += blockDim.y) {
int offset_src_dst = offset_block + head_id * hn + hn_id;
scalar_t v_src = src[offset_src_dst];
scalar_t v_src_rotate = (hn_id + hn2 / 2 < hn2)
? -src[offset_src_dst + hn2 / 2]
: src[offset_src_dst + hn2 / 2 - hn2];
dst[offset_src_dst] = v_src * v_cos + v_src_rotate * v_sin;
for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
int offset_src = offset_block + h_id * stride_h + d_id * stride_d;
int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d;
scalar_t v_src = src[offset_src];
scalar_t v_src_rotate = (d_id + d2 / 2 < d2)
? -src[offset_src + (d2 / 2) * stride_d]
: src[offset_src + (d2 / 2 - d2) * stride_d];
dst[offset_dst] = v_src * v_cos + v_src_rotate * v_sin;
}
}

// copy the rest
if (hn > hn2) {
if (d > d2) {
#pragma unroll
for (int head_id = threadIdx.y; head_id < np; head_id += blockDim.y) {
int offset_head = offset_block + head_id * hn;
for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
int offset_head = offset_block + h_id * stride_h;
int offset_head_dst = offset_block_dst + h_id * o_stride_h;
#pragma unroll
for (int hn_id = hn2 + threadIdx.x; hn_id < hn; hn_id += blockDim.x) {
dst[offset_head + hn_id] = src[offset_head + hn_id];
for (int d_id = d2 + threadIdx.x; d_id < d; d_id += blockDim.x) {
dst[offset_head_dst + d_id * o_stride_d] =
src[offset_head + d_id * stride_d];
}
}
}
}

template <typename scalar_t>
__global__ void fused_rope_backward(int sq, int b, int np, int hn, int hn2,
__global__ void fused_rope_backward(int h, int d, int d2, int stride_s,
int stride_b, int stride_h, int stride_d,
int o_stride_s, int o_stride_b,
int o_stride_h, int o_stride_d,
const scalar_t* src, const scalar_t* cos,
const scalar_t* sin, scalar_t* dst) {
int sq_id = blockIdx.x, b_id = blockIdx.y;
int offset_block = sq_id * b * np * hn + b_id * np * hn;
int s_id = blockIdx.x, b_id = blockIdx.y;
int offset_block = s_id * stride_s + b_id * stride_b;
int offset_block_dst = s_id * o_stride_s + b_id * o_stride_b;
#pragma unroll
for (int hn_id = threadIdx.x; hn_id < hn2; hn_id += blockDim.x) {
scalar_t v_cos = cos[sq_id * hn2 + hn_id];
scalar_t v_sin = (hn_id + hn2 / 2 < hn2)
? sin[sq_id * hn2 + hn_id + hn2 / 2]
: -sin[sq_id * hn2 + hn_id + hn2 / 2 - hn2];
for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) {
scalar_t v_cos = cos[s_id * d2 + d_id];
scalar_t v_sin = (d_id + d2 / 2 < d2)
? sin[s_id * d2 + d_id + d2 / 2]
: -sin[s_id * d2 + d_id + d2 / 2 - d2];
#pragma unroll
for (int head_id = threadIdx.y; head_id < np; head_id += blockDim.y) {
int offset_src_dst = offset_block + head_id * hn + hn_id;
scalar_t v_src = src[offset_src_dst];
scalar_t v_src_rotate = (hn_id + hn2 / 2 < hn2)
? src[offset_src_dst + hn2 / 2]
: src[offset_src_dst + hn2 / 2 - hn2];
dst[offset_src_dst] = v_src * v_cos + v_src_rotate * v_sin;
for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
int offset_src = offset_block + h_id * stride_h + d_id * stride_d;
int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d;
scalar_t v_src = src[offset_src];
scalar_t v_src_rotate = (d_id + d2 / 2 < d2)
? src[offset_src + (d2 / 2) * stride_d]
: src[offset_src + (d2 / 2 - d2) * stride_d];
dst[offset_dst] = v_src * v_cos + v_src_rotate * v_sin;
}
}

// handle the tail
if (hn > hn2) {
if (d > d2) {
#pragma unroll
for (int head_id = threadIdx.y; head_id < np; head_id += blockDim.y) {
int offset_head = offset_block + head_id * hn;
for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
int offset_head = offset_block + h_id * stride_h;
int offset_head_dst = offset_block_dst + h_id * o_stride_h;
#pragma unroll
for (int hn_id = hn2 + threadIdx.x; hn_id < hn; hn_id += blockDim.x) {
dst[offset_head + hn_id] = src[offset_head + hn_id];
for (int d_id = d2 + threadIdx.x; d_id < d; d_id += blockDim.x) {
dst[offset_head_dst + d_id * o_stride_d] = src[offset_head + d_id * stride_d];
}
}
}
Expand All @@ -97,32 +110,40 @@ __global__ void fused_rope_backward(int sq, int b, int np, int hn, int hn2,
} // end of anonymous namespace

template <typename scalar_t>
void dispatch_fused_rope_forward(int sq, int b, int np, int hn, int hn2,
void dispatch_fused_rope_forward(int s, int b, int h, int d, int d2,
int stride_s, int stride_b, int stride_h,
int stride_d, int o_stride_s, int o_stride_b,
int o_stride_h, int o_stride_d,
const scalar_t* input, const scalar_t* cos,
const scalar_t* sin, scalar_t* output) {
auto stream = at::cuda::getCurrentCUDAStream();

int warps_per_block = np < 16 ? 4 : 8;
dim3 blocks(sq, b);
int warps_per_block = h < 16 ? 4 : 8;
dim3 blocks(s, b);
dim3 threads(C10_WARP_SIZE, warps_per_block);

fused_rope_forward<<<blocks, threads, 0, stream>>>(sq, b, np, hn, hn2, input,
cos, sin, output);
fused_rope_forward<<<blocks, threads, 0, stream>>>(
h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b,
o_stride_h, o_stride_d, input, cos, sin, output);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}

template <typename scalar_t>
void dispatch_fused_rope_backward(int sq, int b, int np, int hn, int hn2,
void dispatch_fused_rope_backward(int s, int b, int h, int d, int d2,
int stride_s, int stride_b, int stride_h,
int stride_d, int o_stride_s, int o_stride_b,
int o_stride_h, int o_stride_d,
const scalar_t* output_grads,
const scalar_t* cos, const scalar_t* sin,
scalar_t* input_grads) {
auto stream = at::cuda::getCurrentCUDAStream();

int warps_per_block = np < 16 ? 4 : 8;
dim3 blocks(sq, b);
int warps_per_block = h < 16 ? 4 : 8;
dim3 blocks(s, b);
dim3 threads(C10_WARP_SIZE, warps_per_block);

fused_rope_backward<<<blocks, threads, 0, stream>>>(
sq, b, np, hn, hn2, output_grads, cos, sin, input_grads);
h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b,
o_stride_h, o_stride_d, output_grads, cos, sin, input_grads);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}

0 comments on commit 417f55c

Please sign in to comment.