Add nn.Module support for chunked loss functions #402

Merged · 1 commit · Nov 21, 2024
4 changes: 4 additions & 0 deletions src/liger_kernel/chunked_loss/__init__.py
@@ -0,0 +1,4 @@
from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOLoss # noqa: F401
from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOLoss # noqa: F401
from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOLoss # noqa: F401
from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOLoss # noqa: F401
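The new package `__init__.py` re-exports the four nn.Module wrappers, so they are importable directly from `liger_kernel.chunked_loss`. A minimal import sketch (the constructor call just uses the defaults defined in the classes below):

```python
from liger_kernel.chunked_loss import (
    LigerFusedLinearCPOLoss,
    LigerFusedLinearDPOLoss,
    LigerFusedLinearORPOLoss,
    LigerFusedLinearSimPOLoss,
)

# Each wrapper stores its hyperparameters at construction time and
# dispatches to the matching autograd.Function in forward().
orpo_loss = LigerFusedLinearORPOLoss()
```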
42 changes: 41 additions & 1 deletion src/liger_kernel/chunked_loss/cpo_loss.py
@@ -1,3 +1,4 @@
import torch
import torch.nn.functional as F

from liger_kernel.chunked_loss.fused_linear_preference import (
@@ -46,10 +47,10 @@ def forward(
target,
bias,
loss_fn=LigerFusedLinearCPOFunction.preference_loss_fn,
- compute_nll_loss=compute_nll_loss,
ignore_index=ignore_index,
alpha=alpha,
beta=beta,
+ compute_nll_loss=compute_nll_loss,
compiled=compiled,
)

@@ -59,3 +60,42 @@ def backward(ctx, grad_output):
grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
# Return these gradients, followed by None for the remaining inputs
return *grads, None, None, None, None, None


class LigerFusedLinearCPOLoss(torch.nn.Module):
    """
    Fused linear layer with CPO loss.
    """

    def __init__(
        self,
        ignore_index: int = -100,
        beta: float = 0.1,
        alpha: float = 1.0,
        compute_nll_loss: bool = True,
        compiled: bool = True,
    ):
        """
        Args:
            ignore_index (int): Index to ignore in the loss.
            beta (float): Weight for the preference loss term.
            alpha (float): Weight for the NLL loss term.
            compute_nll_loss (bool): Whether to compute the auxiliary NLL loss.
            compiled (bool): Whether to use torch.compile for the loss computation.
        """
        super().__init__()
        self.ignore_index = ignore_index
        self.beta = beta
        self.alpha = alpha
        self.compute_nll_loss = compute_nll_loss
        self.compiled = compiled

    def forward(self, lin_weight, _input, target, bias=None):
        return LigerFusedLinearCPOFunction.apply(
            _input,
            lin_weight,
            target,
            bias,
            self.ignore_index,
            self.beta,
            self.alpha,
            self.compute_nll_loss,
            self.compiled,
        )
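A minimal usage sketch for the new module, assuming the preference layout used by the chunked losses (chosen sequences stacked before rejected ones along the batch dimension); all shapes are hypothetical. Note that the wrapper takes `lin_weight` before `_input`:

```python
import torch
from liger_kernel.chunked_loss import LigerFusedLinearCPOLoss

B, T, H, V = 8, 16, 128, 1000  # 4 chosen + 4 rejected sequences
_input = torch.randn(B, T, H, requires_grad=True)  # last hidden states
weight = torch.randn(V, H, requires_grad=True)     # lm_head weight
target = torch.randint(0, V, (B, T))               # token ids

cpo_loss = LigerFusedLinearCPOLoss(ignore_index=-100, beta=0.1, alpha=1.0)
loss = cpo_loss(weight, _input, target)  # bias is optional
loss.backward()
```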
39 changes: 38 additions & 1 deletion src/liger_kernel/chunked_loss/dpo_loss.py
@@ -1,3 +1,4 @@
import torch
import torch.nn.functional as F

from liger_kernel.chunked_loss.fused_linear_preference import (
@@ -43,9 +44,9 @@ def forward(
target=target,
bias=bias,
loss_fn=LigerFusedLinearDPOFunction.preference_loss_fn,
- compute_nll_loss=compute_nll_loss,
ignore_index=ignore_index,
beta=beta,
+ compute_nll_loss=compute_nll_loss,
compiled=compiled,
)

@@ -55,3 +56,39 @@ def backward(ctx, grad_output):
grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
# Return these gradients, followed by None for the remaining inputs
return *grads, None, None, None, None


class LigerFusedLinearDPOLoss(torch.nn.Module):
    """
    Fused linear layer with DPO loss.
    """

    def __init__(
        self,
        ignore_index: int = -100,
        beta: float = 0.1,
        compute_nll_loss: bool = True,
        compiled: bool = True,
    ):
        """
        Args:
            ignore_index (int): Index to ignore in the loss.
            beta (float): Temperature parameter for the DPO loss.
            compute_nll_loss (bool): Whether to compute the auxiliary NLL loss.
            compiled (bool): Whether to use torch.compile for the loss computation.
        """
        super().__init__()
        self.ignore_index = ignore_index
        self.beta = beta
        self.compute_nll_loss = compute_nll_loss
        self.compiled = compiled

    def forward(self, lin_weight, _input, target, bias=None):
        return LigerFusedLinearDPOFunction.apply(
            _input,
            lin_weight,
            target,
            bias,
            self.ignore_index,
            self.beta,
            self.compute_nll_loss,
            self.compiled,
        )
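The module call is a thin stateful wrapper: it forwards its stored hyperparameters positionally to the underlying autograd Function. A sketch of the equivalence (tensors are placeholders with hypothetical shapes):

```python
import torch
from liger_kernel.chunked_loss import LigerFusedLinearDPOLoss
from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction

_input = torch.randn(8, 16, 128, requires_grad=True)
weight = torch.randn(1000, 128, requires_grad=True)
target = torch.randint(0, 1000, (8, 16))

dpo_loss = LigerFusedLinearDPOLoss(ignore_index=-100, beta=0.1)

# The module call and the raw Function.apply call are equivalent:
loss_module = dpo_loss(weight, _input, target)
loss_functional = LigerFusedLinearDPOFunction.apply(
    _input, weight, target, None,  # bias defaults to None
    -100, 0.1, True, True,  # ignore_index, beta, compute_nll_loss, compiled
)
```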
9 changes: 9 additions & 0 deletions src/liger_kernel/chunked_loss/functional.py
@@ -0,0 +1,9 @@
from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction
from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction
from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOFunction
from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOFunction

liger_fused_linear_orpo = LigerFusedLinearORPOFunction.apply
liger_fused_linear_dpo = LigerFusedLinearDPOFunction.apply
liger_fused_linear_cpo = LigerFusedLinearCPOFunction.apply
liger_fused_linear_simpo = LigerFusedLinearSimPOFunction.apply
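Because `torch.autograd.Function.apply` does not accept keyword arguments, these functional aliases must be called with every argument in positional order, matching the `forward` signatures above. A sketch for the CPO alias (tensors and values are placeholders):

```python
import torch
from liger_kernel.chunked_loss.functional import liger_fused_linear_cpo

_input = torch.randn(8, 16, 128, requires_grad=True)  # hidden states
weight = torch.randn(1000, 128, requires_grad=True)   # lm_head weight
target = torch.randint(0, 1000, (8, 16))              # token ids

loss = liger_fused_linear_cpo(
    _input,
    weight,
    target,
    None,   # bias
    -100,   # ignore_index
    0.1,    # beta
    1.0,    # alpha
    True,   # compute_nll_loss
    True,   # compiled
)
```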
2 changes: 1 addition & 1 deletion src/liger_kernel/chunked_loss/fused_linear_preference.py
@@ -27,10 +27,10 @@ def forward(
bias=None,
loss_fn=None,
chunk_size=1,
- compute_nll_loss=True,
ignore_index=-100,
alpha=1.0,
beta=0.1,
+ compute_nll_loss=True,
compiled=True,
**loss_kwargs,
):
40 changes: 38 additions & 2 deletions src/liger_kernel/chunked_loss/orpo_loss.py
@@ -34,7 +34,7 @@ def forward(
ignore_index=-100,
beta=0.1,
compute_nll_loss=True,
- compiled=False,
+ compiled=True,
):
"""
Fused linear layer with ORPO (Odds-Ratio Preference Optimization) loss.
@@ -49,9 +49,9 @@
target=target,
bias=bias,
loss_fn=LigerFusedLinearORPOFunction.preference_loss_fn,
- compute_nll_loss=compute_nll_loss,
ignore_index=ignore_index,
beta=beta,
+ compute_nll_loss=compute_nll_loss,
compiled=compiled,
)

@@ -61,3 +61,39 @@ def backward(ctx, grad_output):
grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
# Return these gradients, followed by None for the remaining inputs
return *grads, None, None, None, None


class LigerFusedLinearORPOLoss(torch.nn.Module):
    """
    Fused linear layer with ORPO (Odds-Ratio Preference Optimization) loss.
    """

    def __init__(
        self,
        ignore_index: int = -100,
        beta: float = 0.1,
        compute_nll_loss: bool = True,
        compiled: bool = True,
    ):
        """
        Args:
            ignore_index (int): Index to ignore in the loss.
            beta (float): Weight for the odds ratio loss.
            compute_nll_loss (bool): Whether to compute the auxiliary NLL loss.
            compiled (bool): Whether to use torch.compile for the loss computation.
        """
        super().__init__()
        self.ignore_index = ignore_index
        self.beta = beta
        self.compute_nll_loss = compute_nll_loss
        self.compiled = compiled

    def forward(self, lin_weight, _input, target, bias=None):
        return LigerFusedLinearORPOFunction.apply(
            _input,
            lin_weight,
            target,
            bias,
            self.ignore_index,
            self.beta,
            self.compute_nll_loss,
            self.compiled,
        )
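This file also flips the default of `compiled` in LigerFusedLinearORPOFunction.forward from False to True, so the chunked loss takes the compiled path unless a caller opts out. Opting out through the new module, sketched:

```python
from liger_kernel.chunked_loss import LigerFusedLinearORPOLoss

# Opt out of compilation, e.g. while debugging the chunked computation:
orpo_loss = LigerFusedLinearORPOLoss(compiled=False)
```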
43 changes: 43 additions & 0 deletions src/liger_kernel/chunked_loss/simpo_loss.py
@@ -1,3 +1,4 @@
import torch
import torch.nn.functional as F

from liger_kernel.chunked_loss.fused_linear_preference import (
@@ -62,3 +63,45 @@ def backward(ctx, grad_output):
grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
# Return these gradients, followed by None for the remaining inputs
return *grads, None, None, None, None, None, None


class LigerFusedLinearSimPOLoss(torch.nn.Module):
    """
    Fused linear layer with SimPO loss.
    """

    def __init__(
        self,
        ignore_index: int = -100,
        beta: float = 0.1,
        alpha: float = 1.0,
        compute_nll_loss: bool = True,
        compiled: bool = True,
        gamma: float = 0.5,
    ):
        """
        Args:
            ignore_index (int): Index to ignore in the loss.
            beta (float): Weight for the preference loss term.
            alpha (float): Weight for the NLL loss term.
            compute_nll_loss (bool): Whether to compute the auxiliary NLL loss.
            compiled (bool): Whether to use torch.compile for the loss computation.
            gamma (float): Target reward margin in the SimPO loss.
        """
        super().__init__()
        self.ignore_index = ignore_index
        self.beta = beta
        self.alpha = alpha
        self.compute_nll_loss = compute_nll_loss
        self.compiled = compiled
        self.gamma = gamma

    def forward(self, lin_weight, _input, target, bias=None):
        return LigerFusedLinearSimPOFunction.apply(
            _input,
            lin_weight,
            target,
            bias,
            self.ignore_index,
            self.beta,
            self.alpha,
            self.compute_nll_loss,
            self.compiled,
            self.gamma,
        )
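SimPO adds a `gamma` hyperparameter on top of the CPO-style arguments; in the SimPO formulation it is the target reward margin. A construction sketch with illustrative values (shapes follow the earlier examples):

```python
import torch
from liger_kernel.chunked_loss import LigerFusedLinearSimPOLoss

_input = torch.randn(8, 16, 128, requires_grad=True)
weight = torch.randn(1000, 128, requires_grad=True)
target = torch.randint(0, 1000, (8, 16))

# beta and gamma values here are illustrative, not recommendations.
simpo_loss = LigerFusedLinearSimPOLoss(beta=2.0, gamma=0.5)
loss = simpo_loss(weight, _input, target)
loss.backward()
```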