From ace63a4b54c6679099d43884708fafac3f6f0941 Mon Sep 17 00:00:00 2001
From: Dashiell Stander
Date: Mon, 25 Sep 2023 16:28:11 -0400
Subject: [PATCH] Use Megatron-DeepSpeed flops calculation

Signed-off-by: Dashiell Stander
---
 megatron/logging.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/megatron/logging.py b/megatron/logging.py
index 56c124809..89a4bac80 100644
--- a/megatron/logging.py
+++ b/megatron/logging.py
@@ -79,7 +79,12 @@ def human_readable_flops(num) -> str:
     return "%.1f%s" % (num, "Yi")
 
 
-def get_flops(neox_args, model, iter_time_s) -> float:
+def get_flops(neox_args, iter_time_s) -> float:
+    """
+    Use FLOPS calculation from Megatron-DeepSpeed:
+    https://github.com/microsoft/Megatron-DeepSpeed/blob/cc3a94c636789f74be2bc6cfc62a3d723fd5d749/megatron/utils.py#L253
+    They get it from https://arxiv.org/pdf/2104.04473.pdf
+    """
     world_size = torch.distributed.get_world_size()
     vocab_size = neox_args.padded_vocab_size
     batch_size = neox_args.train_batch_size
@@ -323,9 +328,7 @@ def add_to_logging(name):
         )
 
         # log tflop / gpu
-        flops_per_s_per_gpu = get_flops(
-            neox_args=neox_args, model=model, iter_time_s=iteration_time
-        )
+        flops_per_s_per_gpu = get_flops(neox_args, iteration_time)
         log_string += (
             f" approx flops per GPU: {human_readable_flops(flops_per_s_per_gpu)} |"
         )
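
Note for reviewers: below is a minimal, self-contained sketch of the formula the new docstring points to, i.e. the analytical FLOPs estimate from https://arxiv.org/pdf/2104.04473.pdf as implemented in the linked Megatron-DeepSpeed utils.py. The function and argument names here are illustrative assumptions, not the exact identifiers used in megatron/logging.py.

def approx_flops_per_s_per_gpu(
    batch_size: int,       # global train batch size (sequences per step)
    seq_len: int,          # sequence length
    num_layers: int,       # number of transformer layers
    hidden_size: int,      # model hidden dimension
    vocab_size: int,       # (padded) vocabulary size
    iter_time_s: float,    # measured wall-clock seconds per iteration
    world_size: int,       # total number of GPUs
    checkpoint_activations: bool = False,
) -> float:
    """Sketch of the Narayanan et al. (2021) FLOPs estimate, as in
    Megatron-DeepSpeed; names are hypothetical."""
    # With activation checkpointing the forward pass is recomputed during
    # the backward pass, so the fwd+bwd multiplier grows from 3x to 4x
    # the forward cost (24 * 4 = 96, matching the paper's formula).
    ckpt_factor = 4 if checkpoint_activations else 3
    flops_per_iteration = (
        24 * ckpt_factor * batch_size * seq_len * num_layers * hidden_size**2
    ) * (
        1.0
        + seq_len / (6.0 * hidden_size)                    # attention term
        + vocab_size / (16.0 * num_layers * hidden_size)   # logit layer term
    )
    return flops_per_iteration / (iter_time_s * world_size)


# Example usage with an illustrative 1.3B-ish config on 64 GPUs:
if __name__ == "__main__":
    flops = approx_flops_per_s_per_gpu(
        batch_size=1024, seq_len=2048, num_layers=24,
        hidden_size=2048, vocab_size=50304,
        iter_time_s=10.0, world_size=64,
    )
    print(f"approx {flops / 1e12:.1f} TFLOPs/s per GPU")

Because this estimate is purely analytical (it depends only on the config and the measured iteration time), the model object is no longer needed, which is why the patch drops the model argument from get_flops.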