Commit

🤲 DPO with margin and pipeline fixes

syrn1k authored Sep 5, 2024
2 parents e22595f + 323e99f commit 633f198
Showing 9 changed files with 119 additions and 57 deletions.
8 changes: 7 additions & 1 deletion turbo_alignment/dataset/chat/models.py
@@ -1,7 +1,7 @@
from enum import Enum
from typing import Any

from pydantic import BaseModel
from pydantic import BaseModel, field_validator

from turbo_alignment.dataset.base.models import DatasetRecord

@@ -17,6 +17,12 @@ class ChatMessage(BaseModel):
content: str
disable_loss: bool = False

@field_validator('role', mode='before')
def set_bot_role(cls, values: str) -> str:
if values == 'assistant':
return ChatMessageRole.BOT
return values


class ChatDatasetRecord(DatasetRecord):
messages: list[ChatMessage]
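A quick usage sketch of the new validator (a hypothetical illustration; the role field and the concrete value of ChatMessageRole.BOT are defined outside this hunk): messages that arrive with the OpenAI-style 'assistant' role are normalized to the internal bot role before validation, so downstream code only sees ChatMessageRole values.

# Hypothetical example, assuming ChatMessage and ChatMessageRole are both
# importable from this module and that 'user' is a valid ChatMessageRole value.
from turbo_alignment.dataset.chat.models import ChatMessage, ChatMessageRole

msg = ChatMessage(role='assistant', content='Hello!')
assert msg.role == ChatMessageRole.BOT  # 'assistant' is rewritten by the validator

usr = ChatMessage(role='user', content='Hi')  # other roles pass through unchanged
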
19 changes: 12 additions & 7 deletions turbo_alignment/dataset/pair_preferences/collators.py
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Any
from typing import Any, Iterable

import torch
import transformers
@@ -19,7 +19,9 @@ def _get_batch(
) -> transformers.BatchEncoding:
features = [ex[key] for ex in examples]
labels = [v.tolist() for feature in features for k, v in feature.items() if k == 'labels']
no_labels_features = [{k: v for k, v in feature.items() if k != 'labels'} for feature in features]
no_labels_features = [
{k: v for k, v in feature.items() if k not in ['labels', 'precomputed_margin']} for feature in features
]

batch = tokenizer.pad(
no_labels_features,
@@ -38,13 +40,16 @@ def __call__(self, examples: list[dict[str, dict[str, Any]]]) -> dict[str, Any]:
max_length = 0
for ex in examples:
for t in ex:
if 'input_ids' in ex[t]:
max_length = max(max_length, len(ex[t]['input_ids']))
if isinstance(ex[t], Iterable):
if 'input_ids' in ex[t]:
max_length = max(max_length, len(ex[t]['input_ids']))

batch = {
batch: dict[str, Any] = {
'inputs_w': dict(self._get_batch(examples, self.tokenizer, 'inputs_w', max_length)),
'inputs_l': dict(self._get_batch(examples, self.tokenizer, 'inputs_l', max_length)),
}
if 'best_decode' in examples[0] and len(examples[0]['best_decode']) != 0:
batch['best_decode'] = dict(self._get_batch(examples, self.tokenizer, 'best_decode', max_length))

if 'precomputed_margin' in examples[0] and examples[0]['precomputed_margin'] is not None:
batch['precomputed_margin'] = torch.tensor([ex['precomputed_margin'] for ex in examples])

return batch
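
A hedged sketch of the collator's new contract (illustrative values; only the changed lines are shown in this hunk): each example carries tokenized inputs_w/inputs_l feature dicts plus an optional scalar precomputed_margin, which is excluded from padding and, when present, collected into one per-batch tensor.

# Illustrative input to the collator; field names come from the diff above.
import torch

example = {
    'id': '0',
    'inputs_w': {'input_ids': torch.tensor([1, 2, 3]), 'attention_mask': torch.tensor([1, 1, 1])},
    'inputs_l': {'input_ids': torch.tensor([1, 2]), 'attention_mask': torch.tensor([1, 1])},
    'precomputed_margin': 0.7,
}

# The Iterable guard skips scalar fields such as 'precomputed_margin' when
# computing max_length; inputs_w/inputs_l are padded to the longest sequence in
# the batch, and since the first example's margin is not None the collator also
# emits batch['precomputed_margin'] = torch.tensor([0.7, ...]), one per example.
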
2 changes: 1 addition & 1 deletion turbo_alignment/dataset/pair_preferences/models.py
@@ -11,4 +11,4 @@ class PairPreferenceRecord(DatasetRecord):
context: list[ChatMessage]
answer_w: ChatMessage
answer_l: ChatMessage
best_decode: ChatMessage | None = None
precomputed_margin: float | None = None
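
For context, a pair-preference record under the new schema might look like the following (a hypothetical example; the exact role strings and any extra DatasetRecord fields are assumed rather than shown in this diff):

record = {
    'id': '42',
    'context': [{'role': 'user', 'content': 'Explain DPO in one sentence.'}],
    'answer_w': {'role': 'bot', 'content': 'DPO fits the policy directly to preference pairs.'},
    'answer_l': {'role': 'bot', 'content': 'I am not sure.'},
    # Optional per-pair margin computed offline (for example from reward-model
    # scores); it replaces the removed best_decode message.
    'precomputed_margin': 1.25,
}
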
22 changes: 4 additions & 18 deletions turbo_alignment/dataset/pair_preferences/pair_preference.py
@@ -49,7 +49,6 @@ def __init__(
def convert_records(self, records: list[PairPreferenceRecord]) -> list[dict[str, Any] | None]:
chosen_chat_records: list[ChatDatasetRecord] = []
rejected_chat_records: list[ChatDatasetRecord] = []
best_decode_records: list[ChatDatasetRecord] = []

for record in records:
context = [
@@ -60,45 +59,32 @@ def convert_records(self, records: list[PairPreferenceRecord]) -> list[dict[str,
chosen = ChatMessage(role=record.answer_w.role, content=record.answer_w.content)
rejected = ChatMessage(role=record.answer_l.role, content=record.answer_l.content)

if record.best_decode is not None:
best_decode = ChatMessage(role=record.best_decode.role, content=record.best_decode.content)
best_decode_records.append(ChatDatasetRecord(id=record.id, messages=context + [best_decode]))

chosen_chat_records.append(ChatDatasetRecord(id=record.id, messages=context + [chosen]))
rejected_chat_records.append(ChatDatasetRecord(id=record.id, messages=context + [rejected]))

tokenized_chosen_records = self._chat_dataset.convert_records(chosen_chat_records)
tokenized_rejected_records = self._chat_dataset.convert_records(rejected_chat_records)

if len(best_decode_records) != 0:
tokenized_best_decode_records = self._chat_dataset.convert_records(best_decode_records)
else:
tokenized_best_decode_records = [None] * len(tokenized_chosen_records)

output: list[dict[str, Any] | None] = []
for record, chosen_record, rejected_record, best_decode_record in zip(
records, tokenized_chosen_records, tokenized_rejected_records, tokenized_best_decode_records
for record, chosen_record, rejected_record in zip(
records, tokenized_chosen_records, tokenized_rejected_records
):
if not (chosen_record and rejected_record):
continue

ignore_keys = []
ignore_keys = ['precomputed_margin']
if not self._add_labels:
ignore_keys.append('labels')

chosen_tokens = {k: v.squeeze(0) for k, v in chosen_record.items() if k not in ignore_keys}
rejected_tokens = {k: v.squeeze(0) for k, v in rejected_record.items() if k not in ignore_keys}
best_decode_tokens = {}

if best_decode_record is not None:
best_decode_tokens = {k: v.squeeze(0) for k, v in best_decode_record.items() if k not in ignore_keys}

output.append(
{
'id': record.id,
'inputs_w': chosen_tokens,
'inputs_l': rejected_tokens,
'best_decode': best_decode_tokens,
'precomputed_margin': record.precomputed_margin,
}
)

9 changes: 7 additions & 2 deletions turbo_alignment/pipelines/sampling/rm.py
@@ -3,6 +3,7 @@

from accelerate import Accelerator
from accelerate.utils import set_seed
from accelerate.utils.operations import gather_object
from transformers import PreTrainedModel, PreTrainedTokenizerBase

from turbo_alignment.common.tf.loaders.model.model import load_model
@@ -36,7 +37,9 @@ def _get_rewards(
batch=experiment_settings.rm_batch_size,
micro_batch=experiment_settings.rm_batch_size,
)
outputs: list[RMSamplingInferenceOutput] = generator.generate_from_dataset(dataset)
outputs: list[RMSamplingInferenceOutput] = gather_object(generator.generate_from_dataset(dataset))[
: len(dataset)
]
return {out.id: out.rewards for out in outputs}

@staticmethod
@@ -59,7 +62,9 @@ def sample(self, experiment_settings: SamplingSettingsWithRMT) -> list[SamplingD

model = load_model(experiment_settings.rm, tokenizer)
if accelerator is not None:
model.to(accelerator.device)
model = accelerator.prepare_model(model, device_placement=True, evaluation_mode=True)

model.eval()

dataset = SamplingRMDataset(
source=experiment_settings.dataset_settings.sources[0],
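A minimal sketch of the distributed-inference fix, assuming one process per device under accelerate: the reward model is wrapped with accelerator.prepare_model(..., evaluation_mode=True) so it is placed on its device without being set up for training, each process scores its shard of the dataset, gather_object concatenates the per-process results, and slicing to len(dataset) drops any duplicates accelerate may have added to even out the shards.

from accelerate import Accelerator
from accelerate.utils.operations import gather_object

accelerator = Accelerator()

# Hypothetical per-process outputs; in the pipeline these are
# RMSamplingInferenceOutput objects from generator.generate_from_dataset(dataset).
local_outputs = [f'rank{accelerator.process_index}-sample{i}' for i in range(2)]

all_outputs = gather_object(local_outputs)  # concatenated across all processes
dataset_len = 3                             # stands in for len(dataset)
all_outputs = all_outputs[:dataset_len]     # discard padded duplicates
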
6 changes: 6 additions & 0 deletions turbo_alignment/settings/pipelines/train/dpo.py
@@ -14,6 +14,7 @@

class DPOLossesType(str, Enum):
SIGMOID = 'sigmoid'
SIGMOID_WITH_MARGIN = 'sigmoid_with_margin'
HINGE = 'hinge'
IPO = 'ipo'
KTO = 'kto'
@@ -37,6 +38,10 @@ class SigmoidLossSettings(DPOLossSettings):
label_smoothing: float = 0


class SigmoidLossWithMarginSettings(DPOLossSettings):
loss_type: Literal[DPOLossesType.SIGMOID_WITH_MARGIN]


class HingeLossSettings(DPOLossSettings):
loss_type: Literal[DPOLossesType.HINGE]
norm: bool = True
@@ -86,6 +91,7 @@ class DPOTrainerSettings(TrainerSettings):
| ORPOLossSettings
| SimPOLossSettings
| SlicHfLossSettings
| SigmoidLossWithMarginSettings
)
sync_ref_settings: SyncRefModelSettings
use_ref_model: bool = True
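To opt into the new loss, the trainer's loss settings select the margin-aware variant, roughly as follows (a sketch; any fields inherited from DPOLossSettings, such as beta, are assumed rather than shown in this diff):

from turbo_alignment.settings.pipelines.train.dpo import (
    DPOLossesType,
    SigmoidLossWithMarginSettings,
)

# In a JSON config this presumably corresponds to loss_settings with
# loss_type set to 'sigmoid_with_margin'.
loss_settings = SigmoidLossWithMarginSettings(loss_type=DPOLossesType.SIGMOID_WITH_MARGIN)
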
75 changes: 55 additions & 20 deletions turbo_alignment/trainers/dpo.py
@@ -31,6 +31,7 @@
KTOLossSettings,
ORPOLossSettings,
SigmoidLossSettings,
SigmoidLossWithMarginSettings,
SimPOLossSettings,
SlicHfLossSettings,
SyncRefModelSettings,
@@ -57,7 +58,7 @@ def compute_loss(
policy_rejected_logps: torch.FloatTensor,
reference_chosen_logps: torch.FloatTensor,
reference_rejected_logps: torch.FloatTensor,
policy_best_decode_logps: torch.FloatTensor | None,
precomputed_margins: torch.FloatTensor | None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
pi_logratios = policy_chosen_logps - policy_rejected_logps
ref_logratios = reference_chosen_logps - reference_rejected_logps
@@ -91,7 +92,7 @@ def compute_loss(
policy_rejected_logps: torch.FloatTensor,
reference_chosen_logps: torch.FloatTensor,
reference_rejected_logps: torch.FloatTensor,
policy_best_decode_logps: torch.FloatTensor | None,
precomputed_margins: torch.FloatTensor | None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
chosen_KL = (policy_chosen_logps - reference_chosen_logps).mean().clamp(min=0)
rejected_KL = (policy_rejected_logps - reference_rejected_logps).mean().clamp(min=0)
@@ -125,7 +126,7 @@ def compute_loss(
policy_rejected_logps: torch.FloatTensor,
reference_chosen_logps: torch.FloatTensor,
reference_rejected_logps: torch.FloatTensor,
policy_best_decode_logps: torch.FloatTensor | None,
precomputed_margins: torch.FloatTensor | None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
pi_logratios = policy_chosen_logps - policy_rejected_logps
ref_logratios = reference_chosen_logps - reference_rejected_logps
@@ -153,7 +154,7 @@ def compute_loss(
policy_rejected_logps: torch.FloatTensor,
reference_chosen_logps: torch.FloatTensor,
reference_rejected_logps: torch.FloatTensor,
policy_best_decode_logps: torch.FloatTensor | None,
precomputed_margins: torch.FloatTensor | None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
pi_logratios = policy_chosen_logps - policy_rejected_logps
ref_logratios = reference_chosen_logps - reference_rejected_logps
@@ -184,7 +185,7 @@ def compute_loss(
policy_rejected_logps: torch.FloatTensor,
reference_chosen_logps: torch.FloatTensor,
reference_rejected_logps: torch.FloatTensor,
policy_best_decode_logps: torch.FloatTensor | None,
precomputed_margins: torch.FloatTensor | None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
pi_logratios = policy_chosen_logps - policy_rejected_logps
ref_logratios = reference_chosen_logps - reference_rejected_logps
@@ -216,7 +217,7 @@ def compute_loss(
policy_rejected_logps: torch.FloatTensor,
reference_chosen_logps: torch.FloatTensor,
reference_rejected_logps: torch.FloatTensor,
policy_best_decode_logps: torch.FloatTensor | None,
precomputed_margins: torch.FloatTensor | None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
pi_logratios = policy_chosen_logps - policy_rejected_logps
ref_logratios = reference_chosen_logps - reference_rejected_logps
@@ -231,8 +232,8 @@ def compute_loss(
if self.norm:
loss = torch.relu(self.delta - self.beta * logits)

if policy_best_decode_logps is not None:
loss = loss - self.lam * policy_best_decode_logps
if precomputed_margins is not None:
loss = loss - self.lam * precomputed_margins

return loss, chosen_rewards, rejected_rewards

@@ -250,7 +251,7 @@ def compute_loss(
policy_rejected_logps: torch.FloatTensor,
reference_chosen_logps: torch.FloatTensor | None,
reference_rejected_logps: torch.FloatTensor | None,
policy_best_decode_logps: torch.FloatTensor | None,
precomputed_margins: torch.FloatTensor | None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
pi_logratios = policy_chosen_logps - policy_rejected_logps

@@ -280,7 +281,7 @@ def compute_loss(
policy_rejected_logps: torch.FloatTensor,
reference_chosen_logps: torch.FloatTensor | None,
reference_rejected_logps: torch.FloatTensor | None,
policy_best_decode_logps: torch.FloatTensor | None,
precomputed_margins: torch.FloatTensor | None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
log_odds = (policy_chosen_logps - policy_rejected_logps) - (
torch.log1p(-torch.clamp(torch.exp(policy_chosen_logps), max=1 - 1e-7))
@@ -296,6 +297,40 @@ def compute_loss(
return losses, chosen_rewards, rejected_rewards


@DPOLossRegistry.register(DPOLossesType.SIGMOID_WITH_MARGIN)
class SigmoidLossWithMargin(DPOLossRegistry):
def __init__(self, *args, beta: float = 0.1, **kwargs) -> None:
self.beta = beta
super().__init__(*args, **kwargs)

def compute_loss(
self,
policy_chosen_logps: torch.FloatTensor,
policy_rejected_logps: torch.FloatTensor,
reference_chosen_logps: torch.FloatTensor,
reference_rejected_logps: torch.FloatTensor,
precomputed_margins: torch.FloatTensor | None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
pi_logratios = policy_chosen_logps - policy_rejected_logps
ref_logratios = reference_chosen_logps - reference_rejected_logps

logits = pi_logratios - ref_logratios

chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach()
rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps).detach()

if precomputed_margins is None:
raise ValueError('Precomputed margins should not be none when using SigmoidLossWithMargin')

loss = -F.logsigmoid(self.beta * logits - precomputed_margins)

return (
loss,
chosen_rewards,
rejected_rewards,
)


@dataclass
class DPOTrainingArguments(TrainingArguments):
loss_settings: (
@@ -308,6 +343,7 @@ class DPOTrainingArguments(TrainingArguments):
| ORPOLossSettings
| SimPOLossSettings
| SlicHfLossSettings
| SigmoidLossWithMarginSettings
) = field(
default_factory=SigmoidLossSettings(loss_type=DPOLossesType.SIGMOID)
) # type: ignore[call-overload]
@@ -402,14 +438,14 @@ def dpo_loss(
policy_rejected_logps: torch.Tensor,
reference_chosen_logps: torch.Tensor,
reference_rejected_logps: torch.Tensor,
policy_best_decode_logps: torch.Tensor | None,
precomputed_margins: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
return self.dpo_loss_registry.compute_loss(
policy_chosen_logps=policy_chosen_logps,
policy_rejected_logps=policy_rejected_logps,
reference_chosen_logps=reference_chosen_logps,
reference_rejected_logps=reference_rejected_logps,
policy_best_decode_logps=policy_best_decode_logps,
precomputed_margins=precomputed_margins,
)

def _get_batch_logps(
@@ -437,6 +473,9 @@ def concatenated_forward(
self, model: nn.Module, batch: dict[str, Any]
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor | None]:
concatenated_batch = concatenated_inputs(batch, device=self.accelerator.device)

precomputed_margins: torch.Tensor | None = concatenated_batch.pop('margin', None)

all_logits = model(
concatenated_batch['input_ids'],
attention_mask=concatenated_batch['attention_mask'],
@@ -452,13 +491,9 @@
chosen_logps = all_logps[:chosen_idxs]
rejected_logps = all_logps[chosen_idxs : chosen_idxs + rejected_idx]

policy_best_decode_logps: torch.Tensor = all_logps[chosen_idxs + rejected_idx :]
if len(policy_best_decode_logps) == 0:
policy_best_decode_logps = None # type: ignore[assignment]

chosen_logits = all_logits[:chosen_idxs]
rejected_logits = all_logits[chosen_idxs:]
return chosen_logps, rejected_logps, chosen_logits, rejected_logits, policy_best_decode_logps
return chosen_logps, rejected_logps, chosen_logits, rejected_logits, precomputed_margins

def _get_logps(self, model: nn.Module | None, batch: dict[str, Any]) -> tuple[torch.Tensor, torch.Tensor]:
with torch.no_grad():
Expand Down Expand Up @@ -487,7 +522,7 @@ def get_batch_metrics(
policy_rejected_logps,
policy_chosen_logits,
policy_rejected_logits,
policy_best_decode_logps,
precomputed_margins,
) = self.concatenated_forward(model, batch)

reference_chosen_logps, reference_rejected_logps = torch.Tensor([float('inf')]), torch.Tensor([float('inf')])
@@ -500,7 +535,7 @@
policy_rejected_logps=policy_rejected_logps,
reference_chosen_logps=reference_chosen_logps,
reference_rejected_logps=reference_rejected_logps,
policy_best_decode_logps=policy_best_decode_logps,
precomputed_margins=precomputed_margins,
)

prefix = 'eval_' if train_eval == 'eval' else ''
@@ -561,7 +596,7 @@ def get_batch_metrics(
policy_rejected_logps=policy_rejected_logps,
reference_chosen_logps=sft_chosen_logps,
reference_rejected_logps=sft_rejected_logps,
policy_best_decode_logps=policy_best_decode_logps,
precomputed_margins=precomputed_margins,
)

sft_prefix_name = prefix + 'rewards/sft_'
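For reference, SigmoidLossWithMargin is the standard sigmoid DPO objective with the per-pair precomputed margin m subtracted inside the log-sigmoid; in the notation of the code above (logits = pi_logratios - ref_logratios):

$$ \mathcal{L} = -\log \sigma\Big( \beta \big[ (\log \pi_\theta(y_w \mid x) - \log \pi_\theta(y_l \mid x)) - (\log \pi_{\mathrm{ref}}(y_w \mid x) - \log \pi_{\mathrm{ref}}(y_l \mid x)) \big] - m \Big) $$

With m = 0 this reduces to the plain sigmoid loss; a larger margin demands a larger policy-versus-reference log-ratio gap before the pairwise loss saturates. Note that concatenated_forward now pops the margin from the concatenated batch under the key 'margin', so concatenated_inputs (not part of this diff) is expected to carry each example's precomputed_margin through under that name.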