Commit
* ipo-dpo trainer
* fix missing abstract method
* chatml template, grad checkpointing kwargs support
* fix steps calc for RL and add dataloader kwargs
* wip to fix dpo and start ppo
* more fixes
* refactor to generalize map fn
* fix dataset loop and handle argilla pref dataset
* set training args
* load reference model on separate gpu if more than one device
* no auto upload to hub for dpo, don't add lora adapters to ref model for dpo
* fixes for rl training
* support for ipo from yaml
* set dpo training args from the config, add tests
* chore: lint
* set sequence_len for model in test
* add RLHF docs
Showing 11 changed files with 388 additions and 6 deletions.
@@ -0,0 +1,35 @@
# RLHF (Beta)

### Overview
Reinforcement Learning from Human Feedback (RLHF) is a method for optimizing a language model using human feedback data. Methods include, but are not limited to:

- Proximal Policy Optimization (PPO) (not yet supported in axolotl)
- Direct Preference Optimization (DPO)
- Identity Preference Optimization (IPO)
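
Both DPO and IPO optimize a margin between the chosen and the rejected completion, measured against a frozen reference model; they differ only in the loss applied to that margin. A minimal sketch of the two losses, mirroring how trl computes them (the `beta` coefficient and the summed per-sequence log-probabilities are assumed inputs, not axolotl-specific names):

```python
import torch
import torch.nn.functional as F


def preference_losses(
    policy_chosen_logps: torch.Tensor,
    policy_rejected_logps: torch.Tensor,
    ref_chosen_logps: torch.Tensor,
    ref_rejected_logps: torch.Tensor,
    beta: float = 0.1,
):
    """Per-example DPO and IPO losses from summed log-probs of each completion."""
    # Log-ratio of policy vs. reference for the chosen and rejected completions.
    chosen_logratios = policy_chosen_logps - ref_chosen_logps
    rejected_logratios = policy_rejected_logps - ref_rejected_logps
    margin = chosen_logratios - rejected_logratios

    # DPO: logistic loss on the scaled margin.
    dpo_loss = -F.logsigmoid(beta * margin)
    # IPO: squared loss pulling the margin toward 1 / (2 * beta).
    ipo_loss = (margin - 1 / (2 * beta)) ** 2
    return dpo_loss, ipo_loss
```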
### RLHF using Axolotl

> [!IMPORTANT]
> This is a BETA feature and many features are not fully implemented. You are encouraged to open new PRs to improve the integration and functionality.
The RL training methods are implemented in trl and wrapped via axolotl. Below are examples of how you can use different preference datasets to train models that use ChatML.
#### DPO
```yaml
rl: true
datasets:
  - path: Intel/orca_dpo_pairs
    split: train
    type: intel_apply_chatml
  - path: argilla/ultrafeedback-binarized-preferences
    split: train
    type: argilla_apply_chatml
```
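
Each `type` above selects a mapping that converts the dataset's columns into ChatML-formatted preference fields before training. A hypothetical sketch of such a transform for `Intel/orca_dpo_pairs` (the column names and output keys here are illustrative assumptions, not axolotl's exact implementation):

```python
def intel_apply_chatml_sketch(sample: dict) -> dict:
    """Hypothetical: map an orca_dpo_pairs row to prompt/chosen/rejected in ChatML."""
    prompt = (
        f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
        f"<|im_start|>user\n{sample['question']}<|im_end|>\n"
        "<|im_start|>assistant\n"
    )
    return {
        "prompt": prompt,
        "chosen": f"{sample['chosen']}<|im_end|>",
        "rejected": f"{sample['rejected']}<|im_end|>",
    }
```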
#### IPO
```yaml
rl: ipo
```
@@ -37,3 +37,5 @@ tensorboard
s3fs
gcsfs
# adlfs

trl @ git+https://github.com/huggingface/trl.git@main
@@ -0,0 +1,66 @@
""" | ||
module for TRL PPO training | ||
""" | ||
import torch | ||
from tqdm import tqdm | ||
from trl import PPOTrainer | ||
|
||
|
||
class TRLPPOTrainer(PPOTrainer): | ||
""" | ||
wrapper for ppo trainer to handle customizations | ||
""" | ||
|
||
def train( | ||
self, | ||
reward_pipe, | ||
resume_from_checkpoint=None, # pylint: disable=unused-argument | ||
): | ||
generation_kwargs = { | ||
"min_length": -1, | ||
"top_k": 0.0, | ||
"top_p": 1.0, | ||
"do_sample": True, | ||
"pad_token_id": self.tokenizer.eos_token_id, | ||
"max_new_tokens": 32, | ||
} | ||
sent_kwargs = { | ||
"return_all_scores": True, | ||
"function_to_apply": "none", | ||
"batch_size": 16, | ||
} | ||
|
||
for epoch, batch in tqdm( # pylint: disable=unused-variable | ||
enumerate(self.dataloader) | ||
): | ||
query_tensors = batch["input_ids"] | ||
|
||
# generate model response | ||
response_tensors, ref_response_tensors = self.generate( | ||
query_tensors, | ||
return_prompt=False, | ||
generate_ref_response=True, | ||
**generation_kwargs | ||
) | ||
batch["response"] = self.tokenizer.batch_decode(response_tensors) | ||
batch["ref_response"] = self.tokenizer.batch_decode(ref_response_tensors) | ||
|
||
# Compute sentiment score | ||
texts = [q + r for q, r in zip(batch["query"], batch["response"])] | ||
pipe_outputs = reward_pipe(texts, **sent_kwargs) | ||
rewards = [torch.tensor(output[1]["score"]) for output in pipe_outputs] | ||
ref_texts = [q + r for q, r in zip(batch["query"], batch["ref_response"])] | ||
ref_pipe_outputs = reward_pipe(ref_texts, **sent_kwargs) | ||
ref_rewards = [ | ||
torch.tensor(output[1]["score"]) for output in ref_pipe_outputs | ||
] | ||
batch["ref_rewards"] = ref_rewards | ||
|
||
# Run PPO step | ||
stats = self.step(query_tensors, response_tensors, rewards) | ||
self.log_stats( | ||
stats, | ||
batch, | ||
rewards, | ||
columns_to_log=["query", "response", "ref_response", "ref_rewards"], | ||
) |
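
A minimal sketch of how this wrapper might be driven, following trl's standard PPO setup; the model, dataset, and reward model below are placeholders for illustration only:

```python
from datasets import load_dataset
from transformers import AutoTokenizer, pipeline
from trl import AutoModelForCausalLMWithValueHead, PPOConfig

# Placeholder names; substitute the model and prompt dataset you actually train on.
model_name = "gpt2"
config = PPOConfig(model_name=model_name, batch_size=16, mini_batch_size=4)

model = AutoModelForCausalLMWithValueHead.from_pretrained(model_name)
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

# Tokenized prompts go in "input_ids"; the raw text is kept as "query" because the
# train loop above concatenates query + response before scoring with the reward pipe.
dataset = load_dataset("imdb", split="train[:1%]")
dataset = dataset.map(
    lambda row: {
        "input_ids": tokenizer.encode(row["text"])[:8],
        "query": row["text"][:64],
    }
)
dataset.set_format(type="torch")


def collator(items):
    # Keep variable-length prompts as lists of tensors rather than padding them.
    return {key: [item[key] for item in items] for key in items[0]}


reward_pipe = pipeline("sentiment-analysis", model="lvwerra/distilbert-imdb")

ppo_trainer = TRLPPOTrainer(
    config, model, ref_model, tokenizer, dataset=dataset, data_collator=collator
)
ppo_trainer.train(reward_pipe)
```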