From f5aca196c87e0a5fa70b0afcd740fbe4761f9e44 Mon Sep 17 00:00:00 2001 From: kozistr Date: Fri, 25 Oct 2024 00:35:40 +0900 Subject: [PATCH] update: get_optimizer_parameters --- pytorch_optimizer/optimizer/utils.py | 32 +++++++++++++++++----------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/pytorch_optimizer/optimizer/utils.py b/pytorch_optimizer/optimizer/utils.py index 5d20ae98..05bd99b3 100644 --- a/pytorch_optimizer/optimizer/utils.py +++ b/pytorch_optimizer/optimizer/utils.py @@ -1,7 +1,7 @@ import math import warnings from importlib.util import find_spec -from typing import Callable, Dict, List, Optional, Tuple, Union, Set +from typing import Callable, Dict, List, Optional, Set, Tuple, Union import numpy as np import torch @@ -200,37 +200,43 @@ def get_optimizer_parameters( ) -> PARAMETERS: r"""Get optimizer parameters while filtering specified modules. + Notice that, You can also ban by a module name level (e.g. LayerNorm) if you pass nn.Module instance. You just only + need to input `LayerNorm` to exclude weight decay from the layer norm layer(s). + :param model_or_parameter: Union[nn.Module, List]. model or parameters. :param weight_decay: float. weight_decay. :param wd_ban_list: List[str]. ban list not to set weight decay. :returns: PARAMETERS. new parameter list. """ - banned_parameter_names: Set[str] = set() - for module_name, module in model_or_parameter.named_modules(): - for param_name, _ in module.named_parameters(recurse=False): - full_param_name = f'{module_name}.{param_name}' if module_name else param_name - if ( - any(banned in param_name for banned in wd_ban_list) - or any(banned in module_name for banned in wd_ban_list) - or any(banned in module._get_name() for banned in wd_ban_list) - ): - banned_parameter_names.add(full_param_name) + banned_parameter_patterns: Set[str] = set() if isinstance(model_or_parameter, nn.Module): + for module_name, module in model_or_parameter.named_modules(): + for param_name, _ in module.named_parameters(recurse=False): + full_param_name: str = f'{module_name}.{param_name}' if module_name else param_name + if any( + banned in pattern for banned in wd_ban_list for pattern in (full_param_name, module._get_name()) + ): + banned_parameter_patterns.add(full_param_name) + model_or_parameter = list(model_or_parameter.named_parameters()) + else: + banned_parameter_patterns.update(wd_ban_list) return [ { 'params': [ p for n, p in model_or_parameter - if p.requires_grad and not any(nd in n for nd in banned_parameter_names) + if p.requires_grad and not any(nd in n for nd in banned_parameter_patterns) ], 'weight_decay': weight_decay, }, { 'params': [ - p for n, p in model_or_parameter if p.requires_grad and any(nd in n for nd in banned_parameter_names) + p + for n, p in model_or_parameter + if p.requires_grad and any(nd in n for nd in banned_parameter_patterns) ], 'weight_decay': 0.0, },