diff --git a/megatron/logging.py b/megatron/logging.py
index 174ece5f0..3a40864b5 100644
--- a/megatron/logging.py
+++ b/megatron/logging.py
@@ -79,17 +79,46 @@ def human_readable_flops(num) -> str:
     return "%.1f%s" % (num, "Yi")
 
 
-def get_flops(neox_args, model, iter_time_s) -> float:
+def get_flops(neox_args, iter_time_s) -> float:
+    """
+    Estimate the FLOPs per second per GPU achieved by one training iteration.
+
+    Use FLOPS calculation from Megatron-DeepSpeed:
+    https://github.com/microsoft/Megatron-DeepSpeed/blob/cc3a94c636789f74be2bc6cfc62a3d723fd5d749/megatron/utils.py#L253
+    They get it from https://arxiv.org/pdf/2104.04473.pdf
+
+    FLOPs per iteration = 24 * c * B * s * l * h^2 * (1 + s / (6h) + V / (16lh)),
+    where c is 4 when neox_args.checkpoint_activations is set and 3 otherwise,
+    B is train_batch_size, s is seq_length, l is num_layers, h is hidden_size
+    and V is padded_vocab_size.
+
+    :param neox_args: object providing the configuration attributes read below
+    :param iter_time_s: iteration time (seconds, going by the `_s` suffix)
+    :return: FLOPs per iteration divided by (iter_time_s * world_size)
+    """
     world_size = torch.distributed.get_world_size()
-    ff = model.total_params * 6
-    attn = neox_args.seq_length * neox_args.hidden_size * neox_args.num_layers * 60
-    flops = (
-        neox_args.train_batch_size
-        * neox_args.seq_length
-        * (ff + attn)
-        / (iter_time_s * world_size)
+    vocab_size = neox_args.padded_vocab_size
+    batch_size = neox_args.train_batch_size
+    seq_len = neox_args.seq_length
+    hidden_size = neox_args.hidden_size
+    num_layers = neox_args.num_layers
+    ckpt_activations_factor = 4 if neox_args.checkpoint_activations else 3
+    # The vocab_size term sits inside the last factor so that it is scaled by the
+    # leading product, matching the referenced Megatron-DeepSpeed formula.
+    flops_per_iteration = (
+        24
+        * ckpt_activations_factor
+        * batch_size
+        * seq_len
+        * num_layers
+        * (hidden_size**2)
+        * (
+            1.0
+            + (seq_len / (6.0 * hidden_size))
+            + (vocab_size / (16.0 * num_layers * hidden_size))
+        )
     )
-    return flops
+    return flops_per_iteration / (iter_time_s * world_size)
 
 
 def training_log(
@@ -314,9 +343,8 @@ def add_to_logging(name):
         )
 
         # log tflop / gpu
-        flops_per_s_per_gpu = get_flops(
-            neox_args=neox_args, model=model, iter_time_s=iteration_time
-        )
+        flops_per_s_per_gpu = get_flops(neox_args, iteration_time)
+
         log_string += (
             f" approx flops per GPU: {human_readable_flops(flops_per_s_per_gpu)} |"
         )