☄️ Update Comet integration to include LogCompletionsCallback and Trainer.evaluation_loop() (huggingface#2501)

* Implemented integration with Comet in `LogCompletionsCallback`. Implemented related integration test.

* Implemented integration with Comet in `CPOTrainer.evaluation_loop()` during logging of `game_log` table.

* Implemented integration with Comet in `DPOTrainer.evaluation_loop()` during logging of `game_log` table.

* Implemented integration with Comet in `BCOTrainer.evaluation_loop()` during logging of `game_log` table.

* Implemented integration with Comet in `KTOTrainer.evaluation_loop()` during logging of `game_log` table.

* Implemented integration with Comet in `ORPOTrainer.evaluation_loop()` during logging of `game_log` table.
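
Taken together, these changes let `generate_during_eval` and `LogCompletionsCallback` report completion tables to Comet as well as Weights & Biases. A minimal usage sketch of the resulting behavior; the model, dataset, and config values are illustrative, and Comet credentials are assumed to be configured (e.g. via `COMET_API_KEY`):

```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
dataset = load_dataset("trl-lib/ultrafeedback_binarized")

training_args = DPOConfig(
    output_dir="dpo-comet",
    eval_strategy="steps",
    eval_steps=100,
    generate_during_eval=True,  # log a "game_log" completions table at evaluation time
    report_to="comet_ml",       # or ["wandb", "comet_ml"] to log to both backends
)
trainer = DPOTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    processing_class=tokenizer,
)
trainer.train()
```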
yaricom authored Dec 28, 2024
1 parent aed5da5 commit 763738f
Showing 16 changed files with 182 additions and 100 deletions.
4 changes: 2 additions & 2 deletions tests/test_bco_trainer.py
@@ -346,8 +346,8 @@ def test_bco_trainer_generate_during_eval_no_wandb(self):
 
         with self.assertRaisesRegex(
             ValueError,
-            expected_regex="`generate_during_eval=True` requires Weights and Biases to be installed."
-            " Please install with `pip install wandb` to resolve.",
+            expected_regex="`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
+            " Please install `wandb` or `comet-ml` to resolve.",
         ):
             BCOTrainer(
                 model=self.model,
45 changes: 42 additions & 3 deletions tests/test_callbacks.py
@@ -24,7 +24,7 @@
 from transformers.trainer_utils import get_last_checkpoint
 from transformers.utils import is_peft_available
 
-from tests.testing_utils import require_mergekit
+from tests.testing_utils import require_comet, require_mergekit
 from trl import BasePairwiseJudge, DPOConfig, DPOTrainer, LogCompletionsCallback, MergeModelCallback, WinRateCallback
 from trl.mergekit_utils import MergeConfig
 
@@ -216,7 +216,7 @@ def test_lora(self):
         self.assertListEqual(winrate_history, self.expected_winrates)
 
 
-@require_wandb
 class LogCompletionsCallbackTester(unittest.TestCase):
     def setUp(self):
         self.model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
@@ -234,7 +233,8 @@ def tokenize_function(examples):
 
         self.generation_config = GenerationConfig(max_length=32)
 
-    def test_basic(self):
+    @require_wandb
+    def test_basic_wandb(self):
         import wandb
 
         with tempfile.TemporaryDirectory() as tmp_dir:
@@ -271,6 +271,45 @@ def test_basic(self):
         # Check that the prompt is in the log
         self.assertIn(self.dataset["test"][0]["prompt"], completions["data"][0])
 
+    @require_comet
+    def test_basic_comet(self):
+        import comet_ml
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            training_args = TrainingArguments(
+                output_dir=tmp_dir,
+                eval_strategy="steps",
+                eval_steps=2,  # evaluate every 2 steps
+                per_device_train_batch_size=2,  # 8 samples in total so 4 batches of 2 per epoch
+                per_device_eval_batch_size=2,
+                report_to="comet_ml",
+            )
+            trainer = Trainer(
+                model=self.model,
+                args=training_args,
+                train_dataset=self.dataset["train"],
+                eval_dataset=self.dataset["test"],
+                processing_class=self.tokenizer,
+            )
+            completions_callback = LogCompletionsCallback(trainer, self.generation_config, num_prompts=2)
+            trainer.add_callback(completions_callback)
+            trainer.train()
+
+            # close the experiment to make sure all pending data are flushed
+            experiment = comet_ml.get_running_experiment()
+            assert experiment is not None
+            experiment.end()
+
+            # get experiment assets and check that all required tables were logged
+            steps = len(self.dataset["train"]) + len(self.dataset["test"])
+            tables_logged = int(steps / 2) + 1  # +1 to include zero step
+
+            api_experiment = comet_ml.APIExperiment(previous_experiment=experiment.id)
+            tables = api_experiment.get_asset_list("dataframe")
+            assert tables is not None
+            assert len(tables) == tables_logged
+            assert all(table["fileName"] == "completions.csv" for table in tables)
+
 
 # On Windows, temporary directory cleanup fails when using the MergeModelCallback.
 # This is not an issue with the functionality of the code itself, but it can cause the test to fail
4 changes: 2 additions & 2 deletions tests/test_dpo_trainer.py
@@ -571,8 +571,8 @@ def test_dpo_trainer_generate_during_eval_no_wandb(self):
 
         with self.assertRaisesRegex(
             ValueError,
-            expected_regex="`generate_during_eval=True` requires Weights and Biases to be installed."
-            " Please install `wandb` to resolve.",
+            expected_regex="`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
+            " Please install `wandb` or `comet-ml` to resolve.",
         ):
             DPOTrainer(
                 model=self.model,
4 changes: 2 additions & 2 deletions tests/test_kto_trainer.py
@@ -316,8 +316,8 @@ def test_kto_trainer_generate_during_eval_no_wandb(self):
 
         with self.assertRaisesRegex(
             ValueError,
-            expected_regex="`generate_during_eval=True` requires Weights and Biases to be installed."
-            " Please install with `pip install wandb` to resolve.",
+            expected_regex="`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
+            " Please install `wandb` or `comet-ml` to resolve.",
         ):
             KTOTrainer(
                 model=self.model,
9 changes: 8 additions & 1 deletion tests/testing_utils.py
@@ -15,7 +15,7 @@
 import random
 import unittest
 
-from transformers import is_bitsandbytes_available, is_sklearn_available, is_wandb_available
+from transformers import is_bitsandbytes_available, is_comet_available, is_sklearn_available, is_wandb_available
 
 from trl import BaseBinaryJudge, BasePairwiseJudge, is_diffusers_available, is_llm_blender_available
 from trl.import_utils import is_mergekit_available
@@ -65,6 +65,13 @@ def require_sklearn(test_case):
     return unittest.skipUnless(is_sklearn_available(), "test requires sklearn")(test_case)
 
 
+def require_comet(test_case):
+    """
+    Decorator marking a test that requires Comet. Skips the test if Comet is not available.
+    """
+    return unittest.skipUnless(is_comet_available(), "test requires comet_ml")(test_case)
+
+
 class RandomBinaryJudge(BaseBinaryJudge):
     """
     Random binary judge, for testing purposes.
2 changes: 1 addition & 1 deletion trl/trainer/bco_config.py
@@ -49,7 +49,7 @@ class BCOConfig(TrainingArguments):
         disable_dropout (`bool`, *optional*, defaults to `True`):
             Whether to disable dropout in the model and reference model.
         generate_during_eval (`bool`, *optional*, defaults to `False`):
-            If `True`, generates and logs completions from both the model and the reference model to W&B during
+            If `True`, generates and logs completions from both the model and the reference model to W&B or Comet during
             evaluation.
         is_encoder_decoder (`Optional[bool]`, *optional*, defaults to `None`):
             When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument,
42 changes: 23 additions & 19 deletions trl/trainer/bco_trainer.py
@@ -24,6 +24,7 @@
 from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union
 
 import numpy as np
+import pandas as pd
 import torch
 import torch.amp as amp
 import torch.nn as nn
@@ -44,6 +45,7 @@
     ProcessorMixin,
     Trainer,
     TrainingArguments,
+    is_comet_available,
    is_sklearn_available,
     is_wandb_available,
 )
@@ -60,6 +62,7 @@
     disable_dropout_in_model,
     generate_model_card,
     get_comet_experiment_url,
+    log_table_to_comet_experiment,
     pad_to_length,
     peft_module_casting_to_bf16,
 )
@@ -456,10 +459,10 @@ def make_inputs_require_grad(module, input, output):
 
             model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
 
-        if args.generate_during_eval and not is_wandb_available():
+        if args.generate_during_eval and not (is_wandb_available() or is_comet_available()):
             raise ValueError(
-                "`generate_during_eval=True` requires Weights and Biases to be installed."
-                " Please install with `pip install wandb` to resolve."
+                "`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
+                " Please install `wandb` or `comet-ml` to resolve."
             )
 
         if model is not None:
@@ -1398,28 +1401,29 @@ def evaluation_loop(
             random_batch = self.data_collator(random_batch_dataset)
             random_batch = self._prepare_inputs(random_batch)
 
-            target_indicies = [i for i in range(len(random_batch["delta"])) if random_batch["delta"][i] is False]
+            target_indicies = [i for i in range(len(random_batch["label"])) if random_batch["label"][i] is False]
             target_batch = {
-                "prompt_input_ids": itemgetter(*target_indicies)(random_batch["prompt_input_ids"]),
-                "prompt_attention_mask": itemgetter(*target_indicies)(random_batch["prompt_attention_mask"]),
+                "prompt_input_ids": random_batch["prompt_input_ids"][target_indicies],
+                "prompt_attention_mask": random_batch["prompt_attention_mask"][target_indicies],
                 "prompt": itemgetter(*target_indicies)(random_batch["prompt"]),
             }
             policy_output_decoded, ref_output_decoded = self.generate_from_model_and_ref(self.model, target_batch)
 
-            self.log(
-                {
-                    "game_log": wandb.Table(
-                        columns=["Prompt", "Policy", "Ref Model"],
-                        rows=[
-                            [prompt, pol[len(prompt) :], ref[len(prompt) :]]
-                            for prompt, pol, ref in zip(
-                                target_batch["prompt"], policy_output_decoded, ref_output_decoded
-                            )
-                        ],
-                    )
-                }
+            table = pd.DataFrame(
+                columns=["Prompt", "Policy", "Ref Model"],
+                data=[
+                    [prompt, pol[len(prompt) :], ref[len(prompt) :]]
+                    for prompt, pol, ref in zip(target_batch["prompt"], policy_output_decoded, ref_output_decoded)
+                ],
             )
-            self.state.log_history.pop()
+            if "wandb" in self.args.report_to:
+                wandb.log({"game_log": wandb.Table(data=table)})
+
+            if "comet_ml" in self.args.report_to:
+                log_table_to_comet_experiment(
+                    name="game_log.csv",
+                    table=table,
+                )
 
         # Base evaluation
         initial_output = super().evaluation_loop(
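
Note that the hunks above import `log_table_to_comet_experiment` from the trainer utils, whose definition is not shown in the rendered files. A plausible sketch of such a helper, assuming Comet's `Experiment.log_table` API; the actual implementation in `trl/trainer/utils.py` may differ:

```python
import pandas as pd


def log_table_to_comet_experiment(name: str, table: pd.DataFrame) -> None:
    """Hypothetical sketch: log a DataFrame as a table asset on the running
    Comet experiment. The real helper lives in trl/trainer/utils.py."""
    import comet_ml

    # Assumes a Comet experiment was already started by the trainer's Comet callback.
    experiment = comet_ml.get_running_experiment()
    if experiment is not None:
        # log_table stores the DataFrame as a CSV asset, which is what the
        # tests above retrieve via APIExperiment.get_asset_list("dataframe").
        experiment.log_table(filename=name, tabular_data=table)
```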
26 changes: 19 additions & 7 deletions trl/trainer/callbacks.py
@@ -19,7 +19,7 @@
 import torch
 from accelerate import Accelerator
 from accelerate.state import AcceleratorState
-from accelerate.utils import gather_object, is_deepspeed_available
+from accelerate.utils import gather_object, is_comet_ml_available, is_deepspeed_available, is_wandb_available
 from rich.console import Console, Group
 from rich.live import Live
 from rich.panel import Panel
@@ -34,7 +34,6 @@
     TrainerState,
     TrainingArguments,
 )
-from transformers.integrations import WandbCallback
 from transformers.trainer_utils import has_length
 
 from ..data_utils import maybe_apply_chat_template
@@ -48,6 +47,12 @@
 if is_deepspeed_available():
     import deepspeed
 
+if is_comet_ml_available():
+    pass
+
+if is_wandb_available():
+    import wandb
+
 
 def _generate_completions(
     prompts: list[str],
@@ -406,9 +411,9 @@ def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: Tra
         )
 
 
-class LogCompletionsCallback(WandbCallback):
+class LogCompletionsCallback(TrainerCallback):
     r"""
-    A [`~transformers.TrainerCallback`] that logs completions to Weights & Biases.
+    A [`~transformers.TrainerCallback`] that logs completions to Weights & Biases and/or Comet.
 
     Usage:
     ```python
@@ -436,7 +441,6 @@ def __init__(
         num_prompts: Optional[int] = None,
         freq: Optional[int] = None,
     ):
-        super().__init__()
         self.trainer = trainer
         self.generation_config = generation_config
         self.freq = freq
@@ -483,8 +487,16 @@ def on_step_end(self, args, state, control, **kwargs):
         global_step = [str(state.global_step)] * len(prompts)
         data = list(zip(global_step, prompts, completions))
         self.table.extend(data)
-        table = self._wandb.Table(columns=["step", "prompt", "completion"], data=self.table)
-        self._wandb.log({"completions": table})
+        table = pd.DataFrame(columns=["step", "prompt", "completion"], data=self.table)
+
+        if "wandb" in args.report_to:
+            wandb.log({"completions": table})
+
+        if "comet_ml" in args.report_to:
+            log_table_to_comet_experiment(
+                name="completions.csv",
+                table=table,
+            )
 
         # Save the last logged step, so we don't log the same completions multiple times
         self._last_logged_step = state.global_step
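
For reference, wiring the reworked callback into a run is unchanged from the W&B-only version; `args.report_to` alone decides where the `completions` table goes. A minimal sketch, assuming an existing `trainer` whose `TrainingArguments` set `report_to="comet_ml"` (the generation settings are illustrative):

```python
from transformers import GenerationConfig
from trl import LogCompletionsCallback

generation_config = GenerationConfig(max_new_tokens=64)  # illustrative settings
completions_callback = LogCompletionsCallback(trainer, generation_config, num_prompts=4)
trainer.add_callback(completions_callback)
trainer.train()  # a "completions" table (completions.csv on Comet) is logged periodically
```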
2 changes: 1 addition & 1 deletion trl/trainer/cpo_config.py
@@ -67,7 +67,7 @@ class CPOConfig(TrainingArguments):
             Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`.
             This argument is required if you want to use the default data collator.
         generate_during_eval (`bool`, *optional*, defaults to `False`):
-            If `True`, generates and logs completions from the model to W&B during evaluation.
+            If `True`, generates and logs completions from the model to W&B or Comet during evaluation.
         is_encoder_decoder (`Optional[bool]`, *optional*, defaults to `None`):
             When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument,
             you need to specify if the model returned by the callable is an encoder-decoder model.
33 changes: 19 additions & 14 deletions trl/trainer/cpo_trainer.py
@@ -22,6 +22,7 @@
 from typing import Any, Callable, Literal, Optional, Union
 
 import numpy as np
+import pandas as pd
 import torch
 import torch.amp as amp
 import torch.nn as nn
@@ -40,6 +41,7 @@
     PreTrainedTokenizerBase,
     ProcessorMixin,
     Trainer,
+    is_comet_available,
     is_wandb_available,
 )
 from transformers.trainer_callback import TrainerCallback
@@ -55,6 +57,7 @@
     disable_dropout_in_model,
     generate_model_card,
     get_comet_experiment_url,
+    log_table_to_comet_experiment,
     pad_to_length,
     peft_module_casting_to_bf16,
 )
@@ -200,10 +203,10 @@ def make_inputs_require_grad(module, input, output):
 
             model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
 
-        if args.generate_during_eval and not is_wandb_available():
+        if args.generate_during_eval and not (is_wandb_available() or is_comet_available()):
             raise ValueError(
-                "`generate_during_eval=True` requires Weights and Biases to be installed."
-                " Please install `wandb` to resolve."
+                "`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
+                " Please install `wandb` or `comet-ml` to resolve."
             )
 
         if model is not None:
@@ -936,18 +939,20 @@ def evaluation_loop(
 
             policy_output_decoded = self.generate_from_model(self.model, random_batch)
 
-            self.log(
-                {
-                    "game_log": wandb.Table(
-                        columns=["Prompt", "Policy"],
-                        rows=[
-                            [prompt, pol[len(prompt) :]]
-                            for prompt, pol in zip(random_batch["prompt"], policy_output_decoded)
-                        ],
-                    )
-                }
+            table = pd.DataFrame(
+                columns=["Prompt", "Policy"],
+                data=[
+                    [prompt, pol[len(prompt) :]] for prompt, pol in zip(random_batch["prompt"], policy_output_decoded)
+                ],
             )
-            self.state.log_history.pop()
+            if "wandb" in self.args.report_to:
+                wandb.log({"game_log": wandb.Table(data=table)})
+
+            if "comet_ml" in self.args.report_to:
+                log_table_to_comet_experiment(
+                    name="game_log.csv",
+                    table=table,
+                )
 
         # Base evaluation
         initial_output = super().evaluation_loop(
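
Once `game_log.csv` or `completions.csv` has been logged, it can be fetched back through Comet's API in the same way the new test does; tables logged with `Experiment.log_table` surface as `"dataframe"` assets. A short sketch with a placeholder experiment key:

```python
import comet_ml

# Placeholder key; substitute the key of a real experiment.
api_experiment = comet_ml.APIExperiment(previous_experiment="YOUR_EXPERIMENT_KEY")

for asset in api_experiment.get_asset_list("dataframe"):
    print(asset["fileName"])  # e.g. "game_log.csv" or "completions.csv"
```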
2 changes: 1 addition & 1 deletion trl/trainer/dpo_config.py
@@ -89,7 +89,7 @@ class DPOConfig(TrainingArguments):
         disable_dropout (`bool`, *optional*, defaults to `True`):
             Whether to disable dropout in the model and reference model.
         generate_during_eval (`bool`, *optional*, defaults to `False`):
-            If `True`, generates and logs completions from both the model and the reference model to W&B during
+            If `True`, generates and logs completions from both the model and the reference model to W&B or Comet during
             evaluation.
         precompute_ref_log_probs (`bool`, *optional*, defaults to `False`):
             Whether to precompute reference model log probabilities for training and evaluation datasets. This is