LayerNorm Refactor (#1269)
* Add TE skeleton

* Update NeoXArgs docs automatically

* added option for te version of norms

* import TERMSNorm

* add te norm options to norm arg

* add TE objects in weight decay function

* reformat

* add TERMSNorm and TELayerNorm

* Update NeoXArgs docs automatically

* - add Fused RMS Norm from apex

* - make it consistent with how layernorm looks

* Merged transformer engine and apex fused layernorm branches

* Added assertion if TE is used

* Removed unnecessary transformer-engine import

* Changed importerror text for TE

* Added requirements/requirements-transformerengine.txt

* update comments

* precommit

---------

Co-authored-by: Quentin Anthony <[email protected]>
Co-authored-by: github-actions <[email protected]>
Co-authored-by: lintangsutawika <lintang@stella-ord-0.stella-ord.tenant-eleutherai.svc.tenant.chi.local>
Co-authored-by: lintangsutawika <[email protected]>
Co-authored-by: dmahan93 <[email protected]>
Co-authored-by: aurelion-source <[email protected]>
Co-authored-by: aurelion-source <[email protected]>
8 people authored Sep 9, 2024
1 parent 77e8158 commit 836aefa
Showing 9 changed files with 293 additions and 11 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -33,7 +33,7 @@ repos:
hooks:
- id: codespell
args: [
'--ignore-words-list=reord,dout', # Word used in error messages that need rewording
'--ignore-words-list=reord,dout,te', # Word used in error messages that need rewording. te --> transformerengine
--check-filenames,
--check-hidden,
]
6 changes: 3 additions & 3 deletions configs/neox_arguments.md
@@ -111,7 +111,7 @@ Logging Arguments

- **git_hash**: str

Default = 53d0ae8
Default = 217b4c5

current git hash of repository

@@ -335,11 +335,11 @@ Model Arguments
- **norm**: typing.Literal['layernorm', 'rmsnorm', 'scalenorm']
- **norm**: typing.Literal['layernorm', 'rmsnorm', 'scalenorm', 'te_rmsnorm', 'te_layernorm']
Default = layernorm
Normalization layer to use. Choose from "layernorm", "rmsnorm", "scalenorm".
Normalization layer to use. Choose from "layernorm", "rmsnorm", "scalenorm", "te_rmsnorm", "te_layernorm".
114 changes: 113 additions & 1 deletion megatron/model/fused_layer_norm.py
@@ -21,7 +21,10 @@
except:
HAVE_PERSIST_LAYER_NORM = False

from apex.normalization.fused_layer_norm import FusedLayerNormAffineFunction
from apex.normalization.fused_layer_norm import (
FusedLayerNormAffineFunction,
FusedRMSNormAffineFunction,
)


global fused_layer_norm_cuda
@@ -148,3 +151,112 @@ def forward(self, input):
)

return output


class MixedFusedRMSNorm(torch.nn.Module):
def __init__(
self,
normalized_shape,
eps=1e-5,
no_persist_layer_norm=True,
sequence_parallel=False,
apply_rmsnorm_1p=False,
mem_efficient_rms=True,
):
super(MixedFusedRMSNorm, self).__init__()

self.apply_rmsnorm_1p = apply_rmsnorm_1p
self.mem_efficient_rms = mem_efficient_rms
self.norm_fn = FusedRMSNormAffineFunction

global fused_layer_norm_cuda
fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")

# List of hiddens sizes supported in the persistent layer norm kernel
# If the hidden size is not supported, fall back to the non-persistent
# kernel.
persist_ln_hidden_sizes = [
1024,
1536,
2048,
2304,
3072,
3840,
4096,
5120,
6144,
8192,
10240,
12288,
12800,
15360,
16384,
18432,
20480,
24576,
25600,
30720,
32768,
40960,
49152,
65536,
]
if (
normalized_shape not in persist_ln_hidden_sizes
or not HAVE_PERSIST_LAYER_NORM
):
no_persist_layer_norm = True

if isinstance(normalized_shape, numbers.Integral):
normalized_shape = (normalized_shape,)
self.normalized_shape = torch.Size(normalized_shape)
self.eps = eps
self.scale = Parameter(torch.Tensor(*normalized_shape))
self.reset_parameters()
self.no_persist_layer_norm = no_persist_layer_norm
self.sequence_parallel = sequence_parallel

# set sequence parallelism flag on weight and bias parameters
setattr(self.scale, "sequence_parallel", self.sequence_parallel)

def reset_parameters(self):

if self.apply_rmsnorm_1p:
init.zeros_(self.scale)
else:
init.ones_(self.scale)

def forward(self, input):

weight = self.scale + 1 if self.apply_rmsnorm_1p else self.scale
# CPU path is here for unittest sake.
if not input.is_cuda:
print(
"WARNING! The input of FusedLayerNorm should be on the GPU."
"This warning should only be triggered in the FusedRMSNorm unit tests."
)
# Latest pytorch actually supports F.rms_norm but I don't want to break builds so...
return F.layer_norm(input, self.normalized_shape, weight, None, self.eps)

# Apex does not have versions yet (https://github.com/NVIDIA/apex/pull/1648), so we need to inspect
# the function manually on whether the extra arg introduced in https://github.com/NVIDIA/apex/pull/1715 exists yet
if "memory_efficient" in inspect.getfullargspec(self.norm_fn.forward).args:
return self.norm_fn.apply(
input,
weight,
self.normalized_shape,
self.eps,
self.mem_efficient_rms,
)
else:
return self.norm_fn.apply(input, weight, self.normalized_shape, self.eps)

# Apex's fast layer norm function outputs a 'view' tensor (i.e., has
# a populated '_base' field). This will result in schedule.py's
# deallocate_output_tensor() throwing an error, so a viewless tensor is
# created to prevent this.
output = make_viewless_tensor(
inp=output, requires_grad=input.requires_grad, keep_graph=True
)

return output
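
For orientation, here is a minimal usage sketch of the new `MixedFusedRMSNorm`. It is not part of the diff and assumes apex is installed with its `fused_layer_norm_cuda` extension and that a CUDA device is available.

```python
# Hedged usage sketch (not from this commit). Assumes apex with the
# fused_layer_norm_cuda extension is installed and a GPU is present.
import torch

from megatron.model.fused_layer_norm import MixedFusedRMSNorm

hidden_size = 4096  # the model's hidden dimension
rms_norm = MixedFusedRMSNorm(hidden_size, eps=1e-5).cuda()

x = torch.randn(2, 8, hidden_size, device="cuda")
y = rms_norm(x)  # dispatches to apex's FusedRMSNormAffineFunction
assert y.shape == x.shape
```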
17 changes: 16 additions & 1 deletion megatron/model/norms.py
@@ -18,8 +18,13 @@

def get_norm(neox_args):
if neox_args.norm == "rmsnorm":
norm = RMSNorm
eps = neox_args.rms_norm_epsilon
if neox_args.rmsnorm_fusion:
from .fused_layer_norm import MixedFusedRMSNorm

norm = MixedFusedRMSNorm
else:
norm = RMSNorm
elif neox_args.norm == "layernorm":
eps = neox_args.layernorm_epsilon
if neox_args.layernorm_fusion:
@@ -31,6 +36,16 @@ def get_norm(neox_args):
elif neox_args.norm == "scalenorm":
eps = neox_args.scalenorm_epsilon
norm = ScaleNorm
elif neox_args.norm == "te_rmsnorm":
from .transformer_engine import TERMSNorm

norm = TERMSNorm
eps = neox_args.rms_norm_epsilon
elif neox_args.norm == "te_layernorm":
from .transformer_engine import TELayerNorm

norm = TELayerNorm
eps = neox_args.layernorm_epsilon
else:
raise ValueError(f"norm {neox_args.norm} not recognized")
return norm, eps
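As an illustrative sketch (not part of the diff), the class/epsilon pair returned by `get_norm` is typically consumed like this, where `neox_args` stands in for a configured argument object with a `hidden_size` attribute:

```python
# Illustrative only: how callers typically use get_norm's return values.
from megatron.model.norms import get_norm

norm_cls, eps = get_norm(neox_args)  # e.g. TELayerNorm, 1e-5 when norm == "te_layernorm"
final_norm = norm_cls(neox_args.hidden_size, eps=eps)
```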
137 changes: 137 additions & 0 deletions megatron/model/transformer_engine.py
@@ -0,0 +1,137 @@
import torch

try:
import transformer_engine as te
except ImportError:
raise ImportError(
"Unable to import transformer-engine. Please refer to "
"https://github.com/NVIDIA/TransformerEngine for installation instructions."
)


class TERMSNorm(torch.nn.Module):
def __init__(self, dim, eps=1e-8, **kwargs):
"""
A conditional wrapper to initialize an instance of Transformer-Engine's
`RMSNorm` based on input
:param dim: model size
:param eps: epsilon value, default 1e-8
"""
super(TERMSNorm, self).__init__()

self.d = dim
self.eps = eps
self.norm = te.pytorch.RMSNorm(
hidden_size=self.d,
eps=self.eps,
**kwargs,
)

def forward(self, x):
return self.norm(x)


class TELayerNorm(torch.nn.Module):
def __init__(self, dim, eps=1.0e-5, **kwargs):
"""
A conditional wrapper to initialize an instance of Transformer-Engine's
`LayerNorm` based on input
:param dim: model size
:param eps: epsilon value, default 1.0e-5
"""
super(TELayerNorm, self).__init__()

self.d = dim
self.eps = eps
self.norm = te.pytorch.LayerNorm(
hidden_size=self.d,
eps=self.eps,
**kwargs,
)

def forward(self, x):
return self.norm(x)


class TELinear(te.pytorch.Linear):
"""
Wrapper for the Transformer-Engine's `Linear` layer.
"""

def __init__(self):
# TODO
return

def forward(self, x):
# TODO
return


class TELayerNormColumnParallelLinear(te.pytorch.LayerNormLinear):
"""
Wrapper for the Transformer-Engine's `LayerNormLinear` layer that combines
layernorm and linear layers
"""

def __init__(self):
# TODO
return

def forward(self, x):
# TODO
return


class TEColumnParallelLinear(TELinear):
"""
Wrapper for the Transformer-Engine's `Linear` layer but specialized similar
to megatron's `ColumnParallelLinear` layer.
"""

def __init__(self):
# TODO
return

def forward(self, x):
# TODO
return


class TERowParallelLinear(TELinear):
"""
Wrapper for the Transformer-Engine's `Linear` layer but specialized similar
to megatron's `RowParallelLinear` layer.
"""

def __init__(self):
# TODO
return

def forward(self, x):
# TODO
return


class TEDotProductAttention(te.pytorch.DotProductAttention):
"""
Wrapper for the Transformer-Engine's `DotProductAttention` layer that also
has "flash attention" enabled.
"""

def __init__(self):
# TODO
return

def forward(self, x):
# TODO
return


class TEDelayedScaling(te.common.recipe.DelayedScaling):
"""
Wrapper for the Transformer-Engine's `DelayedScaling` layer.
"""

def __init__(self):
# TODO
return
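
A minimal usage sketch of the two functional wrappers above, `TELayerNorm` and `TERMSNorm` (not part of the commit); it assumes transformer-engine is installed and a CUDA device is available:

```python
# Hedged sketch: exercising TELayerNorm / TERMSNorm. Requires transformer-engine
# (see requirements/requirements-transformerengine.txt) and a CUDA device.
import torch

from megatron.model.transformer_engine import TELayerNorm, TERMSNorm

hidden = 1024
layer_norm = TELayerNorm(hidden, eps=1e-5).cuda()
rms_norm = TERMSNorm(hidden, eps=1e-8).cuda()

x = torch.randn(4, 16, hidden, device="cuda")
print(layer_norm(x).shape, rms_norm(x).shape)  # both preserve the input shape
```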
15 changes: 12 additions & 3 deletions megatron/model/utils.py
@@ -18,6 +18,7 @@
"""Utilities for models."""

import torch
from megatron.model.norms import LayerNorm, RMSNorm, ScaleNorm
from megatron.model.fused_softmax import SoftmaxFusionTypes
from megatron import mpu
from types import GeneratorType
@@ -35,9 +36,17 @@ def get_params_for_weight_decay_optimization(module, neox_args):
"name": "no_weight_decay_params",
}
for module_ in module.modules():
# apply weight decay to any "...Norm" modules.
if "norm" in type(module_).__name__.lower() or neox_args.weight_decay == 0.0:
# also include all parameters here if no weight decay is being done
if any(
[
isinstance(module_, LayerNorm),
isinstance(module_, RMSNorm),
isinstance(module_, TELayerNorm),
isinstance(module_, TERMSNorm),
isinstance(module_, ScaleNorm),
]
) or (
neox_args.weight_decay == 0.0
): # also include all parameters here if no weight decay is being done
no_weight_decay_params["params"].extend(
[p for p in list(module_._parameters.values()) if p is not None]
)
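For context, a hedged sketch (not from the diff) of how the returned parameter groups usually feed an optimizer; `model` and `neox_args` stand in for a constructed GPT-NeoX model and its argument object:

```python
# Illustrative only: the no_weight_decay_params group carries weight_decay=0.0,
# so norm parameters (including the new TE norms) are excluded from decay.
import torch

from megatron.model.utils import get_params_for_weight_decay_optimization

param_groups = get_params_for_weight_decay_optimization(model, neox_args)
optimizer = torch.optim.AdamW(
    param_groups,
    lr=1e-4,  # placeholder learning rate
    weight_decay=neox_args.weight_decay,
)
```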
11 changes: 9 additions & 2 deletions megatron/neox_arguments/neox_args.py
@@ -162,16 +162,23 @@ class NeoXArgsModel(NeoXArgsTemplate):
Maximum number of position embeddings to use. This is the size of position embedding.
"""

norm: Literal["layernorm", "rmsnorm", "scalenorm"] = "layernorm"
norm: Literal[
"layernorm", "rmsnorm", "scalenorm", "te_rmsnorm", "te_layernorm"
] = "layernorm"
"""
Normalization layer to use. Choose from "layernorm", "rmsnorm", "scalenorm".
Normalization layer to use. Choose from "layernorm", "rmsnorm", "scalenorm", "te_rmsnorm", "te_layernorm".
"""

layernorm_fusion: bool = False
"""
Use fused layer norm kernel (if `norm` is `layernorm`).
"""

rmsnorm_fusion: bool = False
"""
Use fused RMS norm kernel (if `norm` is `rmsnorm`).
"""

use_qk_layernorm: bool = False
"""
Use QK Normalization
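To make the new options concrete, a hypothetical config fragment (shown here as a Python dict for illustration; in a real run these keys live in the YAML configuration files passed to GPT-NeoX):

```python
# Hypothetical settings, not taken from the commit.
model_norm_settings = {
    "norm": "te_layernorm",     # or "layernorm", "rmsnorm", "scalenorm", "te_rmsnorm"
    "layernorm_fusion": False,  # fused kernel, only consulted when norm == "layernorm"
    "rmsnorm_fusion": False,    # new flag: fused apex kernel when norm == "rmsnorm"
}
```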
1 change: 1 addition & 0 deletions requirements/requirements-transformerengine.txt
@@ -0,0 +1 @@
pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable
1 change: 1 addition & 0 deletions tests/README.md
@@ -3,6 +3,7 @@
Tests use pytests with coverage and forked plugins. Install with:

```bash
pip install -r requirements/requirements.txt
pip install -r requirements/requirements-dev.txt
```

