Fix RLOO checkpointing (#2114)
* Fix RLOO checkpointing for transformers>=4.45.0

* Add missing import

* Fix pre-commit issues

* Added test for RLOO checkpointing

* Ensure that tokenizer matches SFT and Reward model

* Pre-commit formatting

* processing class

---------

Co-authored-by: Kashif Rasul <[email protected]>
Co-authored-by: Quentin Gallouédec <[email protected]>
Co-authored-by: Quentin Gallouédec <[email protected]>
4 people authored Oct 7, 2024
1 parent ac038ef commit 82ad390
Showing 2 changed files with 56 additions and 5 deletions.
47 changes: 47 additions & 0 deletions tests/test_rloo_trainer.py
@@ -13,8 +13,14 @@
# limitations under the License.
import platform
import subprocess
import tempfile
import unittest

import torch
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer

from trl import RLOOConfig, RLOOTrainer


def test():
@@ -26,6 +32,8 @@ def test():
--gradient_accumulation_steps 1 \
--total_episodes 10 \
--model_name_or_path EleutherAI/pythia-14m \
--sft_model_path EleutherAI/pythia-14m \
--reward_model_path EleutherAI/pythia-14m \
--missing_eos_penalty 1.0 \
--save_strategy no \
--stop_token eos
@@ -71,3 +79,42 @@ def test_rloo_reward():
    baseline = (rlhf_reward.sum(0) - rlhf_reward) / (rloo_k - 1)
    vec_advantages = rlhf_reward - baseline
    torch.testing.assert_close(vec_advantages.flatten(), advantages)


class RLOOTrainerTester(unittest.TestCase):
    def setUp(self):
        self.sft_model_id = "trl-internal-testing/dummy-GPT2-correct-vocab"
        self.reward_model_id = "trl-internal-testing/dummy-GPT2-correct-vocab"

        self.policy_model = AutoModelForCausalLM.from_pretrained(self.sft_model_id)
        self.reward_model = AutoModelForSequenceClassification.from_pretrained(self.reward_model_id)
        self.policy_ref_model = AutoModelForCausalLM.from_pretrained(self.sft_model_id)

        self.tokenizer = AutoTokenizer.from_pretrained(self.sft_model_id, padding_side="left")
        self.tokenizer.chat_template = "{% for message in messages %}{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}{{ message['content'] }}{% if not loop.last %}{{ ' ' }}{% endif %}{% endfor %}{{ eos_token }}"
        self.tokenizer.add_special_tokens({"pad_token": "[PAD]"})

    def test_rloo_checkpoint(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = RLOOConfig(
                output_dir=tmp_dir,
                per_device_train_batch_size=2,
                total_episodes=1,
                report_to="none",
            )

            dummy_text = [{"content": "Hello World!", "role": "user"}]
            dummy_data = self.tokenizer.apply_chat_template(dummy_text)
            dummy_dataset = Dataset.from_dict({"input_ids": dummy_data})

            trainer = RLOOTrainer(
                config=training_args,
                policy=self.policy_model,
                reward_model=self.reward_model,
                ref_policy=self.policy_ref_model,
                processing_class=self.tokenizer,
                train_dataset=dummy_dataset,
                eval_dataset=dummy_dataset,
            )

            trainer._save_checkpoint(trainer.model, trial=None)
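
To run just the new checkpoint test in isolation, a minimal runner along these lines should work (an illustrative sketch, assuming the file sits at tests/test_rloo_trainer.py and the repository root is on the Python path):

# Sketch: execute only RLOOTrainerTester.test_rloo_checkpoint with unittest.
# Assumes it is run from the TRL repository root so that `tests` is importable.
import unittest

from tests.test_rloo_trainer import RLOOTrainerTester

if __name__ == "__main__":
    suite = unittest.TestSuite([RLOOTrainerTester("test_rloo_checkpoint")])
    unittest.TextTestRunner(verbosity=2).run(suite)

The test only exercises trainer._save_checkpoint(...), which is the code path this commit repairs for transformers>=4.45.0.
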
14 changes: 9 additions & 5 deletions trl/trainer/rloo_trainer.py
@@ -43,7 +43,7 @@
)
from transformers.integrations import get_reporting_integration_callbacks
from transformers.trainer import DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK
from transformers.trainer_callback import CallbackHandler, PrinterCallback
from transformers.trainer_callback import CallbackHandler, ExportableState, PrinterCallback

from ..models.utils import unwrap_model_for_generation
from ..trainer.utils import (
@@ -158,17 +158,21 @@ def __init__(
#########
### trainer specifics
#########
self.state = OnlineTrainerState(
is_local_process_zero=self.is_local_process_zero(),
is_world_process_zero=self.is_world_process_zero(),
)
default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
self.callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
self.callback_handler = CallbackHandler(
self.callbacks, self.model, self.processing_class, self.optimizer, self.lr_scheduler
)
self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)
self.control = TrainerControl()
self.state = OnlineTrainerState(
is_local_process_zero=self.is_local_process_zero(),
is_world_process_zero=self.is_world_process_zero(),
stateful_callbacks=[
cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
],
)

self.current_flos = 0
self.hp_search_backend = None
self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
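
The heart of the fix is in the second hunk: in the transformers versions this commit targets (>=4.45.0), TrainerState (and therefore TRL's OnlineTrainerState) records serializable callback state via stateful_callbacks, so the state object has to be constructed after the callback handler and the TrainerControl exist. A standalone sketch of that pattern, using only the public transformers classes named in the diff (illustrative only, not the TRL source):

# Illustrative sketch of the pattern applied by the checkpointing fix (not TRL code).
from transformers import TrainerControl, TrainerState
from transformers.trainer_callback import ExportableState


def build_trainer_state(callback_handler, control: TrainerControl) -> TrainerState:
    # Only callbacks implementing ExportableState can be serialized into a
    # checkpoint, so they are filtered and handed to the state up front.
    stateful = [
        cb for cb in callback_handler.callbacks + [control] if isinstance(cb, ExportableState)
    ]
    return TrainerState(stateful_callbacks=stateful)

This mirrors what transformers' own Trainer does, and it is why the diff moves the OnlineTrainerState construction below self.control = TrainerControl().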
