Skip to content

Commit

Permalink
Merge pull request #16 from MisterChief95/fix_names
Browse files Browse the repository at this point in the history
Fix names
  • Loading branch information
MisterChief95 authored Nov 22, 2024
2 parents b83d727 + 42568ca commit 5a4ab8d
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 7 deletions.
2 changes: 1 addition & 1 deletion hidiffusion/raunet.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ def configure_blocks(
enabled = True

model_configs = {
"SD15": {
"SD 1.5/2.1": {
"blocks": ("3", "8"),
"ca_blocks": ("1", "11"),
"modes": {
Expand Down
22 changes: 16 additions & 6 deletions scripts/forge_hidiffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,9 @@ def show(self, is_img2img):
def ui(self, *args, **kwargs):
with InputAccordion(False, label=self.title()) as enabled:
model_type = gr.Radio(
choices=["SD15", "SDXL"],
choices=["SD 1.5/2.1", "SDXL"],
value=lambda: "SDXL",
label="Model Type",
info="Note: Use SD15 setting for SD 2.1 as well.",
)

with gr.Tab("RAUNet"):
Expand Down Expand Up @@ -68,10 +67,17 @@ def ui(self, *args, **kwargs):

with InputAccordion(False, label="Advanced Options") as use_raunet_advanced:
with gr.Group():
gr.HTML(
"""
Recommended block settings:<br>
<ul><li>SD 1.5/2.1: Input 3 corresponds to Output 8, Input 6 to Output 5, Input 9 to Output 2</li>
<li>SDXL: Input 3 corresponds to Output 5, Input 6 to Output 2</li></ul>
"""
)
raunet_input_blocks = gr.Text(label="Input Blocks", value="3")
raunet_output_blocks = gr.Text(label="Output Blocks", value="8")
gr.Markdown(
"For SD1.5: Input 3 corresponds to Output 8, Input 6 to Output 5, Input 9 to Output 2"
"For SD1.5/2.1: Input 3 corresponds to Output 8, Input 6 to Output 5, Input 9 to Output 2"
)
gr.Markdown("For SDXL: Input 3 corresponds to Output 5, Input 6 to Output 2")

Expand Down Expand Up @@ -127,7 +133,11 @@ def ui(self, *args, **kwargs):
gr.Markdown("Advanced MSW-MSA settings. For fine-tuning performance and quality improvements.")
with gr.Group():
gr.HTML(
"Recommended block settings:<br><ul><li>SD15: input 1,2, output 9,10,11</li><li>SDXL: input 4,5, output 4,5</li></ul>"
"""
Recommended block settings:<br>
<ul><li>SD 1.5/2.1: input 1,2 for output 9,10,11</li>
<li>SDXL: input 4,5 for output 4,5</li></ul>
"""
)
mswmsa_input_blocks = gr.Text(label="Input Blocks", value="1,2")
mswmsa_middle_blocks = gr.Text(label="Middle Blocks", value="")
Expand Down Expand Up @@ -158,7 +168,7 @@ def ui(self, *args, **kwargs):

# Add JavaScript to handle visibility and model-specific settings
def update_raunet_settings(model_type):
if model_type == "SD15":
if model_type == "SD 1.5/2.1":
return "3", "8", "4", "8", 0.0, 0.45, 0.0, 0.3
else: # SDXL
return (
Expand Down Expand Up @@ -188,7 +198,7 @@ def update_raunet_settings(model_type):
)

def update_mswmsa_settings(model_type):
if model_type == "SD15":
if model_type == "SD 1.5/2.1":
return "1,2", "", "9,10,11"
else: # SDXL
return "4,5", "", "4,5"
Expand Down

0 comments on commit 5a4ab8d

Please sign in to comment.