
Add ops.nn.dot_product_attention #20286

Merged (1 commit) on Sep 25, 2024

Conversation

@james77777778 (Contributor) commented on Sep 25, 2024

I haven't checked the performance yet, but I believe adding this operation will be beneficial, since both torch and jax have optimized implementations in their codebases.

From a performance perspective, we should be able to replace _compute_attention in MultiHeadAttention with this op if the input shapes are strictly 4D.

EDITED:
It seems that CI is using jax<=0.4.30 due to a Python version limitation, so I implemented a pure NumPy version of dot_product_attention for the unit tests.

For backends:

  • jax: Uses jax.nn.dot_product_attention when available; otherwise, adapts the implementation from jax==0.4.33.
  • numpy: Adapts the implementation from jax==0.4.31 (no customizable vmap).
  • tensorflow: Adapts the implementation from jax==0.4.31 (no customizable vmap).
  • torch: Uses torch.nn.functional.scaled_dot_product_attention.
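For reference, the core semantics shared by all four backends can be sketched in pure NumPy, in the spirit of the test-only implementation mentioned above. This is a minimal sketch under the assumption of JAX's BTNH layout (batch, sequence, heads, head_dim), with no mask, bias, or causal handling; it is not the actual code from the PR.

```python
import numpy as np

def dot_product_attention(query, key, value, scale=None):
    # Illustrative sketch only, not the PR's implementation.
    # query: (batch, q_len, num_heads, head_dim)
    # key, value: (batch, kv_len, num_heads, head_dim)
    if scale is None:
        scale = 1.0 / np.sqrt(query.shape[-1])
    # Attention logits: (batch, num_heads, q_len, kv_len)
    logits = np.einsum("btnh,bsnh->bnts", query, key) * scale
    # Numerically stable softmax over the key axis
    logits = logits - logits.max(axis=-1, keepdims=True)
    weights = np.exp(logits)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    # Weighted sum of values, back to (batch, q_len, num_heads, head_dim)
    return np.einsum("bnts,bsnh->btnh", weights, value)
```

With a single key/value position, the softmax weight is 1, so the output simply broadcasts the value across all query positions; that edge case is a quick sanity check for any backend implementation.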

@codecov-commenter commented on Sep 25, 2024

Codecov Report

Attention: Patch coverage is 88.95349% with 19 lines in your changes missing coverage. Please review.

Project coverage is 78.91%. Comparing base (577ef63) to head (fb303a3).

| File with missing lines | Patch % | Lines |
| --- | --- | --- |
| keras/src/backend/jax/nn.py | 88.23% | 3 Missing and 3 partials ⚠️ |
| keras/src/backend/torch/nn.py | 76.00% | 3 Missing and 3 partials ⚠️ |
| keras/src/backend/numpy/nn.py | 95.23% | 1 Missing and 1 partial ⚠️ |
| keras/src/backend/tensorflow/nn.py | 94.59% | 1 Missing and 1 partial ⚠️ |
| keras/api/_tf_keras/keras/ops/__init__.py | 0.00% | 1 Missing ⚠️ |
| keras/api/_tf_keras/keras/ops/nn/__init__.py | 0.00% | 1 Missing ⚠️ |
| keras/src/ops/nn.py | 92.30% | 1 Missing ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #20286      +/-   ##
==========================================
+ Coverage   78.88%   78.91%   +0.03%     
==========================================
  Files         511      511              
  Lines       48735    48907     +172     
  Branches     8982     9006      +24     
==========================================
+ Hits        38443    38596     +153     
- Misses       8437     8448      +11     
- Partials     1855     1863       +8     
| Flag | Coverage Δ |
| --- | --- |
| keras | 78.77% <88.95%> (+0.03%) ⬆️ |
| keras-jax | 62.35% <37.79%> (-0.09%) ⬇️ |
| keras-numpy | 57.50% <36.62%> (-0.08%) ⬇️ |
| keras-tensorflow | 63.63% <31.39%> (-0.12%) ⬇️ |
| keras-torch | 62.34% <23.25%> (-0.14%) ⬇️ |

Flags with carried forward coverage won't be shown.


@fchollet (Member) left a comment

Looking good, thank you for the contribution! It seems we need to update our JAX version number on GPU CI.

@google-ml-butler bot added the kokoro:force-run and ready to pull (Ready to be merged into the codebase) labels on Sep 25, 2024
@fchollet fchollet merged commit 3bc3aef into keras-team:master Sep 25, 2024
6 checks passed
@google-ml-butler bot removed the ready to pull (Ready to be merged into the codebase) and kokoro:force-run labels on Sep 25, 2024
@james77777778 james77777778 deleted the add_dot_product_attention branch September 26, 2024 01:46
@james77777778 (Contributor, Author) commented

The main barrier is probably the Python version: jax>=0.4.31 only supports python>=3.10.
