support R-DPO #25

Open · wants to merge 1 commit into base: main
8 changes: 8 additions & 0 deletions turbo_alignment/settings/pipelines/train/dpo.py
@@ -24,6 +24,7 @@ class DPOLossesType(str, Enum):
SIMPO = 'simpo'
APO_ZERO = 'apo_zero'
APO_DOWN = 'apo_down'
RDPO = 'rdpo'


class DPOLossSettings(ExtraFieldsNotAllowedBaseModel):
@@ -85,6 +86,12 @@ class APODownLossSettings(DPOLossSettings):
loss_type: Literal[DPOLossesType.APO_DOWN]


class RDPOLossSettings(DPOLossSettings):
loss_type: Literal[DPOLossesType.RDPO]
beta: float = 0.1
Collaborator comment:
Suggested change
beta: float = 0.1

alpha: float = 0.1
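
For reference, a minimal usage sketch of the new settings class (illustration only, not part of the PR; it assumes the suggested alpha field is added alongside beta, and uses the module path implied by the changed file's path):

from turbo_alignment.settings.pipelines.train.dpo import DPOLossesType, RDPOLossSettings

# 'rdpo' selects the new loss type; beta and alpha are forwarded to RDPOLoss below
rdpo_settings = RDPOLossSettings(loss_type=DPOLossesType.RDPO, beta=0.1, alpha=0.1)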


class SyncRefModelSettings(ExtraFieldsNotAllowedBaseModel):
sync_ref_model: bool = False
alpha: float = 1.0
@@ -104,6 +111,7 @@ class DPOTrainerSettings(TrainerSettings):
| SigmoidLossWithMarginSettings
| APOZeroLossSettings
| APODownLossSettings
| RDPOLossSettings
)
sync_ref_settings: SyncRefModelSettings
use_ref_model: bool = True
83 changes: 82 additions & 1 deletion turbo_alignment/trainers/dpo.py
@@ -37,6 +37,7 @@
SimPOLossSettings,
SlicHfLossSettings,
SyncRefModelSettings,
RDPOLossSettings,
)
from turbo_alignment.trainers.utils import (
DPOLossRegistry,
@@ -61,6 +62,8 @@ def compute_loss(
reference_chosen_logps: torch.FloatTensor,
reference_rejected_logps: torch.FloatTensor,
precomputed_margins: torch.FloatTensor | None,
chosen_lens: int | None,
rejected_lens: int | None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
pi_logratios = policy_chosen_logps - policy_rejected_logps
ref_logratios = reference_chosen_logps - reference_rejected_logps
@@ -95,6 +98,8 @@ def compute_loss(
reference_chosen_logps: torch.FloatTensor,
reference_rejected_logps: torch.FloatTensor,
precomputed_margins: torch.FloatTensor | None,
chosen_lens: int | None,
rejected_lens: int | None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
chosen_KL = (policy_chosen_logps - reference_chosen_logps).mean().clamp(min=0)
rejected_KL = (policy_rejected_logps - reference_rejected_logps).mean().clamp(min=0)
@@ -129,6 +134,8 @@ def compute_loss(
reference_chosen_logps: torch.FloatTensor,
reference_rejected_logps: torch.FloatTensor,
precomputed_margins: torch.FloatTensor | None,
chosen_lens: int | None,
rejected_lens: int | None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
pi_logratios = policy_chosen_logps - policy_rejected_logps
ref_logratios = reference_chosen_logps - reference_rejected_logps
@@ -157,6 +164,8 @@ def compute_loss(
reference_chosen_logps: torch.FloatTensor,
reference_rejected_logps: torch.FloatTensor,
precomputed_margins: torch.FloatTensor | None,
chosen_lens: int | None,
rejected_lens: int | None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
pi_logratios = policy_chosen_logps - policy_rejected_logps
ref_logratios = reference_chosen_logps - reference_rejected_logps
@@ -188,6 +197,8 @@ def compute_loss(
reference_chosen_logps: torch.FloatTensor,
reference_rejected_logps: torch.FloatTensor,
precomputed_margins: torch.FloatTensor | None,
chosen_lens: int | None,
rejected_lens: int | None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
pi_logratios = policy_chosen_logps - policy_rejected_logps
ref_logratios = reference_chosen_logps - reference_rejected_logps
@@ -220,6 +231,8 @@ def compute_loss(
reference_chosen_logps: torch.FloatTensor,
reference_rejected_logps: torch.FloatTensor,
precomputed_margins: torch.FloatTensor | None,
chosen_lens: int | None,
rejected_lens: int | None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
pi_logratios = policy_chosen_logps - policy_rejected_logps
ref_logratios = reference_chosen_logps - reference_rejected_logps
@@ -254,6 +267,8 @@ def compute_loss(
reference_chosen_logps: torch.FloatTensor | None,
reference_rejected_logps: torch.FloatTensor | None,
precomputed_margins: torch.FloatTensor | None,
chosen_lens: int | None,
rejected_lens: int | None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
pi_logratios = policy_chosen_logps - policy_rejected_logps

@@ -271,6 +286,46 @@ def compute_loss(
)


@DPOLossRegistry.register(DPOLossesType.RDPO)
class RDPOLoss(DPOLossRegistry):
def __init__(self, *args, beta: float = 0.1, alpha: float = 0, **kwargs) -> None:
self.beta = beta
self.alpha = alpha
super().__init__(*args, **kwargs)

def compute_loss(
self,
policy_chosen_logps: torch.FloatTensor,
policy_rejected_logps: torch.FloatTensor,
reference_chosen_logps: torch.FloatTensor,
reference_rejected_logps: torch.FloatTensor,
precomputed_margins: torch.FloatTensor | None,
chosen_lens: int | None,
rejected_lens: int | None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
pi_logratios = policy_chosen_logps - policy_rejected_logps
ref_logratios = reference_chosen_logps - reference_rejected_logps

logits = pi_logratios - ref_logratios

chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach()
rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps).detach()

unscaled = self.beta * logits

if chosen_lens is not None and rejected_lens is not None:
Collaborator comment:
These really must not be None here, though — they should be required.

unscaled -= self.alpha * rejected_lens - self.alpha * chosen_lens
Collaborator comment:
Visually this reads nicer and simpler to me, but the way you have it is fine too.

Suggested change
unscaled -= self.alpha * rejected_lens - self.alpha * chosen_lens
unscaled += self.alpha * (chosen_lens - rejected_lens)
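
(The two forms are algebraically identical: -(alpha * rejected_lens - alpha * chosen_lens) = alpha * (chosen_lens - rejected_lens); the suggestion only changes readability.)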


loss = -F.logsigmoid(unscaled)

return (
loss,
chosen_rewards,
rejected_rewards,
)
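
As a sanity check on the computation above, here is a self-contained toy evaluation of the same quantities (a sketch with made-up numbers; in training the inputs come from concatenated_forward below):

import torch
import torch.nn.functional as F

beta, alpha = 0.1, 0.1
policy_chosen_logps = torch.tensor([-12.0])      # summed log-prob of the chosen answer under the policy
policy_rejected_logps = torch.tensor([-15.0])
reference_chosen_logps = torch.tensor([-13.0])   # same quantities under the frozen reference model
reference_rejected_logps = torch.tensor([-14.0])
chosen_lens, rejected_lens = 40, 55              # non-masked label counts, as in concatenated_forward

logits = (policy_chosen_logps - policy_rejected_logps) - (reference_chosen_logps - reference_rejected_logps)
unscaled = beta * logits                                  # 0.1 * (3 - 1) = 0.2
unscaled -= alpha * rejected_lens - alpha * chosen_lens   # 0.2 - 1.5 = -1.3, same sign convention as above
loss = -F.logsigmoid(unscaled)                            # ≈ tensor([1.5410])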


@DPOLossRegistry.register(DPOLossesType.ORPO)
class ORPOLoss(DPOLossRegistry):
def __init__(self, *args, beta: float = 0.1, **kwargs) -> None:
@@ -284,6 +339,8 @@ def compute_loss(
reference_chosen_logps: torch.FloatTensor | None,
reference_rejected_logps: torch.FloatTensor | None,
precomputed_margins: torch.FloatTensor | None,
chosen_lens: int | None,
rejected_lens: int | None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
log_odds = (policy_chosen_logps - policy_rejected_logps) - (
torch.log1p(-torch.clamp(torch.exp(policy_chosen_logps), max=1 - 1e-7))
@@ -312,6 +369,8 @@ def compute_loss(
reference_chosen_logps: torch.FloatTensor,
reference_rejected_logps: torch.FloatTensor,
precomputed_margins: torch.FloatTensor | None,
chosen_lens: int | None,
rejected_lens: int | None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
pi_logratios = policy_chosen_logps - policy_rejected_logps
ref_logratios = reference_chosen_logps - reference_rejected_logps
@@ -346,6 +405,8 @@ def compute_loss(
reference_chosen_logps: torch.FloatTensor,
reference_rejected_logps: torch.FloatTensor,
precomputed_margins: torch.FloatTensor | None,
chosen_lens: int | None,
rejected_lens: int | None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
chosen_logratios = policy_chosen_logps - reference_chosen_logps
rejected_logratios = policy_rejected_logps - reference_rejected_logps
@@ -378,6 +439,8 @@ def compute_loss(
reference_chosen_logps: torch.FloatTensor,
reference_rejected_logps: torch.FloatTensor,
precomputed_margins: torch.FloatTensor | None,
chosen_lens: int | None,
rejected_lens: int | None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
chosen_logratios = policy_chosen_logps - reference_chosen_logps
rejected_logratios = policy_rejected_logps - reference_rejected_logps
@@ -412,6 +475,7 @@ class DPOTrainingArguments(TrainingArguments):
| SigmoidLossWithMarginSettings
| APOZeroLossSettings
| APODownLossSettings
| RDPOLossSettings
) = field(
default_factory=SigmoidLossSettings(loss_type=DPOLossesType.SIGMOID)
) # type: ignore[call-overload]
@@ -507,13 +571,17 @@ def dpo_loss(
reference_chosen_logps: torch.Tensor,
reference_rejected_logps: torch.Tensor,
precomputed_margins: torch.Tensor | None,
chosen_lens: int | None,
rejected_lens: int | None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
return self.dpo_loss_registry.compute_loss(
policy_chosen_logps=policy_chosen_logps,
policy_rejected_logps=policy_rejected_logps,
reference_chosen_logps=reference_chosen_logps,
reference_rejected_logps=reference_rejected_logps,
precomputed_margins=precomputed_margins,
chosen_lens=chosen_lens,
rejected_lens=rejected_lens,
)

def _get_batch_logps(
@@ -559,9 +627,16 @@ def concatenated_forward(
chosen_logps = all_logps[:chosen_idxs]
rejected_logps = all_logps[chosen_idxs : chosen_idxs + rejected_idx]

if self.loss_type == DPOLossesType.RDPO:
chosen_lens = len(torch.where(batch['inputs_w']['labels'] != -100)[1])
rejected_lens = len(torch.where(batch['inputs_l']['labels'] != -100)[1])
else:
chosen_lens = None
rejected_lens = None

chosen_logits = all_logits[:chosen_idxs]
rejected_logits = all_logits[chosen_idxs:]
return chosen_logps, rejected_logps, chosen_logits, rejected_logits, precomputed_margins
return chosen_logps, rejected_logps, chosen_logits, rejected_logits, precomputed_margins, chosen_lens, rejected_lens
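
To make the length computation above concrete, here is a small illustration (a sketch; note that len(torch.where(labels != -100)[1]) counts the non-masked label positions summed over the whole batch, not per example):

import torch

labels = torch.tensor([[-100, -100, 5, 7, 2],
                       [-100, 3, 9, -100, -100]])       # -100 marks masked positions
num_label_tokens = len(torch.where(labels != -100)[1])  # -> 5 (3 from row 0 + 2 from row 1)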

def _get_logps(self, model: nn.Module | None, batch: dict[str, Any]) -> tuple[torch.Tensor, torch.Tensor]:
with torch.no_grad():
@@ -591,6 +666,8 @@ def get_batch_metrics(
policy_chosen_logits,
policy_rejected_logits,
precomputed_margins,
chosen_lens,
rejected_lens,
) = self.concatenated_forward(model, batch)

reference_chosen_logps, reference_rejected_logps = torch.Tensor([float('inf')]), torch.Tensor([float('inf')])
@@ -604,6 +681,8 @@
reference_chosen_logps=reference_chosen_logps,
reference_rejected_logps=reference_rejected_logps,
precomputed_margins=precomputed_margins,
chosen_lens=chosen_lens,
rejected_lens=rejected_lens,
)

prefix = 'eval_' if train_eval == 'eval' else ''
@@ -671,6 +750,8 @@ def get_batch_metrics(
reference_chosen_logps=sft_chosen_logps,
reference_rejected_logps=sft_rejected_logps,
precomputed_margins=precomputed_margins,
chosen_lens=chosen_lens,
rejected_lens=rejected_lens,
)

sft_prefix_name = prefix + 'rewards/sft_'
2 changes: 2 additions & 0 deletions turbo_alignment/trainers/utils.py
@@ -97,5 +97,7 @@ def compute_loss(
reference_chosen_logps: torch.FloatTensor,
reference_rejected_logps: torch.FloatTensor,
precomputed_margins: torch.FloatTensor | None,
chosen_lens: int | None,
rejected_lens: int | None,
Comment on lines +100 to +101
Collaborator comment:
I don't really like the idea of adding something to the shared abstract class that is only used by one loss.

It would be nicer to compute this in a more isolated way, cleanly inside the loss itself, maybe.

Collaborator (author) reply:
Agreed, it's not very clean. I did it following the precomputed_margins example; I'll try to rewrite it without this now.
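
One possible shape for that refactor, sketched purely for illustration (not part of this PR and not the repository's actual API): keep the shared abstract signature as it is today and let individual losses opt into extra inputs via optional keyword arguments that the trainer forwards when available.

# Hypothetical sketch of the direction discussed above.
class RDPOLoss(DPOLossRegistry):
    def compute_loss(self, policy_chosen_logps, policy_rejected_logps,
                     reference_chosen_logps, reference_rejected_logps,
                     precomputed_margins, **extra):
        # Only R-DPO looks for the length information; other losses ignore it.
        chosen_lens = extra.get('chosen_lens')
        rejected_lens = extra.get('rejected_lens')
        ...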

) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
...