TE Import Hotfix (#1272)
* fix the te import

* refactor get_params_for_weight_decay_optimization

* remove incorrect type hint and dead imports
Quentin-Anthony authored Sep 9, 2024
1 parent 01e74f4 commit 61a3daa
Showing 1 changed file with 34 additions and 35 deletions.
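The crux of the hotfix is the is_no_weight_decay_module helper added in the diff below: instead of checking isinstance against TELayerNorm and TERMSNorm directly, which requires transformer-engine to be importable even on installs that never use a TE norm, the helper resolves megatron.model.transformer_engine lazily and only when neox_args.norm selects te_layernorm or te_rmsnorm. A minimal sketch of that lazy-import pattern lifted out of the NeoX context (the function name and return value here are illustrative; the module path and error message follow the diff):

import importlib


def resolve_te_norm_classes():
    # Import transformer-engine only at the point of use; environments without it
    # installed never reach this path unless a TE norm is actually configured.
    try:
        te = importlib.import_module("megatron.model.transformer_engine")
    except ImportError as e:
        raise ImportError(
            "Unable to import transformer-engine. Please refer to "
            "https://github.com/NVIDIA/TransformerEngine for installation instructions."
        ) from e
    return te.TELayerNorm, te.TERMSNorm

Because the import now lives inside a function, merely importing megatron/model/utils.py no longer requires transformer-engine to be present.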
69 changes: 34 additions & 35 deletions megatron/model/utils.py
@@ -24,9 +24,13 @@
 from types import GeneratorType
 import torch.distributed as dist
 
-def get_params_for_weight_decay_optimization(module, neox_args):
-    """Divide params into with-weight-decay and without-weight-decay groups.
+import importlib
+from typing import List, Dict, Any
+
+def get_params_for_weight_decay_optimization(module: Any, neox_args: Any):
+    """
+    Divide params into with-weight-decay and without-weight-decay groups.
 
     Layernorms and biases will have no weight decay but the rest will.
     """
     weight_decay_params = {"params": [], "name": "weight_decay_params"}
@@ -35,43 +39,38 @@ def get_params_for_weight_decay_optimization(module, neox_args):
         "weight_decay": 0.0,
         "name": "no_weight_decay_params",
     }
 
+    def is_no_weight_decay_module(module_: Any) -> bool:
+        if neox_args.norm in ["te_layernorm", "te_rmsnorm"]:
+            try:
+                te = importlib.import_module("megatron.model.transformer_engine")
+                return isinstance(module_, (te.TELayerNorm, te.TERMSNorm))
+            except ImportError:
+                raise ImportError(
+                    "Unable to import transformer-engine. Please refer to "
+                    "https://github.com/NVIDIA/TransformerEngine for installation instructions."
+                )
+        return (
+            isinstance(module_, (LayerNorm, RMSNorm, ScaleNorm))
+            or neox_args.weight_decay == 0.0
+        )
+
     for module_ in module.modules():
-        if any(
-            [
-                isinstance(module_, LayerNorm),
-                isinstance(module_, RMSNorm),
-                isinstance(module_, TELayerNorm),
-                isinstance(module_, TERMSNorm),
-                isinstance(module_, ScaleNorm),
-            ]
-        ) or (
-            neox_args.weight_decay == 0.0
-        ):  # also include all parameters here if no weight decay is being done
+        if is_no_weight_decay_module(module_):
             no_weight_decay_params["params"].extend(
-                [p for p in list(module_._parameters.values()) if p is not None]
+                [p for p in module_._parameters.values() if p is not None]
             )
         else:
-            weight_decay_params["params"].extend(
-                [
-                    p
-                    for n, p in list(module_._parameters.items())
-                    if p is not None
-                    and n != "bias"
-                    and not getattr(p, "_no_weight_decay", False)
-                ]
-            )
-            no_weight_decay_params["params"].extend(
-                [
-                    p
-                    for n, p in list(module_._parameters.items())
-                    if p is not None
-                    and (n == "bias" or getattr(p, "_no_weight_decay", False))
-                ]
-            )
+            for name, param in module_._parameters.items():
+                if param is None:
+                    continue
+                if name == "bias" or getattr(param, "_no_weight_decay", False):
+                    no_weight_decay_params["params"].append(param)
+                else:
+                    weight_decay_params["params"].append(param)
 
     if neox_args.weight_decay == 0.0:
-        # only return a single param group
-        # with onebitadam, we want to minimize the calls to compressed_allreduce. Every param group calls it once.
-        # to avoid this, only use a single param group when weight decay is off.
+        # Only return a single param group to minimize calls to compressed_allreduce with onebitadam
         return [no_weight_decay_params]
     return weight_decay_params, no_weight_decay_params
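For context on what the function hands back: each dict above is an ordinary PyTorch optimizer parameter group, so a caller can pass the groups straight to an optimizer, and the "weight_decay": 0.0 entry overrides the optimizer-level default for that group only. A sketch of that consumption path with stand-in groups (the toy model, optimizer choice, and hyperparameters are illustrative; NeoX builds its optimizer elsewhere, e.g. via DeepSpeed):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(16, 16), nn.LayerNorm(16))

# Stand-ins shaped like the groups returned by get_params_for_weight_decay_optimization.
weight_decay_params = {"params": [model[0].weight], "name": "weight_decay_params"}
no_weight_decay_params = {
    "params": [model[0].bias, *model[1].parameters()],
    "weight_decay": 0.0,
    "name": "no_weight_decay_params",
}

# Extra keys such as "name" are carried along untouched; the per-group
# "weight_decay": 0.0 wins over the optimizer-level default of 0.1 for that group.
optimizer = torch.optim.AdamW(
    [weight_decay_params, no_weight_decay_params], lr=1e-4, weight_decay=0.1
)

The single-group return under weight_decay == 0.0 matches the comment in the diff: with onebitadam, every param group triggers its own compressed_allreduce, so collapsing to one group keeps that to a single call.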

@@ -379,7 +378,7 @@ def reduce_weight_grads_from_model_parallel_region(input_):
         input_ = input_.float()
 
     # All-reduce.
-    torch.distributed.all_reduce(input_, group=mpu.get_model_parallel_group())
+    dist.all_reduce(input_, group=mpu.get_model_parallel_group())
 
     # Bf16 convert
     if dt == torch.bfloat16 and mpu.get_fp32_allreduce():
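The surrounding function follows a common pattern for mixed-precision gradient reduction: record the dtype, optionally upcast bf16 gradients to fp32 so the reduction accumulates at higher precision, all-reduce across the model-parallel group, then cast back. A standalone sketch of that pattern (illustrative; it assumes an initialized torch.distributed process group, whereas the NeoX version gates the upcast on mpu.get_fp32_allreduce() and uses mpu.get_model_parallel_group()):

import torch
import torch.distributed as dist


def allreduce_in_fp32(tensor: torch.Tensor, group=None) -> torch.Tensor:
    # Upcast bf16 to fp32 before the reduction, sum across ranks, then restore
    # the original dtype so downstream code sees the dtype it handed in.
    dt = tensor.dtype
    if dt == torch.bfloat16:
        tensor = tensor.float()
    dist.all_reduce(tensor, group=group)
    if dt == torch.bfloat16:
        tensor = tensor.bfloat16()
    return tensor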