Any suggestion for Llama-3.1-70B (128k seq len) deploy mesh with torchtitan? #678

Open
medivh-xp opened this issue Nov 15, 2024 · 8 comments


medivh-xp commented Nov 15, 2024

With a 128k sequence length, activation memory increases significantly.
CP8 + TP8 seems necessary (they reduce activation memory almost linearly), but there is still as much as 50 GB of activation memory.
Recomputing the MLP activations can reduce it by about 9 GB, while recomputing the attention layer or the MLP up projection seems rather costly. I noticed that the paper at https://arxiv.org/pdf/2410.06511 mentions that full checkpointing was applied to address the activation memory issue, which seems to significantly increase recomputation time.
Does TorchTitan plan to offload activations and reload them during the backward pass to reduce activation memory?

gnadathur (Contributor) commented:

cc: @XilunWu

XilunWu (Contributor) commented Nov 15, 2024

PR #592 enables CP in torchtitan. You can set context_parallel_degree (for example, 8 for CP8) in the toml file. See details in the PR description.

CP8 is enough for 128K on H100 and A100. If you still encounter OOM, you can change the activation checkpointing mode from "selective" to "full" to further reduce peak memory usage.
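
For context, "full" activation checkpointing means each transformer block keeps only its inputs and recomputes everything inside the block during backward. A minimal sketch of that pattern with PyTorch's checkpoint wrapper (not torchtitan's exact code; `model.layers` as the container of transformer blocks is an assumption here):

```python
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    checkpoint_wrapper,
)

def apply_full_ac(model: nn.Module) -> None:
    # Wrap every transformer block so that only the block inputs are kept for
    # backward; the activations inside each block are recomputed during backward.
    # Assumes `model.layers` is an nn.ModuleDict / nn.ModuleList of blocks.
    for name, block in model.layers.named_children():
        wrapped = checkpoint_wrapper(block, checkpoint_impl=CheckpointImpl.NO_REENTRANT)
        model.layers.register_module(name, wrapped)
```

The memory saving comes at the cost of re-running each block's forward during backward, so expect lower MFU than with selective checkpointing.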

gnadathur (Contributor) commented:

cc: @lessw2020

medivh-xp (Author) commented Nov 18, 2024

> PR #592 enables CP in torchtitan. You can set context_parallel_degree (for example, 8 for CP8) in the toml file. See details in the PR description.
>
> CP8 is enough for 128K on H100 and A100. If you still encounter OOM, you can change the activation checkpointing mode from "selective" to "full" to further reduce peak memory usage.

@XilunWu Thank you for your reply! I noticed that in PR #467, activation memory is reduced through activation offloading. If a good balance can be struck among computation, memory, and H2D bandwidth, it seems full AC might not be necessary (I'm not sure if my understanding is correct; full-AC recomputation will significantly reduce MFU). So how should I choose between full AC and activation offloading? It seems activation offloading could, in theory, achieve a higher MFU?

tianyu-l (Contributor) commented:

@awgu can you share a bit more on the status of the activation offloading PR? E.g., is it ready to be used, and how does its performance compare with full AC on Llama models?

tianyu-l added the question and enhancement labels Nov 18, 2024
awgu (Contributor) commented Nov 18, 2024

The PR is meant as a way to add activation offloading to your model with intrusive changes. The main concern is that, for current-gen NVIDIA GPUs, the offloading may contend with inter-node collectives for PCIe bandwidth.

If you apply full activation checkpointing to each transformer block and then further apply activation offloading to the transformer block input, you accumulate no extra GPU memory per transformer block, which can help unblock long-sequence use cases.

There probably needs to be some extra work on the PR for that though.
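
For reference, PyTorch ships a generic version of this offload-and-reload pattern, torch.autograd.graph.save_on_cpu. The sketch below only illustrates the idea and is not what PR #467 implements: it offloads every tensor autograd saves (not just the block inputs) and does not overlap the copies with compute, so it runs straight into the PCIe-bandwidth concern mentioned above.

```python
import torch
import torch.nn as nn

# Toy stand-in for one block's MLP; the real model comes from torchtitan.
block = nn.Sequential(nn.Linear(8192, 28672), nn.SiLU(), nn.Linear(28672, 8192)).cuda()
x = torch.randn(4, 8192, device="cuda", requires_grad=True)

# save_on_cpu registers pack/unpack hooks: tensors autograd would normally keep on
# the GPU for backward are copied to pinned host memory during forward and copied
# back on demand during backward.
with torch.autograd.graph.save_on_cpu(pin_memory=True):
    y = block(x)

y.sum().backward()
```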

XilunWu (Contributor) commented Nov 19, 2024

@medivh-xp I think the general logic is:

  1. Try a larger context parallel degree to see if it unblocks your long-sequence use case (see the sketch after this list for how the degrees must multiply out). 128k works fine with the Llama3-8B model on H100 with 8 GPUs (dp_shard_degree=2 and context_parallel_degree=4). I haven't tested the Llama3-70B model, but you can easily try it out by increasing context_parallel_degree and seeing whether it works within 20 steps.
  2. If not, you can try activation checkpointing to see if this helps.
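
One constraint to keep in mind when picking degrees (my understanding of how the device mesh composes, assuming no pipeline parallelism): the parallel degrees have to multiply out to the number of GPUs.

```python
# Illustrative degree combinations only, not a recommended recipe.
def check_mesh(world_size: int, dp_shard: int, cp: int, tp: int) -> None:
    assert dp_shard * cp * tp == world_size, "parallel degrees must multiply to world size"

check_mesh(world_size=8, dp_shard=2, cp=4, tp=1)   # the Llama3-8B, 128k example above
check_mesh(world_size=64, dp_shard=1, cp=8, tp=8)  # the CP8 + TP8 layout from the original question
```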

XilunWu (Contributor) commented Nov 20, 2024

I just realized that we have a bug in torchtitan if you want to use CP without combining it with DP. The consequence is high memory usage and possibly a diverging loss.

#685 is the fix. cc @fegin @tianyu-l
