
[PyTorch] fused attention and cu_seqlens #1259

Open
Marks101 opened this issue Oct 16, 2024 · 2 comments

@Marks101 (Contributor)

Hi team,

we are currently adapting our training environment to use the fused attention functions. In one of our training setups, we work with batch size one and concatenate multiple documents along the sequence dimension (sbhd format). We set cu_seqlens_q and cu_seqlens_kv so that these documents cannot attend to each other. This is not really a padding use case, because we always fill up the whole sequence, so no packing and unpacking with pack_tensors() and unpack_tensors() is required. With the flash attention backend this worked perfectly fine and produced the results we intended. With the fused attention functions we get device-side assertions for this input. Here is a small sample code:

import os
import torch

from transformer_engine.pytorch.attention import DotProductAttention, _attention_backends

seqlen, batch_size, heads, kv_channels = 2048, 1, 16, 64

q, k, v = [torch.randn(seqlen, batch_size, heads, kv_channels, dtype=torch.float16, device="cuda", requires_grad=True) for _ in range(3)]

cu_seqlens_q = cu_seqlens_kv = torch.tensor([0, 300, 1100, 2048], device="cuda", dtype=torch.int32)

attention_kernel = DotProductAttention(heads, kv_channels)

os.environ["NVTE_FUSED_ATTN"] = "0"
os.environ["NVTE_FLASH_ATTN"] = "1"
output_flash = attention_kernel(q, k, v, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv)

os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_FLASH_ATTN"] = "0"
_attention_backends["backend_selection_requires_update"] = True
output_fused = attention_kernel(q, k, v, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv)

torch.testing.assert_close(output_fused, output_flash, atol=1e-2, rtol=1e-2)

Was the use case we have been working with ever intended? Or is there just an assertion missing that forbids using cu_seqlens without setting a padding mask type?
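For clarity, the attention pattern we are after is block-diagonal across documents. A rough sketch of the mask we expect cu_seqlens to express (illustration only, not something we actually pass to Transformer Engine):

import torch

cu_seqlens = torch.tensor([0, 300, 1100, 2048])
seqlen = int(cu_seqlens[-1])

# Assign each token to its document
doc_id = torch.zeros(seqlen, dtype=torch.long)
for i in range(cu_seqlens.numel() - 1):
    doc_id[cu_seqlens[i]:cu_seqlens[i + 1]] = i

# allowed[q, k] is True iff query token q and key token k belong to the same document;
# in practice a causal mask within each document would be combined on top of this.
allowed = doc_id.unsqueeze(1) == doc_id.unsqueeze(0)   # shape [2048, 2048]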

@cyanguwa cyanguwa self-assigned this Oct 19, 2024

@cyanguwa (Collaborator) commented Oct 29, 2024

Hey Markus!

I think what you wanted to do is in line with the thd format. I've tweaked your script a little and it seems to work for both FlashAttention and FusedAttention. In this case, Transformer Engine is treating your batch as [t=2048, h, d], and batch size b=3 (inferred from cu_seqlens's shape [b+1]). Your original script did run, for FlashAttention, but it wasn't running as intended if I understand your use case correctly. If you turn on NVTE_DEBUG_LEVEL=2, you'll see that it's treating the batch as sbhd format, i.e. all 2048 tokens were in 1 sequence. It's also applying causal mask (because that's the default), and not using cu_seqlens tensors (because we go through the flash_attn_func path, which doesn't take cu_seqlens).
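In other words, with thd the sequence boundaries come entirely from cu_seqlens. A rough sketch of what Transformer Engine infers from your tensor:

import torch

cu_seqlens = torch.tensor([0, 300, 1100, 2048], dtype=torch.int32)
b = cu_seqlens.numel() - 1                    # 3 sequences in the "batch"
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]    # tensor([300, 800, 948]), sums to t=2048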

I have a little blurb over here to explain the use cases of sbhd + padding and thd + padding. Hope that helps!

import os
import torch

from transformer_engine.pytorch.attention import DotProductAttention, _attention_backends

seqlen, batch_size, heads, kv_channels = 2048, 1, 16, 64

q, k, v = [torch.randn(seqlen * batch_size, heads, kv_channels, dtype=torch.float16, device="cuda", requires_grad=True) for _ in range(3)]

cu_seqlens_q = cu_seqlens_kv = torch.tensor([0, 300, 1100, 2048], device="cuda", dtype=torch.int32)

attention_kernel = DotProductAttention(heads, kv_channels)

os.environ["NVTE_FUSED_ATTN"] = "0"
os.environ["NVTE_FLASH_ATTN"] = "1"
output_flash = attention_kernel(q, k, v, qkv_format='thd', attn_mask_type='padding', cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv)

os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_FLASH_ATTN"] = "0"
_attention_backends["backend_selection_requires_update"] = True
output_fused = attention_kernel(q, k, v, qkv_format='thd', attn_mask_type='padding', cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv)

torch.testing.assert_close(output_fused, output_flash, atol=1e-2, rtol=1e-2)

Run:

NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python test.py 
/code/pr-thd-int64/TransformerEngine/transformer_engine/pytorch/attention.py:5162: UserWarning: window_size should be (-1, -1) or (>=0, >=0) for attn_mask_type=padding
  warnings.warn(
[INFO     | DotProductAttention]: Running with FlashAttention backend (version 2.4.2)
[INFO     | DotProductAttention]: Running with FusedAttention backend (sub-backend 1)

@Marks101 (Contributor, Author)
Hey Hey,

Oh, okay, then my use case was simply not correct 🙈
Thank you so much for your explanation and for extending the documentation. That was exactly the information I was looking for!
