Commit

precommit
Quentin-Anthony committed Oct 8, 2024
1 parent 1d33d06 commit e1b2b7c
Showing 1 changed file with 8 additions and 2 deletions.
megatron/training.py: 10 changes (8 additions, 2 deletions)
@@ -406,7 +406,9 @@ def get_batch(neox_args, data_iterator):
             datatype=datatype,
         )
     elif neox_args.train_impl == "kto":
-        assert neox_args.train_micro_batch_size_per_gpu > 1, "For KTO training, the train_micro_batch_size_per_gpu must be greater than 1."
+        assert (
+            neox_args.train_micro_batch_size_per_gpu > 1
+        ), "For KTO training, the train_micro_batch_size_per_gpu must be greater than 1."
         tup = _get_batch(
             neox_args=neox_args,
             tokenizer=neox_args.tokenizer,
@@ -461,7 +463,11 @@ def get_batch(neox_args, data_iterator):
 def get_batch_pipe(data, neox_args, curr_scheduler=None):
     """A modification of get_batch() to work with the latest batch instead of an iterator."""
 
-    assert neox_args.train_impl not in ["kto", "dpo", "rm"], "Pipeline parallel is currently unsupported when using any of kto, dpo, rm. Set pipe_parallel_size to 0"
+    assert neox_args.train_impl not in [
+        "kto",
+        "dpo",
+        "rm",
+    ], "Pipeline parallel is currently unsupported when using any of kto, dpo, rm. Set pipe_parallel_size to 0"
 
     # Items and their type.
     keys = ["text", "label"] if neox_args.train_label_data_paths else ["text"]
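Both hunks are formatting-only: going by the commit title, a pre-commit hook (presumably black or a similar line-length formatter; the exact hook is an assumption) wraps each over-long assert's condition in parentheses and splits the statement across lines. Conditions and error messages are unchanged. A minimal sketch of the before/after pattern, using a hypothetical value variable:

value = 2  # placeholder so the sketch runs standalone

# Before the hook runs: a single assert longer than the formatter's line limit.
assert value > 1, "An explanatory message long enough to push the statement past the line limit."

# After the hook runs: the condition is parenthesized and split across lines,
# and the message follows the closing parenthesis. Behavior is identical.
assert (
    value > 1
), "An explanatory message long enough to push the statement past the line limit."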
