Support Chunked DPO Loss Kernel #378
Conversation
    run_benchmarks,
)

from liger_kernel.alignment.dpo_loss import HF_DPO_Loss, LigerFusedLinearDPOFunction
Should I use the HF DPO implementation here in the benchmark, for reusability, or write another naive implementation in pure torch?
HF DPO should be fine
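For what it's worth, a rough sketch of how the benchmark comparison could be structured if the HF implementation is reused as the baseline; the constructor and call signatures of HF_DPO_Loss and LigerFusedLinearDPOFunction below are placeholders for illustration, not the PR's actual API:

```python
import torch

from liger_kernel.alignment.dpo_loss import HF_DPO_Loss, LigerFusedLinearDPOFunction

# Shapes are illustrative only: batch of chosen/rejected pairs, seq len, hidden size, vocab size.
B, T, H, V = 4, 512, 4096, 32000
_input = torch.randn(B * T, H, device="cuda", requires_grad=True)   # last hidden states
weight = torch.randn(V, H, device="cuda", requires_grad=True)       # lm_head weight
target = torch.randint(0, V, (B * T,), device="cuda")               # next-token labels

# Baseline: materialize the full logits, then run the naive HF-style DPO loss on them.
# (Method name and signature assumed for illustration.)
logits = _input @ weight.t()
loss_hf = HF_DPO_Loss(beta=0.1).get_batch_loss_metrics(logits, target)

# Liger path: the fused/chunked kernel consumes hidden states and the lm_head weight
# directly, so the (B*T, V) logits tensor is never materialized in full. (Signature assumed.)
loss_liger = LigerFusedLinearDPOFunction.apply(_input, weight, target)
```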
        return grad_input, grad_weight, None, grad_bias, None, None, None


class HF_DPO_Loss:
Should I move this HF implementation to test_dpo_loss.py?
Yes, since the HF implementation is only for testing purposes.
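For context, a minimal sketch of what such a naive, test-only DPO loss in pure torch could look like (reference-free form; the function name and the per-sequence log-prob inputs are assumptions, not the PR's exact implementation):

```python
import torch
import torch.nn.functional as F


def naive_dpo_loss(policy_chosen_logps, policy_rejected_logps, beta=0.1):
    """Reference-free DPO loss on per-sequence log-probabilities.

    policy_chosen_logps / policy_rejected_logps: (batch,) summed log-probs of the
    chosen and rejected responses under the policy being trained.
    """
    logits_diff = beta * (policy_chosen_logps - policy_rejected_logps)
    # -log sigmoid(beta * (logp_chosen - logp_rejected)), averaged over the batch
    return -F.logsigmoid(logits_diff).mean()
```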
can we modify
Just an FYI: I think we should wait until @shivam15s pushes a generic/inheritable class that handles the chunking and other repetitive logic common to different loss functions before pushing new loss functions.
Great work @austin362667! The additional summing of the NLL loss is going to be useful for IRPO loss as well :). I'll be creating a simple base class which adds the boilerplate code (backward/torch compile logic) that you can inherit from, as @pramodith mentioned.
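For illustration, a rough sketch of what such an inheritable base class could look like; the class and method names here are hypothetical and only meant to convey the idea, not the interface that was eventually merged:

```python
import torch
import torch.nn.functional as F


class ChunkedPreferenceLoss:
    """Hypothetical base class: owns the chunking and torch.compile boilerplate,
    while concrete losses (DPO, ORPO, IRPO, ...) only override the loss itself."""

    def __init__(self, beta=0.1, chunk_size=256, compiled=True):
        self.beta = beta
        self.chunk_size = chunk_size
        # Compile the per-chunk loss once and reuse it for every chunk.
        self.loss_fn = torch.compile(self.preference_loss_fn) if compiled else self.preference_loss_fn

    @staticmethod
    def preference_loss_fn(chosen_logps, rejected_logps, beta):
        raise NotImplementedError  # supplied by the concrete preference loss

    def __call__(self, chosen_logps, rejected_logps):
        # Walk the batch in chunks and accumulate, so peak memory stays bounded.
        total = torch.zeros((), device=chosen_logps.device)
        count = 0
        for c, r in zip(chosen_logps.split(self.chunk_size), rejected_logps.split(self.chunk_size)):
            total = total + self.loss_fn(c, r, self.beta) * c.numel()
            count += c.numel()
        return total / count


class ChunkedDPOLoss(ChunkedPreferenceLoss):
    @staticmethod
    def preference_loss_fn(chosen_logps, rejected_logps, beta):
        # Reference-free DPO: -log sigmoid(beta * (logp_chosen - logp_rejected))
        return -F.logsigmoid(beta * (chosen_logps - rejected_logps)).mean()
```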
Issue addressed. Thanks @Tcc0403 @lancerts @pramodith @shivam15s and @ByronHsu for the review!
I think we should make the chunked_loss functions nn.Module (like flce and fljsd) for users. Same for ORPO? cc @shivam15s @ByronHsu
@Tcc0403 that is the plan!
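For example, a minimal sketch of what such an nn.Module wrapper could look like; the module name and the forward signature below are assumptions modeled on how the existing fused linear loss modules wrap their autograd Functions:

```python
import torch.nn as nn

from liger_kernel.alignment.dpo_loss import LigerFusedLinearDPOFunction


class LigerFusedLinearDPOLoss(nn.Module):
    """Hypothetical nn.Module wrapper so users can instantiate the chunked DPO loss
    like any other loss module instead of calling the autograd Function directly."""

    def __init__(self, beta=0.1, compiled=True):
        super().__init__()
        self.beta = beta
        self.compiled = compiled

    def forward(self, lin_weight, _input, target, bias=None):
        # Argument order and extra options are assumptions; they would follow
        # whatever LigerFusedLinearDPOFunction.apply() actually expects.
        return LigerFusedLinearDPOFunction.apply(
            _input, lin_weight, target, bias, self.beta, self.compiled
        )
```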
Summary
Add support for a fused, torch-compiled, and chunked DPO (Direct Preference Optimization) loss kernel, as requested in #371.
This implementation is largely based on the excellent work done on ORPO (#362) by @shivam15s.
DPO Loss Formulation
In a reference setting, the loss is computed from the difference between the policy and reference log-probabilities of the chosen and rejected responses, which corresponds to a log-sigmoid loss over that β-scaled difference.
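For reference, a sketch of the standard DPO objective in the reference setting, and the reference-free form it corresponds to when the reference model's log-probabilities are dropped (β is the DPO temperature, y_w / y_l the chosen and rejected responses):

$$
\mathcal{L}_{\mathrm{DPO}}(\pi_\theta;\pi_{\mathrm{ref}}) =
  -\mathbb{E}_{(x,\,y_w,\,y_l)\sim\mathcal{D}}\left[
    \log\sigma\!\left(
      \beta\log\frac{\pi_\theta(y_w\mid x)}{\pi_{\mathrm{ref}}(y_w\mid x)}
      - \beta\log\frac{\pi_\theta(y_l\mid x)}{\pi_{\mathrm{ref}}(y_l\mid x)}
    \right)
  \right]
$$

$$
\mathcal{L} =
  -\mathbb{E}_{(x,\,y_w,\,y_l)\sim\mathcal{D}}\left[
    \log\sigma\!\big(\beta\,(\log\pi_\theta(y_w\mid x) - \log\pi_\theta(y_l\mid x))\big)
  \right]
$$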
Testing Done
- make test to ensure correctness
- make checkstyle to ensure code style
- make test-convergence to ensure convergence