Commit 3cf597e
hako-mikan committed Jan 22, 2025
1 parent 5a0f745 commit 3cf597e
Showing 1 changed file with 116 additions and 68 deletions: scripts/mergers/pluslora.py
@@ -20,7 +20,7 @@
from safetensors.torch import load_file, save_file
from scripts.kohyas import extract_lora_from_models as ext
from scripts.A1111 import networks as nets
from scripts.mergers.model_util import (filenamecutter, savemodel)
from scripts.mergers.model_util import filenamecutter, savemodel
from scripts.mergers.mergers import extract_super, unload_forge
from tqdm import tqdm
from modules.ui import versions_html
@@ -486,7 +486,7 @@ def merge_lora_models(models, ratios, sets, locon, calc_precision, device):
# merge
print(f"merging...")
for key in lora_sd.keys():
if 'alpha' in key:
if 'alpha' in key or "dora" in key:
continue

lora_module_name = key[:key.rfind(".lora_")]
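
The only functional change in this hunk is the key filter: DoRA scale tensors ("dora" keys) are now skipped together with the alpha tensors, so only the lora_up/lora_down weight pairs take part in the merge. A minimal sketch of that filter, assuming kohya-style key names (module.lora_down.weight, module.lora_up.weight, module.alpha, module.dora_scale):

def mergeable_keys(lora_sd: dict):
    # Yield only the keys that carry mergeable weights; alpha and DoRA
    # scale entries are handled separately and must not be merged here.
    for key in lora_sd:
        if "alpha" in key or "dora" in key:
            continue                                  # scale tensors, not weights
        lora_module_name = key[:key.rfind(".lora_")]  # e.g. "lora_unet_..._to_q"
        yield lora_module_name, key

# e.g. "x.lora_down.weight" and "x.lora_up.weight" are yielded, while
# "x.alpha" and "x.dora_scale" are skipped.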
@@ -705,15 +705,12 @@ def lycomerge(filename, ratios, calc_precision, device):
def pluslora(lnames,loraratios,settings,output,model,save_precision,calc_precision,metasets,device):
if model == []: return "ERROR: No model Selected"
if lnames == "":return "ERROR: No LoRA Selected"

add = ""

print("Plus LoRA start")
import lora
lnames = lnames.split(",")
print("Plus LoRA start")
add = ""

temp = []
for n in lnames:
for n in lnames.split(","):
if ":" in n:
temp.append(n.split(":"))
else:
@@ -724,11 +721,11 @@ def pluslora(lnames,loraratios,settings,output,model,save_precision,calc_precisi
loraratios=loraratios.splitlines()
ldict ={}

for i,l in enumerate(loraratios):
for l in loraratios:
if ":" not in l or not any(l.count(",") == x - 1 for x in BLOCKNUMS) : continue
ldict[l.split(":")[0].strip()]=l.split(":")[1]

names, filenames, loratypes, lweis = [], [], [], []
names, filenames, lweis = [], [], []

for n in lnames:
if len(n) ==2:
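
In this hunk each preset line has the form NAME:r1,r2,...,rN and is only accepted when its ratio count matches one of the supported block layouts; a LoRA entry written as name:PRESET is later resolved against that dictionary. A small sketch of the parse and lookup; BLOCKNUMS is given an assumed example value here (the real constant lives elsewhere in pluslora.py) and the numeric fallback is simplified:

BLOCKNUMS = [12, 17, 20, 26]          # assumed example: supported ratio counts

def build_preset_dict(loraratios: str) -> dict:
    ldict = {}
    for line in loraratios.splitlines():
        # a line with N ratios contains N-1 commas; unknown layouts are dropped
        if ":" not in line or not any(line.count(",") == x - 1 for x in BLOCKNUMS):
            continue
        name, ratios = line.split(":", 1)
        ldict[name.strip()] = ratios
    return ldict

def resolve_weights(name_and_spec, ldict):
    name, spec = name_and_spec       # e.g. ("myLoRA", "MYPRESET") or ("myLoRA", "0.6")
    if spec.strip() in ldict:
        return [float(r) for r in ldict[spec.strip()].split(",")]
    return [float(spec)] * 26        # assumed: a single ratio applied to all 26 blocks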
@@ -754,6 +751,7 @@ def pluslora(lnames,loraratios,settings,output,model,save_precision,calc_precisi
dname = dname + "+"+n

checkpoint_info = sd_models.get_closet_checkpoint_match(model)

if forge:
revert_target = sd_models.get_closet_checkpoint_match(shared.opts.sd_model_checkpoint)
print(f"Loading {model}")
@@ -762,6 +760,7 @@ def pluslora(lnames,loraratios,settings,output,model,save_precision,calc_precisi

isxl = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.weight" in theta_0.keys()
isv2 = "cond_stage_model.model.transformer.resblocks.0.attn.out_proj.weight" in theta_0.keys()
isflux = any("double_block" in k for k in theta_0.keys())

try:
import networks
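
The new isflux probe follows the same pattern as the existing SDXL and SD2 checks: each architecture is recognized by a key that only occurs in its state dict. Roughly, as a standalone helper (the function name and return labels are illustrative, not part of the extension):

def detect_arch(theta_0: dict) -> str:
    # Probe characteristic state-dict keys, mirroring the isxl/isv2/isflux checks above.
    if any("double_block" in k for k in theta_0):
        return "flux"    # Flux double-stream transformer blocks
    if "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.weight" in theta_0:
        return "sdxl"    # SDXL's second text encoder
    if "cond_stage_model.model.transformer.resblocks.0.attn.out_proj.weight" in theta_0:
        return "sd2"     # SD2 text encoder layout
    return "sd1"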
@@ -791,65 +790,14 @@ def pluslora(lnames,loraratios,settings,output,model,save_precision,calc_precisi
checkpoint_info = sd_models.get_closet_checkpoint_match(model)
if orig_checkpoint != checkpoint_info:
sd_models.reload_model_weights(info=checkpoint_info)

theta_0 = newpluslora(theta_0,filenames,lweis,names, calc_precision, isxl,isv2, keychanger)

if orig_checkpoint:
sd_models.reload_model_weights(info=orig_checkpoint)
else:
for name,filename, lwei in zip(names,filenames, lweis):
print(f"loading: {name}")
lora_sd, metadata, isv2 = load_state_dict(filename, torch.float, device)

print(f"merging..." ,lwei)
for key in lora_sd.keys():
ratio = 1
fullkey = convert_diffusers_name_to_compvis(key,isv2)

msd_key, _ = fullkey.split(".", 1)
if isxl:
if "lora_unet" in msd_key:
msd_key = msd_key.replace("lora_unet", "diffusion_model")
elif "lora_te1_text_model" in msd_key:
msd_key = msd_key.replace("lora_te1_text_model", "0_transformer_text_model")

for i,block in enumerate(LBLCOKS26):
if block in fullkey or block in msd_key:
ratio = lwei[i]

if "lora_down" in key:
up_key = key.replace("lora_down", "lora_up")
alpha_key = key[:key.index("lora_down")] + 'alpha'

# print(f"apply {key} to {module}")

down_weight = lora_sd[key].to(device="cpu")
up_weight = lora_sd[up_key].to(device="cpu")

dim = down_weight.size()[0]
alpha = lora_sd.get(alpha_key, dim)
scale = alpha / dim
# W <- W + U * D
weight = theta_0[keychanger[msd_key]].to(device="cpu")

if len(weight.size()) == 2:
# linear
weight = weight + ratio * (up_weight @ down_weight) * scale

elif down_weight.size()[2:4] == (1, 1):
# conv2d 1x1
weight = (
weight
+ ratio
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
* scale
)
else:
# conv2d 3x3
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
# print(conved.size(), weight.size(), module.stride, module.padding)
weight = weight + ratio * conved * scale

theta_0[keychanger[msd_key]] = torch.nn.Parameter(weight)
theta_0 = oldpluslora(theta_0,filenames,lweis,names, calc_precision,isxl,isv2, keychanger, device)

#usemodelgen(theta_0,model)
settings.append(save_precision)
settings.append("safetensors")
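
This hunk removes the long inline merge loop and replaces it with two paths: newpluslora reuses the webui's own network machinery, while the previous manual tensor merge is kept as oldpluslora (defined in the next hunk). The exact branch condition is collapsed in this view, so the import-based guard in the sketch below is an assumption:

def apply_loras(theta_0, filenames, lweis, names, calc_precision,
                isxl, isv2, keychanger, device):
    # Simplified dispatch; the checkpoint reload bookkeeping is omitted.
    try:
        import networks                      # present on recent webui / Forge builds
        use_networks = True
    except ImportError:
        use_networks = False
    if use_networks:
        return newpluslora(theta_0, filenames, lweis, names,
                           calc_precision, isxl, isv2, keychanger)
    return oldpluslora(theta_0, filenames, lweis, names,
                       calc_precision, isxl, isv2, keychanger, device)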
@@ -867,6 +815,63 @@ def pluslora(lnames,loraratios,settings,output,model,save_precision,calc_precisi
gc.collect()
return result + add

def oldpluslora(theta_0,filenames,lweis,names, calc_precision,isxl,isv2, keychanger, device):
for name,filename, lwei in zip(names,filenames, lweis):
print(f"loading: {name}")
lora_sd, metadata, isv2 = load_state_dict(filename, torch.float, device)

print(f"merging..." ,lwei)
for key in lora_sd.keys():
ratio = 1
fullkey = convert_diffusers_name_to_compvis(key,isv2)

msd_key, _ = fullkey.split(".", 1)
if isxl:
if "lora_unet" in msd_key:
msd_key = msd_key.replace("lora_unet", "diffusion_model")
elif "lora_te1_text_model" in msd_key:
msd_key = msd_key.replace("lora_te1_text_model", "0_transformer_text_model")

for i,block in enumerate(LBLCOKS26):
if block in fullkey or block in msd_key:
ratio = lwei[i]

if "lora_down" in key:
up_key = key.replace("lora_down", "lora_up")
alpha_key = key[:key.index("lora_down")] + 'alpha'

# print(f"apply {key} to {module}")

down_weight = lora_sd[key].to(device="cpu")
up_weight = lora_sd[up_key].to(device="cpu")

dim = down_weight.size()[0]
alpha = lora_sd.get(alpha_key, dim)
scale = alpha / dim
# W <- W + U * D
weight = theta_0[keychanger[msd_key]].to(device="cpu")

if len(weight.size()) == 2:
# linear
weight = weight + ratio * (up_weight @ down_weight) * scale

elif down_weight.size()[2:4] == (1, 1):
# conv2d 1x1
weight = (
weight
+ ratio
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
* scale
)
else:
# conv2d 3x3
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
# print(conved.size(), weight.size(), module.stride, module.padding)
weight = weight + ratio * conved * scale

theta_0[keychanger[msd_key]] = torch.nn.Parameter(weight)
return theta_0
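
oldpluslora applies the classic merged-weight update W <- W + ratio * (U @ D) * (alpha / dim), with three tensor shapes handled: plain linear layers, 1x1 convolutions (treated as a matmul on squeezed tensors), and 3x3 convolutions (composed with conv2d). A standalone sketch of the first two cases with toy tensors:

import torch

def merge_linear(weight, up, down, alpha, ratio):
    # W <- W + ratio * (U @ D) * (alpha / rank)
    dim = down.size(0)
    return weight + ratio * (up @ down) * (alpha / dim)

def merge_conv1x1(weight, up, down, alpha, ratio):
    # 1x1 conv kernels are multiplied as matrices, then reshaped back to 4D
    dim = down.size(0)
    delta = up.squeeze(3).squeeze(2) @ down.squeeze(3).squeeze(2)
    return weight + ratio * delta.unsqueeze(2).unsqueeze(3) * (alpha / dim)

# toy example: a 320x768 linear weight with a rank-4 LoRA at strength 0.6
w    = torch.zeros(320, 768)
up   = torch.randn(320, 4)
down = torch.randn(4, 768)
print(merge_linear(w, up, down, alpha=4.0, ratio=0.6).shape)  # torch.Size([320, 768])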

def newpluslora(theta_0,filenames,lweis,names, calc_precision,isxl,isv2, keychanger):
nets.load_networks(names, [1]* len(names),[1]* len(names), None, isxl, isv2)

@@ -916,16 +921,21 @@ def newpluslora(theta_0,filenames,lweis,names, calc_precision,isxl,isv2, keychan

def plusweights(weight, module, bias = None):
with torch.no_grad():
updown = module.calc_updown(weight.to(dtype=torch.float))
if weight.dtype == torch.float8_e4m3fn or weight.dtype == torch.float8_e5m2: # weight stored in float8
orig_dtype = weight.dtype
weight = weight.to(torch.float32) # convert to float32
else:
orig_dtype = None
updown = module.calc_updown(weight.to(dtype=torch.float32))
if len(weight.shape) == 4 and weight.shape[1] == 9:
# inpainting model. zero pad updown to make channel[1] 4 to 9
updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))
if type(updown) == tuple:
updown, ex_bias = updown
if ex_bias is not None and bias is not None:
bias += ex_bias
weight += updown
return weight, bias
weight += updown.to(weight.dtype)
return weight if orig_dtype is None else weight.to(orig_dtype), bias
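
The change to plusweights upcasts float8 checkpoint weights to float32 before adding the LoRA delta and casts the result back afterwards, since float8 tensors generally don't support the arithmetic needed for the update. A minimal sketch of that round-trip, assuming a PyTorch build that exposes the float8 dtypes (2.1 or later):

import torch

FLOAT8_DTYPES = (torch.float8_e4m3fn, torch.float8_e5m2)

def add_delta(weight: torch.Tensor, updown: torch.Tensor) -> torch.Tensor:
    # Upcast float8 storage, add the delta in float32, restore the original dtype.
    orig_dtype = weight.dtype if weight.dtype in FLOAT8_DTYPES else None
    if orig_dtype is not None:
        weight = weight.to(torch.float32)
    weight = weight + updown.to(weight.dtype)
    return weight if orig_dtype is None else weight.to(orig_dtype)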

def plusweightsqvk(inweight, outweight, network_layer_name, module ,net,bias = None):
with torch.no_grad():
@@ -972,6 +982,7 @@ def lbw(lora,lwei,isv2):
errormodules.append(key)

ltype = type(lora.modules[key]).__name__

set = False
if ltype in LORAANDSOON.keys():
setattr(lora.modules[key],LORAANDSOON[ltype],torch.nn.Parameter(getattr(lora.modules[key],LORAANDSOON[ltype]) * ratio))
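
lbw applies a per-block ratio by scaling one attribute of each network module in place; LORAANDSOON maps the module's class name to the attribute that carries its trainable weight. A sketch of that getattr/setattr pattern with a purely hypothetical mapping (the real table is defined elsewhere in pluslora.py):

import torch

LORAANDSOON = {                        # hypothetical entries for illustration only
    "ExampleLoraModule": "up_weight",
    "ExampleLohaModule": "w1b",
}

def scale_module(module, ratio: float) -> bool:
    # Multiply the module's weight-carrying attribute by the block ratio.
    attr = LORAANDSOON.get(type(module).__name__)
    if attr is None:
        return False                   # unknown module type, leave untouched
    setattr(module, attr, torch.nn.Parameter(getattr(module, attr) * ratio))
    return True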
@@ -1498,7 +1509,7 @@ def __init__(
}


def convert_diffusers_name_to_compvis(key, is_sd2):
def convert_diffusers_name_to_compvis(key, is_sd2, isflux = False):
def match(match_list, regex_text):
regex = re_compiled.get(regex_text)
if regex is None:
@@ -1561,6 +1572,43 @@ def match(match_list, regex_text):
else:
return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"

#for flux
if match(m, r"lora_unet_double_blocks_(\d+)_(img|txt)_(attn|mlp|mod)_(proj|qkv|lin|\d+)(.*)"):
block_type = m[1] # img or txt
module_type = m[2] # attn, mlp, mod
specific_module = m[3] # proj, qkv, lin, or numeric index

# Create suffix for specific module types
if module_type == "attn":
if specific_module == "proj":
suffix = "proj.weight"
elif specific_module == "qkv":
suffix = "qkv.weight"
else:
suffix = f"norm.{specific_module}_norm.scale"
elif module_type == "mlp":
suffix = f"{specific_module}.weight"
elif module_type == "mod":
suffix = f"lin.weight"
else:
suffix = specific_module

return f"model.diffusion_model.double_blocks.{m[0]}.{block_type}_{module_type}.{suffix}"

if match(m, r"lora_unet_single_blocks_(\d+)_(linear\d+|modulation_lin)(.*)"):
block_index = m[0] # single block index
module_name = m[1] # linear1, linear2, or modulation_lin

# Create suffix for module types
if "linear" in module_name:
suffix = f"{module_name}.weight"
elif module_name == "modulation_lin":
suffix = "modulation.lin.weight"
else:
suffix = module_name

return f"model.diffusion_model.single_blocks.{block_index}.{suffix}"

return key
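
Taken together, the two new branches map kohya-style Flux LoRA keys onto the checkpoint's double_blocks and single_blocks paths. A few mappings implied by the regexes above (module-name portion only, with suffixes such as .lora_down.weight omitted; derived by reading the code rather than from a conversion run, so worth verifying against a real Flux LoRA):

flux_key_examples = {
    "lora_unet_double_blocks_0_img_attn_qkv":
        "model.diffusion_model.double_blocks.0.img_attn.qkv.weight",
    "lora_unet_double_blocks_3_txt_mlp_0":
        "model.diffusion_model.double_blocks.3.txt_mlp.0.weight",
    "lora_unet_double_blocks_7_img_mod_lin":
        "model.diffusion_model.double_blocks.7.img_mod.lin.weight",
    "lora_unet_single_blocks_10_linear1":
        "model.diffusion_model.single_blocks.10.linear1.weight",
    "lora_unet_single_blocks_10_modulation_lin":
        "model.diffusion_model.single_blocks.10.modulation.lin.weight",
}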

def read_model_state_dict(checkpoint_info, device):
