diff --git a/.github/workflows/cpu_ci_on_pr.yml b/.github/workflows/.cpu_ci_on_pr.yml
similarity index 58%
rename from .github/workflows/cpu_ci_on_pr.yml
rename to .github/workflows/.cpu_ci_on_pr.yml
index 971640c18..43ce025c0 100644
--- a/.github/workflows/cpu_ci_on_pr.yml
+++ b/.github/workflows/.cpu_ci_on_pr.yml
@@ -1,3 +1,5 @@
+# This file is hidden (.cpu_ci_on_pr.yml) to minimize the number of runner minutes consumed.
+
 name: "Pull Request CPU Tests"

 on:
@@ -7,7 +9,7 @@ on:

 jobs:
   run-tests:
-    runs-on: [ 'test', 'self-hosted' ]
+    runs-on: ubuntu-22.04 # ubuntu-latest currently points to ubuntu-22.04, but 24.04 is in beta - recommend testing on 24.04 and then switching, rather than using ubuntu-latest
     steps:
       - name: Checkout Repository
         uses: actions/checkout@v4
diff --git a/.github/workflows/coverity_scan.yml b/.github/workflows/coverity_scan.yml
index a79d0d8fb..128d279cc 100644
--- a/.github/workflows/coverity_scan.yml
+++ b/.github/workflows/coverity_scan.yml
@@ -17,9 +17,10 @@ jobs:
     runs-on: ubuntu-latest

     env:
-      COV_USER: ${{ secrets.COV_USER }}
+      COV_USER: ${{ secrets.COV_USER }} # needs to be an email with access to the Coverity stream - add to secrets/actions
       COVERITY_PROJECT: ${{ secrets.COVERITY_PROJECT }}
-      COVERITY_TOKEN: ${{ secrets.COVERITY_TOKEN }}
+      COVERITY_TOKEN: ${{ secrets.COVERITY_TOKEN }} # you can get this token from the Coverity stream dashboard:
+      # https://scan.coverity.com/projects/?tab=project_settings

     steps:
       - uses: actions/checkout@v2
diff --git a/.github/workflows/cpu_ci.yml b/.github/workflows/cpu_ci.yml
index 9160fccab..6910b8a1c 100644
--- a/.github/workflows/cpu_ci.yml
+++ b/.github/workflows/cpu_ci.yml
@@ -5,7 +5,7 @@ on: "push"
 jobs:
   run-tests:
     #runs-on: ubuntu-latest
-    runs-on: [ 'test', 'self-hosted' ]
+    runs-on: ubuntu-22.04
     steps:
       - uses: actions/checkout@v3
diff --git a/.github/workflows/cpu_ci_dispatch.yml b/.github/workflows/cpu_ci_dispatch.yml
index b1d108b3b..38485d6a6 100644
--- a/.github/workflows/cpu_ci_dispatch.yml
+++ b/.github/workflows/cpu_ci_dispatch.yml
@@ -10,7 +10,7 @@ on:

 jobs:
   run-tests:
-    runs-on: [ 'test', 'self-hosted' ]
+    runs-on: ubuntu-22.04
     steps:
       - name: Checkout Repository
         uses: actions/checkout@v4
diff --git a/.github/workflows/pull_request.yml b/.github/workflows/pull_request.yml
index 3213718df..7b06256bf 100644
--- a/.github/workflows/pull_request.yml
+++ b/.github/workflows/pull_request.yml
@@ -1,6 +1,7 @@
 name: Pull Request

-on: [pull_request]
+#on: [pull_request, workflow_dispatch]
+on: workflow_dispatch

 jobs:
   pre-commit:
@@ -9,7 +10,7 @@ jobs:
       - uses: actions/checkout@v2
       - uses: actions/setup-python@v4
         with:
-          python-version: 3.10
+          python-version: "3.10.14"
           cache: "pip"
           cache-dependency-path: "**/requirements*.txt"
       # Need the right version of clang-format
@@ -40,10 +41,20 @@ jobs:
           git commit -m "Update NeoXArgs docs automatically"
           git push
   run-tests:
-    runs-on: self-hosted
+    runs-on: ubuntu-22.04
     steps:
       - uses: actions/checkout@v2
+      - uses: actions/setup-python@v4
+        with:
+          python-version: "3.10.13"
+          cache-dependency-path: "**/requirements*.txt"
       - name: prepare data
-        run: python prepare_data.py
+        run: python3 prepare_data.py
+      - name: install pytest
+        run: python3 -m pip install pytest pytest-forked pyyaml requests wandb
+      - name: install torch
+        run: python3 -m pip install torch
+      - name: install requirements
+        run: pip install -r requirements/requirements.txt
       - name: Run Tests
         run: pytest --forked tests
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 7de35027a..249255306 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -33,7 +33,7 @@ repos:
     hooks:
       - id: codespell
         args: [
-          '--ignore-words-list=reord,dout', # Word used in error messages that need rewording
+          '--ignore-words-list=reord,dout,te', # Words used in error messages that need rewording. te --> transformerengine
           --check-filenames,
           --check-hidden,
         ]
diff --git a/README.md b/README.md
index e7f61bf20..b5fc0d877 100644
--- a/README.md
+++ b/README.md
@@ -96,7 +96,6 @@ To install the remaining basic dependencies, run:
 pip install -r requirements/requirements.txt
 pip install -r requirements/requirements-wandb.txt # optional, if logging using WandB
 pip install -r requirements/requirements-tensorboard.txt # optional, if logging via tensorboard
-python ./megatron/fused_kernels/setup.py install # optional, if using fused kernels
 ```

 from the repository root.
@@ -106,6 +105,16 @@ from the repository root.

+### Fused Kernels
+We now support AMD GPUs (MI100, MI250X) through JIT fused-kernel compilation. Fused kernels will be built and loaded as needed. To avoid waiting during job launch, you can also pre-build them manually:
+
+```python
+python
+from megatron.fused_kernels import load
+load()
+```
+This automatically adapts the build process to the GPU vendor (AMD, NVIDIA) without platform-specific code changes. To further test the fused kernels, run `pytest tests/model/test_fused_kernels.py`.
+
 ### Flash Attention

 To use [Flash-Attention](https://github.com/HazyResearch/flash-attention), install the additional dependencies in `./requirements/requirements-flashattention.txt` and set the attention type in your configuration accordingly (see [configs](./configs/)). This can provide significant speed-ups over regular attention on certain GPU architectures, including Ampere GPUs (such as A100s); see the repository for more details.
@@ -640,7 +649,7 @@ If you need to supply a hostfile for use with the MPI-based DeepSpeed launcher,

 # Profiling

-We support profiling with Nsight Systems and PyTorch Memory Profiling.
+We support profiling with Nsight Systems, the PyTorch Profiler, and PyTorch Memory Profiling.

 ## Nsight Systems Profiling

@@ -656,6 +665,15 @@ The generated output file can then be viewed with the Nsight Systems GUI:

 ![Alt text](images/nsight_profiling.png)

+## PyTorch Profiling
+
+To use the built-in PyTorch profiler, set config options `profile`, `profile_step_start`, and `profile_step_stop`.
+
+The PyTorch profiler will save traces to your `tensorboard` log directory. You can view these traces within
+TensorBoard by following the steps [here](https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html).
+
+![Alt text](images/pytorch_profiling.png)
+
 ## PyTorch Memory Profiling

 To use PyTorch Memory Profiling, set config options `memory_profiling` and `memory_profiling_path`.
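For reference, the profiler settings named in the two sections above are ordinary NeoX config options; a minimal sketch of how they might appear together in a `.yml` config follows (the step range and snapshot path are illustrative assumptions, not values taken from this change):

```yaml
  # Built-in PyTorch profiler; traces are written to the tensorboard log directory.
  "profile": true,
  "profile_step_start": 10,   # illustrative: start tracing at step 10
  "profile_step_stop": 12,    # illustrative: stop tracing at step 12

  # PyTorch memory profiling; the output path below is an assumed example.
  "memory_profiling": true,
  "memory_profiling_path": "./memory_snapshots",
```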
@@ -718,7 +736,7 @@ The following publications by other research groups use this library:
 The following models were trained using this library:

 ### English LLMs
-- EleutherAI's [GPT-NeoX-20B](https://huggingface.co/EleutherAI/gpt-neox-20b), [Pythia (70M through 13B)](https://github.com/EleutherAI/pythia), and [LLeMMA (34B)](https://arxiv.org/abs/2310.10631)
+- EleutherAI's [GPT-NeoX-20B](https://huggingface.co/EleutherAI/gpt-neox-20b) and [Pythia (70M through 13B)](https://github.com/EleutherAI/pythia)
 - CarperAI's [FIM-NeoX-1.3B](https://huggingface.co/CarperAI/FIM-NeoX-1.3B)
 - StabilityAI's [StableLM (3B and 7B)](https://github.com/Stability-AI/StableLM)
 - Together.ai's [RedPajama-INCITE (3B and 7B)](https://together.ai/blog/redpajama-models-v1)
@@ -729,13 +747,15 @@ The following models were trained using this library:
 ### Non-English LLMs
 - EleutherAI's [Polyglot-Ko (1.3B through 12.8B)](https://github.com/EleutherAI/polyglot) (Korean)
 - Korea University's [KULLM-Polyglot (5.8B and 12.8B)](https://github.com/nlpai-lab/KULLM) (Korean)
-- Stability AI's [Japanese Stable LM (7B)](https://huggingface.co/stabilityai/japanese-stablelm-base-alpha-7b)
+- Stability AI's [Japanese Stable LM (7B)](https://huggingface.co/stabilityai/japanese-stablelm-base-alpha-7b) (Japanese)
 - LearnItAnyway's [LLaVA-Polyglot-Ko (1.3B)](https://huggingface.co/LearnItAnyway/llava-polyglot-ko-1.3b-hf) (Korean)
 - Rinna Co.'s [japanese-gpt-neox-3.6b](https://huggingface.co/rinna/japanese-gpt-neox-3.6b) (Japanese) and [bilingual-gpt-neox-4b](https://huggingface.co/rinna/bilingual-gpt-neox-4b) (English / Japanese)
 - CyberAgent's [Open-CLM (125M through 7B)](https://huggingface.co/cyberagent/open-calm-7b) (Japanese)
 - The Hungarian Research Centre for Linguistics's [PULI GPTrio (6.7B)](https://huggingface.co/NYTK/PULI-GPTrio) (Hungarian / English / Chinese)
 - The University of Tokyo's [weblab-10b](https://huggingface.co/Kojima777/weblab-10b) and [weblab-10b-instruct](https://huggingface.co/Kojima777/weblab-10b-instruction-sft) (Japanese)
 - nolando.ai's [Hi-NOLIN (9B)](https://blog.nolano.ai/Hi-NOLIN/) (English, Hindi)
+- Renmin University of China's [YuLan (12B)](https://huggingface.co/yulan-team/YuLan-Base-12b) (English, Chinese)
+- The Basque Center for Language Technology's [Latxa (70B)](https://huggingface.co/HiTZ/latxa-70b-v1.2) (Basque)

 ### Code Models
 - Carnegie Mellon University's [PolyCoder (160M through 2.7B)](https://github.com/VHellendoorn/Code-LMs) and [CAT-LM (2.7B)](https://huggingface.co/nikitharao/catlm)
@@ -743,11 +763,13 @@ The following models were trained using this library:
 - CodeFuse AI's [CodeFuse (13B)](https://huggingface.co/codefuse-ai/CodeFuse-13B)

 ### AI for Science
+- EleutherAI's [LLeMMA (34B)](https://arxiv.org/abs/2310.10631)
 - Oak Ridge National Lab's [FORGE (26B)](https://github.com/at-aaims/forge)
-- Oak Ridge National Lab and EleutherAI's [Unnamed Material Science Domain Models (7B)](https://github.com/at-aaims/forge)
+- Oak Ridge National Lab's [Unnamed Material Science Domain Models (7B)](https://arxiv.org/abs/2402.00691)
 - Pacific Northwest National Lab's [MolJet (undisclosed size)](https://openreview.net/pdf?id=7UudBVsIrr)

 ### Other Modalities
+- Rinna Co.'s [PSLM (7B)](https://arxiv.org/abs/2406.12428) (speech / text)
 - University College London's [ChessGPT-3B](https://huggingface.co/Waterhorse/chessgpt-base-v1)
 - Gretel's [Text-to-Table (3B)](https://huggingface.co/gretelai/text2table)

diff --git a/configs/README.md b/configs/README.md
index d8ae81739..3102a34d1 100644
---
a/configs/README.md +++ b/configs/README.md @@ -9,7 +9,7 @@ Below is an example configuration `.yaml` to train a ~160M parameter GPT model. For a detailed list of all the arguments available for neox, see [neox_arguments.md](neox_arguments.md) -Note: yaml arguments may be formatted with either '-' or '_'. The standard separator used is a '_' as shown in the example configurations below. However, the use of '-' as a separator may be deprecated in the future. +Note: yaml arguments may be formatted with either '-' or '\_'. The standard separator used is a '\_' as shown in the example configurations below. However, the use of '-' as a separator may be deprecated in the future. ```yaml # GPT-3 pretraining setup { @@ -235,6 +235,33 @@ Additional DeepSpeed settings besides those mentioned above should be wrapped in "eval_iters": 10, ``` +However, if you want to use DPO style training you'll need to set pos/neg data paths instead of a single one, e.g. + +```yaml + "dataset_impl": "pairwise", + "train_impl": "dpo", + "pack_impl": "unpacked", + "dpo_beta": 0.1, + "dpo_fp32": true, + "pos_train_data_path": "data/enwik8/enwik8_text_pos_document", + "pos_valid_data_path": "data/enwik8/enwik8_text_pos_document", + "pos_test_data_path": "data/enwik8/enwik8_text_pos_document", + "neg_train_data_path": "data/enwik8/enwik8_text_neg_document", + "neg_valid_data_path": "data/enwik8/enwik8_text_neg_document", + "neg_test_data_path": "data/enwik8/enwik8_text_neg_document", + ## If you have labels... (likely to mask out user turns) + "pos_train_label_data_path": "data/enwik8/enwik8_text_pos_label_document", + "pos_valid_label_data_path": "data/enwik8/enwik8_text_pos_label_document", + "pos_test_label_data_path": "data/enwik8/enwik8_text_pos_label_document", + "neg_train_label_data_path": "data/enwik8/enwik8_text_neg_label_document", + "neg_valid_label_data_path": "data/enwik8/enwik8_text_neg_label_document", + "neg_test_label_data_path": "data/enwik8/enwik8_text_neg_label_document", + ## If you want to precompute the logits over your dataset... 
+ "precompute_model_name": "gpt2", + ## Needed for the generation.py step, if precomputing + "text_gen_type": "precompute" +``` + ### LR Scheduler settings ```yaml diff --git a/configs/llama/13B.yml b/configs/llama/13B.yml index 305567be1..7a823a43c 100644 --- a/configs/llama/13B.yml +++ b/configs/llama/13B.yml @@ -22,5 +22,5 @@ "use_bias_in_norms": false, "use_bias_in_attn_linear": false, "mlp_type": "llama", - "activation": "silu", + "activation": "swiglu", } diff --git a/configs/llama/30B.yml b/configs/llama/30B.yml index 450f8da38..2c356cea2 100644 --- a/configs/llama/30B.yml +++ b/configs/llama/30B.yml @@ -22,5 +22,5 @@ "use_bias_in_norms": false, "use_bias_in_attn_linear": false, "mlp_type": "llama", - "activation": "silu", + "activation": "swiglu", } diff --git a/configs/llama/65B.yml b/configs/llama/65B.yml index 85f199ce2..cc22d3734 100644 --- a/configs/llama/65B.yml +++ b/configs/llama/65B.yml @@ -22,5 +22,5 @@ "use_bias_in_norms": false, "use_bias_in_attn_linear": false, "mlp_type": "llama", - "activation": "silu", + "activation": "swiglu", } diff --git a/configs/llama/7B.yml b/configs/llama/7B.yml index ecbf187a8..0b134ae27 100644 --- a/configs/llama/7B.yml +++ b/configs/llama/7B.yml @@ -22,5 +22,5 @@ "use_bias_in_norms": false, "use_bias_in_attn_linear": false, "mlp_type": "llama", - "activation": "silu", + "activation": "swiglu", } diff --git a/configs/mamba/mamba-1.4B.yml b/configs/mamba/mamba-1.4B.yml index 2898a72fd..eae467d0e 100644 --- a/configs/mamba/mamba-1.4B.yml +++ b/configs/mamba/mamba-1.4B.yml @@ -19,5 +19,71 @@ "mamba_inner_func_fusion": true, # supersedes scan or conv fusion "activation": "silu", - "output_layer_init_method": "single_residual_scaled_normal", + # init methods + "init_method": "small_init", + "output_layer_init_method": "single_residual_scaled_normal", + + # optimizer settings + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.0002, + "betas": [0.9, 0.95], + "eps": 1.0e-8, + } + }, + "min_lr": 0.00002, + + # for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training + "zero_optimization": { + "stage": 1, + "allgather_partitions": True, + "allgather_bucket_size": 500000000, + "overlap_comm": True, + "reduce_scatter": True, + "reduce_bucket_size": 500000000, + "contiguous_gradients": True, + }, + + # batch / data settings + "train_micro_batch_size_per_gpu": 4, + "data_impl": "mmap", + + # activation checkpointing + "checkpoint_activations": true, + "checkpoint_num_layers": 1, + "partition_activations": true, + "synchronize_each_layer": true, + + # regularization + "gradient_clipping": 1.0, + "weight_decay": 0.1, + "hidden_dropout": 0, + "attention_dropout": 0, + + # precision settings + "fp16": { + "fp16": true, + "enabled": true, + "loss_scale": 0, + "loss_scale_window": 1000, + "hysteresis": 2, + "min_loss_scale": 1 + }, + + # misc. 
training settings + "train_iters": 320000, + "lr_decay_iters": 320000, + "distributed_backend": "nccl", + "lr_decay_style": "cosine", + "warmup": 0.01, + "checkpoint_factor": 10000, + "eval_interval": 1000, + "eval_iters": 10, + + # logging + "log_interval": 1, + "steps_per_print": 10, + "keep_last_n_checkpoints": 4, + "wall_clock_breakdown": true, } diff --git a/configs/mamba/mamba-130M.yml b/configs/mamba/mamba-130M.yml index d9a6ab92e..bd05723b2 100644 --- a/configs/mamba/mamba-130M.yml +++ b/configs/mamba/mamba-130M.yml @@ -19,5 +19,71 @@ "mamba_inner_func_fusion": true, # supersedes scan or conv fusion "activation": "silu", - "output_layer_init_method": "single_residual_scaled_normal", + # init methods + "init_method": "small_init", + "output_layer_init_method": "single_residual_scaled_normal", + + + # optimizer settings + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.0006, + "betas": [0.9, 0.95], + "eps": 1.0e-8, + } + }, + "min_lr": 0.00006, + + # for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training + "zero_optimization": { + "stage": 1, + "allgather_partitions": True, + "allgather_bucket_size": 500000000, + "overlap_comm": True, + "reduce_scatter": True, + "reduce_bucket_size": 500000000, + "contiguous_gradients": True, + }, + + # batch / data settings + "train_micro_batch_size_per_gpu": 4, + "data_impl": "mmap", + + # activation checkpointing + "checkpoint_activations": true, + "checkpoint_num_layers": 1, + "partition_activations": true, + "synchronize_each_layer": true, + + # regularization + "gradient_clipping": 1.0, + "weight_decay": 0.1, + "hidden_dropout": 0.0, + "attention_dropout": 0.0, + + # precision settings + "fp16": { + "enabled": true, + "loss_scale": 0, + "loss_scale_window": 1000, + "hysteresis": 2, + "min_loss_scale": 1 + }, + + # misc. 
training settings + "train_iters": 320000, + "lr_decay_iters": 320000, + "distributed_backend": "nccl", + "lr_decay_style": "cosine", + "warmup": 0.01, + "checkpoint_factor": 10000, + "eval_interval": 1000, + "eval_iters": 10, + + # logging + "log_interval": 100, + "steps_per_print": 10, + "keep_last_n_checkpoints": 4, + "wall_clock_breakdown": true, } diff --git a/configs/mamba/mamba-2.8B.yml b/configs/mamba/mamba-2.8B.yml index 1aacb264b..d5afef368 100644 --- a/configs/mamba/mamba-2.8B.yml +++ b/configs/mamba/mamba-2.8B.yml @@ -19,5 +19,71 @@ "mamba_inner_func_fusion": true, # supersedes scan or conv fusion "activation": "silu", - "output_layer_init_method": "single_residual_scaled_normal", + # init methods + "init_method": "small_init", + "output_layer_init_method": "single_residual_scaled_normal", + + # optimizer settings + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.00016, + "betas": [0.9, 0.95], + "eps": 1.0e-8, + } + }, + "min_lr": 0.000016, + + # for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training + "zero_optimization": { + "stage": 1, + "allgather_partitions": True, + "allgather_bucket_size": 500000000, + "overlap_comm": True, + "reduce_scatter": True, + "reduce_bucket_size": 500000000, + "contiguous_gradients": True, + }, + + # batch / data settings + "train_micro_batch_size_per_gpu": 4, + "data_impl": "mmap", + + # activation checkpointing + "checkpoint_activations": true, + "checkpoint_num_layers": 1, + "partition_activations": true, + "synchronize_each_layer": true, + + # regularization + "gradient_clipping": 1.0, + "weight_decay": 0.1, + "hidden_dropout": 0, + "attention_dropout": 0, + + # precision settings + "fp16": { + "fp16": true, + "enabled": true, + "loss_scale": 0, + "loss_scale_window": 1000, + "hysteresis": 2, + "min_loss_scale": 1 + }, + + # misc. 
training settings + "train_iters": 320000, + "lr_decay_iters": 320000, + "distributed_backend": "nccl", + "lr_decay_style": "cosine", + "warmup": 0.01, + "checkpoint_factor": 10000, + "eval_interval": 1000, + "eval_iters": 10, + + # logging + "log_interval": 100, + "steps_per_print": 10, + "keep_last_n_checkpoints": 4, + "wall_clock_breakdown": true, } diff --git a/configs/mamba/mamba-370M.yml b/configs/mamba/mamba-370M.yml index 5e5a78cca..0058f1c0e 100644 --- a/configs/mamba/mamba-370M.yml +++ b/configs/mamba/mamba-370M.yml @@ -12,12 +12,77 @@ "norm": "rmsnorm", "rms_norm_epsilon": 1.0e-5, - "attention_config": [[["mamba"], 64]], + "attention_config": [[["mamba"], 48]], "mamba_selective_scan_fusion": true, "mamba_causal_conv_fusion": true, "mamba_inner_func_fusion": true, # supersedes scan or conv fusion "activation": "silu", - "output_layer_init_method": "single_residual_scaled_normal", + # init methods + "init_method": "small_init", + "output_layer_init_method": "single_residual_scaled_normal", + + # optimizer settings + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.0003, + "betas": [0.9, 0.95], + "eps": 1.0e-8, + } + }, + "min_lr": 0.00003, + + # for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training + "zero_optimization": { + "stage": 1, + "allgather_partitions": True, + "allgather_bucket_size": 500000000, + "overlap_comm": True, + "reduce_scatter": True, + "reduce_bucket_size": 500000000, + "contiguous_gradients": True, + }, + # batch / data settings + "train_micro_batch_size_per_gpu": 4, + "data_impl": "mmap", + + # activation checkpointing + "checkpoint_activations": true, + "checkpoint_num_layers": 1, + "partition_activations": true, + "synchronize_each_layer": true, + + # regularization + "gradient_clipping": 1.0, + "weight_decay": 0.1, + "hidden_dropout": 0, + "attention_dropout": 0, + + # precision settings + "fp16": { + "fp16": true, + "enabled": true, + "loss_scale": 0, + "loss_scale_window": 1000, + "hysteresis": 2, + "min_loss_scale": 1 + }, + + # misc. 
training settings + "train_iters": 320000, + "lr_decay_iters": 320000, + "distributed_backend": "nccl", + "lr_decay_style": "cosine", + "warmup": 0.01, + "checkpoint_factor": 10000, + "eval_interval": 1000, + "eval_iters": 10, + + # logging + "log_interval": 100, + "steps_per_print": 10, + "keep_last_n_checkpoints": 4, + "wall_clock_breakdown": true, } diff --git a/configs/mamba/mamba-790M.yml b/configs/mamba/mamba-790M.yml index fcd324d9d..4aef7e813 100644 --- a/configs/mamba/mamba-790M.yml +++ b/configs/mamba/mamba-790M.yml @@ -12,12 +12,78 @@ "norm": "rmsnorm", "rms_norm_epsilon": 1.0e-5, - "attention_config": [[["mamba"], 64]], + "attention_config": [[["mamba"], 48]], "mamba_selective_scan_fusion": true, "mamba_causal_conv_fusion": true, "mamba_inner_func_fusion": true, # supersedes scan or conv fusion "activation": "silu", - "output_layer_init_method": "single_residual_scaled_normal", + # init methods + "init_method": "small_init", + "output_layer_init_method": "single_residual_scaled_normal", + + # optimizer settings + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.00025, + "betas": [0.9, 0.999], + "eps": 1.0e-8, + } + }, + "min_lr": 0.000025, + + # for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training + "zero_optimization": { + "stage": 1, + "allgather_partitions": True, + "allgather_bucket_size": 500000000, + "overlap_comm": True, + "reduce_scatter": True, + "reduce_bucket_size": 500000000, + "contiguous_gradients": True, + }, + + # batch / data settings + "train_micro_batch_size_per_gpu": 4, + "data_impl": "mmap", + + # activation checkpointing + "checkpoint_activations": true, + "checkpoint_num_layers": 1, + "partition_activations": true, + "synchronize_each_layer": true, + + # regularization + "gradient_clipping": 1.0, + "weight_decay": 0.1, + "hidden_dropout": 0, + "attention_dropout": 0, + + # precision settings + "fp16": { + "fp16": true, + "enabled": true, + "loss_scale": 0, + "loss_scale_window": 1000, + "hysteresis": 2, + "min_loss_scale": 1 + }, + + # misc. training settings + "train_iters": 320000, + "lr_decay_iters": 320000, + "distributed_backend": "nccl", + "lr_decay_style": "cosine", + "warmup": 0.01, + "checkpoint_factor": 10000, + "eval_interval": 1000, + "eval_iters": 10, + + # logging + "log_interval": 100, + "steps_per_print": 10, + "keep_last_n_checkpoints": 4, + "wall_clock_breakdown": true, } diff --git a/configs/neox_arguments.md b/configs/neox_arguments.md index 24313b68d..d24b2b60a 100644 --- a/configs/neox_arguments.md +++ b/configs/neox_arguments.md @@ -111,7 +111,7 @@ Logging Arguments - **git_hash**: str - Default = ad8417c + Default = 217b4c5 current git hash of repository @@ -335,11 +335,11 @@ Model Arguments -- **norm**: typing.Literal['layernorm', 'rmsnorm', 'scalenorm'] +- **norm**: typing.Literal['layernorm', 'rmsnorm', 'scalenorm', 'te_rmsnorm', 'te_layernorm'] Default = layernorm - Normalization layer to use. Choose from "layernorm", "rmsnorm", "scalenorm". + Normalization layer to use. Choose from "layernorm", "rmsnorm", "scalenorm", "te_rmsnorm", "te_layernorm". @@ -1056,6 +1056,16 @@ Parallelism Arguments +- **sequence_parallel**: bool + + Default = False + + flag to determine whether Megatron-style Sequence Parallelism (https://arxiv.org/abs/2205.05198) + (Layernorm inputs and activations are sharded across model parallel group) will be used. Has no effect when model_parallel_size is 1. 
+ **Set by user, in contrast to neox_args.is_pipe_parallel.** + + + - **expert_interval**: int Default = 2 diff --git a/generate.py b/generate.py index 743e350d0..e19ef2e0e 100755 --- a/generate.py +++ b/generate.py @@ -23,6 +23,7 @@ generate_samples_from_prompt, generate_samples_unconditional, generate_samples_interactive, + precompute_logits, ) @@ -83,6 +84,8 @@ def main(input_args=None, overwrite_values=None): top_p=neox_args.top_p, ) + elif neox_args.text_gen_type == "precompute": + precompute_logits(neox_args=neox_args, model=model) else: raise ValueError( f"`text_gen_type` either not specified or not recognised: {neox_args.text_gen_type}" diff --git a/images/pytorch_profiling.png b/images/pytorch_profiling.png new file mode 100644 index 000000000..e85324dc6 Binary files /dev/null and b/images/pytorch_profiling.png differ diff --git a/megatron/data/data_utils.py b/megatron/data/data_utils.py index bc5754cdb..7c13131ad 100644 --- a/megatron/data/data_utils.py +++ b/megatron/data/data_utils.py @@ -23,6 +23,7 @@ from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset from megatron.data.blendable_dataset import BlendableDataset from megatron.data.gpt2_dataset import GPT2Dataset +from megatron.data.pairwise_dataset import PairwiseDataset from megatron.data.samplers import DistributedBatchSampler @@ -53,39 +54,113 @@ def make_data_loader(dataset, neox_args): def build_the_dataset( data_prefix, + pos_data_prefix, + neg_data_prefix, name, data_impl, + pack_impl, + dataset_impl, + allow_chopped, num_samples, seq_length, seed, skip_warmup, build_index_mappings=True, label_prefix=None, + pos_label_prefix=None, + neg_label_prefix=None, + precompute_model_name=None, ): """Build train/valid/test datasets.""" - - indexed_dataset = make_indexed_dataset(data_prefix, data_impl, skip_warmup) - if label_prefix is None: - label_dataset = None + if dataset_impl == "gpt2": + indexed_dataset = make_indexed_dataset(data_prefix, data_impl, skip_warmup) + if label_prefix is None: + label_dataset = None + else: + label_dataset = make_indexed_dataset(label_prefix, data_impl, skip_warmup) + if precompute_model_name is not None: + # If we have the name, assume it exists. If it doesn't, it will just be None which is fine. 
+ precompute_indexed_dataset = make_indexed_dataset( + data_prefix + "_" + precompute_model_name, data_impl, skip_warmup + ) + precompute_indexed_dataset = precompute_indexed_dataset + elif dataset_impl == "pairwise": + pos_indexed_dataset = make_indexed_dataset( + pos_data_prefix, data_impl, skip_warmup + ) + neg_indexed_dataset = make_indexed_dataset( + neg_data_prefix, data_impl, skip_warmup + ) + if pos_label_prefix is None: + pos_label_dataset = None + # Also do neg here since they both must be the same + assert neg_label_prefix is None + neg_label_dataset = None + else: + pos_label_dataset = make_indexed_dataset( + pos_label_prefix, data_impl, skip_warmup + ) + # Also do neg here since they both must be the same + assert neg_label_prefix is not None + neg_label_dataset = make_indexed_dataset( + neg_label_prefix, data_impl, skip_warmup + ) + if precompute_model_name is None: + pos_ref_dataset = None + neg_ref_dataset = None + else: + pos_ref_dataset = make_indexed_dataset( + pos_data_prefix + "_" + precompute_model_name, data_impl, skip_warmup + ) + neg_ref_dataset = make_indexed_dataset( + neg_data_prefix + "_" + precompute_model_name, data_impl, skip_warmup + ) else: - label_dataset = make_indexed_dataset(label_prefix, data_impl, skip_warmup) + raise NotImplementedError(f"dataset_impl={dataset_impl} not implemented") - total_num_of_documents = indexed_dataset.sizes.shape[0] + total_num_of_documents = ( + indexed_dataset.sizes.shape[0] + if dataset_impl == "gpt2" + else pos_indexed_dataset.sizes.shape[0] + ) print_rank_0(" {}:".format(name)) print_rank_0(" no. of documents:{}".format(total_num_of_documents)) dataset = None documents = np.arange(start=0, stop=total_num_of_documents, step=1, dtype=np.int32) - dataset = GPT2Dataset( - name, - data_prefix, - documents, - indexed_dataset, - num_samples, - seq_length, - seed, - build_index_mappings=build_index_mappings, - label_dataset=label_dataset, - ) + + if dataset_impl == "gpt2": + dataset = GPT2Dataset( + name, + data_prefix, + documents, + indexed_dataset, + num_samples, + seq_length, + seed, + pack_impl=pack_impl, + allow_chopped=allow_chopped, + build_index_mappings=build_index_mappings, + label_dataset=label_dataset, + ) + elif dataset_impl == "pairwise": + dataset = PairwiseDataset( + name, + pos_data_prefix, + documents, + pos_indexed_dataset, + neg_indexed_dataset, + num_samples, + seq_length, + seed, + pack_impl=pack_impl, + allow_chopped=allow_chopped, + build_index_mappings=build_index_mappings, + pos_label_dataset=pos_label_dataset, + neg_label_dataset=neg_label_dataset, + pos_ref_dataset=pos_ref_dataset, + neg_ref_dataset=neg_ref_dataset, + ) + return dataset @@ -93,6 +168,8 @@ def build_train_valid_test_datasets( data_prefix, use_shared_fs, data_impl, + pack_impl, + allow_chopped, splits_string, train_valid_test_num_samples, seq_length, @@ -129,7 +206,6 @@ def build_dataset(index, name): documents = np.arange( start=splits[index], stop=splits[index + 1], step=1, dtype=np.int32 ) - dataset = GPT2Dataset( name, data_prefix, @@ -138,6 +214,8 @@ def build_dataset(index, name): train_valid_test_num_samples[index], seq_length, seed, + pack_impl=pack_impl, + allow_chopped=allow_chopped, use_shared_fs=use_shared_fs, ) return dataset @@ -204,54 +282,129 @@ def build_weighted_datasets( ): # build individual datasets train_datasets, valid_datasets, test_datasets = [], [], [] - for i, (train_path, label_path, valid_path, test_path) in enumerate( + for i, ( + train_path, + train_label_path, + valid_path, + valid_label_path, + 
test_path, + test_label_path, + pos_train_path, + neg_train_path, + pos_train_label_path, + neg_train_label_path, + pos_valid_path, + neg_valid_path, + pos_valid_label_path, + neg_valid_label_path, + pos_test_path, + neg_test_path, + pos_test_label_path, + neg_test_label_path, + ) in enumerate( zip_longest( - neox_args.train_data_paths, - neox_args.label_data_paths if neox_args.label_data_paths else [], - neox_args.valid_data_paths, - neox_args.test_data_paths, + neox_args.train_data_paths if neox_args.train_data_paths else [], + neox_args.train_label_data_paths + if neox_args.train_label_data_paths + else [], + neox_args.valid_data_paths if neox_args.valid_data_paths else [], + neox_args.valid_label_data_paths + if neox_args.valid_label_data_paths + else [], + neox_args.test_data_paths if neox_args.test_data_paths else [], + neox_args.test_label_data_paths if neox_args.test_label_data_paths else [], + neox_args.pos_train_data_paths if neox_args.pos_train_data_paths else [], + neox_args.neg_train_data_paths if neox_args.neg_train_data_paths else [], + neox_args.pos_train_label_data_paths + if neox_args.pos_train_label_data_paths + else [], + neox_args.neg_train_label_data_paths + if neox_args.neg_train_label_data_paths + else [], + neox_args.pos_valid_data_paths if neox_args.pos_valid_data_paths else [], + neox_args.neg_valid_data_paths if neox_args.neg_valid_data_paths else [], + neox_args.pos_valid_label_data_paths + if neox_args.pos_valid_label_data_paths + else [], + neox_args.neg_valid_label_data_paths + if neox_args.neg_valid_label_data_paths + else [], + neox_args.pos_test_data_paths if neox_args.pos_test_data_paths else [], + neox_args.neg_test_data_paths if neox_args.neg_test_data_paths else [], + neox_args.pos_test_label_data_paths + if neox_args.pos_test_label_data_paths + else [], + neox_args.neg_test_label_data_paths + if neox_args.neg_test_label_data_paths + else [], ) ): - if train_path: + if train_path or pos_train_path: train_datasets.append( build_the_dataset( data_prefix=train_path, name=f"train_{i}", data_impl=neox_args.data_impl, + pack_impl=neox_args.pack_impl, + allow_chopped=neox_args.allow_chopped, num_samples=train_num_samples[i], seq_length=neox_args.seq_length, seed=neox_args.seed, skip_warmup=(not neox_args.mmap_warmup), build_index_mappings=build_index_mappings, - label_prefix=label_path, + label_prefix=train_label_path, + dataset_impl=neox_args.dataset_impl, + pos_data_prefix=pos_train_path, + neg_data_prefix=neg_train_path, + pos_label_prefix=pos_train_label_path, + neg_label_prefix=neg_train_label_path, + precompute_model_name=neox_args.precompute_model_name, ) ) - if valid_path: + if valid_path or pos_valid_path: valid_datasets.append( build_the_dataset( data_prefix=valid_path, name=f"valid_{i}", data_impl=neox_args.data_impl, + pack_impl=neox_args.pack_impl, + allow_chopped=neox_args.allow_chopped, num_samples=valid_num_samples[i], seq_length=neox_args.seq_length, seed=neox_args.seed, skip_warmup=(not neox_args.mmap_warmup), build_index_mappings=build_index_mappings, + label_prefix=valid_label_path, + dataset_impl=neox_args.dataset_impl, + pos_data_prefix=pos_valid_path, + neg_data_prefix=neg_valid_path, + pos_label_prefix=pos_valid_label_path, + neg_label_prefix=neg_valid_label_path, + precompute_model_name=neox_args.precompute_model_name, ) ) - if test_path: + if test_path or pos_test_path: test_datasets.append( build_the_dataset( data_prefix=test_path, name=f"test_{i}", data_impl=neox_args.data_impl, + pack_impl=neox_args.pack_impl, + 
allow_chopped=neox_args.allow_chopped, num_samples=test_num_samples[i], seq_length=neox_args.seq_length, seed=neox_args.seed, skip_warmup=(not neox_args.mmap_warmup), build_index_mappings=build_index_mappings, + label_prefix=test_label_path, + dataset_impl=neox_args.dataset_impl, + pos_data_prefix=pos_test_path, + neg_data_prefix=neg_test_path, + pos_label_prefix=pos_test_label_path, + neg_label_prefix=neg_test_label_path, + precompute_model_name=neox_args.precompute_model_name, ) ) return train_datasets, valid_datasets, test_datasets @@ -323,7 +476,7 @@ def build_train_valid_test_data_iterators(neox_args): test_iters * neox_args.train_batch_size, ] - if neox_args.train_data_paths: + if (neox_args.train_data_paths) or (neox_args.pos_train_data_paths): # when individual train / valid / test data paths are provided # normalize weight values and get num samples for each dataset train_weights, train_num_samples = get_normalized_weights_and_num_samples( @@ -414,6 +567,8 @@ def build_train_valid_test_data_iterators(neox_args): seq_length=neox_args.seq_length, seed=neox_args.seed, skip_warmup=(not neox_args.mmap_warmup), + pack_impl=neox_args.pack_impl, + allow_chopped=neox_args.allow_chopped, ) # Build dataloders. diff --git a/megatron/data/gpt2_dataset.py b/megatron/data/gpt2_dataset.py index 75e601fda..edba57df2 100644 --- a/megatron/data/gpt2_dataset.py +++ b/megatron/data/gpt2_dataset.py @@ -36,14 +36,19 @@ def __init__( num_samples, seq_length, seed, + pack_impl="packed", + allow_chopped=True, build_index_mappings=True, use_shared_fs=True, label_dataset=None, ): self.name = name + self.pack_impl = pack_impl + self.allow_chopped = allow_chopped self.indexed_dataset = indexed_dataset self.label_dataset = label_dataset + self.seq_length = seq_length # Checks assert np.min(documents) >= 0 @@ -56,10 +61,13 @@ def __init__( data_prefix, documents, self.indexed_dataset.sizes, + self.label_dataset, num_samples, seq_length, seed, + self.pack_impl, use_shared_fs=use_shared_fs, + allow_chopped=self.allow_chopped, ) self.shuffle_idx_len = self.shuffle_idx.shape[0] - 1 self.sample_idx_len = self.sample_idx.shape[0] - 1 @@ -113,8 +121,38 @@ def __getitem__(self, idx): samples.append(np.concatenate(sample_list)) if len(datasets) == 1: + if len(samples[0]) < (self.seq_length + 1): + # Pad with -100s so the masking function can ignore these. + samples[0] = np.pad( + samples[0], + (0, (self.seq_length + 1) - len(samples[0])), + mode="constant", + constant_values=-100, + ) + elif len(samples[0]) > (self.seq_length + 1): + # Check for overflow and truncate. + samples[0] = samples[0][: (self.seq_length + 1)] return {"text": np.array(samples[0], dtype=np.int64)} else: + if len(samples[0]) < (self.seq_length + 1): + # Pad with 0s, can use any number since it's masked. + samples[0] = np.pad( + samples[0], + (0, (self.seq_length + 1) - len(samples[0])), + mode="constant", + constant_values=0, + ) + # pad with -100s so we can mask it out + samples[1] = np.pad( + samples[1], + (0, (self.seq_length + 1) - len(samples[1])), + mode="constant", + constant_values=-100, + ) + elif len(samples[0]) > (self.seq_length + 1): + # Check for overflow and truncate. 
+ samples[0] = samples[0][: (self.seq_length + 1)] + samples[1] = samples[1][: (self.seq_length + 1)] return { "text": np.array(samples[0], dtype=np.int64), "label": np.array(samples[1], dtype=np.int64), @@ -132,10 +170,13 @@ def _build_index_mappings( data_prefix, documents, sizes, + label_dataset, num_samples, seq_length, seed, + packing_impl, use_shared_fs=True, + allow_chopped=True, ): """Build doc-idx, sample-idx, and shuffle-idx. doc-idx: is an array (ordered) of documents to be used in training. @@ -155,6 +196,9 @@ def _build_index_mappings( _filename += "_{}ns".format(num_samples) _filename += "_{}sl".format(seq_length) _filename += "_{}s".format(seed) + _filename += "_{}pi".format(packing_impl) + if allow_chopped: + _filename += "_ac" doc_idx_filename = _filename + "_doc_idx.npy" sample_idx_filename = _filename + "_sample_idx.npy" shuffle_idx_filename = _filename + "_shuffle_idx.npy" @@ -177,44 +221,116 @@ def _build_index_mappings( ) # doc-idx. start_time = time.time() - doc_idx = _build_doc_idx(documents, num_epochs, np_rng) - np.save(doc_idx_filename, doc_idx, allow_pickle=True) - print_rank_0( - " > elapsed time to build and save doc-idx mapping " - "(seconds): {:4f}".format(time.time() - start_time) - ) - # sample-idx. - start_time = time.time() - # Use C++ implementation for speed. - from megatron.data import helpers - - assert doc_idx.dtype == np.int32 - assert sizes.dtype == np.int32 - - num_samples = (num_epochs * tokens_per_epoch - 1) / seq_length - if 2 * (num_samples + 1) < np.iinfo(np.int32).max: - sample_idx = helpers.build_sample_idx_int32( - sizes, doc_idx, seq_length, num_epochs, tokens_per_epoch + if packing_impl == "packed": + doc_idx = _build_doc_idx(documents, num_epochs, np_rng) + np.save(doc_idx_filename, doc_idx, allow_pickle=True) + print_rank_0( + " > elapsed time to build and save doc-idx mapping " + "(seconds): {:4f}".format(time.time() - start_time) ) - else: - sample_idx = helpers.build_sample_idx_int64( - sizes, doc_idx, seq_length, num_epochs, tokens_per_epoch + # sample-idx. + start_time = time.time() + # Use C++ implementation for speed. + from megatron.data import helpers + + assert doc_idx.dtype == np.int32 + assert sizes.dtype == np.int32 + + num_samples = (num_epochs * tokens_per_epoch - 1) / seq_length + if 2 * (num_samples + 1) < np.iinfo(np.int32).max: + sample_idx = helpers.build_sample_idx_int32( + sizes, doc_idx, seq_length, num_epochs, tokens_per_epoch + ) + else: + sample_idx = helpers.build_sample_idx_int64( + sizes, doc_idx, seq_length, num_epochs, tokens_per_epoch + ) + np.save(sample_idx_filename, sample_idx, allow_pickle=True) + print_rank_0( + " > elapsed time to build and save sample-idx mapping " + "(seconds): {:4f}".format(time.time() - start_time) ) - np.save(sample_idx_filename, sample_idx, allow_pickle=True) - print_rank_0( - " > elapsed time to build and save sample-idx mapping " - "(seconds): {:4f}".format(time.time() - start_time) - ) - # shuffle-idx. - start_time = time.time() - # -1 is due to data structure used to retrieve the index: - # sample i --> [sample_idx[i], sample_idx[i+1]) - shuffle_idx = _build_shuffle_idx(sample_idx.shape[0] - 1, np_rng) - np.save(shuffle_idx_filename, shuffle_idx, allow_pickle=True) - print_rank_0( - " > elapsed time to build and save shuffle-idx mapping" - " (seconds): {:4f}".format(time.time() - start_time) - ) + # shuffle-idx. 
+ start_time = time.time() + # -1 is due to data structure used to retrieve the index: + # sample i --> [sample_idx[i], sample_idx[i+1]) + shuffle_idx = _build_shuffle_idx(sample_idx.shape[0] - 1, np_rng) + np.save(shuffle_idx_filename, shuffle_idx, allow_pickle=True) + print_rank_0( + " > elapsed time to build and save shuffle-idx mapping" + " (seconds): {:4f}".format(time.time() - start_time) + ) + elif packing_impl == "pack_until_overflow": + # Naively pack data until it overflows, then roll it over to a new one instead. + shuffle_idx = np.arange(num_samples) # Shuffle index around epochs + np_rng.shuffle(shuffle_idx) + sample_idx = [] + doc_idx = [] + # Iterate over files until we have enough samples. + temp_shuffle_idx = np.arange(len(documents)) + np_rng.shuffle(temp_shuffle_idx) + running_length = 0 + curr_shuffle_idx = 0 + while len(sample_idx) < num_samples: + if not allow_chopped: + # +1 since we shift left/right by 1 + if sizes[temp_shuffle_idx[curr_shuffle_idx]] > seq_length + 1: + curr_shuffle_idx += 1 + continue + # First, check if we need to skip this item... + if label_dataset is not None: + if np.all( + label_dataset.get(temp_shuffle_idx[curr_shuffle_idx])[ + : seq_length + 1 + ] + == -100 + ): + curr_shuffle_idx += 1 + continue + doc_length = sizes[temp_shuffle_idx[curr_shuffle_idx]] + if running_length == 0: + sample_idx.append(np.array([len(doc_idx), 0])) + doc_idx.append(temp_shuffle_idx[curr_shuffle_idx]) + running_length += doc_length + else: + if running_length + doc_length > (seq_length + 1): + running_length = doc_length + sample_idx.append(np.array([len(doc_idx), 0])) + else: + running_length += doc_length + doc_idx.append(temp_shuffle_idx[curr_shuffle_idx]) + curr_shuffle_idx += 1 + if curr_shuffle_idx == len(documents): + curr_shuffle_idx = 0 + np_rng.shuffle(temp_shuffle_idx) + sample_idx.append(np.array([len(doc_idx), 0])) + np.save(doc_idx_filename, doc_idx, allow_pickle=True) + np.save(sample_idx_filename, sample_idx, allow_pickle=True) + np.save(shuffle_idx_filename, shuffle_idx, allow_pickle=True) + elif packing_impl == "unpacked": + # Unpacked data, one sample per document. + shuffle_idx = np.arange(num_samples) # Shuffle index around epochs + np_rng.shuffle(shuffle_idx) + sample_idx = np.zeros((num_samples + 1, 2), dtype=np.int64) + sample_idx[:, 0] = np.array([i for i in range(num_samples + 1)]) + sample_idx[:, 1] = 0 + doc_idx = list() + doc_i = 0 + while len(doc_idx) <= num_samples: + if not allow_chopped: + # +1 since we shift left/right by 1 + if sizes[doc_i] > seq_length + 1: + doc_i = (doc_i + 1) % len(documents) + continue + # Just in case we have bad data in the loop... + if np.all(label_dataset.get(doc_i)[:seq_length] == -100): + doc_i = (doc_i + 1) % len(documents) + continue + doc_idx.append(doc_i) + doc_i = (doc_i + 1) % len(documents) + np.save(doc_idx_filename, doc_idx, allow_pickle=True) + np.save(sample_idx_filename, sample_idx, allow_pickle=True) + np.save(shuffle_idx_filename, shuffle_idx, allow_pickle=True) # This should be a barrier but nccl barrier assumes # device_index=rank which is not the case for model diff --git a/megatron/data/pairwise_dataset.py b/megatron/data/pairwise_dataset.py new file mode 100644 index 000000000..e39b4d626 --- /dev/null +++ b/megatron/data/pairwise_dataset.py @@ -0,0 +1,457 @@ +# Copyright (c) 2024, EleutherAI +# This file is based on code by the authors denoted below and has been modified from its original version. +# +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Pairwise style dataset.""" + +import os +import time + +import numpy as np +import torch + +from megatron import mpu, print_rank_0 + + +class PairwiseDataset(torch.utils.data.Dataset): + def __init__( + self, + name, + pos_data_prefix, # Don't need neg since it's assumed you have paired the data already. + documents, + pos_indexed_dataset, + neg_indexed_dataset, + num_samples, + seq_length, + seed, + pack_impl="unpacked", + build_index_mappings=True, + use_shared_fs=True, + pos_label_dataset=None, + pos_ref_dataset=None, + neg_label_dataset=None, + neg_ref_dataset=None, + allow_chopped=True, + ): + + self.name = name + self.pos_indexed_dataset = pos_indexed_dataset + self.pos_label_dataset = pos_label_dataset + self.pos_ref_dataset = pos_ref_dataset + self.neg_indexed_dataset = neg_indexed_dataset + self.neg_label_dataset = neg_label_dataset + self.neg_ref_dataset = neg_ref_dataset + self.pack_impl = pack_impl + self.seq_length = seq_length + # Checks + assert np.min(documents) >= 0 + assert (neg_label_dataset is not None and pos_label_dataset is not None) or ( + neg_label_dataset is None and pos_label_dataset is None + ), "Label datasets must be both None or both not None" + assert np.max(documents) < pos_indexed_dataset.sizes.shape[0] + assert pos_indexed_dataset.sizes.shape[0] == neg_indexed_dataset.sizes.shape[0] + assert ( + pack_impl != "packed" + ), "Packed implementation not supported for pairwise dataset" + + if build_index_mappings: + # Build index mappings. + self.doc_idx, self.sample_idx, self.shuffle_idx = _build_index_mappings( + self.name, + pos_data_prefix, + documents, + self.pos_indexed_dataset.sizes, + self.neg_indexed_dataset.sizes, + self.pos_label_dataset, + self.neg_label_dataset, + num_samples, + seq_length, + seed, + pack_impl, + use_shared_fs=use_shared_fs, + allow_chopped=allow_chopped, + ) + self.shuffle_idx_len = self.shuffle_idx.shape[0] - 1 + self.sample_idx_len = self.sample_idx.shape[0] - 1 + + if self.shuffle_idx_len != self.sample_idx_len - 1: + print( + f"WARNING: shuffle index length ({self.shuffle_idx_len}) is not equal to sample index length ({self.sample_idx_len})" + ) + + def __len__(self): + return min(self.shuffle_idx_len, self.sample_idx_len) + + def __getitem__(self, idx): + try: + # Get the shuffled index. + idx = self.shuffle_idx[idx] + # Start and end documents and offsets. + doc_index_f = self.sample_idx[idx][0] + doc_index_l = self.sample_idx[idx + 1][0] + offset_f = self.sample_idx[idx][1] + offset_l = self.sample_idx[idx + 1][1] + # Labels and texts are supposed to be fully in sync. + datasets = [self.pos_indexed_dataset, self.neg_indexed_dataset] + + if self.pos_label_dataset is not None: + datasets += [ + self.pos_label_dataset, + self.neg_label_dataset, + ] + if self.pos_ref_dataset is not None: + datasets += [ + self.pos_ref_dataset, + self.neg_ref_dataset, + ] + samples = [] + pos_ref_samples = [] + neg_ref_samples = [] + # If we are within the same document, just extract the chunk. 
+ for n, dataset in enumerate(datasets): + if doc_index_f == doc_index_l: + samples.append( + dataset.get( + self.doc_idx[doc_index_f], + offset=offset_f, + length=offset_l - offset_f + 1, + ) + ) + else: + # Otherwise, get the rest of the initial document. + sample_list = [ + dataset.get(self.doc_idx[doc_index_f], offset=offset_f) + ] + # Loop over all in between documents and add the entire document. + for i in range(doc_index_f + 1, doc_index_l): + sample_list.append(dataset.get(self.doc_idx[i])) + # And finally add the relevant portion of last document. + sample_list.append( + dataset.get(self.doc_idx[doc_index_l], length=offset_l + 1) + ) + samples.append(np.concatenate(sample_list)) + for i in range(len(samples)): + if len(samples[i]) < (self.seq_length + 1): + if ((i == 2) or (i == 3)) and self.pos_label_dataset is not None: + # Labels... So pad with -100 + samples[i] = np.pad( + samples[i], + (0, (self.seq_length + 1) - len(samples[i])), + mode="constant", + constant_values=-100, + ) + else: + # Pad with 0s, can use any number since it's masked. + samples[i] = np.pad( + samples[i], + (0, (self.seq_length + 1) - len(samples[i])), + mode="constant", + constant_values=0, + ) + elif len(samples[i]) > (self.seq_length + 1): + # Check for overflow and truncate. + samples[i] = samples[i][: (self.seq_length + 1)] + ret = {} + ret["pos"] = np.array(samples[0], dtype=np.int64) + ret["neg"] = np.array(samples[1], dtype=np.int64) + if self.pos_label_dataset is not None: + ret["pos_label"] = np.array(samples[2], dtype=np.int64) + ret["neg_label"] = np.array(samples[3], dtype=np.int64) + if self.pos_ref_dataset is not None: + ret["pos_ref"] = np.array(samples[4], dtype=np.float32) + ret["neg_ref"] = np.array(samples[5], dtype=np.float32) + elif self.pos_ref_dataset is not None: + # Don't have labels... + ret["pos_ref"] = np.array(samples[2], dtype=np.float32) + ret["neg_ref"] = np.array(samples[3], dtype=np.float32) + return ret + except IndexError: + new_idx = idx % len(self) + print( + f"WARNING: Got index out of bounds error with index {idx} - taking modulo of index instead ({new_idx})" + ) + return self[new_idx] + + +def _build_index_mappings( + name, + pos_data_prefix, + documents, + pos_sizes, + neg_sizes, + pos_label_dataset, + neg_label_dataset, + num_samples, + seq_length, + seed, + packing_impl, + use_shared_fs=True, + allow_chopped=True, +): + """Build doc-idx, sample-idx, and shuffle-idx. + doc-idx: is an array (ordered) of documents to be used in training. + sample-idx: is the start document index and document offset for each + training sample. + shuffle-idx: maps the sample index into a random index into sample-idx. + """ + # Number of tokens in each epoch and number of required epochs. + tokens_per_epoch = _num_tokens(documents, pos_sizes) + num_epochs = _num_epochs(tokens_per_epoch, seq_length, num_samples) + # rng state + np_rng = np.random.RandomState(seed=seed) + + # Filename of the index mappings. 
+ _filename = pos_data_prefix + _filename += "_{}_indexmap".format(name) + _filename += "_{}ns".format(num_samples) + _filename += "_{}sl".format(seq_length) + _filename += "_{}s".format(seed) + _filename += "_{}pi".format(packing_impl) + doc_idx_filename = _filename + "_doc_idx.npy" + sample_idx_filename = _filename + "_sample_idx.npy" + shuffle_idx_filename = _filename + "_shuffle_idx.npy" + + if not use_shared_fs: + should_process_dataset = int(os.environ["LOCAL_RANK"]) == 0 + else: + should_process_dataset = torch.distributed.get_rank() == 0 + + # Build the indexed mapping if not exist. + if should_process_dataset: + if ( + (not os.path.isfile(doc_idx_filename)) + or (not os.path.isfile(sample_idx_filename)) + or (not os.path.isfile(shuffle_idx_filename)) + ): + print_rank_0( + " > WARNING: could not find index map files, building " + "the indices on rank 0 ..." + ) + # doc-idx. + start_time = time.time() + if packing_impl == "pack_until_overflow": + # Naively pack data until it overflows, then roll it over to a new one instead. + shuffle_idx = np.arange(num_samples) # Shuffle index around epochs + np_rng.shuffle(shuffle_idx) + sample_idx = [] + doc_idx = [] + # Iterate over files until we have enough samples. + temp_shuffle_idx = np.arange(len(documents)) + np_rng.shuffle(temp_shuffle_idx) + running_length = 0 + curr_shuffle_idx = 0 + while len(sample_idx) < num_samples: + # If not allow_chopped, skip this item if it's chopped. + if not allow_chopped: + if ( + pos_sizes[temp_shuffle_idx[curr_shuffle_idx]] + < seq_length + 1 + ): + curr_shuffle_idx += 1 + continue + if ( + neg_sizes[temp_shuffle_idx[curr_shuffle_idx]] + < seq_length + 1 + ): + curr_shuffle_idx += 1 + continue + # Then, check if we need to skip this item... + if pos_label_dataset is not None: + if np.all( + pos_label_dataset.get(temp_shuffle_idx[curr_shuffle_idx])[ + : seq_length + 1 + ] + == -100 + ): + curr_shuffle_idx += 1 + continue + if np.all( + neg_label_dataset.get(temp_shuffle_idx[curr_shuffle_idx])[ + : seq_length + 1 + ] + == -100 + ): + curr_shuffle_idx += 1 + continue + doc_length = max( + pos_sizes[temp_shuffle_idx[curr_shuffle_idx]], + neg_sizes[temp_shuffle_idx[curr_shuffle_idx]], + ) + if running_length == 0: + sample_idx.append(np.array([len(doc_idx), 0])) + doc_idx.append(temp_shuffle_idx[curr_shuffle_idx]) + running_length += doc_length + else: + if running_length + doc_length > (seq_length + 1): + running_length = doc_length + sample_idx.append(np.array([len(doc_idx), 0])) + else: + running_length += doc_length + doc_idx.append(temp_shuffle_idx[curr_shuffle_idx]) + curr_shuffle_idx += 1 + if curr_shuffle_idx == len(documents): + curr_shuffle_idx = 0 + np_rng.shuffle(temp_shuffle_idx) + sample_idx.append(np.array([len(doc_idx), 0])) + np.save(doc_idx_filename, doc_idx, allow_pickle=True) + np.save(sample_idx_filename, sample_idx, allow_pickle=True) + np.save(shuffle_idx_filename, shuffle_idx, allow_pickle=True) + elif packing_impl == "unpacked": + # Unpacked data, one sample per document. + shuffle_idx = np.array([i % len(documents) for i in range(num_samples)]) + np_rng.shuffle(shuffle_idx) + sample_idx = np.zeros((num_samples + 1, 2), dtype=np.int64) + sample_idx[:, 0] = np.array([i for i in range(num_samples + 1)]) + sample_idx[:, 1] = 0 + doc_idx = list() + doc_i = 0 + while len(doc_idx) <= num_samples: + # Check if we need to skip this item... 
+ if not allow_chopped: + # +1 since we shift left/right by 1 + if pos_sizes[doc_i] > seq_length + 1: + doc_i = (doc_i + 1) % len(documents) + continue + if neg_sizes[doc_i] > seq_length + 1: + doc_i = (doc_i + 1) % len(documents) + continue + # In theory if we don't allow chopped we should be able to skip it, but the warm fuzzies I get + # from this are worth the extra bool check + if np.all(pos_label_dataset.get(doc_i)[:seq_length] == -100): + doc_i = (doc_i + 1) % len(documents) + continue + if np.all(neg_label_dataset.get(doc_i)[:seq_length] == -100): + doc_i = (doc_i + 1) % len(documents) + continue + doc_idx.append(doc_i) + doc_i = (doc_i + 1) % len(documents) + np.save(doc_idx_filename, doc_idx, allow_pickle=True) + np.save(sample_idx_filename, sample_idx, allow_pickle=True) + np.save(shuffle_idx_filename, shuffle_idx, allow_pickle=True) + + # This should be a barrier but nccl barrier assumes + # device_index=rank which is not the case for model + # parallel case + counts = torch.cuda.LongTensor([1]) + torch.distributed.all_reduce(counts, group=mpu.get_io_parallel_group()) + assert counts[0].item() == torch.distributed.get_world_size( + group=mpu.get_io_parallel_group() + ) + + # Load mappings. + start_time = time.time() + print_rank_0(" > loading doc-idx mapping from {}".format(doc_idx_filename)) + doc_idx = np.load(doc_idx_filename, allow_pickle=True, mmap_mode="r") + print_rank_0(" > loading sample-idx mapping from {}".format(sample_idx_filename)) + sample_idx = np.load(sample_idx_filename, allow_pickle=True, mmap_mode="r") + print_rank_0(" > loading shuffle-idx mapping from {}".format(shuffle_idx_filename)) + shuffle_idx = np.load(shuffle_idx_filename, allow_pickle=True, mmap_mode="r") + print_rank_0( + " loaded indexed file in {:3.3f} seconds".format(time.time() - start_time) + ) + print_rank_0(" total number of samples: {}".format(sample_idx.shape[0])) + print_rank_0(" total number of epochs: {}".format(num_epochs)) + + return doc_idx, sample_idx, shuffle_idx + + +def _num_tokens(documents, sizes): + """Total number of tokens in the dataset.""" + return np.sum(sizes[documents]) + + +def _num_epochs(tokens_per_epoch, seq_length, num_samples): + """Based on number of samples and sequence length, calculate how many + epochs will be needed.""" + num_epochs = 0 + total_tokens = 0 + while True: + num_epochs += 1 + total_tokens += tokens_per_epoch + # -1 is because we need to retrieve seq_length + 1 token each time + # but the last token will overlap with the first token of the next + # sample except for the last sample. + if ((total_tokens - 1) // seq_length) >= num_samples: + return num_epochs + + +def _build_doc_idx(documents, num_epochs, np_rng): + """Build an array with length = number-of-epochs * number-of-documents. + Each index is mapped to a corresponding document.""" + doc_idx = np.mgrid[0:num_epochs, 0 : len(documents)][1] + doc_idx[:] = documents + doc_idx = doc_idx.reshape(-1) + doc_idx = doc_idx.astype(np.int32) + np_rng.shuffle(doc_idx) + return doc_idx + + +def _build_sample_idx(sizes, doc_idx, seq_length, num_epochs, tokens_per_epoch): + """Sample index mapping is a 2D array with sizes + [number-of-samples + 1, 2] where [..., 0] contains + the index into `doc_idx` and [..., 1] is the + starting offset in that document.""" + + # Total number of samples. For -1 see comments in `_num_epochs`. + num_samples = (num_epochs * tokens_per_epoch - 1) // seq_length + sample_idx = np.zeros([num_samples + 1, 2], dtype=np.int64) + + # Index into sample_idx. 
+    sample_index = 0
+    # Index into doc_idx.
+    doc_idx_index = 0
+    # Beginning offset for each document.
+    doc_offset = 0
+    # Start with first document and no offset.
+    sample_idx[sample_index][0] = doc_idx_index
+    sample_idx[sample_index][1] = doc_offset
+    sample_index += 1
+    while sample_index <= num_samples:
+        # Start with a fresh sequence.
+        remaining_seq_length = seq_length + 1
+        while remaining_seq_length != 0:
+            # Get the document length.
+            doc_id = doc_idx[doc_idx_index]
+            doc_length = sizes[doc_id] - doc_offset
+            # And add it to the current sequence.
+            remaining_seq_length -= doc_length
+            # If we have more than a full sequence, adjust offset and set
+            # remaining length to zero so we return from the while loop.
+            # Note that -1 here is for the same reason we have -1 in
+            # `_num_epochs` calculations.
+            if remaining_seq_length <= 0:
+                doc_offset += remaining_seq_length + doc_length - 1
+                remaining_seq_length = 0
+            else:
+                # Otherwise, start from the beginning of the next document.
+                doc_idx_index += 1
+                doc_offset = 0
+        # Record the sequence.
+        sample_idx[sample_index][0] = doc_idx_index
+        sample_idx[sample_index][1] = doc_offset
+        sample_index += 1
+
+    return sample_idx
+
+
+def _build_shuffle_idx(size, np_rng):
+    """Build the range [0, size) and shuffle."""
+    dtype_ = np.uint32
+    if size >= (np.iinfo(np.uint32).max - 1):
+        dtype_ = np.int64
+    shuffle_idx = np.arange(start=0, stop=size, step=1, dtype=dtype_)
+    np_rng.shuffle(shuffle_idx)
+    return shuffle_idx
diff --git a/megatron/fused_kernels/__init__.py b/megatron/fused_kernels/__init__.py
index 1e4c9efac..3694e964b 100644
--- a/megatron/fused_kernels/__init__.py
+++ b/megatron/fused_kernels/__init__.py
@@ -135,8 +135,8 @@ def _cpp_extention_load_helper(
         srcpath / "fused_rotary_positional_embedding.cpp",
         srcpath / "fused_rotary_positional_embedding_cuda.cu",
     ]
-    fused_rotary_positional_embedding_cuda = _cpp_extention_load_helper(
-        "fused_rotary_positional_embedding_cuda",
+    fused_rotary_positional_embedding = _cpp_extention_load_helper(
+        "fused_rotary_positional_embedding",
         sources,
         extra_cuda_flags,
         extra_include_paths,
@@ -174,7 +174,7 @@ def load_fused_kernels():
         print(e)
         print("=" * 100)
         print(
-            f"ERROR: Fused kernels configured but not properly installed. Please run `pip install {str(srcpath)}` to install them"
+            f"ERROR: Fused kernels configured but not properly installed. Please run `from megatron.fused_kernels import load` and then call `load()` to build and load them"
        )
         print("=" * 100)
         exit()
diff --git a/megatron/model/__init__.py b/megatron/model/__init__.py
index 619b4c33d..23be28936 100755
--- a/megatron/model/__init__.py
+++ b/megatron/model/__init__.py
@@ -16,5 +16,8 @@
 # limitations under the License.
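To make the (document index, offset) bookkeeping of `_build_sample_idx` concrete, here is a toy pure-Python re-derivation with invented sizes; unlike the real function it simply stops when the toy document list runs out instead of spanning multiple epochs.

```python
seq_length = 4
sizes = [3, 5, 2]      # token counts of three toy documents
doc_idx = [0, 1, 2]    # a single epoch, unshuffled for readability

samples = [(0, 0)]     # each entry: (index into doc_idx, starting offset in that document)
doc_i, offset = 0, 0
done = False
while not done:
    remaining = seq_length + 1                  # every sample needs seq_length + 1 tokens
    while remaining > 0:
        doc_len = sizes[doc_idx[doc_i]] - offset
        remaining -= doc_len
        if remaining <= 0:
            offset += remaining + doc_len - 1   # -1: overlap one token with the next sample
            remaining = 0
        else:
            doc_i += 1                          # document exhausted, continue in the next one
            offset = 0
            if doc_i == len(doc_idx):
                done = True
                break
    if not done:
        samples.append((doc_i, offset))

print(samples)  # [(0, 0), (1, 1), (2, 0)]
```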
from .gpt2_model import GPT2ModelPipe -from .utils import get_params_for_weight_decay_optimization +from .utils import ( + get_params_for_weight_decay_optimization, + mark_norms_for_sequence_parallel_grad_sync, +) from .word_embeddings import SoftEmbedding diff --git a/megatron/model/activations.py b/megatron/model/activations.py index 7a29b0716..c0b825261 100644 --- a/megatron/model/activations.py +++ b/megatron/model/activations.py @@ -25,9 +25,23 @@ def get_activation(neox_args): - """retrieves the activation function specified in neox_args""" + """retrieves the activation function specified in neox_args and whether or not the activation is gated""" + is_gated = False if neox_args.activation == "geglu": - activation_func = GEGLU(neox_args=neox_args) + is_gated = True + activation_func = F.gelu + elif neox_args.activation == "reglu": + is_gated = True + activation_func = F.relu + elif neox_args.activation == "bilinear": + is_gated = True + activation_func = lambda x: x + elif neox_args.activation == "swiglu": + is_gated = True + activation_func = swish + elif neox_args.activation == "glu": + is_gated = True + activation_func = F.sigmoid elif neox_args.activation == "gelu": if neox_args.onnx_safe and neox_args.bias_gelu_fusion: raise ValueError("onnx_safe + bias_gelu_fusion not compatible") @@ -49,7 +63,7 @@ def get_activation(neox_args): activation_func = F.silu else: raise ValueError(f"Activation function {neox_args.activation} not recognized") - return activation_func + return activation_func, is_gated ###### BIAS GELU FUSION/ NO AUTOGRAD ################ @@ -119,21 +133,3 @@ def swish(x, beta: float = 1.0): @torch.jit.script def mish(x): return x * torch.tanh(F.softplus(x)) - - -class GEGLU(torch.nn.Module): - def __init__(self, neox_args): - super(GEGLU, self).__init__() - if neox_args.onnx_safe: - self.activation_func = erf_gelu - else: - self.activation_func = F.gelu - - def forward(self, x, bias=None): - x, gate = x.chunk(2, dim=-1) - if bias is not None: - bias_1, bias_2 = bias.chunk(2, dim=-1) - x = x + bias_1 - gate = gate + bias_2 - intermediate_parallel = self.activation_func(gate) - return intermediate_parallel * x diff --git a/megatron/model/fused_layer_norm.py b/megatron/model/fused_layer_norm.py index d33ded506..3fd251147 100644 --- a/megatron/model/fused_layer_norm.py +++ b/megatron/model/fused_layer_norm.py @@ -21,7 +21,10 @@ except: HAVE_PERSIST_LAYER_NORM = False -from apex.normalization.fused_layer_norm import FusedLayerNormAffineFunction +from apex.normalization.fused_layer_norm import ( + FusedLayerNormAffineFunction, + FusedRMSNormAffineFunction, +) global fused_layer_norm_cuda @@ -148,3 +151,112 @@ def forward(self, input): ) return output + + +class MixedFusedRMSNorm(torch.nn.Module): + def __init__( + self, + normalized_shape, + eps=1e-5, + no_persist_layer_norm=True, + sequence_parallel=False, + apply_rmsnorm_1p=False, + mem_efficient_rms=True, + ): + super(MixedFusedRMSNorm, self).__init__() + + self.apply_rmsnorm_1p = apply_rmsnorm_1p + self.mem_efficient_rms = mem_efficient_rms + self.norm_fn = FusedRMSNormAffineFunction + + global fused_layer_norm_cuda + fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda") + + # List of hiddens sizes supported in the persistent layer norm kernel + # If the hidden size is not supported, fall back to the non-persistent + # kernel. 
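Because `get_activation` now returns the plain element-wise function plus an `is_gated` flag, the gating itself is applied downstream; a rough standalone illustration of that pattern follows (the real module also handles an optional bias, and `swish` with beta=1 is simply SiLU).

```python
import torch
import torch.nn.functional as F

def apply_gated(activation, y):
    # y comes from a projection that is twice the MLP inner width: split it into
    # (value, gate), activate the gate, multiply the halves together
    value, gate = y.chunk(2, dim=-1)
    return activation(gate) * value

y = torch.randn(2, 5, 8)                     # [s, b, 2 * ffn_dim_in]
geglu_out = apply_gated(F.gelu, y)           # "geglu"
swiglu_out = apply_gated(F.silu, y)          # "swiglu"
bilinear_out = apply_gated(lambda x: x, y)   # "bilinear"
print(geglu_out.shape)                       # torch.Size([2, 5, 4])
```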
+ persist_ln_hidden_sizes = [ + 1024, + 1536, + 2048, + 2304, + 3072, + 3840, + 4096, + 5120, + 6144, + 8192, + 10240, + 12288, + 12800, + 15360, + 16384, + 18432, + 20480, + 24576, + 25600, + 30720, + 32768, + 40960, + 49152, + 65536, + ] + if ( + normalized_shape not in persist_ln_hidden_sizes + or not HAVE_PERSIST_LAYER_NORM + ): + no_persist_layer_norm = True + + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + self.normalized_shape = torch.Size(normalized_shape) + self.eps = eps + self.scale = Parameter(torch.Tensor(*normalized_shape)) + self.reset_parameters() + self.no_persist_layer_norm = no_persist_layer_norm + self.sequence_parallel = sequence_parallel + + # set sequence parallelism flag on weight and bias parameters + setattr(self.scale, "sequence_parallel", self.sequence_parallel) + + def reset_parameters(self): + + if self.apply_rmsnorm_1p: + init.zeros_(self.scale) + else: + init.ones_(self.scale) + + def forward(self, input): + + weight = self.scale + 1 if self.apply_rmsnorm_1p else self.scale + # CPU path is here for unittest sake. + if not input.is_cuda: + print( + "WARNING! The input of FusedLayerNorm should be on the GPU." + "This warning should only be triggered in the FusedRMSNorm unit tests." + ) + # Latest pytorch actually supports F.rms_norm but I don't want to break builds so... + return F.layer_norm(input, self.normalized_shape, weight, None, self.eps) + + # Apex does not have versions yet (https://github.com/NVIDIA/apex/pull/1648), so we need to inspect + # the function manually on whether the extra arg introduced in https://github.com/NVIDIA/apex/pull/1715 exists yet + if "memory_efficient" in inspect.getfullargspec(self.norm_fn.forward).args: + return self.norm_fn.apply( + input, + weight, + self.normalized_shape, + self.eps, + self.mem_efficient_rms, + ) + else: + return self.norm_fn.apply(input, weight, self.normalized_shape, self.eps) + + # Apex's fast layer norm function outputs a 'view' tensor (i.e., has + # a populated '_base' field). This will result in schedule.py's + # deallocate_output_tensor() throwing an error, so a viewless tensor is + # created to prevent this. 
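For reference, this is the unfused computation that `MixedFusedRMSNorm` is expected to match; a plain-PyTorch sketch rather than the Apex code path, with the `1p` variant storing `weight - 1` in `scale`.

```python
import torch

def rms_norm_ref(x, scale, eps=1e-5, apply_1p=False):
    weight = scale + 1 if apply_1p else scale
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    return (weight * (x.float() * torch.rsqrt(variance + eps))).to(x.dtype)

x = torch.randn(4, 3, 1024)
scale = torch.ones(1024)
print(rms_norm_ref(x, scale).shape)  # torch.Size([4, 3, 1024])
```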
+ output = make_viewless_tensor( + inp=output, requires_grad=input.requires_grad, keep_graph=True + ) + + return output diff --git a/megatron/model/gmlp.py b/megatron/model/gmlp.py index c3462c651..6400640bd 100644 --- a/megatron/model/gmlp.py +++ b/megatron/model/gmlp.py @@ -112,7 +112,7 @@ def __init__( init_method=init_method, skip_bias_add=True, ) - self.activation_func = get_activation(neox_args) + self.activation_func, _ = get_activation(neox_args) ff_dim_parallel = mpu.divide(ff_dim, mpu.get_model_parallel_world_size()) if neox_args.attention_config[layer_number] == "amlp": d_attn = neox_args.gmlp_attn_dim diff --git a/megatron/model/gpt2_model.py b/megatron/model/gpt2_model.py index 9e643874a..7899048db 100644 --- a/megatron/model/gpt2_model.py +++ b/megatron/model/gpt2_model.py @@ -308,7 +308,10 @@ def _logits_helper(embedding, lm_output): ) logits = parallel_lm_logits( - lm_output, embedding.word_embeddings_weight, self.parallel_output + lm_output, + embedding.word_embeddings_weight, + self.parallel_output, + seq_parallel=self.neox_args.sequence_parallel, ) return logits diff --git a/megatron/model/init_functions.py b/megatron/model/init_functions.py index 86a003dbd..8a0b8e251 100644 --- a/megatron/model/init_functions.py +++ b/megatron/model/init_functions.py @@ -145,7 +145,7 @@ def init_(tensor, use_mup=use_mup_outer): def small_init_init_method(dim, use_mup_outer=False, mup_init_scale=1.0): """Fills the input Tensor with values according to the method described in Transformers without Tears: Improving - the Normalization of Self-Attention - Nguyen, T. & Salazar, J. (2010), using a normal distribution.""" + the Normalization of Self-Attention - Nguyen, T. & Salazar, J. (2019), using a normal distribution.""" std = math.sqrt(2 / (5 * dim)) def init_(tensor, use_mup=use_mup_outer): diff --git a/megatron/model/mamba/mamba.py b/megatron/model/mamba/mamba.py index d5d6b336f..950e36fed 100644 --- a/megatron/model/mamba/mamba.py +++ b/megatron/model/mamba/mamba.py @@ -14,7 +14,8 @@ import einops except ModuleNotFoundError: print( - "Unable to import Mamba kernels. Install them from our requirements/requirements-mamba.txt, or directly from https://github.com/state-spaces/mamba" + "Unable to import Mamba kernels. Install them from our requirements/requirements-mamba.txt, \ + or directly from https://github.com/state-spaces/mamba" ) pass @@ -45,12 +46,21 @@ def __init__( neox_args.mamba_use_bias_in_linears and neox_args.mamba_inner_func_fusion ), "Mamba fused inner fn and bias in x_proj not compatible!" 
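Returning briefly to the `init_functions.py` hunk above: the small-init scheme is just a normal distribution with std = sqrt(2 / (5 d)), for example:

```python
import math
import torch

d_model = 4096
std = math.sqrt(2 / (5 * d_model))              # ~0.00988 for d_model = 4096
w = torch.empty(d_model, d_model)
torch.nn.init.normal_(w, mean=0.0, std=std)
print(round(std, 5), round(w.std().item(), 5))  # second value is an empirical estimate
```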
+ assert ( + neox_args.intermediate_size == None or neox_args.expansion_factor == None + ), "Must pass either the absolute intermediate size or the relative expansion factor for the mamba projections" + # set variables, mostly following mamba defaults self.d_model = neox_args.hidden_size self.d_state = 16 # state dimensions per channel self.d_conv = 4 # convolution width - self.expand = 2 # linear projection expansion factors - self.d_inner = int(self.expand * self.d_model) + if neox_args.intermediate_size: + self.d_inner = neox_args.intermediate_size + else: + self.expand = ( + neox_args.expansion_factor if neox_args.expansion_factor else 2 + ) + self.d_inner = int(self.expand * self.d_model) self.dt_rank = math.ceil(self.d_model / 16) # rank of dt / Delta parameter self.dt_scale = 1.0 diff --git a/megatron/model/norms.py b/megatron/model/norms.py index 8b06b177c..ba175d3eb 100644 --- a/megatron/model/norms.py +++ b/megatron/model/norms.py @@ -14,19 +14,38 @@ import torch from torch.nn import LayerNorm as LayerNorm -from .fused_layer_norm import MixedFusedLayerNorm def get_norm(neox_args): if neox_args.norm == "rmsnorm": - norm = RMSNorm eps = neox_args.rms_norm_epsilon + if neox_args.rmsnorm_fusion: + from .fused_layer_norm import MixedFusedRMSNorm + + norm = MixedFusedRMSNorm + else: + norm = RMSNorm elif neox_args.norm == "layernorm": eps = neox_args.layernorm_epsilon - norm = MixedFusedLayerNorm if neox_args.layernorm_fusion else LayerNorm + if neox_args.layernorm_fusion: + from .fused_layer_norm import MixedFusedLayerNorm + + norm = MixedFusedLayerNorm + else: + norm = LayerNorm elif neox_args.norm == "scalenorm": eps = neox_args.scalenorm_epsilon norm = ScaleNorm + elif neox_args.norm == "te_rmsnorm": + from .transformer_engine import TERMSNorm + + norm = TERMSNorm + eps = neox_args.rms_norm_epsilon + elif neox_args.norm == "te_layernorm": + from .transformer_engine import TELayerNorm + + norm = TELayerNorm + eps = neox_args.layernorm_epsilon else: raise ValueError(f"norm {neox_args.norm} not recognized") return norm, eps diff --git a/megatron/model/rwkv/v6/rwkv.py b/megatron/model/rwkv/v6/rwkv.py index 5d4e0d144..b3741a3fc 100644 --- a/megatron/model/rwkv/v6/rwkv.py +++ b/megatron/model/rwkv/v6/rwkv.py @@ -247,11 +247,11 @@ def __init__(self, neox_args, layer_number): self.time_maa_k = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0)) self.time_maa_r = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0)) - self.key = nn.Linear(neox_args.hidden_size, neox_args.dim_ffn, bias=False) + self.key = nn.Linear(neox_args.hidden_size, neox_args.ffn_dim, bias=False) self.receptance = nn.Linear( neox_args.hidden_size, neox_args.hidden_size, bias=False ) - self.value = nn.Linear(neox_args.dim_ffn, neox_args.hidden_size, bias=False) + self.value = nn.Linear(neox_args.ffn_dim, neox_args.hidden_size, bias=False) def forward(self, x): xx = self.time_shift(x) - x @@ -275,14 +275,23 @@ def __init__(self, neox_args, layer_number): self.layer_number = layer_number self.fp16 = neox_args.precision == "fp16" self.bf16 = neox_args.precision == "bfloat16" + assert ( + neox_args.intermediate_size == None or neox_args.expansion_factor == None + ), "Must pass either the absolute intermediate size or the relative expansion factor for the mamba projections" if not hasattr(neox_args, "dim_att"): neox_args.dim_att = neox_args.hidden_size - if not hasattr(neox_args, "dim_ffn"): - # Make hidden size 3.5x. 
Round to nearest multiple of 32 until we add hdim rounding logic - neox_args.dim_ffn = int((neox_args.hidden_size * 3.5) // 32 * 32) + if neox_args.intermediate_size: + neox_args.ffn_dim = neox_args.intermediate_size + else: + self.expand = ( + neox_args.expansion_factor if neox_args.expansion_factor else 3.5 + ) + neox_args.ffn_dim = int(self.expand * neox_args.hidden_size) + # Make hidden size 3.5x by default. Round to nearest multiple of 32 until we add hdim rounding logic + neox_args.ffn_dim = int(neox_args.ffn_dim // 32 * 32) assert neox_args.hidden_size % 32 == 0 assert neox_args.dim_att % 32 == 0 - assert neox_args.dim_ffn % 32 == 0 + assert neox_args.ffn_dim % 32 == 0 self.neox_args.head_size = neox_args.dim_att // neox_args.num_attention_heads self.head_size = self.neox_args.head_size self.num_attention_heads = neox_args.num_attention_heads diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index c154b09f4..d2b93eb06 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -93,37 +93,57 @@ def __init__( init_method, output_layer_init_method, parallel_output=False, + multiple_of=256, MOE=False, MoE_mp_size=1, ): super().__init__() + assert ( + neox_args.intermediate_size == None or neox_args.expansion_factor == None + ), "Must pass either the absolute intermediate size or the relative expansion factor for the mamba projections" - self.activation_func = get_activation(neox_args) + self.activation_func, self.is_gated = get_activation(neox_args) self.activation_type = neox_args.activation self.bias_gelu_fusion = neox_args.bias_gelu_fusion + self.multiple_of = multiple_of - # auto scale so geglu has equal parameters - ff_mult = int(4 * 2 / 3) if self.activation_type == "geglu" else 4 - ff_dim = ( - int(ff_mult * neox_args.hidden_size) * 2 - if self.activation_type == "geglu" - else ff_mult * neox_args.hidden_size + if neox_args.intermediate_size: + ffn_dim = neox_args.intermediate_size + elif neox_args.expansion_factor: + ffn_dim = int(neox_args.expansion_factor * neox_args.hidden_size) + else: + # 4h is default for ffn_dim + ffn_dim = 4 * neox_args.hidden_size + ffn_dim_in = ffn_dim + if self.is_gated: + # set activation function to be gated implementation + self.activation_func = Gated_Activation(self.activation_func) + # auto scale so gated activations has equal parameters + ffn_dim = int(ffn_dim * 2 / 3) + ffn_dim_in = ffn_dim // 2 + # set multiple + ffn_dim = int( + (2 * self.multiple_of) + * ((ffn_dim + (2 * multiple_of) - 1) // (2 * multiple_of)) + ) + ffn_dim_in = int( + self.multiple_of * ((ffn_dim_in + multiple_of - 1) // multiple_of) ) - self.dense_h_to_4h = mpu.ColumnParallelLinear( + + self.linear1 = mpu.ColumnParallelLinear( neox_args=neox_args, input_size=neox_args.hidden_size, - output_size=ff_dim, + output_size=ffn_dim, gather_output=False, init_method=init_method, skip_bias_add=True, MOE=MOE, MoE_mp_size=MoE_mp_size, ) - ff_dim_in = ff_dim // 2 if self.activation_type == "geglu" else ff_dim # Project back to h. 
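The width arithmetic above is easiest to see with concrete numbers; a worked example for `hidden_size=4096`, a gated activation, and the default `multiple_of=256` (values chosen only for illustration).

```python
hidden_size, multiple_of = 4096, 256
ffn_dim = 4 * hidden_size            # default 4h -> 16384
ffn_dim = int(ffn_dim * 2 / 3)       # gated: keep parameter count roughly equal -> 10922
ffn_dim_in = ffn_dim // 2            # 5461
ffn_dim = (2 * multiple_of) * ((ffn_dim + 2 * multiple_of - 1) // (2 * multiple_of))
ffn_dim_in = multiple_of * ((ffn_dim_in + multiple_of - 1) // multiple_of)
print(ffn_dim, ffn_dim_in)           # 11264 5632, i.e. linear1 out is exactly 2 * linear2 in
```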
- self.dense_4h_to_h = mpu.RowParallelLinear( + self.linear2 = mpu.RowParallelLinear( neox_args=neox_args, - input_size=ff_dim_in, + input_size=ffn_dim_in, output_size=neox_args.hidden_size, input_is_parallel=True, init_method=output_layer_init_method, @@ -134,13 +154,10 @@ def __init__( ) def forward(self, hidden_states): + # [s, b, intermediate_size] + intermediate_parallel, bias_parallel = self.linear1(hidden_states) - # [s, b, 4hp] - intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states) - - if ( - self.activation_type == "gelu" and self.bias_gelu_fusion - ) or self.activation_type == "geglu": + if self.is_gated or (self.activation_type == "gelu" and self.bias_gelu_fusion): intermediate_parallel = self.activation_func( intermediate_parallel, bias_parallel ) @@ -150,84 +167,23 @@ def forward(self, hidden_states): ) # [s, b, h] - output, output_bias = self.dense_4h_to_h(intermediate_parallel) + output, output_bias = self.linear2(intermediate_parallel) return output, output_bias -class LLaMAParallelMLP(nn.Module): - """LLaMA's MLP. - - MLP will take the input with h hidden state, project it to 4*h - hidden dimension, perform nonlinear transformation, and project the - state back into h hidden dimension. At the end, dropout is also - applied. - - Note: multiple_of is used to compute the hidden dimension of the MLP - """ - - def __init__( - self, - neox_args, - init_method, - output_layer_init_method, - parallel_output=False, - multiple_of=256, - MOE=False, - MoE_mp_size=1, - ): +class Gated_Activation(torch.nn.Module): + def __init__(self, activation_func): super().__init__() + self.activation_func = activation_func - self.activation_func = get_activation(neox_args) - self.activation_type = neox_args.activation - - self.multiple_of = multiple_of - - # Allow custom intermediate size, e.g. 
for Mistral - if neox_args.intermediate_size is not None: - ff_dim = neox_args.intermediate_size - else: - ff_dim = int(2 * neox_args.hidden_size * 4 / 3) - ff_dim = self.multiple_of * ((ff_dim + multiple_of - 1) // multiple_of) - - self.w1 = mpu.ColumnParallelLinear( - neox_args=neox_args, - input_size=neox_args.hidden_size, - output_size=ff_dim, - gather_output=False, - init_method=init_method, - skip_bias_add=True, - bias=False, - MOE=MOE, - MoE_mp_size=MoE_mp_size, - ) - self.w3 = mpu.ColumnParallelLinear( - neox_args=neox_args, - input_size=neox_args.hidden_size, - output_size=ff_dim, - gather_output=False, - init_method=init_method, - skip_bias_add=True, - bias=False, - MOE=MOE, - MoE_mp_size=MoE_mp_size, - ) - self.w2 = mpu.RowParallelLinear( - neox_args=neox_args, - input_size=ff_dim, - output_size=neox_args.hidden_size, - input_is_parallel=True, - init_method=output_layer_init_method, - skip_bias_add=True, - parallel_output=parallel_output, - bias=False, - MOE=MOE, - MoE_mp_size=MoE_mp_size, - ) - - def forward(self, hidden_states): - w1_out, _ = self.w1(hidden_states) - w3_out, _ = self.w3(hidden_states) - return self.w2(self.activation_func(w1_out) * w3_out) + def forward(self, x, bias=None): + x, gate = x.chunk(2, dim=-1) + if bias is not None: + bias_1, bias_2 = bias.chunk(2, dim=-1) + x = x + bias_1 + gate = gate + bias_2 + intermediate_parallel = self.activation_func(gate) + return intermediate_parallel * x class ParallelLinear(nn.Module): @@ -254,6 +210,7 @@ def __init__( gather_output=not parallel_output, skip_bias_add=False, mup_rescale_parameters=is_last_layer, # rescale params only called if neox_args.use_mup = True, despite it not being included here + seq_dim=1, # important: must mark that this layer receives shape [b, s, h] not [s, b, h] and so Seq. Parallel comms must gather along dim=1 rather than dim=0 ) # else: @@ -1024,7 +981,14 @@ def __init__( self.moe_type = neox_args.moe_type if self.gpt_j_residual: - self.reduce = mpu.mappings.reduce_from_model_parallel_region + # GPT-J style layers allow us to defer the reduction of results across TP ranks until the end of the two sublayers. + # the reduction we use is a simple allreduce for pure Tensor Parallel, + # but needs to be a reduce-scatter when using Megatron-style Sequence Parallel (LN sharding.) + self.reduce = ( + mpu.mappings.reduce_from_model_parallel_region + if not neox_args.sequence_parallel + else mpu.mappings.reduce_scatter_to_sequence_parallel_region + ) # Self attention. 
self.attention = ParallelSelfAttention( @@ -1046,24 +1010,13 @@ def __init__( # MLP def get_mlp(mlp_type, **kw): - if mlp_type == "regular": - return ParallelMLP( - neox_args=neox_args, - init_method=init_method, - output_layer_init_method=output_layer_init_method, - parallel_output=self.gpt_j_residual, - **kw, - ) - elif mlp_type == "llama": - return LLaMAParallelMLP( - neox_args=neox_args, - init_method=init_method, - output_layer_init_method=output_layer_init_method, - parallel_output=self.gpt_j_residual, - **kw, - ) - else: - raise KeyError(mlp_type) + return ParallelMLP( + neox_args=neox_args, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + parallel_output=self.gpt_j_residual, + **kw, + ) self.num_experts = ( neox_args.moe_num_experts @@ -1280,7 +1233,7 @@ def forward(self, x, attention_mask, layer_past=None): with torch.enable_grad(): if ( - self.mlp_type == "llama" + self.activation == "swiglu" or self.num_experts > 1 and self.moe_type == "deepspeed" ): @@ -1339,10 +1292,25 @@ def forward(self, args): return self.norm(args) -def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, bias=None): +def parallel_lm_logits( + input_, + word_embeddings_weight, + parallel_output, + seq_parallel=False, + seq_dim=1, + bias=None, +): """LM logits using word embedding weights.""" # Parallel logits. - input_parallel = mpu.copy_to_model_parallel_region(input_) + if seq_parallel: + # if using Sequence Parallelism, our logits are sharded along the sequence dimension. + # gather them here. (backward pass: reduce-scatter) + input_parallel = mpu.gather_from_sequence_parallel_region( + input_, seq_dim=seq_dim + ) + else: + # Set up backprop all-reduce. + input_parallel = mpu.copy_to_model_parallel_region(input_) # Matrix multiply. if bias is None: diff --git a/megatron/model/transformer_engine.py b/megatron/model/transformer_engine.py index 8e3d0d527..338513a97 100644 --- a/megatron/model/transformer_engine.py +++ b/megatron/model/transformer_engine.py @@ -1,19 +1,56 @@ -import transformer_engine as te import torch -from pkg_resources import packaging -_te_version = packaging.version.Version(version("transformer-engine")) +try: + import transformer_engine as te +except ImportError: + raise ImportError( + "Unable to import transformer-engine. Please refer to " + "https://github.com/NVIDIA/TransformerEngine for installation instructions." + ) + + +class TERMSNorm(torch.nn.Module): + def __init__(self, dim, eps=1e-8, **kwargs): + """ + A conditional wrapper to initialize an instance of Transformer-Engine's + `RMSNorm` based on input + :param dim: model size + :param eps: epsilon value, default 1e-8 + """ + super(TERMSNorm, self).__init__() + + self.d = dim + self.eps = eps + self.norm = te.pytorch.RMSNorm( + hidden_size=self.d, + eps=self.eps, + **kwargs, + ) + def forward(self, x): + return self.norm(x) + + +class TELayerNorm(torch.nn.Module): + def __init__(self, dim, eps=1.0e-5, **kwargs): + """ + A conditional wrapper to initialize an instance of Transformer-Engine's + `LayerNorm` based on input + :param dim: model size + :param eps: epsilon value, default 1.0e-5 + """ + super(TELayerNorm, self).__init__() + + self.d = dim + self.eps = eps + self.norm = te.pytorch.LayerNorm( + hidden_size=self.d, + eps=self.eps, + **kwargs, + ) -class TENorm: - """ - A conditional wrapper to initialize an instance of Transformer-Engine's - `LayerNorm` or `RMSNorm` based on input - """ - - def __new__(): - return - # TODO ??? 
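A hypothetical usage sketch for the new norm wrappers, assuming transformer-engine is installed, a CUDA device is available, and the repository is on `PYTHONPATH`; in a config they are selected with `"norm": "te_rmsnorm"` or `"norm": "te_layernorm"` as wired up in `norms.py` above.

```python
import torch
from megatron.model.transformer_engine import TERMSNorm, TELayerNorm

rms = TERMSNorm(dim=1024, eps=1e-8).cuda()
ln = TELayerNorm(dim=1024).cuda()

x = torch.randn(4, 16, 1024, device="cuda")
print(rms(x).shape, ln(x).shape)
```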
+ def forward(self, x): + return self.norm(x) class TELinear(te.pytorch.Linear): @@ -22,12 +59,12 @@ class TELinear(te.pytorch.Linear): """ def __init__(self): + # TODO return - # TODO: Nick def forward(self, x): + # TODO return - # TODO: Nick class TELayerNormColumnParallelLinear(te.pytorch.LayerNormLinear): @@ -37,12 +74,12 @@ class TELayerNormColumnParallelLinear(te.pytorch.LayerNormLinear): """ def __init__(self): + # TODO return - # TODO: Nick def forward(self, x): + # TODO return - # TODO: Nick class TEColumnParallelLinear(TELinear): @@ -52,12 +89,12 @@ class TEColumnParallelLinear(TELinear): """ def __init__(self): - # TODO: Nick + # TODO return def forward(self, x): + # TODO return - # TODO: Nick class TERowParallelLinear(TELinear): @@ -67,11 +104,11 @@ class TERowParallelLinear(TELinear): """ def __init__(self): - # TODO: Nick + # TODO return def forward(self, x): - # TODO: Nick + # TODO return @@ -82,11 +119,11 @@ class TEDotProductAttention(te.pytorch.DotProductAttention): """ def __init__(self): - # TODO: tfidia + # TODO return def forward(self, x): - # TODO: tfidia + # TODO return @@ -96,5 +133,5 @@ class TEDelayedScaling(te.common.recipe.DelayedScaling): """ def __init__(self): - # TODO: ??? + # TODO return diff --git a/megatron/model/utils.py b/megatron/model/utils.py index c3da2ce8b..77e7f521d 100644 --- a/megatron/model/utils.py +++ b/megatron/model/utils.py @@ -20,6 +20,7 @@ import torch from megatron.model.norms import LayerNorm, RMSNorm, ScaleNorm from megatron.model.fused_softmax import SoftmaxFusionTypes +from megatron import mpu from types import GeneratorType import torch.distributed as dist @@ -39,6 +40,8 @@ def get_params_for_weight_decay_optimization(module, neox_args): [ isinstance(module_, LayerNorm), isinstance(module_, RMSNorm), + isinstance(module_, TELayerNorm), + isinstance(module_, TERMSNorm), isinstance(module_, ScaleNorm), ] ) or ( @@ -359,3 +362,45 @@ def get_fusion_type(neox_args): elif neox_args.scaled_masked_softmax_fusion: fusion_type = SoftmaxFusionTypes.general return fusion_type + + +def reduce_weight_grads_from_model_parallel_region(input_): + """A hook that can be applied to any weight tensor via .register_hook(). + Allreduces grads for e.g. LN weights across the model parallel group. + Needed to keep LNs in sync, despite them getting diff data -> diff gradients when using sequence parallel. + """ + # Bypass the function if no TP -> no comm needed. + if mpu.get_model_parallel_world_size() == 1: + return input_ + + # Bf16 convert + dt = input_.dtype + if dt == torch.bfloat16 and mpu.get_fp32_allreduce(): + input_ = input_.float() + + # All-reduce. + torch.distributed.all_reduce(input_, group=mpu.get_model_parallel_group()) + + # Bf16 convert + if dt == torch.bfloat16 and mpu.get_fp32_allreduce(): + input_ = input_.bfloat16() + + return input_ + + +def mark_norms_for_sequence_parallel_grad_sync(module, neox_args): + """Iterate through the modules in our model, and for any "...Norm" classnames, + register a hook on each of that module's parameters which will allreduce norms' weights' grads across + the model (sequence) parallel region. 
+ """ + + if not neox_args.sequence_parallel: + # if we aren't using sequence parallelism, this is a no-op + return + + for module_ in module.modules(): + if "norm" in type(module_).__name__.lower(): + # this is a norm, we want to allreduce its weight grads across sequence parallel region + for name, param in module_.named_parameters(): + if param.requires_grad: + param.register_hook(reduce_weight_grads_from_model_parallel_region) diff --git a/megatron/model/word_embeddings.py b/megatron/model/word_embeddings.py index f7372bc55..ce3c1117e 100644 --- a/megatron/model/word_embeddings.py +++ b/megatron/model/word_embeddings.py @@ -50,6 +50,11 @@ def __init__( self.hidden_size = hidden_size self.init_method = init_method self.num_tokentypes = num_tokentypes + + self.sequence_parallel = ( + neox_args.sequence_parallel + ) # if we are using sequence parallelism, then we'll want to scatter our inputs across the seqlen dim across TP ranks + self.use_mup = neox_args.use_mup self.mup_embedding_mult = neox_args.mup_embedding_mult self.mup_rp_embedding_mult = neox_args.mup_rp_embedding_mult @@ -159,6 +164,11 @@ def forward(self, input_ids, position_ids, tokentype_ids=None): with torch.no_grad(): embeddings.mul_(self.mup_embedding_mult) + if self.sequence_parallel: + # TODO: megatron-lm does dropout using the scattered embs. This would save a tiny bit of time, perhaps? + # Not a priority since we don't often use dropout + embeddings = mpu.scatter_to_sequence_parallel_region(embeddings) + return embeddings diff --git a/megatron/mpu/__init__.py b/megatron/mpu/__init__.py index 2365507d9..780fb33e8 100644 --- a/megatron/mpu/__init__.py +++ b/megatron/mpu/__init__.py @@ -47,6 +47,9 @@ from .mappings import gather_from_model_parallel_region from .mappings import reduce_from_model_parallel_region from .mappings import scatter_to_model_parallel_region +from .mappings import reduce_scatter_to_sequence_parallel_region +from .mappings import gather_from_sequence_parallel_region +from .mappings import scatter_to_sequence_parallel_region from .random import checkpoint from .random import get_cuda_rng_tracker diff --git a/megatron/mpu/layers.py b/megatron/mpu/layers.py index 0d14806ac..d59edab94 100644 --- a/megatron/mpu/layers.py +++ b/megatron/mpu/layers.py @@ -33,6 +33,8 @@ from .mappings import gather_from_model_parallel_region from .mappings import reduce_from_model_parallel_region from .mappings import scatter_to_model_parallel_region +from .mappings import reduce_scatter_to_sequence_parallel_region +from .mappings import gather_from_sequence_parallel_region from .random import get_cuda_rng_tracker from .utils import divide from .utils import VocabUtility @@ -416,6 +418,7 @@ def __init__( MOE=False, MoE_mp_size=1, mup_rescale_parameters=False, + seq_dim=0, # Dimension which is the seq_len dimension. final ParallelLinear overrides this to be 1 ; otherwise, the default is used throughout. 
): super(ColumnParallelLinear, self).__init__() @@ -427,6 +430,10 @@ def __init__( world_size = MoE_mp_size if MOE else get_model_parallel_world_size() self.output_size_per_partition = divide(output_size, world_size) self.skip_bias_add = skip_bias_add + + self.sequence_parallel = neox_args.sequence_parallel + self.seq_dim = seq_dim + self.init_method = init_method self.stride = stride self.mup_rescale_parameters = mup_rescale_parameters @@ -551,14 +558,29 @@ def set_parallel_output(self, value: bool): def forward(self, input_): if self.use_mup and self.mup_rescale_parameters: input_ /= self.width_mult() - # Set up backprop all-reduce. - input_parallel = copy_to_model_parallel_region(input_) + + if self.sequence_parallel: + input_parallel = input_ + else: + # Set up backprop all-reduce. + input_parallel = copy_to_model_parallel_region(input_) # Matrix multiply. + if self.sequence_parallel: + # do an AG in the fwd pass, RS in bwd pass. + # gather / scatter portion happens across the sequence dim (self.seq_dim)-- + # almost always is [s, b, h] and so dim 0, but for lm_head ParallelLinear it is seq_dim=1 and [b, s, h] + input_parallel = gather_from_sequence_parallel_region( + input_parallel, seq_dim=self.seq_dim + ) + bias = self.bias if not self.skip_bias_add else None output_parallel = F.linear(input_parallel, self.weight, bias) if self.gather_output: # All-gather across the partitions. + assert ( + not self.sequence_parallel + ), "sequence_parallel=True and gather_output=True are incompatible!" output = gather_from_model_parallel_region(output_parallel) else: output = output_parallel @@ -623,6 +645,12 @@ def __init__( self.input_size_per_partition = divide(input_size, world_size) self.skip_bias_add = skip_bias_add self.parallel_output = parallel_output + + self.sequence_parallel = neox_args.sequence_parallel + assert not ( + self.sequence_parallel and not self.input_is_parallel + ), "Cannot have self.input_is_parallel=False and self.sequence_parallel=True." + self.init_method = init_method self.stride = stride self.keep_master_weight_for_test = keep_master_weight_for_test @@ -748,7 +776,12 @@ def forward(self, input_): # Matrix multiply. output_parallel = F.linear(input_parallel, self.weight) # All-reduce across all the partitions. - if not self.parallel_output: + if self.sequence_parallel and not self.parallel_output: + # do an RS in the fwd pass, AG in bwd pass. + # skip in the gpt-j parallel sublayer case (self.parallel_output=True) + # (user responsible for calling reduce-scatter) + output_ = reduce_scatter_to_sequence_parallel_region(output_parallel) + elif not self.parallel_output: output_ = reduce_from_model_parallel_region(output_parallel) else: output_ = output_parallel diff --git a/megatron/mpu/mappings.py b/megatron/mpu/mappings.py index 535fe6255..f11d9e6ab 100644 --- a/megatron/mpu/mappings.py +++ b/megatron/mpu/mappings.py @@ -23,7 +23,7 @@ get_model_parallel_rank, get_fp32_allreduce, ) -from .utils import split_tensor_along_last_dim +from .utils import split_tensor_along_last_dim, split_tensor_along_any_dim def _reduce(input_): @@ -33,17 +33,17 @@ def _reduce(input_): if get_model_parallel_world_size() == 1: return input_ - # Bf16 convert + # upcast to fp32 if using fp32 allreduce dt = input_.dtype - if dt == torch.bfloat16 and get_fp32_allreduce(): + if get_fp32_allreduce(): input_ = input_.float() # All-reduce. 
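A single-process toy of the tensor shapes in the sequence-parallel `ColumnParallelLinear` forward described above; `torch.cat` stands in for the all-gather collective and the sizes are arbitrary.

```python
import torch
import torch.nn.functional as F

world = 2
s, b, h = 8, 4, 16
shards = [torch.randn(s // world, b, h) for _ in range(world)]  # each rank holds [s/p, b, h]
full_w = torch.randn(3 * h, h)                                  # pretend the output dim is 3h
w_shards = full_w.chunk(world, dim=0)                           # column-parallel weight split

gathered = torch.cat(shards, dim=0)       # fwd: gather_from_sequence_parallel_region (all-gather)
outputs = [F.linear(gathered, w) for w in w_shards]
print([tuple(o.shape) for o in outputs])  # each rank computes [s, b, 3h / p]
```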
torch.distributed.all_reduce(input_, group=get_model_parallel_group()) - # Bf16 convert - if dt == torch.bfloat16 and get_fp32_allreduce(): - input_ = input_.bfloat16() + # reconvert to original Bf16/Fp16 dtype + if get_fp32_allreduce(): + input_ = input_.to(dt) return input_ @@ -75,11 +75,6 @@ def _gather(input_): if world_size == 1: return input_ - # Bf16 convert - dt = input_.dtype - if dt == torch.bfloat16 and get_fp32_allreduce(): - input_ = input_.float() - # Size and dimension. last_dim = input_.dim() - 1 rank = get_model_parallel_rank() @@ -91,9 +86,100 @@ def _gather(input_): # Note: torch.cat already creates a contiguous tensor. output = torch.cat(tensor_list, dim=last_dim).contiguous() - # Bf16 convert - if dt == torch.bfloat16 and get_fp32_allreduce(): - output = output.bfloat16() + return output + + +def _reduce_scatter_along_seq_dim(input_, seq_dim): + """Reduce-scatter the input tensor across model parallel group, scattering across sequence dim.""" + world_size = get_model_parallel_world_size() + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + + # upcast to fp32 if using fp32 allreduce + dt = input_.dtype + if get_fp32_allreduce(): + input_ = input_.float() + + dim_size = list(input_.size()) + assert ( + isinstance(seq_dim, int) and seq_dim < len(dim_size) and seq_dim >= 0 + ), "seq_dim must be a valid tensor dim" + assert dim_size[seq_dim] % world_size == 0 + + if seq_dim == 0: + # reduce_scatter_tensor is faster but only works correctly on dimension 0 + dim_size[seq_dim] = dim_size[seq_dim] // world_size + output = torch.empty( + dim_size, dtype=input_.dtype, device=torch.cuda.current_device() + ) + torch.distributed.reduce_scatter_tensor( + output, input_.contiguous(), group=get_model_parallel_group() + ) + else: + tensor_list = list( + torch.split(input_, input_.shape[seq_dim] // world_size, seq_dim) + ) + output = torch.empty_like(tensor_list[0]) + torch.distributed.reduce_scatter(output, tensor_list) + + # reconvert to original Bf16/Fp16 dtype + if get_fp32_allreduce(): + output = output.to(dt) + + return output + + +def _gather_along_seq_dim(input_, seq_dim): + """Gather tensors and concatinate along the (manually-specified) sequence dimension.""" + + world_size = get_model_parallel_world_size() + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + + dim_size = list(input_.size()) + assert ( + isinstance(seq_dim, int) and seq_dim < len(dim_size) and seq_dim >= 0 + ), "seq_dim must be a valid tensor dim" + dim_size[seq_dim] = dim_size[seq_dim] * world_size + + if seq_dim == 0: + # reduce_gather_tensor is faster but only works correctly on dimension 0 + output = torch.empty( + dim_size, dtype=input_.dtype, device=torch.cuda.current_device() + ) + torch.distributed.all_gather_into_tensor( + output, input_.contiguous(), group=get_model_parallel_group() + ) + else: + input_ = input_.contiguous() + rank = get_model_parallel_rank() + tensor_list = [torch.empty_like(input_) for _ in range(world_size)] + tensor_list[rank] = input_ + torch.distributed.all_gather( + tensor_list, input_, group=get_model_parallel_group() + ) + output = torch.cat(tensor_list, dim=seq_dim) + + return output + + +def _split_along_seq_dim(input_, seq_dim): + """Split the tensor along the sequence dimension (as manually selected) and keep the + corresponding slice.""" + + world_size = get_model_parallel_world_size() + # Bypass the function if we are using only 1 GPU. 
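And the complementary pair used on the row-parallel side: a single-process shape check of reduce-scatter followed by all-gather along the sequence dimension, with stack/sum, chunk, and cat standing in for the collectives.

```python
import torch

world, s, b, h = 2, 8, 4, 16
per_rank = [torch.randn(s, b, h) for _ in range(world)]

summed = torch.stack(per_rank).sum(dim=0)     # the "reduce" part
scattered = list(summed.chunk(world, dim=0))  # each rank keeps a [s/p, b, h] slice
gathered = torch.cat(scattered, dim=0)        # all-gather along the seq dim restores [s, b, h]

print(scattered[0].shape, gathered.shape)
assert torch.allclose(gathered, summed)
```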
+ if world_size == 1: + return input_ + + # Split along second dimension. + input_list = split_tensor_along_any_dim(input_, world_size, seq_dim) + + # Note: torch.split does not create contiguous tensors by default. + rank = get_model_parallel_rank() + output = input_list[rank].contiguous() return output @@ -162,6 +248,65 @@ def backward(ctx, grad_output): return _split(grad_output) +class _ReduceScatterToSequenceParallelRegion(torch.autograd.Function): + """Reduce-Scatter across sequence parallel region (same as model parallel region.) + Note: same region as model parallel region + """ + + @staticmethod + def symbolic(graph, input_, seq_dim): + return _reduce_scatter_along_seq_dim(input_, seq_dim=seq_dim) + + @staticmethod + def forward(ctx, input_, seq_dim): + ctx.seq_dim = seq_dim + return _reduce_scatter_along_seq_dim(input_, seq_dim=seq_dim) + + @staticmethod + def backward(ctx, grad_output): + seq_dim = ctx.seq_dim + return _gather_along_seq_dim(grad_output, seq_dim=seq_dim), None + + +class _GatherFromSequenceParallelRegion(torch.autograd.Function): + """All-Gather across sequence parallel region (same region as model parallel region.)""" + + @staticmethod + def symbolic(graph, input_, seq_dim): + return _gather_along_seq_dim(input_, seq_dim=seq_dim) + + @staticmethod + def forward(ctx, input_, seq_dim): + ctx.seq_dim = seq_dim + return _gather_along_seq_dim(input_, seq_dim=seq_dim) + + @staticmethod + def backward(ctx, grad_output): + seq_dim = ctx.seq_dim + return _reduce_scatter_along_seq_dim(grad_output, seq_dim=seq_dim), None + + +class _ScatterToSequenceParallelRegion(torch.autograd.Function): + """Scatter (split) sequence length across sequence parallel region (=> same region as model parallel.)""" + + @staticmethod + def symbolic(graph, input_, seq_dim): + return _split_along_seq_dim(input_, seq_dim=seq_dim) + + @staticmethod + def forward(ctx, input_, seq_dim): + ctx.seq_dim = seq_dim + return _split_along_seq_dim(input_, seq_dim=seq_dim) + + @staticmethod + def backward(ctx, grad_output): + seq_dim = ctx.seq_dim + return ( + _gather_along_seq_dim(grad_output, seq_dim=seq_dim), + None, + ) + + # ----------------- # Helper functions. # ----------------- @@ -181,3 +326,17 @@ def scatter_to_model_parallel_region(input_): def gather_from_model_parallel_region(input_): return _GatherFromModelParallelRegion.apply(input_) + + +def reduce_scatter_to_sequence_parallel_region(input_, seq_dim=0): + return _ReduceScatterToSequenceParallelRegion.apply(input_, seq_dim) + + +def gather_from_sequence_parallel_region(input_, seq_dim=0): + return _GatherFromSequenceParallelRegion.apply(input_, seq_dim) + + +def scatter_to_sequence_parallel_region( + input_, seq_dim=1 +): # use this fn in scattering input embeds across TP ranks. There, shape of inps is [b, s, h] instead of the usual [s, b, h] + return _ScatterToSequenceParallelRegion.apply(input_, seq_dim) diff --git a/megatron/mpu/utils.py b/megatron/mpu/utils.py index 13941dc29..1f97e0e76 100644 --- a/megatron/mpu/utils.py +++ b/megatron/mpu/utils.py @@ -53,6 +53,28 @@ def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks= return tensor_list +def split_tensor_along_any_dim( + tensor, num_partitions, seq_dim, contiguous_split_chunks=False +): + """Split a tensor along a user-specified dimension. + Arguments: + tensor: input tensor. + num_partitions: number of partitions to split the tensor + seq_dim: dimension along which to split the tensor + contiguous_split_chunks: If True, make each chunk contiguous + in memory. 
+ """ + # Get the size and dimension. + seq_dim_size = divide(tensor.size()[seq_dim], num_partitions) + # Split. + tensor_list = torch.split(tensor, seq_dim_size, dim=seq_dim) + # Note: torch.split does not create contiguous tensors by default. + if contiguous_split_chunks: + return tuple(chunk.contiguous() for chunk in tensor_list) + + return tensor_list + + class VocabUtility: """Split the vocabulary into `world_size` chunks amd return the first and last index of the vocabulary belonging to the `rank` diff --git a/megatron/neox_arguments/arguments.py b/megatron/neox_arguments/arguments.py index ff4f4bc21..1677bf072 100644 --- a/megatron/neox_arguments/arguments.py +++ b/megatron/neox_arguments/arguments.py @@ -180,7 +180,6 @@ def from_ymls(cls, paths_to_yml_files: List[str], overwrite_values: Dict = None) config_files = dict() # iterate of all to be loaded yaml files for conf_file_name in paths_to_yml_files: - # load file with open(conf_file_name) as conf_file: conf = yaml.load(conf_file, Loader=yaml.FullLoader) @@ -477,7 +476,6 @@ def get_extra_deepspeed_args(self): return extra_ds_args def get_deepspeed_main_args(self): - args_list = list() if self.autotuning_run is not None: @@ -803,7 +801,6 @@ def calculate_batch_parameters( @staticmethod def check_batch_parameters(dp_world_size, train_batch, micro_batch, grad_acc): - assert ( train_batch > 0 ), f"Train batch size: {train_batch} has to be greater than 0" @@ -1033,10 +1030,7 @@ def calculate_derived(self): # Update 'is pipe parallel' flag # if we set pipe_parallel_size to 0 or 1, GPT2ModelPipe.to_sequential() is called, and we run training with # the sequential model without the PipelineModule wrapper to avoid the overhead it incurs - self.update_value( - "is_pipe_parallel", - self.pipe_parallel_size > 1 and self.moe_num_experts == 1, - ) + self.update_value("is_pipe_parallel", self.pipe_parallel_size >= 1) if self.moe_num_experts > 1: assert not ( self.is_pipe_parallel or self.pipe_parallel_size > 1 @@ -1044,6 +1038,10 @@ def calculate_derived(self): assert self.zero_optimization["stage"] != 3, "MoE not compatible with zero3" assert self.mlp_type == "regular", "MoE not compatible with LLaMA" + assert ( + self.sequence_parallel is False + ), "MoE not compatible with Sequence Parallel" + # Attention config if self.attention_config is None: self.update_value("attention_config", [[["global"], self.num_layers]]) @@ -1070,8 +1068,8 @@ def calculate_derived(self): ), "Mamba does not yet have dropout implemented" if "rwkv" in self.attention_config: assert ( - not self.is_pipe_parallel and self.model_parallel_size == 1 - ), "RWKV not currently compatible with parallelism" + self.model_parallel_size == 1 + ), "RWKV not currently compatible with model parallelism" if isinstance(self.zero_stage, int): assert self.zero_stage <= 2, "Zero stage 3 not compatible with RWKV" assert ( @@ -1118,15 +1116,19 @@ def calculate_derived(self): # Adding equal dataset weights if none are provided if self.train_data_paths and (self.train_data_weights is None): self.train_data_weights = [1.0] * len(self.train_data_paths) + elif self.pos_train_data_paths and (self.train_data_weights is None): + self.train_data_weights = [1.0] * len(self.pos_train_data_paths) if self.valid_data_paths and (self.valid_data_weights is None): self.valid_data_weights = [1.0] * len(self.valid_data_paths) + elif self.pos_valid_data_paths and (self.valid_data_weights is None): + self.valid_data_weights = [1.0] * len(self.pos_valid_data_paths) if self.test_data_paths and 
(self.test_data_weights is None): self.test_data_weights = [1.0] * len(self.test_data_paths) + elif self.pos_test_data_paths and (self.test_data_weights is None): + self.test_data_weights = [1.0] * len(self.pos_test_data_paths) - if self.label_data_paths: - err_str = ( - "Must use `label_data_paths` with `train_data_paths`, not `data_path`" - ) + if self.train_label_data_paths: + err_str = "Must use `train_label_data_paths` with `train_data_paths`, not `data_path`" assert self.train_data_paths and not self.data_path, err_str # if a sample input file is provided, default text_gen_type type to input-file @@ -1191,7 +1193,9 @@ def validate_values(self): return False # Checks. - if self.hidden_size % self.num_attention_heads != 0: + if self.hidden_size % self.num_attention_heads != 0 and not ( + "mamba" in self.attention_config + ): error_message = ( self.__class__.__name__ + ".validate_values() hidden_size must be divisible by num_attention_heads" @@ -1234,7 +1238,6 @@ def validate_values(self): # Parameters sharing does not work with torch DDP. if (self.num_unique_layers is not None) and (self.num_layers is not None): - if not (self.num_unique_layers <= self.num_layers): error_message = ( self.__class__.__name__ diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py index febefb3c2..b5e7a619d 100644 --- a/megatron/neox_arguments/neox_args.py +++ b/megatron/neox_arguments/neox_args.py @@ -85,6 +85,13 @@ class NeoXArgsParallelism(NeoXArgsTemplate): according to pipeline parallel size. """ + sequence_parallel: bool = False + """ + flag to determine whether Megatron-style Sequence Parallelism (https://arxiv.org/abs/2205.05198) + (Layernorm inputs and activations are sharded across model parallel group) will be used. Has no effect when model_parallel_size is 1. + **Set by user, in contrast to neox_args.is_pipe_parallel.** + """ + expert_interval: int = 2 """ Have one MoE layer every expert_interval layers @@ -114,9 +121,12 @@ class NeoXArgsModel(NeoXArgsTemplate): intermediate_size: int = None """ - Transformer intermediate size. Currently only used for "mlp_type": "llama". + Transformer intermediate size. Default = 4h + """ - If not passed, will be set to a reasonable default. + expansion_factor: float = None + """ + Transformer intermediate size. Default = 4 """ num_attention_heads: int = None @@ -152,9 +162,11 @@ class NeoXArgsModel(NeoXArgsTemplate): Maximum number of position embeddings to use. This is the size of position embedding. """ - norm: Literal["layernorm", "rmsnorm", "scalenorm"] = "layernorm" + norm: Literal[ + "layernorm", "rmsnorm", "scalenorm", "te_rmsnorm", "te_layernorm" + ] = "layernorm" """ - Normalization layer to use. Choose from "layernorm", "rmsnorm", "scalenorm". + Normalization layer to use. Choose from "layernorm", "rmsnorm", "scalenorm", "te_rmsnorm", "te_layernorm". """ layernorm_fusion: bool = False @@ -162,6 +174,11 @@ class NeoXArgsModel(NeoXArgsTemplate): Use fused layer norm kernel (if `norm` is `layernorm`). """ + rmsnorm_fusion: bool = False + """ + Use fused RMS norm kernel (if `norm` is `rmsnorm`). 
+ """ + use_qk_layernorm: bool = False """ Use QK Normalization @@ -271,10 +288,20 @@ class NeoXArgsModel(NeoXArgsTemplate): """ activation: Literal[ - "gelu", "geglu", "relu", "softsign", "swish", "mish", "silu" + "gelu", + "geglu", + "relu", + "softsign", + "swish", + "mish", + "silu", + "reglu", + "swiglu", + "bilinear", + "glu", ] = "gelu" """ - Activation function to use - choose from ["gelu", "geglu", "relu", "softsign", "swish", "mish", "silu"] + Activation function to use - choose from ["gelu", "geglu", "relu", "softsign", "swish", "mish", "silu", "reglu", "swiglu", "bilinear", "glu"] """ scaled_upper_triang_masked_softmax_fusion: bool = False @@ -414,9 +441,9 @@ class NeoXArgsModel(NeoXArgsTemplate): mlp_type: str = "regular" """ + Currently, the only mlp_type is "regular." This behavior is currently deprecated. Types: regular: Megatron implementation - llama: LLaMA MLP (SiLU-gated MLP) """ soft_prompt_tuning: dict = None @@ -848,9 +875,9 @@ class NeoXArgsTraining(NeoXArgsTemplate): List of paths to train datasets. """ - label_data_paths: list = None + train_label_data_paths: list = None """ - List of paths to label datasets (not shifted by 1 yet!). + List of paths to train label datasets (not shifted by 1 yet!). """ test_data_paths: list = None @@ -858,11 +885,57 @@ class NeoXArgsTraining(NeoXArgsTemplate): List of paths to test datasets. """ + test_label_data_paths: list = None + """ + List of paths to test label datasets (not shifted by 1 yet!). + """ + valid_data_paths: list = None """ List of paths to validation datasets. """ + valid_label_data_paths: list = None + """ + List of paths to validation label datasets (not shifted by 1 yet!). + """ + + pos_train_data_paths: list = None + neg_train_data_paths: list = None + """ + List of paths to positive and negative training datasets. + """ + + pos_train_label_data_paths: list = None + neg_train_label_data_paths: list = None + """ + List of paths to positive and negative training label datasets (not shifted by 1 yet!). + """ + + pos_valid_data_paths: list = None + neg_valid_data_paths: list = None + """ + List of paths to positive and negative validation datasets. + """ + + pos_valid_label_data_paths: list = None + neg_valid_label_data_paths: list = None + """ + List of paths to positive and negative validation label datasets (not shifted by 1 yet!). + """ + + pos_test_data_paths: list = None + neg_test_data_paths: list = None + """ + List of paths to positive and negative test datasets. + """ + + pos_test_label_data_paths: list = None + neg_test_label_data_paths: list = None + """ + List of paths to positive and negative test label datasets (not shifted by 1 yet!). + """ + train_data_weights: list = None """ List of 'weights' that decide how often to sample from each training dataset when blending datasets. If None, defaults to equal weighting. @@ -912,6 +985,41 @@ class NeoXArgsTraining(NeoXArgsTemplate): Implementation of indexed datasets, can be one of "infer", "cached", or "mmap" """ + pack_impl: Literal["packed", "pack_until_overflow", "unpacked"] = "packed" + """ + Packing implementation, can be one of "packed", "pack_until_overflow", or "unpacked". 
+ + warning: pack_until_overflow is very naive and will likely have issues with pretraining scale datasets + """ + + dataset_impl: Literal["gpt2", "pairwise"] = "gpt2" + """ + Dataset implementation, can be one of "gpt2" or "pairwise" + """ + + train_impl: Literal["normal", "dpo"] = "normal" + """ + Training implementation, can be one of "normal" or "dpo" + """ + + dpo_fp32: bool = True + """ + Whether to cast logits to fp32 for DPO loss calculation. + """ + + dpo_beta: float = 0.1 + """ + Beta value for DPO + """ + + allow_chopped: bool = True + """ + WARNING: if your packing impl is packed, this is ignored. + + Allow chopped samples in the dataset. + (e.g if your sequence length is 1024 and you have a sample of length 1026, it will be chopped to 1024) + """ + mmap_warmup: bool = False """ Warm up mmap files. @@ -1200,7 +1308,12 @@ class NeoXArgsTextgen(NeoXArgsTemplate): text_gen_type: str = None """ How to generate text/sample the model. - Options: `unconditional`, `input-file`, `interactive` + Options: `unconditional`, `input-file`, `interactive`, `precompute` + """ + + precompute_model_name: str = None + """ + Model name to use for saving precomputed logprobs """ temperature: float = 0.0 diff --git a/megatron/text_generation_utils.py b/megatron/text_generation_utils.py index 7b7a390ab..02926c2c3 100644 --- a/megatron/text_generation_utils.py +++ b/megatron/text_generation_utils.py @@ -23,12 +23,15 @@ import time from typing import List, Union +import numpy as np import torch import torch.nn.functional as F from megatron import print_rank_0 from megatron import mpu from megatron.utils import get_ltor_masks_and_position_ids, is_mp_rank_0 +from megatron.data.indexed_dataset import make_builder, make_dataset +from megatron.mpu.mappings import gather_from_model_parallel_region def get_batch(neox_args, context_tokens: torch.Tensor): @@ -52,7 +55,9 @@ def get_batch(neox_args, context_tokens: torch.Tensor): return tokens, attention_mask, position_ids -def pad_batch(context_tokens: List[List[int]], pad_id: int, pad_len: int): +def pad_batch( + context_tokens: List[List[int]], pad_id: int, pad_len: int, truncate: bool = False +): """ pads context lengths in context_tokens with pad_id to equal neox_args.seq_length, and returns the padded batch and the new lengths. 
@@ -60,17 +65,21 @@ def pad_batch(context_tokens: List[List[int]], pad_id: int, pad_len: int): context_tokens: list of lists of tokens pad_id: int, integer to use as padding token pad_len: int, context length to be padded; all batch items will be padded to the same length + truncate: bool, if True, truncate context tokens to pad_len if they are longer than pad_len returns: tuple of padded context tokens and a list of unpadded token count """ context_lengths = [] - for tokens in context_tokens: + for i, tokens in enumerate(context_tokens): context_length = len(tokens) if context_length < pad_len: tokens.extend([pad_id] * (pad_len - context_length)) elif context_length > pad_len: - raise ValueError("context_length is bigger than to be padded length") + if not truncate: + raise ValueError("context_length is bigger than to be padded length") + context_tokens[i] = tokens[:pad_len] + context_length = pad_len context_lengths.append(context_length) return context_tokens, context_lengths @@ -807,3 +816,180 @@ def generate_samples_interactive( print_rank_0("Generated Text: " + generated_text) if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0: _ = input("\n") + + +def get_logp(logits, labels, force_fp32=False): + if force_fp32: + logits = logits.float() + logp = logits.log_softmax(dim=-1) + return torch.gather(logp, dim=2, index=labels.unsqueeze(2)).squeeze(2) + + +def precompute_logits(neox_args, model): + """ + Precomputes logprobs from training/testing/validation datasets + + Saves it to the same directory as the dataset with the model name appended to it + + neox_args: NeoXArgs. + model: a Megatron model + + """ + if neox_args.precompute_model_name is None: + mdl_name = str(hash(neox_args.load)) + else: + mdl_name = neox_args.precompute_model_name + print_rank_0("Precomputing logprobs...") + model.eval() + data_paths = list() + if neox_args.train_data_paths is not None: + for path in neox_args.train_data_paths: + data_paths.append(path) + for path in neox_args.test_data_paths: + data_paths.append(path) + for path in neox_args.valid_data_paths: + data_paths.append(path) + elif neox_args.pos_train_data_paths is not None: + # Pairwise data... + for path in neox_args.pos_train_data_paths: + data_paths.append(path) + for path in neox_args.neg_train_data_paths: + data_paths.append(path) + for path in neox_args.pos_valid_data_paths: + data_paths.append(path) + for path in neox_args.neg_valid_data_paths: + data_paths.append(path) + for path in neox_args.pos_test_data_paths: + data_paths.append(path) + for path in neox_args.neg_test_data_paths: + data_paths.append(path) + for path in data_paths: + print_rank_0(f"Precomputing logits for {path}") + # Add hash to path... 
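On a toy batch, `get_logp` defined above reduces to the following (shapes are illustrative only).

```python
import torch

batch, seq, vocab = 2, 5, 11
logits = torch.randn(batch, seq, vocab)
labels = torch.randint(vocab, (batch, seq))

logp = logits.float().log_softmax(dim=-1)
per_token_logp = torch.gather(logp, dim=2, index=labels.unsqueeze(2)).squeeze(2)
print(per_token_logp.shape)  # torch.Size([2, 5]), one log-probability per label token
```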
+ out_path = path + f"_{mdl_name}" + if os.path.exists(out_path + ".idx"): + continue + dataset = make_dataset(path, neox_args.data_impl, not neox_args.mmap_warmup) + if is_mp_rank_0(): + out_dataset = make_builder(out_path + ".bin", neox_args.data_impl) + out_dataset._dtype = np.float32 + i = 0 + while i < len(dataset): + start = time.time() + model.module.clear_cache() # clear kv cache between batches + if is_mp_rank_0(): + offset = ( + mpu.get_data_parallel_rank() + * neox_args.train_micro_batch_size_per_gpu + ) + context_tokens = [ + [int(x) for x in dataset.get(j % len(dataset)).tolist()] + for j in range( + i + offset, + i + (neox_args.train_micro_batch_size_per_gpu + offset), + ) + ] + # grab microbatch + # pad batch in order to allow conversion to tensor + context_tokens, context_lengths = pad_batch( + copy.deepcopy(context_tokens), + pad_id=0, + pad_len=neox_args.seq_length + 1, + truncate=True, + ) + # print(context_tokens) + label_tokens = [tokens[1:] for tokens in context_tokens] + context_tokens = [tokens[:-1] for tokens in context_tokens] + else: + context_tokens = [ + [0 for _ in range(neox_args.seq_length)] + for _ in range(neox_args.batch_size) + ] + label_tokens = [ + [0 for _ in range(neox_args.seq_length)] + for _ in range(neox_args.batch_size) + ] + context_lengths = [0 for _ in range(neox_args.batch_size)] + i += ( + neox_args.train_micro_batch_size_per_gpu + * mpu.get_data_parallel_world_size() + ) + # print(context_tokens) + # convert to tensor and broadcast + context_tokens = torch.cuda.LongTensor(context_tokens) + label_tokens = torch.cuda.LongTensor(label_tokens) + # Make sure context tokens + start tokens are the same across all ranks + token_generation_start_index = torch.cuda.LongTensor(context_lengths) + torch.distributed.broadcast( + context_tokens, + mpu.get_model_parallel_src_rank(), + group=mpu.get_model_parallel_group(), + ) + torch.distributed.broadcast( + token_generation_start_index, + mpu.get_model_parallel_src_rank(), + group=mpu.get_model_parallel_group(), + ) + torch.distributed.broadcast( + label_tokens, + mpu.get_model_parallel_src_rank(), + group=mpu.get_model_parallel_group(), + ) + # context_tokens = context_tokens[:, :chop_len].contiguous() + # label_tokens = label_tokens[:, :chop_len].contiguous() + with torch.no_grad(): + # get attention mask / position ids + context_tokens, attention_mask, position_ids = get_batch( + neox_args, context_tokens + ) + model_inputs = ( + context_tokens, + position_ids, + attention_mask, + ) + maybe_tuple = forward_model( + model, model_inputs, neox_args.is_pipe_parallel + ) + if isinstance(maybe_tuple, tuple): + logits, _ = maybe_tuple + else: + logits = maybe_tuple + if logits is not None: # if pipe parallel, not all ranks return logits + logits = gather_from_model_parallel_region(logits) + logp = get_logp(logits, label_tokens, True).squeeze() + if neox_args.is_pipe_parallel: + # broadcast generated tokens to pipe parallel group + src_rank = model.grid.stage_to_global(model.num_stages - 1) + logp = ( + logp + if logits is not None + else torch.zeros( + neox_args.batch_size, dtype=torch.float32 + ).cuda() + ) + torch.distributed.broadcast( + tensor=logp, + src=src_rank, + group=mpu.get_pipe_parallel_group(), + ) + logp = logp.squeeze() + logp_list = [ + torch.zeros_like(logp) + for _ in range(mpu.get_data_parallel_world_size()) + ] + torch.distributed.all_gather( + logp_list, logp, group=mpu.get_data_parallel_group() + ) + logp = torch.cat(logp_list, dim=0).cpu().numpy() + if (mpu.get_model_parallel_rank() == 
0) and ( + mpu.get_data_parallel_rank() == 0 + ): + for j in range(logp.shape[0]): + out_dataset.add_item(logp[j]) + out_dataset.end_document() + print_rank_0(f"Processed {i} / {len(dataset)} in {time.time() - start}") + if is_mp_rank_0(): + out_dataset.finalize( + out_path + ".idx", + ) + torch.distributed.barrier() diff --git a/megatron/tokenizer/tokenizer.py b/megatron/tokenizer/tokenizer.py index 348c7cefe..d39e18243 100644 --- a/megatron/tokenizer/tokenizer.py +++ b/megatron/tokenizer/tokenizer.py @@ -32,6 +32,10 @@ def build_tokenizer(args): if args.rank == 0: print("> building {} tokenizer ...".format(args.tokenizer_type), flush=True) + assert ( + args.tokenizer_type is not None + ), "tokenizer_type must be specified in the .yml config" + # Select and instantiate the tokenizer. if args.tokenizer_type.lower() == "GPT2BPETokenizer".lower(): assert args.vocab_file is not None diff --git a/megatron/training.py b/megatron/training.py index 6a4e843ab..d9932483a 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -21,12 +21,14 @@ """Pretrain utilities.""" from datetime import datetime from functools import partial +from collections import defaultdict import math import sys from contextlib import nullcontext import torch +import torch.nn.functional as F import deepspeed from deepspeed.runtime.data_pipeline.curriculum_scheduler import CurriculumScheduler import numpy as np @@ -43,7 +45,9 @@ GPT2ModelPipe, SoftEmbedding, get_params_for_weight_decay_optimization, + mark_norms_for_sequence_parallel_grad_sync, ) +from megatron.mpu.mappings import gather_from_model_parallel_region from megatron.checkpointing import load_checkpoint, save_checkpoint from megatron.data.data_utils import build_train_valid_test_data_iterators from megatron.initialize import initialize_megatron @@ -136,7 +140,7 @@ def gen(): old_hidden_size = neox_args.hidden_size neox_args.hidden_size = hidden_size - model, optimizer, _ = setup_model_and_optimizer( + model, optimizer, _, _ = setup_model_and_optimizer( neox_args=neox_args, use_cache=False ) @@ -192,7 +196,7 @@ def pretrain(neox_args): # Model, optimizer, and learning rate. timers("model and optimizer").start() - model, optimizer, lr_scheduler = setup_model_and_optimizer( + model, optimizer, lr_scheduler, reference_model = setup_model_and_optimizer( neox_args=neox_args, use_cache=False, iteration=neox_args.iteration ) timers("model and optimizer").stop() @@ -230,6 +234,7 @@ def pretrain(neox_args): neox_args=neox_args, timers=timers, model=model, + reference_model=reference_model, optimizer=optimizer, lr_scheduler=lr_scheduler, train_data_iterator=train_data_iterator, @@ -277,16 +282,19 @@ def pretrain(neox_args): def _get_batch(neox_args, tokenizer, keys, data, datatype): """Support function for get_batch / get_batch pipe (to avoid code repetition)""" data_b = mpu.broadcast_data(keys, data, datatype) - + token_key = keys[0] + label_key = keys[1] if len(keys) > 1 else None # Unpack. 
- tokens_ = data_b["text"].long() - if "label" in data_b: + tokens_ = data_b[token_key].long() + if label_key in data_b: + label_mask = (data_b[label_key].long() >= 0)[:, 1:].contiguous() labels = torch.where( - data_b["label"].long() >= 0, - data_b["label"].long(), - torch.zeros_like(data_b["label"].long()), + data_b[label_key].long() >= 0, + data_b[label_key].long(), + torch.zeros_like(data_b[label_key].long()), )[:, 1:].contiguous() else: + label_mask = (tokens_.long() >= 0)[:, 1:].contiguous() labels = tokens_[:, 1:].contiguous() tokens = tokens_[:, :-1].contiguous() @@ -297,9 +305,9 @@ def _get_batch(neox_args, tokenizer, keys, data, datatype): eod_mask_loss=neox_args.eod_mask_loss, sliding_window_width=neox_args.sliding_window_width, ) - # If `label` is present, any token < 0 (e.g., -100, the default for torch) skips the loss computation - if "label" in data_b: - loss_mask = (data_b["label"][:, 1:] >= 0).to(loss_mask.dtype) + + # combine loss masks from get_ltor_masks_and_position_ids with loss masks from data + loss_mask = label_mask.to(loss_mask.dtype) * loss_mask return tokens, labels, loss_mask, attention_mask, position_ids @@ -307,7 +315,14 @@ def get_batch(neox_args, data_iterator): """Generate a batch""" # Items and their type. - keys = ["text", "label"] if neox_args.label_data_paths else ["text"] + if neox_args.train_impl == "normal": + keys = ["text", "label"] if neox_args.train_label_data_paths else ["text"] + elif neox_args.train_impl == "dpo": + keys = ( + [["pos", "pos_label"], ["neg", "neg_label"]] + if neox_args.pos_train_label_data_paths + else [["pos"], ["neg"]] + ) datatype = torch.int64 # Broadcast data. @@ -315,19 +330,49 @@ def get_batch(neox_args, data_iterator): data = next(data_iterator) else: data = None - return _get_batch( - neox_args=neox_args, - tokenizer=neox_args.tokenizer, - keys=keys, - data=data, - datatype=datatype, - ) + if neox_args.train_impl == "normal": + return _get_batch( + neox_args=neox_args, + tokenizer=neox_args.tokenizer, + keys=keys, + data=data, + datatype=datatype, + ) + elif neox_args.train_impl == "dpo": + pos_tup = _get_batch( + neox_args=neox_args, + tokenizer=neox_args.tokenizer, + keys=keys[0], + data=data, + datatype=datatype, + ) + neg_tup = _get_batch( + neox_args=neox_args, + tokenizer=neox_args.tokenizer, + keys=keys[1], + data=data, + datatype=datatype, + ) + if neox_args.precompute_model_name: + ref_data = mpu.broadcast_data(["pos_ref", "neg_ref"], data, torch.float) + else: + ref_data = {"pos_ref": None} + return [ + torch.cat((pos_item, neg_item), dim=0) + for pos_item, neg_item in zip(pos_tup, neg_tup) + ] + [ + torch.cat((ref_data["pos_ref"], ref_data["neg_ref"]), dim=0)[ + :, :-1 + ].contiguous() + if ref_data["pos_ref"] is not None + else None + ] def get_batch_pipe(data, neox_args, curr_scheduler=None): """A modification of get_batch() to work with the latest batch instead of an iterator.""" # Items and their type. - keys = ["text", "label"] if neox_args.label_data_paths else ["text"] + keys = ["text", "label"] if neox_args.train_label_data_paths else ["text"] datatype = torch.int64 tokens, labels, loss_mask, attention_mask, position_ids = _get_batch( @@ -415,8 +460,23 @@ def mb_moe_loss_func(args, loss_mask, output_tensor=None): return averaged_lbl, loss_dict +def get_pos_neg_logp(logits, labels, force_fp32=False): + if force_fp32: + logits = logits.float() + logp = logits.log_softmax(dim=-1) + per_token_logp = torch.gather(logp, dim=2, index=labels.unsqueeze(2)).squeeze(2) + # Split to pos/neg... 
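+    # get_batch concatenates the chosen ("pos") and rejected ("neg") sequences along dim 0, so chunking in half recovers the per-token log-probs for each half.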
+ return torch.chunk(per_token_logp, 2, 0) + + def forward_step( - data_iterator, model, neox_args, timers, return_logits=False, is_train=False + data_iterator, + model, + neox_args, + timers, + return_logits=False, + is_train=False, + reference_model=None, ): """Forward step.""" if neox_args.is_pipe_parallel: @@ -427,9 +487,14 @@ def forward_step( torch.cuda.nvtx.range_push(f"Get batch") if timers is not None: timers("batch generator").start() - tokens, labels, loss_mask, attention_mask, position_ids = get_batch( - neox_args=neox_args, data_iterator=data_iterator - ) + if neox_args.train_impl == "normal": + tokens, labels, loss_mask, attention_mask, position_ids = get_batch( + neox_args=neox_args, data_iterator=data_iterator + ) + if neox_args.train_impl == "dpo": + tokens, labels, loss_mask, attention_mask, position_ids, ref_logp = get_batch( + neox_args=neox_args, data_iterator=data_iterator + ) if timers is not None: timers("batch generator").stop() @@ -438,38 +503,100 @@ def forward_step( if neox_args.memory_profiling: torch.cuda.nvtx.range_push(f"Forward pass") - # Sequential returns moe_losses, but this is not yet supported by pipe parallel - maybe_tuple = model((tokens, position_ids, attention_mask), neox_args=neox_args) - if type(maybe_tuple) is tuple: - outputs, moe_losses = maybe_tuple - else: - outputs = maybe_tuple - moe_losses = [] - if ( - is_train - and neox_args.curriculum_learning - and neox_args.curriculum_seqlen < neox_args.seq_length - ): - loss_mask = loss_mask[:, : neox_args.curriculum_seqlen].contiguous() - labels = labels[:, : neox_args.curriculum_seqlen].contiguous() - main_loss = cross_entropy( - outputs, (labels, loss_mask), _fp16=neox_args.fp16_lm_cross_entropy - ) - if neox_args.moe_num_experts > 1: - if neox_args.moe_type == "deepspeed": - moe_loss = neox_args.moe_loss_coeff * sum(m.item() for m in moe_losses) - elif neox_args.moe_type == "megablocks": - moe_loss = mb_moe_loss_func(neox_args, loss_mask, outputs)[0] + metrics = {} + if neox_args.train_impl == "normal": + # Sequential returns moe_losses, but this is not yet supported by pipe parallel + maybe_tuple = model((tokens, position_ids, attention_mask), neox_args=neox_args) + if type(maybe_tuple) is tuple: + outputs, moe_losses = maybe_tuple else: - raise ValueError(f"Unsupported moe_type: {neox_args.moe_type}") - else: - moe_loss = 0.0 - loss = main_loss + moe_loss + outputs = maybe_tuple + moe_losses = [] + if ( + is_train + and neox_args.curriculum_learning + and neox_args.curriculum_seqlen < neox_args.seq_length + ): + loss_mask = loss_mask[:, : neox_args.curriculum_seqlen].contiguous() + labels = labels[:, : neox_args.curriculum_seqlen].contiguous() + main_loss = cross_entropy( + outputs, (labels, loss_mask), _fp16=neox_args.fp16_lm_cross_entropy + ) + if neox_args.moe_num_experts > 1: + if neox_args.moe_type == "deepspeed": + moe_loss = neox_args.moe_loss_coeff * sum(m.item() for m in moe_losses) + elif neox_args.moe_type == "megablocks": + moe_loss = mb_moe_loss_func(neox_args, loss_mask, outputs)[0] + else: + raise ValueError(f"Unsupported moe_type: {neox_args.moe_type}") + else: + moe_loss = 0.0 + loss = main_loss + moe_loss + elif neox_args.train_impl == "dpo": + # Based on https://github.com/eric-mitchell/direct-preference-optimization/blob/main/trainers.py#L90 + with torch.no_grad(): + # So we can gather token logps... 
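+            # Reference log-probs are computed under no_grad since only the policy model receives gradients;
+            # the loss computed at the end of this branch is the DPO objective: -logsigmoid(dpo_beta * ((chosen_pos - chosen_neg) - (ref_pos - ref_neg))).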
+ token_logp_labels = labels.clone() + token_logp_labels[token_logp_labels == -100] = 0 + pos_loss_mask, neg_loss_mask = torch.chunk(loss_mask, 2, 0) + if ref_logp is None: + ref_maybe_tuple = reference_model( + (tokens, position_ids, attention_mask), neox_args=neox_args + ) + if type(ref_maybe_tuple) is tuple: + # We should ignore MoE losses yeah? + ref_outputs, _ = ref_maybe_tuple + else: + ref_outputs = ref_maybe_tuple + # gather across tensor parallel group + ref_outputs = gather_from_model_parallel_region(ref_outputs) + ref_pos, ref_neg = get_pos_neg_logp( + ref_outputs, token_logp_labels, neox_args.dpo_fp32 + ) + else: + ref_pos, ref_neg = torch.chunk(ref_logp, 2, 0) + ref_pos = (ref_pos * pos_loss_mask).sum(-1) + ref_neg = (ref_neg * neg_loss_mask).sum(-1) + chosen_maybe_tuple = model( + (tokens, position_ids, attention_mask), neox_args=neox_args + ) + if type(chosen_maybe_tuple) is tuple: + # We should ignore MoE losses yeah? + chosen_outputs, _ = chosen_maybe_tuple + else: + chosen_outputs = chosen_maybe_tuple + chosen_outputs = gather_from_model_parallel_region(chosen_outputs) + chosen_pos, chosen_neg = get_pos_neg_logp( + chosen_outputs, token_logp_labels, neox_args.dpo_fp32 + ) + chosen_pos = (chosen_pos * pos_loss_mask).sum(-1) + chosen_neg = (chosen_neg * neg_loss_mask).sum(-1) + with torch.no_grad(): + # Collect metrics... + metrics["ref_neg"] = ref_neg.clone().detach().mean() + metrics["ref_pos"] = ref_pos.clone().detach().mean() + metrics["chosen_neg"] = chosen_neg.clone().detach().mean() + metrics["chosen_pos"] = chosen_pos.clone().detach().mean() + chosen_rewards = neox_args.dpo_beta * ( + chosen_pos.clone().detach() - ref_pos.clone().detach() + ) + rejected_rewards = neox_args.dpo_beta * ( + chosen_neg.clone().detach() - ref_neg.clone().detach() + ) + reward_acc = (chosen_rewards > rejected_rewards).float() + metrics["reward_acc"] = reward_acc.mean() + metrics["chosen_rewards"] = chosen_rewards.mean() + metrics["rejected_rewards"] = rejected_rewards.mean() + metrics["margins"] = (chosen_rewards - rejected_rewards).mean() + pi_logrations = chosen_pos - chosen_neg + ref_logrations = ref_pos - ref_neg + logits = pi_logrations - ref_logrations + loss = -F.logsigmoid(neox_args.dpo_beta * logits).mean() if neox_args.memory_profiling: torch.cuda.nvtx.range_pop() if return_logits: - return loss, outputs - return loss + return loss, outputs, metrics + return loss, metrics def get_model(neox_args, use_cache=False): @@ -544,9 +671,14 @@ def get_model(neox_args, use_cache=False): raise ValueError("Must be using deepspeed to run neox") -def get_optimizer(model, neox_args): +def get_optimizer(model, neox_args, dummy=False): """Set up the optimizer.""" - if neox_args.no_load_optim: + if neox_args.no_load_optim and neox_args.deepspeed: + # Required to have something so... + dummy = True + neox_args.optimizer = {"params": {"lr": 0.0}} + neox_args.optimizer_type = "adam" + elif neox_args.no_load_optim: return None, None if neox_args.optimizer is None: @@ -580,8 +712,13 @@ def get_optimizer(model, neox_args): _param_groups = [] for param_group in param_groups: trainable_params = [p for p in param_group["params"] if p.requires_grad] + if dummy: + trainable_params = [trainable_params[0]] # just take the first one param_group["params"] = trainable_params _param_groups.append(param_group) + if dummy: + # Only need one. 
+ break param_groups = _param_groups # If we're using mup, then the optimizer must be adam or sgd @@ -695,7 +832,7 @@ def get_optimizer(model, neox_args): def get_learning_rate_scheduler(optimizer, neox_args): """Build the learning rate scheduler.""" - if neox_args.no_load_optim: + if (neox_args.no_load_optim) and not neox_args.deepspeed: # TODO: this should be configured as a separate arg return None if neox_args.deepspeed and neox_args.optimizer_type.lower() == "onebitadam": @@ -740,19 +877,30 @@ def setup_model_and_optimizer(neox_args, use_cache=False, iteration=None): ) """Setup model and optimizer.""" + needs_reference_model = (neox_args.train_impl == "dpo") and ( + neox_args.precompute_model_name is None + ) model = get_model(neox_args=neox_args, use_cache=use_cache) + if needs_reference_model: + reference_model = get_model(neox_args=neox_args, use_cache=use_cache) + else: + reference_model = None optimizer, param_groups = get_optimizer(model=model, neox_args=neox_args) lr_scheduler = get_learning_rate_scheduler(optimizer=optimizer, neox_args=neox_args) - + if neox_args.deepspeed and needs_reference_model: + # Need an optimizer & lr_scheduler so make a very small one to keep deepspeed happy... + ref_optimizer, ref_param_groups = get_optimizer( + model=reference_model, neox_args=neox_args, dummy=True + ) + ref_lr_scheduler = get_learning_rate_scheduler( + optimizer=ref_optimizer, neox_args=neox_args + ) + else: + ref_optimizer, ref_param_groups, ref_lr_scheduler = None, None, None if neox_args.deepspeed: print_rank_0("DeepSpeed is enabled.") - if neox_args.no_load_optim: - assert optimizer is None - _model_params = None - _lr_scheduler = None - else: - _model_params = param_groups if optimizer is None else None - _lr_scheduler = lr_scheduler + _model_params = param_groups if optimizer is None else None + _lr_scheduler = lr_scheduler model, optimizer, _, lr_scheduler = deepspeed.initialize( model=model, @@ -765,6 +913,17 @@ def setup_model_and_optimizer(neox_args, use_cache=False, iteration=None): # config_params=neox_args.deepspeed_config, mpu=mpu if not neox_args.is_pipe_parallel else None, ) + if needs_reference_model: + reference_model, _, _, _ = deepspeed.initialize( + model=reference_model, + optimizer=ref_optimizer, + args=neox_args, + lr_scheduler=ref_lr_scheduler, + dist_init_required=False, + model_parameters=ref_param_groups, + mpu=mpu if not neox_args.is_pipe_parallel else None, + ) + mark_norms_for_sequence_parallel_grad_sync(model, neox_args) if neox_args.moe_num_experts > 1 and neox_args.moe_type == "megablocks": # We need to additionally set this flag to ensure DS parallelism properly handles this foreign MoE. 
model.has_moe_layers = True @@ -800,6 +959,14 @@ def setup_model_and_optimizer(neox_args, use_cache=False, iteration=None): lr_scheduler=lr_scheduler, iteration=iteration, ) + if needs_reference_model: + _ = load_checkpoint( + neox_args=neox_args, + model=reference_model, + optimizer=ref_optimizer, + lr_scheduler=ref_lr_scheduler, + iteration=iteration, + ) print_rank_0( f"Loading checkpoint and starting from iteration {neox_args.iteration}" ) @@ -811,7 +978,7 @@ def setup_model_and_optimizer(neox_args, use_cache=False, iteration=None): if lr_scheduler is not None: lr_scheduler.optimizer = model.optimizer - return model, optimizer, lr_scheduler + return model, optimizer, lr_scheduler, reference_model def backward_step(neox_args, timers, optimizer, model, loss): @@ -833,7 +1000,15 @@ def backward_step(neox_args, timers, optimizer, model, loss): raise ValueError("Must be using deepspeed to run neox") -def train_step(neox_args, timers, data_iterator, model, optimizer, lr_scheduler): +def train_step( + neox_args, + timers, + data_iterator, + model, + optimizer, + lr_scheduler, + reference_model=None, +): """Single training step.""" # Pipeline parallelism schedules forward/backward/step @@ -841,6 +1016,7 @@ def train_step(neox_args, timers, data_iterator, model, optimizer, lr_scheduler) reduced_loss = train_step_pipe( neox_args=neox_args, timers=timers, model=model, data_iterator=data_iterator ) + reduce_metrics = reduced_loss if ( neox_args.memory_profiling and neox_args.iteration >= neox_args.profile_step_start @@ -850,18 +1026,22 @@ def train_step(neox_args, timers, data_iterator, model, optimizer, lr_scheduler) save_snapshot(neox_args) else: losses = [] + metric_dicts = defaultdict(list) for _ in range(neox_args.gradient_accumulation_steps): # Forward model for one step. timers("forward").start() - loss = forward_step( + loss, metric_dict = forward_step( neox_args=neox_args, timers=timers, data_iterator=data_iterator, model=model, is_train=True, + reference_model=reference_model, ) timers("forward").stop() losses.append(loss) + for key in metric_dict.keys(): + metric_dicts[key].append(metric_dict[key]) # Calculate gradients, reduce across processes, and clip. 
if ( neox_args.profile @@ -891,6 +1071,7 @@ def train_step(neox_args, timers, data_iterator, model, optimizer, lr_scheduler) and neox_args.iteration <= neox_args.profile_step_stop ): torch.cuda.nvtx.range_push(f"Optimizer step") + timers("optimizer").start() if neox_args.deepspeed: model.step() @@ -910,17 +1091,19 @@ def train_step(neox_args, timers, data_iterator, model, optimizer, lr_scheduler) and torch.distributed.get_rank() == 0 ): save_snapshot(neox_args) - reduced_loss = { - "lm_loss": reduce_losses(losses).mean() - } # reduces losses across machines for logging + # reduces metrics across machines for logging + reduce_metrics = { + key: reduce_losses(metric_dicts[key]).mean() for key in metric_dicts.keys() + } + reduce_metrics["lm_loss"] = reduce_losses(losses).mean() if neox_args.precision == "fp16" and model.optimizer.overflow: skipped_iter = 1 else: skipped_iter = 0 - collect_loss_for_unit_test(reduced_loss["lm_loss"]) - return reduced_loss, skipped_iter + collect_loss_for_unit_test(reduce_metrics["lm_loss"]) + return reduce_metrics, skipped_iter def train_step_pipe(neox_args, timers, model, data_iterator): @@ -946,6 +1129,7 @@ def train( neox_args, timers, model, + reference_model, optimizer, lr_scheduler, train_data_iterator, @@ -970,7 +1154,28 @@ def train( # to monitor if we've skipped many iterations in a row and trigger an early exit overflow_monitor = OverflowMonitor(optimizer) + + if neox_args.profile: + schedule = torch.profiler.schedule( + wait=neox_args.profile_step_start, + warmup=1, + active=neox_args.profile_step_stop - neox_args.profile_step_start, + ) + prof = torch.profiler.profile( + schedule=schedule, + on_trace_ready=torch.profiler.tensorboard_trace_handler( + neox_args.tensorboard_dir + ), + record_shapes=True, + profile_memory=True, + with_flops=True, + with_modules=True, + with_stack=True, + ) + prof.start() while iteration < neox_args.train_iters: + if neox_args.profile: + prof.step() if neox_args.profile and iteration == neox_args.profile_step_start: torch.cuda.cudart().cudaProfilerStart() loss_dict, skipped_iter = train_step( @@ -980,9 +1185,11 @@ def train( model=model, optimizer=optimizer, lr_scheduler=lr_scheduler, + reference_model=reference_model, ) if neox_args.profile and iteration == neox_args.profile_step_stop: torch.cuda.cudart().cudaProfilerStop() + prof.stop() iteration += 1 neox_args.iteration = iteration if neox_args.precision == "fp16": @@ -1069,6 +1276,7 @@ def evaluate( # Turn on evaluation mode which disables dropout. 
model.eval() losses = [] + metric_dicts = defaultdict(list) if neox_args.char_level_ppl: data_iterator = CharCounter(data_iterator, neox_args.tokenizer) @@ -1090,14 +1298,15 @@ def evaluate( else neox_args.gradient_accumulation_steps ): # Forward evaluation - loss = forward_step_fn( + loss, metric_dict = forward_step_fn( model=model, data_iterator=data_iterator, neox_args=neox_args, timers=timers, ) losses.append(loss) - + for key in metric_dict.keys(): + metric_dicts[key].append(metric_dict[key]) # When contiguous memory optimizations are enabled, the buffers # allocated by the optimizations are deallocated during backward pass # in the absence of backward pass the buffers should be reset after each @@ -1107,6 +1316,8 @@ def evaluate( # reduces losses across processes for logging & run eval harness tasks eval_results = {"lm_loss": reduce_losses(losses).mean().item()} + for key in metric_dicts.keys(): + eval_results[key] = reduce_losses(metric_dicts[key]).mean().item() eval_results["lm_loss_ppl"] = math.exp(eval_results["lm_loss"]) if neox_args.char_level_ppl: diff --git a/megatron/utils.py b/megatron/utils.py index 26b4439bd..a64a8ba6c 100644 --- a/megatron/utils.py +++ b/megatron/utils.py @@ -449,7 +449,7 @@ def setup_for_inference_or_eval(use_cache=True, overwrite_values=None, input_arg initialize_megatron(neox_args) # set up model and load checkpoint. - model, _, _ = setup_model_and_optimizer( + model, _, _, _ = setup_model_and_optimizer( neox_args=neox_args, use_cache=use_cache, iteration=neox_args.iteration, diff --git a/requirements/requirements-transformerengine.txt b/requirements/requirements-transformerengine.txt new file mode 100644 index 000000000..2050d7566 --- /dev/null +++ b/requirements/requirements-transformerengine.txt @@ -0,0 +1 @@ +pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable diff --git a/requirements/requirements.txt b/requirements/requirements.txt index a051200b5..b5a84674b 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -1,11 +1,11 @@ -git+https://github.com/EleutherAI/DeeperSpeed.git@02e2ebf7dee6aaab3d89094ed470a4609763c742#egg=deepspeed +deepspeed@git+https://github.com/EleutherAI/DeeperSpeed.git@02e2ebf7dee6aaab3d89094ed470a4609763c742#egg=deepspeed ftfy>=6.0.1 -git+https://github.com/EleutherAI/lm_dataformat.git@4eec05349977071bf67fc072290b95e31c8dd836 huggingface_hub>=0.11.0 -jinja2==3.1.3 +jinja2==3.1.4 +lm_dataformat@git+https://github.com/EleutherAI/lm_dataformat.git@4eec05349977071bf67fc072290b95e31c8dd836 lm_eval>=0.4.0,<=0.4.1 mpi4py>=3.0.3 -numpy>=1.22.0 +numpy<2.0 pybind11>=2.6.2 regex sentencepiece diff --git a/tests/README.md b/tests/README.md index c1fac0f81..32618d757 100644 --- a/tests/README.md +++ b/tests/README.md @@ -33,7 +33,7 @@ pytest --forked tests/model/test_model_generation.py Some tests can run on cpu only. These are marked with the decorator @pytest.mark.cpu. The test cases for cpu can be run with: -```` +``` pytest tests -m cpu ``` @@ -50,3 +50,80 @@ if You see this kind of error: RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method ``` It usually means that you used some pytorch.cuda function before the test creates the processes. However just importing `from torch.utils import cpp_extension` can also trigger this. + + +## CPU Test Integration + +Tests can be run against physical CPUs through GitHub Actions. 
To have tests run on the physical CPU runner, the CI should generally be written as follows:
+
+### runs-on
+
+#### NOTE: These BKMs (best known methods) were written to work with CI infrastructure that is no longer in place. To use the GitHub runners (ubuntu-22.04 / ubuntu-latest), skip the 'runs-on' section.
+
+The CI needs to be written to target the CPU GitHub Actions runner. The jobs that need to run on CPU should use the hardware runner's labels:
+```yaml
+jobs:
+  cpu-test-job:
+    runs-on: [ 'self-hosted', 'aws', 'test'] # these labels tell GitHub to execute on the runner with the 'aws' and 'test' labels
+```
+
+### Software dependencies
+
+Hardware tests that need Python and Docker should install them as part of the test execution to make sure the tests run as expected:
+```yaml
+steps:
+  # sample syntax to set up Python with pip
+  - uses: actions/setup-python@v4
+    with:
+      python-version: "3.8"
+      cache: "pip"
+
+  # sample setup of Docker (there is no official Docker setup action)
+  - name: Docker setup
+    run: | # taken from Docker's installation page: https://docs.docker.com/engine/install/ubuntu/
+      # Add Docker's official GPG key:
+      sudo apt-get update
+      sudo apt-get install ca-certificates curl
+      sudo install -m 0755 -d /etc/apt/keyrings
+      sudo curl -fsSL https://download.docker.com/linux/ubuntu/gpg -o /etc/apt/keyrings/docker.asc
+      sudo chmod a+r /etc/apt/keyrings/docker.asc
+      # Add the repository to Apt sources:
+      echo \
+        "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.asc] https://download.docker.com/linux/ubuntu \
+        $(. /etc/os-release && echo "$VERSION_CODENAME") stable" | \
+        sudo tee /etc/apt/sources.list.d/docker.list > /dev/null
+      sudo apt-get update
+      sudo apt-get install docker-ce docker-ce-cli containerd.io docker-buildx-plugin docker-compose-plugin -y
+```
+
+Any other software dependencies should be assumed to be missing and should be installed as part of the CI.
+
+### Using Docker image
+
+Running the tests inside a container built from the Docker image is recommended, as it avoids environment issues. A modified docker-compose.yml intended for the CPU tests is provided in the tests/cpu_tests directory:
+
+```bash
+cp tests/cpu_tests/docker-compose.yml .
+# export any env variables here that should be used:
+export NEOX_DATA_PATH='./data/enwik8'
+docker compose run -d --build --name $CONTAINER gpt-neox tail -f /dev/null
+# then set up and run tests in the container using docker exec
+docker exec $CONTAINER pip install -r /workspace/requirements-dev.txt
+# etc.
+# please clean up the container as part of the CI:
+docker rm $CONTAINER
+```
+
+At the time of writing there is no built-in way to provide an offline-built Docker image to `jobs.<job_id>.container`.
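+One possible workaround (a sketch only; the image name and tarball path are placeholders, and it assumes the image already contains the repository and test dependencies) is to load an image built offline with `docker save` in a regular step and run the tests with plain `docker` commands instead of `jobs.<job_id>.container`:
+
+```bash
+# load an image tarball produced offline, e.g. with `docker save gpt-neox-ci:latest -o gpt-neox-ci.tar`
+docker load -i gpt-neox-ci.tar
+# run the CPU-only tests inside the pre-built image
+docker run --rm gpt-neox-ci:latest pytest tests -m cpu
+```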
+ +### Using existing CPU test CI + +There is an existing CPU test workflow that can be included in existing CI: + +```yaml +steps: + - name: Run CPU Tests + uses: + target_test_ref: $GITHUB_REF # replace with the ref/SHA that the tests should be run on + # have a look at the reusable workflow here: https://github.com/EleutherAI/gpt-neox/blob/main/tests/cpu_tests/action.yml +``` diff --git a/tests/model/test_fused_kernels.py b/tests/model/test_fused_kernels.py index cc458bf4a..125eb6c52 100644 --- a/tests/model/test_fused_kernels.py +++ b/tests/model/test_fused_kernels.py @@ -30,9 +30,7 @@ ) -@pytest.mark.xfail( - reason="ModuleNotFoundError: No module named 'scaled_masked_softmax_cuda'" -) +@pytest.mark.xfail(reason="SystemExit: None") def test_load_fused_kernels(): load() try: diff --git a/tools/ckpts/README.md b/tools/ckpts/README.md index 24d5cf31c..770cfb9c6 100644 --- a/tools/ckpts/README.md +++ b/tools/ckpts/README.md @@ -131,3 +131,20 @@ options: --num_output_shards NUM_OUTPUT_SHARDS --pipeline_parallel Only use if PP>1 ``` + +### `convert_hf_llama_to_neox.py` +Takes an HF Llama checkpoint and puts it into a NeoX-compatible format. + +Note that this does not support pipeline parallelism! + +``` +usage: convert_hf_llama_to_neox.py [-h] [--tp TP] [--pp PP] [--model MODEL] [--model_path MODEL_PATH] + +options: + -h, --help show this help message and exit + --tp TP Number of tensor parallelism ranks + --pp PP Number of pipeline parallelism stages + --model MODEL HF model name + --model_path MODEL_PATH + Path to save model +``` diff --git a/tools/ckpts/convert_hf_llama_to_neox.py b/tools/ckpts/convert_hf_llama_to_neox.py new file mode 100644 index 000000000..2adddb19d --- /dev/null +++ b/tools/ckpts/convert_hf_llama_to_neox.py @@ -0,0 +1,219 @@ +import torch +import argparse +from transformers import AutoTokenizer, AutoModelForCausalLM +import os +import tqdm + + +def convert_model(hf_state_dict, hf_config, tp_ranks): + conv_state_dicts = [{} for _ in range(tp_ranks)] + # get embeddings... + for i, chunk in enumerate( + torch.chunk(hf_state_dict["model.embed_tokens.weight"], tp_ranks, dim=0) + ): + conv_state_dicts[i][ + "sequential.0.word_embeddings.weight" + ] = chunk.clone().detach() + print( + "model.embed_tokens.weight", + hf_state_dict["model.embed_tokens.weight"].shape, + "sequential.0.word_embeddings.weight", + conv_state_dicts[0]["sequential.0.word_embeddings.weight"].shape, + ) + # Get config data... + num_kv_heads = hf_config.num_key_value_heads + num_q_heads = hf_config.num_attention_heads + head_dim = hf_config.hidden_size // num_q_heads + # do layers... + for layer_num in tqdm.tqdm(range(model.model.config.num_hidden_layers)): + # --- attention --- + # Output first since it's a simple row parallel... + for i, chunk in enumerate( + torch.chunk( + hf_state_dict[f"model.layers.{layer_num}.self_attn.o_proj.weight"], + tp_ranks, + dim=1, + ) + ): + conv_state_dicts[i][ + f"sequential.{layer_num+2}.attention.dense.weight" + ] = chunk.clone().detach() + print( + f"model.layers.{layer_num}.self_attn.o_proj.weight", + hf_state_dict[f"model.layers.{layer_num}.self_attn.o_proj.weight"].shape, + f"sequential.{layer_num+2}.attention.dense.weight", + conv_state_dicts[0][ + f"sequential.{layer_num+2}.attention.dense.weight" + ].shape, + ) + # Now for attention... + # Split into heads... 
+ q = hf_state_dict[f"model.layers.{layer_num}.self_attn.q_proj.weight"] + k = hf_state_dict[f"model.layers.{layer_num}.self_attn.k_proj.weight"] + v = hf_state_dict[f"model.layers.{layer_num}.self_attn.v_proj.weight"] + # The GQA code splits the heads by the num_q_heads so we also do that + # here to ensure it matches... + q = q.view(num_q_heads, -1, q.shape[-1]) + k = k.view(num_q_heads, -1, q.shape[-1]) + v = v.view(num_q_heads, -1, q.shape[-1]) + # Chunk for tensor parallelism... + for i, q_chunk, k_chunk, v_chunk in zip( + range(tp_ranks), + torch.chunk(q, tp_ranks, dim=0), + torch.chunk(k, tp_ranks, dim=0), + torch.chunk(v, tp_ranks, dim=0), + ): + # Need to join the heads across q, k, v... + conv_state_dicts[i][ + f"sequential.{layer_num+2}.attention.query_key_value.weight" + ] = ( + torch.cat([q_chunk, k_chunk, v_chunk], dim=1) + .view(-1, q.shape[-1]) + .clone() + .detach() + ) + print( + f"model.layers.{layer_num}.self_attn.(q/k/v)_proj.weight", + hf_state_dict[f"model.layers.{layer_num}.self_attn.q_proj.weight"].shape, + hf_state_dict[f"model.layers.{layer_num}.self_attn.k_proj.weight"].shape, + hf_state_dict[f"model.layers.{layer_num}.self_attn.v_proj.weight"].shape, + f"sequential.{layer_num+2}.attention.query_key_value.weight", + conv_state_dicts[0][ + f"sequential.{layer_num+2}.attention.query_key_value.weight" + ].shape, + ) + # --- mlp --- + # Do SwiGLU weights... + # w1... + for i, chunk in enumerate( + torch.chunk( + hf_state_dict[f"model.layers.{layer_num}.mlp.gate_proj.weight"], + tp_ranks, + dim=0, + ) + ): + conv_state_dicts[i][ + f"sequential.{layer_num+2}.mlp.w1.weight" + ] = chunk.clone().detach() + print( + f"model.layers.{layer_num}.mlp.gate_proj.weight", + hf_state_dict[f"model.layers.{layer_num}.mlp.gate_proj.weight"].shape, + f"sequential.{layer_num+2}.mlp.w1.weight", + conv_state_dicts[0][f"sequential.{layer_num+2}.mlp.w1.weight"].shape, + ) + # w3... + for i, chunk in enumerate( + torch.chunk( + hf_state_dict[f"model.layers.{layer_num}.mlp.up_proj.weight"], + tp_ranks, + dim=0, + ) + ): + conv_state_dicts[i][ + f"sequential.{layer_num+2}.mlp.w3.weight" + ] = chunk.clone().detach() + print( + f"model.layers.{layer_num}.mlp.up_proj.weight", + hf_state_dict[f"model.layers.{layer_num}.mlp.up_proj.weight"].shape, + f"sequential.{layer_num+2}.mlp.w3.weight", + conv_state_dicts[0][f"sequential.{layer_num+2}.mlp.w3.weight"].shape, + ) + # w2 (output)... + for i, chunk in enumerate( + torch.chunk( + hf_state_dict[f"model.layers.{layer_num}.mlp.down_proj.weight"], + tp_ranks, + dim=1, + ) + ): + conv_state_dicts[i][ + f"sequential.{layer_num+2}.mlp.w2.weight" + ] = chunk.clone().detach() + print( + f"model.layers.{layer_num}.mlp.down_proj.weight", + hf_state_dict[f"model.layers.{layer_num}.mlp.down_proj.weight"].shape, + f"sequential.{layer_num+2}.mlp.w2.weight", + conv_state_dicts[0][f"sequential.{layer_num+2}.mlp.w2.weight"].shape, + ) + # --- norm --- + for i in range(tp_ranks): + conv_state_dicts[i][f"sequential.{layer_num+2}.input_layernorm.scale"] = ( + hf_state_dict[f"model.layers.{layer_num}.input_layernorm.weight"] + .clone() + .detach() + ) + conv_state_dicts[i][ + f"sequential.{layer_num+2}.post_attention_layernorm.scale" + ] = ( + hf_state_dict[ + f"model.layers.{layer_num}.post_attention_layernorm.weight" + ] + .clone() + .detach() + ) + + # Get final ln/linear.... 
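+    # In the NeoX "sequential" naming used throughout this file, the embedding is sequential.0, transformer
+    # blocks are sequential.2 ... sequential.{num_hidden_layers+1}, the final norm is sequential.{num_hidden_layers+3},
+    # and the LM head is sequential.{num_hidden_layers+4}; hence the index arithmetic below.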
+ index = model.model.config.num_hidden_layers + 3 + for i in range(tp_ranks): + conv_state_dicts[i][f"sequential.{index}.norm.scale"] = ( + hf_state_dict["model.norm.weight"].clone().detach() + ) + index += 1 + # do output... + for i, chunk in enumerate( + torch.chunk(hf_state_dict["lm_head.weight"], tp_ranks, dim=0) + ): + conv_state_dicts[i][ + f"sequential.{index}.final_linear.weight" + ] = chunk.clone().detach() + print( + "lm_head.weight", + hf_state_dict["lm_head.weight"].shape, + f"sequential.{index}.final_linear.weight", + conv_state_dicts[0][f"sequential.{index}.final_linear.weight"].shape, + ) + return conv_state_dicts + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--tp", type=int, default=1, help="Number of tensor parallelism ranks" + ) + parser.add_argument( + "--pp", type=int, default=0, help="Number of pipeline parallelism stages" + ) + parser.add_argument("--model", type=str, default="gpt2", help="HF model name") + parser.add_argument( + "--model_path", type=str, default=None, help="Path to save model" + ) + args = parser.parse_args() + assert args.pp == 0, "Pipeline parallelism not supported yet" + tokenizer = AutoTokenizer.from_pretrained(args.model).save_pretrained( + args.model_path + "/tokenizer" + ) + model = AutoModelForCausalLM.from_pretrained(args.model, torch_dtype="auto") + state_dict = model.state_dict() + for key in state_dict.keys(): + print(key, state_dict[key].shape) + os.makedirs(args.model_path, exist_ok=True) + # Setup model directory... + os.makedirs(f"{args.model_path}/0", exist_ok=True) + # Save the latest file so neox can figure out where to grab the weights... + with open(f"{args.model_path}/latest", "w") as f: + f.write("0") + # Convert the model... + tp_state_dicts = convert_model(state_dict, model.model.config, args.tp) + for i in range(args.tp): + torch.save( + { + "dp_world_size": 1, + "mp_world_size": args.tp, + "optimizer": {}, + "global_steps": 1, + "skipped_steps": 1, + "iteration": 1, + "module": tp_state_dicts[i], + }, + f"{args.model_path}/0/mp_rank_{i:02d}_model_states.pt", + ) diff --git a/tools/ckpts/convert_hf_to_sequential.py b/tools/ckpts/convert_hf_to_sequential.py index c53f28391..55cfc6517 100644 --- a/tools/ckpts/convert_hf_to_sequential.py +++ b/tools/ckpts/convert_hf_to_sequential.py @@ -119,16 +119,27 @@ def shard_sequential_mp(num_mp_ranks, sequential): ranks = {x: dict() for x in range(num_mp_ranks)} for k, v in sequential.items(): if reduce( + np.logical_or, + [ + x in k + for x in [ + "dense_4h_to_h.bias", + "attention.dense.bias", + ] + ], + ): + # Divide by tp_size since they get added together + for x in range(num_mp_ranks): + ranks[x][k] = v / num_mp_ranks + elif reduce( np.logical_or, [ x in k for x in [ "layernorm", "rotary_emb", - "dense_4h_to_h.bias", "norm.weight", "norm.bias", - "attention.dense.bias", ] ], ): diff --git a/tools/ckpts/convert_neox_to_hf.py b/tools/ckpts/convert_neox_to_hf.py index 35812383e..f4e0ccf9f 100644 --- a/tools/ckpts/convert_neox_to_hf.py +++ b/tools/ckpts/convert_neox_to_hf.py @@ -580,30 +580,59 @@ def convert( # Load output embedding if not sequential: - loaded_tp_ranks = load_partitions( - input_checkpoint_path, - mp_partitions, - get_key(loaded_config, "num-layers") + 4, - sequential=sequential, - ) + if get_key(loaded_config, "no-weight-tying", False): + # if we have trained input + output embedding layers without tied weights + loaded_tp_ranks = load_partitions( + input_checkpoint_path, + mp_partitions, + get_key(loaded_config, 
"num-layers") + 4, + sequential=sequential, + ) + else: + # in this case, output embedding layer and input embedding layer are tied. + # load + save the input embed weights into the output embedding layer's place. + loaded_tp_ranks = load_partitions( + input_checkpoint_path, + mp_partitions, + layer_idx=0, + sequential=sequential, + ) # output embedding / LM head if architecture == "neox": # name of lm head / final linear proj varies lm_head = hf_model.embed_out else: lm_head = hf_model.lm_head - lm_head.load_state_dict( - { - "weight": torch.cat( - get_state( - loaded_tp_ranks, - "final_linear.weight", - layer_idx=get_key(loaded_config, "num-layers") + 4, - sequential=sequential, + + if get_key(loaded_config, "no-weight-tying", False): + # save the (untied) final linear into LM head for HF + lm_head.load_state_dict( + { + "weight": torch.cat( + get_state( + loaded_tp_ranks, + "final_linear.weight", + layer_idx=get_key(loaded_config, "num-layers") + 4, + sequential=sequential, + ), + dim=0, ), - dim=0, - ), - } - ) + } + ) + else: + # embedding layers are tied. transpose input layer and save + lm_head.load_state_dict( + { + "weight": torch.cat( + get_state( + loaded_tp_ranks, + "word_embeddings.weight", + layer_idx=0, + sequential=sequential, + ), + dim=0, + ), + } + ) del loaded_tp_ranks diff --git a/tools/datasets/README.md b/tools/datasets/README.md index f8215959c..af3009a23 100644 --- a/tools/datasets/README.md +++ b/tools/datasets/README.md @@ -93,6 +93,57 @@ output data: --dataset-impl {lazy,cached,mmap} Dataset implementation to use. Default: mmap +runtime: + --workers WORKERS Number of worker processes to launch + --log-interval LOG_INTERVAL + Interval between progress updates +``` +## `preprocess_data_with_chat_template.py` +Similar, but uses huggingface's [chat templates](https://huggingface.co/docs/transformers/main/en/chat_templating) to +tokenize the data to support multiturn and more complicated use cases. + +N.B. If using this, you **must** specify your data when training/finetuning with the following configs +```json +"train_data_paths": ["train_documents"], +"test_data_paths": ["test_documents"], +"valid_data_paths": ["test_documents"], +"label_data_paths": ["label_documents"] +``` + +the `"data_path"` option will not work with `"label_data_paths"`. + + +``` +usage: preprocess_data_with_chat_template.py [-h] --input INPUT [--jsonl-keys JSONL_KEYS [JSONL_KEYS ...]] [--no-mask] + [--generation-role GENERATION_ROLE] [--only-last] [--num-docs NUM_DOCS] + --tokenizer-path TOKENIZER_PATH [--ftfy] --output-prefix OUTPUT_PREFIX + [--dataset-impl {lazy,cached,mmap}] [--workers WORKERS] + [--log-interval LOG_INTERVAL] + +options: + -h, --help show this help message and exit + +input data: + --input INPUT Path to input jsonl files or lmd archive(s) - if using multiple archives, put them in a comma separated list + --jsonl-keys JSONL_KEYS [JSONL_KEYS ...] + space separate listed of keys to extract from jsonl. Default: text + --no-mask If set, this will not mask any tokens in the input data. + --generation-role GENERATION_ROLE + The role of the model generating the chat, usually 'assistant'. Default: assistant + --only-last If set, this will mask everything except the last turn in the chat. + --num-docs NUM_DOCS Optional: Number of documents in the input data (if known) for an accurate progress bar. + +tokenizer: + --tokenizer-path TOKENIZER_PATH + Path to HF Tokenizer. 
+ --ftfy Use ftfy to clean text + +output data: + --output-prefix OUTPUT_PREFIX + Path to binary output file without suffix + --dataset-impl {lazy,cached,mmap} + Dataset implementation to use. Default: mmap + runtime: --workers WORKERS Number of worker processes to launch --log-interval LOG_INTERVAL diff --git a/tools/datasets/preprocess_data_with_chat_template.py b/tools/datasets/preprocess_data_with_chat_template.py new file mode 100644 index 000000000..4e101ea5a --- /dev/null +++ b/tools/datasets/preprocess_data_with_chat_template.py @@ -0,0 +1,349 @@ +# Copyright (c) 2024, EleutherAI +# This file is based on code by the authors denoted below and has been modified from its original version. +# +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +A script for processing a dataset such that chat templates are utilized in the creation of the data. +These are then used to perform instruction/chat model finetunes (for example, finetuning a model on only the assistant +portions of a chatml dataset). + +This follows the same output format as 'preprocess_data_with_mask.py' but using chat templates to generate the data. +This way we can support multiturn chat data in the finetuning process. instead of relying on a single turn of data. + +To run this script, first edit `tools/datasets/corpora.py` such that the command to call + `tools/datasets/preprocess_data_with_chat_template.py` is as follows: + +``` +cmd = f"python tools/datasets/preprocess_data_with_with_chat_template.py \ + --input {jsonl_filepath} \ + --output-prefix {parent_folder}/{self.name} \ + --tokenizer-path {hf-tokenizer} \ + --jsonl-keys {jsonl_keys} \ + --dataset-impl mmap \ + --workers {self.num_workers} " + +if self.only_last: + cmd += f"--only-last " + +if self.no_mask: + cmd += f"--no-mask " +``` + +Then, specify +``` +"train_data_paths": ["/path/to/dataset/name_text_document"], +"label_data_paths": ["/path/to/dataset/name_label_document"] +``` +in your YML config. This will then allow for finetuning on the data with loss masks set appropriately. + +""" + +import argparse +import multiprocessing +import os +import sys + +import lm_dataformat as lmd +import numpy as np + +sys.path.append( + os.path.abspath( + os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir) + ) +) + +import time +import tqdm +import jsonlines + +from megatron.data import indexed_dataset +from threading import Semaphore +from typing import List, Dict, Tuple +from transformers import AutoTokenizer, PreTrainedTokenizer + + +def build_chat( + chat: List[Dict[str, str]], + generation_role: str, + apply_mask: bool, + tokenizer: PreTrainedTokenizer, + only_last_turn: bool = False, +) -> Tuple[List[int], List[int]]: + """ + Build a chat from a list of dictionaries. 
Each dictionary should have a "role" and "content" key, this follows the + Chat Template from https://huggingface.co/docs/transformers/main/en/chat_templating + + :param chat: A list of dictionaries with "role" and "content" keys + :param generation_role: The role of the model generating the chat, usually "assistant" + :param apply_mask: Whether to apply a loss mask to the chat, if False, all tokens will be included in the loss + :param tokenizer: A HF tokenizer + :param only_last_turn: Whether to only include the last turn in the chat, needed for some fine-tuning tasks + """ + tokens = [] + mask = [] + if apply_mask is False: + tokens = tokenizer.apply_chat_template(chat) + mask = tokens + return tokens, mask + for i, turn in enumerate(chat): + add_gen = ( + False if i == len(chat) - 1 else chat[i + 1]["role"] == generation_role + ) + chat_tokens = tokenizer.apply_chat_template( + chat[: i + 1], add_generation_prompt=add_gen + )[len(tokens) :] + + # remove previous stuff... + tokens.extend(chat_tokens) + if only_last_turn and (i != len(chat) - 1): + mask.extend([-100] * len(chat_tokens)) + elif apply_mask and (turn["role"] != generation_role): + mask.extend([-100] * len(chat_tokens)) + else: + mask.extend(chat_tokens) + if tokenizer.eos_token_id is not None: + mask.append(tokenizer.eos_token_id if mask[-1] != -100 else -100) + tokens.append(tokenizer.eos_token_id) + return tokens, mask + + +class Encoder(object): + def __init__(self, args): + self.args = args + + def initializer(self): + # Use Encoder class as a container for global data + Encoder.tokenizer = AutoTokenizer.from_pretrained(self.args.tokenizer_path) + + def encode(self, text): + ids = {} + for key in self.args.jsonl_keys: + text_ids, label_ids = build_chat( + text[key], + self.args.generation_role, + not self.args.no_mask, + Encoder.tokenizer, + self.args.only_last, + ) + ids[key] = (text_ids, label_ids) + return ids, len(text) + + +def get_args(): + parser = argparse.ArgumentParser() + group = parser.add_argument_group(title="input data") + group.add_argument( + "--input", + type=str, + required=True, + help="Path to input jsonl files or lmd archive(s) - if using multiple archives, put them in a comma separated " + "list", + ) + group.add_argument( + "--jsonl-keys", + nargs="+", + default=["conversation"], + help="space separate listed of keys to extract from jsonl. Default: text", + ) + group.add_argument( + "--no-mask", + help="If set, this will not mask any tokens in the input data.", + action="store_true", + ) + group.add_argument( + "--generation-role", + type=str, + default="assistant", + help="The role of the model generating the chat, usually 'assistant'. 
Default: assistant", + ) + group.add_argument( + "--only-last", + help="If set, this will mask everything except the last turn in the chat.", + action="store_true", + ) + group.add_argument( + "--num-docs", + default=None, + help="Optional: Number of documents in the input data (if known) for an accurate progress bar.", + type=int, + ) + group = parser.add_argument_group(title="tokenizer") + group.add_argument( + "--tokenizer-path", + type=str, + required=True, + help="Path to HF Tokenizer.", + ) + group.add_argument("--ftfy", action="store_true", help="Use ftfy to clean text") + group = parser.add_argument_group(title="output data") + group.add_argument( + "--output-prefix", + type=str, + required=True, + help="Path to binary output file without suffix", + ) + group.add_argument( + "--dataset-impl", + type=str, + default="mmap", + choices=["lazy", "cached", "mmap"], + help="Dataset implementation to use. Default: mmap", + ) + + group = parser.add_argument_group(title="runtime") + group.add_argument( + "--workers", type=int, default=1, help="Number of worker processes to launch" + ) + group.add_argument( + "--log-interval", + type=int, + default=100, + help="Interval between progress updates", + ) + args = parser.parse_args() + args.keep_empty = False + + # some default/dummy values for the tokenizer + args.rank = 0 + args.make_vocab_size_divisible_by = 128 + args.model_parallel_size = 1 + + return args + + +def yield_from_files(fnames: list, semaphore): + """ + Iterator over input documents using lm_dataformat. Should be able to handle jsons / texts / + other compressed formats. Also filters out empty documents. + + :param fnames: list of filenames + """ + + def yielder(fname, semaphore): + with open(fname, encoding="utf-8") as f: + reader = jsonlines.Reader(f) + for f in reader: + semaphore.acquire() + yield f + + for fname in fnames: + semaphore.acquire() + + yield from yielder(fname, semaphore) + + +def main(): + args = get_args() + encoder = Encoder(args) + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path) + print(f"Vocab size: {tokenizer.vocab_size}") + print(f"Output prefix: {args.output_prefix}") + + # build a semaphore object to stop `yield_from_files` from getting ahead of encoder.encode and + # hence building up memory + semaphore = Semaphore(10000 + args.workers) + + # use multiprocessing to iterate over input documents + fin = yield_from_files(args.input.split(","), semaphore) + + if args.workers > 1: + pool = multiprocessing.Pool(args.workers, initializer=encoder.initializer) + encoded_docs = pool.imap(encoder.encode, fin, chunksize=25) + else: + encoder.initializer() + encoded_docs = (encoder.encode(doc) for doc in fin) + + # make a dataset builder for each key in args.jsonl_keys + # each key will output to a different file beginning with args.output_prefix + output_bin_files = {} + output_idx_files = {} + builders = {} + for key in args.jsonl_keys: + output_bin_files[key] = "{}_{}_{}.bin".format( + args.output_prefix, key, "document" + ) + output_idx_files[key] = "{}_{}_{}.idx".format( + args.output_prefix, key, "document" + ) + builders[key] = indexed_dataset.make_builder( + output_bin_files[key], + impl=args.dataset_impl, + vocab_size=tokenizer.vocab_size, + ) + builders[key]._dtype = np.int32 + if not args.no_mask: + assert ( + key + "_label" not in args.jsonl_keys + ), "label should not be included as it will be generated according to the mask." 
+ key += "_label" + output_bin_files[key] = "{}_{}_{}.bin".format( + args.output_prefix, key, "document" + ) + output_idx_files[key] = "{}_{}_{}.idx".format( + args.output_prefix, key, "document" + ) + builders[key] = indexed_dataset.make_builder( + output_bin_files[key], + impl=args.dataset_impl, + vocab_size=tokenizer.vocab_size, + ) + builders[key]._dtype = np.int32 + + # actually do tokenization + proc_start = time.time() + total_bytes_processed = 0 + pbar = tqdm.tqdm() + for i, (doc, bytes_processed) in enumerate(encoded_docs, start=1): + total_bytes_processed += bytes_processed + + # release semaphore so `yield_from_files` can add another file to the buffer + semaphore.release() + + # add each tokenized document / sentence + for key, conv in doc.items(): + tokens = conv[0] + token_mask = conv[1] + builders[key].add_item(np.array(tokens, dtype=builders[key].dtype)) + builders[key + "_label"].add_item( + np.array(token_mask, dtype=builders[key + "_label"].dtype) + ) + # add indx... + builders[key].end_document() + builders[key + "_label"].end_document() + if i == 1: + print("key: ", key) + print("tokens: ", tokens) + print("token_mask: ", token_mask) + # log progress + if i % args.log_interval == 0: + current = time.time() + elapsed = current - proc_start + mbs = total_bytes_processed / elapsed / 1024 / 1024 + pbar.set_description( + f"Processed {i}{'' if args.num_docs is None else '/' + str(args.num_docs)} documents ({i / elapsed} docs/s, {mbs} MB/s)." + ) + if i != 0: + pbar.update(args.log_interval) + + # save output file + update_keys = args.jsonl_keys + for key in update_keys: + builders[key].finalize(output_idx_files[key]) + builders[key + "_label"].finalize(output_idx_files[key + "_label"]) + + +if __name__ == "__main__": + main()
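A usage note on `preprocess_data_with_chat_template.py` (added above): each input line is expected to be a JSON object whose `--jsonl-keys` field (default: `conversation`) holds a list of `{"role", "content"}` turns in the Hugging Face chat-template format. Below is a minimal sketch for producing such a file; the file name and chat contents are illustrative only:

```python
# Minimal sketch: write one chat per line in the format consumed by
# preprocess_data_with_chat_template.py (the key matches the default --jsonl-keys value).
import jsonlines

chats = [
    {
        "conversation": [
            {"role": "user", "content": "What is GPT-NeoX?"},
            {"role": "assistant", "content": "A framework for training large autoregressive language models."},
        ]
    }
]

with jsonlines.open("sample_chat_data.jsonl", mode="w") as writer:
    writer.write_all(chats)
```

The resulting file can then be passed to the script, e.g. `python tools/datasets/preprocess_data_with_chat_template.py --input sample_chat_data.jsonl --tokenizer-path <hf-tokenizer> --output-prefix <prefix>`.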