diff --git a/calc/calc_transformer_flops.py b/calc/calc_transformer_flops.py
index d9f1cd2..32866c4 100644
--- a/calc/calc_transformer_flops.py
+++ b/calc/calc_transformer_flops.py
@@ -36,6 +36,10 @@ def config_parser():
                         type=float,
                         default=1.0,
                         help='Ratio of kv heads to query heads used in model. 1.0 for MHA')
+    parser.add_argument("--ffn-expansion-factor", "-ff",
+                        type=int,
+                        default=4,
+                        help='How much the MLP hidden size expands')
     parser.add_argument("--moe",
                         action="store_true",
                         help='Whether our model is MoE')
@@ -102,9 +106,9 @@ def calc_flops(args):
         attention_matrix_flops = iter_factor * 2 * args.num_layers * args.tokens * args.sequence_length * args.hidden_size
         attention_over_values_flops = iter_factor * 2 * args.num_layers * args.tokens * args.sequence_length * args.hidden_size
         linear_projection_flops = iter_factor * 2 * args.num_layers * args.tokens * args.hidden_size * args.hidden_size
-        ffn_flops = iter_factor * 2 * args.num_mlp_linears * args.ffn_expansion_factor * args.num_layers * args.tokens * args.hidden_size * args.hidden_size
+        ffn_flops = int(iter_factor * 2 * args.num_mlp_linears * args.ffn_expansion_factor) * args.num_layers * args.tokens * args.hidden_size * args.hidden_size
         if args.swiglu:
-            ffn_flops = 3/2 * ffn_flops
+            ffn_flops = int(3/2 * ffn_flops)
         # no activation checkpointing for embeddings
         embedding_flops = 6 * args.tokens * args.hidden_size * args.vocab_size
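
For reference, a minimal standalone sketch of the dense-model FFN FLOPs term as computed after this patch. The model-shape values below are illustrative assumptions, not values taken from the script, except ffn_expansion_factor=4, which mirrors the new flag's default.

# Sketch of the FFN FLOPs term; all model-shape values are assumed for illustration.
iter_factor = 3                    # assumed: forward + backward, no activation checkpointing
num_mlp_linears = 2                # assumed: two linear layers per MLP block
ffn_expansion_factor = 4           # default of the new --ffn-expansion-factor flag
num_layers = 32                    # assumed model shape
hidden_size = 4096                 # assumed model shape
tokens = 300_000_000_000           # assumed training tokens
swiglu = True

ffn_flops = int(iter_factor * 2 * num_mlp_linears * ffn_expansion_factor) \
    * num_layers * tokens * hidden_size * hidden_size
if swiglu:
    # SwiGLU adds a third projection of the same size, hence the 3/2 scaling;
    # int() keeps the result integral after the float division.
    ffn_flops = int(3 / 2 * ffn_flops)

print(f"FFN FLOPs: {ffn_flops:.3e}")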