Fix distill base chunk_size scaling
Signed-off-by: Austin Liu <[email protected]>

Set default `chunk_size` to `1024`

Signed-off-by: Austin Liu <[email protected]>

Rebase

Signed-off-by: Austin Liu <[email protected]>
austin362667 committed Dec 4, 2024
1 parent a83a8a5 commit a538bf0
Showing 1 changed file with 26 additions and 27 deletions.
53 changes: 26 additions & 27 deletions src/liger_kernel/chunked_loss/fused_linear_distillation.py
@@ -30,31 +30,27 @@ def chunk_forward(
         compute_ce_loss=True,
     ):
         # Student
-        student_per_token_logits_chunk = student_input_chunk @ student_weight.t()
+        student_logits_chunk = student_input_chunk @ student_weight.t()
         if student_bias is not None:
-            student_per_token_logits_chunk += student_bias
-        student_per_token_log_probs_chunk = F.log_softmax(
-            student_per_token_logits_chunk.float(), dim=-1
-        )
+            student_logits_chunk += student_bias
+        student_log_probs_chunk = F.log_softmax(student_logits_chunk.float(), dim=-1)

         # Teacher
-        teacher_per_token_logits_chunk = teacher_input_chunk @ teacher_weight.t()
+        teacher_logits_chunk = teacher_input_chunk @ teacher_weight.t()
         if teacher_bias is not None:
-            teacher_per_token_logits_chunk += teacher_bias
+            teacher_logits_chunk += teacher_bias

         # The hard/task loss
         ce_loss = 0.0
         if compute_ce_loss:
-            ce_loss = F.cross_entropy(
-                student_per_token_log_probs_chunk.view(
-                    -1, student_per_token_log_probs_chunk.shape[-1]
-                ),
+            ce_loss = F.nll_loss(
+                student_log_probs_chunk.view(-1, student_log_probs_chunk.shape[-1]),
                 target_chunk.view(-1),
                 reduction="sum",
                 ignore_index=ignore_index,
             )

-        return student_per_token_logits_chunk, teacher_per_token_logits_chunk, ce_loss
+        return student_logits_chunk, teacher_logits_chunk, ce_loss

     @staticmethod
     def forward(
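
Note on the `F.cross_entropy` to `F.nll_loss` switch above: `chunk_forward` already materializes `student_log_probs_chunk` with `F.log_softmax`, and `F.cross_entropy` applies `log_softmax` internally, so the old call re-normalized an already-normalized tensor. Because `log_softmax` is mathematically idempotent, the old result was the same up to floating point, but `F.nll_loss` consumes log-probabilities directly and skips the redundant pass. A minimal sketch of the identity, with made-up tensors:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 10)                 # (num_tokens, vocab_size)
target = torch.randint(0, 10, (4,))
log_probs = F.log_softmax(logits, dim=-1)

# cross_entropy(logits) == nll_loss(log_softmax(logits)), so once the
# log-probs are already in hand, nll_loss is the direct call.
a = F.cross_entropy(logits, target, reduction="sum")
b = F.nll_loss(log_probs, target, reduction="sum")
assert torch.allclose(a, b)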
@@ -81,24 +77,23 @@ def forward(
         Only need to compute gradients for student model.
         Args:
-            student_input (torch.Tensor): Student input tensor. Shape: (batch_size * seq_len, hidden_size).
-            student_weight (torch.Tensor): Student weight tensor. Shape: (vocab_size, hidden_size).
-            teacher_input (torch.Tensor): Teacher input tensor. Shape: (batch_size * seq_len, hidden_size).
-            teacher_weight (torch.Tensor): Teacher weight tensor. Shape: (vocab_size, hidden_size).
+            student_input (torch.Tensor): Student input tensor. Shape: (batch_size * seq_len, student_hidden_size).
+            student_weight (torch.Tensor): Student weight tensor. Shape: (vocab_size, student_hidden_size).
+            teacher_input (torch.Tensor): Teacher input tensor. Shape: (batch_size * seq_len, teacher_hidden_size).
+            teacher_weight (torch.Tensor): Teacher weight tensor. Shape: (vocab_size, teacher_hidden_size).
             target (torch.Tensor): Target truth label tensor. Shape: (batch_size * seq_len).
             student_bias (torch.Tensor, optional): Student bias tensor. Shape: (vocab_size,).
             teacher_bias (torch.Tensor, optional): Teacher bias tensor. Shape: (vocab_size,).
             loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
-            chunk_size (int): Size of a chunk (# of batches of stacked chosen and rejected inputs).
+            chunk_size (int): Size of a chunk.
             compute_ce_loss (bool): Whether to compute CE loss.
             ignore_index (int): Index to ignore for loss computation.
-            weight_hard_loss (float): Weight for hard loss.
-            weight_soft_loss (float): Weight for soft loss.
+            weight_hard_loss (float): Weight for hard/task loss.
+            weight_soft_loss (float): Weight for soft/distillation loss.
             compiled (bool): Whether to use torch compile for chunk accumulation.
             loss_kwargs (dict): Other possible arguments that a loss function might need
         """
-        CHUNK_SIZE = chunk_size

         grad_weight = torch.zeros_like(student_weight)
         grad_inputs = []
         grad_bias = torch.zeros_like(student_bias) if student_bias is not None else None
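
For intuition on the shapes documented above: the student and teacher inputs arrive flattened to (batch_size * seq_len, hidden_size), and the forward pass walks them in slices of at most `chunk_size` rows. A toy illustration of that slicing, with made-up sizes:

import torch

batch_size, seq_len, hidden_size = 2, 6, 8
chunk_size = 4

student_input = torch.randn(batch_size * seq_len, hidden_size)
chunks = torch.split(student_input, chunk_size, dim=0)
print([tuple(c.shape) for c in chunks])     # [(4, 8), (4, 8), (4, 8)]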
@@ -107,12 +102,13 @@
         loss_func_to_call = partial(
             LigerFusedLinearDistillationBase._compute_loss,
             distillation_loss_fn=loss_fn,
-            full_target=target,
+            chunk_size=chunk_size,
             ignore_index=ignore_index,
             weight_hard_loss=weight_hard_loss,
             weight_soft_loss=weight_soft_loss,
             compute_ce_loss=compute_ce_loss,
             temperature=temperature,
+            full_target=target,
             **loss_kwargs,
         )
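
The `partial` call above binds every argument that is shared across chunks, now including `chunk_size`, into one callable, so the accumulation loop only has to supply the per-chunk tensors. A stripped-down sketch of the pattern, using a hypothetical toy loss function:

from functools import partial

def compute_chunk_loss(chunk_loss_sum, *, num_valid_tokens, chunk_size):
    # Toy stand-in: scale one chunk's summed loss by the chunk count.
    return chunk_loss_sum / max(1, num_valid_tokens // chunk_size)

loss_func_to_call = partial(compute_chunk_loss, num_valid_tokens=100, chunk_size=10)
print(loss_func_to_call(5.0))               # 0.5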

@@ -206,6 +202,7 @@ def _compute_loss(
         teacher_bias=None,
         distillation_loss_fn=None,
         full_target=None,
+        chunk_size=1,
         ignore_index=-100,
         temperature=1.0,
         weight_hard_loss=0.5,
@@ -217,10 +214,10 @@
         Compute the total loss for a chunk of input and target, while using a knowledge distillation loss function.
         Args:
             distillation_loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
-            student_input_chunk (torch.Tensor): Chunk of input tensor. Shape: (chunk_size, hidden_size).
-            student_weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
-            teacher_input_chunk (torch.Tensor): Chunk of input tensor. Shape: (chunk_size, hidden_size).
-            teacher_weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
+            student_input_chunk (torch.Tensor): Chunk of input tensor. Shape: (chunk_size, student_hidden_size).
+            student_weight (torch.Tensor): Weight tensor. Shape: (vocab_size, student_hidden_size).
+            teacher_input_chunk (torch.Tensor): Chunk of input tensor. Shape: (chunk_size, teacher_hidden_size).
+            teacher_weight (torch.Tensor): Weight tensor. Shape: (vocab_size, teacher_hidden_size).
             target_chunk (torch.Tensor): Chunk of target tensor. Shape: (chunk_size,).
             student_bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
             teacher_bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
@@ -245,10 +242,12 @@
             )
         )

-        hard_loss = hard_loss / (full_target != ignore_index).sum()
+        hard_loss /= (full_target != ignore_index).sum()

         soft_loss = distillation_loss_fn(student_logits, teacher_logits, temperature)
-        soft_loss = soft_loss / (full_target != ignore_index).sum()
+        soft_loss /= max(
+            1, (full_target[: full_target.shape[0]] != ignore_index).sum() // chunk_size
+        )

         loss = weight_hard_loss * hard_loss + weight_soft_loss * soft_loss
         return loss, (soft_loss, hard_loss, student_logits, teacher_logits)
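
The chunk_size scaling fix named in the commit title is the last change above: the hard loss is a per-token sum accumulated over every chunk, so dividing by the count of non-ignored tokens yields a per-token mean, while the soft loss is produced once per chunk, so the diff divides it by roughly the number of chunks, `num_valid_tokens // chunk_size`, with `max(1, ...)` guarding against a zero divisor. A toy check of that arithmetic, with assumed values:

import torch

ignore_index = -100
chunk_size = 4
full_target = torch.tensor([1, 2, -100, 3, 4, 5, 6, -100])   # 6 valid labels

num_valid = (full_target != ignore_index).sum()              # tensor(6)
num_chunks = max(1, num_valid // chunk_size)                 # 6 // 4 -> 1

hard_loss = torch.tensor(12.0) / num_valid                   # per-token mean: 2.0
soft_loss = torch.tensor(3.0) / num_chunks                   # per-chunk mean: 3.0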
