diff --git a/docs/changelogs/v3.1.0.md b/docs/changelogs/v3.1.0.md index 61d03de06..19e423d47 100644 --- a/docs/changelogs/v3.1.0.md +++ b/docs/changelogs/v3.1.0.md @@ -9,12 +9,15 @@ * you can use by `optimizer = load_optimizer('q_galore_adamw8bit')` * Support more bnb optimizers. (#258) * `bnb_paged_adam8bit`, `bnb_paged_adamw8bit`, `bnb_*_*32bit`. +* Improve `power_iteration()` speed up to 40%. (#259) +* Improve `reg_noise()` (E-MCMC) speed up to 120%. (#260) ### Refactor -* Refactor `AdamMini`. (#258) +* Refactor `AdamMini` optimizer. (#258) * Deprecate optional dependency, `bitsandbytes`. (#258) * Move `get_rms`, `approximate_sq_grad` functions to `BaseOptimizer` for reusability. (#258) +* Refactor `shampoo_utils.py`. (#259) ### Bug diff --git a/poetry.lock b/poetry.lock index 68d693053..d4ae9cb20 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,5 +1,24 @@ # This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. +[[package]] +name = "bitsandbytes" +version = "0.43.1" +description = "k-bit optimizers and matrix multiplication routines." +optional = true +python-versions = "*" +files = [ + {file = "bitsandbytes-0.43.1-py3-none-manylinux_2_24_x86_64.whl", hash = "sha256:a81c826d576d6d691c7b4a7491c8fdc0f37f769795d6ca2e54afa605d2c260a3"}, + {file = "bitsandbytes-0.43.1-py3-none-win_amd64.whl", hash = "sha256:52c1c7189a6ca006555a9663e544e75f40520a97a26e075411f9f9aca0771fcd"}, +] + +[package.dependencies] +numpy = "*" +torch = "*" + +[package.extras] +benchmark = ["matplotlib", "pandas"] +test = ["scipy"] + [[package]] name = "black" version = "24.4.2" @@ -488,13 +507,13 @@ testing = ["pytest", "pytest-benchmark"] [[package]] name = "pytest" -version = "8.2.2" +version = "8.3.1" description = "pytest: simple powerful testing with Python" optional = false python-versions = ">=3.8" files = [ - {file = "pytest-8.2.2-py3-none-any.whl", hash = "sha256:c434598117762e2bd304e526244f67bf66bbd7b5d6cf22138be51ff661980343"}, - {file = "pytest-8.2.2.tar.gz", hash = "sha256:de4bb8104e201939ccdc688b27a89a7be2079b22e2bd2b07f806b6ba71117977"}, + {file = "pytest-8.3.1-py3-none-any.whl", hash = "sha256:e9600ccf4f563976e2c99fa02c7624ab938296551f280835ee6516df8bc4ae8c"}, + {file = "pytest-8.3.1.tar.gz", hash = "sha256:7e8e5c5abd6e93cb1cc151f23e57adc31fcf8cfd2a3ff2da63e23f732de35db6"}, ] [package.dependencies] @@ -502,7 +521,7 @@ colorama = {version = "*", markers = "sys_platform == \"win32\""} exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""} iniconfig = "*" packaging = "*" -pluggy = ">=1.5,<2.0" +pluggy = ">=1.5,<2" tomli = {version = ">=1", markers = "python_version < \"3.11\""} [package.extras] @@ -528,40 +547,40 @@ testing = ["fields", "hunter", "process-tests", "pytest-xdist", "virtualenv"] [[package]] name = "ruff" -version = "0.5.1" +version = "0.5.4" description = "An extremely fast Python linter and code formatter, written in Rust." optional = false python-versions = ">=3.7" files = [ - {file = "ruff-0.5.1-py3-none-linux_armv6l.whl", hash = "sha256:6ecf968fcf94d942d42b700af18ede94b07521bd188aaf2cd7bc898dd8cb63b6"}, - {file = "ruff-0.5.1-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:204fb0a472f00f2e6280a7c8c7c066e11e20e23a37557d63045bf27a616ba61c"}, - {file = "ruff-0.5.1-py3-none-macosx_11_0_arm64.whl", hash = "sha256:d235968460e8758d1e1297e1de59a38d94102f60cafb4d5382033c324404ee9d"}, - {file = "ruff-0.5.1-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:38beace10b8d5f9b6bdc91619310af6d63dd2019f3fb2d17a2da26360d7962fa"}, - {file = "ruff-0.5.1-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5e478d2f09cf06add143cf8c4540ef77b6599191e0c50ed976582f06e588c994"}, - {file = "ruff-0.5.1-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f0368d765eec8247b8550251c49ebb20554cc4e812f383ff9f5bf0d5d94190b0"}, - {file = "ruff-0.5.1-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:3a9a9a1b582e37669b0138b7c1d9d60b9edac880b80eb2baba6d0e566bdeca4d"}, - {file = "ruff-0.5.1-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bdd9f723e16003623423affabcc0a807a66552ee6a29f90eddad87a40c750b78"}, - {file = "ruff-0.5.1-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:be9fd62c1e99539da05fcdc1e90d20f74aec1b7a1613463ed77870057cd6bd96"}, - {file = "ruff-0.5.1-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e216fc75a80ea1fbd96af94a6233d90190d5b65cc3d5dfacf2bd48c3e067d3e1"}, - {file = "ruff-0.5.1-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:c4c2112e9883a40967827d5c24803525145e7dab315497fae149764979ac7929"}, - {file = "ruff-0.5.1-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:dfaf11c8a116394da3b65cd4b36de30d8552fa45b8119b9ef5ca6638ab964fa3"}, - {file = "ruff-0.5.1-py3-none-musllinux_1_2_i686.whl", hash = "sha256:d7ceb9b2fe700ee09a0c6b192c5ef03c56eb82a0514218d8ff700f6ade004108"}, - {file = "ruff-0.5.1-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:bac6288e82f6296f82ed5285f597713acb2a6ae26618ffc6b429c597b392535c"}, - {file = "ruff-0.5.1-py3-none-win32.whl", hash = "sha256:5c441d9c24ec09e1cb190a04535c5379b36b73c4bc20aa180c54812c27d1cca4"}, - {file = "ruff-0.5.1-py3-none-win_amd64.whl", hash = "sha256:b1789bf2cd3d1b5a7d38397cac1398ddf3ad7f73f4de01b1e913e2abc7dfc51d"}, - {file = "ruff-0.5.1-py3-none-win_arm64.whl", hash = "sha256:2875b7596a740cbbd492f32d24be73e545a4ce0a3daf51e4f4e609962bfd3cd2"}, - {file = "ruff-0.5.1.tar.gz", hash = "sha256:3164488aebd89b1745b47fd00604fb4358d774465f20d1fcd907f9c0fc1b0655"}, + {file = "ruff-0.5.4-py3-none-linux_armv6l.whl", hash = "sha256:82acef724fc639699b4d3177ed5cc14c2a5aacd92edd578a9e846d5b5ec18ddf"}, + {file = "ruff-0.5.4-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:da62e87637c8838b325e65beee485f71eb36202ce8e3cdbc24b9fcb8b99a37be"}, + {file = "ruff-0.5.4-py3-none-macosx_11_0_arm64.whl", hash = "sha256:e98ad088edfe2f3b85a925ee96da652028f093d6b9b56b76fc242d8abb8e2059"}, + {file = "ruff-0.5.4-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4c55efbecc3152d614cfe6c2247a3054cfe358cefbf794f8c79c8575456efe19"}, + {file = "ruff-0.5.4-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f9b85eaa1f653abd0a70603b8b7008d9e00c9fa1bbd0bf40dad3f0c0bdd06793"}, + {file = "ruff-0.5.4-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0cf497a47751be8c883059c4613ba2f50dd06ec672692de2811f039432875278"}, + {file = "ruff-0.5.4-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:09c14ed6a72af9ccc8d2e313d7acf7037f0faff43cde4b507e66f14e812e37f7"}, + {file = "ruff-0.5.4-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:628f6b8f97b8bad2490240aa84f3e68f390e13fabc9af5c0d3b96b485921cd60"}, + {file = "ruff-0.5.4-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3520a00c0563d7a7a7c324ad7e2cde2355733dafa9592c671fb2e9e3cd8194c1"}, + {file = "ruff-0.5.4-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:93789f14ca2244fb91ed481456f6d0bb8af1f75a330e133b67d08f06ad85b516"}, + {file = "ruff-0.5.4-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:029454e2824eafa25b9df46882f7f7844d36fd8ce51c1b7f6d97e2615a57bbcc"}, + {file = "ruff-0.5.4-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:9492320eed573a13a0bc09a2957f17aa733fff9ce5bf00e66e6d4a88ec33813f"}, + {file = "ruff-0.5.4-py3-none-musllinux_1_2_i686.whl", hash = "sha256:a6e1f62a92c645e2919b65c02e79d1f61e78a58eddaebca6c23659e7c7cb4ac7"}, + {file = "ruff-0.5.4-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:768fa9208df2bec4b2ce61dbc7c2ddd6b1be9fb48f1f8d3b78b3332c7d71c1ff"}, + {file = "ruff-0.5.4-py3-none-win32.whl", hash = "sha256:e1e7393e9c56128e870b233c82ceb42164966f25b30f68acbb24ed69ce9c3a4e"}, + {file = "ruff-0.5.4-py3-none-win_amd64.whl", hash = "sha256:58b54459221fd3f661a7329f177f091eb35cf7a603f01d9eb3eb11cc348d38c4"}, + {file = "ruff-0.5.4-py3-none-win_arm64.whl", hash = "sha256:bd53da65f1085fb5b307c38fd3c0829e76acf7b2a912d8d79cadcdb4875c1eb7"}, + {file = "ruff-0.5.4.tar.gz", hash = "sha256:2795726d5f71c4f4e70653273d1c23a8182f07dd8e48c12de5d867bfb7557eed"}, ] [[package]] name = "sympy" -version = "1.13.0" +version = "1.13.1" description = "Computer algebra system (CAS) in Python" optional = false python-versions = ">=3.8" files = [ - {file = "sympy-1.13.0-py3-none-any.whl", hash = "sha256:6b0b32a4673fb91bd3cac3b55406c8e01d53ae22780be467301cc452f6680c92"}, - {file = "sympy-1.13.0.tar.gz", hash = "sha256:3b6af8f4d008b9a1a6a4268b335b984b23835f26d1d60b0526ebc71d48a25f57"}, + {file = "sympy-1.13.1-py3-none-any.whl", hash = "sha256:db36cdc64bf61b9b24578b6f7bab1ecdd2452cf008f34faa33776680c26d66f8"}, + {file = "sympy-1.13.1.tar.gz", hash = "sha256:9cebf7e04ff162015ce31c9c6c9144daa34a93bd082f54fd8f12deca4f47515f"}, ] [package.dependencies] @@ -642,7 +661,10 @@ files = [ {file = "typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8"}, ] +[extras] +bitsandbytes = ["bitsandbytes"] + [metadata] lock-version = "2.0" python-versions = ">=3.8,<4.0.0" -content-hash = "8bc2a8c8202fa34296ac68c6f058bf8e33360abeddb06b2fca03ba7b64a3d02f" +content-hash = "d51586f8352db14a18dd407b19285c9649564b029e6e6aae52a0d566515e5c81" diff --git a/pytorch_optimizer/optimizer/utils.py b/pytorch_optimizer/optimizer/utils.py index 464cc3e9f..df96bb908 100644 --- a/pytorch_optimizer/optimizer/utils.py +++ b/pytorch_optimizer/optimizer/utils.py @@ -7,7 +7,7 @@ import torch from torch import nn from torch.distributed import all_reduce -from torch.nn import functional as f +from torch.nn.functional import cosine_similarity from torch.nn.modules.batchnorm import _BatchNorm from torch.nn.utils import clip_grad_norm_ @@ -62,7 +62,7 @@ def to_real(x: torch.Tensor) -> torch.Tensor: return x.real if torch.is_complex(x) else x -def normalize_gradient(x: torch.Tensor, use_channels: bool = False, epsilon: float = 1e-8): +def normalize_gradient(x: torch.Tensor, use_channels: bool = False, epsilon: float = 1e-8) -> None: r"""Normalize gradient with stddev. :param x: torch.Tensor. gradient. @@ -119,7 +119,7 @@ def cosine_similarity_by_view( """ x = view_func(x) y = view_func(y) - return f.cosine_similarity(x, y, dim=1, eps=eps).abs_() + return cosine_similarity(x, y, dim=1, eps=eps).abs_() def clip_grad_norm( @@ -315,6 +315,7 @@ def reduce_max_except_dim(x: torch.Tensor, dim: int) -> torch.Tensor: return x +@torch.no_grad() def reg_noise( network1: nn.Module, network2: nn.Module, num_data: int, lr: float, eta: float = 8e-3, temperature: float = 1e-4 ) -> Union[torch.Tensor, float]: @@ -332,11 +333,14 @@ def reg_noise( reg_coef: float = 0.5 / (eta * num_data) noise_coef: float = math.sqrt(2.0 / lr / num_data * temperature) - loss = 0 - for param1, param2 in zip(network1.parameters(), network2.parameters(), strict=True): - reg = torch.sub(param1, param2).pow_(2) * reg_coef - noise1 = param1 * torch.randn_like(param1) * noise_coef - noise2 = param2 * torch.randn_like(param2) * noise_coef - loss += torch.sum(reg - noise1 - noise2) + loss = torch.tensor(0.0, device=next(network1.parameters()).device) + + for param1, param2 in zip(network1.parameters(), network2.parameters()): + reg = (param1 - param2).pow_(2).mul_(reg_coef).sum() + + noise = param1 * torch.randn_like(param1) + noise.add_(param2 * torch.randn_like(param2)) + + loss.add_(reg - noise.mul_(noise_coef).sum()) return loss diff --git a/requirements-dev.txt b/requirements-dev.txt index e431b1b07..e4f7097fc 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -22,9 +22,9 @@ pathspec==0.12.1 ; python_version >= "3.8" and python_full_version < "4.0.0" platformdirs==4.2.2 ; python_version >= "3.8" and python_full_version < "4.0.0" pluggy==1.5.0 ; python_version >= "3.8" and python_full_version < "4.0.0" pytest-cov==5.0.0 ; python_version >= "3.8" and python_full_version < "4.0.0" -pytest==8.2.2 ; python_version >= "3.8" and python_full_version < "4.0.0" -ruff==0.5.1 ; python_version >= "3.8" and python_full_version < "4.0.0" -sympy==1.13.0 ; python_version >= "3.8" and python_full_version < "4.0.0" +pytest==8.3.1 ; python_version >= "3.8" and python_full_version < "4.0.0" +ruff==0.5.4 ; python_version >= "3.8" and python_full_version < "4.0.0" +sympy==1.13.1 ; python_version >= "3.8" and python_full_version < "4.0.0" tbb==2021.13.0 ; python_version >= "3.8" and python_full_version < "4.0.0" and platform_system == "Windows" tomli==2.0.1 ; python_version >= "3.8" and python_full_version <= "3.11.0a6" torch==2.3.1+cpu ; python_version >= "3.8" and python_full_version < "4.0.0" diff --git a/requirements.txt b/requirements.txt index deac15a7f..ac61ed57e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,7 +9,7 @@ mkl==2021.4.0 ; python_version >= "3.8" and python_full_version < "4.0.0" and pl mpmath==1.3.0 ; python_version >= "3.8" and python_full_version < "4.0.0" networkx==3.1 ; python_version >= "3.8" and python_full_version < "4.0.0" numpy==1.24.4 ; python_version >= "3.8" and python_full_version < "4.0.0" -sympy==1.13.0 ; python_version >= "3.8" and python_full_version < "4.0.0" +sympy==1.13.1 ; python_version >= "3.8" and python_full_version < "4.0.0" tbb==2021.13.0 ; python_version >= "3.8" and python_full_version < "4.0.0" and platform_system == "Windows" torch==2.3.1+cpu ; python_version >= "3.8" and python_full_version < "4.0.0" typing-extensions==4.12.2 ; python_version >= "3.8" and python_full_version < "4.0.0"