diff --git a/src/axolotl/core/trainers/kd.py b/src/axolotl/core/trainers/kd.py index 6473f26f0..c12cee77a 100644 --- a/src/axolotl/core/trainers/kd.py +++ b/src/axolotl/core/trainers/kd.py @@ -7,7 +7,6 @@ import torch from axolotl.core.trainers.base import AxolotlTrainer -from axolotl.integrations.kd.kernels.kd import kd_loss_triton def kd_loss_function( @@ -93,59 +92,59 @@ def _set_signature_columns_if_needed(self): if columns_to_add: self._signature_columns += columns_to_add - def compute_loss_w_triton( - self, model, inputs, return_outputs=False, num_items_in_batch=None - ): - target_logprobs = inputs.pop("target_logprobs") - target_token_ids = inputs.pop("target_token_ids") - target_mask = inputs.pop("target_mask") - - if self.model_accepts_loss_kwargs: - loss_kwargs = {} - if num_items_in_batch is not None: - loss_kwargs["num_items_in_batch"] = num_items_in_batch - inputs = {**inputs, **loss_kwargs} - outputs = model(**inputs) - - student_logits = outputs["logits"] - # Slice or gather student logits to match teacher seq len - # e.g.: - teacher_seq_len = target_token_ids.shape[1] - student_logits_for_kd = student_logits[ - :, :teacher_seq_len, : - ] # [B, seq_len, vocab_size] - - # GATHER top-K from student - student_logits_topk = torch.gather( - student_logits_for_kd, - dim=-1, - index=target_token_ids, # same shape [B, seq_len, K] - ) - - # Now call the Triton-based KD loss - kd_sum = kd_loss_triton( - student_logits_topk, - target_logprobs, # teacher logprobs [B, seq_len, K] - target_mask, # mask [B, seq_len, K] - ) - - # Normalize however you want - if num_items_in_batch is not None: - loss_kd = kd_sum / num_items_in_batch - else: - # or do e.g. average over valid tokens - # quick example: - total_valid = target_mask.sum() - loss_kd = kd_sum / (total_valid + 1e-8) - - # optionally combine with CE loss - if self.args.kd_ce_alpha > 0: - kd_alpha = self.args.kd_alpha - loss = self.args.kd_ce_alpha * outputs["loss"] + kd_alpha * loss_kd - else: - loss = loss_kd - - return (loss, outputs) if return_outputs else loss + # def compute_loss_w_triton( + # self, model, inputs, return_outputs=False, num_items_in_batch=None + # ): + # target_logprobs = inputs.pop("target_logprobs") + # target_token_ids = inputs.pop("target_token_ids") + # target_mask = inputs.pop("target_mask") + # + # if self.model_accepts_loss_kwargs: + # loss_kwargs = {} + # if num_items_in_batch is not None: + # loss_kwargs["num_items_in_batch"] = num_items_in_batch + # inputs = {**inputs, **loss_kwargs} + # outputs = model(**inputs) + # + # student_logits = outputs["logits"] + # # Slice or gather student logits to match teacher seq len + # # e.g.: + # teacher_seq_len = target_token_ids.shape[1] + # student_logits_for_kd = student_logits[ + # :, :teacher_seq_len, : + # ] # [B, seq_len, vocab_size] + # + # # GATHER top-K from student + # student_logits_topk = torch.gather( + # student_logits_for_kd, + # dim=-1, + # index=target_token_ids, # same shape [B, seq_len, K] + # ) + # + # # Now call the Triton-based KD loss + # kd_sum = kd_loss_triton( + # student_logits_topk, + # target_logprobs, # teacher logprobs [B, seq_len, K] + # target_mask, # mask [B, seq_len, K] + # ) + # + # # Normalize however you want + # if num_items_in_batch is not None: + # loss_kd = kd_sum / num_items_in_batch + # else: + # # or do e.g. 
average over valid tokens + # # quick example: + # total_valid = target_mask.sum() + # loss_kd = kd_sum / (total_valid + 1e-8) + # + # # optionally combine with CE loss + # if self.args.kd_ce_alpha > 0: + # kd_alpha = self.args.kd_alpha + # loss = self.args.kd_ce_alpha * outputs["loss"] + kd_alpha * loss_kd + # else: + # loss = loss_kd + # + # return (loss, outputs) if return_outputs else loss def compute_loss( self, model, inputs, return_outputs=False, num_items_in_batch=None @@ -156,10 +155,6 @@ def compute_loss( Subclass and override for custom behavior. """ - # return self.compute_loss_w_triton( - # model, inputs, return_outputs, num_items_in_batch - # ) - target_logprobs = inputs.pop("target_logprobs") target_token_ids = inputs.pop("target_token_ids") target_mask = inputs.pop("target_mask") diff --git a/src/axolotl/integrations/kd/kernels/kd.py b/src/axolotl/integrations/kd/kernels/kd.py index 26cf76aac..28a3c8e67 100644 --- a/src/axolotl/integrations/kd/kernels/kd.py +++ b/src/axolotl/integrations/kd/kernels/kd.py @@ -1,3 +1,7 @@ +""" +Triton kernel for optimized kl divergence loss +""" + import torch import triton import triton.language as tl @@ -37,10 +41,10 @@ def kd_forward_kernel( mask_ptr: tl.tensor, # partial_kd: [B*seq_len] flattened buffer to store partial sums partial_kd_ptr: tl.tensor, - B: tl.int32, + B: tl.int32, # pylint: disable=invalid-name seq_len: tl.int32, - K: tl.int32, - BLOCK_SIZE: tl.constexpr, + K: tl.int32, # pylint: disable=invalid-name + BLOCK_SIZE: tl.constexpr, # pylint: disable=invalid-name ): """ For each position in [0..B*seq_len), we: @@ -82,11 +86,7 @@ def kd_forward_kernel( # load student logits, masked out-of-bounds with a large negative # so they don't affect the max - student_val = tl.where( - mask_pos, - tl.load(student_logits_ptr + offset_k), - -1e30 - ) + student_val = tl.where(mask_pos, tl.load(student_logits_ptr + offset_k), -1e30) # update running max max_val = tl.where(student_val > max_val, student_val, max_val) @@ -96,11 +96,7 @@ def kd_forward_kernel( exp_sum = tl.zeros([BLOCK_SIZE], dtype=tl.float32) for k in range(K): offset_k = b_idx * (seq_len * K) + s_idx * K + k - student_val = tl.where( - mask_pos, - tl.load(student_logits_ptr + offset_k), - -1e30 - ) + student_val = tl.where(mask_pos, tl.load(student_logits_ptr + offset_k), -1e30) # exponent exponent = tl.exp(student_val - max_val) exp_sum += exponent @@ -119,20 +115,12 @@ def kd_forward_kernel( for k in range(K): offset_k = b_idx * (seq_len * K) + s_idx * K + k # teacher logprobs - t_log = tl.where( - mask_pos, - tl.load(teacher_logprobs_ptr + offset_k), - -1e30 - ) + t_log = tl.where(mask_pos, tl.load(teacher_logprobs_ptr + offset_k), -1e30) # teacher prob t_prob = tl.exp(t_log) # student logit - s_val = tl.where( - mask_pos, - tl.load(student_logits_ptr + offset_k), - -1e30 - ) + s_val = tl.where(mask_pos, tl.load(student_logits_ptr + offset_k), -1e30) # student logprob s_logprob = s_val - logsumexp_val @@ -142,7 +130,7 @@ def kd_forward_kernel( # also read mask to disable invalid tokens if mask is not purely sequence-based valid_k = tl.load(mask_ptr + offset_k) # if mask is bool => use 'valid_k != 0', if it's 0/1 => same - is_valid = (valid_k > 0) + is_valid = valid_k > 0 # zero out if either this index is out-of-bounds or mask is invalid kl_val = tl.where(mask_pos & is_valid, kl_val, 0.0) @@ -158,17 +146,17 @@ def kd_forward_kernel( def kd_forward_pass_triton( - student_logits, # [B, seq_len, K] (already gathered) + student_logits, # [B, seq_len, K] (already gathered) 
teacher_logprobs, # [B, seq_len, K] - mask, # [B, seq_len, K] bool or 0/1 - BLOCK_SIZE=1024, + mask, # [B, seq_len, K] bool or 0/1 + BLOCK_SIZE=1024, # pylint: disable=invalid-name ): """ Returns total KL (float). We do the sum on the Python side. NOTE: No normalization is done here. You might divide by `num_items_in_batch` or # valid tokens afterward. """ - B, seq_len, K = student_logits.shape + B, seq_len, K = student_logits.shape # pylint: disable=invalid-name # Flatten student_logits_flat = student_logits.reshape(-1) teacher_logprobs_flat = teacher_logprobs.reshape(-1) @@ -188,14 +176,17 @@ def kd_forward_pass_triton( teacher_logprobs_flat, mask_flat, partial_kd, - B, seq_len, K, - BLOCK_SIZE=BLOCK_SIZE + B, + seq_len, + K, + BLOCK_SIZE=BLOCK_SIZE, ) # Sum on CPU or GPU kd_sum = partial_kd.sum() return kd_sum + class _KLDivergenceTritonFn(torch.autograd.Function): @staticmethod def forward(ctx, student_logits, teacher_logprobs, mask): @@ -211,7 +202,6 @@ def forward(ctx, student_logits, teacher_logprobs, mask): ctx.save_for_backward(student_logits, teacher_logprobs, mask) return kd_loss - @staticmethod def backward(ctx, grad_output): # We'll do naive PyTorch re-computation for gradient wrt student_logits @@ -244,7 +234,7 @@ def kd_loss_triton( student_logits, # [B, teacher_seq_len, vocab_size], but typically we gather for top-K teacher_logprobs, mask, - num_items_in_batch=None, + num_items_in_batch=None, # pylint: disable=unused-argument ): """ Wrapper that calls our Triton-based forward+backward for KD. @@ -253,5 +243,7 @@ def kd_loss_triton( called gather on student_logits -> shape [B, seq_len, K]. """ return _KLDivergenceTritonFn.apply( - student_logits, teacher_logprobs, mask, # num_items_in_batch + student_logits, + teacher_logprobs, + mask, # num_items_in_batch )
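
Reviewer note: the masked top-K KL sum that `kd_forward_pass_triton` accumulates can be reproduced in a few lines of plain PyTorch, which is handy as a numerical sanity check while the Triton path in the trainer is commented out. The sketch below is illustrative only: the helper name `kd_loss_reference`, the toy shapes, and the `allclose` tolerance are assumptions of mine, and it needs a CUDA device since `kd_loss_triton` launches a Triton kernel. Shapes follow the comments in the kernel, i.e. [B, seq_len, K] for the gathered student logits, the teacher logprobs, and the mask.

import torch

from axolotl.integrations.kd.kernels.kd import kd_loss_triton


def kd_loss_reference(student_logits_topk, teacher_logprobs, mask):
    # log-softmax over the gathered top-K student logits; this mirrors the
    # kernel's max / exp-sum / logsumexp passes
    student_logprobs = torch.log_softmax(student_logits_topk.float(), dim=-1)
    teacher_probs = teacher_logprobs.float().exp()
    # per-entry forward-KL contribution: p_teacher * (logp_teacher - logp_student)
    kl = teacher_probs * (teacher_logprobs.float() - student_logprobs)
    # zero out masked entries and return the unnormalized sum, matching the
    # "no normalization is done here" contract of the Triton forward pass
    return (kl * mask.float()).sum()


# toy spot check (shapes are arbitrary; assumes a CUDA device with Triton available)
B, seq_len, K = 2, 32, 8
student = torch.randn(B, seq_len, K, device="cuda")
teacher = torch.log_softmax(torch.randn(B, seq_len, K, device="cuda"), dim=-1)
mask = torch.ones(B, seq_len, K, device="cuda")
assert torch.allclose(
    kd_loss_reference(student, teacher, mask),
    kd_loss_triton(student, teacher, mask),
    atol=1e-4,
)

Since the reference goes through autograd, comparing gradients via `torch.autograd.grad` against the custom backward in `_KLDivergenceTritonFn` is a natural follow-up check, though how closely they agree will depend on how that backward treats the mask.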