Bug in chunking? #439
Comments
@cinjon there seems to be some confusion here. The
I understand that it's a different purpose. My intention was just to show that chunking is different from what this does. This is more like just iterating through the batch, somewhat equivalent to gradient accumulation, afaict. I would think that a chunking solution for the preference loss would use the chunks to get logps and then sum them up in the final loss. That way, we don't have to materialize the full log-softmax; we only materialize one chunk of it at a time.
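Roughly, what I mean is something like the sketch below (just an illustration with made-up names, not the gist or Liger's code), where only one [b, chunk, v] slice of the log-softmax exists at a time:

```python
import torch
import torch.nn.functional as F

def chunked_sequence_logps(hidden, lm_head_weight, labels, chunk_size=1024):
    # hidden: [b, t, h], lm_head_weight: [v, h], labels: [b, t]
    logp_sums = []
    for start in range(0, hidden.size(1), chunk_size):
        end = start + chunk_size
        # Only a [b, chunk, v] slice of the logits / log-softmax exists at a time.
        logits_chunk = hidden[:, start:end] @ lm_head_weight.t()
        logps_chunk = F.log_softmax(logits_chunk.float(), dim=-1)
        token_logps = torch.gather(
            logps_chunk, dim=-1, index=labels[:, start:end].unsqueeze(-1)
        ).squeeze(-1)
        logp_sums.append(token_logps.sum(dim=-1))
    # Per-sequence log-probabilities, summed over all sequence chunks.
    return torch.stack(logp_sums, dim=0).sum(dim=0)
```

Note that under plain autograd this still saves each chunk's log-softmax for the backward pass, which is the limitation I come back to below.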
Got you. Yes, the "chunking" here is applied on the batch dimension for the preference loss, not on the sequence-length dimension. We are currently improving the API, so these details will be added to the docs :) I am not sure how easy/difficult it would be to do sequence-level chunking on top of this (to further reduce memory). Let me know if you would like to give this a try :) Thanks again!
I think something like this would work, though it would have to be adapted to the Liger approach: https://gist.github.com/cinjon/e30343a854521a8c95d71e257f51fa9c.
The above chunking approach saves memory, but does not save enough to run my model. I think I'll have to implement it in place, a la fused_linear_cross_entropy's use of
Do you have any guidance on the above? It's a healthy chunk of work.
@cinjon I went through your code and it seems doable since the hidden states for each token can be processed independently (similar to FLCE). Will discuss with @shivam15s and @ByronHsu to align this with our roadmap.
Cool, thanks for checking it out. Yeah, the main problem that still persists in the code I showed is that the gradient for the full [b, t, v] tensor still has to be materialized in the backward pass.
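To avoid that, the per-chunk gradient would have to be computed during the forward pass, FLCE-style. A rough sketch of that pattern (hypothetical names, masking and the lm_head weight gradient omitted for brevity; not the actual Liger implementation):

```python
import torch
import torch.nn.functional as F

class ChunkedLinearCE(torch.autograd.Function):
    @staticmethod
    def forward(ctx, hidden, weight, labels, chunk_size):
        # hidden: [b, t, h], weight: [v, h], labels: [b, t]
        grad_hidden = torch.zeros_like(hidden)
        total_loss = hidden.new_zeros(())
        n_tokens = labels.numel()
        for start in range(0, hidden.size(1), chunk_size):
            end = min(start + chunk_size, hidden.size(1))
            h_chunk = hidden[:, start:end].detach().requires_grad_(True)
            with torch.enable_grad():
                logits = h_chunk @ weight.t()  # [b, chunk, v]
                loss = F.cross_entropy(
                    logits.reshape(-1, logits.size(-1)),
                    labels[:, start:end].reshape(-1),
                    reduction="sum",
                ) / n_tokens
            # Gradient for this chunk only; the full [b, t, v] gradient
            # is never materialized.
            grad_hidden[:, start:end] = torch.autograd.grad(loss, h_chunk)[0]
            total_loss = total_loss + loss.detach()
        ctx.save_for_backward(grad_hidden)
        return total_loss

    @staticmethod
    def backward(ctx, grad_output):
        (grad_hidden,) = ctx.saved_tensors
        # Gradients for weight, labels and chunk_size are omitted in this sketch.
        return grad_output * grad_hidden, None, None, None
```

Called as something like `loss = ChunkedLinearCE.apply(hidden_states, lm_head_weight, labels, 1024)`; a real implementation would accumulate the weight gradient per chunk in the same way.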
Hi @cinjon, just wanted to get an estimate of the memory consumption for your use case. Can you share some details about the GPU that you are using, so that we can plan accordingly?
We're using H100s and training models larger than Gemma 9B. That one is fine with normal preference optimization, but I run into problems once I go much bigger than that.
🐛 Describe the bug
I'm a bit confused by the chunked_loss implementation in src/liger_kernel/chunked_loss/fused_linear_preference.py. Namely, it seems more like a batched loss than a chunked loss.
My expectation is that it will chunk on the tokens, a la https://pytorch.org/torchtune/0.3/generated/torchtune.modules.loss.CEWithChunkedOutputLoss.html. But it chunks on the batch instead, by first separating the chosen from the rejected and then choosing `chunks` based on the batch dimension. Is this intended?
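For contrast, the batch-dimension chunking described above looks roughly like this (a simplified illustration, not the actual fused_linear_preference.py code; the DPO-style sigmoid loss and all names are made up for the example):

```python
import torch
import torch.nn.functional as F

def _per_sequence_logps(logits, labels):
    # Sum of per-token log-probabilities for each sequence in the chunk.
    logps = F.log_softmax(logits.float(), dim=-1)
    return torch.gather(logps, -1, labels.unsqueeze(-1)).squeeze(-1).sum(-1)

def batch_chunked_preference_loss(hidden, weight, labels, beta=0.1, chunks=4):
    # hidden: [2*b, t, h] with the chosen half first and the rejected half second.
    b = hidden.size(0) // 2
    loss = hidden.new_zeros(())
    for c_h, r_h, c_l, r_l in zip(
        torch.chunk(hidden[:b], chunks, dim=0),
        torch.chunk(hidden[b:], chunks, dim=0),
        torch.chunk(labels[:b], chunks, dim=0),
        torch.chunk(labels[b:], chunks, dim=0),
    ):
        # Each iteration still materializes logits over the full sequence
        # length for its slice of the batch: [b/chunks, t, v].
        c_logps = _per_sequence_logps(c_h @ weight.t(), c_l)
        r_logps = _per_sequence_logps(r_h @ weight.t(), r_l)
        loss = loss - F.logsigmoid(beta * (c_logps - r_logps)).sum()
    return loss / b
```

The memory saving here scales with the batch dimension; each chunk still carries a full sequence-length log-softmax for its slice.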
Reproduce
No response
Versions
Environment Report:
Operating System: Linux-6.1.100+-x86_64-with-glibc2.35
Python version: 3.10.12
PyTorch version: 2.5.0+cu124
CUDA version: 12.4
Triton version: 3.1.0
Transformers version: 4.42.3