Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Minimal preload before dumping current model data #2215

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 37 additions & 2 deletions backend/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,24 @@
dir_path = os.path.dirname(__file__)


def check_huggingface_component(component_name: str, cls_name: str, state_dict) -> None:
    """Validate that a huggingface pipeline component has a usable state dict.

    Only the components that must ship real weights are checked: the VAE
    (``AutoencoderKL``), CLIP text encoders (``CLIPTextModel`` /
    ``CLIPTextModelWithProjection`` under a ``text_encoder*`` key) and
    ``T5EncoderModel``. Every other component is accepted as-is.

    Args:
        component_name: component key from the pipeline config
            (e.g. ``'vae'``, ``'text_encoder'``, ``'text_encoder_2'``).
        cls_name: the component's class name from the pipeline config.
        state_dict: the state dict loaded for this component. For checked
            components it must be a dict with more than 16 entries.
            NOTE(review): the 16-entry threshold is a heuristic lifted
            verbatim from ``load_huggingface_component``; a real model
            state dict has far more than 16 tensors, so anything smaller
            (or not a dict at all) is treated as missing.

    Raises:
        AssertionError: if a checked component has no plausible state dict.
    """
    check_sd = False
    comp_str = 'model'

    if cls_name == 'AutoencoderKL':
        check_sd = True
        comp_str = 'VAE'
    elif component_name.startswith('text_encoder') and cls_name in ['CLIPTextModel', 'CLIPTextModelWithProjection']:
        check_sd = True
        comp_str = 'CLIP'
    elif cls_name == 'T5EncoderModel':
        check_sd = True
        comp_str = 'T5'

    if check_sd and (not isinstance(state_dict, dict) or len(state_dict) <= 16):
        raise AssertionError(f'You do not have {comp_str} state dict!')


def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_path, state_dict):
config_path = os.path.join(repo_path, component_name)

Expand Down Expand Up @@ -269,16 +287,33 @@ def split_state_dict(sd, additional_state_dicts: list = None):


@torch.inference_mode()
def forge_preloader(sd, additional_state_dicts=None):
    """Perform the minimum steps to validate model params before reloading models.

    Splits the checkpoint state dict, resolves the matching local huggingface
    repo config, and validates every required component's state dict — all
    BEFORE the currently loaded model is unloaded, so a failure here leaves
    the existing model data intact.

    Args:
        sd: path/filename of the checkpoint to load (passed to split_state_dict).
        additional_state_dicts: optional list of extra module state dicts.

    Returns:
        Tuple ``(state_dicts, estimated_config, config)`` to be handed to
        ``forge_loader``.

    Raises:
        ValueError: if the model type cannot be recognized.
        AssertionError: if a required component state dict is missing/invalid
            (propagated from ``check_huggingface_component``).
    """
    try:
        state_dicts, estimated_config = split_state_dict(sd, additional_state_dicts=additional_state_dicts)
    except Exception as e:
        # A bare `except:` would also swallow KeyboardInterrupt/SystemExit;
        # chain the original exception so the real cause is not lost.
        raise ValueError('Failed to recognize model type!') from e

    repo_name = estimated_config.huggingface_repo

    local_path = os.path.join(dir_path, 'huggingface', repo_name)

    config: dict = DiffusionPipeline.load_config(local_path)

    for component_name, v in config.items():
        if isinstance(v, list) and len(v) == 2:
            _, cls_name = v
            component_sd = state_dicts.get(component_name, None)
            # Raises AssertionError for invalid params, before any unload.
            check_huggingface_component(component_name, cls_name, component_sd)

    return state_dicts, estimated_config, config


@torch.inference_mode()
def forge_loader(state_dicts:dict, estimated_config, config:dict):
repo_name = estimated_config.huggingface_repo
local_path = os.path.join(dir_path, 'huggingface', repo_name)

huggingface_components = {}
for component_name, v in config.items():
if isinstance(v, list) and len(v) == 2:
Expand Down
26 changes: 14 additions & 12 deletions modules/sd_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from modules.shared import opts, cmd_opts
from modules.timer import Timer
import numpy as np
from backend.loader import forge_loader
from backend.loader import forge_preloader, forge_loader
from backend import memory_management
from backend.args import dynamic_args
from backend.utils import load_torch_file
Expand Down Expand Up @@ -475,7 +475,19 @@ def forge_model_reload():

if model_data.forge_hash == current_hash:
return model_data.sd_model, False

# verify model/components before reload
checkpoint_info = model_data.forge_loading_parameters['checkpoint_info']

if checkpoint_info is None:
raise ValueError('You do not have any model! Please download at least one model in [models/Stable-diffusion].')

state_dict = checkpoint_info.filename
additional_state_dicts = model_data.forge_loading_parameters.get('additional_modules', [])

state_dicts, estimated_config, config = forge_preloader(state_dict, additional_state_dicts=additional_state_dicts)

# reload model
print('Loading Model: ' + str(model_data.forge_loading_parameters))

timer = Timer()
Expand All @@ -488,20 +500,10 @@ def forge_model_reload():

timer.record("unload existing model")

checkpoint_info = model_data.forge_loading_parameters['checkpoint_info']

if checkpoint_info is None:
raise ValueError('You do not have any model! Please download at least one model in [models/Stable-diffusion].')

state_dict = checkpoint_info.filename
additional_state_dicts = model_data.forge_loading_parameters.get('additional_modules', [])

timer.record("cache state dict")

dynamic_args['forge_unet_storage_dtype'] = model_data.forge_loading_parameters.get('unet_storage_dtype', None)
dynamic_args['embedding_dir'] = cmd_opts.embeddings_dir
dynamic_args['emphasis_name'] = opts.emphasis
sd_model = forge_loader(state_dict, additional_state_dicts=additional_state_dicts)
sd_model = forge_loader(state_dicts, estimated_config, config)
timer.record("forge model load")

sd_model.extra_generation_params = {}
Expand Down