Skip to content

Commit

Permalink
Merge pull request #260 from kozistr/refactor/code
Browse files Browse the repository at this point in the history
[Update] Improve the performance
  • Loading branch information
kozistr authored Jul 21, 2024
2 parents 22f994b + 4666678 commit 474510f
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 41 deletions.
5 changes: 4 additions & 1 deletion docs/changelogs/v3.1.0.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,15 @@
* you can use by `optimizer = load_optimizer('q_galore_adamw8bit')`
* Support more bnb optimizers. (#258)
* `bnb_paged_adam8bit`, `bnb_paged_adamw8bit`, `bnb_*_*32bit`.
* Improve `power_iteration()` speed up to 40%. (#259)
* Improve `reg_noise()` (E-MCMC) speed up to 120%. (#260)

### Refactor

* Refactor `AdamMini`. (#258)
* Refactor `AdamMini` optimizer. (#258)
* Deprecate optional dependency, `bitsandbytes`. (#258)
* Move `get_rms`, `approximate_sq_grad` functions to `BaseOptimizer` for reusability. (#258)
* Refactor `shampoo_utils.py`. (#259)

### Bug

Expand Down
76 changes: 49 additions & 27 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

22 changes: 13 additions & 9 deletions pytorch_optimizer/optimizer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torch
from torch import nn
from torch.distributed import all_reduce
from torch.nn import functional as f
from torch.nn.functional import cosine_similarity
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.utils import clip_grad_norm_

Expand Down Expand Up @@ -62,7 +62,7 @@ def to_real(x: torch.Tensor) -> torch.Tensor:
return x.real if torch.is_complex(x) else x


def normalize_gradient(x: torch.Tensor, use_channels: bool = False, epsilon: float = 1e-8):
def normalize_gradient(x: torch.Tensor, use_channels: bool = False, epsilon: float = 1e-8) -> None:
r"""Normalize gradient with stddev.
:param x: torch.Tensor. gradient.
Expand Down Expand Up @@ -119,7 +119,7 @@ def cosine_similarity_by_view(
"""
x = view_func(x)
y = view_func(y)
return f.cosine_similarity(x, y, dim=1, eps=eps).abs_()
return cosine_similarity(x, y, dim=1, eps=eps).abs_()


def clip_grad_norm(
Expand Down Expand Up @@ -315,6 +315,7 @@ def reduce_max_except_dim(x: torch.Tensor, dim: int) -> torch.Tensor:
return x


@torch.no_grad()
def reg_noise(
network1: nn.Module, network2: nn.Module, num_data: int, lr: float, eta: float = 8e-3, temperature: float = 1e-4
) -> Union[torch.Tensor, float]:
Expand All @@ -332,11 +333,14 @@ def reg_noise(
reg_coef: float = 0.5 / (eta * num_data)
noise_coef: float = math.sqrt(2.0 / lr / num_data * temperature)

loss = 0
for param1, param2 in zip(network1.parameters(), network2.parameters(), strict=True):
reg = torch.sub(param1, param2).pow_(2) * reg_coef
noise1 = param1 * torch.randn_like(param1) * noise_coef
noise2 = param2 * torch.randn_like(param2) * noise_coef
loss += torch.sum(reg - noise1 - noise2)
loss = torch.tensor(0.0, device=next(network1.parameters()).device)

for param1, param2 in zip(network1.parameters(), network2.parameters()):
reg = (param1 - param2).pow_(2).mul_(reg_coef).sum()

noise = param1 * torch.randn_like(param1)
noise.add_(param2 * torch.randn_like(param2))

loss.add_(reg - noise.mul_(noise_coef).sum())

return loss
6 changes: 3 additions & 3 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ pathspec==0.12.1 ; python_version >= "3.8" and python_full_version < "4.0.0"
platformdirs==4.2.2 ; python_version >= "3.8" and python_full_version < "4.0.0"
pluggy==1.5.0 ; python_version >= "3.8" and python_full_version < "4.0.0"
pytest-cov==5.0.0 ; python_version >= "3.8" and python_full_version < "4.0.0"
pytest==8.2.2 ; python_version >= "3.8" and python_full_version < "4.0.0"
ruff==0.5.1 ; python_version >= "3.8" and python_full_version < "4.0.0"
sympy==1.13.0 ; python_version >= "3.8" and python_full_version < "4.0.0"
pytest==8.3.1 ; python_version >= "3.8" and python_full_version < "4.0.0"
ruff==0.5.4 ; python_version >= "3.8" and python_full_version < "4.0.0"
sympy==1.13.1 ; python_version >= "3.8" and python_full_version < "4.0.0"
tbb==2021.13.0 ; python_version >= "3.8" and python_full_version < "4.0.0" and platform_system == "Windows"
tomli==2.0.1 ; python_version >= "3.8" and python_full_version <= "3.11.0a6"
torch==2.3.1+cpu ; python_version >= "3.8" and python_full_version < "4.0.0"
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ mkl==2021.4.0 ; python_version >= "3.8" and python_full_version < "4.0.0" and pl
mpmath==1.3.0 ; python_version >= "3.8" and python_full_version < "4.0.0"
networkx==3.1 ; python_version >= "3.8" and python_full_version < "4.0.0"
numpy==1.24.4 ; python_version >= "3.8" and python_full_version < "4.0.0"
sympy==1.13.0 ; python_version >= "3.8" and python_full_version < "4.0.0"
sympy==1.13.1 ; python_version >= "3.8" and python_full_version < "4.0.0"
tbb==2021.13.0 ; python_version >= "3.8" and python_full_version < "4.0.0" and platform_system == "Windows"
torch==2.3.1+cpu ; python_version >= "3.8" and python_full_version < "4.0.0"
typing-extensions==4.12.2 ; python_version >= "3.8" and python_full_version < "4.0.0"

0 comments on commit 474510f

Please sign in to comment.