enable param group configuration in llm-foundry (#760)
* enable param group configuration in llm-foundry

* add doc string

* add debug logs

* add test, fix bug

* spell check; mark test gpu

* update to use RegEx search

* Apply suggestions from code review

Co-authored-by: Daniel King <[email protected]>

* update with dakinggg PR comments

---------

Co-authored-by: Daniel King <[email protected]>
vchiley and dakinggg authored Nov 29, 2023
1 parent 4f399bf commit 5f21855
Showing 3 changed files with 211 additions and 18 deletions.
22 changes: 12 additions & 10 deletions llmfoundry/optim/lion8b.py
@@ -1,7 +1,7 @@
 # Copyright 2022 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import Any, Callable, Dict, Iterable, Optional, Tuple
+from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union
 
 import torch
 
@@ -58,15 +58,17 @@ class DecoupledLionW_8bit(torch.optim.Optimizer):
             device, or b) step() is executed on a non-CUDA parameter.
     """
 
-    def __init__(self,
-                 params: Iterable[torch.Tensor],
-                 lr: float = 1e-3,
-                 betas: Tuple[float, float] = (0.9, 0.99),
-                 weight_decay: float = 0,
-                 quantize: bool = True,
-                 compress_state_dict: bool = False,
-                 error_correction: bool = False,
-                 _fused: bool = True):  # XXX this flag is mostly for testing...
+    def __init__(
+        self,
+        params: Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]],
+        lr: float = 1e-3,
+        betas: Tuple[float, float] = (0.9, 0.99),
+        weight_decay: float = 0,
+        quantize: bool = True,
+        compress_state_dict: bool = False,
+        error_correction: bool = False,
+        _fused: bool = True,  # XXX this flag is mostly for testing...
+    ):
 
         if lr < 0.0:
             raise ValueError('Invalid learning rate: {}'.format(lr))
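The widened `params` annotation reflects how torch optimizers conventionally accept either a flat iterable of tensors or a list of param-group dicts. A minimal sketch of what this permits (not part of the diff; the toy submodules are assumptions, and `quantize=False` is used so the example does not require a CUDA device):

import torch

from llmfoundry.optim.lion8b import DecoupledLionW_8bit

# Two toy submodules standing in for a real model's parameters.
linear = torch.nn.Linear(4, 4)
norm = torch.nn.LayerNorm(4)

param_groups = [
    {'params': linear.parameters()},                     # inherits the defaults below
    {'params': norm.parameters(), 'weight_decay': 0.0},  # per-group override
]

# quantize=False keeps this runnable on CPU; the 8-bit path expects CUDA parameters.
opt = DecoupledLionW_8bit(param_groups, lr=1e-3, weight_decay=1e-2, quantize=False)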
118 changes: 112 additions & 6 deletions llmfoundry/utils/builders.py
@@ -1,10 +1,13 @@
 # Copyright 2022 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0
 
+import functools
 import logging
 import os
+import re
 import warnings
-from typing import Any, Dict, List, Optional, Tuple, Union
+from collections import OrderedDict
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
 
 import torch
 from composer import algorithms
@@ -155,18 +158,121 @@ def build_algorithm(name: str, kwargs: Dict[str, Any]) -> Algorithm:
         raise ValueError(f'Not sure how to build algorithm: {name}')
 
 
+def _extract_param_groups(
+    model: torch.nn.Module,
+    optimizer_config: Dict[str, Any],
+) -> Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]]:
"""Extracts parameter groups defined in the optimizer config.
The optimizer_config defines the optimizer args. It can additionally have key
`disable_grad` which is a string or list of strings. If a string matches a
parameter name, then that parameter will have `requires_grad=False`. This is
useful for freezing parameters. It can additionally have a key
`param_groups` which is a list of dicts. In this dict, key `param_str_match`
defines a string; if a parameter name contains this string, then it will be
in this parameter group. This is useful for grouping parameters together.
The dict can also contain any other key that is a valid optimizer arg.
Note: to handle name overlap conflicts, params are assigned to parameter
groups and added to `param_groups` in the order that `param_str_match` appear
in `param_groups`.
Usage
To disable gradient for all parameters that contain the string "norm" or "bias":
```
optimizer_config: {
"name": "decoupled_lionw",
"lr": 1e-3,
"weight_decay": 1e-2,
"betas": [0.9, 0.999],
"eps": 1e-8,
"disable_grad": ["norm", "bias"]
}
```
To create and modify the optimizer parameters for all parameters that contain
the string "norm" and "bias" separately:
```
optimizer_config: {
"name": "decoupled_lionw",
"lr": 1e-3,
"weight_decay": 1e-2,
"betas": [0.9, 0.999],
"eps": 1e-8,
"param_groups": [
{
"param_str_match": "norm",
"lr": 1e-4,
"weight_decay": 0.0,
},
{
"param_str_match": "bias",
"lr": 5e-4,
"weight_decay": 0.0,
},
],
}
```
Args:
model (torch.nn.Module): model to extract parameters from
optimizer_config (Dict[str, Any]): optimizer config
Returns:
Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]]: an iterable of
torch.Tensor's or dict's. Specifies what Tensors should be optimized
and their param groupings.
"""
+    if 'disable_grad' in optimizer_config.keys():
+        str_matches = optimizer_config.pop('disable_grad')
+        if isinstance(str_matches, str):
+            str_matches = [str_matches]
+        for str_match in str_matches:
+            for n, p in model.named_parameters():
+                if re.search(str_match, n):
+                    p.requires_grad = False
+                    log.debug(f'Setting `{n}.requires_grad = False`.')
+
+    param_groups_config = optimizer_config.pop('param_groups', None)
+    if param_groups_config is not None:
+        params = []
+        param_dict = OrderedDict((n, p) for n, p in model.named_parameters())
+
+        log.debug(f'Default optimizer settings: {optimizer_config}.')
+        for param_group_config in param_groups_config:
+            str_match = param_group_config.pop('param_str_match')
+            filter_fn = functools.partial(re.search, str_match)
+            param_names = [n for n in param_dict.keys() if filter_fn(n)]
+            group_params = {'params': [param_dict.pop(n) for n in param_names]}
+            group_params.update(param_group_config)
+
+            log.debug(
+                f'Creating optimizer param_group with parameters: {param_names} ' +\
+                f'(extracted using {str_match=}). The param_group optimizer ' +\
+                f'setting overrides are: {param_group_config}.')
+
+            params.append(group_params)
+
+        params.insert(0, {'params': param_dict.values()})
+        return params
+
+    return model.parameters()
+
+
 def build_optimizer(model: torch.nn.Module, name: str,
                     optimizer_config: Dict[str, Any]) -> Optimizer:
+
+    params = _extract_param_groups(model, optimizer_config)
+
     if name == 'decoupled_adamw':
-        return DecoupledAdamW(model.parameters(), **optimizer_config)
+        return DecoupledAdamW(params, **optimizer_config)
     elif name == 'decoupled_lionw':
-        return DecoupledLionW(model.parameters(), **optimizer_config)
+        return DecoupledLionW(params, **optimizer_config)
     elif name == 'clip_lion':
-        return DecoupledClipLion(model.parameters(), **optimizer_config)
+        return DecoupledClipLion(params, **optimizer_config)
    elif name == 'adalr_lion':
-        return DecoupledAdaLRLion(model.parameters(), **optimizer_config)
+        return DecoupledAdaLRLion(params, **optimizer_config)
     elif name == 'decoupled_lionw_8b':
-        return DecoupledLionW_8bit(model.parameters(), **optimizer_config)
+        return DecoupledLionW_8bit(params, **optimizer_config)
     else:
         raise ValueError(f'Not sure how to build optimizer: {name}')
 
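For reference, a hedged usage sketch of the new code path (the toy `TinyModel` below is an assumption, not part of llm-foundry): freeze every parameter whose name matches "bias" and give the LayerNorm parameters their own learning rate and weight decay.

import torch.nn as nn

from llmfoundry.utils.builders import build_optimizer


class TinyModel(nn.Module):
    """Stand-in model with parameter names that the regexes below can match."""

    def __init__(self):
        super().__init__()
        self.linear0 = nn.Linear(8, 8)
        self.norm0 = nn.LayerNorm(8)
        self.linear1 = nn.Linear(8, 2)


model = TinyModel()
optimizer_config = {
    'lr': 1e-3,
    'weight_decay': 1e-2,
    'betas': [0.9, 0.999],
    'eps': 1e-8,
    'disable_grad': 'bias',         # freeze params whose names match "bias"
    'param_groups': [{
        'param_str_match': 'norm',  # params matching "norm" get their own settings
        'lr': 1e-4,
        'weight_decay': 0.0,
    }],
}
optimizer = build_optimizer(model, 'decoupled_adamw', optimizer_config)

# param_groups[0] holds the unmatched parameters; param_groups[1] is the "norm" group.
assert len(optimizer.param_groups) == 2
assert optimizer.param_groups[1]['lr'] == 1e-4
assert not model.norm0.bias.requires_grad  # frozen by `disable_grad`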
89 changes: 87 additions & 2 deletions tests/test_builders.py
@@ -1,17 +1,22 @@
 # Copyright 2022 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0
 
+import re
 import unittest.mock as mock
-from typing import Union
+from copy import deepcopy
+from typing import Any, Dict, Union
 
 import pytest
+import torch
+import torch.nn as nn
 from composer.callbacks import Generate
 from omegaconf import OmegaConf as om
 from transformers import PreTrainedTokenizerBase
 
 from llmfoundry.callbacks import HuggingFaceCheckpointer
 from llmfoundry.tokenizers.tiktoken import TiktokenTokenizerWrapper
-from llmfoundry.utils.builders import build_callback, build_tokenizer
+from llmfoundry.utils.builders import (build_callback, build_optimizer,
+                                       build_tokenizer)
 
 
 @pytest.mark.parametrize('tokenizer_name,tokenizer_kwargs', [
@@ -110,3 +115,83 @@ def test_build_hf_checkpointer_callback():
     assert isinstance(kwargs['mlflow_logging_config'], dict)
     assert isinstance(kwargs['mlflow_logging_config']['metadata'], dict)
     assert kwargs['mlflow_logging_config'] == mlflow_logging_config_dict
+
+
+class _DummyModule(nn.Module):
+
+    def __init__(self, device: str = 'cpu', dtype: torch.dtype = torch.float32):
+        super().__init__()
+        self.linear0 = nn.Linear(4, 3, device=device, dtype=dtype)
+        self.norm0 = nn.LayerNorm(3, device=device, dtype=dtype)
+        self.linear1 = nn.Linear(3, 5, device=device, dtype=dtype)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:  # type:ignore
+        return self.linear1(self.norm0(self.linear0(x)))
+
+
+@pytest.mark.parametrize('name, optimizer_config', [
+    ('decoupled_adamw', {}),
+    ('decoupled_lionw', {}),
+    ('clip_lion', {}),
+    ('adalr_lion', {}),
+    pytest.param('decoupled_lionw_8b', {}, marks=pytest.mark.gpu),
+])
+@pytest.mark.parametrize('opt_additional_config', [
+    {
+        'disable_grad': 'norm'
+    },
+    {
+        'disable_grad': ['norm', 'bias']
+    },
+    {
+        'param_groups': [{
+            'param_str_match': 'norm',
+            'lr': 1e-9,
+            'weight_decay': 0.0,
+        },]
+    },
+    {
+        'param_groups': [{
+            'param_str_match': 'no.*.bias',
+            'lr': 1e-9,
+            'weight_decay': 0.0,
+        },]
+    },
+    {
+        'param_groups': [{
+            'param_str_match': 'norm',
+            'lr': 1e-4,
+            'weight_decay': 0.0,
+        },],
+        'disable_grad': ['bias'],
+    },
+])
+def test_build_optimizer(name: str, optimizer_config: Dict[str, Any],
+                         opt_additional_config: Dict[str, Any]):
+    model = _DummyModule()
+    optimizer_config.update(deepcopy(opt_additional_config))
+    optimizer = build_optimizer(model, name, optimizer_config)
+
+    if 'disable_grad' in opt_additional_config.keys():
+        disable_grad = opt_additional_config['disable_grad']
+        if isinstance(disable_grad, str):
+            disable_grad = [disable_grad]
+        for n, p in model.named_parameters():
+            for k in disable_grad:
+                if re.search(k, n):
+                    assert not p.requires_grad
+
+    if 'param_groups' in opt_additional_config.keys():
+        for param_group_config, param_group in zip(
+                opt_additional_config['param_groups'],
+                optimizer.param_groups[1:]):
+            param_group_config = deepcopy(param_group_config)
+            param_str_match = param_group_config.pop('param_str_match')
+
+            for k, v in param_group_config.items():
+                assert param_group[k] == v
+
+            param_ids = [id(p) for p in param_group['params']]
+            for n, p in model.named_parameters():
+                if re.search(param_str_match, n):
+                    assert id(p) in param_ids
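One behavior the tests above do not exercise directly is the docstring's note on overlapping matches: because groups are extracted in order and matched parameters are popped from the pool, a parameter whose name matches several `param_str_match` patterns lands in the first listed group. A small hedged illustration (the toy module below is an assumption):

import torch.nn as nn

from llmfoundry.utils.builders import build_optimizer


class OverlapModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.norm0 = nn.LayerNorm(4)    # `norm0.bias` matches both patterns below
        self.linear0 = nn.Linear(4, 2)


optimizer = build_optimizer(
    OverlapModel(), 'decoupled_adamw', {
        'lr': 1e-3,
        'param_groups': [
            {'param_str_match': 'norm', 'lr': 1e-4},  # claims norm0.weight and norm0.bias
            {'param_str_match': 'bias', 'lr': 5e-4},  # only linear0.bias is left to claim
        ],
    })

# Groups come back as [default, 'norm', 'bias'].
assert len(optimizer.param_groups[2]['params']) == 1  # norm0.bias went to the 'norm' group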