From 2a0775891e562d3443a992e822cd7caa43eb56e0 Mon Sep 17 00:00:00 2001 From: sanzhou Date: Mon, 9 Dec 2024 14:56:32 +0800 Subject: [PATCH] support qwen2 hf<->mcore ckpt converter --- examples/qwen/README.md | 71 +++++ tools/checkpoint/loader_mcore.py | 371 +++++++++++++++++-------- tools/checkpoint/loader_qwen2_hf.py | 405 ++++++++++++++++++++++++++++ tools/checkpoint/saver_mcore.py | 67 ++++- tools/checkpoint/saver_qwen2_hf.py | 333 +++++++++++++++++++++++ tools/checkpoint/schema_mcore.py | 3 + 6 files changed, 1127 insertions(+), 123 deletions(-) create mode 100644 examples/qwen/README.md create mode 100644 tools/checkpoint/loader_qwen2_hf.py create mode 100644 tools/checkpoint/saver_qwen2_hf.py diff --git a/examples/qwen/README.md b/examples/qwen/README.md new file mode 100644 index 0000000000..209ec196b4 --- /dev/null +++ b/examples/qwen/README.md @@ -0,0 +1,71 @@ +# QWen2 ckpt converter + +## Download QWen2 Checkpoints +Download QWen2 HF format checkpoint from [HF-hub](https://huggingface.co/Qwen/Qwen2.5-3B-Instruct) + +Or you can simply run this following script to download QWen2 into a specific folder. +```python +from huggingface_hub import snapshot_download +SAVED_DIR = "" # Specify the saved directory +# Download HF checkpoints +snapshot_download(repo_id="Qwen/Qwen2.5-3B-Instruct", ignore_patterns=["*.pt"], local_dir=SAVED_DIR, local_dir_use_symlinks=False) +``` + +## Convert QWen2 checkpoints from HF to MCore +Since MCore 0.7, we support using distributed checkpointing to load and save checkpoints with different parallel mappings. +To convert HF model to distributed checkpoints, use following instructions: + +``` +TOKENIZER_MODEL=/workspace/checkpoints/qwen2/tokenizer.model +MEGATRON_PATH="/workspace/megatron-lm" +export PYTHONPATH=$MEGATRON_PATH:$PYTHONPATH +export CUDA_DEVICE_MAX_CONNECTIONS=1 + +TARGET_TP_SIZE=1 +TARGET_EP_SIZE=1 +TARGET_PP_SIZE=1 + +HF_FORMAT_DIR=/workspace/checkpoints/qwen2-hf +MCORE_FORMAT_DIR=/workspace/checkpoints/qwen-mcore-TP${TARGET_TP_SIZE}PP${TARGET_PP_SIZE}EP${TARGET_EP_SIZE} + +TARGET_EP_SIZE=${TARGET_EP_SIZE:-1} +TARGET_PP_SIZE=${TARGET_PP_SIZE:-1} +TARGET_CKPT_FORMAT=${TARGET_CKPT_FORMAT:-"torch_dist"} + +torchrun --nproc-per-node=1 --nnodes=1 checkpoint/convert.py \ +--model-type GPT \ +--loader qwen2_hf \ +--saver mcore \ +--target-tensor-parallel-size ${TARGET_TP_SIZE} \ +--target-pipeline-parallel-size ${TARGET_PP_SIZE} \ +--target-expert-parallel-size ${TARGET_EP_SIZE} \ +--load-dir ${HF_FORMAT_DIR} \ +--save-dir ${MCORE_FORMAT_DIR} \ +--tokenizer-model ${TOKENIZER_MODEL} \ +--target-ckpt-format ${TARGET_CKPT_FORMAT}``` +``` + +## Convert QWen2 checkpoints from MCore to HF +Since MCore 0.7, we support using distributed checkpointing to load and save checkpoints with different parallel mappings. 
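+
+For either conversion direction, the converters only rely on a handful of fields from the HF `config.json`. A quick way to inspect them before converting (a minimal sketch, assuming `transformers` is installed; the path is a placeholder for your `HF_FORMAT_DIR`):
+```python
+from transformers import AutoConfig
+
+cfg = AutoConfig.from_pretrained("/workspace/checkpoints/qwen2-hf")
+# Fields mapped onto Megatron arguments by the Qwen2 loader
+for name in ["hidden_size", "num_hidden_layers", "num_attention_heads",
+             "num_key_value_heads", "intermediate_size", "rms_norm_eps",
+             "rope_theta", "vocab_size", "max_position_embeddings",
+             "tie_word_embeddings"]:
+    print(name, getattr(cfg, name, None))
+print("head_dim", getattr(cfg, "head_dim", None))  # optional; used for kv_channels when present
+```
+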
+To convert HF model to distributed checkpoints, use following instructions: + +``` +MEGATRON_PATH="/workspace/megatron-lm" +export PYTHONPATH=$MEGATRON_PATH:$PYTHONPATH +export CUDA_DEVICE_MAX_CONNECTIONS=1 + +MCORE_FORMAT_DIR=/workspace/checkpoints/qwen-mcore-TP1PP1EP1 +HF_FORMAT_DIR=/workspace/checkpoints/qwen2-hf + +torchrun --nproc-per-node=1 --nnodes=1 checkpoint/convert.py \ +--model-type GPT \ +--loader mcore \ +--saver qwen2_hf \ +--load-dir ${MCORE_FORMAT_DIR} \ +--save-dir ${HF_FORMAT_DIR} +``` +NOTE: for qwen2moe, need to set gate=True for shared_experts in gpt_layer_specs.py + +## Acknowledgements +Contributors outside NVIDIA for the huggingface converter and example of QWen models in Megatron-Core: +- QWen Team diff --git a/tools/checkpoint/loader_mcore.py b/tools/checkpoint/loader_mcore.py index 9185969b33..06c7cc5655 100644 --- a/tools/checkpoint/loader_mcore.py +++ b/tools/checkpoint/loader_mcore.py @@ -3,8 +3,9 @@ import json import os import sys -import torch import types +import torch +import packaging from schema_mcore import get_model_schema from utils import print_memory_usage @@ -35,14 +36,16 @@ def _load_checkpoint(queue, args): # Search in directory above this sys.path.append(os.path.abspath( os.path.join(os.path.dirname(__file__), + os.path.pardir, os.path.pardir))) if args.megatron_path is not None: sys.path.insert(0, args.megatron_path) try: from megatron.training.arguments import parse_args, validate_args - from megatron.training.global_vars import set_args, set_global_variables + from megatron.training.global_vars import set_global_variables from megatron.training.checkpointing import load_args_from_checkpoint, load_checkpoint + from megatron.core.parallel_state import initialize_model_parallel from megatron.legacy.model import module from megatron.core import mpu from megatron.core.enums import ModelType @@ -59,6 +62,7 @@ def _load_checkpoint(queue, args): '--no-bias-dropout-fusion', '--no-async-tensor-model-parallel-allreduce', '--use-cpu-initialization', + '--auto-detect-ckpt-format', '--micro-batch-size', '1', '--no-load-optim', '--no-load-rng', @@ -73,16 +77,60 @@ def _load_checkpoint(queue, args): ] margs = parse_args() + + device_count = torch.cuda.device_count() + if device_count > 0: + torch.cuda.set_device(0) + device_id = torch.device(f'cuda:0') + else: + device_id = None + margs, checkpoint_args = load_args_from_checkpoint(margs) + # for now, if load dist ckpt, we load it as tp1pp1ep1vp1 for convenience + if checkpoint_args.use_dist_ckpt: + # Call the init process + init_process_group_kwargs = { + 'backend': 'nccl', + 'world_size': 1, + 'rank': 0, + } + + if packaging.version.Version(torch.__version__) >= packaging.version.Version("2.3.0"): + init_process_group_kwargs['device_id'] = device_id + margs.tensor_model_parallel_size = 1 + margs.pipeline_model_parallel_size = 1 + margs.expert_model_parallel_size = 1 + margs.virtual_pipeline_model_parallel_size = 1 + torch.distributed.init_process_group(**init_process_group_kwargs) + initialize_model_parallel() + print(f"real initializing distributed") + else: + print(f"fake initializing distributed") + margs.tensor_model_parallel_size = checkpoint_args.tensor_model_parallel_size + margs.pipeline_model_parallel_size = checkpoint_args.pipeline_model_parallel_size + margs.expert_model_parallel_size = checkpoint_args.expert_model_parallel_size + margs.virtual_pipeline_model_parallel_size = checkpoint_args.virtual_pipeline_model_parallel_size + margs.sequence_parallel = checkpoint_args.sequence_parallel + 
margs.ckpt_format = checkpoint_args.ckpt_format + # Arguments do sanity checks on the world size, but we don't care, # so trick it into thinking we are plenty of processes - margs.world_size = margs.tensor_model_parallel_size * margs.pipeline_model_parallel_size + margs.world_size = margs.tensor_model_parallel_size * margs.pipeline_model_parallel_size * margs.expert_model_parallel_size # Explicitly copy data types from checkpoint. margs.fp16 = checkpoint_args.fp16 margs.bf16 = checkpoint_args.bf16 + margs.use_legacy_models = False + margs.transformer_impl = args.loader_transformer_impl + margs.norm_epsilon = checkpoint_args.norm_epsilon + margs.rotary_base = checkpoint_args.rotary_base + if checkpoint_args.num_experts: + margs.moe_shared_expert_intermediate_size = checkpoint_args.moe_shared_expert_intermediate_size + margs.num_experts = checkpoint_args.num_experts + margs.moe_router_topk = checkpoint_args.moe_router_topk + # Expert parallelism requires sequence parallelism. if margs.expert_model_parallel_size > 1: margs.sequence_parallel = True @@ -90,9 +138,6 @@ def _load_checkpoint(queue, args): # Validate margs. margs = validate_args(margs) - margs.use_legacy_models = False - margs.transformer_impl = args.loader_transformer_impl - def check_for_arg(arg_name, default=None): if getattr(margs, arg_name, None) is None: if default is not None: @@ -105,6 +150,7 @@ def check_for_arg(arg_name, default=None): check_for_arg('tensor_model_parallel_size') check_for_arg('pipeline_model_parallel_size') + check_for_arg('expert_model_parallel_size') check_for_arg('num_layers') check_for_arg('hidden_size') check_for_arg('seq_length') @@ -117,7 +163,9 @@ def check_for_arg(arg_name, default=None): check_for_arg('disable_bias_linear', False) check_for_arg('params_dtype') check_for_arg('swiglu', False) - + if checkpoint_args.num_experts: + check_for_arg('num_experts') + print(f"checkpoint_args {checkpoint_args}") # Determine how to make our models if args.model_type == 'GPT': from pretrain_gpt import model_provider @@ -133,67 +181,69 @@ def check_for_arg(arg_name, default=None): consumed_train_samples = None consumed_valid_samples = None - def get_models(count, dtype): + def get_models(tp_size, ep_size, dtype): nonlocal consumed_train_samples nonlocal consumed_valid_samples model_array_len = margs.virtual_pipeline_model_parallel_size if model_array_len is None: model_array_len = 1 - models = [[] for _ in range(model_array_len)] + models = [[[] for _ in range(ep_size)] for _ in range(model_array_len)] pre_process = mpu.is_pipeline_first_stage() post_process = mpu.is_pipeline_last_stage() - for rank in range(count): - mpu.set_tensor_model_parallel_rank(rank) - if margs.virtual_pipeline_model_parallel_size is not None: - model_ = [] - for i in range(margs.virtual_pipeline_model_parallel_size): - mpu.set_virtual_pipeline_model_parallel_rank(i) - # Set pre_process and post_process only after virtual rank is set. + for ep_rank in range(ep_size): + mpu.set_expert_model_parallel_rank(ep_rank) + for tp_rank in range(tp_size): + mpu.set_tensor_model_parallel_rank(tp_rank) + if margs.virtual_pipeline_model_parallel_size is not None: + model_ = [] + for i in range(margs.virtual_pipeline_model_parallel_size): + mpu.set_virtual_pipeline_model_parallel_rank(i) + # Set pre_process and post_process only after virtual rank is set. 
+ pre_process = mpu.is_pipeline_first_stage() + post_process = mpu.is_pipeline_last_stage() + this_model = model_provider( + pre_process=pre_process, + post_process=post_process + ).to(dtype) + model_.append(this_model) + else: pre_process = mpu.is_pipeline_first_stage() post_process = mpu.is_pipeline_last_stage() - this_model = model_provider( - pre_process=pre_process, - post_process=post_process - ).to(dtype) - model_.append(this_model) - else: - pre_process = mpu.is_pipeline_first_stage() - post_process = mpu.is_pipeline_last_stage() - model_rank = 0 - model_ = [model_provider(pre_process, post_process).to(dtype)] - margs.consumed_train_samples = 0 - margs.consumed_valid_samples = 0 - margs.exit_on_missing_checkpoint = True - load_checkpoint(model_, None, None) - - if consumed_train_samples is not None: - assert(margs.consumed_train_samples == consumed_train_samples) - else: - consumed_train_samples = margs.consumed_train_samples - if consumed_valid_samples is not None: - assert(margs.consumed_valid_samples == consumed_valid_samples) - else: - consumed_valid_samples = margs.consumed_valid_samples - for vp_rank in range(model_array_len): - models[vp_rank].append(model_[vp_rank]) + model_ = [model_provider(pre_process, post_process).to(dtype)] + margs.consumed_train_samples = 0 + margs.consumed_valid_samples = 0 + margs.exit_on_missing_checkpoint = True + load_checkpoint(model_, None, None, strict=False) + + if consumed_train_samples is not None: + assert(margs.consumed_train_samples == consumed_train_samples) + else: + consumed_train_samples = margs.consumed_train_samples + if consumed_valid_samples is not None: + assert(margs.consumed_valid_samples == consumed_valid_samples) + else: + consumed_valid_samples = margs.consumed_valid_samples + for vp_rank in range(model_array_len): + models[vp_rank][ep_rank].append(model_[vp_rank]) - # Print memory usage. - print_memory_usage("loader", rank, count) + # Print memory usage. 
+ print_memory_usage("loader", tp_rank, tp_size) return models set_global_variables(margs, build_tokenizer=False) mpu.set_tensor_model_parallel_world_size(margs.tensor_model_parallel_size) mpu.set_pipeline_model_parallel_world_size(margs.pipeline_model_parallel_size) - mpu.set_virtual_pipeline_model_parallel_world_size(margs.virtual_pipeline_model_parallel_size) mpu.set_expert_model_parallel_world_size(margs.expert_model_parallel_size) + mpu.set_virtual_pipeline_model_parallel_world_size(margs.virtual_pipeline_model_parallel_size) fused_kernels.load(margs) - + print(f"loader's margs {margs}") # Get true (non-padded) vocab size if args.true_vocab_size is not None: true_vocab_size = args.true_vocab_size elif args.vocab_file is not None: - vocab = json.load(open(args.vocab_file)) + with open(args.vocab_file) as vocab_file_handler: + vocab = json.load(vocab_file_handler) true_vocab_size = len(vocab) if args.true_vocab_size is not None and true_vocab_size != args.true_vocab_size: print("Both --true-vocab-size and --vocab-file specified and the vocab size does not match, aborting.") @@ -205,6 +255,7 @@ def get_models(count, dtype): # short aliases tp_size = margs.tensor_model_parallel_size pp_size = margs.pipeline_model_parallel_size + ep_size = margs.expert_model_parallel_size vp_size = margs.virtual_pipeline_model_parallel_size if vp_size is None: vp_size = 1 @@ -221,6 +272,7 @@ def get_models(count, dtype): md.model_type = args.model_type md.num_layers = margs.num_layers md.hidden_size = margs.hidden_size + md.ffn_hidden_size = margs.ffn_hidden_size md.seq_length = margs.seq_length md.num_attention_heads = margs.num_attention_heads md.max_position_embeddings = margs.max_position_embeddings @@ -229,6 +281,7 @@ def get_models(count, dtype): md.params_dtype = margs.params_dtype md.bert_binary_head = margs.bert_binary_head md.output_layer = margs.untie_embeddings_and_output_weights + md.untie_embeddings_and_output_weights = margs.untie_embeddings_and_output_weights md.position_embedding_type = margs.position_embedding_type md.linear_bias = margs.add_bias_linear md.qkv_bias = margs.add_qkv_bias @@ -236,15 +289,27 @@ def get_models(count, dtype): md.swiglu = margs.swiglu md.previous_tensor_parallel_size = margs.tensor_model_parallel_size md.previous_pipeline_parallel_size = margs.pipeline_model_parallel_size + md.previous_expert_parallel_size = margs.expert_model_parallel_size md.true_vocab_size = true_vocab_size md.make_vocab_size_divisible_by = margs.make_vocab_size_divisible_by md.checkpoint_args = checkpoint_args md.use_legacy_models = margs.use_legacy_models - - # Get first pipe stage. + md.num_query_groups = margs.num_query_groups + md.group_query_attention = margs.group_query_attention + md.norm_epsilon = margs.norm_epsilon + md.rotary_base = margs.rotary_base + md.padded_vocab_size = margs.padded_vocab_size + md.num_experts = margs.num_experts + md.moe_router_topk = margs.moe_router_topk + md.moe_shared_expert_intermediate_size = margs.moe_shared_expert_intermediate_size + + # Get first pipe stage mpu.set_pipeline_model_parallel_rank(0) - all_models = [get_models(tp_size, md.params_dtype)] + # all_models: pp_rank, vp_rank, ep_rank, tp_rank + all_models = [get_models(tp_size, ep_size, md.params_dtype)] models = all_models[0][0] + if ep_size == 1: + assert len(models) == 1 md.consumed_train_samples = consumed_train_samples md.consumed_valid_samples = consumed_valid_samples @@ -264,7 +329,7 @@ def queue_put(name, msg): ) # Send embeddings. 
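+    # `models` is now indexed [ep_rank][tp_rank]; embeddings are replicated across
+    # expert-parallel ranks, so gather them from the tensor-parallel shards of ep_rank 0 only.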
- embeddings = [ schema.get("embeddings", model) for model in models ] + embeddings = [ schema.get("embeddings", model) for model in models[0] ] message = { "word embeddings": torch.cat([ e["word"] for e in embeddings ], dim=0) } @@ -274,7 +339,140 @@ def queue_put(name, msg): assert embeddings[0]["pos"] is None queue_put("embeddings", message) - # Send layers. + def set_common_message(message): + # Get non-parallel tensors from tp_rank 0 + layer = schema.get_layer(models[0][0], layer_num) + message["input norm weight"] = layer["self_attn_norm_weight"] + message["post norm weight"] = layer["mlp_norm_weight"] + if norm_has_bias: + message["input norm bias"] = layer["self_attn_norm_bias"] + message["post norm bias"] = layer["mlp_norm_bias"] + if md.linear_bias: + message["dense bias"] = layer["self_attn_proj_bias"] + + # Grab attention parallel tensors for this layer + qkv_weight = [] + qkv_bias = [] + dense_weight = [] + for tp_rank, model in enumerate(models[0]): + layer = schema.get_layer(model, layer_num) + qkv_weight.append(layer["self_attn_qkv_weight"]) + dense_weight.append(layer["self_attn_proj_weight"]) + if md.qkv_bias: + qkv_bias.append(layer["self_attn_qkv_bias"]) + + # simple concat of the rest + message["qkv weight"] = torch.cat(qkv_weight, dim=0) + message["dense weight"] = torch.cat(dense_weight, dim=1) + if md.qkv_bias: + message["qkv bias"] = torch.cat(qkv_bias, dim=0) + + def set_message_for_dense_model(message): + # Get non-parallel tensors from tp_rank 0 + layer = schema.get_layer(models[0][0], layer_num) + if md.linear_bias: + message["mlp l1 bias"] = layer["mlp_fc2_bias"] + + # Grab mlp parallel tensors for this layer + mlp_l0_weight = [] + mlp_l0_bias = [] + mlp_l1_weight = [] + for tp_rank, model in enumerate(models[0]): + layer = schema.get_layer(model, layer_num) + mlp_l0_weight.append(layer["mlp_fc1_weight"]) + mlp_l1_weight.append(layer["mlp_fc2_weight"]) + if md.linear_bias: + mlp_l0_bias.append(layer["mlp_fc1_bias"]) + + # Handle gated linear units + if md.swiglu: + # concat all the first halves ('W's) and all the second halves ('V's) + for tp_rank in range(tp_size): + mlp_l0_weight[tp_rank] = torch.chunk(mlp_l0_weight[tp_rank], 2, dim=0) + message["mlp l0 weight W"] = torch.cat([w[0] for w in mlp_l0_weight], dim=0) + message["mlp l0 weight V"] = torch.cat([w[1] for w in mlp_l0_weight], dim=0) + else: + message["mlp l0 weight"] = torch.cat(mlp_l0_weight, dim=0) + + # simple concat of the rest + message["mlp l1 weight"] = torch.cat(mlp_l1_weight, dim=1) + if md.linear_bias: + if md.swiglu: + for tp_rank in range(tp_size): + mlp_l0_bias[tp_rank] = torch.chunk(mlp_l0_bias[tp_rank], 2, dim=0) + message["mlp l0 bias W"] = torch.cat([b[0] for b in mlp_l0_bias], dim=0) + message["mlp l0 bias V"] = torch.cat([b[1] for b in mlp_l0_bias], dim=0) + else: + message["mlp l0 bias"] = torch.cat(mlp_l0_bias, dim=0) + + def set_message_for_moe_model(args, message): + use_shared_expert = args.moe_shared_expert_intermediate_size is not None + + # Get non-parallel tensors from tp_rank 0 + layer = schema.get_layer(models[0][0], layer_num) + router_weight = layer["router_weight"] + if use_shared_expert: + shared_mlp_gate_weight = layer["shared_mlp_gate_weight"] + + # Grab all parallel tensors for this layer + shared_expert_mlp_l0_weight = [] + shared_expert_mlp_l1_weight = [] + mlp_l0_weight_list = [[] for _ in range(margs.num_experts)] + mlp_l1_weight_list = [[] for _ in range(margs.num_experts)] + + # Routed Experts modules + num_experts_per_rank = margs.num_experts // ep_size + 
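+        # Each EP rank holds num_experts // ep_size local experts. The loop below maps every
+        # (ep_rank, local_expert_idx) pair to its global expert id, collects the fc1/fc2 shards,
+        # concatenates each expert's shards along the TP dimension, and finally stacks all
+        # experts along a new leading expert dimension for the saver.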
for ep_rank, tp_models in enumerate(models): + for tp_rank, model in enumerate(tp_models): + layer = schema.get_layer(model, layer_num) + for local_expert_idx in range(num_experts_per_rank): + expert_idx = int(ep_rank * num_experts_per_rank + local_expert_idx) + mlp_l0_weight_list[expert_idx].append(layer[f"mlp_fc1_weight.{local_expert_idx}"]) + mlp_l1_weight_list[expert_idx].append(layer[f"mlp_fc2_weight.{local_expert_idx}"]) + + mlp_l0_weight_w_list = [[] for _ in range(margs.num_experts)] + mlp_l0_weight_v_list = [[] for _ in range(margs.num_experts)] + # Concat along the tensor parallel dimension + for expert_idx in range(margs.num_experts): + mlp_l0_weight = mlp_l0_weight_list[expert_idx] + if md.swiglu: + for tp_rank in range(tp_size): + mlp_l0_weight[tp_rank] = torch.chunk(mlp_l0_weight[tp_rank], 2, dim=0) + mlp_l0_weight_w_list[expert_idx] = torch.cat([w[0] for w in mlp_l0_weight], dim=0) + mlp_l0_weight_v_list[expert_idx] = torch.cat([w[1] for w in mlp_l0_weight], dim=0) + else: + mlp_l0_weight_list[expert_idx] = torch.cat(mlp_l0_weight, dim=0) + mlp_l1_weight_list[expert_idx] = torch.cat(mlp_l1_weight_list[expert_idx], dim=1) + + if md.swiglu: + # Stack along the expert parallel dimension + message["mlp l0 weight W"] = torch.stack(mlp_l0_weight_w_list) + message["mlp l0 weight V"] = torch.stack(mlp_l0_weight_v_list) + else: + message["mlp l0 weight"] = torch.stack(mlp_l0_weight_list) + message["mlp l1 weight"] = torch.stack(mlp_l1_weight_list) + + # Share Experts modules + if use_shared_expert: + for tp_rank, model in enumerate(models[0]): + layer = schema.get_layer(model, layer_num) + shared_expert_mlp_l0_weight.append(layer["shared_mlp_fc1_weight"]) + shared_expert_mlp_l1_weight.append(layer["shared_mlp_fc2_weight"]) + + if md.swiglu: + for tp_rank in range(tp_size): + shared_expert_mlp_l0_weight[tp_rank] = torch.chunk(shared_expert_mlp_l0_weight[tp_rank], 2, dim=0) + message["shared mlp l0 weight W"] = torch.cat([w[0] for w in shared_expert_mlp_l0_weight], dim=0) + message["shared mlp l0 weight V"] = torch.cat([w[1] for w in shared_expert_mlp_l0_weight], dim=0) + else: + message["shared mlp l0 weight"] = torch.cat(shared_expert_mlp_l0_weight, dim=0) + message["shared mlp l1 weight"] = torch.cat(shared_expert_mlp_l1_weight, dim=1) + message["shared gate weight"] = shared_mlp_gate_weight + + # Do nothing to router + message["router weight"] = router_weight + + total_layer_num = 0 for vp_rank in range(vp_size): mpu.set_virtual_pipeline_model_parallel_rank(vp_rank) @@ -282,71 +480,22 @@ def queue_put(name, msg): if pp_rank > 0: mpu.set_pipeline_model_parallel_rank(pp_rank) if vp_rank == 0: - all_models.append(get_models(tp_size, md.params_dtype)) + all_models.append(get_models(tp_size, ep_size, md.params_dtype)) models = all_models[pp_rank][vp_rank] - for layer_num in range(schema.get_num_layers(models[0])): + for layer_num in range(schema.get_num_layers(models[0][0])): message = {} - - # Get non-parallel tensors from tp_rank 0 - layer = schema.get_layer(models[0], layer_num) - message["input norm weight"] = layer["self_attn_norm_weight"] - message["post norm weight"] = layer["mlp_norm_weight"] - if norm_has_bias: - message["input norm bias"] = layer["self_attn_norm_bias"] - message["post norm bias"] = layer["mlp_norm_bias"] - if md.linear_bias: - message["dense bias"] = layer["self_attn_proj_bias"] - message["mlp l1 bias"] = layer["mlp_fc2_bias"] - - # Grab all parallel tensors for this layer - qkv_weight = [] - qkv_bias = [] - dense_weight = [] - mlp_l0_weight = [] - mlp_l0_bias 
= [] - mlp_l1_weight = [] - for tp_rank, model in enumerate(models): - layer = schema.get_layer(model, layer_num) - qkv_weight.append(layer["self_attn_qkv_weight"]) - dense_weight.append(layer["self_attn_proj_weight"]) - mlp_l0_weight.append(layer["mlp_fc1_weight"]) - mlp_l1_weight.append(layer["mlp_fc2_weight"]) - if md.qkv_bias: - qkv_bias.append(layer["self_attn_qkv_bias"]) - if md.linear_bias: - mlp_l0_bias.append(layer["mlp_fc1_bias"]) - - # Handle gated linear units - if md.swiglu: - # concat all the first halves ('W's) and all the second halves ('V's) - for tp_rank in range(tp_size): - mlp_l0_weight[tp_rank] = torch.chunk(mlp_l0_weight[tp_rank], 2, dim=0) - message["mlp l0 weight W"] = torch.cat([w[0] for w in mlp_l0_weight], dim=0) - message["mlp l0 weight V"] = torch.cat([w[1] for w in mlp_l0_weight], dim=0) + set_common_message(message) + if margs.num_experts: + set_message_for_moe_model(margs, message) else: - message["mlp l0 weight"] = torch.cat(mlp_l0_weight, dim=0) - - # simple concat of the rest - message["qkv weight"] = torch.cat(qkv_weight, dim=0) - message["dense weight"] = torch.cat(dense_weight, dim=1) - message["mlp l1 weight"] = torch.cat(mlp_l1_weight, dim=1) - if md.qkv_bias: - message["qkv bias"] = torch.cat(qkv_bias, dim=0) - if md.linear_bias: - if md.swiglu: - for tp_rank in range(tp_size): - mlp_l0_bias[tp_rank] = torch.chunk(mlp_l0_bias[tp_rank], 2, dim=0) - message["mlp l0 bias W"] = torch.cat([b[0] for b in mlp_l0_bias],dim=0) - message["mlp l0 bias V"] = torch.cat([b[1] for b in mlp_l0_bias],dim=0) - else: - message["mlp l0 bias"] = torch.cat(mlp_l0_bias, dim=0) + set_message_for_dense_model(message) queue_put(f"transformer layer {total_layer_num}", message) total_layer_num = total_layer_num + 1 # Send final norm from tp_rank 0. - final_norm = schema.get("final_norm", models[0]) + final_norm = schema.get("final_norm", models[0][0]) message = { "weight": final_norm["weight"], } @@ -356,7 +505,7 @@ def queue_put(name, msg): # Send output layer. if md.output_layer: - output_layer_ranks = [ schema.get("output_layer", m) for m in models ] + output_layer_ranks = [ schema.get("output_layer", m) for m in models[0] ] message = { "weight": torch.cat([r["weight"] for r in output_layer_ranks], dim=0), } @@ -366,7 +515,7 @@ def queue_put(name, msg): if md.model_type == 'BERT': # Pooler. - pooler = schema.get("pooler", models[0]) + pooler = schema.get("pooler", models[0][0]) message = { "weight": pooler["weight"], "bias": pooler["bias"], @@ -374,7 +523,7 @@ def queue_put(name, msg): queue_put("pooler", message) # LM head. - lm_head = schema.get("lm_head", models[0]) + lm_head = schema.get("lm_head", models[0][0]) message = { "dense weight": lm_head["dense_weight"], "dense bias": lm_head["dense_bias"], @@ -386,7 +535,7 @@ def queue_put(name, msg): # Binary head. if md.bert_binary_head: - binary_head = schema.get("binary_head", models[0]) + binary_head = schema.get("binary_head", models[0][0]) message = { "weight": binary_head["weight"], "bias": binary_head["bias"], diff --git a/tools/checkpoint/loader_qwen2_hf.py b/tools/checkpoint/loader_qwen2_hf.py new file mode 100644 index 0000000000..6c7da00e99 --- /dev/null +++ b/tools/checkpoint/loader_qwen2_hf.py @@ -0,0 +1,405 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 
+"""Support model: dense, moe""" + +import os +import sys +import torch +import transformers +from tqdm import tqdm +import types + + +def add_arguments(parser): + group = parser.add_argument_group(title='Qwen2 HF loader.') + + group.add_argument('--true-vocab-size', type=int, default=None, + help='original size of vocab, if specified will trim padding from embedding table.') + group.add_argument('--vocab-file', type=str, default=None, + help='Path to the vocab file. If specified will use this to get vocab size and ' + 'trim padding from the embedding table.') + group.add_argument('--tokenizer-model', required=True, + help='Sentencepiece tokenizer model.') + group.add_argument('--megatron-path', type=str, default=None, + help='Base directory of deepspeed repository') + + +def load_args_from_checkpoint(args): + # Read Qwen2 args. + from transformers import AutoConfig + qwen2_config = AutoConfig.from_pretrained(args.load) + + # Update Megatron args. + args.untie_embeddings_and_output_weights = not qwen2_config.tie_word_embeddings + args.kv_channels = getattr(qwen2_config, "head_dim", None) + args.seq_length = 4096 + args.global_batch_size = 1024 + args.iteration = 1 # '0', 'release' don't work + args.add_position_embedding = False + args.use_rotary_position_embeddings = True + args.swiglu = True + args.bf16 = True + args.add_bias_linear = False + args.normalization = "RMSNorm" + args.tokenizer_type = "HuggingFaceTokenizer" + args.disable_bias_linear = True + args.add_qkv_bias = True + + args.max_position_embeddings = qwen2_config.max_position_embeddings + args.hidden_size = qwen2_config.hidden_size + args.num_attention_heads = qwen2_config.num_attention_heads + args.num_layers = qwen2_config.num_hidden_layers + args.norm_epsilon = qwen2_config.rms_norm_eps + args.rotary_base = qwen2_config.rope_theta + args.vocab_size = qwen2_config.vocab_size + args.padded_vocab_size = qwen2_config.vocab_size + args.qwen2 = qwen2_config + if "num_experts" in qwen2_config.__dict__.keys(): + args.num_experts = qwen2_config.num_experts + args.moe_router_topk = qwen2_config.num_experts_per_tok + args.ffn_hidden_size = qwen2_config.moe_intermediate_size + if qwen2_config.shared_expert_intermediate_size: + args.moe_shared_expert_intermediate_size = qwen2_config.shared_expert_intermediate_size + else: + args.ffn_hidden_size = qwen2_config.intermediate_size + args.sequence_parallel = True + + if qwen2_config.num_key_value_heads: + args.group_query_attention = True + args.num_query_groups = qwen2_config.num_key_value_heads + +def verify_transformers_version(): + version_parts = transformers.__version__.split('.')[:2] + major, minor = map(int, version_parts) + assert major >= 4 and minor >= 37 + +def set_preprocess_state(args, model, hf_model): + '''Set embedding params.''' + model.embedding.word_embeddings.weight.data.copy_( + hf_model.model.embed_tokens.weight) + +def set_postprocess_state(args, model, hf_model): + '''Set output layer & norm params.''' + model.decoder.final_layernorm.weight.data.copy_(hf_model.model.norm.weight) + if args.untie_embeddings_and_output_weights: + model.output_layer.weight.data.copy_(hf_model.lm_head.weight) + +def set_attn_state(args, layer, hf_layer): + '''Set self-attention params.''' + + # Get attention layer & state. + attn = layer.self_attention + hf_attn = hf_layer.self_attn + + # Reshape loaded weights. 
+ tp = args.tensor_model_parallel_size + num_heads = args.num_attention_heads // tp + num_query_groups = (args.num_query_groups if args.group_query_attention else args.num_attention_heads) // tp + num_querys_per_group = num_heads // num_query_groups + dim = args.kv_channels + assert num_heads % num_querys_per_group == 0 + + # Copy weights (re-order dimensions for Megatron). + attn.linear_qkv.weight.data.copy_(torch.cat([ + hf_attn.q_proj.weight.reshape((num_query_groups, num_querys_per_group*dim, -1)), + hf_attn.k_proj.weight.reshape((num_query_groups, dim, -1)), + hf_attn.v_proj.weight.reshape((num_query_groups, dim, -1)), + ], dim=1).reshape((-1, args.hidden_size))) + attn.linear_proj.weight.data.copy_(hf_attn.o_proj.weight) + + # Copy bias + attn.linear_qkv.bias.data.copy_(torch.cat([ + hf_attn.q_proj.bias.reshape((num_query_groups, num_querys_per_group*dim, -1)), + hf_attn.k_proj.bias.reshape((num_query_groups, dim, -1)), + hf_attn.v_proj.bias.reshape((num_query_groups, dim, -1)), + ], dim=1).reshape(num_query_groups*(num_querys_per_group+2)*dim)) + +def set_mlp_state(args, layer, hf_layer): + '''Set MLP params.''' + layer.mlp.linear_fc1.weight.data.copy_( + torch.cat([ + hf_layer.mlp.gate_proj.weight, + hf_layer.mlp.up_proj.weight + ], dim=0) + ) + layer.mlp.linear_fc2.weight.data.copy_( + hf_layer.mlp.down_proj.weight + ) + +def set_moe_mlp_state(args, layer, hf_layer): + '''Set MOE MLP params.''' + layer.mlp.router.weight.data.copy_(hf_layer.mlp.gate.weight) + layer.mlp.shared_experts.gate_weight.data.copy_(hf_layer.mlp.shared_expert_gate.weight) + mcore_experts = layer.mlp.experts.local_experts + hf_experts = hf_layer.mlp.experts + for expert_idx in range(args.num_experts): + mcore_experts[expert_idx].linear_fc1.weight.data.copy_( + torch.cat([ + hf_experts[expert_idx].gate_proj.weight, + hf_experts[expert_idx].up_proj.weight + ], dim=0) + ) + mcore_experts[expert_idx].linear_fc2.weight.data.copy_( + hf_experts[expert_idx].down_proj.weight + ) + + # shared exp + layer.mlp.shared_experts.linear_fc1.weight.data.copy_( + torch.cat([ + hf_layer.mlp.shared_expert.gate_proj.weight, + hf_layer.mlp.shared_expert.up_proj.weight + ], dim=0) + ) + layer.mlp.shared_experts.linear_fc2.weight.data.copy_( + hf_layer.mlp.shared_expert.down_proj.weight + ) + +def set_layer_state(args, model, hf_model, layer_idx): + '''Set transformer layer params.''' + + layer = model.decoder.layers[layer_idx] + hf_layer = hf_model.model.layers[layer_idx] + + set_attn_state(args, layer, hf_layer) + if args.num_experts: + set_moe_mlp_state(args, layer, hf_layer) + layer.pre_mlp_layernorm.weight.data.copy_(hf_layer.post_attention_layernorm.weight) + else: + set_mlp_state(args, layer, hf_layer) + layer.mlp.linear_fc1.layer_norm_weight.data.copy_(hf_layer.post_attention_layernorm.weight) + + layer.self_attention.linear_qkv.layer_norm_weight.data.copy_(hf_layer.input_layernorm.weight) + +def load_checkpoint_to_model(args, model_type): + '''Set model params.''' + + from pretrain_gpt import model_provider + from transformers import AutoModelForCausalLM, AutoModel, AutoConfig + + # Load Huggingface model. + hf_model = AutoModelForCausalLM.from_pretrained(args.load, torch_dtype=torch.bfloat16, device_map="cpu") + + # Init Megatron model. + model = model_provider(True, True).to(args.params_dtype) + # Set model state. 
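+    # Note: the HF model is loaded on CPU in bf16 and the MCore model is built unsharded
+    # (TP=PP=EP=1 with --use-cpu-initialization), so both full copies must fit in host
+    # memory; the helpers below copy weights module by module.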
+ set_preprocess_state(args, model, hf_model) + set_postprocess_state(args, model, hf_model) + for layer_idx in tqdm(range(args.num_layers), "set layer states"): + set_layer_state(args, model, hf_model, layer_idx) + return model + +def get_mlp_message_for_dense_model(message, layer, md): + if md.swiglu: + chunked_mlp_l0_weight = torch.chunk(layer.mlp.linear_fc1.weight.data, 2, dim=0) + message["mlp l0 weight W"] = chunked_mlp_l0_weight[0] + message["mlp l0 weight V"] = chunked_mlp_l0_weight[1] + else: + message["mlp l0 weight"] = layer.mlp.linear_fc1.weight.data + message["mlp l1 weight"] = layer.mlp.linear_fc2.weight.data + +def get_mlp_message_for_moe_model(message, layer, md): + experts = layer.mlp.experts.local_experts + message["router weight"] = layer.mlp.router.weight.data + if md.swiglu: + chunked_mlp_l0_weight = [torch.chunk(local_expert.linear_fc1.weight.data, 2, dim=0) for local_expert in experts] + message["mlp l0 weight W"] = torch.stack([local_weight[0] for local_weight in chunked_mlp_l0_weight], dim=0) + message["mlp l0 weight V"] = torch.stack([local_weight[1] for local_weight in chunked_mlp_l0_weight], dim=0) + else: + message["mlp l0 weight"] = torch.stack([local_expert.linear_fc1.weight.data for local_expert in experts]) + message["mlp l1 weight"] = torch.stack([local_expert.linear_fc2.weight.data for local_expert in experts], dim=0) + + # shared exp + if md.swiglu: + chunked_mlp_l0_weight = torch.chunk(layer.mlp.shared_experts.linear_fc1.weight.data, 2, dim=0) + message["shared mlp l0 weight W"] = chunked_mlp_l0_weight[0] + message["shared mlp l0 weight V"] = chunked_mlp_l0_weight[1] + else: + message["shared mlp l0 weight"] = layer.mlp.shared_experts.linear_fc1.weight.data + message["shared mlp l1 weight"] = layer.mlp.shared_experts.linear_fc2.weight.data + message["shared mlp gate weight"] = layer.mlp.shared_experts.gate_weight.data + +def _load_checkpoint(queue, args): + + # Llama-2 requires HF transformers >=4.31.0. + verify_transformers_version() + + # Search in directory above this. + sys.path.append(os.path.abspath( + os.path.join(os.path.dirname(__file__), + os.path.pardir, + os.path.pardir))) + if args.megatron_path is not None: + sys.path.insert(0, args.megatron_path) + + try: + from megatron.training.arguments import parse_args, validate_args + from megatron.training.global_vars import set_args, set_global_variables + from megatron.legacy.model import module + from megatron.core import mpu + from megatron.core.enums import ModelType + from megatron.legacy import fused_kernels + except ModuleNotFoundError: + print("Unable to import Megatron, please specify the path to Megatron using --megatron-path. Exiting.") + queue.put("exit") + exit(1) + + # We want all arguments to come from us. + sys.argv = ['script.py', + '--use-mcore-models', + '--disable-bias-linear', + '--no-masked-softmax-fusion', + '--no-bias-gelu-fusion', + '--no-bias-dropout-fusion', + '--no-async-tensor-model-parallel-allreduce', + '--use-cpu-initialization', + '--micro-batch-size', '1', + '--no-load-optim', + '--no-load-rng', + '--no-save-optim', + '--no-save-rng', + '--no-initialization', + '--mock-data', # To pass the "blend data checks" in arguments.py + '--transformer-impl', 'transformer_engine', + '--load', args.load_dir + ] + + margs = parse_args() + margs.tokenizer_model = args.tokenizer_model + load_args_from_checkpoint(margs) + + # Arguments do sanity checks on the world size, but we don't care, + # so trick it into thinking we are plenty of processes. 
+ margs.world_size = margs.tensor_model_parallel_size * margs.pipeline_model_parallel_size + + margs = validate_args(margs) + + def check_for_arg(arg_name, default=None): + if getattr(margs, arg_name, None) is None: + if default is not None: + setattr(margs, arg_name, default) + else: + print(f"Checkpoint does not specify the argument {arg_name}. Exiting.") + print(f"Arguments: {margs}") + queue.put("exit") + exit(1) + + check_for_arg('tensor_model_parallel_size') + check_for_arg('pipeline_model_parallel_size') + check_for_arg('num_layers') + check_for_arg('hidden_size') + check_for_arg('seq_length') + check_for_arg('num_attention_heads') + check_for_arg('max_position_embeddings') + check_for_arg('position_embedding_type') + check_for_arg('tokenizer_type') + check_for_arg('iteration') + check_for_arg('disable_bias_linear') + check_for_arg('add_qkv_bias') + check_for_arg('params_dtype') + check_for_arg('swiglu') + + # Determine how to make our models. + assert args.model_type == 'GPT', 'Qwen2 is a GPT-moe model.' + margs.model_type = ModelType.encoder_or_decoder + + # Suppress warning about torch.distributed not being initialized. + module.MegatronModule.embedding_warning_printed = True + + set_global_variables(margs, build_tokenizer=False) + mpu.set_tensor_model_parallel_world_size(margs.tensor_model_parallel_size) + mpu.set_pipeline_model_parallel_world_size(margs.pipeline_model_parallel_size) + mpu.set_virtual_pipeline_model_parallel_world_size(margs.virtual_pipeline_model_parallel_size) + mpu.set_expert_model_parallel_world_size(margs.expert_model_parallel_size) + fused_kernels.load(margs) + + # Metadata. + md = types.SimpleNamespace() + md.model_type = args.model_type + md.num_layers = margs.num_layers + md.hidden_size = margs.hidden_size + md.seq_length = margs.seq_length + md.num_attention_heads = margs.num_attention_heads + md.max_position_embeddings = margs.max_position_embeddings + md.tokenizer_type = margs.tokenizer_type + md.iteration = margs.iteration + md.params_dtype = margs.params_dtype + md.bert_binary_head = margs.bert_binary_head + md.output_layer = margs.untie_embeddings_and_output_weights + md.position_embedding_type = margs.position_embedding_type + md.linear_bias = margs.add_bias_linear + md.norm_has_bias = False + md.swiglu = margs.swiglu + md.previous_tensor_parallel_size = margs.tensor_model_parallel_size + md.previous_pipeline_parallel_size = margs.pipeline_model_parallel_size + md.true_vocab_size = margs.vocab_size # skips padding in saver + md.make_vocab_size_divisible_by = None + md.checkpoint_args = margs + md.consumed_train_samples = 0 + md.consumed_valid_samples = 0 + md.qkv_bias = margs.add_qkv_bias + if margs.num_experts: + md.num_experts = margs.num_experts + md.moe_shared_expert_intermediate_size = margs.moe_shared_expert_intermediate_size + md.moe_shared_experts_gate = True + + # Get first pipe stage. + mpu.set_tensor_model_parallel_rank(0) + mpu.set_pipeline_model_parallel_rank(0) + mpu.set_expert_model_parallel_rank(0) + model = load_checkpoint_to_model(margs, args.model_type) + + queue.put(md) + + def queue_put(name, msg): + print(f"sending {name}") + msg["name"] = name + queue.put(msg) + + # Send embeddings. 
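+    # From here on the loader follows the converter queue protocol: after the metadata
+    # namespace, it sends one dict per named message ("embeddings", "transformer layer N",
+    # "final norm", "output layer", then "done"), each holding full unsharded tensors that
+    # the selected saver re-shards for the target parallel sizes.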
+ message = { + "word embeddings": model.embedding.word_embeddings.weight.data + } + if md.position_embedding_type == 'learned_absolute': + message["position embeddings"] = model.embedding.position_embeddings.weight.data + else: + assert not hasattr(model.embedding, 'position_embeddings') + + queue_put("embeddings", message) + + for layer_idx in range(margs.num_layers): + message = {} + + # Get non-parallel tensors from tp_rank 0. + layer = model.decoder.layers[layer_idx] + message["input norm weight"] = layer.self_attention.linear_qkv.layer_norm_weight.data + + # Simple concat of the rest. + message["qkv weight"] = layer.self_attention.linear_qkv.weight.data + message["qkv bias"] = layer.self_attention.linear_qkv.bias.data + message["dense weight"] = layer.self_attention.linear_proj.weight.data + + # Grab all parallel tensors for this layer. + if margs.num_experts: + message["post norm weight"] = layer.pre_mlp_layernorm.weight.data + get_mlp_message_for_moe_model(message, layer, md) + else: + message["post norm weight"] = layer.mlp.linear_fc1.layer_norm_weight.data + get_mlp_message_for_dense_model(message, layer, md) + + queue_put(f"transformer layer {layer_idx}", message) + + queue_put("final norm", { + "weight": model.decoder.final_layernorm.weight.data, + }) + + if md.output_layer: + queue_put("output layer", { + "weight": model.output_layer.weight.data + }) + queue.put("done") + +def load_checkpoint(queue, args): + try: + _load_checkpoint(queue, args) + except: + queue.put("exit") + raise diff --git a/tools/checkpoint/saver_mcore.py b/tools/checkpoint/saver_mcore.py index 2caf26a9a0..fb220fb77b 100644 --- a/tools/checkpoint/saver_mcore.py +++ b/tools/checkpoint/saver_mcore.py @@ -26,6 +26,9 @@ def add_arguments(parser): help='Which Transformer implementation to use.') group.add_argument('--target-expert-parallel-size', type=int, default=1, help='Target expert model parallel size, default to 1') + parser.add_argument('--target-ckpt-format', default='torch', + choices=['torch', 'torch_dist', 'zarr'], + help='Checkpoint format to use.') def save_checkpoint(queue, args): @@ -47,6 +50,7 @@ def save_checkpoint(queue, args): from megatron.training.arguments import (parse_args, validate_args) from megatron.training.checkpointing import save_checkpoint from megatron.training.global_vars import set_global_variables, get_args + from megatron.core.parallel_state import initialize_model_parallel from megatron.core.enums import ModelType from megatron.training.tokenizer.tokenizer import _vocab_size_with_padding from megatron.legacy import fused_kernels @@ -98,6 +102,10 @@ def check_message(msg): "Default to 1.") args.target_pipeline_parallel_size = 1 + if args.target_ckpt_format == "torch_dist": + assert args.target_tensor_parallel_size == 1, "Please setting --target-tensor-parallel-size to 1 to use dist ckpt" + assert args.target_pipeline_parallel_size == 1, "Please setting --target-pipeline-parallel-size to 1 to use dist ckpt" + assert args.target_expert_parallel_size == 1, "Please setting --target-expert-parallel-size to 1 to use dist ckpt" # Arguments do sanity checks on the world size, but we don't care, # so trick it into thinking we are plenty of processes @@ -133,7 +141,7 @@ def check_message(msg): '--no-initialization', '--save-interval', '1', '--save', args.save_dir, - '--ckpt-format', 'torch', # only 'torch' supported for conversion + '--ckpt-format', str(args.target_ckpt_format), '--no-one-logger', ] @@ -226,14 +234,25 @@ def check_message(msg): margs.model_type = 
ModelType.encoder_or_decoder else: raise Exception(f'unrecognized model type: {args.model_type}') - + print(f"saver's margs {margs}") # fake initializing distributed - mpu.set_tensor_model_parallel_world_size(args.target_tensor_parallel_size) - mpu.set_pipeline_model_parallel_world_size(args.target_pipeline_parallel_size) - mpu.set_expert_model_parallel_world_size(args.target_expert_parallel_size) - mpu.set_tensor_model_parallel_rank(0) - mpu.set_pipeline_model_parallel_rank(0) - mpu.set_expert_model_parallel_rank(0) + if args.target_ckpt_format == "torch_dist": + torch.distributed.init_process_group( + backend=margs.distributed_backend, + world_size=margs.world_size, + rank=margs.rank, + ) + initialize_model_parallel() + print(f"real initializing distributed") + else: + print(f"fake initializing distributed") + # fake initializing distributed + mpu.set_tensor_model_parallel_world_size(args.target_tensor_parallel_size) + mpu.set_pipeline_model_parallel_world_size(args.target_pipeline_parallel_size) + mpu.set_expert_model_parallel_world_size(args.target_expert_parallel_size) + mpu.set_tensor_model_parallel_rank(0) + mpu.set_pipeline_model_parallel_rank(0) + mpu.set_expert_model_parallel_rank(0) fused_kernels.load(margs) # Embeddings @@ -366,7 +385,23 @@ def chunk_bias(bias, parallel_mode, tp_size=1, ep_size=1): # Split up the parallel tensors qkv_weight = chunk_weight(msg.pop("qkv weight"), "column", args.target_tensor_parallel_size) dense_weight = chunk_weight(msg.pop("dense weight"), "row", args.target_tensor_parallel_size) - mlp_l1_weight = chunk_weight(msg.pop("mlp l1 weight"), "row", args.target_tensor_parallel_size, args.target_expert_parallel_size) + mlp_l1_weight = chunk_weight(msg.pop("mlp l1 weight"), "row", args.target_tensor_parallel_size, + args.target_expert_parallel_size) + + if margs.moe_shared_expert_intermediate_size: + if md.swiglu: + shared_mlp_l0_weight_W = chunk_weight(msg.pop("shared mlp l0 weight W"), "column", + args.target_tensor_parallel_size) + shared_mlp_l0_weight_V = chunk_weight(msg.pop("shared mlp l0 weight V"), "column", + args.target_tensor_parallel_size) + shared_mlp_l0_weight = torch.cat((shared_mlp_l0_weight_W, shared_mlp_l0_weight_V), dim=-2) + else: + shared_mlp_l0_weight = chunk_weight(msg.pop("shared mlp l0 weight"), "column", + args.target_tensor_parallel_size) + shared_mlp_l1_weight = chunk_weight(msg.pop("shared mlp l1 weight"), "row", + args.target_tensor_parallel_size) + if hasattr(md, "moe_shared_experts_gate") and md.moe_shared_experts_gate: + shared_experts_gate = msg.pop("shared mlp gate weight") if margs.num_experts: router = msg.pop("router weight") @@ -400,11 +435,15 @@ def chunk_bias(bias, parallel_mode, tp_size=1, ep_size=1): "self_attn_proj_weight" : dense_weight[tp_rank], "mlp_norm_weight" : post_norm_weight } - if margs.num_experts: + if margs.moe_shared_expert_intermediate_size: params_dict.update({ - "mlp_fc1_weight" : mlp_l0_weight[ep_rank][tp_rank], - "mlp_fc2_weight" : mlp_l1_weight[ep_rank][tp_rank] + "shared_mlp_fc1_weight" : shared_mlp_l0_weight[tp_rank], + "shared_mlp_fc2_weight" : shared_mlp_l1_weight[tp_rank] }) + if margs.num_experts: + num_local_experts = margs.num_experts // args.target_expert_parallel_size + params_dict.update(**{f"mlp_fc1_weight.{expert_idx}" : mlp_l0_weight[ep_rank][tp_rank][expert_idx] for expert_idx in range(num_local_experts) }, + **{f"mlp_fc2_weight.{expert_idx}" : mlp_l1_weight[ep_rank][tp_rank][expert_idx] for expert_idx in range(num_local_experts) }) else: params_dict.update({ "mlp_fc1_weight" 
: mlp_l0_weight[tp_rank], @@ -436,6 +475,10 @@ def chunk_bias(bias, parallel_mode, tp_size=1, ep_size=1): params_dict.update({ "router_weight": router }) + if hasattr(md, "moe_shared_experts_gate") and md.moe_shared_experts_gate: + params_dict.update({ + "shared_mlp_gate_weight": shared_experts_gate + }) model = get_local_model(pp_rank, ep_rank, tp_rank) schema.set_layer(model, layer_id, params_dict) diff --git a/tools/checkpoint/saver_qwen2_hf.py b/tools/checkpoint/saver_qwen2_hf.py new file mode 100644 index 0000000000..fccbe86322 --- /dev/null +++ b/tools/checkpoint/saver_qwen2_hf.py @@ -0,0 +1,333 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +import argparse +import json +import pprint +import re +from pathlib import Path +from typing import Optional + +import torch +from safetensors.torch import save_file as safe_save_file + +# The regex to extract layer names. +LAYER_RE = re.compile(r"layers\.(\d+)\.([a-z0-9_.]+)\.([a-z]+)") + +LOCAL_HF_FILES_PATH = Path(__file__).absolute().parent.parent / "local-hf-files" + +try: + from transformers.modeling_utils import shard_checkpoint + + + def save_state_dict(path, state_dict, max_shard_size): + shards_dict, shards_index = shard_checkpoint( + state_dict, max_shard_size, weights_name='model.safetensors') + + # Save index. + if shards_index: + # Only save non-empty shards index. + safe_index_filename = path / "model.safetensors.index.json" + with open(safe_index_filename, 'w', encoding='utf-8') as f_safe_index: + content = json.dumps(shards_index, indent=2, sort_keys=True) + "\n" + f_safe_index.write(content) + + # Save shards. + for shard_file, shard in shards_dict.items(): + shard_filename = path / shard_file + print(f'Saving to shard checkpoint {shard_filename} ...') + safe_save_file(shard, shard_filename, metadata={"format": "pt"}) + +except ImportError: + print('WARNING: Cannot import `transformers.modeling_utils.shard_checkpoint`, ' + 'use `huggingface_hub.split_torch_state_dict_into_shards` instead.') + + from huggingface_hub import split_torch_state_dict_into_shards + + + def save_state_dict(path, state_dict, max_shard_size): + state_dict_split = split_torch_state_dict_into_shards( + state_dict, filename_pattern="model{suffix}.safetensors", max_shard_size=max_shard_size) + + for filename, tensors in state_dict_split.filename_to_tensors.items(): + shard = {tensor: state_dict[tensor] for tensor in tensors} + safe_save_file( + shard, + path / filename, + metadata={"format": "pt"}, + ) + if state_dict_split.is_sharded: + shards_index = { + "metadata": state_dict_split.metadata, + "weight_map": state_dict_split.tensor_to_filename, + } + # Only save non-empty shards index. + safe_index_filename = path / "model.safetensors.index.json" + with open(safe_index_filename, 'w', encoding='utf-8') as f_safe_index: + content = json.dumps(shards_index, indent=2, sort_keys=True) + "\n" + f_safe_index.write(content) + + +def save_hf_checkpoint( + path: Path, + state_dict: dict, + max_shard_size: Optional[str], +): + path.mkdir(exist_ok=True, parents=True) + + if max_shard_size is None: + ckpt_filename = path / 'pytorch_model.bin' + print(f'Saving to no-shard checkpoint ...') + torch.save(state_dict, ckpt_filename) + else: + save_state_dict(path, state_dict, max_shard_size) + print(f'Successful saved checkpoint to {path}') + + +def add_arguments(parser): + group = parser.add_argument_group(title='Hf qwen saver') + parser.add_argument( + "--shard", + type=str, + default=None, + help='Sharded size of converted HF checkpoint, e.g. 
"2GB", "8GB", default to None (no shards)', + ) + return group + + +def split_qkv( + param, num_heads, hidden_size, num_key_value_heads +): + input_shape = param.size() + channels = hidden_size // num_heads + saved_shape = [num_key_value_heads, (num_heads // num_key_value_heads + 2) * channels] + list(input_shape[1:]) + qkv_weight = param.view(*saved_shape) + query, key, value = qkv_weight.split( + [num_heads // num_key_value_heads * channels, channels, channels], dim=1) + + query, key, value = query.contiguous().view([-1] + list(input_shape[1:])), \ + key.contiguous().view([-1] + list(input_shape[1:])), \ + value.contiguous().view([-1] + list(input_shape[1:])) + + return query, key, value + + +def construct_qwen2moe_config( + megatron_cfg: argparse.Namespace, + num_query_groups: int = None, +): + assert getattr(megatron_cfg, 'num_experts', 1) > 1, 'Not a MoE model' + + try: + from transformers.models.qwen2_moe import Qwen2MoeConfig, Qwen2MoeForCausalLM + except ImportError: + raise('Cannot import Qwen2MoeForCausalLM from transformers.') + + print("Converting from megatron to qwen2-moe ...") + + if megatron_cfg.moe_shared_expert_intermediate_size is not None: + moe_shared_expert_intermediate_size = megatron_cfg.moe_shared_expert_intermediate_size + else: + moe_shared_expert_intermediate_size = 0 + + # Spell out all parameters. + qwen2_moe_cfg = Qwen2MoeConfig( + bos_token_id=151643, + eos_token_id=151643, + hidden_size=megatron_cfg.hidden_size, + intermediate_size=megatron_cfg.ffn_hidden_size * megatron_cfg.moe_router_topk, + max_position_embeddings=megatron_cfg.max_position_embeddings, + num_attention_heads=megatron_cfg.num_attention_heads, + num_key_value_heads=megatron_cfg.num_attention_heads, + num_hidden_layers=getattr(megatron_cfg, "num_layers_without_padding", megatron_cfg.num_layers), + rms_norm_eps=megatron_cfg.norm_epsilon, + rope_theta=megatron_cfg.rotary_base, + torch_dtype="bfloat16", + vocab_size=megatron_cfg.padded_vocab_size, + tie_word_embeddings=not megatron_cfg.untie_embeddings_and_output_weights, + decoder_sparse_step=1, + moe_intermediate_size=megatron_cfg.ffn_hidden_size, + shared_expert_intermediate_size=moe_shared_expert_intermediate_size, + num_experts_per_tok=megatron_cfg.moe_router_topk, + num_experts=megatron_cfg.num_experts, + output_router_logits=False, + router_aux_loss_coef=0.001 + ) + if num_query_groups is not None: + qwen2_moe_cfg.num_key_value_heads = num_query_groups + + if getattr(megatron_cfg, 'group_query_attention', False): + # Set from megatron config. 
+ qwen2_moe_cfg.num_key_value_heads = megatron_cfg.num_query_groups + + qwen2_moe_cfg.architectures = ["Qwen2MoeForCausalLM"] + print('Qwen2-MoE config:', qwen2_moe_cfg) + return qwen2_moe_cfg + +def construct_qwen2_config( + megatron_cfg: argparse.Namespace, + num_query_groups: int = None, +): + try: + from transformers.models.qwen2 import Qwen2Config, Qwen2ForCausalLM + except ImportError as e: + print('Cannot import Qwen2Model, please check your transformers install.') + exit(1) + + print("Converting from megatron to qwen2 ...") + + config_dict = dict( + bos_token_id=151643, + eos_token_id=151643, + hidden_size=megatron_cfg.hidden_size, + intermediate_size=megatron_cfg.ffn_hidden_size, + max_position_embeddings=megatron_cfg.max_position_embeddings, + num_attention_heads=megatron_cfg.num_attention_heads, + num_key_value_heads=megatron_cfg.num_attention_heads, + num_hidden_layers=megatron_cfg.num_layers, + rms_norm_eps=megatron_cfg.norm_epsilon, + rope_theta=megatron_cfg.rotary_base, + torch_dtype='bfloat16', + vocab_size=megatron_cfg.padded_vocab_size, + tie_word_embeddings=not megatron_cfg.untie_embeddings_and_output_weights, + ) + qwen2_cfg = Qwen2Config(**config_dict) + + if num_query_groups is not None: + qwen2_cfg.num_key_value_heads = num_query_groups + if getattr(megatron_cfg, 'group_query_attention', False): + # Set from megatron config. + qwen2_cfg.num_key_value_heads = megatron_cfg.num_query_groups + + qwen2_cfg.architectures = ["Qwen2ForCausalLM"] + print('Qwen2 config:', qwen2_cfg) + return qwen2_cfg + +def set_dense_mlp(qwen2_hf, prefix, msg): + mlp_l0_weight_W = msg.pop("mlp l0 weight W") + mlp_l0_weight_V = msg.pop("mlp l0 weight V") + mlp_l1_weight = msg.pop("mlp l1 weight") + qwen2_hf[f"{prefix}.mlp.gate_proj.weight"] = mlp_l0_weight_W + qwen2_hf[f"{prefix}.mlp.up_proj.weight"] = mlp_l0_weight_V + qwen2_hf[f"{prefix}.mlp.down_proj.weight"] = mlp_l1_weight + + +def set_moe_mlp(qwen2_hf, prefix, msg, md): + shared_expert_mlp_l0_weight_W = msg.pop("shared mlp l0 weight W") + shared_expert_mlp_l0_weight_V = msg.pop("shared mlp l0 weight V") + shared_expert_mlp_l1_weight = msg.pop("shared mlp l1 weight") + shared_expert_gate_weight = msg.pop("shared gate weight") + qwen2_hf[f'{prefix}.mlp.shared_expert_gate.weight'] = shared_expert_gate_weight + qwen2_hf[f'{prefix}.mlp.shared_expert.gate_proj.weight'] = shared_expert_mlp_l0_weight_W + qwen2_hf[f'{prefix}.mlp.shared_expert.up_proj.weight'] = shared_expert_mlp_l0_weight_V + qwen2_hf[f'{prefix}.mlp.shared_expert.down_proj.weight'] = shared_expert_mlp_l1_weight + + router_weight = msg.pop("router weight") + qwen2_hf[f'{prefix}.mlp.gate.weight'] = router_weight + + mlp_l0_weight_W = msg.pop("mlp l0 weight W") + mlp_l0_weight_V = msg.pop("mlp l0 weight V") + mlp_l1_weight = msg.pop("mlp l1 weight") + + assert len(mlp_l0_weight_W) == md.num_experts + for expert_idx in range(md.num_experts): + qwen2_hf[prefix + f".mlp.experts.{expert_idx}.gate_proj.weight"] = mlp_l0_weight_W[expert_idx] + qwen2_hf[prefix + f".mlp.experts.{expert_idx}.up_proj.weight"] = mlp_l0_weight_V[expert_idx] + qwen2_hf[prefix + f".mlp.experts.{expert_idx}.down_proj.weight"] = mlp_l1_weight[expert_idx] + + +def save_checkpoint(queue, args): + import os + import sys + sys.path.append(os.path.abspath( + os.path.join(os.path.dirname(__file__), + os.path.pardir, + os.path.pardir))) + if args.megatron_path is not None: + sys.path.insert(0, args.megatron_path) + + def queue_get(name=None): + val = queue.get() + if val == "exit": + print("Loader exited, exiting saver") 
+ exit(1) + if name is not None and args.checking and val["name"] != name: + val_name = val["name"] + print(f'Unexpected message. Expecting "{name}" but got "{val_name}". Exiting saver.') + exit(1) + if name is not None: + print(f"received {name}") + return val + + md = queue_get() + qwen2_hf = {} + # Embeddings + # ----------- + embeddings_msg = queue_get("embeddings") + # Truncate the embedding table to vocab_size rows. + word_embeddings = embeddings_msg['word embeddings'][: md.padded_vocab_size, :] + qwen2_hf["model.embed_tokens.weight"] = word_embeddings + + # Transformer layers. + # ------------------ + total_layer_num = 0 + for layer_idx in range(md.num_layers): + layer_name = f"model.layers.{layer_idx}" + + msg = queue_get(f"transformer layer {total_layer_num}") + + input_norm_weight = msg.pop("input norm weight") + post_norm_weight = msg.pop("post norm weight") + qwen2_hf[layer_name + ".input_layernorm.weight"] = input_norm_weight + qwen2_hf[layer_name + ".post_attention_layernorm.weight"] = post_norm_weight + + # attention + qkv_weight = msg.pop("qkv weight") + dense_weight = msg.pop("dense weight") + + hidden_size = md.hidden_size + heads = md.num_attention_heads + num_key_value_heads = md.num_query_groups + q, k, v = split_qkv( + qkv_weight, heads, hidden_size, num_key_value_heads + ) + qwen2_hf[layer_name + ".self_attn.q_proj.weight"] = q + qwen2_hf[layer_name + ".self_attn.k_proj.weight"] = k + qwen2_hf[layer_name + ".self_attn.v_proj.weight"] = v + # Transpose the bias. + if md.qkv_bias: + qkv_bias = msg.pop("qkv bias") + q_b, k_b, v_b = split_qkv( + qkv_bias, heads, hidden_size, num_key_value_heads + ) + qwen2_hf[layer_name + ".self_attn.q_proj.bias"] = q_b + qwen2_hf[layer_name + ".self_attn.k_proj.bias"] = k_b + qwen2_hf[layer_name + ".self_attn.v_proj.bias"] = v_b + qwen2_hf[layer_name + ".self_attn.o_proj.weight"] = dense_weight + + # mlp + if md.num_experts: + set_moe_mlp(qwen2_hf, layer_name, msg, md) + else: + set_dense_mlp(qwen2_hf, layer_name, msg) + + total_layer_num = total_layer_num + 1 + + msg = queue_get("final norm") + final_norm_weight = msg.pop("weight") + qwen2_hf["model.norm.weight"] = final_norm_weight + + if md.output_layer: + msg = queue_get("output layer") + # LM head + if md.untie_embeddings_and_output_weights: + qwen2_hf["lm_head.weight"] = msg.pop("weight") + else: + qwen2_hf["lm_head.weight"] = word_embeddings + + if md.num_experts: + qwen2_cfg = construct_qwen2moe_config(md, num_query_groups=md.num_query_groups) + else: + qwen2_cfg = construct_qwen2_config(md, num_query_groups=md.num_query_groups) + save_hf_checkpoint(Path(args.save_dir), qwen2_hf, args.shard) + qwen2_cfg.save_pretrained(Path(args.save_dir)) + print("Done!") diff --git a/tools/checkpoint/schema_mcore.py b/tools/checkpoint/schema_mcore.py index ef90ff0aa3..bcd7b5b47a 100644 --- a/tools/checkpoint/schema_mcore.py +++ b/tools/checkpoint/schema_mcore.py @@ -123,6 +123,9 @@ def __init__(self, model_type, num_experts, expert_model_parallel_size): **{f"mlp_fc1_weight.{expert_idx}" : f"mlp.experts.local_experts.{expert_idx}.linear_fc1.weight" for expert_idx in range(num_local_experts) }, **{f"mlp_fc2_weight.{expert_idx}" : f"mlp.experts.local_experts.{expert_idx}.linear_fc2.weight" for expert_idx in range(num_local_experts) }, + "shared_mlp_fc1_weight": "mlp.shared_experts.linear_fc1.weight", + "shared_mlp_fc2_weight": "mlp.shared_experts.linear_fc2.weight", + "shared_mlp_gate_weight": "mlp.shared_experts.gate_weight", })
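A possible follow-up for examples/qwen/README.md: after the MCore to HF conversion, the result can be sanity-checked by loading it with `transformers` and running a short greedy generation. This is only a sketch; the paths and prompt are placeholders, and the tokenizer is taken from the original HF repo because saver_qwen2_hf.py writes only the weights and config.json:
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

CONVERTED_DIR = "/workspace/checkpoints/qwen2-hf"   # output of the MCore -> HF conversion
TOKENIZER_REF = "Qwen/Qwen2.5-3B-Instruct"          # original repo; the saver does not emit tokenizer files

tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_REF)
model = AutoModelForCausalLM.from_pretrained(CONVERTED_DIR, torch_dtype=torch.bfloat16)

inputs = tokenizer("Hello, my name is", return_tensors="pt")
with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=20, do_sample=False)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```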