Commit

Merge pull request #283 from kozistr/refactor/get-model-parameters
[Fix] when `model_or_parameter` is not an `nn.Module` instance.

kozistr authored Oct 24, 2024
2 parents 769e5fb + 23adc86, commit ed1d3e1
Showing 6 changed files with 117 additions and 67 deletions.
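In short, `get_optimizer_parameters` previously assumed an `nn.Module` input and crashed on anything else (an `AttributeError`, since a plain list has no `.named_modules()`). A minimal sketch of the two call paths; the toy model and hyperparameters below are illustrative, not from the repository:

```python
import torch.nn as nn

from pytorch_optimizer.optimizer.utils import get_optimizer_parameters

model = nn.Sequential(nn.Linear(4, 8), nn.LayerNorm(8), nn.Linear(8, 2))

# Module path: worked before and after this commit.
groups = get_optimizer_parameters(model, weight_decay=1e-3)

# Parameter-list path: before this fix it raised AttributeError, because a
# plain list has no `.named_modules()`; it is now handled by an isinstance branch.
groups = get_optimizer_parameters(list(model.named_parameters()), weight_decay=1e-3)
```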
5 changes: 5 additions & 0 deletions docs/changelogs/v3.2.0.md
@@ -8,7 +8,12 @@
   * `bnb_ademamix8bit`, `bnb_ademamix32bit`, `bnb_paged_ademamix8bit`, `bnb_paged_ademamix32bit`
 * Support 8/4bit, fp8 optimizers. (#208, #281)
   * `torchao_adamw8bit`, `torchao_adamw4bit`, `torchao_adamwfp8`.
+* Support a module-name-level (e.g. `LayerNorm`) weight decay exclusion for `get_optimizer_parameters`. (#282, #283)
+
+### Bug
+
+* Fix `should_grokfast` condition when initialization. (#279, #280)
 
 ### Contributions
 
 thanks to @Vectorrent
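A hedged illustration of the module-name-level exclusion noted in the changelog (the model and values here are ours): when an `nn.Module` is passed, a bare class name in the ban list is matched against `module._get_name()`, so one entry covers every parameter of every matching layer.

```python
import torch.nn as nn

from pytorch_optimizer.optimizer.utils import get_optimizer_parameters

model = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))

# 'LayerNorm' matches via module._get_name(), so the norm's weight and bias both
# land in the zero-weight-decay group without naming each parameter explicitly.
decay_group, no_decay_group = get_optimizer_parameters(model, weight_decay=1e-2, wd_ban_list=('bias', 'LayerNorm'))
assert no_decay_group['weight_decay'] == 0.0
```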
113 changes: 75 additions & 38 deletions poetry.lock

Some generated files are not rendered by default.

40 changes: 21 additions & 19 deletions pytorch_optimizer/optimizer/utils.py
@@ -1,7 +1,7 @@
 import math
 import warnings
 from importlib.util import find_spec
-from typing import Callable, Dict, List, Optional, Tuple, Union
+from typing import Callable, Dict, List, Optional, Set, Tuple, Union
 
 import numpy as np
 import torch
@@ -198,43 +198,45 @@ def get_optimizer_parameters(
     weight_decay: float,
     wd_ban_list: List[str] = ('bias', 'LayerNorm.bias', 'LayerNorm.weight'),
 ) -> PARAMETERS:
-    r"""
-    Get optimizer parameters while filtering specified modules.
+    r"""Get optimizer parameters while filtering specified modules.
+
+    Note that you can also ban at the module-name level (e.g. LayerNorm) if you pass an nn.Module instance: a bare
+    `LayerNorm` entry is enough to exclude weight decay from the layer norm layer(s).
 
     :param model_or_parameter: Union[nn.Module, List]. model or parameters.
     :param weight_decay: float. weight_decay.
     :param wd_ban_list: List[str]. ban list not to set weight decay.
     :returns: PARAMETERS. new parameter list.
     """
+    banned_parameter_patterns: Set[str] = set()
 
-    fully_qualified_names = []
-    for module_name, module in model_or_parameter.named_modules():
-        for param_name, _param in module.named_parameters(recurse=False):
-            # Full parameter name includes module and parameter names
-            full_param_name = f'{module_name}.{param_name}' if module_name else param_name
-            # Check if any ban list substring is in the parameter name or module name
-            if (
-                any(banned in param_name for banned in wd_ban_list)
-                or any(banned in module_name for banned in wd_ban_list)
-                or any(banned in module._get_name() for banned in wd_ban_list)
-            ):
-                fully_qualified_names.append(full_param_name)
+    if isinstance(model_or_parameter, nn.Module):
+        for module_name, module in model_or_parameter.named_modules():
+            for param_name, _ in module.named_parameters(recurse=False):
+                full_param_name: str = f'{module_name}.{param_name}' if module_name else param_name
+                if any(
+                    banned in pattern for banned in wd_ban_list for pattern in (full_param_name, module._get_name())
+                ):
+                    banned_parameter_patterns.add(full_param_name)
 
-    model_or_parameter = list(model_or_parameter.named_parameters())
+        model_or_parameter = list(model_or_parameter.named_parameters())
+    else:
+        banned_parameter_patterns.update(wd_ban_list)
 
     return [
         {
             'params': [
                 p
                 for n, p in model_or_parameter
-                if p.requires_grad and not any(nd in n for nd in fully_qualified_names)
+                if p.requires_grad and not any(nd in n for nd in banned_parameter_patterns)
             ],
             'weight_decay': weight_decay,
         },
         {
             'params': [
-                p for n, p in model_or_parameter if p.requires_grad and any(nd in n for nd in fully_qualified_names)
+                p
+                for n, p in model_or_parameter
+                if p.requires_grad and any(nd in n for nd in banned_parameter_patterns)
             ],
             'weight_decay': 0.0,
         },
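For context, the return shape is unchanged by this commit: two parameter groups that can be passed straight to a torch optimizer. A usage sketch under the same assumptions (toy model and learning rate are illustrative):

```python
import torch
import torch.nn as nn

from pytorch_optimizer.optimizer.utils import get_optimizer_parameters

model = nn.Sequential(nn.Linear(16, 16), nn.LayerNorm(16))

# The first group keeps the requested weight decay; the second pins it to 0.0
# for parameters matched by the ban list.
groups = get_optimizer_parameters(model, weight_decay=1e-2, wd_ban_list=('bias', 'LayerNorm'))
optimizer = torch.optim.AdamW(groups, lr=3e-4)
```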
12 changes: 7 additions & 5 deletions requirements-dev.txt
@@ -6,12 +6,12 @@ colorama==0.4.6 ; python_version >= "3.8" and (sys_platform == "win32" or platfo
 coverage[toml]==7.6.1 ; python_version >= "3.8"
 exceptiongroup==1.2.2 ; python_version < "3.11" and python_version >= "3.8"
 filelock==3.16.1 ; python_version >= "3.8"
-fsspec==2024.9.0 ; python_version >= "3.8"
+fsspec==2024.10.0 ; python_version >= "3.8"
 iniconfig==2.0.0 ; python_version >= "3.8"
 isort==5.13.2 ; python_version >= "3.8"
 jinja2==3.1.4 ; python_version >= "3.8"
 markupsafe==2.1.5 ; python_version >= "3.8"
-mpmath==1.3.0 ; python_version >= "3.8"
+mpmath==1.3.0 ; python_version >= "3.9" or python_version == "3.8"
 mypy-extensions==1.0.0 ; python_version >= "3.8"
 networkx==3.1 ; python_version >= "3.8"
 numpy==1.24.4 ; python_version < "3.9" and python_version >= "3.8"
@@ -22,8 +22,10 @@ platformdirs==4.3.6 ; python_version >= "3.8"
 pluggy==1.5.0 ; python_version >= "3.8"
 pytest-cov==5.0.0 ; python_version >= "3.8"
 pytest==8.3.3 ; python_version >= "3.8"
-ruff==0.6.9 ; python_version >= "3.8"
-sympy==1.13.3 ; python_version >= "3.8"
+ruff==0.7.0 ; python_version >= "3.8"
+setuptools==75.2.0 ; python_version >= "3.12"
+sympy==1.12.1 ; python_version == "3.8"
+sympy==1.13.1 ; python_version >= "3.9"
 tomli==2.0.2 ; python_full_version <= "3.11.0a6" and python_version >= "3.8"
-torch==2.4.1+cpu ; python_version >= "3.8"
+torch==2.5.0+cpu ; python_version >= "3.8"
 typing-extensions==4.12.2 ; python_version >= "3.8"
10 changes: 6 additions & 4 deletions requirements.txt
@@ -1,13 +1,15 @@
 --extra-index-url https://download.pytorch.org/whl/cpu
 
 filelock==3.16.1 ; python_version >= "3.8"
-fsspec==2024.9.0 ; python_version >= "3.8"
+fsspec==2024.10.0 ; python_version >= "3.8"
 jinja2==3.1.4 ; python_version >= "3.8"
 markupsafe==2.1.5 ; python_version >= "3.8"
-mpmath==1.3.0 ; python_version >= "3.8"
+mpmath==1.3.0 ; python_version >= "3.9" or python_version == "3.8"
 networkx==3.1 ; python_version >= "3.8"
 numpy==1.24.4 ; python_version < "3.9" and python_version >= "3.8"
 numpy==2.0.2 ; python_version >= "3.9"
-sympy==1.13.3 ; python_version >= "3.8"
-torch==2.4.1+cpu ; python_version >= "3.8"
+setuptools==75.2.0 ; python_version >= "3.12"
+sympy==1.12.1 ; python_version == "3.8"
+sympy==1.13.1 ; python_version >= "3.9"
+torch==2.5.0+cpu ; python_version >= "3.8"
 typing-extensions==4.12.2 ; python_version >= "3.8"
4 changes: 3 additions & 1 deletion tests/test_utils.py
@@ -101,11 +101,13 @@ def test_get_optimizer_parameters():
     wd_ban_list: List[str] = ['bias', 'LayerNorm.bias', 'LayerNorm.weight', 'LayerNorm']
 
     before_parameters = list(model.named_parameters())
+
+    _ = get_optimizer_parameters(before_parameters, weight_decay=1e-3, wd_ban_list=wd_ban_list)
     after_parameters = get_optimizer_parameters(model, weight_decay=1e-3, wd_ban_list=wd_ban_list)
 
     for before, after in zip(before_parameters, after_parameters):
         layer_name: str = before[0]
-        if layer_name.find('bias') != -1 or layer_name in wd_ban_list:
+        if layer_name.find('bias') != -1 or layer_name.find('LayerNorm') != -1:
             assert after['weight_decay'] == 0.0
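The extra call added above exercises the non-`nn.Module` path. The two paths differ in how the ban list is matched; a sketch of that difference, assuming a hypothetical model whose attribute names are ours:

```python
import torch.nn as nn

from pytorch_optimizer.optimizer.utils import get_optimizer_parameters


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 8)
        self.norm = nn.LayerNorm(8)


model = Net()
ban = ['bias', 'LayerNorm.bias', 'LayerNorm.weight', 'LayerNorm']

# Module path: 'LayerNorm' matches via the module class name, so 'norm.weight'
# is excluded even though the parameter name itself never contains 'LayerNorm'.
by_module = get_optimizer_parameters(model, weight_decay=1e-3, wd_ban_list=ban)

# List path: plain substring matching against parameter names only, so the same
# 'norm.weight' keeps its weight decay here.
by_list = get_optimizer_parameters(list(model.named_parameters()), weight_decay=1e-3, wd_ban_list=ban)
```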
