From e1b2b7caa2d578cfdd2f2dd9e36975a73acfc599 Mon Sep 17 00:00:00 2001
From: Quentin Anthony
Date: Tue, 8 Oct 2024 12:25:29 -0700
Subject: [PATCH] precommit

---
 megatron/training.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/megatron/training.py b/megatron/training.py
index ee5a339e4..277f127c3 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -406,7 +406,9 @@ def get_batch(neox_args, data_iterator):
             datatype=datatype,
         )
     elif neox_args.train_impl == "kto":
-        assert neox_args.train_micro_batch_size_per_gpu > 1, "For KTO training, the train_micro_batch_size_per_gpu must be greater than 1."
+        assert (
+            neox_args.train_micro_batch_size_per_gpu > 1
+        ), "For KTO training, the train_micro_batch_size_per_gpu must be greater than 1."
         tup = _get_batch(
             neox_args=neox_args,
             tokenizer=neox_args.tokenizer,
@@ -461,7 +463,11 @@ def get_batch(neox_args, data_iterator):
 def get_batch_pipe(data, neox_args, curr_scheduler=None):
     """A modification of get_batch() to work with the latest batch instead of an iterator."""
-    assert neox_args.train_impl not in ["kto", "dpo", "rm"], "Pipeline parallel is currently unsupported when using any of kto, dpo, rm. Set pipe_parallel_size to 0"
+    assert neox_args.train_impl not in [
+        "kto",
+        "dpo",
+        "rm",
+    ], "Pipeline parallel is currently unsupported when using any of kto, dpo, rm. Set pipe_parallel_size to 0"
     # Items and their type.
     keys = ["text", "label"] if neox_args.train_label_data_paths else ["text"]
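
For reference, the two guards reformatted by this patch can be read as the standalone sketch below. `NeoXArgsStub` and `check_post_training_args` are hypothetical names used only for illustration; the real checks live on GPT-NeoX's `neox_args` inside `get_batch()` and `get_batch_pipe()` in megatron/training.py, and the gating on `pipe_parallel_size` here is an assumption inferred from the second assert's message.

```python
from dataclasses import dataclass


@dataclass
class NeoXArgsStub:
    # Hypothetical stand-in for the relevant fields of GPT-NeoX's neox_args.
    train_impl: str = "kto"
    train_micro_batch_size_per_gpu: int = 2
    pipe_parallel_size: int = 0


def check_post_training_args(args: NeoXArgsStub) -> None:
    """Mirror the two asserts touched by this patch (illustrative only)."""
    if args.train_impl == "kto":
        # A micro-batch of 1 is rejected for KTO training.
        assert (
            args.train_micro_batch_size_per_gpu > 1
        ), "For KTO training, the train_micro_batch_size_per_gpu must be greater than 1."
    if args.pipe_parallel_size > 0:
        # get_batch_pipe() is the pipeline-parallel batch path; kto/dpo/rm
        # are not supported there, per the assert's message.
        assert args.train_impl not in [
            "kto",
            "dpo",
            "rm",
        ], "Pipeline parallel is currently unsupported when using any of kto, dpo, rm. Set pipe_parallel_size to 0"


if __name__ == "__main__":
    # Passes with the defaults above: micro-batch of 2, no pipeline parallelism.
    check_post_training_args(NeoXArgsStub())
```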