Megatron-LM style Sequence Parallel (#1257)
* first draft (shape errors occurring)

* training works (but poor convergence)

* debugging progress: current commit works if we do regular TP by implementing the all-reduce in RowParallelLinear as a reduce-scatter followed by an all-gather

* Update NeoXArgs docs automatically

* push most recent code (updated mark_norms fn, back to 'real' sequence parallel)

* Update NeoXArgs docs automatically

* Fix LayerNorm all reduce gradient hook

* Sum instead of average for LayerNorm gradient all reduce

* Update NeoXArgs docs automatically

* Update NeoXArgs docs automatically

* Fix gather and reduce scatter ops on sequence dimension

* Fix sequence parallel with tied weight embeddings

* Update NeoXArgs docs automatically

* cleanup pass + add MoE arguments.py guard

* pre-commit and clean up comments

* remove vestigial debug code

* remove unused debugging code

* remove dummy test config

* update fp32_allreduce to handle fp16 ; don't cast to fp32 for gathers

* run linter on the rest of the files

* Improve performance of sequence parallel gather, scatter, and reduce

* Add comment

* Update NeoXArgs docs automatically

---------

Co-authored-by: github-actions <[email protected]>
Co-authored-by: Brandon Yang <[email protected]>
Co-authored-by: Quentin Anthony <[email protected]>
4 people authored Aug 23, 2024
1 parent f8c9e68 commit 8b43196
Showing 13 changed files with 349 additions and 33 deletions.
12 changes: 11 additions & 1 deletion configs/neox_arguments.md
@@ -111,7 +111,7 @@ Logging Arguments

- **git_hash**: str

Default = 455446c
Default = 53d0ae8

current git hash of repository

@@ -1056,6 +1056,16 @@ Parallelism Arguments
- **sequence_parallel**: bool
Default = False
Flag to determine whether Megatron-style sequence parallelism (https://arxiv.org/abs/2205.05198),
in which LayerNorm inputs and activations are sharded across the model parallel group, will be used. Has no effect when model_parallel_size is 1.
**Set by user, in contrast to neox_args.is_pipe_parallel.**
- **expert_interval**: int
Default = 2
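For reference, a hypothetical excerpt of a NeoX parallelism config using the new flag (the key names other than sequence_parallel follow the existing arguments documented in this file; the exact values are illustrative only):

# Sequence parallelism only does anything when tensor (model) parallelism is
# actually splitting the transformer layers across ranks.
parallel_config = {
    "model_parallel_size": 2,   # must be > 1 for sequence_parallel to have any effect
    "sequence_parallel": True,  # shard LayerNorm/dropout activations along the sequence dim
    "pipe_parallel_size": 1,
}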
5 changes: 4 additions & 1 deletion megatron/model/__init__.py
@@ -16,5 +16,8 @@
# limitations under the License.

from .gpt2_model import GPT2ModelPipe
from .utils import get_params_for_weight_decay_optimization
from .utils import (
get_params_for_weight_decay_optimization,
mark_norms_for_sequence_parallel_grad_sync,
)
from .word_embeddings import SoftEmbedding
5 changes: 4 additions & 1 deletion megatron/model/gpt2_model.py
@@ -308,7 +308,10 @@ def _logits_helper(embedding, lm_output):
)

logits = parallel_lm_logits(
lm_output, embedding.word_embeddings_weight, self.parallel_output
lm_output,
embedding.word_embeddings_weight,
self.parallel_output,
seq_parallel=self.neox_args.sequence_parallel,
)
return logits

29 changes: 26 additions & 3 deletions megatron/model/transformer.py
@@ -254,6 +254,7 @@ def __init__(
gather_output=not parallel_output,
skip_bias_add=False,
mup_rescale_parameters=is_last_layer, # rescale params only called if neox_args.use_mup = True, despite it not being included here
seq_dim=1, # important: must mark that this layer receives shape [b, s, h] not [s, b, h] and so Seq. Parallel comms must gather along dim=1 rather than dim=0
)

# else:
@@ -1024,7 +1025,14 @@ def __init__(
self.moe_type = neox_args.moe_type

if self.gpt_j_residual:
self.reduce = mpu.mappings.reduce_from_model_parallel_region
# GPT-J style layers allow us to defer the reduction of results across TP ranks until the end of the two sublayers.
# the reduction we use is a simple allreduce for pure Tensor Parallel,
# but needs to be a reduce-scatter when using Megatron-style Sequence Parallel (LN sharding.)
self.reduce = (
mpu.mappings.reduce_from_model_parallel_region
if not neox_args.sequence_parallel
else mpu.mappings.reduce_scatter_to_sequence_parallel_region
)

# Self attention.
self.attention = ParallelSelfAttention(
@@ -1339,10 +1347,25 @@ def forward(self, args):
return self.norm(args)


def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, bias=None):
def parallel_lm_logits(
input_,
word_embeddings_weight,
parallel_output,
seq_parallel=False,
seq_dim=1,
bias=None,
):
"""LM logits using word embedding weights."""
# Parallel logits.
input_parallel = mpu.copy_to_model_parallel_region(input_)
if seq_parallel:
# if using Sequence Parallelism, our logits are sharded along the sequence dimension.
# gather them here. (backward pass: reduce-scatter)
input_parallel = mpu.gather_from_sequence_parallel_region(
input_, seq_dim=seq_dim
)
else:
# Set up backprop all-reduce.
input_parallel = mpu.copy_to_model_parallel_region(input_)

# Matrix multiply.
if bias is None:
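The sequence-parallel mappings used above (gather_from_sequence_parallel_region and reduce_scatter_to_sequence_parallel_region) live in megatron/mpu/mappings.py, which is not expanded in this view. A minimal sketch of the gather, assuming the list-based torch.distributed collectives and the seq_dim argument introduced here (the real implementation also handles the fp32_allreduce casting mentioned in the commit messages):

import torch
import torch.distributed as dist


class _GatherFromSequenceParallelRegion(torch.autograd.Function):
    """All-gather along the sequence dim in forward; reduce-scatter in backward (sketch)."""

    @staticmethod
    def forward(ctx, input_, seq_dim, group):
        ctx.seq_dim = seq_dim
        ctx.group = group
        world_size = dist.get_world_size(group=group)
        if world_size == 1:
            return input_
        gathered = [torch.empty_like(input_) for _ in range(world_size)]
        dist.all_gather(gathered, input_.contiguous(), group=group)
        return torch.cat(gathered, dim=seq_dim)

    @staticmethod
    def backward(ctx, grad_output):
        world_size = dist.get_world_size(group=ctx.group)
        if world_size == 1:
            return grad_output, None, None
        # Reduce-scatter the gradient back down to the local sequence shard.
        grad_chunks = [
            c.contiguous() for c in torch.chunk(grad_output, world_size, dim=ctx.seq_dim)
        ]
        grad_input = torch.empty_like(grad_chunks[0])
        dist.reduce_scatter(grad_input, grad_chunks, group=ctx.group)
        return grad_input, None, None

gather_from_sequence_parallel_region(input_, seq_dim) would then simply apply a function like this over the model parallel group, which is the forward-AG / backward-RS pairing the comments in parallel_lm_logits describe.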
56 changes: 46 additions & 10 deletions megatron/model/utils.py
@@ -18,8 +18,8 @@
"""Utilities for models."""

import torch
from megatron.model.norms import LayerNorm, RMSNorm, ScaleNorm
from megatron.model.fused_softmax import SoftmaxFusionTypes
from megatron import mpu
from types import GeneratorType
import torch.distributed as dist

@@ -35,15 +35,9 @@ def get_params_for_weight_decay_optimization(module, neox_args):
"name": "no_weight_decay_params",
}
for module_ in module.modules():
if any(
[
isinstance(module_, LayerNorm),
isinstance(module_, RMSNorm),
isinstance(module_, ScaleNorm),
]
) or (
neox_args.weight_decay == 0.0
): # also include all parameters here if no weight decay is being done
# put any "...Norm" module's parameters in the no-weight-decay group.
if "norm" in type(module_).__name__.lower() or neox_args.weight_decay == 0.0:
# also include all parameters here if no weight decay is being done
no_weight_decay_params["params"].extend(
[p for p in list(module_._parameters.values()) if p is not None]
)
@@ -359,3 +353,45 @@ def get_fusion_type(neox_args):
elif neox_args.scaled_masked_softmax_fusion:
fusion_type = SoftmaxFusionTypes.general
return fusion_type


def reduce_weight_grads_from_model_parallel_region(input_):
"""A hook that can be applied to any weight tensor via .register_hook().
Allreduces grads for e.g. LN weights across the model parallel group.
Needed to keep the norms in sync, since each rank sees different data (and therefore computes different gradients) when using sequence parallelism.
"""
# Bypass the function if no TP -> no comm needed.
if mpu.get_model_parallel_world_size() == 1:
return input_

# Bf16 convert
dt = input_.dtype
if dt == torch.bfloat16 and mpu.get_fp32_allreduce():
input_ = input_.float()

# All-reduce.
torch.distributed.all_reduce(input_, group=mpu.get_model_parallel_group())

# Bf16 convert
if dt == torch.bfloat16 and mpu.get_fp32_allreduce():
input_ = input_.bfloat16()

return input_


def mark_norms_for_sequence_parallel_grad_sync(module, neox_args):
"""Iterate through the modules in our model, and for any "...Norm" classnames,
register a hook on each of that module's parameters which will allreduce norms' weights' grads across
the model (sequence) parallel region.
"""

if not neox_args.sequence_parallel:
# if we aren't using sequence parallelism, this is a no-op
return

for module_ in module.modules():
if "norm" in type(module_).__name__.lower():
# this is a norm, we want to allreduce its weight grads across sequence parallel region
for name, param in module_.named_parameters():
if param.requires_grad:
param.register_hook(reduce_weight_grads_from_model_parallel_region)
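
A minimal usage sketch for the hook registration above; the actual call site is in the training setup code, which is not among the files shown, and build_model is a hypothetical stand-in for however the model gets constructed:

from megatron.model import mark_norms_for_sequence_parallel_grad_sync

# After constructing the model (and before training), register the grad hooks
# on every "...Norm" parameter. This is a no-op unless
# neox_args.sequence_parallel is True.
model = build_model(neox_args)  # hypothetical helper, not part of this diff
mark_norms_for_sequence_parallel_grad_sync(model, neox_args)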
10 changes: 10 additions & 0 deletions megatron/model/word_embeddings.py
@@ -50,6 +50,11 @@ def __init__(
self.hidden_size = hidden_size
self.init_method = init_method
self.num_tokentypes = num_tokentypes

self.sequence_parallel = (
neox_args.sequence_parallel
) # if we are using sequence parallelism, then we'll want to scatter our inputs across the seqlen dim across TP ranks

self.use_mup = neox_args.use_mup
self.mup_embedding_mult = neox_args.mup_embedding_mult
self.mup_rp_embedding_mult = neox_args.mup_rp_embedding_mult
@@ -159,6 +164,11 @@ def forward(self, input_ids, position_ids, tokentype_ids=None):
with torch.no_grad():
embeddings.mul_(self.mup_embedding_mult)

if self.sequence_parallel:
# TODO: megatron-lm does dropout using the scattered embs. This would save a tiny bit of time, perhaps?
# Not a priority since we don't often use dropout
embeddings = mpu.scatter_to_sequence_parallel_region(embeddings)

return embeddings


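scatter_to_sequence_parallel_region is the mirror image of the gather used for the logits: keep only the local sequence shard in the forward pass, all-gather gradients in the backward pass. A rough sketch under the same assumptions as above (an [s, b, h] layout and the standard torch.distributed collectives; the real version in megatron/mpu/mappings.py also handles the fp32_allreduce casting):

import torch
import torch.distributed as dist


class _ScatterToSequenceParallelRegion(torch.autograd.Function):
    """Keep the local sequence shard in forward; all-gather gradients in backward (sketch)."""

    @staticmethod
    def forward(ctx, input_, group):
        ctx.group = group
        world_size = dist.get_world_size(group=group)
        rank = dist.get_rank(group=group)
        if world_size == 1:
            return input_
        # Embeddings are [s, b, h]; keep this rank's contiguous slice of the sequence.
        chunks = torch.chunk(input_, world_size, dim=0)
        return chunks[rank].contiguous()

    @staticmethod
    def backward(ctx, grad_output):
        world_size = dist.get_world_size(group=ctx.group)
        if world_size == 1:
            return grad_output, None
        gathered = [torch.empty_like(grad_output) for _ in range(world_size)]
        dist.all_gather(gathered, grad_output.contiguous(), group=ctx.group)
        return torch.cat(gathered, dim=0), None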
3 changes: 3 additions & 0 deletions megatron/mpu/__init__.py
@@ -47,6 +47,9 @@
from .mappings import gather_from_model_parallel_region
from .mappings import reduce_from_model_parallel_region
from .mappings import scatter_to_model_parallel_region
from .mappings import reduce_scatter_to_sequence_parallel_region
from .mappings import gather_from_sequence_parallel_region
from .mappings import scatter_to_sequence_parallel_region

from .random import checkpoint
from .random import get_cuda_rng_tracker
39 changes: 36 additions & 3 deletions megatron/mpu/layers.py
@@ -33,6 +33,8 @@
from .mappings import gather_from_model_parallel_region
from .mappings import reduce_from_model_parallel_region
from .mappings import scatter_to_model_parallel_region
from .mappings import reduce_scatter_to_sequence_parallel_region
from .mappings import gather_from_sequence_parallel_region
from .random import get_cuda_rng_tracker
from .utils import divide
from .utils import VocabUtility
@@ -416,6 +418,7 @@ def __init__(
MOE=False,
MoE_mp_size=1,
mup_rescale_parameters=False,
seq_dim=0, # Dimension which is the seq_len dimension. final ParallelLinear overrides this to be 1 ; otherwise, the default is used throughout.
):
super(ColumnParallelLinear, self).__init__()

@@ -427,6 +430,10 @@ def __init__(
world_size = MoE_mp_size if MOE else get_model_parallel_world_size()
self.output_size_per_partition = divide(output_size, world_size)
self.skip_bias_add = skip_bias_add

self.sequence_parallel = neox_args.sequence_parallel
self.seq_dim = seq_dim

self.init_method = init_method
self.stride = stride
self.mup_rescale_parameters = mup_rescale_parameters
@@ -551,14 +558,29 @@ def set_parallel_output(self, value: bool):
def forward(self, input_):
if self.use_mup and self.mup_rescale_parameters:
input_ /= self.width_mult()
# Set up backprop all-reduce.
input_parallel = copy_to_model_parallel_region(input_)

if self.sequence_parallel:
input_parallel = input_
else:
# Set up backprop all-reduce.
input_parallel = copy_to_model_parallel_region(input_)
# Matrix multiply.

if self.sequence_parallel:
# do an AG in the fwd pass, RS in bwd pass.
# gather / scatter portion happens across the sequence dim (self.seq_dim)--
# almost always is [s, b, h] and so dim 0, but for lm_head ParallelLinear it is seq_dim=1 and [b, s, h]
input_parallel = gather_from_sequence_parallel_region(
input_parallel, seq_dim=self.seq_dim
)

bias = self.bias if not self.skip_bias_add else None
output_parallel = F.linear(input_parallel, self.weight, bias)
if self.gather_output:
# All-gather across the partitions.
assert (
not self.sequence_parallel
), "sequence_parallel=True and gather_output=True are incompatible!"
output = gather_from_model_parallel_region(output_parallel)
else:
output = output_parallel
@@ -623,6 +645,12 @@ def __init__(
self.input_size_per_partition = divide(input_size, world_size)
self.skip_bias_add = skip_bias_add
self.parallel_output = parallel_output

self.sequence_parallel = neox_args.sequence_parallel
assert not (
self.sequence_parallel and not self.input_is_parallel
), "Cannot have self.input_is_parallel=False and self.sequence_parallel=True."

self.init_method = init_method
self.stride = stride
self.keep_master_weight_for_test = keep_master_weight_for_test
@@ -748,7 +776,12 @@ def forward(self, input_):
# Matrix multiply.
output_parallel = F.linear(input_parallel, self.weight)
# All-reduce across all the partitions.
if not self.parallel_output:
if self.sequence_parallel and not self.parallel_output:
# do an RS in the fwd pass, AG in bwd pass.
# skip in the gpt-j parallel sublayer case (self.parallel_output=True)
# (user responsible for calling reduce-scatter)
output_ = reduce_scatter_to_sequence_parallel_region(output_parallel)
elif not self.parallel_output:
output_ = reduce_from_model_parallel_region(output_parallel)
else:
output_ = output_parallel
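To make the intended data flow through these layers concrete, here is a single-process shape walk-through of one MLP sublayer under sequence parallelism (hypothetical sizes; the torch.cat and torch.chunk calls stand in for the actual all-gather and reduce-scatter collectives):

import torch

# Hypothetical sizes: seq_len=8, batch=2, hidden=16, ffn=4*hidden, tp=2.
s, b, h, tp = 8, 2, 16, 2

x_local = torch.randn(s // tp, b, h)            # sequence-sharded LayerNorm output on this rank

# ColumnParallelLinear (sequence_parallel=True): all-gather along seq, then a column-sharded matmul.
x_full = torch.cat([x_local] * tp, dim=0)       # stands in for gather_from_sequence_parallel_region
w_col = torch.randn(4 * h // tp, h)             # this rank's 1/tp slice of the first MLP weight
y_features = x_full @ w_col.t()                 # [s, b, 4h/tp]

# RowParallelLinear (sequence_parallel=True): row-sharded matmul, then reduce-scatter along seq.
w_row = torch.randn(h, 4 * h // tp)             # this rank's 1/tp slice of the second MLP weight
y_partial = y_features @ w_row.t()              # [s, b, h], a partial sum across ranks
y_local = torch.chunk(y_partial, tp, dim=0)[0]  # stands in for reduce_scatter_to_sequence_parallel_region

assert y_local.shape == (s // tp, b, h)         # sequence-sharded again, ready for the next LayerNorm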