[Refactor] Stuffs #259

Merged
merged 12 commits on Jul 20, 2024
5 changes: 1 addition & 4 deletions poetry.lock

Some generated files are not rendered by default.

4 changes: 4 additions & 0 deletions pyproject.toml
@@ -47,6 +47,7 @@ classifiers = [
python = ">=3.8,<4.0.0"
numpy = { version = "*", python = ">=3.8" }
torch = { version = ">=1.10", python = ">=3.8", source = "torch" }
bitsandbytes = { version = "^0.43", optional = true }

[tool.poetry.dev-dependencies]
isort = { version = "^5", python = ">=3.8" }
@@ -55,6 +56,9 @@ ruff = "*"
pytest = "*"
pytest-cov = "*"

[tool.poetry.extras]
bitsandbytes = ["bitsandbytes"]

[[tool.poetry.source]]
name = "torch"
url = "https://download.pytorch.org/whl/cpu"
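Since `bitsandbytes` is wired up as an opt-in extra, downstream code will presumably guard the import the same way `rotograd.py` below now guards `geotorch`. A minimal sketch of that pattern (the `load_bnb_optimizer` helper and the 8-bit AdamW choice are mine, purely for illustration):

```python
from importlib.util import find_spec

HAS_BNB: bool = find_spec('bitsandbytes') is not None

if HAS_BNB:
    import bitsandbytes as bnb


def load_bnb_optimizer(parameters, lr: float = 1e-3):
    # hypothetical consumer of the optional extra: fail with an actionable
    # message instead of a bare ModuleNotFoundError deep in the call stack
    if not HAS_BNB:
        raise ImportError('[-] you need to install `bitsandbytes`. `pip install bitsandbytes`')
    return bnb.optim.AdamW8bit(parameters, lr=lr)
```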
17 changes: 9 additions & 8 deletions pytorch_optimizer/optimizer/rotograd.py
@@ -1,20 +1,21 @@
from importlib.util import find_spec
from typing import Any, List, Optional, Sequence

import torch
from torch import nn

try:
from geotorch import orthogonal
HAS_GEOTORCH: bool = find_spec('geotorch') is not None

HAS_GEOTORCH = True
except ImportError:
HAS_GEOTORCH = False
if HAS_GEOTORCH:
from geotorch import orthogonal


def divide(numer, denom, eps: float = 1e-15):
def divide(numer: torch.Tensor, de_nom: torch.Tensor, eps: float = 1e-15) -> torch.Tensor:
r"""Numerically stable division."""
return (
torch.sign(numer) * torch.sign(denom) * torch.exp(torch.log(numer.abs() + eps) - torch.log(denom.abs() + eps))
torch.sign(numer)
* torch.sign(de_nom)
* torch.exp(torch.log(numer.abs() + eps) - torch.log(de_nom.abs() + eps))
)
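On the reflowed `divide` above: behavior is unchanged; it still divides in log-space and leans on `torch.sign(0) == 0`, so a zero denominator maps to 0 rather than inf/nan. A quick standalone sanity check (my example, not part of the diff):

```python
import torch


def divide(numer: torch.Tensor, de_nom: torch.Tensor, eps: float = 1e-15) -> torch.Tensor:
    # same formulation as above: subtract logs instead of dividing directly
    return (
        torch.sign(numer)
        * torch.sign(de_nom)
        * torch.exp(torch.log(numer.abs() + eps) - torch.log(de_nom.abs() + eps))
    )


a = torch.tensor([1.0, -2.0, 3.0])
b = torch.tensor([2.0, 0.0, -4.0])
print(divide(a, b))  # ~[0.5, 0.0, -0.75]; the zero-denominator entry stays finite
print(a / b)         # [0.5, -inf, -0.75] for comparison
```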


@@ -181,7 +182,7 @@ def __init__(
):
super().__init__()
if not HAS_GEOTORCH:
raise ImportError('[-] you need to install geotorch to use RotoGrad. pip install geotorch')
raise ImportError('[-] you need to install `geotorch` to use RotoGrad. `pip install geotorch`')

self._backbone = [backbone]
self.heads = heads
16 changes: 9 additions & 7 deletions pytorch_optimizer/optimizer/shampoo.py
@@ -336,13 +336,15 @@ def step(self, closure: CLOSURE = None) -> LOSS:

shampoo_grad.mul_(graft_norm / (shampoo_norm + 1e-16))

if group['weight_decay'] > 0.0:
if not group['decoupled_weight_decay']:
graft_grad.add_(p, alpha=group['weight_decay'])
shampoo_grad.add_(p, alpha=group['weight_decay'])
else:
graft_grad.mul_(1.0 - group['lr'] * group['weight_decay'])
shampoo_grad.mul_(1.0 - group['lr'] * group['weight_decay'])
for g in (graft_grad, shampoo_grad):
self.apply_weight_decay(
p,
g,
group['lr'],
group['weight_decay'],
group['decoupled_weight_decay'],
fixed_decay=False,
)

state['momentum'].mul_(beta1).add_(shampoo_grad)
graft_momentum = graft.update_momentum(grad, beta1)
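Context for the `graft_norm / shampoo_norm` line kept above (not changed in this PR): that is standard norm grafting, keep Shampoo's update direction but borrow the step size of the simpler graft update. A minimal standalone illustration (names are mine, not from this file):

```python
import torch


def graft_norm_rescale(shampoo_update: torch.Tensor, graft_update: torch.Tensor, eps: float = 1e-16) -> torch.Tensor:
    # keep Shampoo's direction, but scale it so its L2 norm matches the
    # graft (e.g. SGD / Adagrad) update's norm
    graft_norm = torch.linalg.norm(graft_update)
    shampoo_norm = torch.linalg.norm(shampoo_update)
    return shampoo_update * (graft_norm / (shampoo_norm + eps))


s = torch.tensor([3.0, 4.0])  # norm 5
g = torch.tensor([1.0, 0.0])  # norm 1
print(graft_norm_rescale(s, g))  # tensor([0.6000, 0.8000]): Shampoo's direction, graft's norm
```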
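As for the decay change itself, the two removed branches correspond to the usual coupled vs. decoupled weight-decay conventions that `apply_weight_decay` now selects via `decoupled_weight_decay`; with `fixed_decay=False` the decoupled factor is presumably scaled by `lr`, as sketched here. A generic sketch of the two conventions (my own illustration, not the helper's actual body):

```python
import torch


def coupled_weight_decay(p: torch.Tensor, grad: torch.Tensor, weight_decay: float) -> None:
    # classic L2 regularization: fold wd * p into the gradient, so the decay term
    # is scaled by whatever learning rate / preconditioner the optimizer applies later
    grad.add_(p, alpha=weight_decay)


def decoupled_weight_decay(p: torch.Tensor, lr: float, weight_decay: float) -> None:
    # AdamW-style decay: shrink the weights directly, independent of the gradient path
    # (some implementations, like the removed branch above, fold the same
    # 1 - lr * wd factor into the update buffer instead)
    p.mul_(1.0 - lr * weight_decay)


p = torch.ones(3)
grad = torch.full((3,), 0.1)
coupled_weight_decay(p, grad, weight_decay=0.1)
decoupled_weight_decay(p, lr=0.1, weight_decay=0.1)
print(grad)  # tensor([0.2000, 0.2000, 0.2000])
print(p)     # tensor([0.9900, 0.9900, 0.9900])
```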
68 changes: 34 additions & 34 deletions pytorch_optimizer/optimizer/shampoo_utils.py
@@ -29,7 +29,7 @@ class Graft:
def __init__(self, *args):
pass

def add_statistics(self, grad: torch.Tensor, unused_beta2: float):
def add_statistics(self, grad: torch.Tensor, unused_beta2: float) -> None:
r"""Add the statistics."""
pass

@@ -47,7 +47,7 @@ class SGDGraft(Graft):

def __init__(self, var: torch.Tensor):
super().__init__(var)
self.momentum: torch.Tensor = torch.zeros_like(var, device=var.device)
self.momentum: torch.Tensor = torch.zeros_like(var)

def update_momentum(self, update: torch.Tensor, beta1: float) -> torch.Tensor:
r"""Update momentum."""
@@ -78,13 +78,13 @@ def __init__(self, var: torch.Tensor, diagonal_eps: float):
self.diagonal_eps = diagonal_eps
self.statistics: torch.Tensor = torch.zeros_like(var)

def add_statistics(self, grad: torch.Tensor, _):
def add_statistics(self, grad: torch.Tensor, _) -> None:
r"""Add the statistics."""
self.statistics.add_(grad.pow(2))

def precondition_gradient(self, grad: torch.Tensor) -> torch.Tensor:
r"""Get preconditioned gradient."""
return grad / (torch.sqrt(self.statistics) + self.diagonal_eps)
return grad.div(self.statistics.sqrt().add_(self.diagonal_eps))


class RMSPropGraft(SGDGraft):
@@ -99,13 +99,13 @@ def __init__(self, var: torch.Tensor, diagonal_eps: float):
self.diagonal_eps = diagonal_eps
self.statistics: torch.Tensor = torch.zeros_like(var)

def add_statistics(self, grad: torch.Tensor, beta2: float):
def add_statistics(self, grad: torch.Tensor, beta2: float) -> None:
r"""Add the statistics."""
self.statistics.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

def precondition_gradient(self, grad: torch.Tensor) -> torch.Tensor:
r"""Get preconditioned gradient."""
return grad / (torch.sqrt(self.statistics) + self.diagonal_eps)
return grad.div(self.statistics.sqrt().add_(self.diagonal_eps))
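Worth a second look in both `precondition_gradient` rewrites above: `self.statistics.sqrt()` returns a fresh tensor, so the chained in-place `.add_(self.diagonal_eps)` only touches that temporary and never mutates the accumulated statistics. A quick standalone check of the equivalence:

```python
import torch

statistics = torch.tensor([4.0, 9.0, 0.0])
grad = torch.tensor([2.0, 3.0, 1.0])
eps = 1e-10

old = grad / (torch.sqrt(statistics) + eps)  # previous formulation
new = grad.div(statistics.sqrt().add_(eps))  # refactored formulation

print(torch.allclose(old, new))  # True
print(statistics)                # tensor([4., 9., 0.]), untouched by the in-place add_
```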


class BlockPartitioner:
@@ -121,51 +121,51 @@ class BlockPartitioner:
"""

def __init__(self, var: torch.Tensor, rank: int, block_size: int, pre_conditioner_type: int):
self.shape: List[int] = var.shape
self.shape: torch.Size = var.shape

self.splits: List[Tuple[int, np.ndarray]] = []
self.split_sizes: List[Tuple[int, np.ndarray]] = []
self.splits: List[Tuple[int, torch.Tensor]] = []
self.split_sizes: List[Tuple[int, torch.Tensor]] = []

split_sizes: List[np.ndarray] = []
split_sizes: List[torch.Tensor] = []

# We split var into smaller blocks. Here we store the metadata to make that split.
for i, d in enumerate(self.shape):
if block_size <= 0 or block_size >= d:
split_sizes.append(np.array([d], dtype=np.int32))
split_sizes.append(torch.tensor([d], dtype=torch.int32))
continue

# d - 1, otherwise split appends a 0-size array.
num_split: int = (d - 1) // block_size
indices = (np.arange(num_split, dtype=np.int32) + 1) * block_size
indices = (torch.arange(num_split, dtype=torch.int32) + 1) * block_size

sizes: np.ndarray = np.ones(num_split + 1, dtype=np.int32) * block_size
sizes: torch.Tensor = torch.full((num_split + 1,), block_size, dtype=torch.int32)
sizes[-1] = d - indices[-1]

self.splits.append((i, indices))
self.split_sizes.append((i, sizes))
split_sizes.append(sizes)

self.num_splits: int = len(split_sizes)
self.pre_conditioner_shapes: List[List[int]] = self.build_pre_conditioner_shapes(
self.pre_conditioner_shapes: List[List[torch.Tensor]] = self.build_pre_conditioner_shapes(
split_sizes, pre_conditioner_type, rank
)

@staticmethod
def build_pre_conditioner_shapes(
split_sizes: List[np.ndarray], pre_conditioner_type: int, rank: int
) -> List[List[int]]:
split_sizes: List[torch.Tensor], pre_conditioner_type: int, rank: int
) -> List[List[torch.Tensor]]:
r"""Build pre-conditioner shapes."""
pre_conditioner_shapes: List[List[int]] = []
pre_conditioner_shapes: List[List[torch.Tensor]] = []
for t in itertools.product(*split_sizes):
t_shape: List[Optional[List[int]]] = [[d, d] for d in t]
t_shape: List[Optional[List[torch.Tensor]]] = [[d, d] for d in t]
if pre_conditioner_type == PreConditionerType.INPUT:
t_shape = t_shape[:-1] + [None]
if pre_conditioner_type == PreConditionerType.OUTPUT:
t_shape[-1] = None
elif pre_conditioner_type == PreConditionerType.OUTPUT:
t_shape = [None] * (rank - 1) + t_shape[-1:]
pre_conditioner_shapes.extend(t_shape)
return pre_conditioner_shapes

def shapes_for_pre_conditioners(self) -> List[List[int]]:
def shapes_for_pre_conditioners(self) -> List[List[torch.Tensor]]:
r"""Get shapes of pre-conditioner."""
return self.pre_conditioner_shapes
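To sanity-check the NumPy-to-torch conversion of the split metadata in the `BlockPartitioner` constructor above, here is the same arithmetic on a toy dimension; how `partition` consumes these sizes is below the fold, so the final `torch.split` call is only illustrative:

```python
import torch

d, block_size = 10, 4

# boundaries at multiples of block_size, with the last block absorbing the remainder
num_split = (d - 1) // block_size
indices = (torch.arange(num_split, dtype=torch.int32) + 1) * block_size
sizes = torch.full((num_split + 1,), block_size, dtype=torch.int32)
sizes[-1] = d - indices[-1]

print(indices)  # tensor([4, 8], dtype=torch.int32)
print(sizes)    # tensor([4, 4, 2], dtype=torch.int32)

x = torch.randn(10, 3)
print([t.shape for t in torch.split(x, sizes.tolist(), dim=0)])
# [torch.Size([4, 3]), torch.Size([4, 3]), torch.Size([2, 3])]
```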

@@ -244,7 +244,7 @@ def __init__(

self.w2: float = 1.0 if self.beta2 == 1.0 else (1.0 - self.beta2)

self.original_shape: List[int] = var.shape
self.original_shape: torch.Size = var.shape
self.transformed_shape: List[int] = (
merge_small_dims(self.original_shape, block_size) if shape_interpretation else var.shape
)
@@ -267,7 +267,7 @@ def __init__(
pre_conditioner_type=self.pre_conditioner_type,
)

shapes: List[Optional[List[int]]] = self.partitioner.shapes_for_pre_conditioners()
shapes: List[Optional[List[torch.Tensor]]] = self.partitioner.shapes_for_pre_conditioners()
self.statistics = [self.matrix_eps * torch.eye(shape[0], device=var.device) for shape in shapes if shape]
self.pre_conditioners = [torch.eye(shape[0], device=var.device) for shape in shapes if shape]
self.is_same_shapes = None not in shapes and len(np.unique(shapes)) == 1
@@ -291,7 +291,7 @@ def skip_precondition(self, x: torch.Tensor) -> bool:
dim > self.no_preconditioning_for_layers_with_dim_gt for dim in x.shape
)

def add_statistics(self, grad: torch.Tensor):
def add_statistics(self, grad: torch.Tensor) -> None:
r"""Compute statistics from gradients and add to the correct state entries.

:param grad: torch.Tensor. gradient to compute statistics from.
@@ -302,14 +302,13 @@ def add_statistics(self, grad: torch.Tensor):
reshaped_grad: torch.Tensor = torch.reshape(grad, self.transformed_shape)
partitioned_grads: List[torch.Tensor] = self.partitioner.partition(reshaped_grad)

for j in range(len(partitioned_grads)):
partitioned_grad: torch.Tensor = partitioned_grads[j]
for j, partitioned_grad in enumerate(partitioned_grads):
for i in range(self.rank):
axes: List[int] = [ax for ax in range(partitioned_grad.ndim) if ax != i]
stat: torch.Tensor = torch.tensordot(partitioned_grad, partitioned_grad, dims=[axes, axes])
self.statistics[j * self.rank + i].mul_(self.beta2).add_(stat, alpha=self.w2)
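Context for the statistics update above (the switch to `enumerate` is purely cosmetic): contracting every axis except `i` with `torch.tensordot(g, g, dims=[axes, axes])` yields the mode-`i` Gram matrix of the gradient. A small standalone check for a 2-D gradient:

```python
import torch

g = torch.randn(8, 5)

# mode-0 statistic: contract every axis except 0  ->  g @ g.T  (8 x 8)
stat_rows = torch.tensordot(g, g, dims=[[1], [1]])
# mode-1 statistic: contract every axis except 1  ->  g.T @ g  (5 x 5)
stat_cols = torch.tensordot(g, g, dims=[[0], [0]])

print(torch.allclose(stat_rows, g @ g.t()))  # True
print(torch.allclose(stat_cols, g.t() @ g))  # True
```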

def compute_pre_conditioners(self):
def compute_pre_conditioners(self) -> None:
r"""Compute L^{-1/exp} for each stats matrix L.

If `self.use_svd` is enabled and all statistics & pre-conditioner shapes are the same, perform a batched SVD.
@@ -333,15 +332,15 @@ def compute_pre_conditioners(self):
def precondition_block(
partitioned_grad: torch.Tensor,
should_preconditioned_dims: List[bool],
pre_conditioners_for_grad: List[torch.Tensor],
pre_conditioners_for_grad: Union[List[torch.Tensor], torch.Tensor],
) -> torch.Tensor:
r"""Perform a preconditioning operation on a single gradient block.

Loop invariant: the dimension to be preconditioned is first
We keep all axes in the same cyclic order they were originally.
"""
rank: int = len(partitioned_grad.shape)
roll: Tuple[int, ...] = (*tuple(range(1, rank)), 0)
roll: Tuple[int, ...] = (*range(1, rank), 0)

i: int = 0
for should_precondition_dim in should_preconditioned_dims:
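The `roll` simplification above is cosmetic, but the axis-cycling idea it supports is easy to miss with only this much of `precondition_block` visible in the hunk: each pass handles dim 0 and sends it to the back, so after `rank` passes the axes return to their original order. A self-contained sketch of that idea (my own code, not the function's actual body, which also handles the `should_precondition_dim` bookkeeping):

```python
from typing import List, Optional

import torch


def precondition_all_dims(grad: torch.Tensor, pre_conditioners: List[Optional[torch.Tensor]]) -> torch.Tensor:
    # each pass handles dim 0: tensordot contracts it with its (d_i x d_i)
    # pre-conditioner and appends the result dim at the back; a plain permute
    # performs the same rotation when a dim is not preconditioned, so after
    # `rank` passes the axes are back in their original cyclic order
    rank = grad.dim()
    for i in range(rank):
        if pre_conditioners[i] is not None:
            grad = torch.tensordot(grad, pre_conditioners[i], dims=[[0], [0]])
        else:
            grad = grad.permute(*range(1, rank), 0)
    return grad


g = torch.randn(4, 5, 6)
pre_conds = [torch.eye(4), None, torch.eye(6)]
print(precondition_all_dims(g, pre_conds).shape)               # torch.Size([4, 5, 6])
print(torch.allclose(precondition_all_dims(g, pre_conds), g))  # True with identity pre-conditioners
```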
@@ -376,7 +375,7 @@ def preconditioned_grad(self, grad: torch.Tensor) -> torch.Tensor:

merged_grad = self.partitioner.merge_partitions(pre_cond_partitioned_grads)

return torch.reshape(merged_grad, self.original_shape)
return merged_grad.reshape(self.original_shape)


def build_graft(p: torch.Tensor, graft_type: int, diagonal_eps: float = 1e-10):
@@ -407,7 +406,8 @@ def power_iteration(mat_g: torch.Tensor, num_iters: int = 100) -> torch.Tensor:

for _ in range(num_iters):
torch.mv(mat_g, v, out=mat_v)
v = mat_v.div(torch.linalg.norm(mat_v))
v.copy_(mat_v)
v.div_(torch.linalg.norm(v))

return (v.t() @ mat_g @ v).clamp_min_(1e-16)
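The new `copy_` / `div_` pair above avoids reallocating `v` on every iteration but keeps the standard power-iteration recipe. A standalone sanity check of the pattern against `torch.linalg.eigvalsh` (my example, not repo code):

```python
import torch

torch.manual_seed(0)

# estimate the largest eigenvalue of a symmetric PSD matrix, like a Shampoo statistics matrix
a = torch.randn(64, 32)
mat_g = a.t() @ a

v = torch.randn(32)
mat_v = torch.empty_like(v)
for _ in range(100):
    torch.mv(mat_g, v, out=mat_v)     # mat_v <- mat_g @ v, written in place
    v.copy_(mat_v)
    v.div_(torch.linalg.norm(v))      # re-normalize so v tracks the dominant eigenvector

estimate = v @ mat_g @ v              # Rayleigh quotient of the normalized v
print(estimate, torch.linalg.eigvalsh(mat_g).max())  # the two values should agree closely
```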

@@ -490,7 +490,7 @@ def compute_power_schur_newton(

@torch.no_grad()
def compute_power_svd(matrix: torch.Tensor, power: float) -> torch.Tensor:
r"""Compute G^{-1/p} using a SVD.
r"""Compute G^{-1/p} using SVD.

Calculate the SVD on the GPU. Sometimes SVD on the CPU is faster than on the GPU, but based on several experiments,
CUDA is usually much faster than the CPU.
@@ -503,14 +503,14 @@ def compute_power_svd(matrix: torch.Tensor, power: float) -> torch.Tensor:
return u @ (s.diag() if len(matrix.shape) == 2 else s.diag_embed()) @ vh
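For reviewers skimming `compute_power_svd`: the `-1/power` exponentiation of the singular values presumably happens just above the fold; since the statistics matrices are symmetric PSD, the SVD coincides with the eigendecomposition, so this is a valid way to form G^{-1/p}. A compact standalone version of the idea:

```python
import torch


def inverse_pth_root_svd(mat: torch.Tensor, p: float, eps: float = 1e-16) -> torch.Tensor:
    # sketch: raise the singular values to -1/p, then recompose
    u, s, vh = torch.linalg.svd(mat, full_matrices=False)
    s = (s + eps).pow(-1.0 / p)
    return u @ s.diag() @ vh


g = torch.tensor([[4.0, 0.0], [0.0, 9.0]])
print(inverse_pth_root_svd(g, 2.0))  # ~diag(0.5, 0.3333), i.e. G^{-1/2}
```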


def merge_small_dims(shape_to_merge: List[int], max_dim: int) -> List[int]:
def merge_small_dims(shape_to_merge: Union[List[int], torch.Size], max_dim: int) -> List[int]:
r"""Merge small dimensions.

If there are some small dimensions, we collapse them
e.g. [1, 2, 512, 1, 2048, 1, 3, 4] --> [1024, 2048, 12] if max_dim = 1024
[1, 2, 768, 1, 2048] --> [2, 768, 2048].

:param shape_to_merge: List[int]. Shape to merge small dimensions.
:param shape_to_merge: Union[List[int], torch.Size]. Shape to merge small dimensions.
:param max_dim: int. Maximal dimension of output shape used in merging.
"""
merged_shape: List[int] = []
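To make the docstring examples above executable, a sketch of the merging rule as I read it (greedy left-to-right; it reproduces the documented examples but is not necessarily the exact implementation below the fold):

```python
from typing import List, Union

import torch


def merge_small_dims_sketch(shape_to_merge: Union[List[int], torch.Size], max_dim: int) -> List[int]:
    # greedily multiply adjacent dims while the running product stays <= max_dim,
    # then start a new output dim; size-1 dims get absorbed for free
    merged: List[int] = []
    product = 1
    for d in shape_to_merge:
        if product * d <= max_dim or product == 1:
            product *= d
        else:
            merged.append(product)
            product = d
    merged.append(product)
    return merged


print(merge_small_dims_sketch([1, 2, 512, 1, 2048, 1, 3, 4], 1024))  # [1024, 2048, 12]
print(merge_small_dims_sketch([1, 2, 768, 1, 2048], 1024))           # [2, 768, 2048]
```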