[Feature] Support Adafactor Optimizer (#1361)
okotaku authored Sep 21, 2023
1 parent 53474ef commit d617bca
Showing 4 changed files with 65 additions and 2 deletions.
32 changes: 32 additions & 0 deletions docs/en/common_usage/better_optimizers.md
@@ -121,3 +121,35 @@ runner = Runner(
)
runner.train()
```

## transformers

[transformers](https://github.com/huggingface/transformers) provides the `Adafactor` optimizer.

```{note}
If you use an optimizer provided by transformers, you need to upgrade mmengine to `0.8.5` or a later version.
```

- Installation

```bash
pip install transformers
```

- Usage

Take `Adafactor` as an example.

```python
runner = Runner(
model=ResNet18(),
work_dir='./work_dir',
train_dataloader=train_dataloader_cfg,
# To view the input parameters for Adafactor, you can refer to
# https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/optimization.py#L492
optim_wrapper=dict(optimizer=dict(type='Adafactor', lr=1e-5,
weight_decay=1e-2, scale_parameter=False, relative_step=False)),
train_cfg=dict(by_epoch=True, max_epochs=3),
)
runner.train()
```
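
To sanity-check the optimizer construction without launching a full training run, the optimizer wrapper can also be built directly with mmengine's `build_optim_wrapper`. The snippet below is a minimal sketch, assuming `transformers` is installed and mmengine is at least `0.8.5`; the `nn.Linear` model is purely illustrative.

```python
import torch.nn as nn

from mmengine.optim import build_optim_wrapper

# A toy model used only for illustration.
model = nn.Linear(8, 2)

optim_wrapper = build_optim_wrapper(
    model,
    dict(optimizer=dict(type='Adafactor', lr=1e-5, weight_decay=1e-2,
                        scale_parameter=False, relative_step=False)))
# The wrapped optimizer is the transformers implementation of Adafactor.
print(type(optim_wrapper.optimizer))
```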
15 changes: 15 additions & 0 deletions mmengine/optim/optimizer/builder.py
@@ -157,6 +157,21 @@ def register_bitsandbytes_optimizers() -> List[str]:
BITSANDBYTES_OPTIMIZERS = register_bitsandbytes_optimizers()


def register_transformers_optimizers():
    # transformers is an optional dependency: register its ``Adafactor``
    # optimizer only when the package can be imported.
    transformer_optimizers = []
    try:
        from transformers import Adafactor
    except ImportError:
        pass
    else:
        OPTIMIZERS.register_module(name='Adafactor', module=Adafactor)
        transformer_optimizers.append('Adafactor')
    return transformer_optimizers


TRANSFORMERS_OPTIMIZERS = register_transformers_optimizers()


def build_optim_wrapper(model: nn.Module,
cfg: Union[dict, Config, ConfigDict]) -> OptimWrapper:
"""Build function of OptimWrapper.
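For completeness, here is a rough sketch of how the new registration can be exercised outside of a `Runner`: `TRANSFORMERS_OPTIMIZERS` records which names were registered, and the class itself can be fetched back from the `OPTIMIZERS` registry. This assumes `transformers` is installed; otherwise the list is simply empty and the lookup is skipped.

```python
import torch.nn as nn

from mmengine.optim.optimizer.builder import TRANSFORMERS_OPTIMIZERS
from mmengine.registry import OPTIMIZERS

if 'Adafactor' in TRANSFORMERS_OPTIMIZERS:
    # Retrieve the registered class and instantiate it like any other optimizer.
    adafactor_cls = OPTIMIZERS.get('Adafactor')
    optimizer = adafactor_cls(
        nn.Linear(8, 2).parameters(), lr=1e-5, weight_decay=1e-2,
        scale_parameter=False, relative_step=False)
```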
1 change: 1 addition & 0 deletions requirements/tests.txt
@@ -11,3 +11,4 @@ neptune
parameterized
pydantic==1.10.9
pytest
transformers
19 changes: 17 additions & 2 deletions tests/test_optim/test_optimizer/test_optimizer.py
@@ -17,7 +17,8 @@
from mmengine.optim.optimizer.builder import (BITSANDBYTES_OPTIMIZERS,
DADAPTATION_OPTIMIZERS,
LION_OPTIMIZERS,
TORCH_OPTIMIZERS)
TORCH_OPTIMIZERS,
TRANSFORMERS_OPTIMIZERS)
from mmengine.registry import DefaultScope, Registry, build_from_cfg
from mmengine.testing._internal import MultiProcessTestCase
from mmengine.utils.dl_utils import TORCH_VERSION, mmcv_full_available
@@ -53,6 +54,14 @@ def has_bitsandbytes() -> bool:
return False


def has_transformers() -> bool:
try:
import transformers # noqa: F401
return True
except ImportError:
return False


class ExampleModel(nn.Module):

def __init__(self):
@@ -244,7 +253,7 @@ def test_dadaptation_optimizers(self):
def test_lion_optimizers(self):
assert 'Lion' in LION_OPTIMIZERS

@unittest.skipIf(not has_bitsandbytes(), 'dadaptation is not installed')
@unittest.skipIf(not has_bitsandbytes(), 'bitsandbytes is not installed')
def test_bitsandbytes_optimizers(self):
bitsandbytes_optimizers = [
'AdamW8bit', 'Adam8bit', 'Adagrad8bit', 'PagedAdam8bit',
@@ -254,6 +263,12 @@ def test_bitsandbytes_optimizers(self):
assert set(bitsandbytes_optimizers).issubset(
set(BITSANDBYTES_OPTIMIZERS))

@unittest.skipIf(not has_transformers(), 'transformers is not installed')
def test_transformers_optimizers(self):
transformers_optimizers = ['Adafactor']
assert set(transformers_optimizers).issubset(
set(TRANSFORMERS_OPTIMIZERS))

def test_build_optimizer(self):
# test build function without ``constructor`` and ``paramwise_cfg``
optim_wrapper_cfg = dict(