Skip to content

Commit

Permalink
mamba flop calculations (#1291)
Browse files Browse the repository at this point in the history
* mamba flop calculations

* mamba flop calculations

* beef up comments and remove useless line

* undo precommit change

---------

Co-authored-by: Quentin Anthony <[email protected]>
  • Loading branch information
jahatef and Quentin-Anthony authored Sep 23, 2024
1 parent 4765384 commit 1bce90c
Showing 1 changed file with 33 additions and 0 deletions.
33 changes: 33 additions & 0 deletions megatron/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from megatron import mpu, print_rank_0
from megatron.utils import report_memory
import math


class Tee:
Expand Down Expand Up @@ -106,6 +107,38 @@ def get_flops(neox_args, iter_time_s) -> float:
+ 18 * hidden_size * hidden_size * num_layers / num_heads
)
)
elif "mamba" in neox_args.attention_config:
# from https://github.com/Zyphra/zcookbook/blob/main/calc/calc_mamba_flops.py
if neox_args.expansion_factor:
d_inner = neox_args.hidden_size * neox_args.expansion_factor
elif neox_args.intermediate_size:
d_inner = neox_args.intermediate_size
else:
d_inner = neox_args.hidden_size * 2 # default expansion factor
d_state = 16 # TODO make d_state an arg. Currently hardcoded in neox mamba definition and here
conv_dimension = 4 # TODO make conv_dimension an arg. Currently hardcoded in neox mamba definition and here
dt_rank = math.ceil(neox_args.hidden_size / 16)
ssm_flops = (
ckpt_activations_factor
* d_inner
* seq_len
* batch_size
* (11 * d_state + 4 * dt_rank + 1)
)
mamba_projectors_flops = (
ckpt_activations_factor * seq_len * batch_size * 6 * d_inner * hidden_size
)
mamba_conv_flops = (
ckpt_activations_factor
* seq_len
* batch_size
* 2
* d_inner
* conv_dimension
)
mamba_flops = ssm_flops + mamba_projectors_flops + mamba_conv_flops
embedding_flops = 6 * seq_len * batch_size * hidden_size * vocab_size
flops_per_iteration = mamba_flops * num_layers + embedding_flops
else:
flops_per_iteration = (
24
Expand Down

0 comments on commit 1bce90c

Please sign in to comment.