From 323e99f1e9db3be295ad6bdafb62323ca2e9986a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=D0=9C=D0=B0=D0=BB=D0=B0=D1=85=D0=BE=D0=B2=20=D0=90=D0=BB?=
 =?UTF-8?q?=D0=B5=D0=BA=D1=81=D0=B5=D0=B9=20=D0=9F=D0=B0=D0=B2=D0=BB=D0=BE?=
 =?UTF-8?q?=D0=B2=D0=B8=D1=87?=
Date: Tue, 3 Sep 2024 15:34:30 +0000
Subject: [PATCH] fix linters and tests

---
 .../dataset/pair_preferences/pair_preference.py | 14 ++------------
 turbo_alignment/pipelines/sampling/rm.py        |  3 +--
 turbo_alignment/settings/pipelines/train/dpo.py |  8 +++++---
 turbo_alignment/trainers/dpo.py                 |  7 +++----
 turbo_alignment/trainers/rm.py                  |  2 --
 5 files changed, 11 insertions(+), 23 deletions(-)

diff --git a/turbo_alignment/dataset/pair_preferences/pair_preference.py b/turbo_alignment/dataset/pair_preferences/pair_preference.py
index e01a5fe..fc6ff70 100755
--- a/turbo_alignment/dataset/pair_preferences/pair_preference.py
+++ b/turbo_alignment/dataset/pair_preferences/pair_preference.py
@@ -49,7 +49,6 @@ def __init__(
     def convert_records(self, records: list[PairPreferenceRecord]) -> list[dict[str, Any] | None]:
         chosen_chat_records: list[ChatDatasetRecord] = []
         rejected_chat_records: list[ChatDatasetRecord] = []
-        best_decode_records: list[ChatDatasetRecord] = []
 
         for record in records:
             context = [
@@ -60,24 +59,15 @@ def convert_records(self, records: list[PairPreferenceRecord]) -> list[dict[str
             chosen = ChatMessage(role=record.answer_w.role, content=record.answer_w.content)
             rejected = ChatMessage(role=record.answer_l.role, content=record.answer_l.content)
 
-            if record.best_decode is not None:
-                best_decode = ChatMessage(role=record.best_decode.role, content=record.best_decode.content)
-                best_decode_records.append(ChatDatasetRecord(id=record.id, messages=context + [best_decode]))
-
             chosen_chat_records.append(ChatDatasetRecord(id=record.id, messages=context + [chosen]))
             rejected_chat_records.append(ChatDatasetRecord(id=record.id, messages=context + [rejected]))
 
         tokenized_chosen_records = self._chat_dataset.convert_records(chosen_chat_records)
         tokenized_rejected_records = self._chat_dataset.convert_records(rejected_chat_records)
 
-        if len(best_decode_records) != 0:
-            tokenized_best_decode_records = self._chat_dataset.convert_records(best_decode_records)
-        else:
-            tokenized_best_decode_records = [None] * len(tokenized_chosen_records)
-
         output: list[dict[str, Any] | None] = []
-        for record, chosen_record, rejected_record, best_decode_record in zip(
-            records, tokenized_chosen_records, tokenized_rejected_records, tokenized_best_decode_records
+        for record, chosen_record, rejected_record in zip(
+            records, tokenized_chosen_records, tokenized_rejected_records
         ):
             if not (chosen_record and rejected_record):
                 continue
diff --git a/turbo_alignment/pipelines/sampling/rm.py b/turbo_alignment/pipelines/sampling/rm.py
index 49c7783..86b9f8c 100755
--- a/turbo_alignment/pipelines/sampling/rm.py
+++ b/turbo_alignment/pipelines/sampling/rm.py
@@ -4,7 +4,6 @@
 from accelerate import Accelerator
 from accelerate.utils import set_seed
 from accelerate.utils.operations import gather_object
-
 from transformers import PreTrainedModel, PreTrainedTokenizerBase
 
 from turbo_alignment.common.tf.loaders.model.model import load_model
@@ -64,7 +63,7 @@ def sample(self, experiment_settings: SamplingSettingsWithRMT) -> list[SamplingD
         model = load_model(experiment_settings.rm, tokenizer)
         if accelerator is not None:
             model = accelerator.prepare_model(model, device_placement=True, evaluation_mode=True)
-        
+
         model.eval()
 
         dataset = SamplingRMDataset(
diff --git a/turbo_alignment/settings/pipelines/train/dpo.py b/turbo_alignment/settings/pipelines/train/dpo.py
index b29e277..9464031 100755
--- a/turbo_alignment/settings/pipelines/train/dpo.py
+++ b/turbo_alignment/settings/pipelines/train/dpo.py
@@ -14,7 +14,7 @@
 class DPOLossesType(str, Enum):
     SIGMOID = 'sigmoid'
-    SIGMOD_WITH_MARGIN = 'sigmoid_with_margin'
+    SIGMOID_WITH_MARGIN = 'sigmoid_with_margin'
     HINGE = 'hinge'
     IPO = 'ipo'
     KTO = 'kto'
@@ -37,8 +37,10 @@ class SigmoidLossSettings(DPOLossSettings):
     loss_type: Literal[DPOLossesType.SIGMOID]
     label_smoothing: float = 0
 
-class SigmoidLossWithMarginSettings(SigmoidLossSettings):
-    loss_type: Literal[DPOLossesType.SIGMOD_WITH_MARGIN]
+
+class SigmoidLossWithMarginSettings(DPOLossSettings):
+    loss_type: Literal[DPOLossesType.SIGMOID_WITH_MARGIN]
+
 
 class HingeLossSettings(DPOLossSettings):
     loss_type: Literal[DPOLossesType.HINGE]
diff --git a/turbo_alignment/trainers/dpo.py b/turbo_alignment/trainers/dpo.py
index f9b654a..0ce83db 100755
--- a/turbo_alignment/trainers/dpo.py
+++ b/turbo_alignment/trainers/dpo.py
@@ -31,6 +31,7 @@
     KTOLossSettings,
     ORPOLossSettings,
     SigmoidLossSettings,
+    SigmoidLossWithMarginSettings,
     SimPOLossSettings,
     SlicHfLossSettings,
     SyncRefModelSettings,
@@ -294,14 +295,13 @@ def compute_loss(
         rejected_rewards = self.beta * policy_rejected_logps.detach()
 
         return losses, chosen_rewards, rejected_rewards
-    
+
 
 @DPOLossRegistry.register(DPOLossesType.SIGMOID_WITH_MARGIN)
 class SigmoidLossWithMargin(DPOLossRegistry):
     def __init__(self, *args, beta: float = 0.1, **kwargs) -> None:
         self.beta = beta
-        super().__ini__(*args, **kwargs)
-
+        super().__init__(*args, **kwargs)
 
     def compute_loss(
         self,
@@ -309,7 +309,6 @@ def compute_loss(
         policy_rejected_logps: torch.FloatTensor,
         reference_chosen_logps: torch.FloatTensor,
         reference_rejected_logps: torch.FloatTensor,
-        policy_best_decode_logps: torch.FloatTensor | None,
         precomputed_margins: torch.FloatTensor | None,
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         pi_logratios = policy_chosen_logps - policy_rejected_logps
diff --git a/turbo_alignment/trainers/rm.py b/turbo_alignment/trainers/rm.py
index 2ecf755..d78a7e9 100755
--- a/turbo_alignment/trainers/rm.py
+++ b/turbo_alignment/trainers/rm.py
@@ -9,12 +9,10 @@
 from turbo_alignment.trainers.multigpu import MultiGPUCherryPicksTrainer
 from turbo_alignment.trainers.utils import concatenated_inputs
 
-
 logger = logging.get_logger(__name__)
 
 
 class RMTrainer(MultiGPUCherryPicksTrainer):
-
     def concatenated_forward(self, model: nn.Module, batch: dict[str, Any]) -> tuple[torch.Tensor, torch.Tensor]:
         concatenated_batch = concatenated_inputs(batch, device=self.accelerator.device)
         all_rewards = model(