diff --git a/tools/ckpts/convert_neox_to_hf.py b/tools/ckpts/convert_neox_to_hf.py
index 4fdfbc5bb..eb3396c59 100644
--- a/tools/ckpts/convert_neox_to_hf.py
+++ b/tools/ckpts/convert_neox_to_hf.py
@@ -20,9 +20,9 @@ from tqdm import tqdm
 
 import torch
-from transformers import GPTNeoXConfig, GPTNeoXForCausalLM
+from transformers import GPTNeoXConfig, AutoModelForCausalLM
 
-from typing import List
+from typing import List, Literal
 
 sys.path.append(
     os.path.abspath(
@@ -106,16 +106,16 @@ def load_partitions(
     input_checkpoint_path: str, mp_partitions: int, layer_idx: int, sequential: bool
 ) -> List[torch.Tensor]:
     """Returns a list containing all states from a model (across MP partitions)"""
-    if sequential is True:
-        filename_format = f"mp_rank_{i:02}_model_states.pt"
-    else if sequential is False:
-        filename_format = f"layer_{layer_idx:02}-model_{i:02}-model_states.pt",
+    if sequential:
+        filename_format = f"mp_rank_{{i:02}}_model_states.pt"
+    else:
+        filename_format = f"layer_{layer_idx:02}-model_{{i:02}}-model_states.pt"
 
     loaded_tp_ranks = [
         torch.load(
             os.path.join(
                 input_checkpoint_path,
-                filename_format,
+                filename_format.format(i=i),
             ),
             map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
         )
@@ -228,7 +228,7 @@ def convert(
     output_checkpoint_path,
     sequential: bool = True,
     precision: Literal["auto", "fp16", "bf16", "fp32"] = "auto",
-    architecture: Literal["neox", "llama", "mistral"],
+    architecture: Literal["neox", "llama", "mistral"] = "neox",
 ):
     """convert a NeoX checkpoint to a HF model format.
     should perform model-parallel merging correctly
@@ -236,15 +236,17 @@ def convert(
     """
     ARCH = MODEL_KEYS[architecture]
 
+
+
+    hf_config = create_config(loaded_config)
+
+    hf_model = AutoModelForCausalLM.from_config(hf_config)
+
     if architecture == "neox":
         hf_transformer = hf_model.gpt_neox
     else:
         hf_transformer = hf_model.model
 
-    hf_config = create_config(loaded_config)
-
-    hf_model = AutoModelForCausalLM(hf_config)
-
     if precision == "auto":
         print("Auto-detecting precision to save model into...")
         # save model in FP16 if Deepspeed fp16 was used in config, else 32 bit
@@ -284,18 +286,18 @@
     # for the pipeline-parallel case (pipeline-parallel-size >= 1),
     # we must load the correct layer's states at each step.
     # (this does mean that less memory is required for PP conversion.)
-    loaded_tp_ranks = load_partitions(input_checkpoint_path, mp_partitions, layer_idx=0)
+    loaded_tp_ranks = load_partitions(input_checkpoint_path, mp_partitions, layer_idx=0, sequential=sequential)
 
     ### Embedding layer ###
     # Embedding is layer idx 0
-    if architecture == "neox":
+    if architecture == "neox":
         embed_in = hf_transformer.embed_in
     else:
         embed_in = hf_transformer.embed_tokens
     embed_in.load_state_dict(  # TODO: embed_in is not always model's name for embedding
         {
             "weight": torch.cat(
-                get_state(loaded_tp_ranks, "word_embeddings.weight", 0), dim=0
+                get_state(loaded_tp_ranks, "word_embeddings.weight", layer_idx=0, sequential=sequential), dim=0
             )
         }
     )
@@ -312,7 +314,7 @@ def convert(
         if not sequential:
             # in the non-sequential case, must load from each layer individually.
            # use layer index + 2 bc of embed layer and a dummy _pre_transformer_block, which are "layers 0 and 1"
-            loaded_tp_ranks = load_partitions(input_checkpoint_path, mp_partitions, layer_idx=layer_i + 2)
+            loaded_tp_ranks = load_partitions(input_checkpoint_path, mp_partitions, layer_idx=layer_i + 2, sequential=sequential)
             # + 2 bc of embed layer and a dummy _pre_transformer_block
 
         state_dict = {}
@@ -354,9 +356,9 @@ def convert(
         # load state_dict into layer
         hf_layer.load_state_dict(state_dict)
 
-    if sequential:
+    if not sequential:
         loaded_tp_ranks = load_partitions(
-            input_checkpoint_path, mp_partitions, get_key(loaded_config, "num-layers") + 3
+            input_checkpoint_path, mp_partitions, get_key(loaded_config, "num-layers") + 3, sequential=sequential
         )
     # Load final layer norm
     if architecture == "neox":
@@ -384,9 +386,9 @@ def convert(
     )
 
     # Load output embedding
-    if sequential:
+    if not sequential:
         loaded_tp_ranks = load_partitions(
-            input_checkpoint_path, mp_partitions, get_key(loaded_config, "num-layers") + 4
+            input_checkpoint_path, mp_partitions, get_key(loaded_config, "num-layers") + 4, sequential=sequential
        )
     # output embedding / LM head
     if architecture == "neox":  # name of lm head / final linear proj varies
@@ -435,7 +437,7 @@ def main(input_args=None, overwrite_values=None):
     )
     parser.add_argument(
         "--precision",
-        type=Literal["auto", "fp16", "bf16", "fp32"]
+        type=str,
         default="auto",
         help="What precision to save the model into. Defaults to auto, which auto-detects which 16-bit dtype to save into, or falls back to fp32."
     )
@@ -446,12 +448,16 @@ def main(input_args=None, overwrite_values=None):
     )
     parser.add_argument(
         "--architecture",
-        type=Literal["neox", "llama", "mistral"],
+        type=str,
         default="neox",
         help="What HF model class type to export into."
     )
     args = parser.parse_args(input_args)
 
+    # validate arguments
+    assert args.precision in ["auto", "fp16", "bf16", "fp32"], f"expected --precision to be one of 'auto', 'fp16', 'bf16', 'fp32' but got '{args.precision}' !"
+    assert args.architecture in ["neox", "llama", "mistral"], f"expected --architecture to be one of 'neox', 'mistral', 'llama', but got '{args.architecture}' !"
+
     with open(args.config_file) as f:
         loaded_config = yaml.full_load(f)
         if overwrite_values:
@@ -491,14 +497,6 @@ def main(input_args=None, overwrite_values=None):
         tokenizer.save_pretrained(args.output_dir)
         print("tokenizer saved!")
 
-        print(
-            tokenizer.decode(
-                hf_model.generate(
-                    tokenizer.encode("Hello, I am testing ", return_tensors="pt")
-                )[0]
-            )
-        )
-
 
 if __name__ == "__main__":
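Note on the load_partitions() change above: the doubled braces in f"mp_rank_{{i:02}}_model_states.pt" keep the MP-rank placeholder literal inside the f-string so it can be filled per rank later via str.format(i=i), while layer_idx is substituted immediately. A minimal standalone sketch of how the two templates resolve; the layer index and rank count below are illustrative values, not taken from a real checkpoint.

# Standalone illustration of the filename templating used in the patched
# load_partitions(); layer_idx and mp_partitions are example values only.
layer_idx = 5
mp_partitions = 2

sequential_format = f"mp_rank_{{i:02}}_model_states.pt"
pipeline_format = f"layer_{layer_idx:02}-model_{{i:02}}-model_states.pt"

for i in range(mp_partitions):
    # sequential checkpoints: one file per MP rank
    print(sequential_format.format(i=i))  # mp_rank_00_model_states.pt, mp_rank_01_model_states.pt
    # pipeline-parallel checkpoints: one file per layer per MP rank
    print(pipeline_format.format(i=i))    # layer_05-model_00-model_states.pt, layer_05-model_01-model_states.pt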
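The reordering in convert() also replaces AutoModelForCausalLM(hf_config) with AutoModelForCausalLM.from_config(hf_config): the Auto classes refuse direct instantiation, and from_config() dispatches to the concrete class registered for the config's model_type. A quick sanity check of that dispatch, illustrative only; the tiny hand-made GPTNeoXConfig stands in for the config that create_config() builds from the NeoX YAML.

# Illustrative check that AutoModelForCausalLM.from_config() dispatches on the
# config's model_type; the small config here is a stand-in, not create_config() output.
from transformers import AutoModelForCausalLM, GPTNeoXConfig

cfg = GPTNeoXConfig(
    vocab_size=1024,
    hidden_size=64,
    num_hidden_layers=2,
    num_attention_heads=4,
    intermediate_size=256,
)
model = AutoModelForCausalLM.from_config(cfg)  # randomly initialized weights
print(type(model).__name__)  # -> GPTNeoXForCausalLM
# AutoModelForCausalLM(cfg) would raise instead: Auto classes are meant to be
# created via from_config() / from_pretrained(), not called directly.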