From 62c9738aad50cfbcd4c70b52d4e78746f9abfb47 Mon Sep 17 00:00:00 2001 From: Quentin Anthony Date: Sun, 22 Sep 2024 14:53:18 -0700 Subject: [PATCH 1/4] fix typo from https://github.com/EleutherAI/gpt-neox/pull/1244/files#diff-383134de6f3512484e20625419bd5fb6b1675a922f47aeb1a6bd3cff6185a754R126 (#1290) --- megatron/data/gpt2_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/megatron/data/gpt2_dataset.py b/megatron/data/gpt2_dataset.py index e37c558d2..c4729cc3e 100644 --- a/megatron/data/gpt2_dataset.py +++ b/megatron/data/gpt2_dataset.py @@ -123,7 +123,7 @@ def __getitem__(self, idx): samples.append( dataset.get( self.doc_idx[doc_index_f], - offset=offset_l, + offset=offset_f, length=offset_l - offset_f + 1, ) ) From 4765384414d919d06250356f5d1d41b0b04d7446 Mon Sep 17 00:00:00 2001 From: Quentin Anthony Date: Mon, 23 Sep 2024 15:33:19 -0700 Subject: [PATCH 2/4] update args docs (#1293) * update args docs * undo pre-commit change --- configs/neox_arguments.md | 339 ++++++++++++++++++++++++++++++++++++-- 1 file changed, 326 insertions(+), 13 deletions(-) diff --git a/configs/neox_arguments.md b/configs/neox_arguments.md index d24b2b60a..698e28697 100644 --- a/configs/neox_arguments.md +++ b/configs/neox_arguments.md @@ -111,7 +111,7 @@ Logging Arguments - **git_hash**: str - Default = 217b4c5 + Default = 62c9738a current git hash of repository @@ -133,6 +133,54 @@ Logging Arguments +- **use_comet**: bool + + Default = None + + Flag indicating if comet is to be used. + + + +- **comet_workspace**: Optional + + Default = None + + Comet workspace name, if not configured Comet Experiments will be created in the user configured default workspace. + + + +- **comet_project**: Optional + + Default = None + + Comet project name, if not configured Comet Experiments will be created in the Uncategorized Experiments project. + + + +- **comet_experiment_name**: Optional + + Default = None + + Custom name for the Comet experiment. If not provided, a random name is used. + + + +- **comet_tags**: Optional + + Default = None + + List of tags to attach to the created Comet Experiment. + + + +- **comet_others**: Optional + + Default = None + + Custom metadata to attach to the created Comet Experiment. + + + - **log_interval**: int Default = 100 @@ -281,9 +329,23 @@ Model Arguments Default = None - Transformer intermediate size. Currently only used for "mlp_type": "llama". + Transformer intermediate size. Default = 4h + + + +- **mlp_multiple_of**: int + + Default = 1 + + force mlp size to be a multiple of this value + + + +- **expansion_factor**: float - If not passed, will be set to a reasonable default. + Default = None + + Transformer intermediate size. Default = 4 @@ -351,6 +413,14 @@ Model Arguments +- **rmsnorm_fusion**: bool + + Default = False + + Use fused RMS norm kernel (if `norm` is `rmsnorm`). + + + - **use_qk_layernorm**: bool Default = False @@ -497,11 +567,19 @@ Model Arguments -- **activation**: typing.Literal['gelu', 'geglu', 'relu', 'softsign', 'swish', 'mish', 'silu'] +- **activation**: typing.Literal['gelu', 'geglu', 'relu', 'softsign', 'swish', 'mish', 'silu', 'reglu', 'swiglu', 'bilinear', 'glu'] Default = gelu - Activation function to use - choose from ["gelu", "geglu", "relu", "softsign", "swish", "mish", "silu"] + Activation function to use - choose from ["gelu", "geglu", "relu", "softsign", "swish", "mish", "silu", "reglu", "swiglu", "bilinear", "glu"] + + + +- **use_flashattn_swiglu**: bool + + Default = False + + Use flash attention's version of swiglu @@ -681,13 +759,11 @@ Model Arguments -- **mlp_type**: str +- **use_bias_in_mlp**: bool - Default = regular + Default = True - Types: - regular: Megatron implementation - llama: LLaMA MLP (SiLU-gated MLP) + If false, mlps will not have bias terms @@ -1091,7 +1167,15 @@ Text Generation arguments Default = None How to generate text/sample the model. - Options: `unconditional`, `input-file`, `interactive` + Options: `unconditional`, `input-file`, `interactive`, `precompute` + + + +- **precompute_model_name**: str + + Default = None + + Model name to use for saving precomputed logprobs @@ -1378,11 +1462,19 @@ Training Arguments -- **label_data_paths**: list +- **train_label_data_paths**: list Default = None - List of paths to label datasets (not shifted by 1 yet!). + List of paths to train label datasets (not shifted by 1 yet!). + + + +- **train_reward_data_paths**: list + + Default = None + + List of paths to train reward datasets @@ -1394,6 +1486,22 @@ Training Arguments +- **test_label_data_paths**: list + + Default = None + + List of paths to test label datasets (not shifted by 1 yet!). + + + +- **test_reward_data_paths**: list + + Default = None + + List of paths to test reward datasets + + + - **valid_data_paths**: list Default = None @@ -1402,6 +1510,118 @@ Training Arguments +- **valid_label_data_paths**: list + + Default = None + + List of paths to validation label datasets (not shifted by 1 yet!). + + + +- **valid_reward_data_paths**: list + + Default = None + + List of paths to validation reward datasets + + + +- **pos_train_data_paths**: list + + Default = None + + + + + +- **neg_train_data_paths**: list + + Default = None + + List of paths to positive and negative training datasets. + + + +- **pos_train_label_data_paths**: list + + Default = None + + + + + +- **neg_train_label_data_paths**: list + + Default = None + + List of paths to positive and negative training label datasets (not shifted by 1 yet!). + + + +- **pos_valid_data_paths**: list + + Default = None + + + + + +- **neg_valid_data_paths**: list + + Default = None + + List of paths to positive and negative validation datasets. + + + +- **pos_valid_label_data_paths**: list + + Default = None + + + + + +- **neg_valid_label_data_paths**: list + + Default = None + + List of paths to positive and negative validation label datasets (not shifted by 1 yet!). + + + +- **pos_test_data_paths**: list + + Default = None + + + + + +- **neg_test_data_paths**: list + + Default = None + + List of paths to positive and negative test datasets. + + + +- **pos_test_label_data_paths**: list + + Default = None + + + + + +- **neg_test_label_data_paths**: list + + Default = None + + List of paths to positive and negative test label datasets (not shifted by 1 yet!). + + + - **train_data_weights**: list Default = None @@ -1469,6 +1689,99 @@ Training Arguments +- **pack_impl**: typing.Literal['packed', 'pack_until_overflow', 'unpacked'] + + Default = packed + + Packing implementation, can be one of "packed", "pack_until_overflow", or "unpacked". + + warning: pack_until_overflow is very naive and will likely have issues with pretraining scale datasets + + + +- **dataset_impl**: typing.Literal['gpt2', 'pairwise'] + + Default = gpt2 + + Dataset implementation, can be one of "gpt2" or "pairwise" + + + +- **train_impl**: typing.Literal['normal', 'dpo', 'rm', 'kto'] + + Default = normal + + Training implementation, can be one of "normal", "dpo", "kto", or "rm" + + + +- **dpo_fp32**: bool + + Default = True + + Whether to cast logits to fp32 for DPO loss calculation. + + + +- **dpo_reference_free**: bool + + Default = False + + Whether to use reference-free DPO. + + + +- **dpo_beta**: float + + Default = 0.1 + + Beta value for DPO + + + +- **kto_fp32**: bool + + Default = True + + Whether to cast logits to fp32 for KTO loss calculation. + + + +- **kto_desirable_weight**: float + + Default = 1.0 + + Weight for desirable loss in KTO. Might help if you have unbalanced desirable and undesirable classes. + + + +- **kto_undesirable_weight**: float + + Default = 1.0 + + Weight for undesirable loss in KTO. Might help if you have unbalanced desirable and undesirable classes. + + + +- **kto_beta**: float + + Default = 0.1 + + Beta value for KTO + + + +- **allow_chopped**: bool + + Default = True + + WARNING: if your packing impl is packed, this is ignored. + + Allow chopped samples in the dataset. + (e.g if your sequence length is 1024 and you have a sample of length 1026, it will be chopped to 1024) + + + - **mmap_warmup**: bool Default = False From 1bce90c1ac60f4bad6c45d630d9aa7b41591f358 Mon Sep 17 00:00:00 2001 From: Jacob Hatef <74274091+jahatef@users.noreply.github.com> Date: Mon, 23 Sep 2024 18:51:45 -0400 Subject: [PATCH 3/4] mamba flop calculations (#1291) * mamba flop calculations * mamba flop calculations * beef up comments and remove useless line * undo precommit change --------- Co-authored-by: Quentin Anthony --- megatron/logging.py | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/megatron/logging.py b/megatron/logging.py index 05945fdda..af8a41fe5 100644 --- a/megatron/logging.py +++ b/megatron/logging.py @@ -23,6 +23,7 @@ from megatron import mpu, print_rank_0 from megatron.utils import report_memory +import math class Tee: @@ -106,6 +107,38 @@ def get_flops(neox_args, iter_time_s) -> float: + 18 * hidden_size * hidden_size * num_layers / num_heads ) ) + elif "mamba" in neox_args.attention_config: + # from https://github.com/Zyphra/zcookbook/blob/main/calc/calc_mamba_flops.py + if neox_args.expansion_factor: + d_inner = neox_args.hidden_size * neox_args.expansion_factor + elif neox_args.intermediate_size: + d_inner = neox_args.intermediate_size + else: + d_inner = neox_args.hidden_size * 2 # default expansion factor + d_state = 16 # TODO make d_state an arg. Currently hardcoded in neox mamba definition and here + conv_dimension = 4 # TODO make conv_dimension an arg. Currently hardcoded in neox mamba definition and here + dt_rank = math.ceil(neox_args.hidden_size / 16) + ssm_flops = ( + ckpt_activations_factor + * d_inner + * seq_len + * batch_size + * (11 * d_state + 4 * dt_rank + 1) + ) + mamba_projectors_flops = ( + ckpt_activations_factor * seq_len * batch_size * 6 * d_inner * hidden_size + ) + mamba_conv_flops = ( + ckpt_activations_factor + * seq_len + * batch_size + * 2 + * d_inner + * conv_dimension + ) + mamba_flops = ssm_flops + mamba_projectors_flops + mamba_conv_flops + embedding_flops = 6 * seq_len * batch_size * hidden_size * vocab_size + flops_per_iteration = mamba_flops * num_layers + embedding_flops else: flops_per_iteration = ( 24 From f5d7ff9f0fcf90e710081adcf14b75e49233c4db Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ga=C3=ABtan=20Caillaut?= <2855911+gcaillaut@users.noreply.github.com> Date: Tue, 24 Sep 2024 21:21:27 +0200 Subject: [PATCH 4/4] Do not fail when git is not installed (#1280) --- megatron/neox_arguments/neox_args.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py index 5194047d5..df7e51da6 100644 --- a/megatron/neox_arguments/neox_args.py +++ b/megatron/neox_arguments/neox_args.py @@ -46,7 +46,7 @@ def get_git_commit_hash(): try: git_hash = subprocess.check_output(["git", "describe", "--always"]).strip() git_hash = git_hash.decode() - except subprocess.CalledProcessError: + except (subprocess.CalledProcessError, FileNotFoundError): git_hash = None return git_hash