[Issue]: RuntimeError: Expected dout_seq_stride == out_seq_stride to be true, but got false. #41
Comments
q = rearrange(q, '(b h) l d -> b l h d', b=bsz).contiguous()
k = rearrange(k, '(b h) l d -> b l h d', b=bsz).contiguous()
v = rearrange(v, '(b h) l d -> b l h d', b=bsz).contiguous()
print(q.shape)
print(k.shape)
print(v.shape)
attn = flash_attn_func(q, k, v, causal=is_causal)
attn = rearrange(attn, 'b l h d -> (b h) l d')
The error message:
torch.Size([2, 2048, 48, 64])
torch.Size([2, 2048, 48, 64])
torch.Size([2, 2048, 48, 64])
torch.Size([2, 2048, 48, 64])
[... the same torch.Size([2, 2048, 48, 64]) line repeated many more times ...]
Traceback (most recent call last):
File "/tmp/amlt-code-download/fairseq/train.py", line 14, in <module>
cli_main()
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/fairseq_cli/train.py", line 543, in cli_main
distributed_utils.call_main(cfg, main)
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/fairseq/distributed/utils.py", line 365, in call_main
distributed_main(cfg.distributed_training.device_id, main, cfg, kwargs)
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/fairseq/distributed/utils.py", line 339, in distributed_main
main(cfg, **kwargs)
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/fairseq_cli/train.py", line 191, in main
valid_losses, should_stop = train(cfg, trainer, task, epoch_itr)
File "/opt/conda/envs/py_3.9/lib/python3.9/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/fairseq_cli/train.py", line 307, in train
log_output = trainer.train_step(samples)
File "/opt/conda/envs/py_3.9/lib/python3.9/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/fairseq/trainer.py", line 850, in train_step
raise e
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/fairseq/trainer.py", line 818, in train_step
loss, sample_size_i, logging_output = self.task.train_step(
File "/tmp/amlt-code-download/fairseq/tasks/gpt.py", line 253, in train_step
optimizer.backward(loss)
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/fairseq/optim/fp16_optimizer.py", line 393, in backward
loss.backward()
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_tensor.py", line 492, in backward
torch.autograd.backward(
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/autograd/__init__.py", line 251, in backward
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/autograd/function.py", line 288, in apply
return user_fn(self, *args)
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/fairscale/nn/checkpoint/checkpoint_activations.py", line 311, in backward
torch.autograd.backward(outputs_with_grad, args_with_grad)
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/autograd/__init__.py", line 251, in backward
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/autograd/function.py", line 288, in apply
return user_fn(self, *args)
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/flash_attn/flash_attn_interface.py", line 236, in backward
_flash_attn_backward(
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/flash_attn/flash_attn_interface.py", line 66, in _flash_attn_backward
dq, dk, dv, softmax_d, = flash_attn_cuda.bwd(
RuntimeError: Expected dout_seq_stride == out_seq_stride to be true, but got false.
The log shows that the program runs fine for a number of steps before triggering the bug, rather than hitting the error on the first call.
I tried both the rocm-5.7 and rocm-6.0 dockers.
The bug is related to the qkv shape:
[1, 2048, 48, 64]: works well
[2, 2048, 48, 64]: triggers the bug
[4, 2048, 48, 64]: triggers the bug
[1, 2048, 24, 128]: triggers the bug
[2, 2048, 24, 128]: triggers the bug
[2, 2048, 25, 128]: triggers the bug
[2, 2048, 24, 124]: works well
[2, 2048, 48, 62]: works well
@donglixp Can I have the script you are running?
@howiejayz Forward and backward with the following shapes:
[2, 2048, 48, 64]: triggers the bug
[4, 2048, 48, 64]: triggers the bug
[1, 2048, 24, 128]: triggers the bug
[2, 2048, 24, 128]: triggers the bug
[2, 2048, 25, 128]: triggers the bug
q = rearrange(q, '(b h) l d -> b l h d', b=bsz).contiguous()
k = rearrange(k, '(b h) l d -> b l h d', b=bsz).contiguous()
v = rearrange(v, '(b h) l d -> b l h d', b=bsz).contiguous()
attn = flash_attn_func(q, k, v, causal=is_causal)
attn = rearrange(attn, 'b l h d -> (b h) l d')
Although using [2, 2048, 48, 62] didn't trigger the above error, the job ran into loss divergence, whereas a similar recipe ran successfully before (when the VM ROCm was 5.4 and the docker was 5.7).
The error seems to be triggered when dout is not contiguous. May I ask how you generate the dout that is passed to the backward?
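To make the "dout is not contiguous" hypothesis concrete, here is a minimal pure-Python sketch of the stride arithmetic, with no GPU or flash_attn needed. It assumes the upstream gradient arrives laid out contiguously as (b, h, l, d), which is what the final rearrange to '(b h) l d' would suggest, and is then viewed back as (b, l, h, d). The helper names are made up for illustration, and whether this is what actually happens in the failing job has not been verified here.

```python
def contiguous_strides(shape):
    """Row-major (C-contiguous) strides, in elements, for the given shape."""
    strides = []
    step = 1
    for size in reversed(shape):
        strides.append(step)
        step *= size
    return tuple(reversed(strides))


def permute(shape, strides, order):
    """Reorder dimensions without moving data, like a permuted tensor view."""
    return tuple(shape[i] for i in order), tuple(strides[i] for i in order)


b, l, h, d = 2, 2048, 48, 64

# Forward output of shape (b, l, h, d), assumed contiguous.
out_shape = (b, l, h, d)
out_strides = contiguous_strides(out_shape)

# Assumption: the incoming gradient is contiguous in (b, h, l, d) order
# and is exposed as a (b, l, h, d) view by swapping dims 1 and 2.
grad_shape = (b, h, l, d)
grad_strides = contiguous_strides(grad_shape)
dout_shape, dout_strides = permute(grad_shape, grad_strides, (0, 2, 1, 3))

out_seq_stride = out_strides[1]    # h * d
dout_seq_stride = dout_strides[1]  # d

print("out  strides:", out_strides)
print("dout strides:", dout_strides)
print("seq strides equal:", out_seq_stride == dout_seq_stride)
```

Under that assumption the two tensors have the same shape but different sequence strides (h * d for out, d for dout), which is the condition the check in the error message rejects.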
@howiejayz Yes, they were. The contiguous() was also handled at
@donglixp Thanks. Can I also have the repo and branch you are testing, so I can reproduce your result and see which step goes wrong?
We are changing the backend of Flash Attention 2 in the ck_tile branch.
Will I be able to run other models that use Flash-Attention-2 on Instinct GPUs if the PR is not merged yet? Btw, what is your work email? I can't find your name in Team.
cc @danyao12
Problem Description
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/fairseq/optim/fp16_optimizer.py", line 393, in backward
loss.backward()
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_tensor.py", line 492, in backward
torch.autograd.backward(
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/autograd/__init__.py", line 251, in backward
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/autograd/function.py", line 288, in apply
return user_fn(self, *args)
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/fairscale/nn/checkpoint/checkpoint_activations.py", line 311, in backward
torch.autograd.backward(outputs_with_grad, args_with_grad)
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/autograd/__init__.py", line 251, in backward
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/autograd/function.py", line 288, in apply
return user_fn(self, *args)
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/flash_attn/flash_attn_interface.py", line 236, in backward
_flash_attn_backward(
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/flash_attn/flash_attn_interface.py", line 66, in _flash_attn_backward
dq, dk, dv, softmax_d, = flash_attn_cuda.bwd(
RuntimeError: Expected dout_seq_stride == out_seq_stride to be true, but got false. (Could this error message be improved? If so, please report an enhancement request to PyTorch.)
Operating System
Ubuntu 20.04.6 LTS (Focal Fossa)
CPU
AMD EPYC 7V12 64-Core Processor
GPU
AMD Instinct MI250X
ROCm Version
ROCm 6.0.0, ROCm 5.7.1
ROCm Component
No response
Steps to Reproduce
No response
(Optional for Linux users) Output of /opt/rocm/bin/rocminfo --support
No response
Additional Information
No response