-
Notifications
You must be signed in to change notification settings - Fork 22
/
hubconf.py
23 lines (19 loc) · 952 Bytes
/
hubconf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
from functools import partial as _partial
from functools import update_wrapper as _update_wrapper
from pytorch_optimizer import get_supported_lr_schedulers as _get_supported_lr_schedulers
from pytorch_optimizer import get_supported_optimizers as _get_supported_optimizers
from pytorch_optimizer import load_lr_scheduler as _load_lr_scheduler
from pytorch_optimizer import load_optimizer as _load_optimizer
# Per the torch.hub hubconf convention, this module-level list names the pip
# packages that must be importable before any entrypoint below can be loaded.
dependencies = ['torch']


def _register_entrypoints(classes, loader, keyword):
    """Publish one hub entrypoint per class in *classes* as a module global.

    For each class, the entrypoint is ``functools.partial(loader,
    **{keyword: cls.__name__})``. ``functools.update_wrapper`` copies the
    class's metadata (``__name__``, ``__doc__``, ...) onto the partial,
    presumably so hub help/listing shows the class's own docs. The same
    entrypoint object is bound under three aliases: the class name as
    written, its lower-case form and its upper-case form.

    Loop variables stay local to this function, so they do not leak into the
    module namespace that torch.hub scans for entrypoints.

    Args:
        classes: iterable of classes; each class's ``__name__`` becomes the
            entrypoint name.
        loader: factory called when the entrypoint is invoked; it receives
            the selected class name through *keyword*.
        keyword: name of *loader*'s keyword argument that selects the class.
    """
    namespace = globals()
    for cls in classes:
        name: str = cls.__name__
        entrypoint = _update_wrapper(_partial(loader, **{keyword: name}), cls)
        # Aliases may repeat (e.g. an already lower-case name); rebinding the
        # same object under the same key is harmless.
        for alias in (name, name.lower(), name.upper()):
            namespace[alias] = entrypoint


# Order matters: schedulers are registered after optimizers, so on a name
# collision the scheduler entrypoint wins (same as the original loop order).
_register_entrypoints(_get_supported_optimizers(), _load_optimizer, 'optimizer')
_register_entrypoints(_get_supported_lr_schedulers(), _load_lr_scheduler, 'lr_scheduler')