[BugFix] fix_outputlayer.weight_distributed #9135
base: develop
Conversation
Thanks for your contribution!
@@ -1126,7 +1125,10 @@ def forward(self, hidden_states, return_last_logit=False):
         if self.config.sequence_parallel:
             hidden_states = GatherOp.apply(hidden_states)
             hidden_states = paddle.reshape_(hidden_states, [self.config.seq_length, -1, self.config.hidden_size])
-        logits = parallel_matmul(hidden_states, self.decoder_weight, self.config.tensor_parallel_output)
+        if self.config.tensor_parallel_degree > 1:
+            logits = parallel_matmul(hidden_states, self.weight, self.config.tensor_parallel_output)
Why do a parallel_matmul here? As far as I can tell, only the non-TP case enters this forward.
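For context on the reviewer's question, here is a rough, hypothetical sketch of what a tensor-parallel logits matmul like parallel_matmul typically does (this is an illustration of the idea, not PaddleNLP's actual implementation):

# Hypothetical sketch of a tensor-parallel logits matmul; an illustration
# of the idea only, not PaddleNLP's actual parallel_matmul.
import paddle
import paddle.distributed as dist
from paddle.distributed import fleet

def parallel_matmul_sketch(x, weight, tensor_parallel_output=True):
    # Each rank holds a vocab shard of the weight: [hidden, vocab / mp_degree],
    # so the local matmul produces only this rank's slice of the logits.
    logits = paddle.matmul(x, weight)
    hcg = fleet.get_hybrid_communicate_group()
    mp_group = hcg.get_model_parallel_group()
    if mp_group.nranks > 1 and not tensor_parallel_output:
        # Recover full-vocab logits on every rank by gathering the shards
        # and concatenating along the vocabulary axis.
        shards = []
        dist.all_gather(shards, logits, group=mp_group)
        logits = paddle.concat(shards, axis=-1)
    return logits

Under this reading, calling it in a forward path that only non-TP runs reach would be redundant, which is what the comment above is questioning.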
@@ -1238,9 +1240,9 @@ def forward(
         lm_logits = parallel_matmul(
Why manually call parallel_matmul here? The inner forward function already supports the tp > 1 case; it would be better to settle on a single layer at which parallel_matmul is applied.
Codecov Report
Attention: Patch coverage is …
Additional details and impacted files

@@            Coverage Diff             @@
##           develop    #9135      +/-   ##
===========================================
- Coverage    53.29%   53.28%   -0.02%
===========================================
  Files          652      652
  Lines       105483   105579      +96
===========================================
+ Hits         56222    56254      +32
- Misses       49261    49325      +64

☔ View full report in Codecov by Sentry.
PR types
Bug fixes
PR changes
Models
Description
When parameters are sharded manually, is_distributed must be set to True, so that the parameter broadcast in distributed_model skips these sharded parameters.
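As a minimal sketch of that mechanism (class name and shapes here are illustrative, not this PR's actual code):

# Hypothetical sketch: each tensor-parallel rank creates its own vocab shard
# of the output-layer weight, then marks it so fleet.distributed_model()
# skips it when broadcasting parameters from rank 0.
import paddle
from paddle import nn

class OutputLayerSketch(nn.Layer):
    def __init__(self, hidden_size, vocab_size, tensor_parallel_degree):
        super().__init__()
        vocab_shard = vocab_size // tensor_parallel_degree
        self.weight = self.create_parameter(
            shape=[hidden_size, vocab_shard],
            dtype=paddle.get_default_dtype(),
        )
        if tensor_parallel_degree > 1:
            # Without this flag, the broadcast would overwrite every rank's
            # shard with rank 0's values, corrupting the sharded layout.
            self.weight.is_distributed = True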
Initial loss drops from 11 to 1.63.