Merge pull request #65 from turbo-llm/nca_pair_loss
Nca pair loss
oltsy authored Dec 13, 2024
2 parents 74ac8a9 + cc063a9 commit 12b29d4
Showing 2 changed files with 37 additions and 0 deletions.
6 changes: 6 additions & 0 deletions turbo_alignment/settings/pipelines/train/dpo.py
@@ -26,6 +26,7 @@ class DPOLossesType(str, Enum):
     APO_DOWN = 'apo_down'
     ASFT = 'asft'
     DPOP = 'dpop'
+    NCA_PAIR = 'nca_pair'
 
 
 class DPOLossSettings(ExtraFieldsNotAllowedBaseModel):
@@ -93,6 +94,10 @@ class DPOPLossSettings(DPOLossSettings):
     lam: float = 0.1
 
 
+class NCAPairLossSettings(DPOLossSettings):
+    loss_type: Literal[DPOLossesType.NCA_PAIR]
+
+
 class SyncRefModelSettings(ExtraFieldsNotAllowedBaseModel):
     sync_ref_model: bool = False
     alpha: float = 1.0
@@ -114,6 +119,7 @@ class DPOTrainerSettings(TrainerSettings):
         | APOZeroLossSettings
         | APODownLossSettings
         | DPOPLossSettings
+        | NCAPairLossSettings
     )
     sync_ref_settings: SyncRefModelSettings
     use_ref_model: bool = True
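
With this change, nca_pair joins the set of selectable DPO loss types. A minimal usage sketch, constructing the new settings object directly; only loss_type is taken from the diff, and if the base DPOLossSettings requires further fields (e.g. a beta matching the trainer-side default of 0.1), they would be passed here as well:

    from turbo_alignment.settings.pipelines.train.dpo import (
        DPOLossesType,
        NCAPairLossSettings,
    )

    # Select the NCA pair loss; any extra base-class fields (such as beta)
    # are assumptions not shown in this diff.
    loss_settings = NCAPairLossSettings(loss_type=DPOLossesType.NCA_PAIR)
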
31 changes: 31 additions & 0 deletions turbo_alignment/trainers/dpo.py
@@ -33,6 +33,7 @@
     HingeLossSettings,
     IPOLossSettings,
     KTOLossSettings,
+    NCAPairLossSettings,
     ORPOLossSettings,
     SigmoidLossSettings,
     SigmoidLossWithMarginSettings,
@@ -456,6 +457,35 @@ def compute_loss(
         return loss, chosen_rewards, rejected_rewards
 
 
+@DPOLossRegistry.register(DPOLossesType.NCA_PAIR)
+class NCAPairLoss(DPOLossRegistry):
+    def __init__(self, *args, beta: float = 0.1, **kwargs) -> None:
+        self.beta = beta
+        super().__init__(*args, **kwargs)
+
+    def compute_loss(
+        self,
+        policy_chosen_logps: torch.FloatTensor,
+        policy_rejected_logps: torch.FloatTensor,
+        reference_chosen_logps: torch.FloatTensor,
+        reference_rejected_logps: torch.FloatTensor,
+        precomputed_margins: torch.FloatTensor | None,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        chosen_logratios = policy_chosen_logps - reference_chosen_logps
+        rejected_logratios = policy_rejected_logps - reference_rejected_logps
+
+        chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach()
+        rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps).detach()
+
+        loss = (
+            -F.logsigmoid(chosen_logratios * self.beta)
+            - 0.5 * F.logsigmoid(-chosen_logratios * self.beta)
+            - 0.5 * F.logsigmoid(-rejected_logratios * self.beta)
+        )
+
+        return loss, chosen_rewards, rejected_rewards
+
+
 @dataclass
 class DPOTrainingArguments(TrainingArguments):
     loss_settings: (
@@ -473,6 +503,7 @@ class DPOTrainingArguments(TrainingArguments):
         | APOZeroLossSettings
         | APODownLossSettings
         | DPOPLossSettings
+        | NCAPairLossSettings
     ) = field(
         default_factory=SigmoidLossSettings(loss_type=DPOLossesType.SIGMOID)
     )  # type: ignore[call-overload]
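
For reference, the per-pair objective added above is -log sigmoid(b*Dc) - 0.5*log sigmoid(-b*Dc) - 0.5*log sigmoid(-b*Dr), where Dc and Dr are the chosen/rejected log-ratios between policy and reference and b is beta. A self-contained sketch of the same arithmetic on toy tensors (values purely illustrative):

    import torch
    import torch.nn.functional as F

    beta = 0.1
    # Illustrative per-example sequence log-probabilities.
    policy_chosen_logps = torch.tensor([-10.0, -12.0])
    policy_rejected_logps = torch.tensor([-14.0, -13.0])
    reference_chosen_logps = torch.tensor([-11.0, -12.5])
    reference_rejected_logps = torch.tensor([-13.0, -13.5])

    chosen_logratios = policy_chosen_logps - reference_chosen_logps
    rejected_logratios = policy_rejected_logps - reference_rejected_logps

    loss = (
        -F.logsigmoid(beta * chosen_logratios)
        - 0.5 * F.logsigmoid(-beta * chosen_logratios)
        - 0.5 * F.logsigmoid(-beta * rejected_logratios)
    )
    print(loss)  # one loss value per pair; the trainer reduces these to a scalar

One property worth noting: setting the gradient of the chosen terms to zero gives 1.5*sigmoid(b*Dc) - 1 = 0, i.e. sigmoid(b*Dc) = 2/3, so the chosen log-ratio is driven toward the finite value ln(2)/beta rather than unboundedly upward, while the rejected log-ratio is still pushed down. Roughly speaking, this constrains the absolute chosen reward instead of only the chosen-rejected margin, unlike the plain sigmoid (DPO) loss.
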
