-
Notifications
You must be signed in to change notification settings - Fork 2.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Unified Checkpoint] Support expert parallel #9055
Conversation
Thanks for your contribution! |
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## develop #9055 +/- ##
===========================================
- Coverage 53.29% 53.24% -0.06%
===========================================
Files 652 652
Lines 105483 105599 +116
===========================================
+ Hits 56222 56225 +3
- Misses 49261 49374 +113 ☔ View full report in Codecov by Sentry. |
f5b02c3
to
cc5463c
Compare
…nto support_ep
ready for review? |
yes, come!!! |
try: | ||
from paddle.base import core | ||
except: | ||
core = None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个是为了 预测版本 paddle?
expected_keys = set() | ||
for key in model_state_dict.keys(): | ||
if getattr(model_state_dict[key], "no_sync", False): | ||
expected_keys.add(key) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
非no_sync的QKV参数,是通过额外的广播加载? 模型,opt,master weight各是怎样?
@@ -1982,16 +2083,58 @@ def gather_sharded_object(index_file, total_size, is_optimizer=False): | |||
return index_file_list, total_size_list | |||
|
|||
|
|||
def rename_shard_file(args, shard_file, file_name): | |||
"""rename shard file when using expert_parallel.""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
注释,解释一下?这里是还是按照原来的方式命名,修改一下? 为什么不是moe从新写一个?
def generate_base_static_name(vname): | ||
# return base static name and specific type name, like [embedding_0.w_0, moment1_0] | ||
if FP32_MASTER in vname: | ||
vname = vname.split("_" + FP32_MASTER + "_") | ||
return vname[0], vname[1] | ||
else: | ||
vname = vname.split(".") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
原来的代码,我看不太明白了。等价吗?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
embedding_0.w_0.moment1_0
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
moe_gate_1_moment1_0
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
* [Unified Checkpoint] Support expert parallel (#9055) * update code * [Unified Checkpoint] Fix generation config save (#9223) * [Unified Checkpoint] update async_save_info in develop (#9173) * [Unified Checkpoint] update async save logic (#9274) * update async save signal * fix async save hang * bug fix --------- Co-authored-by: Weiguo Zhu <[email protected]>
* [Unified Checkpoint] Support expert parallel (#9055) * update code * [Unified Checkpoint] Fix generation config save (#9223) * [Unified Checkpoint] update async_save_info in develop (#9173) * [Unified Checkpoint] update async save logic (#9274) * update async save signal * fix async save hang * bug fix * bug fix * [Trainer] fix save_model (#9286) * bug fix * bug fix --------- Co-authored-by: Weiguo Zhu <[email protected]>
PR types
New features
PR changes
Others
Description
Support expert parallel for unified checkpoint.