diff --git a/src/liger_kernel/chunked_loss/fused_linear_distillation.py b/src/liger_kernel/chunked_loss/fused_linear_distillation.py
new file mode 100644
index 000000000..11ae767f6
--- /dev/null
+++ b/src/liger_kernel/chunked_loss/fused_linear_distillation.py
@@ -0,0 +1,250 @@
+from abc import abstractmethod
+from functools import partial
+
+import torch
+from torch.nn import functional as F
+
+
+class LigerFusedLinearDistillationBase(torch.autograd.Function):
+
+    @abstractmethod
+    def distillation_loss_fn(student_logits, teacher_logits, temperature):
+        """
+        Compute distillation loss.
+        Args:
+            student_logits (torch.Tensor): Raw logits of student tokens. Shape: (batch_size * seq_len, vocab_size).
+            teacher_logits (torch.Tensor): Raw logits of teacher tokens. Shape: (batch_size * seq_len, vocab_size).
+        """
+        raise NotImplementedError("Distillation loss function must be implemented.")
+
+    @staticmethod
+    def chunk_forward(
+        student_input_chunk,
+        student_weight,
+        teacher_input_chunk,
+        teacher_weight,
+        target_chunk,
+        student_bias=None,
+        teacher_bias=None,
+        ignore_index=-100,
+        compute_ce_loss=True,
+    ):
+        # Student
+        student_logits_chunk = student_input_chunk @ student_weight.t()
+        if student_bias is not None:
+            student_logits_chunk += student_bias
+        student_log_probs_chunk = F.log_softmax(student_logits_chunk.float(), dim=-1)
+
+        # Teacher
+        with torch.no_grad():
+            teacher_logits_chunk = teacher_input_chunk @ teacher_weight.t()
+            if teacher_bias is not None:
+                teacher_logits_chunk += teacher_bias
+
+        # The hard/task loss
+        ce_loss = 0.0
+        if compute_ce_loss:
+            ce_loss = F.nll_loss(
+                student_log_probs_chunk.view(-1, student_log_probs_chunk.shape[-1]),
+                target_chunk.view(-1),
+                reduction="sum",
+                ignore_index=ignore_index,
+            )
+
+        return student_logits_chunk, teacher_logits_chunk, ce_loss
+
+    @staticmethod
+    def _compute_loss(
+        student_input_chunk,
+        student_weight,
+        teacher_input_chunk,
+        teacher_weight,
+        target_chunk,
+        student_bias=None,
+        teacher_bias=None,
+        distillation_loss_fn=None,
+        full_target=None,
+        ignore_index=-100,
+        temperature=1.0,
+        weight_hard_loss=0.5,
+        weight_soft_loss=0.5,
+        compute_ce_loss=True,
+        **loss_kwargs,
+    ):
+        """
+        Compute the total loss for a chunk of input and target, while using a knowledge distillation loss function.
+        Args:
+            distillation_loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
+            student_input_chunk (torch.Tensor): Chunk of input tensor. Shape: (chunk_size, student_hidden_size).
+            student_weight (torch.Tensor): Weight tensor. Shape: (vocab_size, student_hidden_size).
+            teacher_input_chunk (torch.Tensor): Chunk of input tensor. Shape: (chunk_size, teacher_hidden_size).
+            teacher_weight (torch.Tensor): Weight tensor. Shape: (vocab_size, teacher_hidden_size).
+            target_chunk (torch.Tensor): Chunk of target tensor. Shape: (chunk_size,).
+            student_bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
+            teacher_bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
+            full_target (torch.Tensor): Full target tensor. Shape: (batch_size * seq_len,).
+            ignore_index (int): Index to ignore for loss computation.
+            weight_hard_loss (float): Weight for hard loss.
+            weight_soft_loss (float): Weight for soft loss.
+            compute_ce_loss (bool): Whether to compute CE loss.
+            loss_kwargs (dict): Additional arguments for the loss function.
+ """ + student_logits_chunk, teacher_logits_chunk, hard_loss = ( + LigerFusedLinearDistillationBase.chunk_forward( + student_input_chunk, + student_weight, + teacher_input_chunk, + teacher_weight, + target_chunk, + student_bias=student_bias, + teacher_bias=teacher_bias, + ignore_index=ignore_index, + compute_ce_loss=compute_ce_loss, + ) + ) + + hard_loss /= full_target.shape[0] + + soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk, temperature) + soft_loss /= full_target.shape[0] + + loss = weight_hard_loss * hard_loss + weight_soft_loss * soft_loss + return loss, (soft_loss, hard_loss, student_logits_chunk, teacher_logits_chunk) + + @staticmethod + def forward( + ctx, + student_input, + student_weight, + teacher_input, + teacher_weight, + target, + student_bias=None, + teacher_bias=None, + loss_fn=None, + chunk_size=1024, + ignore_index=-100, + weight_hard_loss=0.5, + weight_soft_loss=0.5, + compute_ce_loss=True, + temperature=1.0, + compiled=True, + **loss_kwargs, + ): + """ + Base class for fused linear layer with distillation loss. + Only need to compute gradients for student model. + + Args: + student_input (torch.Tensor): Student input tensor. Shape: (batch_size * seq_len, student_hidden_size). + student_weight (torch.Tensor): Student weight tensor. Shape: (vocab_size, student_hidden_size). + teacher_input (torch.Tensor): Teacher input tensor. Shape: (batch_size * seq_len, teacher_hidden_size). + teacher_weight (torch.Tensor): Teacher weight tensor. Shape: (vocab_size, teacher_hidden_size). + target (torch.Tensor): Target truth label tensor. Shape: (batch_size * seq_len). + student_bias (torch.Tensor, optional): Student bias tensor. Shape: (vocab_size,). + teacher_bias (torch.Tensor, optional): Teacher bias tensor. Shape: (vocab_size,). + loss_fn (callable): Loss function to compute the loss on a chunk of input/target. + chunk_size (int): Size of a chunk. + compute_ce_loss (bool): Whether to compute CE loss. + ignore_index (int): Index to ignore for loss computation. + weight_hard_loss (float): Weight for hard/task loss. + weight_soft_loss (float): Weight for soft/distillation loss. + compiled (bool): Whether to use torch compile for chunk accumulation. 
+            loss_kwargs (dict): Other possible arguments that a loss function might need
+        """
+        CHUNK_SIZE = chunk_size
+        grad_weight = torch.zeros_like(student_weight)
+        grad_inputs = []
+        grad_bias = torch.zeros_like(student_bias) if student_bias is not None else None
+        loss_acc = torch.zeros((), device=student_input.device)
+
+        loss_func_to_call = partial(
+            LigerFusedLinearDistillationBase._compute_loss,
+            distillation_loss_fn=loss_fn,
+            full_target=target,
+            ignore_index=ignore_index,
+            weight_hard_loss=weight_hard_loss,
+            weight_soft_loss=weight_soft_loss,
+            compute_ce_loss=compute_ce_loss,
+            temperature=temperature,
+            **loss_kwargs,
+        )
+
+        def accumulate_chunk(student_input_chunk, teacher_input_chunk, target_chunk):
+            if student_bias is not None:
+                (chunk_grad_input, chunk_grad_weight, chunk_grad_bias), (
+                    chunk_loss,
+                    (
+                        chunk_soft_loss,
+                        chunk_hard_loss,
+                        chunk_student_logits,
+                        chunk_teacher_logits,
+                    ),
+                ) = torch.func.grad_and_value(
+                    loss_func_to_call, argnums=(0, 1, 5), has_aux=True
+                )(
+                    student_input_chunk,
+                    student_weight,
+                    teacher_input_chunk,
+                    teacher_weight,
+                    target_chunk,
+                    student_bias,
+                    teacher_bias,
+                )
+                grad_bias.add_(chunk_grad_bias)
+            else:
+                (chunk_grad_input, chunk_grad_weight), (
+                    chunk_loss,
+                    (
+                        chunk_soft_loss,
+                        chunk_hard_loss,
+                        chunk_student_logits,
+                        chunk_teacher_logits,
+                    ),
+                ) = torch.func.grad_and_value(
+                    loss_func_to_call, argnums=(0, 1), has_aux=True
+                )(
+                    student_input_chunk,
+                    student_weight,
+                    teacher_input_chunk,
+                    teacher_weight,
+                    target_chunk,
+                    student_bias,
+                    teacher_bias,
+                )
+            grad_weight.add_(chunk_grad_weight)
+            loss_acc.add_(chunk_loss)
+            return chunk_grad_input
+
+        if compiled:
+            accumulate_chunk = torch.compile(accumulate_chunk)
+
+        num_chunks = max(1, student_input.shape[0] // CHUNK_SIZE)
+        _student_input_chunks = torch.chunk(student_input, chunks=num_chunks, dim=0)
+        _teacher_input_chunks = torch.chunk(teacher_input, chunks=num_chunks, dim=0)
+        _target_chunks = torch.chunk(target, chunks=num_chunks, dim=0)
+
+        for student_input_chunk, teacher_input_chunk, target_chunk in zip(
+            _student_input_chunks, _teacher_input_chunks, _target_chunks
+        ):
+            grad_input = accumulate_chunk(
+                student_input_chunk, teacher_input_chunk, target_chunk
+            )
+            grad_inputs.append(grad_input)
+
+        ctx.save_for_backward(
+            torch.cat(grad_inputs, dim=0),
+            grad_weight,
+            grad_bias,
+        )
+        return loss_acc
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        grad_input, grad_weight, grad_bias = ctx.saved_tensors
+        if torch.ne(grad_output, torch.tensor(1.0, device=grad_output.device)):
+            grad_input = grad_input * grad_output
+            grad_weight = grad_weight * grad_output
+            grad_bias = grad_bias * grad_output if grad_bias is not None else None
+
+        return grad_input, grad_weight, None, grad_bias
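Reviewer note (not part of the patch): a minimal sketch of how a concrete soft loss could be plugged into `LigerFusedLinearDistillationBase` above. The subclass name `LigerKLDivDistillationFunction`, the temperature-scaled KL formulation, and the reduced forward signature are illustrative assumptions for this note only.

```python
# Illustrative sketch, assuming the base class added in this patch.
import torch.nn.functional as F

from liger_kernel.chunked_loss.fused_linear_distillation import (
    LigerFusedLinearDistillationBase,
)


class LigerKLDivDistillationFunction(LigerFusedLinearDistillationBase):
    @staticmethod
    def distillation_loss_fn(student_logits, teacher_logits, temperature=1.0):
        # KL(teacher || student) on temperature-scaled logits, summed over the
        # chunk; the base class then divides by full_target.shape[0].
        student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
        teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)
        return F.kl_div(student_log_probs, teacher_probs, reduction="sum")

    @staticmethod
    def forward(
        ctx,
        student_input,
        student_weight,
        teacher_input,
        teacher_weight,
        target,
        student_bias=None,
        teacher_bias=None,
        temperature=1.0,
    ):
        return LigerFusedLinearDistillationBase.forward(
            ctx,
            student_input,
            student_weight,
            teacher_input,
            teacher_weight,
            target,
            student_bias=student_bias,
            teacher_bias=teacher_bias,
            loss_fn=LigerKLDivDistillationFunction.distillation_loss_fn,
            temperature=temperature,
        )

    @staticmethod
    def backward(ctx, grad_output):
        grad_input, grad_weight, _, grad_bias = LigerFusedLinearDistillationBase.backward(
            ctx, grad_output
        )
        # This forward() takes 8 inputs after ctx; only the student input,
        # weight, and bias receive gradients, the rest get None.
        return grad_input, grad_weight, None, None, None, grad_bias, None, None


# Usage (shapes illustrative): gradients flow only to the student tensors.
# loss = LigerKLDivDistillationFunction.apply(
#     student_hidden, student_lm_head_weight,
#     teacher_hidden, teacher_lm_head_weight,
#     labels,
# )
```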
diff --git a/src/liger_kernel/chunked_loss/fused_linear_preference.py b/src/liger_kernel/chunked_loss/fused_linear_preference.py
index c31cbba8b..26ae38a3d 100644
--- a/src/liger_kernel/chunked_loss/fused_linear_preference.py
+++ b/src/liger_kernel/chunked_loss/fused_linear_preference.py
@@ -64,6 +64,103 @@ def chunk_forward(
             chosen_nll_loss,
         )
 
+    @staticmethod
+    def _compute_loss(
+        input_chunk,
+        weight,
+        target_chunk,
+        bias=None,
+        preference_loss_fn=None,
+        full_target=None,
+        ignore_index=-100,
+        alpha=1.0,
+        beta=0.1,
+        compute_nll_loss=True,
+        use_ref_model=False,
+        ref_weight=None,
+        ref_bias=None,
+        **loss_kwargs,
+    ):
+        """
+        Compute the total loss for a chunk of input and target, while using an alignment/preference loss function.
+        Args:
+            preference_loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
+            input_chunk (torch.Tensor): Chunk of input tensor. Shape: (2 * chunk_size, sequence_length, hidden_size).
+            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
+            target_chunk (torch.Tensor): Chunk of target tensor. Shape: (2 * chunk_size, sequence_length).
+            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
+            full_target (torch.Tensor): Full target tensor. Shape: (batch_size, sequence_length).
+            ignore_index (int): Index to ignore for loss computation.
+            alpha (float): Weight for the NLL loss.
+            beta (float): Weight for the odds ratio loss.
+            compute_nll_loss (bool): Whether to compute NLL loss.
+            use_ref_model (bool): Whether to use a reference model for the alignment loss.
+            ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
+            ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
+            loss_kwargs (dict): Additional arguments for the loss function.
+        """
+        (
+            chosen_logps,
+            rejected_logps,
+            chosen_logits,
+            rejected_logits,
+            chosen_nll_loss,
+        ) = LigerFusedLinearPreferenceBase.chunk_forward(
+            input_chunk,
+            weight,
+            target_chunk,
+            bias=bias,
+            ignore_index=ignore_index,
+            compute_nll_loss=compute_nll_loss,
+        )
+        chosen_nll_loss = (
+            chosen_nll_loss
+            / (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
+        )
+        chosen_logits_mean = chosen_logits.sum() / (
+            full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0]
+        )
+        rejected_logits_mean = rejected_logits.sum() / (
+            full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0]
+        )
+
+        if use_ref_model:
+            with torch.no_grad():
+                (
+                    ref_chosen_logps,
+                    ref_rejected_logps,
+                    ref_chosen_logits,
+                    ref_rejected_logits,
+                    ref_chosen_nll_loss,
+                ) = LigerFusedLinearPreferenceBase.chunk_forward(
+                    input_chunk,
+                    ref_weight,
+                    target_chunk,
+                    ref_bias,
+                    ignore_index=ignore_index,
+                    compute_nll_loss=False,  # We don't need NLL loss for the reference model
+                )
+            loss_kwargs["ref_chosen_logps"] = ref_chosen_logps
+            loss_kwargs["ref_rejected_logps"] = ref_rejected_logps
+
+        preference_loss_outputs = preference_loss_fn(
+            chosen_logps, rejected_logps, full_target, beta=beta, **loss_kwargs
+        )
+        if isinstance(preference_loss_outputs, tuple):
+            preference_loss, *aux_outputs = preference_loss_outputs
+        else:
+            preference_loss, aux_outputs = preference_loss_outputs, []
+
+        loss = alpha * chosen_nll_loss - preference_loss
+        return_vars = (
+            chosen_logps,
+            rejected_logps,
+            chosen_logits_mean,
+            rejected_logits_mean,
+            chosen_nll_loss,
+        )
+        return loss, (*return_vars, *aux_outputs)
+
     @staticmethod
     def forward(
         ctx,
@@ -134,7 +231,7 @@ def forward(
             **loss_kwargs,
         )
 
-        def accumulate_helper(input_chunk, target_chunk):
+        def accumulate_core(input_chunk, target_chunk):
             if bias is not None:
                 return torch.func.grad_and_value(
                     loss_func_to_call, argnums=(0, 1, 3), has_aux=True
@@ -156,7 +253,7 @@ def accumulate_chunk(input_chunk, target_chunk):
                         chunk_nll_loss,
                         *aux_outputs,
                     ),
-                ) = accumulate_helper(input_chunk, target_chunk)
+                ) = accumulate_core(input_chunk, target_chunk)
                 grad_bias.add_(chunk_grad_bias)  # accumulate bias gradient
             else:
                 (chunk_grad_input, chunk_grad_weight), (
@@ -169,7 +266,7 @@ def accumulate_chunk(input_chunk, target_chunk):
                         chunk_nll_loss,
                         *aux_outputs,
                     ),
-                ) = accumulate_helper(input_chunk, target_chunk)
+                ) = accumulate_core(input_chunk, target_chunk)
 
             grad_weight.add_(chunk_grad_weight)
             loss_acc.add_(chunk_loss)
@@ -199,7 +296,7 @@ def accumulate_chunk(input_chunk, target_chunk):
             return chunk_grad_input
 
         if compiled:
-            accumulate_helper = torch.compile(accumulate_helper)
+            accumulate_core = torch.compile(accumulate_core)
 
         len_chosen = target.shape[0] // 2
         chunks = max(1, _input.shape[0] // (2 * CHUNK_SIZE))
@@ -270,100 +367,3 @@ def backward(ctx, *grad_output):
             grad_bias = grad_bias * grad_output[0][0] if grad_bias is not None else None
 
         return grad_input, grad_weight, None, grad_bias, None, None, None
-
-    @staticmethod
-    def _compute_loss(
-        input_chunk,
-        weight,
-        target_chunk,
-        bias=None,
-        preference_loss_fn=None,
-        full_target=None,
-        ignore_index=-100,
-        alpha=1.0,
-        beta=0.1,
-        compute_nll_loss=True,
-        use_ref_model=False,
-        ref_weight=None,
-        ref_bias=None,
-        **loss_kwargs,
-    ):
-        """
-        Compute the total loss for a chunk of input and target, while using an alignment/preference loss function.
-        Args:
-            preference_loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
-            input_chunk (torch.Tensor): Chunk of input tensor. Shape: (2 * chunk_size, sequence_length, hidden_size).
-            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
-            target_chunk (torch.Tensor): Chunk of target tensor. Shape: (2 * chunk_size, sequence_length).
-            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
-            full_target (torch.Tensor): Full target tensor. Shape: (batch_size, sequence_length).
-            ignore_index (int): Index to ignore for loss computation.
-            alpha (float): Weight for the NLL loss.
-            beta (float): Weight for the odds ratio loss.
-            compute_nll_loss (bool): Whether to compute NLL loss.
-            use_ref_model (bool): Whether to use a reference model for the alignment loss.
-            ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
-            ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
-            loss_kwargs (dict): Additional arguments for the loss function.
- """ - ( - chosen_logps, - rejected_logps, - chosen_logits, - rejected_logits, - chosen_nll_loss, - ) = LigerFusedLinearPreferenceBase.chunk_forward( - input_chunk, - weight, - target_chunk, - bias=bias, - ignore_index=ignore_index, - compute_nll_loss=compute_nll_loss, - ) - chosen_nll_loss = ( - chosen_nll_loss - / (full_target[: full_target.shape[0] // 2] != ignore_index).sum() - ) - chosen_logits_mean = chosen_logits.sum() / ( - full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0] - ) - rejected_logits_mean = rejected_logits.sum() / ( - full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0] - ) - - if use_ref_model: - with torch.no_grad(): - ( - ref_chosen_logps, - ref_rejected_logps, - ref_chosen_logits, - ref_rejected_logits, - ref_chosen_nll_loss, - ) = LigerFusedLinearPreferenceBase.chunk_forward( - input_chunk, - ref_weight, - target_chunk, - ref_bias, - ignore_index=ignore_index, - compute_nll_loss=False, # We don't need NLL loss for the reference model - ) - loss_kwargs["ref_chosen_logps"] = ref_chosen_logps - loss_kwargs["ref_rejected_logps"] = ref_rejected_logps - - preference_loss_outputs = preference_loss_fn( - chosen_logps, rejected_logps, full_target, beta=beta, **loss_kwargs - ) - if isinstance(preference_loss_outputs, tuple): - preference_loss, *aux_outputs = preference_loss_outputs - else: - preference_loss, aux_outputs = preference_loss_outputs, [] - - loss = alpha * chosen_nll_loss - preference_loss - return_vars = ( - chosen_logps, - rejected_logps, - chosen_logits_mean, - rejected_logits_mean, - chosen_nll_loss, - ) - return loss, (*return_vars, *aux_outputs) diff --git a/test/utils.py b/test/utils.py index 584b6b9d6..4c2fd5195 100644 --- a/test/utils.py +++ b/test/utils.py @@ -520,3 +520,113 @@ def get_batch_loss_metrics( policy_nll_loss, ) return loss, (*return_vars, *aggregated_aux_outputs) + + +class HFDistillationLoss: + def __init__( + self, + weight_hard_loss: float = 0.5, + weight_soft_loss: float = 0.5, + ignore_index: int = -100, + temperature: float = 1, + ): + self.weight_hard_loss = weight_hard_loss + self.weight_soft_loss = weight_soft_loss + self.ignore_index = ignore_index + self.temperature = temperature + + @abstractmethod + def distillation_loss(self, student_logits, teacher_logits): + """Abstract method for computing distillation loss.""" + pass + + def concatenated_forward( + self, + student_input: torch.FloatTensor, + student_weight: torch.FloatTensor, + teacher_input: torch.FloatTensor, + teacher_weight: torch.FloatTensor, + target: torch.LongTensor, + student_bias: torch.FloatTensor = None, + teacher_bias: torch.FloatTensor = None, + ) -> Tuple[ + torch.FloatTensor, + torch.FloatTensor, + torch.FloatTensor, + torch.FloatTensor, + torch.FloatTensor, + ]: + """Compute forward pass for both student and teacher models.""" + + student_batch_seq_len_size, student_hidden_size = student_input.shape + student_input_reshaped = student_input.view(-1, student_hidden_size) + teacher_batch_seq_len_size, teacher_hidden_size = teacher_input.shape + teacher_input_reshaped = teacher_input.view(-1, teacher_hidden_size) + + student_outputs = student_input_reshaped @ student_weight.t() + if student_bias is not None: + student_outputs = student_outputs + student_bias + + with torch.no_grad(): + teacher_outputs = teacher_input_reshaped @ teacher_weight.t() + if teacher_bias is not None: + teacher_outputs = teacher_outputs + teacher_bias + + student_logits = student_outputs.view(student_batch_seq_len_size, -1).float() + teacher_logits = 
diff --git a/test/utils.py b/test/utils.py
index 584b6b9d6..4c2fd5195 100644
--- a/test/utils.py
+++ b/test/utils.py
@@ -520,3 +520,113 @@ def get_batch_loss_metrics(
             policy_nll_loss,
         )
         return loss, (*return_vars, *aggregated_aux_outputs)
+
+
+class HFDistillationLoss:
+    def __init__(
+        self,
+        weight_hard_loss: float = 0.5,
+        weight_soft_loss: float = 0.5,
+        ignore_index: int = -100,
+        temperature: float = 1,
+    ):
+        self.weight_hard_loss = weight_hard_loss
+        self.weight_soft_loss = weight_soft_loss
+        self.ignore_index = ignore_index
+        self.temperature = temperature
+
+    @abstractmethod
+    def distillation_loss(self, student_logits, teacher_logits):
+        """Abstract method for computing distillation loss."""
+        pass
+
+    def concatenated_forward(
+        self,
+        student_input: torch.FloatTensor,
+        student_weight: torch.FloatTensor,
+        teacher_input: torch.FloatTensor,
+        teacher_weight: torch.FloatTensor,
+        target: torch.LongTensor,
+        student_bias: torch.FloatTensor = None,
+        teacher_bias: torch.FloatTensor = None,
+    ) -> Tuple[
+        torch.FloatTensor,
+        torch.FloatTensor,
+        torch.FloatTensor,
+        torch.FloatTensor,
+        torch.FloatTensor,
+    ]:
+        """Compute forward pass for both student and teacher models."""
+
+        student_batch_seq_len_size, student_hidden_size = student_input.shape
+        student_input_reshaped = student_input.view(-1, student_hidden_size)
+        teacher_batch_seq_len_size, teacher_hidden_size = teacher_input.shape
+        teacher_input_reshaped = teacher_input.view(-1, teacher_hidden_size)
+
+        student_outputs = student_input_reshaped @ student_weight.t()
+        if student_bias is not None:
+            student_outputs = student_outputs + student_bias
+
+        with torch.no_grad():
+            teacher_outputs = teacher_input_reshaped @ teacher_weight.t()
+            if teacher_bias is not None:
+                teacher_outputs = teacher_outputs + teacher_bias
+
+        student_logits = student_outputs.view(student_batch_seq_len_size, -1).float()
+        teacher_logits = teacher_outputs.view(teacher_batch_seq_len_size, -1).float()
+
+        if torch.all(target == self.ignore_index):
+            return torch.tensor(0.0)
+
+        def cross_entropy_loss(logits, labels):
+            # Flatten the tokens
+            loss_fct = nn.CrossEntropyLoss(ignore_index=self.ignore_index)
+            logits = logits.view(-1, logits.shape[-1])
+            labels = labels.view(-1)
+            # Enable model parallelism
+            labels = labels.to(logits.device)
+            loss = loss_fct(logits, labels)
+            return loss
+
+        labels = target
+        ce_loss = cross_entropy_loss(
+            student_logits.view(-1, student_logits.shape[-1]),
+            labels.view(-1),
+        )
+
+        return (
+            student_logits,
+            teacher_logits,
+            ce_loss,
+        )
+
+    def get_batch_loss_metrics(
+        self,
+        student_input: torch.FloatTensor,
+        student_weight: torch.FloatTensor,
+        teacher_input: torch.FloatTensor,
+        teacher_weight: torch.FloatTensor,
+        target: torch.LongTensor,
+        student_bias: torch.FloatTensor = None,
+        teacher_bias: torch.FloatTensor = None,
+    ):
+        """Compute the distillation loss metrics for the given batch."""
+        forward_output = self.concatenated_forward(
+            student_input,
+            student_weight,
+            teacher_input,
+            teacher_weight,
+            target,
+            student_bias,
+            teacher_bias,
+        )
+        (
+            student_logits,
+            teacher_logits,
+            hard_loss,
+        ) = forward_output
+
+        soft_loss = self.distillation_loss(student_logits, teacher_logits)
+        # full loss: weighted sum of the hard (CE) loss and the mean soft (distillation) loss
+        loss = self.weight_hard_loss * hard_loss + self.weight_soft_loss * soft_loss.mean()
+        return loss
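Reviewer note (not part of the patch): a sketch of how the `HFDistillationLoss` reference helper above might be specialized when checking the fused, chunked path. The `NaiveKLDistillationLoss` name, the import path, and the hyperparameters are assumptions for this note, not code from this diff.

```python
# Illustrative sketch, assuming the test helper added in this patch.
import torch.nn.functional as F

from test.utils import HFDistillationLoss  # path assumed; adjust to the test layout


class NaiveKLDistillationLoss(HFDistillationLoss):
    def distillation_loss(self, student_logits, teacher_logits):
        # Per-position KL(teacher || student); get_batch_loss_metrics() takes
        # the mean and mixes it with the hard CE loss.
        student_log_probs = F.log_softmax(student_logits / self.temperature, dim=-1)
        teacher_probs = F.softmax(teacher_logits / self.temperature, dim=-1)
        return F.kl_div(
            student_log_probs, teacher_probs, reduction="none"
        ).sum(dim=-1)


# Reference loss for comparison against the fused implementation:
# ref = NaiveKLDistillationLoss(weight_hard_loss=0.5, weight_soft_loss=0.5, temperature=2.0)
# expected = ref.get_batch_loss_metrics(
#     student_hidden, student_lm_head_weight,
#     teacher_hidden, teacher_lm_head_weight,
#     labels,
# )
```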