📦 Support for packing tokenized datasets for SFT (#2011)
* feat: add support for packing tokenized datasets

Signed-off-by: Mehant Kammakomati <[email protected]>

* fix: address review comments

Signed-off-by: Mehant Kammakomati <[email protected]>

* feat: add tests for pretokenized dataset packing

Signed-off-by: Mehant Kammakomati <[email protected]>

---------

Signed-off-by: Mehant Kammakomati <[email protected]>
kmehant authored Nov 25, 2024
1 parent 163695e commit 17e8060
Showing 3 changed files with 99 additions and 4 deletions.
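
For context, a minimal usage sketch of what this commit enables: passing a dataset that already contains an `input_ids` column to `SFTTrainer` with `packing=True`. The model id and dataset below are illustrative placeholders and are not taken from this commit.

# Hedged sketch: pack a pretokenized dataset with SFTTrainer (placeholder model id).
from datasets import Dataset
from transformers import AutoTokenizer
from trl import SFTConfig, SFTTrainer

model_id = "Qwen/Qwen2.5-0.5B"  # any causal LM id; placeholder for illustration
tokenizer = AutoTokenizer.from_pretrained(model_id)
texts = ["TRL is a library to post-train LLMs with SFT, PPO, and DPO."] * 32
train_dataset = Dataset.from_dict({"input_ids": [tokenizer.encode(t) for t in texts]})

training_args = SFTConfig(output_dir="out", packing=True, max_steps=4, report_to="none")
trainer = SFTTrainer(
    model=model_id,
    args=training_args,
    train_dataset=train_dataset,  # already tokenized: no text column needed
)
trainer.train()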
79 changes: 79 additions & 0 deletions tests/test_sft_trainer.py
@@ -39,6 +39,10 @@ def formatting_prompts_func(example):
return text


def formatting_func_for_pretokenized(example):
return example["input_ids"]


def formatting_prompts_func_batched(example):
output_text = []
for i, question in enumerate(example["question"]):
@@ -93,6 +97,17 @@ def setUp(self):
],
}
)
self.dummy_tokenized_dataset = Dataset.from_dict(
{
"input_ids": [
self.tokenizer.encode(
"TRL is a library to post-train LLMs and diffusion models with methods such as Supervised Fine-tuning (SFT), Proximal Policy Optimization (PPO), and Direct Preference Optimization (DPO)."
)
]
* 10
}
)

self.conversational_lm_dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling")
self.standard_prompt_completion_dataset = load_dataset(
"trl-internal-testing/zen", "standard_prompt_completion"
@@ -158,6 +173,42 @@ def setUp(self):
num_of_sequences=16,
)

self.train_dataset_from_pretokenized = ConstantLengthDataset(
self.tokenizer,
self.dummy_tokenized_dataset,
seq_length=16,
num_of_sequences=16,
formatting_func=formatting_func_for_pretokenized,
)

self.eval_dataset_from_pretokenized = ConstantLengthDataset(
self.tokenizer,
self.dummy_tokenized_dataset,
seq_length=16,
num_of_sequences=16,
formatting_func=formatting_func_for_pretokenized,
)

def test_constant_length_dataset_with_pretokenized_data(self):
constant_len_dataset = ConstantLengthDataset(
self.tokenizer,
self.dummy_tokenized_dataset,
formatting_func=formatting_func_for_pretokenized,
)

assert len(constant_len_dataset) == len(self.dummy_tokenized_dataset)
assert len(constant_len_dataset) > 0

for example in constant_len_dataset:
assert "input_ids" in example
assert "labels" in example

assert len(example["input_ids"]) == constant_len_dataset.seq_length
assert len(example["labels"]) == constant_len_dataset.seq_length

decoded_text = self.tokenizer.decode(example["input_ids"])
assert ("TRL" in decoded_text) and ("(DPO)" in decoded_text)

def test_constant_length_dataset(self):
formatted_dataset = ConstantLengthDataset(
self.tokenizer,
@@ -236,6 +287,34 @@ def test_sft_trainer(self):

self.assertIn("model.safetensors", os.listdir(tmp_dir + "/checkpoint-2"))

def test_sft_trainer_with_pretokenized_data_packing(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = SFTConfig(
output_dir=tmp_dir,
dataloader_drop_last=True,
eval_strategy="steps",
max_steps=4,
eval_steps=2,
save_steps=2,
per_device_train_batch_size=2,
packing=True,
report_to="none",
)

trainer = SFTTrainer(
model=self.model_id,
args=training_args,
train_dataset=self.train_dataset_from_pretokenized,
eval_dataset=self.eval_dataset_from_pretokenized,
)

trainer.train()

assert trainer.state.log_history[-1]["train_loss"] is not None
assert trainer.state.log_history[0]["eval_loss"] is not None

assert "model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2")

def test_sft_trainer_uncorrect_data(self):
with tempfile.TemporaryDirectory() as tmp_dir:
# Should work as SFTTrainer natively supports conversational LM datasets
6 changes: 5 additions & 1 deletion trl/trainer/sft_trainer.py
@@ -367,7 +367,11 @@ def _prepare_dataset(
"You passed a dataset that is already processed (contains an `input_ids` field) together with a valid formatting function. Therefore `formatting_func` will be ignored."
)

return dataset
def formatting_func(x):
return x["input_ids"]

if not packing:
return dataset

# check if torch dataset / dataloader and do nothing
# see https://github.com/huggingface/trl/pull/1468 for why datasets.IterableDataset needs a separate check
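
In short, `_prepare_dataset` previously returned a dataset containing `input_ids` untouched even when packing was requested; with this change it installs a default `formatting_func` that yields the raw token ids and only returns early when `packing` is disabled, so pretokenized data now reaches the packing path. A simplified sketch of that control flow (not the real signature; helper name is illustrative):

# Simplified sketch of the branch above; names other than formatting_func are illustrative.
def prepare_if_pretokenized(dataset, packing, formatting_func=None):
    if "input_ids" in dataset.column_names:
        if formatting_func is not None:
            print("formatting_func will be ignored for an already-processed dataset")

        # New: expose the existing token ids through a formatting function
        # so the packing path can consume them directly.
        def formatting_func(x):
            return x["input_ids"]

        if not packing:
            # Without packing, the tokenized dataset is used as-is.
            return dataset

    # With packing=True, execution falls through to the packing branch, which
    # wraps the dataset in a ConstantLengthDataset built with formatting_func.
    ...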
18 changes: 15 additions & 3 deletions trl/trainer/utils.py
Expand Up @@ -21,6 +21,7 @@
from importlib.metadata import version
from typing import Any, Dict, List, Literal, Optional, Tuple, Union

import datasets
import numpy as np
import pandas as pd
import torch
@@ -627,6 +628,14 @@ def __init__(
"The passed formatting_func has more than one argument. Usually that function should have a single argument `example`"
" which corresponds to the dictionary returned by each element of the dataset. Make sure you know what you are doing."
)
self.pretokenized = False
column_names = (
dataset.column_names if isinstance(dataset, (datasets.Dataset, datasets.IterableDataset)) else None
)
if column_names is not None and "input_ids" in column_names:
self.pretokenized = True
# since the dataset is tokenized, the unit of buffer size should be tokens
self.max_buffer_size = seq_length * num_of_sequences

def __len__(self):
return len(self.dataset)
@@ -651,9 +660,12 @@ def __iter__(self):
break
if self.shuffle:
random.shuffle(buffer)
tokenized_inputs = self.tokenizer(buffer, add_special_tokens=self.add_special_tokens, truncation=False)[
"input_ids"
]
if self.pretokenized:
tokenized_inputs = buffer
else:
tokenized_inputs = self.tokenizer(
buffer, add_special_tokens=self.add_special_tokens, truncation=False
)["input_ids"]
all_token_ids = []
for tokenized_input in tokenized_inputs:
if self.append_concat_token:
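
The net effect in `ConstantLengthDataset`: when the wrapped dataset exposes an `input_ids` column it is treated as pretokenized, the buffer size is counted in tokens rather than characters, and `__iter__` packs the buffered id lists directly instead of calling the tokenizer. A standalone sketch of that packing step (the real class yields torch tensors and refills its buffer lazily; the function name here is illustrative):

# Standalone sketch, assuming `buffer` is a list of token-id lists collected via
# formatting_func (here, x["input_ids"]) and eos_token_id comes from the tokenizer.
def pack_pretokenized(buffer, seq_length, eos_token_id, append_concat_token=True):
    all_token_ids = []
    for tokenized_input in buffer:  # no tokenizer call: ids are already available
        if append_concat_token:
            tokenized_input = tokenized_input + [eos_token_id]
        all_token_ids.extend(tokenized_input)

    # Slice the concatenated stream into fixed-length examples; an incomplete tail is dropped.
    examples = []
    for i in range(0, len(all_token_ids), seq_length):
        input_ids = all_token_ids[i : i + seq_length]
        if len(input_ids) == seq_length:
            examples.append({"input_ids": input_ids, "labels": list(input_ids)})
    return examples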
