From b2c1b014401fbe5049357e2fd03188f809c6ca77 Mon Sep 17 00:00:00 2001
From: jahatef
Date: Sun, 22 Sep 2024 19:49:38 -0400
Subject: [PATCH] mamba flop calculations

---
 megatron/logging.py | 18 ++++++++++++++++++
 1 file changed, 18 insertions(+)

diff --git a/megatron/logging.py b/megatron/logging.py
index 05945fdda..4942cfb04 100644
--- a/megatron/logging.py
+++ b/megatron/logging.py
@@ -23,6 +23,7 @@
 
 from megatron import mpu, print_rank_0
 from megatron.utils import report_memory
+import math
 
 
 class Tee:
@@ -106,6 +107,23 @@ def get_flops(neox_args, iter_time_s) -> float:
                 + 18 * hidden_size * hidden_size * num_layers / num_heads
             )
         )
+    elif "mamba" in neox_args.attention_config:
+        if neox_args.expansion_factor:
+            d_inner = neox_args.hidden_size * neox_args.expansion_factor
+        elif neox_args.intermediate_size:
+            d_inner = neox_args.intermediate_size
+        else:
+            d_inner = neox_args.hidden_size * 2  # default expansion factor
+        d_state = 16  # d_state is hard coded
+        conv_dimension = 4  # conv_dimension is hard coded
+        dt_rank = math.ceil(neox_args.hidden_size / 16)
+        ssm_flops = ckpt_activations_factor * d_inner * seq_len * batch_size * (11 * d_state + 4 * dt_rank + 1)
+        mamba_projectors_flops = ckpt_activations_factor * seq_len * batch_size * 6 * d_inner * hidden_size
+        mamba_conv_flops = ckpt_activations_factor * seq_len * batch_size * 2 * d_inner * conv_dimension
+        mamba_flops = ssm_flops + mamba_projectors_flops + mamba_conv_flops
+        flops_per_iteration = mamba_flops
+        embedding_flops = 6 * seq_len * batch_size * hidden_size * vocab_size
+        flops_per_iteration = mamba_flops * num_layers + embedding_flops
     else:
         flops_per_iteration = (
             24
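
Note (not part of the patch): below is a minimal standalone sketch of the per-iteration FLOPs the new "mamba" branch computes, useful for sanity-checking the formula outside of training. The SimpleNamespace stand-in for neox_args and the example hyperparameters are hypothetical and purely illustrative; d_state=16, conv_dimension=4, and the ckpt_activations_factor of 4-vs-3 mirror the values used in get_flops().

    # Standalone sanity check of the mamba FLOPs formula added by this patch.
    import math
    from types import SimpleNamespace

    # Hypothetical Mamba config (illustrative only, not taken from the patch).
    neox_args = SimpleNamespace(
        hidden_size=768,
        num_layers=24,
        seq_length=2048,
        train_batch_size=32,
        padded_vocab_size=50304,
        expansion_factor=2,
        intermediate_size=None,
        checkpoint_activations=True,
    )

    batch_size = neox_args.train_batch_size
    seq_len = neox_args.seq_length
    hidden_size = neox_args.hidden_size
    num_layers = neox_args.num_layers
    vocab_size = neox_args.padded_vocab_size
    # 3x for forward + backward, 4x when activations are recomputed under checkpointing.
    ckpt_activations_factor = 4 if neox_args.checkpoint_activations else 3

    # Inner dimension, falling back to 2x hidden_size as in the patch.
    if neox_args.expansion_factor:
        d_inner = neox_args.hidden_size * neox_args.expansion_factor
    elif neox_args.intermediate_size:
        d_inner = neox_args.intermediate_size
    else:
        d_inner = neox_args.hidden_size * 2
    d_state = 16        # hard coded, as in the patch
    conv_dimension = 4  # hard coded, as in the patch
    dt_rank = math.ceil(neox_args.hidden_size / 16)

    # Per-layer terms: selective-scan (SSM), projections, and causal conv, per the patch's formula.
    ssm_flops = ckpt_activations_factor * d_inner * seq_len * batch_size * (11 * d_state + 4 * dt_rank + 1)
    mamba_projectors_flops = ckpt_activations_factor * seq_len * batch_size * 6 * d_inner * hidden_size
    mamba_conv_flops = ckpt_activations_factor * seq_len * batch_size * 2 * d_inner * conv_dimension
    mamba_flops = ssm_flops + mamba_projectors_flops + mamba_conv_flops

    # Total: per-layer cost times num_layers plus the embedding/output-head term.
    embedding_flops = 6 * seq_len * batch_size * hidden_size * vocab_size
    flops_per_iteration = mamba_flops * num_layers + embedding_flops
    print(f"{flops_per_iteration / 1e12:.1f} TFLOPs per iteration")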