Skip to content

Commit

Permalink
fix the te import
Browse files Browse the repository at this point in the history
  • Loading branch information
Quentin-Anthony committed Sep 9, 2024
1 parent 01e74f4 commit 3d7d706
Showing 1 changed file with 18 additions and 3 deletions.
21 changes: 18 additions & 3 deletions megatron/model/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,27 @@ def get_params_for_weight_decay_optimization(module, neox_args):
"name": "no_weight_decay_params",
}
for module_ in module.modules():
if any(
if neox_args.norm == "te_layernorm" or neox_args.norm == "te_rmsnorm":
try:
from megatron.model.transformer_engine import TELayerNorm, TERMSNorm
except ImportError:
raise ImportError(
"""Unable to import transformer-engine. Please refer to
https://github.com/NVIDIA/TransformerEngine for installation instructions."""
)
if any(
[
isinstance(module_, TELayerNorm),
isinstance(module_, TERMSNorm),
]
):
no_weight_decay_params["params"].extend(
[p for p in list(module_._parameters.values()) if p is not None]
)
elif any(
[
isinstance(module_, LayerNorm),
isinstance(module_, RMSNorm),
isinstance(module_, TELayerNorm),
isinstance(module_, TERMSNorm),
isinstance(module_, ScaleNorm),
]
) or (
Expand Down

0 comments on commit 3d7d706

Please sign in to comment.