From 1e343a327c102764a458386329539d0b8ec93c06 Mon Sep 17 00:00:00 2001
From: Dashiell Stander
Date: Mon, 25 Sep 2023 16:24:50 -0400
Subject: [PATCH 1/6] Use Megatron-DeepSpeed flops calculation

Signed-off-by: Dashiell Stander
---
 megatron/logging.py | 25 +++++++++++++++++--------
 1 file changed, 17 insertions(+), 8 deletions(-)

diff --git a/megatron/logging.py b/megatron/logging.py
index 174ece5f0..56c124809 100644
--- a/megatron/logging.py
+++ b/megatron/logging.py
@@ -81,15 +81,24 @@ def human_readable_flops(num) -> str:
 
 def get_flops(neox_args, model, iter_time_s) -> float:
     world_size = torch.distributed.get_world_size()
-    ff = model.total_params * 6
-    attn = neox_args.seq_length * neox_args.hidden_size * neox_args.num_layers * 60
-    flops = (
-        neox_args.train_batch_size
-        * neox_args.seq_length
-        * (ff + attn)
-        / (iter_time_s * world_size)
+    vocab_size = neox_args.padded_vocab_size
+    batch_size = neox_args.train_batch_size
+    seq_len = neox_args.seq_length
+    hidden_size = neox_args.hidden_size
+    num_layers = neox_args.num_layers
+    ckpt_activations_factor = 4 if neox_args.checkpoint_activations else 3
+    flops_calc1 = (
+        24
+        * ckpt_activations_factor
+        * batch_size
+        * seq_len
+        * num_layers
+        * (hidden_size**2)
+        * (1.0 + (seq_len / (6.0 * hidden_size)))
     )
-    return flops
+    flops_calc2 = vocab_size / (16.0 * num_layers * hidden_size)
+    flops_per_iteration = flops_calc1 + flops_calc2
+    return flops_per_iteration / (iter_time_s * world_size)
 
 
 def training_log(

From ace63a4b54c6679099d43884708fafac3f6f0941 Mon Sep 17 00:00:00 2001
From: Dashiell Stander
Date: Mon, 25 Sep 2023 16:28:11 -0400
Subject: [PATCH 2/6] Use Megatron-DeepSpeed flops calculation

Signed-off-by: Dashiell Stander
---
 megatron/logging.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/megatron/logging.py b/megatron/logging.py
index 56c124809..89a4bac80 100644
--- a/megatron/logging.py
+++ b/megatron/logging.py
@@ -79,7 +79,12 @@ def human_readable_flops(num) -> str:
     return "%.1f%s" % (num, "Yi")
 
 
-def get_flops(neox_args, model, iter_time_s) -> float:
+def get_flops(neox_args, iter_time_s) -> float:
+    """
+    Use FLOPS calculation from Megatron-DeepSpeed:
+    https://github.com/microsoft/Megatron-DeepSpeed/blob/cc3a94c636789f74be2bc6cfc62a3d723fd5d749/megatron/utils.py#L253
+    They get it from https://arxiv.org/pdf/2104.04473.pdf
+    """
     world_size = torch.distributed.get_world_size()
     vocab_size = neox_args.padded_vocab_size
     batch_size = neox_args.train_batch_size
@@ -323,9 +328,7 @@ def add_to_logging(name):
     )
 
     # log tflop / gpu
-    flops_per_s_per_gpu = get_flops(
-        neox_args=neox_args, model=model, iter_time_s=iteration_time
-    )
+    flops_per_s_per_gpu = get_flops(neox_args, iteration_time)
     log_string += (
         f" approx flops per GPU: {human_readable_flops(flops_per_s_per_gpu)} |"
     )

From d3ce33e2f1269ab8d6111b92c52e27dd7c73e54d Mon Sep 17 00:00:00 2001
From: github-actions
Date: Mon, 25 Sep 2023 20:41:53 +0000
Subject: [PATCH 3/6] Update NeoXArgs docs automatically

---
 configs/neox_arguments.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/configs/neox_arguments.md b/configs/neox_arguments.md
index 26fc841a1..0d93a732d 100644
--- a/configs/neox_arguments.md
+++ b/configs/neox_arguments.md
@@ -111,7 +111,7 @@ Logging Arguments
 
 - **git_hash**: str
 
-    Default = 1f832c1
+    Default = ace63a4
 
     current git hash of repository
 

From 9012e174d4486960931c434b1bfe1cde455ef469 Mon Sep 17 00:00:00 2001
From: github-actions
Date: Mon, 25 Sep 2023 23:15:42 +0000
Subject: [PATCH 4/6] Update NeoXArgs docs automatically

---
 configs/neox_arguments.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/configs/neox_arguments.md b/configs/neox_arguments.md
index efb828799..9e522b9b3 100644
--- a/configs/neox_arguments.md
+++ b/configs/neox_arguments.md
@@ -111,7 +111,7 @@ Logging Arguments
 
 - **git_hash**: str
 
-    Default = 3f8c63c
+    Default = db768fb
 
     current git hash of repository
 

From 8feced0afa22d34187d47b1c6da0089a60feb3ec Mon Sep 17 00:00:00 2001
From: Dashiell Stander
Date: Wed, 27 Sep 2023 11:17:55 -0400
Subject: [PATCH 5/6] Direct comparison of FLOPS calculations

Signed-off-by: Dashiell Stander
---
 megatron/logging.py | 23 +++++++++++++++++++++++
 1 file changed, 23 insertions(+)

diff --git a/megatron/logging.py b/megatron/logging.py
index 89a4bac80..927ee982b 100644
--- a/megatron/logging.py
+++ b/megatron/logging.py
@@ -79,6 +79,19 @@ def human_readable_flops(num) -> str:
     return "%.1f%s" % (num, "Yi")
 
 
+def old_flops(neox_args, model, iter_time_s) -> float:
+    world_size = torch.distributed.get_world_size()
+    ff = model.total_params * 6
+    attn = neox_args.seq_length * neox_args.hidden_size * neox_args.num_layers * 60
+    flops = (
+        neox_args.train_batch_size
+        * neox_args.seq_length
+        * (ff + attn)
+        / (iter_time_s * world_size)
+    )
+    return flops
+
+
 def get_flops(neox_args, iter_time_s) -> float:
     """
     Use FLOPS calculation from Megatron-DeepSpeed:
@@ -106,6 +119,9 @@ def get_flops(neox_args, iter_time_s) -> float:
     return flops_per_iteration / (iter_time_s * world_size)
 
 
+import json
+
+
 def training_log(
     neox_args,
     timers,
@@ -329,6 +345,13 @@ def add_to_logging(name):
 
     # log tflop / gpu
     flops_per_s_per_gpu = get_flops(neox_args, iteration_time)
+    flops_old_calc = old_flops(neox_args, model, iteration_time)
+    with open("flops_calc_comparison.json", mode="a") as jfile:
+        data = {
+            "new": human_readable_flops(flops_per_s_per_gpu),
+            "old": human_readable_flops(flops_old_calc),
+        }
+        json.dump(data, jfile)
     log_string += (
         f" approx flops per GPU: {human_readable_flops(flops_per_s_per_gpu)} |"
     )

From 73c3168bf08c39afc30bc054176305a0693a0979 Mon Sep 17 00:00:00 2001
From: Dashiell Stander
Date: Wed, 27 Sep 2023 15:10:54 -0400
Subject: [PATCH 6/6] Remove test logging

Signed-off-by: Dashiell Stander
---
 megatron/logging.py | 24 +-----------------------
 1 file changed, 1 insertion(+), 23 deletions(-)

diff --git a/megatron/logging.py b/megatron/logging.py
index 927ee982b..3a40864b5 100644
--- a/megatron/logging.py
+++ b/megatron/logging.py
@@ -79,19 +79,6 @@ def human_readable_flops(num) -> str:
     return "%.1f%s" % (num, "Yi")
 
 
-def old_flops(neox_args, model, iter_time_s) -> float:
-    world_size = torch.distributed.get_world_size()
-    ff = model.total_params * 6
-    attn = neox_args.seq_length * neox_args.hidden_size * neox_args.num_layers * 60
-    flops = (
-        neox_args.train_batch_size
-        * neox_args.seq_length
-        * (ff + attn)
-        / (iter_time_s * world_size)
-    )
-    return flops
-
-
 def get_flops(neox_args, iter_time_s) -> float:
     """
     Use FLOPS calculation from Megatron-DeepSpeed:
@@ -119,9 +106,6 @@ def get_flops(neox_args, iter_time_s) -> float:
     return flops_per_iteration / (iter_time_s * world_size)
 
 
-import json
-
-
 def training_log(
     neox_args,
     timers,
@@ -345,13 +329,7 @@ def add_to_logging(name):
 
     # log tflop / gpu
     flops_per_s_per_gpu = get_flops(neox_args, iteration_time)
-    flops_old_calc = old_flops(neox_args, model, iteration_time)
-    with open("flops_calc_comparison.json", mode="a") as jfile:
-        data = {
-            "new": human_readable_flops(flops_per_s_per_gpu),
-            "old": human_readable_flops(flops_old_calc),
-        }
-        json.dump(data, jfile)
+
     log_string += (
         f" approx flops per GPU: {human_readable_flops(flops_per_s_per_gpu)} |"
    )
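
For reference, the expression in the linked Megatron-DeepSpeed utils.py (and in https://arxiv.org/pdf/2104.04473.pdf) groups the vocabulary term inside the parenthesised factor: 24 * ckpt_factor * B * s * L * h^2 * (1 + s/(6h) + V/(16*L*h)) FLOPs per iteration, which the patched get_flops then divides by iteration time and world size. Below is a minimal, self-contained sketch of that reference formula and normalisation; the function name and the example configuration and timing values are illustrative assumptions, not part of the patch series.

# Illustrative sketch only, not code from the patches above: FLOPs per training
# iteration following the Megatron-DeepSpeed / Narayanan et al. (2021) formula,
# plus the per-second, per-GPU normalisation that get_flops applies.
def flops_per_iteration_sketch(
    batch_size: int,
    seq_len: int,
    num_layers: int,
    hidden_size: int,
    vocab_size: int,
    checkpoint_activations: bool = False,
) -> float:
    # 3 = forward + backward; 4 adds the extra forward pass of activation recomputation.
    ckpt_factor = 4 if checkpoint_activations else 3
    return (
        24
        * ckpt_factor
        * batch_size
        * seq_len
        * num_layers
        * hidden_size**2
        * (
            1.0
            + seq_len / (6.0 * hidden_size)
            + vocab_size / (16.0 * num_layers * hidden_size)
        )
    )


if __name__ == "__main__":
    # Hypothetical configuration and timing, for illustration only.
    per_iteration = flops_per_iteration_sketch(
        batch_size=512,
        seq_len=2048,
        num_layers=32,
        hidden_size=4096,
        vocab_size=50304,
        checkpoint_activations=True,
    )
    iter_time_s, world_size = 10.0, 64
    print(f"approx {per_iteration / (iter_time_s * world_size):.3e} FLOPs/s per GPU")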