Merge branch 'mike/groups_change_4' into 'main'
Allow Encoder to Have Different TP Size

See merge request ADLR/megatron-lm!1593
ericharper committed Aug 5, 2024
2 parents 3fd0c44 + 8af3dae commit 2fd6e2b
Showing 17 changed files with 477 additions and 137 deletions.
examples/multimodal/train.py (12 changes: 9 additions & 3 deletions)
@@ -71,9 +71,6 @@ def model_provider(

    vision_config = deepcopy(base_config)
    vision_config = get_vision_model_config(vision_config, apply_query_key_layer_scaling=args.apply_query_key_layer_scaling)
    if args.pipeline_model_parallel_size > 1:
        assert args.encoder_pipeline_model_parallel_size == 1, "ViT can only live on 1 pipeline stage."
        vision_config.pipeline_model_parallel_size = args.encoder_pipeline_model_parallel_size

    if use_te:
        vision_transformer_layer_spec = get_layer_spec_te(is_vit=True)
@@ -82,6 +79,15 @@

    vision_projection_config = deepcopy(base_config)
    vision_projection_config = get_vision_projection_config(vision_projection_config, language_config.hidden_size)

    if args.encoder_pipeline_model_parallel_size > 0:
        assert args.encoder_pipeline_model_parallel_size == 1, "ViT can only live on 1 pipeline stage."
        vision_config.pipeline_model_parallel_size = args.encoder_pipeline_model_parallel_size
        vision_projection_config.pipeline_model_parallel_size = args.encoder_pipeline_model_parallel_size
    if args.encoder_tensor_model_parallel_size > 0:
        vision_config.tensor_model_parallel_size = args.encoder_tensor_model_parallel_size
        vision_projection_config.tensor_model_parallel_size = args.encoder_tensor_model_parallel_size

    vision_projection_layer_spec = get_mlp_module_spec(use_te=use_te).submodules

    model = LLaVAModel(
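The hunk above replaces the old pipeline-only check (args.pipeline_model_parallel_size > 1) with overrides driven by the new encoder-specific sizes: when encoder_pipeline_model_parallel_size or encoder_tensor_model_parallel_size is set (> 0), the vision encoder and its projection take those values instead of inheriting the decoder's. As a self-contained illustration of just that override logic (plain dataclasses stand in for Megatron's config objects; ToyParallelConfig and apply_encoder_overrides are made-up names for this sketch, not part of the library):

from dataclasses import dataclass


@dataclass
class ToyParallelConfig:
    # Stand-in for the two fields the hunk touches on the real configs.
    tensor_model_parallel_size: int = 1
    pipeline_model_parallel_size: int = 1


def apply_encoder_overrides(vision_cfg, projection_cfg, encoder_tp, encoder_pp):
    # Encoder-specific sizes, when set (> 0), override whatever the vision
    # encoder and projection inherited from the decoder's base config.
    if encoder_pp > 0:
        assert encoder_pp == 1, "ViT can only live on 1 pipeline stage."
        vision_cfg.pipeline_model_parallel_size = encoder_pp
        projection_cfg.pipeline_model_parallel_size = encoder_pp
    if encoder_tp > 0:
        vision_cfg.tensor_model_parallel_size = encoder_tp
        projection_cfg.tensor_model_parallel_size = encoder_tp


vision = ToyParallelConfig(tensor_model_parallel_size=8)
projection = ToyParallelConfig(tensor_model_parallel_size=8)
apply_encoder_overrides(vision, projection, encoder_tp=2, encoder_pp=1)
print(vision)      # tensor_model_parallel_size=2, pipeline_model_parallel_size=1
print(projection)  # the projection follows the encoder, not the decoder
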
megatron/core/distributed/finalize_model_grads.py (26 changes: 21 additions & 5 deletions)
@@ -135,13 +135,29 @@ def finalize_model_grads(model: List[torch.nn.Module], num_tokens: Optional[torc
    # if we are using by the number of tokens, then we use that as a divisor. this number
    # will be the total number of non-padded tokens in the global batch.
    if num_tokens is not None:

        # the number of tokens is only present on the last stage, so broadcast it
        # to the other ranks in the pipeline parallel group.
        torch.distributed.broadcast(
            num_tokens,
            src=parallel_state.get_pipeline_model_parallel_last_rank(),
            group=parallel_state.get_pipeline_model_parallel_group(),
        )
        last_rank = parallel_state.get_pipeline_model_parallel_last_rank()
        pp_group = parallel_state.get_pipeline_model_parallel_group()

        if not isinstance(last_rank, list):
            assert not isinstance(pp_group, list)
            last_rank = [last_rank]
            pp_group = [pp_group]

        # need to do a broadcast for every pp group, even though num_tokens should be the same.
        num_tokens_list = []
        for lr, group in zip(last_rank, pp_group):
            torch.distributed.broadcast(
                num_tokens,
                src=lr,
                group=group,
            )
            num_tokens_list.append(torch.clone(num_tokens))
        assert all(x.item() == num_tokens_list[0] for x in num_tokens_list)

        # all-reduce across DP ranks.
        torch.distributed.all_reduce(num_tokens, group=parallel_state.get_data_parallel_group())
        for model_chunk in model:
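The rewritten block generalizes the old single broadcast: get_pipeline_model_parallel_last_rank() and get_pipeline_model_parallel_group() may now return lists (one entry per pipeline-parallel group, apparently to cover setups where the encoder has its own parallel layout), so num_tokens is broadcast once per group and the results are cross-checked. A minimal single-process sketch of that control flow (gloo backend, world size 1, arbitrary local port; it illustrates only the loop, not a real multi-rank launch):

import torch
import torch.distributed as dist

# One process stands in for the whole job so the sketch runs anywhere.
dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29512", rank=0, world_size=1)

last_rank = [0]                   # last pipeline stage of each PP group
pp_group = [dist.new_group([0])]  # a single trivial group here; possibly several in Megatron

num_tokens = torch.tensor([4096], dtype=torch.long)
num_tokens_list = []
for lr, group in zip(last_rank, pp_group):
    # num_tokens is only computed on the last pipeline stage, so each group
    # broadcasts it from its own last rank to the rest of the group.
    dist.broadcast(num_tokens, src=lr, group=group)
    num_tokens_list.append(torch.clone(num_tokens))

# Every group should agree on the token count before it is used as a divisor.
assert all(x.item() == num_tokens_list[0] for x in num_tokens_list)
dist.destroy_process_group()
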
megatron/core/models/multimodal/llava_spec.py (43 changes: 43 additions & 0 deletions)
@@ -27,6 +27,21 @@
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules

try:
    import apex

    from megatron.core.fusions.fused_layer_norm import FusedLayerNorm

    HAVE_APEX = True
    LNImpl = FusedLayerNorm
except ImportError:
    import warnings

    from megatron.core.transformer.torch_layer_norm import WrappedTorchLayerNorm

    warnings.warn(f'Apex is not installed. Falling back to Torch LayerNorm')
    LNImpl = WrappedTorchLayerNorm


def decoder_model_with_transformer_engine_default_spec(
    num_experts: int = None, moe_grouped_gemm: bool = False, qk_layernorm: bool = False
@@ -54,3 +69,31 @@ def decoder_model_with_transformer_engine_default_spec(
            mlp_bda=get_bias_dropout_add,
        ),
    )


def decoder_model_with_local_default_spec(
    num_experts: int = None, moe_grouped_gemm: bool = False, qk_layernorm: bool = False
) -> ModuleSpec:
    """LLava decoder local spec."""
    mlp = _get_mlp_module_spec(
        use_te=False, num_experts=num_experts, moe_grouped_gemm=moe_grouped_gemm
    )
    return ModuleSpec(
        module=TransformerLayer,
        submodules=TransformerLayerSubmodules(
            input_layernorm=LNImpl,
            self_attention=ModuleSpec(
                module=SelfAttention,
                params={"attn_mask_type": AttnMaskType.causal},
                submodules=SelfAttentionSubmodules(
                    linear_qkv=ColumnParallelLinear,
                    core_attention=DotProductAttention,
                    linear_proj=RowParallelLinear,
                ),
            ),
            self_attn_bda=get_bias_dropout_add,
            pre_mlp_layernorm=LNImpl,
            mlp=mlp,
            mlp_bda=get_bias_dropout_add,
        ),
    )
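
The new function mirrors decoder_model_with_transformer_engine_default_spec but wires in megatron-core's own SelfAttention, DotProductAttention and Column/RowParallelLinear, plus whichever LayerNorm the Apex fallback above selected. A small sketch of how the resulting spec could be inspected (assumes a megatron-core checkout containing this commit, with the module's optional dependencies importable):

from megatron.core.models.multimodal.llava_spec import decoder_model_with_local_default_spec

spec = decoder_model_with_local_default_spec()

# A ModuleSpec is declarative: it records which classes to build later, not instances.
print(spec.module.__name__)                            # TransformerLayer
print(spec.submodules.self_attention.module.__name__)  # SelfAttention
print(spec.submodules.self_attention.params)           # {'attn_mask_type': AttnMaskType.causal}
print(spec.submodules.input_layernorm.__name__)        # FusedLayerNorm, or WrappedTorchLayerNorm without Apex
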
megatron/core/models/vision/vit_layer_specs.py (43 changes: 42 additions & 1 deletion)
@@ -8,12 +8,28 @@
    TELayerNormColumnParallelLinear,
    TERowParallelLinear,
)
from megatron.core.transformer.dot_product_attention import DotProductAttention
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules

try:
    import apex

    from megatron.core.fusions.fused_layer_norm import FusedLayerNorm

    HAVE_APEX = True
    LNImpl = FusedLayerNorm
except ImportError:
    import warnings

    from megatron.core.transformer.torch_layer_norm import WrappedTorchLayerNorm

    warnings.warn(f'Apex is not installed. Falling back to Torch LayerNorm')
    LNImpl = WrappedTorchLayerNorm


# Use this spec to use lower level Transformer Engine modules (required for fp8 training)
def get_vit_layer_with_transformer_engine_spec() -> ModuleSpec:
@@ -40,8 +56,33 @@ def get_vit_layer_with_transformer_engine_spec(
    )


def get_vit_layer_with_local_spec() -> ModuleSpec:
    mlp = _get_mlp_module_spec(use_te=False)
    return ModuleSpec(
        module=TransformerLayer,
        submodules=TransformerLayerSubmodules(
            input_layernorm=LNImpl,
            self_attention=ModuleSpec(
                module=SelfAttention,
                params={"attn_mask_type": AttnMaskType.causal},
                submodules=SelfAttentionSubmodules(
                    linear_qkv=ColumnParallelLinear,
                    core_attention=DotProductAttention,
                    linear_proj=RowParallelLinear,
                ),
            ),
            self_attn_bda=get_bias_dropout_add,
            pre_mlp_layernorm=LNImpl,
            mlp=mlp,
            mlp_bda=get_bias_dropout_add,
        ),
    )


# Helper function to get module spec for MLP/MoE
def _get_mlp_module_spec(use_te: bool = True,) -> ModuleSpec:
def _get_mlp_module_spec(
    use_te: bool = True,
) -> ModuleSpec:
    # Dense MLP w/ or w/o TE modules.
    return ModuleSpec(
        module=MLP,
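
get_vit_layer_with_local_spec is the TE-free counterpart of get_vit_layer_with_transformer_engine_spec: the layer layout is the same, only the implementations the submodules point to differ. A hedged comparison sketch, under the same importability assumption as the sketch above:

from megatron.core.models.vision.vit_layer_specs import (
    get_vit_layer_with_local_spec,
    get_vit_layer_with_transformer_engine_spec,
)

local_spec = get_vit_layer_with_local_spec()
te_spec = get_vit_layer_with_transformer_engine_spec()

# Both specs describe a TransformerLayer; they differ in which attention, linear
# and norm classes the submodules resolve to (local megatron-core vs Transformer Engine).
local_qkv = local_spec.submodules.self_attention.submodules.linear_qkv
te_qkv = te_spec.submodules.self_attention.submodules.linear_qkv
print(local_qkv.__name__)  # ColumnParallelLinear
print(te_qkv.__name__)     # one of the TE-backed linear wrappers imported at the top of the file
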
[Diffs for the remaining 13 changed files are not shown.]
