diff --git a/README.md b/README.md
index 2e1b77e8..24d013eb 100644
--- a/README.md
+++ b/README.md
@@ -10,7 +10,7 @@
 
 **pytorch-optimizer** is optimizer & lr scheduler collections in PyTorch.
 I just re-implemented (speed & memory tweaks, plug-ins) the algorithm while based on the original paper. Also, It includes useful and practical optimization ideas.
-Currently, **77 optimizers (+ `bitsandbytes`, `qgalore`)**, **16 lr schedulers**, and **13 loss functions** are supported!
+Currently, **77 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
 
 Highly inspired by [pytorch-optimizer](https://github.com/jettify/pytorch-optimizer).
 
@@ -27,8 +27,8 @@ So, please double-check the license before using it at your work.
 $ pip3 install pytorch-optimizer
 ```
 
-From `v2.12.0`, `v3.1.0`, you can use `bitsandbytes`, `q-galore-torch` optimizers respectively!
-please check [the bnb requirements](https://github.com/TimDettmers/bitsandbytes?tab=readme-ov-file#tldr), [q-galore-torch installation](https://github.com/VITA-Group/Q-GaLore?tab=readme-ov-file#install-q-galore-optimizer)
+From `v2.12.0`, `v3.1.0`, you can use `bitsandbytes`, `q-galore-torch`, `torchao` optimizers respectively!
+please check [the bnb requirements](https://github.com/TimDettmers/bitsandbytes?tab=readme-ov-file#tldr), [q-galore-torch installation](https://github.com/VITA-Group/Q-GaLore?tab=readme-ov-file#install-q-galore-optimizer), [torchao installation](https://github.com/pytorch/ao?tab=readme-ov-file#installation)
 before installing it.
 
 From `v3.0.0`, drop `Python 3.7` support. However, you can still use this package with `Python 3.7` by installing with `--ignore-requires-python` option.
diff --git a/docs/changelogs/v3.2.0.md b/docs/changelogs/v3.2.0.md
index 830bf53d..d6d508f7 100644
--- a/docs/changelogs/v3.2.0.md
+++ b/docs/changelogs/v3.2.0.md
@@ -6,6 +6,8 @@
     * [SOAP: Improving and Stabilizing Shampoo using Adam](https://arxiv.org/abs/2409.11321)
 * Support `AdEMAMix` variants. (#276)
     * `bnb_ademamix8bit`, `bnb_ademamix32bit`, `bnb_paged_ademamix8bit`, `bnb_paged_ademamix32bit`
+* Support 8/4bit, fp8 optimizers. (#208, #281)
+    * `torchao_adamw8bit`, `torchao_adamw4bit`, `torchao_adamwfp8`.
 
 ### Bug
 
diff --git a/docs/index.md b/docs/index.md
index 2e1b77e8..24d013eb 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -10,7 +10,7 @@
 
 **pytorch-optimizer** is optimizer & lr scheduler collections in PyTorch.
 I just re-implemented (speed & memory tweaks, plug-ins) the algorithm while based on the original paper. Also, It includes useful and practical optimization ideas.
-Currently, **77 optimizers (+ `bitsandbytes`, `qgalore`)**, **16 lr schedulers**, and **13 loss functions** are supported!
+Currently, **77 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
 
 Highly inspired by [pytorch-optimizer](https://github.com/jettify/pytorch-optimizer).
 
@@ -27,8 +27,8 @@ So, please double-check the license before using it at your work.
 $ pip3 install pytorch-optimizer
 ```
 
-From `v2.12.0`, `v3.1.0`, you can use `bitsandbytes`, `q-galore-torch` optimizers respectively!
-please check [the bnb requirements](https://github.com/TimDettmers/bitsandbytes?tab=readme-ov-file#tldr), [q-galore-torch installation](https://github.com/VITA-Group/Q-GaLore?tab=readme-ov-file#install-q-galore-optimizer)
+From `v2.12.0`, `v3.1.0`, you can use `bitsandbytes`, `q-galore-torch`, `torchao` optimizers respectively!
+please check [the bnb requirements](https://github.com/TimDettmers/bitsandbytes?tab=readme-ov-file#tldr), [q-galore-torch installation](https://github.com/VITA-Group/Q-GaLore?tab=readme-ov-file#install-q-galore-optimizer), [torchao installation](https://github.com/pytorch/ao?tab=readme-ov-file#installation)
 before installing it.
 
 From `v3.0.0`, drop `Python 3.7` support. However, you can still use this package with `Python 3.7` by installing with `--ignore-requires-python` option.
diff --git a/pytorch_optimizer/__init__.py b/pytorch_optimizer/__init__.py
index 34a0d0c4..42b5e9ec 100644
--- a/pytorch_optimizer/__init__.py
+++ b/pytorch_optimizer/__init__.py
@@ -134,6 +134,7 @@
 
 HAS_BNB: bool = find_spec('bitsandbytes') is not None
 HAS_Q_GALORE: bool = find_spec('q-galore-torch') is not None
+HAS_TORCHAO: bool = find_spec('torchao') is not None
 
 OPTIMIZER_LIST: List[OPTIMIZER] = [
     AdamW,
@@ -323,19 +324,40 @@ def load_q_galore_optimizer(optimizer: str) -> OPTIMIZER:  # pragma: no cover
     raise NotImplementedError(f'[-] not implemented optimizer : {optimizer}')
 
 
+def load_ao_optimizer(optimizer: str) -> OPTIMIZER:  # pragma: no cover
+    r"""load TorchAO optimizer instance."""
+    from torchao.prototype import low_bit_optim
+
+    if 'adamw8bit' in optimizer:
+        return low_bit_optim.AdamW8bit
+    if 'adamw4bit' in optimizer:
+        return low_bit_optim.AdamW4bit
+    if 'adamwfp8' in optimizer:
+        return low_bit_optim.AdamWFp8
+
+    raise NotImplementedError(f'[-] not implemented optimizer : {optimizer}')
+
+
 def load_optimizer(optimizer: str) -> OPTIMIZER:
     optimizer: str = optimizer.lower()
 
     if optimizer.startswith('bnb'):
         if HAS_BNB and torch.cuda.is_available():
             return load_bnb_optimizer(optimizer)  # pragma: no cover
-        raise ImportError(f'[-] bitsandbytes and CUDA required for the optimizer {optimizer}')
+        raise ImportError(f'bitsandbytes and CUDA required for the optimizer {optimizer}')
     if optimizer.startswith('q_galore'):
         if HAS_Q_GALORE and torch.cuda.is_available():
             return load_q_galore_optimizer(optimizer)  # pragma: no cover
-        raise ImportError(f'[-] bitsandbytes, q-galore-torch, and CUDA required for the optimizer {optimizer}')
+        raise ImportError(f'bitsandbytes, q-galore-torch, and CUDA required for the optimizer {optimizer}')
+    if optimizer.startswith('torchao'):
+        if HAS_TORCHAO and torch.cuda.is_available():
+            return load_ao_optimizer(optimizer)  # pragma: no cover
+        raise ImportError(
+            f'torchao required for the optimizer {optimizer}. '
+            'usage: https://github.com/pytorch/ao/tree/main/torchao/prototype/low_bit_optim#usage'
+        )
     if optimizer not in OPTIMIZERS:
-        raise NotImplementedError(f'[-] not implemented optimizer : {optimizer}')
+        raise NotImplementedError(f'not implemented optimizer : {optimizer}')
 
     return OPTIMIZERS[optimizer]
 
diff --git a/tests/test_create_optimizer.py b/tests/test_create_optimizer.py
index 9daa047c..c5771ed6 100644
--- a/tests/test_create_optimizer.py
+++ b/tests/test_create_optimizer.py
@@ -33,3 +33,8 @@ def test_bnb_optimizer():
 def test_q_galore_optimizer():
     with pytest.raises(ImportError):
         load_optimizer('q_galore_adamw8bit')
+
+
+def test_torchao_optimizer():
+    with pytest.raises(ImportError):
+        load_optimizer('torchao_adamw4bit')
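For context, a minimal usage sketch of the new `torchao` path (not part of this diff): it assumes `torchao` is installed and a CUDA device is available, since `load_optimizer` otherwise raises the `ImportError` exercised by `test_torchao_optimizer` above. The model and hyperparameters are illustrative only.

```python
# Illustrative sketch: route through the new 'torchao' branch of load_optimizer.
# Assumes torchao is installed and CUDA is available; otherwise ImportError is raised.
import torch

from pytorch_optimizer import load_optimizer

model = torch.nn.Linear(16, 2).cuda()

# 'torchao_adamw8bit' matches the startswith('torchao') check; load_ao_optimizer then
# resolves it to torchao.prototype.low_bit_optim.AdamW8bit ('adamw4bit' / 'adamwfp8' analogous).
optimizer_cls = load_optimizer('torchao_adamw8bit')
optimizer = optimizer_cls(model.parameters(), lr=1e-3)

# one toy step with the 8-bit optimizer state
loss = model(torch.randn(8, 16, device='cuda')).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```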