move types/classes to separate python file #18

Merged: 2 commits, Nov 24, 2024
2 changes: 1 addition & 1 deletion hidiffusion/attention.py
@@ -312,7 +312,7 @@ def apply_mswmsaa_attention_simple(model_type: str, model: UnetPatcher) -> UnetP

    time_range: tuple[float] = (0.2, 1.0)

    if model_type == "SD15":
    if model_type == "SD 1.5/2.1":
        blocks: tuple[str] = ("1,2", "", "11,10,9")
    elif model_type == "SDXL":
        blocks: tuple[str] = ("4,5", "", "5,4")
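
For reference, a minimal usage sketch of the renamed model type string (not part of this PR; it assumes a UnetPatcher instance, here called unet_patcher, is already available from the surrounding backend):

from hidiffusion.attention import apply_mswmsaa_attention_simple

# "SD15" no longer matches after this change; the accepted label is now "SD 1.5/2.1".
patched = apply_mswmsaa_attention_simple("SD 1.5/2.1", unet_patcher)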
192 changes: 18 additions & 174 deletions hidiffusion/raunet.py
@@ -1,183 +1,27 @@
from typing import Any

import torch
import torch.nn.functional as F

from backend.modules.k_model import KModel
from backend.modules.k_prediction import Prediction
import backend.nn.unet as unet
from backend.patcher.unet import UnetPatcher
import backend.nn.unet as unet

from .logger import logger
from .utils import (
    check_time,
    convert_time,
    get_sigma,
    parse_blocks,
    scale_samples,
    UPSCALE_METHODS,
)
from .logger import logger


class HDConfigClass:
    """A configuration class for high-definition image processing.
    This class manages settings and parameters for the HD upscale of the generation.
    Attributes:
        enabled (bool): Flag to enable/disable HD processing. Defaults to False.
        start_sigma (float, optional): Starting sigma value for processing range.
        end_sigma (float, optional): Ending sigma value for processing range.
        use_blocks (list, optional): List of valid processing blocks.
        two_stage_upscale (bool): Whether to use two-stage upscaling. Defaults to True.
        upscale_mode (str): Upscaling algorithm to use. Defaults to "bislerp".
    Methods:
        check(topts): Validates processing options against configuration settings.
            Args:
                topts (dict): Dictionary containing processing options to validate.
            Returns:
                bool: True if options are valid according to configuration, False otherwise.
    """

    enabled: bool = False
    start_sigma: float | None = None
    end_sigma: float | None = None
    use_blocks: float | None = None
    two_stage_upscale: bool = True
    upscale_mode: str = UPSCALE_METHODS[0]

    def check(self, topts: dict[str, torch.Tensor]) -> bool:
        if not self.enabled or not isinstance(topts, dict) or topts.get("block") not in self.use_blocks:
            return False
        return check_time(topts, self.start_sigma, self.end_sigma)


HDCONFIG = HDConfigClass()
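
A minimal sketch of how this gate is exercised (illustrative only; the block tuples and the "sigmas" key in transformer_options are assumptions about the formats produced by parse_blocks and the sampler):

HDCONFIG.enabled = True
HDCONFIG.use_blocks = {("input", 3), ("output", 8)}  # assumed parse_blocks output format
HDCONFIG.start_sigma, HDCONFIG.end_sigma = 14.6, 0.29  # assumed sigma range

topts = {"block": ("input", 3), "sigmas": torch.tensor([7.5])}  # assumed transformer_options layout
if HDCONFIG.check(topts):
    pass  # the HD up/downsampling path would be taken for this block at this sigma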

CONTROLNET_SCALE_ARGS: dict[str, Any] = {"mode": "bilinear", "align_corners": False}
ORIG_APPLY_CONTROL = unet.apply_control
ORIG_FORWARD_TIMESTEP_EMBED = unet.TimestepEmbedSequential.forward
ORIG_UPSAMPLE = unet.Upsample
ORIG_DOWNSAMPLE = unet.Downsample


class HDUpsample(ORIG_UPSAMPLE):
    """
    A modified upsampling layer that extends ORIG_UPSAMPLE for high-definition image processing.
    This class implements custom upsampling behavior based on configuration settings,
    with options for two-stage upscaling and different upscaling modes.
    Parameters:
        Inherits all parameters from ORIG_UPSAMPLE parent class.
    Returns:
        torch.Tensor: The upsampled tensor.
    Methods:
        forward(x, output_shape=None, transformer_options=None):
            Performs the upsampling operation on the input tensor.
            Args:
                x (torch.Tensor): Input tensor to be upsampled
                output_shape (tuple, optional): Desired output shape. Defaults to None.
                transformer_options (dict, optional): Configuration options for transformation. Defaults to None.
            Returns:
                torch.Tensor: Upsampled tensor after processing through interpolation and convolution
    """

    def forward(self, x, output_shape=None, transformer_options=None):
        if self.dims == 3 or not self.use_conv or not HDCONFIG.check(transformer_options):
            return super().forward(x, output_shape=output_shape)
        shape = output_shape[2:4] if output_shape is not None else (x.shape[2] * 4, x.shape[3] * 4)
        if HDCONFIG.two_stage_upscale:
            x = F.interpolate(x, size=(shape[0] // 2, shape[1] // 2), mode="nearest")
        x = scale_samples(
            x,
            shape[1],
            shape[0],
            mode=HDCONFIG.upscale_mode,
        )
        return self.conv(x)
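
For orientation, a standalone sketch of the shape arithmetic above (assumptions: no output_shape is passed, so the target is 4x the input's spatial size, and plain bilinear interpolation stands in for scale_samples with the configured upscale_mode):

import torch
import torch.nn.functional as F

x = torch.randn(1, 4, 32, 32)
target = (x.shape[2] * 4, x.shape[3] * 4)  # (128, 128) when output_shape is None
x = F.interpolate(x, size=(target[0] // 2, target[1] // 2), mode="nearest")  # optional first stage
x = F.interpolate(x, size=target, mode="bilinear")  # stand-in for scale_samples(..., mode=HDCONFIG.upscale_mode)
assert tuple(x.shape[-2:]) == target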


class HDDownsample(ORIG_DOWNSAMPLE):
    """HDDownsample is a modified downsampling layer that extends ORIG_DOWNSAMPLE.
    This class implements specialized downsampling for images using dilated convolutions
    when specific conditions are met. Otherwise, it falls back to original downsampling behavior.
    Attributes:
        COPY_OP_KEYS (tuple): Keys of attributes to copy from original operation to temporary operation.
            Includes parameters_manual_cast, weight_function, bias_function, weight, and bias.
    Args:
        *args (list): Variable length argument list passed to parent class.
        **kwargs (dict): Arbitrary keyword arguments passed to parent class.
    Methods:
        forward(x, transformer_options=None): Performs the downsampling operation.
            Uses dilated convolution when dims==2, use_conv is True and HDCONFIG conditions are met.
            Otherwise falls back to original downsampling.
            Args:
                x: Input tensor to downsample
                transformer_options: Optional configuration for transformation
            Returns:
                Downsampled tensor using either dilated convolution or original method
    """

    COPY_OP_KEYS = (
        "parameters_manual_cast",
        "weight_function",
        "bias_function",
        "weight",
        "bias",
    )

    def __init__(self, *args: list, **kwargs: dict):
        super().__init__(*args, **kwargs)

    def forward(self, x, transformer_options=None):
        if self.dims == 3 or not self.use_conv or not HDCONFIG.check(transformer_options):
            return super().forward(x)
        tempop = unet.conv_nd(
            self.dims,
            self.channels,
            self.out_channels,
            3,  # kernel size
            stride=(4, 4),
            padding=(2, 2),
            dilation=(2, 2),
        )
        for k in self.COPY_OP_KEYS:
            if hasattr(self.op, k):
                setattr(tempop, k, getattr(self.op, k))
        return tempop(x)
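
A quick standalone check of what the dilated convolution above does to spatial size (sketch only; channel counts are arbitrary):

import torch
import torch.nn as nn

conv = nn.Conv2d(4, 4, kernel_size=3, stride=4, padding=2, dilation=2)
y = conv(torch.randn(1, 4, 64, 64))
print(tuple(y.shape[-2:]))  # (16, 16): a 4x reduction instead of the usual 2x from the stock Downsample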


# Create proxy classes that inherit from original UNet classes
class ProxyUpsample(HDUpsample):
    """Proxy class that can switch between HD and original upsampling implementations."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.orig_instance = ORIG_UPSAMPLE(*args, **kwargs)
        # Transfer weights and parameters
        self.orig_instance.conv = self.conv

    def forward(self, *args, **kwargs):
        if HDCONFIG.enabled:
            return super().forward(*args, **kwargs)
        return self.orig_instance.forward(*args, **kwargs)


class ProxyDownsample(HDDownsample):
    """Proxy class that can switch between HD and original downsampling implementations."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.orig_instance = ORIG_DOWNSAMPLE(*args, **kwargs)
        # Transfer weights and parameters
        self.orig_instance.op = self.op

    def forward(self, *args, **kwargs):
        if HDCONFIG.enabled:
            return super().forward(*args, **kwargs)
        return self.orig_instance.forward(*args, **kwargs)

from .types import (
    CONTROLNET_SCALE_ARGS,
    HD_CONFIG,
    HDDownsample,
    HDUpsample,
    ORIG_APPLY_CONTROL,
    ORIG_FORWARD_TIMESTEP_EMBED,
)

# Replace original classes with proxy classes
unet.Upsample = ProxyUpsample
unet.Downsample = ProxyDownsample

logger.info("Proxied UNet Upsample and Downsample classes")

@@ -223,15 +67,15 @@ def apply_unet_patches():
"""
Apply patches to modify UNet behavior for HiDiffusion functionality.
This function applies patches to the UNet model by:
1. Enabling the HDCONFIG flag
1. Enabling the HD_CONFIG flag
2. Overriding TimestepEmbedSequential's forward method with HiDiffusion implementation
3. Overriding UNet's apply_control method with HiDiffusion implementation
The patches allow the UNet to work with the HiDiffusion architecture and processing.
Note:
This is a side-effect function that modifies global state.
"""

HDCONFIG.enabled = True
HD_CONFIG.enabled = True
unet.TimestepEmbedSequential.forward = hd_forward_timestep_embed
unet.apply_control = hd_apply_control
logger.info("Applied UNet patches")
@@ -242,15 +86,15 @@ def remove_unet_patches():
    Removes patches applied to the UNet model by disabling HiDiffusion configuration and restoring original
    forward methods.
    This function removes patches to the UNet model by:
    1. Disabling the HDCONFIG flag
    1. Disabling the HD_CONFIG flag
    2. Restores original forward method for TimestepEmbedSequential
    3. Restores original apply_control method
    This function should be called to restore UNet to its original state after using HiDiffusion.
    Returns:
        None
    """

    HDCONFIG.enabled = False
    HD_CONFIG.enabled = False
    unet.TimestepEmbedSequential.forward = ORIG_FORWARD_TIMESTEP_EMBED
    unet.apply_control = ORIG_APPLY_CONTROL
    logger.info("Removed UNet patches")
@@ -282,7 +126,7 @@ def apply_rau_net(
    kmodel: KModel = unet_patcher.model
    predictor: Prediction = kmodel.predictor

    HDCONFIG.start_sigma, HDCONFIG.end_sigma = convert_time(
    HD_CONFIG.start_sigma, HD_CONFIG.end_sigma = convert_time(
        predictor,
        time_mode,
        start_time,
@@ -328,9 +172,9 @@ def output_block_patch(h, hsp, extra_options):
    unet_patcher.set_model_input_block_patch(input_block_patch)
    unet_patcher.set_model_output_block_patch(output_block_patch)

    HDCONFIG.use_blocks = use_blocks
    HDCONFIG.two_stage_upscale = not skip_two_stage_upscale
    HDCONFIG.upscale_mode = upscale_mode
    HD_CONFIG.use_blocks = use_blocks
    HD_CONFIG.two_stage_upscale = not skip_two_stage_upscale
    HD_CONFIG.upscale_mode = upscale_mode

    return unet_patcher
