[RFC] Liger FlexChunkLoss: Alignment and Distillation loss #371

Open
5 of 12 tasks
shivam15s opened this issue Nov 8, 2024 · 21 comments

@shivam15s
Collaborator

shivam15s commented Nov 8, 2024

🚀 The feature, motivation and pitch

We want to support various alignment and distillation loss functions.
Refer to this PR on ORPO: #362

Progress

Alignment

Distillation

  • KL divergence
  • cosine_similarity
  • earth_mover_distance
  • JSD
  • KVD

Design

Approach Overview:

The core idea is to extend the methods used in chunked Fused Linear Cross Entropy (FLCE) to various alignment algorithms. Here's how the process is structured:

  1. Modular Optimization Process:
    • Every alignment algorithm’s optimization can be broken into three key steps:
      • Linear layer computation
      • Loss computation
      • Gradient calculation
  2. Fused Linear and Loss Computation:
    • Similar to FLCE, we aim to fuse the linear layer with the loss computation for efficiency.
  3. Chunking & Forward Optimization:
    • Since this is the final step in the model’s forward pass, we can also compute gradients directly during the forward pass instead of waiting for a separate backward pass.
    • We also chunk the input within the forward pass of the model, significantly reducing the peak GPU memory required.
  4. Torch Compile for Kernel Optimization:
    • Instead of manually handling kernel-level optimizations, we let torch.compile automatically optimize kernel execution. This reduces the need for low-level optimizations while still achieving performance gains.

By combining these strategies, we efficiently optimize alignment algorithms while also simplifying development.
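As a rough illustration of steps 1 to 3, here is a minimal NumPy sketch (a hypothetical helper, not Liger's code) that fuses a linear layer with cross-entropy, processes the rows in chunks, and computes each chunk's gradients inside the forward loop. It assumes plain cross-entropy as the loss and writes the per-chunk backward by hand; the proposal itself leaves kernel generation to torch.compile.

```python
import numpy as np

def chunked_linear_ce_with_grads(h, W, targets, chunk_size):
    """Hypothetical helper: fused linear + cross-entropy, processed in row chunks.

    h: (N, H) hidden states, W: (V, H) output-projection weight,
    targets: (N,) integer class ids. Returns the mean loss and its gradients
    w.r.t. h and W. Only a (chunk_size, V) logits block exists at a time.
    """
    n = h.shape[0]
    loss_sum = 0.0
    grad_h = np.zeros_like(h)
    grad_W = np.zeros_like(W)
    for start in range(0, n, chunk_size):
        hc = h[start:start + chunk_size]              # (C, H) chunk of hidden states
        tc = targets[start:start + chunk_size]        # (C,) matching targets
        logits = hc @ W.T                             # step 1: linear layer
        shifted = logits - logits.max(axis=1, keepdims=True)
        log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
        rows = np.arange(len(tc))
        loss_sum += -log_probs[rows, tc].sum()        # step 2: loss for this chunk
        # Step 3: gradient computed now, not in a later backward pass.
        # d(mean CE)/d(logits) = (softmax - one_hot) / n
        dlogits = np.exp(log_probs)
        dlogits[rows, tc] -= 1.0
        dlogits /= n
        grad_h[start:start + chunk_size] = dlogits @ W  # (C, H)
        grad_W += dlogits.T @ hc                        # (V, H), accumulated
    return loss_sum / n, grad_h, grad_W
```

Because the logits for a chunk are discarded once its gradients are accumulated, peak memory scales with `chunk_size` rather than with the full number of rows, while the results match the unchunked computation.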

Key Findings

By leveraging torch.compile alongside optimization techniques like chunking, online softmax, etc, we observed close to custom triton kernel performance and reduced development time. This is why we want to introduce torch.compile as a key component of Liger.
References:

  1. Torch compiled FLCE is 2x faster than the current FLCE #227
  2. https://gist.github.com/Chillee/22cd93e11b887db1f596ab754d60a899#file-lce_benchmark-py
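Online softmax is one of the techniques mentioned above. A minimal pure-Python sketch of the idea (illustrative only, not the referenced benchmark code): the softmax normalizer, log-sum-exp, can be computed block by block by keeping a running max and rescaling the running sum whenever the max increases, so the full row of logits never has to be held at once.

```python
import math

def online_logsumexp(blocks):
    """Streaming log-sum-exp over blocks of logits (the online softmax normalizer).

    Keeps a running max `m` and a running sum `s` of exp(x - m), rescaling
    `s` whenever a new block raises the max.
    """
    m = -math.inf
    s = 0.0
    for block in blocks:
        new_m = max(m, max(block))
        # Rescale the old sum to the new max, then add this block's terms.
        s = s * math.exp(m - new_m) + sum(math.exp(x - new_m) for x in block)
        m = new_m
    return m + math.log(s)
```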

Interface

Have a base class FlexChunkLoss that handles chunking, accumulation and compiling strategies.
A custom loss class subclasses FlexChunkLoss and implements loss_fn, which operates on a single chunk.

```python
class MyCustomLoss(FlexChunkLoss):
    def loss_fn(self, *chunk_inputs):
        # Compute and return the loss for a single chunk here.
        ...
```
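To make the interface concrete, here is a toy, framework-free sketch of what such a base class could look like. This `FlexChunkLoss` is hypothetical: it only splits the inputs into row chunks, calls `loss_fn` on each chunk, and averages the summed result. Fused linear layers, gradient accumulation and torch.compile are deliberately left out.

```python
import numpy as np

class FlexChunkLoss:
    """Toy sketch: chunk the inputs, call the subclass's loss_fn per chunk,
    and accumulate the result."""

    def __init__(self, chunk_size):
        self.chunk_size = chunk_size

    def loss_fn(self, *chunk_inputs):
        # Subclasses return the *summed* loss over the rows of this chunk.
        raise NotImplementedError

    def __call__(self, *inputs):
        n = len(inputs[0])
        total = 0.0
        for start in range(0, n, self.chunk_size):
            chunk = [x[start:start + self.chunk_size] for x in inputs]
            total += self.loss_fn(*chunk)
        return total / n  # mean over all rows


class SquaredErrorLoss(FlexChunkLoss):
    """Example custom loss: only the per-chunk computation is implemented."""

    def loss_fn(self, pred, target):
        return float(((pred - target) ** 2).sum())


pred = np.arange(10, dtype=float)
print(SquaredErrorLoss(chunk_size=3)(pred, np.zeros(10)))  # 28.5
```

Summing inside `loss_fn` and dividing once at the end keeps the result identical to the unchunked mean even when the last chunk is smaller than the others.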

Alternatives

No response

Additional context

No response

@shivam15s shivam15s changed the title Liger FlexChunkLoss: Supporting various alignment (DPO, ORPO, IRPO, CPO, etc) and distillation (KL divergence, cosine_similarity, earth_mover_distance, etc) loss functions [RFC] Liger FlexChunkLoss: Supporting various alignment (DPO, ORPO, IRPO, CPO, etc) and distillation (KL divergence, cosine_similarity, earth_mover_distance, etc) loss functions Nov 8, 2024
@austin362667
Collaborator

take DPO

@hongpeng-guo
Collaborator

I can take fused linear kl div. BTW, really nice illustration on the chunk linear op fusion from the paper. Very clear to new contributors 😄

@pramodith
Collaborator

pramodith commented Nov 13, 2024

@shivam15s @ByronHsu I think we should also consider including some of the loss functions commonly used for training embedding models, especially the popular ones supported in Sentence transformers.

It's quite common for embedding models to require large batch sizes to be trained well. Coupled with the fact that their batch/input structure is kind of similar to RLHF, where we have positive and negative pairs, I believe this can prove useful. I'd recommend supporting CoSENTLoss, MatryoshkaLoss and TripletLoss for starters https://sbert.net/docs/package_reference/sentence_transformer/losses.html#cosentloss. Perhaps this can be its own roadmap, separate from this one, although the idea of chunking and fusing remains the same.

@ByronHsu
Collaborator

@pramodith that is a good idea! Do you know if embedding models also have a large vocab and suffer from a memory bottleneck?

@pramodith
Collaborator

@ByronHsu most embedding models have a final Linear layer of shape (hidden_dim, hidden_dim), so vocab size doesn't really come into the picture for them; you're right to point that out. However, it is common to have an effective batch size of 65k.

@ByronHsu
Collaborator

Then I think chunk loss is still helpful given the large batch size.

@pramodith
Collaborator

> Then I think chunk loss is still helpful given the large batch size.

Yes, I think so too. I can give this a try after we wrap up all the important RLHF and distillation losses. I'll also get Tom Aarsen's perspective since he's the lead of Sentence Transformers.

@ByronHsu ByronHsu changed the title [RFC] Liger FlexChunkLoss: Supporting various alignment (DPO, ORPO, IRPO, CPO, etc) and distillation (KL divergence, cosine_similarity, earth_mover_distance, etc) loss functions [RFC] Liger FlexChunkLoss: Supporting various alignment and distillation loss functions Nov 15, 2024
@ByronHsu ByronHsu pinned this issue Nov 15, 2024
@ByronHsu ByronHsu changed the title [RFC] Liger FlexChunkLoss: Supporting various alignment and distillation loss functions [RFC] Liger FlexChunkLoss: Alignment and Distillation loss Nov 15, 2024
ByronHsu pushed a commit that referenced this issue Nov 15, 2024
## Summary

Add support for a fused, torch-compiled, and chunked DPO ([Direct
Preference Optimization](https://arxiv.org/html/2305.18290v3)) loss
kernel, as requested in
#371.
This implementation is largely based on the excellent work done on ORPO
(#362) by @shivam15s.

### DPO Loss Formulation

In a reference setting (not reference free), the implicit reward margin
between the chosen response $y_c$ and the rejected response $y_r$ is:

$$r_\theta(x,y_c) - r_\theta(x,y_r) = \beta\Big[\big(\log\pi_\theta(y_c|x) - \log\pi_{\text{ref}}(y_c|x)\big) - \big(\log\pi_\theta(y_r|x) - \log\pi_{\text{ref}}(y_r|x)\big)\Big]$$

and the DPO loss is:

$$L_{DPO} = -\log\sigma\big(r_\theta(x,y_c) - r_\theta(x,y_r)\big)$$

Corresponds to:
```python
import torch.nn.functional as F

# `log_probs` is a placeholder for a helper that reduces each response's
# per-token log-probs to one sequence-level log probability.

# Policy model log probabilities
policy_chosen_logps = log_probs(policy_chosen_logits)
policy_rejected_logps = log_probs(policy_rejected_logits)

# Reference model log probabilities
ref_chosen_logps = log_probs(ref_chosen_logits)
ref_rejected_logps = log_probs(ref_rejected_logits)

# Compute advantages (policy-vs-reference log-ratios)
chosen_advantages = policy_chosen_logps - ref_chosen_logps
rejected_advantages = policy_rejected_logps - ref_rejected_logps

# DPO loss: beta scales the log-ratio margin
logits_diff = beta * (chosen_advantages - rejected_advantages)
losses = -F.logsigmoid(logits_diff)
```

In this PR:

1. To maximize the reward margin, the loss depends only on the difference:
    $$r_\theta(x,y_c) - r_\theta(x,y_r)$$
2. When the reference log-probs are omitted (reference free), this simplifies to:
    $$-\log\sigma\big(\beta(\log\pi_\theta(y_c|x) - \log\pi_\theta(y_r|x))\big)$$
3. So, the code implements:
    ```python
    logits_diff = beta * (chosen_logps - rejected_logps)  # β(log π_θ(y_c|x) - log π_θ(y_r|x))
    losses = -F.logsigmoid(logits_diff)  # -log(σ(logits_diff))
    ```
4. Sum up DPO and NLL:
    $$L_{DPO+NLL} = L_{DPO} + \alpha L_{NLL}$$
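A minimal, framework-free sketch of the per-example DPO + NLL loss (the function name and arguments are hypothetical). Following the DPO paper, `beta` multiplies the policy-vs-reference log-ratio margin; the sequence-level log-probs and the chosen-response NLL are assumed to be precomputed scalars.

```python
import math

def log_sigmoid(x):
    # Numerically stable log(sigmoid(x))
    return -math.log1p(math.exp(-x)) if x >= 0 else x - math.log1p(math.exp(x))

def dpo_nll_loss(policy_chosen_logp, policy_rejected_logp,
                 ref_chosen_logp, ref_rejected_logp,
                 chosen_nll, beta=0.1, alpha=1.0):
    """Per-example DPO loss plus an alpha-weighted NLL term on the chosen response."""
    chosen_logratio = policy_chosen_logp - ref_chosen_logp
    rejected_logratio = policy_rejected_logp - ref_rejected_logp
    margin = beta * (chosen_logratio - rejected_logratio)
    dpo = -log_sigmoid(margin)
    return dpo + alpha * chosen_nll
```

Passing `0.0` for both reference log-probs gives the reference-free variant.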

## Testing Done


![dpo_loss_memory](https://github.com/user-attachments/assets/d48965a2-bab7-4a81-9872-a43826106731)

![dpo_loss_speed](https://github.com/user-attachments/assets/10ab33c3-a905-435f-886b-67c911b8fff6)


- Hardware Type: **NVIDIA L40S (48G)**
- [X] run `make test` to ensure correctness
- [X] run `make checkstyle` to ensure code style
- [X] run `make test-convergence` to ensure convergence

---------

Signed-off-by: Austin Liu <[email protected]>
Co-authored-by: shivam15s <[email protected]>
@pramodith
Collaborator

#take SimPO and IRPO, since they are just extensions of CPO.

@vulkomilev

I will #take KTO next.

@vulkomilev

A little update on KTO: I am now working on the tests.

@ByronHsu
Collaborator

ByronHsu commented Nov 22, 2024

@Chillee FYI We are working on a set of post-training losses based on your compiled chunked loss implementation for CE. Thanks for the reference!

@vulkomilev

Update on KTO loss: I am done with the loss, but I have a problem with the assertions. I am working on it.

@hongpeng-guo
Collaborator

hongpeng-guo commented Nov 25, 2024

I was following this thread and working on a chunked, fused linear KL-divergence implementation for distillation use cases. Since distillation losses differ from preference losses, introducing a LigerFusedLinearDistillationBase parent class could be helpful.

In general, the distillation pipeline involves three key inputs: teacher_logits, student_logits, and ground_truth_label. The first two inputs are used to calculate the soft loss (KL divergence), while the latter two are used to compute the hard loss (cross-entropy). The final distillation loss is typically a weighted sum of these two components.

To leverage chunked, linear-fused optimizations, we could design the solution to accept inputs as teacher_tensor (BT, hidden_dim_teacher), student_tensor (BT, hidden_dim_student), and true_label (BT,). Using these inputs, we can apply the chunked, linear-fused approach to efficiently compute both the KL-divergence loss and the cross-entropy loss.

cc @ByronHsu, @shivam15s, @pramodith: What are your thoughts on this? Do you think it makes sense to include the cross-entropy loss as part of the DistillationBase class? Thanks for your feedback!
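A minimal NumPy sketch of this proposal (names, default weights and chunk size are illustrative assumptions, not a Liger API): rows of the `(BT, hidden)` student and teacher tensors are projected through their own output weights one chunk at a time, a per-token KL divergence gives the soft loss, cross-entropy against the labels gives the hard loss, and the weighted sum is normalized by the number of non-ignored tokens. Only the forward value is computed here.

```python
import numpy as np

def log_softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))

def chunked_distill_loss(student_h, teacher_h, W_s, W_t, labels,
                         weight_soft=0.5, weight_hard=0.5,
                         temperature=1.0, chunk_size=4, ignore_index=-100):
    """Hypothetical sketch: per-token soft (KL) + hard (CE) loss over B*T rows.

    student_h: (BT, H_s), teacher_h: (BT, H_t), W_s: (V, H_s), W_t: (V, H_t),
    labels: (BT,) with ignore_index marking positions to skip.
    """
    soft_sum, hard_sum = 0.0, 0.0
    for start in range(0, len(labels), chunk_size):
        sl = slice(start, start + chunk_size)
        z_s = student_h[sl] @ W_s.T   # student logits for this chunk
        z_t = teacher_h[sl] @ W_t.T   # teacher logits for this chunk
        y = labels[sl]
        mask = y != ignore_index
        # Soft loss: KL(p_teacher || p_student) at the given temperature, per token
        lp_t = log_softmax(z_t / temperature)
        lp_s = log_softmax(z_s / temperature)
        kl = (np.exp(lp_t) * (lp_t - lp_s)).sum(axis=-1)
        soft_sum += kl[mask].sum()
        # Hard loss: cross-entropy of the student logits vs. ground-truth labels
        lp = log_softmax(z_s)
        rows = np.nonzero(mask)[0]
        hard_sum += -lp[rows, y[rows]].sum()
    n_valid = (labels != ignore_index).sum()  # normalize by real tokens only
    return (weight_soft * soft_sum + weight_hard * hard_sum) / n_valid
```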

@pramodith
Collaborator

@hongpeng-guo yes! I like your approach; it's cleaner to create a new base class for distillation losses. We're kind of doing the same for the alignment losses too, by computing the NLL (cross-entropy loss) of the accepted responses inside the base class.


@ByronHsu
Collaborator

+1 on @hongpeng-guo proposal. @shivam15s can help polish the base class

@shivam15s
Collaborator Author

Sounds good @hongpeng-guo, a separate base class for distillation is absolutely needed!

@vulkomilev

Please review and comment on my PR on KTO here: #410

@vulkomilev

There is an update on #410.

@ccdv-ai

ccdv-ai commented Dec 8, 2024

Is CPO-SimPO planned? This can be implemented in SimPO.

Reference: https://github.com/fe1ixxu/CPO_SIMPO

Quote

CPO and SimPO share similar objectives but have different goals. CPO adds a BC-regularizer to prevent the model from deviating too much from the preferred data distribution.

$L_{CPO}(\pi_\theta;U) = -E_{(x,y_w,y_l) \sim \mathcal{D}} \Big[ \log \sigma \Big( \beta \log \pi_{\theta}(y_w | x) - \beta \log \pi_{\theta}(y_l | x) \Big) + \log \pi_\theta(y_w| x)\Big]$

SimPO incorporates length normalization and target reward margin to improve model performance and prevent the generation of long but low-quality sequences:

$L_{SimPO}(\pi_\theta;U) = -E_{(x,y_w,y_l) \sim \mathcal{D}} \Big[ \log \sigma \Big( \frac{\beta}{|y_w|} \log \pi_{\theta}(y_w | x) - \frac{\beta}{|y_l|} \log \pi_{\theta}(y_l | x) - \gamma \Big) \Big]$

These two objectives can be jointly used, which we call CPO-SimPO:

$L_{CPO-SimPO}(\pi_\theta;U) = -E_{(x,y_w,y_l) \sim \mathcal{D}} \Big[ \log \sigma \Big( \frac{\beta}{|y_w|} \log \pi_{\theta}(y_w | x) - \frac{\beta}{|y_l|} \log \pi_{\theta}(y_l | x) - \gamma \Big)+ \alpha \log \pi_\theta(y_w| x)\Big]$
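A minimal pure-Python sketch of the per-example CPO-SimPO objective above (the function name and default hyperparameters are illustrative assumptions). `logp_w`/`logp_l` are the summed token log-probs of the chosen/rejected responses and `len_w`/`len_l` are their lengths; setting `alpha=0` recovers the SimPO term.

```python
import math

def log_sigmoid(x):
    # Numerically stable log(sigmoid(x))
    return -math.log1p(math.exp(-x)) if x >= 0 else x - math.log1p(math.exp(x))

def cpo_simpo_loss(logp_w, logp_l, len_w, len_l, beta=2.0, gamma=0.5, alpha=1.0):
    """Per-example CPO-SimPO loss: the negated term inside the expectation."""
    # Length-normalized reward margin minus the target margin gamma (SimPO)
    margin = beta / len_w * logp_w - beta / len_l * logp_l - gamma
    # alpha * log pi(y_w|x) is the BC regularizer from CPO
    return -(log_sigmoid(margin) + alpha * logp_w)
```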

@pramodith
Collaborator

@ccdv-ai I think this can be done with the existing set of hyperparameters, by setting compute_nll_loss=True and using alpha for the BC regularizer. Also, right now all our alignment loss functions do assume length normalization:

```python
average_log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
```
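A small NumPy example of that length normalization (the values are made up): masked positions are excluded, so each row's result is the mean log-prob over its response tokens. Multiplying by `beta` then gives the $\frac{\beta}{|y|}\log\pi_\theta(y|x)$ term from the SimPO objective.

```python
import numpy as np

# Illustrative values: 2 sequences, 4 positions each.
per_token_logps = np.array([[-0.5, -1.0, -1.5, 0.0],
                            [-2.0, -2.0, 0.0, 0.0]])
# 1 marks response tokens that count toward the loss, 0 marks prompt/padding.
loss_mask = np.array([[1, 1, 1, 0],
                      [1, 1, 0, 0]])

average_log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
print(average_log_prob)  # [-1. -2.]

beta = 2.0  # illustrative value
length_normalized_reward = beta * average_log_prob  # beta / |y| * log pi(y|x)
```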

ByronHsu pushed a commit that referenced this issue Dec 9, 2024
## Summary

Made #417 from the main
repo.

Thanks to the nice suggestions from @Tcc0403 and @hongpeng-guo. This PR
is the first split from
#408, focusing solely on
introducing the Knowledge Distillation base class. As a result, this PR
does not include any tests at the moment.

#### Code Changes

1. Refactor `beta` into two weights: `weight_hard_loss` and
`weight_soft_loss`, as coefficients between `hard_loss` and `soft_loss`.
@Tcc0403 also pointed out that we could use `torch.lerp` if applicable.

2. Pass `teacher_logits` and `student_logits` directly to the divergence
loss function. This avoids redundant computations of converting logits
to log probabilities and then reverting them to raw logits. However, note
that we are not reusing the `student_log_probs` value calculated during
`ce_loss` in the distillation base.

    1. Remove the unnecessary `get_batch_logps` in `test/utils.py`.

3. Modify `chunking` dimensions from `B` to `B * T`. Thanks to
@hongpeng-guo's great advice.

    1. Fix the loss calculation to use per-token values instead of averaging
across the sequence length dimension.

4. Normalize the `distillation_loss` using `(full_target !=
ignore_index).sum()`.
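A small pure-Python sketch of items 1 and 4 (helper names and numbers are hypothetical). `lerp` follows the formula `torch.lerp` documents, `start + weight * (end - start)`, so when `weight_soft = 1 - weight_hard` the weighted sum equals `lerp(soft_loss, hard_loss, weight_hard)`. The normalization helper divides the summed per-token loss by the global count of non-ignored targets, which differs from averaging each chunk first when ignored positions are unevenly distributed.

```python
IGNORE_INDEX = -100  # hypothetical ignore_index value

def lerp(start, end, weight):
    # Same formula as torch.lerp: start + weight * (end - start)
    return start + weight * (end - start)

def globally_normalized_loss(chunks, ignore_index=IGNORE_INDEX):
    """chunks: list of (per_token_losses, targets) pairs.

    Sums the losses of non-ignored tokens across all chunks and divides once
    by the global count, i.e. (full_target != ignore_index).sum().
    """
    total, count = 0.0, 0
    for losses, targets in chunks:
        for loss, target in zip(losses, targets):
            if target != ignore_index:
                total += loss
                count += 1
    return total / count

def mean_of_chunk_means(chunks, ignore_index=IGNORE_INDEX):
    # Alternative for comparison: average within each chunk, then across chunks.
    means = []
    for losses, targets in chunks:
        kept = [loss for loss, t in zip(losses, targets) if t != ignore_index]
        means.append(sum(kept) / len(kept))
    return sum(means) / len(means)

chunks = [([1.0, 2.0, 3.0], [5, 6, 7]), ([4.0, 9.0], [8, IGNORE_INDEX])]
print(globally_normalized_loss(chunks))  # 2.5
print(mean_of_chunk_means(chunks))       # 3.0
```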

#### TODO  

1. [X] Although a slight slowdown is reasonable, we need to
investigate why this PR's implementation is **significantly slower**
compared to the naive approach. Thanks to @Tcc0403's clarification.

The issue arises because we are not properly configuring the
`chunk_size` for the `B * T` dimension, which is extremely large (a few
thousand). The previous default of 1 results in an excessive number of
chunks.

In contrast, this problem does not occur with the preference loss, as
chunking is performed on the `B` dimension. This produces fewer than 10
chunks, which is efficient and works as expected.

In conclusion, setting `chunk_size` to `1024` works well in the new
benchmark results, as shown in
#425

2. [ ] #417 (comment)
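The chunk-count arithmetic behind this, with hypothetical sizes `B = 8` and `T = 512`:

```python
import math

def num_chunks(rows, chunk_size):
    # Number of iterations of the chunk loop for `rows` input rows
    return math.ceil(rows / chunk_size)

B, T = 8, 512  # hypothetical batch size and sequence length

print(num_chunks(B * T, 1))     # 4096: chunking B * T rows one at a time
print(num_chunks(B * T, 1024))  # 4: chunk_size = 1024
print(num_chunks(B, 1))         # 8: preference losses chunk over B only
```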

#### Knowledge Distillation

Knowledge Distillation (KD; [Hinton et al.
2015](https://arxiv.org/abs/1503.02531), [Gou et al.
2020](https://arxiv.org/abs/2006.05525)) is a straightforward way to
build a smaller, cheaper model (“student model”) to speed up inference
by transferring skills from a pre-trained expensive model (“teacher
model”) into the student.

In knowledge distillation, a student model is trained to replicate the
outputs of a teacher model using a distillation loss. Neural networks
typically include a softmax layer; for instance, a large language model
produces a probability distribution over tokens. Let `z_t` and `z_s`
represent the logits before the softmax layer for the teacher and
student models, respectively. The distillation loss reduces the
discrepancy between the two softmax outputs at a high temperature `T`.
When ground truth labels `y` are available, this approach can be
combined with a supervised learning objective, such as cross-entropy, to
compare the student’s outputs with the ground truth.

The combined loss function is defined as:

```math
\mathcal{L}_{\text{knowledge distillation}} = w_{\text{soft}} \cdot \mathcal{L}_{\text{distill}}(\mathbf{z_t}, \mathbf{z_s}, T) + w_{\text{hard}} \cdot \mathcal{L}_{\text{cross entropy}}(\mathbf{y}, \mathbf{z_s}),
```

Here, we directly pass in `logits` rather than `logprobs`. @Tcc0403
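A minimal pure-Python sketch of this formula for a single token, assuming KL divergence as $\mathcal{L}_{\text{distill}}$ (other divergences such as JSD would slot in the same way) and applying the temperature only to the distillation term. Names and defaults are illustrative.

```python
import math

def softmax(z, T=1.0):
    # Temperature-scaled softmax, shifted by the max for numerical stability
    m = max(z)
    e = [math.exp((x - m) / T) for x in z]
    s = sum(e)
    return [x / s for x in e]

def kd_loss(z_t, z_s, y, T=2.0, w_soft=0.5, w_hard=0.5):
    p_t = softmax(z_t, T)
    p_s = softmax(z_s, T)
    # L_distill: KL(p_t || p_s) between the temperature-softened distributions
    l_distill = sum(pt * math.log(pt / ps) for pt, ps in zip(p_t, p_s))
    # L_cross_entropy: the student's (T = 1) distribution vs. ground-truth label y
    l_ce = -math.log(softmax(z_s)[y])
    return w_soft * l_distill + w_hard * l_ce
```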

#### Shared `DistillationBase`

To support various distillation learning objectives, this PR aims to add
a `LigerFusedLinearDistillationBase`, which is basically the same as proposed
by @hongpeng-guo within this discussion
#371 (comment).
Thank you @hongpeng-guo for thinking through this.

## Testing Done

I'll post JSD tests and benchmark results in the next PR:
#425

- Hardware Type: L40S
- [ ] run `make test` to ensure correctness
- [ ] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence

---------

Signed-off-by: Austin Liu <[email protected]>
Co-authored-by: shivam15s <[email protected]>
@vulkomilev

There is an update about KTO on #410.

@hebiao064 hebiao064 mentioned this issue Dec 13, 2024