From f17f8170ac0635d5e1067c52c23bb889a427fde7 Mon Sep 17 00:00:00 2001
From: altoiddealer
Date: Tue, 29 Oct 2024 10:57:45 -0400
Subject: [PATCH] Minimal preload before dumping current model data

Performs the minimum steps needed to verify the model and its modules
before the currently loaded model data is discarded, preventing
unnecessary computation and avoidable failures such as a missing
required text encoder.
---
 backend/loader.py    | 39 +++++++++++++++++++++++++++++++++++++++--
 modules/sd_models.py | 26 ++++++++++++++------------
 2 files changed, 51 insertions(+), 14 deletions(-)

diff --git a/backend/loader.py b/backend/loader.py
index 7a6f86ede..b5e142aad 100644
--- a/backend/loader.py
+++ b/backend/loader.py
@@ -30,6 +30,24 @@
 dir_path = os.path.dirname(__file__)
 
 
+def check_huggingface_component(component_name:str, cls_name:str, state_dict):
+    check_sd = False
+    comp_str = 'model'
+
+    if cls_name == 'AutoencoderKL':
+        check_sd = True
+        comp_str = 'VAE'
+    elif component_name.startswith('text_encoder') and cls_name in ['CLIPTextModel', 'CLIPTextModelWithProjection']:
+        check_sd = True
+        comp_str = 'CLIP'
+    elif cls_name == 'T5EncoderModel':
+        check_sd = True
+        comp_str = 'T5'
+
+    if check_sd and (not isinstance(state_dict, dict) or len(state_dict) <= 16):
+        raise AssertionError(f'You do not have {comp_str} state dict!')
+
+
 def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_path, state_dict):
     config_path = os.path.join(repo_path, component_name)
 
@@ -269,16 +287,33 @@ def split_state_dict(sd, additional_state_dicts: list = None):
 
 
 @torch.inference_mode()
-def forge_loader(sd, additional_state_dicts=None):
+def forge_preloader(sd, additional_state_dicts=None):
+    """performs minimum steps to validate model params before reloading models"""
     try:
         state_dicts, estimated_config = split_state_dict(sd, additional_state_dicts=additional_state_dicts)
     except:
         raise ValueError('Failed to recognize model type!')
 
     repo_name = estimated_config.huggingface_repo
-
     local_path = os.path.join(dir_path, 'huggingface', repo_name)
+
     config: dict = DiffusionPipeline.load_config(local_path)
+
+    for component_name, v in config.items():
+        if isinstance(v, list) and len(v) == 2:
+            _, cls_name = v
+            component_sd = state_dicts.get(component_name, None)
+            # Raise AssertionError for invalid params
+            check_huggingface_component(component_name, cls_name, component_sd)
+
+    return state_dicts, estimated_config, config
+
+
+@torch.inference_mode()
+def forge_loader(state_dicts:dict, estimated_config, config:dict):
+    repo_name = estimated_config.huggingface_repo
+    local_path = os.path.join(dir_path, 'huggingface', repo_name)
+
     huggingface_components = {}
     for component_name, v in config.items():
         if isinstance(v, list) and len(v) == 2:
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 3359603e6..3a1e6030a 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -18,7 +18,7 @@
 from modules.shared import opts, cmd_opts
 from modules.timer import Timer
 import numpy as np
-from backend.loader import forge_loader
+from backend.loader import forge_preloader, forge_loader
 from backend import memory_management
 from backend.args import dynamic_args
 from backend.utils import load_torch_file
@@ -475,7 +475,19 @@ def forge_model_reload():
     if model_data.forge_hash == current_hash:
         return model_data.sd_model, False
 
+
+    # verify model/components before reload
+    checkpoint_info = model_data.forge_loading_parameters['checkpoint_info']
+    if checkpoint_info is None:
+        raise ValueError('You do not have any model! Please download at least one model in [models/Stable-diffusion].')
+
+    state_dict = checkpoint_info.filename
+    additional_state_dicts = model_data.forge_loading_parameters.get('additional_modules', [])
+
+    state_dicts, estimated_config, config = forge_preloader(state_dict, additional_state_dicts=additional_state_dicts)
+
+    # reload model
     print('Loading Model: ' + str(model_data.forge_loading_parameters))
 
     timer = Timer()
 
@@ -488,20 +500,10 @@
     timer.record("unload existing model")
 
-    checkpoint_info = model_data.forge_loading_parameters['checkpoint_info']
-
-    if checkpoint_info is None:
-        raise ValueError('You do not have any model! Please download at least one model in [models/Stable-diffusion].')
-
-    state_dict = checkpoint_info.filename
-    additional_state_dicts = model_data.forge_loading_parameters.get('additional_modules', [])
-
-    timer.record("cache state dict")
-
     dynamic_args['forge_unet_storage_dtype'] = model_data.forge_loading_parameters.get('unet_storage_dtype', None)
     dynamic_args['embedding_dir'] = cmd_opts.embeddings_dir
     dynamic_args['emphasis_name'] = opts.emphasis
 
-    sd_model = forge_loader(state_dict, additional_state_dicts=additional_state_dicts)
+    sd_model = forge_loader(state_dicts, estimated_config, config)
     timer.record("forge model load")
 
     sd_model.extra_generation_params = {}
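
For reference, below is a minimal sketch of the two-phase reload flow this patch introduces, seen from the caller's side. It assumes the patch is applied (so backend.loader exposes forge_preloader and forge_loader, and the checkpoint is passed as a filename, mirroring checkpoint_info.filename); the unload_current_model() helper and reload_checkpoint() wrapper are hypothetical stand-ins for the unload/gc steps and surrounding logic that forge_model_reload() performs in modules/sd_models.py.

# Minimal sketch of the preload-then-load flow (assumes this patch is applied).
from backend.loader import forge_preloader, forge_loader


def unload_current_model():
    # Hypothetical placeholder for the unload sequence in forge_model_reload():
    # memory_management.unload_all_models(), soft_empty_cache(), gc.collect().
    pass


def reload_checkpoint(checkpoint_path, additional_modules=None):
    # Phase 1: validate the new checkpoint while the current model is still loaded.
    # forge_preloader raises ValueError if the model type cannot be recognized and
    # AssertionError if a required VAE/CLIP/T5 state dict is missing, so a bad
    # selection fails here, before anything has been unloaded.
    state_dicts, estimated_config, config = forge_preloader(
        checkpoint_path, additional_state_dicts=additional_modules or [])

    # Phase 2: only after validation succeeds is the current model discarded,
    # and the new one is built from the already-validated pieces.
    unload_current_model()
    return forge_loader(state_dicts, estimated_config, config)

The point of the split is that forge_preloader does all the work that can fail for a bad checkpoint selection, while forge_loader only consumes data that has already been validated, so the existing model is never thrown away for a load that was doomed to fail.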