From 2254e7664d505acc14298f027b190ee24daf223b Mon Sep 17 00:00:00 2001 From: Brendan Lackey Date: Thu, 21 Nov 2024 14:52:25 -0700 Subject: [PATCH] Add some doc comments to classes --- hidiffusion/raunet.py | 41 ++++++++++++++++++++++++++++++++++++++++- 1 file changed, 40 insertions(+), 1 deletion(-) diff --git a/hidiffusion/raunet.py b/hidiffusion/raunet.py index 641f1c1..66a7074 100644 --- a/hidiffusion/raunet.py +++ b/hidiffusion/raunet.py @@ -54,6 +54,25 @@ def check(self, topts: dict[str, torch.Tensor]) -> bool: class HDUpsample(ORIG_UPSAMPLE): + """ + A modified upsampling layer that extends ORIG_UPSAMPLE for high-definition image processing. + This class implements custom upsampling behavior based on configuration settings, + with options for two-stage upscaling and different upscaling modes. + Parameters: + Inherits all parameters from ORIG_UPSAMPLE parent class. + Returns: + torch.Tensor: The upsampled tensor. + Methods: + forward(x, output_shape=None, transformer_options=None): + Performs the upsampling operation on the input tensor. + Args: + x (torch.Tensor): Input tensor to be upsampled + output_shape (tuple, optional): Desired output shape. Defaults to None. + transformer_options (dict, optional): Configuration options for transformation. Defaults to None. + Returns: + torch.Tensor: Upsampled tensor after processing through interpolation and convolution + """ + def forward(self, x, output_shape=None, transformer_options=None): if self.dims == 3 or not self.use_conv or not HDCONFIG.check(transformer_options): return super().forward(x, output_shape=output_shape) @@ -70,6 +89,26 @@ def forward(self, x, output_shape=None, transformer_options=None): class HDDownsample(ORIG_DOWNSAMPLE): + """HDDownsample is a modified downsampling layer that extends ORIG_DOWNSAMPLE. + This class implements specialized downsampling for images using dilated convolutions + when specific conditions are met. Otherwise, it falls back to original downsampling behavior. + Attributes: + COPY_OP_KEYS (tuple): Keys of attributes to copy from original operation to temporary operation. + Includes parameters_manual_cast, weight_function, bias_function, weight, and bias. + Args: + *args (list): Variable length argument list passed to parent class. + **kwargs (dict): Arbitrary keyword arguments passed to parent class. + Methods: + forward(x, transformer_options=None): Performs the downsampling operation. + Uses dilated convolution when dims==2, use_conv is True and HDCONFIG conditions are met. + Otherwise falls back to original downsampling. + Args: + x: Input tensor to downsample + transformer_options: Optional configuration for transformation + Returns: + Downsampled tensor using either dilated convolution or original method + """ + COPY_OP_KEYS = ( "parameters_manual_cast", "weight_function", @@ -134,7 +173,7 @@ def forward(self, *args, **kwargs): unet.Upsample = ProxyUpsample unet.Downsample = ProxyDownsample -logger.info("\x1b[32m[HiDiffusion]\x1b[0m Proxied UNet Upsample and Downsample classes") +logger.info("Proxied UNet Upsample and Downsample classes") # TODO: Implement Forge FreeU compatibility