From 82ad390caf821db092f78186ba6b19f48bc804fd Mon Sep 17 00:00:00 2001
From: bartoszzuk <57541034+bartoszzuk@users.noreply.github.com>
Date: Mon, 7 Oct 2024 13:11:17 +0200
Subject: [PATCH] Fix RLOO checkpointing (#2114)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Fix RLOO checkpointing for transformers>=4.45.0

* Add missing import

* Fix pre-commit issues

* Added test for RLOO checkpointing

* Ensure that tokenizer matches SFT and Reward model

* Pre-commit formatting

* processing class

---------

Co-authored-by: Kashif Rasul
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec
---
 tests/test_rloo_trainer.py  | 47 +++++++++++++++++++++++++++++++++++++++++++
 trl/trainer/rloo_trainer.py | 14 +++++++----
 2 files changed, 56 insertions(+), 5 deletions(-)

diff --git a/tests/test_rloo_trainer.py b/tests/test_rloo_trainer.py
index 03a7ea7709..bb5bb8f2c9 100644
--- a/tests/test_rloo_trainer.py
+++ b/tests/test_rloo_trainer.py
@@ -13,8 +13,14 @@
 # limitations under the License.
 import platform
 import subprocess
+import tempfile
+import unittest
 
 import torch
+from datasets import Dataset
+from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
+
+from trl import RLOOConfig, RLOOTrainer
 
 
 def test():
@@ -26,6 +32,8 @@ def test():
     --gradient_accumulation_steps 1 \
     --total_episodes 10 \
     --model_name_or_path EleutherAI/pythia-14m \
+    --sft_model_path EleutherAI/pythia-14m \
+    --reward_model_path EleutherAI/pythia-14m \
     --missing_eos_penalty 1.0 \
     --save_strategy no \
     --stop_token eos
@@ -71,3 +79,42 @@ def test_rloo_reward():
     baseline = (rlhf_reward.sum(0) - rlhf_reward) / (rloo_k - 1)
     vec_advantages = rlhf_reward - baseline
     torch.testing.assert_close(vec_advantages.flatten(), advantages)
+
+
+class RLOOTrainerTester(unittest.TestCase):
+    def setUp(self):
+        self.sft_model_id = "trl-internal-testing/dummy-GPT2-correct-vocab"
+        self.reward_model_id = "trl-internal-testing/dummy-GPT2-correct-vocab"
+
+        self.policy_model = AutoModelForCausalLM.from_pretrained(self.sft_model_id)
+        self.reward_model = AutoModelForSequenceClassification.from_pretrained(self.reward_model_id)
+        self.policy_ref_model = AutoModelForCausalLM.from_pretrained(self.sft_model_id)
+
+        self.tokenizer = AutoTokenizer.from_pretrained(self.sft_model_id, padding_side="left")
+        self.tokenizer.chat_template = "{% for message in messages %}{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}{{ message['content'] }}{% if not loop.last %}{{ ' ' }}{% endif %}{% endfor %}{{ eos_token }}"
+        self.tokenizer.add_special_tokens({"pad_token": "[PAD]"})
+
+    def test_rloo_checkpoint(self):
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            training_args = RLOOConfig(
+                output_dir=tmp_dir,
+                per_device_train_batch_size=2,
+                total_episodes=1,
+                report_to="none",
+            )
+
+            dummy_text = {"content": "Hello World!", "role": "user"}
+            dummy_data = self.tokenizer.apply_chat_template(dummy_text)
+            dummy_dataset = Dataset.from_dict({"input_ids": dummy_data})
+
+            trainer = RLOOTrainer(
+                config=training_args,
+                policy=self.policy_model,
+                reward_model=self.reward_model,
+                ref_policy=self.policy_ref_model,
+                processing_class=self.tokenizer,
+                train_dataset=dummy_dataset,
+                eval_dataset=dummy_dataset,
+            )
+
+            trainer._save_checkpoint(trainer.model, trial=None)
diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py
index b082a1d827..18066976ca 100644
--- a/trl/trainer/rloo_trainer.py
+++ b/trl/trainer/rloo_trainer.py
@@ -43,7 +43,7 @@
 )
 from transformers.integrations import get_reporting_integration_callbacks
 from transformers.trainer import DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK
-from transformers.trainer_callback import CallbackHandler, PrinterCallback
+from transformers.trainer_callback import CallbackHandler, ExportableState, PrinterCallback
 
 from ..models.utils import unwrap_model_for_generation
 from ..trainer.utils import (
@@ -158,10 +158,6 @@ def __init__(
         #########
         ### trainer specifics
         #########
-        self.state = OnlineTrainerState(
-            is_local_process_zero=self.is_local_process_zero(),
-            is_world_process_zero=self.is_world_process_zero(),
-        )
         default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
         self.callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
         self.callback_handler = CallbackHandler(
@@ -169,6 +165,14 @@ def __init__(
         )
         self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)
         self.control = TrainerControl()
+        self.state = OnlineTrainerState(
+            is_local_process_zero=self.is_local_process_zero(),
+            is_world_process_zero=self.is_world_process_zero(),
+            stateful_callbacks=[
+                cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
+            ],
+        )
+
         self.current_flos = 0
         self.hp_search_backend = None
         self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None