Introduce set_default_sdpa #79

Merged
cbalioglu merged 1 commit into main from sdpa on Oct 2, 2023


cbalioglu
Contributor

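This PR introduces the set_default_sdpa function and the sdpa context manager, which switch between attention implementations at runtime. set_default_sdpa changes the default implementation, while sdpa overrides it only within a with block.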

from fairseq2.nn.transformer import TorchSDPA, NaiveSDPA, set_default_sdpa, sdpa

set_default_sdpa(TorchSDPA)  # or None to use the library default

# Use naive SDPA for debugging (e.g. pdb)
with sdpa(NaiveSDPA):
    model = load_llama_model("llama_7b")
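
For readers unfamiliar with the pattern, here is a minimal, self-contained sketch of how a "default factory plus context manager" switch could work. It is not fairseq2's implementation: every name in it (NaiveAttention, FusedAttention, set_default_attention, create_default_attention, attention) is a hypothetical stand-in. It also assumes that the active factory is consulted when a module is constructed, which is what building the model inside the with block above suggests.

```python
from contextlib import contextmanager
from typing import Callable, Iterator, Optional


class NaiveAttention:
    """Hypothetical stand-in for a reference (easy to debug) implementation."""


class FusedAttention:
    """Hypothetical stand-in for an optimized implementation."""


# A factory is any zero-argument callable that returns an attention module.
AttentionFactory = Callable[[], object]

# Assumption: the library ships with the optimized implementation as default.
_LIBRARY_DEFAULT: AttentionFactory = FusedAttention
_current_factory: AttentionFactory = _LIBRARY_DEFAULT


def set_default_attention(factory: Optional[AttentionFactory]) -> None:
    """Set the global factory; passing None restores the library default."""
    global _current_factory
    _current_factory = _LIBRARY_DEFAULT if factory is None else factory


def create_default_attention() -> object:
    """Build an attention module with whichever factory is currently active."""
    return _current_factory()


@contextmanager
def attention(factory: Optional[AttentionFactory]) -> Iterator[None]:
    """Temporarily switch the factory and restore the previous one on exit."""
    previous = _current_factory
    set_default_attention(factory)
    try:
        yield
    finally:
        # Runs even if the body raises, so the override never leaks out.
        set_default_attention(previous)


# Usage: modules built inside the block use the override.
with attention(NaiveAttention):
    inside = create_default_attention()
outside = create_default_attention()
print(type(inside).__name__, type(outside).__name__)
```

Restoring the previous factory in a finally clause means the override is undone even when the body raises, and nested with blocks unwind in the expected order.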

@facebook-github-bot added the CLA Signed label Oct 2, 2023
@cbalioglu merged commit 4f5f8b2 into main Oct 2, 2023
@cbalioglu deleted the sdpa branch October 2, 2023 21:06