
Error During Model Saving QLORA + FSDP #2149

Open
ghsama opened this issue Dec 7, 2024 · 7 comments · Fixed by #2155 or huggingface/transformers#35212
Labels
bug (Something isn't working) · waiting on upstream

Comments

@ghsama

ghsama commented Dec 7, 2024

Please check that this issue hasn't been reported before.

  • I searched previous Bug Reports and didn't find any similar reports.

Expected Behavior

The model is supposed to be saved without issue after training finishes.

Current behaviour

It raises an error that it couldn't find a Parameter: a KeyError coming from the function _unflatten_param_groups in python3.11/site-packages/torch/distributed/fsdp/_optim_utils.py.

```
[rank0]:     trainer.train(resume_from_checkpoint=resume_from_checkpoint)
[rank0]:   File "/workspace/axolotl/new_env/lib/python3.11/site-packages/transformers/trainer.py", line 2164, in train
[rank0]:     return inner_training_loop(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/axolotl/new_env/lib/python3.11/site-packages/transformers/trainer.py", line 2589, in _inner_training_loop
[rank0]:     self._maybe_log_save_evaluate(
[rank0]:   File "/workspace/axolotl/new_env/lib/python3.11/site-packages/transformers/trainer.py", line 3054, in _maybe_log_save_evaluate
[rank0]:     self._save_checkpoint(model, trial)
[rank0]:   File "/workspace/axolotl/src/axolotl/core/trainer_builder.py", line 990, in _save_checkpoint
[rank0]:     saved = super()._save_checkpoint(model, trial, **kwargs)
[rank0]:             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/axolotl/new_env/lib/python3.11/site-packages/transformers/trainer.py", line 3190, in _save_checkpoint
[rank0]:     self._save_optimizer_and_scheduler(output_dir)
[rank0]:   File "/workspace/axolotl/new_env/lib/python3.11/site-packages/transformers/trainer.py", line 3306, in _save_optimizer_and_scheduler
[rank0]:     save_fsdp_optimizer(
[rank0]:   File "/workspace/axolotl/new_env/lib/python3.11/site-packages/accelerate/utils/fsdp_utils.py", line 186, in save_fsdp_optimizer
[rank0]:     optim_state = FSDP.optim_state_dict(model, optimizer)
[rank0]:                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/axolotl/new_env/lib/python3.11/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1890, in optim_state_dict
[rank0]:     return FullyShardedDataParallel._optim_state_dict_impl(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/axolotl/new_env/lib/python3.11/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1301, in _optim_state_dict_impl
[rank0]:     return _optim_state_dict(
[rank0]:            ^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/axolotl/new_env/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/axolotl/new_env/lib/python3.11/site-packages/torch/distributed/fsdp/_optim_utils.py", line 2019, in _optim_state_dict
[rank0]:     fsdp_osd["param_groups"] = _unflatten_param_groups(
[rank0]:                                ^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/axolotl/new_env/lib/python3.11/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1275, in _unflatten_param_groups
[rank0]:     nested_unflat_param_names = [
[rank0]:                                 ^
[rank0]:   File "/workspace/axolotl/new_env/lib/python3.11/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1276, in <listcomp>
[rank0]:     param_to_fqns[param] for param in param_group_params
[rank0]:     ~~~~~~~~~~~~~^^^^^^^
[rank0]: KeyError: Parameter containing:
[rank0]: tensor([[ 0.0166, -0.0110, -0.0100,  ...,  0.0162,  0.0114, -0.0138],
[rank0]:         [-0.0087, -0.0147,  0.0176,  ...,  0.0143,  0.0167,  0.0006],
[rank0]:         [-0.0201,  0.0154,  0.0206,  ..., -0.0014, -0.0056, -0.0142],
[rank0]:         ...,
[rank0]:         [ 0.0045,  0.0175,  0.0115,  ...,  0.0125, -0.0143, -0.0030],
[rank0]:         [-0.0189,  0.0029,  0.0198,  ...,  0.0095,  0.0130, -0.0187],
[rank0]:         [ 0.0137, -0.0036,  0.0138,  ..., -0.0183, -0.0217, -0.0056]],
[rank0]:        device='cuda:0', requires_grad=True)
```
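To make the failure mode concrete, here is an illustrative sketch (not code from PyTorch; the names `param_to_fqns` and `param_group_params` are taken from the traceback, everything else is hypothetical): the optimizer's param groups contain a Parameter that FSDP never registered in its param-to-FQN mapping, so the dictionary lookup in `_unflatten_param_groups` raises the KeyError shown above.

```python
import torch

# Hypothetical reproduction of the failure mode: one Parameter is known to
# FSDP's param -> fully-qualified-name mapping, the other is not, yet both
# show up in the optimizer's param groups.
registered = torch.nn.Parameter(torch.zeros(4, 4))
unregistered = torch.nn.Parameter(torch.zeros(4, 4))

param_to_fqns = {registered: ["model.layers.0.self_attn.q_proj.weight"]}  # hypothetical FQN
param_group_params = [registered, unregistered]

# Same shape as the list comprehension in _unflatten_param_groups:
# raises KeyError on `unregistered`, mirroring the traceback above.
nested_unflat_param_names = [param_to_fqns[p] for p in param_group_params]
```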



Steps to reproduce

Run a training with the following command:

```shell
CUDA_LAUNCH_BLOCKING=1 accelerate launch -m axolotl.cli.train qlora-2048.yaml --dataset_processes=1 --max_steps=1 --batch_size=1 --micro_batch_size=1 --val_set_size=0.01 --sample_packing=False --eval_sample_packing=False --dataset_prepared_path=temp_debug/axolotl_outputs/data --output_dir=temp_debug/axolotl_outputs/model
```

using the config yaml shown in the "Config yaml" section below.


Config yaml

```yaml
base_model: unsloth/Llama-3.2-1B-Instruct #unsloth/Llama-3.1-Nemotron-70B-Instruct
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer  
load_in_8bit: false
load_in_4bit: true
strict: false

datasets:
  - path: fozziethebeat/alpaca_messages_2k_test
    type: chat_template
    shards: 10


evaluation_strategy : steps #epoch
eval_steps: 50
#evals_per_epoch: 1

eval_table_size: 3
eval_max_new_tokens: 512


dataset_prepared_path: last_run_prepared
val_set_size: 0.05
output_dir: ./outputs/out/qlora-llama3.2

adapter: qlora
lora_model_dir:

sequence_len: 2048
sample_packing: false
pad_to_sequence_len: true

lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:

use_wandb: true
wandb_project: llama3.1-seq2048-fine-tune-cuad-qa-partial
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model: checkpoint

#gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 4
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 0.00001

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_steps: 10
#evals_per_epoch: 4
eval_table_size:
saves_per_epoch: 1

debug: true
deepspeed:
weight_decay: 0.0
fsdp:
  - full_shard
  - auto_wrap
fsdp_config:
  fsdp_limit_all_gathers: true
  fsdp_sync_module_states: true
  fsdp_offload_params: true
  fsdp_use_orig_params: false
  fsdp_cpu_ram_efficient_loading: true
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_sharding_strategy: FULL_SHARD
  activation_checkpointing: true
  pin_memory: false
special_tokens:
  pad_token: <|end_of_text|>
```

Possible solution

No response

Which Operating Systems are you using?

  • Linux
  • macOS
  • Windows

Python Version

3.11

axolotl branch-commit

main

Acknowledgements

  • My issue title is concise, descriptive, and in title casing.
  • I have searched the existing issues to make sure this bug has not been reported yet.
  • I am using the latest version of axolotl.
  • I have provided enough information for the maintainers to reproduce and diagnose the issue.
@ghsama ghsama added the bug Something isn't working label Dec 7, 2024
@winglian
Collaborator

winglian commented Dec 7, 2024

I was able to reproduce this on my end. Upon further digging, there is already a pending fix upstream in accelerate: huggingface/accelerate#3213. In the meantime, you can use save_only_model: true as a workaround, except that you won't be able to resume from a checkpoint, since the optimizer state won't be saved with that.
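For anyone applying the workaround, a minimal sketch of the change to the config yaml posted above (this only adds the flag mentioned in this comment; the rest of the config stays as posted):

```yaml
# Workaround sketch: save only the model weights so the failing
# FSDP optim_state_dict call is skipped. Checkpoints saved this way
# cannot be resumed from, since no optimizer state is written.
save_only_model: true
```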

@winglian
Collaborator

winglian commented Dec 7, 2024

A similar issue is also mentioned here: huggingface/peft#2205

@ghsama
Author

ghsama commented Dec 8, 2024

I will test the transformers-version-flexibility branch in my case.

@ghsama
Author

ghsama commented Dec 8, 2024

Using the branch version, after the training I receive this error, both with save_only_model: true and without it:

```
[rank0]:                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/axolotl/src/axolotl/train.py", line 192, in train
[rank0]:     trainer.train(resume_from_checkpoint=resume_from_checkpoint)
[rank0]:   File "/root/miniconda3/envs/py3.11/lib/python3.11/site-packages/transformers/trainer.py", line 2164, in train
[rank0]:     return inner_training_loop(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "<string>", line 508, in _fixed_inner_training_loop
[rank0]:   File "/root/miniconda3/envs/py3.11/lib/python3.11/site-packages/transformers/trainer_callback.py", line 472, in on_train_end
[rank0]:     return self.call_event("on_train_end", args, state, control)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/envs/py3.11/lib/python3.11/site-packages/transformers/trainer_callback.py", line 519, in call_event
[rank0]:     result = getattr(callback, event)(
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/envs/py3.11/lib/python3.11/site-packages/transformers/integrations/integration_utils.py", line 918, in on_train_end
[rank0]:     fake_trainer = Trainer(args=args, model=model, processing_class=tokenizer, eval_dataset=["fake"])
[rank0]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/envs/py3.11/lib/python3.11/site-packages/transformers/utils/deprecation.py", line 165, in wrapped_func
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/envs/py3.11/lib/python3.11/site-packages/transformers/trainer.py", line 554, in __init__
[rank0]:     raise ValueError(
[rank0]: ValueError: You cannot perform fine-tuning on purely quantized models. Please attach trainable adapters on top of the quantized model to correctly perform fine-tuning. Please see: https://huggingface.co/docs/transformers/peft for more details
```
@winglian
Collaborator

winglian commented Dec 8, 2024

Using the branch version, after the training I receive this error, both with save_only_model: true and without it

Are you using lora/qlora? It looks like it's erroring because you are trying to full fine-tune (FFT) a quantized model.

@ghsama
Author

ghsama commented Dec 8, 2024

Using the branch version, after the training I receive this error, both with save_only_model: true and without it

Are you using lora/qlora? It looks like it's erroring because you are trying to full fine-tune (FFT) a quantized model.

I'm using qlora, more specifically qlora + fsdp.

@winglian
Collaborator

winglian commented Dec 8, 2024

Sorry, it auto-closed due to the linked PR getting merged.
