diff --git a/metaseq/models/transformer_decoder.py b/metaseq/models/transformer_decoder.py index 74d1da5a6..7a2b0c69b 100644 --- a/metaseq/models/transformer_decoder.py +++ b/metaseq/models/transformer_decoder.py @@ -3,8 +3,8 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import logging import math +import logging from typing import Any, Dict, List, Optional import torch @@ -20,10 +20,21 @@ LayerNorm, PositionalEmbedding, TransformerDecoderLayer, + ModelParallelTransformerDecoderLayer, Linear, ) from metaseq.modules.checkpoint_activations import checkpoint_wrapper +try: + from megatron import mpu + from megatron.mpu import ( + gather_from_tensor_model_parallel_region, + ) + + has_megatron_submodule = True +except (ImportError, ModuleNotFoundError): + has_megatron_submodule = False + logger = logging.getLogger(__name__) @@ -497,3 +508,58 @@ def buffered_future_mask(self, tensor, input_tokens=None): return self._future_mask else: return self._future_mask[:cur_seq_len, :cur_seq_len] + + +class ModelParallelTransformerDecoder(TransformerDecoder): + """ + Model Parallel Transformer decoder consisting of *args.decoder_layers* layers. Each layer + is a :class:`ModelParallelTransformerDecoderLayer`. + """ + + def build_base_decoder_layer(self, args, **kwargs): + return ModelParallelTransformerDecoderLayer(args) + + def output_layer(self, features, **kwargs): + """Project features to the vocabulary size.""" + if not self.share_input_output_embed: + raise NotImplementedError( + "Model parallel training currently requires --share-decoder-input-output-embed" + ) + + is_sequence_parallel = getattr(self.args, "sequence_parallel", False) + if is_sequence_parallel: + input_parallel = features + else: + input_parallel = mpu.copy_to_tensor_model_parallel_region(features) + + # project back to size of vocabulary + x = mpu.LinearWithGradAccumulationAndAsyncCommunication.apply( + input_parallel, + self.output_projection.weight, + None, + False, # gradient_accumulation_fusion + False, # async_grad_allreduce + is_sequence_parallel, # sequence_parallel + ) + # Gather output if the model is in inference mode (i.e. evallm or generation), because neither is yet compatible with + # parallel vocab embeddings + if getattr(self.args, "criterion") != "vocab_parallel_cross_entropy" or getattr( + self, "inference", False + ): + x = gather_from_tensor_model_parallel_region(x).contiguous() + + return x + + # This hook is used as a proxy for tracking whether the model is in eval or generation mode. + def make_generation_fast_(self, **unused): + self.inference = True + + def forward_embedding( + self, + *args, + ): + x, embed, positions = super().forward_embedding(*args) + is_sequence_parallel = getattr(self.args, "sequence_parallel", False) + if is_sequence_parallel: + x = mpu.scatter_to_sequence_parallel_region(x) + return x, embed, positions diff --git a/metaseq/models/transformer_decoder_mp.py b/metaseq/models/transformer_decoder_mp.py deleted file mode 100644 index a6310434b..000000000 --- a/metaseq/models/transformer_decoder_mp.py +++ /dev/null @@ -1,79 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree.
- -import logging - -from metaseq.modules import ( - ModelParallelTransformerDecoderLayer, -) -from metaseq.models.transformer_decoder import TransformerDecoder - -try: - from megatron import mpu - from megatron.mpu import ( - gather_from_tensor_model_parallel_region, - ) - - has_megatron_submodule = True -except (ImportError, ModuleNotFoundError): - has_megatron_submodule = False - - -logger = logging.getLogger(__name__) - - -class ModelParallelTransformerDecoder(TransformerDecoder): - """ - Model Parallel Transformer decoder consisting of *args.decoder_layers* layers. Each layer - is a :class:`ModelParallelTransformerDecoderLayer`. - """ - - def build_base_decoder_layer(self, args, **kwargs): - return ModelParallelTransformerDecoderLayer(args) - - def output_layer(self, features, **kwargs): - """Project features to the vocabulary size.""" - if not self.share_input_output_embed: - raise NotImplementedError( - "Model parallel training currently requires --share-decoder-input-output-embed" - ) - - is_sequence_parallel = getattr(self.args, "sequence_parallel", False) - if is_sequence_parallel: - input_parallel = features - else: - input_parallel = mpu.copy_to_tensor_model_parallel_region(features) - - # project back to size of vocabulary - x = mpu.LinearWithGradAccumulationAndAsyncCommunication.apply( - input_parallel, - self.output_projection.weight, - None, - False, # gradient_accumulation_fusion - False, # async_grad_allreduce - is_sequence_parallel, # sequence_parallel - ) - # Gather output if model is in inference mode (i.e. evallm or generation) cause both are not yet compatible with - # parallel vocab embeddings - if getattr(self.args, "criterion") != "vocab_parallel_cross_entropy" or getattr( - self, "inference", False - ): - x = gather_from_tensor_model_parallel_region(x).contiguous() - - return x - - # This hook used as proxy for tracking state if model is in eval or generation mode. - def make_generation_fast_(self, **unused): - self.inference = True - - def forward_embedding( - self, - *args, - ): - x, embed, positions = super().forward_embedding(*args) - is_sequence_parallel = getattr(self.args, "sequence_parallel", False) - if is_sequence_parallel: - x = mpu.scatter_to_sequence_parallel_region(x) - return x, embed, positions diff --git a/metaseq/models/transformer_lm.py b/metaseq/models/transformer_lm.py index d6dc12929..963ec192b 100644 --- a/metaseq/models/transformer_lm.py +++ b/metaseq/models/transformer_lm.py @@ -3,8 +3,15 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+import torch +import torch.nn as nn -import logging +try: + from megatron.mpu import VocabParallelEmbedding + + has_megatron_submodule = True +except (ImportError, ModuleNotFoundError): + has_megatron_submodule = False from dataclasses import dataclass, field from typing import Optional @@ -23,12 +30,14 @@ from metaseq.models.transformer_decoder import ( DEFAULT_MIN_PARAMS_TO_WRAP, TransformerDecoder, + ModelParallelTransformerDecoder, ) from metaseq.modules.embedding import Embedding from metaseq.modules.activation_functions import get_available_activation_fns -DEFAULT_MAX_TARGET_POSITIONS = 1024 +import logging +DEFAULT_MAX_TARGET_POSITIONS = 1024 logger = logging.getLogger(__name__) @@ -225,30 +234,134 @@ def build_embedding(cls, args, dictionary, embed_dim, path=None): ) +@register_model("model_parallel_transformer_lm") +class ModelParallelTransformerLanguageModel(TransformerLanguageModel): + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + if not has_megatron_submodule: + raise ImportError( + "\n\nPlease install megatron using the setup instructions!" + ) + + # make sure all arguments are present in older models + base_lm_architecture(args) + + task.source_dictionary.pad_to_multiple_(8) + task.target_dictionary.pad_to_multiple_(8) + + # task.source_dictionary.pad_to_multiple_(args.model_parallel_size * 8) + # task.target_dictionary.pad_to_multiple_(args.model_parallel_size * 8) + + if getattr(args, "max_target_positions", None) is None: + args.max_target_positions = getattr( + args, "tokens_per_sample", DEFAULT_MAX_TARGET_POSITIONS + ) + + embed_tokens = cls.build_embedding( + args, task.source_dictionary, args.decoder_embed_dim + ) + assert getattr( + args, "use_sharded_state", False + ), "Use sharded state must be True for tensor parallel, otherwise model saving and loading might be broken" + + if getattr(args, "sequence_parallel", False): + assert ( + getattr(args, "model_parallel_size", 1) > 1 + ), "--sequence-parallel only works when --model-parallel-size is greater than 1" + assert ( + getattr(args, "dropout", 0.0) == 0.0 + ), "haven't yet tested whether rng states are correct for dropout with sequence parallel" + assert ( + getattr(args, "activation_fn", "gelu") == "gelu" + or getattr(args, "activation_fn", "gelu") == "relu" + ), "For now, only gelu and relu are supported" + assert not getattr( + args, "checkpoint_activations", False + ), "Cannot set --checkpoint-activations with sequence parallel." + assert not getattr( + args, "distribute_checkpointed_activations", False + ), "Cannot set --distribute-checkpointed-activations with sequence parallel."
+ + decoder = ModelParallelTransformerDecoder( + args, + task.target_dictionary, + embed_tokens, + ) + return cls(decoder) + + @staticmethod + def add_args(parser): + TransformerLanguageModel.add_args(parser) + + @classmethod + def build_embedding(cls, args, dictionary, embed_dim, path=None): + def _vocab_init(tensor, **kwargs): + std = embed_dim**-0.5 + if getattr(args, "truncate_init", False): + nn.init.trunc_normal_(tensor, mean=0, std=std, a=-3 * std, b=3 * std) + else: + nn.init.normal_(tensor, mean=0, std=std) + nn.init.constant_(tensor[1], 0) + + def _vocab_init_megatron(tensor, **kwargs): + std = getattr(args, "megatron_init_sigma", 0.006) + if getattr(args, "truncate_init", False): + nn.init.trunc_normal_(tensor, mean=0, std=std, a=-3 * std, b=3 * std) + else: + nn.init.normal_(tensor, mean=0, std=std) + nn.init.constant_(tensor[1], 0) + + if getattr(args, "memory_efficient_fp16", False): + dtype = torch.bfloat16 if getattr(args, "bf16", False) else torch.half + else: + dtype = torch.float32 + + embed_tokens = VocabParallelEmbedding( + len(dictionary), + embed_dim, + dictionary.pad(), + init_method=_vocab_init_megatron + if getattr(args, "full_megatron_init", False) + else _vocab_init, + use_cpu_initialization=not getattr( + args, "tensor_parallel_init_model_on_gpu", False + ), + dtype=dtype, + ) + return embed_tokens + + def base_lm_architecture(args): + args.activation_fn = getattr(args, "activation_fn", "relu") args.dropout = getattr(args, "dropout", 0.1) args.attention_dropout = getattr(args, "attention_dropout", 0.0) - args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512) args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 2048) args.decoder_layers = getattr(args, "decoder_layers", 6) args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8) - args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False) - args.decoder_learned_sinusoidal = getattr(args, "decoder_learned_sinusoidal", False) - args.activation_fn = getattr(args, "activation_fn", "relu") - - args.add_bos_token = getattr(args, "add_bos_token", False) args.share_decoder_input_output_embed = getattr( args, "share_decoder_input_output_embed", False ) + args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False) + args.decoder_learned_sinusoidal = getattr(args, "decoder_learned_sinusoidal", False) args.no_scale_embedding = getattr(args, "no_scale_embedding", False) - args.checkpoint_activations = getattr(args, "checkpoint_activations", False) - args.offload_activations = getattr(args, "offload_activations", False) - if args.offload_activations: - args.checkpoint_activations = True + args.add_bos_token = getattr(args, "add_bos_token", False) -@register_model_architecture("transformer_lm", "transformer_lm_gpt") +@register_model_architecture("model_parallel_transformer_lm", "transformer_lm_megatron") +def transformer_lm_megatron(args): + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 3072) + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 3072 * 4) + args.decoder_layers = getattr(args, "decoder_layers", 72) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 32) + args.dropout = getattr(args, "dropout", 0.1) + args.attention_dropout = getattr(args, "attention_dropout", 0.1) + args.activation_fn = getattr(args, "activation_fn", "gelu") + base_lm_architecture(args) + + +@register_model_architecture("model_parallel_transformer_lm", "transformer_lm_gpt") def transformer_lm_gpt(args): args.decoder_embed_dim = 
getattr(args, "decoder_embed_dim", 768) args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 3072) @@ -260,7 +373,9 @@ def transformer_lm_gpt(args): base_lm_architecture(args) -@register_model_architecture("transformer_lm", "transformer_lm_gpt2_tiny") +@register_model_architecture( + "model_parallel_transformer_lm", "transformer_lm_gpt2_tiny" +) def transformer_lm_gpt2_tiny(args): args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 64) args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 64) diff --git a/metaseq/models/transformer_lm_mp.py b/metaseq/models/transformer_lm_mp.py deleted file mode 100644 index 04306465a..000000000 --- a/metaseq/models/transformer_lm_mp.py +++ /dev/null @@ -1,150 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -import torch -import torch.nn as nn -from metaseq.models.transformer_decoder_mp import ( - ModelParallelTransformerDecoder, -) -from metaseq.models import register_model, register_model_architecture -from metaseq.models.transformer_lm import TransformerLanguageModel - - -try: - from megatron.mpu import VocabParallelEmbedding - - has_megatron_submodule = True -except (ImportError, ModuleNotFoundError): - has_megatron_submodule = False - - -DEFAULT_MAX_TARGET_POSITIONS = 1024 - - -@register_model("model_parallel_transformer_lm") -class ModelParallelTransformerLanguageModel(TransformerLanguageModel): - @classmethod - def build_model(cls, args, task): - """Build a new model instance.""" - if not has_megatron_submodule: - raise ImportError( - "\n\nPlease install megatron using the setup instructions!" - ) - - # make sure all arguments are present in older models - base_lm_architecture(args) - - task.source_dictionary.pad_to_multiple_(8) - task.target_dictionary.pad_to_multiple_(8) - - # task.source_dictionary.pad_to_multiple_(args.model_parallel_size * 8) - # task.target_dictionary.pad_to_multiple_(args.model_parallel_size * 8) - - if getattr(args, "max_target_positions", None) is None: - args.max_target_positions = getattr( - args, "tokens_per_sample", DEFAULT_MAX_TARGET_POSITIONS - ) - - embed_tokens = cls.build_embedding( - args, task.source_dictionary, args.decoder_embed_dim - ) - assert getattr( - args, "use_sharded_state", False - ), "Use sharded state must be True for tensor parallel, otherwise model saving and loaded might be broken" - - if getattr(args, "sequence_parallel", False): - assert ( - getattr(args, "model_parallel_size", 1) > 1 - ), "--sequence-parallel only works when --model-parallel-size is greater than 1" - assert ( - getattr(args, "dropout", 0.0) == 0.0 - ), "havent yet tested if rng states are correct for dropout with seq_parallel" - assert ( - getattr(args, "activation_fn", "gelu") == "gelu" - or getattr(args, "activation_fn", "gelu") == "relu" - ), "For now only supports gelu and relu" - assert not getattr( - args, "checkpoint_activations", False - ), "Cannot set --checkpoint-activations with sequence parallel." - assert not getattr( - args, "distribute_checkpointed_activations", False - ), "Cannot set --distribute-checkpointed-activations with sequence parallel." 
- - decoder = ModelParallelTransformerDecoder( - args, - task.target_dictionary, - embed_tokens, - ) - return cls(decoder) - - @staticmethod - def add_args(parser): - TransformerLanguageModel.add_args(parser) - - @classmethod - def build_embedding(cls, args, dictionary, embed_dim, path=None): - def _vocab_init(tensor, **kwargs): - std = embed_dim**-0.5 - if getattr(args, "truncate_init", False): - nn.init.trunc_normal_(tensor, mean=0, std=std, a=-3 * std, b=3 * std) - else: - nn.init.normal_(tensor, mean=0, std=std) - nn.init.constant_(tensor[1], 0) - - def _vocab_init_megatron(tensor, **kwargs): - std = getattr(args, "megatron_init_sigma", 0.006) - if getattr(args, "truncate_init", False): - nn.init.trunc_normal_(tensor, mean=0, std=std, a=-3 * std, b=3 * std) - else: - nn.init.normal_(tensor, mean=0, std=std) - nn.init.constant_(tensor[1], 0) - - if getattr(args, "memory_efficient_fp16", False): - dtype = torch.bfloat16 if getattr(args, "bf16", False) else torch.half - else: - dtype = torch.float32 - - embed_tokens = VocabParallelEmbedding( - len(dictionary), - embed_dim, - dictionary.pad(), - init_method=_vocab_init_megatron - if getattr(args, "full_megatron_init", False) - else _vocab_init, - use_cpu_initialization=not getattr( - args, "tensor_parallel_init_model_on_gpu", False - ), - dtype=dtype, - ) - return embed_tokens - - -def base_lm_architecture(args): - args.activation_fn = getattr(args, "activation_fn", "relu") - args.dropout = getattr(args, "dropout", 0.1) - args.attention_dropout = getattr(args, "attention_dropout", 0.0) - args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512) - args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 2048) - args.decoder_layers = getattr(args, "decoder_layers", 6) - args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8) - args.share_decoder_input_output_embed = getattr( - args, "share_decoder_input_output_embed", False - ) - args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False) - args.decoder_learned_sinusoidal = getattr(args, "decoder_learned_sinusoidal", False) - args.no_scale_embedding = getattr(args, "no_scale_embedding", False) - args.add_bos_token = getattr(args, "add_bos_token", False) - - -@register_model_architecture("model_parallel_transformer_lm", "transformer_lm_megatron") -def transformer_lm_megatron(args): - args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 3072) - args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 3072 * 4) - args.decoder_layers = getattr(args, "decoder_layers", 72) - args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 32) - args.dropout = getattr(args, "dropout", 0.1) - args.attention_dropout = getattr(args, "attention_dropout", 0.1) - args.activation_fn = getattr(args, "activation_fn", "gelu") - base_lm_architecture(args) diff --git a/metaseq/modules/__init__.py b/metaseq/modules/__init__.py index d5d853eee..cfdf6dc45 100644 --- a/metaseq/modules/__init__.py +++ b/metaseq/modules/__init__.py @@ -14,8 +14,10 @@ from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding from .linear import Linear from .feedforward_network import FeedForwardNetwork -from .transformer_decoder_layer import TransformerDecoderLayer -from .transformer_decoder_layer_mp import ModelParallelTransformerDecoderLayer +from .transformer_decoder_layer import ( + TransformerDecoderLayer, + ModelParallelTransformerDecoderLayer, +) from .sequence_parallel_transformer_layer import SequeuceParallelTransformerBlock __all__ = [ 
diff --git a/metaseq/modules/transformer_decoder_layer.py b/metaseq/modules/transformer_decoder_layer.py index 9a96d9dbc..07f677a28 100644 --- a/metaseq/modules/transformer_decoder_layer.py +++ b/metaseq/modules/transformer_decoder_layer.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. from typing import Dict, Optional - +import math import torch import torch.nn as nn from torch import Tensor @@ -13,6 +13,7 @@ from metaseq.modules import ( ActivationFn, MultiheadAttention, + ModelParallelMultiheadAttention, Dropout, FeedForwardNetwork, LayerNorm, @@ -23,6 +24,20 @@ load_megatron_fused_kernel, ) +try: + from megatron.mpu import ( + ColumnParallelLinear, + RowParallelLinear, + ) + + has_megatron_submodule = True +except (ImportError, ModuleNotFoundError): + has_megatron_submodule = False + + +def _weight_init(weight): + return nn.init.kaiming_uniform_(weight, a=math.sqrt(5)) + class TransformerDecoderLayer(nn.Module): """Pre-norm Decoder layer block. @@ -247,3 +262,156 @@ def forward( def make_generation_fast_(self, **kwargs): pass + + +class ModelParallelTransformerDecoderLayer(TransformerDecoderLayer): + """Decoder layer block. + + See "Megatron-LM: https://arxiv.org/pdf/1909.08053.pdf" for more details. + """ + + # TODO[susanz]: unify method signatures with non-model-parallel version. + def build_fc1( + self, + input_dim, + output_dim, + initialize_params_on_gpu, + full_megatron_init, + megatron_init_sigma, + dtype, + disable_bias=False, + truncate_init=False, + ): + if not has_megatron_submodule: + raise ImportError( + "\n\nPlease install megatron using the setup instructions!" + ) + + def _init_method_bias(bias): + fan_in = input_dim + bound = 1 / math.sqrt(fan_in) + nn.init.uniform_(bias, -bound, bound) + + if full_megatron_init: + # Setting bias init method to None, initializes biases with zero. + init_method_weights = utils.init_method_normal( + megatron_init_sigma, truncate_init=truncate_init + ) + init_method_bias = None + else: + init_method_weights = _weight_init + init_method_bias = _init_method_bias + + return ColumnParallelLinear( + input_dim, + output_dim, + gather_output=False, + init_method=init_method_weights, + skip_bias_add=self.skip_bias_add, + init_method_bias=init_method_bias, + use_cpu_initialization=not initialize_params_on_gpu, + dtype=dtype, + bias=not disable_bias, + ) + + # TODO[susanz]: unify method signatures with non-model-parallel version. + # Note: only fc2 includes the full_megatron_init_scalar arg, given the scaled_init_method_normal call. + def build_fc2( + self, + input_dim, + output_dim, + initialize_params_on_gpu, + full_megatron_init, + full_megatron_init_scalar, + megatron_init_sigma, + num_layers, + dtype, + disable_bias=False, + truncate_init=False, + ): + if not has_megatron_submodule: + raise ImportError( + "\n\nPlease install megatron using the setup instructions!" + ) + + skip_bias_add = self.skip_bias_add + if full_megatron_init: + init_method_weights = utils.scaled_init_method_normal( + megatron_init_sigma * full_megatron_init_scalar, + num_layers, + truncate_init=truncate_init, + ) + else: + init_method_weights = _weight_init + + fc2 = RowParallelLinear( + input_dim, + output_dim, + input_is_parallel=True, + init_method=init_method_weights, + skip_bias_add=skip_bias_add, + use_cpu_initialization=not initialize_params_on_gpu, + bias=not disable_bias, + dtype=dtype, + ) + if not full_megatron_init: + # Copy nn.linear initialization to get same initialization as of non-model-parallel. 
+ # fan_in, _ = nn.init._calculate_fan_in_and_fan_out(fc2.weight) + fan_in = input_dim + bound = 1 / math.sqrt(fan_in) + nn.init.uniform_(fc2.bias, -bound, bound) + return fc2 + + def build_self_attention(self, embed_dim, args, **unused_kwargs): + return ModelParallelMultiheadAttention( + embed_dim=embed_dim, + num_heads=args.decoder_attention_heads, + dropout=args.attention_dropout, + self_attention=True, + use_cpu_initialization=not getattr( + args, "tensor_parallel_init_model_on_gpu", False + ), + full_megatron_init=getattr(args, "full_megatron_init", False), + full_megatron_init_scalar=getattr(args, "full_megatron_init_scalar", 1.0), + megatron_init_sigma=getattr(args, "megatron_init_sigma", 0.006), + num_layers=args.decoder_layers, + dtype=utils.get_model_init_dtype(args), + bias=not getattr(args, "disable_bias", False), + attn_variant=getattr(args, "attn_variant", "default"), + xf_attn_op=getattr(args, "xf_attn_op", None), + truncate_init=getattr(args, "truncate_init", None), + ) + + def forward_attention( + self, + query, + key, + value, + residual, + key_padding_mask=None, + incremental_state=None, + attn_mask=None, + ): + # This is calling into ModelParallelMultiheadAttention.forward + attn_output, attn_bias = self.self_attn( + query=query, + key=key, + value=value, + key_padding_mask=key_padding_mask, + incremental_state=incremental_state, + attn_mask=attn_mask, + ) + # Note [naman]: got rid of fused bias, dropout and residual because + # we no longer use dropout. We also avoid jit scripting because + # it seems to use additional gpu memory for dropout activations + # even when dropout is disabled. + if attn_bias is not None: + attn_output = attn_output + attn_bias.view(1, 1, -1) + + x = torch.nn.functional.dropout( + attn_output, + p=self.args.dropout, + training=self.training, + ) + x = x + residual + return x diff --git a/metaseq/modules/transformer_decoder_layer_mp.py b/metaseq/modules/transformer_decoder_layer_mp.py deleted file mode 100644 index 845625f80..000000000 --- a/metaseq/modules/transformer_decoder_layer_mp.py +++ /dev/null @@ -1,180 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -import math - -import torch -from torch import nn - -from metaseq import utils -from metaseq.modules import ModelParallelMultiheadAttention -from metaseq.modules import TransformerDecoderLayer - -try: - from megatron.mpu import ( - ColumnParallelLinear, - RowParallelLinear, - ) - - has_megatron_submodule = True -except (ImportError, ModuleNotFoundError): - has_megatron_submodule = False - - -def _weight_init(weight): - return nn.init.kaiming_uniform_(weight, a=math.sqrt(5)) - - -class ModelParallelTransformerDecoderLayer(TransformerDecoderLayer): - """Decoder layer block. - - See "Megatron-LM: https://arxiv.org/pdf/1909.08053.pdf" for more details. - - # TODO[susanz]: unify method signatures with non-model-parallel version. - def build_fc1( - self, - input_dim, - output_dim, - initialize_params_on_gpu, - full_megatron_init, - megatron_init_sigma, - dtype, - disable_bias=False, - truncate_init=False, - ): - if not has_megatron_submodule: - raise ImportError( - "\n\nPlease install megatron using the setup instructions!"
- ) - - def _init_method_bias(bias): - fan_in = input_dim - bound = 1 / math.sqrt(fan_in) - nn.init.uniform_(bias, -bound, bound) - - if full_megatron_init: - # Setting bias init method to None, initializes biases with zero. - init_method_weights = utils.init_method_normal( - megatron_init_sigma, truncate_init=truncate_init - ) - init_method_bias = None - else: - init_method_weights = _weight_init - init_method_bias = _init_method_bias - - return ColumnParallelLinear( - input_dim, - output_dim, - gather_output=False, - init_method=init_method_weights, - skip_bias_add=self.skip_bias_add, - init_method_bias=init_method_bias, - use_cpu_initialization=not initialize_params_on_gpu, - dtype=dtype, - bias=not disable_bias, - ) - - # TODO[susanz]: unify method signatures with non-model-parallel version. - # Note: only fc2 includes the full_megatron_init_scalar arg, given the scaled_init_method_normal call. - def build_fc2( - self, - input_dim, - output_dim, - initialize_params_on_gpu, - full_megatron_init, - full_megatron_init_scalar, - megatron_init_sigma, - num_layers, - dtype, - disable_bias=False, - truncate_init=False, - ): - if not has_megatron_submodule: - raise ImportError( - "\n\nPlease install megatron using the setup instructions!" - ) - - skip_bias_add = self.skip_bias_add - if full_megatron_init: - init_method_weights = utils.scaled_init_method_normal( - megatron_init_sigma * full_megatron_init_scalar, - num_layers, - truncate_init=truncate_init, - ) - else: - init_method_weights = _weight_init - - fc2 = RowParallelLinear( - input_dim, - output_dim, - input_is_parallel=True, - init_method=init_method_weights, - skip_bias_add=skip_bias_add, - use_cpu_initialization=not initialize_params_on_gpu, - bias=not disable_bias, - dtype=dtype, - ) - if not full_megatron_init: - # Copy nn.linear initialization to get same initialization as of non-model-parallel. - # fan_in, _ = nn.init._calculate_fan_in_and_fan_out(fc2.weight) - fan_in = input_dim - bound = 1 / math.sqrt(fan_in) - nn.init.uniform_(fc2.bias, -bound, bound) - return fc2 - - def build_self_attention(self, embed_dim, args, **unused_kwargs): - return ModelParallelMultiheadAttention( - embed_dim=embed_dim, - num_heads=args.decoder_attention_heads, - dropout=args.attention_dropout, - self_attention=True, - use_cpu_initialization=not getattr( - args, "tensor_parallel_init_model_on_gpu", False - ), - full_megatron_init=getattr(args, "full_megatron_init", False), - full_megatron_init_scalar=getattr(args, "full_megatron_init_scalar", 1.0), - megatron_init_sigma=getattr(args, "megatron_init_sigma", 0.006), - num_layers=args.decoder_layers, - dtype=utils.get_model_init_dtype(args), - bias=not getattr(args, "disable_bias", False), - attn_variant=getattr(args, "attn_variant", "default"), - xf_attn_op=getattr(args, "xf_attn_op", None), - truncate_init=getattr(args, "truncate_init", None), - ) - - def forward_attention( - self, - query, - key, - value, - residual, - key_padding_mask=None, - incremental_state=None, - attn_mask=None, - ): - # This is calling into ModelParallelMultiheadAttention.forward - attn_output, attn_bias = self.self_attn( - query=query, - key=key, - value=value, - key_padding_mask=key_padding_mask, - incremental_state=incremental_state, - attn_mask=attn_mask, - ) - # Note [naman]: got rid off fused bias, dropout and residual cause - # now we dont use dropout. And we dont use jit scripting also cause - # it seems to use additional gpu memory for activations for dropout - # even when its disabled. 
- if attn_bias is not None: - attn_output = attn_output + attn_bias.view(1, 1, -1) - - x = torch.nn.functional.dropout( - attn_output, - p=self.args.dropout, - training=self.training, - ) - x = x + residual - return x
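For reference, a minimal sketch of the import paths after this consolidation (assumptions: a metaseq checkout with the Megatron submodule installed; all module, class, and architecture names below are taken from the diff above):

# Sketch only: the deleted *_mp.py modules are gone; the model-parallel classes
# now live alongside their base classes and are re-exported from metaseq.modules.
from metaseq.models.transformer_decoder import (
    TransformerDecoder,
    ModelParallelTransformerDecoder,
)
from metaseq.models.transformer_lm import ModelParallelTransformerLanguageModel
from metaseq.modules import (
    TransformerDecoderLayer,
    ModelParallelTransformerDecoderLayer,
)

# The model-parallel LM is registered as "model_parallel_transformer_lm", and the
# transformer_lm_megatron, transformer_lm_gpt and transformer_lm_gpt2_tiny
# architectures now resolve to it (see the register_model_architecture calls in
# metaseq/models/transformer_lm.py above).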