Skip to content

Commit

Permalink
add asserts and fix post training readme
Browse files Browse the repository at this point in the history
  • Loading branch information
AI-WAIFU committed Oct 7, 2024
1 parent 774eb58 commit 1d33d06
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
4 changes: 4 additions & 0 deletions megatron/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,7 @@ def get_batch(neox_args, data_iterator):
datatype=datatype,
)
elif neox_args.train_impl == "kto":
assert neox_args.train_micro_batch_size_per_gpu > 1, "For KTO training, the train_micro_batch_size_per_gpu must be greater than 1."
tup = _get_batch(
neox_args=neox_args,
tokenizer=neox_args.tokenizer,
Expand Down Expand Up @@ -459,6 +460,9 @@ def get_batch(neox_args, data_iterator):

def get_batch_pipe(data, neox_args, curr_scheduler=None):
"""A modification of get_batch() to work with the latest batch instead of an iterator."""

assert neox_args.train_impl not in ["kto", "dpo", "rm"], "Pipeline parallel is currently unsupported when using any of kto, dpo, rm. Set pipe_parallel_size to 0"

# Items and their type.
keys = ["text", "label"] if neox_args.train_label_data_paths else ["text"]
datatype = torch.int64
Expand Down
2 changes: 0 additions & 2 deletions post-training/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,13 @@ python tools/datasets/preprocess_data_with_chat_template.py --input data/pairwis

## SFT data
```bash
python post-training/llama_dpo_data.py
python tools/datasets/preprocess_data_with_chat_template.py --input data/sft/llama3_sft_train_filtered.jsonl --output-prefix data/sft/llama3_train --tokenizer-path checkpoints/neox_converted/llama3-8b-instruct/tokenizer --jsonl-keys messages
python tools/datasets/preprocess_data_with_chat_template.py --input data/sft/llama3_sft_test_filtered.jsonl --output-prefix data/sft/llama3_test --tokenizer-path checkpoints/neox_converted/llama3-8b-instruct/tokenizer --jsonl-keys messages
python tools/datasets/preprocess_data_with_chat_template.py --input data/sft/llama3_sft_train_filtered.jsonl --output-prefix data/sft/llama3_val --tokenizer-path checkpoints/neox_converted/llama3-8b-instruct/tokenizer --jsonl-keys messages
```

## KTO data
```bash
python post-training/llama_dpo_data.py
python tools/datasets/preprocess_data_with_chat_template.py --input data/kto/llama3_sft_train_filtered.jsonl --output-prefix data/kto/llama3_train --tokenizer-path checkpoints/neox_converted/llama3-8b-instruct/tokenizer --jsonl-keys messages --reward-key reward
python tools/datasets/preprocess_data_with_chat_template.py --input data/kto/llama3_sft_test_filtered.jsonl --output-prefix data/kto/llama3_test --tokenizer-path checkpoints/neox_converted/llama3-8b-instruct/tokenizer --jsonl-keys messages --reward-key reward
python tools/datasets/preprocess_data_with_chat_template.py --input data/kto/llama3_sft_train_filtered.jsonl --output-prefix data/kto/llama3_val --tokenizer-path checkpoints/neox_converted/llama3-8b-instruct/tokenizer --jsonl-keys messages --reward-key reward
Expand Down

0 comments on commit 1d33d06

Please sign in to comment.