From d617bcafdd0bb232fa93ebfef0fc0f3716cc69b0 Mon Sep 17 00:00:00 2001 From: takuoko Date: Thu, 21 Sep 2023 17:30:24 +0900 Subject: [PATCH] [Feature] Support Adafactor Optimizer (#1361) --- docs/en/common_usage/better_optimizers.md | 32 +++++++++++++++++++ mmengine/optim/optimizer/builder.py | 15 +++++++++ requirements/tests.txt | 1 + .../test_optimizer/test_optimizer.py | 19 +++++++++-- 4 files changed, 65 insertions(+), 2 deletions(-) diff --git a/docs/en/common_usage/better_optimizers.md b/docs/en/common_usage/better_optimizers.md index 69c64e6927..9f6bdd6c39 100644 --- a/docs/en/common_usage/better_optimizers.md +++ b/docs/en/common_usage/better_optimizers.md @@ -121,3 +121,35 @@ runner = Runner( ) runner.train() ``` + +## transformers + +[transformers](https://github.com/huggingface/transformers) provides `Adafactor` optimzier。 + +```{note} +If you use the optimizer provided by transformers, you need to upgrade mmengine to `0.8.5`. +``` + +- Installation + +```bash +pip install transformers +``` + +- Usage + +Take the `Adafactor` as an example. + +```python +runner = Runner( + model=ResNet18(), + work_dir='./work_dir', + train_dataloader=train_dataloader_cfg, + # To view the input parameters for Adafactor, you can refer to + # https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/optimization.py#L492 + optim_wrapper=dict(optimizer=dict(type='Adafactor', lr=1e-5, + weight_decay=1e-2, scale_parameter=False, relative_step=False)), + train_cfg=dict(by_epoch=True, max_epochs=3), +) +runner.train() +``` diff --git a/mmengine/optim/optimizer/builder.py b/mmengine/optim/optimizer/builder.py index 644aa25609..6543dacdd2 100644 --- a/mmengine/optim/optimizer/builder.py +++ b/mmengine/optim/optimizer/builder.py @@ -157,6 +157,21 @@ def register_bitsandbytes_optimizers() -> List[str]: BITSANDBYTES_OPTIMIZERS = register_bitsandbytes_optimizers() +def register_transformers_optimizers(): + transformer_optimizers = [] + try: + from transformers import Adafactor + except ImportError: + pass + else: + OPTIMIZERS.register_module(name='Adafactor', module=Adafactor) + transformer_optimizers.append('Adafactor') + return transformer_optimizers + + +TRANSFORMERS_OPTIMIZERS = register_transformers_optimizers() + + def build_optim_wrapper(model: nn.Module, cfg: Union[dict, Config, ConfigDict]) -> OptimWrapper: """Build function of OptimWrapper. diff --git a/requirements/tests.txt b/requirements/tests.txt index 4bbea413fc..599163fd1a 100644 --- a/requirements/tests.txt +++ b/requirements/tests.txt @@ -11,3 +11,4 @@ neptune parameterized pydantic==1.10.9 pytest +transformers diff --git a/tests/test_optim/test_optimizer/test_optimizer.py b/tests/test_optim/test_optimizer/test_optimizer.py index 5edfd95ff0..0cf60fcb83 100644 --- a/tests/test_optim/test_optimizer/test_optimizer.py +++ b/tests/test_optim/test_optimizer/test_optimizer.py @@ -17,7 +17,8 @@ from mmengine.optim.optimizer.builder import (BITSANDBYTES_OPTIMIZERS, DADAPTATION_OPTIMIZERS, LION_OPTIMIZERS, - TORCH_OPTIMIZERS) + TORCH_OPTIMIZERS, + TRANSFORMERS_OPTIMIZERS) from mmengine.registry import DefaultScope, Registry, build_from_cfg from mmengine.testing._internal import MultiProcessTestCase from mmengine.utils.dl_utils import TORCH_VERSION, mmcv_full_available @@ -53,6 +54,14 @@ def has_bitsandbytes() -> bool: return False +def has_transformers() -> bool: + try: + import transformers # noqa: F401 + return True + except ImportError: + return False + + class ExampleModel(nn.Module): def __init__(self): @@ -244,7 +253,7 @@ def test_dadaptation_optimizers(self): def test_lion_optimizers(self): assert 'Lion' in LION_OPTIMIZERS - @unittest.skipIf(not has_bitsandbytes(), 'dadaptation is not installed') + @unittest.skipIf(not has_bitsandbytes(), 'bitsandbytes is not installed') def test_bitsandbytes_optimizers(self): bitsandbytes_optimizers = [ 'AdamW8bit', 'Adam8bit', 'Adagrad8bit', 'PagedAdam8bit', @@ -254,6 +263,12 @@ def test_bitsandbytes_optimizers(self): assert set(bitsandbytes_optimizers).issubset( set(BITSANDBYTES_OPTIMIZERS)) + @unittest.skipIf(not has_transformers(), 'transformers is not installed') + def test_transformers_optimizers(self): + transformers_optimizers = ['Adafactor'] + assert set(transformers_optimizers).issubset( + set(TRANSFORMERS_OPTIMIZERS)) + def test_build_optimizer(self): # test build function without ``constructor`` and ``paramwise_cfg`` optim_wrapper_cfg = dict(