move process_data arg into TrainingArgs
Signed-off-by: Michael Clifford <[email protected]>
MichaelClifford committed Oct 5, 2024
1 parent aefde0e commit 9a7986a
Showing 4 changed files with 28 additions and 12 deletions.
23 changes: 21 additions & 2 deletions README.md
@@ -282,8 +282,9 @@ run_training(
)

```

## Example training with separate data pre-processing

If the machines in the example above have shared storage, users can pre-process the training dataset once and then distribute it to each machine by making the following updates:

```python
from instructlab.training import (
    run_training,
    TorchrunArgs,
    TrainingArgs,
    DataProcessArgs,
    data_process as dp,
)

training_args = TrainingArgs(
# define data-specific arguments
model_path = "ibm-granite/granite-7b-base",
data_path = "path/to/dataset.jsonl",
ckpt_output_dir = "data/saved_checkpoints",
data_output_dir = "data/outputs",

# define model-training parameters
max_seq_len = 4096,
max_batch_len = 60000,
num_epochs = 10,
effective_batch_size = 3840,
save_samples = 250000,
learning_rate = 2e-6,
warmup_steps = 800,
is_padding_free = True, # set this to true when using Granite-based models
random_seed = 42,
process_data = False, # the dataset is pre-processed separately via dp.main() below, so skip it inside run_training()
)
...

data_process_args = DataProcessArgs(
    ...  # data-processing arguments omitted here
)

dp.main(data_process_args)
run_training(
torch_args=torchrun_args,
train_args=training_args,
)
```
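For existing callers migrating across this commit, the call-site change looks like the sketch below (reusing `torchrun_args` and `training_args` from the example above; the commented-out line shows the old keyword argument, which `run_training()` no longer accepts):

```python
# before this commit: pre-processing was toggled per call
# run_training(torch_args=torchrun_args, train_args=training_args, process_data=False)

# after this commit: the flag is carried by TrainingArgs (see process_data above)
run_training(torch_args=torchrun_args, train_args=training_args)
```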
8 changes: 2 additions & 6 deletions src/instructlab/training/__init__.py
@@ -28,13 +28,9 @@


# defer import of main_ds
def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
"""Wrapper around the main training job that calls torchrun."""
# Local
from .main_ds import run_training

return run_training(torch_args=torch_args, train_args=train_args)
3 changes: 3 additions & 0 deletions src/instructlab/training/config.py
@@ -199,3 +199,6 @@ class TrainingArgs(BaseModel):
# https://github.com/instructlab/training/issues/28
# quantize_dtype: QuantizeDataType = QuantizeDataType.NONE
lora: LoraOptions | None = None

# This field defines whether or not data processing will occur inside of `run_training()`
process_data: Optional[bool] = True
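Since the field defaults to `True`, configs that never mention it keep the old always-pre-process behavior. A self-contained sketch of the same Pydantic pattern on a toy model (not the real `TrainingArgs` class):

```python
from typing import Optional

from pydantic import BaseModel

class ToyArgs(BaseModel):
    # mirrors the new TrainingArgs field: defaults to True so existing
    # configs keep pre-processing unless they explicitly opt out
    process_data: Optional[bool] = True

print(ToyArgs().process_data)                    # True (default)
print(ToyArgs(process_data=False).process_data)  # False (caller opts out)
```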
6 changes: 2 additions & 4 deletions src/instructlab/training/main_ds.py
@@ -635,9 +635,7 @@ def main(args):


# public API
def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
"""
Wrapper around the main training job that calls torchrun.
"""
@@ -647,7 +645,7 @@ def run_training(
f"the `max_batch_len` cannot be less than `max_seq_len`: {train_args.max_batch_len=} < {train_args.max_seq_len=}"
)

if train_args.process_data:
dp.main(
DataProcessArgs(
# XXX(osilkin): make a decision here, either:
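The guard now reads the flag from the config object rather than a loose parameter. A minimal, self-contained sketch of the resulting control flow (toy types and print statements standing in for `dp.main()` and torchrun):

```python
from dataclasses import dataclass

@dataclass
class ToyTrainArgs:
    process_data: bool = True

def toy_run_training(train_args: ToyTrainArgs) -> None:
    if train_args.process_data:
        print("pre-processing dataset")  # dp.main(DataProcessArgs(...)) in the real module
    print("launching training via torchrun")

toy_run_training(ToyTrainArgs())                    # pre-processes, then trains
toy_run_training(ToyTrainArgs(process_data=False))  # skips straight to training
```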
