tested: neox models with TP = 1, PipelineModule, work
haileyschoelkopf committed Jan 18, 2024
1 parent 87da3a9 commit 7a9aed4
Showing 1 changed file with 28 additions and 30 deletions.
58 changes: 28 additions & 30 deletions tools/ckpts/convert_neox_to_hf.py
@@ -20,9 +20,9 @@
 from tqdm import tqdm
 
 import torch
-from transformers import GPTNeoXConfig, GPTNeoXForCausalLM
+from transformers import GPTNeoXConfig, AutoModelForCausalLM
 
-from typing import List
+from typing import List, Literal
 
 sys.path.append(
     os.path.abspath(
@@ -106,16 +106,16 @@
 def load_partitions(input_checkpoint_path: str, mp_partitions: int, layer_idx: int, sequential: bool) -> List[torch.Tensor]:
     """Returns a list containing all states from a model (across MP partitions)"""
 
-    if sequential is True:
-        filename_format = f"mp_rank_{i:02}_model_states.pt"
-    else if sequential is False:
-        filename_format = f"layer_{layer_idx:02}-model_{i:02}-model_states.pt",
+    if sequential:
+        filename_format = f"mp_rank_{{i:02}}_model_states.pt"
+    else:
+        filename_format = f"layer_{layer_idx:02}-model_{{i:02}}-model_states.pt"
 
     loaded_tp_ranks = [
         torch.load(
             os.path.join(
                 input_checkpoint_path,
-                filename_format,
+                filename_format.format(i=i),
             ),
             map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
         )
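
Note on the fix above: the old branch was not valid Python ("else if" instead of "elif"), left a stray trailing comma that turned the string into a tuple, and referenced the rank index i before any rank was in scope. The new code escapes the braces ({{i:02}}) so the f-string leaves the rank slot unformatted, then fills it per rank with str.format. A minimal standalone sketch of the two naming schemes:

    layer_idx = 0

    # Sequential (merged) checkpoints: one file per MP rank.
    seq_format = f"mp_rank_{{i:02}}_model_states.pt"
    # PipelineModule checkpoints: one file per layer per MP rank.
    pp_format = f"layer_{layer_idx:02}-model_{{i:02}}-model_states.pt"

    print(seq_format.format(i=0))  # mp_rank_00_model_states.pt
    print(pp_format.format(i=1))   # layer_00-model_01-model_states.pt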
@@ -228,23 +228,25 @@ def convert(
     output_checkpoint_path,
     sequential: bool = True,
     precision: Literal["auto", "fp16", "bf16", "fp32"] = "auto",
-    architecture: Literal["neox", "llama", "mistral"],
+    architecture: Literal["neox", "llama", "mistral"] = "neox",
 ):
     """convert a NeoX checkpoint to a HF model format.
     should perform model-parallel merging correctly
     but only supports features allowed by HF GPT-NeoX implementation (e.g. rotary embeddings)
     """
 
     ARCH = MODEL_KEYS[architecture]
 
+    hf_config = create_config(loaded_config)
+
+    hf_model = AutoModelForCausalLM.from_config(hf_config)
+
+    if architecture == "neox":
+        hf_transformer = hf_model.gpt_neox
+    else:
+        hf_transformer = hf_model.model
 
-    hf_config = create_config(loaded_config)
-
-    hf_model = AutoModelForCausalLM(hf_config)
-
     if precision == "auto":
         print("Auto-detecting precision to save model into...")
         # save model in FP16 if Deepspeed fp16 was used in config, else 32 bit
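
Note on the fix above: the Auto* classes in transformers are not directly instantiable, so AutoModelForCausalLM(hf_config) raises an error, while AutoModelForCausalLM.from_config(hf_config) dispatches to the concrete class named by the config and returns a randomly initialized model. A minimal sketch (the tiny config values are illustrative only):

    from transformers import AutoModelForCausalLM, GPTNeoXConfig

    config = GPTNeoXConfig(
        hidden_size=64,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=256,
    )
    model = AutoModelForCausalLM.from_config(config)  # random weights, no download
    print(type(model).__name__)        # GPTNeoXForCausalLM
    print(hasattr(model, "gpt_neox"))  # True; Llama/Mistral expose .model instead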
@@ -284,18 +286,18 @@ def convert(
     # for the pipeline-parallel case (pipeline-parallel-size >= 1),
     # we must load the correct layer's states at each step.
     # (this does mean that less memory is required for PP conversion.)
-    loaded_tp_ranks = load_partitions(input_checkpoint_path, mp_partitions, layer_idx=0)
+    loaded_tp_ranks = load_partitions(input_checkpoint_path, mp_partitions, layer_idx=0, sequential=sequential)
 
     ### Embedding layer ###
     # Embedding is layer idx 0
-    if architecture == "neox":
+    if architecture == "neox":
         embed_in = hf_transformer.embed_in
     else:
         embed_in = hf_transformer.embed_tokens
     embed_in.load_state_dict(  # TODO: embed_in is not always model's name for embedding
         {
             "weight": torch.cat(
-                get_state(loaded_tp_ranks, "word_embeddings.weight", 0), dim=0
+                get_state(loaded_tp_ranks, "word_embeddings.weight", layer_idx=0, sequential=sequential), dim=0
             )
         }
     )
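
Note: get_state collects one parameter from every model-parallel partition, and torch.cat(..., dim=0) stitches the vocab-sharded embedding back together, since each TP rank owns a contiguous slice of the vocabulary dimension. An illustrative sketch with toy shapes:

    import torch

    vocab_size, hidden_size, mp_partitions = 8, 4, 2
    # each rank holds a (vocab_size / mp, hidden_size) slice
    shards = [
        torch.randn(vocab_size // mp_partitions, hidden_size)
        for _ in range(mp_partitions)
    ]
    merged = torch.cat(shards, dim=0)
    assert merged.shape == (vocab_size, hidden_size)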
@@ -312,7 +314,7 @@
         if not sequential:
             # in the non-sequential case, must load from each layer individually.
             # use layer index + 2 bc of embed layer and a dummy _pre_transformer_block, which are "layers 0 and 1"
-            loaded_tp_ranks = load_partitions(input_checkpoint_path, mp_partitions, layer_idx=layer_i + 2)
+            loaded_tp_ranks = load_partitions(input_checkpoint_path, mp_partitions, layer_idx=layer_i + 2, sequential=sequential)
 
         # + 2 bc of embed layer and a dummy _pre_transformer_block
         state_dict = {}
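
Note: the "+ 2" here, and the "num-layers + 3" / "+ 4" offsets further down, follow the PipelineModule layer numbering: index 0 is the embedding, index 1 a dummy _pre_transformer_block, then the transformer blocks, then the final layer norm and the LM head. A sketch of the resulting file mapping (filenames shown for TP rank 00; layer count illustrative):

    num_layers = 4  # illustrative
    for layer_i in range(num_layers):
        print(f"hf layer {layer_i} <- layer_{layer_i + 2:02}-model_00-model_states.pt")
    print(f"final norm <- layer_{num_layers + 3:02}-model_00-model_states.pt")
    print(f"lm head    <- layer_{num_layers + 4:02}-model_00-model_states.pt")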
@@ -354,9 +356,9 @@ def convert(
     # load state_dict into layer
     hf_layer.load_state_dict(state_dict)
 
-    if sequential:
+    if not sequential:
         loaded_tp_ranks = load_partitions(
-            input_checkpoint_path, mp_partitions, get_key(loaded_config, "num-layers") + 3
+            input_checkpoint_path, mp_partitions, get_key(loaded_config, "num-layers") + 3, sequential=sequential
         )
     # Load final layer norm
     if architecture == "neox":
@@ -384,9 +386,9 @@ def convert(
     )
 
     # Load output embedding
-    if sequential:
+    if not sequential:
         loaded_tp_ranks = load_partitions(
-            input_checkpoint_path, mp_partitions, get_key(loaded_config, "num-layers") + 4
+            input_checkpoint_path, mp_partitions, get_key(loaded_config, "num-layers") + 4, sequential=sequential
         )
     # output embedding / LM head
     if architecture == "neox":  # name of lm head / final linear proj varies
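
Note: the architecture branches exist because the HF classes name these modules differently. A rough map (attribute names assumed from the transformers source; worth re-checking against the installed version):

    # (backbone, final norm, lm head) per exported architecture -- assumed names
    HF_ATTR_NAMES = {
        "neox":    ("gpt_neox", "final_layer_norm", "embed_out"),
        "llama":   ("model",    "norm",             "lm_head"),
        "mistral": ("model",    "norm",             "lm_head"),
    }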
@@ -435,7 +437,7 @@ def main(input_args=None, overwrite_values=None):
     )
     parser.add_argument(
         "--precision",
-        type=Literal["auto", "fp16", "bf16", "fp32"]
+        type=str,
         default="auto",
         help="What precision to save the model into. Defaults to auto, which auto-detects which 16-bit dtype to save into, or falls back to fp32."
     )
@@ -446,12 +448,16 @@
     )
     parser.add_argument(
         "--architecture",
-        type=Literal["neox", "llama", "mistral"],
+        type=str,
         default="neox",
         help="What HF model class type to export into."
     )
     args = parser.parse_args(input_args)
 
+    # validate arguments
+    assert args.precision in ["auto", "fp16", "bf16", "fp32"], f"expected --precision to be one of 'auto', 'fp16', 'bf16', 'fp32' but got '{args.precision}' !"
+    assert args.architecture in ["neox", "llama", "mistral"], f"expected --architecture to be one of 'neox', 'mistral', 'llama', but got '{args.architecture}' !"
+
     with open(args.config_file) as f:
         loaded_config = yaml.full_load(f)
     if overwrite_values:
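
Note: typing.Literal is a type annotation, not a callable converter, so it could never have worked as an argparse "type"; the commit falls back to str plus explicit asserts. argparse's built-in "choices" would reject bad values at parse time, as in this alternative sketch:

    parser.add_argument(
        "--architecture",
        type=str,
        default="neox",
        choices=["neox", "llama", "mistral"],
        help="What HF model class type to export into.",
    )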
@@ -491,14 +497,6 @@ def main(input_args=None, overwrite_values=None):
     tokenizer.save_pretrained(args.output_dir)
     print("tokenizer saved!")
 
-    print(
-        tokenizer.decode(
-            hf_model.generate(
-                tokenizer.encode("Hello, I am testing ", return_tensors="pt")
-            )[0]
-        )
-    )
-
 
 if __name__ == "__main__":
