LLaMA MLP projection layers mismatch with HF config during conversion #1319

Closed
Vmjkom opened this issue Nov 6, 2024 · 2 comments
Labels
bug Something isn't working

Comments


Vmjkom commented Nov 6, 2024

Describe the bug
When I try to convert a NeoX-trained LLaMA model (config below) with convert_neox_to_hf.py, I get the error shown in the screenshot.
So in my view, the dimensions of the MLP layers are not configured correctly during training. At least, I hadn't come across this issue before #1212.
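For context, here is a minimal sketch (mine, not from the issue) of the MLP shapes that a Hugging Face LlamaForCausalLM built from the equivalent config expects; the conversion appears to fail because the converted NeoX MLP weights do not match these tensors:

# Minimal sketch, assuming only the hidden_size / intermediate_size values from
# the config below; one layer is enough to inspect the expected shapes.
from transformers import LlamaConfig, LlamaForCausalLM

hf_config = LlamaConfig(
    hidden_size=2048,          # "hidden_size" in the NeoX config below
    intermediate_size=8192,    # "intermediate_size" in the NeoX config below
    num_hidden_layers=1,
    num_attention_heads=32,
    num_key_value_heads=32,
)
layer = LlamaForCausalLM(hf_config).model.layers[0]
print(layer.mlp.gate_proj.weight.shape)  # torch.Size([8192, 2048])
print(layer.mlp.up_proj.weight.shape)    # torch.Size([8192, 2048])
print(layer.mlp.down_proj.weight.shape)  # torch.Size([2048, 8192])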

To Reproduce
Train a model with the provided config and try to convert it to Hugging Face format with convert_neox_to_hf.py.

Proposed solution
I would look at #1276 and #1212 for possible issues regarding the LLaMA MLP that could have led to the aforementioned problem.
One could also revert to the LLAMAParallelMLP class and mlp_type: "llama" parameter combination from before.

Screenshots
[Screenshot: size-mismatch error on the MLP weights raised by convert_neox_to_hf.py]

Environment (please complete the following information):

  • GPUs: 2x8 MI250X (AMD)
  • Libraries:
    deepspeed @ git+https://github.com/EleutherAI/DeeperSpeed.git@02e2ebf7dee6aaab3d89094ed470a4609763c742
    flash-attn @ file:///opt/wheels/flash_attn-2.0.4-cp310-cp310-linux_x86_64.whl#sha256=0dc568c7b3516cc3f45f33858fe5ef048e5b7a82ba56c89189d5f6a97f4574f2
    ftfy==6.2.3
    lion-pytorch==0.1.4
    lm-dataformat @ git+https://github.com/EleutherAI/lm_dataformat.git@4eec05349977071bf67fc072290b95e31c8dd836
    lm_eval==0.4.1
    mpi4py @ file:///opt/wheels/mpi4py-3.1.4-cp310-cp310-linux_x86_64.whl#sha256=6e012d8c61c0a0d8d6e93b4d98ba6946bb5a5c3d8280d1e0db93862ec19025c2
    numpy==1.26.3
    pybind11==2.13.6
    pytorch-triton-rocm==2.2.0
    regex==2024.5.15
    sentencepiece==0.2.0
    six==1.16.0
    tiktoken==0.7.0
    tokenizers==0.15.2
    torch==2.2.2+rocm5.6
    torchaudio==2.2.2+rocm5.6
    torchdata==0.7.1
    torchtext==0.17.2+cpu
    torchvision==0.17.2+rocm5.6
    transformers==4.38.0
    Python 3.10.13
  • Config:
{
  # parallelism settings ( you will want to change these based on your cluster setup, ideally scheduling pipeline stages
  # across the node boundaries )
  "pipe_parallel_size": 0,
  "model_parallel_size": 1,

  "seed": 42,

  #Tokenizer
  "make_vocab_size_divisible_by": 1,
  "tokenizer_type": "GPT2BPETokenizer",
  "data_path": "/scratch/project_462000353/jburdge/data/fineweb-edu-100B/tokenized/gpt2_text_document",
  "vocab_file": "/scratch/project_462000353/tokenizers/gpt2/vocab.json",
  "merge_file": "/scratch/project_462000353/tokenizers/gpt2/merges.txt",

  # model settings
  "num_layers": 24,
  "hidden_size": 2048,
  "num_attention_heads": 32,
  "seq_length": 2048,
  "max_position_embeddings": 2048,
  "norm": "rmsnorm",
  "rms_norm_epsilon": 1.0e-05,
  "pos_emb": "rotary",
  "intermediate_size": 8192,
  "no_weight_tying": true,
  "gpt_j_residual": false,
  "output_layer_parallelism": "column",
  "num_kv_heads": 32,

  "scaled_upper_triang_masked_softmax_fusion": false,
  "bias_gelu_fusion": false,
  "use_bias_in_norms": false,
  "use_bias_in_attn_linear": false,
  "activation": "swiglu",
  "use_flashattn_swiglu": true,
  "mlp_multiple_of": 1,
  "use_bias_in_mlp": false,

  #flash_attention - value = num_layers
  "attention_config": [[["flash"], 24]],

  # init methods
  "init_method": "small_init",
  "init_method_std": 0.02,
  "output_layer_init_method": "wang_init",

  # optimizer settings
  "optimizer":
    {
      "type": "Adam",
      "params": { "lr": 3.0e-4, "betas": [0.9, 0.95], "eps": 1.0e-8 },
    },
  "min_lr": 3.0e-5,

  # for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training
  "zero_optimization":
    {
      "stage": 0,
      "allgather_partitions": True,
      "allgather_bucket_size": 50000000,
      "overlap_comm": True,
      "reduce_scatter": false,
      "reduce_bucket_size": 50000000,
      "contiguous_gradients": True,
    },

  # batch / data settings
  "train_micro_batch_size_per_gpu": 32,
  "gradient_accumulation_steps": 2,
  "data_impl": "mmap",

  # activation checkpointing
  "checkpoint_activations": true,
  "checkpoint_num_layers": 1,
  "partition_activations": false,
  "synchronize_each_layer": false,

  # regularization
  "gradient_clipping": 1.0,
  "weight_decay": 0.1,
  "hidden_dropout": 0.0,
  "attention_dropout": 0.0,

  # precision settings
  "precision": "bfloat16",
  "fp32_allreduce": true,

  # misc. training settings
  "train_iters": 10,
  "lr_decay_iters": 10,
  "distributed_backend": "nccl",
  "lr_decay_style": "cosine",
  "warmup": 0.01,

  #Evaluation
  "eval_interval": 10,
  "eval_iters": 5,

  #Dataloader workers
  "num_workers": 2,

  #Checkpoints
  "checkpoint_factor": 10,
  "keep_last_n_checkpoints": 1,
  "save": "/scratch/project_462000353/villekom/checkpoints/neox/debug/",
  #"load": "/scratch/project_462000353/villekom/checkpoints/neox/debug/",

  # logging
  "log_interval": 1,
  "steps_per_print": 1,
  "tensorboard_dir": "logs/tb/",
  "log_grad_pct_zeros": True,
  "log_grad_norm": True,
  "log_gradient_noise_scale": False, #Gradient Noise Scale logging does not work with zero stage 2+, as the gradients are distributed across ranks.

  #Deepspeed misc
  "wall_clock_breakdown": true,
  "tensorboard": { "enabled": false, "output_path": "logs/tb/" },
  "comms_logger":
    { "enabled": false, "verbose": false, "prof_all": true, "debug": False },
}


@Vmjkom added the bug label Nov 6, 2024
@tiandeyu-cs
Contributor

I also encountered this issue with the llama-type MLP, and I had to set 'intermediate_size' to three times the intended value to work around it.
I made a pull request (#1309) that fixes the LLaMA configurations in the 'example' directories. I hope this helps.
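For illustration only (my sketch, not part of the thread): applied to the config in this issue, where the intended Hugging Face intermediate_size is 8192, the workaround above would mean setting the NeoX intermediate_size to three times that value.

# Illustration of the workaround described above, using this issue's numbers.
# 24576 is simply 3 * 8192; it is not an official recommendation.
intended_hf_intermediate_size = 8192
workaround_neox_intermediate_size = 3 * intended_hf_intermediate_size
print(workaround_neox_intermediate_size)  # 24576 -> "intermediate_size" in the NeoX config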

@Quentin-Anthony
Member

This should be resolved now that #1309 is merged. Reopen if this isn't the case for you!
