diff --git a/examples/stable-diffusion/README.md b/examples/stable-diffusion/README.md
index 5f33c6fb7e..5e8ab263a7 100644
--- a/examples/stable-diffusion/README.md
+++ b/examples/stable-diffusion/README.md
@@ -323,3 +323,38 @@ python image_to_video_generation.py \
     --gaudi_config Habana/stable-diffusion \
     --bf16
 ```
+
+### Image-to-video ControlNet
+
+Here is how to generate a video conditioned by depth frames:
+
+```bash
+python image_to_video_generation.py \
+    --model_name_or_path "stabilityai/stable-video-diffusion-img2vid" \
+    --controlnet_model_name_or_path "CiaraRowles/temporal-controlnet-depth-svd-v1" \
+    --control_image_path "https://github.com/CiaraStrawberry/svd-temporal-controlnet/blob/main/validation_demo/depth/frame_0.png?raw=true" \
+    "https://github.com/CiaraStrawberry/svd-temporal-controlnet/blob/main/validation_demo/depth/frame_1.png?raw=true" \
+    "https://github.com/CiaraStrawberry/svd-temporal-controlnet/blob/main/validation_demo/depth/frame_2.png?raw=true" \
+    "https://github.com/CiaraStrawberry/svd-temporal-controlnet/blob/main/validation_demo/depth/frame_3.png?raw=true" \
+    "https://github.com/CiaraStrawberry/svd-temporal-controlnet/blob/main/validation_demo/depth/frame_4.png?raw=true" \
+    "https://github.com/CiaraStrawberry/svd-temporal-controlnet/blob/main/validation_demo/depth/frame_5.png?raw=true" \
+    "https://github.com/CiaraStrawberry/svd-temporal-controlnet/blob/main/validation_demo/depth/frame_6.png?raw=true" \
+    "https://github.com/CiaraStrawberry/svd-temporal-controlnet/blob/main/validation_demo/depth/frame_7.png?raw=true" \
+    "https://github.com/CiaraStrawberry/svd-temporal-controlnet/blob/main/validation_demo/depth/frame_8.png?raw=true" \
+    "https://github.com/CiaraStrawberry/svd-temporal-controlnet/blob/main/validation_demo/depth/frame_9.png?raw=true" \
+    "https://github.com/CiaraStrawberry/svd-temporal-controlnet/blob/main/validation_demo/depth/frame_10.png?raw=true" \
+    "https://github.com/CiaraStrawberry/svd-temporal-controlnet/blob/main/validation_demo/depth/frame_11.png?raw=true" \
+    "https://github.com/CiaraStrawberry/svd-temporal-controlnet/blob/main/validation_demo/depth/frame_12.png?raw=true" \
+    "https://github.com/CiaraStrawberry/svd-temporal-controlnet/blob/main/validation_demo/depth/frame_13.png?raw=true" \
+    --image_path "https://github.com/CiaraStrawberry/svd-temporal-controlnet/blob/main/validation_demo/chair.png?raw=true" \
+    --video_save_dir SVD_controlnet \
+    --save_frames_as_images \
+    --use_habana \
+    --use_hpu_graphs \
+    --gaudi_config Habana/stable-diffusion \
+    --bf16 \
+    --num_frames 14 \
+    --motion_bucket_id=14 \
+    --width=512 \
+    --height=512
+```
diff --git a/examples/stable-diffusion/image_to_video_generation.py b/examples/stable-diffusion/image_to_video_generation.py
index 7beb73a1ac..a2fdee4a51 100755
--- a/examples/stable-diffusion/image_to_video_generation.py
+++ b/examples/stable-diffusion/image_to_video_generation.py
@@ -49,6 +49,12 @@ def main():
         type=str,
         help="Path to pre-trained model",
     )
+    parser.add_argument(
+        "--controlnet_model_name_or_path",
+        default="CiaraRowles/temporal-controlnet-depth-svd-v1",
+        type=str,
+        help="Path to pre-trained controlnet model",
+    )
 
     # Pipeline arguments
     parser.add_argument(
@@ -58,6 +64,13 @@ def main():
         nargs="*",
         help="Path to input image(s) to guide video generation",
     )
+    parser.add_argument(
+        "--control_image_path",
+        type=str,
+        default=None,
+        nargs="*",
+        help="Path to controlnet input image(s) to guide video generation",
+    )
     parser.add_argument(
         "--num_videos_per_prompt", type=int, default=1,
help="The number of videos to generate per prompt image." ) @@ -164,7 +177,12 @@ def main(): ), ) parser.add_argument("--bf16", action="store_true", help="Whether to perform generation in bf16 precision.") - + parser.add_argument( + "--num_frames", + type=int, + default=25, + help="The number of video frames to generate." + ) args = parser.parse_args() from optimum.habana.diffusers import GaudiStableVideoDiffusionPipeline @@ -177,6 +195,30 @@ def main(): ) logger.setLevel(logging.INFO) + # Load input image(s) + input = [] + logger.info("Input image(s):") + if isinstance(args.image_path, str): + args.image_path = [args.image_path] + for image_path in args.image_path: + print(image_path) + image = load_image(image_path) + image = image.resize((args.height, args.width)) + input.append(image) + logger.info(image_path) + + # Load control input image + control_input = [] + if args.control_image_path is not None: + logger.info("Input control image(s):") + if isinstance(args.control_image_path, str): + args.control_image_path = [args.control_image_path] + for control_image in args.control_image_path: + image = load_image(control_image) + image = image.resize((args.height, args.width)) + control_input.append(image) + logger.info(control_image) + # Initialize the scheduler and the generation pipeline scheduler = GaudiEulerDiscreteScheduler.from_pretrained(args.model_name_or_path, subfolder="scheduler") kwargs = { @@ -188,41 +230,68 @@ def main(): if args.bf16: kwargs["torch_dtype"] = torch.bfloat16 - pipeline = GaudiStableVideoDiffusionPipeline.from_pretrained( - args.model_name_or_path, - **kwargs, - ) + if args.control_image_path is not None: + from optimum.habana.diffusers import GaudiStableVideoDiffusionPipelineControlNet + from optimum.habana.diffusers.models import ControlNetSDVModel + from optimum.habana.diffusers.models import UNetSpatioTemporalConditionControlNetModel + controlnet = controlnet = ControlNetSDVModel.from_pretrained( + args.controlnet_model_name_or_path, + subfolder="controlnet") + unet = UNetSpatioTemporalConditionControlNetModel.from_pretrained( + args.model_name_or_path, + subfolder="unet") + pipeline = GaudiStableVideoDiffusionPipelineControlNet.from_pretrained( + args.model_name_or_path, + controlnet=controlnet, + unet=unet, + **kwargs) - # Set seed before running the model - set_seed(args.seed) + # Set seed before running the model + set_seed(args.seed) - # Load input image(s) - input = [] - logger.info("Input image(s):") - if isinstance(args.image_path, str): - args.image_path = [args.image_path] - for image_path in args.image_path: - image = load_image(image_path) - image = image.resize((args.height, args.width)) - input.append(image) - logger.info(image_path) + # Generate images + outputs = pipeline( + image=input, + controlnet_condition=control_input, + num_videos_per_prompt=args.num_videos_per_prompt, + batch_size=args.batch_size, + height=args.height, + width=args.width, + num_inference_steps=args.num_inference_steps, + min_guidance_scale=args.min_guidance_scale, + max_guidance_scale=args.max_guidance_scale, + fps=args.fps, + motion_bucket_id=args.motion_bucket_id, + noise_aug_strength=args.noise_aug_strength, + decode_chunk_size=args.decode_chunk_size, + output_type=args.output_type, + num_frames=args.num_frames, + ) + else: + pipeline = GaudiStableVideoDiffusionPipeline.from_pretrained( + args.model_name_or_path, + **kwargs, + ) - # Generate images - outputs = pipeline( - image=input, - num_videos_per_prompt=args.num_videos_per_prompt, - 
batch_size=args.batch_size, - height=args.height, - width=args.width, - num_inference_steps=args.num_inference_steps, - min_guidance_scale=args.min_guidance_scale, - max_guidance_scale=args.max_guidance_scale, - fps=args.fps, - motion_bucket_id=args.motion_bucket_id, - noise_aug_strength=args.noise_aug_strength, - decode_chunk_size=args.decode_chunk_size, - output_type=args.output_type, - ) + # Set seed before running the model + set_seed(args.seed) + + # Generate images + outputs = pipeline( + image=input, + num_videos_per_prompt=args.num_videos_per_prompt, + batch_size=args.batch_size, + height=args.height, + width=args.width, + num_inference_steps=args.num_inference_steps, + min_guidance_scale=args.min_guidance_scale, + max_guidance_scale=args.max_guidance_scale, + fps=args.fps, + motion_bucket_id=args.motion_bucket_id, + noise_aug_strength=args.noise_aug_strength, + decode_chunk_size=args.decode_chunk_size, + output_type=args.output_type, + ) # Save the pipeline in the specified directory if not None if args.pipeline_save_dir is not None: diff --git a/optimum/habana/diffusers/__init__.py b/optimum/habana/diffusers/__init__.py index 26d5d2d359..c2be285aa1 100644 --- a/optimum/habana/diffusers/__init__.py +++ b/optimum/habana/diffusers/__init__.py @@ -5,4 +5,5 @@ from .pipelines.stable_diffusion.pipeline_stable_diffusion_upscale import GaudiStableDiffusionUpscalePipeline from .pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import GaudiStableDiffusionXLPipeline from .pipelines.stable_video_diffusion.pipeline_stable_video_diffusion import GaudiStableVideoDiffusionPipeline +from .pipelines.controlnet.pipeline_stable_video_diffusion_controlnet import GaudiStableVideoDiffusionPipelineControlNet from .schedulers import GaudiDDIMScheduler, GaudiEulerAncestralDiscreteScheduler, GaudiEulerDiscreteScheduler diff --git a/optimum/habana/diffusers/models/__init__.py b/optimum/habana/diffusers/models/__init__.py index 245359bcf2..0a67a7c93e 100644 --- a/optimum/habana/diffusers/models/__init__.py +++ b/optimum/habana/diffusers/models/__init__.py @@ -1 +1,3 @@ from .unet_2d_condition import gaudi_unet_2d_condition_model_forward +from .controlnet_sdv import ControlNetSDVModel +from .unet_spatio_temporal_condition_controlnet import UNetSpatioTemporalConditionControlNetModel diff --git a/optimum/habana/diffusers/models/controlnet_sdv.py b/optimum/habana/diffusers/models/controlnet_sdv.py new file mode 100644 index 0000000000..c3c142f176 --- /dev/null +++ b/optimum/habana/diffusers/models/controlnet_sdv.py @@ -0,0 +1,758 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
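The two `__init__.py` hunks above export the new classes from `optimum.habana.diffusers` and `optimum.habana.diffusers.models`. As a quick orientation before the contents of `controlnet_sdv.py` below, here is a hedged sketch of how the model defined in this file is typically obtained: either loaded from a published temporal-ControlNet checkpoint, as the example script does, or initialized from an existing SVD UNet through the `from_unet` helper defined later in the file. The checkpoint IDs are the ones used in the README example; everything else is illustrative.

```python
# Sketch only: two ways to obtain a ControlNetSDVModel instance.
from optimum.habana.diffusers.models import (
    ControlNetSDVModel,
    UNetSpatioTemporalConditionControlNetModel,
)

# 1) Load trained temporal-ControlNet weights (what the example script does).
controlnet = ControlNetSDVModel.from_pretrained(
    "CiaraRowles/temporal-controlnet-depth-svd-v1", subfolder="controlnet"
)

# 2) Derive a fresh ControlNet from an SVD UNet: the shared blocks are copied from
#    the UNet, while the controlnet-specific projection convolutions start at zero.
unet = UNetSpatioTemporalConditionControlNetModel.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid", subfolder="unet"
)
controlnet_from_scratch = ControlNetSDVModel.from_unet(unet)
```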
+from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +from torch import nn +from torch.nn import functional as F + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders.controlnet import FromOriginalControlNetMixin +from diffusers.utils import BaseOutput, logging +from diffusers.models.attention_processor import ( + ADDED_KV_ATTENTION_PROCESSORS, + CROSS_ATTENTION_PROCESSORS, + AttentionProcessor, + AttnAddedKVProcessor, + AttnProcessor, +) +from diffusers.models.embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.unets.unet_3d_blocks import ( + get_down_block, get_up_block,UNetMidBlockSpatioTemporal, +) +from diffusers.models import UNetSpatioTemporalConditionModel + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class ControlNetOutput(BaseOutput): + """ + Copied from https://github.com/CiaraStrawberry/svd-temporal-controlnet/blob/main/models/controlnet_sdv.py + The output of [`ControlNetModel`]. + + Args: + down_block_res_samples (`tuple[torch.Tensor]`): + A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should + be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be + used to condition the original UNet's downsampling activations. + mid_down_block_re_sample (`torch.Tensor`): + The activation of the midde block (the lowest sample resolution). Each tensor should be of shape + `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`. + Output can be used to condition the original UNet's middle block activation. + """ + + down_block_res_samples: Tuple[torch.Tensor] + mid_block_res_sample: torch.Tensor + + +class ControlNetConditioningEmbeddingSVD(nn.Module): + """ + Copied from https://github.com/CiaraStrawberry/svd-temporal-controlnet/blob/main/models/controlnet_sdv.py + Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN + [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized + training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the + convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides + (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full + model) to encode image-space conditions ... into feature maps ..." + """ + + def __init__( + self, + conditioning_embedding_channels: int, + conditioning_channels: int = 3, + block_out_channels: Tuple[int, ...] 
= (16, 32, 96, 256), + ): + super().__init__() + + self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1) + + self.blocks = nn.ModuleList([]) + + for i in range(len(block_out_channels) - 1): + channel_in = block_out_channels[i] + channel_out = block_out_channels[i + 1] + self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1)) + self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2)) + + self.conv_out = zero_module( + nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1) + ) + + def forward(self, conditioning): + #this seeems appropriate? idk if i should be applying a more complex setup to handle the frames + #combine batch and frames dimensions + batch_size, frames, channels, height, width = conditioning.size() + conditioning = conditioning.view(batch_size * frames, channels, height, width) + + embedding = self.conv_in(conditioning) + embedding = F.silu(embedding) + + for block in self.blocks: + embedding = block(embedding) + embedding = F.silu(embedding) + + embedding = self.conv_out(embedding) + + #split them apart again + #actually not needed + #new_channels, new_height, new_width = embedding.shape[1], embedding.shape[2], embedding.shape[3] + #embedding = embedding.view(batch_size, frames, new_channels, new_height, new_width) + + + return embedding + + +class ControlNetSDVModel(ModelMixin, ConfigMixin, FromOriginalControlNetMixin): + r""" + Copied from https://github.com/CiaraStrawberry/svd-temporal-controlnet/blob/main/models/controlnet_sdv.py + A conditional Spatio-Temporal UNet model that takes a noisy video frames, conditional state, and a timestep and returns a sample + shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): + Height and width of input/output sample. + in_channels (`int`, *optional*, defaults to 8): Number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): Number of channels in the output. + down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "DownBlockSpatioTemporal")`): + The tuple of downsample blocks to use. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal")`): + The tuple of upsample blocks to use. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + addition_time_embed_dim: (`int`, defaults to 256): + Dimension to to encode the additional time ids. + projection_class_embeddings_input_dim (`int`, defaults to 768): + The dimension of the projection of encoded `added_time_ids`. + layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. + cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): + The dimension of the cross attention features. + transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. 
Only relevant for + [`~models.unet_3d_blocks.CrossAttnDownBlockSpatioTemporal`], [`~models.unet_3d_blocks.CrossAttnUpBlockSpatioTemporal`], + [`~models.unet_3d_blocks.UNetMidBlockSpatioTemporal`]. + num_attention_heads (`int`, `Tuple[int]`, defaults to `(5, 10, 10, 20)`): + The number of attention heads. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 8, + out_channels: int = 4, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlockSpatioTemporal", + "CrossAttnDownBlockSpatioTemporal", + "CrossAttnDownBlockSpatioTemporal", + "DownBlockSpatioTemporal", + ), + up_block_types: Tuple[str] = ( + "UpBlockSpatioTemporal", + "CrossAttnUpBlockSpatioTemporal", + "CrossAttnUpBlockSpatioTemporal", + "CrossAttnUpBlockSpatioTemporal", + ), + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + addition_time_embed_dim: int = 256, + projection_class_embeddings_input_dim: int = 768, + layers_per_block: Union[int, Tuple[int]] = 2, + cross_attention_dim: Union[int, Tuple[int]] = 1024, + transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, + num_attention_heads: Union[int, Tuple[int]] = (5, 10, 10, 20), + num_frames: int = 25, + conditioning_channels: int = 3, + conditioning_embedding_out_channels : Optional[Tuple[int, ...]] = (16, 32, 96, 256), + ): + super().__init__() + self.sample_size = sample_size + + print("layers per block is", layers_per_block) + + # Check inputs + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." + ) + + if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." 
+ ) + + # input + self.conv_in = nn.Conv2d( + in_channels, + block_out_channels[0], + kernel_size=3, + padding=1, + ) + + # time + time_embed_dim = block_out_channels[0] * 4 + + self.time_proj = Timesteps(block_out_channels[0], True, downscale_freq_shift=0) + timestep_input_dim = block_out_channels[0] + + self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + + self.add_time_proj = Timesteps(addition_time_embed_dim, True, downscale_freq_shift=0) + self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + + self.down_blocks = nn.ModuleList([]) + self.controlnet_down_blocks = nn.ModuleList([]) + + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(down_block_types) + + if isinstance(cross_attention_dim, int): + cross_attention_dim = (cross_attention_dim,) * len(down_block_types) + + if isinstance(layers_per_block, int): + layers_per_block = [layers_per_block] * len(down_block_types) + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + + blocks_time_embed_dim = time_embed_dim + self.controlnet_cond_embedding = ControlNetConditioningEmbeddingSVD( + conditioning_embedding_channels=block_out_channels[0], + block_out_channels=conditioning_embedding_out_channels, + conditioning_channels=conditioning_channels, + ) + + # down + output_channel = block_out_channels[0] + controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + self.controlnet_down_blocks.append(controlnet_block) + + + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block[i], + transformer_layers_per_block=transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + temb_channels=blocks_time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=1e-5, + cross_attention_dim=cross_attention_dim[i], + num_attention_heads=num_attention_heads[i], + resnet_act_fn="silu", + ) + self.down_blocks.append(down_block) + + for _ in range(layers_per_block[i]): + controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + self.controlnet_down_blocks.append(controlnet_block) + + if not is_final_block: + controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + self.controlnet_down_blocks.append(controlnet_block) + + + # mid + mid_block_channel = block_out_channels[-1] + controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + self.controlnet_mid_block = controlnet_block + + self.mid_block = UNetMidBlockSpatioTemporal( + block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + transformer_layers_per_block=transformer_layers_per_block[-1], + cross_attention_dim=cross_attention_dim[-1], + num_attention_heads=num_attention_heads[-1], + ) + + @property + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. 
+ """ + # set recursively + processors = {} + + def fn_recursive_add_processors( + name: str, + module: torch.nn.Module, + processors: Dict[str, AttentionProcessor], + ): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True) + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + if all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor) + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + # Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking + def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: + """ + Sets the attention processor to use [feed forward + chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers). + + Parameters: + chunk_size (`int`, *optional*): + The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually + over each tensor of dim=`dim`. + dim (`int`, *optional*, defaults to `0`): + The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch) + or dim=1 (sequence length). 
+ """ + if dim not in [0, 1]: + raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}") + + # By default chunk size is 1 + chunk_size = chunk_size or 1 + + def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): + if hasattr(module, "set_chunk_feed_forward"): + module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) + + for child in module.children(): + fn_recursive_feed_forward(child, chunk_size, dim) + + for module in self.children(): + fn_recursive_feed_forward(module, chunk_size, dim) + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + added_time_ids: torch.Tensor, + controlnet_cond: torch.FloatTensor = None, + image_only_indicator: Optional[torch.Tensor] = None, + return_dict: bool = True, + guess_mode: bool = False, + conditioning_scale: float = 1.0, + + + ) -> Union[ControlNetOutput, Tuple]: + r""" + The [`UNetSpatioTemporalConditionModel`] forward method. + + Args: + sample (`torch.FloatTensor`): + The noisy input tensor with the following shape `(batch, num_frames, channel, height, width)`. + timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. + encoder_hidden_states (`torch.FloatTensor`): + The encoder hidden states with shape `(batch, sequence_length, cross_attention_dim)`. + added_time_ids: (`torch.FloatTensor`): + The additional time ids with shape `(batch, num_additional_ids)`. These are encoded with sinusoidal + embeddings and added to the time embeddings. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] instead of a plain + tuple. + Returns: + [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] or `tuple`: + If `return_dict` is True, an [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] is returned, otherwise + a `tuple` is returned where the first element is the sample tensor. + """ + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + batch_size, num_frames = sample.shape[:2] + timesteps = timesteps.expand(batch_size) + + t_emb = self.time_proj(timesteps) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. 
+ t_emb = t_emb.to(dtype=sample.dtype) + + emb = self.time_embedding(t_emb) + + time_embeds = self.add_time_proj(added_time_ids.flatten()) + time_embeds = time_embeds.reshape((batch_size, -1)) + time_embeds = time_embeds.to(emb.dtype) + aug_emb = self.add_embedding(time_embeds) + emb = emb + aug_emb + + # Flatten the batch and frames dimensions + # sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width] + sample = sample.flatten(0, 1) + # Repeat the embeddings num_video_frames times + # emb: [batch, channels] -> [batch * frames, channels] + emb = emb.repeat_interleave(num_frames, dim=0) + # encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels] + encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0) + + # 2. pre-process + sample = self.conv_in(sample) + + #controlnet cond + if controlnet_cond != None: + controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) + sample = sample + controlnet_cond + + image_only_indicator = torch.zeros(batch_size, num_frames, dtype=sample.dtype, device=sample.device) + + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + image_only_indicator=image_only_indicator, + ) + else: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + image_only_indicator=image_only_indicator, + ) + + down_block_res_samples += res_samples + + # 4. mid + sample = self.mid_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + image_only_indicator=image_only_indicator, + ) + + controlnet_down_block_res_samples = () + + for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks): + down_block_res_sample = controlnet_block(down_block_res_sample) + controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,) + + down_block_res_samples = controlnet_down_block_res_samples + + mid_block_res_sample = self.controlnet_mid_block(sample) + + # 6. scaling + + down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples] + mid_block_res_sample = mid_block_res_sample * conditioning_scale + + if not return_dict: + return (down_block_res_samples, mid_block_res_sample) + + return ControlNetOutput( + down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample + ) + + @classmethod + def from_unet( + cls, + unet: UNetSpatioTemporalConditionModel, + controlnet_conditioning_channel_order: str = "rgb", + conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256), + load_weights_from_unet: bool = True, + conditioning_channels: int = 3, + ): + r""" + Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`]. + + Parameters: + unet (`UNet2DConditionModel`): + The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied + where applicable. 
+ """ + + transformer_layers_per_block = ( + unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1 + ) + encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None + encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None + addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None + addition_time_embed_dim = ( + unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None + ) + print(unet.config) + controlnet = cls( + in_channels=unet.config.in_channels, + down_block_types=unet.config.down_block_types, + block_out_channels=unet.config.block_out_channels, + addition_time_embed_dim=unet.config.addition_time_embed_dim, + transformer_layers_per_block=unet.config.transformer_layers_per_block, + cross_attention_dim=unet.config.cross_attention_dim, + num_attention_heads=unet.config.num_attention_heads, + num_frames=unet.config.num_frames, + sample_size=unet.config.sample_size, # Added based on the dict + layers_per_block=unet.config.layers_per_block, + projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim, + conditioning_channels = conditioning_channels, + conditioning_embedding_out_channels = conditioning_embedding_out_channels, + ) + #controlnet rgb channel order ignored, set to not makea difference by default + + if load_weights_from_unet: + controlnet.conv_in.load_state_dict(unet.conv_in.state_dict()) + controlnet.time_proj.load_state_dict(unet.time_proj.state_dict()) + controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict()) + + # if controlnet.class_embedding: + # controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict()) + + controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict()) + controlnet.mid_block.load_state_dict(unet.mid_block.state_dict()) + + return controlnet + + @property + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True) + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor( + self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False + ): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. 
This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor, _remove_lora=_remove_lora) + else: + module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnAddedKVProcessor() + elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor, _remove_lora=True) + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice + def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None: + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module splits the input tensor in slices to compute attention in + several steps. This is useful for saving some memory in exchange for a small decrease in speed. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If + `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. 
+ """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_sliceable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_sliceable_dims(module) + + num_sliceable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_sliceable_layers * [1] + + slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. + # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + +def zero_module(module): + for p in module.parameters(): + nn.init.zeros_(p) + return module diff --git a/optimum/habana/diffusers/models/unet_spatio_temporal_condition_controlnet.py b/optimum/habana/diffusers/models/unet_spatio_temporal_condition_controlnet.py new file mode 100644 index 0000000000..e108d26db7 --- /dev/null +++ b/optimum/habana/diffusers/models/unet_spatio_temporal_condition_controlnet.py @@ -0,0 +1,243 @@ +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn + +from diffusers.configuration_utils import register_to_config +from diffusers.utils import logging +from diffusers.models.embeddings import Timesteps +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.unets.unet_spatio_temporal_condition import UNetSpatioTemporalConditionOutput, UNetSpatioTemporalConditionModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class UNetSpatioTemporalConditionControlNetModel(UNetSpatioTemporalConditionModel): + r""" + Copied from https://github.com/CiaraStrawberry/svd-temporal-controlnet/blob/765cd95c3659c54593ae36a9616121f00b3d7c29/models/unet_spatio_temporal_condition_controlnet.py#L356 + A conditional Spatio-Temporal UNet model that takes a noisy video frames, conditional state, and a timestep and returns a sample + shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): + Height and width of input/output sample. 
+ in_channels (`int`, *optional*, defaults to 8): Number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): Number of channels in the output. + down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "DownBlockSpatioTemporal")`): + The tuple of downsample blocks to use. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal")`): + The tuple of upsample blocks to use. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + addition_time_embed_dim: (`int`, defaults to 256): + Dimension to to encode the additional time ids. + projection_class_embeddings_input_dim (`int`, defaults to 768): + The dimension of the projection of encoded `added_time_ids`. + layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. + cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): + The dimension of the cross attention features. + transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for + [`~models.unet_3d_blocks.CrossAttnDownBlockSpatioTemporal`], [`~models.unet_3d_blocks.CrossAttnUpBlockSpatioTemporal`], + [`~models.unet_3d_blocks.UNetMidBlockSpatioTemporal`]. + num_attention_heads (`int`, `Tuple[int]`, defaults to `(5, 10, 10, 20)`): + The number of attention heads. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. 
+ """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 8, + out_channels: int = 4, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlockSpatioTemporal", + "CrossAttnDownBlockSpatioTemporal", + "CrossAttnDownBlockSpatioTemporal", + "DownBlockSpatioTemporal", + ), + up_block_types: Tuple[str] = ( + "UpBlockSpatioTemporal", + "CrossAttnUpBlockSpatioTemporal", + "CrossAttnUpBlockSpatioTemporal", + "CrossAttnUpBlockSpatioTemporal", + ), + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + addition_time_embed_dim: int = 256, + projection_class_embeddings_input_dim: int = 768, + layers_per_block: Union[int, Tuple[int]] = 2, + cross_attention_dim: Union[int, Tuple[int]] = 1024, + transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, + num_attention_heads: Union[int, Tuple[int]] = (5, 10, 10, 20), + num_frames: int = 25, + ): + super().__init__( + sample_size=sample_size, + in_channels=in_channels, + out_channels=out_channels, + down_block_types=down_block_types, + up_block_types=up_block_types, + block_out_channels=block_out_channels, + addition_time_embed_dim=addition_time_embed_dim, + projection_class_embeddings_input_dim=projection_class_embeddings_input_dim, + layers_per_block=layers_per_block, + cross_attention_dim=cross_attention_dim, + transformer_layers_per_block=transformer_layers_per_block, + num_attention_heads=num_attention_heads, + num_frames=num_frames + ) + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + mid_block_additional_residual: Optional[torch.Tensor] = None, + return_dict: bool = True, + added_time_ids: torch.Tensor=None, + ) -> Union[UNetSpatioTemporalConditionOutput, Tuple]: + r""" + The [`UNetSpatioTemporalConditionModel`] forward method. + + Args: + sample (`torch.FloatTensor`): + The noisy input tensor with the following shape `(batch, num_frames, channel, height, width)`. + timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. + encoder_hidden_states (`torch.FloatTensor`): + The encoder hidden states with shape `(batch, sequence_length, cross_attention_dim)`. + added_time_ids: (`torch.FloatTensor`): + The additional time ids with shape `(batch, num_additional_ids)`. These are encoded with sinusoidal + embeddings and added to the time embeddings. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] instead of a plain + tuple. + Returns: + [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] or `tuple`: + If `return_dict` is True, an [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] is returned, otherwise + a `tuple` is returned where the first element is the sample tensor. + """ + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. 
So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + batch_size, num_frames = sample.shape[:2] + timesteps = timesteps.expand(batch_size) + + t_emb = self.time_proj(timesteps) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=sample.dtype) + + emb = self.time_embedding(t_emb) + + time_embeds = self.add_time_proj(added_time_ids.flatten()) + time_embeds = time_embeds.reshape((batch_size, -1)) + time_embeds = time_embeds.to(emb.dtype) + aug_emb = self.add_embedding(time_embeds) + emb = emb + aug_emb + + # Flatten the batch and frames dimensions + # sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width] + sample = sample.flatten(0, 1) + # Repeat the embeddings num_video_frames times + # emb: [batch, channels] -> [batch * frames, channels] + emb = emb.repeat_interleave(num_frames, dim=0) + # encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels] + encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0) + + # 2. pre-process + sample = self.conv_in(sample) + image_only_indicator = torch.zeros(batch_size, num_frames, dtype=sample.dtype, device=sample.device) + + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + image_only_indicator=image_only_indicator, + ) + else: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + image_only_indicator=image_only_indicator, + ) + + down_block_res_samples += res_samples + + new_down_block_res_samples = () + + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = down_block_res_sample + down_block_additional_residual + new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) + + down_block_res_samples = new_down_block_res_samples + + # 4. mid + sample = self.mid_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + image_only_indicator=image_only_indicator, + ) + sample = sample + mid_block_additional_residual + + + # 5. 
up + for i, upsample_block in enumerate(self.up_blocks): + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + image_only_indicator=image_only_indicator, + ) + else: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + image_only_indicator=image_only_indicator, + ) + + # 6. post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + # 7. Reshape back to original shape + sample = sample.reshape(batch_size, num_frames, *sample.shape[1:]) + + if not return_dict: + return (sample,) + + return UNetSpatioTemporalConditionOutput(sample=sample) diff --git a/optimum/habana/diffusers/pipelines/controlnet/pipeline_stable_video_diffusion_controlnet.py b/optimum/habana/diffusers/pipelines/controlnet/pipeline_stable_video_diffusion_controlnet.py new file mode 100644 index 0000000000..8ec9941272 --- /dev/null +++ b/optimum/habana/diffusers/pipelines/controlnet/pipeline_stable_video_diffusion_controlnet.py @@ -0,0 +1,720 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
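Before the pipeline implementation itself, here is a hedged sketch of how the pieces introduced by this patch fit together at inference time. It mirrors the example script and the README command above; the model IDs come from that example, the image paths are placeholders, and the `from_pretrained` keyword arguments are the ones the example script assembles (scheduler, Habana flags, bf16 dtype).

```python
# Sketch only: wiring the Gaudi SVD ControlNet pipeline together (requires a Gaudi/HPU device).
import torch
from diffusers.utils import load_image

from optimum.habana.diffusers import (
    GaudiEulerDiscreteScheduler,
    GaudiStableVideoDiffusionPipelineControlNet,
)
from optimum.habana.diffusers.models import (
    ControlNetSDVModel,
    UNetSpatioTemporalConditionControlNetModel,
)

model_id = "stabilityai/stable-video-diffusion-img2vid"
controlnet = ControlNetSDVModel.from_pretrained(
    "CiaraRowles/temporal-controlnet-depth-svd-v1", subfolder="controlnet"
)
unet = UNetSpatioTemporalConditionControlNetModel.from_pretrained(model_id, subfolder="unet")
scheduler = GaudiEulerDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")

pipeline = GaudiStableVideoDiffusionPipelineControlNet.from_pretrained(
    model_id,
    controlnet=controlnet,
    unet=unet,
    scheduler=scheduler,
    use_habana=True,
    use_hpu_graphs=True,
    gaudi_config="Habana/stable-diffusion",
    torch_dtype=torch.bfloat16,
)

# One conditioning image plus one depth frame per generated video frame (placeholder paths).
image = load_image("chair.png").resize((512, 512))
depth_frames = [load_image(f"depth/frame_{i}.png").resize((512, 512)) for i in range(14)]

outputs = pipeline(
    image=[image],
    controlnet_condition=depth_frames,
    num_frames=14,
    motion_bucket_id=14,
    height=512,
    width=512,
)
# outputs.frames holds the generated frames; the example script saves them
# as a video and, with --save_frames_as_images, as individual images.
```

Note that the README command supplies exactly one control frame per generated frame (14 depth images for `--num_frames 14`), which is the pattern this sketch follows.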
+ +import time +import PIL.Image +import numpy as np +from math import ceil +import torch +from typing import Callable, Dict, List, Optional, Union + +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection +from diffusers.image_processor import VaeImageProcessor +from diffusers.models import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel +from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion import ( + _append_dims, + _resize_with_antialiasing, + tensor2vid, +) +from diffusers.utils import logging +from diffusers.utils.torch_utils import randn_tensor +from diffusers.schedulers import EulerDiscreteScheduler + +from ..stable_video_diffusion.pipeline_stable_video_diffusion import ( + GaudiStableVideoDiffusionPipeline, + GaudiStableVideoDiffusionPipelineOutput +) +from ..pipeline_utils import GaudiDiffusionPipeline +from ...models import ControlNetSDVModel +from ...models import UNetSpatioTemporalConditionControlNetModel +from ....transformers.gaudi_configuration import GaudiConfig +from ....utils import speed_metrics + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class GaudiStableVideoDiffusionPipelineControlNet(GaudiStableVideoDiffusionPipeline): + r""" + Adapted from: https://github.com/CiaraStrawberry/svd-temporal-controlnet/blob/765cd95c3659c54593ae36a9616121f00b3d7c29/pipeline/pipeline_stable_video_diffusion_controlnet.py#L99 + - Generation is performed by batches + - Added support for HPU graphs + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + image_encoder ([`~transformers.CLIPVisionModelWithProjection`]): + Frozen CLIP image-encoder ([laion/CLIP-ViT-H-14-laion2B-s32B-b79K](https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K)). + unet ([`UNetSpatioTemporalConditionControlNetModel`]): + A `UNetSpatioTemporalConditionControlNetModel` to denoise the encoded image latents. + scheduler ([`EulerDiscreteScheduler`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images. 
+ """ + + def __init__( + self, + vae: AutoencoderKLTemporalDecoder, + image_encoder: CLIPVisionModelWithProjection, + unet: UNetSpatioTemporalConditionControlNetModel, + controlnet: ControlNetSDVModel, + scheduler: EulerDiscreteScheduler, + feature_extractor: CLIPImageProcessor, + use_habana: bool = False, + use_hpu_graphs: bool = False, + gaudi_config: Union[str, GaudiConfig] = None, + bf16_full_eval: bool = False, + ): + GaudiDiffusionPipeline.__init__( + self, + use_habana, + use_hpu_graphs, + gaudi_config, + bf16_full_eval, + ) + + self.register_modules( + vae=vae, + image_encoder=image_encoder, + controlnet=controlnet, + unet=unet, + scheduler=scheduler, + feature_extractor=feature_extractor, + ) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + self.to(self._device) + + def _encode_image(self, image, device, num_videos_per_prompt, do_classifier_free_guidance): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.image_processor.pil_to_numpy(image) + image = self.image_processor.numpy_to_pt(image) + + ''' This is different with statble_video_diffusion + # We normalize the image before resizing to match with the original implementation. + # Then we unnormalize it after resizing. + image = image * 2.0 - 1.0 + image = _resize_with_antialiasing(image, (224, 224)) + image = (image + 1.0) / 2.0 + + # Normalize the image with for CLIP input + image = self.feature_extractor( + images=image, + do_normalize=True, + do_center_crop=False, + do_resize=False, + do_rescale=False, + return_tensors="pt", + ).pixel_values + ''' + + #image = image.unsqueeze(0) + image = _resize_with_antialiasing(image, (224, 224)) + + image = image.to(device=device, dtype=dtype) + image_embeddings = self.image_encoder(image).image_embeds + image_embeddings = image_embeddings.unsqueeze(1) + + # duplicate image embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = image_embeddings.shape + image_embeddings = image_embeddings.repeat(1, num_videos_per_prompt, 1) + image_embeddings = image_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + negative_image_embeddings = torch.zeros_like(image_embeddings) + + # For classifier free guidance, we need to do two forward passes. 
+            # Here we concatenate the unconditional and image embeddings into a single batch
+            # to avoid doing two forward passes
+            image_embeddings = torch.cat([negative_image_embeddings, image_embeddings])
+
+        return image_embeddings
+
+    @classmethod
+    def _split_inputs_into_batches(
+        cls,
+        batch_size,
+        latents,
+        image_latents,
+        image_embeddings,
+        controlnet_condition,
+        added_time_ids,
+        num_images,
+        do_classifier_free_guidance,
+    ):
+        if do_classifier_free_guidance:
+            negative_image_embeddings, image_embeddings = image_embeddings.chunk(2)
+            negative_added_time_ids, added_time_ids = added_time_ids.chunk(2)
+        else:
+            negative_image_embeddings = None
+            negative_added_time_ids = None
+
+        # If the last batch has fewer samples than batch_size, compute the number of dummy samples to pad
+        last_samples = latents.shape[0] % batch_size
+        num_dummy_samples = batch_size - last_samples if last_samples > 0 else 0
+
+        # Generate num_batches batches of size batch_size
+        latents_batches = cls._split_input_into_batches(latents, batch_size, num_dummy_samples)
+        image_latents_batches = cls._split_image_latents_into_batches(
+            image_latents, batch_size, num_dummy_samples, num_images, do_classifier_free_guidance
+        )
+        image_embeddings_batches = cls._split_input_into_batches(
+            image_embeddings, batch_size, num_dummy_samples, negative_image_embeddings
+        )
+        controlnet_condition_batches = cls._split_input_into_batches(
+            controlnet_condition, batch_size, num_dummy_samples
+        )
+        added_time_ids_batches = cls._split_input_into_batches(
+            added_time_ids, batch_size, num_dummy_samples, negative_added_time_ids
+        )
+
+        return (
+            latents_batches,
+            image_latents_batches,
+            image_embeddings_batches,
+            controlnet_condition_batches,
+            added_time_ids_batches,
+            num_dummy_samples,
+        )
+
+    @torch.no_grad()
+    def __call__(
+        self,
+        image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor],
+        controlnet_condition: Optional[Union[List[PIL.Image.Image], torch.FloatTensor]] = None,
+        height: int = 576,
+        width: int = 1024,
+        num_frames: Optional[int] = None,
+        num_inference_steps: int = 25,
+        min_guidance_scale: float = 1.0,
+        max_guidance_scale: float = 3.0,
+        fps: int = 7,
+        motion_bucket_id: int = 127,
+        noise_aug_strength: float = 0.02,
+        decode_chunk_size: Optional[int] = None,
+        num_videos_per_prompt: Optional[int] = 1,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        latents: Optional[torch.FloatTensor] = None,
+        output_type: Optional[str] = "pil",
+        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        return_dict: bool = True,
+        controlnet_cond_scale: float = 1.0,
+        batch_size: int = 1,
+        **kwargs,
+    ):
+        r"""
+        The call function to the pipeline for generation.
+
+        Args:
+            image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
+                Image or images to guide video generation. If you provide a tensor, it needs to be compatible with
+                [`CLIPImageProcessor`](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json).
+            height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+                The height in pixels of the generated image.
+            width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+                The width in pixels of the generated image.
+            num_frames (`int`, *optional*):
+                The number of video frames to generate. Defaults to 14 for `stable-video-diffusion-img2vid` and
+                to 25 for `stable-video-diffusion-img2vid-xt`.
+            num_inference_steps (`int`, *optional*, defaults to 25):
+                The number of denoising steps. More denoising steps usually lead to a higher quality video at the
+                expense of slower inference.
+            min_guidance_scale (`float`, *optional*, defaults to 1.0):
+                The minimum guidance scale. Used for the classifier free guidance with first frame.
+            max_guidance_scale (`float`, *optional*, defaults to 3.0):
+                The maximum guidance scale. Used for the classifier free guidance with last frame.
+            fps (`int`, *optional*, defaults to 7):
+                Frames per second. The rate at which the generated images shall be exported to a video after generation.
+                Note that Stable Diffusion Video's UNet was micro-conditioned on fps-1 during training.
+            motion_bucket_id (`int`, *optional*, defaults to 127):
+                The motion bucket ID. Used as conditioning for the generation. The higher the number the more motion will be in the video.
+            noise_aug_strength (`float`, *optional*, defaults to 0.02):
+                The amount of noise added to the init image, the higher it is the less the video will look like the init image. Increase it for more motion.
+            decode_chunk_size (`int`, *optional*):
+                The number of frames to decode at a time. The higher the chunk size, the higher the temporal consistency
+                between frames, but also the higher the memory consumption. By default, the decoder will decode all frames at once
+                for maximal quality. Reduce `decode_chunk_size` to reduce memory usage.
+            num_videos_per_prompt (`int`, *optional*, defaults to 1):
+                The number of videos to generate per prompt.
+            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+                generation deterministic.
+            latents (`torch.FloatTensor`, *optional*):
+                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor is generated by sampling using the supplied random `generator`.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+            callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. The function is called
+                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+                `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end_tensor_inputs (`List`, *optional*):
+                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+                `._callback_tensor_inputs` attribute of your pipeline class.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`GaudiStableVideoDiffusionPipelineOutput`] instead of the plain list of
+                generated frames.
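+            controlnet_condition (`List[PIL.Image.Image]` or `torch.FloatTensor`, *optional*):
+                Per-frame ControlNet conditioning images (e.g. depth maps) that guide the video generation. They are
+                resized to (`height`, `width`) and passed to the ControlNet at every denoising step.
+            controlnet_cond_scale (`float`, *optional*, defaults to 1.0):
+                Scale applied to the ControlNet residuals before they are added to the UNet. Lower values weaken the
+                conditioning.
+            batch_size (`int`, *optional*, defaults to 1):
+                The number of samples processed in a single batch on HPU.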
+
+        Returns:
+            [`GaudiStableVideoDiffusionPipelineOutput`] or `list`:
+                If `return_dict` is `True`, a [`GaudiStableVideoDiffusionPipelineOutput`] is returned, otherwise the
+                list of generated frames is returned directly.
+
+        Examples:
+
+        ```py
+        import torch
+        from diffusers.utils import load_image, export_to_video
+        from optimum.habana.diffusers import GaudiStableVideoDiffusionPipelineControlNet
+        from optimum.habana.diffusers.models import ControlNetSDVModel, UNetSpatioTemporalConditionControlNetModel

+        # The checkpoints below are illustrative; see examples/stable-diffusion/README.md for a full command line.
+        controlnet = ControlNetSDVModel.from_pretrained(
+            "CiaraRowles/temporal-controlnet-depth-svd-v1", subfolder="controlnet"
+        )
+        unet = UNetSpatioTemporalConditionControlNetModel.from_pretrained(
+            "stabilityai/stable-video-diffusion-img2vid", subfolder="unet"
+        )
+        pipe = GaudiStableVideoDiffusionPipelineControlNet.from_pretrained(
+            "stabilityai/stable-video-diffusion-img2vid",
+            controlnet=controlnet,
+            unet=unet,
+            use_habana=True,
+            use_hpu_graphs=True,
+            gaudi_config="Habana/stable-diffusion",
+            torch_dtype=torch.bfloat16,
+        )

+        image = load_image("https://github.com/CiaraStrawberry/svd-temporal-controlnet/blob/main/validation_demo/chair.png?raw=true")
+        image = image.resize((512, 512))
+        control_frames = [...]  # per-frame depth maps as PIL images, resized to (512, 512)

+        frames = pipe(
+            image=image, controlnet_condition=control_frames, num_frames=14, decode_chunk_size=8
+        ).frames[0]
+        export_to_video(frames, "generated.mp4", fps=7)
+        ```
+        """
+        with torch.autocast(device_type="hpu", dtype=torch.bfloat16, enabled=self.gaudi_config.use_torch_autocast):
+            # 0. Default height and width to unet
+            height = height or self.unet.config.sample_size * self.vae_scale_factor
+            width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+            num_frames = num_frames if num_frames is not None else self.unet.config.num_frames
+            decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames
+
+            # 1. Check inputs. Raise error if not correct
+            self.check_inputs(image, height, width)
+
+            # 2. Define call parameters
+            if isinstance(image, PIL.Image.Image):
+                num_images = 1
+            elif isinstance(image, list):
+                num_images = len(image)
+            else:
+                num_images = image.shape[0]
+            num_batches = ceil((num_videos_per_prompt * num_images) / batch_size)
+            logger.info(
+                f"{num_images} image(s) received, {num_videos_per_prompt} video(s) per prompt,"
+                f" {batch_size} sample(s) per batch, {num_batches} total batch(es)."
+            )
+            if num_batches < 3:
+                logger.warning("The first two iterations are slower so it is recommended to feed more batches.")
+
+            device = self._execution_device
+            self.controlnet.to(device, dtype=torch.bfloat16)
+            self.unet.to(device, dtype=torch.bfloat16)
+            # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+            # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+            # corresponds to doing no classifier free guidance.
+            do_classifier_free_guidance = max_guidance_scale > 1.0
+
+            # 3. Encode input image
+            image_embeddings = self._encode_image(image, device, num_videos_per_prompt, do_classifier_free_guidance)
+
+            # NOTE: Stable Diffusion Video was conditioned on fps - 1, which
+            # is why it is reduced here.
+            # See: https://github.com/Stability-AI/generative-models/blob/ed0997173f98eaf8f4edf7ba5fe8f15c6b877fd3/scripts/sampling/simple_video_sample.py#L188
+            fps = fps - 1
+
+            # 4. Encode input image using VAE
+            image = self.image_processor.preprocess(image, height=height, width=width)
+            # torch.randn is broken on HPU so running it on CPU
+            rand_device = "cpu" if device.type == "hpu" else device
+            noise = randn_tensor(image.shape, generator=generator, device=rand_device, dtype=image.dtype).to(device)
+            image = image + noise_aug_strength * noise
+
+            needs_upcasting = (
+                self.vae.dtype == torch.float16 or self.vae.dtype == torch.bfloat16
+            ) and self.vae.config.force_upcast
+
+            if needs_upcasting:
+                cast_dtype = self.vae.dtype
+                self.vae.to(dtype=torch.float32)
+
+            # Override do_classifier_free_guidance to return only the conditional latents
+            image_latents = self._encode_vae_image(
+                image, device, num_videos_per_prompt, do_classifier_free_guidance=False
+            )
+            image_latents = image_latents.to(image_embeddings.dtype)
+
+            # cast the VAE back to its original dtype if needed
+            if needs_upcasting:
+                self.vae.to(dtype=cast_dtype)
+
+            # Repeat the image latents for each frame so we can concatenate them with the noise
+            # image_latents [batch, channels, height, width] -> [batch, num_frames, channels, height, width]
+            image_latents = image_latents.unsqueeze(1).repeat(1, num_frames, 1, 1, 1)
+
+            # 5. Get Added Time IDs
+            added_time_ids = self._get_add_time_ids(
+                fps,
+                motion_bucket_id,
+                noise_aug_strength,
+                image_embeddings.dtype,
+                batch_size,
+                num_videos_per_prompt,
+                do_classifier_free_guidance,
+                device,
+            )
+            added_time_ids = added_time_ids.to(device)
+            # 6. Prepare timesteps
+            self.scheduler.set_timesteps(num_inference_steps, device=device)
+            timesteps = self.scheduler.timesteps
+            self.scheduler.reset_timestep_dependent_params()
+
+            # 7. Prepare latent variables
+            num_channels_latents = self.unet.config.in_channels
+            latents = self.prepare_latents(
+                batch_size * num_videos_per_prompt,
+                num_frames,
+                num_channels_latents,
+                height,
+                width,
+                image_embeddings.dtype,
+                device,
+                generator,
+                latents,
+            )
+            # Prepare the ControlNet condition
+            controlnet_condition = self.image_processor.preprocess(controlnet_condition, height=height, width=width)
+            controlnet_condition = controlnet_condition.unsqueeze(0)
+            controlnet_condition = controlnet_condition.to(device, latents.dtype)
+
+            # 8. Prepare guidance scale
+            guidance_scale = torch.linspace(min_guidance_scale, max_guidance_scale, num_frames).unsqueeze(0)
+            guidance_scale = guidance_scale.to(device, latents.dtype)
+            guidance_scale = guidance_scale.repeat(batch_size * num_videos_per_prompt, 1)
+            guidance_scale = _append_dims(guidance_scale, latents.ndim)
+
+            self._guidance_scale = guidance_scale
+            # 9. Split into batches (HPU-specific step)
+            (
+                latents_batches,
+                image_latents_batches,
+                image_embeddings_batches,
+                controlnet_condition_batches,
+                added_time_ids_batches,
+                num_dummy_samples,
+            ) = self._split_inputs_into_batches(
+                batch_size,
+                latents,
+                image_latents,
+                image_embeddings,
+                controlnet_condition,
+                added_time_ids,
+                num_images,
+                do_classifier_free_guidance,
+            )
+            outputs = {
+                "frames": [],
+            }
+            t0 = time.time()
+            t1 = t0
+
+            # 10. Denoising loop
+            throughput_warmup_steps = kwargs.get("throughput_warmup_steps", 3)
+            self._num_timesteps = len(timesteps)
+            for j in self.progress_bar(range(num_batches)):
+                # The throughput is calculated from the 3rd iteration
+                # because compilation occurs in the first iterations
+                if j == throughput_warmup_steps:
+                    t1 = time.time()
+
+                latents_batch = latents_batches[0]
+                latents_batches = torch.roll(latents_batches, shifts=-1, dims=0)
+                image_latents_batch = image_latents_batches[0]
+                image_latents_batches = torch.roll(image_latents_batches, shifts=-1, dims=0)
+                image_embeddings_batch = image_embeddings_batches[0]
+                image_embeddings_batches = torch.roll(image_embeddings_batches, shifts=-1, dims=0)
+                added_time_ids_batch = added_time_ids_batches[0]
+                added_time_ids_batches = torch.roll(added_time_ids_batches, shifts=-1, dims=0)
+                controlnet_condition_batch = controlnet_condition_batches[0]
+                controlnet_condition_batches = torch.roll(controlnet_condition_batches, shifts=-1, dims=0)
+
+                for i in self.progress_bar(range(num_inference_steps)):
+                    timestep = timesteps[0]
+                    timesteps = torch.roll(timesteps, shifts=-1, dims=0)
+
+                    # expand the latents if we are doing classifier free guidance
+                    latent_model_input = torch.cat([latents_batch] * 2) if do_classifier_free_guidance else latents_batch
+                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, timestep)
+                    if do_classifier_free_guidance:
+                        controlnet_condition_input = torch.cat([controlnet_condition_batch] * 2)
+                    else:
+                        controlnet_condition_input = controlnet_condition_batch
+
+                    # Concatenate image_latents over the channels dimension
+                    latent_model_input = torch.cat([latent_model_input, image_latents_batch], dim=2)
+                    down_block_res_samples, mid_block_res_sample = self.controlnet_hpu(
+                        latent_model_input,
+                        timestep,
+                        encoder_hidden_states=image_embeddings_batch,
+                        added_time_ids=added_time_ids_batch,
+                        controlnet_cond=controlnet_condition_input,
+                        return_dict=False,
+                        guess_mode=False,
+                        conditioning_scale=controlnet_cond_scale,
+                    )
+                    # predict the noise residual
+                    noise_pred = self.unet_hpu(
+                        latent_model_input,
+                        timestep,
+                        encoder_hidden_states=image_embeddings_batch,
+                        down_block_additional_residuals=down_block_res_samples,
+                        mid_block_additional_residual=mid_block_res_sample,
+                        return_dict=False,
+                        added_time_ids=added_time_ids_batch,
+                    )
+                    # perform guidance
+                    if do_classifier_free_guidance:
+                        noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
+                        noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
+
+                    # compute the previous noisy sample x_t -> x_t-1
+                    latents_batch = self.scheduler.step(noise_pred, timestep, latents_batch).prev_sample
+                    if callback_on_step_end is not None:
+                        callback_kwargs = {}
+                        for k in callback_on_step_end_tensor_inputs:
+                            callback_kwargs[k] = locals()[k]
+                        callback_outputs = callback_on_step_end(self, i, timestep, callback_kwargs)
+
+                        latents_batch = callback_outputs.pop("latents", latents_batch)
+
+                if not output_type == "latent":
+                    # make sure the VAE is in its original dtype before decoding
+                    if needs_upcasting:
+                        self.vae.to(dtype=cast_dtype)
+                    frames = self.decode_latents(latents_batch, num_frames, decode_chunk_size)
+                    frames = tensor2vid(frames, self.image_processor, output_type=output_type)
+                else:
+                    frames = latents_batch
+
+                outputs["frames"].append(frames)
+
+            speed_metrics_prefix = "generation"
+            speed_measures = speed_metrics(
+                split=speed_metrics_prefix,
+                start_time=t0,
+                num_samples=num_batches * batch_size
+                if t1 == t0
+                else (num_batches - throughput_warmup_steps) * batch_size,
+                num_steps=num_batches,
+
start_time_after_warmup=t1, + ) + logger.info(f"Speed metrics: {speed_measures}") + + # Remove dummy generations if needed + if num_dummy_samples > 0: + outputs["frames"][-1] = outputs["frames"][-1][:-num_dummy_samples] + + # Process generated images + for i, frames in enumerate(outputs["frames"][:]): + if i == 0: + outputs["frames"].clear() + + if output_type == "pil": + outputs["frames"] += frames + else: + outputs["frames"] += [*frames] + + self.maybe_free_model_hooks() + + if not return_dict: + return outputs["frames"] + + return GaudiStableVideoDiffusionPipelineOutput( + frames=outputs["frames"], + throughput=speed_measures[f"{speed_metrics_prefix}_samples_per_second"], + ) + + @torch.no_grad() + def controlnet_hpu( + self, + control_model_input, + timestep, + encoder_hidden_states, + added_time_ids, + controlnet_cond, + return_dict, + guess_mode, + conditioning_scale, + ): + if self.use_hpu_graphs: + return self.controlnet_capture_replay( + control_model_input, + timestep, + encoder_hidden_states, + added_time_ids, + controlnet_cond, + return_dict, + guess_mode, + conditioning_scale, + ) + else: + return self.controlnet( + control_model_input, + timestep, + encoder_hidden_states=encoder_hidden_states, + added_time_ids=added_time_ids, + controlnet_cond=controlnet_cond, + return_dict=return_dict, + guess_mode=guess_mode, + conditioning_scale=conditioning_scale, + ) + + @torch.no_grad() + def controlnet_capture_replay( + self, + control_model_input, + timestep, + encoder_hidden_states, + added_time_ids, + controlnet_cond, + return_dict, + guess_mode, + conditioning_scale, + ): + inputs = [ + control_model_input, + timestep, + encoder_hidden_states, + added_time_ids, + controlnet_cond, + return_dict, + guess_mode, + conditioning_scale, + ] + h = self.ht.hpu.graphs.input_hash(inputs) + cached = self.cache.get(h) + + if cached is None: + # Capture the graph and cache it + with self.ht.hpu.stream(self.hpu_stream): + graph = self.ht.hpu.HPUGraph() + graph.capture_begin() + outputs = self.controlnet( + inputs[0], + inputs[1], + inputs[2], + inputs[3], + inputs[4], + None, + inputs[5], + inputs[6], + inputs[7], + ) + graph.capture_end() + graph_inputs = inputs + graph_outputs = outputs + self.cache[h] = self.ht.hpu.graphs.CachedParams(graph_inputs, graph_outputs, graph) + return outputs + + # Replay the cached graph with updated inputs + self.ht.hpu.graphs.copy_to(cached.graph_inputs, inputs) + cached.graph.replay() + self.ht.core.hpu.default_stream().synchronize() + + return cached.graph_outputs + + @torch.no_grad() + def unet_hpu( + self, + latent_model_input, + timestep, + encoder_hidden_states, + down_block_additional_residuals, + mid_block_additional_residual, + return_dict, + added_time_ids, + ): + if self.use_hpu_graphs: + return self.unet_capture_replay( + latent_model_input, + timestep, + encoder_hidden_states, + down_block_additional_residuals, + mid_block_additional_residual, + return_dict, + added_time_ids, + ) + else: + return self.unet( + latent_model_input, + timestep, + encoder_hidden_states=encoder_hidden_states, + down_block_additional_residuals=down_block_additional_residuals, + mid_block_additional_residual=mid_block_additional_residual, + return_dict=return_dict, + added_time_ids=added_time_ids, + )[0] + + @torch.no_grad() + def unet_capture_replay( + self, + latent_model_input, + timestep, + encoder_hidden_states, + down_block_additional_residuals, + mid_block_additional_residual, + return_dict, + added_time_ids, + ): + inputs = [ + latent_model_input, + timestep, + 
encoder_hidden_states, + down_block_additional_residuals, + mid_block_additional_residual, + return_dict, + added_time_ids, + ] + h = self.ht.hpu.graphs.input_hash(inputs) + cached = self.cache.get(h) + + if cached is None: + # Capture the graph and cache it + with self.ht.hpu.stream(self.hpu_stream): + graph = self.ht.hpu.HPUGraph() + graph.capture_begin() + outputs = self.unet( + inputs[0], + inputs[1], + inputs[2], + inputs[3], + inputs[4], + inputs[5], + inputs[6], + )[0] + graph.capture_end() + graph_inputs = inputs + graph_outputs = outputs + self.cache[h] = self.ht.hpu.graphs.CachedParams(graph_inputs, graph_outputs, graph) + return outputs + + # Replay the cached graph with updated inputs + self.ht.hpu.graphs.copy_to(cached.graph_inputs, inputs) + cached.graph.replay() + self.ht.core.hpu.default_stream().synchronize() + + return cached.graph_outputs