diff --git a/src/liger_kernel/chunked_loss/cpo_loss.py b/src/liger_kernel/chunked_loss/cpo_loss.py
index 2b8052e25..987f0cdcf 100644
--- a/src/liger_kernel/chunked_loss/cpo_loss.py
+++ b/src/liger_kernel/chunked_loss/cpo_loss.py
@@ -9,7 +9,9 @@ class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
 
     @staticmethod
-    def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1):
+    def preference_loss_fn(
+        chosen_logps, rejected_logps, full_target, beta=0.1, label_smoothing=0.0
+    ):
         """
         Paper: https://arxiv.org/pdf/2401.08417
 
@@ -30,9 +32,14 @@ def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1):
             rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
             full_target (torch.Tensor): Non chunked full target tensor
             beta (float): Weight for the CPO loss
+            label_smoothing (float): Label smoothing factor, will reduce to Equation above when label_smoothing -> 0.
         """
         logits = beta * (chosen_logps - rejected_logps)
-        loss = F.logsigmoid(logits).sum() / (full_target.shape[0] // 2)
+        loss = (
+            F.logsigmoid(logits) * (1 - label_smoothing)
+            + F.logsigmoid(-logits) * label_smoothing
+        ).sum() / (full_target.shape[0] // 2)
+
         return loss
 
     @staticmethod
@@ -45,6 +52,7 @@ def forward(
         ignore_index=-100,
         beta=0.1,
         alpha=1.0,
+        label_smoothing=0.0,
         compute_nll_loss=True,
         compiled=True,
     ):
@@ -58,6 +66,7 @@ def forward(
             ignore_index=ignore_index,
             alpha=alpha,
             beta=beta,
+            label_smoothing=label_smoothing,
             compute_nll_loss=compute_nll_loss,
             compiled=compiled,
         )
@@ -65,7 +74,7 @@ def forward(
     @staticmethod
     def backward(ctx, *grad_output):
         grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
-        return *grads, None, None, None, None, None
+        return *grads, None, None, None, None, None, None
 
 
 class LigerFusedLinearCPOLoss(torch.nn.Module):
@@ -78,6 +87,7 @@ def __init__(
         ignore_index: int = -100,
         beta: float = 0.1,
         alpha: float = 1.0,
+        label_smoothing: float = 0.0,
         compute_nll_loss: bool = True,
         compiled: bool = True,
     ):
@@ -90,6 +100,7 @@ def __init__(
         self.ignore_index = ignore_index
         self.beta = beta
         self.alpha = alpha
+        self.label_smoothing = label_smoothing
         self.compute_nll_loss = compute_nll_loss
         self.compiled = compiled
 
@@ -102,6 +113,7 @@ def forward(self, lin_weight, _input, target, bias=None):
             self.ignore_index,
             self.beta,
             self.alpha,
+            self.label_smoothing,
             self.compute_nll_loss,
             self.compiled,
         )
diff --git a/src/liger_kernel/chunked_loss/simpo_loss.py b/src/liger_kernel/chunked_loss/simpo_loss.py
index 7efa0603d..2dc9f1a6b 100644
--- a/src/liger_kernel/chunked_loss/simpo_loss.py
+++ b/src/liger_kernel/chunked_loss/simpo_loss.py
@@ -10,7 +10,12 @@ class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
 
     @staticmethod
     def preference_loss_fn(
-        chosen_logps, rejected_logps, full_target, beta=0.1, gamma=0.5
+        chosen_logps,
+        rejected_logps,
+        full_target,
+        beta=0.1,
+        gamma=0.5,
+        label_smoothing=0.0,
     ):
         """
         Paper: https://arxiv.org/pdf/2405.14734
@@ -33,9 +38,14 @@ def preference_loss_fn(
             full_target: Non chunked full target tensor
             beta (float): beta weight
             gamma (float): gemma margin term
+            label_smoothing (float): Label smoothing factor, will reduce to Equation above when label_smoothing -> 0.
         """
         logits = beta * (chosen_logps - rejected_logps) - gamma
-        loss = F.logsigmoid(logits).sum() / (full_target.shape[0] // 2)
+        loss = (
+            F.logsigmoid(logits) * (1 - label_smoothing)
+            + F.logsigmoid(-logits) * label_smoothing
+        ).sum() / (full_target.shape[0] // 2)
+
         return loss
 
     @staticmethod
@@ -48,6 +58,7 @@ def forward(
         ignore_index=-100,
         beta=0.1,
         alpha=1.0,
+        label_smoothing=0.0,
         compute_nll_loss=False,
         compiled=True,
         gamma=0.5,
@@ -63,6 +74,7 @@ def forward(
             ignore_index=ignore_index,
             alpha=alpha,
             beta=beta,
+            label_smoothing=label_smoothing,
             compiled=compiled,
             gamma=gamma,
         )
@@ -70,7 +82,7 @@ def forward(
     @staticmethod
     def backward(ctx, *grad_output):
         grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
-        return *grads, None, None, None, None, None, None
+        return *grads, None, None, None, None, None, None, None
 
 
 class LigerFusedLinearSimPOLoss(torch.nn.Module):
@@ -83,6 +95,7 @@ def __init__(
         ignore_index: int = -100,
         beta: float = 0.1,
         alpha: float = 1.0,
+        label_smoothing: float = 0.0,
         compute_nll_loss: bool = True,
         compiled: bool = True,
         gamma: float = 0.5,
@@ -96,6 +109,7 @@ def __init__(
         self.ignore_index = ignore_index
         self.beta = beta
         self.alpha = alpha
+        self.label_smoothing = label_smoothing
         self.compute_nll_loss = compute_nll_loss
         self.compiled = compiled
         self.gamma = gamma
@@ -109,6 +123,7 @@ def forward(self, lin_weight, _input, target, bias=None):
             self.ignore_index,
             self.beta,
             self.alpha,
+            self.label_smoothing,
             self.compute_nll_loss,
             self.compiled,
             self.gamma,
diff --git a/test/chunked_loss/test_cpo_loss.py b/test/chunked_loss/test_cpo_loss.py
index f0fef7734..a0c4050e5 100644
--- a/test/chunked_loss/test_cpo_loss.py
+++ b/test/chunked_loss/test_cpo_loss.py
@@ -86,6 +86,7 @@ def __init__(
         ignore_index: int = -100,
         beta: float = 0.1,
         alpha: float = 1.0,
+        label_smoothing: float = 0.0,
         loss_type: str = "sigmoid",
         simpo_gamma: float = 0.5,
     ):
@@ -97,6 +98,7 @@ def __init__(
             ignore_index=ignore_index,
             beta=beta,
             loss_type=loss_type,
+            label_smoothing=label_smoothing,
             simpo_gamma=simpo_gamma,
         ).get_batch_loss_metrics
 
@@ -114,13 +116,17 @@ def __init__(
         ignore_index: int = -100,
         beta: float = 0.1,
         alpha: float = 1.0,
+        label_smoothing: float = 0.0,
     ):
         super().__init__()
         self.lin = torch.nn.Linear(
             in_features=H, out_features=V, bias=bias, dtype=dtype
         )
         self.cpo_loss = LigerFusedLinearCPOLoss(
-            ignore_index=ignore_index, beta=beta, alpha=alpha
+            ignore_index=ignore_index,
+            beta=beta,
+            alpha=alpha,
+            label_smoothing=label_smoothing,
         )
 
     def forward(self, x, y):
@@ -145,8 +151,21 @@ def forward(self, x, y):
 @pytest.mark.parametrize(
     "ignore_index, beta, alpha", [(-100, 0.1, 1.0), (42, 0.2, 0.85)]
 )
+@pytest.mark.parametrize("label_smoothing", [0.0, 0.1])
 def test_correctness(
-    B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, beta, alpha
+    B,
+    T,
+    H,
+    V,
+    scalar,
+    dtype,
+    atol,
+    rtol,
+    bias,
+    ignore_index,
+    beta,
+    alpha,
+    label_smoothing,
 ):
     B = 2 * B  # cpo loss requires B to be even
 
@@ -157,6 +176,7 @@
         bias=bias,
         ignore_index=ignore_index,
         beta=beta,
+        label_smoothing=label_smoothing,
     )
     liger_lm_head_cpo = LigerLMHeadCPO(
         H=H,
@@ -165,6 +185,7 @@
         bias=bias,
         ignore_index=ignore_index,
         beta=beta,
+        label_smoothing=label_smoothing,
     )
 
     torch_lm_head_cpo.lin.weight.data = liger_lm_head_cpo.lin.weight.data = torch.randn(
diff --git a/test/chunked_loss/test_simpo_loss.py b/test/chunked_loss/test_simpo_loss.py
index 3d0937c27..eede598fe 100644
--- a/test/chunked_loss/test_simpo_loss.py
+++ b/test/chunked_loss/test_simpo_loss.py
@@ -25,6 +25,7 @@ def __init__(
         ignore_index: int = -100,
         beta: float = 0.1,
         alpha: float = 1.0,
+        label_smoothing: float = 0.0,
         gamma: float = 0.5,
     ):
         super().__init__()
@@ -32,7 +33,11 @@ def __init__(
             in_features=H, out_features=V, bias=bias, dtype=dtype
         )
         self.simpo_loss = LigerFusedLinearSimPOLoss(
-            ignore_index=ignore_index, beta=beta, alpha=alpha, gamma=gamma
+            ignore_index=ignore_index,
+            beta=beta,
+            alpha=alpha,
+            gamma=gamma,
+            label_smoothing=label_smoothing,
         )
 
     def forward(self, x, y):
@@ -57,8 +62,21 @@ def forward(self, x, y):
 @pytest.mark.parametrize(
     "ignore_index, beta, gamma", [(-100, 0.1, 0.5), (42, 0.2, 0.85)]
 )
+@pytest.mark.parametrize("label_smoothing", [0.0, 0.1])
 def test_correctness(
-    B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, beta, gamma
+    B,
+    T,
+    H,
+    V,
+    scalar,
+    dtype,
+    atol,
+    rtol,
+    bias,
+    ignore_index,
+    beta,
+    gamma,
+    label_smoothing,
 ):
     B = 2 * B  # SimPO loss requires B to be even
 
@@ -70,6 +88,7 @@
         ignore_index=ignore_index,
         beta=beta,
         loss_type="simpo",
+        label_smoothing=label_smoothing,
         simpo_gamma=gamma,
     )
     liger_lm_head_simpo = LigerLMHeadSimPO(
@@ -79,6 +98,7 @@
         bias=bias,
         ignore_index=ignore_index,
         beta=beta,
+        label_smoothing=label_smoothing,
         gamma=gamma,
     )
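
Reviewer note: both `preference_loss_fn` changes implement the same label-smoothed objective, `logsigmoid(z) * (1 - eps) + logsigmoid(-z) * eps`, where `z` is the beta-scaled log-prob margin (with the extra `- gamma` term for SimPO), so `eps = 0` reduces exactly to the original `F.logsigmoid(logits)` path. Below is a minimal standalone sketch of that math in plain PyTorch for anyone who wants to sanity-check it outside the fused kernel; the function name and toy shapes are illustrative only and not part of this patch (the loss reads nothing from `full_target` except `shape[0]`).

```python
import torch
import torch.nn.functional as F


def smoothed_preference_loss(
    chosen_logps, rejected_logps, full_target, beta=0.1, gamma=0.0, label_smoothing=0.0
):
    """Standalone restatement of the patched loss (hypothetical helper).

    gamma=0.0 corresponds to the CPO branch; a nonzero gamma adds the
    SimPO margin. label_smoothing=0.0 recovers the unsmoothed loss.
    """
    logits = beta * (chosen_logps - rejected_logps) - gamma
    return (
        F.logsigmoid(logits) * (1 - label_smoothing)
        + F.logsigmoid(-logits) * label_smoothing
    ).sum() / (full_target.shape[0] // 2)


# Check that label_smoothing=0.0 reproduces the original formulation.
B = 4  # full batch holds chosen and rejected halves, so B must be even
chosen = torch.randn(B // 2)
rejected = torch.randn(B // 2)
full_target = torch.zeros(B, 10)  # only shape[0] feeds the divisor

original = F.logsigmoid(0.1 * (chosen - rejected)).sum() / (B // 2)
smoothed = smoothed_preference_loss(chosen, rejected, full_target, beta=0.1)
assert torch.allclose(original, smoothed)
```

The extra `None` appended to each `backward` return is the matching bookkeeping: `torch.autograd.Function.backward` must return one gradient slot per `forward` argument, and the new non-differentiable `label_smoothing` input adds one slot to both signatures.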