From 0769e97416dd8ba19823bd5a732439c05822eaf6 Mon Sep 17 00:00:00 2001 From: Austin Liu Date: Mon, 2 Dec 2024 15:46:19 +0800 Subject: [PATCH 01/11] Add liger and naive distill base Signed-off-by: Austin Liu --- src/liger_kernel/chunked_loss/__init__.py | 1 + src/liger_kernel/chunked_loss/functional.py | 2 + .../chunked_loss/fused_linear_distillation.py | 250 ++++++++++++++++++ test/utils.py | 110 ++++++++ 4 files changed, 363 insertions(+) create mode 100644 src/liger_kernel/chunked_loss/fused_linear_distillation.py diff --git a/src/liger_kernel/chunked_loss/__init__.py b/src/liger_kernel/chunked_loss/__init__.py index 238bdded9..5dcd2fbf3 100644 --- a/src/liger_kernel/chunked_loss/__init__.py +++ b/src/liger_kernel/chunked_loss/__init__.py @@ -1,4 +1,5 @@ from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOLoss # noqa: F401 from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOLoss # noqa: F401 +from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDLoss # noqa: F401 from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOLoss # noqa: F401 from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOLoss # noqa: F401 diff --git a/src/liger_kernel/chunked_loss/functional.py b/src/liger_kernel/chunked_loss/functional.py index 5a51d3f72..6bb742291 100644 --- a/src/liger_kernel/chunked_loss/functional.py +++ b/src/liger_kernel/chunked_loss/functional.py @@ -1,9 +1,11 @@ from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction +from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDFunction from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOFunction from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOFunction liger_fused_linear_orpo = LigerFusedLinearORPOFunction.apply liger_fused_linear_dpo = LigerFusedLinearDPOFunction.apply +liger_fused_linear_jsd = LigerFusedLinearJSDFunction.apply liger_fused_linear_cpo = LigerFusedLinearCPOFunction.apply liger_fused_linear_simpo = LigerFusedLinearSimPOFunction.apply diff --git a/src/liger_kernel/chunked_loss/fused_linear_distillation.py b/src/liger_kernel/chunked_loss/fused_linear_distillation.py new file mode 100644 index 000000000..87a0896df --- /dev/null +++ b/src/liger_kernel/chunked_loss/fused_linear_distillation.py @@ -0,0 +1,250 @@ +from abc import abstractmethod +from functools import partial + +import torch +from torch.nn import functional as F + + +class LigerFusedLinearDistillationBase(torch.autograd.Function): + + @abstractmethod + def distillation_loss_fn(student_logits, teacher_logits, temperature): + """ + Compute distillation loss. + Args: + student_logits (torch.Tensor): Raw logits of student tokens. Shape: (batch_size * seq_len, vocab_size). + teacher_logits (torch.Tensor): Raw logits of teacher tokens. Shape: (batch_size * seq_len, vocab_size). 
+ """ + raise NotImplementedError("Distillation loss function must be implemented.") + + @staticmethod + def chunk_forward( + student_input_chunk, + student_weight, + teacher_input_chunk, + teacher_weight, + target_chunk, + student_bias=None, + teacher_bias=None, + ignore_index=-100, + compute_ce_loss=True, + ): + loss_mask = target_chunk != ignore_index + label_chunk = torch.where(loss_mask, target_chunk, 0) + + # Student + student_per_token_logits_chunk = student_input_chunk @ student_weight.t() + if student_bias is not None: + student_per_token_logits_chunk += student_bias + student_per_token_log_probs_chunk = F.log_softmax(student_per_token_logits_chunk.float(), dim=-1) + + # Teacher + teacher_per_token_logits_chunk = teacher_input_chunk @ teacher_weight.t() + if teacher_bias is not None: + teacher_per_token_logits_chunk += teacher_bias + + # The hard/task loss + ce_loss = 0.0 + if compute_ce_loss: + ce_loss = F.cross_entropy( + student_per_token_log_probs_chunk.view(-1, student_per_token_log_probs_chunk.shape[-1]), + target_chunk.view(-1), + reduction="sum", + ignore_index=ignore_index, + ) + + return student_per_token_logits_chunk, teacher_per_token_logits_chunk, ce_loss + + @staticmethod + def forward( + ctx, + student_input, + student_weight, + teacher_input, + teacher_weight, + target, + student_bias=None, + teacher_bias=None, + loss_fn=None, + chunk_size=1, + ignore_index=-100, + beta=0.5, + compute_ce_loss=True, + temperature=1.0, + compiled=True, + **loss_kwargs, + ): + """ + Base class for fused linear layer with distillation loss. + Only need to compute gradients for student model. + + Args: + student_input (torch.Tensor): Student input tensor. Shape: (batch_size * seq_len, hidden_size). + student_weight (torch.Tensor): Student weight tensor. Shape: (vocab_size, hidden_size). + teacher_input (torch.Tensor): Teacher input tensor. Shape: (batch_size * seq_len, hidden_size). + teacher_weight (torch.Tensor): Teacher weight tensor. Shape: (vocab_size, hidden_size). + target (torch.Tensor): Target truth label tensor. Shape: (batch_size * seq_len). + student_bias (torch.Tensor, optional): Student bias tensor. Shape: (vocab_size,). + teacher_bias (torch.Tensor, optional): Teacher bias tensor. Shape: (vocab_size,). + loss_fn (callable): Loss function to compute the loss on a chunk of input/target. + chunk_size (int): Size of a chunk (# of batches of stacked chosen and rejected inputs). + compute_ce_loss (bool): Whether to compute CE loss. + ignore_index (int): Index to ignore for loss computation. + beta (float): Weight between soft and hard loss. + compiled (bool): Whether to use torch compile for chunk accumulation. 
+ loss_kwargs (dict): Other possible arguments that a loss function might need + """ + CHUNK_SIZE = chunk_size + + grad_weight = torch.zeros_like(student_weight) + grad_inputs = [] + grad_bias = torch.zeros_like(student_bias) if student_bias is not None else None + loss_acc = torch.zeros((), device=student_input.device) + + loss_func_to_call = partial( + LigerFusedLinearDistillationBase._compute_loss, + distillation_loss_fn=loss_fn, + ignore_index=ignore_index, + beta=beta, + compute_ce_loss=compute_ce_loss, + temperature=temperature, + full_target=target, + **loss_kwargs, + ) + + def accumulate_chunk(student_input_chunk, teacher_input_chunk, target_chunk): + if student_bias is not None: + (chunk_grad_input, chunk_grad_weight, chunk_grad_bias), ( + chunk_loss, + ( + chunk_soft_loss, + chunk_hard_loss, + chunk_student_logits, + chunk_teacher_logits, + ), + ) = torch.func.grad_and_value( + loss_func_to_call, argnums=(0, 1, 5), has_aux=True + )( + student_input_chunk, + student_weight, + teacher_input_chunk, + teacher_weight, + target_chunk, + student_bias, + teacher_bias, + ) + grad_bias.add_(chunk_grad_bias) + else: + (chunk_grad_input, chunk_grad_weight), ( + chunk_loss, + ( + chunk_soft_loss, + chunk_hard_loss, + chunk_student_logits, + chunk_teacher_logits, + ), + ) = torch.func.grad_and_value( + loss_func_to_call, argnums=(0, 1), has_aux=True + )( + student_input_chunk, + student_weight, + teacher_input_chunk, + teacher_weight, + target_chunk, + student_bias, + teacher_bias, + ) + grad_weight.add_(chunk_grad_weight) + loss_acc.add_(chunk_loss) + return chunk_grad_input + + if compiled: + accumulate_chunk = torch.compile(accumulate_chunk) + + num_chunks = max(1, student_input.shape[0] // CHUNK_SIZE) + _student_input_chunks = torch.chunk(student_input, chunks=num_chunks, dim=0) + _teacher_input_chunks = torch.chunk(teacher_input, chunks=num_chunks, dim=0) + _target_chunks = torch.chunk(target, chunks=num_chunks, dim=0) + + for student_input_chunk, teacher_input_chunk, target_chunk in zip( + _student_input_chunks, _teacher_input_chunks, _target_chunks + ): + grad_input = accumulate_chunk( + student_input_chunk, teacher_input_chunk, target_chunk + ) + grad_inputs.append(grad_input) + + ctx.save_for_backward( + torch.cat(grad_inputs, dim=0), + grad_weight, + grad_bias, + ) + return loss_acc + + @staticmethod + def backward(ctx, grad_output): + grad_input, grad_weight, grad_bias = ctx.saved_tensors + if torch.ne(grad_output, torch.tensor(1.0, device=grad_output.device)): + grad_input = grad_input * grad_output + grad_weight = grad_weight * grad_output + grad_bias = grad_bias * grad_output if grad_bias is not None else None + + return grad_input, grad_weight, None, grad_bias, None, None, None + + @staticmethod + def _compute_loss( + student_input_chunk, + student_weight, + teacher_input_chunk, + teacher_weight, + target_chunk, + student_bias=None, + teacher_bias=None, + distillation_loss_fn=None, + full_target=None, + ignore_index=-100, + temperature=1.0, + beta=0.5, + compute_ce_loss=True, + **loss_kwargs, + ): + """ + Compute the total loss for a chunk of input and target, while using an knowleedge distillation loss function. + Args: + distillation_loss_fn (callable): Loss function to compute the loss on a chunk of input/target. + student_input_chunk (torch.Tensor): Chunk of input tensor. Shape: (chunk_size, hidden_size). + student_weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size). + teacher_input_chunk (torch.Tensor): Chunk of input tensor. 
Shape: (chunk_size, hidden_size). + teacher_weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size). + target_chunk (torch.Tensor): Chunk of target tensor. Shape: (chunk_size,). + student_bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,). + teacher_bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,). + full_target (torch.Tensor): Full target tensor. Shape: (chunk_size,). + ignore_index (int): Index to ignore for loss computation. + beta (float): Weight between soft and hard loss. + compute_ce_loss (bool): Whether to compute CE loss. + loss_kwargs (dict): Additional arguments for the loss function. + """ + student_logits, teacher_logits, hard_loss = ( + LigerFusedLinearDistillationBase.chunk_forward( + student_input_chunk, + student_weight, + teacher_input_chunk, + teacher_weight, + target_chunk, + student_bias=student_bias, + teacher_bias=teacher_bias, + ignore_index=ignore_index, + compute_ce_loss=compute_ce_loss, + ) + ) + + hard_loss = hard_loss / (full_target != ignore_index).sum() + + soft_loss = distillation_loss_fn( + student_logits, teacher_logits, temperature + ) + soft_loss = soft_loss / (full_target != ignore_index).sum() + + loss = beta * hard_loss + (1 - beta) * soft_loss + return loss, (soft_loss, hard_loss, student_logits, teacher_logits) diff --git a/test/utils.py b/test/utils.py index 584b6b9d6..de94039d5 100644 --- a/test/utils.py +++ b/test/utils.py @@ -520,3 +520,113 @@ def get_batch_loss_metrics( policy_nll_loss, ) return loss, (*return_vars, *aggregated_aux_outputs) + + +class HFDistillationLoss: + def __init__( + self, + beta: float = 0.5, + ignore_index: int = -100, + temperature: float = 1, + ): + self.beta = beta + self.ignore_index = ignore_index + self.temperature = temperature + + @abstractmethod + def distillation_loss(self, student_logps, teacher_logps): + """Abstract method for computing distillation loss.""" + pass + + def concatenated_forward( + self, + student_input: torch.FloatTensor, + student_weight: torch.FloatTensor, + teacher_input: torch.FloatTensor, + teacher_weight: torch.FloatTensor, + target: torch.LongTensor, + student_bias: torch.FloatTensor = None, + teacher_bias: torch.FloatTensor = None, + average_log_prob: bool = True, + ) -> Tuple[ + torch.FloatTensor, + torch.FloatTensor, + torch.FloatTensor, + torch.FloatTensor, + torch.FloatTensor, + ]: + """Compute forward pass for both student and teacher models.""" + + student_batch_seq_len_size, student_hidden_size = student_input.shape + student_input_reshaped = student_input.view(-1, student_hidden_size) + teacher_batch_seq_len_size, teacher_hidden_size = teacher_input.shape + teacher_input_reshaped = teacher_input.view(-1, teacher_hidden_size) + + student_outputs = student_input_reshaped @ student_weight.t() + if student_bias is not None: + student_outputs = student_outputs + student_bias + + teacher_outputs = teacher_input_reshaped @ teacher_weight.t() + if teacher_bias is not None: + teacher_outputs = teacher_outputs + teacher_bias + + student_logits = student_outputs.view(student_batch_seq_len_size, -1).float() + teacher_logits = teacher_outputs.view(teacher_batch_seq_len_size, -1).float() + + if torch.all(target == self.ignore_index): + return torch.tensor(0.0) + + def cross_entropy_loss(logits, labels): + # Flatten the tokens + loss_fct = nn.CrossEntropyLoss(ignore_index=self.ignore_index) + logits = logits.view(-1, logits.shape[-1]) + labels = labels.view(-1) + # Enable model parallelism + labels = labels.to(logits.device) + loss = 
loss_fct(logits, labels) + return loss + + labels = target + ce_loss = cross_entropy_loss( + student_logits.view(-1, student_logits.shape[-1]), + labels.view(-1), + ) + + return ( + student_logits, + teacher_logits, + ce_loss, + ) + + def get_batch_loss_metrics( + self, + student_input: torch.FloatTensor, + student_weight: torch.FloatTensor, + teacher_input: torch.FloatTensor, + teacher_weight: torch.FloatTensor, + target: torch.LongTensor, + student_bias: torch.FloatTensor = None, + teacher_bias: torch.FloatTensor = None, + average_log_prob: bool = True, + ): + """Compute the distillation loss metrics for the given batch.""" + forward_output = self.concatenated_forward( + student_input, + student_weight, + teacher_input, + teacher_weight, + target, + student_bias, + teacher_bias, + average_log_prob, + ) + ( + student_logits, + teacher_logits, + hard_loss, + ) = forward_output + + soft_loss = self.distillation_loss(student_logits, teacher_logits) + # full loss + loss = hard_loss * (self.beta) + soft_loss.mean() * (1 - self.beta) + return loss From a81c95988569695781a66717bbdfb24ebfce1d98 Mon Sep 17 00:00:00 2001 From: Austin Liu Date: Mon, 2 Dec 2024 16:28:39 +0800 Subject: [PATCH 02/11] Format Signed-off-by: Austin Liu --- .../chunked_loss/fused_linear_distillation.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/src/liger_kernel/chunked_loss/fused_linear_distillation.py b/src/liger_kernel/chunked_loss/fused_linear_distillation.py index 87a0896df..937b1e0ec 100644 --- a/src/liger_kernel/chunked_loss/fused_linear_distillation.py +++ b/src/liger_kernel/chunked_loss/fused_linear_distillation.py @@ -29,14 +29,13 @@ def chunk_forward( ignore_index=-100, compute_ce_loss=True, ): - loss_mask = target_chunk != ignore_index - label_chunk = torch.where(loss_mask, target_chunk, 0) - # Student student_per_token_logits_chunk = student_input_chunk @ student_weight.t() if student_bias is not None: student_per_token_logits_chunk += student_bias - student_per_token_log_probs_chunk = F.log_softmax(student_per_token_logits_chunk.float(), dim=-1) + student_per_token_log_probs_chunk = F.log_softmax( + student_per_token_logits_chunk.float(), dim=-1 + ) # Teacher teacher_per_token_logits_chunk = teacher_input_chunk @ teacher_weight.t() @@ -47,7 +46,9 @@ def chunk_forward( ce_loss = 0.0 if compute_ce_loss: ce_loss = F.cross_entropy( - student_per_token_log_probs_chunk.view(-1, student_per_token_log_probs_chunk.shape[-1]), + student_per_token_log_probs_chunk.view( + -1, student_per_token_log_probs_chunk.shape[-1] + ), target_chunk.view(-1), reduction="sum", ignore_index=ignore_index, @@ -241,9 +242,7 @@ def _compute_loss( hard_loss = hard_loss / (full_target != ignore_index).sum() - soft_loss = distillation_loss_fn( - student_logits, teacher_logits, temperature - ) + soft_loss = distillation_loss_fn(student_logits, teacher_logits, temperature) soft_loss = soft_loss / (full_target != ignore_index).sum() loss = beta * hard_loss + (1 - beta) * soft_loss From e13994a9972ef1a481601c5895c93a7a1ba4c4b2 Mon Sep 17 00:00:00 2001 From: Austin Liu Date: Mon, 2 Dec 2024 17:11:28 +0800 Subject: [PATCH 03/11] Refactor beta Signed-off-by: Austin Liu --- .../chunked_loss/fused_linear_distillation.py | 19 ++++++++++++------- test/utils.py | 11 +++++------ 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/src/liger_kernel/chunked_loss/fused_linear_distillation.py b/src/liger_kernel/chunked_loss/fused_linear_distillation.py index 937b1e0ec..cd6a84c7e 100644 --- 
a/src/liger_kernel/chunked_loss/fused_linear_distillation.py +++ b/src/liger_kernel/chunked_loss/fused_linear_distillation.py @@ -69,7 +69,8 @@ def forward( loss_fn=None, chunk_size=1, ignore_index=-100, - beta=0.5, + weight_hard_loss=0.5, + weight_soft_loss=0.5, compute_ce_loss=True, temperature=1.0, compiled=True, @@ -91,7 +92,8 @@ def forward( chunk_size (int): Size of a chunk (# of batches of stacked chosen and rejected inputs). compute_ce_loss (bool): Whether to compute CE loss. ignore_index (int): Index to ignore for loss computation. - beta (float): Weight between soft and hard loss. + weight_hard_loss (float): Weight for hard loss. + weight_soft_loss (float): Weight for soft loss. compiled (bool): Whether to use torch compile for chunk accumulation. loss_kwargs (dict): Other possible arguments that a loss function might need """ @@ -106,7 +108,8 @@ def forward( LigerFusedLinearDistillationBase._compute_loss, distillation_loss_fn=loss_fn, ignore_index=ignore_index, - beta=beta, + weight_hard_loss=weight_hard_loss, + weight_soft_loss=weight_soft_loss, compute_ce_loss=compute_ce_loss, temperature=temperature, full_target=target, @@ -190,7 +193,7 @@ def backward(ctx, grad_output): grad_weight = grad_weight * grad_output grad_bias = grad_bias * grad_output if grad_bias is not None else None - return grad_input, grad_weight, None, grad_bias, None, None, None + return grad_input, grad_weight, None, grad_bias @staticmethod def _compute_loss( @@ -205,7 +208,8 @@ def _compute_loss( full_target=None, ignore_index=-100, temperature=1.0, - beta=0.5, + weight_hard_loss=0.5, + weight_soft_loss=0.5, compute_ce_loss=True, **loss_kwargs, ): @@ -222,7 +226,8 @@ def _compute_loss( teacher_bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,). full_target (torch.Tensor): Full target tensor. Shape: (chunk_size,). ignore_index (int): Index to ignore for loss computation. - beta (float): Weight between soft and hard loss. + weight_hard_loss (float): Weight for hard loss. + weight_soft_loss (float): Weight for soft loss. compute_ce_loss (bool): Whether to compute CE loss. loss_kwargs (dict): Additional arguments for the loss function. 
""" @@ -245,5 +250,5 @@ def _compute_loss( soft_loss = distillation_loss_fn(student_logits, teacher_logits, temperature) soft_loss = soft_loss / (full_target != ignore_index).sum() - loss = beta * hard_loss + (1 - beta) * soft_loss + loss = weight_hard_loss * hard_loss + weight_soft_loss * soft_loss return loss, (soft_loss, hard_loss, student_logits, teacher_logits) diff --git a/test/utils.py b/test/utils.py index de94039d5..fc5ba4944 100644 --- a/test/utils.py +++ b/test/utils.py @@ -525,11 +525,13 @@ def get_batch_loss_metrics( class HFDistillationLoss: def __init__( self, - beta: float = 0.5, + weight_hard_loss: float = 0.5, + weight_soft_loss: float = 0.5, ignore_index: int = -100, temperature: float = 1, ): - self.beta = beta + self.weight_hard_loss = weight_hard_loss + self.weight_soft_loss = weight_soft_loss self.ignore_index = ignore_index self.temperature = temperature @@ -547,7 +549,6 @@ def concatenated_forward( target: torch.LongTensor, student_bias: torch.FloatTensor = None, teacher_bias: torch.FloatTensor = None, - average_log_prob: bool = True, ) -> Tuple[ torch.FloatTensor, torch.FloatTensor, @@ -607,7 +608,6 @@ def get_batch_loss_metrics( target: torch.LongTensor, student_bias: torch.FloatTensor = None, teacher_bias: torch.FloatTensor = None, - average_log_prob: bool = True, ): """Compute the distillation loss metrics for the given batch.""" forward_output = self.concatenated_forward( @@ -618,7 +618,6 @@ def get_batch_loss_metrics( target, student_bias, teacher_bias, - average_log_prob, ) ( student_logits, @@ -628,5 +627,5 @@ def get_batch_loss_metrics( soft_loss = self.distillation_loss(student_logits, teacher_logits) # full loss - loss = hard_loss * (self.beta) + soft_loss.mean() * (1 - self.beta) + loss = self.weight_hard_loss * hard_loss + self.weight_soft_loss * soft_loss.mean() return loss From 720b5cbeac32811f5f43087707e0fefc4c49a10a Mon Sep 17 00:00:00 2001 From: Austin Liu Date: Mon, 2 Dec 2024 22:40:55 +0800 Subject: [PATCH 04/11] Remove imports Signed-off-by: Austin Liu --- src/liger_kernel/chunked_loss/__init__.py | 1 - src/liger_kernel/chunked_loss/functional.py | 2 -- 2 files changed, 3 deletions(-) diff --git a/src/liger_kernel/chunked_loss/__init__.py b/src/liger_kernel/chunked_loss/__init__.py index 5dcd2fbf3..238bdded9 100644 --- a/src/liger_kernel/chunked_loss/__init__.py +++ b/src/liger_kernel/chunked_loss/__init__.py @@ -1,5 +1,4 @@ from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOLoss # noqa: F401 from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOLoss # noqa: F401 -from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDLoss # noqa: F401 from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOLoss # noqa: F401 from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOLoss # noqa: F401 diff --git a/src/liger_kernel/chunked_loss/functional.py b/src/liger_kernel/chunked_loss/functional.py index 6bb742291..5a51d3f72 100644 --- a/src/liger_kernel/chunked_loss/functional.py +++ b/src/liger_kernel/chunked_loss/functional.py @@ -1,11 +1,9 @@ from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction -from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDFunction from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOFunction from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOFunction liger_fused_linear_orpo = LigerFusedLinearORPOFunction.apply 
liger_fused_linear_dpo = LigerFusedLinearDPOFunction.apply -liger_fused_linear_jsd = LigerFusedLinearJSDFunction.apply liger_fused_linear_cpo = LigerFusedLinearCPOFunction.apply liger_fused_linear_simpo = LigerFusedLinearSimPOFunction.apply From 17c5b333747173a88f9618c5de3487e4c7b10241 Mon Sep 17 00:00:00 2001 From: Austin Liu Date: Wed, 4 Dec 2024 22:56:27 +0800 Subject: [PATCH 05/11] Fix distill base `chunk_size` scaling Signed-off-by: Austin Liu Set default `chunk_size` to `1024` Signed-off-by: Austin Liu Rebase Signed-off-by: Austin Liu --- .../chunked_loss/fused_linear_distillation.py | 53 +++++++++---------- 1 file changed, 26 insertions(+), 27 deletions(-) diff --git a/src/liger_kernel/chunked_loss/fused_linear_distillation.py b/src/liger_kernel/chunked_loss/fused_linear_distillation.py index cd6a84c7e..46ef2cba9 100644 --- a/src/liger_kernel/chunked_loss/fused_linear_distillation.py +++ b/src/liger_kernel/chunked_loss/fused_linear_distillation.py @@ -30,31 +30,27 @@ def chunk_forward( compute_ce_loss=True, ): # Student - student_per_token_logits_chunk = student_input_chunk @ student_weight.t() + student_logits_chunk = student_input_chunk @ student_weight.t() if student_bias is not None: - student_per_token_logits_chunk += student_bias - student_per_token_log_probs_chunk = F.log_softmax( - student_per_token_logits_chunk.float(), dim=-1 - ) + student_logits_chunk += student_bias + student_log_probs_chunk = F.log_softmax(student_logits_chunk.float(), dim=-1) # Teacher - teacher_per_token_logits_chunk = teacher_input_chunk @ teacher_weight.t() + teacher_logits_chunk = teacher_input_chunk @ teacher_weight.t() if teacher_bias is not None: - teacher_per_token_logits_chunk += teacher_bias + teacher_logits_chunk += teacher_bias # The hard/task loss ce_loss = 0.0 if compute_ce_loss: - ce_loss = F.cross_entropy( - student_per_token_log_probs_chunk.view( - -1, student_per_token_log_probs_chunk.shape[-1] - ), + ce_loss = F.nll_loss( + student_log_probs_chunk.view(-1, student_log_probs_chunk.shape[-1]), target_chunk.view(-1), reduction="sum", ignore_index=ignore_index, ) - return student_per_token_logits_chunk, teacher_per_token_logits_chunk, ce_loss + return student_logits_chunk, teacher_logits_chunk, ce_loss @staticmethod def forward( @@ -81,24 +77,23 @@ def forward( Only need to compute gradients for student model. Args: - student_input (torch.Tensor): Student input tensor. Shape: (batch_size * seq_len, hidden_size). - student_weight (torch.Tensor): Student weight tensor. Shape: (vocab_size, hidden_size). - teacher_input (torch.Tensor): Teacher input tensor. Shape: (batch_size * seq_len, hidden_size). - teacher_weight (torch.Tensor): Teacher weight tensor. Shape: (vocab_size, hidden_size). + student_input (torch.Tensor): Student input tensor. Shape: (batch_size * seq_len, student_hidden_size). + student_weight (torch.Tensor): Student weight tensor. Shape: (vocab_size, student_hidden_size). + teacher_input (torch.Tensor): Teacher input tensor. Shape: (batch_size * seq_len, teacher_hidden_size). + teacher_weight (torch.Tensor): Teacher weight tensor. Shape: (vocab_size, teacher_hidden_size). target (torch.Tensor): Target truth label tensor. Shape: (batch_size * seq_len). student_bias (torch.Tensor, optional): Student bias tensor. Shape: (vocab_size,). teacher_bias (torch.Tensor, optional): Teacher bias tensor. Shape: (vocab_size,). loss_fn (callable): Loss function to compute the loss on a chunk of input/target. 
- chunk_size (int): Size of a chunk (# of batches of stacked chosen and rejected inputs). + chunk_size (int): Size of a chunk. compute_ce_loss (bool): Whether to compute CE loss. ignore_index (int): Index to ignore for loss computation. - weight_hard_loss (float): Weight for hard loss. - weight_soft_loss (float): Weight for soft loss. + weight_hard_loss (float): Weight for hard/task loss. + weight_soft_loss (float): Weight for soft/distillation loss. compiled (bool): Whether to use torch compile for chunk accumulation. loss_kwargs (dict): Other possible arguments that a loss function might need """ CHUNK_SIZE = chunk_size - grad_weight = torch.zeros_like(student_weight) grad_inputs = [] grad_bias = torch.zeros_like(student_bias) if student_bias is not None else None @@ -107,12 +102,13 @@ def forward( loss_func_to_call = partial( LigerFusedLinearDistillationBase._compute_loss, distillation_loss_fn=loss_fn, + full_target=target, + chunk_size=chunk_size, ignore_index=ignore_index, weight_hard_loss=weight_hard_loss, weight_soft_loss=weight_soft_loss, compute_ce_loss=compute_ce_loss, temperature=temperature, - full_target=target, **loss_kwargs, ) @@ -206,6 +202,7 @@ def _compute_loss( teacher_bias=None, distillation_loss_fn=None, full_target=None, + chunk_size=1, ignore_index=-100, temperature=1.0, weight_hard_loss=0.5, @@ -217,10 +214,10 @@ def _compute_loss( Compute the total loss for a chunk of input and target, while using an knowleedge distillation loss function. Args: distillation_loss_fn (callable): Loss function to compute the loss on a chunk of input/target. - student_input_chunk (torch.Tensor): Chunk of input tensor. Shape: (chunk_size, hidden_size). - student_weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size). - teacher_input_chunk (torch.Tensor): Chunk of input tensor. Shape: (chunk_size, hidden_size). - teacher_weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size). + student_input_chunk (torch.Tensor): Chunk of input tensor. Shape: (chunk_size, student_hidden_size). + student_weight (torch.Tensor): Weight tensor. Shape: (vocab_size, student_hidden_size). + teacher_input_chunk (torch.Tensor): Chunk of input tensor. Shape: (chunk_size, teacher_hidden_size). + teacher_weight (torch.Tensor): Weight tensor. Shape: (vocab_size, teacher_hidden_size). target_chunk (torch.Tensor): Chunk of target tensor. Shape: (chunk_size,). student_bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,). teacher_bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,). 
@@ -245,10 +242,12 @@ def _compute_loss( ) ) - hard_loss = hard_loss / (full_target != ignore_index).sum() + hard_loss /= (full_target != ignore_index).sum() soft_loss = distillation_loss_fn(student_logits, teacher_logits, temperature) - soft_loss = soft_loss / (full_target != ignore_index).sum() + soft_loss /= max( + 1, (full_target[: full_target.shape[0]] != ignore_index).sum() // chunk_size + ) loss = weight_hard_loss * hard_loss + weight_soft_loss * soft_loss return loss, (soft_loss, hard_loss, student_logits, teacher_logits) From e3dada062f5d68be278158f4cef621ea1767d690 Mon Sep 17 00:00:00 2001 From: Austin Liu Date: Wed, 4 Dec 2024 23:41:39 +0800 Subject: [PATCH 06/11] Fix chunk division Signed-off-by: Austin Liu --- .../chunked_loss/fused_linear_distillation.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/src/liger_kernel/chunked_loss/fused_linear_distillation.py b/src/liger_kernel/chunked_loss/fused_linear_distillation.py index 46ef2cba9..9954960cd 100644 --- a/src/liger_kernel/chunked_loss/fused_linear_distillation.py +++ b/src/liger_kernel/chunked_loss/fused_linear_distillation.py @@ -228,7 +228,7 @@ def _compute_loss( compute_ce_loss (bool): Whether to compute CE loss. loss_kwargs (dict): Additional arguments for the loss function. """ - student_logits, teacher_logits, hard_loss = ( + student_logits_chunk, teacher_logits_chunk, hard_loss = ( LigerFusedLinearDistillationBase.chunk_forward( student_input_chunk, student_weight, @@ -242,12 +242,9 @@ def _compute_loss( ) ) - hard_loss /= (full_target != ignore_index).sum() - - soft_loss = distillation_loss_fn(student_logits, teacher_logits, temperature) - soft_loss /= max( - 1, (full_target[: full_target.shape[0]] != ignore_index).sum() // chunk_size - ) + hard_loss /= full_target.shape[0] + soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk, temperature) + soft_loss /= (full_target.shape[0] // student_input_chunk.shape[0]) loss = weight_hard_loss * hard_loss + weight_soft_loss * soft_loss - return loss, (soft_loss, hard_loss, student_logits, teacher_logits) + return loss, (soft_loss, hard_loss, student_logits_chunk, teacher_logits_chunk) From 5662554942946ef43b1c1470cb72d30d4fea4fcd Mon Sep 17 00:00:00 2001 From: Austin Liu Date: Wed, 4 Dec 2024 23:42:52 +0800 Subject: [PATCH 07/11] Remove chunk arg Signed-off-by: Austin Liu --- src/liger_kernel/chunked_loss/fused_linear_distillation.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/liger_kernel/chunked_loss/fused_linear_distillation.py b/src/liger_kernel/chunked_loss/fused_linear_distillation.py index 9954960cd..41e7d82be 100644 --- a/src/liger_kernel/chunked_loss/fused_linear_distillation.py +++ b/src/liger_kernel/chunked_loss/fused_linear_distillation.py @@ -103,7 +103,6 @@ def forward( LigerFusedLinearDistillationBase._compute_loss, distillation_loss_fn=loss_fn, full_target=target, - chunk_size=chunk_size, ignore_index=ignore_index, weight_hard_loss=weight_hard_loss, weight_soft_loss=weight_soft_loss, @@ -202,7 +201,6 @@ def _compute_loss( teacher_bias=None, distillation_loss_fn=None, full_target=None, - chunk_size=1, ignore_index=-100, temperature=1.0, weight_hard_loss=0.5, From 7acb5ca8156b74ba7b68a856de75d7136dd6b755 Mon Sep 17 00:00:00 2001 From: Austin Liu Date: Thu, 5 Dec 2024 23:08:00 +0800 Subject: [PATCH 08/11] Fix `distillation_loss` arg typo Signed-off-by: Austin Liu --- test/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/utils.py b/test/utils.py index 
fc5ba4944..69f8a8c40 100644 --- a/test/utils.py +++ b/test/utils.py @@ -536,7 +536,7 @@ def __init__( self.temperature = temperature @abstractmethod - def distillation_loss(self, student_logps, teacher_logps): + def distillation_loss(self, student_logits, teacher_logits): """Abstract method for computing distillation loss.""" pass From e381569d6c88283490887d76de70a824bb463b0c Mon Sep 17 00:00:00 2001 From: shivam15s Date: Sat, 7 Dec 2024 00:01:20 +0000 Subject: [PATCH 09/11] use torch no grad and change normalization term --- .../chunked_loss/fused_linear_distillation.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/liger_kernel/chunked_loss/fused_linear_distillation.py b/src/liger_kernel/chunked_loss/fused_linear_distillation.py index 41e7d82be..929bdff34 100644 --- a/src/liger_kernel/chunked_loss/fused_linear_distillation.py +++ b/src/liger_kernel/chunked_loss/fused_linear_distillation.py @@ -36,9 +36,10 @@ def chunk_forward( student_log_probs_chunk = F.log_softmax(student_logits_chunk.float(), dim=-1) # Teacher - teacher_logits_chunk = teacher_input_chunk @ teacher_weight.t() - if teacher_bias is not None: - teacher_logits_chunk += teacher_bias + with torch.no_grad(): + teacher_logits_chunk = teacher_input_chunk @ teacher_weight.t() + if teacher_bias is not None: + teacher_logits_chunk += teacher_bias # The hard/task loss ce_loss = 0.0 @@ -63,7 +64,7 @@ def forward( student_bias=None, teacher_bias=None, loss_fn=None, - chunk_size=1, + chunk_size=1024, ignore_index=-100, weight_hard_loss=0.5, weight_soft_loss=0.5, @@ -243,6 +244,7 @@ def _compute_loss( hard_loss /= full_target.shape[0] soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk, temperature) - soft_loss /= (full_target.shape[0] // student_input_chunk.shape[0]) + soft_loss /= full_target.shape[0] + loss = weight_hard_loss * hard_loss + weight_soft_loss * soft_loss return loss, (soft_loss, hard_loss, student_logits_chunk, teacher_logits_chunk) From 8aa842a442ff0e013df09eb562257f4343dc50db Mon Sep 17 00:00:00 2001 From: shivam15s Date: Sat, 7 Dec 2024 00:33:25 +0000 Subject: [PATCH 10/11] rearrange fns for readability --- .../chunked_loss/fused_linear_distillation.py | 116 +++++----- .../chunked_loss/fused_linear_preference.py | 202 +++++++++--------- 2 files changed, 159 insertions(+), 159 deletions(-) diff --git a/src/liger_kernel/chunked_loss/fused_linear_distillation.py b/src/liger_kernel/chunked_loss/fused_linear_distillation.py index 929bdff34..11ae767f6 100644 --- a/src/liger_kernel/chunked_loss/fused_linear_distillation.py +++ b/src/liger_kernel/chunked_loss/fused_linear_distillation.py @@ -53,6 +53,64 @@ def chunk_forward( return student_logits_chunk, teacher_logits_chunk, ce_loss + @staticmethod + def _compute_loss( + student_input_chunk, + student_weight, + teacher_input_chunk, + teacher_weight, + target_chunk, + student_bias=None, + teacher_bias=None, + distillation_loss_fn=None, + full_target=None, + ignore_index=-100, + temperature=1.0, + weight_hard_loss=0.5, + weight_soft_loss=0.5, + compute_ce_loss=True, + **loss_kwargs, + ): + """ + Compute the total loss for a chunk of input and target, while using an knowleedge distillation loss function. + Args: + distillation_loss_fn (callable): Loss function to compute the loss on a chunk of input/target. + student_input_chunk (torch.Tensor): Chunk of input tensor. Shape: (chunk_size, student_hidden_size). + student_weight (torch.Tensor): Weight tensor. Shape: (vocab_size, student_hidden_size). 
+ teacher_input_chunk (torch.Tensor): Chunk of input tensor. Shape: (chunk_size, teacher_hidden_size). + teacher_weight (torch.Tensor): Weight tensor. Shape: (vocab_size, teacher_hidden_size). + target_chunk (torch.Tensor): Chunk of target tensor. Shape: (chunk_size,). + student_bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,). + teacher_bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,). + full_target (torch.Tensor): Full target tensor. Shape: (chunk_size,). + ignore_index (int): Index to ignore for loss computation. + weight_hard_loss (float): Weight for hard loss. + weight_soft_loss (float): Weight for soft loss. + compute_ce_loss (bool): Whether to compute CE loss. + loss_kwargs (dict): Additional arguments for the loss function. + """ + student_logits_chunk, teacher_logits_chunk, hard_loss = ( + LigerFusedLinearDistillationBase.chunk_forward( + student_input_chunk, + student_weight, + teacher_input_chunk, + teacher_weight, + target_chunk, + student_bias=student_bias, + teacher_bias=teacher_bias, + ignore_index=ignore_index, + compute_ce_loss=compute_ce_loss, + ) + ) + + hard_loss /= full_target.shape[0] + + soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk, temperature) + soft_loss /= full_target.shape[0] + + loss = weight_hard_loss * hard_loss + weight_soft_loss * soft_loss + return loss, (soft_loss, hard_loss, student_logits_chunk, teacher_logits_chunk) + @staticmethod def forward( ctx, @@ -190,61 +248,3 @@ def backward(ctx, grad_output): grad_bias = grad_bias * grad_output if grad_bias is not None else None return grad_input, grad_weight, None, grad_bias - - @staticmethod - def _compute_loss( - student_input_chunk, - student_weight, - teacher_input_chunk, - teacher_weight, - target_chunk, - student_bias=None, - teacher_bias=None, - distillation_loss_fn=None, - full_target=None, - ignore_index=-100, - temperature=1.0, - weight_hard_loss=0.5, - weight_soft_loss=0.5, - compute_ce_loss=True, - **loss_kwargs, - ): - """ - Compute the total loss for a chunk of input and target, while using an knowleedge distillation loss function. - Args: - distillation_loss_fn (callable): Loss function to compute the loss on a chunk of input/target. - student_input_chunk (torch.Tensor): Chunk of input tensor. Shape: (chunk_size, student_hidden_size). - student_weight (torch.Tensor): Weight tensor. Shape: (vocab_size, student_hidden_size). - teacher_input_chunk (torch.Tensor): Chunk of input tensor. Shape: (chunk_size, teacher_hidden_size). - teacher_weight (torch.Tensor): Weight tensor. Shape: (vocab_size, teacher_hidden_size). - target_chunk (torch.Tensor): Chunk of target tensor. Shape: (chunk_size,). - student_bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,). - teacher_bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,). - full_target (torch.Tensor): Full target tensor. Shape: (chunk_size,). - ignore_index (int): Index to ignore for loss computation. - weight_hard_loss (float): Weight for hard loss. - weight_soft_loss (float): Weight for soft loss. - compute_ce_loss (bool): Whether to compute CE loss. - loss_kwargs (dict): Additional arguments for the loss function. 
- """ - student_logits_chunk, teacher_logits_chunk, hard_loss = ( - LigerFusedLinearDistillationBase.chunk_forward( - student_input_chunk, - student_weight, - teacher_input_chunk, - teacher_weight, - target_chunk, - student_bias=student_bias, - teacher_bias=teacher_bias, - ignore_index=ignore_index, - compute_ce_loss=compute_ce_loss, - ) - ) - - hard_loss /= full_target.shape[0] - - soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk, temperature) - soft_loss /= full_target.shape[0] - - loss = weight_hard_loss * hard_loss + weight_soft_loss * soft_loss - return loss, (soft_loss, hard_loss, student_logits_chunk, teacher_logits_chunk) diff --git a/src/liger_kernel/chunked_loss/fused_linear_preference.py b/src/liger_kernel/chunked_loss/fused_linear_preference.py index c31cbba8b..26ae38a3d 100644 --- a/src/liger_kernel/chunked_loss/fused_linear_preference.py +++ b/src/liger_kernel/chunked_loss/fused_linear_preference.py @@ -64,6 +64,103 @@ def chunk_forward( chosen_nll_loss, ) + @staticmethod + def _compute_loss( + input_chunk, + weight, + target_chunk, + bias=None, + preference_loss_fn=None, + full_target=None, + ignore_index=-100, + alpha=1.0, + beta=0.1, + compute_nll_loss=True, + use_ref_model=False, + ref_weight=None, + ref_bias=None, + **loss_kwargs, + ): + """ + Compute the total loss for a chunk of input and target, while using an alignment/preference loss function. + Args: + preference_loss_fn (callable): Loss function to compute the loss on a chunk of input/target. + input_chunk (torch.Tensor): Chunk of input tensor. Shape: (2 * chunk_size, sequence_length, hidden_size). + weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size). + target_chunk (torch.Tensor): Chunk of target tensor. Shape: (2 * chunk_size, sequence_length). + bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,). + full_target (torch.Tensor): Full target tensor. Shape: (batch_size, sequence_length). + ignore_index (int): Index to ignore for loss computation. + alpha (float): Weight for the NLL loss. + beta (float): Weight for the odds ratio loss. + compute_nll_loss (bool): Whether to compute NLL loss. + use_ref_model (bool): Whether to use a reference model for the alignment loss. + ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size). + ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,). + loss_kwargs (dict): Additional arguments for the loss function. 
+ """ + ( + chosen_logps, + rejected_logps, + chosen_logits, + rejected_logits, + chosen_nll_loss, + ) = LigerFusedLinearPreferenceBase.chunk_forward( + input_chunk, + weight, + target_chunk, + bias=bias, + ignore_index=ignore_index, + compute_nll_loss=compute_nll_loss, + ) + chosen_nll_loss = ( + chosen_nll_loss + / (full_target[: full_target.shape[0] // 2] != ignore_index).sum() + ) + chosen_logits_mean = chosen_logits.sum() / ( + full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0] + ) + rejected_logits_mean = rejected_logits.sum() / ( + full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0] + ) + + if use_ref_model: + with torch.no_grad(): + ( + ref_chosen_logps, + ref_rejected_logps, + ref_chosen_logits, + ref_rejected_logits, + ref_chosen_nll_loss, + ) = LigerFusedLinearPreferenceBase.chunk_forward( + input_chunk, + ref_weight, + target_chunk, + ref_bias, + ignore_index=ignore_index, + compute_nll_loss=False, # We don't need NLL loss for the reference model + ) + loss_kwargs["ref_chosen_logps"] = ref_chosen_logps + loss_kwargs["ref_rejected_logps"] = ref_rejected_logps + + preference_loss_outputs = preference_loss_fn( + chosen_logps, rejected_logps, full_target, beta=beta, **loss_kwargs + ) + if isinstance(preference_loss_outputs, tuple): + preference_loss, *aux_outputs = preference_loss_outputs + else: + preference_loss, aux_outputs = preference_loss_outputs, [] + + loss = alpha * chosen_nll_loss - preference_loss + return_vars = ( + chosen_logps, + rejected_logps, + chosen_logits_mean, + rejected_logits_mean, + chosen_nll_loss, + ) + return loss, (*return_vars, *aux_outputs) + @staticmethod def forward( ctx, @@ -134,7 +231,7 @@ def forward( **loss_kwargs, ) - def accumulate_helper(input_chunk, target_chunk): + def accumulate_core(input_chunk, target_chunk): if bias is not None: return torch.func.grad_and_value( loss_func_to_call, argnums=(0, 1, 3), has_aux=True @@ -156,7 +253,7 @@ def accumulate_chunk(input_chunk, target_chunk): chunk_nll_loss, *aux_outputs, ), - ) = accumulate_helper(input_chunk, target_chunk) + ) = accumulate_core(input_chunk, target_chunk) grad_bias.add_(chunk_grad_bias) # accumulate bias gradient else: (chunk_grad_input, chunk_grad_weight), ( @@ -169,7 +266,7 @@ def accumulate_chunk(input_chunk, target_chunk): chunk_nll_loss, *aux_outputs, ), - ) = accumulate_helper(input_chunk, target_chunk) + ) = accumulate_core(input_chunk, target_chunk) grad_weight.add_(chunk_grad_weight) loss_acc.add_(chunk_loss) @@ -199,7 +296,7 @@ def accumulate_chunk(input_chunk, target_chunk): return chunk_grad_input if compiled: - accumulate_helper = torch.compile(accumulate_helper) + accumulate_core = torch.compile(accumulate_core) len_chosen = target.shape[0] // 2 chunks = max(1, _input.shape[0] // (2 * CHUNK_SIZE)) @@ -270,100 +367,3 @@ def backward(ctx, *grad_output): grad_bias = grad_bias * grad_output[0][0] if grad_bias is not None else None return grad_input, grad_weight, None, grad_bias, None, None, None - - @staticmethod - def _compute_loss( - input_chunk, - weight, - target_chunk, - bias=None, - preference_loss_fn=None, - full_target=None, - ignore_index=-100, - alpha=1.0, - beta=0.1, - compute_nll_loss=True, - use_ref_model=False, - ref_weight=None, - ref_bias=None, - **loss_kwargs, - ): - """ - Compute the total loss for a chunk of input and target, while using an alignment/preference loss function. - Args: - preference_loss_fn (callable): Loss function to compute the loss on a chunk of input/target. 
- input_chunk (torch.Tensor): Chunk of input tensor. Shape: (2 * chunk_size, sequence_length, hidden_size). - weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size). - target_chunk (torch.Tensor): Chunk of target tensor. Shape: (2 * chunk_size, sequence_length). - bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,). - full_target (torch.Tensor): Full target tensor. Shape: (batch_size, sequence_length). - ignore_index (int): Index to ignore for loss computation. - alpha (float): Weight for the NLL loss. - beta (float): Weight for the odds ratio loss. - compute_nll_loss (bool): Whether to compute NLL loss. - use_ref_model (bool): Whether to use a reference model for the alignment loss. - ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size). - ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,). - loss_kwargs (dict): Additional arguments for the loss function. - """ - ( - chosen_logps, - rejected_logps, - chosen_logits, - rejected_logits, - chosen_nll_loss, - ) = LigerFusedLinearPreferenceBase.chunk_forward( - input_chunk, - weight, - target_chunk, - bias=bias, - ignore_index=ignore_index, - compute_nll_loss=compute_nll_loss, - ) - chosen_nll_loss = ( - chosen_nll_loss - / (full_target[: full_target.shape[0] // 2] != ignore_index).sum() - ) - chosen_logits_mean = chosen_logits.sum() / ( - full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0] - ) - rejected_logits_mean = rejected_logits.sum() / ( - full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0] - ) - - if use_ref_model: - with torch.no_grad(): - ( - ref_chosen_logps, - ref_rejected_logps, - ref_chosen_logits, - ref_rejected_logits, - ref_chosen_nll_loss, - ) = LigerFusedLinearPreferenceBase.chunk_forward( - input_chunk, - ref_weight, - target_chunk, - ref_bias, - ignore_index=ignore_index, - compute_nll_loss=False, # We don't need NLL loss for the reference model - ) - loss_kwargs["ref_chosen_logps"] = ref_chosen_logps - loss_kwargs["ref_rejected_logps"] = ref_rejected_logps - - preference_loss_outputs = preference_loss_fn( - chosen_logps, rejected_logps, full_target, beta=beta, **loss_kwargs - ) - if isinstance(preference_loss_outputs, tuple): - preference_loss, *aux_outputs = preference_loss_outputs - else: - preference_loss, aux_outputs = preference_loss_outputs, [] - - loss = alpha * chosen_nll_loss - preference_loss - return_vars = ( - chosen_logps, - rejected_logps, - chosen_logits_mean, - rejected_logits_mean, - chosen_nll_loss, - ) - return loss, (*return_vars, *aux_outputs) From 356152524501da66863a5ba09b87daaa0232e8e0 Mon Sep 17 00:00:00 2001 From: shivam15s Date: Sat, 7 Dec 2024 01:21:02 +0000 Subject: [PATCH 11/11] add no grad in tests --- test/utils.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/test/utils.py b/test/utils.py index 69f8a8c40..4c2fd5195 100644 --- a/test/utils.py +++ b/test/utils.py @@ -567,9 +567,10 @@ def concatenated_forward( if student_bias is not None: student_outputs = student_outputs + student_bias - teacher_outputs = teacher_input_reshaped @ teacher_weight.t() - if teacher_bias is not None: - teacher_outputs = teacher_outputs + teacher_bias + with torch.no_grad(): + teacher_outputs = teacher_input_reshaped @ teacher_weight.t() + if teacher_bias is not None: + teacher_outputs = teacher_outputs + teacher_bias student_logits = student_outputs.view(student_batch_seq_len_size, -1).float() teacher_logits = teacher_outputs.view(teacher_batch_seq_len_size, -1).float()