
lora example not working with deepspeed zero3 #1481

Closed
6 of 8 tasks
xu3kev opened this issue Apr 5, 2024 · 6 comments
Labels
bug Something isn't working

Comments


xu3kev commented Apr 5, 2024

Please check that this issue hasn't been reported before.

  • I searched previous Bug Reports and didn't find any similar reports.

Expected Behavior

Training should run as usual with this example.

Current behaviour

The run crashes with the following error message:

Traceback (most recent call last):
  File "/home/wl678/.conda/envs/axolotl4/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/wl678/.conda/envs/axolotl4/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/wl678/axolotl/src/axolotl/cli/train.py", line 59, in <module>
    fire.Fire(do_cli)
  File "/home/wl678/.conda/envs/axolotl4/lib/python3.10/site-packages/fire/core.py", line 143, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
  File "/home/wl678/.conda/envs/axolotl4/lib/python3.10/site-packages/fire/core.py", line 477, in _Fire
    component, remaining_args = _CallAndUpdateTrace(
  File "/home/wl678/.conda/envs/axolotl4/lib/python3.10/site-packages/fire/core.py", line 693, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
  File "/home/wl678/axolotl/src/axolotl/cli/train.py", line 35, in do_cli
    return do_train(parsed_cfg, parsed_cli_args)
  File "/home/wl678/axolotl/src/axolotl/cli/train.py", line 55, in do_train
    return train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
  File "/home/wl678/axolotl/src/axolotl/train.py", line 84, in train
    model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference)
  File "/home/wl678/axolotl/src/axolotl/utils/models.py", line 608, in load_model
    raise err
  File "/home/wl678/axolotl/src/axolotl/utils/models.py", line 527, in load_model
    model = LlamaForCausalLM.from_pretrained(
  File "/home/wl678/.conda/envs/axolotl4/lib/python3.10/site-packages/transformers/modeling_utils.py", line 3562, in from_pretrained
    ) = cls._load_pretrained_model(
  File "/home/wl678/.conda/envs/axolotl4/lib/python3.10/site-packages/transformers/modeling_utils.py", line 3989, in _load_pretrained_model
    new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
  File "/home/wl678/.conda/envs/axolotl4/lib/python3.10/site-packages/transformers/modeling_utils.py", line 822, in _load_state_dict_into_meta_model
    value = type(value)(value.data.to("cpu"), **value.__dict__)
  File "/home/wl678/.conda/envs/axolotl4/lib/python3.10/site-packages/bitsandbytes/nn/modules.py", line 491, in __new__
    obj = torch.Tensor._make_subclass(cls, data, requires_grad)
RuntimeError: Only Tensors of floating point and complex dtype can require gradients

Steps to reproduce

Run the code-llama 7b LoRA example with DeepSpeed ZeRO-3:

accelerate launch -m axolotl.cli.train examples/code-llama/7b/lora.yml --deepspeed deepspeed_configs/zero3.json --eval_sample_packing false

Config yaml

base_model: codellama/CodeLlama-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer

load_in_8bit: true
load_in_4bit: false
strict: false

datasets:
  - path: mhenrichsen/alpaca_2k_test
    type: alpaca
dataset_prepared_path:
val_set_size: 0.05
output_dir: ./lora-out

sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true

adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
s2_attention:

warmup_steps: 10
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
  bos_token: "<s>"
  eos_token: "</s>"
  unk_token: "<unk>"

Possible solution

No response

Which Operating Systems are you using?

  • Linux
  • macOS
  • Windows

Python Version

3.10

axolotl branch-commit

main/c2b64e4dcff59cfbd754626e5172688433cc13e1

Acknowledgements

  • My issue title is concise, descriptive, and in title casing.
  • I have searched the existing issues to make sure this bug has not been reported yet.
  • I am using the latest version of axolotl.
  • I have provided enough information for the maintainers to reproduce and diagnose the issue.
xu3kev added the bug label on Apr 5, 2024
@1716649290

I met the same problem. Have you tried any good solutions since then?

@watsonchua

I'm facing this issue now. It has been two months since this was reported. Did anybody find a solution?

@kalomaze

Also running into this

@Nero10578
Contributor

Ran into this and setting load_in_8bit to false made it work.
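
In the example YAML above, that change would be (a minimal sketch; everything else in the config stays the same, though without 8-bit quantization the base model uses more GPU memory):

# Workaround: load the base model without bitsandbytes 8-bit quantization so
# ZeRO-3 can partition it; the LoRA adapter still trains as usual
load_in_8bit: false
load_in_4bit: false
adapter: lora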

@cjakfskvnad

same here.

@bursteratom
Collaborator

Hi, this is a known issue; the fix depends on an upstream change in transformers. Currently ZeRO-3 works with 4-bit QLoRA but not 8-bit LoRA. See #2068.
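
For reference, a 4-bit QLoRA variant of the example config would look roughly like this (a sketch assuming only the quantization/adapter fields change and the rest of the YAML above stays as-is):

# 4-bit QLoRA settings, which currently work with DeepSpeed ZeRO-3
load_in_8bit: false
load_in_4bit: true
adapter: qlora
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true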
