Skip to content

Commit

Permalink
use torch no grad and change normalization term
Browse files Browse the repository at this point in the history
  • Loading branch information
shivam15s committed Dec 7, 2024
1 parent 7acb5ca commit e381569
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions src/liger_kernel/chunked_loss/fused_linear_distillation.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,10 @@ def chunk_forward(
student_log_probs_chunk = F.log_softmax(student_logits_chunk.float(), dim=-1)

# Teacher
teacher_logits_chunk = teacher_input_chunk @ teacher_weight.t()
if teacher_bias is not None:
teacher_logits_chunk += teacher_bias
with torch.no_grad():
teacher_logits_chunk = teacher_input_chunk @ teacher_weight.t()
if teacher_bias is not None:
teacher_logits_chunk += teacher_bias

# The hard/task loss
ce_loss = 0.0
Expand All @@ -63,7 +64,7 @@ def forward(
student_bias=None,
teacher_bias=None,
loss_fn=None,
chunk_size=1,
chunk_size=1024,
ignore_index=-100,
weight_hard_loss=0.5,
weight_soft_loss=0.5,
Expand Down Expand Up @@ -243,6 +244,7 @@ def _compute_loss(
hard_loss /= full_target.shape[0]

soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk, temperature)
soft_loss /= (full_target.shape[0] // student_input_chunk.shape[0])
soft_loss /= full_target.shape[0]

loss = weight_hard_loss * hard_loss + weight_soft_loss * soft_loss
return loss, (soft_loss, hard_loss, student_logits_chunk, teacher_logits_chunk)

0 comments on commit e381569

Please sign in to comment.