Introduce Knowledge Distillation Base #417

Closed
wants to merge 12 commits
250 changes: 250 additions & 0 deletions src/liger_kernel/chunked_loss/fused_linear_distillation.py
@@ -0,0 +1,250 @@
from abc import abstractmethod
from functools import partial

import torch
from torch.nn import functional as F


class LigerFusedLinearDistillationBase(torch.autograd.Function):

@abstractmethod
def distillation_loss_fn(student_logits, teacher_logits, temperature):
"""
Compute distillation loss.
Args:
student_logits (torch.Tensor): Raw logits of student tokens. Shape: (batch_size * seq_len, vocab_size).
teacher_logits (torch.Tensor): Raw logits of teacher tokens. Shape: (batch_size * seq_len, vocab_size).
temperature (float): Temperature to control the softness of the probability distributions.
"""
raise NotImplementedError("Distillation loss function must be implemented.")

@staticmethod
def chunk_forward(
student_input_chunk,
student_weight,
teacher_input_chunk,
teacher_weight,
target_chunk,
student_bias=None,
teacher_bias=None,
ignore_index=-100,
compute_ce_loss=True,
):
# Student
student_logits_chunk = student_input_chunk @ student_weight.t()
if student_bias is not None:
student_logits_chunk += student_bias
student_log_probs_chunk = F.log_softmax(student_logits_chunk.float(), dim=-1)

# Teacher
with torch.no_grad():
teacher_logits_chunk = teacher_input_chunk @ teacher_weight.t()
if teacher_bias is not None:
teacher_logits_chunk += teacher_bias

# The hard/task loss
ce_loss = 0.0
if compute_ce_loss:
ce_loss = F.nll_loss(
student_log_probs_chunk.view(-1, student_log_probs_chunk.shape[-1]),
target_chunk.view(-1),
reduction="sum",
ignore_index=ignore_index,
)

return student_logits_chunk, teacher_logits_chunk, ce_loss

@staticmethod
def forward(
ctx,
student_input,
student_weight,
teacher_input,
teacher_weight,
target,
student_bias=None,
teacher_bias=None,
loss_fn=None,
chunk_size=1024,
ignore_index=-100,
weight_hard_loss=0.5,
weight_soft_loss=0.5,
compute_ce_loss=True,
temperature=1.0,
compiled=True,
**loss_kwargs,
):
"""
Base class for fused linear layer with distillation loss.
Gradients are computed only for the student model.

Args:
student_input (torch.Tensor): Student input tensor. Shape: (batch_size * seq_len, student_hidden_size).
student_weight (torch.Tensor): Student weight tensor. Shape: (vocab_size, student_hidden_size).
teacher_input (torch.Tensor): Teacher input tensor. Shape: (batch_size * seq_len, teacher_hidden_size).
teacher_weight (torch.Tensor): Teacher weight tensor. Shape: (vocab_size, teacher_hidden_size).
target (torch.Tensor): Ground truth label tensor. Shape: (batch_size * seq_len,).
student_bias (torch.Tensor, optional): Student bias tensor. Shape: (vocab_size,).
teacher_bias (torch.Tensor, optional): Teacher bias tensor. Shape: (vocab_size,).
loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
chunk_size (int): Size of a chunk.
compute_ce_loss (bool): Whether to compute CE loss.
ignore_index (int): Index to ignore for loss computation.
weight_hard_loss (float): Weight for hard/task loss.
weight_soft_loss (float): Weight for soft/distillation loss.
temperature (float): Temperature to control the softness of the probability distributions.
compiled (bool): Whether to use torch compile for chunk accumulation.
loss_kwargs (dict): Other possible arguments that a loss function might need
"""
CHUNK_SIZE = chunk_size
grad_weight = torch.zeros_like(student_weight)
grad_inputs = []
grad_bias = torch.zeros_like(student_bias) if student_bias is not None else None
loss_acc = torch.zeros((), device=student_input.device)

loss_func_to_call = partial(
LigerFusedLinearDistillationBase._compute_loss,
distillation_loss_fn=loss_fn,
full_target=target,
ignore_index=ignore_index,
weight_hard_loss=weight_hard_loss,
weight_soft_loss=weight_soft_loss,
compute_ce_loss=compute_ce_loss,
temperature=temperature,
**loss_kwargs,
)

def accumulate_chunk(student_input_chunk, teacher_input_chunk, target_chunk):
if student_bias is not None:
(chunk_grad_input, chunk_grad_weight, chunk_grad_bias), (
chunk_loss,
(
chunk_soft_loss,
chunk_hard_loss,
chunk_student_logits,
chunk_teacher_logits,
),
) = torch.func.grad_and_value(
loss_func_to_call, argnums=(0, 1, 5), has_aux=True
)(
student_input_chunk,
student_weight,
teacher_input_chunk,
teacher_weight,
target_chunk,
student_bias,
teacher_bias,
)
grad_bias.add_(chunk_grad_bias)
else:
(chunk_grad_input, chunk_grad_weight), (
chunk_loss,
(
chunk_soft_loss,
chunk_hard_loss,
chunk_student_logits,
chunk_teacher_logits,
),
) = torch.func.grad_and_value(
loss_func_to_call, argnums=(0, 1), has_aux=True
)(
student_input_chunk,
student_weight,
teacher_input_chunk,
teacher_weight,
target_chunk,
student_bias,
teacher_bias,
)
grad_weight.add_(chunk_grad_weight)
loss_acc.add_(chunk_loss)
return chunk_grad_input

if compiled:
accumulate_chunk = torch.compile(accumulate_chunk)

num_chunks = max(1, student_input.shape[0] // CHUNK_SIZE)
_student_input_chunks = torch.chunk(student_input, chunks=num_chunks, dim=0)
_teacher_input_chunks = torch.chunk(teacher_input, chunks=num_chunks, dim=0)
_target_chunks = torch.chunk(target, chunks=num_chunks, dim=0)

for student_input_chunk, teacher_input_chunk, target_chunk in zip(
_student_input_chunks, _teacher_input_chunks, _target_chunks
):
grad_input = accumulate_chunk(
student_input_chunk, teacher_input_chunk, target_chunk
)
grad_inputs.append(grad_input)

ctx.save_for_backward(
torch.cat(grad_inputs, dim=0),
grad_weight,
grad_bias,
)
return loss_acc

@staticmethod
def backward(ctx, grad_output):
grad_input, grad_weight, grad_bias = ctx.saved_tensors
if torch.ne(grad_output, torch.tensor(1.0, device=grad_output.device)):
grad_input = grad_input * grad_output
grad_weight = grad_weight * grad_output
grad_bias = grad_bias * grad_output if grad_bias is not None else None

return grad_input, grad_weight, None, grad_bias

@staticmethod
def _compute_loss(
student_input_chunk,
student_weight,
teacher_input_chunk,
teacher_weight,
target_chunk,
student_bias=None,
teacher_bias=None,
distillation_loss_fn=None,
full_target=None,
ignore_index=-100,
temperature=1.0,
weight_hard_loss=0.5,
weight_soft_loss=0.5,
compute_ce_loss=True,
**loss_kwargs,
):
"""
Compute the total loss for a chunk of input and target, using a knowledge distillation loss function.
Args:
distillation_loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
student_input_chunk (torch.Tensor): Chunk of input tensor. Shape: (chunk_size, student_hidden_size).
student_weight (torch.Tensor): Weight tensor. Shape: (vocab_size, student_hidden_size).
teacher_input_chunk (torch.Tensor): Chunk of input tensor. Shape: (chunk_size, teacher_hidden_size).
teacher_weight (torch.Tensor): Weight tensor. Shape: (vocab_size, teacher_hidden_size).
target_chunk (torch.Tensor): Chunk of target tensor. Shape: (chunk_size,).
student_bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
teacher_bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
full_target (torch.Tensor): Full target tensor. Shape: (batch_size * seq_len,).
ignore_index (int): Index to ignore for loss computation.
temperature (float): Temperature to control the softness of the probability distributions.
weight_hard_loss (float): Weight for hard loss.
weight_soft_loss (float): Weight for soft loss.
compute_ce_loss (bool): Whether to compute CE loss.
loss_kwargs (dict): Additional arguments for the loss function.
"""
student_logits_chunk, teacher_logits_chunk, hard_loss = (
LigerFusedLinearDistillationBase.chunk_forward(
student_input_chunk,
student_weight,
teacher_input_chunk,
teacher_weight,
target_chunk,
student_bias=student_bias,
teacher_bias=teacher_bias,
ignore_index=ignore_index,
compute_ce_loss=compute_ce_loss,
)
)

hard_loss /= full_target.shape[0]

soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk, temperature)
soft_loss /= full_target.shape[0]

loss = weight_hard_loss * hard_loss + weight_soft_loss * soft_loss
return loss, (soft_loss, hard_loss, student_logits_chunk, teacher_logits_chunk)
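
To make the intended usage of LigerFusedLinearDistillationBase concrete, here is a minimal sketch of a subclass wired to a temperature-scaled KL-divergence loss. It is illustrative only and not part of the diff above; the class name ToyKLDistillationFunction, the choice of forward KL, the temperature value, and the toy shapes are assumptions.

# Illustrative sketch, not part of this PR: a hypothetical subclass that plugs
# a temperature-scaled KL(teacher || student) loss into the base class.
import torch
import torch.nn.functional as F

from liger_kernel.chunked_loss.fused_linear_distillation import (
    LigerFusedLinearDistillationBase,
)


class ToyKLDistillationFunction(LigerFusedLinearDistillationBase):
    @staticmethod
    def distillation_loss_fn(student_logits, teacher_logits, temperature=1.0):
        # Soften both distributions, then sum the KL divergence over the chunk;
        # the base class divides by the full batch_size * seq_len afterwards.
        student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
        teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)
        return F.kl_div(student_log_probs, teacher_probs, reduction="sum")

    @staticmethod
    def forward(
        ctx,
        student_input,
        student_weight,
        teacher_input,
        teacher_weight,
        target,
        student_bias=None,
        teacher_bias=None,
    ):
        return LigerFusedLinearDistillationBase.forward(
            ctx,
            student_input,
            student_weight,
            teacher_input,
            teacher_weight,
            target,
            student_bias,
            teacher_bias,
            loss_fn=ToyKLDistillationFunction.distillation_loss_fn,
            temperature=2.0,
            compiled=False,  # keep the sketch simple; the base defaults to torch.compile
        )

    @staticmethod
    def backward(ctx, grad_output):
        grad_input, grad_weight, _, grad_bias = LigerFusedLinearDistillationBase.backward(
            ctx, grad_output
        )
        # Gradients flow only to the student-side tensors; the teacher tensors
        # and the target receive None.
        return grad_input, grad_weight, None, None, None, grad_bias, None


# Toy usage with made-up shapes: (batch*seq, hidden) inputs, (vocab, hidden) weights.
B_T, H, V = 8, 16, 32
student_input = torch.randn(B_T, H, requires_grad=True)
student_weight = torch.randn(V, H, requires_grad=True)
teacher_input = torch.randn(B_T, H)
teacher_weight = torch.randn(V, H)
target = torch.randint(0, V, (B_T,))

loss = ToyKLDistillationFunction.apply(
    student_input, student_weight, teacher_input, teacher_weight, target
)
loss.backward()  # populates student_input.grad and student_weight.grad only

Because _compute_loss already divides both the hard and soft losses by full_target.shape[0], returning a per-chunk sum from distillation_loss_fn is what makes the accumulated result a per-token mean over the whole batch.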
109 changes: 109 additions & 0 deletions test/utils.py
@@ -520,3 +520,112 @@ def get_batch_loss_metrics(
policy_nll_loss,
)
return loss, (*return_vars, *aggregated_aux_outputs)


class HFDistillationLoss:
def __init__(
self,
weight_hard_loss: float = 0.5,
weight_soft_loss: float = 0.5,
ignore_index: int = -100,
temperature: float = 1,
):
self.weight_hard_loss = weight_hard_loss
self.weight_soft_loss = weight_soft_loss
self.ignore_index = ignore_index
self.temperature = temperature

@abstractmethod
def distillation_loss(self, student_logits, teacher_logits):
"""Abstract method for computing distillation loss."""
pass

def concatenated_forward(
self,
student_input: torch.FloatTensor,
student_weight: torch.FloatTensor,
teacher_input: torch.FloatTensor,
teacher_weight: torch.FloatTensor,
target: torch.LongTensor,
student_bias: torch.FloatTensor = None,
teacher_bias: torch.FloatTensor = None,
) -> Tuple[
torch.FloatTensor,
torch.FloatTensor,
torch.FloatTensor,
]:
"""Compute forward pass for both student and teacher models."""

student_batch_seq_len_size, student_hidden_size = student_input.shape
student_input_reshaped = student_input.view(-1, student_hidden_size)
teacher_batch_seq_len_size, teacher_hidden_size = teacher_input.shape
teacher_input_reshaped = teacher_input.view(-1, teacher_hidden_size)

student_outputs = student_input_reshaped @ student_weight.t()
if student_bias is not None:
student_outputs = student_outputs + student_bias

teacher_outputs = teacher_input_reshaped @ teacher_weight.t()
if teacher_bias is not None:
teacher_outputs = teacher_outputs + teacher_bias

student_logits = student_outputs.view(student_batch_seq_len_size, -1).float()
teacher_logits = teacher_outputs.view(teacher_batch_seq_len_size, -1).float()

if torch.all(target == self.ignore_index):
return torch.tensor(0.0)

def cross_entropy_loss(logits, labels):
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss(ignore_index=self.ignore_index)
logits = logits.view(-1, logits.shape[-1])
labels = labels.view(-1)
# Enable model parallelism
labels = labels.to(logits.device)
loss = loss_fct(logits, labels)
return loss

labels = target
ce_loss = cross_entropy_loss(
student_logits.view(-1, student_logits.shape[-1]),
labels.view(-1),
)

return (
student_logits,
teacher_logits,
ce_loss,
)

def get_batch_loss_metrics(
self,
student_input: torch.FloatTensor,
student_weight: torch.FloatTensor,
teacher_input: torch.FloatTensor,
teacher_weight: torch.FloatTensor,
target: torch.LongTensor,
student_bias: torch.FloatTensor = None,
teacher_bias: torch.FloatTensor = None,
):
"""Compute the distillation loss metrics for the given batch."""
forward_output = self.concatenated_forward(
student_input,
student_weight,
teacher_input,
teacher_weight,
target,
student_bias,
teacher_bias,
)
(
student_logits,
teacher_logits,
hard_loss,
) = forward_output

soft_loss = self.distillation_loss(student_logits, teacher_logits)
Contributor:

The method uses logprobs (def distillation_loss(self, student_logps, teacher_logps):), but you use logits here.

I'd actually like to see both a logit and a logprob implementation, since it's easy to get logprobs offline from vLLM and that is a faster way to generate the dataset.

Collaborator Author (austin362667, Dec 5, 2024):

> The method uses logprobs (def distillation_loss(self, student_logps, teacher_logps):), but you use logits here.

@winglian Nice catch! Thank you so much.

> I'd actually like to see both a logit and a logprob implementation, since it's easy to get logprobs offline from vLLM and that is a faster way to generate the dataset.

Sure, I think it's doable. That said, I'm not quite sure I fully understand the need for a logprob implementation. Mind elaborating more on the vLLM use case?

Contributor:

Rather than having to keep the teacher model loaded during training, depending on the workload type it can be faster and more compute-efficient to pre-compute the logits/logprobs offline beforehand. However, vLLM and SGLang only provide the logprobs, and those are not easily back-calculated to logits.

Collaborator Author:

I see. That makes a lot of sense to me. Thank you!

Collaborator:

@winglian Curious if vLLM/SGLang support temperature-scaled logprobs. This would be needed to enable https://github.com/huggingface/trl/blob/9c5388b69e0842f76edc46a2ff9d0b51e1db4337/trl/trainer/gkd_trainer.py#L174

Collaborator:

I believe we can address this ask in a subsequent PR.
@ByronHsu what do you think?

# full loss
loss = self.weight_hard_loss * hard_loss + self.weight_soft_loss * soft_loss.mean()
return loss
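
As a footnote to the review thread above, here is a rough sketch of what a logprob-based distillation loss could look like when teacher log-probabilities are precomputed offline (for example, exported from vLLM or SGLang). It is illustrative only and not part of this PR; the function name and the forward-KL formulation are assumptions.

# Illustrative sketch, not part of this PR: a distillation loss that consumes
# precomputed teacher log-probabilities instead of teacher logits.
import torch
import torch.nn.functional as F


def logprob_distillation_loss(student_logits, teacher_log_probs, temperature=1.0):
    # Forward KL(teacher || student), summed over the chunk. The teacher
    # distribution is fixed by whatever settings were used when the logprobs
    # were exported, so temperature here only softens the student side.
    student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
    teacher_probs = teacher_log_probs.exp()
    return torch.sum(teacher_probs * (teacher_log_probs - student_log_probs))

This only matches the logit-based path when the teacher side is effectively at temperature 1.0, which is exactly the open question raised above about temperature-scaled logprobs.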