No overlapping observed when enabling Smart Scheduling #168
Comments
I will check it out. However, it looks like you missed setting the FMOE_FASTER_GROUP_SIZE variable.
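(As an aside, here is a minimal sketch of setting the variables mentioned in this thread; the group-size value and whether FastMoE picks the variables up when they are set from Python rather than on the launch command are assumptions, not documented behaviour.)

```python
# Illustrative only: FastMoE's smart schedule is configured through environment
# variables, which are normally set on the launch command, e.g.
#   FMOE_FASTER_SCHEDULE_ENABLE=1 FMOE_FASTER_GROUP_SIZE=4 torchrun ... example.py
# Setting them from Python before fmoe is imported may also work (assumption).
import os

os.environ.setdefault("FMOE_FASTER_SCHEDULE_ENABLE", "1")
os.environ.setdefault("FMOE_FASTER_GROUP_SIZE", "4")   # the value 4 is a placeholder
```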
This issue turns out to be caused by using the default CUDA stream, which synchronizes with all other streams. Simply using a separate stream in smgr for NCCL solves the problem. Credits to @Harry-Chen for finding the point. Looking forward to a pull request.
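(For illustration only: this is not FastMoE's actual smgr code, but a minimal PyTorch sketch of the pattern being described, where communication is issued on its own non-default stream so it can run concurrently with compute instead of serializing on the default stream. The `compute_fn` and `comm_fn` names are placeholders.)

```python
import torch

comm_stream = torch.cuda.Stream()  # dedicated stream for communication (e.g. NCCL)

def overlapped_step(compute_fn, comm_fn, x, msg):
    # Make the comm stream wait for any pending work that produced `msg`.
    comm_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(comm_stream):
        recv = comm_fn(msg)            # e.g. an all-to-all issued on the comm stream
    y = compute_fn(x)                  # expert compute runs on the default stream
    # Before consuming the communication result, re-synchronize the streams.
    torch.cuda.current_stream().wait_stream(comm_stream)
    return y, recv
```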
Hi @chenyu-jiang, I finally found some bugs. I've fixed them in this branch; maybe you can re-run your trace on it?
Hi @zms1999, extremely sorry for the (very) delayed response. After the fix, I can now see overlapping in the example program. Thanks a lot for the fix! It is tremendously helpful.
Sorry for bothering again, but I am still running into problems when running the above example code with SwitchGate (i.e., with the gate switched to SwitchGate). The error message is:
While if the code is run with
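(For context, here is a sketch of how the gate is typically swapped in FastMoE; the exact constructor arguments and the `gate` keyword are assumptions based on FastMoE's public API rather than taken from the attached code.)

```python
# Hypothetical illustration of the SwitchGate variant mentioned above.
from fmoe.gates import SwitchGate
from fmoe.transformer import FMoETransformerMLP

moe_layer = FMoETransformerMLP(
    num_expert=2,        # experts per worker (illustrative value)
    d_model=1024,
    d_hidden=4096,
    world_size=16,       # 2 nodes x 8 GPUs, as in this issue
    gate=SwitchGate,     # instead of the default NaiveGate
)
```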
I guess that you're right, because I've already fixed some synchronization bugs. There could be more; I will check next week.
Describe the bug
I am trying to create a minimal runnable example of the Smart Scheduling proposed in the FasterMoE paper. However, when I profile the example with Nsight Systems, there appears to be no overlap between the all-to-all communication and the expert computation.
Example of the profile result (one of the forward passes):
Judging by the CUDA API stack trace, it is indeed running the smart-schedule code path:
The code I used can be found below. Could you let me know whether this is caused by my misuse of FastMoE or by something else? Thanks.
To Reproduce
The test is done on 2 nodes, each with 8 V100 GPUs.
The code I used for the tests: example.py
Steps to reproduce the behavior:
1. In the pytorch/pytorch:2.0.1-cuda11.7-cudnn8-devel container, install FastMoE.
2. Run: FMOE_FASTER_SCHEDULE_ENABLE=1 torchrun --nnodes=2 --nproc-per-node=8 --rdzv-id=0 --rdzv-backend=c10d --rdzv-endpoint=xxx.xxx.xx.xx example.py
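Since example.py is attached rather than inlined, the following is only a rough, hypothetical stand-in for a comparable minimal test; the use of FMoETransformerMLP, the layer dimensions, and the expert count are illustrative assumptions, not the contents of the attached file.

```python
# Illustrative stand-in for the attached example.py (see assumptions above).
import os
import torch
import torch.distributed as dist
from fmoe.transformer import FMoETransformerMLP

def main():
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

    # A single MoE MLP layer spanning all workers; smart scheduling is toggled
    # by FMOE_FASTER_SCHEDULE_ENABLE=1 on the launch command above.
    layer = FMoETransformerMLP(
        num_expert=2,                      # experts per GPU (illustrative)
        d_model=1024,
        d_hidden=4096,
        world_size=dist.get_world_size(),
    ).cuda()

    x = torch.randn(8192, 1024, device="cuda")
    for _ in range(10):                    # a few iterations to capture in the profile
        y = layer(x)
        y.sum().backward()
        torch.cuda.synchronize()

if __name__ == "__main__":
    main()
```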
Expected behavior
Expert computation should overlap with the all-to-all communication.
Logs
N/A
Platform
Additional context
N/A