
Merge branch 'main' into add-softcapping-to-preference-based
ryankert01 authored Dec 21, 2024
2 parents cb215b3 + 15a2f58 commit 6f460ec
Showing 15 changed files with 179 additions and 36 deletions.
22 changes: 19 additions & 3 deletions README.md
@@ -59,8 +59,9 @@

<details>
<summary>Latest News 🔥</summary>

- [2024/12/15] We release LinkedIn Engineering Blog - [Liger-Kernel: Empowering an open source ecosystem of Triton Kernels for Efficient LLM Training](https://www.linkedin.com/blog/engineering/open-source/liger-kernel-open-source-ecosystem-for-efficient-llm-training)

- [2024/12/11] We release [v0.5.0](https://github.com/linkedin/Liger-Kernel/releases/tag/v0.5.0): 80% more memory efficient post training losses (DPO, ORPO, CPO, etc)!
- [2024/12/5] We release LinkedIn Engineering Blog - [Liger-Kernel: Empowering an open source ecosystem of Triton Kernels for Efficient LLM Training](https://www.linkedin.com/blog/engineering/open-source/liger-kernel-open-source-ecosystem-for-efficient-llm-training)
- [2024/11/6] We release [v0.4.0](https://github.com/linkedin/Liger-Kernel/releases/tag/v0.4.0): Full AMD support, Tech Report, Modal CI, Llama-3.2-Vision!
- [2024/10/21] We have released the tech report of Liger Kernel on Arxiv: https://arxiv.org/pdf/2410.10989
- [2024/9/6] We release v0.2.1 ([X post](https://x.com/liger_kernel/status/1832168197002510649)). 2500+ Stars, 10+ New Contributors, 50+ PRs, 50k Downloads in two weeks!
@@ -72,7 +73,7 @@

**Liger Kernel** is a collection of Triton kernels designed specifically for LLM training. It can effectively increase multi-GPU **training throughput by 20%** and reduce **memory usage by 60%**. We have implemented **Hugging Face Compatible** `RMSNorm`, `RoPE`, `SwiGLU`, `CrossEntropy`, `FusedLinearCrossEntropy`, with more to come. The kernels work out of the box with [Flash Attention](https://github.com/Dao-AILab/flash-attention), [PyTorch FSDP](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html), and [Microsoft DeepSpeed](https://github.com/microsoft/DeepSpeed). We welcome contributions from the community to gather the best kernels for LLM training.

We've also added optimized Post-Training kernels that deliver **up to 80% memory savings** for alignment and distillation tasks. We support losses like DPO, CPO, ORPO, SimPO, JSD, and many more.
We've also added optimized Post-Training kernels that deliver **up to 80% memory savings** for alignment and distillation tasks. We support losses like DPO, CPO, ORPO, SimPO, JSD, and many more. Check out [how we optimize the memory](https://x.com/hsu_byron/status/1866577403918917655).

## Supercharge Your Model with Liger Kernel

@@ -89,6 +90,21 @@ With one line of code, Liger Kernel can increase throughput by more than 20% and
> - Benchmark conditions: LLaMA 3-8B, Batch Size = 8, Data Type = `bf16`, Optimizer = AdamW, Gradient Checkpointing = True, Distributed Strategy = FSDP1 on 8 A100s.
> - Hugging Face models start to OOM at a 4K context length, whereas Hugging Face + Liger Kernel scales up to 16K.
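
As a hedged illustration of that one line (the checkpoint name is only an example; `apply_liger_kernel_to_llama` lives in `liger_kernel.transformers` and is typically called before the model is instantiated):

```python
# Minimal sketch: patch Hugging Face Llama modeling code with Liger kernels,
# then load the model as usual.
import transformers
from liger_kernel.transformers import apply_liger_kernel_to_llama

apply_liger_kernel_to_llama()  # swaps in RMSNorm, RoPE, SwiGLU, fused cross entropy, ...

model = transformers.AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B")
```
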
## Optimize Post Training with Liger Kernel

<p align="center">
<img src="https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/docs/images/post-training.png" width="50%" alt="Post Training">
</p>

We provide optimized post-training kernels such as DPO, ORPO, SimPO, and more, which can reduce memory usage by up to 80%. You can use them directly as Python modules.

```python
from liger_kernel.chunked_loss import LigerFusedLinearORPOLoss
orpo_loss = LigerFusedLinearORPOLoss()
y = orpo_loss(lm_head.weight, x, target)
```
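
A slightly fuller, hedged usage sketch with made-up shapes (assumption: the chunked preference losses expect chosen and rejected sequences concatenated along the batch dimension, so the batch size is even; run on a GPU in practice):

```python
import torch
from liger_kernel.chunked_loss import LigerFusedLinearORPOLoss

B, T, H, V = 4, 128, 2048, 32000             # even B: first half chosen, second half rejected
lm_head = torch.nn.Linear(H, V, bias=False)

x = torch.randn(B, T, H)                     # last hidden states from the model body
target = torch.randint(0, V, (B, T))         # token labels; ignore_index defaults to -100

orpo_loss = LigerFusedLinearORPOLoss(beta=0.1)
out = orpo_loss(lm_head.weight, x, target)   # fused lm_head projection + ORPO loss
```

Because the projection onto the vocabulary is fused into the loss and computed chunk by chunk, the full `(B, T, V)` logits tensor is never materialized, which is where the memory savings come from.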


## Examples

| **Use Case** | **Description** |
Expand Down
2 changes: 1 addition & 1 deletion dev/modal/tests_bwd.py
@@ -8,7 +8,7 @@

image = modal.Image.debian_slim(python_version=PYTHON_VERSION).pip_install("uv")

app = modal.App("liger_tests", image=image)
app = modal.App("liger_tests_bwd", image=image)

# mount: add local files to the remote container
repo = modal.Mount.from_local_dir(ROOT_PATH, remote_path=REMOTE_ROOT_PATH)
Binary file added docs/images/post-training.png
16 changes: 14 additions & 2 deletions src/liger_kernel/chunked_loss/cpo_loss.py
@@ -11,7 +11,9 @@
class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):

@staticmethod
def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1):
def preference_loss_fn(
chosen_logps, rejected_logps, full_target, beta=0.1, label_smoothing=0.0
):
"""
Paper: https://arxiv.org/pdf/2401.08417
@@ -32,9 +34,14 @@ def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1):
rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
full_target (torch.Tensor): Non chunked full target tensor
beta (float): Weight for the CPO loss
label_smoothing (float): Label smoothing factor, will reduce to Equation above when label_smoothing -> 0.
"""
logits = beta * (chosen_logps - rejected_logps)
loss = F.logsigmoid(logits).sum() / (full_target.shape[0] // 2)
loss = (
- F.logsigmoid(logits) * (1 - label_smoothing)
- F.logsigmoid(-logits) * label_smoothing
).sum() / (full_target.shape[0] // 2)

return loss
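
For reference, a hedged restatement of the label-smoothed objective implemented above, where ε is `label_smoothing`, N = `full_target.shape[0] // 2` is the number of chosen/rejected pairs, and Δᵢ is the difference of the average per-token log-probabilities of the chosen and rejected responses (with ε = 0 it reduces to the plain −log σ form; since the value is a positive loss, the base class now adds it to the NLL term):

$$
\mathcal{L}_{\mathrm{CPO}} = -\frac{1}{N}\sum_{i=1}^{N}\Big[(1-\varepsilon)\,\log\sigma\big(\beta\,\Delta_i\big) + \varepsilon\,\log\sigma\big(-\beta\,\Delta_i\big)\Big],
\qquad \Delta_i = \log p_\theta(y_i^{w}\mid x_i) - \log p_\theta(y_i^{l}\mid x_i)
$$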

@staticmethod
@@ -47,6 +54,7 @@ def forward(
ignore_index=-100,
beta=0.1,
alpha=1.0,
label_smoothing=0.0,
compute_nll_loss=True,
compiled=True,
softcap=None,
@@ -61,6 +69,7 @@
ignore_index=ignore_index,
alpha=alpha,
beta=beta,
label_smoothing=label_smoothing,
compute_nll_loss=compute_nll_loss,
compiled=compiled,
softcap=softcap,
@@ -82,6 +91,7 @@ def __init__(
ignore_index: int = -100,
beta: float = 0.1,
alpha: float = 1.0,
label_smoothing: float = 0.0,
compute_nll_loss: bool = True,
compiled: bool = True,
softcap: Optional[float] = None,
@@ -96,6 +106,7 @@ def __init__(
self.ignore_index = ignore_index
self.beta = beta
self.alpha = alpha
self.label_smoothing = label_smoothing
self.compute_nll_loss = compute_nll_loss
self.compiled = compiled
self.softcap = softcap
@@ -109,6 +120,7 @@ def forward(self, lin_weight, _input, target, bias=None):
self.ignore_index,
self.beta,
self.alpha,
self.label_smoothing,
self.compute_nll_loss,
self.compiled,
self.softcap,
4 changes: 2 additions & 2 deletions src/liger_kernel/chunked_loss/dpo_loss.py
@@ -66,7 +66,7 @@ def forward(
ref_bias=None,
ignore_index=-100,
beta=0.1,
compute_nll_loss=True,
compute_nll_loss=False,
compiled=True,
use_ref_model=True,
softcap=None,
@@ -104,7 +104,7 @@ def __init__(
self,
ignore_index: int = -100,
beta: float = 0.1,
compute_nll_loss: bool = True,
compute_nll_loss: bool = False,
compiled: bool = True,
use_ref_model: bool = False,
softcap: Optional[float] = None,
2 changes: 1 addition & 1 deletion src/liger_kernel/chunked_loss/fused_linear_preference.py
@@ -419,7 +419,7 @@ def _compute_loss(
else:
preference_loss, aux_outputs = preference_loss_outputs, []

loss = alpha * chosen_nll_loss - preference_loss
loss = alpha * chosen_nll_loss + preference_loss
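# (added because preference_loss_fn now returns a positive, negative-log-sigmoid loss,
#  so the preference term is summed with the weighted chosen NLL instead of subtracted)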
return_vars = (
chosen_logps,
rejected_logps,
2 changes: 1 addition & 1 deletion src/liger_kernel/chunked_loss/orpo_loss.py
@@ -38,7 +38,7 @@ def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1):
- torch.log1p(-torch.exp(rejected_logps))
)
ratio = F.logsigmoid(log_odds)
loss = beta * ratio.sum() / (full_target.shape[0] // 2)
loss = -beta * ratio.sum() / (full_target.shape[0] // 2)

chosen_rewards = beta * chosen_logps
rejected_rewards = beta * rejected_logps
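
For context, a hedged restatement of the odds-ratio term whose sign this hunk corrects, following the standard ORPO formulation (`log_odds` above is the log of the odds ratio, computed from the average per-token log-probabilities via `log1p`; N is the number of chosen/rejected pairs):

$$
\mathcal{L}_{\mathrm{OR}} = -\frac{\beta}{N}\sum_{i=1}^{N}\log\sigma\!\left(\log\frac{\mathrm{odds}_\theta(y_i^{w}\mid x_i)}{\mathrm{odds}_\theta(y_i^{l}\mid x_i)}\right),
\qquad \mathrm{odds}_\theta(y\mid x)=\frac{p_\theta(y\mid x)}{1-p_\theta(y\mid x)}
$$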
19 changes: 17 additions & 2 deletions src/liger_kernel/chunked_loss/simpo_loss.py
@@ -12,7 +12,12 @@ class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):

@staticmethod
def preference_loss_fn(
chosen_logps, rejected_logps, full_target, beta=0.1, gamma=0.5
chosen_logps,
rejected_logps,
full_target,
beta=0.1,
gamma=0.5,
label_smoothing=0.0,
):
"""
Paper: https://arxiv.org/pdf/2405.14734
@@ -35,9 +40,14 @@ def preference_loss_fn(
full_target: Non chunked full target tensor
beta (float): beta weight
gamma (float): gamma margin term
label_smoothing (float): Label smoothing factor, will reduce to Equation above when label_smoothing -> 0.
"""
logits = beta * (chosen_logps - rejected_logps) - gamma
loss = F.logsigmoid(logits).sum() / (full_target.shape[0] // 2)
loss = (
- F.logsigmoid(logits) * (1 - label_smoothing)
- F.logsigmoid(-logits) * label_smoothing
).sum() / (full_target.shape[0] // 2)

return loss
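
As with CPO above, a hedged restatement of the SimPO term implemented here; the only difference is the reward margin γ inside the sigmoid (ε is `label_smoothing`, N the number of pairs, Δᵢ the difference of the length-normalized average log-probabilities):

$$
\mathcal{L}_{\mathrm{SimPO}} = -\frac{1}{N}\sum_{i=1}^{N}\Big[(1-\varepsilon)\,\log\sigma\big(\beta\,\Delta_i-\gamma\big) + \varepsilon\,\log\sigma\big(\gamma-\beta\,\Delta_i\big)\Big]
$$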

@staticmethod
Expand All @@ -50,6 +60,7 @@ def forward(
ignore_index=-100,
beta=0.1,
alpha=1.0,
label_smoothing=0.0,
compute_nll_loss=False,
compiled=True,
gamma=0.5,
@@ -66,6 +77,7 @@
ignore_index=ignore_index,
alpha=alpha,
beta=beta,
label_smoothing=label_smoothing,
compiled=compiled,
gamma=gamma,
softcap=softcap,
@@ -87,6 +99,7 @@ def __init__(
ignore_index: int = -100,
beta: float = 0.1,
alpha: float = 1.0,
label_smoothing: float = 0.0,
compute_nll_loss: bool = True,
compiled: bool = True,
gamma: float = 0.5,
@@ -102,6 +115,7 @@ def __init__(
self.ignore_index = ignore_index
self.beta = beta
self.alpha = alpha
self.label_smoothing = label_smoothing
self.compute_nll_loss = compute_nll_loss
self.compiled = compiled
self.gamma = gamma
@@ -116,6 +130,7 @@ def forward(self, lin_weight, _input, target, bias=None):
self.ignore_index,
self.beta,
self.alpha,
self.label_smoothing,
self.compute_nll_loss,
self.compiled,
self.gamma,
2 changes: 1 addition & 1 deletion src/liger_kernel/transformers/model/mixtral.py
@@ -38,7 +38,7 @@ def lce_forward_deprecated(
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, MoeCausalLMOutputWithPast]:
r"""
Copy paste Mixtral's forward from transfomers v4.44.2 but replace torch cross entropy with liger fused linear cross entropy
Copy paste Mixtral's forward from transformers v4.44.2 but replace torch cross entropy with liger fused linear cross entropy
Args:
6 changes: 5 additions & 1 deletion src/liger_kernel/transformers/trainer/orpo_trainer.py
@@ -17,7 +17,7 @@ class _FSDPForwardRedirection:
This is needed in cases where we call a submodule of a FSDP module. For instance, when we want to call only
the `LlamaModel` part out of a FSDP-wrapped `LlamaForCausalLM` to get the hidden states without involving
GPU-memory-heavy `lm_head` and cross entropy computation, doing this directly (i.e. `model.model.forward()`)
will not work because the first `nn.Emebedding` layer is not independently wrapped as a FSDP module (because of
will not work because the first `nn.Embedding` layer is not independently wrapped as a FSDP module (because of
the transformer-based wrapping policy), and not calling it through FSDP root module forward will not all-gather
its parameter, thus resulting in "RuntimeError: 'weight' must be 2-D" error. Similarly, if we want to call just
the `lm_head` part of a model, we need this trick too to properly get its params all-gathered.
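
A minimal, hedged sketch of the redirection pattern this docstring describes (names are illustrative, not the exact Liger implementation): temporarily swap the FSDP root module's `forward` so that calling the root still fires FSDP's pre/post-forward hooks (parameter all-gather and reshard) while actually executing the submodule method we care about.

```python
import torch.nn as nn


class _ForwardRedirectionSketch:
    """Run `method` (e.g. `model.model.forward` or `lm_head.forward`) through the
    FSDP root's __call__ so its parameters are all-gathered first."""

    def __call__(self, root_module: nn.Module, method, *args, **kwargs):
        original_forward = root_module.forward

        def wrapped_forward(*fargs, **fkwargs):
            root_module.forward = original_forward   # restore the real forward
            return method(*fargs, **fkwargs)         # run the target submodule call

        root_module.forward = wrapped_forward
        # Calling the module (not .forward directly) is what triggers FSDP's hooks.
        return root_module(*args, **kwargs)
```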
@@ -125,6 +125,10 @@ def orpo_partial(lm_head, last_hidden_state, concatenated_labels):
outputs.last_hidden_state,
concatenated_batch["concatenated_labels"],
)
# if aux_loss_enabled, add the aux_loss to the orpo_loss
if self.aux_loss_enabled:
orpo_loss += self.aux_loss_coef * outputs.aux_loss

return orpo_loss, aux_outputs

def get_batch_loss_metrics(
33 changes: 27 additions & 6 deletions test/chunked_loss/test_cpo_loss.py
@@ -60,14 +60,14 @@ def alignment_loss(
if self.loss_type == "sigmoid":
# This reduces to Equation 3 from the CPO paper when label_smoothing -> 0.
losses = (
F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
+ F.logsigmoid(-self.beta * logits) * self.label_smoothing
-F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
- F.logsigmoid(-self.beta * logits) * self.label_smoothing
)
elif self.loss_type == "simpo":
logits = logits - (self.simpo_gamma / self.beta)
losses = (
F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
+ F.logsigmoid(-self.beta * logits) * self.label_smoothing
-F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
- F.logsigmoid(-self.beta * logits) * self.label_smoothing
)
else:
raise ValueError(
@@ -86,6 +86,7 @@ def __init__(
ignore_index: int = -100,
beta: float = 0.1,
alpha: float = 1.0,
label_smoothing: float = 0.0,
loss_type: str = "sigmoid",
simpo_gamma: float = 0.5,
):
@@ -97,6 +98,7 @@
ignore_index=ignore_index,
beta=beta,
loss_type=loss_type,
label_smoothing=label_smoothing,
simpo_gamma=simpo_gamma,
).get_batch_loss_metrics

@@ -114,13 +116,17 @@ def __init__(
ignore_index: int = -100,
beta: float = 0.1,
alpha: float = 1.0,
label_smoothing: float = 0.0,
):
super().__init__()
self.lin = torch.nn.Linear(
in_features=H, out_features=V, bias=bias, dtype=dtype
)
self.cpo_loss = LigerFusedLinearCPOLoss(
ignore_index=ignore_index, beta=beta, alpha=alpha
ignore_index=ignore_index,
beta=beta,
alpha=alpha,
label_smoothing=label_smoothing,
)

def forward(self, x, y):
@@ -145,8 +151,21 @@ def forward(self, x, y):
@pytest.mark.parametrize(
"ignore_index, beta, alpha", [(-100, 0.1, 1.0), (42, 0.2, 0.85)]
)
@pytest.mark.parametrize("label_smoothing", [0.0, 0.1])
def test_correctness(
B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, beta, alpha
B,
T,
H,
V,
scalar,
dtype,
atol,
rtol,
bias,
ignore_index,
beta,
alpha,
label_smoothing,
):
B = 2 * B # cpo loss requires B to be even

@@ -157,6 +176,7 @@ def test_correctness(
bias=bias,
ignore_index=ignore_index,
beta=beta,
label_smoothing=label_smoothing,
)
liger_lm_head_cpo = LigerLMHeadCPO(
H=H,
@@ -165,6 +185,7 @@
bias=bias,
ignore_index=ignore_index,
beta=beta,
label_smoothing=label_smoothing,
)

torch_lm_head_cpo.lin.weight.data = liger_lm_head_cpo.lin.weight.data = torch.randn(