From eef9f83d9b1de3f2a7b6c3ae8ce4b212d9d95466 Mon Sep 17 00:00:00 2001 From: Tianyu Zhao Date: Mon, 22 Jan 2024 14:54:08 +0900 Subject: [PATCH] Added qwen adaptation code, conversion code, and scripts. --- .../convert_hf_checkpoint_to_nemo_qwen.py | 260 ++++++++++++++ .../convert_nemo_checkpoint_to_hf_qwen.py | 222 ++++++++++++ .../conf/megatron_llama_70b_config.yaml | 1 + .../conf/megatron_llama_config.yaml | 1 + .../conf/megatron_qwen_config.yaml | 329 ++++++++++++++++++ .../nlp/language_modeling/qwen_14b.sh | 28 ++ .../nlp/language_modeling/test_qwen.sh | 72 ++++ .../tokenizers/huggingface/auto_tokenizer.py | 10 +- .../language_modeling/megatron/gpt_model.py | 2 + .../language_modeling/megatron_gpt_model.py | 1 + .../modules/common/megatron/language_model.py | 4 + .../modules/common/megatron/transformer.py | 14 +- nemo/requirements/requirements.txt | 2 + 13 files changed, 941 insertions(+), 5 deletions(-) create mode 100644 nemo/examples/nlp/language_modeling/checkpoint_conversion/convert_hf_checkpoint_to_nemo_qwen.py create mode 100644 nemo/examples/nlp/language_modeling/checkpoint_conversion/convert_nemo_checkpoint_to_hf_qwen.py create mode 100644 nemo/examples/nlp/language_modeling/conf/megatron_qwen_config.yaml create mode 100755 nemo/examples/nlp/language_modeling/qwen_14b.sh create mode 100755 nemo/examples/nlp/language_modeling/test_qwen.sh diff --git a/nemo/examples/nlp/language_modeling/checkpoint_conversion/convert_hf_checkpoint_to_nemo_qwen.py b/nemo/examples/nlp/language_modeling/checkpoint_conversion/convert_hf_checkpoint_to_nemo_qwen.py new file mode 100644 index 0000000..262017f --- /dev/null +++ b/nemo/examples/nlp/language_modeling/checkpoint_conversion/convert_hf_checkpoint_to_nemo_qwen.py @@ -0,0 +1,260 @@ +import argparse +import json +from pathlib import Path +import numpy as np +import torch +import os +import glob +from transformers import AutoModelForCausalLM + + +def fix_query_key_value_ordering( + param, checkpoint_version, num_splits, num_heads, hidden_size +): + # Permutes layout of param tensor to [num_splits * num_heads * hidden_size, :] + # for compatibility with later versions of NVIDIA Megatron-LM. + # The inverse operation is performed inside Megatron-LM to read checkpoints: + # https://github.com/NVIDIA/Megatron-LM/blob/v2.4/megatron/checkpointing.py#L209 + # If param is the weight tensor of the self-attention block, the returned tensor + # will have to be transposed one more time to be read by HuggingFace GPT2. + input_shape = param.size() + if checkpoint_version == 1.0: + # version 1.0 stores [num_heads * hidden_size * num_splits, :] + saved_shape = (num_heads, hidden_size, num_splits) + input_shape[1:] + param = param.view(*saved_shape) + param = param.transpose(0, 2) + param = param.transpose(1, 2).contiguous() + elif checkpoint_version >= 2.0: + # other versions store [num_heads * num_splits * hidden_size, :] + saved_shape = (num_splits, num_heads, hidden_size) + input_shape[1:] + param = param.view(*saved_shape) + param = param.transpose(0, 1).contiguous() + param = param.view(*input_shape) + return param + + +def convert_checkpoint(args): + + with open(args.config_file, "r") as f: + config = json.load(f) + print(config) + br_key = "h." 
# Used to filter all transformer layers except layernorm + + translation = { + "model.language_model.embedding.word_embeddings.weight": (1, "transformer.wte.weight", 0, 0), # a['model']['language_model']['word_embeddings']['weight'] + "input_layernorm.weight": (0, "ln_1.weight", None, 0), + "self_attention.query_key_value.weight": (1, "attn.c_attn.weight", 0, 0), + "self_attention.query_key_value.bias": (1, "attn.c_attn.bias", 0, 0), + "self_attention.dense.weight": (1, "attn.c_proj.weight", 1, 0), + "post_attention_layernorm.weight": (0, "ln_2.weight", None, 0), + "self_attention.core_attention.rotary_emb.inv_freq": (0, "rotary_emb.inv_freq", None, 0), + "mlp.dense_h_to_4h.weight": (1, "mlp.w2.weight", 0, 0), + "mlp.dense_h_to_4h_2.weight": (1, "mlp.w1.weight", 0, 0), + "mlp.dense_4h_to_h.weight": (1, "mlp.c_proj.weight", 1, 0), + "model.language_model.encoder.final_layernorm.weight": (0, "transformer.ln_f.weight", None, 0), + "model.language_model.output_layer.weight": (1, "lm_head.weight", 0, 0), # this is shared + } + + reverse_translation = {} + for k, v in translation.items(): + split, br_k, dim, transpose = v + reverse_translation[br_k] = (split, k, dim, transpose) + + TP = args.tp_degree + PP = args.pp_degree + + hf_model = AutoModelForCausalLM.from_pretrained(args.path_to_checkpoint, trust_remote_code=True) + # hf_model.resize_token_embeddings(pad_to_multiple_of=128) + model_bedrock = hf_model.state_dict() + + for i in range(config["num_hidden_layers"]): + model_bedrock[f"transformer.h.{i}.rotary_emb.inv_freq"] = hf_model.transformer.rotary_emb.inv_freq + + print(list(model_bedrock.keys())) + + print("Loaded QWen model") + + + for p in range(PP): + for i in range(TP): + print(f"=== PP {p}, TP {i} ===") + nemo_model = {} + for k, v in model_bedrock.items(): + # print(f">>> {k}") + if "attention.masked_bias" in k: + # We don't want to copy attention mask bias, since its a constant of 1e4 + continue + if br_key in k: + parts = k.split(br_key)[1].split(".") + layer_number = parts[0] + if int(layer_number) >= (config["num_hidden_layers"]//PP)*(p+1) or int(layer_number) < (config["num_hidden_layers"]//PP)*p: + continue + k = ".".join(parts[1:]) + split, key, dim, tranpose = reverse_translation[k] + layer_number = layer_number if PP == 1 else str(int(layer_number) % (config["num_hidden_layers"]//PP)) + key = "model.language_model.encoder.layers." + layer_number + "." 
+ key + nemo_model[key] = v + if tranpose: + nemo_model[key]= torch.transpose( + nemo_model[key], 0, 1 + ) + + if "query_key_value" in (key): + heads = config["num_attention_heads"] + hidden_size = config["hidden_size"] + hidden_size_per_head = config["hidden_size"] // heads + + def permute_rotary(w): + assert w.shape == (heads, hidden_size_per_head, hidden_size*3) + return ( + w.view(heads, hidden_size_per_head // 2, 2, hidden_size*3) + .transpose(1, 2) + .reshape(heads, hidden_size_per_head, hidden_size*3) + ) + + def permute(w, n_heads=heads, dim1=hidden_size, dim2=hidden_size*3): + return w.view(n_heads, 2, dim1 // n_heads // 2, dim2).transpose(1, 2).reshape(dim1, dim2) + + if "weight" in key: + nemo_model[key] = permute_rotary( + permute(nemo_model[key]).view( + heads, hidden_size_per_head, hidden_size*3 + ) + ) + nemo_model[key] = nemo_model[key].view( + 3, + heads, + hidden_size_per_head, + hidden_size, + ).transpose(0, 1).contiguous().view( + heads * 3 * hidden_size_per_head, + hidden_size, + ) + nemo_model[key] = nemo_model[key].view( + TP, + heads * 3 * hidden_size_per_head // TP, + hidden_size, + ) + elif "bias" in key: + nemo_model[key] = nemo_model[key].view( + 3, + heads, + hidden_size_per_head, + ).transpose(0, 1).contiguous().view( + heads * 3 * hidden_size_per_head + ) + nemo_model[key] = nemo_model[key].view( + TP, + heads * 3 * hidden_size_per_head // TP, + ) + + if split: + if "query_key_value" in (key): + nemo_model[key] = nemo_model[key][i] + else: + tp_last_dim_size = nemo_model[key].shape[dim] // TP + if dim: # First or last dimension to shard + nemo_model[key] = nemo_model[key][ + ..., i * tp_last_dim_size : (i + 1) * tp_last_dim_size + ].clone() + else: + nemo_model[key] = nemo_model[key][ + i * tp_last_dim_size : (i + 1) * tp_last_dim_size, ... + ].clone() + + print(key, split, nemo_model[key].shape, v.shape) + else: + split, key, dim, transpose = reverse_translation[k] + if "wte" in k and p==0: + # Padding to make it divisble by TP degree + if v.shape[0] % TP > 0: + x = torch.nn.functional.pad( + v, (0, 0, 0, (TP - v.shape[0] % TP)) + ) + else: + x = v + last_dim_size = x.shape[0] + tp_last_dim_size = last_dim_size // TP + nemo_model[key] = x[ + i * tp_last_dim_size : (i + 1) * tp_last_dim_size, ... 
+ ].clone() + print(key, split, nemo_model[key].shape, v.shape) + elif "transformer.ln_f" in k and p == (PP-1): + nemo_model[key] = v + print(key, split, nemo_model[key].shape, v.shape) + elif "lm_head" in k and p == (PP-1): + # Padding to make it divisble by TP degree + if v.shape[0] % TP > 0: + x = torch.nn.functional.pad( + v, (0, 0, 0, (TP - v.shape[0] % TP)) + ) + else: + x = v + if split: + tp_last_dim_size = x.shape[dim]//TP + if dim: + nemo_model[key] = x[..., i*tp_last_dim_size:(i+1)*tp_last_dim_size].clone() + else: + nemo_model[key] = x[i*tp_last_dim_size:(i+1)*tp_last_dim_size, ...].clone() + print(key, split, nemo_model[key].shape, v.shape) + + if args.save_bf16: + for _k in nemo_model: + nemo_model[_k] = nemo_model[_k].to(dtype=torch.bfloat16, device='cpu') + out_model = {"state_dict": nemo_model} + + output_folder = args.output_path + if TP > 1: + if PP>1: + output_folder = output_folder + f"/tp_rank_{i:02d}" + else: + output_folder = output_folder + f"/mp_rank_{i:02d}" + if PP > 1: + output_folder = output_folder + f"_pp_rank_{p:03d}" + if not os.path.exists(output_folder): + os.makedirs(output_folder) + torch.save(out_model, f"{output_folder}/model_optim_rng.ckpt") #, (not master_only), global_master=True) + + print("Done saving Megatron checkpoint") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--checkpoint_version", default=2.0) + parser.add_argument( + "--path_to_checkpoint", + type=str, + help="Path to the checkpoint file (.zip archive or direct .pt file)", + ) + parser.add_argument( + "--config_file", + default="", + type=str, + help="An optional config json file describing the pre-trained model.", + ) + parser.add_argument( + "--output_path", + default="", + type=str, + help="output path", + ) + parser.add_argument( + "--tp_degree", + default=1, + type=int, + help="Tensor parallelism", + ) + parser.add_argument( + "--pp_degree", + default=1, + type=int, + help="Pipeline parallelism", + ) + parser.add_argument( + "--save_bf16", + default=False, + type=bool, + help="Save weights in bf16.", + ) + args = parser.parse_args() + convert_checkpoint(args) diff --git a/nemo/examples/nlp/language_modeling/checkpoint_conversion/convert_nemo_checkpoint_to_hf_qwen.py b/nemo/examples/nlp/language_modeling/checkpoint_conversion/convert_nemo_checkpoint_to_hf_qwen.py new file mode 100644 index 0000000..e68472e --- /dev/null +++ b/nemo/examples/nlp/language_modeling/checkpoint_conversion/convert_nemo_checkpoint_to_hf_qwen.py @@ -0,0 +1,222 @@ +import os +import argparse +import json +from pathlib import Path, PurePath +from os.path import join +from glob import glob +import re +import numpy as np +import torch +import torch_xla.utils.serialization as xser + + +def fix_query_key_value_ordering(param, num_heads, hidden_size_per_head): + input_shape = param.size() + saved_shape = (num_heads, 3, hidden_size_per_head) + input_shape[1:] + param = param.view(*saved_shape) + param = param.transpose(0, 1).contiguous() + param = param.view(*input_shape) + return param + + +def get_tp_pp_degree(path_to_checkpoints): + + dir_name = PurePath(path_to_checkpoints).name + TP = 1 + PP = 1 + for folder in os.listdir(path_to_checkpoints): + pp_search = re.search('pp_rank_[\d]*', folder) + if pp_search: + PP = max(PP, 1+int(pp_search[0].split('pp_rank_')[1])) + if PP>1: + tp_search = re.search('tp_rank_[\d]*', folder) + if tp_search: + TP = max(TP, 1+int(tp_search[0].split('tp_rank_')[1])) + else: + tp_search = re.search('mp_rank_[\d]*', folder) + if 
tp_search: + TP = max(TP, 1+int(tp_search[0].split('mp_rank_')[1])) + + return TP, PP + + +def _get_tp_str(tp: int): + tp_template = '00' + tp = str(tp) + leading_zeros = len(tp_template) - len(tp) + return ''.join(['0'] * leading_zeros + list(tp)) + + +def _get_pp_str(pp: int): + pp_template = '000' + pp = str(pp) + leading_zeros = len(pp_template) - len(pp) + return ''.join(['0'] * leading_zeros + list(pp)) + + +def get_checkpoints_for_pp(pp: int, path_to_checkpoints: str, PP: int=1, TP: int=1, is_xser: bool=False): + """ + Returns all checkpoints for specified PP rank + """ + if PP == 1 and TP == 1: + pp_str = "" + else: + pp_str = f'tp_rank_*_pp_rank_{_get_pp_str(pp)}' if PP > 1 else "mp_rank_*" + + template = join(path_to_checkpoints, pp_str, '*.ckpt') + + # take largest step saved model from the available checkpoints + max_step_recorded=max( + [int(re.match(r".*megatron_llama--step=(\d+).*ckpt$", i).group(1)) + for i in glob(template)]) + template = join(path_to_checkpoints, pp_str, f'*megatron_llama--step={max_step_recorded}*.ckpt') + + tp_paths = sorted(glob(template)) + return {i: xser.load(p)['state_dict'] if is_xser else torch.load(p)['state_dict'] for i, p in enumerate(tp_paths)} + + +def get_checkpoints_for_tp(tp: int, path_to_checkpoints: str, is_xser: bool=False): + """ + Returns all checkpoints for specified TP rank + """ + tp_str = _get_tp_str(tp) + template = join(path_to_checkpoints, f'tp_rank_{tp_str}_pp_rank_*', '*.ckpt') + + pp_paths = sorted(glob(template)) + return {i: xser.load(p)['state_dict'] if is_xser else torch.load(p)['state_dict'] for i, p in enumerate(pp_paths)} + + +def _get_nemo_key(k, nemo_key = 'model.language_model.'): + if "final_layernorm" in k: + nemo_key += 'encoder.' + return k.replace(nemo_key, '') + + +def convert_checkpoint(config_file, + path_to_checkpoints, + output_path, + checkpoint_version=2.0, + is_xser=False, + bf16=False): + + with open(config_file, "r") as f: + config = json.load(f) + + translation = { + "embedding.word_embeddings.weight": (1, "transformer.wte.weight", 0, 0), # a['model']['language_model']['word_embeddings']['weight'] + "input_layernorm.weight": (0, "ln_1.weight", None, 0), + "self_attention.query_key_value.weight": (1, "attn.c_attn.weight", 0, 0), + "self_attention.query_key_value.bias": (1, "attn.c_attn.bias", 0, 0), + "self_attention.dense.weight": (1, "attn.c_proj.weight", 1, 0), + "post_attention_layernorm.weight": (0, "ln_2.weight", None, 0), + "self_attention.core_attention.rotary_emb.inv_freq": (0, "rotary_emb.inv_freq", None, 0), + "mlp.dense_h_to_4h.weight": (1, "mlp.w2.weight", 0, 0), + "mlp.dense_h_to_4h_2.weight": (1, "mlp.w1.weight", 0, 0), + "mlp.dense_4h_to_h.weight": (1, "mlp.c_proj.weight", 1, 0), + "final_layernorm.weight": (0, "transformer.ln_f.weight", None, 0), + "output_layer.weight": (1, "lm_head.weight", 0, 0), # this is shared + } + + nemo_key = "model.language_model." + br_key = "transformer.h." 
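+    # Note on the translation table above: each value is a tuple
+    # (tp_split, hf_key, split_dim, transpose). tp_split marks tensors that are
+    # sharded across tensor-parallel ranks and must be concatenated along
+    # split_dim, hf_key is the target HuggingFace Qwen parameter name, and
+    # transpose marks weights stored transposed relative to the HF layout.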
+ + TP, PP = get_tp_pp_degree(path_to_checkpoints) + print(f"TP: {TP}, PP: {PP}") + + heads = config["num_attention_heads"] + hidden_size_per_head = config["hidden_size"] // heads + + hf_model = {} + + layer_re = re.compile("model.language_model.encoder.layers\.(\d+)\.([a-z0-9_.]+)\.([a-z_]+)") + + for pp in range(PP): + print(f"Loading PP={pp}") + tp_models = get_checkpoints_for_pp(pp, path_to_checkpoints, PP, TP, is_xser) + layer_keys = tp_models[0].keys() + for k in layer_keys: + print(f">> {k}") + if "position_embeddings" in k: + nemo_key = _get_nemo_key(k) + _, key, _, _ = translation[nemo_key] + hf_model[key] = tp_models[0][k] + continue + + if "word_embeddings" in k: + nemo_key = _get_nemo_key(k) + split, key, dim, transpose = translation[nemo_key] + hf_model[key] = torch.concat([tp_models[i][k] for i in range(len(tp_models))], dim=0) + continue + + if "output_layer" in k: + nemo_key = _get_nemo_key(k) + split, key, dim, transpose = translation[nemo_key] + hf_model[key] = torch.concat([tp_models[i][k] for i in range(len(tp_models))], dim=dim) + continue + + if "final_layernorm" in k: + nemo_key = _get_nemo_key(k) + split, key, dim, transpose = translation[nemo_key] + hf_model[key] = tp_models[0][k] + continue + + m = layer_re.match(k) + layer_idx = m.group(1) + op_name = m.group(2) + weight_or_bias = m.group(3) + nemo_key = f"{op_name}.{weight_or_bias}" + split, key, dim, transpose = translation[nemo_key] + ln_idx= int(layer_idx) + pp*(config["num_hidden_layers"]//PP) + hf_key = f"{br_key}{ln_idx}.{key}" + if split: + hf_model[hf_key] = torch.concat([tp_models[i][k] for i in range(len(tp_models))], dim=dim) + else: + hf_model[hf_key] = tp_models[0][k] + + if "query_key_value" in k: + hf_model[hf_key] = fix_query_key_value_ordering(hf_model[hf_key], heads, hidden_size_per_head) + + if transpose: + hf_model[hf_key] = torch.transpose(hf_model[hf_key], 0, 1) + + if args.bf16: + for k, v in hf_model.items(): + hf_model[k] = v.to(torch.bfloat16) + + path = Path(output_path) + path.mkdir(parents=True, exist_ok=True) + torch.save(hf_model, str(path)+"/pytorch_model.bin") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--checkpoint_version", default=2.0) + parser.add_argument( + "--path_to_checkpoints", + type=str, + help="Path to the checkpoints from creating NeMo checkpoint files using `convert_hf_checkpoint_to_nemo.py`", + required=True + ) + parser.add_argument( + "--config_file", + type=str, + help="The config json file describing the pre-trained model.", + required=True + ) + parser.add_argument( + "--output_path", + default="", + type=str, + help="output path", + ) + parser.add_argument( + "--is_xser", + default=False, + type=bool + ) + parser.add_argument( + "--bf16", + action="store_true" + ) + args = parser.parse_args() + convert_checkpoint(args.config_file, args.path_to_checkpoints, args.output_path, args.checkpoint_version, args.is_xser, args.bf16) diff --git a/nemo/examples/nlp/language_modeling/conf/megatron_llama_70b_config.yaml b/nemo/examples/nlp/language_modeling/conf/megatron_llama_70b_config.yaml index 461672b..f55becb 100644 --- a/nemo/examples/nlp/language_modeling/conf/megatron_llama_70b_config.yaml +++ b/nemo/examples/nlp/language_modeling/conf/megatron_llama_70b_config.yaml @@ -79,6 +79,7 @@ model: activation: 'swiglu' # ['swiglu', 'gelu'] transformer_block_type: 'pre_ln' # ['pre_ln', 'post_ln', 'normformer', 'gpt_j'] 
https://github.com/EleutherAI/gpt-neox/blob/303d7be582ae1c969347c25c54f568cc122445fc/megatron/model/transformer.py#L804-L847 has_bias: False + has_bias_qkv: False num_kv_heads: 8 tokenizer: diff --git a/nemo/examples/nlp/language_modeling/conf/megatron_llama_config.yaml b/nemo/examples/nlp/language_modeling/conf/megatron_llama_config.yaml index e57d01f..cef2959 100644 --- a/nemo/examples/nlp/language_modeling/conf/megatron_llama_config.yaml +++ b/nemo/examples/nlp/language_modeling/conf/megatron_llama_config.yaml @@ -78,6 +78,7 @@ model: activation: 'swiglu' # ['swiglu', 'gelu'] transformer_block_type: 'pre_ln' # ['pre_ln', 'post_ln', 'normformer', 'gpt_j'] https://github.com/EleutherAI/gpt-neox/blob/303d7be582ae1c969347c25c54f568cc122445fc/megatron/model/transformer.py#L804-L847 has_bias: False + has_bias_qkv: False tokenizer: library: 'huggingface' diff --git a/nemo/examples/nlp/language_modeling/conf/megatron_qwen_config.yaml b/nemo/examples/nlp/language_modeling/conf/megatron_qwen_config.yaml new file mode 100644 index 0000000..2b4cac1 --- /dev/null +++ b/nemo/examples/nlp/language_modeling/conf/megatron_qwen_config.yaml @@ -0,0 +1,329 @@ +name: megatron_llama +restore_from_path: null # used when starting from a .nemo file + +trainer: + devices: 1 + num_nodes: 1 + accelerator: tpu + precision: 32 + logger: False # logger provided by exp_manager + enable_checkpointing: False + replace_sampler_ddp: False + max_epochs: -1 # PTL default. In practice, max_steps will be reached first. + max_steps: 100000 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches + log_every_n_steps: 10 + val_check_interval: 100 + limit_val_batches: 50 + limit_test_batches: 500 + accumulate_grad_batches: 1 # do not modify, grad acc is automatic for training megatron models + gradient_clip_val: 1.0 + benchmark: False + enable_model_summary: False # default PTL callback for this does not support model parallelism, instead we log manually + +exp_manager: + create_tensorboard_logger: True + explicit_log_dir: null + exp_dir: null + name: megatron_llama + create_wandb_logger: False + wandb_logger_kwargs: + project: null + name: null + resume_if_exists: True + resume_ignore_no_checkpoint: True + create_checkpoint_callback: True + checkpoint_callback_params: + monitor: step + save_top_k: 1 + mode: max + save_last: False + always_save_nemo: False # saves nemo file during validation, not implemented for model parallel + save_nemo_on_train_end: False # not recommended when training large models on clusters with short time limits + filename: 'megatron_llama--{step}-{consumed_samples}' + model_parallel_size: ${multiply:${model.tensor_model_parallel_size}, ${model.pipeline_model_parallel_size}} + +model: + # specify micro_batch_size, global_batch_size, and model parallelism + # gradient accumulation will be done automatically based on data_parallel_size + micro_batch_size: 4 # limited by GPU memory + global_batch_size: 8 # will use more micro batches to reach global batch size + tensor_model_parallel_size: 1 # intra-layer model parallelism + pipeline_model_parallel_size: 1 # inter-layer model parallelism + virtual_pipeline_model_parallel_size: null # interleaved pipeline + + # model architecture + encoder_seq_length: 512 + max_position_embeddings: ${.encoder_seq_length} + num_layers: 12 + hidden_size: 768 + ffn_hidden_size: 3072 # Transformer FFN hidden size. 
For Llama it's 8/3*hidden_size + num_attention_heads: 12 + init_method_std: 0.02 # Standard deviation of the zero mean normal distribution used for weight initialization.') + use_scaled_init_method: True # use scaled residuals initialization + hidden_dropout: 0 # Dropout probability for hidden state transformer. + attention_dropout: 0 # Dropout probability in the attention layer. + ffn_dropout: 0 # Dropout probability in the feed-forward layer. + kv_channels: null # Projection weights dimension in multi-head attention. Set to hidden_size // num_attention_heads if null + apply_query_key_layer_scaling: True # scale Q * K^T by 1 / layer-number. + normalization: 'rmsnorm' # Type of normalization layers ['rmsnorm', 'layernorm'] + layernorm_epsilon: 1e-5 + do_layer_norm_weight_decay: False # True means weight decay on all params + make_vocab_size_divisible_by: 8 # Pad the vocab size to be divisible by this value for computation efficiency. + pre_process: True # add embedding + post_process: True # add pooler + persist_layer_norm: True # Use of persistent fused layer norm kernel. + share_embeddings_and_output_weights: False # Untie embedding and output layer weights. + position_embedding_type: 'rope' # Position embedding type. Options ['learned_absolute', 'rope] + rotary_percentage: 1 # If using position_embedding_type=rope, then the per head dim is multiplied by this. + activation: 'swiglu' # ['swiglu', 'gelu'] + transformer_block_type: 'pre_ln' # ['pre_ln', 'post_ln', 'normformer', 'gpt_j'] https://github.com/EleutherAI/gpt-neox/blob/303d7be582ae1c969347c25c54f568cc122445fc/megatron/model/transformer.py#L804-L847 + has_bias: False + has_bias_qkv: True + + tokenizer: + library: 'huggingface' + type: '/root/scripts/data/llama7b-hf' + model: null + vocab_file: null + merge_file: null + delimiter: null # only used for tabular tokenizer + sentencepiece_legacy: False # Legacy=True allows you to add special tokens to sentencepiece tokenizers. + use_fast: False + + # Mixed precision + native_amp_init_scale: 4294967296 # 2 ** 32 + native_amp_growth_interval: 1000 + hysteresis: 2 # Gradient scale hysteresis + fp32_residual_connection: False # Move residual connections to fp32 + fp16_lm_cross_entropy: False # Move the cross entropy unreduced loss calculation for lm head to fp16 + + # Megatron O2-style half-precision + megatron_amp_O2: False # Enable O2-level automatic mixed precision using main parameters + grad_allreduce_chunk_size_mb: 125 + + # Fusion + grad_div_ar_fusion: False # Fuse grad division into torch.distributed.all_reduce. Only used with O2 and no pipeline parallelism.. + gradient_accumulation_fusion: False # Fuse weight gradient accumulation to GEMMs. Only used with pipeline parallelism and O2. + bias_activation_fusion: False # Use a kernel that fuses the bias addition from weight matrices with the subsequent activation function. + bias_dropout_add_fusion: False # Use a kernel that fuses the bias addition, dropout and residual connection addition. + masked_softmax_fusion: False # Use a kernel that fuses the attention softmax with it's mask. + + + # Miscellaneous + seed: 1234 + resume_from_checkpoint: null # manually set the checkpoint file to load from + use_cpu_initialization: False # Init weights on the CPU (slow for large models) + onnx_safe: False # Use work-arounds for known problems with Torch ONNX exporter. + apex_transformer_log_level: 30 # Python logging level displays logs with severity greater than or equal to this + gradient_as_bucket_view: True # PyTorch DDP argument. 
Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory) + sync_batch_comm: False # Enable stream synchronization after each p2p communication between pipeline stages + log_parameter_norm: True # Logs parameter norm across model parallel ranks + log_gradient_norm: True # Logs gradient norm across model parallel ranks + + ## Activation Checkpointing + # NeMo Megatron supports 'selective' activation checkpointing where only the memory intensive part of attention is checkpointed. + # These memory intensive activations are also less compute intensive which makes activation checkpointing more efficient for LLMs (20B+). + # See Reducing Activation Recomputation in Large Transformer Models: https://arxiv.org/abs/2205.05198 for more details. + # 'full' will checkpoint the entire transformer layer. + activations_checkpoint_granularity: selective # 'selective' or 'full' + activations_checkpoint_method: uniform # 'uniform', 'block' + # 'uniform' divides the total number of transformer layers and checkpoints the input activation + # of each chunk at the specified granularity. When used with 'selective', 'uniform' checkpoints all attention blocks in the model. + # 'block' checkpoints the specified number of layers per pipeline stage at the specified granularity + activations_checkpoint_num_layers: 1 + # when using 'uniform' this creates groups of transformer layers to checkpoint. Usually set to 1. Increase to save more memory. + # when using 'block' this this will checkpoint the first activations_checkpoint_num_layers per pipeline stage. + num_micro_batches_with_partial_activation_checkpoints: null + # This feature is valid only when used with pipeline-model-parallelism. + # When an integer value is provided, it sets the number of micro-batches where only a partial number of Transformer layers get checkpointed + # and recomputed within a window of micro-batches. The rest of micro-batches in the window checkpoint all Transformer layers. The size of window is + # set by the maximum outstanding micro-batch backpropagations, which varies at different pipeline stages. The number of partial layers to checkpoint + # per micro-batch is set by 'activations_checkpoint_num_layers' with 'activations_checkpoint_method' of 'block'. + # This feature enables using activation checkpoint at a fraction of micro-batches up to the point of full GPU memory usage. + activations_checkpoint_layers_per_pipeline: null + # This feature is valid only when used with pipeline-model-parallelism. + # When an integer value (rounded down when float is given) is provided, it sets the number of Transformer layers to skip checkpointing at later + # pipeline stages. For example, 'activations_checkpoint_layers_per_pipeline' of 3 makes pipeline stage 1 to checkpoint 3 layers less than + # stage 0 and stage 2 to checkpoint 6 layers less stage 0, and so on. This is possible because later pipeline stage + # uses less GPU memory with fewer outstanding micro-batch backpropagations. Used with 'num_micro_batches_with_partial_activation_checkpoints', + # this feature removes most of activation checkpoints at the last pipeline stage, which is the critical execution path. + + ## Sequence Parallelism + # Makes tensor parallelism more memory efficient for LLMs (20B+) by parallelizing layer norms and dropout sequentially + # See Reducing Activation Recomputation in Large Transformer Models: https://arxiv.org/abs/2205.05198 for more details. 
+ sequence_parallel: True + + ## Zero Redundancy Optimizer + # Wraps your chosen optimizer with a Zero Redundancy Optimizer + # Partitions optimizer states across ranks reducing memory consumption + wrap_with_zero: False + + ## Transformer Engine + transformer_engine: False + fp8: False # enables fp8 in TransformerLayer forward + fp8_e4m3: False # sets fp8_format = recipe.Format.E4M3 + fp8_hybrid: False # sets fp8_format = recipe.Format.HYBRID + fp8_margin: 0 # scaling margin + fp8_interval: 1 # scaling update interval + fp8_amax_history_len: 1 # Number of steps for which amax history is recorded per tensor + fp8_amax_compute_algo: most_recent # 'most_recent' or 'max'. Algorithm for computing amax from history + use_emha: False # Use fused multi-head attention for large sequence-length. Note this is not yet supported. Please set to False. + + convert_to_hf: False # convert model to Huggingface format + output_dir: null # output directory to save converted model + config_path: null # path to HF config.json file + + data: + # Path to data must be specified by the user. + # Supports List, String and Dictionary + # List : can override from the CLI: "model.data.data_prefix=[.5,/raid/data/pile/my-gpt3_00_text_document,.5,/raid/data/pile/my-gpt3_01_text_document]", + # Or see example below: + # data_prefix: + # - .5 + # - /raid/data/pile/my-gpt3_00_text_document + # - .5 + # - /raid/data/pile/my-gpt3_01_text_document + # Dictionary: can override from CLI "model.data.data_prefix"={"train":[1.0, /path/to/data], "validation":/path/to/data, "test":/path/to/test} + # Or see example below: + # "model.data.data_prefix: {train:[1.0,/path/to/data], validation:[/path/to/data], test:[/path/to/test]}" + data_prefix: ??? + index_mapping_dir: null # path to save index mapping .npy files, by default will save in the same location as data_prefix + data_impl: mmap + splits_string: 980,10,10 + seq_length: ${model.encoder_seq_length} + skip_warmup: True + num_workers: 1 + dataloader_type: single # cyclic + reset_position_ids: False # Reset position ids after end-of-document token + reset_attention_mask: False # Reset attention mask after end-of-document token + eod_mask_loss: False # Mask loss for the end of document tokens + validation_drop_last: True # Set to false if the last partial validation samples is to be consumed + no_seqlen_plus_one_input_tokens: False # Set to True to disable fetching (sequence length + 1) input tokens, instead get (sequence length) input tokens and mask the last token + pad_samples_to_global_batch_size: False # Set to True if you want to pad the last partial batch with -1's to equal global batch size + + # Below configs only used with fine tuning + fine_tuning: False # Set to True to use fine-tuning dataloader instead of pretraining + + chat: False # whether use chatbot data or not + train_ds: + # Example of how to specify paths to multiple datasets + # file_names: + # - /path/to/squad.jsonl + # - /path/to/mnli.jsonl + # - /path/to/boolq.jsonl + # Example of how each dataset is formatted + # {'input': 'John von Neumann\nVon Neumann made fundamental contributions .... Q: What did the math of artificial viscosity do?', 'output': 'smoothed the shock transition without sacrificing basic physics'} + file_names: # Path to a list of JSONL files corresponding to the source data. 
+ - /path/to/file.jsonl + global_batch_size: ${model.global_batch_size} + micro_batch_size: ${model.micro_batch_size} + shuffle: True + num_workers: 1 + memmap_workers: 1 + pin_memory: False + max_seq_length: ${model.encoder_seq_length} + min_seq_length: 1 + drop_last: True + # Example of how to specify concat_sampling_probabilities + # concat_sampling_probabilities: + # - 0.5 + # - 0.25 + # - 0.25 + concat_sampling_probabilities: # When providing a list of datasets, this arg defines the sampling probabilities from each dataset when strategy='random' + - 1.0 + label_key: 'output' + add_eos: True + add_sep: False + add_bos: False + truncation_field: "input" # # Can be multiple keys separated with ',' Options: keys in prompt_template + index_mapping_dir: null # Path to a directory to write index mapping files. + prompt_template: "{input} {output}" # fstring to use for assistant prompt. Example: "Q: {input}\nA: {output}" + hf_dataset: False # Whether to load the json file with the HuggingFace dataset. otherwise, will load the jsonl file with the JSONLMemMapDataset. + truncation_method: 'right' # Truncation from which position, Options: ['left', 'right'] + + validation_ds: + file_names: null # Path to a list of JSONL files corresponding to the source data. Data format is identical to train_ds. + names: null # Names of the corresponding datasets used to log metrics. + concat_sampling_probabilities: # When providing a list of datasets, this arg defines the sampling probabilities from each dataset when strategy='random' + - 1.0 + global_batch_size: ${model.global_batch_size} + micro_batch_size: ${model.micro_batch_size} + shuffle: False + num_workers: 4 + memmap_workers: ${model.data.train_ds.memmap_workers} + pin_memory: False + max_seq_length: ${model.encoder_seq_length} + min_seq_length: 1 + drop_last: False + label_key: ${model.data.train_ds.label_key} + add_eos: ${model.data.train_ds.add_eos} + add_sep: ${model.data.train_ds.add_sep} + add_bos: ${model.data.train_ds.add_bos} + write_predictions_to_file: False + output_file_path_prefix: null # Prefix of the file to write predictions to. + truncation_field: ${model.data.train_ds.truncation_field} # Options: keys in prompt_template + index_mapping_dir: null # Path to a directory to write index mapping files. + prompt_template: ${model.data.train_ds.prompt_template} # fstring to use for assistant prompt. Example: "Q: {input}\nA: {output}" + tokens_to_generate: 32 # decide how many tokens we want to generate to evaluate performance with string metrics + hf_dataset: False # Whether to load the json file with the HuggingFace dataset. otherwise, will load the jsonl file with the JSONLMemMapDataset. + truncation_method: 'right' # Truncation from which position, Options: ['left', 'right'] + + metric: + name: "loss" # Name of the evaluation metric to use. Options: ['exact_string_match', 'loss', 'rouge', 'token_f1'] + average: null # Average the metric over the dataset. Options: ['macro', 'micro']. Works only for 'F1', 'accuracy' etc. Refer to torchmetrics for metrics where this is supported. + num_classes: null + + test_ds: + file_names: null # Path to a list of JSONL files corresponding to the source data. Data format is identical to train_ds. + names: null # Names of the corresponding datasets used to log metrics. 
+ global_batch_size: ${model.global_batch_size} + micro_batch_size: ${model.micro_batch_size} + shuffle: False + num_workers: 4 + memmap_workers: ${model.data.train_ds.memmap_workers} + pin_memory: True + max_seq_length: ${model.data.train_ds.max_seq_length} + min_seq_length: 1 + drop_last: False + label_key: ${model.data.train_ds.label_key} + add_eos: ${model.data.train_ds.add_eos} + add_sep: ${model.data.train_ds.add_sep} + add_bos: ${model.data.train_ds.add_bos} + write_predictions_to_file: False + output_file_path_prefix: null # Prefix of the file to write predictions to. + truncation_field: ${model.data.train_ds.truncation_field} # Options: keys in prompt_template + index_mapping_dir: null # Path to a directory to write index mapping files. + prompt_template: ${model.data.train_ds.prompt_template} # fstring to use for assistant prompt. Example: "Q: {input}\nA: {output}" + tokens_to_generate: 32 # decide how many tokens we want to generate to evaluate performance with string metrics + hf_dataset: False # Whether to load the json file with the HuggingFace dataset. otherwise, will load the jsonl file with the JSONLMemMapDataset. + truncation_method: 'right' # Truncation from which position, Options: Options: ['left', 'right'] + + metric: + name: "loss" # Name of the evaluation metric to use. Options: ['exact_string_match', 'loss'] + average: null # Average the metric over the dataset. Options: ['macro', 'micro']. Works only for 'F1', 'accuracy' etc. Refer to torchmetrics for metrics where this is supported. + num_classes: null + + # Nsys profiling options + nsys_profile: + enabled: False + start_step: 10 # Global batch to start profiling + end_step: 10 # Global batch to end profiling + ranks: [0] # Global rank IDs to profile + gen_shape: False # Generate model and kernel details including input shapes + + optim: + name: adamw + lr: 2e-4 + weight_decay: 0.01 + capturable: False + betas: + - 0.9 + - 0.98 + sched: + name: CosineAnnealing + warmup_steps: 500 + constant_steps: 50000 + min_lr: 2e-5 + + +enable_recovery_time_instrumentation: False # default to not printing the detailing timing for recovery diff --git a/nemo/examples/nlp/language_modeling/qwen_14b.sh b/nemo/examples/nlp/language_modeling/qwen_14b.sh new file mode 100755 index 0000000..f761eb5 --- /dev/null +++ b/nemo/examples/nlp/language_modeling/qwen_14b.sh @@ -0,0 +1,28 @@ +#!/usr/bin/env bash +export SEQ_LENGTH=4096 +export HS=5120 +export TP=8 +export PP=4 +export N_LAYERS=40 +export N_AH=40 +export FFN_HS=13696 +export GBS=512 +export UBS=1 +export TRAIN_ITERS=50000 + +export VALIDATE_INTERVAL=250 +export SAVE_CHECKPOINT_INTERVAL=1000 + +export INIT_METHOD_STD=0.02 +export LAYERNORM_EPSILON=1e-8 +export WARMUP_STEPS=500 + +export LOAD_CHECKPOINT_FROM='/fsx/qwen-14b-tp8-pp4/tp_rank_07_pp_rank_003/model_optim_rng.ckpt' + +# This helps to build the helpers.cpp required (only once) +# cd /usr/local/lib/python3.8/site-packages/nemo/collections/nlp/data/language_modeling/megatron/ +# make +# cd /root/scripts/nemo +# + +./test_qwen.sh diff --git a/nemo/examples/nlp/language_modeling/test_qwen.sh b/nemo/examples/nlp/language_modeling/test_qwen.sh new file mode 100755 index 0000000..f95eb7c --- /dev/null +++ b/nemo/examples/nlp/language_modeling/test_qwen.sh @@ -0,0 +1,72 @@ +#!/usr/bin/env bash + +source ./train_setup.sh + +: ${TOKENIZER_PATH=/fsx/qwen-14b-hf} +: ${DATASET_PATH=$HOME/examples_datasets/qwen/book.jsonl-processed_text_document} + +echo "SEQ_LEN=$SEQ_LENGTH, HS=$HS, FFN_HS=$FFN_HS TP=$TP PP=$PP N_LAYERS=$N_LAYERS 
N_AH=$N_AH GBS=$GBS UBS=$UBS TRAIN_ITERS=$TRAIN_ITERS" + + +$MAYBE_COMPILE torchrun $DISTRIBUTED_ARGS megatron_gpt_pretraining.py \ + --config-path=conf \ + --config-name=megatron_qwen_config \ + trainer.devices=$PROCESSES_PER_NODE \ + trainer.num_nodes=$NTASKS \ + trainer.max_epochs=null \ + trainer.max_steps=$TRAIN_ITERS\ + trainer.val_check_interval=$VALIDATE_INTERVAL \ + trainer.log_every_n_steps=1 \ + trainer.limit_val_batches=1 \ + trainer.limit_test_batches=1 \ + trainer.accumulate_grad_batches=1 \ + trainer.precision=32 \ + model.megatron_amp_O2=$megatron_amp_O2 \ + model.tokenizer.type=$TOKENIZER_PATH \ + model.micro_batch_size=$UBS \ + model.global_batch_size=$GBS \ + model.tensor_model_parallel_size=$TP \ + model.pipeline_model_parallel_size=$PP \ + model.max_position_embeddings=$SEQ_LENGTH \ + model.encoder_seq_length=$SEQ_LENGTH \ + model.hidden_size=$HS \ + model.ffn_hidden_size=$FFN_HS \ + model.num_layers=$N_LAYERS \ + model.num_attention_heads=$N_AH \ + model.init_method_std=$INIT_METHOD_STD \ + model.hidden_dropout=0 \ + model.layernorm_epsilon=$LAYERNORM_EPSILON \ + model.data.data_prefix=[1.0,$DATASET_PATH] \ + model.data.num_workers=1 \ + model.data.seq_length=$SEQ_LENGTH \ + model.optim.name=$OPTIM_NAME \ + model.optim.lr=3.0e-5 \ + model.optim.betas=[0.9,0.95] \ + model.optim.weight_decay=0.1 \ + model.optim.sched.name=CosineAnnealing \ + model.optim.sched.warmup_steps=$WARMUP_STEPS \ + model.optim.sched.constant_steps=0 \ + model.optim.sched.min_lr=3.0e-6 \ + model.optim.capturable=True \ + model.sequence_parallel=True \ + model.activations_checkpoint_granularity=full \ + model.activations_checkpoint_method=uniform \ + model.activations_checkpoint_num_layers=1 \ + model.make_vocab_size_divisible_by=32\ + +model.save_xser=True\ + exp_manager.create_tensorboard_logger=$CREATE_TB_LOGGER \ + exp_manager.resume_if_exists=True \ + exp_manager.resume_ignore_no_checkpoint=True \ + exp_manager.create_checkpoint_callback=$CHECKPOINT_CALLBACK \ + exp_manager.explicit_log_dir=$EXPLICIT_LOGDIR \ + +exp_manager.checkpoint_callback_params.train_time_interval=$SAVE_CHECKPOINT_INTERVAL \ + exp_manager.checkpoint_callback_params.save_top_k=3 \ + model.use_cpu_initialization=False \ + +model.load_xser=True \ + model.resume_from_checkpoint=$LOAD_CHECKPOINT_FROM \ + 2>&1 | tee -a $LOG_PATH/log + +# Note: to resume training using a checkpoint, please add the following configuration above, adjusting for your checkpoint path +# model.use_cpu_initialization=False \ +# +model.load_xser=True \ +# model.resume_from_checkpoint='/efs/checkpoint/megatron_gpt--step\=1085-consumed_samples\=69632.0-last.ckpt' \ diff --git a/nemo/nemo/collections/common/tokenizers/huggingface/auto_tokenizer.py b/nemo/nemo/collections/common/tokenizers/huggingface/auto_tokenizer.py index 57e98ab..6cfcac2 100644 --- a/nemo/nemo/collections/common/tokenizers/huggingface/auto_tokenizer.py +++ b/nemo/nemo/collections/common/tokenizers/huggingface/auto_tokenizer.py @@ -64,11 +64,16 @@ def __init__( # this logic deals with different huggingface tokenizers having different positional args if vocab_file is None: self.tokenizer = AUTOTOKENIZER.from_pretrained( - pretrained_model_name_or_path=pretrained_model_name, use_fast=use_fast, + pretrained_model_name_or_path=pretrained_model_name, + use_fast=use_fast, + trust_remote_code=True, ) elif merges_file is None: self.tokenizer = AUTOTOKENIZER.from_pretrained( - pretrained_model_name_or_path=pretrained_model_name, vocab_file=vocab_file, use_fast=use_fast, + 
pretrained_model_name_or_path=pretrained_model_name, + vocab_file=vocab_file, + use_fast=use_fast, + trust_remote_code=True, ) else: self.tokenizer = AUTOTOKENIZER.from_pretrained( @@ -76,6 +81,7 @@ def __init__( vocab_file=vocab_file, merges_file=merges_file, use_fast=use_fast, + trust_remote_code=True, ) except Exception as e: raise ValueError( diff --git a/nemo/nemo/collections/nlp/models/language_modeling/megatron/gpt_model.py b/nemo/nemo/collections/nlp/models/language_modeling/megatron/gpt_model.py index 51d07b3..1c1cee7 100644 --- a/nemo/nemo/collections/nlp/models/language_modeling/megatron/gpt_model.py +++ b/nemo/nemo/collections/nlp/models/language_modeling/megatron/gpt_model.py @@ -140,6 +140,7 @@ def __init__( normalization='layernorm', layernorm_epsilon=1e-5, bias=True, + bias_qkv=True, bias_activation_fusion=True, bias_dropout_add_fusion=True, masked_softmax_fusion=True, @@ -228,6 +229,7 @@ def __init__( rotary_percentage=rotary_percentage, share_embeddings_and_output_weights=share_embeddings_and_output_weights, bias=bias, + bias_qkv=bias_qkv, bias_activation_fusion=bias_activation_fusion, bias_dropout_add_fusion=bias_dropout_add_fusion, masked_softmax_fusion=masked_softmax_fusion, diff --git a/nemo/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index f2f0d9b..5542aba 100755 --- a/nemo/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -229,6 +229,7 @@ def model_provider_func(self, pre_process, post_process): rotary_percentage=self.cfg.get('rotary_percentage', 1.0), activation=self.cfg.get('activation', 'gelu'), bias=self.cfg.get('has_bias', True), + bias_qkv=self.cfg.get('has_bias_qkv', True), transformer_block_type=self.cfg.get('transformer_block_type','pre_ln'), masked_softmax_fusion=self.cfg.get('masked_softmax_fusion', True), gradient_accumulation_fusion=self.cfg.get('gradient_accumulation_fusion', False), diff --git a/nemo/nemo/collections/nlp/modules/common/megatron/language_model.py b/nemo/nemo/collections/nlp/modules/common/megatron/language_model.py index a223c3c..2f585f9 100644 --- a/nemo/nemo/collections/nlp/modules/common/megatron/language_model.py +++ b/nemo/nemo/collections/nlp/modules/common/megatron/language_model.py @@ -81,6 +81,7 @@ def get_language_model( multi_query_attention=False, bias_dropout_add_fusion=True, bias=True, + bias_qkv=True, gradient_accumulation_fusion=False, persist_layer_norm=False, openai_gelu=False, @@ -148,6 +149,7 @@ def get_language_model( bias_activation_fusion=bias_activation_fusion, bias_dropout_add_fusion=bias_dropout_add_fusion, bias=bias, + bias_qkv=bias_qkv, rotary_percentage=rotary_percentage, share_embeddings_and_output_weights=share_embeddings_and_output_weights, masked_softmax_fusion=masked_softmax_fusion, @@ -450,6 +452,7 @@ def __init__( bias_activation_fusion=True, bias_dropout_add_fusion=True, bias=True, + bias_qkv=True, masked_softmax_fusion=True, activation='gelu', headscale=False, @@ -562,6 +565,7 @@ def __init__( openai_gelu=openai_gelu, onnx_safe=onnx_safe, bias=bias, + bias_qkv=bias_qkv, bias_activation_fusion=bias_activation_fusion, bias_dropout_add_fusion=bias_dropout_add_fusion, masked_softmax_fusion=masked_softmax_fusion, diff --git a/nemo/nemo/collections/nlp/modules/common/megatron/transformer.py b/nemo/nemo/collections/nlp/modules/common/megatron/transformer.py index fd4f100..18aedab 100644 --- 
a/nemo/nemo/collections/nlp/modules/common/megatron/transformer.py +++ b/nemo/nemo/collections/nlp/modules/common/megatron/transformer.py @@ -710,6 +710,7 @@ def __init__( layer_type=None, megatron_legacy=False, bias=True, + bias_qkv=True, headscale=False, position_embedding_type='learned_absolute', multi_query_attention=False, @@ -765,7 +766,7 @@ def __init__( gather_output=False, init_method=init_method, use_cpu_initialization=use_cpu_initialization, - bias=bias, + bias=bias_qkv, sequence_parallel_enabled=sequence_parallel, no_async_tensor_model_parallel_allreduce=no_async_tensor_model_parallel_allreduce, gradient_accumulation_fusion=gradient_accumulation_fusion, @@ -780,7 +781,7 @@ def __init__( projection_size, gather_output=False, init_method=init_method, - bias=bias, + bias=bias_qkv, sequence_parallel_enabled=sequence_parallel, no_async_tensor_model_parallel_allreduce=no_async_tensor_model_parallel_allreduce, gradient_accumulation_fusion=gradient_accumulation_fusion, @@ -792,7 +793,7 @@ def __init__( 2 * projection_size, gather_output=False, init_method=init_method, - bias=bias, + bias=bias_qkv, sequence_parallel_enabled=sequence_parallel, no_async_tensor_model_parallel_allreduce=no_async_tensor_model_parallel_allreduce, gradient_accumulation_fusion=gradient_accumulation_fusion, @@ -1339,6 +1340,7 @@ def __init__( activation='gelu', megatron_legacy=False, bias=True, + bias_qkv=True, chunk_size=64, normalization='layernorm', transformer_block_type='pre_ln', @@ -1367,6 +1369,7 @@ def __init__( self.layer_type = layer_type self.sequence_parallel = sequence_parallel self.bias = bias + self.bias_qkv = bias_qkv self.transformer_block_type = transformer_block_type self.position_embedding_type = position_embedding_type self.position_interpolation_factor = position_interpolation_factor @@ -1432,6 +1435,7 @@ def __init__( layer_type=layer_type, megatron_legacy=megatron_legacy, bias=bias, + bias_qkv=bias_qkv, headscale=headscale, activations_checkpoint_granularity=activations_checkpoint_granularity, position_embedding_type=position_embedding_type, @@ -1899,6 +1903,7 @@ def __init__( activation='gelu', megatron_legacy=False, bias=True, + bias_qkv=True, chunk_size=64, normalization='layernorm', transformer_block_type='pre_ln', @@ -1943,6 +1948,7 @@ def __init__( activation=activation, megatron_legacy=megatron_legacy, bias=bias, + bias_qkv=bias_qkv, chunk_size=chunk_size, normalization=normalization, transformer_block_type=transformer_block_type, @@ -2157,6 +2163,7 @@ def __init__( model_type=ModelType.encoder_or_decoder, megatron_legacy=False, bias=True, + bias_qkv=True, chunk_size=64, normalization='layernorm', transformer_block_type='pre_ln', @@ -2353,6 +2360,7 @@ def build_layer(layer_number): activation=activation, megatron_legacy=megatron_legacy, bias=bias, + bias_qkv=bias_qkv, chunk_size=chunk_size, normalization=normalization, transformer_block_type=transformer_block_type, diff --git a/nemo/requirements/requirements.txt b/nemo/requirements/requirements.txt index 0a6e3e9..8bb235f 100644 --- a/nemo/requirements/requirements.txt +++ b/nemo/requirements/requirements.txt @@ -8,8 +8,10 @@ scikit-learn setuptools==59.5.0 tensorboard text-unidecode +tiktoken torch tqdm>=4.41.0 +transformers>=4.32.0 wget wrapt pytorch-lightning==1.8.6
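For reference, a minimal standalone sketch of the QKV layout round trip performed by the two conversion scripts above, assuming illustrative shapes (4 attention heads, head dim 8, so hidden size 32 rather than Qwen-14B's real sizes) and ignoring the additional rotary permutation that the HF-to-NeMo script applies to the weight before the final reshape:

    import torch

    heads, head_dim = 4, 8          # illustrative sizes only
    hidden = heads * head_dim

    def hf_to_nemo_qkv(w):
        # HF Qwen's attn.c_attn stacks rows as [all Q | all K | all V];
        # NeMo/Megatron expects per-head interleaving [h0.q, h0.k, h0.v, h1.q, ...],
        # matching the final view/transpose in convert_hf_checkpoint_to_nemo_qwen.py.
        return (w.view(3, heads, head_dim, hidden)
                 .transpose(0, 1).contiguous()
                 .view(3 * hidden, hidden))

    def nemo_to_hf_qkv(w):
        # Inverse mapping, mirroring fix_query_key_value_ordering() in
        # convert_nemo_checkpoint_to_hf_qwen.py.
        return (w.view(heads, 3, head_dim, hidden)
                 .transpose(0, 1).contiguous()
                 .view(3 * hidden, hidden))

    w_hf = torch.randn(3 * hidden, hidden)
    assert torch.equal(nemo_to_hf_qkv(hf_to_nemo_qkv(w_hf)), w_hf)
    print("QKV interleaving round trip OK")

The QKV bias (enabled via the new has_bias_qkv flag) follows the same permutation, just without the trailing hidden dimension.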