
Commit

Unify-transformer_lm_megatron-and-transformer_lm (#634)
* move to transformer_lm_mp.py

* remove transformer_lm_mp

* add arch transformer_lm_gpt and transformer_lm_gpt2_tiny

* update arch to model_parallel_transformer_lm

* rename to transformer_lm.py

* remove transformer_decoder_mp.py

* remove transformer_decoder_layer_mp.py

* fix lint errors

---------

Co-authored-by: Nikolay Bashlykov <[email protected]>
bashnick and Nikolay Bashlykov authored Feb 4, 2023
1 parent 5c819d7 commit 51871bd
Showing 7 changed files with 369 additions and 427 deletions.
68 changes: 67 additions & 1 deletion metaseq/models/transformer_decoder.py
@@ -3,8 +3,8 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import math
import logging
from typing import Any, Dict, List, Optional

import torch
@@ -20,10 +20,21 @@
LayerNorm,
PositionalEmbedding,
TransformerDecoderLayer,
ModelParallelTransformerDecoderLayer,
Linear,
)
from metaseq.modules.checkpoint_activations import checkpoint_wrapper

try:
from megatron import mpu
from megatron.mpu import (
gather_from_tensor_model_parallel_region,
)

has_megatron_submodule = True
except (ImportError, ModuleNotFoundError):
has_megatron_submodule = False

logger = logging.getLogger(__name__)


@@ -497,3 +508,58 @@ def buffered_future_mask(self, tensor, input_tokens=None):
return self._future_mask
else:
return self._future_mask[:cur_seq_len, :cur_seq_len]


class ModelParallelTransformerDecoder(TransformerDecoder):
"""
Model Parallel Transformer decoder consisting of *args.decoder_layers* layers. Each layer
is a :class:`ModelParallelTransformerDecoderLayer`.
"""

def build_base_decoder_layer(self, args, **kwargs):
return ModelParallelTransformerDecoderLayer(args)

def output_layer(self, features, **kwargs):
"""Project features to the vocabulary size."""
if not self.share_input_output_embed:
raise NotImplementedError(
"Model parallel training currently requires --share-decoder-input-output-embed"
)

is_sequence_parallel = getattr(self.args, "sequence_parallel", False)
if is_sequence_parallel:
input_parallel = features
else:
input_parallel = mpu.copy_to_tensor_model_parallel_region(features)

# project back to size of vocabulary
x = mpu.LinearWithGradAccumulationAndAsyncCommunication.apply(
input_parallel,
self.output_projection.weight,
None,
False, # gradient_accumulation_fusion
False, # async_grad_allreduce
is_sequence_parallel, # sequence_parallel
)
# Gather the output if the model is in inference mode (i.e. evallm or generation), because both are not yet
# compatible with parallel vocab embeddings.
if getattr(self.args, "criterion") != "vocab_parallel_cross_entropy" or getattr(
self, "inference", False
):
x = gather_from_tensor_model_parallel_region(x).contiguous()

return x

# This hook is used as a proxy for tracking whether the model is in eval or generation mode.
def make_generation_fast_(self, **unused):
self.inference = True

def forward_embedding(
self,
*args,
):
x, embed, positions = super().forward_embedding(*args)
is_sequence_parallel = getattr(self.args, "sequence_parallel", False)
if is_sequence_parallel:
x = mpu.scatter_to_sequence_parallel_region(x)
return x, embed, positions
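
Editorial aside, not part of this commit: a minimal single-process sketch of what the tied output projection in ModelParallelTransformerDecoder.output_layer computes when --share-decoder-input-output-embed is set. The toy sizes and names below are assumptions for illustration only; the tensor-parallel vocabulary split and the gather_from_tensor_model_parallel_region call are intentionally omitted.

import torch
import torch.nn.functional as F

# Toy sizes, chosen only for illustration.
vocab_size, embed_dim, seq_len = 32, 8, 5

# Shared token embedding (pad index assumed to be 1, matching the tensor[1]
# row zeroed by the vocab init helpers in transformer_lm.py below).
embed = torch.nn.Embedding(vocab_size, embed_dim, padding_idx=1)

tokens = torch.randint(0, vocab_size, (1, seq_len))
features = embed(tokens)  # stand-in for the decoder's output features

# Tied projection back to vocabulary size: logits = features @ embed.weight.T.
# Under tensor parallelism each rank holds only a vocabulary shard of this
# matrix, and the per-rank outputs are gathered for inference.
logits = F.linear(features, embed.weight)
assert logits.shape == (1, seq_len, vocab_size)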
79 changes: 0 additions & 79 deletions metaseq/models/transformer_decoder_mp.py

This file was deleted.

143 changes: 129 additions & 14 deletions metaseq/models/transformer_lm.py
@@ -3,8 +3,15 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn

import logging
try:
from megatron.mpu import VocabParallelEmbedding

has_megatron_submodule = True
except (ImportError, ModuleNotFoundError):
has_megatron_submodule = False

from dataclasses import dataclass, field
from typing import Optional
@@ -23,12 +30,14 @@
from metaseq.models.transformer_decoder import (
DEFAULT_MIN_PARAMS_TO_WRAP,
TransformerDecoder,
ModelParallelTransformerDecoder,
)
from metaseq.modules.embedding import Embedding
from metaseq.modules.activation_functions import get_available_activation_fns

DEFAULT_MAX_TARGET_POSITIONS = 1024
import logging

DEFAULT_MAX_TARGET_POSITIONS = 1024
logger = logging.getLogger(__name__)


@@ -225,30 +234,134 @@ def build_embedding(cls, args, dictionary, embed_dim, path=None):
)


@register_model("model_parallel_transformer_lm")
class ModelParallelTransformerLanguageModel(TransformerLanguageModel):
@classmethod
def build_model(cls, args, task):
"""Build a new model instance."""
if not has_megatron_submodule:
raise ImportError(
"\n\nPlease install megatron using the setup instructions!"
)

# make sure all arguments are present in older models
base_lm_architecture(args)

task.source_dictionary.pad_to_multiple_(8)
task.target_dictionary.pad_to_multiple_(8)

# task.source_dictionary.pad_to_multiple_(args.model_parallel_size * 8)
# task.target_dictionary.pad_to_multiple_(args.model_parallel_size * 8)

if getattr(args, "max_target_positions", None) is None:
args.max_target_positions = getattr(
args, "tokens_per_sample", DEFAULT_MAX_TARGET_POSITIONS
)

embed_tokens = cls.build_embedding(
args, task.source_dictionary, args.decoder_embed_dim
)
assert getattr(
args, "use_sharded_state", False
), "Use sharded state must be True for tensor parallel, otherwise model saving and loading might be broken"

if getattr(args, "sequence_parallel", False):
assert (
getattr(args, "model_parallel_size", 1) > 1
), "--sequence-parallel only works when --model-parallel-size is greater than 1"
assert (
getattr(args, "dropout", 0.0) == 0.0
), "haven't yet tested if rng states are correct for dropout with seq_parallel"
assert (
getattr(args, "activation_fn", "gelu") == "gelu"
or getattr(args, "activation_fn", "gelu") == "relu"
), "For now only supports gelu and relu"
assert not getattr(
args, "checkpoint_activations", False
), "Cannot set --checkpoint-activations with sequence parallel."
assert not getattr(
args, "distribute_checkpointed_activations", False
), "Cannot set --distribute-checkpointed-activations with sequence parallel."

decoder = ModelParallelTransformerDecoder(
args,
task.target_dictionary,
embed_tokens,
)
return cls(decoder)

@staticmethod
def add_args(parser):
TransformerLanguageModel.add_args(parser)

@classmethod
def build_embedding(cls, args, dictionary, embed_dim, path=None):
def _vocab_init(tensor, **kwargs):
std = embed_dim**-0.5
if getattr(args, "truncate_init", False):
nn.init.trunc_normal_(tensor, mean=0, std=std, a=-3 * std, b=3 * std)
else:
nn.init.normal_(tensor, mean=0, std=std)
nn.init.constant_(tensor[1], 0)

def _vocab_init_megatron(tensor, **kwargs):
std = getattr(args, "megatron_init_sigma", 0.006)
if getattr(args, "truncate_init", False):
nn.init.trunc_normal_(tensor, mean=0, std=std, a=-3 * std, b=3 * std)
else:
nn.init.normal_(tensor, mean=0, std=std)
nn.init.constant_(tensor[1], 0)

if getattr(args, "memory_efficient_fp16", False):
dtype = torch.bfloat16 if getattr(args, "bf16", False) else torch.half
else:
dtype = torch.float32

embed_tokens = VocabParallelEmbedding(
len(dictionary),
embed_dim,
dictionary.pad(),
init_method=_vocab_init_megatron
if getattr(args, "full_megatron_init", False)
else _vocab_init,
use_cpu_initialization=not getattr(
args, "tensor_parallel_init_model_on_gpu", False
),
dtype=dtype,
)
return embed_tokens
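
A second illustrative aside (again, not part of the diff): the two init helpers above differ only in the standard deviation they use (embed_dim ** -0.5 for _vocab_init, args.megatron_init_sigma for _vocab_init_megatron); both zero row 1, the pad row, and truncate at plus or minus three standard deviations when --truncate-init is set. A minimal standalone sketch of that pattern, with toy sizes and a hypothetical helper name:

import torch
import torch.nn as nn

def init_vocab_weight(weight, std, truncate=False, pad_idx=1):
    # Hypothetical helper mirroring _vocab_init / _vocab_init_megatron above:
    # normal (or truncated-normal at +/- 3 std) init, then zero the pad row.
    if truncate:
        nn.init.trunc_normal_(weight, mean=0.0, std=std, a=-3 * std, b=3 * std)
    else:
        nn.init.normal_(weight, mean=0.0, std=std)
    nn.init.constant_(weight[pad_idx], 0.0)
    return weight

embed_dim = 8  # toy size
weight = torch.empty(32, embed_dim)
init_vocab_weight(weight, std=embed_dim**-0.5, truncate=True)
assert torch.all(weight[1] == 0)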


def base_lm_architecture(args):
args.activation_fn = getattr(args, "activation_fn", "relu")
args.dropout = getattr(args, "dropout", 0.1)
args.attention_dropout = getattr(args, "attention_dropout", 0.0)

args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512)
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 2048)
args.decoder_layers = getattr(args, "decoder_layers", 6)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
args.decoder_learned_sinusoidal = getattr(args, "decoder_learned_sinusoidal", False)
args.activation_fn = getattr(args, "activation_fn", "relu")

args.add_bos_token = getattr(args, "add_bos_token", False)
args.share_decoder_input_output_embed = getattr(
args, "share_decoder_input_output_embed", False
)
args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
args.decoder_learned_sinusoidal = getattr(args, "decoder_learned_sinusoidal", False)
args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
args.checkpoint_activations = getattr(args, "checkpoint_activations", False)
args.offload_activations = getattr(args, "offload_activations", False)
if args.offload_activations:
args.checkpoint_activations = True
args.add_bos_token = getattr(args, "add_bos_token", False)


@register_model_architecture("transformer_lm", "transformer_lm_gpt")
@register_model_architecture("model_parallel_transformer_lm", "transformer_lm_megatron")
def transformer_lm_megatron(args):
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 3072)
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 3072 * 4)
args.decoder_layers = getattr(args, "decoder_layers", 72)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 32)
args.dropout = getattr(args, "dropout", 0.1)
args.attention_dropout = getattr(args, "attention_dropout", 0.1)
args.activation_fn = getattr(args, "activation_fn", "gelu")
base_lm_architecture(args)


@register_model_architecture("model_parallel_transformer_lm", "transformer_lm_gpt")
def transformer_lm_gpt(args):
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 768)
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 3072)
@@ -260,7 +373,9 @@ def transformer_lm_gpt(args):
base_lm_architecture(args)


@register_model_architecture("transformer_lm", "transformer_lm_gpt2_tiny")
@register_model_architecture(
"model_parallel_transformer_lm", "transformer_lm_gpt2_tiny"
)
def transformer_lm_gpt2_tiny(args):
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 64)
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 64)
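
One last illustrative aside, not taken from the diff: the architecture presets above (transformer_lm_megatron, transformer_lm_gpt, transformer_lm_gpt2_tiny) all follow the same pattern, where each getattr(args, name, default) assignment fills a hyperparameter only when the user has not already set it, and base_lm_architecture then fills whatever remains. A toy demonstration with illustrative values and a hypothetical preset name:

from argparse import Namespace

def toy_preset(args):
    # Same getattr pattern as the arch functions above; the defaults here
    # are illustrative, not the real preset values.
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 768)
    args.decoder_layers = getattr(args, "decoder_layers", 12)
    return args

args = Namespace(decoder_layers=24)   # pretend the user overrode one value
toy_preset(args)
assert args.decoder_layers == 24      # user override preserved
assert args.decoder_embed_dim == 768  # preset default applied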
