From 773232462534cc824fb5bd3c6f546688c0997521 Mon Sep 17 00:00:00 2001
From: "d.taranets"
Date: Tue, 15 Oct 2024 15:55:58 +0300
Subject: [PATCH] lint fixes

---
 turbo_alignment/cherry_picks/rm.py        | 61 +++++++++++++++++++++++
 turbo_alignment/cli/train.py              |  2 +-
 turbo_alignment/generators/multihead.py   | 56 +++++++++++++++++++++
 turbo_alignment/pipelines/train/base.py   |  2 +-
 turbo_alignment/pipelines/train/sft_rm.py | 24 +++++----
 turbo_alignment/trainers/sft_with_rm.py   | 13 +++--
 6 files changed, 138 insertions(+), 20 deletions(-)
 create mode 100644 turbo_alignment/generators/multihead.py

diff --git a/turbo_alignment/cherry_picks/rm.py b/turbo_alignment/cherry_picks/rm.py
index 6cdcbcc..0b9c8d3 100755
--- a/turbo_alignment/cherry_picks/rm.py
+++ b/turbo_alignment/cherry_picks/rm.py
@@ -6,6 +6,7 @@
 from turbo_alignment.cherry_picks.base import CherryPickCallbackBase
 from turbo_alignment.dataset.pair_preferences import PairPreferenceDataset
 from turbo_alignment.generators.rm import RMPairGenerator
+from turbo_alignment.generators.multihead import MultiHeadPairGenerator
 from turbo_alignment.metrics.metric import Metric
 from turbo_alignment.settings.cherry_pick import RMCherryPickSettings
 from turbo_alignment.settings.metric import ElementWiseScores, MetricResults
@@ -69,3 +70,63 @@ def _get_dataset_metrics(
         ]
 
     return metric_outputs
+
+
+class MultiHeadCherryPickCallback(CherryPickCallbackBase[PairPreferenceDataset]):
+    def __init__(
+        self,
+        cherry_pick_settings: RMCherryPickSettings,
+        datasets: Iterable[PairPreferenceDataset],
+        metrics: list[Metric],
+    ) -> None:
+        super().__init__(cherry_pick_settings=cherry_pick_settings, datasets=datasets, metrics=metrics)
+
+    def _get_dataset_metrics(
+        self,
+        dataset: PairPreferenceDataset,
+        model: PreTrainedModel,
+        tokenizer: PreTrainedTokenizerBase,
+        **kwargs,
+    ) -> list[MetricResults]:
+        accelerator: Accelerator = kwargs.get('accelerator', None)
+
+        generator = MultiHeadPairGenerator(
+            model=model,
+            tokenizer=tokenizer,
+            accelerator=accelerator,
+        )
+
+        generations = generator.generate_from_dataset(dataset)
+        generations_w = [gen.reward_w for gen in generations]
+        generations_l = [gen.reward_l for gen in generations]
+
+        pair_scores = [1 if w > l else 0 for w, l in zip(generations_w, generations_l)]
+
+        answers_w = [record.answer_w.content for record in generations]
+        answers_l = [record.answer_l.content for record in generations]
+
+        metric_outputs = [
+            MetricResults(
+                element_wise_scores=[ElementWiseScores(label=dataset.source.name + '@@' + 'chosen', values=answers_w)]
+            ),
+            MetricResults(
+                element_wise_scores=[
+                    ElementWiseScores(label=dataset.source.name + '@@' + 'rejected', values=answers_l)
+                ]
+            ),
+            MetricResults(
+                element_wise_scores=[
+                    ElementWiseScores(label=dataset.source.name + '@@' + 'chosen_reward', values=generations_w)
+                ]
+            ),
+            MetricResults(
+                element_wise_scores=[
+                    ElementWiseScores(label=dataset.source.name + '@@' + 'rejected_reward', values=generations_l)
+                ]
+            ),
+            MetricResults(
+                element_wise_scores=[ElementWiseScores(label=dataset.source.name + '@@' + 'score', values=pair_scores)]
+            ),
+        ]
+
+        return metric_outputs
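Note on the new 'score' metric above: for each preference pair it is 1 when the chosen answer receives the higher reward, so averaging it over a dataset gives the reward head's pairwise accuracy. A minimal standalone sketch of that reduction (function and variable names here are illustrative, not part of the patch):

    def pairwise_accuracy(rewards_w: list[float], rewards_l: list[float]) -> float:
        # 1 if the chosen (w) answer out-scores the rejected (l) one, else 0
        scores = [1 if w > l else 0 for w, l in zip(rewards_w, rewards_l)]
        return sum(scores) / len(scores) if scores else 0.0
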
diff --git a/turbo_alignment/cli/train.py b/turbo_alignment/cli/train.py
index 2837715..dab9f9f 100755
--- a/turbo_alignment/cli/train.py
+++ b/turbo_alignment/cli/train.py
@@ -73,7 +73,7 @@ def train_rm_entrypoint(
 
 
 @app.command(name='train_multihead', help='Run RM pipeline')
-def train_rm_entrypoint(
+def train_multihead(
     experiment_settings_path: Path = typer.Option(
         ...,
         '--experiment_settings_path',
diff --git a/turbo_alignment/generators/multihead.py b/turbo_alignment/generators/multihead.py
new file mode 100644
index 0000000..3ff42ad
--- /dev/null
+++ b/turbo_alignment/generators/multihead.py
@@ -0,0 +1,56 @@
+from typing import Any
+
+import torch
+from transformers import DataCollatorWithPadding, PreTrainedTokenizerBase
+
+from turbo_alignment.dataset.pair_preferences import PairPreferenceRecord
+from turbo_alignment.generators.base import BaseGenerator
+from turbo_alignment.settings.generators.outputs.rm import RMPairInferenceOutput
+
+
+class MultiHeadPairGenerator(BaseGenerator[PairPreferenceRecord, RMPairInferenceOutput]):
+    def __init__(self, tokenizer: PreTrainedTokenizerBase, **kwargs):
+        self._collator = DataCollatorWithPadding(tokenizer=tokenizer)
+
+        super().__init__(tokenizer=tokenizer, **kwargs)
+
+    def _generate_from_batch(
+        self, records: list[dict[str, Any]], original_records: list[PairPreferenceRecord], dataset_name: str
+    ) -> list[RMPairInferenceOutput]:
+        rewards_w, rewards_l = [], []
+        for record in records:
+            input_ids_w = record['inputs_w']['input_ids'].to(self.device)
+            att_w = record['inputs_w']['attention_mask'].to(self.device)
+            inputs_w = {'input_ids': input_ids_w, 'attention_mask': att_w}
+
+            input_ids_l = record['inputs_l']['input_ids'].to(self.device)
+            att_l = record['inputs_l']['attention_mask'].to(self.device)
+            inputs_l = {'input_ids': input_ids_l, 'attention_mask': att_l}
+
+            batch = {'inputs_w': inputs_w, 'inputs_l': inputs_l}
+
+            with torch.no_grad():
+                _, reward_w, reward_l, _ = self._model.forward(batch)
+            rewards_w.append(reward_w.cpu())
+            rewards_l.append(reward_l.cpu())
+
+        # merged_inputs = [r['inputs_w'] for r in records] + [r['inputs_l'] for r in records]
+        # batch = self._collator(merged_inputs)
+        # input_ids = batch['input_ids'].to(self.device)
+        # attn_mask = batch['attention_mask'].to(self.device)
+
+        # with torch.no_grad():
+        #     _, rewards_w, rewards_l, _ = self._model(input_ids=input_ids, attention_mask=attn_mask).logits.cpu()
+
+        return [
+            RMPairInferenceOutput(
+                id=record.id,
+                context=record.context,
+                answer_w=record.answer_w,
+                answer_l=record.answer_l,
+                reward_w=reward_w.item(),
+                reward_l=reward_l.item(),
+                dataset_name=dataset_name,
+            )
+            for record, reward_w, reward_l in zip(original_records, rewards_w, rewards_l)
+        ]
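For context, _generate_from_batch above expects every record to already carry tokenized tensors under the 'inputs_w' and 'inputs_l' keys, matching what MultiheadModel.forward reads. A purely illustrative record shape (values invented for the example, not taken from the patch):

    import torch

    record = {
        'inputs_w': {'input_ids': torch.tensor([[1, 5, 9]]), 'attention_mask': torch.tensor([[1, 1, 1]])},
        'inputs_l': {'input_ids': torch.tensor([[1, 7, 2, 4]]), 'attention_mask': torch.tensor([[1, 1, 1, 1]])},
    }

The commented-out block is a padded, batched alternative that was left for reference; the active loop scores one preference pair at a time.
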
diff --git a/turbo_alignment/pipelines/train/base.py b/turbo_alignment/pipelines/train/base.py
index 0f98602..a7be9c3 100755
--- a/turbo_alignment/pipelines/train/base.py
+++ b/turbo_alignment/pipelines/train/base.py
@@ -185,7 +185,7 @@ def run(self, experiment_settings: ExperimentSettingsT) -> None:
         if self.trainer.accelerator.is_main_process:
             self._dataset_and_collator_sanity_check(train_dataset, data_collator)
 
-        self._add_trainer_callbacks(experiment_settings)
+        # self._add_trainer_callbacks(experiment_settings)
 
         os.makedirs(self.trainer.args.output_dir, exist_ok=True)
         self._save_experiment_config(
diff --git a/turbo_alignment/pipelines/train/sft_rm.py b/turbo_alignment/pipelines/train/sft_rm.py
index fc65b75..71aa8d5 100755
--- a/turbo_alignment/pipelines/train/sft_rm.py
+++ b/turbo_alignment/pipelines/train/sft_rm.py
@@ -4,7 +4,13 @@
 import numpy as np
 from torch import nn
 from torch.utils.data import Dataset
-from transformers import PreTrainedModel, PreTrainedTokenizerBase, TrainingArguments, AutoModel, AutoModelForCausalLM, GenerationMixin, AutoConfig
+from transformers import (
+    PreTrainedModel,
+    PreTrainedTokenizerBase,
+    TrainingArguments,
+    GenerationMixin,
+    AutoConfig,
+)
 from transformers.data.data_collator import DataCollatorMixin
 
 from turbo_alignment.cherry_picks.rm import MultiHeadCherryPickCallback
@@ -23,7 +29,6 @@
 from turbo_alignment.settings.datasets import DatasetStrategy
 from turbo_alignment.settings.pipelines import RMTrainExperimentSettings
 from turbo_alignment.trainers.sft_with_rm import SFTwithRMTrainer
-from turbo_alignment.trainers.utils import concatenated_inputs
 
 logger = get_project_logger()
 
@@ -39,24 +44,22 @@ def __init__(self, config, model_settings, tokenizer):
 
         reward_token_ids = tokenizer.encode('', add_special_tokens=False)
         if len(reward_token_ids) != 1:
-            raise ValueError(' token us not found in the tokenizer')
+            raise ValueError(' token is not found in the tokenizer')
 
         self.reward_token_ids = reward_token_ids[0]
-
-    def forward(self, batch):
 
+    def forward(self, batch):
         outputs_w = self.decoder(**batch['inputs_w']).last_hidden_state[0]
         outputs_l = self.decoder(**batch['inputs_l']).last_hidden_state[0]
-
         reward_token_pos_w = np.where(batch['inputs_w']['input_ids'][0].cpu() == self.reward_token_ids)[0]
         reward_token_pos_l = np.where(batch['inputs_l']['input_ids'][0].cpu() == self.reward_token_ids)[0]
 
         if len(reward_token_pos_w) != 1 or len(reward_token_pos_l) != 1:
             raise ValueError('More than one token detected in replica')
 
-        outputs_w_1 = outputs_w[:reward_token_pos_w[0]]
-        outputs_w_2 = outputs_w[reward_token_pos_w[0]+1:]
+        outputs_w_1 = outputs_w[: reward_token_pos_w[0]]
+        outputs_w_2 = outputs_w[reward_token_pos_w[0] + 1 :]
         outputs_w_cat = torch.cat((outputs_w_1, outputs_w_2), dim=0)
 
         lm_logits = self.lm_head(outputs_w_cat)
@@ -64,7 +67,7 @@ def forward(self, batch):
         rm_logits_w = self.rm_head(outputs_w[reward_token_pos_w[0]])
         rm_logits_l = self.rm_head(outputs_l[reward_token_pos_l[0]])
 
         return lm_logits, rm_logits_w, rm_logits_l, reward_token_pos_w
-
+
 
 class TrainMultiheadStrategy(BaseTrainStrategy[RMTrainExperimentSettings]):
     @staticmethod
@@ -106,7 +109,7 @@ def _get_training_args(experiment_settings: RMTrainExperimentSettings) -> Traini
             remove_unused_columns=False,
             **experiment_settings.trainer_settings.dict(),
         )
-
+
     @staticmethod
     def _load_model(
         experiment_settings: RMTrainExperimentSettings,
@@ -115,7 +118,6 @@
         config = AutoConfig.from_pretrained(experiment_settings.model_settings.model_path)
 
         return MultiheadModel(config, experiment_settings.model_settings, tokenizer)
-
     @staticmethod
     def _get_trainer(
         training_args: TrainingArguments,
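For reference, MultiheadModel.forward above returns four values: LM logits over the chosen sequence with the reward-token position cut out, a scalar reward for the chosen answer, a scalar reward for the rejected answer, and the reward-token position itself. The [0] indexing on hidden states and input_ids means it scores a single preference pair per call. A minimal sketch of consuming that output (illustrative names, single-pair batch assumed, not part of the patch):

    # batch = {'inputs_w': {...}, 'inputs_l': {...}} as built by MultiHeadPairGenerator
    lm_logits, reward_w, reward_l, reward_pos = model.forward(batch)
    # preference margin that the trainer's pairwise term pushes to be positive
    margin = (reward_w - reward_l).item()
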
diff --git a/turbo_alignment/trainers/sft_with_rm.py b/turbo_alignment/trainers/sft_with_rm.py
index c40c6f1..a492388 100644
--- a/turbo_alignment/trainers/sft_with_rm.py
+++ b/turbo_alignment/trainers/sft_with_rm.py
@@ -4,7 +4,7 @@
 import torch
 from torch import nn
 from pathlib import Path
-from transformers import PreTrainedModel, Trainer
+from transformers import PreTrainedModel
 from transformers.trainer_pt_utils import nested_detach
 from transformers.utils import logging
 
@@ -16,7 +16,6 @@
 
 class SFTwithRMTrainer(MultiGPUCherryPicksTrainer):
     def compute_loss(self, model, inputs, return_outputs=False) -> tuple[torch.Tensor, dict[str, Any]] | torch.Tensor:
-
         sft_logits, rewards_w, rewards_l, reward_token_pos_w = model.forward(inputs)
 
         sft_logits = sft_logits.view(-1, sft_logits.size(-1))
@@ -25,11 +24,13 @@ def compute_loss(self, model, inputs, return_outputs=False) -> tuple[torch.Tenso
         # print(f'INPUTS TEXT SHAPE: {sft_labels.shape}')
         # print(f'INPUTS TEXT: {sft_labels}')
 
-        sft_labels_1 = sft_labels.view(-1)[:reward_token_pos_w[0]]
-        sft_labels_2 = sft_labels.view(-1)[reward_token_pos_w[0]+1:]
+        sft_labels_1 = sft_labels.view(-1)[: reward_token_pos_w[0]]
+        sft_labels_2 = sft_labels.view(-1)[reward_token_pos_w[0] + 1 :]
         sft_labels_cat = torch.cat((sft_labels_1, sft_labels_2), dim=0)
 
-        loss = -torch.nn.functional.logsigmoid(rewards_w - rewards_l).mean() + torch.nn.functional.cross_entropy(sft_logits, sft_labels_cat)
+        loss = -torch.nn.functional.logsigmoid(rewards_w - rewards_l).mean() + torch.nn.functional.cross_entropy(
+            sft_logits, sft_labels_cat
+        )
         if return_outputs:
             return loss, {'rewards_w': rewards_w, 'rewards_l': rewards_l}
         return loss
@@ -41,7 +42,6 @@ def prediction_step(
         prediction_loss_only: bool,
         ignore_keys: list[str] | None,
     ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
-
         inputs = self._prepare_inputs(inputs)
 
         if ignore_keys is None:
@@ -67,7 +67,6 @@
         labels = labels.long()
 
         return loss, logits, labels
-
     def _save_checkpoint(self, model, trial, metrics=None):
         logger.info('Running custom _save_checkpoint')
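Taken together, compute_loss above trains the two heads jointly: a standard pairwise preference term on the reward head and a token-level cross-entropy term on the LM head, with the reward-token position dropped from the labels so that logits and labels line up. A compact restatement of that objective, mirroring the tensors used in the hunk above (sketch only, not a verbatim excerpt):

    # pairwise term: push reward of the chosen answer above the rejected one
    pairwise_term = -torch.nn.functional.logsigmoid(rewards_w - rewards_l).mean()
    # SFT term: next-token cross-entropy on the chosen sequence without the reward token
    sft_term = torch.nn.functional.cross_entropy(sft_logits, sft_labels_cat)
    loss = pairwise_term + sft_term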