# Refactor LigerFusedLinearPreferenceBase (#381)
## Summary
This PR refactors the `LigerFusedLinearPreferenceBase` class to contain
an abstract method, `preference_loss_fn`, that computes the alignment
loss and must be implemented by all subclasses.
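
For illustration, a subclass only has to override `preference_loss_fn`. A minimal sketch of that contract (the `SimplePreferenceLoss` class and its sigmoid-margin loss are hypothetical examples, not part of this PR):

```python
import torch.nn.functional as F

from liger_kernel.chunked_loss.fused_linear_preference import (
    LigerFusedLinearPreferenceBase,
)


class SimplePreferenceLoss(LigerFusedLinearPreferenceBase):

    @staticmethod
    def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1):
        # Reward the margin between the chosen and rejected average
        # log-probs; _compute_loss subtracts this reward from the NLL term.
        return beta * F.logsigmoid(chosen_logps - rejected_logps).sum()
```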

It also adds a new method to the class, `_compute_loss`, which is
mostly the same as the `_compute_orpo_loss` function introduced in #362
but is generic: it computes the NLL/cross-entropy loss and accepts a
custom loss function that implements the alignment loss.
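
Concretely, `forward` binds the chosen loss function and the static arguments into `_compute_loss` once per call; here is an excerpt of that wiring from the diff below (shown for orientation, not runnable on its own):

```python
from functools import partial

# Inside LigerFusedLinearPreferenceBase.forward: everything except the
# per-chunk tensors (input_chunk, weight, target_chunk, bias) is bound here,
# so torch.func.grad_and_value only differentiates through those.
loss_func_to_call = partial(
    LigerFusedLinearPreferenceBase._compute_loss,
    preference_loss_fn=loss_fn,
    ignore_index=ignore_index,
    beta=beta,
    compute_nll_loss=compute_nll_loss,
    full_target=target,
)
```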

Most RLHF/RLAIF/alignment algorithms state their final loss as `NLL +
Beta * (Alignment_Loss)`, so adding the NLL logic to the base class
reduces repeated code.
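
Written out, with ORPO's odds-ratio term as the concrete example (in code, `preference_loss_fn` returns the beta-weighted log-sigmoid as a reward and `_compute_loss` subtracts it, which is equivalent):

$$
\mathcal{L} = \mathcal{L}_{\mathrm{NLL}} + \beta \cdot \mathcal{L}_{\mathrm{align}},
\qquad
\beta \cdot \mathcal{L}_{\mathrm{align}}^{\mathrm{ORPO}} = -\beta \log \sigma\!\left(\log \frac{\mathrm{odds}(y_{\mathrm{chosen}})}{\mathrm{odds}(y_{\mathrm{rejected}})}\right)
$$

where $\mathrm{odds}(y) = p(y)/(1-p(y))$; the `torch.log1p(-torch.exp(...))` terms in `preference_loss_fn` compute the $\log(1-p)$ parts.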

The `_compute_loss` function accepts the subclass's `preference_loss_fn` along with `full_target`, `ignore_index`, `beta`, and `compute_nll_loss`, and forwards any additional keyword arguments to the preference loss function.
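
End to end, usage of the ORPO subclass then looks roughly like this (a hypothetical sketch, not from this PR; it assumes the chosen sequences are stacked before the rejected ones along dim 0, as `_compute_loss` expects, and that `forward` returns the scalar loss):

```python
import torch

from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOFunction

B, T, H, V = 2, 16, 64, 128  # preference pairs, seq len, hidden dim, vocab size

_input = torch.randn(2 * B, T, H, requires_grad=True)  # chosen stacked on rejected
weight = torch.randn(V, H, requires_grad=True)         # lm_head weight
target = torch.randint(0, V, (2 * B, T))

# Positional args follow forward(ctx, _input, weight, target, bias=None, ...).
loss = LigerFusedLinearORPOFunction.apply(_input, weight, target)
loss.backward()
```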

## Testing Done

- Hardware Type: A100-80G-SXM
- [X] run `make test` to ensure correctness
- [X] run `make checkstyle` to ensure code style
- [X] run `make test-convergence` to ensure convergence

---------

Co-authored-by: pramodith <[email protected]>
pramodith and pramodith authored Nov 14, 2024
1 parent 6b2fd02 commit 2281b7e
Showing 2 changed files with 126 additions and 81 deletions.
103 changes: 101 additions & 2 deletions src/liger_kernel/chunked_loss/fused_linear_preference.py
@@ -1,7 +1,23 @@
from abc import abstractmethod
from functools import partial

import torch
from torch.nn import functional as F


class LigerFusedLinearPreferenceBase(torch.autograd.Function):

@abstractmethod
def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1):
"""
Compute preference loss.
Args:
chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
beta (float): Weight for the odds ratio loss.
"""
raise NotImplementedError("Preference loss function must be implemented.")

@staticmethod
def forward(
ctx,
@@ -11,6 +27,9 @@ def forward(
bias=None,
loss_fn=None,
chunk_size=1,
compute_nll_loss=True,
ignore_index=-100,
beta=0.1,
compiled=True,
):
"""
@@ -24,6 +43,9 @@
bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
chunk_size (int): Size of a chunk (# of batches of stacked chosen and rejected inputs).
compute_nll_loss (bool): Whether to compute NLL loss.
ignore_index (int): Index to ignore for loss computation.
beta (float): Weight for the odds ratio loss.
compiled (bool): Whether to use torch compile for chunk accumulation.
"""
# TODO: Tune CHUNK_SIZE to fully utilize the GPU
@@ -36,21 +58,33 @@
loss_acc = torch.zeros((), device=_input.device)

chunks = max(1, _input.shape[0] // (2 * CHUNK_SIZE))
loss_func_to_call = partial(
LigerFusedLinearPreferenceBase._compute_loss,
preference_loss_fn=loss_fn,
ignore_index=ignore_index,
beta=beta,
compute_nll_loss=compute_nll_loss,
full_target=target,
)

def accumulate_chunk(input_chunk, target_chunk):
if bias is not None:
(chunk_grad_input, chunk_grad_weight, chunk_grad_bias), (
chunk_loss,
(chunk_or_loss, chunk_chosen_logps, chunk_rejected_logps),
-                ) = torch.func.grad_and_value(loss_fn, argnums=(0, 1, 3), has_aux=True)(
+                ) = torch.func.grad_and_value(
+                    loss_func_to_call, argnums=(0, 1, 3), has_aux=True
+                )(
input_chunk, weight, target_chunk, bias
)
grad_bias.add_(chunk_grad_bias)
else:
(chunk_grad_input, chunk_grad_weight), (
chunk_loss,
(chunk_or_loss, chunk_chosen_logps, chunk_rejected_logps),
-                ) = torch.func.grad_and_value(loss_fn, argnums=(0, 1), has_aux=True)(
+                ) = torch.func.grad_and_value(
+                    loss_func_to_call, argnums=(0, 1), has_aux=True
+                )(
input_chunk, weight, target_chunk
)
grad_weight.add_(chunk_grad_weight)
@@ -105,3 +139,68 @@ def backward(ctx, grad_output):
grad_bias = grad_bias * grad_output if grad_bias is not None else None

return grad_input, grad_weight, None, grad_bias, None, None, None

@staticmethod
def _compute_loss(
input_chunk,
weight,
target_chunk,
bias=None,
preference_loss_fn=None,
full_target=None,
ignore_index=-100,
beta=0.1,
compute_nll_loss=True,
**loss_kwargs,
):
"""
Compute the total loss for a chunk of input and target, while using an alignment/preference loss function.
Args:
preference_loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
input_chunk (torch.Tensor): Chunk of input tensor. Shape: (2 * chunk_size, sequence_length, hidden_size).
weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
target_chunk (torch.Tensor): Chunk of target tensor. Shape: (2 * chunk_size, sequence_length).
bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
full_target (torch.Tensor): Full target tensor. Shape: (batch_size, sequence_length).
ignore_index (int): Index to ignore for loss computation.
beta (float): Weight for the odds ratio loss.
loss_kwargs (dict): Additional arguments for the loss function.
"""
len_chosen_chunk = target_chunk.shape[0] // 2

logits_chunk = input_chunk @ weight.t() # chunk_size x V
if bias is not None:
logits_chunk = logits_chunk + bias
log_probs_chunk = F.log_softmax(logits_chunk.float(), dim=-1)

chosen_nll_loss = 0.0
if compute_nll_loss:
chosen_nll_loss = F.nll_loss(
log_probs_chunk[:len_chosen_chunk].view(-1, log_probs_chunk.shape[-1]),
target_chunk[:len_chosen_chunk].view(-1),
reduction="sum",
ignore_index=ignore_index,
)
chosen_nll_loss = (
chosen_nll_loss
/ (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
)

loss_mask = target_chunk != ignore_index
label_chunk = torch.where(loss_mask, target_chunk, 0)

per_token_logps = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(
-1
)
average_log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)

chosen_logps = average_log_prob[:len_chosen_chunk]
rejected_logps = average_log_prob[len_chosen_chunk:]

alignment_loss = preference_loss_fn(
chosen_logps, rejected_logps, beta=beta, **loss_kwargs
)
alignment_loss = alignment_loss / (full_target.shape[0] // 2)

loss = chosen_nll_loss - alignment_loss
return loss, (alignment_loss, chosen_logps, rejected_logps)
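
(Aside: a tiny runnable check of the masked average-log-prob logic added above; illustrative only, not part of the commit.)

```python
import torch
import torch.nn.functional as F

# 2 sequences (1 chosen, 1 rejected), 3 tokens, 5-word vocab.
log_probs = F.log_softmax(torch.randn(2, 3, 5), dim=-1)
target = torch.tensor([[1, 4, -100], [0, 2, 3]])  # -100 = ignore_index

loss_mask = target != -100
labels = torch.where(loss_mask, target, 0)
per_token_logps = log_probs.gather(-1, labels.unsqueeze(-1)).squeeze(-1)
avg_logps = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
print(avg_logps)  # shape (2,): [chosen_avg_logp, rejected_avg_logp]
```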
104 changes: 25 additions & 79 deletions src/liger_kernel/chunked_loss/orpo_loss.py
@@ -1,5 +1,3 @@
-from functools import partial
-
import torch
import torch.nn.functional as F

@@ -8,79 +6,24 @@
)


-def odds_ratio_loss(chosen_logps, rejected_logps, beta=0.1):
-    """
-    Compute odds-ratio loss.
-    Args:
-        chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
-        rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
-        beta (float): Weight for the odds ratio loss.
-    """
-    log_odds = (chosen_logps - rejected_logps) - (
-        torch.log1p(-torch.exp(chosen_logps)) - torch.log1p(-torch.exp(rejected_logps))
-    )
-    ratio = F.logsigmoid(log_odds)
-    return beta * ratio.sum()
-
-
-def _compute_orpo_loss(
-    input_chunk,
-    weight,
-    target_chunk,
-    bias=None,
-    full_target=None,
-    ignore_index=-100,
-    beta=0.1,
-    compute_nll_loss=True,
-):
-    """
-    Compute ORPO loss for a chunk of input and target.
-    Args:
-        input_chunk (torch.Tensor): Chunk of input tensor. Shape: (2 * chunk_size, sequence_length, hidden_size).
-        weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
-        target_chunk (torch.Tensor): Chunk of target tensor. Shape: (2 * chunk_size, sequence_length).
-        bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
-        full_target (torch.Tensor): Full target tensor. Shape: (batch_size, sequence_length).
-        ignore_index (int): Index to ignore for loss computation.
-        beta (float): Weight for the odds ratio loss.
-    """
-    len_chosen_chunk = target_chunk.shape[0] // 2
-
-    logits_chunk = input_chunk @ weight.t()  # chunk_size x V
-    if bias is not None:
-        logits_chunk = logits_chunk + bias
-    log_probs_chunk = F.log_softmax(logits_chunk.float(), dim=-1)
-
-    chosen_nll_loss = 0.0
-    if compute_nll_loss:
-        chosen_nll_loss = F.nll_loss(
-            log_probs_chunk[:len_chosen_chunk].view(-1, log_probs_chunk.shape[-1]),
-            target_chunk[:len_chosen_chunk].view(-1),
-            reduction="sum",
-            ignore_index=ignore_index,
-        )
-        chosen_nll_loss = (
-            chosen_nll_loss
-            / (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
-        )
-
-    loss_mask = target_chunk != ignore_index
-    label_chunk = torch.where(loss_mask, target_chunk, 0)
-
-    per_token_logps = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(-1)
-    average_log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
-
-    chosen_logps = average_log_prob[:len_chosen_chunk]
-    rejected_logps = average_log_prob[len_chosen_chunk:]
-
-    or_loss = odds_ratio_loss(chosen_logps, rejected_logps, beta=beta)
-    or_loss = or_loss / (full_target.shape[0] // 2)
-
-    loss = chosen_nll_loss - or_loss
-    return loss, (or_loss, chosen_logps, rejected_logps)
-
-
-class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
+class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
+
+    @staticmethod
+    def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1):
+        """
+        Compute odds-ratio loss.
+        Args:
+            chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
+            rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
+            beta (float): Weight for the odds ratio loss.
+        """
+        log_odds = (chosen_logps - rejected_logps) - (
+            torch.log1p(-torch.exp(chosen_logps))
+            - torch.log1p(-torch.exp(rejected_logps))
+        )
+        ratio = F.logsigmoid(log_odds)
+        return beta * ratio.sum()
+
@staticmethod
def forward(
ctx,
@@ -98,15 +41,18 @@ def forward(
Handles both the forward and backward pass of the final linear layer with ORPO loss.
Inspired from LigerFusedLinearCrossEntropyFunction (https://arxiv.org/abs/2410.10989) which fuses final linear layer and CE loss.
"""
-        orpo_loss_fn = partial(
-            _compute_orpo_loss,
-            full_target=target,
+
+        return LigerFusedLinearPreferenceBase.forward(
+            ctx=ctx,
+            _input=_input,
+            weight=weight,
+            target=target,
+            bias=bias,
+            loss_fn=LigerFusedLinearORPOFunction.preference_loss_fn,
+            compute_nll_loss=compute_nll_loss,
             ignore_index=ignore_index,
             beta=beta,
-            compute_nll_loss=compute_nll_loss,
-        )
-        return LigerFusedLinearPreferenceBase.forward(
-            ctx, _input, weight, target, bias, loss_fn=orpo_loss_fn
+            compiled=compiled,
         )

@staticmethod
