Flash attention support. #20152

Open
wants to merge 2 commits into master

hazemessamm
Contributor

I added support for flash attention for PyTorch.

Let me know what you think of the current implementation so I can add support for JAX, and maybe try TF as well.
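For readers following along, here is a minimal sketch of how flash attention is usually reached from PyTorch. It is not the code in this PR: `fused_attention` and `naive_attention` are hypothetical helper names, and the `(batch, heads, seq_len, head_dim)` layout is an assumption. `torch.nn.functional.scaled_dot_product_attention` (PyTorch 2.0+) selects a fused kernel, FlashAttention among them, only when the device, dtype, and mask make it eligible; otherwise it falls back to a plain math path that should give the same result.

```python
import math

import torch
import torch.nn.functional as F


def naive_attention(query, key, value, is_causal=False):
    """Reference attention that materializes the full score matrix."""
    scale = 1.0 / math.sqrt(query.shape[-1])
    scores = torch.matmul(query, key.transpose(-2, -1)) * scale
    if is_causal:
        q_len, k_len = scores.shape[-2], scores.shape[-1]
        # Lower-triangular mask: position i may attend to positions <= i.
        keep = torch.ones(
            q_len, k_len, dtype=torch.bool, device=scores.device
        ).tril()
        scores = scores.masked_fill(~keep, float("-inf"))
    weights = torch.softmax(scores, dim=-1)
    return torch.matmul(weights, value)


def fused_attention(query, key, value, is_causal=False):
    """Hypothetical wrapper (not the PR's code) around PyTorch's fused op.

    Inputs are assumed to be (batch, heads, seq_len, head_dim). PyTorch
    chooses the kernel itself; a flash kernel is only used when the
    inputs qualify (for example CUDA tensors in half precision).
    """
    return F.scaled_dot_product_attention(
        query, key, value, is_causal=is_causal
    )
```

A unit test for such a path could compare the fused output against the naive reference within a small tolerance, with and without causal masking.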


codecov-commenter commented Aug 22, 2024

Codecov Report

Attention: Patch coverage is 8.00000% with 23 lines in your changes missing coverage. Please review.

Project coverage is 79.31%. Comparing base (f4a4725) to head (7e99f06).
Report is 1 commit behind head on master.

| Files | Patch % | Lines |
|---|---|---|
| keras/src/backend/torch/nn.py | 8.00% | 23 Missing ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #20152      +/-   ##
==========================================
- Coverage   79.35%   79.31%   -0.04%     
==========================================
  Files         501      501              
  Lines       47311    47336      +25     
  Branches     8689     8695       +6     
==========================================
+ Hits        37542    37544       +2     
- Misses       8014     8037      +23     
  Partials     1755     1755              
| Flag | Coverage Δ |
|---|---|
| keras | 79.16% <8.00%> (-0.04%) ⬇️ |
| keras-jax | 62.44% <8.00%> (-0.03%) ⬇️ |
| keras-numpy | 57.55% <8.00%> (-0.03%) ⬇️ |
| keras-tensorflow | 63.83% <8.00%> (-0.03%) ⬇️ |
| keras-torch | 62.47% <8.00%> (-0.03%) ⬇️ |

Flags with carried forward coverage won't be shown. Click here to find out more.


Member

fchollet left a comment

Thanks for the PR -- the code looks good! Please add a unit test.

For the JAX version, I think we'd want to rely on a Pallas kernel. We can get help from the JAX team.


This PR is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you.

github-actions bot added the stale label Sep 24, 2024
Projects
Status: Assigned Reviewer
4 participants