Fix RoPE compatibility with cos/sin position embeddings for batch size > 1 (#477)

## Summary

This PR fixes a compatibility issue in the RoPE implementation when cosine/sine position embeddings are used with a batch size greater than 1.

By default, transformers sets `position_ids` to `None` during training, which results in the following computation:

```python
cache_position = torch.arange(seq_len)
position_ids = cache_position.unsqueeze(0)
```

The position embeddings therefore have shape `(1, seq_len, head_dim)`, which matches Liger's implementation. However, if `position_ids` are pre-computed for any reason (in my experiment, I implement m-RoPE in a different way that requires pre-computed `position_ids`), the current implementation fails to handle this case (see the shape sketch below). This PR fixes the RoPE implementation so that it also accommodates pre-computed `position_ids`.

In the unit test `test_rope.py`, I added a variable `expand_position_ids` to simulate this condition. The previous implementation fails under this scenario, while the new patch resolves the issue.

pytest details:
![image](https://github.com/user-attachments/assets/cf75debe-1048-4481-a909-1c846be760ed)

## Testing Done

- Hardware Type: <BLANK>
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence

---------

Co-authored-by: Shao Tang <[email protected]>
Co-authored-by: ByronHsu <[email protected]>
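To make the failure mode concrete, here is a minimal sketch of how the cos/sin shapes differ between the two paths. This is not Liger's Triton kernel code; the sizes, the base of `10000.0`, and the names `default_ids`/`expanded_ids` are illustrative assumptions.

```python
import torch

batch, seq_len, head_dim = 4, 16, 64

# Standard RoPE inverse frequencies; base 10000.0 is the common default,
# used here only for illustration.
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2).float() / head_dim))

# Default transformers path: position_ids is None, so it is built with a
# leading batch dim of 1 -> cos/sin end up with shape (1, seq_len, head_dim).
default_ids = torch.arange(seq_len).unsqueeze(0)        # (1, seq_len)

# Pre-computed path (e.g. a custom m-RoPE): one row of position ids per
# sequence -> cos/sin end up with shape (batch, seq_len, head_dim).
expanded_ids = torch.arange(seq_len).expand(batch, -1)  # (batch, seq_len)

for position_ids in (default_ids, expanded_ids):
    freqs = position_ids[..., None].float() * inv_freq  # (bsz, seq_len, head_dim // 2)
    emb = torch.cat((freqs, freqs), dim=-1)             # (bsz, seq_len, head_dim)
    cos, sin = emb.cos(), emb.sin()
    print(cos.shape)
# torch.Size([1, 16, 64])   <- assumed by the old kernel
# torch.Size([4, 16, 64])   <- pre-computed position_ids; handled after this fix
```

A RoPE kernel that hard-codes the leading dimension of 1 works only for the first case; the fix makes the kernel accept both shapes.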
1 parent c899cc7 · commit d7c78df
Showing 3 changed files with 55 additions and 13 deletions.