From 43d65ae68c97e077117b17b7c9d1936583f965eb Mon Sep 17 00:00:00 2001
From: Mel Massadian
Date: Sun, 9 Jun 2024 19:13:46 +0200
Subject: [PATCH] =?UTF-8?q?feat:=20=E2=9C=A8=20add=20ModelPruner=20(wip)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
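Review note, not part of the committed files: a minimal usage sketch, assuming
`model`, `clip` and `vae` come from a regular checkpoint loader and that the
widget strings are resolved through Precision.from_str / Operation.from_str
(values below are illustrative, not the node defaults):

    pruner = MTB_ModelPruner()
    pruner.prune(
        unet=model, clip=clip, vae=vae,
        save_separately=False,
        save_folder="checkpoints/ComfyUI",
        fix_clip=True,
        remove_junk=True,
        ema_mode="remove_ema",
        precision_unet="fp16", operation_unet="convert",
        precision_clip="fp16", operation_clip="convert",
        precision_vae="fp16", operation_vae="convert",
    )

With these values the pruned checkpoint should land in
<output>/checkpoints/ComfyUI-fp16-remove_ema-clip-fix.safetensors.
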
bf16") + return tensor + case Precision.FULL | Precision.FP32: + return tensor + + def is_sdxl_model(self, clip: dict[str, torch.Tensor] | None): + if clip: + return (any(k.startswith("conditioner.embedders") for k in clip),) + return False + + def has_ema(self, unet: dict[str, torch.Tensor]): + return any(k.startswith("model_ema") for k in unet) + + def fix_clip(self, clip: dict[str, torch.Tensor] | None): + if self.is_sdxl_model(clip): + log.warn("[fix clip] SDXL not supported") + return + + if clip is None: + return + + position_id_key = ( + "cond_stage_model.transformer.text_model.embeddings.position_ids" + ) + if position_id_key in clip: + correct = torch.Tensor([list(range(77))]).to(torch.int64) + now = clip[position_id_key].to(torch.int64) + + broken = correct.ne(now) + broken = [i for i in range(77) if broken[0][i]] + + if len(broken) != 0: + clip[position_id_key] = correct + log.info(f"[Converter] Fixed broken clip\n{broken}") + else: + log.info( + "[Converter] Clip in this model is fine, skip fixing..." + ) + + else: + log.info("[Converter] Missing position id in model, try fixing...") + clip[position_id_key] = torch.Tensor([list(range(77))]).to( + torch.int64 + ) + return clip + + def get_dicts(self, unet, clip, vae): + clip_sd = clip.get_sd() + state_dict = unet.model.state_dict_for_saving( + clip_sd, vae.get_sd(), None + ) + + unet = { + k: v + for k, v in state_dict.items() + if k.startswith("model.diffusion_model") + } + clip = { + k: v + for k, v in state_dict.items() + if k.startswith("cond_stage_model") + or k.startswith("conditioner.embedders") + } + vae = { + k: v + for k, v in state_dict.items() + if k.startswith("first_stage_model") + } + + other = { + k: v + for k, v in state_dict.items() + if k not in unet and k not in vae and k not in clip + } + + return (unet, clip, vae, other) + + def do_remove_junk(self, tensors: dict[str, dict[str, torch.Tensor]]): + need_delete: list[str] = [] + for layer in tensors: + for key in layer: + for jk in PRUNE_DATA["known_junk_prefix"]: + if key.startswith(jk): + need_delete.append(".".join([layer, key])) + + for k in need_delete: + log.info(f"Removing junk data: {k}") + del tensors[k] + + return tensors + + def prune( + self, + *, + save_separately: bool, + save_folder: str, + fix_clip: bool, + remove_junk: bool, + ema_mode: str, + precision_unet: Precision, + precision_clip: Precision, + precision_vae: Precision, + operation_unet: str, + operation_clip: str, + operation_vae: str, + unet: dict[str, torch.Tensor] | None = None, + clip: dict[str, torch.Tensor] | None = None, + vae: dict[str, torch.Tensor] | None = None, + ): + operation = { + "unet": Operation.from_str(operation_unet), + "clip": Operation.from_str(operation_clip), + "vae": Operation.from_str(operation_vae), + } + precision = { + "unet": Precision.from_str(precision_unet), + "clip": Precision.from_str(precision_clip), + "vae": Precision.from_str(precision_vae), + } + + unet, clip, vae, _other = self.get_dicts(unet, clip, vae) + + out_dir = Path(save_folder) + folder = out_dir.parent + if not out_dir.is_absolute(): + folder = (comfy_out_dir / save_folder).parent + + if not folder.exists(): + if folder.parent.exists(): + folder.mkdir() + else: + raise FileNotFoundError( + f"Folder {folder.parent} does not exist" + ) + + name = out_dir.name + save_name = f"{name}-{precision_unet}" + if ema_mode != "disabled": + save_name += f"-{ema_mode}" + if fix_clip: + save_name += "-clip-fix" + + if ( + any(o == Operation.CONVERT for o in operation.values()) + and any(p == 
+        log.info("[Converter] Converting model...")
+
+        for part_name, part in zip(
+            ["unet", "vae", "clip"],
+            [unet, vae, clip],
+            strict=True,
+        ):
+            if part:
+                match ema_mode:
+                    case "remove_ema":
+                        for k, v in tqdm.tqdm(part.items()):
+                            if "model_ema." not in k:
+                                _hf(part_name, k, v)
+                    case "ema_only":
+                        if not self.has_ema(part):
+                            log.warning("No EMA to extract")
+                            return ()
+                        for k in tqdm.tqdm(part):
+                            ema_k = "___"
+                            try:
+                                ema_k = "model_ema." + k[6:].replace(".", "")
+                            except Exception:
+                                pass
+                            if ema_k in part:
+                                _hf(part_name, k, part[ema_k])
+                            elif not k.startswith("model_ema.") or k in [
+                                "model_ema.num_updates",
+                                "model_ema.decay",
+                            ]:
+                                _hf(part_name, k, part[k])
+                    case "disabled" | _:
+                        for k, v in tqdm.tqdm(part.items()):
+                            _hf(part_name, k, v)
+
+                if save_separately:
+                    if remove_junk:
+                        ok = self.do_remove_junk(ok)
+
+                    flat_ok = {
+                        k: v
+                        for _, subdict in ok.items()
+                        for k, v in subdict.items()
+                    }
+                    save_path = (
+                        folder / f"{part_name}-{save_name}.safetensors"
+                    ).as_posix()
+                    safetensors.torch.save_file(flat_ok, save_path)
+                    ok = {
+                        "unet": {},
+                        "clip": {},
+                        "vae": {},
+                    }
+
+        if save_separately:
+            return ()
+
+        if remove_junk:
+            ok = self.do_remove_junk(ok)
+
+        flat_ok = {
+            k: v for _, subdict in ok.items() for k, v in subdict.items()
+        }
+
+        try:
+            safetensors.torch.save_file(
+                flat_ok, (folder / f"{save_name}.safetensors").as_posix()
+            )
+        except Exception as e:
+            log.error(e)
+
+        return ()
+
+
+__nodes__ = [MTB_ModelPruner]
diff --git a/utils.py b/utils.py
index 4cc1c8c..d35ff56 100644
--- a/utils.py
+++ b/utils.py
@@ -9,8 +9,9 @@
 import subprocess
 import sys
 import uuid
+from enum import Enum
 from pathlib import Path
-from typing import List, Optional, Union
+from typing import List, Optional, TypeVar, Union
 
 import folder_paths
 import numpy as np
@@ -209,6 +210,88 @@ def get_server_info():
 
 
 # region MISC Utilities
+
+# TODO: use mtb.core directly instead of copying parts here
+T = TypeVar("T", bound="StringConvertibleEnum")
+
+
+class StringConvertibleEnum(Enum):
+    """Base class for enums with utility methods for string conversion and member listing."""
+
+    @classmethod
+    def from_str(cls: type[T], label: str | T) -> T:
+        """
+        Convert a string to the corresponding enum value (case sensitive).
+
+        Args:
+            label (Union[str, T]): The string or enum value to convert.
+
+        Returns
+        -------
+        T: The corresponding enum value.
+ + Raises + ------ + ValueError: If the label does not correspond to any enum member. + """ + if isinstance(label, cls): + return label + if isinstance(label, str): + # from key + if label in cls.__members__: + return cls[label] + + for member in cls: + if member.value == label: + return member + + raise ValueError( + f"Unknown label: '{label}'. Valid members: {list(cls.__members__.keys())}, " + f"valid values: {cls.list_members()}" + ) + + @classmethod + def to_str(cls: type[T], enum_value: T) -> str: + """ + Convert an enum value to its string representation. + + Args: + enum_value (T): The enum value to convert. + + Returns + ------- + str: The string representation of the enum value. + + Raises + ------ + ValueError: If the enum value is invalid. + """ + if isinstance(enum_value, cls): + return enum_value.value + raise ValueError(f"Invalid Enum: {enum_value}") + + @classmethod + def list_members(cls: type[T]) -> list[str]: + """ + Return a list of string representations of all enum members. + + Returns + ------- + List[str]: List of all enum member values. + """ + return [enum.value for enum in cls] + + def __str__(self) -> str: + """ + Returns the string representation of the enum value. + + Returns + ------- + str: The string representation of the enum value. + """ + return self.value + + def backup_file( fp: Path, target: Optional[Path] = None,