diff --git a/src/liger_kernel/chunked_loss/fused_linear_preference.py b/src/liger_kernel/chunked_loss/fused_linear_preference.py
index c95aa40ed..8412f20a4 100644
--- a/src/liger_kernel/chunked_loss/fused_linear_preference.py
+++ b/src/liger_kernel/chunked_loss/fused_linear_preference.py
@@ -1,7 +1,23 @@
+from abc import abstractmethod
+from functools import partial
+
 import torch
+from torch.nn import functional as F
 
 
 class LigerFusedLinearPreferenceBase(torch.autograd.Function):
+
+    @abstractmethod
+    def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1):
+        """
+        Compute preference loss.
+        Args:
+            chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
+            rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
+            beta (float): Weight for the preference loss.
+        """
+        raise NotImplementedError("Preference loss function must be implemented.")
+
     @staticmethod
     def forward(
         ctx,
@@ -11,6 +27,9 @@ def forward(
         bias=None,
         loss_fn=None,
         chunk_size=1,
+        compute_nll_loss=True,
+        ignore_index=-100,
+        beta=0.1,
         compiled=True,
     ):
         """
@@ -24,6 +43,9 @@ def forward(
             bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
             loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
             chunk_size (int): Size of a chunk (# of batches of stacked chosen and rejected inputs).
+            compute_nll_loss (bool): Whether to compute NLL loss.
+            ignore_index (int): Index to ignore for loss computation.
+            beta (float): Weight for the preference loss.
             compiled (bool): Whether to use torch compile for chunk accumulation.
         """
         # TODO: Tune CHUNK_SIZE to fully utilize the GPU
@@ -36,13 +58,23 @@ def forward(
         loss_acc = torch.zeros((), device=_input.device)
 
         chunks = max(1, _input.shape[0] // (2 * CHUNK_SIZE))
+        loss_func_to_call = partial(
+            LigerFusedLinearPreferenceBase._compute_loss,
+            preference_loss_fn=loss_fn,
+            ignore_index=ignore_index,
+            beta=beta,
+            compute_nll_loss=compute_nll_loss,
+            full_target=target,
+        )
 
         def accumulate_chunk(input_chunk, target_chunk):
             if bias is not None:
                 (chunk_grad_input, chunk_grad_weight, chunk_grad_bias), (
                     chunk_loss,
                     (chunk_or_loss, chunk_chosen_logps, chunk_rejected_logps),
-                ) = torch.func.grad_and_value(loss_fn, argnums=(0, 1, 3), has_aux=True)(
+                ) = torch.func.grad_and_value(
+                    loss_func_to_call, argnums=(0, 1, 3), has_aux=True
+                )(
                     input_chunk, weight, target_chunk, bias
                 )
                 grad_bias.add_(chunk_grad_bias)
@@ -50,7 +82,9 @@ def accumulate_chunk(input_chunk, target_chunk):
                 (chunk_grad_input, chunk_grad_weight), (
                     chunk_loss,
                     (chunk_or_loss, chunk_chosen_logps, chunk_rejected_logps),
-                ) = torch.func.grad_and_value(loss_fn, argnums=(0, 1), has_aux=True)(
+                ) = torch.func.grad_and_value(
+                    loss_func_to_call, argnums=(0, 1), has_aux=True
+                )(
                     input_chunk, weight, target_chunk
                 )
                 grad_weight.add_(chunk_grad_weight)
@@ -105,3 +139,69 @@ def backward(ctx, grad_output):
         grad_bias = grad_bias * grad_output if grad_bias is not None else None
 
         return grad_input, grad_weight, None, grad_bias, None, None, None
+
+    @staticmethod
+    def _compute_loss(
+        input_chunk,
+        weight,
+        target_chunk,
+        bias=None,
+        preference_loss_fn=None,
+        full_target=None,
+        ignore_index=-100,
+        beta=0.1,
+        compute_nll_loss=True,
+        **loss_kwargs,
+    ):
+        """
+        Compute the total loss for a chunk of input and target using an alignment/preference loss function.
+        Args:
+            preference_loss_fn (callable): Loss function to compute the preference loss on a chunk's chosen/rejected log probabilities.
+            input_chunk (torch.Tensor): Chunk of input tensor. Shape: (2 * chunk_size, sequence_length, hidden_size).
+            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
+            target_chunk (torch.Tensor): Chunk of target tensor. Shape: (2 * chunk_size, sequence_length).
+            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
+            full_target (torch.Tensor): Full target tensor. Shape: (batch_size, sequence_length).
+            ignore_index (int): Index to ignore for loss computation.
+            beta (float): Weight for the preference loss.
+            compute_nll_loss (bool): Whether to compute NLL loss on the chosen half of the chunk.
+            loss_kwargs (dict): Additional arguments for the loss function.
+        """
+        len_chosen_chunk = target_chunk.shape[0] // 2
+
+        logits_chunk = input_chunk @ weight.t()  # (2 * chunk_size, sequence_length, vocab_size)
+        if bias is not None:
+            logits_chunk = logits_chunk + bias
+        log_probs_chunk = F.log_softmax(logits_chunk.float(), dim=-1)
+
+        chosen_nll_loss = 0.0
+        if compute_nll_loss:
+            chosen_nll_loss = F.nll_loss(
+                log_probs_chunk[:len_chosen_chunk].view(-1, log_probs_chunk.shape[-1]),
+                target_chunk[:len_chosen_chunk].view(-1),
+                reduction="sum",
+                ignore_index=ignore_index,
+            )
+            chosen_nll_loss = (
+                chosen_nll_loss
+                / (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
+            )
+
+        loss_mask = target_chunk != ignore_index
+        label_chunk = torch.where(loss_mask, target_chunk, 0)
+
+        per_token_logps = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(
+            -1
+        )
+        average_log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
+
+        chosen_logps = average_log_prob[:len_chosen_chunk]
+        rejected_logps = average_log_prob[len_chosen_chunk:]
+
+        alignment_loss = preference_loss_fn(
+            chosen_logps, rejected_logps, beta=beta, **loss_kwargs
+        )
+        alignment_loss = alignment_loss / (full_target.shape[0] // 2)
+
+        loss = chosen_nll_loss - alignment_loss
+        return loss, (alignment_loss, chosen_logps, rejected_logps)
diff --git a/src/liger_kernel/chunked_loss/orpo_loss.py b/src/liger_kernel/chunked_loss/orpo_loss.py
index 1cd6fe21e..0ff146d5d 100644
--- a/src/liger_kernel/chunked_loss/orpo_loss.py
+++ b/src/liger_kernel/chunked_loss/orpo_loss.py
@@ -1,5 +1,3 @@
-from functools import partial
-
 import torch
 import torch.nn.functional as F
 
@@ -8,79 +6,24 @@
 )
 
 
-def odds_ratio_loss(chosen_logps, rejected_logps, beta=0.1):
-    """
-    Compute odds-ratio loss.
-    Args:
-        chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
-        rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
-        beta (float): Weight for the odds ratio loss.
-    """
-    log_odds = (chosen_logps - rejected_logps) - (
-        torch.log1p(-torch.exp(chosen_logps)) - torch.log1p(-torch.exp(rejected_logps))
-    )
-    ratio = F.logsigmoid(log_odds)
-    return beta * ratio.sum()
-
-
-def _compute_orpo_loss(
-    input_chunk,
-    weight,
-    target_chunk,
-    bias=None,
-    full_target=None,
-    ignore_index=-100,
-    beta=0.1,
-    compute_nll_loss=True,
-):
-    """
-    Compute ORPO loss for a chunk of input and target.
-    Args:
-        input_chunk (torch.Tensor): Chunk of input tensor. Shape: (2 * chunk_size, sequence_length, hidden_size).
-        weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
-        target_chunk (torch.Tensor): Chunk of target tensor. Shape: (2 * chunk_size, sequence_length).
-        bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
-        full_target (torch.Tensor): Full target tensor. Shape: (batch_size, sequence_length).
-        ignore_index (int): Index to ignore for loss computation.
-        beta (float): Weight for the odds ratio loss.
-    """
-    len_chosen_chunk = target_chunk.shape[0] // 2
-
-    logits_chunk = input_chunk @ weight.t()  # chunk_size x V
-    if bias is not None:
-        logits_chunk = logits_chunk + bias
-    log_probs_chunk = F.log_softmax(logits_chunk.float(), dim=-1)
+class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
 
-    chosen_nll_loss = 0.0
-    if compute_nll_loss:
-        chosen_nll_loss = F.nll_loss(
-            log_probs_chunk[:len_chosen_chunk].view(-1, log_probs_chunk.shape[-1]),
-            target_chunk[:len_chosen_chunk].view(-1),
-            reduction="sum",
-            ignore_index=ignore_index,
-        )
-        chosen_nll_loss = (
-            chosen_nll_loss
-            / (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
+    @staticmethod
+    def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1):
+        """
+        Compute odds-ratio loss.
+        Args:
+            chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
+            rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
+            beta (float): Weight for the odds ratio loss.
+        """
+        log_odds = (chosen_logps - rejected_logps) - (
+            torch.log1p(-torch.exp(chosen_logps))
+            - torch.log1p(-torch.exp(rejected_logps))
         )
+        ratio = F.logsigmoid(log_odds)
+        return beta * ratio.sum()
 
-    loss_mask = target_chunk != ignore_index
-    label_chunk = torch.where(loss_mask, target_chunk, 0)
-
-    per_token_logps = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(-1)
-    average_log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
-
-    chosen_logps = average_log_prob[:len_chosen_chunk]
-    rejected_logps = average_log_prob[len_chosen_chunk:]
-
-    or_loss = odds_ratio_loss(chosen_logps, rejected_logps, beta=beta)
-    or_loss = or_loss / (full_target.shape[0] // 2)
-
-    loss = chosen_nll_loss - or_loss
-    return loss, (or_loss, chosen_logps, rejected_logps)
-
-
-class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
     @staticmethod
     def forward(
         ctx,
@@ -98,15 +41,18 @@ def forward(
         Handles both the forward and backward pass of the final linear layer with ORPO loss.
         Inspired from LigerFusedLinearCrossEntropyFunction (https://arxiv.org/abs/2410.10989) which fuses final linear layer and CE loss.
         """
-        orpo_loss_fn = partial(
-            _compute_orpo_loss,
-            full_target=target,
+
+        return LigerFusedLinearPreferenceBase.forward(
+            ctx=ctx,
+            _input=_input,
+            weight=weight,
+            target=target,
+            bias=bias,
+            loss_fn=LigerFusedLinearORPOFunction.preference_loss_fn,
+            compute_nll_loss=compute_nll_loss,
             ignore_index=ignore_index,
             beta=beta,
-            compute_nll_loss=compute_nll_loss,
-        )
-        return LigerFusedLinearPreferenceBase.forward(
-            ctx, _input, weight, target, bias, loss_fn=orpo_loss_fn
+            compiled=compiled,
         )
 
     @staticmethod
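Note for context: the chunk-accumulation machinery this patch moves into `LigerFusedLinearPreferenceBase` hinges on `torch.func.grad_and_value(..., has_aux=True)`, which returns each chunk's gradients together with the loss and its auxiliary outputs from a single call, so the full logits tensor never has to be materialized at once. A minimal, self-contained sketch of that pattern on a toy loss (`toy_loss`, `x`, and `w` are illustrative names, not part of the patch):

```python
import torch

def toy_loss(x, w):
    # The differentiated output comes first, auxiliary (non-differentiated)
    # values second -- mirroring how _compute_loss returns
    # (loss, (alignment_loss, chosen_logps, rejected_logps)).
    loss = ((x @ w) ** 2).sum()
    return loss, loss.detach()

x, w = torch.randn(4, 3), torch.randn(3, 2)
# Gradients come back as a tuple matching argnums; the second element
# bundles the primal loss with the aux outputs.
(grad_x, grad_w), (loss, aux) = torch.func.grad_and_value(
    toy_loss, argnums=(0, 1), has_aux=True
)(x, w)
```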
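For reference, the odds-ratio term that moves into `preference_loss_fn` is beta * logsigmoid(log(odds_chosen / odds_rejected)), where odds(p) = p / (1 - p) is evaluated in log space via `log1p(-exp(log_p))`. A standalone sketch on made-up average log-probabilities (the tensor values are illustrative only):

```python
import torch
import torch.nn.functional as F

beta = 0.1
# Toy per-sequence average log-probabilities, shape (batch_size,).
chosen_logps = torch.tensor([-0.5, -1.0])
rejected_logps = torch.tensor([-2.0, -1.5])

# (log p_c - log p_r) - (log(1 - p_c) - log(1 - p_r)), exactly as in
# preference_loss_fn above.
log_odds = (chosen_logps - rejected_logps) - (
    torch.log1p(-torch.exp(chosen_logps)) - torch.log1p(-torch.exp(rejected_logps))
)
or_term = beta * F.logsigmoid(log_odds).sum()  # subtracted from the chosen NLL loss
```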
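End to end, the refactored function might be exercised as below. The sizes are made up, and this assumes the remainder of `LigerFusedLinearORPOFunction` (its full `forward` signature and the `backward` cut off above) is wired up as in the rest of the PR; the first half of the batch holds the chosen sequences and the second half the rejected ones, matching the stacked layout the chunking logic expects:

```python
import torch

from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOFunction

B, T, H, V = 2, 8, 16, 32  # hypothetical sizes: pairs, seq len, hidden, vocab

_input = torch.randn(2 * B, T, H, requires_grad=True)  # chosen rows first, then rejected
weight = torch.randn(V, H, requires_grad=True)
target = torch.randint(0, V, (2 * B, T))

loss = LigerFusedLinearORPOFunction.apply(_input, weight, target)
loss.backward()  # gradients flow through the fused linear + ORPO loss
```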