Skip to content

Commit

Permalink
update: get_optimizer_parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
kozistr committed Oct 24, 2024
1 parent 3e48e92 commit f5aca19
Showing 1 changed file with 19 additions and 13 deletions.
32 changes: 19 additions & 13 deletions pytorch_optimizer/optimizer/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import math
import warnings
from importlib.util import find_spec
from typing import Callable, Dict, List, Optional, Tuple, Union, Set
from typing import Callable, Dict, List, Optional, Set, Tuple, Union

import numpy as np
import torch
Expand Down Expand Up @@ -200,37 +200,43 @@ def get_optimizer_parameters(
) -> PARAMETERS:
r"""Get optimizer parameters while filtering specified modules.
Notice that, You can also ban by a module name level (e.g. LayerNorm) if you pass nn.Module instance. You just only
need to input `LayerNorm` to exclude weight decay from the layer norm layer(s).
:param model_or_parameter: Union[nn.Module, List]. model or parameters.
:param weight_decay: float. weight_decay.
:param wd_ban_list: List[str]. ban list not to set weight decay.
:returns: PARAMETERS. new parameter list.
"""
banned_parameter_names: Set[str] = set()
for module_name, module in model_or_parameter.named_modules():
for param_name, _ in module.named_parameters(recurse=False):
full_param_name = f'{module_name}.{param_name}' if module_name else param_name
if (
any(banned in param_name for banned in wd_ban_list)
or any(banned in module_name for banned in wd_ban_list)
or any(banned in module._get_name() for banned in wd_ban_list)
):
banned_parameter_names.add(full_param_name)
banned_parameter_patterns: Set[str] = set()

if isinstance(model_or_parameter, nn.Module):
for module_name, module in model_or_parameter.named_modules():
for param_name, _ in module.named_parameters(recurse=False):
full_param_name: str = f'{module_name}.{param_name}' if module_name else param_name
if any(
banned in pattern for banned in wd_ban_list for pattern in (full_param_name, module._get_name())
):
banned_parameter_patterns.add(full_param_name)

model_or_parameter = list(model_or_parameter.named_parameters())
else:
banned_parameter_patterns.update(wd_ban_list)

return [
{
'params': [
p
for n, p in model_or_parameter
if p.requires_grad and not any(nd in n for nd in banned_parameter_names)
if p.requires_grad and not any(nd in n for nd in banned_parameter_patterns)
],
'weight_decay': weight_decay,
},
{
'params': [
p for n, p in model_or_parameter if p.requires_grad and any(nd in n for nd in banned_parameter_names)
p
for n, p in model_or_parameter
if p.requires_grad and any(nd in n for nd in banned_parameter_patterns)
],
'weight_decay': 0.0,
},
Expand Down

0 comments on commit f5aca19

Please sign in to comment.