From 4ad542a47793f7daca4ce7217c8230510fcc9bc5 Mon Sep 17 00:00:00 2001
From: Quentin Anthony
Date: Mon, 23 Sep 2024 22:46:06 +0000
Subject: [PATCH] beef up comments and remove useless line

---
 megatron/logging.py | 29 ++++++++++++++++++++++-------
 1 file changed, 22 insertions(+), 7 deletions(-)

diff --git a/megatron/logging.py b/megatron/logging.py
index 4942cfb04..af8a41fe5 100644
--- a/megatron/logging.py
+++ b/megatron/logging.py
@@ -108,20 +108,35 @@ def get_flops(neox_args, iter_time_s) -> float:
             )
         )
     elif "mamba" in neox_args.attention_config:
+        # from https://github.com/Zyphra/zcookbook/blob/main/calc/calc_mamba_flops.py
         if neox_args.expansion_factor:
             d_inner = neox_args.hidden_size * neox_args.expansion_factor
         elif neox_args.intermediate_size:
             d_inner = neox_args.intermediate_size
         else:
-            d_inner = neox_args.hidden_size * 2 # default expansion factor
-        d_state = 16 # d_state is hard coded
-        conv_dimension = 4 # conv_dimension is hard coded
+            d_inner = neox_args.hidden_size * 2  # default expansion factor
+        d_state = 16  # TODO make d_state an arg. Currently hardcoded in neox mamba definition and here
+        conv_dimension = 4  # TODO make conv_dimension an arg. Currently hardcoded in neox mamba definition and here
         dt_rank = math.ceil(neox_args.hidden_size / 16)
-        ssm_flops = ckpt_activations_factor * d_inner * seq_len * batch_size * (11 * d_state + 4 * dt_rank + 1)
-        mamba_projectors_flops = ckpt_activations_factor * seq_len * batch_size * 6 * d_inner * hidden_size
-        mamba_conv_flops = ckpt_activations_factor * seq_len * batch_size * 2 * d_inner * conv_dimension
+        ssm_flops = (
+            ckpt_activations_factor
+            * d_inner
+            * seq_len
+            * batch_size
+            * (11 * d_state + 4 * dt_rank + 1)
+        )
+        mamba_projectors_flops = (
+            ckpt_activations_factor * seq_len * batch_size * 6 * d_inner * hidden_size
+        )
+        mamba_conv_flops = (
+            ckpt_activations_factor
+            * seq_len
+            * batch_size
+            * 2
+            * d_inner
+            * conv_dimension
+        )
         mamba_flops = ssm_flops + mamba_projectors_flops + mamba_conv_flops
-        flops_per_iteration = mamba_flops
         embedding_flops = 6 * seq_len * batch_size * hidden_size * vocab_size
         flops_per_iteration = mamba_flops * num_layers + embedding_flops
     else:
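
For reference, a minimal standalone sketch of the Mamba FLOPs estimate applied by the hunk above. The dimension values (hidden_size=768, num_layers=12, seq_len=2048, batch_size=8, vocab_size=50304) and ckpt_activations_factor=4 are illustrative assumptions, not values from the repository; in megatron/logging.py they come from neox_args and the activation-checkpointing setting.

# Standalone sketch of the Mamba FLOPs formula from this patch.
# All dimension values below are illustrative assumptions.
import math

hidden_size = 768
num_layers = 12
seq_len = 2048
batch_size = 8
vocab_size = 50304
ckpt_activations_factor = 4  # assumed value; depends on activation checkpointing

d_inner = hidden_size * 2  # default expansion factor of 2
d_state = 16  # hardcoded, matching the neox mamba definition
conv_dimension = 4  # hardcoded, matching the neox mamba definition
dt_rank = math.ceil(hidden_size / 16)

ssm_flops = (
    ckpt_activations_factor
    * d_inner
    * seq_len
    * batch_size
    * (11 * d_state + 4 * dt_rank + 1)
)
mamba_projectors_flops = (
    ckpt_activations_factor * seq_len * batch_size * 6 * d_inner * hidden_size
)
mamba_conv_flops = (
    ckpt_activations_factor * seq_len * batch_size * 2 * d_inner * conv_dimension
)
mamba_flops = ssm_flops + mamba_projectors_flops + mamba_conv_flops

# Embedding FLOPs are added once per iteration, not per layer.
embedding_flops = 6 * seq_len * batch_size * hidden_size * vocab_size
flops_per_iteration = mamba_flops * num_layers + embedding_flops
print(f"Estimated FLOPs per iteration: {flops_per_iteration:.3e}")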