Commit b6253b0

Merge branch 'main' into tcc/weight-ce

Tcc0403 authored Dec 22, 2024
2 parents 5535a60 + 15a2f58
Showing 12 changed files with 161 additions and 34 deletions.
18 changes: 15 additions & 3 deletions src/liger_kernel/chunked_loss/cpo_loss.py
@@ -9,7 +9,9 @@
class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):

@staticmethod
def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1):
def preference_loss_fn(
chosen_logps, rejected_logps, full_target, beta=0.1, label_smoothing=0.0
):
"""
Paper: https://arxiv.org/pdf/2401.08417
@@ -30,9 +32,14 @@ def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1):
rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
full_target (torch.Tensor): Non chunked full target tensor
beta (float): Weight for the CPO loss
label_smoothing (float): Label smoothing factor, will reduce to Equation above when label_smoothing -> 0.
"""
logits = beta * (chosen_logps - rejected_logps)
loss = F.logsigmoid(logits).sum() / (full_target.shape[0] // 2)
loss = (
- F.logsigmoid(logits) * (1 - label_smoothing)
- F.logsigmoid(-logits) * label_smoothing
).sum() / (full_target.shape[0] // 2)

return loss

@staticmethod
@@ -45,6 +52,7 @@ def forward(
ignore_index=-100,
beta=0.1,
alpha=1.0,
label_smoothing=0.0,
compute_nll_loss=True,
compiled=True,
):
@@ -58,14 +66,15 @@
ignore_index=ignore_index,
alpha=alpha,
beta=beta,
label_smoothing=label_smoothing,
compute_nll_loss=compute_nll_loss,
compiled=compiled,
)

@staticmethod
def backward(ctx, *grad_output):
grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
return *grads, None, None, None, None, None
return *grads, None, None, None, None, None, None


class LigerFusedLinearCPOLoss(torch.nn.Module):
@@ -78,6 +87,7 @@ def __init__(
ignore_index: int = -100,
beta: float = 0.1,
alpha: float = 1.0,
label_smoothing: float = 0.0,
compute_nll_loss: bool = True,
compiled: bool = True,
):
@@ -90,6 +100,7 @@ def __init__(
self.ignore_index = ignore_index
self.beta = beta
self.alpha = alpha
self.label_smoothing = label_smoothing
self.compute_nll_loss = compute_nll_loss
self.compiled = compiled

@@ -102,6 +113,7 @@ def forward(self, lin_weight, _input, target, bias=None):
self.ignore_index,
self.beta,
self.alpha,
self.label_smoothing,
self.compute_nll_loss,
self.compiled,
)
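
For reference, a standalone sketch of the label-smoothed CPO preference loss this hunk introduces; the function mirrors the new preference_loss_fn, while the toy tensors and sizes below are illustrative assumptions, not values from the commit:

```python
import torch
import torch.nn.functional as F


def label_smoothed_cpo_loss(
    chosen_logps, rejected_logps, full_target, beta=0.1, label_smoothing=0.0
):
    # Reward margin between chosen and rejected completions.
    logits = beta * (chosen_logps - rejected_logps)
    # With label_smoothing == 0 this reduces to -logsigmoid(logits),
    # the standard sigmoid preference loss from the CPO paper.
    return (
        -F.logsigmoid(logits) * (1 - label_smoothing)
        - F.logsigmoid(-logits) * label_smoothing
    ).sum() / (full_target.shape[0] // 2)


# Toy usage; shapes and values are illustrative only.
chosen = torch.tensor([-0.6, -1.1])    # avg log-probs of chosen responses
rejected = torch.tensor([-1.3, -1.8])  # avg log-probs of rejected responses
full_target = torch.zeros(4, 8)        # chosen + rejected halves, batch dim = 4
print(label_smoothed_cpo_loss(chosen, rejected, full_target, label_smoothing=0.1))
```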
4 changes: 2 additions & 2 deletions src/liger_kernel/chunked_loss/dpo_loss.py
@@ -64,7 +64,7 @@ def forward(
ref_bias=None,
ignore_index=-100,
beta=0.1,
compute_nll_loss=True,
compute_nll_loss=False,
compiled=True,
use_ref_model=True,
):
@@ -100,7 +100,7 @@ def __init__(
self,
ignore_index: int = -100,
beta: float = 0.1,
compute_nll_loss: bool = True,
compute_nll_loss: bool = False,
compiled: bool = True,
use_ref_model: bool = False,
):
2 changes: 1 addition & 1 deletion src/liger_kernel/chunked_loss/fused_linear_preference.py
@@ -408,7 +408,7 @@ def _compute_loss(
else:
preference_loss, aux_outputs = preference_loss_outputs, []

loss = alpha * chosen_nll_loss - preference_loss
loss = alpha * chosen_nll_loss + preference_loss
return_vars = (
chosen_logps,
rejected_logps,
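
This sign flip pairs with the negation of each preference_loss_fn elsewhere in the commit: the per-pair term is now a true loss to be minimized, so it is added to the NLL term instead of subtracted. A toy check of the equivalence (all tensors below are illustrative assumptions):

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([0.3, -0.2, 1.1])  # illustrative reward margins
nll = torch.tensor(2.0)                  # illustrative chosen NLL term
alpha = 1.0

# Old convention: preference_loss_fn returned summed log-sigmoids (to maximize),
# so _compute_loss subtracted them from the NLL term.
old_total = alpha * nll - F.logsigmoid(logits).sum()

# New convention: preference_loss_fn returns negated log-sigmoids (a loss),
# so _compute_loss adds them. Both conventions yield the same total.
new_total = alpha * nll + (-F.logsigmoid(logits)).sum()

assert torch.allclose(old_total, new_total)
```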
2 changes: 1 addition & 1 deletion src/liger_kernel/chunked_loss/orpo_loss.py
@@ -36,7 +36,7 @@ def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1):
- torch.log1p(-torch.exp(rejected_logps))
)
ratio = F.logsigmoid(log_odds)
loss = beta * ratio.sum() / (full_target.shape[0] // 2)
loss = -beta * ratio.sum() / (full_target.shape[0] // 2)

chosen_rewards = beta * chosen_logps
rejected_rewards = beta * rejected_logps
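
A standalone sketch of the ORPO odds-ratio term with the corrected sign, following the hunk above; the toy log-probabilities are illustrative assumptions:

```python
import torch
import torch.nn.functional as F


def orpo_odds_ratio_loss(chosen_logps, rejected_logps, full_target, beta=0.1):
    # Log-odds difference: log(p_c / (1 - p_c)) - log(p_r / (1 - p_r)),
    # computed from average log-probabilities (which must be negative).
    log_odds = (chosen_logps - rejected_logps) - (
        torch.log1p(-torch.exp(chosen_logps))
        - torch.log1p(-torch.exp(rejected_logps))
    )
    ratio = F.logsigmoid(log_odds)
    # Negated so the returned value is a loss to minimize, matching the
    # addition now performed in _compute_loss.
    return -beta * ratio.sum() / (full_target.shape[0] // 2)


chosen = torch.tensor([-0.4, -0.7])
rejected = torch.tensor([-1.2, -1.5])
full_target = torch.zeros(4, 8)
print(orpo_odds_ratio_loss(chosen, rejected, full_target))
```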
21 changes: 18 additions & 3 deletions src/liger_kernel/chunked_loss/simpo_loss.py
@@ -10,7 +10,12 @@ class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):

@staticmethod
def preference_loss_fn(
chosen_logps, rejected_logps, full_target, beta=0.1, gamma=0.5
chosen_logps,
rejected_logps,
full_target,
beta=0.1,
gamma=0.5,
label_smoothing=0.0,
):
"""
Paper: https://arxiv.org/pdf/2405.14734
@@ -33,9 +38,14 @@ def preference_loss_fn(
full_target: Non chunked full target tensor
beta (float): beta weight
gamma (float): gemma margin term
label_smoothing (float): Label smoothing factor, will reduce to Equation above when label_smoothing -> 0.
"""
logits = beta * (chosen_logps - rejected_logps) - gamma
loss = F.logsigmoid(logits).sum() / (full_target.shape[0] // 2)
loss = (
- F.logsigmoid(logits) * (1 - label_smoothing)
- F.logsigmoid(-logits) * label_smoothing
).sum() / (full_target.shape[0] // 2)

return loss

@staticmethod
@@ -48,6 +58,7 @@ def forward(
ignore_index=-100,
beta=0.1,
alpha=1.0,
label_smoothing=0.0,
compute_nll_loss=False,
compiled=True,
gamma=0.5,
@@ -63,14 +74,15 @@
ignore_index=ignore_index,
alpha=alpha,
beta=beta,
label_smoothing=label_smoothing,
compiled=compiled,
gamma=gamma,
)

@staticmethod
def backward(ctx, *grad_output):
grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
return *grads, None, None, None, None, None, None
return *grads, None, None, None, None, None, None, None


class LigerFusedLinearSimPOLoss(torch.nn.Module):
@@ -83,6 +95,7 @@ def __init__(
ignore_index: int = -100,
beta: float = 0.1,
alpha: float = 1.0,
label_smoothing: float = 0.0,
compute_nll_loss: bool = True,
compiled: bool = True,
gamma: float = 0.5,
@@ -96,6 +109,7 @@ def __init__(
self.ignore_index = ignore_index
self.beta = beta
self.alpha = alpha
self.label_smoothing = label_smoothing
self.compute_nll_loss = compute_nll_loss
self.compiled = compiled
self.gamma = gamma
@@ -109,6 +123,7 @@ def forward(self, lin_weight, _input, target, bias=None):
self.ignore_index,
self.beta,
self.alpha,
self.label_smoothing,
self.compute_nll_loss,
self.compiled,
self.gamma,
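
Likewise for SimPO, a standalone sketch of the new loss with the gamma margin and label smoothing; the toy values are illustrative assumptions:

```python
import torch
import torch.nn.functional as F


def label_smoothed_simpo_loss(
    chosen_logps, rejected_logps, full_target,
    beta=0.1, gamma=0.5, label_smoothing=0.0,
):
    # SimPO subtracts a target reward margin gamma before the sigmoid;
    # label smoothing then mixes in the flipped-label term, as in CPO above.
    logits = beta * (chosen_logps - rejected_logps) - gamma
    return (
        -F.logsigmoid(logits) * (1 - label_smoothing)
        - F.logsigmoid(-logits) * label_smoothing
    ).sum() / (full_target.shape[0] // 2)


chosen = torch.tensor([-0.8, -0.3])
rejected = torch.tensor([-1.4, -0.9])
full_target = torch.zeros(4, 8)
print(label_smoothed_simpo_loss(chosen, rejected, full_target, label_smoothing=0.1))
```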
2 changes: 1 addition & 1 deletion src/liger_kernel/transformers/model/mixtral.py
@@ -38,7 +38,7 @@ def lce_forward_deprecated(
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, MoeCausalLMOutputWithPast]:
r"""
Copy paste Mixtral's forward from transfomers v4.44.2 but replace torch cross entropy with liger fused linear cross entropy
Copy paste Mixtral's forward from transformers v4.44.2 but replace torch cross entropy with liger fused linear cross entropy
Args:
6 changes: 5 additions & 1 deletion src/liger_kernel/transformers/trainer/orpo_trainer.py
@@ -17,7 +17,7 @@ class _FSDPForwardRedirection:
This is needed in cases where we call a submodule of a FSDP module. For instance, when we want to call only
the `LlamaModel` part out of a FSDP-wrapped `LlamaForCausalLM` to get the hidden states without involving
GPU-memory-heavy `lm_head` and cross entropy computation, doing this directly (i.e. `model.model.forward()`)
will not work because the first `nn.Emebedding` layer is not independently wrapped as a FSDP module (because of
will not work because the first `nn.Embedding` layer is not independently wrapped as a FSDP module (because of
the transformer-based wrapping policy), and not calling it through FSDP root module forward will not all-gather
its parameter, thus resulting in "RuntimeError: 'weight' must be 2-D" error. Similarly, if we want to call just
the `lm_head` part of a model, we need this trick too to properly get its params all-gathered.
@@ -125,6 +125,10 @@ def orpo_partial(lm_head, last_hidden_state, concatenated_labels):
outputs.last_hidden_state,
concatenated_batch["concatenated_labels"],
)
# if aux_loss_enabled, add the aux_loss to the orpo_loss
if self.aux_loss_enabled:
orpo_loss += self.aux_loss_coef * outputs.aux_loss

return orpo_loss, aux_outputs

def get_batch_loss_metrics(
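
A minimal illustration of the new branch: when the model reports a router auxiliary loss, the trainer scales it by aux_loss_coef and folds it into the ORPO loss. The values below are stand-ins, not from the commit:

```python
import torch

# Stand-ins for values the trainer reads off the model outputs.
orpo_loss = torch.tensor(1.25)
aux_loss = torch.tensor(0.04)   # e.g. a MoE router load-balancing loss
aux_loss_enabled = True
aux_loss_coef = 0.02

if aux_loss_enabled:
    orpo_loss = orpo_loss + aux_loss_coef * aux_loss
```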
33 changes: 27 additions & 6 deletions test/chunked_loss/test_cpo_loss.py
@@ -60,14 +60,14 @@ def alignment_loss(
if self.loss_type == "sigmoid":
# This reduces to Equation 3 from the CPO paper when label_smoothing -> 0.
losses = (
F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
+ F.logsigmoid(-self.beta * logits) * self.label_smoothing
-F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
- F.logsigmoid(-self.beta * logits) * self.label_smoothing
)
elif self.loss_type == "simpo":
logits = logits - (self.simpo_gamma / self.beta)
losses = (
F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
+ F.logsigmoid(-self.beta * logits) * self.label_smoothing
-F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
- F.logsigmoid(-self.beta * logits) * self.label_smoothing
)
else:
raise ValueError(
@@ -86,6 +86,7 @@ def __init__(
ignore_index: int = -100,
beta: float = 0.1,
alpha: float = 1.0,
label_smoothing: float = 0.0,
loss_type: str = "sigmoid",
simpo_gamma: float = 0.5,
):
@@ -97,6 +98,7 @@ def __init__(
ignore_index=ignore_index,
beta=beta,
loss_type=loss_type,
label_smoothing=label_smoothing,
simpo_gamma=simpo_gamma,
).get_batch_loss_metrics

@@ -114,13 +116,17 @@ def __init__(
ignore_index: int = -100,
beta: float = 0.1,
alpha: float = 1.0,
label_smoothing: float = 0.0,
):
super().__init__()
self.lin = torch.nn.Linear(
in_features=H, out_features=V, bias=bias, dtype=dtype
)
self.cpo_loss = LigerFusedLinearCPOLoss(
ignore_index=ignore_index, beta=beta, alpha=alpha
ignore_index=ignore_index,
beta=beta,
alpha=alpha,
label_smoothing=label_smoothing,
)

def forward(self, x, y):
@@ -145,8 +151,21 @@ def forward(self, x, y):
@pytest.mark.parametrize(
"ignore_index, beta, alpha", [(-100, 0.1, 1.0), (42, 0.2, 0.85)]
)
@pytest.mark.parametrize("label_smoothing", [0.0, 0.1])
def test_correctness(
B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, beta, alpha
B,
T,
H,
V,
scalar,
dtype,
atol,
rtol,
bias,
ignore_index,
beta,
alpha,
label_smoothing,
):
B = 2 * B # cpo loss requires B to be even

@@ -157,6 +176,7 @@ def test_correctness(
bias=bias,
ignore_index=ignore_index,
beta=beta,
label_smoothing=label_smoothing,
)
liger_lm_head_cpo = LigerLMHeadCPO(
H=H,
@@ -165,6 +185,7 @@
bias=bias,
ignore_index=ignore_index,
beta=beta,
label_smoothing=label_smoothing,
)

torch_lm_head_cpo.lin.weight.data = liger_lm_head_cpo.lin.weight.data = torch.randn(