Commit 10738ec

in progress

Малахов Алексей Павлович committed Nov 11, 2024
1 parent 7bb08fd commit 10738ec
Showing 9 changed files with 43 additions and 33 deletions.
6 changes: 6 additions & 0 deletions turbo_alignment/dataset/pair_preferences/collators.py
@@ -52,4 +52,10 @@ def __call__(self, examples: list[dict[str, dict[str, Any]]]) -> dict[str, Any]:
         if 'precomputed_margin' in examples[0] and examples[0]['precomputed_margin'] is not None:
             batch['precomputed_margin'] = torch.tensor([ex['precomputed_margin'] for ex in examples])
 
+        if 'ref_chosen_logps' in examples[0] and examples[0]['ref_chosen_logps'] is not None:
+            batch['ref_chosen_logps'] = torch.tensor([ex['ref_chosen_logps'] for ex in examples])
+
+        if 'ref_rejected_logps' in examples[0] and examples[0]['ref_rejected_logps'] is not None:
+            batch['ref_rejected_logps'] = torch.tensor([ex['ref_rejected_logps'] for ex in examples])
+
         return batch
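Each optional per-example scalar thus becomes one (batch_size,) tensor in the batch. A minimal sketch of the new behaviour in isolation, with the collator class simplified away (only the field names come from this diff):

import torch

examples = [
    {'ref_chosen_logps': -12.3, 'ref_rejected_logps': -15.7},
    {'ref_chosen_logps': -9.8, 'ref_rejected_logps': -14.1},
]

batch = {}
for key in ('ref_chosen_logps', 'ref_rejected_logps'):
    # Same guard as the diff: batch the field only if it is present and not None.
    if key in examples[0] and examples[0][key] is not None:
        batch[key] = torch.tensor([ex[key] for ex in examples])

print(batch['ref_chosen_logps'].shape)  # torch.Size([2]), one scalar per pair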
2 changes: 2 additions & 0 deletions turbo_alignment/dataset/pair_preferences/models.py
@@ -12,3 +12,5 @@ class PairPreferenceRecord(DatasetRecord):
     answer_w: ChatMessage
     answer_l: ChatMessage
     precomputed_margin: float | None = None
+    ref_chosen_logps: float | None = None
+    ref_rejected_logps: float | None = None
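With the two new optional fields, a dataset record can carry the frozen reference model's log-probability for each answer alongside the preference pair. A hypothetical record as a Python dict; the context and ChatMessage payloads are simplified, and field names other than the ones in this diff are assumptions:

record = {
    'id': '0',
    'context': [{'role': 'user', 'content': 'Is the sky blue?'}],
    'answer_w': {'role': 'bot', 'content': 'Yes.'},       # chosen answer
    'answer_l': {'role': 'bot', 'content': 'No idea.'},   # rejected answer
    'precomputed_margin': None,
    'ref_chosen_logps': -12.3,    # new: reference log P(answer_w | context)
    'ref_rejected_logps': -15.7,  # new: reference log P(answer_l | context)
}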
4 changes: 3 additions & 1 deletion turbo_alignment/dataset/pair_preferences/pair_preference.py
@@ -72,7 +72,7 @@ def convert_records(self, records: list[PairPreferenceRecord]) -> list[dict[str,
             if not (chosen_record and rejected_record):
                 continue
 
-            ignore_keys = ['precomputed_margin']
+            ignore_keys = ['precomputed_margin', 'ref_chosen_logps', 'ref_rejected_logps']
             if not self._add_labels:
                 ignore_keys.append('labels')
 
@@ -85,6 +85,8 @@ def convert_records(self, records: list[PairPreferenceRecord]) -> list[dict[str,
                     'inputs_w': chosen_tokens,
                     'inputs_l': rejected_tokens,
                     'precomputed_margin': record.precomputed_margin,
+                    'ref_chosen_logps': record.ref_chosen_logps,
+                    'ref_rejected_logps': record.ref_rejected_logps,
                 }
             )
 
1 change: 1 addition & 0 deletions turbo_alignment/metrics/distinctness.py
@@ -1,4 +1,5 @@
 from collections import defaultdict
+
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase
 
 from turbo_alignment.metrics.metric import Metric
1 change: 1 addition & 0 deletions turbo_alignment/metrics/registry.py
@@ -1,4 +1,5 @@
 from enum import Enum
+
 from pydantic import field_validator
 
 from turbo_alignment.common.registry import Registrable
2 changes: 1 addition & 1 deletion turbo_alignment/pipelines/train/__init__.py
@@ -4,6 +4,6 @@
 from .kto import TrainKTOStrategy
 from .multimodal import TrainMultimodalStrategy
 from .rag import TrainRAGStrategy
+from .ref_model import ref_model_DPOStrategy
 from .rm import TrainRMStrategy
 from .sft import TrainSFTStrategy
-from .ref_model import ref_model_DPOStrategy
11 changes: 3 additions & 8 deletions turbo_alignment/pipelines/train/ref_model.py
@@ -1,9 +1,8 @@
 from typing import Callable
 
-from torch.utils.data import Dataset
+from torch.utils.data import ConcatDataset, Dataset
 from transformers import PreTrainedModel, PreTrainedTokenizerBase
 from transformers.data.data_collator import DataCollatorMixin
-from torch.utils.data import ConcatDataset, Dataset
 
 from turbo_alignment.cherry_picks.chat import ChatCherryPickCallback
 from turbo_alignment.common.logging import get_project_logger
@@ -61,7 +60,6 @@ def _get_trainer(
         data_collator: Callable,
         output_path='logits_outputs.jsonl',
     ):
-
         extra_args = {}
         if experiment_settings.trainer_settings.use_ref_model:
             ref_model = load_model(experiment_settings.model_settings, tokenizer)
@@ -109,8 +107,6 @@ def _dataset_and_collator_sanity_check(self, dataset: Dataset, collator: DataCol
             'Mask-l check: {mask}'.format(mask=collator([dataset[0], dataset[1]])['inputs_l']['attention_mask'][0])
         )
 
-
-
     def run(self, experiment_settings: DPOTrainExperimentSettings) -> None:
         training_args = self._get_training_args(experiment_settings)
 
@@ -154,16 +150,15 @@ def run(self, experiment_settings: DPOTrainExperimentSettings) -> None:
             train_dataset,
             val_dataset,
             data_collator,
-            output_path='logits_outputs.jsonl'
+            output_path='logits_outputs.jsonl',
         )
 
         if self.trainer.accelerator.is_main_process:
             self._dataset_and_collator_sanity_check(train_dataset, data_collator)
 
         self._add_trainer_callbacks(experiment_settings)
 
-
-        print('👀'*15, 'ref_model')
+        print('👀' * 15, 'ref_model')
 
         self.trainer.train()
 
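For orientation: this pipeline runs the frozen reference model once over the preference data and dumps per-pair log-probs to output_path rather than optimizing anything. A rough sketch of that pass under stated assumptions (score_pair is a hypothetical stand-in for ref_model_DPOTrainer's forward plus _get_batch_logps; in the repo the dump actually happens inside the trainer's compute_loss):

import json

import torch

@torch.no_grad()
def precompute_ref_logps(score_pair, dataloader, output_path='logits_outputs.jsonl'):
    # Illustrative only: one forward pass per batch, no optimizer step.
    with open(output_path, 'w', encoding='utf-8') as f:
        for batch in dataloader:
            chosen_logps, rejected_logps = score_pair(batch)  # (batch_size,) tensors
            for c, r in zip(chosen_logps.tolist(), rejected_logps.tolist()):
                f.write(json.dumps({'ref_chosen_logps': c, 'ref_rejected_logps': r}) + '\n')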
33 changes: 21 additions & 12 deletions turbo_alignment/trainers/dpo.py
@@ -26,9 +26,9 @@
 from turbo_alignment.settings.pipelines.train.dpo import (
     APODownLossSettings,
     APOZeroLossSettings,
+    ASFTLossSettings,
     CPOLossSettings,
     DPOLossesType,
-    ASFTLossSettings,
     HingeLossSettings,
     IPOLossSettings,
     KTOLossSettings,
@@ -575,12 +575,18 @@ def _get_batch_logps(
         return (per_token_logps * loss_mask).sum(-1)
 
     def concatenated_forward(
-        self, model: nn.Module, batch: dict[str, Any]
+        self,
+        model: nn.Module,
+        batch: dict[str, Any],
+        get_from_dataset=False,
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor | None]:
         concatenated_batch = concatenated_inputs(batch, device=self.accelerator.device)
 
         precomputed_margins: torch.Tensor | None = concatenated_batch.pop('margin', None)
 
+        if get_from_dataset:
+            return concatenated_batch.pop('ref_chosen_logps'), concatenated_batch.pop('ref_rejected_logps')
+
         all_logits = model(
             concatenated_batch['input_ids'],
             attention_mask=concatenated_batch['attention_mask'],
@@ -601,16 +607,19 @@ def concatenated_forward(
         return chosen_logps, rejected_logps, chosen_logits, rejected_logits, precomputed_margins
 
     def _get_logps(self, model: nn.Module | None, batch: dict[str, Any]) -> tuple[torch.Tensor, torch.Tensor]:
-        with torch.no_grad():
-            if model is not None:
-                (chosen_logps, rejected_logps, *_) = self.concatenated_forward(model, batch)
-            else:
-                with self.accelerator.unwrap_model(self.model).disable_adapter():
-                    (
-                        chosen_logps,
-                        rejected_logps,
-                        *_,
-                    ) = self.concatenated_forward(self.model, batch)
+        if self.ref_model is None:
+            chosen_logps, rejected_logps = self.concatenated_forward(model, batch, get_from_dataset=True)
+        else:
+            with torch.no_grad():
+                if model is not None:
+                    (chosen_logps, rejected_logps, *_) = self.concatenated_forward(model, batch)
+                else:
+                    with self.accelerator.unwrap_model(self.model).disable_adapter():
+                        (
+                            chosen_logps,
+                            rejected_logps,
+                            *_,
+                        ) = self.concatenated_forward(self.model, batch)
 
         return chosen_logps, rejected_logps
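Why dataset-supplied values are sufficient here: in the DPO objective the reference model enters only through one scalar log-prob per response, so those scalars can be computed once with the frozen model and reused on every epoch instead of keeping a second model in memory. A minimal sketch of the standard DPO sigmoid loss (Rafailov et al., 2023), not this trainer's exact implementation:

import torch
import torch.nn.functional as F

def dpo_loss(policy_chosen_logps: torch.Tensor,
             policy_rejected_logps: torch.Tensor,
             ref_chosen_logps: torch.Tensor,
             ref_rejected_logps: torch.Tensor,
             beta: float = 0.1) -> torch.Tensor:
    # The reference model contributes only these two scalars per pair,
    # which is what makes precomputing and batching them possible.
    pi_logratios = policy_chosen_logps - policy_rejected_logps
    ref_logratios = ref_chosen_logps - ref_rejected_logps
    return -F.logsigmoid(beta * (pi_logratios - ref_logratios)).mean()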
16 changes: 5 additions & 11 deletions turbo_alignment/trainers/ref_model.py
@@ -1,8 +1,7 @@
+import gc
 from pathlib import Path
 from typing import Any, Callable, Literal
 
-import gc
-
 import torch
 from torch import nn
 from torch.utils.data import Dataset
@@ -16,15 +15,11 @@
 from turbo_alignment.common.data.io import write_jsonl
 from turbo_alignment.common.logging import get_project_logger
 from turbo_alignment.constants import DISABLE_LOSS_LABEL
-from turbo_alignment.trainers.utils import (
-    concatenated_inputs,
-    prepare_model,
-)
+from turbo_alignment.trainers.utils import concatenated_inputs, prepare_model
 
 logger = get_project_logger()
 
 
-
 class ref_model_DPOTrainer(Trainer):
     """
     Inspired by https://github.com/huggingface/trl/blob/main/trl/trainer/dpo_trainer.py
@@ -62,7 +57,6 @@ def __init__(
         )
         self.model = prepare_model(ref_model, self.accelerator, self.is_deepspeed_enabled)
 
-
     def _get_batch_logps(
         self,
         logits: torch.Tensor,
@@ -130,7 +124,6 @@ def get_batch_metrics(
         batch: dict[str, Any],
         train_eval: Literal['train', 'eval'] = 'train',
     ) -> tuple[torch.Tensor, dict[str, float]]:
-
         reference_chosen_logps, reference_rejected_logps = self._get_logps(self.model, batch)
         return reference_chosen_logps, reference_rejected_logps
 
@@ -141,7 +134,9 @@ def compute_loss(
         return_outputs=False,
         num_items_in_batch=None,
     ) -> torch.Tensor | tuple[torch.Tensor, dict[str, float]]:
-        reference_chosen_logps, reference_rejected_logps = self.get_batch_metrics(self.model, inputs, train_eval='train')
+        reference_chosen_logps, reference_rejected_logps = self.get_batch_metrics(
+            self.model, inputs, train_eval='train'
+        )
 
         if self.accelerator.is_main_process:
             data = {
@@ -163,7 +158,6 @@ def prediction_step(
         prediction_loss_only: bool,
         ignore_keys: list[str] | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
-
         if prediction_loss_only:
             return torch.tensor([0.0]).detach(), None, None
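To close the loop, the dumped log-probs have to be joined back onto the preference dataset before a ref-model-free DPO run. The exact schema written by this trainer is truncated in the diff (the data dict is cut off), so the sketch below is a hypothetical merge script whose field names are assumptions; it presumes both files are JSON lines in the same order:

import json

def merge_ref_logps(dataset_path: str, logps_path: str, out_path: str) -> None:
    # Attach precomputed reference log-probs to each preference record.
    with open(dataset_path, encoding='utf-8') as ds, \
         open(logps_path, encoding='utf-8') as lp, \
         open(out_path, 'w', encoding='utf-8') as out:
        for record_line, logps_line in zip(ds, lp):
            record, logps = json.loads(record_line), json.loads(logps_line)
            record['ref_chosen_logps'] = logps['ref_chosen_logps']
            record['ref_rejected_logps'] = logps['ref_rejected_logps']
            out.write(json.dumps(record, ensure_ascii=False) + '\n')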
