
Persistent version of Flash Attention #2407

Draft
manman-ren wants to merge 2 commits into main

Conversation

@manman-ren (Contributor) commented Aug 2, 2024

Added two more variants: triton_tutorial_flash_v2_persistent and triton_tutorial_flash_v2_persistent_tma.

The variants handle the non-causal case only. The causal case makes two invocations of attn_fwd_inner, so we would end up with one outer (persistent) loop containing two inner loops:

```
for ...      # persistent loop over flattened tiles
    for ...  # first attn_fwd_inner call
    for ...  # second attn_fwd_inner call
```

It is not clear how to flatten this nest into a 1D loop.
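For context, here is a minimal, hypothetical sketch of the persistent pattern the non-causal variant relies on: launch a fixed number of programs (rather than one per tile) and have each program stride over the flattened tile index space, decomposing the 1D tile_id back into its two coordinates. This is not the PR's kernel; the kernel name, NUM_PROGRAMS, and the row-scaling body are illustrative assumptions.

```python
import torch
import triton
import triton.language as tl

@triton.jit
def _scale_rows_persistent(x_ptr, out_ptr, n_rows, n_cols, scale,
                           BLOCK: tl.constexpr, NUM_PROGRAMS: tl.constexpr):
    # Persistent idiom: the grid size is fixed at NUM_PROGRAMS; each program
    # strides over the flattened (row, col-tile) index space.
    pid = tl.program_id(0)
    tiles_per_row = tl.cdiv(n_cols, BLOCK)
    total_tiles = n_rows * tiles_per_row
    for tile_id in range(pid, total_tiles, NUM_PROGRAMS):
        row = tile_id // tiles_per_row       # analogous to off_hz (batch * head)
        col_tile = tile_id % tiles_per_row   # analogous to start_m (query tile)
        offs = col_tile * BLOCK + tl.arange(0, BLOCK)
        mask = offs < n_cols
        x = tl.load(x_ptr + row * n_cols + offs, mask=mask)
        tl.store(out_ptr + row * n_cols + offs, x * scale, mask=mask)

x = torch.randn(64, 1024, device="cuda")
out = torch.empty_like(x)
NUM_PROGRAMS = torch.cuda.get_device_properties(0).multi_processor_count
_scale_rows_persistent[(NUM_PROGRAMS,)](x, out, x.shape[0], x.shape[1], 2.0,
                                        BLOCK=128, NUM_PROGRAMS=NUM_PROGRAMS)
```

In the causal case, each tile_id iteration would itself contain the two sequential attn_fwd_inner loops, so the body no longer maps cleanly onto a single flattened 1D loop, which is the difficulty described above.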

@manman-ren marked this pull request as draft August 2, 2024 18:41
@manman-ren temporarily deployed to docker-s3-upload August 2, 2024 18:41 — with GitHub Actions
@manman-ren requested review from embg and xuzhao9 August 2, 2024 19:07

Review comment on:

```python
@triton.autotune(list(filter(keep, configs)), key=["N_CTX"])
@triton.jit
def _attn_fwd_persistent_tma(Q, Out, desc_q, desc_k, desc_v, sm_scale, M, desc_o,  #
```
Is this a copy of _attn_fwd_persistent but with TMA changes?
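For readers of the diff, a hedged sketch of what the TMA change typically amounts to, assuming Triton's experimental descriptor API (which is what desc_q/desc_k/desc_v/desc_o in the signature suggest): tile loads go through host-created TMA descriptors instead of manually computed block pointers. The helper below is an illustration under that assumption, not code from this PR.

```python
import triton
import triton.language as tl

@triton.jit
def _load_qk_tiles(desc_q, desc_k, start_m, start_n,
                   BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
                   HEAD_DIM: tl.constexpr):
    # Pointer-based kernels build Q/K block pointers from strides and tl.load them.
    # The TMA variant instead indexes a descriptor; the hardware copy engine
    # performs the address arithmetic and bounds handling.
    q = tl._experimental_descriptor_load(
        desc_q, [start_m, 0], [BLOCK_M, HEAD_DIM], tl.float16)
    k = tl._experimental_descriptor_load(
        desc_k, [start_n, 0], [BLOCK_N, HEAD_DIM], tl.float16)
    return q, k
```

Host-side, the descriptors would come from something like triton.tools.experimental_descriptor.create_2d_tma_descriptor(q.data_ptr(), M, HEAD_DIM, BLOCK_M, HEAD_DIM, q.element_size()); that API was experimental and version-dependent at the time of this PR.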
