diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml
index e9fcc7276d6c..574c5b554be2 100644
--- a/.github/workflows/cicd-main.yml
+++ b/.github/workflows/cicd-main.yml
@@ -3634,6 +3634,15 @@ jobs:
       RUNNER: self-hosted-azure
       SCRIPT: |
         TRANSFORMERS_OFFLINE=1 python tests/collections/vlm/hf/peft.py --model /home/TestData/vlm/qwen2-2b/ --max-steps 3 --disable-ckpt --strategy fsdp --devices 2
+
+  L2_VLM_HF_Transformer_PEFT_4bit:
+    needs: [ cicd-test-container-setup ]
+    uses: ./.github/workflows/_test_template.yml
+    if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_VLM_HF_Transformer_PEFT_4bit') || needs.cicd-test-container-setup.outputs.all == 'true'
+    with:
+      RUNNER: self-hosted-azure-gpus-1
+      SCRIPT: |
+        TRANSFORMERS_OFFLINE=1 python tests/collections/vlm/hf/peft.py --model /home/TestData/vlm/qwen2-2b/ --max-steps 3 --disable-ckpt --use-4bit
       AFTER_SCRIPT: |
         rm -rf nemo_experiments

@@ -4930,6 +4939,7 @@ jobs:
       - L2_HF_Transformer_SFT_2gpu
       - L2_VLM_HF_Transformer_PEFT
       - L2_VLM_HF_Transformer_PEFT_FSDP
+      - L2_VLM_HF_Transformer_PEFT_4bit
       - L2_HF_Transformer_SFT_2gpu_nemorun
       - L2_HF_Transformer_SFT_TE_Acceleration
       - L2_NeMo_2_SSM_Pretraining
diff --git a/examples/vlm/hf/peft.py b/examples/vlm/hf/peft.py
index 2400c333f398..01ba0fb7d5e7 100644
--- a/examples/vlm/hf/peft.py
+++ b/examples/vlm/hf/peft.py
@@ -85,6 +85,7 @@ def fmt(sample):
     parser.add_argument('--accelerator', default='gpu', choices=['gpu'])
     parser.add_argument('--max-steps', type=int, default=100)
     parser.add_argument('--wandb-project', type=str, default=None)
+    parser.add_argument('--use-4bit', help="Load model in 4bit", action="store_true")
     args = parser.parse_args()

     wandb = None
@@ -103,7 +104,7 @@ def fmt(sample):
     processor = vlm.HFAutoModelForImageTextToText.configure_processor(args.model)

     llm.api.finetune(
-        model=vlm.HFAutoModelForImageTextToText(args.model),
+        model=vlm.HFAutoModelForImageTextToText(args.model, load_in_4bit=args.use_4bit),
         data=mk_hf_vlm_dataset(processor, args.mbs, args.gbs),
         trainer=nl.Trainer(
             devices=args.devices,
@@ -124,5 +125,6 @@ def fmt(sample):
         peft=llm.peft.LoRA(
             target_modules=['*_proj'],
             dim=16,
+            lora_dtype=torch.bfloat16 if args.use_4bit else None,
         ),
     )
diff --git a/nemo/collections/llm/peft/lora.py b/nemo/collections/llm/peft/lora.py
index a0318c587e57..6c7e7e93ae8f 100644
--- a/nemo/collections/llm/peft/lora.py
+++ b/nemo/collections/llm/peft/lora.py
@@ -132,9 +132,12 @@ def _init_adapter(
         obj.dropout_position = dropout_position

     @staticmethod
-    def _forward(obj, x):
+    def _forward(obj, x, fwd=None):
         # pylint: disable=C0115,C0116
-        res = F.linear(x, obj.weight, obj.bias)
+        if fwd is not None:
+            res = fwd(x)
+        else:
+            res = F.linear(x, obj.weight, obj.bias)
         if obj.dropout_position == 'pre':
             x = obj.dropout(x)
         lora_res = x @ obj.lora_a
@@ -187,7 +190,11 @@ def patch_linear_module(
     assert isinstance(orig_linear, nn.Linear)

     LinearAdapter._init_adapter(orig_linear, dim, alpha, dropout, dropout_position, lora_A_init_method, lora_dtype)
-    orig_linear.forward = lambda x: LinearAdapter._forward(orig_linear, x)
+    fwd = None
+    # If the model uses quantized weights, we want to use orig_linear's forward
+    if orig_linear.weight.dtype == torch.uint8:
+        fwd = orig_linear.forward
+    orig_linear.forward = lambda x: LinearAdapter._forward(orig_linear, x, fwd)
     return orig_linear


@@ -264,7 +271,7 @@ def transform(self, m: nn.Module, name=None, prefix=None):
         full_name = f"{prefix}.{name}" if prefix else name
         if name in self.target_modules or any(wildcard_match(pattern, full_name) for pattern in self.target_modules):
             if isinstance(m, nn.Linear):
-                if self._is_fsdp_v1:
+                if self._is_fsdp_v1 or m.weight.data.dtype == torch.uint8:
                     lora_cls = patch_linear_module
                 else:
                     lora_cls = LinearAdapter
diff --git a/requirements/requirements_multimodal.txt b/requirements/requirements_multimodal.txt
index aa33b3b55127..35a060164c5e 100644
--- a/requirements/requirements_multimodal.txt
+++ b/requirements/requirements_multimodal.txt
@@ -1,4 +1,5 @@
 addict
+bitsandbytes==0.45.0
 clip
 decord; sys_platform == 'linux'
 diffusers>=0.19.3
diff --git a/tests/collections/vlm/hf/peft.py b/tests/collections/vlm/hf/peft.py
index 96caebb5c243..bbe5462e431a 100644
--- a/tests/collections/vlm/hf/peft.py
+++ b/tests/collections/vlm/hf/peft.py
@@ -86,6 +86,7 @@ def fmt(sample):
     parser.add_argument('--max-steps', type=int, default=100)
     parser.add_argument('--wandb-project', type=str, default=None)
     parser.add_argument('--disable-ckpt', action='store_false')
+    parser.add_argument('--use-4bit', help="Load model in 4bit", action="store_true")
     args = parser.parse_args()

     wandb = None
@@ -103,7 +104,7 @@ def fmt(sample):
     processor = vlm.HFAutoModelForImageTextToText.configure_processor(args.model)

     llm.api.finetune(
-        model=vlm.HFAutoModelForImageTextToText(args.model),
+        model=vlm.HFAutoModelForImageTextToText(args.model, load_in_4bit=args.use_4bit),
         data=mk_hf_vlm_dataset(processor, args.mbs, args.gbs),
         trainer=nl.Trainer(
             devices=args.devices,
@@ -125,5 +126,6 @@ def fmt(sample):
         peft=llm.peft.LoRA(
             target_modules=['*_proj'],
             dim=16,
+            lora_dtype=torch.bfloat16 if args.use_4bit else None,
         ),
     )
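The sketch below illustrates, outside the diff, the pattern the `nemo/collections/llm/peft/lora.py` change relies on: when a linear layer's weight is a packed quantized buffer (bitsandbytes stores 4-bit data as `uint8`), the original module forward is reused for the base projection instead of `F.linear`, and the LoRA update is computed in a higher-precision dtype on top. This is a minimal, self-contained sketch, not NeMo's `LinearAdapter`; the names `add_lora_to_linear`, `rank`, and `alpha` are illustrative only.

```python
# Minimal sketch of LoRA over a possibly-quantized nn.Linear (assumed names, not NeMo APIs).
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def add_lora_to_linear(linear: nn.Linear, rank: int = 16, alpha: int = 32) -> nn.Linear:
    # LoRA factors are kept in bfloat16 so they remain trainable even when the
    # base weight is a packed uint8 buffer that F.linear cannot consume directly.
    lora_a = torch.empty(linear.in_features, rank)
    nn.init.kaiming_uniform_(lora_a, a=math.sqrt(5))
    linear.lora_a = nn.Parameter(lora_a.to(torch.bfloat16))
    linear.lora_b = nn.Parameter(torch.zeros(rank, linear.out_features, dtype=torch.bfloat16))

    # Capture the original forward only for quantized weights; dense layers keep
    # the plain F.linear path.
    base_fwd = linear.forward if linear.weight.dtype == torch.uint8 else None

    def forward(x):
        if base_fwd is not None:
            res = base_fwd(x)  # the quantized module handles dequantization internally
        else:
            res = F.linear(x, linear.weight, linear.bias)
        lora_res = (x.to(linear.lora_a.dtype) @ linear.lora_a) @ linear.lora_b
        return res + (alpha / rank) * lora_res.to(res.dtype)

    linear.forward = forward
    return linear
```

This mirrors why the diff routes layers with `uint8` weights through `patch_linear_module` and why the example scripts pin `lora_dtype=torch.bfloat16` when `--use-4bit` is set.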