torch distributed: add support for user-specified parameter synchronization #1612
base: master
Conversation
Force-pushed from 5e029dd to 532a4c4.
Force-pushed from 532a4c4 to e852396.
What specific synchronization schemes do you have in mind? While this API looks flexible, I think it's actually not very flexible and has lots of implicit assumptions.

For most of the extensions I have in mind for distributed training, this API would not work. I'm also not sure there is a good way to design such an API, because there are basically infinitely many ways in which distributed training could be done, and we don't really know yet what we want to do, or what other people might want. So I think we should only implement a new API for the parts where we know exactly that we need it right now. That brings me back to my initial question: what specific synchronization schemes do you have in mind? What API would you need for exactly that synchronization scheme, so that it would be possible to implement it in user space? Or: instead of trying to implement such a flexible API, directly implement the specific method you have in mind? The flexible API might only really work well for this specific method, but not for anything else. In that case, we have just added complexity without real value.
Thank you for your comment, that makes sense to me. I was thinking of a scheme where you do parameter averaging at different step intervals depending on whether you are averaging within one node or across nodes. But your comment, especially w.r.t. GPUs leaving and joining the computation, makes it clear why such an abstract API does not make that much sense, at least not as long as you cannot always update the local/global rank indices when GPUs join and leave the cluster.
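A minimal sketch of that hierarchical averaging idea, using plain `torch.distributed` and assuming a fixed number of GPUs per node with node-contiguous rank assignment; all names and the grouping logic here are illustrative, not part of this PR:

```python
import torch
import torch.distributed as dist


def build_intra_node_group(gpus_per_node: int):
    """Create a sub process group containing only the ranks of the local node.

    Assumes the global process group is already initialized and ranks are
    assigned node by node (ranks 0..gpus_per_node-1 on node 0, and so on).
    dist.new_group() must be called by all ranks with the same arguments,
    so every rank iterates over all nodes and keeps only its own group.
    """
    rank = dist.get_rank()
    num_nodes = dist.get_world_size() // gpus_per_node
    own_group = None
    for node_idx in range(num_nodes):
        ranks = list(range(node_idx * gpus_per_node, (node_idx + 1) * gpus_per_node))
        group = dist.new_group(ranks=ranks)
        if node_idx == rank // gpus_per_node:
            own_group = group
    return own_group


def average_params(module: torch.nn.Module, group=None):
    """Average parameters over the given process group (global group if None)."""
    size = dist.get_world_size(group=group)
    for param in module.parameters():
        dist.all_reduce(param.data, op=dist.ReduceOp.SUM, group=group)
        param.data /= size
```

With that, intra-node averaging could run every step via `average_params(module, group=intra_node_group)`, while the global average (`group=None`) would only run every N steps.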
Ah yes, that's a good idea. Ok, this is very specific, so let's think about what API we need to allow for flexible user-defined schemes for this. I think a flexible way would be if the user just defines a custom step function:

```python
def custom_step_after_param_update(*, module: torch.nn.Module, epoch_step_idx: int, **_kwargs):
    ...


torch_distributed = {
    "reduce_type": "custom_step_after_param_update",
    "custom_step_after_param_update": custom_step_after_param_update,
}
```

That's all that is needed, right? If you need to know the global/local rank/size inside that custom step func, or any other environment information, that would be up to the user. E.g. the user can always do:

```python
from returnn.torch import distributed

rank = distributed.get_ctx().rank
```
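As a hedged illustration of what such a user-defined function could do, here is plain parameter averaging across all ranks every 100 steps; the interval and the function body are just an example, not part of the proposed API:

```python
import torch
import torch.distributed as dist


def custom_step_after_param_update(*, module: torch.nn.Module, epoch_step_idx: int, **_kwargs):
    """Example: average all parameters across all ranks every 100 train steps."""
    if epoch_step_idx % 100 != 0:
        return
    world_size = dist.get_world_size()
    for param in module.parameters():
        dist.all_reduce(param.data, op=dist.ReduceOp.SUM)
        param.data /= world_size
```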
Hmm, when would you set up the sub process groups needed for synchronization? E.g. on the first invocation of the function? In the class approach this is quite easy because you can initialize the class (and any sub process groups) right after the global process group is initialized, so the moment is very well defined. I'm not sure it's feasible to initialize them when the RETURNN config is parsed (i.e. by initializing a callable class like e.g. …).
Yes, that should work just fine, right?
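A sketch of that lazy setup as a callable class, assuming the global process group already exists by the time the first training step runs; the class name, arguments, and intervals are illustrative:

```python
import torch
import torch.distributed as dist


class HierarchicalParamAvg:
    """Callable usable as custom_step_after_param_update.

    The intra-node sub process group is created lazily on the first call,
    i.e. after the global process group exists, not at config parsing time.
    """

    def __init__(self, gpus_per_node: int, global_sync_step: int = 100):
        self.gpus_per_node = gpus_per_node
        self.global_sync_step = global_sync_step
        self._intra_node_group = None

    def _lazy_init(self):
        if self._intra_node_group is not None:
            return
        rank = dist.get_rank()
        num_nodes = dist.get_world_size() // self.gpus_per_node
        for node_idx in range(num_nodes):
            ranks = list(range(node_idx * self.gpus_per_node, (node_idx + 1) * self.gpus_per_node))
            group = dist.new_group(ranks=ranks)
            if node_idx == rank // self.gpus_per_node:
                self._intra_node_group = group

    def __call__(self, *, module: torch.nn.Module, epoch_step_idx: int, **_kwargs):
        self._lazy_init()
        # Average globally every global_sync_step steps, otherwise only within the node.
        group = None if epoch_step_idx % self.global_sync_step == 0 else self._intra_node_group
        size = dist.get_world_size(group=group)
        for param in module.parameters():
            dist.all_reduce(param.data, op=dist.ReduceOp.SUM, group=group)
            param.data /= size
```

The instance could then be passed as the `custom_step_after_param_update` value in the config, so nothing collective happens at config parsing time; the sub group is only created on the first call.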
Some update: in the future, I want to implement something similar to this: https://github.com/PrimeIntellect-ai/prime (or maybe just reuse the existing code there). Specifically, this includes … This is just an example of what to keep in mind when making this more flexible here.
To allow extending the existing param_avg and gradient_sync strategies with custom user-defined ones, for easier experimentation.