diff --git a/src/liger_kernel/chunked_loss/fused_linear_distillation.py b/src/liger_kernel/chunked_loss/fused_linear_distillation.py index 46ef2cba9..9954960cd 100644 --- a/src/liger_kernel/chunked_loss/fused_linear_distillation.py +++ b/src/liger_kernel/chunked_loss/fused_linear_distillation.py @@ -228,7 +228,7 @@ def _compute_loss( compute_ce_loss (bool): Whether to compute CE loss. loss_kwargs (dict): Additional arguments for the loss function. """ - student_logits, teacher_logits, hard_loss = ( + student_logits_chunk, teacher_logits_chunk, hard_loss = ( LigerFusedLinearDistillationBase.chunk_forward( student_input_chunk, student_weight, @@ -242,12 +242,9 @@ def _compute_loss( ) ) - hard_loss /= (full_target != ignore_index).sum() - - soft_loss = distillation_loss_fn(student_logits, teacher_logits, temperature) - soft_loss /= max( - 1, (full_target[: full_target.shape[0]] != ignore_index).sum() // chunk_size - ) + hard_loss /= full_target.shape[0] + soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk, temperature) + soft_loss /= (full_target.shape[0] // student_input_chunk.shape[0]) loss = weight_hard_loss * hard_loss + weight_soft_loss * soft_loss - return loss, (soft_loss, hard_loss, student_logits, teacher_logits) + return loss, (soft_loss, hard_loss, student_logits_chunk, teacher_logits_chunk)