diff --git a/megatron/fused_kernels/fused_rotary_positional_embedding.h b/megatron/fused_kernels/fused_rotary_positional_embedding.h index 7ac13932d..28dca70a5 100644 --- a/megatron/fused_kernels/fused_rotary_positional_embedding.h +++ b/megatron/fused_kernels/fused_rotary_positional_embedding.h @@ -52,8 +52,7 @@ __global__ void fused_rope_forward(int sq, int b, int np, int hn, int hn2, int offset_head = offset_block + head_id * hn; #pragma unroll for (int hn_id = hn2 + threadIdx.x; hn_id < hn; hn_id += blockDim.x) { - int offset_src_dst = offset_head + hn_id; - dst[offset_src_dst] = src[offset_src_dst]; + dst[offset_head + hn_id] = src[offset_head + hn_id]; } } } @@ -89,7 +88,7 @@ __global__ void fused_rope_backward(int sq, int b, int np, int hn, int hn2, int offset_head = offset_block + head_id * hn; #pragma unroll for (int hn_id = hn2 + threadIdx.x; hn_id < hn; hn_id += blockDim.x) { - dst[offset_head + hn_id] = 1.0; + dst[offset_head + hn_id] = src[offset_head + hn_id]; } } }