Add additional asserts and update post training readme (#1300)
* add asserts and fix post training readme

* precommit

---------

Co-authored-by: Quentin Anthony <[email protected]>
2 people authored and jahatef committed Oct 31, 2024
1 parent a418670 commit 540d856
Showing 2 changed files with 10 additions and 2 deletions.
10 changes: 10 additions & 0 deletions megatron/training.py
@@ -406,6 +406,9 @@ def get_batch(neox_args, data_iterator):
datatype=datatype,
)
elif neox_args.train_impl == "kto":
assert (
neox_args.train_micro_batch_size_per_gpu > 1
), "For KTO training, the train_micro_batch_size_per_gpu must be greater than 1."
tup = _get_batch(
neox_args=neox_args,
tokenizer=neox_args.tokenizer,
@@ -459,6 +462,13 @@ def get_batch(neox_args, data_iterator):

def get_batch_pipe(data, neox_args, curr_scheduler=None):
"""A modification of get_batch() to work with the latest batch instead of an iterator."""

assert neox_args.train_impl not in [
"kto",
"dpo",
"rm",
], "Pipeline parallel is currently unsupported when using any of kto, dpo, rm. Set pipe_parallel_size to 0"

# Items and their type.
keys = ["text", "label"] if neox_args.train_label_data_paths else ["text"]
datatype = torch.int64
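The two guards added above can be exercised in isolation. Below is a minimal sketch; the `Args` dataclass is a hypothetical stand-in for the fields of GPT-NeoX's real `NeoXArgs` object, and the check functions simply mirror the asserts in the diff:

```python
from dataclasses import dataclass


@dataclass
class Args:
    # Hypothetical stand-in for the relevant NeoXArgs fields.
    train_impl: str = "normal"
    train_micro_batch_size_per_gpu: int = 1


def check_kto_batch(neox_args):
    # Mirrors the assert added to get_batch(): per the assertion message,
    # KTO training needs more than one example per micro-batch.
    assert (
        neox_args.train_micro_batch_size_per_gpu > 1
    ), "For KTO training, the train_micro_batch_size_per_gpu must be greater than 1."


def check_pipeline(neox_args):
    # Mirrors the assert added to get_batch_pipe(): the pipeline-parallel
    # batch path does not support the post-training implementations.
    assert neox_args.train_impl not in [
        "kto",
        "dpo",
        "rm",
    ], "Pipeline parallel is currently unsupported when using any of kto, dpo, rm. Set pipe_parallel_size to 0"
```

Because the checks run at the top of batch construction, a misconfigured run fails immediately with an actionable message instead of crashing later inside the training loop.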
2 changes: 0 additions & 2 deletions post-training/README.md
@@ -34,15 +34,13 @@ python tools/datasets/preprocess_data_with_chat_template.py --input data/pairwis

## SFT data
```bash
python post-training/llama_dpo_data.py
python tools/datasets/preprocess_data_with_chat_template.py --input data/sft/llama3_sft_train_filtered.jsonl --output-prefix data/sft/llama3_train --tokenizer-path checkpoints/neox_converted/llama3-8b-instruct/tokenizer --jsonl-keys messages
python tools/datasets/preprocess_data_with_chat_template.py --input data/sft/llama3_sft_test_filtered.jsonl --output-prefix data/sft/llama3_test --tokenizer-path checkpoints/neox_converted/llama3-8b-instruct/tokenizer --jsonl-keys messages
python tools/datasets/preprocess_data_with_chat_template.py --input data/sft/llama3_sft_train_filtered.jsonl --output-prefix data/sft/llama3_val --tokenizer-path checkpoints/neox_converted/llama3-8b-instruct/tokenizer --jsonl-keys messages
```

## KTO data
```bash
python post-training/llama_dpo_data.py
python tools/datasets/preprocess_data_with_chat_template.py --input data/kto/llama3_sft_train_filtered.jsonl --output-prefix data/kto/llama3_train --tokenizer-path checkpoints/neox_converted/llama3-8b-instruct/tokenizer --jsonl-keys messages --reward-key reward
python tools/datasets/preprocess_data_with_chat_template.py --input data/kto/llama3_sft_test_filtered.jsonl --output-prefix data/kto/llama3_test --tokenizer-path checkpoints/neox_converted/llama3-8b-instruct/tokenizer --jsonl-keys messages --reward-key reward
python tools/datasets/preprocess_data_with_chat_template.py --input data/kto/llama3_sft_train_filtered.jsonl --output-prefix data/kto/llama3_val --tokenizer-path checkpoints/neox_converted/llama3-8b-instruct/tokenizer --jsonl-keys messages --reward-key reward
