Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create preprocess CLI #785

Merged
merged 10 commits into from
Oct 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 31 additions & 23 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ Features:
- [How to Use Custom Pretokenized Dataset](#how-to-use-your-custom-pretokenized-dataset)
- [Config](#config)
- [Train](#train)
- [Training w/ Deepspeed](#training-with-deepspeed)
- [Inference](#inference)
- [Merge LORA to Base](#merge-lora-to-base)
- [Common Errors](#common-errors-)
Expand Down Expand Up @@ -824,14 +823,41 @@ Run
accelerate launch -m axolotl.cli.train your_config.yml
```

#### Multi-GPU
#### Preprocess dataset

You can optionally pre-tokenize dataset with the following before finetuning.
This is recommended for large datasets.

- Set `push_dataset_to_hub: hf_user/repo` to push it to Huggingface.
- Use `--debug` to see preprocessed examples.

You can optionally pre-tokenize dataset with the following before finetuning:
```bash
CUDA_VISIBLE_DEVICES=0 accelerate launch -m axolotl.cli.train your_config.yml --prepare_ds_only
python -m axolotl.cli.preprocess your_config.yml
```

##### Config
#### Multi-GPU

Below are the options available in axolotl for training with multiple GPUs. Note that DeepSpeed
is the recommended multi-GPU option currently because FSDP may experience
[loss instability](https://github.com/huggingface/transformers/issues/26498).

##### DeepSpeed

Deepspeed is an optimization suite for multi-gpu systems allowing you to train much larger models than you
might typically be able to fit into your GPU's VRAM. More information about the various optimization types
for deepspeed is available at https://huggingface.co/docs/accelerate/main/en/usage_guides/deepspeed#what-is-integrated

We provide several default deepspeed JSON configurations for ZeRO stage 1, 2, and 3.

```yaml
deepspeed: deepspeed/zero1.json
```

```shell
accelerate launch -m axolotl.cli.train examples/llama-2/config.py --deepspeed deepspeed/zero1.json
```

##### FSDP

- llama FSDP
```yaml
Expand All @@ -856,24 +882,6 @@ wandb_run_id:
wandb_log_model:
```

### Training with Deepspeed

Deepspeed is an optimization suite for multi-gpu systems allowing you to train much larger models than you
might typically be able to fit into your GPU's VRAM. More information about the various optimization types
for deepspeed is available at https://huggingface.co/docs/accelerate/main/en/usage_guides/deepspeed#what-is-integrated

We provide several default deepspeed JSON configurations for ZeRO stage 1, 2, and 3.

```shell
accelerate launch -m axolotl.cli.train examples/llama-2/config.py --deepspeed deepspeed/zero1.json
```

or

```yaml
deepspeed: deepspeed/zero1.json
```

### Inference

Pass the appropriate flag to the train command:
Expand Down
2 changes: 0 additions & 2 deletions scripts/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,6 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
shard(cfg=parsed_cfg, cli_args=parsed_cli_args)
else:
dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
if parsed_cli_args.prepare_ds_only:
return
train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)


Expand Down
8 changes: 7 additions & 1 deletion src/axolotl/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,9 @@ def load_datasets(
) -> TrainDatasetMeta:
tokenizer = load_tokenizer(cfg)

train_dataset, eval_dataset, total_num_steps = prepare_dataset(cfg, tokenizer)
train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset(
cfg, tokenizer
)

if cli_args.debug or cfg.debug:
LOG.info("check_dataset_labels...")
Expand All @@ -238,6 +240,10 @@ def load_datasets(
text_only=cli_args.debug_text_only,
)

LOG.info("printing prompters...")
for prompter in prompters:
LOG.info(prompter)

return TrainDatasetMeta(
train_dataset=train_dataset,
eval_dataset=eval_dataset,
Expand Down
53 changes: 53 additions & 0 deletions src/axolotl/cli/preprocess.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
"""
CLI to run training on a model
"""
import logging
from pathlib import Path

import fire
import transformers
from colorama import Fore

from axolotl.cli import (
check_accelerate_default_config,
check_user_token,
load_cfg,
load_datasets,
print_axolotl_text_art,
)
from axolotl.common.cli import PreprocessCliArgs
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH

LOG = logging.getLogger("axolotl.cli.preprocess")


def do_cli(config: Path = Path("examples/"), **kwargs):
# pylint: disable=duplicate-code
print_axolotl_text_art()
parsed_cfg = load_cfg(config, **kwargs)
check_accelerate_default_config()
check_user_token()
parser = transformers.HfArgumentParser((PreprocessCliArgs))
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
if not parsed_cfg.dataset_prepared_path:
msg = (
Fore.RED
+ "preprocess CLI called without dataset_prepared_path set, "
+ f"using default path: {DEFAULT_DATASET_PREPARED_PATH}"
+ Fore.RESET
)
LOG.warning(msg)
parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH

_ = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
LOG.info(
Fore.GREEN
+ f"Success! Preprocessed data path: `dataset_prepared_path: {parsed_cfg.dataset_prepared_path}`"
+ Fore.RESET
)


if __name__ == "__main__":
fire.Fire(do_cli)
13 changes: 0 additions & 13 deletions src/axolotl/cli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

import fire
import transformers
from colorama import Fore

from axolotl.cli import (
check_accelerate_default_config,
Expand All @@ -16,7 +15,6 @@
print_axolotl_text_art,
)
from axolotl.common.cli import TrainerCliArgs
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.train import train

LOG = logging.getLogger("axolotl.cli.train")
Expand All @@ -32,18 +30,7 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
if parsed_cli_args.prepare_ds_only and not parsed_cfg.dataset_prepared_path:
msg = (
Fore.RED
+ "--prepare_ds_only called without dataset_prepared_path set."
+ Fore.RESET
)
LOG.warning(msg)
parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH

dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
if parsed_cli_args.prepare_ds_only:
return
train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)


Expand Down
13 changes: 12 additions & 1 deletion src/axolotl/common/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,22 @@ class TrainerCliArgs:
debug_num_examples: int = field(default=5)
inference: bool = field(default=False)
merge_lora: bool = field(default=False)
prepare_ds_only: bool = field(default=False)
prompter: Optional[str] = field(default=None)
shard: bool = field(default=False)


@dataclass
class PreprocessCliArgs:
"""
dataclass representing arguments for preprocessing only
"""

debug: bool = field(default=False)
debug_text_only: bool = field(default=False)
debug_num_examples: int = field(default=1)
prompter: Optional[str] = field(default=None)


def load_model_and_tokenizer(
*,
cfg: DictDefault,
Expand Down
1 change: 1 addition & 0 deletions src/axolotl/prompt_tokenizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ def parse_instruction_fields(self, prompt) -> Tuple[str, str, str, str, str]:
raise NotImplementedError

def tokenize_prompt(self, prompt):
# pylint: disable=duplicate-code
(
instruction,
input, # pylint: disable=redefined-builtin
Expand Down
83 changes: 69 additions & 14 deletions src/axolotl/prompters.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@
from enum import Enum
from typing import Generator, Optional, Union

from colorama import Fore
from fastchat.conversation import Conversation, get_conv_template

LOG = logging.getLogger("axolotl")
IGNORE_TOKEN_ID = -100
REPR_TEMPLATE = "\n<start>\n" + Fore.CYAN + "{full_prompt}" + Fore.RESET + "\n<end>\n"


class PromptStyle(Enum):
Expand Down Expand Up @@ -55,20 +57,15 @@ def match_prompt_style(self):
)
self.system_format = "<|im_start|>system\n{system}<|im_end|>\n"

def build_prompt(
self,
instruction: str,
input: Union[None, str] = None, # pylint: disable=redefined-builtin
output: Union[None, str] = None,
) -> Generator[str, None, None]:
def _build_result(self, instruction, input_text, output):
# returns the full prompt from instruction and optional input
# if a label (=response, =output) is provided, it's also appended.
if input:
if input_text:
res = (
self.system_format.format(system=self.system_prompt)
if self.system_prompt
else ""
) + self.turn_format.format(instruction=instruction, input=input)
) + self.turn_format.format(instruction=instruction, input=input_text)
else:
res = (
self.system_format.format(system=self.system_no_input_prompt)
Expand All @@ -77,7 +74,21 @@ def build_prompt(
) + self.turn_no_input_format.format(instruction=instruction)
if output:
res = f"{res}{output}"
yield res

return res

def build_prompt(
self,
instruction: str,
input: Union[None, str] = None, # pylint: disable=redefined-builtin
output: Union[None, str] = None,
) -> Generator[str, None, None]:
yield self._build_result(instruction, input, output)

def __repr__(self) -> str:
return REPR_TEMPLATE.format(
full_prompt=self._build_result("{instruction}", "{input}", "{output}")
)


class UnpromptedPrompter(AlpacaPrompter):
Expand Down Expand Up @@ -191,14 +202,14 @@ def match_prompt_style(self):
)
self.response_split = "ASSISTANT:"

def build_prompt(
def _build_result(
self,
instruction: str,
input: Union[None, str] = None, # pylint: disable=redefined-builtin
output: Union[None, str] = None,
reflection: Union[None, str] = None,
corrected: Union[None, str] = None,
) -> Generator[str, None, None]:
):
# returns the full prompt from instruction and optional input
# if a label (=response, =output) is provided, it's also appended.
if input:
Expand All @@ -212,7 +223,30 @@ def build_prompt(
corrected=corrected,
)
res = f"{res}{label}"
yield res

return res

def build_prompt(
self,
instruction: str,
input: Union[None, str] = None, # pylint: disable=redefined-builtin
output: Union[None, str] = None,
reflection: Union[None, str] = None,
corrected: Union[None, str] = None,
) -> Generator[str, None, None]:
# pylint: disable=duplicate-code
yield self._build_result(
instruction,
input,
output,
reflection,
corrected,
)

def __repr__(self) -> str:
return REPR_TEMPLATE.format(
full_prompt=self._build_result("{instruction}", "{input}", "{output}")
)


SHAREGPT_ASSERTION_FAILED_ROLE = (
Expand Down Expand Up @@ -247,7 +281,7 @@ def __init__(
if role_key_model:
self.role_key_model = role_key_model

def build_prompt(self, source) -> Generator[str, None, None]:
def _build_result(self, source):
if len(source) < 2:
# If there isn't a back and forth conversation, ignore it
# also happens on the data splitting leaving empty conversations
Expand Down Expand Up @@ -282,11 +316,20 @@ def build_prompt(self, source) -> Generator[str, None, None]:
LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
conv.append_message(role, sentence["value"])

for part in conv.get_turns():
return conv.get_turns()

def build_prompt(self, source) -> Generator[str, None, None]:
turns = self._build_result(source)

for part in turns:
if part[0] and not part[1]:
LOG.warning(f"role with empty message: {part[0]}")
yield part

def __repr__(self) -> str:
turns = self._build_result([{"from": "{from}", "value": "{value}"}])
return "\n".join([REPR_TEMPLATE.format(full_prompt=part) for part in turns])


class ShareGPTPrompterV2(ShareGPTPrompter):
"""
Expand All @@ -304,3 +347,15 @@ def __init__(
role_key_human=role_key_human,
role_key_model=role_key_model,
)


class UnsupportedPrompter:
"""
A dummy class for custom prompters
"""

def __init__(self) -> None:
pass

def __repr__(self):
return "Pre-tokenized or custom dataset types are unsupported for logging"
Loading