No Significant Improvement Observed in Model Training Speed #409

Open
lianghsun opened this issue Nov 27, 2024 · 4 comments

Comments

@lianghsun

🐛 Describe the bug

I am training the meta-llama/Llama-3.2-1B model using LLaMA-Factory with the following YAML configuration:

### model
model_name_or_path: meta-llama/Llama-3.2-1B

### method
stage: pt
do_train: true
do_eval: true
finetuning_type: full
deepspeed: /path/to/ds_z3_config.json
use_liger_kernel: true
enable_liger_kernel: true

### dataset
dataset: /path/to/dir/
eval_dataset: /path/to/dir/
template: llama3
cutoff_len: 4000
max_samples: 30000000000
overwrite_cache: true
preprocessing_num_workers: 64
preprocessing_batch_size: 60000
tokenized_path: /path/to/dir/

### output
output_dir: /path/to/dir/
logging_steps: 1
save_steps: 5
plot_loss: true
overwrite_output_dir: true
save_total_limit: 8

### train
per_device_train_batch_size: 94
gradient_accumulation_steps: 32
learning_rate: 5.0e-5
num_train_epochs: 10
lr_scheduler_type: cosine
optim: adamw_torch_fused
warmup_ratio: 0.01
weight_decay: 0.1
bf16: true
ddp_timeout: 1080
ddp_find_unused_parameters: false
max_grad_norm: 1.0
seed: 42
dataloader_num_workers: 64
packing: true
flash_attn: auto

### eval
per_device_eval_batch_size: 4
eval_strategy: steps
eval_steps: 10

However, I have noticed that enabling or disabling liger_kernel does not lead to any noticeable reduction in training time. The runtime metrics remain nearly identical in both cases. Are there specific parameter settings in my YAML configuration that might be preventing liger_kernel from functioning optimally? Thanks :(
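
As a sanity check (independent of LLaMA-Factory), the sketch below verifies that the Liger kernels are actually patched into the model; it uses liger-kernel's apply_liger_kernel_to_llama entry point, and the class-name check is only an illustration.

import torch
from transformers import AutoModelForCausalLM
from liger_kernel.transformers import apply_liger_kernel_to_llama

# Patch the Hugging Face Llama modeling classes before the model is instantiated.
apply_liger_kernel_to_llama()

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B", torch_dtype=torch.bfloat16
)

# If the patch took effect, some submodules (e.g. RMSNorm, the MLP) should now
# be Liger classes rather than the stock HF implementations.
liger_modules = {type(m).__name__ for m in model.modules() if "Liger" in type(m).__name__}
print("Liger modules found:", liger_modules or "none")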

Reproduce

  1. Use the YAML configuration provided above.
  2. Train the meta-llama/Llama-3.2-1B model with and without liger_kernel:
     llamafactory-cli train /path/to/above/yaml
  3. Compare training times and throughput metrics (see the sketch below).
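
For step 3, a hypothetical comparison helper is sketched below; it assumes each run writes the usual Hugging Face Trainer summary to train_results.json in its output_dir, and the two run directories are placeholders.

import json
from pathlib import Path

def load_train_metrics(output_dir: str) -> dict:
    """Read the train_results.json summary written at the end of training."""
    return json.loads((Path(output_dir) / "train_results.json").read_text())

for run in ("/path/to/output_liger_on", "/path/to/output_liger_off"):
    metrics = load_train_metrics(run)
    print(
        run,
        "runtime (s):", metrics.get("train_runtime"),
        "samples/s:", metrics.get("train_samples_per_second"),
    )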

Versions

Environment Report:

Operating System: Linux-5.15.0-107-generic-x86_64-with-glibc2.35
Python version: 3.12.7
PyTorch version: 2.5.1+cu124
CUDA version: 12.4
Triton version: 3.1.0
Transformers version: 4.46.1

@ByronHsu
Collaborator

ByronHsu commented Nov 27, 2024

It is likely because the model is too small: it is not fully utilizing the GPU, so the effect of the Liger kernel is not significant.

  1. Can you compare the memory before and after Liger? (See the sketch after this list.)
  2. After Liger, memory usage should drop a lot because of the large vocab size. Given that, you can disable gradient checkpointing to avoid forward re-computation. You might have to adjust per_device_train_batch_size and gradient_accumulation_steps a bit to fit into the GPU.
  3. I would suggest switching from DDP to FSDP so you can get an even larger memory reduction.
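
A minimal sketch for point 1, assuming you can add a few lines around the training loop; the tag name and call site are placeholders:

import torch

def log_peak_memory(tag: str) -> None:
    """Print peak allocated/reserved GPU memory since the last reset."""
    gib = 1024 ** 3
    print(
        f"[{tag}] peak allocated: {torch.cuda.max_memory_allocated() / gib:.2f} GiB, "
        f"peak reserved: {torch.cuda.max_memory_reserved() / gib:.2f} GiB"
    )

if torch.cuda.is_available():
    torch.cuda.reset_peak_memory_stats()
    # ... run a few training steps here ...
    log_peak_memory("liger_on")  # repeat in a second run with Liger disabled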

@lianghsun
Author

> It is likely because the model is too small: it is not fully utilizing the GPU, so the effect of the Liger kernel is not significant.
>
>   1. Can you compare the memory before and after Liger?
>   2. After Liger, memory usage should drop a lot because of the large vocab size. Given that, you can disable gradient checkpointing to avoid forward re-computation. You might have to adjust per_device_train_batch_size and gradient_accumulation_steps a bit to fit into the GPU.
>   3. I would suggest switching from DDP to FSDP so you can get an even larger memory reduction.

Thanks @ByronHsu for your response.

Regarding the first point, I have configured the settings in the YAML file so that training fully utilizes the memory of the H100 GPU. However, I observed no significant difference in memory usage with or without Liger; the memory consumption remained essentially unchanged.

I will experiment with the second point as you suggested. Could you please explain why disabling gradient checkpointing is recommended? If it's related to Liger's underlying mechanism, I will review the relevant literature for a deeper understanding.

I appreciate your third suggestion and will implement it in future trials.

Thank you.

@tyler-romero
Collaborator

Gradient checkpointing isn't related to Liger Kernel per se; it's a technique that trades training speed for a reduction in memory consumption. If Liger is enabled (and reduces the memory consumption of your model's training step), it can let you turn off gradient checkpointing and thereby speed up training. A short sketch of this trade-off follows.
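
An illustrative sketch of that trade-off, timing one forward/backward step with gradient checkpointing on and off; the model name is taken from the issue, while the batch shape and timing loop are arbitrary choices:

import time
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B", torch_dtype=torch.bfloat16
).cuda()
model.train()

batch = torch.randint(0, model.config.vocab_size, (4, 1024), device="cuda")

for use_ckpt in (True, False):
    if use_ckpt:
        model.gradient_checkpointing_enable()
    else:
        model.gradient_checkpointing_disable()
    # Note: the first timed step also includes CUDA/kernel warm-up, so run a
    # warm-up step beforehand for a fairer comparison.
    torch.cuda.synchronize()
    start = time.perf_counter()
    loss = model(input_ids=batch, labels=batch).loss
    loss.backward()
    torch.cuda.synchronize()
    print(f"gradient_checkpointing={use_ckpt}: {time.perf_counter() - start:.3f}s per step")
    model.zero_grad(set_to_none=True)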

@tyler-romero
Collaborator

tyler-romero commented Nov 28, 2024

Another suggestion is to hook up a profiler to see where time is being spent over the course of a few training steps.

https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html#using-tracing-functionality
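
A short sketch adapted from that recipe, assuming you can wrap a few training steps; the trace file name is a placeholder:

from torch.profiler import ProfilerActivity, profile

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    record_shapes=True,
) as prof:
    # ... run a handful of training steps here ...
    pass

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
prof.export_chrome_trace("trace.json")  # open in chrome://tracing or Perfetto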
