refactor so we can easily add new loss functions
winglian committed Dec 29, 2024
1 parent cbac5e1 commit bea553b
Showing 3 changed files with 74 additions and 68 deletions.
@@ -2,76 +2,10 @@
KD trainer
"""

from typing import Optional

import torch

from axolotl.core.trainers.base import AxolotlTrainer


def kd_loss_function(
    student_logits,
    target_token_ids,
    target_logprobs,
    target_mask,
    num_items_in_batch: Optional[int] = None,
    kd_temperature: float = 1.0,
):
    # target_mask (teacher mask): [B, teacher_seq_len, K], where 1 indicates a valid token and 0 indicates padding

    # Determine the teacher sequence length
    teacher_seq_len = target_token_ids.shape[1]

    # Slice student logits to match the teacher-provided sequence length
    student_logits_for_kd = student_logits[
        :, :teacher_seq_len, :
    ]  # [B, teacher_seq_len, vocab_size]

    # Gather student logits for the teacher's top-K tokens
    # shape -> [B, teacher_seq_len, K]
    student_logits_topk = torch.gather(
        student_logits_for_kd, dim=-1, index=target_token_ids
    )

    # Apply KD temperature to the student's logits:
    # z_s(T) = z_s / T
    if kd_temperature != 1.0:
        student_logits_topk = student_logits_topk / kd_temperature

    # Convert student top-K logits to logprobs
    student_logprobs_topk = student_logits_topk - torch.logsumexp(
        student_logits_topk, dim=-1, keepdim=True
    )  # [B, teacher_seq_len, K]

    # Convert target_mask to boolean for indexing
    valid_mask = target_mask.bool()

    # Prune tensors to keep only valid tokens
    # This yields 1D tensors over the valid positions
    student_logprobs_topk = student_logprobs_topk[valid_mask]  # [N_valid_tokens]
    target_logprobs = target_logprobs[valid_mask]  # [N_valid_tokens]

    # The teacher logprobs are already normalized, so just exponentiate to get probabilities
    teacher_probs = target_logprobs.exp()

    # Compute forward KL:
    # KL = sum p^T_k (log p^T_k - log p^S_k), summed over all valid tokens
    kd_loss_per_token = teacher_probs * (target_logprobs - student_logprobs_topk)
    kd_loss = kd_loss_per_token.sum()

    # Multiply by T^2 (classical KD scaling)
    if kd_temperature != 1.0:
        kd_loss = kd_loss * (kd_temperature**2)

    # Normalize by the number of items in the batch, or average over valid tokens
    if num_items_in_batch is not None:
        # If you know how many items should be considered in the batch
        kd_loss = kd_loss / num_items_in_batch
    else:
        # Otherwise, average over all valid tokens
        kd_loss = kd_loss / kd_loss_per_token.size(0)

    return kd_loss
from axolotl.core.trainers.kd.topk_logprob.forward_kl import loss as topk_kd_loss


class AxolotlKDTrainer(AxolotlTrainer):
@@ -176,7 +110,7 @@ def compute_loss(
        shift_target_token_ids = target_token_ids[..., 1:, :].contiguous()
        shift_target_mask = target_mask[..., 1:, :].contiguous()

        loss_kd = kd_loss_function(
        loss_kd = topk_kd_loss(
            shift_logits,
            shift_target_token_ids,
            shift_target_logprobs,
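The extension point this refactor creates is the import swap above: compute_loss now calls whatever callable is bound to the name topk_kd_loss, and each loss variant gets its own module under axolotl.core.trainers.kd. A minimal sketch of that plug-in pattern follows; the reverse_kl module it mentions is hypothetical and not part of this commit.

# Sketch only: how an alternative top-K loss could be wired in later.
# forward_kl exists in this commit; reverse_kl is a hypothetical future module.
from axolotl.core.trainers.kd.topk_logprob.forward_kl import loss as topk_kd_loss
# from axolotl.core.trainers.kd.topk_logprob.reverse_kl import loss as topk_kd_loss

# Any replacement must keep the signature the trainer relies on:
#   loss(student_logits, target_token_ids, target_logprobs, target_mask,
#        num_items_in_batch=None, kd_temperature=1.0)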
Empty file.
72 changes: 72 additions & 0 deletions src/axolotl/core/trainers/kd/topk_logprob/forward_kl.py
@@ -0,0 +1,72 @@
"""
loss for top_k KL divergence
"""
from typing import Optional

import torch


def loss(
    student_logits,
    target_token_ids,
    target_logprobs,
    target_mask,
    num_items_in_batch: Optional[int] = None,
    kd_temperature: float = 1.0,
):
    # target_mask (teacher mask): [B, teacher_seq_len, K], where 1 indicates a valid token and 0 indicates padding

    # Determine the teacher sequence length
    # _, teacher_seq_len, top_k = target_token_ids.shape
    teacher_seq_len = target_token_ids.shape[1]

    # Slice student logits to match the teacher-provided sequence length
    student_logits_for_kd = student_logits[
        :, :teacher_seq_len, :
    ]  # [B, teacher_seq_len, vocab_size]

    # Gather student logits for the teacher's top-K tokens
    # shape -> [B, teacher_seq_len, K]
    student_logits_topk = torch.gather(
        student_logits_for_kd, dim=-1, index=target_token_ids
    )

    # Apply KD temperature to the student's logits:
    # z_s(T) = z_s / T
    if kd_temperature != 1.0:
        student_logits_topk = student_logits_topk / kd_temperature

    # Convert student top-K logits to logprobs
    student_logprobs_topk = student_logits_topk - torch.logsumexp(
        student_logits_topk, dim=-1, keepdim=True
    )  # [B, teacher_seq_len, K]

    # Convert target_mask to boolean for indexing
    valid_mask = target_mask.bool()

    # Prune tensors to keep only valid tokens
    # This yields 1D tensors over the valid positions
    student_logprobs_topk = student_logprobs_topk[valid_mask]  # [N_valid_tokens]
    target_logprobs = target_logprobs[valid_mask]  # [N_valid_tokens]

    # The teacher logprobs are already normalized, so just exponentiate to get probabilities
    teacher_probs = target_logprobs.exp()

    # Compute forward KL:
    # KL = sum p^T_k (log p^T_k - log p^S_k), summed over all valid tokens
    kd_loss_per_token = teacher_probs * (target_logprobs - student_logprobs_topk)
    kd_loss = kd_loss_per_token.sum()

    # Multiply by T^2 (classical KD scaling)
    if kd_temperature != 1.0:
        kd_loss = kd_loss * (kd_temperature**2)

    # Normalize by the number of items in the batch, or average over valid tokens
    if num_items_in_batch is not None:
        # If you know how many items should be considered in the batch
        kd_loss = kd_loss / num_items_in_batch
    else:
        # Otherwise, average over all valid tokens
        kd_loss = kd_loss / kd_loss_per_token.size(0)

    return kd_loss
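As a quick sanity check, the new loss can be exercised on toy tensors. The sizes, the synthetic teacher, and the padding pattern below are made up for illustration; in training, the target_* tensors come from precomputed teacher top-K logprobs and are shifted by one position in compute_loss.

import torch

from axolotl.core.trainers.kd.topk_logprob.forward_kl import loss

# Hypothetical toy sizes: batch=2, sequence length=5, vocab=11, top-K=3
B, S, V, K = 2, 5, 11, 3

student_logits = torch.randn(B, S, V)

# Fake a teacher distribution and keep its top-K tokens per position
teacher_probs_full = torch.softmax(torch.randn(B, S, V), dim=-1)
topk = teacher_probs_full.topk(K, dim=-1)
target_token_ids = topk.indices      # [B, S, K] token ids (int64, as gather requires)
target_logprobs = topk.values.log()  # [B, S, K] teacher log-probabilities
target_mask = torch.ones(B, S, K, dtype=torch.long)
target_mask[:, -1, :] = 0            # pretend the last position is padding

kd_loss = loss(
    student_logits,
    target_token_ids,
    target_logprobs,
    target_mask,
    num_items_in_batch=None,
    kd_temperature=1.0,
)
print(kd_loss)  # scalar tensor: mean forward KL over the valid top-K positions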
