diff --git a/src/liger_kernel/chunked_loss/fused_linear_distillation.py b/src/liger_kernel/chunked_loss/fused_linear_distillation.py
index 929bdff34..11ae767f6 100644
--- a/src/liger_kernel/chunked_loss/fused_linear_distillation.py
+++ b/src/liger_kernel/chunked_loss/fused_linear_distillation.py
@@ -53,6 +53,64 @@ def chunk_forward(
 
         return student_logits_chunk, teacher_logits_chunk, ce_loss
 
+    @staticmethod
+    def _compute_loss(
+        student_input_chunk,
+        student_weight,
+        teacher_input_chunk,
+        teacher_weight,
+        target_chunk,
+        student_bias=None,
+        teacher_bias=None,
+        distillation_loss_fn=None,
+        full_target=None,
+        ignore_index=-100,
+        temperature=1.0,
+        weight_hard_loss=0.5,
+        weight_soft_loss=0.5,
+        compute_ce_loss=True,
+        **loss_kwargs,
+    ):
+        """
+        Compute the total loss for a chunk of input and target, using a knowledge distillation loss function.
+        Args:
+            distillation_loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
+            student_input_chunk (torch.Tensor): Chunk of input tensor. Shape: (chunk_size, student_hidden_size).
+            student_weight (torch.Tensor): Weight tensor. Shape: (vocab_size, student_hidden_size).
+            teacher_input_chunk (torch.Tensor): Chunk of input tensor. Shape: (chunk_size, teacher_hidden_size).
+            teacher_weight (torch.Tensor): Weight tensor. Shape: (vocab_size, teacher_hidden_size).
+            target_chunk (torch.Tensor): Chunk of target tensor. Shape: (chunk_size,).
+            student_bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
+            teacher_bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
+            full_target (torch.Tensor): Full target tensor. Shape: (chunk_size,).
+            ignore_index (int): Index to ignore for loss computation.
+            weight_hard_loss (float): Weight for hard loss.
+            weight_soft_loss (float): Weight for soft loss.
+            compute_ce_loss (bool): Whether to compute CE loss.
+            loss_kwargs (dict): Additional arguments for the loss function.
+        """
+        student_logits_chunk, teacher_logits_chunk, hard_loss = (
+            LigerFusedLinearDistillationBase.chunk_forward(
+                student_input_chunk,
+                student_weight,
+                teacher_input_chunk,
+                teacher_weight,
+                target_chunk,
+                student_bias=student_bias,
+                teacher_bias=teacher_bias,
+                ignore_index=ignore_index,
+                compute_ce_loss=compute_ce_loss,
+            )
+        )
+
+        hard_loss /= full_target.shape[0]
+
+        soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk, temperature)
+        soft_loss /= full_target.shape[0]
+
+        loss = weight_hard_loss * hard_loss + weight_soft_loss * soft_loss
+        return loss, (soft_loss, hard_loss, student_logits_chunk, teacher_logits_chunk)
+
     @staticmethod
     def forward(
         ctx,
@@ -190,61 +248,3 @@ def backward(ctx, grad_output):
             grad_bias = grad_bias * grad_output if grad_bias is not None else None
 
         return grad_input, grad_weight, None, grad_bias
-
-    @staticmethod
-    def _compute_loss(
-        student_input_chunk,
-        student_weight,
-        teacher_input_chunk,
-        teacher_weight,
-        target_chunk,
-        student_bias=None,
-        teacher_bias=None,
-        distillation_loss_fn=None,
-        full_target=None,
-        ignore_index=-100,
-        temperature=1.0,
-        weight_hard_loss=0.5,
-        weight_soft_loss=0.5,
-        compute_ce_loss=True,
-        **loss_kwargs,
-    ):
-        """
-        Compute the total loss for a chunk of input and target, while using an knowleedge distillation loss function.
-        Args:
-            distillation_loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
-            student_input_chunk (torch.Tensor): Chunk of input tensor. Shape: (chunk_size, student_hidden_size).
-            student_weight (torch.Tensor): Weight tensor. Shape: (vocab_size, student_hidden_size).
-            teacher_input_chunk (torch.Tensor): Chunk of input tensor. Shape: (chunk_size, teacher_hidden_size).
-            teacher_weight (torch.Tensor): Weight tensor. Shape: (vocab_size, teacher_hidden_size).
-            target_chunk (torch.Tensor): Chunk of target tensor. Shape: (chunk_size,).
-            student_bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
-            teacher_bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
-            full_target (torch.Tensor): Full target tensor. Shape: (chunk_size,).
-            ignore_index (int): Index to ignore for loss computation.
-            weight_hard_loss (float): Weight for hard loss.
-            weight_soft_loss (float): Weight for soft loss.
-            compute_ce_loss (bool): Whether to compute CE loss.
-            loss_kwargs (dict): Additional arguments for the loss function.
-        """
-        student_logits_chunk, teacher_logits_chunk, hard_loss = (
-            LigerFusedLinearDistillationBase.chunk_forward(
-                student_input_chunk,
-                student_weight,
-                teacher_input_chunk,
-                teacher_weight,
-                target_chunk,
-                student_bias=student_bias,
-                teacher_bias=teacher_bias,
-                ignore_index=ignore_index,
-                compute_ce_loss=compute_ce_loss,
-            )
-        )
-
-        hard_loss /= full_target.shape[0]
-
-        soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk, temperature)
-        soft_loss /= full_target.shape[0]
-
-        loss = weight_hard_loss * hard_loss + weight_soft_loss * soft_loss
-        return loss, (soft_loss, hard_loss, student_logits_chunk, teacher_logits_chunk)
diff --git a/src/liger_kernel/chunked_loss/fused_linear_preference.py b/src/liger_kernel/chunked_loss/fused_linear_preference.py
index c31cbba8b..26ae38a3d 100644
--- a/src/liger_kernel/chunked_loss/fused_linear_preference.py
+++ b/src/liger_kernel/chunked_loss/fused_linear_preference.py
@@ -64,6 +64,103 @@ def chunk_forward(
             chosen_nll_loss,
         )
 
+    @staticmethod
+    def _compute_loss(
+        input_chunk,
+        weight,
+        target_chunk,
+        bias=None,
+        preference_loss_fn=None,
+        full_target=None,
+        ignore_index=-100,
+        alpha=1.0,
+        beta=0.1,
+        compute_nll_loss=True,
+        use_ref_model=False,
+        ref_weight=None,
+        ref_bias=None,
+        **loss_kwargs,
+    ):
+        """
+        Compute the total loss for a chunk of input and target, using an alignment/preference loss function.
+        Args:
+            preference_loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
+            input_chunk (torch.Tensor): Chunk of input tensor. Shape: (2 * chunk_size, sequence_length, hidden_size).
+            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
+            target_chunk (torch.Tensor): Chunk of target tensor. Shape: (2 * chunk_size, sequence_length).
+            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
+            full_target (torch.Tensor): Full target tensor. Shape: (batch_size, sequence_length).
+            ignore_index (int): Index to ignore for loss computation.
+            alpha (float): Weight for the NLL loss.
+            beta (float): Weight for the odds ratio loss.
+            compute_nll_loss (bool): Whether to compute NLL loss.
+            use_ref_model (bool): Whether to use a reference model for the alignment loss.
+            ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
+            ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
+            loss_kwargs (dict): Additional arguments for the loss function.
+        """
+        (
+            chosen_logps,
+            rejected_logps,
+            chosen_logits,
+            rejected_logits,
+            chosen_nll_loss,
+        ) = LigerFusedLinearPreferenceBase.chunk_forward(
+            input_chunk,
+            weight,
+            target_chunk,
+            bias=bias,
+            ignore_index=ignore_index,
+            compute_nll_loss=compute_nll_loss,
+        )
+        chosen_nll_loss = (
+            chosen_nll_loss
+            / (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
+        )
+        chosen_logits_mean = chosen_logits.sum() / (
+            full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0]
+        )
+        rejected_logits_mean = rejected_logits.sum() / (
+            full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0]
+        )
+
+        if use_ref_model:
+            with torch.no_grad():
+                (
+                    ref_chosen_logps,
+                    ref_rejected_logps,
+                    ref_chosen_logits,
+                    ref_rejected_logits,
+                    ref_chosen_nll_loss,
+                ) = LigerFusedLinearPreferenceBase.chunk_forward(
+                    input_chunk,
+                    ref_weight,
+                    target_chunk,
+                    ref_bias,
+                    ignore_index=ignore_index,
+                    compute_nll_loss=False,  # We don't need NLL loss for the reference model
+                )
+            loss_kwargs["ref_chosen_logps"] = ref_chosen_logps
+            loss_kwargs["ref_rejected_logps"] = ref_rejected_logps
+
+        preference_loss_outputs = preference_loss_fn(
+            chosen_logps, rejected_logps, full_target, beta=beta, **loss_kwargs
+        )
+        if isinstance(preference_loss_outputs, tuple):
+            preference_loss, *aux_outputs = preference_loss_outputs
+        else:
+            preference_loss, aux_outputs = preference_loss_outputs, []
+
+        loss = alpha * chosen_nll_loss - preference_loss
+        return_vars = (
+            chosen_logps,
+            rejected_logps,
+            chosen_logits_mean,
+            rejected_logits_mean,
+            chosen_nll_loss,
+        )
+        return loss, (*return_vars, *aux_outputs)
+
     @staticmethod
     def forward(
         ctx,
@@ -134,7 +231,7 @@ def forward(
             **loss_kwargs,
         )
 
-        def accumulate_helper(input_chunk, target_chunk):
+        def accumulate_core(input_chunk, target_chunk):
             if bias is not None:
                 return torch.func.grad_and_value(
                     loss_func_to_call, argnums=(0, 1, 3), has_aux=True
@@ -156,7 +253,7 @@ def accumulate_chunk(input_chunk, target_chunk):
                         chunk_nll_loss,
                         *aux_outputs,
                     ),
-                ) = accumulate_helper(input_chunk, target_chunk)
+                ) = accumulate_core(input_chunk, target_chunk)
                 grad_bias.add_(chunk_grad_bias)  # accumulate bias gradient
             else:
                 (chunk_grad_input, chunk_grad_weight), (
@@ -169,7 +266,7 @@ def accumulate_chunk(input_chunk, target_chunk):
                         chunk_nll_loss,
                         *aux_outputs,
                     ),
-                ) = accumulate_helper(input_chunk, target_chunk)
+                ) = accumulate_core(input_chunk, target_chunk)
 
             grad_weight.add_(chunk_grad_weight)
             loss_acc.add_(chunk_loss)
@@ -199,7 +296,7 @@ def accumulate_chunk(input_chunk, target_chunk):
             return chunk_grad_input
 
         if compiled:
-            accumulate_helper = torch.compile(accumulate_helper)
+            accumulate_core = torch.compile(accumulate_core)
 
         len_chosen = target.shape[0] // 2
         chunks = max(1, _input.shape[0] // (2 * CHUNK_SIZE))
@@ -270,100 +367,3 @@ def backward(ctx, *grad_output):
             grad_bias = grad_bias * grad_output[0][0] if grad_bias is not None else None
 
         return grad_input, grad_weight, None, grad_bias, None, None, None
-
-    @staticmethod
-    def _compute_loss(
-        input_chunk,
-        weight,
-        target_chunk,
-        bias=None,
-        preference_loss_fn=None,
-        full_target=None,
-        ignore_index=-100,
-        alpha=1.0,
-        beta=0.1,
-        compute_nll_loss=True,
-        use_ref_model=False,
-        ref_weight=None,
-        ref_bias=None,
-        **loss_kwargs,
-    ):
-        """
-        Compute the total loss for a chunk of input and target, while using an alignment/preference loss function.
-        Args:
-            preference_loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
-            input_chunk (torch.Tensor): Chunk of input tensor. Shape: (2 * chunk_size, sequence_length, hidden_size).
-            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
-            target_chunk (torch.Tensor): Chunk of target tensor. Shape: (2 * chunk_size, sequence_length).
-            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
-            full_target (torch.Tensor): Full target tensor. Shape: (batch_size, sequence_length).
-            ignore_index (int): Index to ignore for loss computation.
-            alpha (float): Weight for the NLL loss.
-            beta (float): Weight for the odds ratio loss.
-            compute_nll_loss (bool): Whether to compute NLL loss.
-            use_ref_model (bool): Whether to use a reference model for the alignment loss.
-            ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
-            ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
-            loss_kwargs (dict): Additional arguments for the loss function.
-        """
-        (
-            chosen_logps,
-            rejected_logps,
-            chosen_logits,
-            rejected_logits,
-            chosen_nll_loss,
-        ) = LigerFusedLinearPreferenceBase.chunk_forward(
-            input_chunk,
-            weight,
-            target_chunk,
-            bias=bias,
-            ignore_index=ignore_index,
-            compute_nll_loss=compute_nll_loss,
-        )
-        chosen_nll_loss = (
-            chosen_nll_loss
-            / (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
-        )
-        chosen_logits_mean = chosen_logits.sum() / (
-            full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0]
-        )
-        rejected_logits_mean = rejected_logits.sum() / (
-            full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0]
-        )
-
-        if use_ref_model:
-            with torch.no_grad():
-                (
-                    ref_chosen_logps,
-                    ref_rejected_logps,
-                    ref_chosen_logits,
-                    ref_rejected_logits,
-                    ref_chosen_nll_loss,
-                ) = LigerFusedLinearPreferenceBase.chunk_forward(
-                    input_chunk,
-                    ref_weight,
-                    target_chunk,
-                    ref_bias,
-                    ignore_index=ignore_index,
-                    compute_nll_loss=False,  # We don't need NLL loss for the reference model
-                )
-            loss_kwargs["ref_chosen_logps"] = ref_chosen_logps
-            loss_kwargs["ref_rejected_logps"] = ref_rejected_logps
-
-        preference_loss_outputs = preference_loss_fn(
-            chosen_logps, rejected_logps, full_target, beta=beta, **loss_kwargs
-        )
-        if isinstance(preference_loss_outputs, tuple):
-            preference_loss, *aux_outputs = preference_loss_outputs
-        else:
-            preference_loss, aux_outputs = preference_loss_outputs, []
-
-        loss = alpha * chosen_nll_loss - preference_loss
-        return_vars = (
-            chosen_logps,
-            rejected_logps,
-            chosen_logits_mean,
-            rejected_logits_mean,
-            chosen_nll_loss,
-        )
-        return loss, (*return_vars, *aux_outputs)
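
For reviewers skimming the patch, the standalone sketch below illustrates the per-chunk hard/soft weighting that the relocated distillation `_compute_loss` performs. It is not part of the patch: the KL-based `example_soft_loss` and the helper names are assumptions standing in for the `distillation_loss_fn` that concrete subclasses supply, and the CE term is simplified relative to what `chunk_forward` actually computes.

# Standalone sketch (not part of the patch), assuming a KL-divergence soft loss.
import torch
import torch.nn.functional as F


def example_soft_loss(student_logits, teacher_logits, temperature):
    # Assumed distillation_loss_fn: temperature-scaled KL divergence, summed so the
    # caller can normalize by the full batch size, mirroring _compute_loss.
    student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
    teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)
    return F.kl_div(student_log_probs, teacher_probs, reduction="sum") * (temperature**2)


def chunk_loss_sketch(
    student_logits_chunk,
    teacher_logits_chunk,
    target_chunk,
    full_batch_size,
    temperature=1.0,
    weight_hard_loss=0.5,
    weight_soft_loss=0.5,
):
    # Hard loss: student cross-entropy against labels (a simplified stand-in for what
    # chunk_forward returns when compute_ce_loss=True), divided by the *full* batch
    # size so that per-chunk losses accumulate with the same scale as an unchunked pass.
    hard_loss = F.cross_entropy(student_logits_chunk, target_chunk, reduction="sum")
    hard_loss = hard_loss / full_batch_size

    # Soft loss: distillation term between student and teacher logits, same normalization.
    soft_loss = example_soft_loss(student_logits_chunk, teacher_logits_chunk, temperature)
    soft_loss = soft_loss / full_batch_size

    return weight_hard_loss * hard_loss + weight_soft_loss * soft_loss


if __name__ == "__main__":
    torch.manual_seed(0)
    chunk_size, vocab_size, full_batch = 4, 16, 32
    student = torch.randn(chunk_size, vocab_size)
    teacher = torch.randn(chunk_size, vocab_size)
    labels = torch.randint(0, vocab_size, (chunk_size,))
    print(chunk_loss_sketch(student, teacher, labels, full_batch, temperature=2.0))

The key design point carried over from the patched code is the normalization by the full batch size rather than the chunk size: it lets the chunked accumulation in `forward` sum per-chunk losses and gradients without a final rescale.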