
Commit

Fixes
bclyang committed Oct 1, 2024
1 parent 15c0ebb commit ad97aaa
Showing 6 changed files with 27 additions and 35 deletions.
4 changes: 2 additions & 2 deletions megatron/data/data_utils.py
@@ -531,7 +531,7 @@ def build_train_valid_test_data_loaders(neox_args):
else:
pipe_load = True

- # Data loader only on rank 0 of each model/sequence parallel group.
+ # Data loader only on rank 0 of each model and context parallel group.
if (
mpu.get_model_parallel_rank() == 0
and pipe_load
@@ -675,7 +675,7 @@ def build_train_valid_test_data_loaders(neox_args):
# broadcast globally instead of just the model parallel group.
torch.distributed.broadcast(flags, src=0)
else:
- # The same data should be used for the model parallel and sequence parallel groups
+ # The same data should be used for the model parallel and context parallel groups
torch.distributed.broadcast(
flags,
mpu.get_model_parallel_src_rank(),
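For readers skimming the hunk above: the surrounding code builds the data loaders only on the source rank of each parallel group, then broadcasts a small flags tensor so the other ranks agree on what was built. A minimal sketch of that pattern follows; it assumes the broadcast runs over the model parallel group, whereas the folded lines in the real function may use a different (e.g. combined model/context) group.

    import torch
    from megatron import mpu

    # Hypothetical illustration of the gate-and-broadcast pattern, not the real loader code.
    flags = torch.zeros(3, dtype=torch.long, device="cuda")
    if mpu.get_model_parallel_rank() == 0:
        # Only the source rank of each group builds the train/valid/test loaders ...
        flags = torch.tensor([1, 1, 1], dtype=torch.long, device="cuda")
    # ... and broadcasts the resulting flags so every rank in the group stays in sync.
    torch.distributed.broadcast(
        flags,
        mpu.get_model_parallel_src_rank(),
        group=mpu.get_model_parallel_group(),
    )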
16 changes: 7 additions & 9 deletions megatron/initialize.py
@@ -158,21 +158,19 @@ def _initialize_distributed(neox_args):
# Setup 3D topology.
pp = neox_args.pipe_parallel_size if neox_args.pipe_parallel_size >= 1 else 1
mp = neox_args.model_parallel_size if neox_args.model_parallel_size >= 1 else 1
- sp = neox_args.context_parallel_size if neox_args.context_parallel_size >= 1 else 1
- # assert (
- #     neox_args.world_size % (pp * mp * sp) == 0
- # ), f"world_size={neox_args.world_size}, pp={pp}, mp={mp}, sp={sp}"
+ cp = neox_args.context_parallel_size if neox_args.context_parallel_size >= 1 else 1
+ assert (
+     neox_args.world_size % (pp * mp * cp) == 0
+ ), f"world_size={neox_args.world_size}, pp={pp}, mp={mp}, cp={cp}"
assert (
neox_args.world_size % (pp * mp) == 0
), f"world_size={neox_args.world_size}, pp={pp}, mp={mp}"
- # dp = neox_args.world_size // (pp * mp * sp)
+ # The data parallel ranks will be used for context parallel
+ # to piggy back the gradient all reduce
dp = neox_args.world_size // (pp * mp)
- assert dp % sp == 0  # The data parallel ranks will be used for sequence parallel
+ assert dp % cp == 0
from deepspeed.runtime.pipe.topology import ProcessTopology

- # With 4D parallelism, we have 4 dimensions: pipe, data, model, sequence
- # So we need to define it manually...
- # topo = ProcessTopology(axes=["pipe", "data", "model", "seq"], dims=[pp, dp, mp, sp])
topo = ProcessTopology(axes=["pipe", "data", "model"], dims=[pp, dp, mp])

# Offset base seeds for the interior pipeline stages.
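The new sizing logic above can be summarized with a small standalone sketch. The concrete numbers (a 16-GPU world with pp = mp = cp = 2) are illustrative assumptions; everything else mirrors the hunk.

    from deepspeed.runtime.pipe.topology import ProcessTopology

    pp, mp, cp = 2, 2, 2   # pipe-, model-, and context-parallel sizes (assumed for this example)
    world_size = 16        # assumed cluster size

    # The world must factor into pp * mp * cp (and therefore also into pp * mp).
    assert world_size % (pp * mp * cp) == 0, f"world_size={world_size}, pp={pp}, mp={mp}, cp={cp}"
    assert world_size % (pp * mp) == 0

    # Context-parallel ranks ride inside the data-parallel dimension so they can
    # piggyback on the data-parallel gradient all-reduce; hence dp must split evenly into cp.
    dp = world_size // (pp * mp)   # -> 4
    assert dp % cp == 0

    # The 3D topology therefore only declares pipe/data/model axes.
    topo = ProcessTopology(axes=["pipe", "data", "model"], dims=[pp, dp, mp])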
3 changes: 0 additions & 3 deletions megatron/model/gpt2_model.py
@@ -97,9 +97,6 @@ def cross_entropy(output, labels, _fp16=False):
if dt == torch.bfloat16 and mpu.initialize.get_fp32_allreduce():
loss = loss.bfloat16()
else:
- # torch.distributed.barrier()
- # if torch.distributed.get_rank() == 0:
- # import pdb; pdb.set_trace()
loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask_sum
return loss

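The line kept by this hunk, loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask_sum, averages the per-token losses over only the unmasked tokens. A toy example with made-up numbers:

    import torch

    losses = torch.tensor([[2.0, 1.0], [4.0, 3.0]])   # per-token cross-entropy losses
    loss_mask = torch.tensor([1.0, 1.0, 0.0, 1.0])    # 0 drops padded / ignored positions
    loss_mask_sum = loss_mask.sum()                   # 3 tokens actually count

    loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask_sum
    print(loss)  # (2 + 1 + 3) / 3 = 2.0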
11 changes: 10 additions & 1 deletion megatron/mpu/initialize.py
@@ -63,7 +63,16 @@ def initialize_model_parallel(
Initialize model data parallel groups.
Arguments:
- model_parallel_size: number of GPUs used to parallelize model.
+ model_parallel_size: number of GPUs used for model parallelism.
+ pipe_parallel_size: number of GPUs used for pipeline parallelism.
+ context_parallel_size: number of GPUs used for context parallelism.
+ topology: topology if it exists.
+ fp32_allreduce: whether or not to do all reduce in fp32.
+ Adjacent ranks are ordered by model parallel, then context parallel,
+ then data parallel. Context parallelism duplicates weights among GPUs in
+ a context parallel group, so we piggy back on the data parallel group
+ for the gradient all-reduce.
Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
use 2 GPUs to parallelize the model. The present function will
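The grouping described by the expanded docstring is easiest to see with explicit rank numbers. The sketch below is purely illustrative: it assumes no pipeline parallelism, 8 GPUs, mp = 2, cp = 2 (so dp = 2), and the rank layout rank = d * (cp * mp) + c * mp + m implied by "model parallel, then context parallel, then data parallel".

    mp, cp, dp = 2, 2, 2
    rank = lambda d, c, m: d * (cp * mp) + c * mp + m

    model_groups   = [[rank(d, c, m) for m in range(mp)] for d in range(dp) for c in range(cp)]
    context_groups = [[rank(d, c, m) for c in range(cp)] for d in range(dp) for m in range(mp)]
    data_groups    = [[rank(d, c, m) for d in range(dp)] for c in range(cp) for m in range(mp)]

    print(model_groups)    # [[0, 1], [2, 3], [4, 5], [6, 7]]
    print(context_groups)  # [[0, 2], [1, 3], [4, 6], [5, 7]]
    print(data_groups)     # [[0, 4], [1, 5], [2, 6], [3, 7]] -- gradient all-reduce piggybacks here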
16 changes: 8 additions & 8 deletions megatron/neox_arguments/arguments.py
@@ -805,7 +805,7 @@ def configure_distributed_args(self):
if self.rank == 0:
print(
self.__class__.__name__
+ ".configure_distributed_args() using world size: {}, pipe-parallel size: {}, sequence-parallel size: {}, and model-parallel size: {} ".format(
+ ".configure_distributed_args() using world size: {}, pipe-parallel size: {}, context-parallel size: {}, and model-parallel size: {} ".format(
self.world_size,
self.pipe_parallel_size,
self.context_parallel_size,
@@ -912,16 +912,16 @@ def calculate_derived(self):
pp_size = pp_size if pp_size >= 1 else 1
mp_size = self.model_parallel_size
mp_size = mp_size if mp_size >= 1 else 1
- sp_size = self.context_parallel_size
- sp_size = sp_size if sp_size >= 1 else 1
+ cp_size = self.context_parallel_size
+ cp_size = cp_size if cp_size >= 1 else 1
self.update_value("model_parallel_size", mp_size)
self.update_value("context_parallel_size", sp_size)
self.update_value("context_parallel_size", cp_size)

- # pp_size, mp_size, and sp_size are only used here to compute dp world size and nowhere else.
- dp_world_size = (global_num_gpus / pp_size) / (mp_size * sp_size)
+ # pp_size, mp_size, and cp_size are only used here to compute dp world size and nowhere else.
+ dp_world_size = (global_num_gpus / pp_size) / (mp_size * cp_size)
if not (dp_world_size % 1 == 0):
error_message = (
f"{ERROR}"
"ERROR"
+ self.__class__.__name__
+ ".calculate_derived() "
+ f"(global_num_gpus / pp_size) / mp_size [({global_num_gpus} / {pp_size}) / {mp_size}] must be a whole number"
@@ -1098,7 +1098,7 @@ def calculate_derived(self):
)
assert all([item == "ring" for item in self.attention_config]) or (
not self.is_context_parallel
), "Sequence parallel requires ring attention!"
), "Context parallel requires ring attention!"
assert (
len(self.attention_config) == self.num_layers
), "Length of attention config list must equal num_layers"
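A quick worked example of the derived data-parallel size computed in calculate_derived() above, using illustrative numbers (32 GPUs, pp = mp = cp = 2):

    global_num_gpus, pp_size, mp_size, cp_size = 32, 2, 2, 2

    dp_world_size = (global_num_gpus / pp_size) / (mp_size * cp_size)   # -> 4.0
    # calculate_derived() raises the error above if this is not a whole number.
    assert dp_world_size % 1 == 0
    dp_world_size = int(dp_world_size)   # 4 pure data-parallel replicas once pp, mp, and cp are factored out
    # Note: megatron/initialize.py's dp (= world_size // (pp * mp)) would be 8 here,
    # i.e. dp_world_size * cp_size, since context-parallel ranks live inside that axis.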
12 changes: 0 additions & 12 deletions megatron/training.py
@@ -1272,18 +1272,6 @@ def train_step(
reference_model=None,
):
"""Single training step."""
- # global train_step_counter
- # if train_step_counter > 10:
- # torch.distributed.barrier()
- # if torch.distributed.get_rank() == 0:
- # import pdb; pdb.set_trace()
- # torch.distributed.barrier()
- # if torch.distributed.get_rank() == 1:
- # import pdb; pdb.set_trace()
- # torch.distributed.barrier()
- # # assert False
- # train_step_counter += 1

# Pipeline parallelism schedules forward/backward/step
if neox_args.is_pipe_parallel:
reduced_loss = train_step_pipe(
