From 682633802062602737caf2783969f7c6c59c520f Mon Sep 17 00:00:00 2001 From: BingyangWu Date: Thu, 2 Jan 2025 11:56:38 +0800 Subject: [PATCH] feat(common.py): update flops func --- internlm/utils/common.py | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/internlm/utils/common.py b/internlm/utils/common.py index 1ecb7ea0..08fd09b6 100644 --- a/internlm/utils/common.py +++ b/internlm/utils/common.py @@ -212,25 +212,18 @@ def get_megatron_flops( Calc flops based on the paper of Megatron https://deepakn94.github.io/assets/papers/megatron-sc21.pdf """ - checkpoint_activations_factor = 4 if checkpoint else 3 - if checkpoint: - attn_checkpoint_activation_factor = 3 if selective_checkpoint else 4 - else: - attn_checkpoint_activation_factor = 3 - - if use_swiglu: - mlp_ratio = mlp_ratio * 3 / 2 + checkpoint_activations_factor = 3 + attn_checkpoint_activation_factor = 3 flops_per_iteration = ( # wqkv wo mlp - (checkpoint_activations_factor * ((8 + mlp_ratio * 4) * global_batch_size * seq_len * hidden_size**2)) + (checkpoint_activations_factor * ((8 + mlp_ratio * 6) * global_batch_size * seq_len * hidden_size**2)) * num_layers # attn - + attn_checkpoint_activation_factor * (4 * global_batch_size * seq_len**2 * hidden_size) * num_layers + + attn_checkpoint_activation_factor * (4 * global_batch_size * seq_len**2 * hidden_size) * num_layers / 2 # head + 6 * global_batch_size * seq_len * hidden_size * vocab_size ) - tflops = flops_per_iteration / (elapsed_time_per_iter * global_world_size * (10**12)) return tflops