Skip to content

Commit

Permalink
cache impl
Browse files Browse the repository at this point in the history
  • Loading branch information
ljleb committed Jan 27, 2024
1 parent de71102 commit a01e016
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 21 deletions.
12 changes: 11 additions & 1 deletion sd_meh/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ def merge_models(
work_device: Optional[str] = None,
prune: bool = False,
threads: int = 1,
cache: Optional[Dict] = None,
) -> Dict:
thetas = load_thetas(models, prune, device, precision)

Expand Down Expand Up @@ -169,6 +170,7 @@ def merge_models(
device=device,
work_device=work_device,
threads=threads,
cache=cache,
)

return un_prune_model(merged, thetas, models, device, prune, precision)
Expand Down Expand Up @@ -221,6 +223,7 @@ def simple_merge(
device: str = "cpu",
work_device: Optional[str] = None,
threads: int = 1,
cache: Optional[Dict] = None,
) -> Dict:
futures = []
with tqdm(thetas["model_a"].keys(), desc="stage 1") as progress:
Expand All @@ -238,6 +241,7 @@ def simple_merge(
weights_clip,
device,
work_device,
cache,
)
futures.append(future)

Expand Down Expand Up @@ -367,6 +371,7 @@ def merge_key(
weights_clip: bool = False,
device: str = "cpu",
work_device: Optional[str] = None,
cache: Optional[Dict] = None,
) -> Optional[Tuple[str, Dict]]:
if work_device is None:
work_device = device
Expand Down Expand Up @@ -410,7 +415,7 @@ def merge_key(
except AttributeError as e:
raise ValueError(f"{merge_mode} not implemented, aborting merge!") from e

merge_args = get_merge_method_args(current_bases, thetas, key, work_device)
merge_args = get_merge_method_args(current_bases, thetas, key, work_device, cache)

# dealing with pix2pix and inpainting models
if (a_size := merge_args["a"].size()) != (b_size := merge_args["b"].size()):
Expand Down Expand Up @@ -460,11 +465,16 @@ def get_merge_method_args(
thetas: Dict,
key: str,
work_device: str,
cache: Optional[Dict],
) -> Dict:
if cache is not None and key not in cache:
cache[key] = {}

merge_method_args = {
"a": thetas["model_a"][key].to(work_device),
"b": thetas["model_b"][key].to(work_device),
**current_bases,
"cache": cache[key]
}

if "model_c" in thetas:
Expand Down
55 changes: 35 additions & 20 deletions sd_meh/merge_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,28 +240,35 @@ def rotate(a: Tensor, b: Tensor, alpha: float, beta: float, **kwargs):
a_neurons -= a_centroid
b_neurons -= b_centroid

svd_driver = "gesvd" if a.is_cuda else None
u, _, v_t = torch.linalg.svd(a_neurons.T @ b_neurons, driver=svd_driver)

alpha_is_float = alpha != round(alpha)
if alpha_is_float:
# cancel reflection. without this, eigenvalues often have a complex component
# and then we can't obtain a valid dtype for the merge
u[:, -1] /= torch.det(u) * torch.det(v_t)

transform = rotation = u @ v_t
if not torch.isfinite(u).all():
raise ValueError(
textwrap.dedent(
f"""determinant error: {torch.det(rotation)}.
This can happen when merging on the CPU with the "rotate" method.
Consider merging on a cuda device, or try setting alpha to 1 for the problematic blocks.
See this related discussion for more info: https://github.com/s1dlx/meh/pull/50#discussion_r1429469484"""

if kwargs["cache"] is not None and "rotation" in kwargs["cache"]:
rotation = transform = kwargs["cache"]["rotation"].to(a.device)
else:
svd_driver = "gesvd" if a.is_cuda else None
u, _, v_t = torch.linalg.svd(a_neurons.T @ b_neurons, driver=svd_driver)

if alpha_is_float:
# cancel reflection. without this, eigenvalues often have a complex component
# and then we can't obtain a valid dtype for the merge
u[:, -1] /= torch.det(u) * torch.det(v_t)

rotation = transform = u @ v_t
if not torch.isfinite(u).all():
raise ValueError(
textwrap.dedent(
f"""determinant error: {torch.det(rotation)}.
This can happen when merging on the CPU with the "rotate" method.
Consider merging on a cuda device, or try setting alpha to 1 for the problematic blocks.
See this related discussion for more info: https://github.com/s1dlx/meh/pull/50#discussion_r1429469484"""
)
)
)

if kwargs["cache"] is not None:
kwargs["cache"]["rotation"] = rotation.cpu()

if alpha_is_float:
transform = fractional_matrix_power(transform, alpha)
transform = fractional_matrix_power(transform, alpha, kwargs["cache"])
elif alpha == 0:
transform = torch.eye(
len(transform),
Expand All @@ -280,8 +287,16 @@ def rotate(a: Tensor, b: Tensor, alpha: float, beta: float, **kwargs):
return a_neurons.reshape_as(a).to(a.dtype)


def fractional_matrix_power(matrix: Tensor, power: float):
eigenvalues, eigenvectors = torch.linalg.eig(matrix)
def fractional_matrix_power(matrix: Tensor, power: float, cache: dict):
    """Compute ``matrix ** power`` via eigendecomposition.

    The eigendecomposition is the expensive step, so it may be memoized in
    *cache* (keys ``"eigenvalues"`` / ``"eigenvectors"``, stored on CPU) and
    reused across calls. Pass ``cache=None`` to disable caching.

    Returns the real part of the result, cast back to ``matrix.dtype``.
    """
    if cache is not None and "eigenvalues" in cache:
        eigenvalues = cache["eigenvalues"].to(matrix.device)
        eigenvectors = cache["eigenvectors"].to(matrix.device)
    else:
        eigenvalues, eigenvectors = torch.linalg.eig(matrix)
        if cache is not None:
            cache["eigenvalues"] = eigenvalues.cpu()
            cache["eigenvectors"] = eigenvectors.cpu()

    # BUG FIX: must be out-of-place. `.to(device)` / `.cpu()` return the SAME
    # tensor object when no device change is needed, so the previous in-place
    # `eigenvalues.pow_(power)` mutated the cached eigenvalues as well --
    # every cached call re-raised the already-powed values, giving wrong results.
    eigenvalues = eigenvalues.pow(power)
    result = eigenvectors @ torch.diag(eigenvalues) @ torch.linalg.inv(eigenvectors)
    return result.real.to(dtype=matrix.dtype)

0 comments on commit a01e016

Please sign in to comment.