Commit

pre-commit and clean up comments
Quentin-Anthony committed Aug 19, 2024
1 parent 8f26029 commit d9db749
Showing 2 changed files with 5 additions and 4 deletions.
7 changes: 4 additions & 3 deletions megatron/model/word_embeddings.py
@@ -51,7 +51,9 @@ def __init__(
         self.init_method = init_method
         self.num_tokentypes = num_tokentypes
 
-        self.sequence_parallel = neox_args.sequence_parallel # if we are using sequence parallelism, then we'll want to scatter our inputs across the seqlen dim across TP ranks
+        self.sequence_parallel = (
+            neox_args.sequence_parallel
+        )  # if we are using sequence parallelism, then we'll want to scatter our inputs across the seqlen dim across TP ranks
 
         self.use_mup = neox_args.use_mup
         self.mup_embedding_mult = neox_args.mup_embedding_mult
@@ -163,10 +165,9 @@ def forward(self, input_ids, position_ids, tokentype_ids=None):
                 embeddings.mul_(self.mup_embedding_mult)
 
         if self.sequence_parallel:
-            # TODO: megatron-lm does dropout using the scattered embs. This'd save a tiny bit of time, perhaps?
+            # TODO: megatron-lm does dropout using the scattered embs. This would save a tiny bit of time, perhaps?
             # Not a priority since we don't often use dropout
             embeddings = mpu.scatter_to_sequence_parallel_region(embeddings)
-            # pass
 
         return embeddings

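Note on the sequence-parallel scatter above: the comment and the call to mpu.scatter_to_sequence_parallel_region describe splitting the embedding activations along the sequence dimension so that each tensor-parallel rank keeps only its slice. Below is a minimal conceptual sketch of that idea, assuming a [seq_len, batch, hidden] layout and illustrative names (scatter_along_seq_dim, tp_rank, tp_world_size are not the NeoX API; the real mpu implementation also handles autograd and the process-group plumbing):

import torch

def scatter_along_seq_dim(embeddings: torch.Tensor, tp_rank: int, tp_world_size: int) -> torch.Tensor:
    # Keep only this rank's contiguous chunk of the sequence dimension.
    # Assumes seq_len is divisible by tp_world_size.
    chunks = torch.chunk(embeddings, tp_world_size, dim=0)
    return chunks[tp_rank].contiguous()

# e.g. seq_len=8 split across 2 tensor-parallel ranks -> 4 positions per rank
emb = torch.randn(8, 4, 16)  # [seq_len, batch, hidden]
local = scatter_along_seq_dim(emb, tp_rank=0, tp_world_size=2)
assert local.shape == (4, 4, 16)
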
2 changes: 1 addition & 1 deletion megatron/neox_arguments/neox_args.py
@@ -85,7 +85,7 @@ class NeoXArgsParallelism(NeoXArgsTemplate):
     according to pipeline parallel size.
     """
 
-    sequence_parallel: bool = False # TODO: default to True?
+    sequence_parallel: bool = False
     """
     flag to determine whether Megatron-style Sequence Parallelism (https://arxiv.org/abs/2205.05198)
     (Layernorm inputs and activations are sharded across model parallel group) will be used. Has no effect when model_parallel_size is 1.
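As context for the neox_args change: per the docstring, sequence_parallel only has an effect when model parallelism is actually in use. An illustrative config fragment (shown as a Python dict rather than a full NeoX config; only these two keys are taken from this diff and its docstring):

config = {
    "model_parallel_size": 2,   # sequence_parallel is a no-op when this is 1
    "sequence_parallel": True,  # shard layernorm inputs/activations across the model-parallel group
}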
