Fix sequence parallel with tied weight embeddings
bclyang committed Aug 18, 2024
1 parent 9d883de commit 28a5a62
Showing 2 changed files with 27 additions and 7 deletions.
5 changes: 4 additions & 1 deletion megatron/model/gpt2_model.py
@@ -308,7 +308,10 @@ def _logits_helper(embedding, lm_output):
                )

            logits = parallel_lm_logits(
-                lm_output, embedding.word_embeddings_weight, self.parallel_output
+                lm_output,
+                embedding.word_embeddings_weight,
+                self.parallel_output,
+                seq_parallel=self.neox_args.sequence_parallel,
            )
            return logits

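For context on the commit title: weight tying means the LM head reuses the input embedding matrix as its output projection, which is why `_logits_helper` passes `embedding.word_embeddings_weight` into `parallel_lm_logits` and now also forwards the `sequence_parallel` setting. A minimal, framework-agnostic sketch of the tying itself (the class and names below are illustrative assumptions, not the gpt-neox pipeline code):

```python
# Illustrative sketch of a tied-weight LM head; not the gpt-neox pipeline code.
import torch
import torch.nn as nn


class TinyTiedLM(nn.Module):
    def __init__(self, vocab_size: int = 128, hidden: int = 32):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        hidden_states = self.embed(token_ids)  # [b, s, h]
        # Tied weights: the output projection reuses the embedding matrix,
        # so logits = hidden_states @ W_embed^T, shape [b, s, vocab].
        return hidden_states @ self.embed.weight.t()


logits = TinyTiedLM()(torch.randint(0, 128, (2, 8)))  # [2, 8, 128]
```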
29 changes: 23 additions & 6 deletions megatron/model/transformer.py
@@ -113,7 +113,7 @@ def __init__(
            neox_args=neox_args,
            input_size=neox_args.hidden_size,
            output_size=ff_dim,
-            gather_output=False, # TODO: add a parallel-input check? need to AG in fwd pass to reshard, here?
+            gather_output=False,  # TODO: add a parallel-input check? need to AG in fwd pass to reshard, here?
            init_method=init_method,
            skip_bias_add=True,
            MOE=MOE,
@@ -127,7 +127,7 @@ def __init__(
            output_size=neox_args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
-            parallel_output=parallel_output, # seqpar should do parallel_output?
+            parallel_output=parallel_output,  # seqpar should do parallel_output?
            skip_bias_add=True,
            MOE=MOE,
            MoE_mp_size=MoE_mp_size,
@@ -254,7 +254,7 @@ def __init__(
                gather_output=not parallel_output,
                skip_bias_add=False,
                mup_rescale_parameters=is_last_layer,  # rescale params only called if neox_args.use_mup = True, despite it not being included here
-                seq_dim=1, # important: must mark that this layer receives shape [b, s, h] not [s, b, h] and so Seq. Parallel comms must gather along dim=1
+                seq_dim=1,  # important: must mark that this layer receives shape [b, s, h] not [s, b, h] and so Seq. Parallel comms must gather along dim=1
            )

        # else:
@@ -1028,7 +1028,11 @@ def __init__(
        # GPT-J style layers allow us to defer the reduction of results across TP ranks until the end of the two sublayers.
        # the reduction we use is a simple allreduce for pure Tensor Parallel,
        # but needs to be a reduce-scatter when using sequence parallel (LN sharding.)
-        self.reduce = mpu.mappings.reduce_from_model_parallel_region if not neox_args.sequence_parallel else mpu.mappings.reduce_scatter_to_sequence_parallel_region
+        self.reduce = (
+            mpu.mappings.reduce_from_model_parallel_region
+            if not neox_args.sequence_parallel
+            else mpu.mappings.reduce_scatter_to_sequence_parallel_region
+        )

        # Self attention.
        self.attention = ParallelSelfAttention(
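The comment in this hunk is the heart of the sequence-parallel interaction: the deferred reduction must switch from an all-reduce to a reduce-scatter so that the sharded layer norm which follows sees only its own sequence slice. A rough sketch of that choice using plain torch.distributed collectives (the helper name, the [s, b, h] layout, and the divisibility assumption are illustrative, not the mpu mapping implementations):

```python
# Illustrative sketch only: selecting the cross-rank reduction for a GPT-J style
# residual path. Assumes torch.distributed is initialized and the sequence length
# is divisible by the tensor-parallel world size.
import torch
import torch.distributed as dist


def reduce_sublayer_output(partial: torch.Tensor, sequence_parallel: bool) -> torch.Tensor:
    """partial: [s, b, h] partial sums held by each tensor-parallel rank."""
    if not sequence_parallel:
        # Pure tensor parallel: every rank needs the full summed activation.
        dist.all_reduce(partial)
        return partial
    # Sequence parallel: sum across ranks but hand each rank only its own
    # sequence shard, so the layer norm that follows operates on [s/tp, b, h].
    world_size = dist.get_world_size()
    out = torch.empty_like(partial.chunk(world_size, dim=0)[0])
    dist.reduce_scatter_tensor(out, partial.contiguous())
    return out
```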
Expand Down Expand Up @@ -1343,10 +1347,23 @@ def forward(self, args):
        return self.norm(args)


-def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, bias=None):
+def parallel_lm_logits(
+    input_,
+    word_embeddings_weight,
+    parallel_output,
+    seq_parallel=False,
+    seq_dim=1,
+    bias=None,
+):
    """LM logits using word embedding weights."""
    # Parallel logits.
-    input_parallel = mpu.copy_to_model_parallel_region(input_)
+    if seq_parallel:
+        input_parallel = mpu.gather_from_sequence_parallel_region(
+            input_, seq_dim=seq_dim
+        )
+    else:
+        # Set up backprop all-reduce.
+        input_parallel = mpu.copy_to_model_parallel_region(input_)

    # Matrix multiply.
    if bias is None:
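To see why the new `seq_parallel` branch gathers before the matmul: with sequence parallelism the activations reaching the LM head are sharded along the sequence dimension, so the full sequence has to be reassembled before it is projected onto the tied embedding weight, whereas the pure tensor-parallel path only needs the backprop all-reduce that `copy_to_model_parallel_region` sets up. A minimal sketch with plain torch.distributed, not the mpu implementation (the function name, the [b, s, h] layout, and the collective setup are assumptions):

```python
# Illustrative sketch only: logits with a tied embedding weight when activations
# arrive sequence-sharded. Assumes torch.distributed is initialized and the full
# sequence length is divisible by the tensor-parallel world size.
import torch
import torch.distributed as dist


def lm_logits_sketch(
    hidden: torch.Tensor,                  # [b, s/tp, h] if seq_parallel else [b, s, h]
    word_embeddings_weight: torch.Tensor,  # [vocab, h]
    seq_parallel: bool = False,
    seq_dim: int = 1,
) -> torch.Tensor:
    if seq_parallel:
        # Each rank holds a slice of the sequence; gather the full sequence
        # before projecting, mirroring gather_from_sequence_parallel_region.
        world_size = dist.get_world_size()
        shards = [torch.empty_like(hidden) for _ in range(world_size)]
        dist.all_gather(shards, hidden.contiguous())
        hidden = torch.cat(shards, dim=seq_dim)
    # Tied weights: logits are a projection onto the word-embedding matrix.
    return hidden @ word_embeddings_weight.t()  # [b, s, vocab]
```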