Commit 4f85480

sft rm

d.taranets committed Oct 14, 2024
1 parent 2b9f206 commit 4f85480
Showing 4 changed files with 242 additions and 0 deletions.
13 changes: 13 additions & 0 deletions turbo_alignment/cli/train.py
@@ -72,6 +72,19 @@ def train_rm_entrypoint(
pipelines.TrainRMStrategy().run(experiment_settings)


@app.command(name='train_multihead', help='Run SFT + RM multihead pipeline')
def train_multihead_entrypoint(
experiment_settings_path: Path = typer.Option(
...,
'--experiment_settings_path',
exists=True,
help='Path to experiment config file',
)
) -> None:
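    # the multihead pipeline reuses the reward-model experiment settings schema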
experiment_settings = pipeline_settings.RMTrainExperimentSettings.parse_file(experiment_settings_path)
pipelines.TrainMultiheadStrategy().run(experiment_settings)


@app.command(name='train_classification', help='Train Classifier')
def classification_training(
experiment_settings_path: Path = typer.Option(
1 change: 1 addition & 0 deletions turbo_alignment/pipelines/train/__init__.py
@@ -6,3 +6,4 @@
from .rag import TrainRAGStrategy
from .rm import TrainRMStrategy
from .sft import TrainSFTStrategy
from .sft_rm import TrainMultiheadStrategy
144 changes: 144 additions & 0 deletions turbo_alignment/pipelines/train/sft_rm.py
@@ -0,0 +1,144 @@
from typing import Callable

import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset
from transformers import AutoConfig, GenerationMixin, PreTrainedModel, PreTrainedTokenizerBase, TrainingArguments
from transformers.data.data_collator import DataCollatorMixin

from turbo_alignment.cherry_picks.rm import MultiHeadCherryPickCallback
from turbo_alignment.common.logging import get_project_logger
from turbo_alignment.constants import TRAINER_LOGS_FOLDER
from turbo_alignment.dataset.loader import DatasetLoader
from turbo_alignment.dataset.pair_preferences import (
PairPreferenceDataCollator,
PairPreferenceDataset,
)
from turbo_alignment.common.tf.loaders.model.model import load_model
from turbo_alignment.metrics.metric import Metric
from turbo_alignment.metrics.registry import MetricSettingsRegistry
from turbo_alignment.metrics.reward import compute_metrics as compute_rm_metrics
from turbo_alignment.pipelines.train.base import BaseTrainStrategy
from turbo_alignment.settings.datasets import DatasetStrategy
from turbo_alignment.settings.pipelines import RMTrainExperimentSettings
from turbo_alignment.trainers.sft_with_rm import SFTwithRMTrainer
from turbo_alignment.trainers.utils import concatenated_inputs

logger = get_project_logger()


class MultiheadModel(PreTrainedModel, GenerationMixin):
def __init__(self, config, model_settings, tokenizer):
super().__init__(config)

self.decoder = load_model(model_settings=model_settings, tokenizer=tokenizer)

self.lm_head = nn.Linear(self.decoder.norm.weight.shape[0], len(tokenizer), bias=False)
self.rm_head = nn.Linear(self.decoder.norm.weight.shape[0], 1, bias=False)
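        # lm_head maps decoder hidden states to vocabulary logits for the SFT objective;
        # rm_head maps the hidden state at the <reward> position to a scalar reward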

reward_token_ids = tokenizer.encode('<reward>', add_special_tokens=False)
if len(reward_token_ids) != 1:
            raise ValueError("'<reward>' is not a single token in the tokenizer vocabulary")

self.reward_token_ids = reward_token_ids[0]

def forward(self, batch):
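        # 'inputs_w' is the chosen sequence of a preference pair, 'inputs_l' the rejected one;
        # the [0] indexing below uses only the first element of the batch (effectively batch size 1)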

outputs_w = self.decoder(**batch['inputs_w']).last_hidden_state[0]
outputs_l = self.decoder(**batch['inputs_l']).last_hidden_state[0]


reward_token_pos_w = np.where(batch['inputs_w']['input_ids'][0].cpu() == self.reward_token_ids)[0]
reward_token_pos_l = np.where(batch['inputs_l']['input_ids'][0].cpu() == self.reward_token_ids)[0]

if len(reward_token_pos_w) != 1 or len(reward_token_pos_l) != 1:
            raise ValueError('Expected exactly one <reward> token in each sequence')

outputs_w_1 = outputs_w[:reward_token_pos_w[0]]
outputs_w_2 = outputs_w[reward_token_pos_w[0]+1:]
outputs_w_cat = torch.cat((outputs_w_1, outputs_w_2), dim=0)
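        # hidden states of the chosen sequence with the <reward> position removed,
        # so the LM head only predicts ordinary tokens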

lm_logits = self.lm_head(outputs_w_cat)
rm_logits_w = self.rm_head(outputs_w[reward_token_pos_w[0]])
rm_logits_l = self.rm_head(outputs_l[reward_token_pos_l[0]])

return lm_logits, rm_logits_w, rm_logits_l, reward_token_pos_w


class TrainMultiheadStrategy(BaseTrainStrategy[RMTrainExperimentSettings]):
@staticmethod
def _get_data_collator(
experiment_settings: RMTrainExperimentSettings,
tokenizer: PreTrainedTokenizerBase,
**_kwargs,
) -> Callable:
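        # add_labels=False: SFT labels are taken directly from the chosen input_ids inside the trainer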
return PairPreferenceDataCollator(tokenizer=tokenizer, add_labels=False)

@staticmethod
def _get_cherry_pick_callback(
experiment_settings: RMTrainExperimentSettings,
tokenizer: PreTrainedTokenizerBase,
**_kwargs,
) -> MultiHeadCherryPickCallback:
cherry_pick_settings = experiment_settings.cherry_pick_settings

cherry_pick_datasets = DatasetLoader[PairPreferenceDataset](PairPreferenceDataset).load_datasets(
cherry_pick_settings.dataset_settings, tokenizer=tokenizer, strategy=DatasetStrategy.TRAIN
)

metrics = [
Metric.by_name(metric.type)(MetricSettingsRegistry.by_name(metric.type)(**metric.parameters))
for metric in cherry_pick_settings.metric_settings
]

return MultiHeadCherryPickCallback(
cherry_pick_settings=cherry_pick_settings,
datasets=cherry_pick_datasets,
metrics=metrics,
)

@staticmethod
def _get_training_args(experiment_settings: RMTrainExperimentSettings) -> TrainingArguments:
return TrainingArguments(
output_dir=str(experiment_settings.log_path / TRAINER_LOGS_FOLDER),
label_names=[],
remove_unused_columns=False,
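            # label_names=[] and remove_unused_columns=False keep the nested inputs_w / inputs_l
            # dicts produced by the pair-preference collator from being filtered out by the Trainer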
**experiment_settings.trainer_settings.dict(),
)

@staticmethod
def _load_model(
experiment_settings: RMTrainExperimentSettings,
tokenizer: PreTrainedTokenizerBase,
) -> nn.Module | PreTrainedModel:
config = AutoConfig.from_pretrained(experiment_settings.model_settings.model_path)
return MultiheadModel(config, experiment_settings.model_settings, tokenizer)


@staticmethod
def _get_trainer(
training_args: TrainingArguments,
experiment_settings: RMTrainExperimentSettings,
model: PreTrainedModel,
tokenizer: PreTrainedTokenizerBase,
train_dataset: Dataset,
val_dataset: Dataset,
data_collator: DataCollatorMixin,
**_kwargs,
):
return SFTwithRMTrainer(
model=model,
tokenizer=tokenizer,
args=training_args,
train_dataset=train_dataset,
eval_dataset=val_dataset,
data_collator=data_collator,
compute_metrics=compute_rm_metrics,
callbacks=[],
)

def _dataset_and_collator_sanity_check(self, dataset: Dataset, collator: DataCollatorMixin) -> None:
        logger.info(f'Train sample:\n{dataset[0]}')
        logger.info(f'Chosen (inputs_w) example:\n{self.tokenizer.decode(dataset[0]["inputs_w"]["input_ids"])}')
        logger.info(f'Rejected (inputs_l) example:\n{self.tokenizer.decode(dataset[0]["inputs_l"]["input_ids"])}')
84 changes: 84 additions & 0 deletions turbo_alignment/trainers/sft_with_rm.py
@@ -0,0 +1,84 @@
import os
from pathlib import Path
from typing import Any

import torch
from torch import nn
from transformers import PreTrainedModel
from transformers.trainer_pt_utils import nested_detach
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from transformers.utils import logging

from turbo_alignment.trainers.multigpu import MultiGPUCherryPicksTrainer

logger = logging.get_logger(__name__)


class SFTwithRMTrainer(MultiGPUCherryPicksTrainer):
def compute_loss(self, model, inputs, return_outputs=False) -> tuple[torch.Tensor, dict[str, Any]] | torch.Tensor:
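        # The multihead model returns next-token logits over the chosen sequence (with the <reward>
        # position removed), scalar rewards for the chosen and rejected sequences, and the <reward> position.
        # The loss combines a pairwise logistic (Bradley-Terry style) term on the rewards with
        # token-level cross-entropy on the chosen sequence.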

sft_logits, rewards_w, rewards_l, reward_token_pos_w = model.forward(inputs)

sft_logits = sft_logits.view(-1, sft_logits.size(-1))
sft_labels = inputs['inputs_w']['input_ids']

sft_labels_1 = sft_labels.view(-1)[:reward_token_pos_w[0]]
sft_labels_2 = sft_labels.view(-1)[reward_token_pos_w[0]+1:]
sft_labels_cat = torch.cat((sft_labels_1, sft_labels_2), dim=0)

loss = -torch.nn.functional.logsigmoid(rewards_w - rewards_l).mean() + torch.nn.functional.cross_entropy(sft_logits, sft_labels_cat)
if return_outputs:
return loss, {'rewards_w': rewards_w, 'rewards_l': rewards_l}
return loss

def prediction_step(
self,
model: PreTrainedModel | nn.Module,
inputs: dict[str, dict[str, torch.Tensor]],
prediction_loss_only: bool,
ignore_keys: list[str] | None,
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:

inputs = self._prepare_inputs(inputs)

if ignore_keys is None:
if hasattr(self.model, 'config'):
ignore_keys = getattr(self.model.config, 'keys_to_ignore_at_inference', [])
else:
ignore_keys = []

with torch.no_grad():
loss, logits_dict = self.compute_loss(model, inputs, return_outputs=True)

if prediction_loss_only:
return (loss, None, None)

loss = loss.detach()
logits = tuple(v for k, v in logits_dict.items() if k not in ignore_keys)
logits = nested_detach(logits)

logits = torch.stack(logits).T
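        # after stacking, column 0 holds the chosen reward and column 1 the rejected reward;
        # the derived label is 1 when the chosen response scores higher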

labels = logits[:, 0] > logits[:, 1]

labels = labels.long()

return loss, logits, labels


def _save_checkpoint(self, model, trial, metrics=None):
logger.info('Running custom _save_checkpoint')
checkpoint_folder = f'{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}'
run_dir = self._get_output_dir(trial=trial)
output_dir = Path(os.path.join(run_dir, checkpoint_folder))

(output_dir / 'decoder').mkdir(parents=True, exist_ok=True)
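        # the two heads are saved as plain state dicts; the shared decoder and the tokenizer
        # are written to the 'decoder' subfolder in the standard Hugging Face format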

torch.save(model.module.lm_head.state_dict(), output_dir / 'lm_head.pt')
torch.save(model.module.rm_head.state_dict(), output_dir / 'rm_head.pt')

model.module.decoder.save_pretrained(output_dir / 'decoder')
self.tokenizer.save_pretrained(output_dir / 'decoder')
