from pathlib import Path

import safetensors.torch
import torch
import tqdm

from ..log import log
from ..utils import StringConvertibleEnum
from ..utils import output_dir as comfy_out_dir

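# Key prefixes that are known junk in merged checkpoints (textual-inversion
# embedding managers, text-encoder LoRA remnants, bundled ControlNet weights),
# plus NovelAI-style CLIP keys mapped back to the standard `text_model` layout.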
PRUNE_DATA = {
    "known_junk_prefix": [
        "embedding_manager.embedder.",
        "lora_te_text_model",
        "control_model.",
    ],
    "nai_keys": {
        "cond_stage_model.transformer.embeddings.": "cond_stage_model.transformer.text_model.embeddings.",
        "cond_stage_model.transformer.encoder.": "cond_stage_model.transformer.text_model.encoder.",
        "cond_stage_model.transformer.final_layer_norm.": "cond_stage_model.transformer.text_model.final_layer_norm.",
    },
}

# position_ids in clip is int64, model_ema.num_updates is int32: only the
# source dtypes listed here are safe to convert to each target precision.
dtypes_to_fp16 = {torch.float32, torch.float64, torch.bfloat16}
dtypes_to_bf16 = {torch.float32, torch.float64, torch.float16}
dtypes_to_fp8 = {torch.float32, torch.float64, torch.bfloat16, torch.float16}

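
# StringConvertibleEnum (from ..utils) provides from_str() and list_members(),
# which let the node UI pass plain strings in place of these enum members.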
class Precision(StringConvertibleEnum):
    FULL = "full"
    FP32 = "fp32"
    FP16 = "fp16"
    BF16 = "bf16"
    FP8 = "fp8"


class Operation(StringConvertibleEnum):
    COPY = "copy"
    CONVERT = "convert"
    DELETE = "delete"

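
# ComfyUI output node: accepts any combination of MODEL / CLIP / VAE inputs
# and writes the pruned, precision-converted weights to a .safetensors file.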
class MTB_ModelPruner:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "optional": {
                "unet": ("MODEL",),
                "clip": ("CLIP",),
                "vae": ("VAE",),
            },
            "required": {
                "save_separately": ("BOOLEAN", {"default": False}),
                "save_folder": ("STRING", {"default": "checkpoints/ComfyUI"}),
                "fix_clip": ("BOOLEAN", {"default": True}),
                "remove_junk": ("BOOLEAN", {"default": True}),
                "ema_mode": (
                    ("disabled", "remove_ema", "ema_only"),
                    {"default": "remove_ema"},
                ),
                "precision_unet": (
                    Precision.list_members(),
                    {"default": Precision.FULL.value},
                ),
                "operation_unet": (
                    Operation.list_members(),
                    {"default": Operation.CONVERT.value},
                ),
                "precision_clip": (
                    Precision.list_members(),
                    {"default": Precision.FULL.value},
                ),
                "operation_clip": (
                    Operation.list_members(),
                    {"default": Operation.CONVERT.value},
                ),
                "precision_vae": (
                    Precision.list_members(),
                    {"default": Precision.FULL.value},
                ),
                "operation_vae": (
                    Operation.list_members(),
                    {"default": Operation.CONVERT.value},
                ),
            },
        }

    OUTPUT_NODE = True
    RETURN_TYPES = ()
    CATEGORY = "mtb/prune"
    FUNCTION = "prune"

    def convert_precision(self, tensor: torch.Tensor, precision: Precision):
        precision = Precision.from_str(precision)
        log.debug(f"Converting to {precision}")
        match precision:
            case Precision.FP8:
                if tensor.dtype in dtypes_to_fp8:
                    return tensor.to(torch.float8_e4m3fn)
                log.error(f"Cannot convert {tensor.dtype} to fp8")
                return tensor
            case Precision.FP16:
                if tensor.dtype in dtypes_to_fp16:
                    return tensor.half()
                log.error(f"Cannot convert {tensor.dtype} to fp16")
                return tensor
            case Precision.BF16:
                if tensor.dtype in dtypes_to_bf16:
                    return tensor.bfloat16()
                log.error(f"Cannot convert {tensor.dtype} to bf16")
                return tensor
            case Precision.FULL | Precision.FP32:
                return tensor

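    # e.g. convert_precision(torch.zeros(2), Precision.FP16).dtype is
    # torch.float16, while an int64 tensor (such as CLIP's position_ids)
    # is logged as unconvertible and returned unchanged.
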
    def is_sdxl_model(self, clip: dict[str, torch.Tensor] | None):
        # note: must return a bare bool, a 1-tuple would always be truthy
        if clip:
            return any(k.startswith("conditioner.embedders") for k in clip)
        return False

    def has_ema(self, unet: dict[str, torch.Tensor]):
        return any(k.startswith("model_ema") for k in unet)

    def fix_clip(self, clip: dict[str, torch.Tensor] | None):
        if clip is None:
            return None

        if self.is_sdxl_model(clip):
            log.warning("[fix clip] SDXL not supported")
            # return the dict untouched so the caller does not lose it
            return clip

        position_id_key = (
            "cond_stage_model.transformer.text_model.embeddings.position_ids"
        )
        if position_id_key in clip:
            correct = torch.Tensor([list(range(77))]).to(torch.int64)
            now = clip[position_id_key].to(torch.int64)

            broken = correct.ne(now)
            broken = [i for i in range(77) if broken[0][i]]

            if len(broken) != 0:
                clip[position_id_key] = correct
                log.info(f"[Converter] Fixed broken clip\n{broken}")
            else:
                log.info(
                    "[Converter] Clip in this model is fine, skip fixing..."
                )
        else:
            log.info("[Converter] Missing position id in model, trying to fix...")
            clip[position_id_key] = torch.Tensor([list(range(77))]).to(
                torch.int64
            )
        return clip

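    # Splits the full checkpoint state dict (as ComfyUI saves it) into
    # UNet / CLIP / VAE buckets by key prefix; leftovers land in `other`.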
    def get_dicts(self, unet, clip, vae):
        clip_sd = clip.get_sd()
        state_dict = unet.model.state_dict_for_saving(
            clip_sd, vae.get_sd(), None
        )

        unet = {
            k: v
            for k, v in state_dict.items()
            if k.startswith("model.diffusion_model")
        }
        clip = {
            k: v
            for k, v in state_dict.items()
            if k.startswith("cond_stage_model")
            or k.startswith("conditioner.embedders")
        }
        vae = {
            k: v
            for k, v in state_dict.items()
            if k.startswith("first_stage_model")
        }

        other = {
            k: v
            for k, v in state_dict.items()
            if k not in unet and k not in vae and k not in clip
        }

        return (unet, clip, vae, other)

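    # Strips keys matching `known_junk_prefix` from every part's dict before
    # the final save.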
    def do_remove_junk(self, tensors: dict[str, dict[str, torch.Tensor]]):
        # collect (part, key) pairs first so we don't mutate while iterating
        need_delete: list[tuple[str, str]] = []
        for layer in tensors:
            for key in tensors[layer]:
                for jk in PRUNE_DATA["known_junk_prefix"]:
                    if key.startswith(jk):
                        need_delete.append((layer, key))

        for layer, key in need_delete:
            log.info(f"Removing junk data: {layer}.{key}")
            del tensors[layer][key]

        return tensors

    def prune(
        self,
        *,
        save_separately: bool,
        save_folder: str,
        fix_clip: bool,
        remove_junk: bool,
        ema_mode: str,
        precision_unet: str,
        precision_clip: str,
        precision_vae: str,
        operation_unet: str,
        operation_clip: str,
        operation_vae: str,
        unet=None,  # ComfyUI MODEL wrapper, unwrapped by get_dicts below
        clip=None,  # ComfyUI CLIP wrapper
        vae=None,  # ComfyUI VAE wrapper
    ):
        operation = {
            "unet": Operation.from_str(operation_unet),
            "clip": Operation.from_str(operation_clip),
            "vae": Operation.from_str(operation_vae),
        }
        precision = {
            "unet": Precision.from_str(precision_unet),
            "clip": Precision.from_str(precision_clip),
            "vae": Precision.from_str(precision_vae),
        }

        unet, clip, vae, _other = self.get_dicts(unet, clip, vae)

        out_dir = Path(save_folder)
        folder = out_dir.parent
        if not out_dir.is_absolute():
            folder = (comfy_out_dir / save_folder).parent

        if not folder.exists():
            if folder.parent.exists():
                folder.mkdir()
            else:
                raise FileNotFoundError(
                    f"Folder {folder.parent} does not exist"
                )

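        # e.g. the default save_folder "checkpoints/ComfyUI" resolves to
        # <comfy output dir>/checkpoints, and "ComfyUI" becomes the base of
        # the file name built below.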
        name = out_dir.name
        save_name = f"{name}-{precision_unet}"
        if ema_mode != "disabled":
            save_name += f"-{ema_mode}"
        if fix_clip:
            save_name += "-clip-fix"

        if (
            any(o == Operation.CONVERT for o in operation.values())
            and any(p == Precision.FP8 for p in precision.values())
            and torch.__version__ < "2.1.0"
        ):
            raise NotImplementedError(
                "PyTorch 2.1.0 or newer is required for fp8 conversion"
            )

        if not self.is_sdxl_model(clip):
            for part in [unet, vae, clip]:
                if part:
                    nai_keys = PRUNE_DATA["nai_keys"]
                    for k in list(part.keys()):
                        for r in nai_keys:
                            if isinstance(k, str) and k.startswith(r):
                                new_key = k.replace(r, nai_keys[r])
                                part[new_key] = part[k]
                                del part[k]
                                log.info(
                                    f"[Converter] Fixed novelai error key {k}"
                                )
                                break

        if fix_clip:
            clip = self.fix_clip(clip)

        ok: dict[str, dict[str, torch.Tensor]] = {
            "unet": {},
            "clip": {},
            "vae": {},
        }

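        # _hf routes one weight into its part's bucket according to the
        # requested operation: convert precision, copy verbatim, or drop it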
        def _hf(part: str, wk: str, t: torch.Tensor):
            if not isinstance(t, torch.Tensor):
                log.debug("Not a torch tensor, skipping key")
                return

            log.debug(f"Operation {operation[part]}")
            if operation[part] == Operation.CONVERT:
                ok[part][wk] = self.convert_precision(t, precision[part])
            elif operation[part] == Operation.COPY:
                ok[part][wk] = t
            elif operation[part] == Operation.DELETE:
                return

        log.info("[Converter] Converting model...")

        # note: `_other` keys (everything outside the three parts) are not
        # saved; the zip only iterates unet, vae and clip
        for part_name, part in zip(
            ["unet", "vae", "clip"],
            [unet, vae, clip],
            strict=True,
        ):
            if part:
                match ema_mode:
                    case "remove_ema":
                        for k, v in tqdm.tqdm(part.items()):
                            if "model_ema." not in k:
                                _hf(part_name, k, v)
                    case "ema_only":
                        if not self.has_ema(part):
                            log.warning("No EMA to extract")
                            return ()
                        for k in tqdm.tqdm(part):
                            # EMA weights live under flattened names:
                            # "model_ema." + key without its "model." prefix
                            # and with the dots removed
                            ema_k = "___"
                            try:
                                ema_k = "model_ema." + k[6:].replace(".", "")
                            except Exception:
                                pass
                            if ema_k in part:
                                _hf(part_name, k, part[ema_k])
                            elif not k.startswith("model_ema.") or k in [
                                "model_ema.num_updates",
                                "model_ema.decay",
                            ]:
                                _hf(part_name, k, part[k])
                    case "disabled" | _:
                        for k, v in tqdm.tqdm(part.items()):
                            _hf(part_name, k, v)

                # when saving separately, flush each part to its own file
                if save_separately:
                    if remove_junk:
                        ok = self.do_remove_junk(ok)

                    flat_ok = {
                        k: v
                        for _, subdict in ok.items()
                        for k, v in subdict.items()
                    }
                    save_path = (
                        folder / f"{part_name}-{save_name}.safetensors"
                    ).as_posix()
                    safetensors.torch.save_file(flat_ok, save_path)
                    # reset the buckets so the next part is written alone
                    ok = {
                        "unet": {},
                        "clip": {},
                        "vae": {},
                    }

        if save_separately:
            return ()

        if remove_junk:
            ok = self.do_remove_junk(ok)

        flat_ok = {
            k: v for _, subdict in ok.items() for k, v in subdict.items()
        }

        try:
            safetensors.torch.save_file(
                flat_ok, (folder / f"{save_name}.safetensors").as_posix()
            )
        except Exception as e:
            log.error(e)

        return ()


__nodes__ = [MTB_ModelPruner]
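
# Minimal usage sketch (assumes a standard ComfyUI setup): wire a checkpoint
# loader's MODEL/CLIP/VAE outputs into this node, pick e.g.
# precision_unet="fp16" with ema_mode="remove_ema", and the pruned weights
# are written to <output>/checkpoints/ComfyUI-fp16-remove_ema-clip-fix.safetensors
# (the default fix_clip=True appends "-clip-fix").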