From aaa000cbc337abcee6a8049141b4486c15faef0d Mon Sep 17 00:00:00 2001 From: Stella Biderman Date: Sat, 25 Nov 2023 11:08:00 -0500 Subject: [PATCH] Update fused_rotary_positional_embedding.h Ports the fix from https://github.com/NVIDIA/apex/pull/1750 into this branch. --- megatron/fused_kernels/fused_rotary_positional_embedding.h | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/megatron/fused_kernels/fused_rotary_positional_embedding.h b/megatron/fused_kernels/fused_rotary_positional_embedding.h index 7ac13932d..28dca70a5 100644 --- a/megatron/fused_kernels/fused_rotary_positional_embedding.h +++ b/megatron/fused_kernels/fused_rotary_positional_embedding.h @@ -52,8 +52,7 @@ __global__ void fused_rope_forward(int sq, int b, int np, int hn, int hn2, int offset_head = offset_block + head_id * hn; #pragma unroll for (int hn_id = hn2 + threadIdx.x; hn_id < hn; hn_id += blockDim.x) { - int offset_src_dst = offset_head + hn_id; - dst[offset_src_dst] = src[offset_src_dst]; + dst[offset_head + hn_id] = src[offset_head + hn_id]; } } } @@ -89,7 +88,7 @@ __global__ void fused_rope_backward(int sq, int b, int np, int hn, int hn2, int offset_head = offset_block + head_id * hn; #pragma unroll for (int hn_id = hn2 + threadIdx.x; hn_id < hn; hn_id += blockDim.x) { - dst[offset_head + hn_id] = 1.0; + dst[offset_head + hn_id] = src[offset_head + hn_id]; } } }