Fix Rope Compatibility with Cos/Sin Position Embedding for Batch Size > 1 (#477)

## Summary
Fix Rope Compatibility with Cos/Sin Position Embedding for Batch Size > 1

This PR fixes a compatibility issue in the RoPE implementation when cosine/sine position embeddings are used with a batch size greater than 1.
In the default behavior of transformers, `position_ids` is set to `None` during training, which results in the following computation:
```python
cache_position = torch.arange(seq_len)
position_ids = cache_position.unsqueeze(0)
```
This yields position embeddings of shape (1, seq_len, head_dim), which is consistent with the Liger implementation.
However, if `position_ids` are pre-computed for any reason (in my experiment, I implement M-RoPE in a different way, which leaves `position_ids` pre-computed), the current implementation fails to handle this scenario correctly. This PR introduces a fix so that the RoPE implementation can accommodate pre-computed `position_ids`.
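For illustration, here is a minimal sketch of the two scenarios (not part of this PR; sizes are arbitrary, and it assumes a transformers version where `LlamaRotaryEmbedding(head_dim)` is a valid constructor, as in this repo's tests):
```python
import torch
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding

bsz, num_kv_heads, seq_len, head_dim = 4, 2, 128, 64
rotary_emb = LlamaRotaryEmbedding(head_dim)
k = torch.randn(bsz, num_kv_heads, seq_len, head_dim)

# Default path: position_ids derived from cache_position, batch dim of 1
pos_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0)
cos, sin = rotary_emb(k, pos_ids)
print(cos.shape)  # torch.Size([1, 128, 64]) -> (1, seq_len, head_dim)

# Pre-computed path: position_ids already carry the full batch dimension
pos_ids = pos_ids.expand(bsz, -1)
cos, sin = rotary_emb(k, pos_ids)
print(cos.shape)  # torch.Size([4, 128, 64]) -> (bsz, seq_len, head_dim)
```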
In the unit test `test_rope.py`, I have added an `expand_position_ids` parameter to simulate this condition. The previous implementation fails under this scenario, while the new patch resolves the issue.

pytest details:

![image](https://github.com/user-attachments/assets/cf75debe-1048-4481-a909-1c846be760ed)

## Testing Done

- Hardware Type: <BLANK>
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence

---------

Co-authored-by: Shao Tang <[email protected]>
Co-authored-by: ByronHsu <[email protected]>
3 people authored Dec 23, 2024
1 parent c899cc7 commit d7c78df
Showing 3 changed files with 55 additions and 13 deletions.
32 changes: 23 additions & 9 deletions src/liger_kernel/ops/rope.py
@@ -15,6 +15,7 @@ def _triton_rope(
sin_row_stride,
sl,
bs: tl.constexpr,
cos_bs: tl.constexpr,
n_qh: tl.constexpr,
n_kh: tl.constexpr,
hd: tl.constexpr,
@@ -29,7 +30,7 @@ def _triton_rope(
# k size: (bsz, seq_len, num_kv_heads, head_dim)
# k stride: (seq_len * num_kv_heads * head_dim, num_kv_heads * head_dim, head_dim, 1)

# cos size: (1, seq_len, head_dim)
# cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
# stride: (seq_len * head_dim, head_dim, 1)
pid = tl.program_id(0)

@@ -48,9 +49,19 @@
# and pid % sl to get the sequence index.
# 2. We only need the left half of cos and sin matrix because the right half is just
# a clone of the left half.
cos_row_idx = pid % (sl)
cos = cos + cos_row_idx * cos_row_stride
sin = sin + cos_row_idx * sin_row_stride
batch_idx = pid // sl
cos_row_idx = pid % sl
cos = cos + tl.where(
cos_bs == 1,
cos_row_idx * cos_row_stride,
batch_idx * (sl * cos_row_stride) + cos_row_idx * cos_row_stride,
)
sin = sin + tl.where(
cos_bs == 1,
cos_row_idx * sin_row_stride,
batch_idx * (sl * sin_row_stride) + cos_row_idx * sin_row_stride,
)

cos_offsets = tl.arange(0, pad_hd // 2)
cos_mask = cos_offsets < hd // 2
cos_row = tl.load(cos + cos_offsets, mask=cos_mask, other=0)
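For reference (not part of the commit), the pointer arithmetic in the hunk above is equivalent to the following plain-PyTorch indexing, assuming a contiguous `cos` of shape `(cos_bs, seq_len, head_dim)` where `cos_bs` is either 1 or `bsz`:
```python
import torch

def select_cos_row(cos: torch.Tensor, pid: int, sl: int) -> torch.Tensor:
    # cos: (cos_bs, seq_len, head_dim), contiguous; cos_bs is 1 or bsz
    cos_bs = cos.shape[0]
    batch_idx = pid // sl    # which sequence in the batch this program handles
    cos_row_idx = pid % sl   # which position within that sequence
    if cos_bs == 1:
        # broadcast case: every sequence in the batch shares the same row
        return cos[0, cos_row_idx]
    # batched case: skip batch_idx full (seq_len, head_dim) blocks, then index the row
    return cos[batch_idx, cos_row_idx]
```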
@@ -118,7 +129,6 @@ def _triton_rope(


def rope_forward(q, k, cos, sin):

# transpose it back to the physical shape because Triton looks at the physical storage
# note: q and k are incontiguous before the transformation and will become contiguous after transpose
q = q.transpose(1, 2)
@@ -138,6 +148,7 @@ def rope_forward(q, k, cos, sin):
k = k.contiguous()
cos = cos.contiguous()
sin = sin.contiguous()
cos_batch_size = cos.shape[0]

_triton_rope[(n_row,)](
q,
@@ -150,6 +161,7 @@ def rope_forward(q, k, cos, sin):
sin.stride(-2),
seq_len,
batch_size,
cos_batch_size,
n_q_head,
n_kv_head,
head_dim,
@@ -167,6 +179,7 @@ def rope_backward(dq, dk, cos, sin):
dk = dk.transpose(1, 2)

batch_size, seq_len, n_q_head, head_dim = dq.shape
cos_batch_size = cos.shape[0]
n_kv_head = dk.shape[2]
pad_hd = triton.next_power_of_2(head_dim)
pad_n_q_head = triton.next_power_of_2(n_q_head)
@@ -191,6 +204,7 @@ def rope_backward(dq, dk, cos, sin):
sin.stride(-2),
seq_len,
batch_size,
cos_batch_size,
n_q_head,
n_kv_head,
head_dim,
@@ -221,8 +235,8 @@ def forward(ctx, q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""
q size: (bsz, n_q_head, seq_len, head_dim)
k size: (bsz, n_kv_head, seq_len, head_dim)
cos size: (1, seq_len, head_dim)
sin size: (1, seq_len, head_dim)
cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
sin size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
"""
q, k, cos, sin = rope_forward(q, k, cos, sin)
ctx.save_for_backward(cos, sin)
@@ -232,8 +246,8 @@ def backward(ctx, dq, dk):
"""
dq size: (bsz, n_q_head, seq_len, head_dim)
dk size: (bsz, n_kv_head, seq_len, head_dim)
cos size: (1, seq_len, head_dim)
sin size: (1, seq_len, head_dim)
cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
sin size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
"""

cos, sin = ctx.saved_tensors
4 changes: 2 additions & 2 deletions src/liger_kernel/transformers/rope.py
@@ -8,8 +8,8 @@ def liger_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
Args:
q (torch.Tensor): The query tensor of shape (bsz, n_q_head, seq_len, head_dim).
k (torch.Tensor): The key tensor of shape (bsz, n_kv_head, seq_len, head_dim).
cos (torch.Tensor): The cosine tensor of shape (1, seq_len, head_dim).
sin (torch.Tensor): The sine tensor of shape (1, seq_len, head_dim).
cos (torch.Tensor): The cosine tensor of shape (1, seq_len, head_dim) or (bsz, seq_len, head_dim).
sin (torch.Tensor): The sine tensor of shape (1, seq_len, head_dim) or (bsz, seq_len, head_dim).
position_ids (torch.Tensor, optional): The position ids tensor. Defaults to None.
unsqueeze_dim (int, optional): The dimension to unsqueeze. Defaults to 1.
32 changes: 30 additions & 2 deletions test/transformers/test_rope.py
@@ -46,8 +46,20 @@
),
],
)
@pytest.mark.parametrize(
"expand_position_ids",
[True, False],
)
def test_correctness(
bsz, seq_len, num_q_heads, num_kv_heads, head_dim, dtype, atol, rtol
bsz,
seq_len,
num_q_heads,
num_kv_heads,
head_dim,
dtype,
expand_position_ids,
atol,
rtol,
):
rotary_emb = LlamaRotaryEmbedding(head_dim, device=device)

@@ -70,6 +82,8 @@ def test_correctness(
k2 = _tensor_k.clone().requires_grad_(True)

pos_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0)
if expand_position_ids:
pos_ids = pos_ids.expand(bsz, -1)
cos, sin = rotary_emb(k1, pos_ids)

# validate forward pass
@@ -111,8 +125,20 @@ def test_correctness(
(torch.bfloat16, 1e-1, 1e-5),
],
)
@pytest.mark.parametrize(
"expand_position_ids",
[True, False],
)
def test_functional_correctness(
bsz, seq_len, num_q_heads, num_kv_heads, head_dim, dtype, atol, rtol
bsz,
seq_len,
num_q_heads,
num_kv_heads,
head_dim,
expand_position_ids,
dtype,
atol,
rtol,
):
_q = torch.randn((bsz, num_q_heads, seq_len, head_dim), device=device, dtype=dtype)
_k = torch.randn((bsz, num_kv_heads, seq_len, head_dim), device=device, dtype=dtype)
@@ -126,6 +152,8 @@ def test_functional_correctness(
rotary_emb = LlamaRotaryEmbedding(head_dim, device=device)

pos_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0)
if expand_position_ids:
pos_ids = pos_ids.expand(bsz, -1)
cos, sin = rotary_emb(k1, pos_ids)

functional_q, functional_k = liger_rope(q=q1, k=k1, cos=cos, sin=sin)
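After this change, batched cos/sin can be passed directly to the Liger RoPE API. A usage sketch (illustrative names and sizes, assuming a CUDA device and the same `LlamaRotaryEmbedding` setup as the tests):
```python
import torch
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding

from liger_kernel.transformers.rope import liger_rotary_pos_emb

bsz, n_q_head, n_kv_head, seq_len, head_dim = 2, 8, 2, 128, 64
device, dtype = "cuda", torch.bfloat16

q = torch.randn(bsz, n_q_head, seq_len, head_dim, device=device, dtype=dtype)
k = torch.randn(bsz, n_kv_head, seq_len, head_dim, device=device, dtype=dtype)

# Pre-computed, batch-expanded position ids -> cos/sin of shape (bsz, seq_len, head_dim)
pos_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0).expand(bsz, -1)
rotary_emb = LlamaRotaryEmbedding(head_dim, device=device)
cos, sin = rotary_emb(k, pos_ids)

q_rot, k_rot = liger_rotary_pos_emb(q, k, cos, sin)
```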
