Improve FLOPS Calculation (#1044)
* Use Megatron-DeepSpeed flops calculation

Signed-off-by: Dashiell Stander <[email protected]>

* Use Megatron-DeepSpeed flops calculation

Signed-off-by: Dashiell Stander <[email protected]>

* Update NeoXArgs docs automatically

* Update NeoXArgs docs automatically

* Direct comparison of FLOPS calculations

Signed-off-by: Dashiell Stander <[email protected]>

* Remove test logging

Signed-off-by: Dashiell Stander <[email protected]>

---------

Signed-off-by: Dashiell Stander <[email protected]>
Co-authored-by: github-actions <[email protected]>
Co-authored-by: Quentin Anthony <[email protected]>
3 people authored Sep 27, 2023
1 parent ba51ca0 commit 5f36401
Showing 1 changed file with 25 additions and 12 deletions.
megatron/logging.py
@@ -79,17 +79,31 @@ def human_readable_flops(num) -> str:
     return "%.1f%s" % (num, "Yi")
 
 
-def get_flops(neox_args, model, iter_time_s) -> float:
+def get_flops(neox_args, iter_time_s) -> float:
+    """
+    Use FLOPS calculation from Megatron-DeepSpeed:
+    https://github.com/microsoft/Megatron-DeepSpeed/blob/cc3a94c636789f74be2bc6cfc62a3d723fd5d749/megatron/utils.py#L253
+    They get it from https://arxiv.org/pdf/2104.04473.pdf
+    """
     world_size = torch.distributed.get_world_size()
-    ff = model.total_params * 6
-    attn = neox_args.seq_length * neox_args.hidden_size * neox_args.num_layers * 60
-    flops = (
-        neox_args.train_batch_size
-        * neox_args.seq_length
-        * (ff + attn)
-        / (iter_time_s * world_size)
+    vocab_size = neox_args.padded_vocab_size
+    batch_size = neox_args.train_batch_size
+    seq_len = neox_args.seq_length
+    hidden_size = neox_args.hidden_size
+    num_layers = neox_args.num_layers
+    ckpt_activations_factor = 4 if neox_args.checkpoint_activations else 3
+    flops_calc1 = (
+        24
+        * ckpt_activations_factor
+        * batch_size
+        * seq_len
+        * num_layers
+        * (hidden_size**2)
+        * (1.0 + (seq_len / (6.0 * hidden_size)))
     )
+    flops_calc2 = vocab_size / (16.0 * num_layers * hidden_size)
+    flops_per_iteration = flops_calc1 + flops_calc2
+    return flops_per_iteration / (iter_time_s * world_size)
 
 
 def training_log(
@@ -314,9 +328,8 @@ def add_to_logging(name):
         )
 
         # log tflop / gpu
-        flops_per_s_per_gpu = get_flops(
-            neox_args=neox_args, model=model, iter_time_s=iteration_time
-        )
+        flops_per_s_per_gpu = get_flops(neox_args, iteration_time)
+
         log_string += (
             f" approx flops per GPU: {human_readable_flops(flops_per_s_per_gpu)} |"
         )
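Note on the formula: the calculation adopted here comes from the Megatron-LM paper (https://arxiv.org/pdf/2104.04473.pdf), which estimates the FLOPs of one forward+backward iteration of a GPT-style transformer as 24 * ckpt_factor * B * s * L * h^2 * (1 + s/(6h) + V/(16*L*h)), where B is the batch size, s the sequence length, L the number of layers, h the hidden size, V the padded vocabulary size, and ckpt_factor is 4 with activation checkpointing (one extra forward pass for recomputation) or 3 without. In the paper's grouping, the vocabulary term V/(16*L*h) carries the same 24 * ckpt_factor * B * s * L * h^2 prefactor as the other two terms. The sketch below is not part of the commit; it reimplements the paper's formula standalone, with a hypothetical GPT-NeoX-20B-like configuration and assumed timing numbers chosen only for illustration.

def megatron_flops_per_iteration(
    batch_size, seq_len, num_layers, hidden_size, vocab_size,
    checkpoint_activations=True,
):
    """FLOPs of one fwd+bwd pass of a GPT-style transformer (arXiv:2104.04473)."""
    # fwd+bwd costs ~3x a forward pass; activation recomputation adds one more forward.
    ckpt_factor = 4 if checkpoint_activations else 3
    prefactor = 24 * ckpt_factor * batch_size * seq_len * num_layers * hidden_size**2
    # All three terms share the same prefactor in the paper's formulation.
    return prefactor * (
        1.0
        + seq_len / (6.0 * hidden_size)
        + vocab_size / (16.0 * num_layers * hidden_size)
    )


# Hypothetical, GPT-NeoX-20B-like configuration (illustrative only).
flops = megatron_flops_per_iteration(
    batch_size=1024, seq_len=2048, num_layers=44, hidden_size=6144, vocab_size=50432
)
iter_time_s = 60.0  # assumed seconds per training iteration
world_size = 96  # assumed total number of GPUs
print("%.2e FLOPs/s per GPU" % (flops / (iter_time_s * world_size)))

With these assumed numbers the estimate works out to roughly 6.2e13 FLOPs/s (about 62 TFLOPS) per GPU, which is the quantity that human_readable_flops formats for the training log.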
