
Bug in chunking? #439

Open
cinjon opened this issue Dec 8, 2024 · 10 comments

Comments

@cinjon

cinjon commented Dec 8, 2024

🐛 Describe the bug

I'm a bit confused by the chunked_loss implementation in src/liger_kernel/chunked_loss/fused_linear_preference.py. Namely, it seems more like a batched loss than a chunked loss.

My expectation was that it would chunk over the tokens, a la https://pytorch.org/torchtune/0.3/generated/torchtune.modules.loss.CEWithChunkedOutputLoss.html. But it chunks over the batch instead: it first separates the chosen from the rejected examples, then takes chunks along the batch dimension.

Is this intended?

Args:
_input (torch.Tensor): Input tensor. Shape: (batch_size, seq_len, hidden_size).
target (torch.Tensor): Target tensor. Shape: (batch_size, seq_len)
....

        len_chosen = target.shape[0] // 2
        chunks = max(1, _input.shape[0] // (2 * CHUNK_SIZE))
        _chosen_input_chunks = torch.chunk(_input[:len_chosen], chunks=chunks, dim=0)
        _chosen_target_chunks = torch.chunk(target[:len_chosen], chunks=chunks, dim=0)
        _rejected_input_chunks = torch.chunk(_input[len_chosen:], chunks=chunks, dim=0)
        _rejected_target_chunks = torch.chunk(target[len_chosen:], chunks=chunks, dim=0)
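
For concreteness, here is a minimal standalone sketch (the shapes and CHUNK_SIZE value are illustrative, not taken from the repo) showing that torch.chunk(..., dim=0) splits along the batch dimension, so every chunk still carries the full sequence length:

    import torch

    CHUNK_SIZE = 1  # illustrative value, not Liger's actual constant

    _input = torch.randn(8, 512, 4096)          # (batch_size, seq_len, hidden_size)
    target = torch.randint(0, 32000, (8, 512))  # (batch_size, seq_len)

    len_chosen = target.shape[0] // 2
    chunks = max(1, _input.shape[0] // (2 * CHUNK_SIZE))
    _chosen_input_chunks = torch.chunk(_input[:len_chosen], chunks=chunks, dim=0)

    # Each chunk keeps the full sequence, e.g. (1, 512, 4096); seq_len is never split.
    print(_chosen_input_chunks[0].shape)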

Reproduce

No response

Versions

Environment Report:

Operating System: Linux-6.1.100+-x86_64-with-glibc2.35
Python version: 3.10.12
PyTorch version: 2.5.0+cu124
CUDA version: 12.4
Triton version: 3.1.0
Transformers version: 4.42.3

@kvignesh1420
Collaborator

@cinjon there seems to be some confusion here.
The PyTorch reference you posted is for fused linear CE (FLCE), which is implemented here: https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/fused_linear_cross_entropy.py

src/liger_kernel/chunked_loss/fused_linear_preference.py is for the preference-tuning loss, and its implementation differs somewhat from FLCE.

@cinjon
Author

cinjon commented Dec 9, 2024

I understand that it serves a different purpose. My intention was just to show that this differs from what I'd call chunking. This is more like iterating over the batch, roughly equivalent to gradient accumulation as far as I can tell.

I would expect a chunking solution for preference loss to use the chunks to compute log-probs and then sum them up later in the final loss, as sketched below. That way we never materialize the full log-softmax, only one chunk of it at a time.
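
A minimal sketch of what I mean (the function name, shapes, and num_chunks are mine for illustration, not Liger's API): chunk the hidden states along the sequence dimension, project each chunk through the LM head, take the log-softmax over that chunk only, gather the target log-probs, and accumulate a per-sequence sum.

    import torch
    import torch.nn.functional as F

    def chunked_sequence_logps(hidden, lm_head_weight, target, num_chunks=4):
        # hidden: (b, t, h), lm_head_weight: (vocab, h), target: (b, t)
        # Only one (b, t // num_chunks, vocab) block of logits / log-softmax
        # is materialized at a time.
        logp_sum = hidden.new_zeros(hidden.shape[0])
        hidden_chunks = torch.chunk(hidden, num_chunks, dim=1)  # chunk on seq_len
        target_chunks = torch.chunk(target, num_chunks, dim=1)
        for h_chunk, t_chunk in zip(hidden_chunks, target_chunks):
            logits = h_chunk @ lm_head_weight.T                 # (b, t_chunk, vocab)
            logps = F.log_softmax(logits.float(), dim=-1)
            token_logps = torch.gather(logps, -1, t_chunk.unsqueeze(-1)).squeeze(-1)
            logp_sum = logp_sum + token_logps.sum(dim=-1)
        return logp_sum  # (b,), combined into the preference loss afterwards

The per-sequence sums for chosen and rejected sequences can then be plugged into the usual preference loss.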

@kvignesh1420
Collaborator

Got it. Yes, the "chunking" here is applied on the batch dimension for the preference loss, not on the sequence-length dimension. We are currently improving the API, so these details will be added to the docs :)

I am not sure how easy or difficult it would be to do sequence-level chunking on top of this (to further reduce memory). Let me know if you would like to give it a try :)

Thanks again!

@cinjon
Author

cinjon commented Dec 9, 2024

I think something like this would work, though it would have to be adapted to the Liger approach: https://gist.github.com/cinjon/e30343a854521a8c95d71e257f51fa9c.

@cinjon
Author

cinjon commented Dec 9, 2024

The above chunking approach saves memory, but not enough to run my model. I think I'll have to implement it in place, a la fused_linear_cross_entropy's use of liger_cross_entropy_kernel.

@cinjon
Author

cinjon commented Dec 9, 2024

Do you have any guidance on the above? It's a healthy chunk of work.

@kvignesh1420
Collaborator

kvignesh1420 commented Dec 9, 2024

@cinjon I went through your code and it seems doable since the hidden states for each token can be processed independently (similar to FLCE). Will discuss with @shivam15s and @ByronHsu to align this with our roadmap.
cc: @hebiao064

@cinjon
Author

cinjon commented Dec 9, 2024

Cool, thanks for checking it out. Yeah, the main problem remaining in the code I shared is that the gradient for the full [b, t, v] logits still has to be materialized in the backward pass. A rough sketch of the kind of chunked backward I have in mind is below.
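
This is only a sketch under my own assumptions, not Liger's implementation: because the preference loss depends on each token log-prob only through the per-sequence sum, the upstream gradient with respect to every chunk's contribution is the same per-sequence scalar (dloss_dlogp below, computed once from the outer loss). That lets us redo the forward and backward chunk by chunk, so only a chunk-sized block of logits and its gradient is ever live.

    import torch
    import torch.nn.functional as F

    def chunked_logp_backward(hidden, w, target, dloss_dlogp, num_chunks=4):
        # hidden: (b, t, h) detached activations; w: (vocab, h) LM-head weight
        #   with requires_grad=True; target: (b, t);
        # dloss_dlogp: (b,) = d(outer preference loss) / d(per-sequence logp sum).
        grad_hidden = torch.zeros_like(hidden)
        grad_w = torch.zeros_like(w)
        h_chunks = torch.chunk(hidden, num_chunks, dim=1)
        t_chunks = torch.chunk(target, num_chunks, dim=1)
        offset = 0
        for h_chunk, t_chunk in zip(h_chunks, t_chunks):
            h_chunk = h_chunk.detach().requires_grad_(True)
            logits = h_chunk @ w.T                              # (b, t_chunk, vocab)
            logps = F.log_softmax(logits.float(), dim=-1)
            token_logps = torch.gather(logps, -1, t_chunk.unsqueeze(-1)).squeeze(-1)
            # This chunk's contribution to the loss, already scaled by the
            # per-sequence upstream gradient.
            chunk_loss = (token_logps.sum(dim=-1) * dloss_dlogp).sum()
            gh, gw = torch.autograd.grad(chunk_loss, (h_chunk, w))
            grad_hidden[:, offset:offset + h_chunk.shape[1]] = gh
            grad_w += gw
            offset += h_chunk.shape[1]
        return grad_hidden, grad_w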

@kvignesh1420
Collaborator

Hi @cinjon, just wanted to get an estimate of the memory consumption for your use case. Can you share some details about the GPU you are using so that we can plan accordingly?

@cinjon
Author

cinjon commented Dec 10, 2024

We're using H100s and training models larger than Gemma 9B. That one is fine with normal preference optimization, but once I go much bigger than that, I run into problems.
