
Introduce Knowledge Distillation Base #417

Closed
austin362667 wants to merge 12 commits from the feat/distill_base branch

Conversation

austin362667
Collaborator

@austin362667 austin362667 commented Dec 2, 2024

Summary

Thanks to the nice suggestions from @Tcc0403 and @hongpeng-guo. This PR is the first split from #408, focusing solely on introducing the Knowledge Distillation base class. As a result, this PR does not include any tests at the moment.

Code Changes

  1. Refactor beta into two weights, weight_hard_loss and weight_soft_loss, which act as coefficients for hard_loss and soft_loss (see the sketch after this list). @Tcc0403 also pointed out that we could use torch.lerp if applicable.

  2. Pass teacher_logits and student_logits directly to the divergence loss function. This avoids the redundant computation of converting logits to log probabilities and then back to raw logits. Note, however, that we are not reusing the student_log_probs value calculated during ce_loss in the distillation base.

    1. Remove the unnecessary get_batch_logps in test/utils.py.
  3. Modify chunking dimensions from B to B * T. Thanks to @hongpeng-guo's great advice.

    1. Fix the loss calculation to use per-token values instead of averaging across the sequence length dimension.
  4. Normalize the distillation_loss using (full_target != ignore_index).sum().
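
A minimal sketch of how the two weights from item 1 combine the hard and soft terms (combine_losses is a hypothetical helper for illustration, not code from this PR):

```python
import torch

def combine_losses(hard_loss: torch.Tensor,
                   soft_loss: torch.Tensor,
                   weight_hard_loss: float,
                   weight_soft_loss: float) -> torch.Tensor:
    # Weighted sum of the supervised (hard) term and the distillation (soft) term.
    return weight_hard_loss * hard_loss + weight_soft_loss * soft_loss

# If the two weights sum to 1 (the old single-beta parameterization), the same
# combination can be written with torch.lerp, as @Tcc0403 suggested:
#   torch.lerp(hard_loss, soft_loss, beta) == (1 - beta) * hard_loss + beta * soft_loss
```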

TODO

  1. Although a slight slowdown is expected, we need to investigate why this PR's implementation is significantly slower than the naive approach. Thanks to @Tcc0403 for the clarification.

    The issue arises because we are not properly configuring the chunk_size for the B * T dimension, which is extremely large (a few thousand). The previous default of 1 results in an excessive number of chunks.

    In contrast, this problem does not occur with the preference loss, as chunking is performed on the B dimension. This produces fewer than 10 chunks, which is efficient and works as expected.

    In conclusion, setting chunk_size to 1024 works pretty well in the new benchmark results, as shown in Add JSD Loss for Distillation #425.
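
For rough intuition about the chunk counts involved (the batch and sequence sizes below are illustrative assumptions, not the benchmark configuration):

```python
import math

B, T = 8, 4096               # illustrative shapes only
BT = B * T                   # distillation chunks over the flattened B * T dimension

print(math.ceil(BT / 1))     # chunk_size = 1 (previous default) -> 32768 chunks
print(math.ceil(BT / 1024))  # chunk_size = 1024                 -> 32 chunks
print(math.ceil(B / 1))      # preference loss chunks over B     -> 8 chunks
```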

Knowledge Distillation

Knowledge Distillation (KD; Hinton et al. 2015, Gou et al. 2020) is a straightforward way to build a smaller, cheaper model (“student model”) to speed up inference by transferring skills from a pre-trained expensive model (“teacher model”) into the student.

In knowledge distillation, a student model is trained to replicate the outputs of a teacher model using a distillation loss. Neural networks typically include a softmax layer; for instance, a large language model produces a probability distribution over tokens. Let z_t and z_s represent the logits before the softmax layer for the teacher and student models, respectively. The distillation loss reduces the discrepancy between the two softmax outputs at a high temperature T. When ground truth labels y are available, this approach can be combined with a supervised learning objective, such as cross-entropy, to compare the student’s outputs with the ground truth.

The combined loss function is defined as:

$$\mathcal{L}_{\text{knowledge distillation}} = w_{\text{soft}} \cdot \mathcal{L}_{\text{distill}}(\mathbf{z_t}, \mathbf{z_s}, T) + w_{\text{hard}} \cdot \mathcal{L}_{\text{cross entropy}}(\mathbf{y}, \mathbf{z_s}).$$

Here, we directly pass in logits rather than logprobs. @Tcc0403
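
A plain-PyTorch sketch of this combined objective (illustration only, not the chunked fused-linear implementation in this PR; the temperature-scaled KL term is just one common choice for the soft loss):

```python
import torch
import torch.nn.functional as F

def kd_loss(student_logits, teacher_logits, labels,
            weight_soft=0.5, weight_hard=0.5, T=2.0, ignore_index=-100):
    V = student_logits.size(-1)
    s = student_logits.view(-1, V)   # (B*T, V)
    t = teacher_logits.view(-1, V)
    y = labels.view(-1)

    # Soft loss: KL(teacher || student) between temperature-scaled distributions,
    # scaled by T^2 as in Hinton et al. (2015).
    soft = F.kl_div(
        F.log_softmax(s / T, dim=-1),
        F.log_softmax(t / T, dim=-1),
        reduction="batchmean",
        log_target=True,
    ) * (T * T)

    # Hard loss: standard cross-entropy against the ground-truth labels.
    hard = F.cross_entropy(s, y, ignore_index=ignore_index)

    return weight_soft * soft + weight_hard * hard
```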

Shared DistillationBase

To support various distillation learning objectives, this PR aims to add a LigerFusedLinearDistillationBase, which is basically the same as the one proposed by @hongpeng-guo in this discussion: #371 (comment). Thank you @hongpeng-guo for thinking this through.
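
For intuition only, the overall pattern is roughly the sketch below; the function name, signature, and loss_fn hook are placeholders, not the actual LigerFusedLinearDistillationBase API:

```python
import torch

def chunked_distillation_loss(student_hidden, teacher_hidden,
                              student_weight, teacher_weight,
                              labels, loss_fn,
                              chunk_size=1024, ignore_index=-100):
    # Flatten (B, T, H) -> (B*T, H) so chunking happens over the B * T dimension.
    s_hidden = student_hidden.view(-1, student_hidden.size(-1))
    t_hidden = teacher_hidden.view(-1, teacher_hidden.size(-1))
    y = labels.view(-1)

    total = 0.0
    for s_chunk, t_chunk, y_chunk in zip(s_hidden.split(chunk_size),
                                         t_hidden.split(chunk_size),
                                         y.split(chunk_size)):
        # Project each chunk through the student / teacher lm_head weights,
        # so the full (B*T, V) logits are never materialized at once.
        s_logits = s_chunk @ student_weight.t()
        t_logits = t_chunk @ teacher_weight.t()
        total = total + loss_fn(s_logits, t_logits, y_chunk)  # per-chunk summed loss

    # Normalize by the number of non-ignored tokens, as described above.
    return total / (y != ignore_index).sum()
```

Chunking over the flattened B * T dimension keeps only chunk_size rows of logits alive at a time, which is why the choice of chunk_size (see the TODO above) matters so much for speed.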

Testing Done

I'll post JSD tests and benchmark results in the next PR: #425

  • Hardware Type: L40S
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

@austin362667 austin362667 changed the title Feat/distill base Introduce Knowledge Distillation Base Dec 2, 2024
@austin362667 austin362667 force-pushed the feat/distill_base branch 2 times, most recently from 5257d26 to 3a9f125 Compare December 4, 2024 15:55
@austin362667 austin362667 mentioned this pull request Dec 4, 2024
hard_loss,
) = forward_output

soft_loss = self.distillation_loss(student_logits, teacher_logits)
Contributor

The method uses logprobs: def distillation_loss(self, student_logps, teacher_logps), but you pass logits here.

I'd actually like to see both a logit and logprob implementation since it's easy to get logprobs offline from vllm and that is a faster way to generate the dataset.

Collaborator Author

@austin362667 austin362667 Dec 5, 2024


The method uses logprobs: def distillation_loss(self, student_logps, teacher_logps), but you pass logits here.

@winglian Nice catch! Thank you so much.

I'd actually like to see both a logit and logprob implementation since it's easy to get logprobs offline from vllm and that is a faster way to generate the dataset.

Sure, I think it's doable. That said, I'm not quite sure I fully understand the need for a logprobs implementation. Mind elaborating more on the vLLM use case?

Contributor

So rather than having to keep the teacher model loaded during training, depending on the workload type, it can be faster and more compute-efficient to pre-compute the logits/logprobs offline beforehand. However, vllm and sglang only provide the logprobs, and those are not easily back-calculated into logits.
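
As an illustration of the logprob variant being requested (a sketch under the assumption that full-vocabulary teacher logprobs are available; not part of this PR):

```python
import torch
import torch.nn.functional as F

def soft_loss_from_logprobs(student_logits: torch.Tensor,
                            teacher_logprobs: torch.Tensor) -> torch.Tensor:
    # Forward KL(teacher || student) computed from precomputed teacher
    # log-probabilities (e.g. dumped offline) and live student logits.
    student_logprobs = F.log_softmax(student_logits, dim=-1)
    teacher_probs = teacher_logprobs.exp()
    return (teacher_probs * (teacher_logprobs - student_logprobs)).sum(dim=-1).mean()
```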

Collaborator Author

I see. That makes a lot of sense to me. Thank you!

Collaborator

@winglian Curious whether vllm/sglang support temperature-scaled logprobs. This would be needed to enable https://github.com/huggingface/trl/blob/9c5388b69e0842f76edc46a2ff9d0b51e1db4337/trl/trainer/gkd_trainer.py#L174

Collaborator

I believe we can address this ask in a subsequent PR. @ByronHsu, what do you think?

@austin362667
Collaborator Author

austin362667 commented Dec 6, 2024

Although a slight slowdown is expected, we need to investigate why this PR's implementation is significantly slower than the naive approach. Thanks to @Tcc0403 for the clarification.

The issue arises because we are not properly configuring the chunk_size for the B * T dimension, which is extremely large (a few thousand). The previous default of 1 results in an excessive number of chunks.

In contrast, this problem does not occur with the preference loss, as chunking is performed on the B dimension. This produces fewer than 10 chunks, which is efficient and works as expected.

In conclusion, setting chunk_size to 1024 works pretty well in the new benchmark results, as shown in #425.

austin362667 and others added 9 commits December 7, 2024 00:03
Signed-off-by: Austin Liu <[email protected]>
Signed-off-by: Austin Liu <[email protected]>
Signed-off-by: Austin Liu <[email protected]>
Signed-off-by: Austin Liu <[email protected]>

Set default `chunk_size` to `1024`

Signed-off-by: Austin Liu <[email protected]>

Rebase

Signed-off-by: Austin Liu <[email protected]>
Signed-off-by: Austin Liu <[email protected]>
Signed-off-by: Austin Liu <[email protected]>
Collaborator

@hongpeng-guo hongpeng-guo left a comment

@austin362667 FWIW, to run the Modal GPU CIs, this PR needs to be made from the main repo, i.e., linkedin/Liger-Kernel, instead of the forked repo.
A similar example is: I closed #399 and moved to #400 to enable the CI pipeline.

Collaborator

@shivam15s shivam15s left a comment

LGTM

Collaborator

@shivam15s shivam15s left a comment

Can you create another PR in the linkedin repo? Some tests fail for me locally, so I'd like to confirm before merging.

@austin362667
Collaborator Author

@shivam15s Certainly, right here: #432. Thanks a lot!

@austin362667
Collaborator Author

Move discussion to #432

ByronHsu pushed a commit that referenced this pull request Dec 9, 2024
## Summary

Made #417 from the main repo.

---------

Signed-off-by: Austin Liu <[email protected]>
Co-authored-by: shivam15s <[email protected]>