-
Notifications
You must be signed in to change notification settings - Fork 232
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactor
LigerFusedLinearPreferenceBase
(#381)
## Summary This PR refactors the `LigerFusedLinearPreferenceBase` class to contain an abstractmethod corresponding to the calculation of the loss that needs to be implemented by all sub-classes. It also adds a new function to the class called `_compute_loss` which is mostly the same as the `_compute_orpo_loss` function introduced in #362 but makes it generic to calculate the NLL/Cross Entropy Loss plus accepts a custom loss function that implements a new alignment loss function. Most RLHF/RLAIF/Alignment algorithms state their final loss as `NLL + Beta * (Alignment_Loss) `so adding the NLL logic inside the base class reduces repeated code. The _compute_loss function accepts ## Testing Done On A100-80G-SXM - Hardware Type: <BLANK> - [X] run `make test` to ensure correctness - [X] run `make checkstyle` to ensure code style - [X] run `make test-convergence` to ensure convergence --------- Co-authored-by: pramodith <[email protected]>
- Loading branch information
Showing
2 changed files
with
126 additions
and
81 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters