diff --git a/README.md b/README.md index 80162cef4..09e623401 100644 --- a/README.md +++ b/README.md @@ -286,7 +286,7 @@ Or use the 20B tokenizer (for which only a single Vocab file is needed): (alternatively, you can provide any tokenizer file that can be loaded by Hugging Face's tokenizers library with the `Tokenizer.from_pretrained()` command) -You can now pretokenize your data using `tools/preprocess_data.py`, the arguments for which are detailed below: +You can now pretokenize your data using `tools/datasets/preprocess_data.py`, the arguments for which are detailed below: ``` usage: preprocess_data.py [-h] --input INPUT [--jsonl-keys JSONL_KEYS [JSONL_KEYS ...]] [--num-docs NUM_DOCS] --tokenizer-type {HFGPT2Tokenizer,HFTokenizer,GPT2BPETokenizer,CharLevelTokenizer} [--vocab-file VOCAB_FILE] [--merge-file MERGE_FILE] [--append-eod] [--ftfy] --output-prefix OUTPUT_PREFIX @@ -327,7 +327,7 @@ runtime: For example: ```bash -python tools/preprocess_data.py \ +python tools/datasets/preprocess_data.py \ --input ./data/mydataset.jsonl.zst \ --output-prefix ./data/mydataset \ --vocab ./data/gpt2-vocab.json \ @@ -431,19 +431,19 @@ GPT-NeoX is optimized heavily for training only, and GPT-NeoX model checkpoints To convert a NeoX checkpoint (with pipeline-parallel-size>=1) to Hugging Face-loadable format, run: ```bash -python ./tools/convert_module_to_hf.py --input_dir /path/to/model/global_stepXXX --config_file your_config.yaml --output_dir hf_model/save/location +python ./tools/ckpts/convert_module_to_hf.py --input_dir /path/to/model/global_stepXXX --config_file your_config.yaml --output_dir hf_model/save/location ``` To convert a sequential model to Hugging Face format, run: ```bash -python ./tools/convert_sequential_to_hf.py --input_dir /path/to/model/global_stepXXX --config_file your_config.yaml --output_dir hf_model/save/location +python ./tools/ckpts/convert_sequential_to_hf.py --input_dir /path/to/model/global_stepXXX --config_file your_config.yaml --output_dir hf_model/save/location ``` (Note: this script should be used for v2.0 checkpoints saved on a v2.0 commit prior to https://github.com/EleutherAI/gpt-neox/pull/866 and which used `pipe-parallel-size=1`. Using `pipe-parallel-size=0` will also save models in this format.) Then to upload a model to [the Hugging Face Hub](https://huggingface.co/), run: ```bash huggingface-cli login -python ./tools/upload.py +python ./tools/ckpts/upload.py ``` and input the requested information, including HF hub user token. diff --git a/configs/neox_arguments.md b/configs/neox_arguments.md index 8084a6728..6ba7a58bf 100644 --- a/configs/neox_arguments.md +++ b/configs/neox_arguments.md @@ -111,7 +111,7 @@ Logging Arguments - **git_hash**: str - Default = fd35b00 + Default = a0cf0e8 current git hash of repository diff --git a/prepare_data.py b/prepare_data.py index cda75deff..86d8b5f5a 100644 --- a/prepare_data.py +++ b/prepare_data.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from tools.corpora import prepare_dataset, DATA_DOWNLOADERS +from tools.datasets.corpora import prepare_dataset, DATA_DOWNLOADERS import argparse TOKENIZER_CHOICES = [ diff --git a/tools/README.md b/tools/README.md new file mode 100644 index 000000000..ccfb1c5f8 --- /dev/null +++ b/tools/README.md @@ -0,0 +1,15 @@ +# GPT-NeoX Auxiliary Tools + +This directory contains a number of auxiliary tools that are useful for working with GPT-NeoX but not part of the main training code. + +## Bash + +This directory contains some simple, frequently used bash commands to make working on multiple machines easier. + +## Checkpoints + +This directory contains tools for manipulating and converting checkpoints including changing the parallelism settings of a pretrained model, converting between GPT-NeoX and the transformers library, and updating checkpoints trained with Version 1.x of this library to be compatible with Version 2.x. + +## Datasets + +This directory contains tools for downloading and preprocessing datasets to the format expected by the GPT-NeoX library. diff --git a/tools/bash/README.md b/tools/bash/README.md new file mode 100644 index 000000000..5b7ed897c --- /dev/null +++ b/tools/bash/README.md @@ -0,0 +1,8 @@ +# Bash Scripts +Useful for running distributed per-node scripts on e.g. Kubernetes + +* `kill.sh` kills all python processes +* `killall.sh` uses pdsh to kill all `train.py` processes on the nodes listed in `/job/hosts/` +* `sync_cmd.sh` uses pdsh to run a command on all the nodes listed in `/job/hosts/` +* `sync.sh` uses pdcp to copy every file in a provided path to all of the nodes listed in `/job/hosts/` +* `syncdir.sh` uses pdcp to copy every file in a provided path to all of the nodes listed in `/job/hosts/` diff --git a/tools/kill.sh b/tools/bash/kill.sh similarity index 100% rename from tools/kill.sh rename to tools/bash/kill.sh diff --git a/tools/killall.sh b/tools/bash/killall.sh similarity index 100% rename from tools/killall.sh rename to tools/bash/killall.sh diff --git a/tools/sync.sh b/tools/bash/sync.sh similarity index 100% rename from tools/sync.sh rename to tools/bash/sync.sh diff --git a/tools/sync_cmd.sh b/tools/bash/sync_cmd.sh similarity index 100% rename from tools/sync_cmd.sh rename to tools/bash/sync_cmd.sh diff --git a/tools/syncdir.sh b/tools/bash/syncdir.sh similarity index 96% rename from tools/syncdir.sh rename to tools/bash/syncdir.sh index f43edbd8c..210d142ae 100755 --- a/tools/syncdir.sh +++ b/tools/bash/syncdir.sh @@ -16,7 +16,7 @@ # Push files to all nodes # Usage -# sync.sh file [file2..] +# syncdir.sh file [file2..] echo Number of files to upload: $# diff --git a/tools/ckpts/README.md b/tools/ckpts/README.md new file mode 100644 index 000000000..778f51b0a --- /dev/null +++ b/tools/ckpts/README.md @@ -0,0 +1,133 @@ +# Checkpoint Scripts + + +## Utilities + +### `inspect_checkpoints.py` +Reports information about a saved checkpoint. +``` +usage: inspect_checkpoints.py [-h] [--attributes [ATTRIBUTES ...]] [--interactive] [--compare] [--diff] dir + +positional arguments: + dir The checkpoint dir to inspect. Must be either: - a directory containing pickle binaries saved with 'torch.save' ending in .pt or .ckpt - a single path to a .pt or .ckpt file - two comma separated directories - + in which case the script will *compare* the two checkpoints + +options: + -h, --help show this help message and exit + --attributes [ATTRIBUTES ...] + Name of one or several attributes to query. To access an attribute within a nested structure, use '/' as separator. + --interactive, -i Drops into interactive shell after printing the summary. + --compare, -c If true, script will compare two directories separated by commas + --diff, -d In compare mode, only print diffs +``` + +## HuggingFace Scripts + +### `convert_hf_to_sequential.py` +A script for converting publicly available Huggingface (HF) checkpoints NeoX format. + +Note that this script requires access to corresponding config files for equivalent NeoX models to those found in Hugging face. + +``` +Example usage: (Converts the 70M Pythia model to NeoX format) +================================================================ +OMPI_COMM_WORLD_RANK=0 CUDA_VISIBLE_DEVICES=0 python tools/ckpts/convert_hf_to_sequential.py \ + --hf-model-name pythia-70m-v0 \ + --revision 143000 \ + --output-dir checkpoints/neox_converted/pythia/70m \ + --cache-dir checkpoints/HF \ + --config configs/pythia/70M.yml configs/local_setup.yml \ + --test + + +For multi-gpu support we must initialize deepspeed: +NOTE: This requires manually changing the arguments below. +================================================================ +CUDA_VISIBLE_DEVICES=0,1,2,3 python ./deepy.py tools/ckpts/convert_hf_to_sequential.py \ + -d configs pythia/70M.yml local_setup.yml +``` +### `convert_module_to_hf.py` +Converts a NeoX model with pipeline parallelism greater than 1 to a HuggingFace transformers `GPTNeoXForCausalLM` model + +Note that this script does not support all NeoX features. +Please investigate carefully whether your model is compatible with all architectures supported by the GPTNeoXForCausalLM class in HF. + +(e.g. position embeddings such as AliBi may not be supported by Huggingface's GPT-NeoX architecture) + +``` +usage: convert_module_to_hf.py [-h] [--input_dir INPUT_DIR] [--config_file CONFIG_FILE] [--output_dir OUTPUT_DIR] [--upload] + +Merge MP partitions and convert to HF Model. + +options: + -h, --help show this help message and exit + --input_dir INPUT_DIR + Path to NeoX checkpoint, e.g. /path/to/model/global_step143000 + --config_file CONFIG_FILE + Path to config file for the input NeoX checkpoint. + --output_dir OUTPUT_DIR + Output dir, where to save the HF Model, tokenizer, and configs + --upload Set to true in order to upload to the HF Hub directly. +``` + +### `convert_sequential_to_hf.py` +Converts a NeoX model without pipeline parallelism to a HuggingFace transformers `GPTNeoXForCausalLM` model. + +``` +usage: convert_sequential_to_hf.py [-h] [--input_dir INPUT_DIR] [--config_file CONFIG_FILE] [--output_dir OUTPUT_DIR] [--upload] + +Merge MP partitions and convert to HF Model. + +options: + -h, --help show this help message and exit + --input_dir INPUT_DIR + Path to NeoX checkpoint, e.g. /path/to/model/global_step143000 + --config_file CONFIG_FILE + Path to config file for the input NeoX checkpoint. + --output_dir OUTPUT_DIR + Output dir, where to save the HF Model, tokenizer, and configs + --upload Set to true in order to upload to the HF Hub directly. +``` +### `upload.py` +Uploads a _converted_ checkpoint to the HuggingFace hub. + +``` +python upload.py +``` +## NeoX-20B Scripts + +### `merge20b.py` +Reduces model and pipeline parallelism of a 20B checkpoint to 1 and 1. + +``` +usage: merge20b.py [-h] [--input_dir INPUT_DIR] [--output_dir OUTPUT_DIR] + +Merge 20B checkpoint. + +options: + -h, --help show this help message and exit + --input_dir INPUT_DIR + Checkpoint dir, which should contain (e.g. a folder named "global_step150000") + --output_dir OUTPUT_DIR + Output dir, to save the 1-GPU weights configs +``` +## Llama Scripts + +### `convert_raw_llama_weights_to_neox.py` +Takes a Llama checkpoint and puts it into a NeoX-compatible format. + +``` +usage: convert_raw_llama_weights_to_neox.py [-h] [--input_dir INPUT_DIR] [--model_size {7B,13B,30B,65B,tokenizer_only}] [--output_dir OUTPUT_DIR] [--num_output_shards NUM_OUTPUT_SHARDS] [--pipeline_parallel] + +Convert raw LLaMA checkpoints to GPT-NeoX format. + +options: + -h, --help show this help message and exit + --input_dir INPUT_DIR + Location of LLaMA weights, which contains tokenizer.model and model folders + --model_size {7B,13B,30B,65B,tokenizer_only} + --output_dir OUTPUT_DIR + Location to write GPT-NeoX mode + --num_output_shards NUM_OUTPUT_SHARDS + --pipeline_parallel Only use if PP>1 +``` diff --git a/tools/convert_hf_to_sequential.py b/tools/ckpts/convert_hf_to_sequential.py similarity index 59% rename from tools/convert_hf_to_sequential.py rename to tools/ckpts/convert_hf_to_sequential.py index 4ed5c67f4..8a3902bce 100644 --- a/tools/convert_hf_to_sequential.py +++ b/tools/ckpts/convert_hf_to_sequential.py @@ -2,6 +2,7 @@ import os import copy import deepspeed + # import time import argparse @@ -13,13 +14,14 @@ from transformers import GPTNeoXForCausalLM, GPTNeoXConfig sys.path.append( - os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)) + os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir)) ) from megatron.neox_arguments import NeoXArgs from megatron.training import get_model, get_optimizer, get_learning_rate_scheduler from megatron.initialize import initialize_megatron from megatron import mpu from megatron.checkpointing import load_checkpoint, save_checkpoint + # from megatron.utils import ( # Timers, # init_wandb, @@ -28,11 +30,11 @@ """ A script for converting publicly available Huggingface (HF) checkpoints NeoX format. -Note that this script requires access to correspoinding config files for equivalent NeoX models to those found ing Hugging face. +Note that this script requires access to corresponding config files for equivalent NeoX models to those found in Hugging face. Example usage: (Converts the 70M Pythia model to NeoX format) ================================================================ -OMPI_COMM_WORLD_RANK=0 CUDA_VISIBLE_DEVICES=0 python tools/convert_hf_to_sequential.py \ +OMPI_COMM_WORLD_RANK=0 CUDA_VISIBLE_DEVICES=0 python tools/ckpts/convert_hf_to_sequential.py \ --hf-model-name pythia-70m-v0 \ --revision 143000 \ --output-dir checkpoints/neox_converted/pythia/70m \ @@ -41,39 +43,39 @@ --test -For multi-gpu support we must initiliaze deepspeed: +For multi-gpu support we must initialize deepspeed: NOTE: This requires manually changing the arguments below. ================================================================ -CUDA_VISIBLE_DEVICES=0,1,2,3 python ./deepy.py tools/convert_hf_to_sequential.py \ +CUDA_VISIBLE_DEVICES=0,1,2,3 python ./deepy.py tools/ckpts/convert_hf_to_sequential.py \ -d configs pythia/70M.yml local_setup.yml """ MULTI_GPU_ARGS = " ".join( - [ - "--hf-model-name pythia-70m-v0", - "--revision 143000", - "--output-dir checkpoints/neox_converted/pythia/70m", - "--cache-dir checkpoints/HF", - "--config configs/pythia/70M.yml configs/local_setup.yml", - "--test", - ] - ) - + [ + "--hf-model-name pythia-70m-v0", + "--revision 143000", + "--output-dir checkpoints/neox_converted/pythia/70m", + "--cache-dir checkpoints/HF", + "--config configs/pythia/70M.yml configs/local_setup.yml", + "--test", + ] +) -def convert_hf_to_sequential(hf_model,seq_state_dict): +def convert_hf_to_sequential(hf_model, seq_state_dict): """Converts the weights of a HuggingFace model to neox 2.0 format. - + :param hf_model: the huggingface model :param seq_state_dict: the state dict of the equivalent neox model - + returns the updated sequential state dict """ num_layers = hf_model.config.num_hidden_layers # Embedding is layer idx 0 - seq_state_dict['sequential.0.word_embeddings.weight'] =\ - hf_model.gpt_neox.embed_in.state_dict()['weight'] - + seq_state_dict[ + "sequential.0.word_embeddings.weight" + ] = hf_model.gpt_neox.embed_in.state_dict()["weight"] + for layer_hf in range(num_layers): # offset by 2 layer_seq = layer_hf + 2 @@ -81,29 +83,30 @@ def convert_hf_to_sequential(hf_model,seq_state_dict): # get layer from hf model hf_layer = hf_model.gpt_neox.layers[layer_hf] hf_layer_sd = hf_layer.state_dict() - - + for key in hf_model.gpt_neox.layers[0].state_dict().keys(): - - if key in ["attention.bias","attention.masked_bias"]: + + if key in ["attention.bias", "attention.masked_bias"]: continue seq_state_dict[f"sequential.{layer_seq}.{key}"] = hf_layer_sd[key] - + # Load final layer norm layer_seq = num_layers + 3 - seq_state_dict[f"sequential.{layer_seq}.norm.weight"] = \ - hf_model.gpt_neox.final_layer_norm.state_dict()['weight'] - seq_state_dict[f"sequential.{layer_seq}.norm.bias"] = \ - hf_model.gpt_neox.final_layer_norm.state_dict()['bias'] - + seq_state_dict[ + f"sequential.{layer_seq}.norm.weight" + ] = hf_model.gpt_neox.final_layer_norm.state_dict()["weight"] + seq_state_dict[ + f"sequential.{layer_seq}.norm.bias" + ] = hf_model.gpt_neox.final_layer_norm.state_dict()["bias"] + # output embedding / LM head layer_seq += 1 - seq_state_dict[f"sequential.{layer_seq}.final_linear.weight"] = \ - hf_model.embed_out.state_dict()['weight'] - + seq_state_dict[ + f"sequential.{layer_seq}.final_linear.weight" + ] = hf_model.embed_out.state_dict()["weight"] -def shard_sequential_mp(num_mp_ranks,sequential): +def shard_sequential_mp(num_mp_ranks, sequential): """Shards the sequential model into model parallel ranks. :param num_mp_ranks: the number of model parallel ranks @@ -111,114 +114,146 @@ def shard_sequential_mp(num_mp_ranks,sequential): returns a dict of state dicts for each mp rank """ - ranks = {x:dict() for x in range(num_mp_ranks)} - for k,v in sequential.items(): - if reduce(np.logical_or,[x in k for x in ['layernorm', - 'rotary_emb', - 'dense_4h_to_h.bias', - 'norm.weight', - 'norm.bias', - 'attention.dense.bias']]): + ranks = {x: dict() for x in range(num_mp_ranks)} + for k, v in sequential.items(): + if reduce( + np.logical_or, + [ + x in k + for x in [ + "layernorm", + "rotary_emb", + "dense_4h_to_h.bias", + "norm.weight", + "norm.bias", + "attention.dense.bias", + ] + ], + ): # no splitting for x in range(num_mp_ranks): ranks[x][k] = v else: if len(v.shape) == 1: size_per_rank = v.shape[0] / num_mp_ranks - if size_per_rank % 128 != 0.: + if size_per_rank % 128 != 0.0: padded_size = (128 - (size_per_rank % 128)) + size_per_rank size_diff = int((padded_size * 4) - v.shape[max_]) zero_pad = torch.zeros((size_diff)) - v = torch.cat([v,zero_pad],dim=max_) + v = torch.cat([v, zero_pad], dim=max_) else: padded_size = size_per_rank - assert size_per_rank % 1. == 0. - assert padded_size % 1. == 0. + assert size_per_rank % 1.0 == 0.0 + assert padded_size % 1.0 == 0.0 padded_size = int(padded_size) size_per_rank = int(size_per_rank) for x in range(num_mp_ranks): if size_per_rank != padded_size: - #need to pad - ranks[x][k] = v[padded_size * x : padded_size * (x+1)] + # need to pad + ranks[x][k] = v[padded_size * x : padded_size * (x + 1)] else: - ranks[x][k] = v[size_per_rank * x : size_per_rank * (x+1)] - + ranks[x][k] = v[size_per_rank * x : size_per_rank * (x + 1)] + elif len(v.shape) == 2: - if reduce(np.logical_or,[x in k for x in [ "attention.dense.weight", - "mlp.dense_4h_to_h.weight", ]]):\ - # column parallel + if reduce( + np.logical_or, + [ + x in k + for x in [ + "attention.dense.weight", + "mlp.dense_4h_to_h.weight", + ] + ], + ): # column parallel max_, min_ = 1, 0 - elif reduce(np.logical_or,[x in k for x in [ "mlp.dense_h_to_4h.weight", - "mlp.dense_h_to_4h.bias", - "attention.query_key_value.weight", - "attention.query_key_value.bias", - "word_embeddings.weight", - "final_linear.weight" ]]): + elif reduce( + np.logical_or, + [ + x in k + for x in [ + "mlp.dense_h_to_4h.weight", + "mlp.dense_h_to_4h.bias", + "attention.query_key_value.weight", + "attention.query_key_value.bias", + "word_embeddings.weight", + "final_linear.weight", + ] + ], + ): # row parallel max_, min_ = 0, 1 else: raise Exception("Unknown weight to shard: {}".format(k)) - size_per_rank = v.shape[max_] / num_mp_ranks - if size_per_rank % 128 != 0.: + if size_per_rank % 128 != 0.0: padded_size = (128 - (size_per_rank % 128)) + size_per_rank size_diff = int((padded_size * num_mp_ranks) - v.shape[max_]) - assert size_diff > 0, \ - "[ERROR] size diff is negative: {} for size_per_rank: {}, k:{}, shape:{}, padded_size:{}".format( - size_diff,size_per_rank,k,v.shape,padded_size) + assert ( + size_diff > 0 + ), "[ERROR] size diff is negative: {} for size_per_rank: {}, k:{}, shape:{}, padded_size:{}".format( + size_diff, size_per_rank, k, v.shape, padded_size + ) - zero_pad = torch.zeros((size_diff,v.shape[min_])) if max_ == 0 \ - else torch.zeros((v.shape[min_],size_diff)) + zero_pad = ( + torch.zeros((size_diff, v.shape[min_])) + if max_ == 0 + else torch.zeros((v.shape[min_], size_diff)) + ) - v = torch.cat([v,zero_pad],dim=max_) + v = torch.cat([v, zero_pad], dim=max_) else: padded_size = size_per_rank - assert size_per_rank % 1. == 0. - assert padded_size % 1. == 0. + assert size_per_rank % 1.0 == 0.0 + assert padded_size % 1.0 == 0.0 padded_size = int(padded_size) size_per_rank = int(size_per_rank) for x in range(num_mp_ranks): if size_per_rank != padded_size: - #need to pad - ranks[x][k] = v[padded_size * x : padded_size * (x+1),:] if max_ == 0 \ - else v[:,padded_size * x : padded_size * (x+1)] + # need to pad + ranks[x][k] = ( + v[padded_size * x : padded_size * (x + 1), :] + if max_ == 0 + else v[:, padded_size * x : padded_size * (x + 1)] + ) else: - ranks[x][k] = v[size_per_rank * x : size_per_rank * (x+1),...] if max_ == 0 \ - else v[:,size_per_rank * x : size_per_rank * (x+1)] - - else: + ranks[x][k] = ( + v[size_per_rank * x : size_per_rank * (x + 1), ...] + if max_ == 0 + else v[:, size_per_rank * x : size_per_rank * (x + 1)] + ) + + else: raise NotImplementedError() return ranks +def replace_sharded_seq(mp_checkpoints, mp_sharded_seq): + """replaces the values within checkpointed configs with those + from the sharded sequential object.""" -def replace_sharded_seq(mp_checkpoints,mp_sharded_seq): - """replaces the values within checkpointed configs with those - from the sharded sequential object.""" - for mp_idx, shard in mp_sharded_seq.items(): - mp_key = f'mp_rank_{mp_idx:02}_model_states.pt' - + mp_key = f"mp_rank_{mp_idx:02}_model_states.pt" + # use for loop instead of direct assignment # to check for compatibility - for k,v in mp_checkpoints[mp_key]['module'].items(): + for k, v in mp_checkpoints[mp_key]["module"].items(): try: - mp_checkpoints[mp_key]['module'][k] = shard[k] + mp_checkpoints[mp_key]["module"][k] = shard[k] except KeyError: print("ERROR key:{} not found in shard.".format(k)) -def shard_pp(sequential,mp_rank,num_layers): +def shard_pp(sequential, mp_rank, num_layers): """Shards the model into layers. :param sequential: the state dict of the sequential model at mp=1 @@ -227,29 +262,32 @@ def shard_pp(sequential,mp_rank,num_layers): returns a dict of state dicts for each layer """ suffix = f"-model_{mp_rank:02}-model_states.pt" - + layers_seq = dict() layers_seq[f"layer_00" + suffix] = { - "word_embeddings.weight" : sequential[f"sequential.0.word_embeddings.weight"] + "word_embeddings.weight": sequential[f"sequential.0.word_embeddings.weight"] } - layers_seq[f"layer_{num_layers+3:02}" + suffix] = { - "norm.weight" : sequential[f"sequential.{num_layers+3}.norm.weight"], - "norm.bias" : sequential[f"sequential.{num_layers+3}.norm.bias"], + layers_seq[f"layer_{num_layers+3:02}" + suffix] = { + "norm.weight": sequential[f"sequential.{num_layers+3}.norm.weight"], + "norm.bias": sequential[f"sequential.{num_layers+3}.norm.bias"], } - + layers_seq[f"layer_{num_layers+4:02}" + suffix] = { - "final_linear.weight" : sequential[f"sequential.{num_layers+4}.final_linear.weight"] + "final_linear.weight": sequential[ + f"sequential.{num_layers+4}.final_linear.weight" + ] } - - for layer in range(2,num_layers+2): + + for layer in range(2, num_layers + 2): layer_keys = [x for x in sequential if ".{}.".format(layer) in x] - layers_seq[f"layer_{layer:02}" + suffix] = \ - {k.split('.{}.'.format(layer))[1] : sequential[k] for k in layer_keys} + layers_seq[f"layer_{layer:02}" + suffix] = { + k.split(".{}.".format(layer))[1]: sequential[k] for k in layer_keys + } return layers_seq -def shard_pp_mp(num_mp_ranks,sequential,num_layers): +def shard_pp_mp(num_mp_ranks, sequential, num_layers): """Shards the model into layers and model parallel ranks. :param num_mp_ranks: the number of model parallel ranks @@ -258,20 +296,18 @@ def shard_pp_mp(num_mp_ranks,sequential,num_layers): returns a dict of state dicts for each layer for each model parallel rank """ - mp_sharded = shard_sequential_mp(num_mp_ranks=num_mp_ranks, - sequential=sequential) - + mp_sharded = shard_sequential_mp(num_mp_ranks=num_mp_ranks, sequential=sequential) + layers_pp_mp = {} for mp_rank, d in mp_sharded.items(): layers_pp_mp.update( - shard_pp(sequential=d, - mp_rank=mp_rank, - num_layers=num_layers) + shard_pp(sequential=d, mp_rank=mp_rank, num_layers=num_layers) ) return layers_pp_mp + def convert(hf_model, ckpt_dir, output_dir): - """Converts a huggingface model to a NeoX checkpoint for different + """Converts a huggingface model to a NeoX checkpoint for different model parallel and pipeline parallel settings degrees. :param hf_model: the huggingface model @@ -280,70 +316,76 @@ def convert(hf_model, ckpt_dir, output_dir): returns None """ - os.listdir(ckpt_dir) ckpts, layers = {}, {} for x in os.listdir(ckpt_dir): if x.startswith("mp_rank"): - ckpts[x] = torch.load(os.path.join(ckpt_dir,x)) + ckpts[x] = torch.load(os.path.join(ckpt_dir, x)) elif x.startswith("layer"): - layers[x] = torch.load(os.path.join(ckpt_dir,x)) + layers[x] = torch.load(os.path.join(ckpt_dir, x)) assert len(layers) + len(ckpts) > 0, "No checkpoints found in {}".format(ckpt_dir) os.makedirs(output_dir, exist_ok=True) seq_state_dict = dict() - convert_hf_to_sequential(hf_model,seq_state_dict) + convert_hf_to_sequential(hf_model, seq_state_dict) if len(ckpts) == 1 and len(layers) == 0: # pp=0, mp=1 key = list(ckpts.keys())[0] - ckpts[key]['module'] = seq_state_dict + ckpts[key]["module"] = seq_state_dict to_save = ckpts elif len(ckpts) > 1 and len(layers) == 0: # pp=0, mp>1 - sharded_seq = shard_sequential_mp(num_mp_ranks=len(ckpts),sequential=seq_state_dict) - replace_sharded_seq(mp_checkpoints=ckpts,mp_sharded_seq=sharded_seq) + sharded_seq = shard_sequential_mp( + num_mp_ranks=len(ckpts), sequential=seq_state_dict + ) + replace_sharded_seq(mp_checkpoints=ckpts, mp_sharded_seq=sharded_seq) to_save = ckpts elif len(ckpts) == 1 and len(layers) > 1: # pp>0, mp==1 - to_save = shard_pp(sequential=seq_state_dict, - mp_rank=0, - num_layers=hf_model.config.num_hidden_layers) + to_save = shard_pp( + sequential=seq_state_dict, + mp_rank=0, + num_layers=hf_model.config.num_hidden_layers, + ) elif len(ckpts) > 1 and len(layers) > 1: # pp>0, mp>1 - to_save = shard_pp_mp(num_mp_ranks=len(ckpts), - sequential=seq_state_dict, - num_layers=hf_model.config.num_hidden_layers) + to_save = shard_pp_mp( + num_mp_ranks=len(ckpts), + sequential=seq_state_dict, + num_layers=hf_model.config.num_hidden_layers, + ) else: - raise NotImplementedError("Not implemented for len(ckpts)={} and len(layers)={}".format( - len(ckpts),len(layers))) + raise NotImplementedError( + "Not implemented for len(ckpts)={} and len(layers)={}".format( + len(ckpts), len(layers) + ) + ) - for k,v in to_save.items(): - print("saving {}...".format(os.path.join(output_dir,k))) - torch.save(v,os.path.join(ckpt_dir,k)) + for k, v in to_save.items(): + print("saving {}...".format(os.path.join(output_dir, k))) + torch.save(v, os.path.join(ckpt_dir, k)) # copy the checkpoint to the output_dir print("rm {}/*".format(output_dir)) os.system("rm {}/*".format(output_dir)) - os.makedirs(output_dir,exist_ok=True) - print("cp {} {}".format(os.path.join(ckpt_dir,'*'),output_dir)) - os.system("cp {} {}".format(os.path.join(ckpt_dir,'*'),output_dir)) - + os.makedirs(output_dir, exist_ok=True) + print("cp {} {}".format(os.path.join(ckpt_dir, "*"), output_dir)) + os.system("cp {} {}".format(os.path.join(ckpt_dir, "*"), output_dir)) # set latest file within the output_dir - latest_file = os.path.join("/".join(output_dir.split("/")[:-1]),'latest') - os.system('rm '+latest_file) - with open(latest_file,'w') as f: + latest_file = os.path.join("/".join(output_dir.split("/")[:-1]), "latest") + os.system("rm " + latest_file) + with open(latest_file, "w") as f: f.write(output_dir.split("/")[-1]) - def consume_neox_args2(args_parsed, overwrite_values=None): """ Deepspeed launcher needs to pass the arguments for `pretrain_gpt2.py` across to all machines. @@ -364,6 +406,7 @@ def consume_neox_args2(args_parsed, overwrite_values=None): megatron_config.update(overwrite_values) return NeoXArgs.from_dict(args_dict=megatron_config) + def get_non_existing_dir(tmp_dir): while os.path.exists(tmp_dir): tmp_dir = os.path.join(tmp_dir, "tmp_dir") @@ -380,71 +423,79 @@ def get_non_existing_dir(tmp_dir): default=143000, help="Revision or step of the Pythia model to convert.", ) - parser.add_argument( + parser.add_argument( "--output-dir", type=str, help="Path to save the converted GPT-NeoX model checkpoint.", ) parser.add_argument( - "--config", + "--config", nargs="*", default=[], - help="Path to the config file for the equivalent NeoX model." + help="Path to the config file for the equivalent NeoX model.", ) parser.add_argument( "--test", action="store_true", - help="If set, will run a test to ensure the conversion was successful." + help="If set, will run a test to ensure the conversion was successful.", ) parser.add_argument( "--download-only", action="store_true", - help="If set, script will only download the model and not convert it." + help="If set, script will only download the model and not convert it.", ) parser.add_argument( "--ckpt-tmp-dir", default="/tmp/ckpt_tmp_dir", - help="Directory to store cached hugging face checkpoints. [WARNING: MUST BE VISIBLE TO ALL RANKS]" + help="Directory to store cached hugging face checkpoints. [WARNING: MUST BE VISIBLE TO ALL RANKS]", ) parser.add_argument( "--hf-model-name", type=str, - help="Name of the hugging face model to download from EleutherAI/{hf-model-name}.}" + help="Name of the hugging face model to download from EleutherAI/{hf-model-name}.}", ) parser.add_argument( "--cache-dir", default="/gpfs/alpine/csc499/proj-shared/hf_checkpoints", - help="Directory to store cached hugging face checkpoints." + help="Directory to store cached hugging face checkpoints.", ) try: - if int(os.environ['WORLD_SIZE']) > 1: + if int(os.environ["WORLD_SIZE"]) > 1: args = parser.parse_args(MULTI_GPU_ARGS.split(" ")) else: args = parser.parse_args() except KeyError: args = parser.parse_args() - tmp_cache_dir = get_non_existing_dir(args.ckpt_tmp_dir) if args.download_only: hf_model = GPTNeoXForCausalLM.from_pretrained( f"EleutherAI/{args.hf_model_name}", revision=f"step{args.revision}", - cache_dir=os.path.join(args.cache_dir,f"{args.hf_model_name}/step{args.revision}") + cache_dir=os.path.join( + args.cache_dir, f"{args.hf_model_name}/step{args.revision}" + ), ).half() exit(0) else: print("======================================================================") - print("Warning the following script will delete files withing {}".format(args.output_dir)) - print("Warning the following script will delete this directory {}".format(tmp_cache_dir)) + print( + "Warning the following script will delete files within {}".format( + args.output_dir + ) + ) + print( + "Warning the following script will delete this directory {}".format( + tmp_cache_dir + ) + ) print("======================================================================") # time.sleep(5) - - if int(os.environ.get('OMPI_COMM_WORLD_SIZE',1)) > 1: + if int(os.environ.get("OMPI_COMM_WORLD_SIZE", 1)) > 1: neox_args = consume_neox_args2(args2) else: neox_args = NeoXArgs.from_ymls(args.config) @@ -452,7 +503,6 @@ def get_non_existing_dir(tmp_dir): neox_args.build_tokenizer() neox_args.initialize_tensorboard_writer() - # setup logging and timers # init_wandb(neox_args=neox_args) # timers = Timers( @@ -477,8 +527,8 @@ def get_non_existing_dir(tmp_dir): mpu=mpu if not neox_args.is_pipe_parallel else None, ) - if os.environ['OMPI_COMM_WORLD_RANK'] == '0': - os.makedirs(f'{tmp_cache_dir}',exist_ok=True) + if os.environ["OMPI_COMM_WORLD_RANK"] == "0": + os.makedirs(f"{tmp_cache_dir}", exist_ok=True) torch.distributed.barrier() neox_args.save = tmp_cache_dir @@ -490,38 +540,42 @@ def get_non_existing_dir(tmp_dir): optimizer=optimizer, lr_scheduler=lr_scheduler, ) - print(os.listdir(f'{tmp_cache_dir}')) - ckpt_dir = os.path.join(tmp_cache_dir,'global_step0') - + print(os.listdir(f"{tmp_cache_dir}")) + ckpt_dir = os.path.join(tmp_cache_dir, "global_step0") if torch.distributed.get_rank() == 0: config = GPTNeoXConfig.from_pretrained( f"EleutherAI/{args.hf_model_name}", revision=f"step{args.revision}", - cache_dir=os.path.join(args.cache_dir,f"{args.hf_model_name}/step{args.revision}")) + cache_dir=os.path.join( + args.cache_dir, f"{args.hf_model_name}/step{args.revision}" + ), + ) # does not change the weights, but is needed to align logits - config.update({'hidden_act':'gelu_fast'}) + config.update({"hidden_act": "gelu_fast"}) hf_model = GPTNeoXForCausalLM.from_pretrained( f"EleutherAI/{args.hf_model_name}", revision=f"step{args.revision}", config=config, - cache_dir=os.path.join(args.cache_dir,f"{args.hf_model_name}/step{args.revision}") + cache_dir=os.path.join( + args.cache_dir, f"{args.hf_model_name}/step{args.revision}" + ), ).half() print("==========================================") print("Loaded Hugging Face model successfully!") print("==========================================") convert(hf_model, ckpt_dir=ckpt_dir, output_dir=args.output_dir) - if os.environ['OMPI_COMM_WORLD_RANK'] == '0': + if os.environ["OMPI_COMM_WORLD_RANK"] == "0": # cleanup temp dir os.system(f"rm -r {tmp_cache_dir}") torch.distributed.barrier() - #verify the conversion can be loaded + # verify the conversion can be loaded neox_args.load = "/".join(args.output_dir.split("/")[:-1]) print(neox_args.load) - neox_args.finetune=True + neox_args.finetune = True load_checkpoint( neox_args=neox_args, model=model, @@ -533,11 +587,9 @@ def get_non_existing_dir(tmp_dir): print("Converted checkpoint successfully loaded!") print("==========================================") - if args.test and torch.distributed.get_world_size() == 1: # only implemented for world size 1 - with torch.no_grad(): # torch.backends.cudnn.benchmark = False # torch.use_deterministic_algorithms(True) #setting the CUBLAS_WORKSPACE_CONFIG=:4096:8 environment variable is required for this to work (tested for A6000) @@ -547,27 +599,46 @@ def get_non_existing_dir(tmp_dir): b = 10 seq_len = 32 inputs = torch.randint(0, 50304, (b, seq_len), dtype=torch.long).cuda() - mask = (torch.triu(torch.ones(seq_len, seq_len)) != 1).transpose(0, 1).cuda() - pos_ids = torch.arange(0,seq_len).unsqueeze(0).cuda() + mask = ( + (torch.triu(torch.ones(seq_len, seq_len)) != 1).transpose(0, 1).cuda() + ) + pos_ids = torch.arange(0, seq_len).unsqueeze(0).cuda() torch.manual_seed(0) - outputs_neox = model.cuda()((inputs,pos_ids,mask.unsqueeze(0).unsqueeze(0)), neox_args=neox_args) + outputs_neox = model.cuda()( + (inputs, pos_ids, mask.unsqueeze(0).unsqueeze(0)), neox_args=neox_args + ) torch.manual_seed(0) outputs = hf_model.cuda()(input_ids=inputs) print("HF logits .sum(): ", outputs.logits.to(torch.float32).sum()) print("NeoX logits .sum(): ", outputs_neox.to(torch.float32).sum()) - - print("\nLogit comparison summary for {} sequences of length {}:".format(b,seq_len)) + + print( + "\nLogit comparison summary for {} sequences of length {}:".format( + b, seq_len + ) + ) print("=============================================================") for i in range(b): - abs_diff = (outputs.logits[i,...].to(torch.float32) - outputs_neox[i,...].to(torch.float32)).abs() - print("[Random sequence {}] (hflogits - neoxlogits).abs() -- mean: {:.5f}\tmax: {:.5f}\tmin: {:.5f}\tmedian: {:.5f}".format( - i,abs_diff.mean(),abs_diff.max(),abs_diff.min(),abs_diff.median())) + abs_diff = ( + outputs.logits[i, ...].to(torch.float32) + - outputs_neox[i, ...].to(torch.float32) + ).abs() + print( + "[Random sequence {}] (hflogits - neoxlogits).abs() -- mean: {:.5f}\tmax: {:.5f}\tmin: {:.5f}\tmedian: {:.5f}".format( + i, + abs_diff.mean(), + abs_diff.max(), + abs_diff.min(), + abs_diff.median(), + ) + ) elif args.test: - print("[INFO] Checkpoint conversion logit test not implemented for distributed world_size > 1. Current world_size: {}".format(torch.distributed.get_world_size())) - - - + print( + "[INFO] Checkpoint conversion logit test not implemented for distributed world_size > 1. Current world_size: {}".format( + torch.distributed.get_world_size() + ) + ) diff --git a/tools/convert_module_to_hf.py b/tools/ckpts/convert_module_to_hf.py similarity index 98% rename from tools/convert_module_to_hf.py rename to tools/ckpts/convert_module_to_hf.py index 2cbf390b9..f3f43c308 100644 --- a/tools/convert_module_to_hf.py +++ b/tools/ckpts/convert_module_to_hf.py @@ -25,7 +25,7 @@ sys.path.append( - os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)) + os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir)) ) from megatron.tokenizer import build_tokenizer @@ -153,15 +153,19 @@ def convert(input_checkpoint_path, loaded_config, output_checkpoint_path): try: # this conditional is quite messy because there were a number of ways to specify bf16 or fp16 training # in DeeperSpeed v1.0 . - if (fp16.get("fp16", None) or fp16["enabled"]) and not (fp16.get("type", None) == "bfloat16"): + if (fp16.get("fp16", None) or fp16["enabled"]) and not ( + fp16.get("type", None) == "bfloat16" + ): hf_model.half() print("Saving weights in fp16 precision...") elif fp16.get("type", None) == "bfloat16": hf_model.to(dtype=torch.bfloat16) print("Saving weights in bf16 precision...") except: - print("Model not trained in fp16 / bf16 mixed precision, saving weights in fp32...") - + print( + "Model not trained in fp16 / bf16 mixed precision, saving weights in fp32..." + ) + mp_partitions = get_key(loaded_config, "model-parallel-size") ### Embedding layer ### diff --git a/tools/convert_raw_llama_weights_to_neox.py b/tools/ckpts/convert_raw_llama_weights_to_neox.py similarity index 100% rename from tools/convert_raw_llama_weights_to_neox.py rename to tools/ckpts/convert_raw_llama_weights_to_neox.py diff --git a/tools/convert_sequential_to_hf.py b/tools/ckpts/convert_sequential_to_hf.py similarity index 98% rename from tools/convert_sequential_to_hf.py rename to tools/ckpts/convert_sequential_to_hf.py index cb2dc276c..f0a505ac3 100644 --- a/tools/convert_sequential_to_hf.py +++ b/tools/ckpts/convert_sequential_to_hf.py @@ -25,7 +25,7 @@ from typing import List sys.path.append( - os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)) + os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir)) ) from megatron.tokenizer import build_tokenizer @@ -153,9 +153,7 @@ def convert(input_checkpoint_path, loaded_config, output_checkpoint_path): hf_config = create_config(loaded_config) - hf_model = GPTNeoXForCausalLM( - hf_config - ) + hf_model = GPTNeoXForCausalLM(hf_config) # save model in FP16 if Deepspeed fp16 was used in config, else 32 bit fp16 = get_key(loaded_config, "fp16") @@ -175,7 +173,9 @@ def convert(input_checkpoint_path, loaded_config, output_checkpoint_path): hf_model.to(dtype=torch.bfloat16) print("Saving weights in bf16 precision...") except: - print("Model not trained in fp16 / bf16 mixed precision, saving weights in fp32...") + print( + "Model not trained in fp16 / bf16 mixed precision, saving weights in fp32..." + ) mp_partitions = get_key(loaded_config, "model-parallel-size") diff --git a/tools/inspect_checkpoints.py b/tools/ckpts/inspect_checkpoints.py similarity index 100% rename from tools/inspect_checkpoints.py rename to tools/ckpts/inspect_checkpoints.py diff --git a/tools/merge20b.py b/tools/ckpts/merge20b.py similarity index 100% rename from tools/merge20b.py rename to tools/ckpts/merge20b.py diff --git a/tools/upload.py b/tools/ckpts/upload.py similarity index 100% rename from tools/upload.py rename to tools/ckpts/upload.py diff --git a/tools/datasets/README.md b/tools/datasets/README.md new file mode 100644 index 000000000..0f4c382e4 --- /dev/null +++ b/tools/datasets/README.md @@ -0,0 +1,107 @@ +# Data Scripts + +## `preprocess_data.py` +Takes a raw dataset, splits it up, tokenizes it, and saves it as numpy files that can be memmapped and used efficiently by the training code. + +``` +usage: preprocess_data.py [-h] --input INPUT [--jsonl-keys JSONL_KEYS [JSONL_KEYS ...]] [--num-docs NUM_DOCS] + --tokenizer-type + {HFGPT2Tokenizer,HFTokenizer,GPT2BPETokenizer,CharLevelTokenizer,TiktokenTokenizer,SPMTokenizer} + [--vocab-file VOCAB_FILE] [--merge-file MERGE_FILE] [--append-eod] [--ftfy] --output-prefix + OUTPUT_PREFIX [--dataset-impl {lazy,cached,mmap}] [--workers WORKERS] + [--log-interval LOG_INTERVAL] + +options: + -h, --help show this help message and exit + +input data: + --input INPUT Path to input jsonl files or lmd archive(s) - if using multiple archives, put them in a comma + separated list + --jsonl-keys JSONL_KEYS [JSONL_KEYS ...] + space separate listed of keys to extract from jsonl. Defa + --num-docs NUM_DOCS Optional: Number of documents in the input data (if known) for an accurate progress bar. + +tokenizer: + --tokenizer-type {HFGPT2Tokenizer,HFTokenizer,GPT2BPETokenizer,CharLevelTokenizer,TiktokenTokenizer,SPMTokenizer} + What type of tokenizer to use. + --vocab-file VOCAB_FILE + Path to the vocab file + --merge-file MERGE_FILE + Path to the BPE merge file (if necessary). + --append-eod Append an token to the end of a document. + --ftfy Use ftfy to clean text + +output data: + --output-prefix OUTPUT_PREFIX + Path to binary output file without suffix + --dataset-impl {lazy,cached,mmap} + Dataset implementation to use. Default: mmap + +runtime: + --workers WORKERS Number of worker processes to launch + --log-interval LOG_INTERVAL + Interval between progress updates +``` +## `preprocess_data_with_mask.py` +Does the same but also creates `label` tensors if the dataset has labels. + +``` +usage: preprocess_data_with_mask.py [-h] --input INPUT [--jsonl-keys JSONL_KEYS [JSONL_KEYS ...]] + [--mask-before-token MASK_BEFORE_TOKEN] [--num-docs NUM_DOCS] --tokenizer-type + {HFGPT2Tokenizer,HFTokenizer,GPT2BPETokenizer,CharLevelTokenizer} + [--vocab-file VOCAB_FILE] [--merge-file MERGE_FILE] [--append-eod] [--ftfy] + --output-prefix OUTPUT_PREFIX [--dataset-impl {lazy,cached,mmap}] + [--workers WORKERS] [--log-interval LOG_INTERVAL] + +options: + -h, --help show this help message and exit + +input data: + --input INPUT Path to input jsonl files or lmd archive(s) - if using multiple archives, put them in a comma + separated list + --jsonl-keys JSONL_KEYS [JSONL_KEYS ...] + space separate listed of keys to extract from jsonl. Defa + --mask-before-token MASK_BEFORE_TOKEN + apply loss masks before certain token(s). If multi-token pattern, separate by commas without + space, e.g. --mask-before-token 0,1,1270 to use the token pattern [0,1,1270]. + --num-docs NUM_DOCS Optional: Number of documents in the input data (if known) for an accurate progress bar. + +tokenizer: + --tokenizer-type {HFGPT2Tokenizer,HFTokenizer,GPT2BPETokenizer,CharLevelTokenizer} + What type of tokenizer to use. + --vocab-file VOCAB_FILE + Path to the vocab file + --merge-file MERGE_FILE + Path to the BPE merge file (if necessary). + --append-eod Append an token to the end of a document. + --ftfy Use ftfy to clean text + +output data: + --output-prefix OUTPUT_PREFIX + Path to binary output file without suffix + --dataset-impl {lazy,cached,mmap} + Dataset implementation to use. Default: mmap + +runtime: + --workers WORKERS Number of worker processes to launch + --log-interval LOG_INTERVAL + Interval between progress updates +``` +## `multinode_prepare_data.sh` +Does the same but distributed over multiple nodes. + +``` +# USAGE: +# This script allows you to prepare your dataset using multiple nodes by chunking the individual files and distributed the chunks +# over the processes. +# This bash script takes a single text file as input argument. +# The text file contains a valid filepath in each line, leading to a jsonl-file. +# Furthermore an environment variable for the rank and the world size needs to be set. +# These default to the SLURM and OMPI variables in this order of priority, but they can be set manually as well +# using the variables $RANK and $WORLD_SIZE, which will overwrite the cluster-specific variables. +# You can also add all arguments of the prepare_data.py script to this script and it will simply pass them through. +``` + + +## `corpora.py` +Has information for common datasets. Primarily meant for use in top-level `prepare_data.py` script. diff --git a/tools/corpora.py b/tools/datasets/corpora.py similarity index 100% rename from tools/corpora.py rename to tools/datasets/corpora.py diff --git a/tools/merge_datasets.py b/tools/datasets/merge_datasets.py similarity index 98% rename from tools/merge_datasets.py rename to tools/datasets/merge_datasets.py index c5d1e6255..4239c5eb5 100644 --- a/tools/merge_datasets.py +++ b/tools/datasets/merge_datasets.py @@ -4,7 +4,7 @@ import argparse sys.path.append( - os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)) + os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir)) ) from megatron.data import indexed_dataset diff --git a/tools/multinode_prepare_data.sh b/tools/datasets/multinode_prepare_data.sh similarity index 95% rename from tools/multinode_prepare_data.sh rename to tools/datasets/multinode_prepare_data.sh index 1ff7244ae..87cb8ef31 100644 --- a/tools/multinode_prepare_data.sh +++ b/tools/datasets/multinode_prepare_data.sh @@ -53,7 +53,7 @@ fi echo "processing $chunk_file with rank $rank at world size $world_size" echo "using the following args: $py_args" # Call the Python script with the list of file paths in the chunk -python tools/preprocess_data.py --input $(tr '\n' ',' < "$chunk_file" | sed 's/,$/\n/') $py_args +python tools/datasets/preprocess_data.py --input $(tr '\n' ',' < "$chunk_file" | sed 's/,$/\n/') $py_args # Clean up rm "$chunk_file" diff --git a/tools/preprocess_data.py b/tools/datasets/preprocess_data.py similarity index 99% rename from tools/preprocess_data.py rename to tools/datasets/preprocess_data.py index 862620eb8..e780bec34 100644 --- a/tools/preprocess_data.py +++ b/tools/datasets/preprocess_data.py @@ -26,7 +26,7 @@ import numpy as np sys.path.append( - os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)) + os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir)) ) import time import tqdm diff --git a/tools/preprocess_data_with_mask.py b/tools/datasets/preprocess_data_with_mask.py similarity index 92% rename from tools/preprocess_data_with_mask.py rename to tools/datasets/preprocess_data_with_mask.py index 636e852ce..093d94b2f 100644 --- a/tools/preprocess_data_with_mask.py +++ b/tools/datasets/preprocess_data_with_mask.py @@ -16,14 +16,14 @@ # limitations under the License. """ -A script for processing a dataset such that corresponding labels are also produced. These are then used to perform masked finetuning -(for example, finetuning a model to only output the text following some delimiter in the finetuning dataset such as "Answer: " +A script for processing a dataset such that corresponding labels are also produced. These are then used to perform masked finetuning +(for example, finetuning a model to only output the text following some delimiter in the finetuning dataset such as "Answer: " rather than generating the entire "Question: ... Answer: " turns of conversation. -To run this script, first edit `tools/corpora.py` such that the command to call `tools/preprocess_data.py` is as follows: +To run this script, first edit `tools/datasets/corpora.py` such that the command to call `tools/datasets/preprocess_data.py` is as follows: ``` -cmd = f"python tools/preprocess_data_with_mask.py \ +cmd = f"python tools/datasets/preprocess_data_with_mask.py \ --input {jsonl_filepath} \ --output-prefix {parent_folder}/{self.name} \ --vocab {self.vocab_file} \ @@ -33,22 +33,22 @@ --append-eod \ --mask-before-token X,Y,Z \ --workers {self.num_workers} " - + if self.num_docs is not None: cmd += f"--num-docs {self.num_docs} " if self.ftfy: cmd += f"--ftfy " ``` -where --mask-before-token must be the (comma-separated) list of tokens produced by encoding your delimiter string. +where --mask-before-token must be the (comma-separated) list of tokens produced by encoding your delimiter string. Up to and including the first occurrence of this token sequence in a document, all tokens will have their loss mask zeroed out when the label dataset is provided to NeoX. -Then, specify +Then, specify ``` "train_data_paths": ["/path/to/dataset/name_text_document"], "label_data_paths": ["/path/to/dataset/name_label_document"] ``` -in your YML config. This will then allow for finetuning on the data with loss masks set appropriately. +in your YML config. This will then allow for finetuning on the data with loss masks set appropriately. (However, be warned that NeoX packs documents to fill context windows, which may degrade performance in some finetuning situations where instead padding out to the context length may be preferred.) """ @@ -62,7 +62,7 @@ import numpy as np sys.path.append( - os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)) + os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir)) ) import time import tqdm @@ -141,7 +141,7 @@ def encode(self, text): ids = {} for key in self.args.jsonl_keys: doc_ids = [] - text_ids = Encoder.tokenizer.tokenize(text['text']) + text_ids = Encoder.tokenizer.tokenize(text["text"]) if len(text_ids) > 0: doc_ids.append(text_ids) if self.args.append_eod: @@ -293,13 +293,15 @@ def main(): encoder.initializer() encoded_docs = (encoder.encode(doc) for doc in fin) - if args.mask_before_token is not None: - token_mask = [int(re.sub(r'[^0-9]', '', r)) for r in args.mask_before_token.split(",") if re.sub(r'[^0-9]', '', r)] + token_mask = [ + int(re.sub(r"[^0-9]", "", r)) + for r in args.mask_before_token.split(",") + if re.sub(r"[^0-9]", "", r) + ] else: token_mask = [] - # make a dataset builder for each key in args.jsonl_keys # each key will output to a different file beginning with args.output_prefix output_bin_files = {} @@ -318,7 +320,9 @@ def main(): vocab_size=tokenizer.vocab_size, ) if token_mask: - assert "label" not in args.jsonl_keys, "label should not be included as it will be generated according to the mask." + assert ( + "label" not in args.jsonl_keys + ), "label should not be included as it will be generated according to the mask." key = "label" output_bin_files[key] = "{}_{}_{}.bin".format( args.output_prefix, key, "document" @@ -335,7 +339,6 @@ def main(): for l in int32_labels: builders[l]._dtype = np.int32 - # actually do tokenization proc_start = time.time() total_bytes_processed = 0 @@ -350,15 +353,16 @@ def main(): for key, sentences in doc.items(): for sentence in sentences: builders[key].add_item(np.array(sentence, dtype=builders[key].dtype)) - if token_mask: + if token_mask: masked_sentence = mask(sentence, token_mask) - builders["label"].add_item(np.array(masked_sentence, dtype=builders["text"].dtype)) + builders["label"].add_item( + np.array(masked_sentence, dtype=builders["text"].dtype) + ) # separate with eos token builders[key].end_document() if token_mask: builders["label"].end_document() - # log progress if i % args.log_interval == 0: current = time.time() @@ -378,4 +382,3 @@ def main(): if __name__ == "__main__": main() - diff --git a/tools/merge_mp_partitions.py b/tools/merge_mp_partitions.py deleted file mode 100644 index 6509718ff..000000000 --- a/tools/merge_mp_partitions.py +++ /dev/null @@ -1,293 +0,0 @@ -# Copyright (c) 2021, EleutherAI -# This file is based on code by the authors denoted below and has been modified from its original version. -# -# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Merge model parallel partitions.""" - -import os -import sys - -sys.path.append( - os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)) -) - -import torch - -from megatron import mpu -from megatron.checkpointing import ensure_directory_exists -from megatron.checkpointing import get_checkpoint_name -from megatron.checkpointing import get_checkpoint_tracker_filename -from megatron.global_vars import rebuild_tokenizer -from megatron.global_vars import _parse_args - - -def split_into_partitions(tensor, num_partitions, partition_dim, stride): - - per_partition_size = mpu.utils.divide(tensor.size(partition_dim), num_partitions) - per_partition_per_stride_size = mpu.utils.divide(per_partition_size, stride) - - partitions_list = torch.split( - tensor, per_partition_per_stride_size, dim=partition_dim - ) - - partitions = [] - for i in range(num_partitions): - partition = torch.cat(partitions_list[i::num_partitions], dim=partition_dim) - partitions.append(partition) - - return partitions - - -def merge_partitions(merged, partitions, partition_dim, stride): - - # Number and size of each partition. - num_partitions = len(partitions) - per_partition_size = None - for partition in partitions: - if per_partition_size is None: - per_partition_size = partition.size(partition_dim) - else: - assert per_partition_size == partition.size(partition_dim) - - def concat_partitions(partitions_): - with torch.no_grad(): - if (per_partition_size * num_partitions) == merged.size(partition_dim): - torch.cat(partitions_, dim=partition_dim, out=merged) - else: - print( - " ***WARNING*** sizes do not match. Will cut " - "the merged partitions by {} along dimension {} " - "to reduce the size from {} to {} ...".format( - (per_partition_size * num_partitions) - - merged.size(partition_dim), - partition_dim, - per_partition_size * num_partitions, - merged.size(partition_dim), - ) - ) - merged_ = torch.cat(partitions_, dim=partition_dim) - merged_split = torch.split( - merged_, merged.size(partition_dim), dim=partition_dim - ) - merged_ = merged_split[0] - assert merged_.size(partition_dim) == merged.size(partition_dim) - merged.data.copy_(merged_.data) - - # If stride is 1, then do simple concatenation. - if stride == 1: - concat_partitions(partitions) - return - - # For none unity strides, first split based on stride and then group. - per_partition_per_stride_size = mpu.utils.divide(per_partition_size, stride) - # Chunk and build a list. - chunks = None - for i, partition in enumerate(partitions): - chunk = torch.split(partition, per_partition_per_stride_size, dim=partition_dim) - - if chunks is None: - chunks = [0] * (num_partitions * len(chunk)) - chunks[i::num_partitions] = chunk - - # Concatinate. - concat_partitions(chunks) - - return - - -def get_model(model_type): - - if model_type == "GPT2": - from pretrain_gpt2 import model_provider - else: - raise Exception("unrecognized model type: {}".format(model_type)) - - model = model_provider() - model = model.half() - - return model - - -def get_parallel_checkpoint_name(path): - - tracker_filename = get_checkpoint_tracker_filename(path) - iteration = 0 - with open(tracker_filename, "r") as f: - metastring = f.read().strip() - iteration = int(metastring) - assert iteration > 0 - checkpoint_name = get_checkpoint_name(path, iteration) - - return checkpoint_name, iteration - - -def test_split_merge(): - - print("testing split and merge ...") - - # [QKV.ROW-COL] - tensor = torch.FloatTensor( - [ - [1.11, 1.12, 1.13, 1.14, 1.15], - [1.21, 1.22, 1.23, 1.24, 1.25], - [1.31, 1.32, 1.33, 1.34, 1.35], - [1.41, 1.42, 1.43, 1.44, 1.45], - [2.11, 2.12, 2.13, 2.14, 2.15], - [2.21, 2.22, 2.23, 2.24, 2.25], - [2.31, 2.32, 2.33, 2.34, 2.35], - [2.41, 2.42, 2.43, 2.44, 2.45], - [3.11, 3.12, 3.13, 3.14, 3.15], - [3.21, 3.22, 3.23, 3.24, 3.25], - [3.31, 3.32, 3.33, 3.34, 3.35], - [3.41, 3.42, 3.43, 3.44, 3.45], - ] - ) - - num_partitions = 2 - partition_dim = 0 - stride = 3 - partitions = split_into_partitions(tensor, num_partitions, partition_dim, stride) - - merged = torch.zeros_like(tensor) - merge_partitions(merged, partitions, partition_dim, stride) - - max_error = (merged - tensor).abs().max() - print(" > max error (should be zero): {}".format(max_error)) - - -def get_mp_merge_args(parser): - """Provide extra arguments required for merging.""" - group = parser.add_argument_group(title="mp merge") - - group.add_argument( - "--model-type", - type=str, - required=True, - choices=["BERT", "GPT2", "RACE", "MNLI", "QQP"], - help="Type of the model.", - ) - - return parser - - -def main(): - - # Args - args = _parse_args(extra_args_provider=get_mp_merge_args) - model_type = args.model_type - orig_model_parallel_size = args.model_parallel_size - args.model_parallel_size = 1 - tokenizer = rebuild_tokenizer(args) - - print("\n merging model parallel partitions ...") - print(" > number of partitions: {}".format(orig_model_parallel_size)) - print(" > checkpoint path: {}".format(args.load)) - print(" > model parameters:") - print(" number of tokens ................ {} ".format(tokenizer.vocab_size)) - print(" number of layers ................ {}".format(args.num_layers)) - print(" hidden size ..................... {}".format(args.hidden_size)) - print(" number of attention heads ....... {}".format(args.num_attention_heads)) - print( - " maximum position embeddings ..... {}".format(args.max_position_embeddings) - ) - - # Full model. - print("> building the full model ...") - mpu.initialize.set_model_parallel_world_size(1) - mpu.initialize.set_model_parallel_rank(0) - merged_model = get_model(model_type) - - # Build and load partitions. - partitions = [] - iteration = 0 - args.model_parallel_size = orig_model_parallel_size - tokenizer = rebuild_tokenizer(args) - mpu.initialize.set_model_parallel_world_size(args.model_parallel_size) - for rank in range(args.model_parallel_size): - mpu.initialize.set_model_parallel_rank(rank) - checkpoint_name, iteration = get_parallel_checkpoint_name(args.load) - print("> loading {} ...".format(checkpoint_name)) - model_ = get_model(model_type) - sd = torch.load(checkpoint_name, map_location="cpu") - model_.load_state_dict(sd["model"]) - partitions.append(model_) - - # Parameter generators so we can loop through them semiltaneouly. - merged_params_gen = merged_model.named_parameters() - partitions_params_gen = [partition.named_parameters() for partition in partitions] - while True: - try: - - # Get the params and check names. - name, merged_param = next(merged_params_gen) - print(" > working on {} ...".format(name)) - print( - " merged type: {}, size: {}".format( - merged_param.dtype, list(merged_param.size()) - ) - ) - partitions_param = [] - for rank, partition_params_gen in enumerate(partitions_params_gen): - partition_name, partition_param = next(partition_params_gen) - assert partition_name == name - partitions_param.append(partition_param) - print( - " partition {} type: {}, size: {}".format( - rank, partition_param.dtype, list(partition_param.size()) - ) - ) - - # For the non-parallel parameters, simply copy the rank 0 values. - if not hasattr(merged_param, "model_parallel"): - print(" none-parallel parameter, simple copy from rank 0") - with torch.no_grad(): - merged_param.data.copy_(partitions_param[0].data) - # For parallel parameters, merge the values - else: - print( - " parallel parameter merge with stride {} along " - "dimension {}".format( - merged_param.stride, merged_param.partition_dim - ) - ) - merge_partitions( - merged_param, - partitions_param, - merged_param.partition_dim, - merged_param.stride, - ) - - except StopIteration: - break - - # Save the model. - args.model_parallel_size = 1 - mpu.initialize.set_model_parallel_rank(0) - sd = {} - sd["model"] = merged_model.state_dict() - sd["iteration"] = iteration - merged_path = os.path.join(args.load, "merged") - checkpoint_name = get_checkpoint_name(merged_path, iteration) - ensure_directory_exists(checkpoint_name) - print("> saving merged model to {}".format(checkpoint_name)) - torch.save(sd, checkpoint_name) - - print("done :-)") - - -if __name__ == "__main__": - - main()