Skip to content

Commit

Permalink
feat(common.py): update flops func
Browse files Browse the repository at this point in the history
  • Loading branch information
BingyangWu committed Jan 2, 2025
1 parent 526503b commit 6826338
Showing 1 changed file with 4 additions and 11 deletions.
15 changes: 4 additions & 11 deletions internlm/utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,25 +212,18 @@ def get_megatron_flops(
Calc flops based on the paper of Megatron https://deepakn94.github.io/assets/papers/megatron-sc21.pdf
"""

checkpoint_activations_factor = 4 if checkpoint else 3
if checkpoint:
attn_checkpoint_activation_factor = 3 if selective_checkpoint else 4
else:
attn_checkpoint_activation_factor = 3

if use_swiglu:
mlp_ratio = mlp_ratio * 3 / 2
checkpoint_activations_factor = 3
attn_checkpoint_activation_factor = 3

flops_per_iteration = (
# wqkv wo mlp
(checkpoint_activations_factor * ((8 + mlp_ratio * 4) * global_batch_size * seq_len * hidden_size**2))
(checkpoint_activations_factor * ((8 + mlp_ratio * 6) * global_batch_size * seq_len * hidden_size**2))
* num_layers
# attn
+ attn_checkpoint_activation_factor * (4 * global_batch_size * seq_len**2 * hidden_size) * num_layers
+ attn_checkpoint_activation_factor * (4 * global_batch_size * seq_len**2 * hidden_size) * num_layers / 2
# head
+ 6 * global_batch_size * seq_len * hidden_size * vocab_size
)

tflops = flops_per_iteration / (elapsed_time_per_iter * global_world_size * (10**12))
return tflops

Expand Down

0 comments on commit 6826338

Please sign in to comment.