
Error when running examples/transformer-xl/scripts/run_enwik8_base_moe.sh with Smart Schedule enabled #207

Open
WhatBrain opened this issue Aug 30, 2024 · 6 comments

Comments

@WhatBrain

Describe the bug
When I set `export FMOE_FASTER_SHADOW_ENABLE=1` and `export FMOE_FASTER_SCHEDULE_ENABLE=1` to turn on Smart Schedule and run `bash examples/transformer-xl/scripts/run_enwik8_base_moe.sh train`, it reports an error:
```
Original Traceback (most recent call last):
  File "/data/anaconda3/envs/fast/lib/python3.9/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
    output = module(*input, **kwargs)
  File "/data/anaconda3/envs/fast/lib/python3.9/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/data/jiang/fastmoe-master/fastmoe-master/examples/transformer-xl/mem_transformer.py", line 801, in forward
    hidden, new_mems = self._forward(data, mems=mems)
  File "/data/jiang/fastmoe-master/fastmoe-master/examples/transformer-xl/mem_transformer.py", line 732, in _forward
    core_out = layer(core_out, pos_emb, self.r_w_bias,
  File "/data/anaconda3/envs/fast/lib/python3.9/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/data/jiang/fastmoe-master/fastmoe-master/examples/transformer-xl/mem_transformer.py", line 481, in forward
    output = self.pos_ff(output)
  File "/data/anaconda3/envs/fast/lib/python3.9/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/data/jiang/fastmoe-master/fastmoe-master/examples/transformer-xl/mem_transformer.py", line 403, in forward
    core_out = super().forward(inp)
  File "/data/anaconda3/envs/fast/lib/python3.9/site-packages/fastmoe-1.1.0-py3.9-linux-x86_64.egg/fmoe/transformer.py", line 65, in forward
    output = super().forward(inp)
  File "/data/anaconda3/envs/fast/lib/python3.9/site-packages/fastmoe-1.1.0-py3.9-linux-x86_64.egg/fmoe/layers.py", line 251, in forward
    fwd = _fmoe_general_global_forward(
  File "/data/anaconda3/envs/fast/lib/python3.9/site-packages/fastmoe-1.1.0-py3.9-linux-x86_64.egg/fmoe/fastermoe/schedule.py", line 136, in _fmoe_general_global_forward
    stored_models = policy_fn(local_expert_count, global_expert_count,
  File "/data/anaconda3/envs/fast/lib/python3.9/site-packages/fastmoe-1.1.0-py3.9-linux-x86_64.egg/fmoe/fastermoe/shadow_policy.py", line 27, in global_policy
    dist.all_gather(agecs, local_expert_count, group=moe_group)
  File "/data/anaconda3/envs/fast/lib/python3.9/site-packages/torch/distributed/distributed_c10d.py", line 1862, in all_gather
    default_pg = _get_default_group()
  File "/data/anaconda3/envs/fast/lib/python3.9/site-packages/torch/distributed/distributed_c10d.py", line 347, in _get_default_group
    raise RuntimeError("Default process group has not been initialized, "
RuntimeError: Default process group has not been initialized, please make sure to call init_process_group.
```

Do I need to add init_process_group to the code in the examples?
Platform

  • Device: NVIDIA A100
  • CUDA version: 11.1
  • NCCL version: 2.7.8
  • PyTorch version: 1.8.0


@laekov
Owner

laekov commented Sep 6, 2024

transformer-xl is a single-process MoE training example; it does not support the faster features.

@WhatBrain
Author

Thank you very much for your answer.

@WhatBrain
Author

I have another question: when I use FasterGate and turn on the faster features, the experts become extremely unbalanced.
The model is the MemTransformerLM from the original code, running on a single machine with 8 GPUs:

```
torchrun --nnodes=1 --nproc-per-node=8 ../train.py
```

init_process_group is called:

```python
torch.distributed.init_process_group(backend="nccl")
```

and the model is wrapped with DistributedGroupedDataParallel:

```python
para_model = fmoeDDP(model, dim=1)
```

At step 100, the numbers of tokens handled by the 64 experts of layer 0 are:

```
[ 77594 102442 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 2 0 0 0 0 0 2 11 13 0 0 0 0 6 0 0 0 0 0 0 0 0]
```
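For reference, a minimal sketch (a hypothetical helper, not part of the example scripts) of how such a per-expert token count can be derived from the gate's top-1 expert indices:

```python
import torch

def tokens_per_expert(top1_idx: torch.Tensor, tot_expert: int) -> torch.Tensor:
    # Count how many tokens the gate routed to each expert, given the
    # top-1 expert index chosen for every token.
    return torch.bincount(top1_idx.flatten(), minlength=tot_expert)

# e.g. counts = tokens_per_expert(top1_idx, tot_expert=64)
```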

@laekov
Owner

laekov commented Sep 12, 2024

faster gate has a balance loss. Is it possible that something went wrong when it was wired into the final loss?
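For reference, a rough sketch of how the gate balance loss could be collected and added to the task loss (assuming the gates expose a get_loss() method as in FastMoE's BaseGate; the weighting factor is a made-up hyper-parameter):

```python
from fmoe.gates.base_gate import BaseGate

def loss_with_balance(model, task_loss, balance_weight=0.01):
    # Sum the balance loss stored on every MoE gate in the model and
    # add it, scaled by a hypothetical weight, to the task loss.
    balance_loss = 0.0
    for module in model.modules():
        if isinstance(module, BaseGate):
            gate_loss = module.get_loss(clear=True)
            if gate_loss is not None:
                balance_loss = balance_loss + gate_loss
    return task_loss + balance_weight * balance_loss
```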

@WhatBrain
Author

The code in the FasterGate class:

```python
c_e = torch.scatter_add(
    torch.zeros(self.tot_expert, device=top1_idx.device),
    0,
    top1_idx,
    torch.ones_like(top1_idx, dtype=torch.float),
) / S
```

Is self.tot_expert the total number of experts per layer, or the number of experts multiplied by the world size? For example, with `--moe-num-expert 64` and a world size of 8, is self.tot_expert 512?

@laekov
Owner

laekov commented Sep 29, 2024

`self.tot_expert = world_size * num_expert`, see here
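A small worked example with the numbers from the question above (illustrative values only):

```python
num_expert = 64                        # --moe-num-expert 64
world_size = 8                         # 8 workers
tot_expert = world_size * num_expert   # 512
# top1_idx therefore holds global expert indices in [0, 512), and c_e
# is a length-512 vector of per-expert token fractions.
```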
