From 1391436c38438d86a93f3190ef3f117723cb66d8 Mon Sep 17 00:00:00 2001
From: ljleb
Date: Wed, 6 Sep 2023 23:55:15 -0400
Subject: [PATCH 1/6] perp-add

---
 sd_meh/merge_methods.py | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/sd_meh/merge_methods.py b/sd_meh/merge_methods.py
index c10c459..9865d52 100644
--- a/sd_meh/merge_methods.py
+++ b/sd_meh/merge_methods.py
@@ -209,3 +209,12 @@ def filter_top_k(a: Tensor, k: float):
     k_value, _ = torch.kthvalue(torch.abs(a.flatten()).float(), k)
     top_k_filter = (torch.abs(a) >= k_value).float()
     return a * top_k_filter
+
+
+def add_perpendicular(
+    a: Tensor, b: Tensor, c: Tensor, alpha: float, **kwargs
+) -> Tensor:
+    a_c = a.float() - c.float()
+    b_c = b.float() - c.float()
+    b_perp = b_c - a_c * torch.sum(a_c * b_c) / torch.norm(a_c) ** 2
+    return a.float() + alpha * b_perp

From b1d510df21597ca3f292c05ff2c9dbb749584616 Mon Sep 17 00:00:00 2001
From: ljleb
Date: Wed, 6 Sep 2023 23:56:57 -0400
Subject: [PATCH 2/6] all

---
 sd_meh/merge_methods.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/sd_meh/merge_methods.py b/sd_meh/merge_methods.py
index 9865d52..3ccd4e3 100644
--- a/sd_meh/merge_methods.py
+++ b/sd_meh/merge_methods.py
@@ -17,6 +17,7 @@
     "similarity_add_difference",
     "distribution_crossover",
     "ties_add_difference",
+    "add_perpendicular",
 ]
 
 

From dbd2ca971dd47d2f8af194868f21a076bddf7f48 Mon Sep 17 00:00:00 2001
From: ljleb
Date: Thu, 7 Sep 2023 01:15:37 -0400
Subject: [PATCH 3/6] no diff space

---
 sd_meh/merge_methods.py | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

diff --git a/sd_meh/merge_methods.py b/sd_meh/merge_methods.py
index 3ccd4e3..a984881 100644
--- a/sd_meh/merge_methods.py
+++ b/sd_meh/merge_methods.py
@@ -213,9 +213,13 @@ def filter_top_k(a: Tensor, k: float):
 
 
 def add_perpendicular(
-    a: Tensor, b: Tensor, c: Tensor, alpha: float, **kwargs
+    a: Tensor, b: Tensor, alpha: float, c: Tensor = None, **kwargs
 ) -> Tensor:
-    a_c = a.float() - c.float()
-    b_c = b.float() - c.float()
-    b_perp = b_c - a_c * torch.sum(a_c * b_c) / torch.norm(a_c) ** 2
-    return a.float() + alpha * b_perp
+    if c is None:
+        c = a
+
+    a = a.float()
+    b = b.float()
+    c = c.float()
+    b_perp = b - c * torch.sum(c * b) / torch.norm(c) ** 2
+    return a + alpha * b_perp

From a7137857dce5a4b4cf5bbed80bcea9f18dbb56ed Mon Sep 17 00:00:00 2001
From: ljleb
Date: Sun, 11 Feb 2024 15:59:00 -0500
Subject: [PATCH 4/6] try alternative orthogonalization strategy

---
 sd_meh/merge_methods.py | 12 ++++--------
 1 file changed, 4 insertions(+), 8 deletions(-)

diff --git a/sd_meh/merge_methods.py b/sd_meh/merge_methods.py
index a984881..f135848 100644
--- a/sd_meh/merge_methods.py
+++ b/sd_meh/merge_methods.py
@@ -215,11 +215,7 @@ def filter_top_k(a: Tensor, k: float):
 def add_perpendicular(
     a: Tensor, b: Tensor, alpha: float, c: Tensor = None, **kwargs
 ) -> Tensor:
-    if c is None:
-        c = a
-
-    a = a.float()
-    b = b.float()
-    c = c.float()
-    b_perp = b - c * torch.sum(c * b) / torch.norm(c) ** 2
-    return a + alpha * b_perp
+    a_diff = a.float() - c.float()
+    b_diff = b.float() - c.float()
+    b_ortho = b_diff * torch.sum(a_diff * b_diff) / torch.norm(b_diff) ** 2
+    return (a + alpha * b_ortho).to(a.dtype)

From 77d75d4f74ee1f8c2c194edd311a36ea6ca2d3ed Mon Sep 17 00:00:00 2001
From: ljleb
Date: Sun, 11 Feb 2024 16:50:19 -0500
Subject: [PATCH 5/6] perp in diff space

---
 sd_meh/merge_methods.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/sd_meh/merge_methods.py b/sd_meh/merge_methods.py
index f135848..6f440e2 100644
--- a/sd_meh/merge_methods.py
+++ b/sd_meh/merge_methods.py
@@ -217,5 +217,5 @@ def add_perpendicular(
 ) -> Tensor:
     a_diff = a.float() - c.float()
     b_diff = b.float() - c.float()
-    b_ortho = b_diff * torch.sum(a_diff * b_diff) / torch.norm(b_diff) ** 2
-    return (a + alpha * b_ortho).to(a.dtype)
+    b_perp = b_diff - a_diff * torch.sum(a_diff * b_diff) / torch.norm(a_diff) ** 2
+    return (a + alpha * b_perp).to(a.dtype)

From ec7ecb886cd3466f0a85b4b03611894d7fe84292 Mon Sep 17 00:00:00 2001
From: ljleb
Date: Mon, 12 Feb 2024 11:32:54 -0500
Subject: [PATCH 6/6] nan

---
 sd_meh/merge_methods.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/sd_meh/merge_methods.py b/sd_meh/merge_methods.py
index 6f440e2..1546833 100644
--- a/sd_meh/merge_methods.py
+++ b/sd_meh/merge_methods.py
@@ -217,5 +217,9 @@ def add_perpendicular(
 ) -> Tensor:
     a_diff = a.float() - c.float()
     b_diff = b.float() - c.float()
-    b_perp = b_diff - a_diff * torch.sum(a_diff * b_diff) / torch.norm(a_diff) ** 2
-    return (a + alpha * b_perp).to(a.dtype)
+    a_ortho = a_diff * (a_diff / torch.linalg.norm(a_diff) * (b_diff / torch.linalg.norm(a_diff))).sum()
+    b_perp = b_diff - a_ortho
+    res = a + alpha * b_perp
+    if torch.isnan(res).any():
+        return a
+    return res.to(a.dtype)