Skip to content

Commit

Permalink
Merge pull request #282 from Vectorrent/fix-weight-decay-banning
Browse files Browse the repository at this point in the history
[Fix] Implement better `wd_ban_list` handling
  • Loading branch information
kozistr authored Oct 24, 2024
2 parents 20ed84f + 546531c commit 769e5fb
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 5 deletions.
29 changes: 25 additions & 4 deletions pytorch_optimizer/optimizer/utils.py
Original file line number Diff line number Diff line change
def get_optimizer_parameters(
    model_or_parameter: Union[nn.Module, List],
    weight_decay: float,
    wd_ban_list: List[str] = ('bias', 'LayerNorm.bias', 'LayerNorm.weight'),
) -> PARAMETERS:
    r"""Get optimizer parameters while filtering specified modules.

    Parameters whose name, owning module name, or owning module class name contains any
    entry of ``wd_ban_list`` are placed in a parameter group with ``weight_decay=0.0``;
    all other trainable parameters get the given ``weight_decay``.

    :param model_or_parameter: Union[nn.Module, List]. model, or a list of
        ``(name, parameter)`` pairs (as produced by ``nn.Module.named_parameters()``).
    :param weight_decay: float. weight_decay.
    :param wd_ban_list: List[str]. ban list not to set weight decay.
    :returns: PARAMETERS. new parameter list (two groups: decayed, non-decayed).
    """
    if isinstance(model_or_parameter, nn.Module):
        # Walk the module tree so bans can match the module class name (e.g. 'LayerNorm')
        # in addition to the parameter / module names. Collect the exact fully-qualified
        # names so the final filter can use exact membership instead of substring checks,
        # which could false-positive (e.g. '0.bias' is a substring of '10.bias').
        banned_names = set()
        for module_name, module in model_or_parameter.named_modules():
            for param_name, _param in module.named_parameters(recurse=False):
                full_param_name = f'{module_name}.{param_name}' if module_name else param_name
                if any(
                    banned in param_name or banned in module_name or banned in module._get_name()
                    for banned in wd_ban_list
                ):
                    banned_names.add(full_param_name)

        model_or_parameter = list(model_or_parameter.named_parameters())
    else:
        # A plain (name, parameter) list carries no module information, so fall back to
        # substring matching of the ban list against the parameter names themselves.
        banned_names = {n for n, _ in model_or_parameter if any(banned in n for banned in wd_ban_list)}

    return [
        {
            'params': [p for n, p in model_or_parameter if p.requires_grad and n not in banned_names],
            'weight_decay': weight_decay,
        },
        {
            'params': [p for n, p in model_or_parameter if p.requires_grad and n in banned_names],
            'weight_decay': 0.0,
        },
    ]
Expand Down
2 changes: 1 addition & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def test_neuron_mean_norm():

def test_get_optimizer_parameters():
model: nn.Module = Example()
wd_ban_list: List[str] = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
wd_ban_list: List[str] = ['bias', 'LayerNorm.bias', 'LayerNorm.weight', 'LayerNorm']

before_parameters = list(model.named_parameters())
after_parameters = get_optimizer_parameters(model, weight_decay=1e-3, wd_ban_list=wd_ban_list)
Expand Down

0 comments on commit 769e5fb

Please sign in to comment.