Fix chunk division
Signed-off-by: Austin Liu <[email protected]>
austin362667 committed Dec 4, 2024
1 parent a538bf0 · commit 805e50f
Showing 1 changed file with 5 additions and 8 deletions.
src/liger_kernel/chunked_loss/fused_linear_distillation.py (13 changes: 5 additions & 8 deletions)
@@ -228,7 +228,7 @@ def _compute_loss(
             compute_ce_loss (bool): Whether to compute CE loss.
             loss_kwargs (dict): Additional arguments for the loss function.
         """
-        student_logits, teacher_logits, hard_loss = (
+        student_logits_chunk, teacher_logits_chunk, hard_loss = (
             LigerFusedLinearDistillationBase.chunk_forward(
                 student_input_chunk,
                 student_weight,
@@ -242,12 +242,9 @@ def _compute_loss(
             )
         )
 
-        hard_loss /= (full_target != ignore_index).sum()
+        hard_loss /= full_target.shape[0]
 
-        soft_loss = distillation_loss_fn(student_logits, teacher_logits, temperature)
-        soft_loss /= max(
-            1, (full_target[: full_target.shape[0]] != ignore_index).sum() // chunk_size
-        )
-
+        soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk, temperature)
+        soft_loss /= (full_target.shape[0] // student_input_chunk.shape[0])
         loss = weight_hard_loss * hard_loss + weight_soft_loss * soft_loss
-        return loss, (soft_loss, hard_loss, student_logits, teacher_logits)
+        return loss, (soft_loss, hard_loss, student_logits_chunk, teacher_logits_chunk)
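
In effect, the fix replaces token-count-based divisors with shape-based ones: the hard loss is now averaged over the full_target row count, and each chunk's soft loss is divided by the number of chunks (full_target.shape[0] // student_input_chunk.shape[0]), so that summing the per-chunk contributions over the chunked loop reproduces a full-batch average. Below is a minimal sketch of that accumulation argument. It is not Liger-Kernel code: mse_soft_loss is a hypothetical stand-in for distillation_loss_fn with mean reduction, and the shapes are made up for illustration.

import torch

def mse_soft_loss(student_chunk, teacher_chunk):
    # Hypothetical stand-in for a mean-reduced distillation loss over one chunk.
    return ((student_chunk - teacher_chunk) ** 2).mean()

BT, H, chunk_size = 8, 4, 2          # total rows, hidden size, rows per chunk (assumed values)
student = torch.randn(BT, H)
teacher = torch.randn(BT, H)

# Mirrors the fixed divisor: full_target.shape[0] // student_input_chunk.shape[0]
num_chunks = BT // chunk_size

# Accumulate per-chunk losses, dividing each by num_chunks as the fixed code does.
chunked_loss = sum(
    mse_soft_loss(student[i : i + chunk_size], teacher[i : i + chunk_size]) / num_chunks
    for i in range(0, BT, chunk_size)
)

# Full-batch loss computed in one shot.
full_loss = mse_soft_loss(student, teacher)

assert torch.allclose(chunked_loss, full_loss)

Because every chunk covers the same number of rows, the mean of the per-chunk means equals the full-batch mean, which is exactly what dividing by num_chunks before summing achieves; computing that divisor directly from tensor shapes avoids the old max(..., // chunk_size) expression's dependence on the non-ignored token count.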
