[WIP] Add softcapping to preference based fused linear #437

Draft: wants to merge 15 commits into base: main
Changes from 2 commits
11 changes: 11 additions & 0 deletions src/liger_kernel/chunked_loss/fused_linear_preference.py
@@ -26,11 +26,16 @@ def chunk_forward(
bias=None,
ignore_index=-100,
compute_nll_loss=True,
softcap=None,
):
len_chosen_chunk = target_chunk.shape[0] // 2
logits_chunk = input_chunk @ weight.t()
if bias is not None:
logits_chunk = logits_chunk + bias
if softcap is not None:
logits_chunk = logits_chunk / softcap
logits_chunk = torch.tanh(logits_chunk)
logits_chunk = logits_chunk * softcap
log_probs_chunk = F.log_softmax(logits_chunk.float(), dim=-1)

chosen_nll_loss = 0.0
@@ -81,6 +86,7 @@ def forward(
use_ref_model=False,
ref_weight=None,
ref_bias=None,
softcap=None,
**loss_kwargs,
):
"""
@@ -103,6 +109,7 @@
use_ref_model (bool): Whether to use a reference model for the alignment loss.
ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
softcap (Optional[float]): The upper threshold for scaling logits to the range (-softcap, +softcap).
loss_kwargs (dict): Other possible arguments that a loss function might need
"""
# TODO: Tune CHUNK_SIZE to fully utilize the GPU
@@ -286,6 +293,7 @@ def _compute_loss(
use_ref_model=False,
ref_weight=None,
ref_bias=None,
softcap=None,
**loss_kwargs,
):
"""
@@ -304,6 +312,7 @@ def _compute_loss(
use_ref_model (bool): Whether to use a reference model for the alignment loss.
ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
softcap (Optional[float]): The upper threshold for scaling logits to the range (-softcap, +softcap).
loss_kwargs (dict): Additional arguments for the loss function.
"""
(
@@ -319,6 +328,7 @@
bias=bias,
ignore_index=ignore_index,
compute_nll_loss=compute_nll_loss,
softcap=softcap,
)
chosen_nll_loss = (
chosen_nll_loss
@@ -346,6 +356,7 @@ def _compute_loss(
ref_bias,
ignore_index=ignore_index,
compute_nll_loss=False, # We don't need NLL loss for the reference model
softcap=softcap,
)
loss_kwargs["ref_chosen_logps"] = ref_chosen_logps
loss_kwargs["ref_rejected_logps"] = ref_rejected_logps
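
For reference, the three steps added to chunk_forward implement the usual tanh logit softcap: scale the logits down by softcap, bound them with tanh, then scale back up, so the capped logits lie in (-softcap, +softcap). A minimal standalone sketch of the transform follows; the helper name softcap_logits is illustrative only and not part of this PR.

import torch

def softcap_logits(logits: torch.Tensor, softcap: float) -> torch.Tensor:
    # Smoothly squash logits into (-softcap, +softcap); equivalent to the
    # three-step chunked version added to chunk_forward above.
    return softcap * torch.tanh(logits / softcap)

# Example: extreme values are pulled inside the cap of +/-30.
logits = torch.tensor([-100.0, -1.0, 0.0, 1.0, 100.0])
print(softcap_logits(logits, softcap=30.0))
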
214 changes: 214 additions & 0 deletions test/transformers/test_fused_linear_preference.py
@@ -0,0 +1,214 @@
from test.utils import assert_verbose_allclose, set_seed
from typing import Optional

import pytest
import torch
import torch.nn.functional as F

from liger_kernel.chunked_loss.fused_linear_preference import (
LigerFusedLinearPreferenceBase,
)
from liger_kernel.utils import infer_device

device = infer_device()

# set random seed globally
set_seed()


class TorchLMHeadPreference(torch.nn.Module):
"""Ground truth implementation of the linear fused with torch based preference loss.

:param H: hidden size
:param V: vocab size
:param dtype: parameter dtype for the linear layer
:param bias: whether to use bias
:param ignore_index: target value ignored when averaging per-token log-probs
:param beta: weight for the preference loss
:param softcap: scalar cap for softcapping logits
"""

def __init__(
self,
H: int,
V: int,
dtype: torch.dtype,
bias: bool = False,
ignore_index: int = -100,
beta: float = 0.1,
softcap: Optional[float] = None,
):
super().__init__()
self.lin = torch.nn.Linear(
in_features=H, out_features=V, bias=bias, dtype=dtype
)
self.ignore_index = ignore_index
self.beta = beta
self.softcap = softcap

def forward(self, x, target):
logits = self.lin(x).to(torch.float32)
if self.softcap is not None and self.softcap != 0.0:
logits = self.softcap * torch.tanh(logits / self.softcap)

log_probs = F.log_softmax(logits, dim=-1)

len_chosen = target.shape[0] // 2
loss_mask = target != self.ignore_index
label = torch.where(loss_mask, target, 0)

per_token_logps = log_probs.gather(-1, label.unsqueeze(-1)).squeeze(-1)
average_log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)

chosen_logps = average_log_prob[:len_chosen]
rejected_logps = average_log_prob[len_chosen:]

# Simple preference loss
preference_loss = -self.beta * (chosen_logps - rejected_logps).mean()

return preference_loss


class LigerLMHeadPreference(torch.nn.Module):
def __init__(
self,
H: int,
V: int,
dtype: torch.dtype,
bias: bool = False,
ignore_index: int = -100,
beta: float = 0.1,
softcap: Optional[float] = None,
):
super().__init__()
self.lin = torch.nn.Linear(
in_features=H, out_features=V, bias=bias, dtype=dtype
)
self.ignore_index = ignore_index
self.beta = beta
self.softcap = softcap

def forward(self, x, target):
def simple_preference_loss(chosen_logps, rejected_logps, target, beta=0.1):
return -beta * (chosen_logps - rejected_logps).mean()

loss, *_ = LigerFusedLinearPreferenceBase.apply(
x,
self.lin.weight,
target,
self.lin.bias,
simple_preference_loss,
chunk_size=1,
ignore_index=self.ignore_index,
beta=self.beta,
compute_nll_loss=False,
compiled=True,
softcap=self.softcap,
)
return loss


@pytest.mark.parametrize(
"B, T, H, V",
[
(8, 128, 1024, 4096),
(4, 47, 31, 123), # random shape
],
)
@pytest.mark.parametrize(
"scalar, dtype, atol, rtol",
[
(1.0, torch.bfloat16, 5e-3, 5e-2),
(1.0, torch.float32, 1e-5, 5e-4),
],
)
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize(
"ignore_index, beta, softcap",
[
(-100, 0.1, None),
(42, 0.2, 30.0), # Pass non-default values to ensure all params work
],
)
def test_correctness(
B,
T,
H,
V,
scalar,
dtype,
bias,
ignore_index,
beta,
softcap,
atol,
rtol,
):
torch_lm_head = TorchLMHeadPreference(
H=H,
V=V,
bias=bias,
ignore_index=ignore_index,
beta=beta,
softcap=softcap,
dtype=dtype,
).to(device)

liger_lm_head = LigerLMHeadPreference(
H=H,
V=V,
bias=bias,
ignore_index=ignore_index,
beta=beta,
softcap=softcap,
dtype=dtype,
).to(device)

# init the linear layers with the same weights
torch_lm_head.lin.weight.data = liger_lm_head.lin.weight.data = torch.rand(
V, H, device=device, dtype=dtype
)

if bias:
torch_lm_head.lin.bias.data = liger_lm_head.lin.bias.data = torch.rand(
V, device=device, dtype=dtype
)

# Create input tensors
_tensor = torch.randn(B * T * 2, H, device=device, dtype=dtype) * scalar # *2 for chosen/rejected pairs
_input1 = _tensor.detach().clone().requires_grad_(True)
_input2 = _tensor.detach().clone().requires_grad_(True)

# Create target tensor
target = torch.randint(0, V, (B * T * 2,), device=device, dtype=torch.long)

# Assign some random elements as ignore_index
num_elements_to_assign = torch.randint(1, B * T, (1,)).item()
indices_to_assign = torch.randperm(B * T * 2)[:num_elements_to_assign]
target[indices_to_assign] = ignore_index

# Forward pass
output1 = torch_lm_head(_input1, target)
output2 = liger_lm_head(_input2, target)

# Check outputs match
assert_verbose_allclose(output1, output2, atol=atol, rtol=rtol)

# Backward pass
output1.backward()
output2.backward()

# Check gradients match
assert_verbose_allclose(_input1.grad, _input2.grad, atol=atol, rtol=rtol)
assert_verbose_allclose(
torch_lm_head.lin.weight.grad,
liger_lm_head.lin.weight.grad,
atol=atol,
rtol=rtol,
)

if bias:
assert_verbose_allclose(
torch_lm_head.lin.bias.grad,
liger_lm_head.lin.bias.grad,
atol=atol,
rtol=rtol,
)
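
The test compares the torch reference head against the fused Liger path for both softcap=None and softcap=30.0, checking loss values and input/weight/bias gradients. To run just this new test locally, an invocation along these lines should work, assuming pytest and a development install of liger-kernel are available (the exact command is not specified in this PR):

python -m pytest test/transformers/test_fused_linear_preference.py -v
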