
llama GPU model with dcn fsdp + ici tp + cudnn flash attention broken #1093

wang2yn84 (Collaborator) opened this issue Dec 10, 2024 · 2 comments
wang2yn84 commented Dec 10, 2024

Using the 7B model as an example, the following config doesn't work even on a 2-node setup:

dcn fsdp = number of nodes
ici tp = 8
attention = cudnn_flash_te
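For reference, that parallelism config corresponds to a MaxText launch roughly like the following. This is a sketch, not the exact command from the linked script: the run name and output path are placeholders, and the config keys follow MaxText's `configs/base.yml` conventions.

```shell
# Hypothetical 2-node x 8-GPU launch reproducing the failing combination:
# FSDP across nodes (DCN axis), tensor parallelism within a node (ICI axis),
# and the Transformer Engine cuDNN flash attention kernel.
python3 MaxText/train.py MaxText/configs/base.yml \
  run_name=llama2-7b-fsdp-tp \
  model_name=llama2-7b \
  hardware=gpu \
  dcn_fsdp_parallelism=2 \
  ici_tensor_parallelism=8 \
  attention=cudnn_flash_te
```

Per the report, swapping `attention=cudnn_flash_te` for `attention=dot_product` in the same launch makes the run succeed.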

It works with dot_product attention. Here is a snippet of the error message:
ERROR 2024-12-09T12:37:43.202101549Z [resource.labels.containerName: gpu-image] 2024-12-09 12:37:43.201690: E external/xla/xla/service/rendezvous.cc:55] This thread has been waiting for first call to collective operation 5688; run_id=1895556971 for 20 seconds and may be stuck. Expected 8 threads to join the rendezvous, but not all of them arrived on time.
ERROR 2024-12-09T12:37:46.679868113Z [resource.labels.containerName: gpu-image] 2024-12-09 12:37:46.679466: F external/xla/xla/service/rendezvous.cc:77] Termination timeout for first call to collective operation 5688; run_id=1895556971 of 40 seconds exceeded. Exiting to ensure a consistent program state. Expected 8 threads to join the rendezvous, but not all of them arrived on time.
ERROR 2024-12-09T12:37:46.679903343Z [resource.labels.containerName: gpu-image] Fatal Python error: Aborted

I had a working image dating back to Oct 8th. I'm not sure Oct 8th is the exact date it broke, but images after that don't work with this config. This is the script I use: https://github.com/AI-Hypercomputer/maxtext/blob/lance-nv/mt_jon_pgle.sh. It has nothing to do with PGLE, though.

@wang2yn84 wang2yn84 changed the title llama model dcn fsdp + ici tp + cudnn flash attention broken llama GPU model with dcn fsdp + ici tp + cudnn flash attention broken Dec 10, 2024
abhinavgoel95 (Contributor) commented:
What are the XLA flags that you used?

abhinavgoel95 (Contributor) commented:

Found it:
--xla_gpu_enable_latency_hiding_scheduler=true
--xla_gpu_enable_triton_gemm=false
--xla_gpu_enable_highest_priority_async_stream=true
--xla_gpu_all_reduce_combine_threshold_bytes=134217728
--xla_gpu_all_gather_combine_threshold_bytes=1073741824
--xla_gpu_reduce_scatter_combine_threshold_bytes=33554432
--xla_gpu_enable_pipelined_all_gather=true
--xla_gpu_enable_pipelined_reduce_scatter=true
--xla_gpu_enable_pipelined_all_reduce=true
--xla_gpu_enable_while_loop_double_buffering=true
--xla_gpu_enable_triton_softmax_fusion=false
--xla_gpu_enable_all_gather_combine_by_dim=false
--xla_gpu_enable_reduce_scatter_combine_by_dim=false
--xla_disable_hlo_passes=rematerialization
--xla_gpu_graph_level=0
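Since the regression appeared somewhere between container images, one way to narrow it down is to diff the flag sets the two images were launched with. A small helper for that (a hypothetical utility in plain Python, not part of MaxText or XLA):

```python
def parse_xla_flags(flag_string):
    """Parse a space-separated XLA flag string into a {flag: value} dict.

    Flags of the form --name=value map to name -> value; bare --name
    flags map to name -> True.
    """
    flags = {}
    for token in flag_string.split():
        token = token.lstrip("-")
        if "=" in token:
            name, value = token.split("=", 1)
            flags[name] = value
        else:
            flags[token] = True
    return flags


# Example: compare the flag sets from two images (values are illustrative).
good = parse_xla_flags("--xla_gpu_graph_level=0 --xla_gpu_enable_triton_gemm=false")
bad = parse_xla_flags("--xla_gpu_graph_level=1 --xla_gpu_enable_triton_gemm=false")
changed = {k for k in good | bad.keys() if good.get(k) != bad.get(k)}
print(changed)  # flags whose values differ between the two images
```

These flags are typically passed to the training process via the `XLA_FLAGS` environment variable, so the strings to compare can be pulled straight from each image's launch environment.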
