Skip to content

Commit

Permalink
mamba flop calculations
Browse files Browse the repository at this point in the history
  • Loading branch information
jahatef authored and Quentin-Anthony committed Sep 23, 2024
1 parent 62c9738 commit e0c6d32
Showing 1 changed file with 18 additions and 0 deletions.
18 changes: 18 additions & 0 deletions megatron/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from megatron import mpu, print_rank_0
from megatron.utils import report_memory
import math


class Tee:
Expand Down Expand Up @@ -106,6 +107,23 @@ def get_flops(neox_args, iter_time_s) -> float:
+ 18 * hidden_size * hidden_size * num_layers / num_heads
)
)
elif "mamba" in neox_args.attention_config:
if neox_args.expansion_factor:
d_inner = neox_args.hidden_size * neox_args.expansion_factor
elif neox_args.intermediate_size:
d_inner = neox_args.intermediate_size
else:
d_inner = neox_args.hidden_size * 2 # default expansion factor
d_state = 16 # d_state is hard coded
conv_dimension = 4 # conv_dimension is hard coded
dt_rank = math.ceil(neox_args.hidden_size / 16)
ssm_flops = ckpt_activations_factor * d_inner * seq_len * batch_size * (11 * d_state + 4 * dt_rank + 1)
mamba_projectors_flops = ckpt_activations_factor * seq_len * batch_size * 6 * d_inner * hidden_size
mamba_conv_flops = ckpt_activations_factor * seq_len * batch_size * 2 * d_inner * conv_dimension
mamba_flops = ssm_flops + mamba_projectors_flops + mamba_conv_flops
flops_per_iteration = mamba_flops
embedding_flops = 6 * seq_len * batch_size * hidden_size * vocab_size
flops_per_iteration = mamba_flops * num_layers + embedding_flops
else:
flops_per_iteration = (
24
Expand Down

0 comments on commit e0c6d32

Please sign in to comment.