From 0630baaf8d7419f7054ba4268ef00e271046ffcc Mon Sep 17 00:00:00 2001
From: Salman Mohammadi
Date: Wed, 8 Jan 2025 17:19:22 +0000
Subject: [PATCH] this should work?

---
 src/axolotl/datasets.py                            |  2 ++
 .../{stepwise.py => stepwise_supervised.py}        | 28 +++++++++++++++----
 src/axolotl/utils/data/sft.py                      |  9 +-----
 3 files changed, 26 insertions(+), 13 deletions(-)
 rename src/axolotl/prompt_strategies/{stepwise.py => stepwise_supervised.py} (79%)

diff --git a/src/axolotl/datasets.py b/src/axolotl/datasets.py
index b5638a614d..fba0bab031 100644
--- a/src/axolotl/datasets.py
+++ b/src/axolotl/datasets.py
@@ -52,6 +52,8 @@ def process(self, dataset):
         if self.prompt_tokenizer.supports_batched:
             map_kwargs["batched"] = True
             map_kwargs["batch_size"] = 100
+        import pdb
+        pdb.set_trace()
         return dataset.map(
             self.prompt_tokenizer.tokenize_prompt,
             num_proc=num_proc,
diff --git a/src/axolotl/prompt_strategies/stepwise.py b/src/axolotl/prompt_strategies/stepwise_supervised.py
similarity index 79%
rename from src/axolotl/prompt_strategies/stepwise.py
rename to src/axolotl/prompt_strategies/stepwise_supervised.py
index 4c6e46bbec..4ea74a8d95 100644
--- a/src/axolotl/prompt_strategies/stepwise.py
+++ b/src/axolotl/prompt_strategies/stepwise_supervised.py
@@ -5,15 +5,16 @@
 
 from itertools import chain
-from typing import Dict, List, Optional, Union
+from typing import Dict, Generator, List, Optional, Union
 
-from transformers import BatchEncoding
+from transformers import BatchEncoding, PreTrainedTokenizer
 
 from axolotl.prompt_tokenizers import IGNORE_INDEX, PromptTokenizingStrategy
 from axolotl.prompters import Prompter
+from axolotl.utils.dict import DictDefault
 
 
-class StepwiseSupervisedPromptTokenizingStrategy(PromptTokenizingStrategy):
+class StepwiseSupervisedPromptTokenizingStrategy:
     """
     Tokenizing strategy for supervised stepwise datasets, typically used for COT-reasoning.
     These datasets should include the following columns:
 
@@ -24,7 +25,6 @@ class StepwiseSupervisedPromptTokenizingStrategy(PromptTokenizingStrategy):
 
     def __init__(
         self,
-        prompter: Prompter,
         tokenizer,
         train_on_inputs: bool = False,
         sequence_len: int = 2048,
@@ -33,7 +33,9 @@ def __init__(
         train_on_last_step_only: bool = False,
         is_eval: bool = False,
     ):
-        super().__init__(prompter, tokenizer, train_on_inputs, sequence_len)
+        self.tokenizer = tokenizer
+        self.train_on_inputs = train_on_inputs
+        self.sequence_len = sequence_len
         self.step_separator = step_separator
         self.max_completion_length = max_completion_length
         self.train_on_last_step_only = train_on_last_step_only
@@ -42,6 +44,8 @@
     def tokenize_prompt(
         self, prompt: Dict[str, Union[str, List[str]]]
     ) -> BatchEncoding:
+        # Inspired by TRL's PRMTrainer
+        # https://github.com/huggingface/trl/blob/ed7de87dc766478c024b68f12530d1b0e7c3ff23/trl/trainer/prm_trainer.py#L206
         prompt_ids = self.tokenizer(prompt["prompt"], add_special_tokens=False)[
             "input_ids"
         ]
@@ -98,3 +102,17 @@
                 "attention_mask": [1] * len(input_ids),
             }
         )
+
+    @property
+    def supports_batched(self):
+        return False
+
+
+def load(
+    tokenizer: PreTrainedTokenizer, cfg: DictDefault
+) -> StepwiseSupervisedPromptTokenizingStrategy:
+    return StepwiseSupervisedPromptTokenizingStrategy(
+        tokenizer,
+        cfg.train_on_inputs,
+        cfg.sequence_len,
+    )
diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py
index 17d41e556f..8a5eeff5a1 100644
--- a/src/axolotl/utils/data/sft.py
+++ b/src/axolotl/utils/data/sft.py
@@ -6,9 +6,9 @@
 from typing import List, Tuple, Union
 
 from datasets import (
+    concatenate_datasets,
     Dataset,
     DatasetDict,
-    concatenate_datasets,
     load_dataset,
     load_from_disk,
 )
@@ -456,13 +456,6 @@ def get_dataset_wrapper(
             dataset,
             **ds_kwargs,
         )
-    elif ds_strategy := config_dataset.type == "stepwise_supervised":
-        dataset_prompter = UnsupportedPrompter()
-        dataset_wrapper = TokenizedPromptDataset(
-            ds_strategy,
-            dataset,
-            **ds_kwargs,
-        )
     elif ds_strategy := load(
         config_dataset.type, tokenizer, cfg, config_dataset, processor=processor
     ):