Commit

refactor get_params_for_weight_decay_optimization
Quentin-Anthony committed Sep 9, 2024
1 parent 3d7d706 commit 8b1cfa7
Showing 1 changed file with 33 additions and 46 deletions.
megatron/model/utils.py: 33 additions & 46 deletions (79 changed lines)
@@ -25,8 +25,15 @@
 import torch.distributed as dist
 
 
-def get_params_for_weight_decay_optimization(module, neox_args):
-    """Divide params into with-weight-decay and without-weight-decay groups.
+import importlib
+from typing import List, Dict, Any, Union
+
+
+def get_params_for_weight_decay_optimization(
+    module: Any, neox_args: Any
+) -> List[Dict[str, Any]]:
+    """
+    Divide params into with-weight-decay and without-weight-decay groups.
     Layernorms and biases will have no weight decay but the rest will.
     """
     weight_decay_params = {"params": [], "name": "weight_decay_params"}
@@ -35,58 +42,38 @@ def get_params_for_weight_decay_optimization(module, neox_args):
         "weight_decay": 0.0,
         "name": "no_weight_decay_params",
     }
-    for module_ in module.modules():
-        if neox_args.norm == "te_layernorm" or neox_args.norm == "te_rmsnorm":
+
+    def is_no_weight_decay_module(module_: Any) -> bool:
+        if neox_args.norm in ["te_layernorm", "te_rmsnorm"]:
             try:
-                from megatron.model.transformer_engine import TELayerNorm, TERMSNorm
+                te = importlib.import_module("megatron.model.transformer_engine")
+                return isinstance(module_, (te.TELayerNorm, te.TERMSNorm))
             except ImportError:
                 raise ImportError(
-                    """Unable to import transformer-engine. Please refer to
-                    https://github.com/NVIDIA/TransformerEngine for installation instructions."""
+                    "Unable to import transformer-engine. Please refer to "
+                    "https://github.com/NVIDIA/TransformerEngine for installation instructions."
                 )
-            if any(
-                [
-                    isinstance(module_, TELayerNorm),
-                    isinstance(module_, TERMSNorm),
-                ]
-            ):
-                no_weight_decay_params["params"].extend(
-                    [p for p in list(module_._parameters.values()) if p is not None]
-                )
-        elif any(
-            [
-                isinstance(module_, LayerNorm),
-                isinstance(module_, RMSNorm),
-                isinstance(module_, ScaleNorm),
-            ]
-        ) or (
-            neox_args.weight_decay == 0.0
-        ):  # also include all parameters here if no weight decay is being done
+        return (
+            isinstance(module_, (LayerNorm, RMSNorm, ScaleNorm))
+            or neox_args.weight_decay == 0.0
+        )
+
+    for module_ in module.modules():
+        if is_no_weight_decay_module(module_):
             no_weight_decay_params["params"].extend(
-                [p for p in list(module_._parameters.values()) if p is not None]
+                [p for p in module_._parameters.values() if p is not None]
             )
         else:
-            weight_decay_params["params"].extend(
-                [
-                    p
-                    for n, p in list(module_._parameters.items())
-                    if p is not None
-                    and n != "bias"
-                    and not getattr(p, "_no_weight_decay", False)
-                ]
-            )
-            no_weight_decay_params["params"].extend(
-                [
-                    p
-                    for n, p in list(module_._parameters.items())
-                    if p is not None
-                    and (n == "bias" or getattr(p, "_no_weight_decay", False))
-                ]
-            )
+            for name, param in module_._parameters.items():
+                if param is None:
+                    continue
+                if name == "bias" or getattr(param, "_no_weight_decay", False):
+                    no_weight_decay_params["params"].append(param)
+                else:
+                    weight_decay_params["params"].append(param)
+
     if neox_args.weight_decay == 0.0:
-        # only return a single param group
-        # with onebitadam, we want to minimize the calls to compressed_allreduce. Every param group calls it once.
-        # to avoid this, only use a single param group when weight decay is off.
+        # Only return a single param group to minimize calls to compressed_allreduce with onebitadam
         return [no_weight_decay_params]
     return weight_decay_params, no_weight_decay_params
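A minimal standalone sketch of how param groups shaped like these are typically consumed (simplified, and not the NeoX code itself: split_param_groups, the toy model, and the AdamW call are illustrative; the real function also dispatches on NeoX's own LayerNorm/RMSNorm/ScaleNorm and the Transformer Engine norms, and returns a tuple rather than a list when weight decay is enabled):

import torch
import torch.nn as nn

def split_param_groups(model: nn.Module, weight_decay: float):
    # Same grouping signals as the NeoX function: norm-layer membership,
    # the "bias" parameter name, and a per-tensor `_no_weight_decay` marker.
    decay = {"params": [], "name": "weight_decay_params"}
    no_decay = {"params": [], "weight_decay": 0.0, "name": "no_weight_decay_params"}
    norm_types = (nn.LayerNorm,)  # stand-in for NeoX's LayerNorm/RMSNorm/ScaleNorm checks

    for module_ in model.modules():
        if isinstance(module_, norm_types) or weight_decay == 0.0:
            no_decay["params"].extend(
                p for p in module_._parameters.values() if p is not None
            )
        else:
            for name, param in module_._parameters.items():
                if param is None:
                    continue
                if name == "bias" or getattr(param, "_no_weight_decay", False):
                    no_decay["params"].append(param)
                else:
                    decay["params"].append(param)

    if weight_decay == 0.0:
        # Single group, mirroring the onebitadam fast path in the diff.
        return [no_decay]
    return [decay, no_decay]

model = nn.Sequential(nn.Linear(16, 32), nn.LayerNorm(32), nn.Linear(32, 4))
model[2].weight._no_weight_decay = True  # opt this tensor out via the marker attribute

groups = split_param_groups(model, weight_decay=0.01)
# The group without an explicit "weight_decay" key inherits the optimizer default,
# while the no-decay group pins it to 0.0; extra keys such as "name" are carried along.
optimizer = torch.optim.AdamW(groups, lr=1e-4, weight_decay=0.01)

The try/except around importlib.import_module is also a general optional-dependency pattern: the backend is resolved lazily, only on the code path that asks for it. A generic sketch of that pattern (the helper name and error message here are made up, not NeoX code):

import importlib
from types import ModuleType

def load_optional_backend(module_name: str) -> ModuleType:
    # Import an optional dependency only when the active configuration asks
    # for it, so environments without it can still import the calling module.
    try:
        return importlib.import_module(module_name)
    except ImportError as exc:
        raise ImportError(
            f"Unable to import {module_name}; install it to use this code path."
        ) from exc

# e.g. te = load_optional_backend("megatron.model.transformer_engine")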
