diff --git a/megatron/training.py b/megatron/training.py index ee5a339e4..277f127c3 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -406,7 +406,9 @@ def get_batch(neox_args, data_iterator): datatype=datatype, ) elif neox_args.train_impl == "kto": - assert neox_args.train_micro_batch_size_per_gpu > 1, "For KTO training, the train_micro_batch_size_per_gpu must be greater than 1." + assert ( + neox_args.train_micro_batch_size_per_gpu > 1 + ), "For KTO training, the train_micro_batch_size_per_gpu must be greater than 1." tup = _get_batch( neox_args=neox_args, tokenizer=neox_args.tokenizer, @@ -461,7 +463,11 @@ def get_batch(neox_args, data_iterator): def get_batch_pipe(data, neox_args, curr_scheduler=None): """A modification of get_batch() to work with the latest batch instead of an iterator.""" - assert neox_args.train_impl not in ["kto", "dpo", "rm"], "Pipeline parallel is currently unsupported when using any of kto, dpo, rm. Set pipe_parallel_size to 0" + assert neox_args.train_impl not in [ + "kto", + "dpo", + "rm", + ], "Pipeline parallel is currently unsupported when using any of kto, dpo, rm. Set pipe_parallel_size to 0" # Items and their type. keys = ["text", "label"] if neox_args.train_label_data_paths else ["text"]