Skip to content

Commit

Permalink
[GKD] interpolate in prob. space (#2204)
Browse files Browse the repository at this point in the history
* interpolate in prob. space

* better var names

* use logsumexp

* set beta dtype

* beta tensor
  • Loading branch information
kashif authored Oct 9, 2024
1 parent ed9ea74 commit 7e5924d
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions trl/trainer/gkd_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,13 +183,18 @@ def generalized_jsd_loss(
student_log_probs = F.log_softmax(student_logits, dim=-1)
teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)

# Compute the interpolated log probabilities
interpolated_log_probs = beta * student_log_probs + (1 - beta) * teacher_log_probs
# Compute the log of the mixture distribution
# log(a + b) = log(exp(log(a)) + exp(log(b))) -> for mixture
beta = torch.tensor(beta, dtype=student_log_probs.dtype)
mixture_log_probs = torch.logsumexp(
torch.stack([student_log_probs + torch.log(beta), teacher_log_probs + torch.log(1 - beta)]),
dim=0,
)

# Compute KL divergences using F.kl_div
# PyTorch differs from the standard mathematical definition, so the order of the probability distributions is swapped compared to that defined in the paper.
kl_teacher = F.kl_div(interpolated_log_probs, teacher_log_probs, reduction="none", log_target=True)
kl_student = F.kl_div(interpolated_log_probs, student_log_probs, reduction="none", log_target=True)
kl_teacher = F.kl_div(mixture_log_probs, teacher_log_probs, reduction="none", log_target=True)
kl_student = F.kl_div(mixture_log_probs, student_log_probs, reduction="none", log_target=True)

# Compute the Generalized Jensen-Shannon Divergence
jsd = beta * kl_teacher + (1 - beta) * kl_student
Expand Down

0 comments on commit 7e5924d

Please sign in to comment.