Skip to content

Commit

Permalink
Update fused_rotary_positional_embedding.h
Browse files Browse the repository at this point in the history
Ports the fix from NVIDIA/apex#1750 into this branch.
  • Loading branch information
StellaAthena authored Nov 25, 2023
1 parent e63242d commit aaa000c
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions megatron/fused_kernels/fused_rotary_positional_embedding.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,7 @@ __global__ void fused_rope_forward(int sq, int b, int np, int hn, int hn2,
int offset_head = offset_block + head_id * hn;
#pragma unroll
for (int hn_id = hn2 + threadIdx.x; hn_id < hn; hn_id += blockDim.x) {
int offset_src_dst = offset_head + hn_id;
dst[offset_src_dst] = src[offset_src_dst];
dst[offset_head + hn_id] = src[offset_head + hn_id];
}
}
}
Expand Down Expand Up @@ -89,7 +88,7 @@ __global__ void fused_rope_backward(int sq, int b, int np, int hn, int hn2,
int offset_head = offset_block + head_id * hn;
#pragma unroll
for (int hn_id = hn2 + threadIdx.x; hn_id < hn; hn_id += blockDim.x) {
dst[offset_head + hn_id] = 1.0;
dst[offset_head + hn_id] = src[offset_head + hn_id];
}
}
}
Expand Down

0 comments on commit aaa000c

Please sign in to comment.