diff --git a/src/liger_kernel/chunked_loss/fused_linear_distillation.py b/src/liger_kernel/chunked_loss/fused_linear_distillation.py
index 41e7d82be..929bdff34 100644
--- a/src/liger_kernel/chunked_loss/fused_linear_distillation.py
+++ b/src/liger_kernel/chunked_loss/fused_linear_distillation.py
@@ -36,9 +36,10 @@ def chunk_forward(
         student_log_probs_chunk = F.log_softmax(student_logits_chunk.float(), dim=-1)
 
         # Teacher
-        teacher_logits_chunk = teacher_input_chunk @ teacher_weight.t()
-        if teacher_bias is not None:
-            teacher_logits_chunk += teacher_bias
+        with torch.no_grad():
+            teacher_logits_chunk = teacher_input_chunk @ teacher_weight.t()
+            if teacher_bias is not None:
+                teacher_logits_chunk += teacher_bias
 
         # The hard/task loss
         ce_loss = 0.0
@@ -63,7 +64,7 @@ def forward(
         student_bias=None,
         teacher_bias=None,
         loss_fn=None,
-        chunk_size=1,
+        chunk_size=1024,
         ignore_index=-100,
         weight_hard_loss=0.5,
         weight_soft_loss=0.5,
@@ -243,6 +244,7 @@ def _compute_loss(
         hard_loss /= full_target.shape[0]
 
         soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk, temperature)
-        soft_loss /= (full_target.shape[0] // student_input_chunk.shape[0])
+        soft_loss /= full_target.shape[0]
+
         loss = weight_hard_loss * hard_loss + weight_soft_loss * soft_loss
         return loss, (soft_loss, hard_loss, student_logits_chunk, teacher_logits_chunk)
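
A minimal standalone sketch of the two behavioral changes, not part of the patch itself: the teacher projection runs under `torch.no_grad()` so no autograd graph is built for the frozen teacher, and every per-chunk loss term is divided by the full number of target rows so summing over chunks yields a true mean. All shapes, weights, and loss choices below are hypothetical placeholders, not the library's actual code path.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical sizes: batch*seq rows, hidden dim, vocab, chunk size.
N, H, V, CHUNK = 32, 16, 64, 8
student_w = torch.randn(V, H, requires_grad=True)
teacher_w = torch.randn(V, H)          # frozen teacher projection
student_x = torch.randn(N, H)
teacher_x = torch.randn(N, H)
target = torch.randint(0, V, (N,))

total_loss = torch.zeros(())
for s_chunk, t_chunk, tgt_chunk in zip(
    student_x.split(CHUNK), teacher_x.split(CHUNK), target.split(CHUNK)
):
    student_logits = s_chunk @ student_w.t()

    # Teacher forward under no_grad: no graph is kept for the teacher weights,
    # mirroring the wrapped teacher matmul in chunk_forward.
    with torch.no_grad():
        teacher_logits = t_chunk @ teacher_w.t()

    hard = F.cross_entropy(student_logits.float(), tgt_chunk, reduction="sum")
    soft = F.kl_div(
        F.log_softmax(student_logits.float(), dim=-1),
        F.softmax(teacher_logits.float(), dim=-1),
        reduction="sum",
    )

    # Both terms are normalized by the full row count (target.shape[0]), not by
    # the number of chunks, so the accumulated loss is a per-row average.
    total_loss = total_loss + (0.5 * hard + 0.5 * soft) / target.shape[0]

total_loss.backward()
print(total_loss.item(), student_w.grad.shape)
```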