-
Notifications
You must be signed in to change notification settings - Fork 913
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Minimal preload before dump current model data #2215
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -30,6 +30,24 @@ | |
dir_path = os.path.dirname(__file__) | ||
|
||
|
||
def check_huggingface_component(component_name:str, cls_name:str, state_dict): | ||
check_sd = False | ||
comp_str = 'model' | ||
|
||
if cls_name == 'AutoencoderKL': | ||
check_sd = True | ||
comp_str = 'VAE' | ||
elif component_name.startswith('text_encoder') and cls_name in ['CLIPTextModel', 'CLIPTextModelWithProjection']: | ||
check_sd = True | ||
comp_str = 'CLIP' | ||
elif cls_name == 'T5EncoderModel': | ||
check_sd = True | ||
comp_str = 'T5' | ||
|
||
if check_sd and (not isinstance(state_dict, dict) or len(state_dict) <= 16): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you please explain this magic number 16? Why 16? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @psydok to answer both questions, the These are Illyasviel code - I simply relocated this conditional check. With this PR it is now checking the condition before trashing the currently loaded model data, so if an exception is raised the current model data does not get unloaded. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Aside from the context clues of this condition, I otherwise have no idea what the expected structure of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Okay, I got it! Thank you very much! |
||
raise AssertionError(f'You do not have {comp_str} state dict!') | ||
|
||
|
||
def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_path, state_dict): | ||
config_path = os.path.join(repo_path, component_name) | ||
|
||
|
@@ -269,16 +287,33 @@ def split_state_dict(sd, additional_state_dicts: list = None): | |
|
||
|
||
@torch.inference_mode() | ||
def forge_loader(sd, additional_state_dicts=None): | ||
def forge_preloader(sd, additional_state_dicts=None): | ||
"""performs minimum steps to validate model params before reloading models""" | ||
try: | ||
state_dicts, estimated_config = split_state_dict(sd, additional_state_dicts=additional_state_dicts) | ||
except: | ||
raise ValueError('Failed to recognize model type!') | ||
|
||
repo_name = estimated_config.huggingface_repo | ||
|
||
local_path = os.path.join(dir_path, 'huggingface', repo_name) | ||
|
||
config: dict = DiffusionPipeline.load_config(local_path) | ||
|
||
for component_name, v in config.items(): | ||
if isinstance(v, list) and len(v) == 2: | ||
_, cls_name = v | ||
component_sd = state_dicts.get(component_name, None) | ||
# Raise AssertionError for invalid params | ||
check_huggingface_component(component_name, cls_name, component_sd) | ||
|
||
return state_dicts, estimated_config, config | ||
|
||
|
||
@torch.inference_mode() | ||
def forge_loader(state_dicts:dict, estimated_config, config:dict): | ||
repo_name = estimated_config.huggingface_repo | ||
local_path = os.path.join(dir_path, 'huggingface', repo_name) | ||
|
||
huggingface_components = {} | ||
for component_name, v in config.items(): | ||
if isinstance(v, list) and len(v) == 2: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In near future, I may be engaged in restoring sd_checkpoints_limit (but not yet fact). Then I will have to work with these modules. So I decided to take look at PRs.
Spaces? Typehint? Private method?