
Can this training framework be used for training gte-qwen2-7b-instruct model? #51

Open
Double-bear opened this issue Aug 28, 2024 · 1 comment

@Double-bear

Hello. I wanted to use GritLM to train an open-source embedding model, gte-qwen2-7b-instruct, but I ran into the following problem:

```
[rank1]: Traceback (most recent call last):
[rank1]:   File "/code/xx/LLM_mine/recall/reference/gritlm/gritlm/training/run.py", line 438, in <module>
[rank1]:     main()
[rank1]:   File "/code/xx/LLM_mine/recall/reference/gritlm/gritlm/training/run.py", line 420, in main
[rank1]:     trainer.train()
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 1859, in train
[rank1]:     return inner_training_loop(
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 2203, in _inner_training_loop
[rank1]:     tr_loss_step = self.training_step(model, inputs)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 3138, in training_step
[rank1]:     loss = self.compute_loss(model, inputs)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 3161, in compute_loss
[rank1]:     outputs = model(**inputs)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/parallel/distributed.py", line 1608, in forward
[rank1]:     else self._run_ddp_forward(*inputs, **kwargs)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/parallel/distributed.py", line 1426, in _run_ddp_forward
[rank1]:     return self.module(*inputs, **kwargs)  # type: ignore[index]
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/accelerate/utils/operations.py", line 825, in forward
[rank1]:     return model_forward(*args, **kwargs)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/accelerate/utils/operations.py", line 813, in __call__
[rank1]:     return convert_to_fp32(self.model_forward(*args, **kwargs))
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py", line 16, in decorate_autocast
[rank1]:     return func(*args, **kwargs)
[rank1]:   File "/code/xx/LLM_mine/recall/reference/gritlm/gritlm/training/model.py", line 204, in forward
[rank1]:     p_reps = self.encode(passage)
[rank1]:   File "/code/xx/LLM_mine/recall/reference/gritlm/gritlm/training/model.py", line 145, in encode
[rank1]:     out = (getattr(self.model, self.embedding_attr) if self.embedding_attr else self.model)(**kwargs)[0]
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:   File "/root/.cache/huggingface/modules/transformers_modules/gte-qwen2-7B-instruct/modeling_qwen.py", line 1081, in forward
[rank1]:     layer_outputs = decoder_layer(
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:   File "/root/.cache/huggingface/modules/transformers_modules/gte-qwen2-7B-instruct/modeling_qwen.py", line 795, in forward
[rank1]:     hidden_states = residual + hidden_states
[rank1]: torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 56.00 MiB. GPU 1 has a total capacity of 79.33 GiB of which 35.81 MiB is free. Process 2768284 has 79.28 GiB memory in use. Of the allocated memory 77.51 GiB is allocated by PyTorch, and 761.60 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
```

My GPUs are A800s with 80 GB each, and I am using 8 × A800.
My launch script is as follows:

```bash
#!/bin/bash
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

export WANDB_DISABLED=true
export PYTHONPATH=/code/xx/LLM_mine/recall/reference/gritlm

# Resolve the master hostname to an IP address for the c10d rendezvous
echo "====MASTER_ADDR===="
echo $MASTER_ADDR
MASTER_ADDR=$(ping -c 1 ${MASTER_ADDR} |grep PING | grep -E -o '([0-9]{1,3}\.){3}[0-9]{1,3}')
echo $MASTER_ADDR
echo "==================="

# Pin library versions and install the bundled GradCache package
# pip install /code/xx/LLM_mine/recall/reference/gritlm
pip install transformers==4.40.0 accelerate==0.29.1 datasets==2.18.0
pip install /code/xx/LLM_mine/recall/reference/gritlm/gritlm/training/GradCache

# 8 processes per node; $WORLD_SIZE is passed as the number of nodes
torchrun --nnodes $WORLD_SIZE --nproc_per_node 8 --rdzv_id=100 --rdzv_backend=c10d --rdzv_endpoint=${MASTER_ADDR}:${MASTER_PORT} \
    /code/xx/LLM_mine/recall/reference/gritlm/gritlm/training/run.py \
    --output_dir /code/xx/LLM_mine/recall/reference/gritlm/output \
    --model_name_or_path /code/xx/LLM_mine/recall/model/gte-qwen2-7B-instruct \
    --train_data /code/xx/LLM_mine/recall/reference/gritlm/gritlm/training/toy_data/cv_embedding.jsonl \
    --learning_rate 1e-5 \
    --num_train_epochs 3 \
    --per_device_train_batch_size 2 \
    --dataloader_drop_last True \
    --normalized True \
    --temperature 0.02 \
    --query_max_len 256 \
    --passage_max_len 2048 \
    --train_group_size 2 \
    --mode embedding \
    --attn bbcc \
    --attn_implementation sdpa \
    --bf16
```

How can I solve this problem?

@Muennighoff
Collaborator

Yes, it can be used to train that model. You're running out of memory (OOM), which you can address with any combination of the following:

  • decrease --per_device_train_batch_size
  • use more GPUs
  • decrease --passage_max_len
  • use --gradient_checkpointing (recomputes activations during the backward pass instead of storing them)
  • use --split_emb
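As a back-of-the-envelope check on why the batch size and passage length are the most direct levers, the sketch below counts the tokens each GPU encodes per step with the settings from the script. It assumes each query is paired with `--train_group_size` passages (one positive plus negatives), that every sequence reaches its maximum length, and that activation memory grows roughly with this token count; `per_device_tokens` is a hypothetical helper for illustration only.

```bash
#!/usr/bin/env bash
# Rough upper bound on tokens encoded per GPU per training step.
# Assumption: each query comes with GROUP passages and every sequence
# is padded/truncated to its maximum length.
per_device_tokens() {
  local batch=$1 group=$2 qlen=$3 plen=$4
  # batch queries of qlen tokens plus batch*group passages of plen tokens
  echo $(( batch * (qlen + group * plen) ))
}

# Settings from the original script: batch 2, group 2, 256 / 2048
per_device_tokens 2 2 256 2048   # prints 8704
# Halving --passage_max_len
per_device_tokens 2 2 256 1024   # prints 4608
# Halving --per_device_train_batch_size
per_device_tokens 1 2 256 2048   # prints 4352
```

With the original settings, passages account for 8192 of the 8704 tokens, so shortening passages or lowering the batch size shrinks the activation footprint far more than shortening queries would.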

Consider taking a look at the example training scripts https://github.com/ContextualAI/gritlm/blob/main/scripts/training/train_gritlm_7b.sh and https://github.com/ContextualAI/gritlm/blob/main/scripts/training/train_gritlm_8x7b.sh.
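For illustration, a variant of the original launch command with the suggested memory-saving options applied might look like the sketch below. It assumes the same exports, `MASTER_ADDR` resolution and pip installs as the original script; the reduced values for `--per_device_train_batch_size` and `--passage_max_len` are illustrative, not tuned. `--gradient_checkpointing` is a standard Hugging Face `TrainingArguments` flag and `--split_emb` is the GritLM option named in the list. The `PYTORCH_CUDA_ALLOC_CONF` line follows the hint in the error message, but it only mitigates fragmentation; since just 761.60 MiB was reserved but unallocated here, it is unlikely to be enough on its own.

```bash
#!/bin/bash
# Sketch only: same environment setup as the original script is assumed.
# Batch size and passage length below are illustrative, untuned values.

# Mitigates allocator fragmentation; does not reduce the memory actually needed
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

torchrun --nnodes $WORLD_SIZE --nproc_per_node 8 --rdzv_id=100 --rdzv_backend=c10d --rdzv_endpoint=${MASTER_ADDR}:${MASTER_PORT} \
    /code/xx/LLM_mine/recall/reference/gritlm/gritlm/training/run.py \
    --output_dir /code/xx/LLM_mine/recall/reference/gritlm/output \
    --model_name_or_path /code/xx/LLM_mine/recall/model/gte-qwen2-7B-instruct \
    --train_data /code/xx/LLM_mine/recall/reference/gritlm/gritlm/training/toy_data/cv_embedding.jsonl \
    --learning_rate 1e-5 \
    --num_train_epochs 3 \
    --per_device_train_batch_size 1 \
    --dataloader_drop_last True \
    --normalized True \
    --temperature 0.02 \
    --query_max_len 256 \
    --passage_max_len 1024 \
    --train_group_size 2 \
    --mode embedding \
    --attn bbcc \
    --attn_implementation sdpa \
    --gradient_checkpointing \
    --split_emb \
    --bf16
```

Lowering the per-device batch size also lowers the global batch, which matters for in-batch negatives in contrastive training; adding GPUs is one way to keep the global batch size unchanged.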
