diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 7de35027a..249255306 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -33,7 +33,7 @@ repos:
     hooks:
       - id: codespell
         args: [
-            '--ignore-words-list=reord,dout',  # Word used in error messages that need rewording
+            '--ignore-words-list=reord,dout,te',  # Words used in error messages that need rewording; te --> transformer engine
            --check-filenames,
            --check-hidden,
        ]
diff --git a/configs/neox_arguments.md b/configs/neox_arguments.md
index 413138597..d24b2b60a 100644
--- a/configs/neox_arguments.md
+++ b/configs/neox_arguments.md
@@ -111,7 +111,7 @@ Logging Arguments

 - **git_hash**: str

-    Default = 53d0ae8
+    Default = 217b4c5

     current git hash of repository

@@ -335,11 +335,11 @@ Model Arguments

-- **norm**: typing.Literal['layernorm', 'rmsnorm', 'scalenorm']
+- **norm**: typing.Literal['layernorm', 'rmsnorm', 'scalenorm', 'te_rmsnorm', 'te_layernorm']

    Default = layernorm

-    Normalization layer to use. Choose from "layernorm", "rmsnorm", "scalenorm".
+    Normalization layer to use. Choose from "layernorm", "rmsnorm", "scalenorm", "te_rmsnorm", "te_layernorm".

diff --git a/megatron/model/fused_layer_norm.py b/megatron/model/fused_layer_norm.py
index d33ded506..3fd251147 100644
--- a/megatron/model/fused_layer_norm.py
+++ b/megatron/model/fused_layer_norm.py
@@ -21,7 +21,10 @@
 except:
     HAVE_PERSIST_LAYER_NORM = False

-from apex.normalization.fused_layer_norm import FusedLayerNormAffineFunction
+from apex.normalization.fused_layer_norm import (
+    FusedLayerNormAffineFunction,
+    FusedRMSNormAffineFunction,
+)

 global fused_layer_norm_cuda

@@ -148,3 +151,112 @@ def forward(self, input):
         )

         return output
+
+
+class MixedFusedRMSNorm(torch.nn.Module):
+    def __init__(
+        self,
+        normalized_shape,
+        eps=1e-5,
+        no_persist_layer_norm=True,
+        sequence_parallel=False,
+        apply_rmsnorm_1p=False,
+        mem_efficient_rms=True,
+    ):
+        super(MixedFusedRMSNorm, self).__init__()
+
+        self.apply_rmsnorm_1p = apply_rmsnorm_1p
+        self.mem_efficient_rms = mem_efficient_rms
+        self.norm_fn = FusedRMSNormAffineFunction
+
+        global fused_layer_norm_cuda
+        fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
+
+        # List of hidden sizes supported in the persistent layer norm kernel.
+        # If the hidden size is not supported, fall back to the non-persistent
+        # kernel.
+        persist_ln_hidden_sizes = [
+            1024,
+            1536,
+            2048,
+            2304,
+            3072,
+            3840,
+            4096,
+            5120,
+            6144,
+            8192,
+            10240,
+            12288,
+            12800,
+            15360,
+            16384,
+            18432,
+            20480,
+            24576,
+            25600,
+            30720,
+            32768,
+            40960,
+            49152,
+            65536,
+        ]
+        if (
+            normalized_shape not in persist_ln_hidden_sizes
+            or not HAVE_PERSIST_LAYER_NORM
+        ):
+            no_persist_layer_norm = True
+
+        if isinstance(normalized_shape, numbers.Integral):
+            normalized_shape = (normalized_shape,)
+        self.normalized_shape = torch.Size(normalized_shape)
+        self.eps = eps
+        self.scale = Parameter(torch.Tensor(*normalized_shape))
+        self.reset_parameters()
+        self.no_persist_layer_norm = no_persist_layer_norm
+        self.sequence_parallel = sequence_parallel
+
+        # set sequence parallelism flag on the scale parameter
+        setattr(self.scale, "sequence_parallel", self.sequence_parallel)
+
+    def reset_parameters(self):
+
+        if self.apply_rmsnorm_1p:
+            init.zeros_(self.scale)
+        else:
+            init.ones_(self.scale)
+
+    def forward(self, input):
+
+        weight = self.scale + 1 if self.apply_rmsnorm_1p else self.scale
+        # CPU path is here for the unit tests only.
+        if not input.is_cuda:
+            print(
+                "WARNING! The input of FusedRMSNorm should be on the GPU. "
+                "This warning should only be triggered in the FusedRMSNorm unit tests."
+            )
+            # Recent PyTorch versions provide F.rms_norm, but F.layer_norm is used
+            # here to avoid breaking older builds; this branch is for unit tests only.
+            return F.layer_norm(input, self.normalized_shape, weight, None, self.eps)
+
+        # Apex does not expose version information yet (https://github.com/NVIDIA/apex/pull/1648),
+        # so inspect the function signature to check whether the extra argument
+        # introduced in https://github.com/NVIDIA/apex/pull/1715 exists yet.
+        if "memory_efficient" in inspect.getfullargspec(self.norm_fn.forward).args:
+            output = self.norm_fn.apply(
+                input,
+                weight,
+                self.normalized_shape,
+                self.eps,
+                self.mem_efficient_rms,
+            )
+        else:
+            output = self.norm_fn.apply(input, weight, self.normalized_shape, self.eps)
+
+        # Apex's fast layer norm function outputs a 'view' tensor (i.e., has
+        # a populated '_base' field). This will result in schedule.py's
+        # deallocate_output_tensor() throwing an error, so a viewless tensor is
+        # created to prevent this.
+        output = make_viewless_tensor(
+            inp=output, requires_grad=input.requires_grad, keep_graph=True
+        )
+
+        return output
diff --git a/megatron/model/norms.py b/megatron/model/norms.py
index 19e1aeae6..ba175d3eb 100644
--- a/megatron/model/norms.py
+++ b/megatron/model/norms.py
@@ -18,8 +18,13 @@

 def get_norm(neox_args):
     if neox_args.norm == "rmsnorm":
-        norm = RMSNorm
         eps = neox_args.rms_norm_epsilon
+        if neox_args.rmsnorm_fusion:
+            from .fused_layer_norm import MixedFusedRMSNorm
+
+            norm = MixedFusedRMSNorm
+        else:
+            norm = RMSNorm
     elif neox_args.norm == "layernorm":
         eps = neox_args.layernorm_epsilon
         if neox_args.layernorm_fusion:
@@ -31,6 +36,16 @@
     elif neox_args.norm == "scalenorm":
         eps = neox_args.scalenorm_epsilon
         norm = ScaleNorm
+    elif neox_args.norm == "te_rmsnorm":
+        from .transformer_engine import TERMSNorm
+
+        norm = TERMSNorm
+        eps = neox_args.rms_norm_epsilon
+    elif neox_args.norm == "te_layernorm":
+        from .transformer_engine import TELayerNorm
+
+        norm = TELayerNorm
+        eps = neox_args.layernorm_epsilon
     else:
         raise ValueError(f"norm {neox_args.norm} not recognized")
     return norm, eps
diff --git a/megatron/model/transformer_engine.py b/megatron/model/transformer_engine.py
new file mode 100644
index 000000000..338513a97
--- /dev/null
+++ b/megatron/model/transformer_engine.py
@@ -0,0 +1,137 @@
+import torch
+
+try:
+    import transformer_engine as te
+except ImportError:
+    raise ImportError(
+        "Unable to import transformer-engine. Please refer to "
+        "https://github.com/NVIDIA/TransformerEngine for installation instructions."
+    )
+
+
+class TERMSNorm(torch.nn.Module):
+    def __init__(self, dim, eps=1e-8, **kwargs):
+        """
+        A thin wrapper that initializes an instance of Transformer Engine's
+        `RMSNorm` from the given arguments.
+        :param dim: model size
+        :param eps: epsilon value, default 1e-8
+        """
+        super(TERMSNorm, self).__init__()
+
+        self.d = dim
+        self.eps = eps
+        self.norm = te.pytorch.RMSNorm(
+            hidden_size=self.d,
+            eps=self.eps,
+            **kwargs,
+        )
+
+    def forward(self, x):
+        return self.norm(x)
+
+
+class TELayerNorm(torch.nn.Module):
+    def __init__(self, dim, eps=1.0e-5, **kwargs):
+        """
+        A thin wrapper that initializes an instance of Transformer Engine's
+        `LayerNorm` from the given arguments.
+        :param dim: model size
+        :param eps: epsilon value, default 1.0e-5
+        """
+        super(TELayerNorm, self).__init__()
+
+        self.d = dim
+        self.eps = eps
+        self.norm = te.pytorch.LayerNorm(
+            hidden_size=self.d,
+            eps=self.eps,
+            **kwargs,
+        )
+
+    def forward(self, x):
+        return self.norm(x)
+
+
+class TELinear(te.pytorch.Linear):
+    """
+    Wrapper for Transformer Engine's `Linear` layer.
+    """
+
+    def __init__(self):
+        # TODO
+        return
+
+    def forward(self, x):
+        # TODO
+        return
+
+
+class TELayerNormColumnParallelLinear(te.pytorch.LayerNormLinear):
+    """
+    Wrapper for Transformer Engine's `LayerNormLinear` layer, which combines
+    the layernorm and linear layers.
+    """
+
+    def __init__(self):
+        # TODO
+        return
+
+    def forward(self, x):
+        # TODO
+        return
+
+
+class TEColumnParallelLinear(TELinear):
+    """
+    Wrapper for Transformer Engine's `Linear` layer, specialized to mirror
+    Megatron's `ColumnParallelLinear` layer.
+    """
+
+    def __init__(self):
+        # TODO
+        return
+
+    def forward(self, x):
+        # TODO
+        return
+
+
+class TERowParallelLinear(TELinear):
+    """
+    Wrapper for Transformer Engine's `Linear` layer, specialized to mirror
+    Megatron's `RowParallelLinear` layer.
+    """
+
+    def __init__(self):
+        # TODO
+        return
+
+    def forward(self, x):
+        # TODO
+        return
+
+
+class TEDotProductAttention(te.pytorch.DotProductAttention):
+    """
+    Wrapper for Transformer Engine's `DotProductAttention` layer, which also
+    has "flash attention" enabled.
+    """
+
+    def __init__(self):
+        # TODO
+        return
+
+    def forward(self, x):
+        # TODO
+        return
+
+
+class TEDelayedScaling(te.common.recipe.DelayedScaling):
+    """
+    Wrapper for Transformer Engine's `DelayedScaling` FP8 recipe.
+    """
+
+    def __init__(self):
+        # TODO
+        return
diff --git a/megatron/model/utils.py b/megatron/model/utils.py
index 97b409c1d..77e7f521d 100644
--- a/megatron/model/utils.py
+++ b/megatron/model/utils.py
@@ -18,6 +18,7 @@
 """Utilities for models."""

 import torch
+from megatron.model.norms import LayerNorm, RMSNorm, ScaleNorm
+
+try:
+    from megatron.model.transformer_engine import TELayerNorm, TERMSNorm
+except ImportError:
+    # transformer-engine is optional; isinstance() accepts a tuple of types,
+    # so empty tuples make the TE checks below always evaluate to False.
+    TELayerNorm, TERMSNorm = (), ()
 from megatron.model.fused_softmax import SoftmaxFusionTypes
 from megatron import mpu
 from types import GeneratorType
@@ -35,9 +36,17 @@ def get_params_for_weight_decay_optimization(module, neox_args):
         "name": "no_weight_decay_params",
     }
     for module_ in module.modules():
-        # apply weight decay to any "...Norm" modules.
-        if "norm" in type(module_).__name__.lower() or neox_args.weight_decay == 0.0:
-            # also include all parameters here if no weight decay is being done
+        if any(
+            [
+                isinstance(module_, LayerNorm),
+                isinstance(module_, RMSNorm),
+                isinstance(module_, TELayerNorm),
+                isinstance(module_, TERMSNorm),
+                isinstance(module_, ScaleNorm),
+            ]
+        ) or (
+            neox_args.weight_decay == 0.0
+        ):  # also include all parameters here if no weight decay is being done
             no_weight_decay_params["params"].extend(
                 [p for p in list(module_._parameters.values()) if p is not None]
             )
diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py
index 814622a5b..b5e7a619d 100644
--- a/megatron/neox_arguments/neox_args.py
+++ b/megatron/neox_arguments/neox_args.py
@@ -162,9 +162,11 @@ class NeoXArgsModel(NeoXArgsTemplate):
     Maximum number of position embeddings to use. This is the size of position embedding.
     """

-    norm: Literal["layernorm", "rmsnorm", "scalenorm"] = "layernorm"
+    norm: Literal[
+        "layernorm", "rmsnorm", "scalenorm", "te_rmsnorm", "te_layernorm"
+    ] = "layernorm"
     """
-    Normalization layer to use. Choose from "layernorm", "rmsnorm", "scalenorm".
+    Normalization layer to use. Choose from "layernorm", "rmsnorm", "scalenorm", "te_rmsnorm", "te_layernorm".
     """

     layernorm_fusion: bool = False
@@ -172,6 +174,11 @@ class NeoXArgsModel(NeoXArgsTemplate):
     Use fused layer norm kernel (if `norm` is `layernorm`).
     """

+    rmsnorm_fusion: bool = False
+    """
+    Use fused RMS norm kernel (if `norm` is `rmsnorm`).
+    """
+
     use_qk_layernorm: bool = False
     """
     Use QK Normalization
diff --git a/requirements/requirements-transformerengine.txt b/requirements/requirements-transformerengine.txt
new file mode 100644
index 000000000..2050d7566
--- /dev/null
+++ b/requirements/requirements-transformerengine.txt
@@ -0,0 +1 @@
+git+https://github.com/NVIDIA/TransformerEngine.git@stable
diff --git a/tests/README.md b/tests/README.md
index f5ba5e560..32618d757 100644
--- a/tests/README.md
+++ b/tests/README.md
@@ -3,6 +3,7 @@
 Tests use pytests with coverage and forked plugins. Install with:

 ```bash
+pip install -r requirements/requirements.txt
 pip install -r requirements/requirements-dev.txt
 ```
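
The patch above adds `te_rmsnorm` and `te_layernorm` as `norm` options (resolved by `get_norm` in `megatron/model/norms.py`) plus a `rmsnorm_fusion` flag. The following is a minimal usage sketch, not part of the patch: it exercises the new `TERMSNorm`/`TELayerNorm` wrappers directly and assumes transformer-engine is installed, a CUDA device is available, and that the tensor shapes chosen here are arbitrary.

```python
# Minimal sketch of the TE norm wrappers added in megatron/model/transformer_engine.py.
# Assumes transformer-engine is installed and a CUDA device is available.
import torch

from megatron.model.transformer_engine import TELayerNorm, TERMSNorm

hidden_size = 4096
x = torch.randn(8, 2048, hidden_size, device="cuda")

# Both wrappers forward to te.pytorch.RMSNorm / te.pytorch.LayerNorm,
# normalizing over the last (hidden) dimension.
rmsnorm = TERMSNorm(dim=hidden_size, eps=1e-8).cuda()
layernorm = TELayerNorm(dim=hidden_size, eps=1e-5).cuda()

y_rms = rmsnorm(x)
y_ln = layernorm(x)
assert y_rms.shape == x.shape and y_ln.shape == x.shape
```

In a NeoX config the same modules are selected indirectly by setting `"norm": "te_rmsnorm"` or `"norm": "te_layernorm"`, with the epsilon taken from `rms_norm_epsilon` or `layernorm_epsilon` respectively.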