Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve FLOPS Calculation #1044

Merged
merged 8 commits into from
Sep 27, 2023
37 changes: 25 additions & 12 deletions megatron/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,17 +79,31 @@ def human_readable_flops(num) -> str:
return "%.1f%s" % (num, "Yi")


def get_flops(neox_args, iter_time_s) -> float:
    """
    Estimate the achieved FLOPS per second per rank for one training iteration.

    Use FLOPS calculation from Megatron-DeepSpeed:
    https://github.com/microsoft/Megatron-DeepSpeed/blob/cc3a94c636789f74be2bc6cfc62a3d723fd5d749/megatron/utils.py#L253
    They get it from https://arxiv.org/pdf/2104.04473.pdf

    The per-iteration FLOP count is

        24 * c * B * s * l * h^2 * (1 + s / (6 * h) + V / (16 * l * h))

    where c is 4 with activation checkpointing (the forward pass is
    recomputed) and 3 otherwise, B is the train batch size, s the sequence
    length, l the number of layers, h the hidden size and V the padded
    vocabulary size.

    Args:
        neox_args: object exposing ``padded_vocab_size``, ``train_batch_size``,
            ``seq_length``, ``hidden_size``, ``num_layers`` and
            ``checkpoint_activations``.
        iter_time_s: wall-clock duration of the iteration, in seconds.

    Returns:
        FLOPs per iteration divided by ``iter_time_s`` and by the
        ``torch.distributed`` world size.
    """
    world_size = torch.distributed.get_world_size()
    vocab_size = neox_args.padded_vocab_size
    batch_size = neox_args.train_batch_size
    seq_len = neox_args.seq_length
    hidden_size = neox_args.hidden_size
    num_layers = neox_args.num_layers
    ckpt_activations_factor = 4 if neox_args.checkpoint_activations else 3

    # Cost of the transformer matmuls before the attention/logit corrections.
    base_flops = (
        24
        * ckpt_activations_factor
        * batch_size
        * seq_len
        * num_layers
        * (hidden_size**2)
    )
    # Both corrections are dimensionless ratios relative to the base cost, so
    # they must sit inside the multiplied factor. Adding the vocabulary ratio
    # to the scaled total instead would make the logit-layer cost vanish.
    attention_term = seq_len / (6.0 * hidden_size)
    vocab_term = vocab_size / (16.0 * num_layers * hidden_size)
    flops_per_iteration = base_flops * (1.0 + attention_term + vocab_term)

    return flops_per_iteration / (iter_time_s * world_size)


def training_log(
Expand Down Expand Up @@ -314,9 +328,8 @@ def add_to_logging(name):
)

# log tflop / gpu
flops_per_s_per_gpu = get_flops(
neox_args=neox_args, model=model, iter_time_s=iteration_time
)
flops_per_s_per_gpu = get_flops(neox_args, iteration_time)

log_string += (
f" approx flops per GPU: {human_readable_flops(flops_per_s_per_gpu)} |"
)
Expand Down
Loading