fix rm trainer with deepspeed stage3 and peft
Малахов Алексей Павлович committed Sep 30, 2024
1 parent 2b9f206 commit 65bf5e9
Showing 2 changed files with 30 additions and 0 deletions.
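Summary (inferred from the diff below; the commit carries no description): under DeepSpeed ZeRO stage 3, a PEFT-wrapped sequence-classification (reward-model) head does not train correctly, so load_model now swaps the wrapped score module for a fresh trainable torch.nn.Linear, and the RM trainer additionally saves that head to cls_head/cls_head.pt at every checkpoint, presumably because the replaced module is no longer covered by the regular PEFT adapter save.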
13 changes: 13 additions & 0 deletions turbo_alignment/common/tf/loaders/model/model.py
@@ -1,3 +1,5 @@
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled

import torch
from peft import PeftModel, get_peft_model, prepare_model_for_int8_training
from transformers import PreTrainedModel, PreTrainedTokenizerBase
@@ -8,6 +10,7 @@
)
from turbo_alignment.settings.model import (
ModelForPeftSettings,
ModelType,
PreTrainedAdaptersModelSettings,
PreTrainedModelSettings,
)
@@ -77,4 +80,14 @@ def load_model(
# creating learnable adapters and freezing non-training parameters
model = _prepare_model_for_peft(model, model_settings.peft_settings)

# deepspeed stage3 currently doesn't work with a seq_cls head and peft
if model_settings.model_type == ModelType.SEQ_CLS and is_deepspeed_zero3_enabled():
model.base_model.model.score = torch.nn.Linear(
in_features=model.base_model.model.score.original_module.in_features,
out_features=model.base_model.model.score.original_module.out_features,
bias=model.base_model.model.score.original_module.bias is not None,
)
model.base_model.model.score.weight.requires_grad = True

return model
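
For context on the .original_module access above: when a module such as the score head is listed in PEFT's modules_to_save, PEFT wraps it in a ModulesToSaveWrapper that keeps the original layer reachable. A minimal sketch of that layout (module names assume a typical AutoModelForSequenceClassification + LoRA setup; illustration only, not part of the commit):

    # Illustration only (not part of this commit): the layout the code above
    # relies on, assuming get_peft_model(..., modules_to_save=['score']).
    from peft.utils import ModulesToSaveWrapper

    score = model.base_model.model.score           # assumes `model` from load_model
    assert isinstance(score, ModulesToSaveWrapper)
    original = score.original_module               # the wrapped original nn.Linear
    print(original.in_features, original.out_features, original.bias is not None)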

17 changes: 17 additions & 0 deletions turbo_alignment/trainers/rm.py
@@ -1,9 +1,13 @@
import os
from pathlib import Path
from typing import Any

import torch
from peft import PeftModel
from torch import nn
from transformers import PreTrainedModel
from transformers.trainer_pt_utils import nested_detach
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from transformers.utils import logging

from turbo_alignment.trainers.multigpu import MultiGPUCherryPicksTrainer
@@ -64,3 +68,16 @@ def prediction_step(
labels = labels.long()

return loss, logits, labels

def _save_checkpoint(self, model, trial, metrics=None):
if isinstance(model, PeftModel) and self.accelerator.state.deepspeed_plugin is not None and self.accelerator.state.deepspeed_plugin.zero_stage == 3:
logger.info('Running custom _save_checkpoint')
checkpoint_folder = f'{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}'
run_dir = self._get_output_dir(trial=trial)
output_dir = Path(os.path.join(run_dir, checkpoint_folder))

(output_dir / 'cls_head').mkdir(parents=True, exist_ok=True)

torch.save(model.base_model.model.score.state_dict(), output_dir / 'cls_head' / 'cls_head.pt')

return super()._save_checkpoint(model=model, trial=trial, metrics=metrics) # pylint: disable=no-member
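
Note that _save_checkpoint above stores the classification head outside the regular checkpoint files, so resuming from such a checkpoint needs a matching restore step. The commit does not include the loading side; a minimal sketch of what it could look like, assuming the same paths and attribute names as the save code (checkpoint_dir is a hypothetical placeholder):

    # Sketch of the assumed loading counterpart (not part of this commit).
    from pathlib import Path
    import torch

    checkpoint_dir = Path('checkpoints/checkpoint-1000')  # hypothetical run path
    state = torch.load(checkpoint_dir / 'cls_head' / 'cls_head.pt', map_location='cpu')
    model.base_model.model.score.load_state_dict(state)   # `model` as built by load_model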
