support R-DPO #25
base: main
@@ -37,6 +37,7 @@
     SimPOLossSettings,
     SlicHfLossSettings,
     SyncRefModelSettings,
+    RDPOLossSettings,
 )
 from turbo_alignment.trainers.utils import (
     DPOLossRegistry,
@@ -61,6 +62,8 @@ def compute_loss(
         reference_chosen_logps: torch.FloatTensor,
         reference_rejected_logps: torch.FloatTensor,
         precomputed_margins: torch.FloatTensor | None,
+        chosen_lens: int | None,
+        rejected_lens: int | None,
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         pi_logratios = policy_chosen_logps - policy_rejected_logps
         ref_logratios = reference_chosen_logps - reference_rejected_logps
@@ -95,6 +98,8 @@ def compute_loss(
         reference_chosen_logps: torch.FloatTensor,
         reference_rejected_logps: torch.FloatTensor,
         precomputed_margins: torch.FloatTensor | None,
+        chosen_lens: int | None,
+        rejected_lens: int | None,
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         chosen_KL = (policy_chosen_logps - reference_chosen_logps).mean().clamp(min=0)
         rejected_KL = (policy_rejected_logps - reference_rejected_logps).mean().clamp(min=0)
@@ -129,6 +134,8 @@ def compute_loss(
         reference_chosen_logps: torch.FloatTensor,
         reference_rejected_logps: torch.FloatTensor,
         precomputed_margins: torch.FloatTensor | None,
+        chosen_lens: int | None,
+        rejected_lens: int | None,
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         pi_logratios = policy_chosen_logps - policy_rejected_logps
         ref_logratios = reference_chosen_logps - reference_rejected_logps
@@ -157,6 +164,8 @@ def compute_loss(
         reference_chosen_logps: torch.FloatTensor,
         reference_rejected_logps: torch.FloatTensor,
         precomputed_margins: torch.FloatTensor | None,
+        chosen_lens: int | None,
+        rejected_lens: int | None,
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         pi_logratios = policy_chosen_logps - policy_rejected_logps
         ref_logratios = reference_chosen_logps - reference_rejected_logps
@@ -188,6 +197,8 @@ def compute_loss(
         reference_chosen_logps: torch.FloatTensor,
         reference_rejected_logps: torch.FloatTensor,
         precomputed_margins: torch.FloatTensor | None,
+        chosen_lens: int | None,
+        rejected_lens: int | None,
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         pi_logratios = policy_chosen_logps - policy_rejected_logps
         ref_logratios = reference_chosen_logps - reference_rejected_logps
@@ -220,6 +231,8 @@ def compute_loss(
         reference_chosen_logps: torch.FloatTensor,
         reference_rejected_logps: torch.FloatTensor,
         precomputed_margins: torch.FloatTensor | None,
+        chosen_lens: int | None,
+        rejected_lens: int | None,
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         pi_logratios = policy_chosen_logps - policy_rejected_logps
         ref_logratios = reference_chosen_logps - reference_rejected_logps
@@ -254,6 +267,8 @@ def compute_loss(
         reference_chosen_logps: torch.FloatTensor | None,
         reference_rejected_logps: torch.FloatTensor | None,
         precomputed_margins: torch.FloatTensor | None,
+        chosen_lens: int | None,
+        rejected_lens: int | None,
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         pi_logratios = policy_chosen_logps - policy_rejected_logps
@@ -271,6 +286,46 @@ def compute_loss(
         )


+@DPOLossRegistry.register(DPOLossesType.RDPO)
+class RDPOLoss(DPOLossRegistry):
+    def __init__(self, *args, beta: float = 0.1, alpha: float = 0, **kwargs) -> None:
+        self.beta = beta
+        self.alpha = alpha
+        super().__init__(*args, **kwargs)
+
+    def compute_loss(
+        self,
+        policy_chosen_logps: torch.FloatTensor,
+        policy_rejected_logps: torch.FloatTensor,
+        reference_chosen_logps: torch.FloatTensor,
+        reference_rejected_logps: torch.FloatTensor,
+        policy_best_decode_logps: torch.FloatTensor | None,
+        precomputed_margins: torch.FloatTensor | None,
+        chosen_lens: int | None,
+        rejected_lens: int | None,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        pi_logratios = policy_chosen_logps - policy_rejected_logps
+        ref_logratios = reference_chosen_logps - reference_rejected_logps
+
+        logits = pi_logratios - ref_logratios
+
+        chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach()
+        rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps).detach()
+
+        unscaled = self.beta * logits
+
+        if chosen_lens is not None and rejected_lens is not None:

Reviewer: surely these must be non-None here — they shouldn't be optional.

+            unscaled -= self.alpha * rejected_lens - self.alpha * chosen_lens

Reviewer: visually it's nicer and simpler for me the other way, but the way you have it is fine too.
Suggested change

+
+        loss = -F.logsigmoid(unscaled)
+
+        return (
+            loss,
+            chosen_rewards,
+            rejected_rewards,
+        )
+
+
 @DPOLossRegistry.register(DPOLossesType.ORPO)
 class ORPOLoss(DPOLossRegistry):
     def __init__(self, *args, beta: float = 0.1, **kwargs) -> None:
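For context: R-DPO ("Disentangling Length from Quality in Direct Preference Optimization", Park et al., 2024) regularizes the DPO margin with the difference of response lengths, so the policy cannot win preferences simply by generating longer answers. Below is a minimal standalone sketch of that objective, assuming per-example length tensors rather than the batch-summed ints this patch passes; note that the paper adds α·|y_l| − α·|y_w| to the margin, which is worth double-checking against the sign of the `unscaled -=` line above.

```python
import torch
import torch.nn.functional as F


# Standalone sketch of the R-DPO objective, not the trainer code above.
# Assumes per-example length tensors of shape (batch,).
def rdpo_loss_sketch(
    policy_chosen_logps: torch.Tensor,
    policy_rejected_logps: torch.Tensor,
    reference_chosen_logps: torch.Tensor,
    reference_rejected_logps: torch.Tensor,
    chosen_lens: torch.Tensor,
    rejected_lens: torch.Tensor,
    beta: float = 0.1,
    alpha: float = 0.05,
) -> torch.Tensor:
    # Standard DPO logits: policy log-ratio minus reference log-ratio.
    logits = (policy_chosen_logps - policy_rejected_logps) - (
        reference_chosen_logps - reference_rejected_logps
    )
    # Length regularizer per the paper: alpha * (|y_l| - |y_w|), which
    # shrinks the margin when the chosen answer is longer than the rejected one.
    length_term = alpha * (rejected_lens - chosen_lens).float()
    return -F.logsigmoid(beta * logits + length_term)
```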
@@ -284,6 +339,8 @@ def compute_loss(
         reference_chosen_logps: torch.FloatTensor | None,
         reference_rejected_logps: torch.FloatTensor | None,
         precomputed_margins: torch.FloatTensor | None,
+        chosen_lens: int | None,
+        rejected_lens: int | None,
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         log_odds = (policy_chosen_logps - policy_rejected_logps) - (
             torch.log1p(-torch.clamp(torch.exp(policy_chosen_logps), max=1 - 1e-7))
@@ -312,6 +369,8 @@ def compute_loss(
         reference_chosen_logps: torch.FloatTensor,
         reference_rejected_logps: torch.FloatTensor,
         precomputed_margins: torch.FloatTensor | None,
+        chosen_lens: int | None,
+        rejected_lens: int | None,
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         pi_logratios = policy_chosen_logps - policy_rejected_logps
         ref_logratios = reference_chosen_logps - reference_rejected_logps
@@ -346,6 +405,8 @@ def compute_loss(
         reference_chosen_logps: torch.FloatTensor,
         reference_rejected_logps: torch.FloatTensor,
         precomputed_margins: torch.FloatTensor | None,
+        chosen_lens: int | None,
+        rejected_lens: int | None,
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         chosen_logratios = policy_chosen_logps - reference_chosen_logps
         rejected_logratios = policy_rejected_logps - reference_rejected_logps
@@ -378,6 +439,8 @@ def compute_loss(
         reference_chosen_logps: torch.FloatTensor,
         reference_rejected_logps: torch.FloatTensor,
         precomputed_margins: torch.FloatTensor | None,
+        chosen_lens: int | None,
+        rejected_lens: int | None,
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         chosen_logratios = policy_chosen_logps - reference_chosen_logps
         rejected_logratios = policy_rejected_logps - reference_rejected_logps
@@ -412,6 +475,7 @@ class DPOTrainingArguments(TrainingArguments):
         | SigmoidLossWithMarginSettings
         | APOZeroLossSettings
         | APODownLossSettings
+        | RDPOLossSettings
     ) = field(
         default_factory=SigmoidLossSettings(loss_type=DPOLossesType.SIGMOID)
     )  # type: ignore[call-overload]
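The `RDPOLossSettings` imported at the top of the file presumably mirrors the other loss settings in this union. A hypothetical shape, with field names taken from `RDPOLoss.__init__` above (the base class and import path are assumptions, not shown by this diff):

```python
from typing import Literal

# Assumed module path; the diff only shows the import list, not the file.
from turbo_alignment.settings.pipelines.train.dpo import DPOLossesType, DPOLossSettings


class RDPOLossSettings(DPOLossSettings):  # hypothetical base class
    loss_type: Literal[DPOLossesType.RDPO] = DPOLossesType.RDPO
    beta: float = 0.1
    alpha: float = 0.0
```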
@@ -507,13 +571,17 @@ def dpo_loss(
         reference_chosen_logps: torch.Tensor,
         reference_rejected_logps: torch.Tensor,
         precomputed_margins: torch.Tensor | None,
+        chosen_lens: int | None,
+        rejected_lens: int | None,
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         return self.dpo_loss_registry.compute_loss(
             policy_chosen_logps=policy_chosen_logps,
             policy_rejected_logps=policy_rejected_logps,
             reference_chosen_logps=reference_chosen_logps,
             reference_rejected_logps=reference_rejected_logps,
             precomputed_margins=precomputed_margins,
+            chosen_lens=chosen_lens,
+            rejected_lens=rejected_lens,
         )

     def _get_batch_logps(
@@ -559,9 +627,16 @@ def concatenated_forward(
         chosen_logps = all_logps[:chosen_idxs]
         rejected_logps = all_logps[chosen_idxs : chosen_idxs + rejected_idx]

+        if self.loss_type == DPOLossesType.RDPO:
+            chosen_lens = len(torch.where(batch['inputs_w']['labels'] != -100)[1])
+            rejected_lens = len(torch.where(batch['inputs_l']['labels'] != -100)[1])
+        else:
+            chosen_lens = None
+            rejected_lens = None
+
         chosen_logits = all_logits[:chosen_idxs]
         rejected_logits = all_logits[chosen_idxs:]
-        return chosen_logps, rejected_logps, chosen_logits, rejected_logits, precomputed_margins
+        return chosen_logps, rejected_logps, chosen_logits, rejected_logits, precomputed_margins, chosen_lens, rejected_lens
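Note that `len(torch.where(labels != -100)[1])` counts the non-masked label tokens summed over the whole micro-batch, so `chosen_lens` is a single int rather than one length per example. If per-example lengths were wanted instead, a boolean sum per row would give a `(batch,)` tensor; a sketch under that assumption:

```python
# Hypothetical per-example variant: one length per row instead of a
# batch-wide total; broadcasts cleanly against per-example log-probs.
chosen_lens = (batch['inputs_w']['labels'] != -100).sum(dim=-1)
rejected_lens = (batch['inputs_l']['labels'] != -100).sum(dim=-1)
```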
     def _get_logps(self, model: nn.Module | None, batch: dict[str, Any]) -> tuple[torch.Tensor, torch.Tensor]:
         with torch.no_grad():
@@ -591,6 +666,8 @@ def get_batch_metrics(
             policy_chosen_logits,
             policy_rejected_logits,
             precomputed_margins,
+            chosen_lens,
+            rejected_lens,
         ) = self.concatenated_forward(model, batch)

         reference_chosen_logps, reference_rejected_logps = torch.Tensor([float('inf')]), torch.Tensor([float('inf')])
@@ -604,6 +681,8 @@ def get_batch_metrics(
             reference_chosen_logps=reference_chosen_logps,
             reference_rejected_logps=reference_rejected_logps,
             precomputed_margins=precomputed_margins,
+            chosen_lens=chosen_lens,
+            rejected_lens=rejected_lens,
         )

         prefix = 'eval_' if train_eval == 'eval' else ''
@@ -671,6 +750,8 @@ def get_batch_metrics(
             reference_chosen_logps=sft_chosen_logps,
             reference_rejected_logps=sft_rejected_logps,
             precomputed_margins=precomputed_margins,
+            chosen_lens=chosen_lens,
+            rejected_lens=rejected_lens,
         )

         sft_prefix_name = prefix + 'rewards/sft_'
@@ -97,5 +97,7 @@ def compute_loss(
         reference_chosen_logps: torch.FloatTensor,
         reference_rejected_logps: torch.FloatTensor,
         precomputed_margins: torch.FloatTensor | None,
+        chosen_lens: int | None,
+        rejected_lens: int | None,
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         ...

Reviewer (on lines +100 to +101): I don't really like the idea of adding something to the shared abstract class that only one loss uses. It would be nicer to compute this in a more isolated, self-contained way inside the loss itself, maybe.

Author: Agreed, it's not very clean. I did it following the precomputed_margins example; I'll try to rewrite it without this.
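One hypothetical way to address the review point without widening the shared signature: let the abstract method accept loss-specific extras via keyword arguments, so only `RDPOLoss` ever reads the lengths. A sketch, assuming the registry base behaves like an ABC (the simplified class below is illustrative, not the repo's actual definition):

```python
from abc import ABC, abstractmethod

import torch


class DPOLossRegistry(ABC):  # simplified stand-in for the real registry base
    @abstractmethod
    def compute_loss(
        self,
        policy_chosen_logps: torch.FloatTensor,
        policy_rejected_logps: torch.FloatTensor,
        reference_chosen_logps: torch.FloatTensor,
        reference_rejected_logps: torch.FloatTensor,
        precomputed_margins: torch.FloatTensor | None,
        **loss_kwargs,  # R-DPO would pull chosen_lens / rejected_lens from here
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        ...
```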