diff --git a/benchmark/scripts/benchmark_distill_jsd_loss.py.py b/benchmark/scripts/benchmark_distill_jsd_loss.py.py new file mode 100644 index 000000000..ab1605f57 --- /dev/null +++ b/benchmark/scripts/benchmark_distill_jsd_loss.py.py @@ -0,0 +1,236 @@ +import os +import sys + +import torch +import triton +from utils import ( + QUANTILES, + SingleBenchmarkRunInput, + SingleBenchmarkRunOutput, + _test_memory, + parse_benchmark_script_args, + run_benchmarks, +) + +from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDFunction +from liger_kernel.utils import infer_device + +device = infer_device() + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) + + +class TorchJSDLoss(torch.nn.Module): + def __init__( + self, + H: int, + V: int, + dtype: torch.dtype, + beta: float = 0.5, + ignore_index: int = -100, + temperature: float = 1.0, + bias: bool = False, + ): + from test.chunked_loss.test_jsd_loss import HFJSDLoss + + super().__init__() + self.student_lin = torch.nn.Linear( + in_features=H // 2, out_features=V, bias=bias, dtype=dtype + ) + self.teacher_lin = torch.nn.Linear( + in_features=H, out_features=V, bias=bias, dtype=dtype + ) + self.jsd_loss = HFJSDLoss( + ignore_index=ignore_index, beta=beta + ).get_batch_loss_metrics + + def forward(self, student, teacher, target): + return self.jsd_loss( + student, + self.student_lin.weight, + teacher, + self.teacher_lin.weight, + target, + ) + + +class LigerJSDLoss(torch.nn.Module): + def __init__( + self, + H: int, + V: int, + dtype: torch.dtype, + beta: float = 0.5, + ignore_index: int = -100, + temperature: float = 1.0, + bias: bool = False, + ): + super().__init__() + self.student_lin = torch.nn.Linear( + in_features=H // 2, out_features=V, bias=bias, dtype=dtype + ) + self.teacher_lin = torch.nn.Linear( + in_features=H, out_features=V, bias=bias, dtype=dtype + ) + self.beta = beta + self.ignore_index = ignore_index + self.temperature = temperature + self.jsd_loss = LigerFusedLinearJSDFunction.apply + + def forward(self, student, teacher, target): + return self.jsd_loss( + student, + self.student_lin.weight, + teacher, + self.teacher_lin.weight, + target, + self.beta, + ) + + +def bench_memory_jsd_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + BT = input.x + H = input.extra_benchmark_config["H"] + V = input.extra_benchmark_config["V"] + dtype = input.extra_benchmark_config["dtype"] + bias = input.extra_benchmark_config["bias"] + beta = input.extra_benchmark_config["beta"] + ignore_index = input.extra_benchmark_config["ignore_index"] + provider = input.kernel_provider + + torch_jsd_loss = TorchJSDLoss( + H=H, V=V, dtype=dtype, ignore_index=ignore_index, bias=bias, beta=beta + ).to(device) + liger_jsd_loss = LigerJSDLoss( + H=H, V=V, dtype=dtype, ignore_index=ignore_index, bias=bias, beta=beta + ).to(device) + + _tensor = torch.rand(BT, H // 2, device=device, dtype=dtype) + student_input1 = _tensor.detach().clone().requires_grad_(True) + student_input2 = _tensor.detach().clone().requires_grad_(True) + + teacher_input = torch.rand(BT, H, device=device, dtype=dtype) + + target = torch.randint(0, V, (BT,), device=device, dtype=torch.long) + + def fwd(): + if provider == "liger": + return liger_jsd_loss(student_input1, teacher_input, target) + elif provider == "torch": + return torch_jsd_loss(student_input2, teacher_input, target) + + def full(): + y = fwd() + y.backward() + + mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES) + return 
SingleBenchmarkRunOutput( + y_20=mem_20, + y_50=mem_50, + y_80=mem_80, + ) + + +def bench_speed_jsd_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + BT = input.x + H = input.extra_benchmark_config["H"] + V = input.extra_benchmark_config["V"] + dtype = input.extra_benchmark_config["dtype"] + bias = input.extra_benchmark_config["bias"] + beta = input.extra_benchmark_config["beta"] + ignore_index = input.extra_benchmark_config["ignore_index"] + provider = input.kernel_provider + mode = input.kernel_operation_mode + + torch_jsd_loss = TorchJSDLoss( + H=H, V=V, dtype=dtype, ignore_index=ignore_index, bias=bias, beta=beta + ).to(device) + liger_jsd_loss = LigerJSDLoss( + H=H, V=V, dtype=dtype, ignore_index=ignore_index, bias=bias, beta=beta + ).to(device) + + _tensor = torch.rand(BT, H // 2, device=device, dtype=dtype) + student_input1 = _tensor.detach().clone().requires_grad_(True) + student_input2 = _tensor.detach().clone().requires_grad_(True) + + teacher_input = torch.rand(BT, H, device=device, dtype=dtype) + + target = torch.randint(0, V, (BT,), device=device, dtype=torch.long) + + def fwd(): + if provider == "liger": + return liger_jsd_loss(student_input1, teacher_input, target) + elif provider == "torch": + return torch_jsd_loss(student_input2, teacher_input, target) + + if mode == "forward": + ms_50, ms_20, ms_80 = triton.testing.do_bench( + fwd, + rep=100, + quantiles=QUANTILES, + ) + elif mode == "backward": + y = fwd() + ms_50, ms_20, ms_80 = triton.testing.do_bench( + lambda: y.backward(retain_graph=True), + grad_to_none=[student_input1, student_input2], + rep=100, + quantiles=QUANTILES, + ) + elif mode == "full": + + def full(): + y = fwd() + y.backward() + + ms_50, ms_20, ms_80 = triton.testing.do_bench( + full, + rep=100, + quantiles=QUANTILES, + ) + + return SingleBenchmarkRunOutput( + y_20=ms_20, + y_50=ms_50, + y_80=ms_80, + ) + + +if __name__ == "__main__": + args = parse_benchmark_script_args() + + common_configs = { + "kernel_name": "distill_jsd_loss", + "x_name": "B", + "x_label": "Batch Size (B)", + "x_values": [2**i for i in range(1, 4)], + "kernel_providers": ["liger", "torch"], + "extra_benchmark_configs": [ + { + "H": 4096, + "V": 128256, + "mode": "forward", + "dtype": torch.bfloat16, + "bias": False, + "beta": 0.5, + "ignore_index": -100, + } + ], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_jsd_loss, + kernel_operation_modes=["forward", "full"], + metric_name="speed", + metric_unit="ms", + **common_configs + ) + + run_benchmarks( + bench_test_fn=bench_memory_jsd_loss, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs + ) diff --git a/src/liger_kernel/chunked_loss/__init__.py b/src/liger_kernel/chunked_loss/__init__.py index 238bdded9..5dcd2fbf3 100644 --- a/src/liger_kernel/chunked_loss/__init__.py +++ b/src/liger_kernel/chunked_loss/__init__.py @@ -1,4 +1,5 @@ from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOLoss # noqa: F401 from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOLoss # noqa: F401 +from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDLoss # noqa: F401 from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOLoss # noqa: F401 from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOLoss # noqa: F401 diff --git a/src/liger_kernel/chunked_loss/functional.py b/src/liger_kernel/chunked_loss/functional.py index 5a51d3f72..f0bd8a8dd 100644 --- a/src/liger_kernel/chunked_loss/functional.py 
+++ b/src/liger_kernel/chunked_loss/functional.py @@ -1,5 +1,6 @@ from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction +from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDFunction from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOFunction from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOFunction @@ -7,3 +8,5 @@ liger_fused_linear_dpo = LigerFusedLinearDPOFunction.apply liger_fused_linear_cpo = LigerFusedLinearCPOFunction.apply liger_fused_linear_simpo = LigerFusedLinearSimPOFunction.apply + +liger_fused_linear_jsd = LigerFusedLinearJSDFunction.apply diff --git a/src/liger_kernel/chunked_loss/fused_linear_distillation.py b/src/liger_kernel/chunked_loss/fused_linear_distillation.py new file mode 100644 index 000000000..0fd23c484 --- /dev/null +++ b/src/liger_kernel/chunked_loss/fused_linear_distillation.py @@ -0,0 +1,278 @@ +from abc import abstractmethod +from functools import partial + +import torch +from torch.nn import functional as F + + +class LigerFusedLinearDistillationBase(torch.autograd.Function): + + @abstractmethod + def distillation_loss_fn(student_logps, teacher_logps): + """ + Compute distillation loss. + Args: + student_logps (torch.Tensor): Avg log probabilities of student inputs. Shape: (batch_size, hidden_size,). + teacher_logps (torch.Tensor): Avg log probabilities of teacher inputs. Shape: (batch_size, hidden_size,). + """ + raise NotImplementedError("Distillation loss function must be implemented.") + + @staticmethod + def chunk_forward( + student_input_chunk, + student_weight, + teacher_input_chunk, + teacher_weight, + target_chunk, + student_bias=None, + teacher_bias=None, + ignore_index=-100, + compute_ce_loss=True, + ): + # Student + student_logits_chunk = student_input_chunk @ student_weight.t() + if student_bias is not None: + student_logits_chunk = student_logits_chunk + student_bias + student_log_probs_chunk = F.log_softmax(student_logits_chunk.float(), dim=-1) + + ce_loss = 0.0 + if compute_ce_loss: + # The hard/task loss + ce_loss = F.cross_entropy( + student_log_probs_chunk.view(-1, student_log_probs_chunk.shape[-1]), + target_chunk.view(-1), + reduction="sum", + ignore_index=ignore_index, + ) + + loss_mask = target_chunk != ignore_index + label_chunk = torch.where(loss_mask, target_chunk, 0) + + student_average_log_prob = torch.zeros_like(loss_mask, dtype=torch.float) + student_per_token_logps = student_log_probs_chunk.gather( + -1, label_chunk.unsqueeze(-1) + ).squeeze(-1) + + loss_mask_sum = loss_mask.sum(-1) + valid_mask = loss_mask_sum > 0 + + if valid_mask.any(): + student_average_log_prob[valid_mask] = ( + student_per_token_logps * loss_mask + ).sum(-1)[valid_mask] / loss_mask_sum[valid_mask] + + # Teacher + teacher_logits_chunk = teacher_input_chunk @ teacher_weight.t() + if teacher_bias is not None: + teacher_logits_chunk = teacher_logits_chunk + teacher_bias + teacher_log_probs_chunk = F.log_softmax(teacher_logits_chunk.float(), dim=-1) + + teacher_average_log_prob = torch.zeros_like(loss_mask, dtype=torch.float) + teacher_per_token_logps = teacher_log_probs_chunk.gather( + -1, label_chunk.unsqueeze(-1) + ).squeeze(-1) + + if valid_mask.any(): + teacher_average_log_prob[valid_mask] = ( + teacher_per_token_logps * loss_mask + ).sum(-1)[valid_mask] / loss_mask_sum[valid_mask] + + return student_average_log_prob, teacher_average_log_prob, ce_loss + + @staticmethod + def forward( + ctx, + student_input, + 
student_weight, + teacher_input, + teacher_weight, + target, + student_bias=None, + teacher_bias=None, + loss_fn=None, + chunk_size=1, + ignore_index=-100, + beta=0.5, + compute_ce_loss=True, + temperature=1.0, + compiled=True, + **loss_kwargs, + ): + """ + Base class for fused linear layer with distillation loss. + Only need to compute gradients for student model. + + Args: + student_input (torch.Tensor): Student input tensor. Shape: (batch_size, seq_len, hidden_size). + student_weight (torch.Tensor): Student weight tensor. Shape: (vocab_size, hidden_size). + teacher_input (torch.Tensor): Teacher input tensor. Shape: (batch_size, seq_len, hidden_size). + teacher_weight (torch.Tensor): Teacher weight tensor. Shape: (vocab_size, hidden_size). + target (torch.Tensor): Target truth label tensor. Shape: (batch_size, seq_len). + student_bias (torch.Tensor, optional): Student bias tensor. Shape: (vocab_size,). + teacher_bias (torch.Tensor, optional): Teacher bias tensor. Shape: (vocab_size,). + loss_fn (callable): Loss function to compute the loss on a chunk of input/target. + chunk_size (int): Size of a chunk (# of batches of stacked chosen and rejected inputs). + compute_ce_loss (bool): Whether to compute NLL loss. + ignore_index (int): Index to ignore for loss computation. + beta (float): Weight between soft and hard loss. + compiled (bool): Whether to use torch compile for chunk accumulation. + loss_kwargs (dict): Other possible arguments that a loss function might need + """ + # TODO: Tune CHUNK_SIZE to fully utilize the GPU + CHUNK_SIZE = chunk_size + + grad_weight = torch.zeros_like(student_weight) + grad_inputs = [] + grad_bias = torch.zeros_like(student_bias) if student_bias is not None else None + loss_acc = torch.zeros((), device=student_input.device) + + loss_func_to_call = partial( + LigerFusedLinearDistillationBase._compute_loss, + distillation_loss_fn=loss_fn, + ignore_index=ignore_index, + beta=beta, + compute_ce_loss=compute_ce_loss, + temperature=temperature, + full_target=target, + **loss_kwargs, + ) + + def accumulate_chunk(student_input_chunk, teacher_input_chunk, target_chunk): + if student_bias is not None: + (chunk_grad_input, chunk_grad_weight, chunk_grad_bias), ( + chunk_loss, + ( + chunk_distillation_loss, + chunk_ce_loss, + chunk_student_logps, + chunk_teacher_logps, + ), + ) = torch.func.grad_and_value( + loss_func_to_call, argnums=(0, 1, 5), has_aux=True + )( + student_input_chunk, + student_weight, + teacher_input_chunk, + teacher_weight, + target_chunk, + student_bias, + teacher_bias, + ) + grad_bias.add_(chunk_grad_bias) + else: + (chunk_grad_input, chunk_grad_weight), ( + chunk_loss, + ( + chunk_distillation_loss, + chunk_ce_loss, + chunk_student_logps, + chunk_teacher_logps, + ), + ) = torch.func.grad_and_value( + loss_func_to_call, argnums=(0, 1), has_aux=True + )( + student_input_chunk, + student_weight, + teacher_input_chunk, + teacher_weight, + target_chunk, + student_bias, + teacher_bias, + ) + grad_weight.add_(chunk_grad_weight) + loss_acc.add_(chunk_loss) + return chunk_grad_input + + if compiled: + accumulate_chunk = torch.compile(accumulate_chunk) + + student_chunks = max(1, student_input.shape[0] // (2 * CHUNK_SIZE)) + teacher_chunks = max(1, teacher_input.shape[0] // (2 * CHUNK_SIZE)) + target_chunks = max(1, target.shape[0] // (2 * CHUNK_SIZE)) + + _student_input_chunks = torch.chunk(student_input, chunks=student_chunks, dim=0) + _teacher_input_chunks = torch.chunk(teacher_input, chunks=teacher_chunks, dim=0) + _target_chunks = torch.chunk(target, 
chunks=target_chunks, dim=0)
+
+        for student_input_chunk, teacher_input_chunk, target_chunk in zip(
+            _student_input_chunks, _teacher_input_chunks, _target_chunks
+        ):
+            grad_input = accumulate_chunk(
+                student_input_chunk, teacher_input_chunk, target_chunk
+            )
+            grad_inputs.append(grad_input)
+
+        ctx.save_for_backward(
+            torch.cat(grad_inputs, dim=0),
+            grad_weight,
+            grad_bias,
+        )
+        return loss_acc
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        grad_input, grad_weight, grad_bias = ctx.saved_tensors
+        if torch.ne(grad_output, torch.tensor(1.0, device=grad_output.device)):
+            grad_input = grad_input * grad_output
+            grad_weight = grad_weight * grad_output
+            grad_bias = grad_bias * grad_output if grad_bias is not None else None
+
+        return grad_input, grad_weight, None, grad_bias, None, None, None
+
+    @staticmethod
+    def _compute_loss(
+        student_input_chunk,
+        student_weight,
+        teacher_input_chunk,
+        teacher_weight,
+        target_chunk,
+        student_bias=None,
+        teacher_bias=None,
+        distillation_loss_fn=None,
+        full_target=None,
+        ignore_index=-100,
+        temperature=1.0,
+        beta=0.5,
+        compute_ce_loss=True,
+        **loss_kwargs,
+    ):
+        """
+        Compute the total loss for a chunk of input and target, using a knowledge distillation loss function.
+        Args:
+            distillation_loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
+            student_input_chunk (torch.Tensor): Chunk of input tensor. Shape: (chunk_size, sequence_length, hidden_size).
+            student_weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
+            teacher_input_chunk (torch.Tensor): Chunk of input tensor. Shape: (chunk_size, sequence_length, hidden_size).
+            teacher_weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
+            target_chunk (torch.Tensor): Chunk of target tensor. Shape: (chunk_size, sequence_length).
+            student_bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
+            teacher_bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
+            full_target (torch.Tensor): Full target tensor. Shape: (batch_size, sequence_length).
+            ignore_index (int): Index to ignore for loss computation.
+            temperature (float): Temperature for softening probability distributions.
+            beta (float): Interpolation weight between the hard CE loss (beta) and the soft distillation loss (1 - beta).
+            compute_ce_loss (bool): Whether to compute CE loss.
+            loss_kwargs (dict): Additional arguments for the loss function.
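+        Returns:
+            torch.Tensor: Scalar chunk loss computed as beta * ce_loss + (1 - beta) * distillation_loss,
+            with (distillation_loss, ce_loss, student_logps, teacher_logps) returned as auxiliary outputs.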
+ """ + student_logps, teacher_logps, ce_loss = ( + LigerFusedLinearDistillationBase.chunk_forward( + student_input_chunk, + student_weight, + teacher_input_chunk, + teacher_weight, + target_chunk, + student_bias=student_bias, + teacher_bias=teacher_bias, + ignore_index=ignore_index, + compute_ce_loss=compute_ce_loss, + ) + ) + + ce_loss = ce_loss / (full_target != ignore_index).sum() + + distillation_loss = distillation_loss_fn( + student_logps, teacher_logps, temperature + ) + distillation_loss = distillation_loss / (full_target.shape[0]) + + loss = beta * ce_loss + (1 - beta) * distillation_loss + return loss, (distillation_loss, ce_loss, student_logps, teacher_logps) diff --git a/src/liger_kernel/chunked_loss/jsd_loss.py b/src/liger_kernel/chunked_loss/jsd_loss.py new file mode 100644 index 000000000..eaa5594d9 --- /dev/null +++ b/src/liger_kernel/chunked_loss/jsd_loss.py @@ -0,0 +1,153 @@ +import torch +import torch.nn.functional as F + +from liger_kernel.chunked_loss.fused_linear_distillation import ( + LigerFusedLinearDistillationBase, +) + + +class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase): + @staticmethod + def distillation_loss_fn(student_logps, teacher_logps, temperature): + """ + Compute Jensen-Shannon Divergence loss between student and teacher distributions. + Args: + student_logps (torch.Tensor): Log probabilities from student model (Raw logits after log_softmax) + teacher_logps (torch.Tensor): Log probabilities from teacher model (Raw logits after log_softmax) + temperature (float): Temperature for softening probability distributions + Returns: + torch.Tensor: Jensen-Shannon Divergence loss + """ + # TODO: should incorporate with (high) temperature scaling on raw logits + + # For instance, + # Scale logits by temperature + # student_logits = student_logits / temperature + # teacher_logits = teacher_logits / temperature + # Convert to probabilities + # student_probs = F.softmax(student_logits, dim=-1) + # teacher_probs = F.softmax(teacher_logits, dim=-1) + + log_mean_probs = torch.log( + (torch.exp(student_logps) + torch.exp(teacher_logps)) / 2 + ) + + student_kl = F.kl_div( + log_mean_probs, student_logps, reduction="batchmean", log_target=True + ) + + teacher_kl = F.kl_div( + log_mean_probs, teacher_logps, reduction="batchmean", log_target=True + ) + + # JSD is the average of the KL divergences + jsd_loss = (student_kl + teacher_kl) / 2 + return jsd_loss + + @staticmethod + def forward( + ctx, + student_input: torch.Tensor, + student_weight: torch.Tensor, + teacher_input: torch.Tensor, + teacher_weight: torch.Tensor, + true_labels: torch.LongTensor, + beta: float = 0.5, + ignore_index: int = -100, + temperature: float = 1.0, + compiled: bool = True, + ): + """ + Fused linear layer with JSD distillation loss. + Args: + student_input (torch.Tensor): Student input tensor. Shape: (BT, H_s) + student_weight (torch.Tensor): Student weight tensor. Shape: (V, H_s) + teacher_input (torch.Tensor): Teacher input tensor. Shape: (BT, H_t) + teacher_weight (torch.Tensor): Teacher weight tensor. Shape: (V, H_t) + true_labels (torch.LongTensor): Target tensor. 
Shape: (BT,) + beta (float): Weight for distillation loss + ignore_index (int): Index to ignore in loss computation + temperature (float): Temperature for softening distributions + compiled (bool): Whether to use torch compile + Returns: + torch.Tensor: Computed loss + """ + return LigerFusedLinearDistillationBase.forward( + ctx=ctx, + student_input=student_input, + student_weight=student_weight, + teacher_input=teacher_input, + teacher_weight=teacher_weight, + target=true_labels, + loss_fn=LigerFusedLinearJSDFunction.distillation_loss_fn, + beta=beta, + ignore_index=ignore_index, + temperature=temperature, + compiled=compiled, + ) + + @staticmethod + def backward(ctx, grad_output): + grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output)[:4] + + return (*grads, None, None, None, None, None) + + +class LigerFusedLinearJSDLoss(torch.nn.Module): + """ + Fused linear layer with JSD distillation loss. + """ + + def __init__( + self, + beta: float = 0.5, + ignore_index: int = -100, + temperature: float = 1.0, + compiled: bool = False, + ): + """ + Args: + beta (float): Weight for distillation loss + ignore_index (int): Index to ignore in the loss + temperature (float): Temperature for softening distributions + compiled (bool): Whether to use torch compile + """ + super().__init__() + assert temperature != 0, "Temperature cannot be 0." + self.beta = beta + self.ignore_index = ignore_index + self.temperature = temperature + self.compiled = compiled + + def forward( + self, + student_input: torch.Tensor, + student_weight: torch.Tensor, + teacher_input: torch.Tensor, + teacher_weight: torch.Tensor, + true_labels: torch.LongTensor, + ) -> torch.Tensor: + """ + Compute the JSD distillation loss. + + Args: + student_input (torch.Tensor): Student input tensor + student_weight (torch.Tensor): Student weight tensor + teacher_input (torch.Tensor): Teacher input tensor + teacher_weight (torch.Tensor): Teacher weight tensor + true_labels (torch.LongTensor): Target labels tensor + + Returns: + torch.Tensor: Computed loss + """ + return LigerFusedLinearJSDFunction.apply( + student_input, + student_weight, + teacher_input, + teacher_weight, + true_labels, + self.beta, + self.ignore_index, + self.temperature, + self.compiled, + ) diff --git a/test/chunked_loss/test_jsd_loss.py b/test/chunked_loss/test_jsd_loss.py new file mode 100644 index 000000000..4c4f4cedb --- /dev/null +++ b/test/chunked_loss/test_jsd_loss.py @@ -0,0 +1,295 @@ +from test.utils import HFDistillationLoss, assert_verbose_allclose, set_seed + +import pytest +import torch +import torch.nn.functional as F + +from liger_kernel.chunked_loss import LigerFusedLinearJSDLoss +from liger_kernel.chunked_loss.functional import liger_fused_linear_jsd +from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDFunction +from liger_kernel.utils import infer_device + +device = infer_device() + +# set random seed globally +set_seed() + + +class HFJSDLoss(HFDistillationLoss): + """ + Naive implementation of a distillation loss using Jensen-Shannon Divergence (JSD). + """ + + def __init__( + self, temperature: float = 1.0, ignore_index: int = -100, beta: float = 0.5 + ): + super().__init__(ignore_index=ignore_index, beta=beta) + self.temperature = temperature + + def distillation_loss( + self, + student_logps: torch.FloatTensor, + teacher_logps: torch.FloatTensor, + ) -> torch.FloatTensor: + """ + Compute Jensen-Shannon Divergence loss between student and teacher distributions. 
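+        Here JSD(P, Q) = 0.5 * KL(P || M) + 0.5 * KL(Q || M) with M = 0.5 * (P + Q),
+        where P and Q are the student and teacher probability distributions.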
+ Args: + student_logps (torch.Tensor): Log probabilities from student model (Raw logits after log_softmax) + teacher_logps (torch.Tensor): Log probabilities from teacher model (Raw logits after log_softmax) + temperature (float): Temperature for softening probability distributions + Returns: + torch.Tensor: Jensen-Shannon Divergence loss + """ + # TODO: should incorporate with (high) temperature scaling on raw logits + + # For instance, + # Scale logits by temperature + # student_logits = student_logits / self.temperature + # teacher_logits = teacher_logits / self.temperature + # Convert to probabilities + # student_probs = F.softmax(student_logits, dim=-1) + # teacher_probs = F.softmax(teacher_logits, dim=-1) + + log_mean_probs = torch.log( + (torch.exp(student_logps) + torch.exp(teacher_logps)) / 2 + ) + + student_kl = F.kl_div( + log_mean_probs, student_logps, reduction="batchmean", log_target=True + ) + + teacher_kl = F.kl_div( + log_mean_probs, teacher_logps, reduction="batchmean", log_target=True + ) + + # JSD is the average of the KL divergences + jsd_loss = (student_kl + teacher_kl) / 2 + return jsd_loss + + +class TorchLMHeadJSD(torch.nn.Module): + """Ground truth implementation of the linear fused with torch based jsd loss. + :param H: hidden size + :param V: vocab size + :param temperature: softmax temperature + :param beta: jsd beta + """ + + def __init__( + self, + H: int, + V: int, + dtype: torch.dtype, + device: torch.device, + beta: float = 0.5, + ignore_index: int = -100, + temperature: float = 1.0, + ): + super().__init__() + # Create smaller student weight + self.student_lin = torch.nn.Linear( + in_features=H // 2, out_features=V, bias=False, dtype=dtype, device=device + ) + self.teacher_lin = torch.nn.Linear( + in_features=H, out_features=V, bias=False, dtype=dtype, device=device + ) + self.jsd = HFJSDLoss( + ignore_index=ignore_index, beta=beta + ).get_batch_loss_metrics + self.temperature = temperature + + def forward(self, student_input, teacher_input, target): + + jsd_loss = self.jsd( + student_input, + self.student_lin.weight, + teacher_input, + self.teacher_lin.weight, + target, + ) + return jsd_loss + + +class LigerLMHeadJSD(torch.nn.Module): + def __init__( + self, + H: int, + V: int, + dtype: torch.dtype, + device: torch.device, + beta: float = 0.5, + ignore_index: int = -100, + temperature: float = 1.0, + ): + super().__init__() + # Create smaller student weight + self.student_lin = torch.nn.Linear( + in_features=H // 2, out_features=V, bias=False, dtype=dtype, device=device + ) + self.teacher_lin = torch.nn.Linear( + in_features=H, out_features=V, bias=False, dtype=dtype, device=device + ) + self.chunked_jsd = LigerFusedLinearJSDLoss(beta=beta, ignore_index=ignore_index) + self.temperature = temperature + + def forward(self, student_input, teacher_input, target): + return self.chunked_jsd( + student_input, + self.student_lin.weight, + teacher_input, + self.teacher_lin.weight, + target, + ) + + +############################################################################# +# Test the correctness of the fused linear JSD +############################################################################# + + +@pytest.mark.parametrize( + "B, T, H, V", + [ + (8, 128, 1024, 4096), + (3, 47, 31, 123), # random shape + ], +) +@pytest.mark.parametrize( + "scalar, dtype, atol, rtol", + [ + (1.0, torch.bfloat16, 5e-2, 5e-2), + (1.0, torch.float32, 1e-4, 5e-3), + ], +) +@pytest.mark.parametrize( + "temperature, beta", + [ + (1.0, 0.5), + (2.0, 0.1), + (1.0, 0.0), + (1.0, 
1.0), + ], +) +def test_correctness(B, T, H, V, scalar, dtype, beta, temperature, atol, rtol): + torch_lm_head_jsd = TorchLMHeadJSD( + H=H, + V=V, + dtype=dtype, + device=device, + temperature=temperature, + beta=beta, + ) + liger_lm_head_jsd = LigerLMHeadJSD( + H=H, + V=V, + dtype=dtype, + device=device, + temperature=temperature, + beta=beta, + ) + + torch_lm_head_jsd.student_lin.weight.data = ( + liger_lm_head_jsd.student_lin.weight.data + ) = torch.rand(V, H // 2, device=device, dtype=dtype) + torch_lm_head_jsd.teacher_lin.weight.data = ( + liger_lm_head_jsd.teacher_lin.weight.data + ) = torch.rand(V, H, device=device, dtype=dtype) + + _tensor = torch.rand(B * T, H // 2, device=device, dtype=dtype) * scalar + student_input1 = _tensor.detach().clone().requires_grad_(True) + student_input2 = _tensor.detach().clone().requires_grad_(True) + + teacher_input = torch.rand(B * T, H, device=device, dtype=dtype) * scalar + + target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long) + + with torch.autograd.detect_anomaly(): + output1 = torch_lm_head_jsd(student_input1, teacher_input, target) + output2 = liger_lm_head_jsd(student_input2, teacher_input, target) + + assert_verbose_allclose(output1, output2, atol=atol, rtol=rtol) + + output1.backward() + output2.backward() + + assert_verbose_allclose( + student_input1.grad, student_input2.grad, atol=atol, rtol=rtol + ) + + assert_verbose_allclose( + torch_lm_head_jsd.student_lin.weight.grad, + liger_lm_head_jsd.student_lin.weight.grad, + atol=atol, + rtol=rtol, + ) + + +@pytest.mark.parametrize( + "B, T, H, V", + [ + (2, 2, 8, 8), + (9, 7, 41, 41), + ], +) +@pytest.mark.parametrize( + "scalar, dtype, atol, rtol", + [ + (1.0, torch.bfloat16, 5e-2, 5e-2), + (1.0, torch.float32, 1e-4, 5e-3), + ], +) +@pytest.mark.parametrize( + "temperature, beta, ignore_index", [(1.0, 0.5, -100), (2.0, 0.1, 42)] +) +def test_correctness_functional( + B, T, H, V, scalar, dtype, beta, ignore_index, temperature, atol, rtol +): + _weight = torch.rand(V, H // 2, device=device, dtype=dtype) + student_weight1 = _weight.detach().clone().requires_grad_(True) + student_weight2 = _weight.detach().clone().requires_grad_(True) + teacher_weight = torch.rand(V, H, device=device, dtype=dtype) + + _tensor = torch.rand(B * T, H // 2, device=device, dtype=dtype) * scalar + student_input1 = _tensor.detach().clone().requires_grad_(True) + student_input2 = _tensor.detach().clone().requires_grad_(True) + teacher_input = torch.rand(B * T, H, device=device, dtype=dtype) * scalar + + label = torch.randint(0, V, (B * T,), device=device, dtype=torch.long) + + num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item() + indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] + label[indices_to_assign] = ignore_index + + output1 = liger_fused_linear_jsd( + student_input1, + student_weight1, + teacher_input, + teacher_weight, + label, + beta, + ignore_index, + temperature, + ) + output2 = LigerFusedLinearJSDFunction.apply( + student_input2, + student_weight2, + teacher_input, + teacher_weight, + label, + beta, + ignore_index, + temperature, + ) + + assert_verbose_allclose(output1, output2, atol=atol, rtol=rtol) + + output1.backward() + output2.backward() + + assert_verbose_allclose( + student_input1.grad, student_input2.grad, atol=atol, rtol=rtol + ) + + assert_verbose_allclose( + student_weight1.grad, student_weight2.grad, atol=atol, rtol=rtol + ) diff --git a/test/utils.py b/test/utils.py index e8383d659..91c2954c3 100644 --- a/test/utils.py +++ b/test/utils.py @@ 
-372,8 +372,6 @@ def get_batch_logps(
             logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
             labels: Labels for which to compute the log probabilities. Label tokens with a value of ignore_index are ignored. Shape: (batch_size, sequence_length)
             average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
-            is_encoder_decoder: Whether the model is an encoder-decoder model.
-
         Returns:
             A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
         """
@@ -509,3 +507,144 @@ def get_batch_loss_metrics(
         # full loss
         loss = policy_nll_loss * self.alpha - losses.mean()
         return loss
+
+
+# TODO: Should this be named HFDistillationLoss or NaiveDistillationLoss? There is no Hugging Face implementation of a distillation loss.
+class HFDistillationLoss:
+    def __init__(
+        self,
+        beta: float = 0.5,
+        ignore_index: int = -100,
+    ):
+        self.beta = beta
+        self.ignore_index = ignore_index
+
+    @abstractmethod
+    def distillation_loss(self, student_logps, teacher_logps):
+        """Abstract method for computing distillation loss."""
+        pass
+
+    def get_batch_logps(
+        self,
+        logits: torch.FloatTensor,
+        labels: torch.LongTensor,
+        average_log_prob: bool = False,
+    ) -> torch.FloatTensor:
+        """Compute log probabilities for a batch."""
+        loss_mask = labels != self.ignore_index
+        safe_labels = torch.where(loss_mask, labels, 0)
+
+        log_probs = logits.log_softmax(dim=-1)
+        per_token_logps = torch.gather(
+            log_probs, dim=-1, index=safe_labels.unsqueeze(-1)
+        ).squeeze(-1)
+
+        per_token_logps = per_token_logps * loss_mask
+
+        if average_log_prob:
+            return per_token_logps.sum(-1) / loss_mask.sum(-1)
+        else:
+            return per_token_logps.sum(-1)
+
+    def concatenated_forward(
+        self,
+        student_input: torch.FloatTensor,
+        student_weight: torch.FloatTensor,
+        teacher_input: torch.FloatTensor,
+        teacher_weight: torch.FloatTensor,
+        target: torch.LongTensor,
+        student_bias: torch.FloatTensor = None,
+        teacher_bias: torch.FloatTensor = None,
+        average_log_prob: bool = True,
+    ) -> Tuple[
+        torch.FloatTensor,
+        torch.FloatTensor,
+        torch.FloatTensor,
+        torch.FloatTensor,
+        torch.FloatTensor,
+    ]:
+        """Compute forward pass for both student and teacher models."""
+
+        student_batch_seq_len_size, student_hidden_size = student_input.shape
+        student_input_reshaped = student_input.view(-1, student_hidden_size)
+        teacher_batch_seq_len_size, teacher_hidden_size = teacher_input.shape
+        teacher_input_reshaped = teacher_input.view(-1, teacher_hidden_size)
+
+        student_outputs = student_input_reshaped @ student_weight.t()
+        if student_bias is not None:
+            student_outputs = student_outputs + student_bias
+
+        teacher_outputs = teacher_input_reshaped @ teacher_weight.t()
+        if teacher_bias is not None:
+            teacher_outputs = teacher_outputs + teacher_bias
+
+        student_logits = student_outputs.view(student_batch_seq_len_size, -1).float()
+        teacher_logits = teacher_outputs.view(teacher_batch_seq_len_size, -1).float()
+
+        if torch.all(target == self.ignore_index):
+            return torch.tensor(0.0)
+
+        def cross_entropy_loss(logits, labels):
+            # Flatten the tokens
+            loss_fct = nn.CrossEntropyLoss(ignore_index=self.ignore_index)
+            logits = logits.view(-1, logits.shape[-1])
+            labels = labels.view(-1)
+            # Enable model parallelism
+            labels = labels.to(logits.device)
+            loss = loss_fct(logits, labels)
+            return loss
+
+        labels = target
+        ce_loss = cross_entropy_loss(
+            student_logits.view(-1,
student_logits.shape[-1]), + labels.view(-1), + ) + + student_logps = self.get_batch_logps( + student_logits, target, average_log_prob=average_log_prob + ) + teacher_logps = self.get_batch_logps( + teacher_logits, target, average_log_prob=average_log_prob + ) + + return ( + student_logps, + teacher_logps, + student_logits, + teacher_logits, + ce_loss, + ) + + def get_batch_loss_metrics( + self, + student_input: torch.FloatTensor, + student_weight: torch.FloatTensor, + teacher_input: torch.FloatTensor, + teacher_weight: torch.FloatTensor, + target: torch.LongTensor, + student_bias: torch.FloatTensor = None, + teacher_bias: torch.FloatTensor = None, + average_log_prob: bool = True, + ): + """Compute the distillation loss and other metrics for the given batch.""" + forward_output = self.concatenated_forward( + student_input, + student_weight, + teacher_input, + teacher_weight, + target, + student_bias, + teacher_bias, + average_log_prob, + ) + ( + student_logps, + teacher_logps, + student_logits, + teacher_logits, + student_ce_loss, + ) = forward_output + + distill_loss = self.distillation_loss(student_logps, teacher_logps) + loss = student_ce_loss * (self.beta) + distill_loss.mean() * (1 - self.beta) + return loss
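A minimal usage sketch of the new LigerFusedLinearJSDLoss module (illustrative only, not part of the patch). The shapes and tensor names are assumptions chosen for the example: a student hidden size of 512, a teacher hidden size of 1024, a shared vocabulary of 32,000, and BT = 64 flattened tokens.

import torch

from liger_kernel.chunked_loss import LigerFusedLinearJSDLoss

BT, H_STUDENT, H_TEACHER, V = 64, 512, 1024, 32000

# Last hidden states of the student and teacher, plus their lm_head weights.
student_hidden = torch.randn(BT, H_STUDENT, requires_grad=True)
student_lm_head = torch.randn(V, H_STUDENT, requires_grad=True)
teacher_hidden = torch.randn(BT, H_TEACHER)
teacher_lm_head = torch.randn(V, H_TEACHER)
labels = torch.randint(0, V, (BT,), dtype=torch.long)

# beta interpolates between the hard CE loss and the soft JSD loss;
# ignore_index masks out padding / prompt tokens.
loss_fn = LigerFusedLinearJSDLoss(beta=0.5, ignore_index=-100, temperature=1.0)
loss = loss_fn(student_hidden, student_lm_head, teacher_hidden, teacher_lm_head, labels)
loss.backward()  # gradients flow only to the student input and student lm_head weight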