hotfix for tp >= 2 and pp > 2 in autoitercount (#1296)
* hotfix

* precommit

---------

Co-authored-by: Quentin Anthony <[email protected]>
AI-WAIFU and Quentin-Anthony authored Oct 1, 2024
1 parent c1105de commit 774eb58
Showing 1 changed file with 20 additions and 9 deletions.
29 changes: 20 additions & 9 deletions megatron/training.py
@@ -183,24 +183,35 @@ def update_iterations(neox_args, data_loaders):
     to do as many iterations as possible while ensuring that each example is seen *at most* train_epochs
     times.
     """
-    if neox_args.train_iters is not None:
+    if (not neox_args.do_train) or (neox_args.train_iters is not None):
         pass
     elif neox_args.train_iters is None and neox_args.train_epochs is None:
         print_rank_0(
             "ERROR:Failed to specify either train_epochs or train_iters in config file"
         )
     else:
-        train_dataloader = data_loaders["train"]
-        train_epochs = neox_args.train_epochs
-        gradient_accumulation_steps = neox_args.gradient_accumulation_steps
+        global_rank = torch.distributed.get_rank()
 
-        train_iterations = (
-            len(train_dataloader) * train_epochs
-        ) // gradient_accumulation_steps
+        if global_rank == 0:
+            train_dataloader = data_loaders["train"]
+            train_epochs = neox_args.train_epochs
+            gradient_accumulation_steps = neox_args.gradient_accumulation_steps
+
+            train_dataloader_len = len(train_dataloader)
+            train_iterations = (
+                train_dataloader_len * train_epochs
+            ) // gradient_accumulation_steps
+
+            train_iters_tensor = torch.cuda.LongTensor([train_iterations])
+        else:
+            train_iters_tensor = torch.cuda.LongTensor([0])
+
+        torch.distributed.broadcast(train_iters_tensor, src=0)
+
+        neox_args.train_iters = train_iters_tensor[0].item()
 
-        neox_args.train_iters = train_iterations
         print_rank_0(
-            f"Training for a total of {train_iterations} iterations, corresponding to {train_epochs} epochs."
+            f"Training for a total of {neox_args.train_iters} iterations, corresponding to {neox_args.train_epochs} epochs."
         )
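
For context, the change moves the iteration-count computation behind a rank check and then broadcasts the result, so every rank ends up with the same neox_args.train_iters even when, under tensor/pipeline parallelism, only some ranks hold a usable train dataloader. Below is a minimal standalone sketch of that compute-on-rank-0-then-broadcast pattern; it assumes an already-initialized torch.distributed process group with a CUDA backend, and the helper name and arguments are illustrative rather than taken from the repository.

# Sketch only: assumes dist.init_process_group() has already run (e.g. with NCCL)
# and that each rank has selected its CUDA device. Names here are illustrative.
import torch
import torch.distributed as dist


def broadcast_train_iters(train_dataloader, train_epochs, gradient_accumulation_steps):
    """Compute the train iteration count on global rank 0 and share it with all ranks."""
    if dist.get_rank() == 0:
        # Only rank 0 is assumed to hold a real dataloader, so only it derives
        # the count from len(dataloader).
        train_iterations = (
            len(train_dataloader) * train_epochs
        ) // gradient_accumulation_steps
        iters_tensor = torch.tensor([train_iterations], dtype=torch.long, device="cuda")
    else:
        # Other ranks allocate a placeholder that broadcast() overwrites in place.
        iters_tensor = torch.tensor([0], dtype=torch.long, device="cuda")

    # Collective call: every rank must reach this line, after which all ranks
    # hold rank 0's value and agree on how long the training loop runs.
    dist.broadcast(iters_tensor, src=0)
    return iters_tensor[0].item()

The design point is that broadcast is a collective, so the non-rank-0 branch still has to allocate a tensor of matching shape and dtype and reach the call; skipping it on any rank would hang the job.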

