beef up comments and remove useless line
Quentin-Anthony committed Sep 23, 2024
1 parent e0c6d32 commit 4ad542a
Showing 1 changed file with 22 additions and 7 deletions.

megatron/logging.py
@@ -108,20 +108,35 @@ def get_flops(neox_args, iter_time_s) -> float:
             )
         )
     elif "mamba" in neox_args.attention_config:
         # from https://github.com/Zyphra/zcookbook/blob/main/calc/calc_mamba_flops.py
         if neox_args.expansion_factor:
             d_inner = neox_args.hidden_size * neox_args.expansion_factor
         elif neox_args.intermediate_size:
             d_inner = neox_args.intermediate_size
         else:
             d_inner = neox_args.hidden_size * 2  # default expansion factor
-        d_state = 16  # d_state is hard coded
-        conv_dimension = 4  # conv_dimension is hard coded
-        d_inner = neox_args.hidden_size * 2  # default expansion factor
+        d_state = 16  # TODO make d_state an arg. Currently hardcoded in neox mamba definition and here
+        conv_dimension = 4  # TODO make conv_dimension an arg. Currently hardcoded in neox mamba definition and here
         dt_rank = math.ceil(neox_args.hidden_size / 16)
-        ssm_flops = ckpt_activations_factor * d_inner * seq_len * batch_size * (11 * d_state + 4 * dt_rank + 1)
-        mamba_projectors_flops = ckpt_activations_factor * seq_len * batch_size * 6 * d_inner * hidden_size
-        mamba_conv_flops = ckpt_activations_factor * seq_len * batch_size * 2 * d_inner * conv_dimension
+        ssm_flops = (
+            ckpt_activations_factor
+            * d_inner
+            * seq_len
+            * batch_size
+            * (11 * d_state + 4 * dt_rank + 1)
+        )
+        mamba_projectors_flops = (
+            ckpt_activations_factor * seq_len * batch_size * 6 * d_inner * hidden_size
+        )
+        mamba_conv_flops = (
+            ckpt_activations_factor
+            * seq_len
+            * batch_size
+            * 2
+            * d_inner
+            * conv_dimension
+        )
         mamba_flops = ssm_flops + mamba_projectors_flops + mamba_conv_flops
-        flops_per_iteration = mamba_flops
+        embedding_flops = 6 * seq_len * batch_size * hidden_size * vocab_size
+        flops_per_iteration = mamba_flops * num_layers + embedding_flops
     else:

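For reference, the post-commit estimate can be lifted into a standalone helper. The sketch below mirrors the formulas in the diff; the function name, the example config values, and the 3-vs-4 ckpt_activations_factor convention (3 FLOPs per parameter-token for forward plus backward, 4 when activation checkpointing recomputes the forward pass) are assumptions read off the surrounding get_flops code, not part of this commit.

import math

def mamba_flops_per_iteration(
    hidden_size,                # neox_args.hidden_size
    num_layers,
    vocab_size,
    seq_len,
    batch_size,
    expansion_factor=2,         # default when neither expansion_factor nor intermediate_size is set
    d_state=16,                 # hardcoded in the neox mamba definition
    conv_dimension=4,           # hardcoded in the neox mamba definition
    checkpoint_activations=False,
):
    # Assumed convention: 3 = forward + backward; 4 adds the recomputed
    # forward pass under activation checkpointing.
    ckpt_activations_factor = 4 if checkpoint_activations else 3
    d_inner = hidden_size * expansion_factor
    dt_rank = math.ceil(hidden_size / 16)
    tokens = seq_len * batch_size  # tokens processed per iteration
    # Per-layer terms, as in the diff above.
    ssm_flops = ckpt_activations_factor * d_inner * tokens * (
        11 * d_state + 4 * dt_rank + 1
    )
    mamba_projectors_flops = (
        ckpt_activations_factor * tokens * 6 * d_inner * hidden_size
    )
    mamba_conv_flops = (
        ckpt_activations_factor * tokens * 2 * d_inner * conv_dimension
    )
    mamba_flops = ssm_flops + mamba_projectors_flops + mamba_conv_flops
    # Embedding/unembedding term, counted once for the whole model.
    embedding_flops = 6 * tokens * hidden_size * vocab_size
    return mamba_flops * num_layers + embedding_flops

# Illustrative call with hypothetical hyperparameters:
print(mamba_flops_per_iteration(
    hidden_size=1536, num_layers=48, vocab_size=50304,
    seq_len=2048, batch_size=32,
))

get_flops itself takes iter_time_s, so the repo presumably divides an estimate like this by iteration time (and world size) to report achieved TFLOPS; the sketch stops at raw FLOPs per iteration.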