diff --git a/hidiffusion/raunet.py b/hidiffusion/raunet.py index 5cf04da..219cb98 100644 --- a/hidiffusion/raunet.py +++ b/hidiffusion/raunet.py @@ -369,7 +369,7 @@ def configure_blocks( enabled = True model_configs = { - "SD15": { + "SD 1.5/2.1": { "blocks": ("3", "8"), "ca_blocks": ("1", "11"), "modes": { diff --git a/scripts/forge_hidiffusion.py b/scripts/forge_hidiffusion.py index 0b0784b..45e97e6 100644 --- a/scripts/forge_hidiffusion.py +++ b/scripts/forge_hidiffusion.py @@ -36,10 +36,9 @@ def show(self, is_img2img): def ui(self, *args, **kwargs): with InputAccordion(False, label=self.title()) as enabled: model_type = gr.Radio( - choices=["SD15", "SDXL"], + choices=["SD 1.5/2.1", "SDXL"], value=lambda: "SDXL", label="Model Type", - info="Note: Use SD15 setting for SD 2.1 as well.", ) with gr.Tab("RAUNet"): @@ -68,10 +67,17 @@ def ui(self, *args, **kwargs): with InputAccordion(False, label="Advanced Options") as use_raunet_advanced: with gr.Group(): + gr.HTML( + """ + Recommended block settings:
+ + """ + ) raunet_input_blocks = gr.Text(label="Input Blocks", value="3") raunet_output_blocks = gr.Text(label="Output Blocks", value="8") gr.Markdown( - "For SD1.5: Input 3 corresponds to Output 8, Input 6 to Output 5, Input 9 to Output 2" + "For SD1.5/2.1: Input 3 corresponds to Output 8, Input 6 to Output 5, Input 9 to Output 2" ) gr.Markdown("For SDXL: Input 3 corresponds to Output 5, Input 6 to Output 2") @@ -127,7 +133,11 @@ def ui(self, *args, **kwargs): gr.Markdown("Advanced MSW-MSA settings. For fine-tuning performance and quality improvements.") with gr.Group(): gr.HTML( - "Recommended block settings:
" + """ + Recommended block settings:
+ + """ ) mswmsa_input_blocks = gr.Text(label="Input Blocks", value="1,2") mswmsa_middle_blocks = gr.Text(label="Middle Blocks", value="") @@ -158,7 +168,7 @@ def ui(self, *args, **kwargs): # Add JavaScript to handle visibility and model-specific settings def update_raunet_settings(model_type): - if model_type == "SD15": + if model_type == "SD 1.5/2.1": return "3", "8", "4", "8", 0.0, 0.45, 0.0, 0.3 else: # SDXL return ( @@ -188,7 +198,7 @@ def update_raunet_settings(model_type): ) def update_mswmsa_settings(model_type): - if model_type == "SD15": + if model_type == "SD 1.5/2.1": return "1,2", "", "9,10,11" else: # SDXL return "4,5", "", "4,5"