
Introduce Distillation with a Chunked, Fused Linear JS-divergence Loss #408

Closed
wants to merge 17 commits

Conversation

Collaborator

@austin362667 austin362667 commented Nov 27, 2024

Summary

Knowledge Distillation

Knowledge Distillation (KD; Hinton et al. 2015, Gou et al. 2020) is a straightforward way to build a smaller, cheaper model (“student model”) to speed up inference by transferring skills from a pre-trained expensive model (“teacher model”) into the student.

In knowledge distillation, a student model is trained to replicate the outputs of a teacher model using a distillation loss. Neural networks typically include a softmax layer; for instance, a large language model produces a probability distribution over tokens. Let z_t and z_s represent the logits before the softmax layer for the teacher and student models, respectively. The distillation loss reduces the discrepancy between the two softmax outputs at a high temperature T. When ground truth labels y are available, this approach can be combined with a supervised learning objective, such as cross-entropy, to compare the student’s outputs with the ground truth.

The combined loss function is defined as:

$$\mathcal{L} = \mathcal{L}_{\text{distill}}(\text{softmax}(\mathbf{z_t}, T), \text{softmax}(\mathbf{z_s}, T)) + \lambda \mathcal{L}_{CE}(\mathbf{y}, \mathbf{z_s}),$$

Here, lambda is a hyperparameter that balances the distillation loss and the supervised objective.
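For illustration, here is a minimal PyTorch sketch of this combined objective (the function name and defaults are hypothetical, not the fused kernel introduced in this PR; KL divergence stands in for the generic distillation term):

```python
import torch.nn.functional as F

def combined_kd_loss(student_logits, teacher_logits, labels, T=2.0, lam=0.5):
    # Soft term: divergence between teacher and student distributions at temperature T.
    # The T**2 factor keeps gradient magnitudes comparable across temperatures.
    soft = F.kl_div(
        F.log_softmax(student_logits / T, dim=-1),
        F.softmax(teacher_logits / T, dim=-1),
        reduction="batchmean",
    ) * (T * T)
    # Hard term: standard cross-entropy against the ground-truth labels y.
    hard = F.cross_entropy(student_logits, labels)
    # L = L_distill + lambda * L_CE, matching the formula above.
    return soft + lam * hard
```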

Shared DistillationBase

To support various distillation learning objectives, this PR aims to add a LigerFusedLinearDistillationBase, which is basically the same as the one proposed by @hongpeng-guo in discussion #371 (comment). Thank you @hongpeng-guo for thinking this through.

Jensen-Shannon Divergence Loss

In addition to adding the base class, this PR implements Jensen-Shannon Divergence (JSD) loss as the soft learning objective in the distillation setting. This component can be replaced with other losses (e.g., KL divergence) as distillation_loss_fn.

JSD is defined as the average of the KL divergences between each distribution and the mean distribution:

$$\text{JSD}(P || Q) = \frac{1}{2} \text{KL}(P || M) + \frac{1}{2} \text{KL}(Q || M), \quad \text{where } M = \frac{1}{2}(P + Q)$$

Here, P and Q are the two probability distributions, and M is their average.
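For reference, a naive (unfused, unchunked) PyTorch version of this definition might look like the sketch below; the chunked, fused-linear kernel in this PR is meant to compute the same quantity without materializing the full logits.

```python
import torch.nn.functional as F

def naive_jsd(student_logits, teacher_logits):
    # P and Q are the teacher and student probability distributions over the vocabulary.
    p = F.softmax(teacher_logits, dim=-1)
    q = F.softmax(student_logits, dim=-1)
    m = 0.5 * (p + q)
    # F.kl_div expects log-probabilities as the first argument and probabilities as the
    # target, so F.kl_div(m.log(), p) computes KL(P || M).
    kl_pm = F.kl_div(m.log(), p, reduction="batchmean")
    kl_qm = F.kl_div(m.log(), q, reduction="batchmean")
    return 0.5 * (kl_pm + kl_qm)
```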

TODO

[EDIT] I found that the Triton-based, fused-linear JSD loss is also much slower than the naive torch JSD implementation in the forward pass.

  • Investigate why the chunked implementation is so much slower than the naive approach.
  • Integrate temperature scaling.

Testing Done

Yes.

> modal run dev.modal.tests
========== 656 passed, 215 skipped, 58 warnings in 233.39s (0:03:53) ===========

(Benchmark figures: jsd_loss_memory, jsd_loss_speed)

  • Hardware Type: A100 40G
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

Signed-off-by: Austin Liu <[email protected]>

@austin362667 austin362667 changed the title Add Support for Knowledge Distillation with a chunked, fused linear JS-divergence Loss Add Distillation with a Chunked, Fused Linear JS-divergence Loss Nov 27, 2024
@austin362667 austin362667 changed the title Add Distillation with a Chunked, Fused Linear JS-divergence Loss Introduce Distillation with a Chunked, Fused Linear JS-divergence Loss Nov 28, 2024
@austin362667 austin362667 marked this pull request as ready for review November 28, 2024 03:22
if valid_mask.any():
    student_average_log_prob[valid_mask] = (
        student_per_token_logps * loss_mask
    ).sum(-1)[valid_mask] / loss_mask_sum[valid_mask]
Collaborator

I don't quite understand what loss_mask_sum and valid_mask do. Is it just a way to avoid a ZeroDivisionError?

loss_fn=None,
chunk_size=1,
ignore_index=-100,
beta=0.5,
Collaborator

@Tcc0403 Tcc0403 Dec 1, 2024

Perhaps we need another variable name for the weight between the soft and hard loss, since some loss functions have a 'beta' parameter, such as the generalized JSD we've implemented in #278.

Since lambda is a reserved keyword, maybe weight_hard_loss and weight_soft_loss?
If the sum of both weights is 1, you can just pick one of them; also consider torch.lerp() for combining the two losses.
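For illustration, a tiny sketch of the torch.lerp idea, assuming the two weights sum to 1 (names here are hypothetical):

```python
import torch

hard_loss = torch.tensor(1.2)   # e.g. cross-entropy against ground-truth labels
soft_loss = torch.tensor(0.8)   # e.g. JSD against the teacher
weight_soft_loss = 0.7

# torch.lerp(a, b, w) = a + w * (b - a) = (1 - w) * a + w * b
combined = torch.lerp(hard_loss, soft_loss, weight_soft_loss)
# equivalent to (1 - weight_soft_loss) * hard_loss + weight_soft_loss * soft_loss
```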

labels.view(-1),
)

student_logps = self.get_batch_logps(
Collaborator

@Tcc0403 Tcc0403 Dec 1, 2024

Do we need to calculate the per-token probabilities for knowledge distillation? I might be wrong, but don't we just pass teacher_logits and student_logits directly to the divergence loss function, such as kldiv (normally with reduction="batchmean") or jsd?

Collaborator Author

@austin362667 austin362667 Dec 1, 2024

@Tcc0403 Thank you for the review!!

Actually, you're right—I’m aware of that. I was just trying to align the interface with the preference-based design and reuse the value of student_log_probs calculated during ce_loss in DistillBase. However, if it's not necessary to maintain the same interface, I prefer your suggestion.

As shown in the distillation calculation function in this PR, it essentially undoes those operations. This redundant computation could be avoided by passing the raw logits directly to the divergence function, instead of first converting them to log probabilities and then converting them back to the original values.
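For illustration, passing raw logits straight into the divergence could look like this rough sketch (names are hypothetical, not the final interface):

```python
import torch.nn.functional as F

def divergence_from_logits(student_logits, teacher_logits, temperature=1.0):
    # Normalize once inside the loss instead of pre-computing log-probs in the base
    # class and undoing that work here.
    student_logp = F.log_softmax(student_logits / temperature, dim=-1)
    teacher_p = F.softmax(teacher_logits / temperature, dim=-1)
    return F.kl_div(student_logp, teacher_p, reduction="batchmean")
```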

label_chunk = torch.where(loss_mask, target_chunk, 0)

student_average_log_prob = torch.zeros_like(loss_mask, dtype=torch.float)
student_per_token_logps = student_log_probs_chunk.gather(
Collaborator

same question as above

@Tcc0403
Collaborator

Tcc0403 commented Dec 1, 2024

[EDIT] I found that the Triton-based, fused-linear JSD loss is also much slower than the naive torch JSD implementation in the forward pass.

If you're referring to the current LigerFusedLinearJSD, there are some benchmarks in #300. When comparing the forward pass only, the fljsd kernel is expected to be slower since it also does the gradient calculation in the forward pass, and it isn't purely written in Triton, so it might also suffer from kernel-launch overhead. But it's true that it doesn't perform well in low-BT scenarios.

Comment on lines +15 to +16
student_logps (torch.Tensor): Avg log probabilities of student inputs. Shape: (batch_size, hidden_size,).
teacher_logps (torch.Tensor): Avg log probabilities of teacher inputs. Shape: (batch_size, hidden_size,).
Collaborator

@hongpeng-guo hongpeng-guo Dec 1, 2024

I think for the general distillation loss, the student and teacher logps should be per-token instead of being averaged along the sequence-length dimension. I.e., both tensors should be of shape (batch_size, sequence_size, vocab_size) or (flattened_batch_sequence_size, vocab_size).

Comment on lines +272 to +275
distillation_loss = distillation_loss_fn(
    student_logps, teacher_logps, temperature
)
distillation_loss = distillation_loss / (full_target.shape[0])
Collaborator

After we make the distillation loss per-token, we may normalize the distillation_loss with (full_target != ignore_index).sum(), similar to the ce_loss.
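A minimal sketch of that normalization, with names borrowed from the snippet above (hypothetical, not the merged code):

```python
import torch

def normalize_distillation_loss(loss_sum, full_target, ignore_index=-100):
    # Divide by the number of non-ignored tokens rather than by the batch size.
    n_valid = (full_target != ignore_index).sum()
    return loss_sum / n_valid.clamp(min=1)
```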

Collaborator

@hongpeng-guo hongpeng-guo left a comment

Thanks a lot for drafting the distillation base class. I left some comments on fused_linear_distillation.py, mainly discussing whether the loss should be computed per token or averaged over the sequence dimension first.

Another major question I have is about the chunking dimensions. The current implementation in this PR chunks only along the batch_size dimension, similar to the implementation of fused_linear_preference.py. However, I think it would be better if we chunked along the flattened dim[0] of (B*T, vocab_size), which is also the chunking scheme described in the paper for the CE loss.

For the preference base class, I think the chunking only happens on the batch_size dimension because the sequence dimension is reduced when calculating the average logps (link). But for distillation, we may prefer to follow the pattern of the CE loss and chunk on the joint B*T dimension, so that this kernel can work in very long sequence/context scenarios. Happy to help refine this base class @austin362667

cc @shivam15s, what do you think about this?
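For illustration, chunking along the flattened B*T dimension could look roughly like this (shapes and names are made up for the example):

```python
import torch

B, T, H, V, chunk_size = 4, 2048, 1024, 32000, 1024

student_hidden = torch.randn(B, T, H).view(-1, H)   # (B*T, H)
target = torch.randint(0, V, (B, T)).view(-1)       # (B*T,)

# Chunk over the flattened token dimension so memory stays bounded even for very
# long sequences: each chunk only materializes a (chunk_size, V) logits tensor.
for hidden_chunk, target_chunk in zip(
    torch.split(student_hidden, chunk_size, dim=0),
    torch.split(target, chunk_size, dim=0),
):
    # compute student/teacher logits for this chunk and accumulate the loss
    pass
```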

@hongpeng-guo
Collaborator

hongpeng-guo commented Dec 1, 2024

@austin362667 nit: A side note is to split this PR into two stacked PRs: the first for the distillation base class and the second for the JSD loss built on top of it. We can prioritize polishing and merging the first PR so that other distillation losses can be based on it, and it stays non-blocking 😄

@austin362667
Collaborator Author

@hongpeng-guo Thanks for the review~

For the preference base class, I think the chunking only happens on the batch_size dimension because the sequence dimension is reduced when calculating the average logps (link). But for distillation, we may prefer to follow the pattern of the CE loss and chunk on the joint B*T dimension, so that this kernel can work in very long sequence/context scenarios. Happy to help refine this base class @austin362667

That makes perfect sense to me; I'll proceed with this approach.

@austin362667 nit: A side note is to split this PR into two stacked PRs: the first for the distillation base class and the second for the JSD loss built on top of it. We can prioritize polishing and merging the first PR so that other distillation losses can be based on it, and it stays non-blocking 😄

Absolutely! I'll split this into two separate PRs.

@austin362667
Collaborator Author

austin362667 commented Dec 2, 2024

Thanks for all the nice comments, @Tcc0403 and @hongpeng-guo!
Moving the discussion to #417.

ByronHsu pushed a commit that referenced this pull request Dec 9, 2024
## Summary

Made #417 from the main
repo.

Thanks to the nice suggestions from @Tcc0403 and @hongpeng-guo. This PR
is the first split from
#408, focusing solely on
introducing the Knowledge Distillation base class. As a result, this PR
does not include any tests at the moment.

#### Code Changes

1. Refactor `beta` into two weights: `weight_hard_loss` and
`weight_soft_loss`, as coefficients between `hard_loss` and `soft_loss`.
@Tcc0403 also pointed out that we could use `torch.lerp` if applicable.

2. Pass `teacher_logits` and `student_logits` directly to the divergence
loss function. This avoids redundant computations of converting logits
to log probabilities and then reverting them to raw logits. However note
that we are not reusing the `student_log_probs` value calculated during
`ce_loss` in distillation base.

    1. Remove the unnecessary `get_batch_logps` in `test/utils.py`.

3. Modify `chunking` dimensions from `B` to `B * T`. Thanks to
@hongpeng-guo's great advice.
1. Fix the loss calculation to use per-token values instead of averaging
across the sequence length dimension.

4. Normalize the `distillation_loss` using `(full_target !=
ignore_index).sum()`.

#### TODO  

1. [X] Although a slight slowdown is reasonable, we need to
investigate why this PR's implementation is **significantly slower**
compared to the naive approach. Thanks to @Tcc0403's clarification.
    
The issue arises because we are not properly configuring the
`chunk_size` for the `B * T` dimension, which is extremely large (a few
thousand). The previous default of 1 results in an excessive number of
chunks.

In contrast, this problem does not occur with the preference loss, as
chunking is performed on the `B` dimension. This produces fewer than 10
chunks, which is efficient and works as expected.

In conclusion, setting `chunk_size` to `1024` works pretty well, as
shown in the new benchmark results in
#425

2. [ ]
#417 (comment)

#### Knowledge Distillation

Knowledge Distillation (KD; [Hinton et al.
2015](https://arxiv.org/abs/1503.02531), [Gou et al.
2020](https://arxiv.org/abs/2006.05525)) is a straightforward way to
build a smaller, cheaper model (“student model”) to speed up inference
by transferring skills from a pre-trained expensive model (“teacher
model”) into the student.

In knowledge distillation, a student model is trained to replicate the
outputs of a teacher model using a distillation loss. Neural networks
typically include a softmax layer; for instance, a large language model
produces a probability distribution over tokens. Let `z_t` and `z_s`
represent the logits before the softmax layer for the teacher and
student models, respectively. The distillation loss reduces the
discrepancy between the two softmax outputs at a high temperature `T`.
When ground truth labels `y` are available, this approach can be
combined with a supervised learning objective, such as cross-entropy, to
compare the student’s outputs with the ground truth.

The combined loss function is defined as:

```math
\mathcal{L}_{\text{knowledge distillation}} = w_{\text{soft}} \cdot \mathcal{L}_{\text{distill}}(\mathbf{z_t}, \mathbf{z_s}, T) + w_{\text{hard}} \cdot \mathcal{L}_{\text{cross entropy}}(\mathbf{y}, \mathbf{z_s}),
``` 

Here, we directly pass in `logits` rather than `logprobs`. @Tcc0403

#### Shared `DistillationBase`

To support various distillation learning objectives, this PR aims to add
a `LigerFusedLinearDistillationBase`, which is basically the same as the one
proposed by @hongpeng-guo within this discussion
#371 (comment).
Thank you @hongpeng-guo for thinking this through.

## Testing Done

I'll post JSD tests and benchmarks results in next PR:
#425

- Hardware Type: L40S
- [ ] run `make test` to ensure correctness
- [ ] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence

---------

Signed-off-by: Austin Liu <[email protected]>
Co-authored-by: shivam15s <[email protected]>