From 540d85658c64cf2f06c23c794a22faa99043a0d3 Mon Sep 17 00:00:00 2001
From: AI-WAIFU <67525070+AI-WAIFU@users.noreply.github.com>
Date: Tue, 8 Oct 2024 20:25:59 +0100
Subject: [PATCH] Add additional asserts and update post training readme (#1300)

* add asserts and fix post training readme

* precommit

---------

Co-authored-by: Quentin Anthony
---
 megatron/training.py    | 10 ++++++++++
 post-training/README.md |  2 --
 2 files changed, 10 insertions(+), 2 deletions(-)

diff --git a/megatron/training.py b/megatron/training.py
index 5976ae6a7..277f127c3 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -406,6 +406,9 @@ def get_batch(neox_args, data_iterator):
             datatype=datatype,
         )
     elif neox_args.train_impl == "kto":
+        assert (
+            neox_args.train_micro_batch_size_per_gpu > 1
+        ), "For KTO training, the train_micro_batch_size_per_gpu must be greater than 1."
         tup = _get_batch(
             neox_args=neox_args,
             tokenizer=neox_args.tokenizer,
@@ -459,6 +462,13 @@ def get_batch(neox_args, data_iterator):
 
 def get_batch_pipe(data, neox_args, curr_scheduler=None):
     """A modification of get_batch() to work with the latest batch instead of an iterator."""
+
+    assert neox_args.train_impl not in [
+        "kto",
+        "dpo",
+        "rm",
+    ], "Pipeline parallel is currently unsupported when using any of kto, dpo, rm. Set pipe_parallel_size to 0"
+
     # Items and their type.
     keys = ["text", "label"] if neox_args.train_label_data_paths else ["text"]
     datatype = torch.int64
diff --git a/post-training/README.md b/post-training/README.md
index 1ba5cde2f..930ad0e31 100644
--- a/post-training/README.md
+++ b/post-training/README.md
@@ -34,7 +34,6 @@ python tools/datasets/preprocess_data_with_chat_template.py --input data/pairwis
 
 ## SFT data
 ```bash
-python post-training/llama_dpo_data.py
 python tools/datasets/preprocess_data_with_chat_template.py --input data/sft/llama3_sft_train_filtered.jsonl --output-prefix data/sft/llama3_train --tokenizer-path checkpoints/neox_converted/llama3-8b-instruct/tokenizer --jsonl-keys messages
 python tools/datasets/preprocess_data_with_chat_template.py --input data/sft/llama3_sft_test_filtered.jsonl --output-prefix data/sft/llama3_test --tokenizer-path checkpoints/neox_converted/llama3-8b-instruct/tokenizer --jsonl-keys messages
 python tools/datasets/preprocess_data_with_chat_template.py --input data/sft/llama3_sft_train_filtered.jsonl --output-prefix data/sft/llama3_val --tokenizer-path checkpoints/neox_converted/llama3-8b-instruct/tokenizer --jsonl-keys messages
@@ -42,7 +41,6 @@ python tools/datasets/preprocess_data_with_chat_template.py --input data/sft/lla
 
 ## KTO data
 ```bash
-python post-training/llama_dpo_data.py
 python tools/datasets/preprocess_data_with_chat_template.py --input data/kto/llama3_sft_train_filtered.jsonl --output-prefix data/kto/llama3_train --tokenizer-path checkpoints/neox_converted/llama3-8b-instruct/tokenizer --jsonl-keys messages --reward-key reward
 python tools/datasets/preprocess_data_with_chat_template.py --input data/kto/llama3_sft_test_filtered.jsonl --output-prefix data/kto/llama3_test --tokenizer-path checkpoints/neox_converted/llama3-8b-instruct/tokenizer --jsonl-keys messages --reward-key reward
 python tools/datasets/preprocess_data_with_chat_template.py --input data/kto/llama3_sft_train_filtered.jsonl --output-prefix data/kto/llama3_val --tokenizer-path checkpoints/neox_converted/llama3-8b-instruct/tokenizer --jsonl-keys messages --reward-key reward