Translate the repetition pattern to expand and reshape #3645

Open. 5 commits to merge into base branch `main`.

naoyam (Collaborator) commented Dec 25, 2024

This repetition pattern is common in RoPE (rotary position embedding), for example in the Hugging Face Llama model:

https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L136

```python
torch.cat((freqs, freqs), dim=-1)
```

In nvFuser, this pattern shows up as a PadOp on each input followed by a CatOp.
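As a rough illustration of the PadOp + CatOp form, the Python sketch below models 1-D tensors as lists. It assumes that each cat input is zero-padded to the full output extent along the cat dimension and that the CatOp then reads each output position from the input whose unpadded region covers it. `pad_1d` and `cat_via_pad` are hypothetical helpers, not nvFuser APIs.

```python
# Hypothetical model of cat as "pad every input, then select per position".
# Tensors are plain 1-D Python lists; these helpers are not nvFuser APIs.


def pad_1d(x, left, right, value=0.0):
    """Pad a 1-D list with `left` and `right` copies of `value`."""
    return [value] * left + list(x) + [value] * right


def cat_via_pad(inputs):
    """Concatenate 1-D lists by padding each one to the output extent,
    then taking each output position from the input that owns it."""
    total = sum(len(x) for x in inputs)
    padded = []
    regions = []  # (start, end) of each input's unpadded region
    offset = 0
    for x in inputs:
        # PadOp: shift x to its slot in the output, zeros elsewhere.
        padded.append(pad_1d(x, offset, total - offset - len(x)))
        regions.append((offset, offset + len(x)))
        offset += len(x)
    out = []
    for j in range(total):
        # CatOp: pick the padded input whose real data covers index j.
        for p, (start, end) in zip(padded, regions):
            if start <= j < end:
                out.append(p[j])
                break
    return out


t0 = [1.0, 2.0, 3.0]
print(cat_via_pad([t0, t0]))  # [1.0, 2.0, 3.0, 1.0, 2.0, 3.0]
```

When both inputs are the same tensor, every padded copy reads from `t0`, which is the redundancy that the expand-and-reshape translation avoids.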

This PR adds a presegmentation (preseg) pass that translates this pattern into expand and reshape ops. For example, given a pattern like:

```
t0 = [i0];
t1 = cat({t0, t0}, -1);
```

It will be translated to:

```
t0 = [i0];
t2 = broadcast(t0, {true, false});
t3 = expand(t2, {2, i0});
t4 = reshape(t3, {2 * i0});
```

And all uses of t1 will be replaced by t4.
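The rewrite can be sanity-checked with a small Python sketch that models tensors as nested lists in row-major order. It mirrors the `t0` to `t4` example for the 1-D case and adds a 2-D case concatenated along the last dimension, where the size-1 broadcast dimension is assumed to go immediately before the cat dimension. All helper names are hypothetical and are not nvFuser or PyTorch APIs.

```python
# Hypothetical sketch: tensors are nested Python lists in row-major order.
# None of these helpers are nvFuser or PyTorch APIs.


def translate_1d(t0, k=2):
    """Repeat a 1-D tensor k times via broadcast -> expand -> reshape."""
    t2 = [t0]  # broadcast(t0, {true, false}): shape [1, i0]
    t3 = t2 * k  # expand(t2, {k, i0}): shape [k, i0]; rows alias t0, like a view
    t4 = [v for row in t3 for v in row]  # reshape(t3, {k * i0})
    return t4


def cat_repeat_last(x, k):
    """Reference: cat({x, ..., x}, -1) with k copies of a 2-D tensor."""
    return [row * k for row in x]


def broadcast_before_last(x):
    """[a, b] -> [a, 1, b]: insert a size-1 dim just before the cat dim."""
    return [[row] for row in x]


def expand_dim1(x, k):
    """[a, 1, b] -> [a, k, b]: repeat the broadcast dim k times."""
    return [[row[0]] * k for row in x]


def reshape_merge_last_two(x):
    """[a, k, b] -> [a, k * b]: row-major merge of the last two dims."""
    return [[v for inner in row for v in inner] for row in x]


# 1-D case: cat({t0, t0}, -1).
t0 = [1.0, 2.0, 3.0]
assert translate_1d(t0) == t0 + t0

# 2-D case: repeat along the last dim.
x = [[1, 2], [3, 4], [5, 6]]
y = reshape_merge_last_two(expand_dim1(broadcast_before_last(x), 2))
assert y == cat_repeat_last(x, 2)
print(y)  # [[1, 2, 1, 2], [3, 4, 3, 4], [5, 6, 5, 6]]
```

The position of the broadcast dimension matters: broadcasting after the cat dimension instead ([a, b] to [a, b, 1], expanded to [a, b, k]) and then merging would interleave the copies (`[1, 1, 2, 2]` rather than `[1, 2, 1, 2]`), which matches `repeat_interleave` rather than `cat`.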

The resize scheduler can already handle this pattern, but it is currently limited to segments that contain only pointwise ops, and its scheduling heuristics are not tuned yet. With this translation, I experimentally observed a significant performance gain in the Mistral backward function.

naoyam (Collaborator, Author) commented Dec 25, 2024

!test --diff

naoyam added the `rope` label Dec 25, 2024

naoyam marked this pull request as ready for review December 25, 2024 04:20
naoyam requested a review from jjsjann123 December 25, 2024 04:20
naoyam changed the title from "Translate the repetition pattern with expand and reshape" to "Translate the repetition pattern to expand and reshape" Dec 25, 2024