
Improvement in ROCM fmha-backward #1082

Merged: 696 commits merged into facebookresearch:main on Aug 22, 2024

Conversation

@qianfengz (Contributor) commented on Aug 17, 2024

This PR mainly provides updates to the ROCm FMHA backward pass. Specifically:

  1. Improved the performance of FMHA backward for essentially all input shapes
  2. Added support for headdim 256 (all headdim sizes larger than 128 and up to 256 are now supported)
  3. Switched to an unpadded LSE layout for the variable-length sequence case
  4. Improved accuracy of the grad_query output for both fp16 and bf16 input types by using fp32 for atomicAdd-based accumulation
  5. Adapted to the kernel API changes in the ck_tile fmha fwd/bwd kernels (to support the requirements from Tri Dao's FlashAttention)
  6. Added an environment variable to control the number of compiled instances
  7. Added an environment variable to select the RTZ or RTN rounding method for fp32-to-bf16 conversion, trading off performance against accuracy (see the sketch after this list)
  8. Bug fixes and stability enhancements
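
As a hedged illustration of item 7: the snippet below sketches the standard bit-level difference between RTZ (round toward zero, i.e. truncation) and RTN (round to nearest, ties to even) when converting fp32 to bf16. It is a minimal host-side sketch for intuition only, not the ck_tile or xFormers implementation; the function names are made up for this example and NaN handling is omitted.

#include <cstdint>
#include <cstring>

// RTZ: drop the low 16 bits of the fp32 encoding (cheapest, truncates toward zero).
uint16_t fp32_to_bf16_rtz(float x) {
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  return static_cast<uint16_t>(bits >> 16);
}

// RTN: add a rounding bias so the result rounds to the nearest bf16 value, ties to even
// (slightly more work per conversion, but more accurate on average).
uint16_t fp32_to_bf16_rtn(float x) {
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  uint32_t rounding_bias = 0x7FFFu + ((bits >> 16) & 1u);
  return static_cast<uint16_t>((bits + rounding_bias) >> 16);
}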

To test/verify, use the following commands:

#> pytest tests/test_mem_eff_attention.py::test_forward -k "not triton and not flshattF and not fa2F"
#> pytest tests/test_mem_eff_attention.py::test_backward -k "not flshattB and not fa2B"
#> pytest tests/test_mem_eff_attention.py::test_dropout
#> pytest tests/test_mem_eff_attention.py::test_dropout_backward_ck

To benchmark performance, use:

#> python xformers/benchmark/benchmark_mem_eff_attention.py --omit-forward

qianfengz and others added 30 commits beginning February 4, 2024, including:
  ensure ck_decoder does not dispatch in test_attn_bias_padded
  Apply the existing linters (1/n)

@facebook-github-bot added the "CLA Signed" and "module: rocm" labels on Aug 17, 2024.

@jianyuh (Member) left a comment:

For the Composable Kernel dependency, is it reliable to use the current CK trunk development branch?

  static void Run(BatchedBackwardParams& param, hipStream_t stream) {
    {
-     constexpr ck_tile::index_t kBlockSize = 256;
+     constexpr ck_tile::index_t kBlockSize = 64;

Member:

Curious about the reason for decreasing the block size?

qianfengz (Contributor, Author):

Just keeping consistent with the ck_tile code (fmha_bwd.py #L524), which is well tested with performance in mind.

@@ -9,6 +9,46 @@
#include <ck_tile/core.hpp>
#include <stdexcept>

#ifndef FMHA_SUPPORT_MAX_HEADDIM_128

Member:

Thanks for adding this! If we want to have only head dim = 128 support (to save compile time), not 64, 32, or 256, is there an easy way to configure this?


qianfengz (Contributor, Author):

With MAX_JOBS unset, compilation is very fast (see the example after the test commands below). Even though it would be easy to build only for dim == 128, we would rather not do that: the verification scripts under tests/ are not specifically prepared for dim 128, so we would be less confident in such a build. For any change in the code, we always run the following scripts to verify that everything works correctly:

#> pytest tests/test_mem_eff_attention.py::test_forward
#> pytest tests/test_mem_eff_attention.py::test_backward
#> pytest tests/test_mem_eff_attention.py::test_dropout
#> pytest tests/test_mem_eff_attention.py::test_dropout_backward_ck
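
As a hedged follow-up to the build advice above: assuming a source checkout of xFormers, one possible way to rebuild with MAX_JOBS left unset (so the build system picks its own parallelism) is

#> unset MAX_JOBS
#> pip install -e .

The exact install command may differ in your environment; the point is simply that MAX_JOBS stays unset.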

In tests/test_mem_eff_attention.py:
@@ -1003,6 +998,38 @@ def test_dropout_backward_ck(dt, q_len, kv_len, batch_size, k, p):
)


@cuda_only

Contributor:

I think this test got here as a merge-conflict resolution gone bad?

@danthe3rd (Contributor) left a comment:

LGTM on the xFormers side (didn't review the generated files/generator, but I trust you on that).

@jianyuh merged commit e3900ba into facebookresearch:main on Aug 22, 2024; 23 checks passed.
@qianfengz deleted the upstream_pr branch on September 5, 2024.