chore: lint
winglian committed Dec 28, 2024
1 parent b246edd commit cbac5e1
Showing 2 changed files with 78 additions and 91 deletions.
111 changes: 53 additions & 58 deletions src/axolotl/core/trainers/kd.py
@@ -7,7 +7,6 @@
import torch

from axolotl.core.trainers.base import AxolotlTrainer
from axolotl.integrations.kd.kernels.kd import kd_loss_triton


def kd_loss_function(
@@ -93,59 +92,59 @@ def _set_signature_columns_if_needed(self):
if columns_to_add:
self._signature_columns += columns_to_add

def compute_loss_w_triton(
self, model, inputs, return_outputs=False, num_items_in_batch=None
):
target_logprobs = inputs.pop("target_logprobs")
target_token_ids = inputs.pop("target_token_ids")
target_mask = inputs.pop("target_mask")

if self.model_accepts_loss_kwargs:
loss_kwargs = {}
if num_items_in_batch is not None:
loss_kwargs["num_items_in_batch"] = num_items_in_batch
inputs = {**inputs, **loss_kwargs}
outputs = model(**inputs)

student_logits = outputs["logits"]
# Slice or gather student logits to match teacher seq len
# e.g.:
teacher_seq_len = target_token_ids.shape[1]
student_logits_for_kd = student_logits[
:, :teacher_seq_len, :
] # [B, seq_len, vocab_size]

# GATHER top-K from student
student_logits_topk = torch.gather(
student_logits_for_kd,
dim=-1,
index=target_token_ids, # same shape [B, seq_len, K]
)

# Now call the Triton-based KD loss
kd_sum = kd_loss_triton(
student_logits_topk,
target_logprobs, # teacher logprobs [B, seq_len, K]
target_mask, # mask [B, seq_len, K]
)

# Normalize however you want
if num_items_in_batch is not None:
loss_kd = kd_sum / num_items_in_batch
else:
# or do e.g. average over valid tokens
# quick example:
total_valid = target_mask.sum()
loss_kd = kd_sum / (total_valid + 1e-8)

# optionally combine with CE loss
if self.args.kd_ce_alpha > 0:
kd_alpha = self.args.kd_alpha
loss = self.args.kd_ce_alpha * outputs["loss"] + kd_alpha * loss_kd
else:
loss = loss_kd

return (loss, outputs) if return_outputs else loss
# def compute_loss_w_triton(
# self, model, inputs, return_outputs=False, num_items_in_batch=None
# ):
# target_logprobs = inputs.pop("target_logprobs")
# target_token_ids = inputs.pop("target_token_ids")
# target_mask = inputs.pop("target_mask")
#
# if self.model_accepts_loss_kwargs:
# loss_kwargs = {}
# if num_items_in_batch is not None:
# loss_kwargs["num_items_in_batch"] = num_items_in_batch
# inputs = {**inputs, **loss_kwargs}
# outputs = model(**inputs)
#
# student_logits = outputs["logits"]
# # Slice or gather student logits to match teacher seq len
# # e.g.:
# teacher_seq_len = target_token_ids.shape[1]
# student_logits_for_kd = student_logits[
# :, :teacher_seq_len, :
# ] # [B, seq_len, vocab_size]
#
# # GATHER top-K from student
# student_logits_topk = torch.gather(
# student_logits_for_kd,
# dim=-1,
# index=target_token_ids, # same shape [B, seq_len, K]
# )
#
# # Now call the Triton-based KD loss
# kd_sum = kd_loss_triton(
# student_logits_topk,
# target_logprobs, # teacher logprobs [B, seq_len, K]
# target_mask, # mask [B, seq_len, K]
# )
#
# # Normalize however you want
# if num_items_in_batch is not None:
# loss_kd = kd_sum / num_items_in_batch
# else:
# # or do e.g. average over valid tokens
# # quick example:
# total_valid = target_mask.sum()
# loss_kd = kd_sum / (total_valid + 1e-8)
#
# # optionally combine with CE loss
# if self.args.kd_ce_alpha > 0:
# kd_alpha = self.args.kd_alpha
# loss = self.args.kd_ce_alpha * outputs["loss"] + kd_alpha * loss_kd
# else:
# loss = loss_kd
#
# return (loss, outputs) if return_outputs else loss

def compute_loss(
self, model, inputs, return_outputs=False, num_items_in_batch=None
@@ -156,10 +155,6 @@ def compute_loss(
Subclass and override for custom behavior.
"""

# return self.compute_loss_w_triton(
# model, inputs, return_outputs, num_items_in_batch
# )

target_logprobs = inputs.pop("target_logprobs")
target_token_ids = inputs.pop("target_token_ids")
target_mask = inputs.pop("target_mask")
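For reference, the Triton path commented out above computes a forward KL over the teacher's top-K tokens: the student logits are gathered at the teacher-provided token ids, a log-softmax is taken over those K entries, and teacher_prob * (teacher_logprob - student_logprob) is summed over valid positions, then normalized. Below is a minimal pure-PyTorch sketch of that same computation, assuming target_logprobs holds the teacher's top-K log-probabilities and target_mask is a 0/1 (or bool) tensor of shape [B, seq_len, K]; the helper name topk_kd_loss and its standalone packaging are illustrative, not part of this repository.

import torch
import torch.nn.functional as F


def topk_kd_loss(student_logits, target_token_ids, target_logprobs, target_mask,
                 num_items_in_batch=None):
    # Align the student to the teacher's sequence length, then gather the
    # student logits at the teacher's top-K token ids -> [B, seq_len, K]
    teacher_seq_len = target_token_ids.shape[1]
    student_topk = torch.gather(
        student_logits[:, :teacher_seq_len, :], dim=-1, index=target_token_ids
    )

    # Log-softmax over the K gathered entries (top-K approximation of the
    # full-vocabulary softmax), reduced in fp32 for stability
    student_logprobs = F.log_softmax(student_topk.float(), dim=-1)

    # Forward KL restricted to the top-K support, zeroed on masked entries
    teacher_probs = target_logprobs.exp()
    kl = teacher_probs * (target_logprobs - student_logprobs)
    kd_sum = (kl * target_mask).sum()

    # Normalize by the trainer-provided token count if available,
    # otherwise by the number of valid (unmasked) entries
    if num_items_in_batch is not None:
        return kd_sum / num_items_in_batch
    return kd_sum / (target_mask.sum() + 1e-8)

As in the trainer above, the result could then be blended with the standard CE loss via kd_ce_alpha and kd_alpha.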
58 changes: 25 additions & 33 deletions src/axolotl/integrations/kd/kernels/kd.py
@@ -1,3 +1,7 @@
"""
Triton kernel for optimized kl divergence loss
"""

import torch
import triton
import triton.language as tl
@@ -37,10 +41,10 @@ def kd_forward_kernel(
mask_ptr: tl.tensor,
# partial_kd: [B*seq_len] flattened buffer to store partial sums
partial_kd_ptr: tl.tensor,
B: tl.int32,
B: tl.int32, # pylint: disable=invalid-name
seq_len: tl.int32,
K: tl.int32,
BLOCK_SIZE: tl.constexpr,
K: tl.int32, # pylint: disable=invalid-name
BLOCK_SIZE: tl.constexpr, # pylint: disable=invalid-name
):
"""
For each position in [0..B*seq_len), we:
@@ -82,11 +86,7 @@ def kd_forward_kernel(

# load student logits, masked out-of-bounds with a large negative
# so they don't affect the max
student_val = tl.where(
mask_pos,
tl.load(student_logits_ptr + offset_k),
-1e30
)
student_val = tl.where(mask_pos, tl.load(student_logits_ptr + offset_k), -1e30)
# update running max
max_val = tl.where(student_val > max_val, student_val, max_val)

@@ -96,11 +96,7 @@ def kd_forward_kernel(
exp_sum = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
for k in range(K):
offset_k = b_idx * (seq_len * K) + s_idx * K + k
student_val = tl.where(
mask_pos,
tl.load(student_logits_ptr + offset_k),
-1e30
)
student_val = tl.where(mask_pos, tl.load(student_logits_ptr + offset_k), -1e30)
# exponent
exponent = tl.exp(student_val - max_val)
exp_sum += exponent
@@ -119,20 +115,12 @@ def kd_forward_kernel(
for k in range(K):
offset_k = b_idx * (seq_len * K) + s_idx * K + k
# teacher logprobs
t_log = tl.where(
mask_pos,
tl.load(teacher_logprobs_ptr + offset_k),
-1e30
)
t_log = tl.where(mask_pos, tl.load(teacher_logprobs_ptr + offset_k), -1e30)
# teacher prob
t_prob = tl.exp(t_log)

# student logit
s_val = tl.where(
mask_pos,
tl.load(student_logits_ptr + offset_k),
-1e30
)
s_val = tl.where(mask_pos, tl.load(student_logits_ptr + offset_k), -1e30)
# student logprob
s_logprob = s_val - logsumexp_val

@@ -142,7 +130,7 @@ def kd_forward_kernel(
# also read mask to disable invalid tokens if mask is not purely sequence-based
valid_k = tl.load(mask_ptr + offset_k)
# if mask is bool => use 'valid_k != 0', if it's 0/1 => same
is_valid = (valid_k > 0)
is_valid = valid_k > 0

# zero out if either this index is out-of-bounds or mask is invalid
kl_val = tl.where(mask_pos & is_valid, kl_val, 0.0)
@@ -158,17 +146,17 @@ def kd_forward_kernel(


def kd_forward_pass_triton(
student_logits, # [B, seq_len, K] (already gathered)
student_logits, # [B, seq_len, K] (already gathered)
teacher_logprobs, # [B, seq_len, K]
mask, # [B, seq_len, K] bool or 0/1
BLOCK_SIZE=1024,
mask, # [B, seq_len, K] bool or 0/1
BLOCK_SIZE=1024, # pylint: disable=invalid-name
):
"""
Returns total KL (float). We do the sum on the Python side.
NOTE: No normalization is done here.
You might divide by `num_items_in_batch` or # valid tokens afterward.
"""
B, seq_len, K = student_logits.shape
B, seq_len, K = student_logits.shape # pylint: disable=invalid-name
# Flatten
student_logits_flat = student_logits.reshape(-1)
teacher_logprobs_flat = teacher_logprobs.reshape(-1)
@@ -188,14 +176,17 @@ def kd_forward_pass_triton(
teacher_logprobs_flat,
mask_flat,
partial_kd,
B, seq_len, K,
BLOCK_SIZE=BLOCK_SIZE
B,
seq_len,
K,
BLOCK_SIZE=BLOCK_SIZE,
)

# Sum on CPU or GPU
kd_sum = partial_kd.sum()
return kd_sum


class _KLDivergenceTritonFn(torch.autograd.Function):
@staticmethod
def forward(ctx, student_logits, teacher_logprobs, mask):
@@ -211,7 +202,6 @@ def forward(ctx, student_logits, teacher_logprobs, mask):
ctx.save_for_backward(student_logits, teacher_logprobs, mask)
return kd_loss


@staticmethod
def backward(ctx, grad_output):
# We'll do naive PyTorch re-computation for gradient wrt student_logits
@@ -244,7 +234,7 @@ def kd_loss_triton(
student_logits, # [B, teacher_seq_len, vocab_size], but typically we gather for top-K
teacher_logprobs,
mask,
num_items_in_batch=None,
num_items_in_batch=None, # pylint: disable=unused-argument
):
"""
Wrapper that calls our Triton-based forward+backward for KD.
@@ -253,5 +243,7 @@
called gather on student_logits -> shape [B, seq_len, K].
"""
return _KLDivergenceTritonFn.apply(
student_logits, teacher_logprobs, mask, # num_items_in_batch
student_logits,
teacher_logprobs,
mask, # num_items_in_batch
)
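As the docstrings above note, kd_loss_triton expects student logits that have already been gathered to the teacher's top-K token ids, returns an unnormalized KL sum, and launches a Triton kernel in its forward pass, so it needs a CUDA device. The usage sketch below uses synthetic tensors; the shapes and random data are purely illustrative, the import path matches the file shown above, and it assumes the (partially shown) backward pass returns a gradient for the gathered student logits.

import torch

from axolotl.integrations.kd.kernels.kd import kd_loss_triton

# Illustrative shapes: batch of 2, sequence length 16, teacher top-K of 8
B, seq_len, K, vocab_size = 2, 16, 8, 32000
student_logits = torch.randn(B, seq_len, vocab_size, device="cuda", requires_grad=True)
target_token_ids = torch.randint(0, vocab_size, (B, seq_len, K), device="cuda")
target_logprobs = torch.log_softmax(torch.randn(B, seq_len, K, device="cuda"), dim=-1)
target_mask = torch.ones(B, seq_len, K, dtype=torch.int32, device="cuda")

# Gather the student's logits at the teacher's top-K ids, as the trainer code does
student_topk = torch.gather(student_logits, dim=-1, index=target_token_ids)

# Unnormalized KL sum from the Triton forward kernel
kd_sum = kd_loss_triton(student_topk, target_logprobs, target_mask)

# Normalization is left to the caller: divide by num_items_in_batch or by the
# number of valid tokens before backpropagating
loss = kd_sum / (target_mask.sum() + 1e-8)
loss.backward()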
