DeepSpeed ZeRO-1, ZeRO-2, and ZeRO-3 run out of memory with a large model #2217

Open
6 of 8 tasks
sankexin opened this issue Dec 23, 2024 · 3 comments
Labels
bug Something isn't working

Comments

@sankexin

sankexin commented Dec 23, 2024

Please check that this issue hasn't been reported before.

  • I searched previous Bug Reports and didn't find any similar reports.

Expected Behavior

Training runs normally.

Current behaviour

Traceback (most recent call last):
File "/usr/local/lib/python3.10/runpy.py", line 196, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/usr/local/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/home/axolotl/src/axolotl/cli/train.py", line 38, in <module>
fire.Fire(do_cli)
File "/usr/local/lib/python3.10/site-packages/fire/core.py", line 135, in Fire
component_trace = _Fire(component, args, parsed_flag_args, context, name)
File "/usr/local/lib/python3.10/site-packages/fire/core.py", line 468, in _Fire
component, remaining_args = _CallAndUpdateTrace(
File "/usr/local/lib/python3.10/site-packages/fire/core.py", line 684, in _CallAndUpdateTrace
component = fn(*varargs, **kwargs)
File "/home/axolotl/src/axolotl/cli/train.py", line 34, in do_cli
train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
File "/home/axolotl/src/axolotl/train.py", line 119, in train
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
File "/usr/local/lib/python3.10/site-packages/transformers/trainer.py", line 2164, in train
return inner_training_loop(
File "/usr/local/lib/python3.10/site-packages/transformers/trainer.py", line 2522, in _inner_training_loop
tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
File "/usr/local/lib/python3.10/site-packages/transformers/trainer.py", line 3656, in training_step
loss = self.compute_loss(model, inputs, num_items_in_batch)
File "/home/axolotl/src/axolotl/core/trainer_builder.py", line 250, in compute_loss
return super().compute_loss(model, inputs, return_outputs=return_outputs)
File "/home/axolotl/src/axolotl/monkeypatch/medusa_utils.py", line 227, in compute_loss
logits = model(
File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
ret_val = func(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 1855, in forward
loss = self.module(*inputs, **kwargs)
File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/peft/peft_model.py", line 1091, in forward
return self.base_model(
File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/peft/tuners/tuners_utils.py", line 160, in forward
return self.model.forward(*args, **kwargs)
File "/home/axolotl/src/axolotl/monkeypatch/medusa_utils.py", line 184, in forward
medusa_logits.append(self.medusa_headi)
File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/peft/utils/other.py", line 250, in forward
return self.modules_to_save[self.active_adapter](*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/container.py", line 215, in forward
input = module(input)
File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 114, in forward
return F.linear(input, self.weight, self.bias)
torch.cuda.OutOfMemoryError: HIP out of memory. Tried to allocate 8.70 GiB. GPU 0 has a total capacty of 63.98 GiB of which 0 bytes is free. Of the allocated memory 52.85 GiB is allocated by PyTorch, and 8.45 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_HIP_ALLOC_CONF

Steps to reproduce

Tried with each of zero1.json, zero2.json, and zero3.json:

accelerate launch -m axolotl.cli.train examples/medusa/qwen_lora_stage1.yml
https://github.com/ctlllll/axolotl.git

I have tried the settings from #1129:
'''
load_in_8bit: false
load_in_4bit: false
...

"bf16": true
'''

But it didn't work. Any help?

Config yaml

'''
base_model: Qwen2-72B-Instruct
trust_remote_code: true
base_model_config: Qwen2-72B-Instruct
model_type: AutoModelForCausalLM
model_name: Qwen2ForCausalLM
tokenizer_type: AutoTokenizer
is_llama_derived_model: false

load_in_8bit: false
load_in_4bit: false
strict: false

datasets:
  - path: data/qwen1.5_7b_medusa_training_data_sharegpt.jsonl
    type: sharegpt
    conversation: qwen-7b-chat
dataset_prepared_path:
val_set_size: 0.01
output_dir: ./qwen_lora_stage1

adapter: lora
lora_model_dir:

lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
  - v_proj
  - q_proj
  - k_proj
  - lm_head
lora_target_linear:
lora_fan_in_fan_out:
lora_modules_to_save:


sequence_len: 1024
sample_packing: false
pad_to_sequence_len: true

wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:

gradient_accumulation_steps: 8
micro_batch_size: 30
num_epochs: 2
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0005

train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: false

warmup_steps: 20
eval_steps: 40
save_steps:
save_total_limit: 1
debug: true
deepspeed: deepspeed/zero2.json
weight_decay: 0.0
fsdp:
  #- full_shard
  #- auto_wrap
fsdp_config:
  #fsdp_limit_all_gathers: true
  #fsdp_sync_module_states: true
  #fsdp_offload_params: true
  #fsdp_use_orig_params: false
  #fsdp_cpu_ram_efficient_loading: true
  #fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  #fsdp_transformer_layer_cls_to_wrap: Qwen2DecoderLayer
  #fsdp_state_dict_type: FULL_STATE_DICT
  #fsdp_sharding_strategy: FULL_SHARD
#special_tokens:
#bos_token: "<|endoftext|>"
#eos_token: "<|im_end|>"
#unk_token: "<|endoftext|>"
#pad_token: "<|endoftext|>"

medusa_num_heads: 4
medusa_num_layers: 1
medusa_heads_coefficient: 0.2
medusa_decay_coefficient: 0.8
medusa_logging: false
medusa_scheduler: constant
medusa_lr_multiplier: 4.0
#medusa_self_distillation: true
medusa_only_heads: true
ddp_find_unused_parameters: true
'''
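As a rough cross-check of the traceback against this config, the failing 8.70 GiB allocation happens inside `F.linear` in a Medusa head, which is where a logits tensor of shape (batch, sequence, vocabulary) would be produced. The sketch below estimates the size of one such tensor in bf16 from `micro_batch_size: 30` and `sequence_len: 1024`. The vocabulary size of 152064 is an assumption (it is not stated anywhere in this issue and should be checked against the model's `config.json`), and `logits_gib` is a hypothetical helper, not part of axolotl or PyTorch.

```python
# Back-of-the-envelope size of one (batch, seq, vocab) logits tensor.
# ASSUMPTION: vocab_size = 152064 is not given in the issue; verify it
# against the model's config.json before relying on these numbers.
GIB = 1024 ** 3


def logits_gib(micro_batch_size, sequence_len, vocab_size, bytes_per_element=2):
    """Return the tensor size in GiB; 2 bytes per element corresponds to bf16."""
    return micro_batch_size * sequence_len * vocab_size * bytes_per_element / GIB


VOCAB_SIZE = 152064  # assumed value, see note above
SEQUENCE_LEN = 1024  # sequence_len from the config (pad_to_sequence_len: true)

# micro_batch_size: 30 is the value from the config; 8 and 1 are shown for comparison.
for mbs in (30, 8, 1):
    size = logits_gib(mbs, SEQUENCE_LEN, VOCAB_SIZE)
    print(f"micro_batch_size={mbs:>2}: {size:.2f} GiB per logits tensor")
```

Under that assumption, the `micro_batch_size: 30` case comes out at about 8.70 GiB, the same size as the failed allocation. If that reading is right, the tensor is an activation, so its size depends on the micro batch size per GPU and not on the ZeRO stage, and with `medusa_num_heads: 4` several such tensors may be alive at once.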

Possible solution

No response

Which Operating Systems are you using?

  • Linux
  • macOS
  • Windows

Python Version

3.10

axolotl branch-commit

https://github.com/ctlllll/axolotl.git

Acknowledgements

  • My issue title is concise, descriptive, and in title casing.
  • I have searched the existing issues to make sure this bug has not been reported yet.
  • I am using the latest version of axolotl.
  • I have provided enough information for the maintainers to reproduce and diagnose the issue.
@sankexin sankexin added the bug Something isn't working label Dec 23, 2024
@winglian
Collaborator

How many GPUs are you using? Even with multiple GPUs, you're going to need zero3 with offload. None of the vanilla zero1-3 configs is going to work for a 72B-parameter model. Additionally, consider using an 8-bit LoRA instead of half precision. A 72B model needs about 144 GB of VRAM just for the model weights at half precision.
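The weight-memory figure above can be reproduced with simple arithmetic. The sketch below is only an estimate: it counts model weights alone (no activations, gradients, optimizer state, or LoRA adapters), uses a nominal 72e9 parameter count, and `weight_gb` is a hypothetical helper introduced for illustration.

```python
# Memory needed for model weights only, at different precisions.
# ASSUMPTION: a nominal 72e9 parameters; the exact count differs slightly.


def weight_gb(num_params, bytes_per_param):
    """Return weight memory in GB (1 GB = 1e9 bytes)."""
    return num_params * bytes_per_param / 1e9


NUM_PARAMS = 72e9

for name, nbytes in (("bf16/fp16", 2), ("8-bit", 1), ("4-bit", 0.5)):
    total = weight_gb(NUM_PARAMS, nbytes)
    # Even split across GPUs, which is only what ZeRO-3 parameter sharding does.
    print(f"{name:>9}: {total:6.1f} GB total, "
          f"{total / 4:5.1f} GB per GPU on 4, {total / 8:5.1f} GB per GPU on 8")
```

Note that ZeRO-1 shards only optimizer state and ZeRO-2 adds gradients, so under both of them every GPU still holds a full copy of the weights. Only ZeRO-3 partitions the parameters themselves, which is why the per-GPU columns apply to ZeRO-3 alone.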

@sankexin
Author

sankexin commented Dec 24, 2024

> How many GPUs are you using? Even with multiple GPUs, you're going to need zero3 with offload. None of the vanilla zero1-3 configs is going to work for a 72B-parameter model. Additionally, consider using an 8-bit LoRA instead of half precision. A 72B model needs about 144 GB of VRAM just for the model weights at half precision.

ZeRO-3 with offload also runs out of memory. With 4 GPUs (4 × 64 GB of VRAM) or more, and with 8 GPUs, the error is the same. I have also tested 8-bit LoRA. It looks as if DeepSpeed is not taking effect, but I checked that DeepSpeed initialization is normal and that the DeepSpeed code path in the transformers Trainer is normal, so I don't know why this happens.

@winglian
Collaborator

Right, but you have to offload everything with zero3, which isn't the default, which is why I'm asking. Also, you're going to need close to 4 GPUs, maybe 8 if you're not doing a quantized LoRA.
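For reference, "offloading everything" under ZeRO-3 usually means enabling both `offload_optimizer` and `offload_param` in the `zero_optimization` section of the DeepSpeed config. The sketch below builds such a config and prints it as JSON. The keys are standard DeepSpeed options, but the specific values are a starting point and not a tested configuration for this model, the `"auto"` entries assume the Hugging Face Trainer integration fills them in, and `zero3_offload_config` is a hypothetical helper name.

```python
import json


def zero3_offload_config(offload_device="cpu"):
    """Build a ZeRO-3 config that offloads optimizer state and parameters."""
    return {
        "zero_optimization": {
            "stage": 3,
            # Keep optimizer state off the GPU.
            "offload_optimizer": {"device": offload_device, "pin_memory": True},
            # Keep parameter shards off the GPU until they are needed.
            "offload_param": {"device": offload_device, "pin_memory": True},
            "overlap_comm": True,
            "contiguous_gradients": True,
            "stage3_gather_16bit_weights_on_model_save": True,
        },
        "bf16": {"enabled": "auto"},
        # "auto" lets the Hugging Face Trainer copy these from its own arguments.
        "gradient_accumulation_steps": "auto",
        "train_micro_batch_size_per_gpu": "auto",
        "gradient_clipping": "auto",
    }


if __name__ == "__main__":
    # Redirect the output to a JSON file and point the `deepspeed:` key at it.
    print(json.dumps(zero3_offload_config(), indent=2))
```

The YAML in the original report sets `deepspeed: deepspeed/zero2.json`, so that key would have to point at the generated file for the offload settings to take effect. Offloading reduces GPU memory used by parameters and optimizer state, but it does not shrink activations, so the micro batch size may still need to come down.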
