# CPO & SimPO add label_smoothing (#493)
## Summary
Add `label_smoothing` support for CPO and SimPO so that they align with the Hugging Face TRL
[interface](https://github.com/huggingface/trl/blob/b668048fe1931c57796ad5ae3f10852337ce7565/trl/trainer/cpo_trainer.py#L645C1-L658C14).
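For reference, this is the standard TRL-style convex mix of two logsigmoid terms. Below is a minimal standalone sketch of the smoothed objective (plain PyTorch, not the fused chunked kernel; the function name and the final negation and mean are illustrative, since in this PR the sign and aggregation are handled by `LigerFusedLinearPreferenceBase`):

```python
import torch
import torch.nn.functional as F

def smoothed_preference_loss(chosen_logps, rejected_logps, beta=0.1, label_smoothing=0.0):
    # CPO-style pairwise logits; SimPO additionally subtracts a gamma margin here.
    logits = beta * (chosen_logps - rejected_logps)
    # Convex mix of the "chosen wins" and "rejected wins" terms; at label_smoothing=0.0
    # this reduces exactly to the original -logsigmoid(logits) objective.
    return -(
        F.logsigmoid(logits) * (1 - label_smoothing)
        + F.logsigmoid(-logits) * label_smoothing
    ).mean()

# Example: a mild smoothing of 0.1
loss = smoothed_preference_loss(torch.tensor([1.2]), torch.tensor([0.4]), label_smoothing=0.1)
```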

## Testing Done
- [x] Something is wrong with the unit test; I'll have to fix it.


- Hardware Type: <BLANK>
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence

---------

Signed-off-by: Mecoli1219 <[email protected]>
Mecoli1219 authored Dec 20, 2024
1 parent 3205342 commit 79e2b02
Showing 4 changed files with 78 additions and 10 deletions.
### src/liger_kernel/chunked_loss/cpo_loss.py (+15, -3)

```diff
@@ -9,7 +9,9 @@
 class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
 
     @staticmethod
-    def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1):
+    def preference_loss_fn(
+        chosen_logps, rejected_logps, full_target, beta=0.1, label_smoothing=0.0
+    ):
         """
         Paper: https://arxiv.org/pdf/2401.08417
 
@@ -30,9 +32,14 @@ def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1):
             rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
             full_target (torch.Tensor): Non chunked full target tensor
             beta (float): Weight for the CPO loss
+            label_smoothing (float): Label smoothing factor, will reduce to Equation above when label_smoothing -> 0.
         """
         logits = beta * (chosen_logps - rejected_logps)
-        loss = F.logsigmoid(logits).sum() / (full_target.shape[0] // 2)
+        loss = (
+            F.logsigmoid(logits) * (1 - label_smoothing)
+            + F.logsigmoid(-logits) * label_smoothing
+        ).sum() / (full_target.shape[0] // 2)
+
         return loss
 
     @staticmethod
@@ -45,6 +52,7 @@ def forward(
         ignore_index=-100,
         beta=0.1,
         alpha=1.0,
+        label_smoothing=0.0,
         compute_nll_loss=True,
         compiled=True,
     ):
@@ -58,14 +66,15 @@ def forward(
             ignore_index=ignore_index,
             alpha=alpha,
             beta=beta,
+            label_smoothing=label_smoothing,
             compute_nll_loss=compute_nll_loss,
             compiled=compiled,
         )
 
     @staticmethod
     def backward(ctx, *grad_output):
         grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
-        return *grads, None, None, None, None, None
+        return *grads, None, None, None, None, None, None
 
 
 class LigerFusedLinearCPOLoss(torch.nn.Module):
@@ -78,6 +87,7 @@ def __init__(
         ignore_index: int = -100,
         beta: float = 0.1,
         alpha: float = 1.0,
+        label_smoothing: float = 0.0,
         compute_nll_loss: bool = True,
         compiled: bool = True,
     ):
@@ -90,6 +100,7 @@ def __init__(
         self.ignore_index = ignore_index
         self.beta = beta
         self.alpha = alpha
+        self.label_smoothing = label_smoothing
         self.compute_nll_loss = compute_nll_loss
         self.compiled = compiled
 
@@ -102,6 +113,7 @@ def forward(self, lin_weight, _input, target, bias=None):
             self.ignore_index,
             self.beta,
             self.alpha,
+            self.label_smoothing,
             self.compute_nll_loss,
             self.compiled,
         )
```
### src/liger_kernel/chunked_loss/simpo_loss.py (+18, -3)

```diff
@@ -10,7 +10,12 @@ class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
 
     @staticmethod
     def preference_loss_fn(
-        chosen_logps, rejected_logps, full_target, beta=0.1, gamma=0.5
+        chosen_logps,
+        rejected_logps,
+        full_target,
+        beta=0.1,
+        gamma=0.5,
+        label_smoothing=0.0,
     ):
         """
         Paper: https://arxiv.org/pdf/2405.14734
@@ -33,9 +38,14 @@ def preference_loss_fn(
             full_target: Non chunked full target tensor
             beta (float): beta weight
             gamma (float): gemma margin term
+            label_smoothing (float): Label smoothing factor, will reduce to Equation above when label_smoothing -> 0.
         """
         logits = beta * (chosen_logps - rejected_logps) - gamma
-        loss = F.logsigmoid(logits).sum() / (full_target.shape[0] // 2)
+        loss = (
+            F.logsigmoid(logits) * (1 - label_smoothing)
+            + F.logsigmoid(-logits) * label_smoothing
+        ).sum() / (full_target.shape[0] // 2)
+
         return loss
 
     @staticmethod
@@ -48,6 +58,7 @@ def forward(
         ignore_index=-100,
         beta=0.1,
         alpha=1.0,
+        label_smoothing=0.0,
         compute_nll_loss=False,
         compiled=True,
         gamma=0.5,
@@ -63,14 +74,15 @@ def forward(
             ignore_index=ignore_index,
             alpha=alpha,
             beta=beta,
+            label_smoothing=label_smoothing,
             compiled=compiled,
             gamma=gamma,
         )
 
     @staticmethod
     def backward(ctx, *grad_output):
         grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
-        return *grads, None, None, None, None, None, None
+        return *grads, None, None, None, None, None, None, None
 
 
 class LigerFusedLinearSimPOLoss(torch.nn.Module):
@@ -83,6 +95,7 @@ def __init__(
         ignore_index: int = -100,
         beta: float = 0.1,
         alpha: float = 1.0,
+        label_smoothing: float = 0.0,
         compute_nll_loss: bool = True,
         compiled: bool = True,
         gamma: float = 0.5,
@@ -96,6 +109,7 @@ def __init__(
         self.ignore_index = ignore_index
         self.beta = beta
         self.alpha = alpha
+        self.label_smoothing = label_smoothing
         self.compute_nll_loss = compute_nll_loss
         self.compiled = compiled
         self.gamma = gamma
@@ -109,6 +123,7 @@ def forward(self, lin_weight, _input, target, bias=None):
             self.ignore_index,
             self.beta,
             self.alpha,
+            self.label_smoothing,
             self.compute_nll_loss,
             self.compiled,
             self.gamma,
```
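The SimPO path gets identical smoothing; the only difference in the loss is the gamma margin subtracted from the logits. Same assumptions as the CPO sketch above:

```python
import torch
from liger_kernel.chunked_loss import LigerFusedLinearSimPOLoss

B, T, H, V = 4, 16, 64, 128
lin = torch.nn.Linear(H, V, bias=False)
loss_fn = LigerFusedLinearSimPOLoss(beta=0.1, gamma=0.5, label_smoothing=0.1)

x = torch.randn(B, T, H)
y = torch.randint(0, V, (B, T))
out = loss_fn(lin.weight, x, y)  # logits = beta * (logp_chosen - logp_rejected) - gamma
loss = out[0] if isinstance(out, tuple) else out
```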
### test/chunked_loss/test_cpo_loss.py (+23, -2)

```diff
@@ -86,6 +86,7 @@ def __init__(
         ignore_index: int = -100,
         beta: float = 0.1,
         alpha: float = 1.0,
+        label_smoothing: float = 0.0,
         loss_type: str = "sigmoid",
         simpo_gamma: float = 0.5,
     ):
@@ -97,6 +98,7 @@ def __init__(
             ignore_index=ignore_index,
             beta=beta,
             loss_type=loss_type,
+            label_smoothing=label_smoothing,
             simpo_gamma=simpo_gamma,
         ).get_batch_loss_metrics
 
@@ -114,13 +116,17 @@ def __init__(
         ignore_index: int = -100,
         beta: float = 0.1,
         alpha: float = 1.0,
+        label_smoothing: float = 0.0,
     ):
         super().__init__()
         self.lin = torch.nn.Linear(
             in_features=H, out_features=V, bias=bias, dtype=dtype
         )
         self.cpo_loss = LigerFusedLinearCPOLoss(
-            ignore_index=ignore_index, beta=beta, alpha=alpha
+            ignore_index=ignore_index,
+            beta=beta,
+            alpha=alpha,
+            label_smoothing=label_smoothing,
         )
 
     def forward(self, x, y):
@@ -145,8 +151,21 @@ def forward(self, x, y):
 @pytest.mark.parametrize(
     "ignore_index, beta, alpha", [(-100, 0.1, 1.0), (42, 0.2, 0.85)]
 )
+@pytest.mark.parametrize("label_smoothing", [0.0, 0.1])
 def test_correctness(
-    B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, beta, alpha
+    B,
+    T,
+    H,
+    V,
+    scalar,
+    dtype,
+    atol,
+    rtol,
+    bias,
+    ignore_index,
+    beta,
+    alpha,
+    label_smoothing,
 ):
     B = 2 * B  # cpo loss requires B to be even
 
@@ -157,6 +176,7 @@ def test_correctness(
         bias=bias,
         ignore_index=ignore_index,
         beta=beta,
+        label_smoothing=label_smoothing,
     )
     liger_lm_head_cpo = LigerLMHeadCPO(
         H=H,
@@ -165,6 +185,7 @@ def test_correctness(
         bias=bias,
         ignore_index=ignore_index,
         beta=beta,
+        label_smoothing=label_smoothing,
    )
 
     torch_lm_head_cpo.lin.weight.data = liger_lm_head_cpo.lin.weight.data = torch.randn(
```
### test/chunked_loss/test_simpo_loss.py (+22, -2)

```diff
@@ -25,14 +25,19 @@ def __init__(
         ignore_index: int = -100,
         beta: float = 0.1,
         alpha: float = 1.0,
+        label_smoothing: float = 0.0,
         gamma: float = 0.5,
     ):
         super().__init__()
         self.lin = torch.nn.Linear(
             in_features=H, out_features=V, bias=bias, dtype=dtype
         )
         self.simpo_loss = LigerFusedLinearSimPOLoss(
-            ignore_index=ignore_index, beta=beta, alpha=alpha, gamma=gamma
+            ignore_index=ignore_index,
+            beta=beta,
+            alpha=alpha,
+            gamma=gamma,
+            label_smoothing=label_smoothing,
         )
 
     def forward(self, x, y):
@@ -57,8 +62,21 @@ def forward(self, x, y):
 @pytest.mark.parametrize(
     "ignore_index, beta, gamma", [(-100, 0.1, 0.5), (42, 0.2, 0.85)]
 )
+@pytest.mark.parametrize("label_smoothing", [0.0, 0.1])
 def test_correctness(
-    B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, beta, gamma
+    B,
+    T,
+    H,
+    V,
+    scalar,
+    dtype,
+    atol,
+    rtol,
+    bias,
+    ignore_index,
+    beta,
+    gamma,
+    label_smoothing,
 ):
     B = 2 * B  # SimPO loss requires B to be even
 
@@ -70,6 +88,7 @@ def test_correctness(
         ignore_index=ignore_index,
         beta=beta,
         loss_type="simpo",
+        label_smoothing=label_smoothing,
         simpo_gamma=gamma,
     )
     liger_lm_head_simpo = LigerLMHeadSimPO(
@@ -79,6 +98,7 @@ def test_correctness(
         bias=bias,
         ignore_index=ignore_index,
         beta=beta,
+        label_smoothing=label_smoothing,
         gamma=gamma,
     )
```
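Both tests parametrize `label_smoothing` over `[0.0, 0.1]`; the `0.0` case should reproduce the pre-PR loss exactly, since the smoothed expression collapses to the original one. A quick standalone identity check (illustrative, not part of this PR):

```python
import torch
import torch.nn.functional as F

logits = torch.randn(8)
eps = 0.0
smoothed = F.logsigmoid(logits) * (1 - eps) + F.logsigmoid(-logits) * eps
assert torch.allclose(smoothed, F.logsigmoid(logits))  # reduces to the unsmoothed loss
```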
