Skip to content

Commit

Permalink
Fix bug for non-first microbatch in TE
Browse files Browse the repository at this point in the history
  • Loading branch information
tocean committed Feb 21, 2024
1 parent 0a28e0f commit eadca11
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 2 deletions.
2 changes: 1 addition & 1 deletion msamp/te/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def set_fp8_weights(self):
weight_cast_attr = f'weight{i}_fp8'
weight_transpose_attr = f'weight{i}_t_fp8'

- if (hasattr(self, weight_cast_attr) and getattr(self, weight_cast_attr).shape == shape):
+ if (hasattr(self, weight_cast_attr) and getattr(self, weight_cast_attr)._data.shape == shape):
return

setattr(
Expand Down
5 changes: 4 additions & 1 deletion tests/te/test_te_replacer.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,5 +164,8 @@ def test_fp8_ddp_with_te(self):
fp8_format = Format.HYBRID
fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo='max')
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
- output = model(x, attention_mask=None)
+ output = model(x, attention_mask=None, is_first_microbatch=True)
output.sum().backward()
+ with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
+ output = model(x, attention_mask=None, is_first_microbatch=False)
+ output.sum().backward()

0 comments on commit eadca11

Please sign in to comment.