Introduce Knowledge Distillation Base #417
Closed
Changes from 9 commits

Commits (12):
0769e97  Add liger and naive distill base (austin362667)
a81c959  Format (austin362667)
e13994a  Refactor beta (austin362667)
720b5cb  Remove imports (austin362667)
17c5b33  Fix distill base `chunk_size` scaling (austin362667)
e3dada0  Fix chunk division (austin362667)
5662554  Remove chunk arg (austin362667)
7acb5ca  Fix `distillation_loss` arg typo (austin362667)
e381569  use torch no grad and change normalization term (shivam15s)
8aa842a  rearrange fns for readability (shivam15s)
3561525  add no grad in tests (shivam15s)
076c220  Merge branch 'main' into feat/distill_base (shivam15s)
src/liger_kernel/chunked_loss/fused_linear_distillation.py (250 additions, 0 deletions)
@@ -0,0 +1,250 @@
from abc import abstractmethod
from functools import partial

import torch
from torch.nn import functional as F


class LigerFusedLinearDistillationBase(torch.autograd.Function):

    @abstractmethod
    def distillation_loss_fn(student_logits, teacher_logits, temperature):
        """
        Compute distillation loss.
        Args:
            student_logits (torch.Tensor): Raw logits of student tokens. Shape: (batch_size * seq_len, vocab_size).
            teacher_logits (torch.Tensor): Raw logits of teacher tokens. Shape: (batch_size * seq_len, vocab_size).
            temperature (float): Temperature used to soften the student and teacher distributions.
        """
        raise NotImplementedError("Distillation loss function must be implemented.")

    @staticmethod
    def chunk_forward(
        student_input_chunk,
        student_weight,
        teacher_input_chunk,
        teacher_weight,
        target_chunk,
        student_bias=None,
        teacher_bias=None,
        ignore_index=-100,
        compute_ce_loss=True,
    ):
        # Student
        student_logits_chunk = student_input_chunk @ student_weight.t()
        if student_bias is not None:
            student_logits_chunk += student_bias
        student_log_probs_chunk = F.log_softmax(student_logits_chunk.float(), dim=-1)

        # Teacher
        with torch.no_grad():
            teacher_logits_chunk = teacher_input_chunk @ teacher_weight.t()
            if teacher_bias is not None:
                teacher_logits_chunk += teacher_bias

        # The hard/task loss
        ce_loss = 0.0
        if compute_ce_loss:
            ce_loss = F.nll_loss(
                student_log_probs_chunk.view(-1, student_log_probs_chunk.shape[-1]),
                target_chunk.view(-1),
                reduction="sum",
                ignore_index=ignore_index,
            )

        return student_logits_chunk, teacher_logits_chunk, ce_loss

    @staticmethod
    def forward(
        ctx,
        student_input,
        student_weight,
        teacher_input,
        teacher_weight,
        target,
        student_bias=None,
        teacher_bias=None,
        loss_fn=None,
        chunk_size=1024,
        ignore_index=-100,
        weight_hard_loss=0.5,
        weight_soft_loss=0.5,
        compute_ce_loss=True,
        temperature=1.0,
        compiled=True,
        **loss_kwargs,
    ):
        """
        Base class for fused linear layer with distillation loss.
        Only needs to compute gradients for the student model.

        Args:
            student_input (torch.Tensor): Student input tensor. Shape: (batch_size * seq_len, student_hidden_size).
            student_weight (torch.Tensor): Student weight tensor. Shape: (vocab_size, student_hidden_size).
            teacher_input (torch.Tensor): Teacher input tensor. Shape: (batch_size * seq_len, teacher_hidden_size).
            teacher_weight (torch.Tensor): Teacher weight tensor. Shape: (vocab_size, teacher_hidden_size).
            target (torch.Tensor): Ground truth label tensor. Shape: (batch_size * seq_len,).
            student_bias (torch.Tensor, optional): Student bias tensor. Shape: (vocab_size,).
            teacher_bias (torch.Tensor, optional): Teacher bias tensor. Shape: (vocab_size,).
            loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
            chunk_size (int): Size of a chunk.
            compute_ce_loss (bool): Whether to compute CE loss.
            ignore_index (int): Index to ignore for loss computation.
            weight_hard_loss (float): Weight for hard/task loss.
            weight_soft_loss (float): Weight for soft/distillation loss.
            temperature (float): Temperature used to soften the student and teacher distributions.
            compiled (bool): Whether to use torch compile for chunk accumulation.
            loss_kwargs (dict): Other possible arguments that a loss function might need.
        """
        CHUNK_SIZE = chunk_size
        grad_weight = torch.zeros_like(student_weight)
        grad_inputs = []
        grad_bias = torch.zeros_like(student_bias) if student_bias is not None else None
        loss_acc = torch.zeros((), device=student_input.device)

        loss_func_to_call = partial(
            LigerFusedLinearDistillationBase._compute_loss,
            distillation_loss_fn=loss_fn,
            full_target=target,
            ignore_index=ignore_index,
            weight_hard_loss=weight_hard_loss,
            weight_soft_loss=weight_soft_loss,
            compute_ce_loss=compute_ce_loss,
            temperature=temperature,
            **loss_kwargs,
        )

        def accumulate_chunk(student_input_chunk, teacher_input_chunk, target_chunk):
            if student_bias is not None:
                (chunk_grad_input, chunk_grad_weight, chunk_grad_bias), (
                    chunk_loss,
                    (
                        chunk_soft_loss,
                        chunk_hard_loss,
                        chunk_student_logits,
                        chunk_teacher_logits,
                    ),
                ) = torch.func.grad_and_value(
                    loss_func_to_call, argnums=(0, 1, 5), has_aux=True
                )(
                    student_input_chunk,
                    student_weight,
                    teacher_input_chunk,
                    teacher_weight,
                    target_chunk,
                    student_bias,
                    teacher_bias,
                )
                grad_bias.add_(chunk_grad_bias)
            else:
                (chunk_grad_input, chunk_grad_weight), (
                    chunk_loss,
                    (
                        chunk_soft_loss,
                        chunk_hard_loss,
                        chunk_student_logits,
                        chunk_teacher_logits,
                    ),
                ) = torch.func.grad_and_value(
                    loss_func_to_call, argnums=(0, 1), has_aux=True
                )(
                    student_input_chunk,
                    student_weight,
                    teacher_input_chunk,
                    teacher_weight,
                    target_chunk,
                    student_bias,
                    teacher_bias,
                )
            grad_weight.add_(chunk_grad_weight)
            loss_acc.add_(chunk_loss)
            return chunk_grad_input

        if compiled:
            accumulate_chunk = torch.compile(accumulate_chunk)

        num_chunks = max(1, student_input.shape[0] // CHUNK_SIZE)
        _student_input_chunks = torch.chunk(student_input, chunks=num_chunks, dim=0)
        _teacher_input_chunks = torch.chunk(teacher_input, chunks=num_chunks, dim=0)
        _target_chunks = torch.chunk(target, chunks=num_chunks, dim=0)

        for student_input_chunk, teacher_input_chunk, target_chunk in zip(
            _student_input_chunks, _teacher_input_chunks, _target_chunks
        ):
            grad_input = accumulate_chunk(
                student_input_chunk, teacher_input_chunk, target_chunk
            )
            grad_inputs.append(grad_input)

        ctx.save_for_backward(
            torch.cat(grad_inputs, dim=0),
            grad_weight,
            grad_bias,
        )
        return loss_acc

    @staticmethod
    def backward(ctx, grad_output):
        grad_input, grad_weight, grad_bias = ctx.saved_tensors
        if torch.ne(grad_output, torch.tensor(1.0, device=grad_output.device)):
            grad_input = grad_input * grad_output
            grad_weight = grad_weight * grad_output
            grad_bias = grad_bias * grad_output if grad_bias is not None else None

        return grad_input, grad_weight, None, grad_bias

    @staticmethod
    def _compute_loss(
        student_input_chunk,
        student_weight,
        teacher_input_chunk,
        teacher_weight,
        target_chunk,
        student_bias=None,
        teacher_bias=None,
        distillation_loss_fn=None,
        full_target=None,
        ignore_index=-100,
        temperature=1.0,
        weight_hard_loss=0.5,
        weight_soft_loss=0.5,
        compute_ce_loss=True,
        **loss_kwargs,
    ):
        """
        Compute the total loss for a chunk of input and target, using a knowledge distillation loss function.
        Args:
            distillation_loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
            student_input_chunk (torch.Tensor): Chunk of input tensor. Shape: (chunk_size, student_hidden_size).
            student_weight (torch.Tensor): Weight tensor. Shape: (vocab_size, student_hidden_size).
            teacher_input_chunk (torch.Tensor): Chunk of input tensor. Shape: (chunk_size, teacher_hidden_size).
            teacher_weight (torch.Tensor): Weight tensor. Shape: (vocab_size, teacher_hidden_size).
            target_chunk (torch.Tensor): Chunk of target tensor. Shape: (chunk_size,).
            student_bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
            teacher_bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
            full_target (torch.Tensor): Full target tensor. Shape: (batch_size * seq_len,).
            ignore_index (int): Index to ignore for loss computation.
            temperature (float): Temperature used to soften the student and teacher distributions.
            weight_hard_loss (float): Weight for hard loss.
            weight_soft_loss (float): Weight for soft loss.
            compute_ce_loss (bool): Whether to compute CE loss.
            loss_kwargs (dict): Additional arguments for the loss function.
        """
        student_logits_chunk, teacher_logits_chunk, hard_loss = (
            LigerFusedLinearDistillationBase.chunk_forward(
                student_input_chunk,
                student_weight,
                teacher_input_chunk,
                teacher_weight,
                target_chunk,
                student_bias=student_bias,
                teacher_bias=teacher_bias,
                ignore_index=ignore_index,
                compute_ce_loss=compute_ce_loss,
            )
        )

        hard_loss /= full_target.shape[0]

        soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk, temperature)
        soft_loss /= full_target.shape[0]

        loss = weight_hard_loss * hard_loss + weight_soft_loss * soft_loss
        return loss, (soft_loss, hard_loss, student_logits_chunk, teacher_logits_chunk)
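For reference, here is a minimal sketch (an assumption, not code from this PR) of how a concrete subclass could plug a soft loss into this base class. The class name NaiveKLDistillationFunction and the forward-KL loss are illustrative choices; the base class handles chunking, the hard CE loss, and normalization by the full batch size.

import torch
import torch.nn.functional as F

from liger_kernel.chunked_loss.fused_linear_distillation import (
    LigerFusedLinearDistillationBase,
)


class NaiveKLDistillationFunction(LigerFusedLinearDistillationBase):
    @staticmethod
    def distillation_loss_fn(student_logits, teacher_logits, temperature=1.0):
        # Forward KL(teacher || student) on temperature-softened distributions,
        # summed over the chunk; the base class divides by the full batch size.
        student_logps = F.log_softmax(student_logits / temperature, dim=-1)
        teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)
        return F.kl_div(student_logps, teacher_probs, reduction="sum")

    @staticmethod
    def forward(ctx, student_input, student_weight, teacher_input, teacher_weight, target):
        return LigerFusedLinearDistillationBase.forward(
            ctx,
            student_input,
            student_weight,
            teacher_input,
            teacher_weight,
            target,
            loss_fn=NaiveKLDistillationFunction.distillation_loss_fn,
        )

    @staticmethod
    def backward(ctx, grad_output):
        grad_input, grad_weight, _, _ = LigerFusedLinearDistillationBase.backward(ctx, grad_output)
        # One gradient per forward input; only the student input and weight need gradients.
        return grad_input, grad_weight, None, None, None


# Hypothetical usage:
# loss = NaiveKLDistillationFunction.apply(
#     student_hidden, student_lm_head_weight, teacher_hidden, teacher_lm_head_weight, labels
# )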
The method uses logprobs:
def distillation_loss(self, student_logps, teacher_logps):
but you use logits here. I'd actually like to see both a logit and a logprob implementation, since it's easy to get logprobs offline from vLLM and that is a faster way to generate the dataset.
@winglian Nice catch! Thank you so much.
Sure, I think it's doable. That said, I'm not quite sure I fully understand the need for a logprobs implementation. Mind elaborating more on the vLLM use case?
So rather than having to keep the teacher model loaded during training, depending on the workload type, it can be faster and more compute-efficient to pre-compute the logits/logprobs offline beforehand. However, vLLM and SGLang only provide the logprobs, and those are not easily back-calculated to logits.
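As a hedged illustration of that request (an assumption, not code from this PR or from vLLM), a soft-loss variant consuming precomputed teacher log-probabilities could look like the sketch below; it assumes the exported log-probs cover the full vocabulary and were generated at the same temperature used here.

import torch.nn.functional as F

def distillation_loss_from_logprobs(student_logits, teacher_logps, temperature=1.0):
    # Forward KL(teacher || student) where the teacher side is already log-probs
    # (e.g. exported offline); log_target=True tells kl_div the target is in log space.
    student_logps = F.log_softmax(student_logits / temperature, dim=-1)
    return F.kl_div(student_logps, teacher_logps, reduction="sum", log_target=True)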
I see. That makes a lot of sense to me. Thank you!
@winglian Curious whether vLLM/SGLang support temperature-scaled logprobs. This would be needed to enable https://github.com/huggingface/trl/blob/9c5388b69e0842f76edc46a2ff9d0b51e1db4337/trl/trainer/gkd_trainer.py#L174
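For clarity, a hedged sketch of what "temperature-scaled logprobs" presumably means here, mirroring the referenced TRL code, which divides logits by the temperature before log_softmax; the helper name is hypothetical.

import torch.nn.functional as F

def temperature_scaled_logprobs(logits, temperature=1.0):
    # Soften (or sharpen) the distribution before taking log-probabilities,
    # so logprobs exported at temperature 1.0 are not directly interchangeable.
    return F.log_softmax(logits / temperature, dim=-1)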
I believe we can address this ask in a subsequent PR.
@ByronHsu what do you think?