☘️ Fix RM trainer with PEFT + DeepSpeed Stage3 #31

Merged
merged 3 commits into from Oct 1, 2024
Changes from all commits
11 changes: 11 additions & 0 deletions turbo_alignment/common/tf/loaders/model/model.py
@@ -1,13 +1,15 @@
import torch
from peft import PeftModel, get_peft_model, prepare_model_for_int8_training
from transformers import PreTrainedModel, PreTrainedTokenizerBase
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled

from turbo_alignment.common.tf.loaders.model.registry import (
    PeftConfigRegistry,
    TransformersAutoModelRegistry,
)
from turbo_alignment.settings.model import (
    ModelForPeftSettings,
    ModelType,
    PreTrainedAdaptersModelSettings,
    PreTrainedModelSettings,
)
@@ -77,4 +79,13 @@ def load_model(
    # creating learnable adapters and freezing non-training parameters
    model = _prepare_model_for_peft(model, model_settings.peft_settings)

    # DeepSpeed Stage 3 currently doesn't work with a seq_cls head wrapped by PEFT,
    # so rebuild the score head as a plain, trainable nn.Linear
    if model_settings.model_type == ModelType.SEQ_CLS and is_deepspeed_zero3_enabled():
        model.base_model.model.score = torch.nn.Linear(
            in_features=model.base_model.model.score.original_module.in_features,
            out_features=model.base_model.model.score.original_module.out_features,
            bias=model.base_model.model.score.original_module.bias is not None,  # nn.Linear expects a bool here
        )
        model.base_model.model.score.weight.requires_grad = True

    return model
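
Note (not part of the diff): a minimal sanity check, assuming a model returned by load_model with model_type == ModelType.SEQ_CLS under ZeRO-3, that the rebuilt score head is indeed a plain trainable Linear layer:

import torch

# Hypothetical check: after the workaround above, the score head should be a
# plain nn.Linear whose weight participates in training.
score = model.base_model.model.score
assert isinstance(score, torch.nn.Linear)
assert score.weight.requires_grad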
18 changes: 18 additions & 0 deletions turbo_alignment/trainers/rm.py
@@ -1,9 +1,14 @@
import os
from pathlib import Path
from typing import Any

import torch
from peft import PeftModel
from torch import nn
from transformers import PreTrainedModel
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.trainer_pt_utils import nested_detach
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from transformers.utils import logging

from turbo_alignment.trainers.multigpu import MultiGPUCherryPicksTrainer
@@ -64,3 +69,16 @@ def prediction_step(
        labels = labels.long()

        return loss, logits, labels

    def _save_checkpoint(self, model, trial, metrics=None):
        if isinstance(model, PeftModel) and is_deepspeed_zero3_enabled():
            logger.info('Running custom _save_checkpoint')
            checkpoint_folder = f'{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}'
            run_dir = self._get_output_dir(trial=trial)
            output_dir = Path(os.path.join(run_dir, checkpoint_folder))

            # save the classification head separately alongside the adapter checkpoint
            (output_dir / 'cls_head').mkdir(parents=True, exist_ok=True)
            torch.save(model.base_model.model.score.state_dict(), output_dir / 'cls_head' / 'cls_head.pt')

        return super()._save_checkpoint(model=model, trial=trial, metrics=metrics)  # pylint: disable=no-member
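
Note (not part of the diff): a minimal sketch of how the separately saved head could be restored later, assuming the checkpoint layout written by _save_checkpoint above and a model built by load_model; the checkpoint path is illustrative:

from pathlib import Path

import torch

# Hypothetical restore step: load the classification head saved to
# <checkpoint>/cls_head/cls_head.pt back into the model's score layer.
checkpoint_dir = Path('output/checkpoint-1000')  # assumed path
state_dict = torch.load(checkpoint_dir / 'cls_head' / 'cls_head.pt', map_location='cpu')
model.base_model.model.score.load_state_dict(state_dict)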