Add additional asserts and update post training readme #1300

Merged · 2 commits · Oct 8, 2024
10 changes: 10 additions & 0 deletions megatron/training.py
```diff
@@ -406,6 +406,9 @@ def get_batch(neox_args, data_iterator):
             datatype=datatype,
         )
     elif neox_args.train_impl == "kto":
+        assert (
+            neox_args.train_micro_batch_size_per_gpu > 1
+        ), "For KTO training, the train_micro_batch_size_per_gpu must be greater than 1."
         tup = _get_batch(
             neox_args=neox_args,
             tokenizer=neox_args.tokenizer,
```
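This check is not arbitrary: KTO's loss includes a KL reference term that is typically estimated from mismatched prompt/completion pairs drawn from the same micro-batch, so a micro-batch of one leaves nothing to estimate it from. A minimal pre-flight sketch of the same constraint, usable when validating configs before launch (the helper is hypothetical; only the attribute names come from the diff):

```python
# Hypothetical pre-flight check mirroring the new assert. `args` stands in
# for any object exposing the NeoX settings named in the diff
# (train_impl, train_micro_batch_size_per_gpu).
def check_kto_batch_size(args) -> None:
    if getattr(args, "train_impl", "normal") == "kto":
        if args.train_micro_batch_size_per_gpu <= 1:
            raise ValueError(
                "KTO estimates its KL term from other samples in the same "
                "micro-batch, so train_micro_batch_size_per_gpu must be "
                "greater than 1."
            )
```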
```diff
@@ -459,6 +462,13 @@ def get_batch(neox_args, data_iterator):
 
 def get_batch_pipe(data, neox_args, curr_scheduler=None):
     """A modification of get_batch() to work with the latest batch instead of an iterator."""
+
+    assert neox_args.train_impl not in [
+        "kto",
+        "dpo",
+        "rm",
+    ], "Pipeline parallel is currently unsupported when using any of kto, dpo, rm. Set pipe_parallel_size to 0"
+
     # Items and their type.
     keys = ["text", "label"] if neox_args.train_label_data_paths else ["text"]
     datatype = torch.int64
```
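The second assert turns an unsupported combination, pipeline parallelism with any of the post-training objectives, into an explicit error rather than a less obvious failure later. The same rule can be checked up front when composing configs; a sketch, with only the setting names (`train_impl`, `pipe_parallel_size`) taken from the assertion itself:

```python
from types import SimpleNamespace

# Combinations ruled out by the new assert in get_batch_pipe(); the helper
# and the example config below are illustrative, not part of the repo.
UNSUPPORTED_WITH_PIPELINE = {"kto", "dpo", "rm"}

def needs_pipe_parallel_zero(train_impl: str, pipe_parallel_size: int) -> bool:
    """True when this combination would trip the new assertion."""
    return train_impl in UNSUPPORTED_WITH_PIPELINE and pipe_parallel_size > 0

# Example: a DPO run configured with pipeline parallelism should be rejected.
cfg = SimpleNamespace(train_impl="dpo", pipe_parallel_size=1)
assert needs_pipe_parallel_zero(cfg.train_impl, cfg.pipe_parallel_size)
```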
2 changes: 0 additions & 2 deletions post-training/README.md
````diff
@@ -34,15 +34,13 @@ python tools/datasets/preprocess_data_with_chat_template.py --input data/pairwis
 
 ## SFT data
 ```bash
-python post-training/llama_dpo_data.py
 python tools/datasets/preprocess_data_with_chat_template.py --input data/sft/llama3_sft_train_filtered.jsonl --output-prefix data/sft/llama3_train --tokenizer-path checkpoints/neox_converted/llama3-8b-instruct/tokenizer --jsonl-keys messages
 python tools/datasets/preprocess_data_with_chat_template.py --input data/sft/llama3_sft_test_filtered.jsonl --output-prefix data/sft/llama3_test --tokenizer-path checkpoints/neox_converted/llama3-8b-instruct/tokenizer --jsonl-keys messages
 python tools/datasets/preprocess_data_with_chat_template.py --input data/sft/llama3_sft_train_filtered.jsonl --output-prefix data/sft/llama3_val --tokenizer-path checkpoints/neox_converted/llama3-8b-instruct/tokenizer --jsonl-keys messages
 ```
 
 ## KTO data
 ```bash
-python post-training/llama_dpo_data.py
 python tools/datasets/preprocess_data_with_chat_template.py --input data/kto/llama3_sft_train_filtered.jsonl --output-prefix data/kto/llama3_train --tokenizer-path checkpoints/neox_converted/llama3-8b-instruct/tokenizer --jsonl-keys messages --reward-key reward
 python tools/datasets/preprocess_data_with_chat_template.py --input data/kto/llama3_sft_test_filtered.jsonl --output-prefix data/kto/llama3_test --tokenizer-path checkpoints/neox_converted/llama3-8b-instruct/tokenizer --jsonl-keys messages --reward-key reward
 python tools/datasets/preprocess_data_with_chat_template.py --input data/kto/llama3_sft_train_filtered.jsonl --output-prefix data/kto/llama3_val --tokenizer-path checkpoints/neox_converted/llama3-8b-instruct/tokenizer --jsonl-keys messages --reward-key reward
````
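The KTO commands differ from the SFT ones only in the added `--reward-key reward` flag, which tells the preprocessor to read a per-example reward from each record. As a rough illustration of the record shape those flags imply (hypothetical example; the authoritative schema is whatever `preprocess_data_with_chat_template.py` accepts):

```python
import json

# Illustrative KTO record (hypothetical contents): `messages` matches
# --jsonl-keys messages and `reward` matches --reward-key reward.
record = {
    "messages": [
        {"role": "user", "content": "What is 2 + 2?"},
        {"role": "assistant", "content": "4"},
    ],
    "reward": 1.0,  # per-example desirability signal consumed by KTO training
}
print(json.dumps(record))  # one line of the input .jsonl
```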