diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py
index e615caf043..93acff6da2 100644
--- a/tests/test_sft_trainer.py
+++ b/tests/test_sft_trainer.py
@@ -39,6 +39,10 @@ def formatting_prompts_func(example):
     return text
 
 
+def formatting_func_for_pretokenized(example):
+    return example["input_ids"]
+
+
 def formatting_prompts_func_batched(example):
     output_text = []
     for i, question in enumerate(example["question"]):
@@ -93,6 +97,17 @@ def setUp(self):
                 ],
             }
         )
+        self.dummy_tokenized_dataset = Dataset.from_dict(
+            {
+                "input_ids": [
+                    self.tokenizer.encode(
+                        "TRL is a library to post-train LLMs and diffusion models with methods such as Supervised Fine-tuning (SFT), Proximal Policy Optimization (PPO), and Direct Preference Optimization (DPO)."
+                    )
+                ]
+                * 10
+            }
+        )
+
         self.conversational_lm_dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling")
         self.standard_prompt_completion_dataset = load_dataset(
             "trl-internal-testing/zen", "standard_prompt_completion"
@@ -158,6 +173,42 @@ def setUp(self):
             num_of_sequences=16,
         )
 
+        self.train_dataset_from_pretokenized = ConstantLengthDataset(
+            self.tokenizer,
+            self.dummy_tokenized_dataset,
+            seq_length=16,
+            num_of_sequences=16,
+            formatting_func=formatting_func_for_pretokenized,
+        )
+
+        self.eval_dataset_from_pretokenized = ConstantLengthDataset(
+            self.tokenizer,
+            self.dummy_tokenized_dataset,
+            seq_length=16,
+            num_of_sequences=16,
+            formatting_func=formatting_func_for_pretokenized,
+        )
+
+    def test_constant_length_dataset_with_pretokenized_data(self):
+        constant_len_dataset = ConstantLengthDataset(
+            self.tokenizer,
+            self.dummy_tokenized_dataset,
+            formatting_func=formatting_func_for_pretokenized,
+        )
+
+        assert len(constant_len_dataset) == len(self.dummy_tokenized_dataset)
+        assert len(constant_len_dataset) > 0
+
+        for example in constant_len_dataset:
+            assert "input_ids" in example
+            assert "labels" in example
+
+            assert len(example["input_ids"]) == constant_len_dataset.seq_length
+            assert len(example["labels"]) == constant_len_dataset.seq_length
+
+            decoded_text = self.tokenizer.decode(example["input_ids"])
+            assert ("TRL" in decoded_text) and ("(DPO)" in decoded_text)
+
     def test_constant_length_dataset(self):
         formatted_dataset = ConstantLengthDataset(
             self.tokenizer,
@@ -236,6 +287,34 @@ def test_sft_trainer(self):
 
             self.assertIn("model.safetensors", os.listdir(tmp_dir + "/checkpoint-2"))
 
+    def test_sft_trainer_with_pretokenized_data_packing(self):
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            training_args = SFTConfig(
+                output_dir=tmp_dir,
+                dataloader_drop_last=True,
+                eval_strategy="steps",
+                max_steps=4,
+                eval_steps=2,
+                save_steps=2,
+                per_device_train_batch_size=2,
+                packing=True,
+                report_to="none",
+            )
+
+            trainer = SFTTrainer(
+                model=self.model_id,
+                args=training_args,
+                train_dataset=self.train_dataset_from_pretokenized,
+                eval_dataset=self.eval_dataset_from_pretokenized,
+            )
+
+            trainer.train()
+
+            assert trainer.state.log_history[-1]["train_loss"] is not None
+            assert trainer.state.log_history[0]["eval_loss"] is not None
+
+            assert "model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2")
+
     def test_sft_trainer_uncorrect_data(self):
         with tempfile.TemporaryDirectory() as tmp_dir:
             # Shoud work as SFTTrainer natively supports conversational lm dataset
diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py
index 2dce57b12b..1e1a6781bc 100644
--- a/trl/trainer/sft_trainer.py
+++ b/trl/trainer/sft_trainer.py
@@ -367,7 +367,11 @@ def _prepare_dataset(
                     "You passed a dataset that is already processed (contains an `input_ids` field) together with a valid formatting function. Therefore `formatting_func` will be ignored."
                 )
 
-            return dataset
+            def formatting_func(x):
+                return x["input_ids"]
+
+            if not packing:
+                return dataset
 
         # check if torch dataset / dataloader and do nothing
         # see https://github.com/huggingface/trl/pull/1468 for why datasets.IterableDataset needs a separate check
diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py
index 96dc8fba24..f2854a1cec 100644
--- a/trl/trainer/utils.py
+++ b/trl/trainer/utils.py
@@ -21,6 +21,7 @@
 from importlib.metadata import version
 from typing import Any, Dict, List, Literal, Optional, Tuple, Union
 
+import datasets
 import numpy as np
 import pandas as pd
 import torch
@@ -627,6 +628,14 @@ def __init__(
                     "The passed formatting_func has more than one argument. Usually that function should have a single argument `example`"
                     " which corresponds to the dictionary returned by each element of the dataset. Make sure you know what you are doing."
                 )
+        self.pretokenized = False
+        column_names = (
+            dataset.column_names if isinstance(dataset, (datasets.Dataset, datasets.IterableDataset)) else None
+        )
+        if column_names is not None and "input_ids" in column_names:
+            self.pretokenized = True
+            # since the dataset is tokenized, the unit of buffer size should be tokens
+            self.max_buffer_size = seq_length * num_of_sequences
 
     def __len__(self):
         return len(self.dataset)
@@ -651,9 +660,12 @@ def __iter__(self):
                     break
             if self.shuffle:
                 random.shuffle(buffer)
-            tokenized_inputs = self.tokenizer(buffer, add_special_tokens=self.add_special_tokens, truncation=False)[
-                "input_ids"
-            ]
+            if self.pretokenized:
+                tokenized_inputs = buffer
+            else:
+                tokenized_inputs = self.tokenizer(
+                    buffer, add_special_tokens=self.add_special_tokens, truncation=False
+                )["input_ids"]
             all_token_ids = []
             for tokenized_input in tokenized_inputs:
                 if self.append_concat_token:
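
Note for reviewers: below is a minimal, self-contained sketch of the usage this change enables. It is illustrative only, not part of the diff; the tokenizer choice and string literals are placeholders, and it assumes `ConstantLengthDataset` is importable from `trl.trainer`.

from datasets import Dataset
from transformers import AutoTokenizer

from trl.trainer import ConstantLengthDataset

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder tokenizer

# A dataset that is already tokenized: it has an `input_ids` column and no text column.
pretokenized = Dataset.from_dict(
    {"input_ids": [tokenizer.encode("TRL supports SFT, PPO, and DPO.")] * 10}
)

# With pretokenized data, the formatting function just surfaces the token ids.
# ConstantLengthDataset detects the `input_ids` column, sizes its buffer in
# tokens rather than characters, and skips re-tokenization in __iter__.
packed = ConstantLengthDataset(
    tokenizer,
    pretokenized,
    seq_length=16,
    num_of_sequences=16,
    formatting_func=lambda example: example["input_ids"],
)

for sample in packed:
    # Every yielded example is a fixed-length packed sequence.
    assert len(sample["input_ids"]) == len(sample["labels"]) == 16
    break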