diff --git a/megatron/logging.py b/megatron/logging.py index 56c124809..89a4bac80 100644 --- a/megatron/logging.py +++ b/megatron/logging.py @@ -79,7 +79,12 @@ def human_readable_flops(num) -> str: return "%.1f%s" % (num, "Yi") -def get_flops(neox_args, model, iter_time_s) -> float: +def get_flops(neox_args, iter_time_s) -> float: + """ + Use FLOPS calculation from Megatron-DeepSpeed: + https://github.com/microsoft/Megatron-DeepSpeed/blob/cc3a94c636789f74be2bc6cfc62a3d723fd5d749/megatron/utils.py#L253 + They get it from https://arxiv.org/pdf/2104.04473.pdf + """ world_size = torch.distributed.get_world_size() vocab_size = neox_args.padded_vocab_size batch_size = neox_args.train_batch_size @@ -323,9 +328,7 @@ def add_to_logging(name): ) # log tflop / gpu - flops_per_s_per_gpu = get_flops( - neox_args=neox_args, model=model, iter_time_s=iteration_time - ) + flops_per_s_per_gpu = get_flops(neox_args, iteration_time) log_string += ( f" approx flops per GPU: {human_readable_flops(flops_per_s_per_gpu)} |" )