From c1ea2a1ff1c062ed89ede27780cecf0122ae3f42 Mon Sep 17 00:00:00 2001 From: dmahan93 <44207705+dmahan93@users.noreply.github.com> Date: Mon, 5 Aug 2024 23:14:27 -0500 Subject: [PATCH 1/4] Add hf llama to neox conversion (#1247) * - Add conversion of HF llama models to NeoX * - Add conversion of HF llama models to NeoX * - minor fix * pre-commit --------- Co-authored-by: Quentin Anthony --- tools/ckpts/README.md | 17 ++ tools/ckpts/convert_hf_llama_to_neox.py | 219 ++++++++++++++++++++++++ 2 files changed, 236 insertions(+) create mode 100644 tools/ckpts/convert_hf_llama_to_neox.py diff --git a/tools/ckpts/README.md b/tools/ckpts/README.md index 24d5cf31c..770cfb9c6 100644 --- a/tools/ckpts/README.md +++ b/tools/ckpts/README.md @@ -131,3 +131,20 @@ options: --num_output_shards NUM_OUTPUT_SHARDS --pipeline_parallel Only use if PP>1 ``` + +### `convert_hf_llama_to_neox.py` +Takes an HF Llama checkpoint and puts it into a NeoX-compatible format. + +Note that this does not support pipeline parallelism! + +``` +usage: convert_hf_llama_to_neox.py [-h] [--tp TP] [--pp PP] [--model MODEL] [--model_path MODEL_PATH] + +options: + -h, --help show this help message and exit + --tp TP Number of tensor parallelism ranks + --pp PP Number of pipeline parallelism stages + --model MODEL HF model name + --model_path MODEL_PATH + Path to save model +``` diff --git a/tools/ckpts/convert_hf_llama_to_neox.py b/tools/ckpts/convert_hf_llama_to_neox.py new file mode 100644 index 000000000..2adddb19d --- /dev/null +++ b/tools/ckpts/convert_hf_llama_to_neox.py @@ -0,0 +1,219 @@ +import torch +import argparse +from transformers import AutoTokenizer, AutoModelForCausalLM +import os +import tqdm + + +def convert_model(hf_state_dict, hf_config, tp_ranks): + conv_state_dicts = [{} for _ in range(tp_ranks)] + # get embeddings... + for i, chunk in enumerate( + torch.chunk(hf_state_dict["model.embed_tokens.weight"], tp_ranks, dim=0) + ): + conv_state_dicts[i][ + "sequential.0.word_embeddings.weight" + ] = chunk.clone().detach() + print( + "model.embed_tokens.weight", + hf_state_dict["model.embed_tokens.weight"].shape, + "sequential.0.word_embeddings.weight", + conv_state_dicts[0]["sequential.0.word_embeddings.weight"].shape, + ) + # Get config data... + num_kv_heads = hf_config.num_key_value_heads + num_q_heads = hf_config.num_attention_heads + head_dim = hf_config.hidden_size // num_q_heads + # do layers... + for layer_num in tqdm.tqdm(range(model.model.config.num_hidden_layers)): + # --- attention --- + # Output first since it's a simple row parallel... + for i, chunk in enumerate( + torch.chunk( + hf_state_dict[f"model.layers.{layer_num}.self_attn.o_proj.weight"], + tp_ranks, + dim=1, + ) + ): + conv_state_dicts[i][ + f"sequential.{layer_num+2}.attention.dense.weight" + ] = chunk.clone().detach() + print( + f"model.layers.{layer_num}.self_attn.o_proj.weight", + hf_state_dict[f"model.layers.{layer_num}.self_attn.o_proj.weight"].shape, + f"sequential.{layer_num+2}.attention.dense.weight", + conv_state_dicts[0][ + f"sequential.{layer_num+2}.attention.dense.weight" + ].shape, + ) + # Now for attention... + # Split into heads... + q = hf_state_dict[f"model.layers.{layer_num}.self_attn.q_proj.weight"] + k = hf_state_dict[f"model.layers.{layer_num}.self_attn.k_proj.weight"] + v = hf_state_dict[f"model.layers.{layer_num}.self_attn.v_proj.weight"] + # The GQA code splits the heads by the num_q_heads so we also do that + # here to ensure it matches... 
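+ # Illustrative shape example (assuming a GQA config such as Llama-3-8B:
+ # hidden_size=4096, 32 query heads, 8 KV heads, head_dim=128):
+ #   q: [4096, 4096] -> view -> [32, 128, 4096]
+ #   k: [1024, 4096] -> view -> [32,  32, 4096]
+ #   v: [1024, 4096] -> view -> [32,  32, 4096]
+ # Chunking dim 0 into tp_ranks pieces then keeps each rank's q/k/v head groups
+ # aligned before they are concatenated into query_key_value below.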
+ q = q.view(num_q_heads, -1, q.shape[-1]) + k = k.view(num_q_heads, -1, q.shape[-1]) + v = v.view(num_q_heads, -1, q.shape[-1]) + # Chunk for tensor parallelism... + for i, q_chunk, k_chunk, v_chunk in zip( + range(tp_ranks), + torch.chunk(q, tp_ranks, dim=0), + torch.chunk(k, tp_ranks, dim=0), + torch.chunk(v, tp_ranks, dim=0), + ): + # Need to join the heads across q, k, v... + conv_state_dicts[i][ + f"sequential.{layer_num+2}.attention.query_key_value.weight" + ] = ( + torch.cat([q_chunk, k_chunk, v_chunk], dim=1) + .view(-1, q.shape[-1]) + .clone() + .detach() + ) + print( + f"model.layers.{layer_num}.self_attn.(q/k/v)_proj.weight", + hf_state_dict[f"model.layers.{layer_num}.self_attn.q_proj.weight"].shape, + hf_state_dict[f"model.layers.{layer_num}.self_attn.k_proj.weight"].shape, + hf_state_dict[f"model.layers.{layer_num}.self_attn.v_proj.weight"].shape, + f"sequential.{layer_num+2}.attention.query_key_value.weight", + conv_state_dicts[0][ + f"sequential.{layer_num+2}.attention.query_key_value.weight" + ].shape, + ) + # --- mlp --- + # Do SwiGLU weights... + # w1... + for i, chunk in enumerate( + torch.chunk( + hf_state_dict[f"model.layers.{layer_num}.mlp.gate_proj.weight"], + tp_ranks, + dim=0, + ) + ): + conv_state_dicts[i][ + f"sequential.{layer_num+2}.mlp.w1.weight" + ] = chunk.clone().detach() + print( + f"model.layers.{layer_num}.mlp.gate_proj.weight", + hf_state_dict[f"model.layers.{layer_num}.mlp.gate_proj.weight"].shape, + f"sequential.{layer_num+2}.mlp.w1.weight", + conv_state_dicts[0][f"sequential.{layer_num+2}.mlp.w1.weight"].shape, + ) + # w3... + for i, chunk in enumerate( + torch.chunk( + hf_state_dict[f"model.layers.{layer_num}.mlp.up_proj.weight"], + tp_ranks, + dim=0, + ) + ): + conv_state_dicts[i][ + f"sequential.{layer_num+2}.mlp.w3.weight" + ] = chunk.clone().detach() + print( + f"model.layers.{layer_num}.mlp.up_proj.weight", + hf_state_dict[f"model.layers.{layer_num}.mlp.up_proj.weight"].shape, + f"sequential.{layer_num+2}.mlp.w3.weight", + conv_state_dicts[0][f"sequential.{layer_num+2}.mlp.w3.weight"].shape, + ) + # w2 (output)... + for i, chunk in enumerate( + torch.chunk( + hf_state_dict[f"model.layers.{layer_num}.mlp.down_proj.weight"], + tp_ranks, + dim=1, + ) + ): + conv_state_dicts[i][ + f"sequential.{layer_num+2}.mlp.w2.weight" + ] = chunk.clone().detach() + print( + f"model.layers.{layer_num}.mlp.down_proj.weight", + hf_state_dict[f"model.layers.{layer_num}.mlp.down_proj.weight"].shape, + f"sequential.{layer_num+2}.mlp.w2.weight", + conv_state_dicts[0][f"sequential.{layer_num+2}.mlp.w2.weight"].shape, + ) + # --- norm --- + for i in range(tp_ranks): + conv_state_dicts[i][f"sequential.{layer_num+2}.input_layernorm.scale"] = ( + hf_state_dict[f"model.layers.{layer_num}.input_layernorm.weight"] + .clone() + .detach() + ) + conv_state_dicts[i][ + f"sequential.{layer_num+2}.post_attention_layernorm.scale" + ] = ( + hf_state_dict[ + f"model.layers.{layer_num}.post_attention_layernorm.weight" + ] + .clone() + .detach() + ) + + # Get final ln/linear.... + index = model.model.config.num_hidden_layers + 3 + for i in range(tp_ranks): + conv_state_dicts[i][f"sequential.{index}.norm.scale"] = ( + hf_state_dict["model.norm.weight"].clone().detach() + ) + index += 1 + # do output... 
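+ # Layout of the NeoX `sequential.*` keys this script produces (L = num_hidden_layers):
+ #   sequential.0       word embeddings
+ #   sequential.2..L+1  transformer layers
+ #   sequential.L+3     final norm
+ #   sequential.L+4     final linear (lm_head), written next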
+ for i, chunk in enumerate( + torch.chunk(hf_state_dict["lm_head.weight"], tp_ranks, dim=0) + ): + conv_state_dicts[i][ + f"sequential.{index}.final_linear.weight" + ] = chunk.clone().detach() + print( + "lm_head.weight", + hf_state_dict["lm_head.weight"].shape, + f"sequential.{index}.final_linear.weight", + conv_state_dicts[0][f"sequential.{index}.final_linear.weight"].shape, + ) + return conv_state_dicts + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--tp", type=int, default=1, help="Number of tensor parallelism ranks" + ) + parser.add_argument( + "--pp", type=int, default=0, help="Number of pipeline parallelism stages" + ) + parser.add_argument("--model", type=str, default="gpt2", help="HF model name") + parser.add_argument( + "--model_path", type=str, default=None, help="Path to save model" + ) + args = parser.parse_args() + assert args.pp == 0, "Pipeline parallelism not supported yet" + tokenizer = AutoTokenizer.from_pretrained(args.model).save_pretrained( + args.model_path + "/tokenizer" + ) + model = AutoModelForCausalLM.from_pretrained(args.model, torch_dtype="auto") + state_dict = model.state_dict() + for key in state_dict.keys(): + print(key, state_dict[key].shape) + os.makedirs(args.model_path, exist_ok=True) + # Setup model directory... + os.makedirs(f"{args.model_path}/0", exist_ok=True) + # Save the latest file so neox can figure out where to grab the weights... + with open(f"{args.model_path}/latest", "w") as f: + f.write("0") + # Convert the model... + tp_state_dicts = convert_model(state_dict, model.model.config, args.tp) + for i in range(args.tp): + torch.save( + { + "dp_world_size": 1, + "mp_world_size": args.tp, + "optimizer": {}, + "global_steps": 1, + "skipped_steps": 1, + "iteration": 1, + "module": tp_state_dicts[i], + }, + f"{args.model_path}/0/mp_rank_{i:02d}_model_states.pt", + ) From 0ef2c074ac03c2b888e9003e7ce4c166cb78cc82 Mon Sep 17 00:00:00 2001 From: dmahan93 <44207705+dmahan93@users.noreply.github.com> Date: Thu, 15 Aug 2024 16:26:15 -0500 Subject: [PATCH 2/4] bugfix: chat turns instead of repeating the conversation in preprocess_data_with_chat_template.py (#1258) * bugfix: chat turns instead of repeating the conversation * pre-commit --- tools/datasets/preprocess_data_with_chat_template.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/datasets/preprocess_data_with_chat_template.py b/tools/datasets/preprocess_data_with_chat_template.py index 81770deff..55623b303 100644 --- a/tools/datasets/preprocess_data_with_chat_template.py +++ b/tools/datasets/preprocess_data_with_chat_template.py @@ -104,7 +104,7 @@ def build_chat( ) chat_tokens = tokenizer.apply_chat_template( chat[: i + 1], add_generation_prompt=add_gen - ) + )[len(tokens) :] # remove previous stuff... 
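+ # i.e. re-template the conversation up through turn i, then slice off the tokens
+ # already emitted for earlier turns, so each iteration appends only the new turn
+ # instead of repeating the whole conversation (the bug this patch fixes).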
tokens.extend(chat_tokens) if only_last_turn and (i != len(chat) - 1): From f8c9e68c4984a0b6f7f5f276b563d2612a6dce9f Mon Sep 17 00:00:00 2001 From: jaimemcc <99298642+jaimemcc-intel@users.noreply.github.com> Date: Thu, 15 Aug 2024 14:57:02 -0700 Subject: [PATCH 3/4] Conversion for CI from self-hosted hardware (#1245) * changing from self-hosted runners to Github's ubuntu-22.04 runner environment * adding warning about not using 'self-hosted' runner labels and using Github runners instead * updated some guidance in comments for coverity scan CI * moving CPU tests to workflow_dispatch only --- .github/workflows/{cpu_ci_on_pr.yml => .cpu_ci_on_pr.yml} | 4 +++- .github/workflows/coverity_scan.yml | 5 +++-- .github/workflows/cpu_ci.yml | 2 +- .github/workflows/cpu_ci_dispatch.yml | 2 +- .github/workflows/pull_request.yml | 5 +++-- tests/README.md | 2 ++ 6 files changed, 13 insertions(+), 7 deletions(-) rename .github/workflows/{cpu_ci_on_pr.yml => .cpu_ci_on_pr.yml} (58%) diff --git a/.github/workflows/cpu_ci_on_pr.yml b/.github/workflows/.cpu_ci_on_pr.yml similarity index 58% rename from .github/workflows/cpu_ci_on_pr.yml rename to .github/workflows/.cpu_ci_on_pr.yml index 971640c18..43ce025c0 100644 --- a/.github/workflows/cpu_ci_on_pr.yml +++ b/.github/workflows/.cpu_ci_on_pr.yml @@ -1,3 +1,5 @@ +# This file is hidden (.cpu_cpi_on_pr.yml) to minimize the number of runner minutes consumed. + name: "Pull Request CPU Tests" on: @@ -7,7 +9,7 @@ on: jobs: run-tests: - runs-on: [ 'test', 'self-hosted' ] + runs-on: ubuntu-22.04 # ubuntu-latest currently points to ubuntu-22.04 but 24.04 is in beta - recommend testing on 24.04 and then changing instead of using ubuntu-latest steps: - name: Checkout Repository uses: actions/checkout@v4 diff --git a/.github/workflows/coverity_scan.yml b/.github/workflows/coverity_scan.yml index a79d0d8fb..128d279cc 100644 --- a/.github/workflows/coverity_scan.yml +++ b/.github/workflows/coverity_scan.yml @@ -17,9 +17,10 @@ jobs: runs-on: ubuntu-latest env: - COV_USER: ${{ secrets.COV_USER }} + COV_USER: ${{ secrets.COV_USER }} # needs to be an email with access to the Coverity stream - add to secrets/actions COVERITY_PROJECT: ${{ secrets.COVERITY_PROJECT }} - COVERITY_TOKEN: ${{ secrets.COVERITY_TOKEN }} + COVERITY_TOKEN: ${{ secrets.COVERITY_TOKEN }} # you can get this token from Coverity stream dashboard: + # https://scan.coverity.com/projects/?tab=project_settings steps: - uses: actions/checkout@v2 diff --git a/.github/workflows/cpu_ci.yml b/.github/workflows/cpu_ci.yml index 9160fccab..6910b8a1c 100644 --- a/.github/workflows/cpu_ci.yml +++ b/.github/workflows/cpu_ci.yml @@ -5,7 +5,7 @@ on: "push" jobs: run-tests: #runs-on: ubuntu-latest - runs-on: [ 'test', 'self-hosted' ] + runs-on: ubuntu-22.04 steps: - uses: actions/checkout@v3 diff --git a/.github/workflows/cpu_ci_dispatch.yml b/.github/workflows/cpu_ci_dispatch.yml index b1d108b3b..38485d6a6 100644 --- a/.github/workflows/cpu_ci_dispatch.yml +++ b/.github/workflows/cpu_ci_dispatch.yml @@ -10,7 +10,7 @@ on: jobs: run-tests: - runs-on: [ 'test', 'self-hosted' ] + runs-on: ubuntu-22.04 steps: - name: Checkout Repository uses: actions/checkout@v4 diff --git a/.github/workflows/pull_request.yml b/.github/workflows/pull_request.yml index 53be528ae..7b06256bf 100644 --- a/.github/workflows/pull_request.yml +++ b/.github/workflows/pull_request.yml @@ -1,6 +1,7 @@ name: Pull Request -on: [pull_request, workflow_dispatch] +#on: [pull_request, workflow_dispatch] +on: workflow_dispatch jobs: pre-commit: @@ -40,7 +41,7 @@ 
jobs: git commit -m "Update NeoXArgs docs automatically" git push run-tests: - runs-on: self-hosted + runs-on: ubuntu-22.04 steps: - uses: actions/checkout@v2 - uses: actions/setup-python@v4 diff --git a/tests/README.md b/tests/README.md index 390a52898..f5ba5e560 100644 --- a/tests/README.md +++ b/tests/README.md @@ -57,6 +57,8 @@ Tests can be run against physical CPUs through GitHub Actions. To have tests run ### runs-on +#### NOTE: These BKMs were written to work with CI infrastructure that is no longer in place. To use the Github runners (ubuntu-22.04 / ubuntu-latest), skip the 'runs-on' section. + The CI needs to be written to target the CPU Github Action runner. The jobs that need to run on CPU should use the hardware runner's labels: ```yaml jobs: From 8b43196fbd832b797be9f3d88d54481171010507 Mon Sep 17 00:00:00 2001 From: Hailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com> Date: Fri, 23 Aug 2024 14:02:59 -0400 Subject: [PATCH 4/4] Megatron-LM style Sequence Parallel (#1257) * first draft (shape errors occurring) * training works (but poor convergence) * debugging progress: current commit works if we do regular TP via impl-ing AR in rowparallel as RS then AG * Update NeoXArgs docs automatically * push most recent code (updated mark_norms fn, back to 'real' sequence parallel) * Update NeoXArgs docs automatically * Fix LayerNorm all reduce gradient hook * Sum instead of average for LayerNorm gradient all reduce * Update NeoXArgs docs automatically * Update NeoXArgs docs automatically * Fix gather and reduce scatter ops on sequence dimension * Fix sequence parallel with tied weight embeddings * Update NeoXArgs docs automatically * cleanup pass + add MoE arguments.py guard * pre-commit and clean up comments * remove vestigial debug code * remove unused debugging code * remove dummy test config * update fp32_allreduce to handle fp16 ; don't cast to fp32 for gathers * run linter on the rest of the files * Improve performance of sequence parallel gather, scatter, and reduce * Add comment * Update NeoXArgs docs automatically --------- Co-authored-by: github-actions Co-authored-by: Brandon Yang Co-authored-by: Quentin Anthony --- configs/neox_arguments.md | 12 +- megatron/model/__init__.py | 5 +- megatron/model/gpt2_model.py | 5 +- megatron/model/transformer.py | 29 ++++- megatron/model/utils.py | 56 ++++++-- megatron/model/word_embeddings.py | 10 ++ megatron/mpu/__init__.py | 3 + megatron/mpu/layers.py | 39 +++++- megatron/mpu/mappings.py | 187 +++++++++++++++++++++++++-- megatron/mpu/utils.py | 22 ++++ megatron/neox_arguments/arguments.py | 4 + megatron/neox_arguments/neox_args.py | 7 + megatron/training.py | 3 + 13 files changed, 349 insertions(+), 33 deletions(-) diff --git a/configs/neox_arguments.md b/configs/neox_arguments.md index 1e67685ed..413138597 100644 --- a/configs/neox_arguments.md +++ b/configs/neox_arguments.md @@ -111,7 +111,7 @@ Logging Arguments - **git_hash**: str - Default = 455446c + Default = 53d0ae8 current git hash of repository @@ -1056,6 +1056,16 @@ Parallelism Arguments +- **sequence_parallel**: bool + + Default = False + + flag to determine whether Megatron-style Sequence Parallelism (https://arxiv.org/abs/2205.05198) + (Layernorm inputs and activations are sharded across model parallel group) will be used. Has no effect when model_parallel_size is 1. 
+ **Set by user, in contrast to neox_args.is_pipe_parallel.** + + + - **expert_interval**: int Default = 2 diff --git a/megatron/model/__init__.py b/megatron/model/__init__.py index 619b4c33d..23be28936 100755 --- a/megatron/model/__init__.py +++ b/megatron/model/__init__.py @@ -16,5 +16,8 @@ # limitations under the License. from .gpt2_model import GPT2ModelPipe -from .utils import get_params_for_weight_decay_optimization +from .utils import ( + get_params_for_weight_decay_optimization, + mark_norms_for_sequence_parallel_grad_sync, +) from .word_embeddings import SoftEmbedding diff --git a/megatron/model/gpt2_model.py b/megatron/model/gpt2_model.py index 9e643874a..7899048db 100644 --- a/megatron/model/gpt2_model.py +++ b/megatron/model/gpt2_model.py @@ -308,7 +308,10 @@ def _logits_helper(embedding, lm_output): ) logits = parallel_lm_logits( - lm_output, embedding.word_embeddings_weight, self.parallel_output + lm_output, + embedding.word_embeddings_weight, + self.parallel_output, + seq_parallel=self.neox_args.sequence_parallel, ) return logits diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index c154b09f4..62e7d3a9c 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -254,6 +254,7 @@ def __init__( gather_output=not parallel_output, skip_bias_add=False, mup_rescale_parameters=is_last_layer, # rescale params only called if neox_args.use_mup = True, despite it not being included here + seq_dim=1, # important: must mark that this layer receives shape [b, s, h] not [s, b, h] and so Seq. Parallel comms must gather along dim=1 rather than dim=0 ) # else: @@ -1024,7 +1025,14 @@ def __init__( self.moe_type = neox_args.moe_type if self.gpt_j_residual: - self.reduce = mpu.mappings.reduce_from_model_parallel_region + # GPT-J style layers allow us to defer the reduction of results across TP ranks until the end of the two sublayers. + # the reduction we use is a simple allreduce for pure Tensor Parallel, + # but needs to be a reduce-scatter when using Megatron-style Sequence Parallel (LN sharding.) + self.reduce = ( + mpu.mappings.reduce_from_model_parallel_region + if not neox_args.sequence_parallel + else mpu.mappings.reduce_scatter_to_sequence_parallel_region + ) # Self attention. self.attention = ParallelSelfAttention( @@ -1339,10 +1347,25 @@ def forward(self, args): return self.norm(args) -def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, bias=None): +def parallel_lm_logits( + input_, + word_embeddings_weight, + parallel_output, + seq_parallel=False, + seq_dim=1, + bias=None, +): """LM logits using word embedding weights.""" # Parallel logits. - input_parallel = mpu.copy_to_model_parallel_region(input_) + if seq_parallel: + # if using Sequence Parallelism, our logits are sharded along the sequence dimension. + # gather them here. (backward pass: reduce-scatter) + input_parallel = mpu.gather_from_sequence_parallel_region( + input_, seq_dim=seq_dim + ) + else: + # Set up backprop all-reduce. + input_parallel = mpu.copy_to_model_parallel_region(input_) # Matrix multiply. 
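+ # Shape sketch under sequence parallelism (illustrative): input_ arrives as
+ # [b, s/tp, h], the gather above restores [b, s, h], and the matmul against the
+ # vocab-parallel word embedding weight yields a [b, s, vocab/tp] logits shard per
+ # rank (gathered afterwards or not, depending on parallel_output).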
if bias is None: diff --git a/megatron/model/utils.py b/megatron/model/utils.py index c3da2ce8b..97b409c1d 100644 --- a/megatron/model/utils.py +++ b/megatron/model/utils.py @@ -18,8 +18,8 @@ """Utilities for models.""" import torch -from megatron.model.norms import LayerNorm, RMSNorm, ScaleNorm from megatron.model.fused_softmax import SoftmaxFusionTypes +from megatron import mpu from types import GeneratorType import torch.distributed as dist @@ -35,15 +35,9 @@ def get_params_for_weight_decay_optimization(module, neox_args): "name": "no_weight_decay_params", } for module_ in module.modules(): - if any( - [ - isinstance(module_, LayerNorm), - isinstance(module_, RMSNorm), - isinstance(module_, ScaleNorm), - ] - ) or ( - neox_args.weight_decay == 0.0 - ): # also include all parameters here if no weight decay is being done + # apply weight decay to any "...Norm" modules. + if "norm" in type(module_).__name__.lower() or neox_args.weight_decay == 0.0: + # also include all parameters here if no weight decay is being done no_weight_decay_params["params"].extend( [p for p in list(module_._parameters.values()) if p is not None] ) @@ -359,3 +353,45 @@ def get_fusion_type(neox_args): elif neox_args.scaled_masked_softmax_fusion: fusion_type = SoftmaxFusionTypes.general return fusion_type + + +def reduce_weight_grads_from_model_parallel_region(input_): + """A hook that can be applied to any weight tensor via .register_hook(). + Allreduces grads for e.g. LN weights across the model parallel group. + Needed to keep LNs in sync, despite them getting diff data -> diff gradients when using sequence parallel. + """ + # Bypass the function if no TP -> no comm needed. + if mpu.get_model_parallel_world_size() == 1: + return input_ + + # Bf16 convert + dt = input_.dtype + if dt == torch.bfloat16 and mpu.get_fp32_allreduce(): + input_ = input_.float() + + # All-reduce. + torch.distributed.all_reduce(input_, group=mpu.get_model_parallel_group()) + + # Bf16 convert + if dt == torch.bfloat16 and mpu.get_fp32_allreduce(): + input_ = input_.bfloat16() + + return input_ + + +def mark_norms_for_sequence_parallel_grad_sync(module, neox_args): + """Iterate through the modules in our model, and for any "...Norm" classnames, + register a hook on each of that module's parameters which will allreduce norms' weights' grads across + the model (sequence) parallel region. 
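+ Gradients are summed (not averaged) across ranks: each rank only sees its own
+ sequence shard, so the full gradient of a shared norm weight is the sum of the
+ per-shard contributions.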
+ """ + + if not neox_args.sequence_parallel: + # if we aren't using sequence parallelism, this is a no-op + return + + for module_ in module.modules(): + if "norm" in type(module_).__name__.lower(): + # this is a norm, we want to allreduce its weight grads across sequence parallel region + for name, param in module_.named_parameters(): + if param.requires_grad: + param.register_hook(reduce_weight_grads_from_model_parallel_region) diff --git a/megatron/model/word_embeddings.py b/megatron/model/word_embeddings.py index f7372bc55..ce3c1117e 100644 --- a/megatron/model/word_embeddings.py +++ b/megatron/model/word_embeddings.py @@ -50,6 +50,11 @@ def __init__( self.hidden_size = hidden_size self.init_method = init_method self.num_tokentypes = num_tokentypes + + self.sequence_parallel = ( + neox_args.sequence_parallel + ) # if we are using sequence parallelism, then we'll want to scatter our inputs across the seqlen dim across TP ranks + self.use_mup = neox_args.use_mup self.mup_embedding_mult = neox_args.mup_embedding_mult self.mup_rp_embedding_mult = neox_args.mup_rp_embedding_mult @@ -159,6 +164,11 @@ def forward(self, input_ids, position_ids, tokentype_ids=None): with torch.no_grad(): embeddings.mul_(self.mup_embedding_mult) + if self.sequence_parallel: + # TODO: megatron-lm does dropout using the scattered embs. This would save a tiny bit of time, perhaps? + # Not a priority since we don't often use dropout + embeddings = mpu.scatter_to_sequence_parallel_region(embeddings) + return embeddings diff --git a/megatron/mpu/__init__.py b/megatron/mpu/__init__.py index 2365507d9..780fb33e8 100644 --- a/megatron/mpu/__init__.py +++ b/megatron/mpu/__init__.py @@ -47,6 +47,9 @@ from .mappings import gather_from_model_parallel_region from .mappings import reduce_from_model_parallel_region from .mappings import scatter_to_model_parallel_region +from .mappings import reduce_scatter_to_sequence_parallel_region +from .mappings import gather_from_sequence_parallel_region +from .mappings import scatter_to_sequence_parallel_region from .random import checkpoint from .random import get_cuda_rng_tracker diff --git a/megatron/mpu/layers.py b/megatron/mpu/layers.py index 0d14806ac..d59edab94 100644 --- a/megatron/mpu/layers.py +++ b/megatron/mpu/layers.py @@ -33,6 +33,8 @@ from .mappings import gather_from_model_parallel_region from .mappings import reduce_from_model_parallel_region from .mappings import scatter_to_model_parallel_region +from .mappings import reduce_scatter_to_sequence_parallel_region +from .mappings import gather_from_sequence_parallel_region from .random import get_cuda_rng_tracker from .utils import divide from .utils import VocabUtility @@ -416,6 +418,7 @@ def __init__( MOE=False, MoE_mp_size=1, mup_rescale_parameters=False, + seq_dim=0, # Dimension which is the seq_len dimension. final ParallelLinear overrides this to be 1 ; otherwise, the default is used throughout. 
): super(ColumnParallelLinear, self).__init__() @@ -427,6 +430,10 @@ def __init__( world_size = MoE_mp_size if MOE else get_model_parallel_world_size() self.output_size_per_partition = divide(output_size, world_size) self.skip_bias_add = skip_bias_add + + self.sequence_parallel = neox_args.sequence_parallel + self.seq_dim = seq_dim + self.init_method = init_method self.stride = stride self.mup_rescale_parameters = mup_rescale_parameters @@ -551,14 +558,29 @@ def set_parallel_output(self, value: bool): def forward(self, input_): if self.use_mup and self.mup_rescale_parameters: input_ /= self.width_mult() - # Set up backprop all-reduce. - input_parallel = copy_to_model_parallel_region(input_) + + if self.sequence_parallel: + input_parallel = input_ + else: + # Set up backprop all-reduce. + input_parallel = copy_to_model_parallel_region(input_) # Matrix multiply. + if self.sequence_parallel: + # do an AG in the fwd pass, RS in bwd pass. + # gather / scatter portion happens across the sequence dim (self.seq_dim)-- + # almost always is [s, b, h] and so dim 0, but for lm_head ParallelLinear it is seq_dim=1 and [b, s, h] + input_parallel = gather_from_sequence_parallel_region( + input_parallel, seq_dim=self.seq_dim + ) + bias = self.bias if not self.skip_bias_add else None output_parallel = F.linear(input_parallel, self.weight, bias) if self.gather_output: # All-gather across the partitions. + assert ( + not self.sequence_parallel + ), "sequence_parallel=True and gather_output=True are incompatible!" output = gather_from_model_parallel_region(output_parallel) else: output = output_parallel @@ -623,6 +645,12 @@ def __init__( self.input_size_per_partition = divide(input_size, world_size) self.skip_bias_add = skip_bias_add self.parallel_output = parallel_output + + self.sequence_parallel = neox_args.sequence_parallel + assert not ( + self.sequence_parallel and not self.input_is_parallel + ), "Cannot have self.input_is_parallel=False and self.sequence_parallel=True." + self.init_method = init_method self.stride = stride self.keep_master_weight_for_test = keep_master_weight_for_test @@ -748,7 +776,12 @@ def forward(self, input_): # Matrix multiply. output_parallel = F.linear(input_parallel, self.weight) # All-reduce across all the partitions. - if not self.parallel_output: + if self.sequence_parallel and not self.parallel_output: + # do an RS in the fwd pass, AG in bwd pass. + # skip in the gpt-j parallel sublayer case (self.parallel_output=True) + # (user responsible for calling reduce-scatter) + output_ = reduce_scatter_to_sequence_parallel_region(output_parallel) + elif not self.parallel_output: output_ = reduce_from_model_parallel_region(output_parallel) else: output_ = output_parallel diff --git a/megatron/mpu/mappings.py b/megatron/mpu/mappings.py index 535fe6255..f11d9e6ab 100644 --- a/megatron/mpu/mappings.py +++ b/megatron/mpu/mappings.py @@ -23,7 +23,7 @@ get_model_parallel_rank, get_fp32_allreduce, ) -from .utils import split_tensor_along_last_dim +from .utils import split_tensor_along_last_dim, split_tensor_along_any_dim def _reduce(input_): @@ -33,17 +33,17 @@ def _reduce(input_): if get_model_parallel_world_size() == 1: return input_ - # Bf16 convert + # upcast to fp32 if using fp32 allreduce dt = input_.dtype - if dt == torch.bfloat16 and get_fp32_allreduce(): + if get_fp32_allreduce(): input_ = input_.float() # All-reduce. 
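+ # (torch.distributed.all_reduce defaults to ReduceOp.SUM, so every rank ends up
+ #  with the elementwise sum across the model parallel group)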
torch.distributed.all_reduce(input_, group=get_model_parallel_group()) - # Bf16 convert - if dt == torch.bfloat16 and get_fp32_allreduce(): - input_ = input_.bfloat16() + # reconvert to original Bf16/Fp16 dtype + if get_fp32_allreduce(): + input_ = input_.to(dt) return input_ @@ -75,11 +75,6 @@ def _gather(input_): if world_size == 1: return input_ - # Bf16 convert - dt = input_.dtype - if dt == torch.bfloat16 and get_fp32_allreduce(): - input_ = input_.float() - # Size and dimension. last_dim = input_.dim() - 1 rank = get_model_parallel_rank() @@ -91,9 +86,100 @@ def _gather(input_): # Note: torch.cat already creates a contiguous tensor. output = torch.cat(tensor_list, dim=last_dim).contiguous() - # Bf16 convert - if dt == torch.bfloat16 and get_fp32_allreduce(): - output = output.bfloat16() + return output + + +def _reduce_scatter_along_seq_dim(input_, seq_dim): + """Reduce-scatter the input tensor across model parallel group, scattering across sequence dim.""" + world_size = get_model_parallel_world_size() + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + + # upcast to fp32 if using fp32 allreduce + dt = input_.dtype + if get_fp32_allreduce(): + input_ = input_.float() + + dim_size = list(input_.size()) + assert ( + isinstance(seq_dim, int) and seq_dim < len(dim_size) and seq_dim >= 0 + ), "seq_dim must be a valid tensor dim" + assert dim_size[seq_dim] % world_size == 0 + + if seq_dim == 0: + # reduce_scatter_tensor is faster but only works correctly on dimension 0 + dim_size[seq_dim] = dim_size[seq_dim] // world_size + output = torch.empty( + dim_size, dtype=input_.dtype, device=torch.cuda.current_device() + ) + torch.distributed.reduce_scatter_tensor( + output, input_.contiguous(), group=get_model_parallel_group() + ) + else: + tensor_list = list( + torch.split(input_, input_.shape[seq_dim] // world_size, seq_dim) + ) + output = torch.empty_like(tensor_list[0]) + torch.distributed.reduce_scatter(output, tensor_list) + + # reconvert to original Bf16/Fp16 dtype + if get_fp32_allreduce(): + output = output.to(dt) + + return output + + +def _gather_along_seq_dim(input_, seq_dim): + """Gather tensors and concatinate along the (manually-specified) sequence dimension.""" + + world_size = get_model_parallel_world_size() + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + + dim_size = list(input_.size()) + assert ( + isinstance(seq_dim, int) and seq_dim < len(dim_size) and seq_dim >= 0 + ), "seq_dim must be a valid tensor dim" + dim_size[seq_dim] = dim_size[seq_dim] * world_size + + if seq_dim == 0: + # reduce_gather_tensor is faster but only works correctly on dimension 0 + output = torch.empty( + dim_size, dtype=input_.dtype, device=torch.cuda.current_device() + ) + torch.distributed.all_gather_into_tensor( + output, input_.contiguous(), group=get_model_parallel_group() + ) + else: + input_ = input_.contiguous() + rank = get_model_parallel_rank() + tensor_list = [torch.empty_like(input_) for _ in range(world_size)] + tensor_list[rank] = input_ + torch.distributed.all_gather( + tensor_list, input_, group=get_model_parallel_group() + ) + output = torch.cat(tensor_list, dim=seq_dim) + + return output + + +def _split_along_seq_dim(input_, seq_dim): + """Split the tensor along the sequence dimension (as manually selected) and keep the + corresponding slice.""" + + world_size = get_model_parallel_world_size() + # Bypass the function if we are using only 1 GPU. 
+ if world_size == 1: + return input_ + + # Split along second dimension. + input_list = split_tensor_along_any_dim(input_, world_size, seq_dim) + + # Note: torch.split does not create contiguous tensors by default. + rank = get_model_parallel_rank() + output = input_list[rank].contiguous() return output @@ -162,6 +248,65 @@ def backward(ctx, grad_output): return _split(grad_output) +class _ReduceScatterToSequenceParallelRegion(torch.autograd.Function): + """Reduce-Scatter across sequence parallel region (same as model parallel region.) + Note: same region as model parallel region + """ + + @staticmethod + def symbolic(graph, input_, seq_dim): + return _reduce_scatter_along_seq_dim(input_, seq_dim=seq_dim) + + @staticmethod + def forward(ctx, input_, seq_dim): + ctx.seq_dim = seq_dim + return _reduce_scatter_along_seq_dim(input_, seq_dim=seq_dim) + + @staticmethod + def backward(ctx, grad_output): + seq_dim = ctx.seq_dim + return _gather_along_seq_dim(grad_output, seq_dim=seq_dim), None + + +class _GatherFromSequenceParallelRegion(torch.autograd.Function): + """All-Gather across sequence parallel region (same region as model parallel region.)""" + + @staticmethod + def symbolic(graph, input_, seq_dim): + return _gather_along_seq_dim(input_, seq_dim=seq_dim) + + @staticmethod + def forward(ctx, input_, seq_dim): + ctx.seq_dim = seq_dim + return _gather_along_seq_dim(input_, seq_dim=seq_dim) + + @staticmethod + def backward(ctx, grad_output): + seq_dim = ctx.seq_dim + return _reduce_scatter_along_seq_dim(grad_output, seq_dim=seq_dim), None + + +class _ScatterToSequenceParallelRegion(torch.autograd.Function): + """Scatter (split) sequence length across sequence parallel region (=> same region as model parallel.)""" + + @staticmethod + def symbolic(graph, input_, seq_dim): + return _split_along_seq_dim(input_, seq_dim=seq_dim) + + @staticmethod + def forward(ctx, input_, seq_dim): + ctx.seq_dim = seq_dim + return _split_along_seq_dim(input_, seq_dim=seq_dim) + + @staticmethod + def backward(ctx, grad_output): + seq_dim = ctx.seq_dim + return ( + _gather_along_seq_dim(grad_output, seq_dim=seq_dim), + None, + ) + + # ----------------- # Helper functions. # ----------------- @@ -181,3 +326,17 @@ def scatter_to_model_parallel_region(input_): def gather_from_model_parallel_region(input_): return _GatherFromModelParallelRegion.apply(input_) + + +def reduce_scatter_to_sequence_parallel_region(input_, seq_dim=0): + return _ReduceScatterToSequenceParallelRegion.apply(input_, seq_dim) + + +def gather_from_sequence_parallel_region(input_, seq_dim=0): + return _GatherFromSequenceParallelRegion.apply(input_, seq_dim) + + +def scatter_to_sequence_parallel_region( + input_, seq_dim=1 +): # use this fn in scattering input embeds across TP ranks. There, shape of inps is [b, s, h] instead of the usual [s, b, h] + return _ScatterToSequenceParallelRegion.apply(input_, seq_dim) diff --git a/megatron/mpu/utils.py b/megatron/mpu/utils.py index 13941dc29..1f97e0e76 100644 --- a/megatron/mpu/utils.py +++ b/megatron/mpu/utils.py @@ -53,6 +53,28 @@ def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks= return tensor_list +def split_tensor_along_any_dim( + tensor, num_partitions, seq_dim, contiguous_split_chunks=False +): + """Split a tensor along a user-specified dimension. + Arguments: + tensor: input tensor. + num_partitions: number of partitions to split the tensor + seq_dim: dimension along which to split the tensor + contiguous_split_chunks: If True, make each chunk contiguous + in memory. 
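+ Returns:
+     A tuple of num_partitions chunks of `tensor` split along `seq_dim` (views into
+     the input unless contiguous_split_chunks is True).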
+ """ + # Get the size and dimension. + seq_dim_size = divide(tensor.size()[seq_dim], num_partitions) + # Split. + tensor_list = torch.split(tensor, seq_dim_size, dim=seq_dim) + # Note: torch.split does not create contiguous tensors by default. + if contiguous_split_chunks: + return tuple(chunk.contiguous() for chunk in tensor_list) + + return tensor_list + + class VocabUtility: """Split the vocabulary into `world_size` chunks amd return the first and last index of the vocabulary belonging to the `rank` diff --git a/megatron/neox_arguments/arguments.py b/megatron/neox_arguments/arguments.py index 054689eda..6a84df6c7 100644 --- a/megatron/neox_arguments/arguments.py +++ b/megatron/neox_arguments/arguments.py @@ -1038,6 +1038,10 @@ def calculate_derived(self): assert self.zero_optimization["stage"] != 3, "MoE not compatible with zero3" assert self.mlp_type == "regular", "MoE not compatible with LLaMA" + assert ( + self.sequence_parallel is False + ), "MoE not compatible with Sequence Parallel" + # Attention config if self.attention_config is None: self.update_value("attention_config", [[["global"], self.num_layers]]) diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py index febefb3c2..7993f785f 100644 --- a/megatron/neox_arguments/neox_args.py +++ b/megatron/neox_arguments/neox_args.py @@ -85,6 +85,13 @@ class NeoXArgsParallelism(NeoXArgsTemplate): according to pipeline parallel size. """ + sequence_parallel: bool = False + """ + flag to determine whether Megatron-style Sequence Parallelism (https://arxiv.org/abs/2205.05198) + (Layernorm inputs and activations are sharded across model parallel group) will be used. Has no effect when model_parallel_size is 1. + **Set by user, in contrast to neox_args.is_pipe_parallel.** + """ + expert_interval: int = 2 """ Have one MoE layer every expert_interval layers diff --git a/megatron/training.py b/megatron/training.py index 3265680c5..ce59b242a 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -43,6 +43,7 @@ GPT2ModelPipe, SoftEmbedding, get_params_for_weight_decay_optimization, + mark_norms_for_sequence_parallel_grad_sync, ) from megatron.checkpointing import load_checkpoint, save_checkpoint from megatron.data.data_utils import build_train_valid_test_data_iterators @@ -765,6 +766,7 @@ def setup_model_and_optimizer(neox_args, use_cache=False, iteration=None): # config_params=neox_args.deepspeed_config, mpu=mpu if not neox_args.is_pipe_parallel else None, ) + mark_norms_for_sequence_parallel_grad_sync(model, neox_args) if neox_args.moe_num_experts > 1 and neox_args.moe_type == "megablocks": # We need to additionally set this flag to ensure DS parallelism properly handles this foreign MoE. model.has_moe_layers = True @@ -891,6 +893,7 @@ def train_step(neox_args, timers, data_iterator, model, optimizer, lr_scheduler) and neox_args.iteration <= neox_args.profile_step_stop ): torch.cuda.nvtx.range_push(f"Optimizer step") + timers("optimizer").start() if neox_args.deepspeed: model.step()