[T170073014] Rewrite distributed examples for Tensor Parallel, Sequence Parallel, 2D (FSDP + TP) #1201
First pass: I think we can do things a lot more simply. Please see the inline comments.
# while for SP, input can be different across all ranks.
# We will use dp_rank for setting the random seed
# to mimic the behavior of the dataloader.
dp_rank = dist.get_rank(dp_pg)
I see, this needs to be consolidated into a device mesh API, cc @wz337
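For readers following the thread, the seeding idea in the snippet above can be sketched without any process groups. This is only an illustration in plain Python, not the PR's code: `mesh_coords` and `make_batch` are hypothetical helpers, and the rank layout (row-major, with tensor parallel as the innermost mesh dimension) is an assumption about how the 2D mesh is arranged.

```python
import random


def mesh_coords(rank, tp_size):
    """Map a global rank to (dp_rank, tp_rank).

    Assumption: ranks are laid out row-major over a (dp, tp) mesh,
    so consecutive ranks belong to the same tensor-parallel group.
    """
    return divmod(rank, tp_size)


def make_batch(dp_rank, n=4, base_seed=0):
    """Stand-in for a dataloader: the seed depends only on dp_rank,
    so every rank in the same TP group draws identical inputs."""
    rng = random.Random(base_seed + dp_rank)
    return [rng.random() for _ in range(n)]


if __name__ == "__main__":
    world_size, tp_size = 8, 4
    for rank in range(world_size):
        dp_rank, tp_rank = mesh_coords(rank, tp_size)
        print(rank, dp_rank, tp_rank, make_batch(dp_rank)[:2])
```

With 8 ranks and a TP size of 4, ranks 0 to 3 share one batch and ranks 4 to 7 share another, which is the behavior a data-parallel dataloader would give.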
Looks much better! Wondering what the reason is for keeping original.py. Also left some inline comments about imports, etc.
This looks great! I only have some nits about logging. Thanks for addressing the comments!
This PR updates the three distributed examples for Tensor Parallel, Sequence Parallel and 2D with the following main changes:
(note - internal reference - task [T170073014] Rewrite TensorParallel/SequenceParallel Examples using our new UX)
1 - Move to torchrun launching (see the run_.sh files) and do the relevant world topology introspection in the setup, instead of using mp.spawn.
2 - Move device mesh creation to the new API, init_device_mesh.
3 - Use custom parallelization plans (ColwiseParallel and RowwiseParallel) rather than the previous prebuilt PairwiseParallel() and SequenceParallel().
4 - For the 2D example, use a more relevant SwiGLU MLP model to showcase applying 2D to a more sophisticated, Llama-style situation.
5 - Add more interactive feedback for the user (at start, per iteration, and on completion).
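To make items 3 and 4 concrete, here is a rough numeric sketch of why a column-wise plus row-wise plan suits a SwiGLU MLP. It is plain Python standing in for the tensor code and is not taken from the PR: all function and weight names are hypothetical, the loop over `r` plays the role of the tensor-parallel ranks, and the running sum stands in for the all-reduce that follows a row-wise sharded layer. Weights are stored as lists of rows in the (out_features, in_features) layout that `nn.Linear` uses.

```python
import math


def silu(x):
    return x / (1.0 + math.exp(-x))


def matvec(w, x):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


def swiglu_mlp(x, w_gate, w_up, w_down):
    """Unsharded reference: down(silu(gate(x)) * up(x))."""
    gate = matvec(w_gate, x)
    up = matvec(w_up, x)
    hidden = [silu(g) * u for g, u in zip(gate, up)]
    return matvec(w_down, hidden)


def sharded_swiglu_mlp(x, w_gate, w_up, w_down, tp_size):
    """Same computation split across tp_size simulated ranks."""
    hidden_dim = len(w_gate)
    assert hidden_dim % tp_size == 0, "hidden dim must divide evenly"
    chunk = hidden_dim // tp_size
    out = [0.0] * len(w_down)
    for r in range(tp_size):
        lo, hi = r * chunk, (r + 1) * chunk
        # Column-wise style shard: each rank owns a slice of the
        # output features of the gate and up projections.
        gate = matvec(w_gate[lo:hi], x)
        up = matvec(w_up[lo:hi], x)
        hidden = [silu(g) * u for g, u in zip(gate, up)]
        # Row-wise style shard: each rank owns the matching slice of
        # the input features of the down projection.
        partial = matvec([row[lo:hi] for row in w_down], hidden)
        # Summing the partial outputs stands in for the all-reduce.
        out = [o + p for o, p in zip(out, partial)]
    return out
```

The elementwise gate never needs data from another rank, so the only communication in this sketch is the final sum, which is why the gate and up projections are paired with a column-wise style and the down projection with a row-wise style.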