diff --git a/megatron/model/utils.py b/megatron/model/utils.py
index 33937f9ad..75777c54c 100644
--- a/megatron/model/utils.py
+++ b/megatron/model/utils.py
@@ -25,8 +25,15 @@
 import torch.distributed as dist
 
 
-def get_params_for_weight_decay_optimization(module, neox_args):
-    """Divide params into with-weight-decay and without-weight-decay groups.
+import importlib
+from typing import List, Dict, Any, Union
+
+
+def get_params_for_weight_decay_optimization(
+    module: Any, neox_args: Any
+) -> List[Dict[str, Any]]:
+    """
+    Divide params into with-weight-decay and without-weight-decay groups.
     Layernorms and biases will have no weight decay but the rest will.
     """
     weight_decay_params = {"params": [], "name": "weight_decay_params"}
@@ -35,58 +42,38 @@ def get_params_for_weight_decay_optimization(module, neox_args):
         "weight_decay": 0.0,
         "name": "no_weight_decay_params",
     }
-    for module_ in module.modules():
-        if neox_args.norm == "te_layernorm" or neox_args.norm == "te_rmsnorm":
+
+    def is_no_weight_decay_module(module_: Any) -> bool:
+        if neox_args.norm in ["te_layernorm", "te_rmsnorm"]:
             try:
-                from megatron.model.transformer_engine import TELayerNorm, TERMSNorm
+                te = importlib.import_module("megatron.model.transformer_engine")
+                return isinstance(module_, (te.TELayerNorm, te.TERMSNorm))
             except ImportError:
                 raise ImportError(
-                    """Unable to import transformer-engine. Please refer to
-                    https://github.com/NVIDIA/TransformerEngine for installation instructions."""
-                )
-            if any(
-                [
-                    isinstance(module_, TELayerNorm),
-                    isinstance(module_, TERMSNorm),
-                ]
-            ):
-                no_weight_decay_params["params"].extend(
-                    [p for p in list(module_._parameters.values()) if p is not None]
+                    "Unable to import transformer-engine. Please refer to "
+                    "https://github.com/NVIDIA/TransformerEngine for installation instructions."
                 )
-        elif any(
-            [
-                isinstance(module_, LayerNorm),
-                isinstance(module_, RMSNorm),
-                isinstance(module_, ScaleNorm),
-            ]
-        ) or (
-            neox_args.weight_decay == 0.0
-        ):  # also include all parameters here if no weight decay is being done
+        return (
+            isinstance(module_, (LayerNorm, RMSNorm, ScaleNorm))
+            or neox_args.weight_decay == 0.0
+        )
+
+    for module_ in module.modules():
+        if is_no_weight_decay_module(module_):
             no_weight_decay_params["params"].extend(
-                [p for p in list(module_._parameters.values()) if p is not None]
+                [p for p in module_._parameters.values() if p is not None]
             )
         else:
-            weight_decay_params["params"].extend(
-                [
-                    p
-                    for n, p in list(module_._parameters.items())
-                    if p is not None
-                    and n != "bias"
-                    and not getattr(p, "_no_weight_decay", False)
-                ]
-            )
-            no_weight_decay_params["params"].extend(
-                [
-                    p
-                    for n, p in list(module_._parameters.items())
-                    if p is not None
-                    and (n == "bias" or getattr(p, "_no_weight_decay", False))
-                ]
-            )
+            for name, param in module_._parameters.items():
+                if param is None:
+                    continue
+                if name == "bias" or getattr(param, "_no_weight_decay", False):
+                    no_weight_decay_params["params"].append(param)
+                else:
+                    weight_decay_params["params"].append(param)
+
     if neox_args.weight_decay == 0.0:
-        # only return a single param group
-        # with onebitadam, we want to minimize the calls to compressed_allreduce. Every param group calls it once.
-        # to avoid this, only use a single param group when weight decay is off.
+        # Only return a single param group to minimize calls to compressed_allreduce with onebitadam
         return [no_weight_decay_params]
     return weight_decay_params, no_weight_decay_params
 
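For reviewers who want to see the call-site shape, here is a minimal, hypothetical usage sketch of the refactored function. It is not part of this diff: it assumes a working GPT-NeoX environment so that `megatron.model.utils` imports cleanly, mocks `neox_args` with a `SimpleNamespace` carrying only the two fields the function reads (`norm`, `weight_decay`), and uses a toy `nn.Sequential` model plus `AdamW` purely for illustration.

```python
# Hypothetical usage sketch -- not part of this diff.
from types import SimpleNamespace

import torch
import torch.nn as nn

from megatron.model.utils import get_params_for_weight_decay_optimization

# Stand-in for neox_args with only the fields the function reads.
neox_args = SimpleNamespace(norm="layernorm", weight_decay=0.01)
# Toy stand-in for a NeoX module.
model = nn.Sequential(nn.Linear(16, 16), nn.LayerNorm(16))

# With weight_decay != 0.0 this returns two groups (decay / no-decay);
# with weight_decay == 0.0 it returns a single group, as the diff's comment notes.
param_groups = get_params_for_weight_decay_optimization(model, neox_args)

# Any torch optimizer accepts an iterable of param-group dicts; the no-decay
# group carries its own "weight_decay": 0.0 override.
optimizer = torch.optim.AdamW(
    param_groups, lr=1e-4, weight_decay=neox_args.weight_decay
)
```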