diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml
index 34db5f524f..aa79f0df2e 100644
--- a/docs/source/_toctree.yml
+++ b/docs/source/_toctree.yml
@@ -16,6 +16,8 @@
     title: Run Inference
   - local: tutorials/stable_diffusion
     title: Stable Diffusion
+  - local: tutorials/stable_diffusion_ldm3d
+    title: LDM3D
   title: Tutorials
 - sections:
   - local: usage_guides/overview
diff --git a/docs/source/tutorials/stable_diffusion_ldm3d.mdx b/docs/source/tutorials/stable_diffusion_ldm3d.mdx
new file mode 100644
index 0000000000..d7c975ceb6
--- /dev/null
+++ b/docs/source/tutorials/stable_diffusion_ldm3d.mdx
@@ -0,0 +1,67 @@
+
+
+# Text-to-(RGB, depth)
+
+LDM3D was proposed in [LDM3D: Latent Diffusion Model for 3D](https://huggingface.co/papers/2305.10853) by Gabriela Ben Melech Stan, Diana Wofk, Scottie Fox, Alex Redden, Will Saxton, Jean Yu, Estelle Aflalo, Shao-Yen Tseng, Fabio Nonato, Matthias Muller, and Vasudev Lal. Unlike existing text-to-image diffusion models such as [Stable Diffusion](./stable_diffusion), which only generate an image, LDM3D generates both an image and a depth map from a given text prompt. With almost the same number of parameters, LDM3D creates a latent space that can compress both the RGB images and the depth maps.
+
+The abstract from the paper is:
+
+*This research paper proposes a Latent Diffusion Model for 3D (LDM3D) that generates both image and depth map data from a given text prompt, allowing users to generate RGBD images from text prompts. The LDM3D model is fine-tuned on a dataset of tuples containing an RGB image, depth map and caption, and validated through extensive experiments. We also develop an application called DepthFusion, which uses the generated RGB images and depth maps to create immersive and interactive 360-degree-view experiences using TouchDesigner. This technology has the potential to transform a wide range of industries, from entertainment and gaming to architecture and design. Overall, this paper presents a significant contribution to the field of generative AI and computer vision, and showcases the potential of LDM3D and DepthFusion to revolutionize content creation and digital experiences. A short video summarizing the approach can be found at [this url](https://t.ly/tdi2).*
+
+
+## How to generate RGB and depth images?
+
+To generate RGB and depth images with Stable Diffusion LDM3D on Gaudi, you need to instantiate two components:
+- A pipeline with [`GaudiStableDiffusionLDM3DPipeline`]. This pipeline supports *text-to-(rgb, depth) generation*.
+- A scheduler with [`GaudiDDIMScheduler`](https://huggingface.co/docs/optimum/habana/package_reference/stable_diffusion_pipeline#optimum.habana.diffusers.GaudiDDIMScheduler). This scheduler has been optimized for Gaudi.
+
+When initializing the pipeline, you have to specify `use_habana=True` to deploy it on HPUs.
+Furthermore, to get the fastest possible generations, you should enable **HPU graphs** with `use_hpu_graphs=True`.
+Finally, you will need to specify a [Gaudi configuration](https://huggingface.co/docs/optimum/habana/package_reference/gaudi_config), which can be downloaded from the Hugging Face Hub.
+
+```python
+from optimum.habana.diffusers import GaudiDDIMScheduler, GaudiStableDiffusionLDM3DPipeline
+from optimum.habana.utils import set_seed
+
+model_name = "Intel/ldm3d-4c"
+
+scheduler = GaudiDDIMScheduler.from_pretrained(model_name, subfolder="scheduler")
+
+set_seed(42)
+
+pipeline = GaudiStableDiffusionLDM3DPipeline.from_pretrained(
+    model_name,
+    scheduler=scheduler,
+    use_habana=True,
+    use_hpu_graphs=True,
+    gaudi_config="Habana/stable-diffusion",
+)
+outputs = pipeline(
+    prompt=["High quality photo of an astronaut riding a horse in space"],
+    num_images_per_prompt=1,
+    batch_size=1,
+    output_type="pil",
+    num_inference_steps=40,
+    guidance_scale=5.0,
+    negative_prompt=None,
+)
+
+rgb_image, depth_image = outputs.rgb, outputs.depth
+rgb_image[0].save("astronaut_ldm3d_rgb.png")
+depth_image[0].save("astronaut_ldm3d_depth.png")
+```
diff --git a/examples/stable-diffusion/README.md b/examples/stable-diffusion/README.md
index 45c4dd7dc0..e1c92e8fcb 100644
--- a/examples/stable-diffusion/README.md
+++ b/examples/stable-diffusion/README.md
@@ -86,3 +86,32 @@ python text_to_image_generation.py \
 > There are two different checkpoints for Stable Diffusion 2:
 > - use [stabilityai/stable-diffusion-2-1](https://huggingface.co/stabilityai/stable-diffusion-2-1) for generating 768x768 images
 > - use [stabilityai/stable-diffusion-2-1-base](https://huggingface.co/stabilityai/stable-diffusion-2-1-base) for generating 512x512 images
+
+
+### Latent Diffusion Model for 3D (LDM3D)
+
+[LDM3D](https://arxiv.org/abs/2305.10853) generates both image and depth map data from a given text prompt, allowing users to generate RGBD images from text prompts.
+
+The [original checkpoint](https://huggingface.co/Intel/ldm3d) and the [latest checkpoint](https://huggingface.co/Intel/ldm3d-4c) are open source.
+A [demo](https://huggingface.co/spaces/Intel/ldm3d) is also available. Here is how to run this model:
+
+```bash
+python text_to_image_generation.py \
+    --model_name_or_path "Intel/ldm3d-4c" \
+    --prompts "An image of a squirrel in Picasso style" \
+    --num_images_per_prompt 10 \
+    --batch_size 2 \
+    --height 768 \
+    --width 768 \
+    --image_save_dir /tmp/stable_diffusion_images \
+    --use_habana \
+    --use_hpu_graphs \
+    --gaudi_config Habana/stable-diffusion-2 \
+    --ldm3d
+```
+
+> There are three different checkpoints for LDM3D:
+> - use the [original checkpoint](https://huggingface.co/Intel/ldm3d) to generate the outputs from the paper
+> - use the [latest checkpoint](https://huggingface.co/Intel/ldm3d-4c) to generate improved results
+> - use the [pano checkpoint](https://huggingface.co/Intel/ldm3d-pano) to generate panoramic views
+
diff --git a/examples/stable-diffusion/text_to_image_generation.py b/examples/stable-diffusion/text_to_image_generation.py
index 747834e2c5..042bfa34a7 100644
--- a/examples/stable-diffusion/text_to_image_generation.py
+++ b/examples/stable-diffusion/text_to_image_generation.py
@@ -20,7 +20,7 @@
 
 import torch
 
-from optimum.habana.diffusers import GaudiDDIMScheduler, GaudiStableDiffusionPipeline
+from optimum.habana.diffusers import GaudiDDIMScheduler
 from optimum.habana.utils import set_seed
 
 
@@ -121,9 +121,20 @@ def main():
         ),
     )
     parser.add_argument("--bf16", action="store_true", help="Whether to perform generation in bf16 precision.")
+    parser.add_argument(
+        "--ldm3d", action="store_true", help="Use LDM3D to generate an image and a depth map from a given text prompt."
+ ) args = parser.parse_args() + if args.ldm3d: + from optimum.habana.diffusers import GaudiStableDiffusionLDM3DPipeline as GaudiStableDiffusionPipeline + + if args.model_name_or_path == "runwayml/stable-diffusion-v1-5": + args.model_name_or_path = "Intel/ldm3d-4c" + else: + from optimum.habana.diffusers import GaudiStableDiffusionPipeline + # Setup logging logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", @@ -174,8 +185,14 @@ def main(): image_save_dir = Path(args.image_save_dir) image_save_dir.mkdir(parents=True, exist_ok=True) logger.info(f"Saving images in {image_save_dir.resolve()}...") - for i, image in enumerate(outputs.images): - image.save(image_save_dir / f"image_{i+1}.png") + if args.ldm3d: + for i, rgb in enumerate(outputs.rgb): + rgb.save(image_save_dir / f"rgb_{i+1}.png") + for i, depth in enumerate(outputs.depth): + depth.save(image_save_dir / f"depth_{i+1}.png") + else: + for i, image in enumerate(outputs.images): + image.save(image_save_dir / f"image_{i+1}.png") else: logger.warning("--output_type should be equal to 'pil' to save images in --image_save_dir.") diff --git a/optimum/habana/diffusers/__init__.py b/optimum/habana/diffusers/__init__.py index 0df716baa0..add6a7fa8a 100644 --- a/optimum/habana/diffusers/__init__.py +++ b/optimum/habana/diffusers/__init__.py @@ -1,3 +1,4 @@ from .pipelines.pipeline_utils import GaudiDiffusionPipeline from .pipelines.stable_diffusion.pipeline_stable_diffusion import GaudiStableDiffusionPipeline +from .pipelines.stable_diffusion.pipeline_stable_diffusion_ldm3d import GaudiStableDiffusionLDM3DPipeline from .schedulers import GaudiDDIMScheduler diff --git a/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py b/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py new file mode 100644 index 0000000000..5af3ed5bc7 --- /dev/null +++ b/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py @@ -0,0 +1,825 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
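+#
+# Gaudi (HPU) adaptation of the Stable Diffusion LDM3D pipeline: it generates an RGB
+# image and a depth map from a text prompt, performs generation in batches, inserts
+# mark_step() calls for lazy mode, and can optionally capture and replay HPU graphs.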
+ +import inspect +import time +import warnings +from dataclasses import dataclass +from math import ceil +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import PIL +import torch +from diffusers.image_processor import VaeImageProcessorLDM3D +from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils import BaseOutput +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from optimum.utils import logging + +from ....transformers.gaudi_configuration import GaudiConfig +from ....utils import speed_metrics +from ..pipeline_utils import GaudiDiffusionPipeline + + +logger = logging.get_logger(__name__) + + +@dataclass +class GaudiStableDiffusionLDM3DPipelineOutput(BaseOutput): + rgb: Union[List[PIL.Image.Image], np.ndarray] + depth: Union[List[PIL.Image.Image], np.ndarray] + throughput: float + nsfw_content_detected: Optional[List[bool]] + + +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +class GaudiStableDiffusionLDM3DPipeline( + GaudiDiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin +): + """ + Extends the [`StableDiffusionPipeline`](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion#diffusers.StableDiffusionPipeline) class: + - Generation is performed by batches + - Two `mark_step()` were added to add support for lazy mode + - Added support for HPU graphs + - Adjusted original Stable Diffusion to match with the LDM3D implementation (input and output being different) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. 
+ safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + use_habana (bool, defaults to `False`): + Whether to use Gaudi (`True`) or CPU (`False`). + use_hpu_graphs (bool, defaults to `False`): + Whether to use HPU graphs or not. + gaudi_config (Union[str, [`GaudiConfig`]], defaults to `None`): + Gaudi configuration to use. Can be a string to download it from the Hub. + Or a previously initialized config can be passed. + bf16_full_eval (bool, defaults to `False`): + Whether to use full bfloat16 evaluation instead of 32-bit. + This will be faster and save memory compared to fp32/mixed precision but can harm generated images. + """ + + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + use_habana: bool = False, + use_hpu_graphs: bool = False, + gaudi_config: Union[str, GaudiConfig] = None, + bf16_full_eval: bool = False, + ): + super().__init__( + use_habana, + use_hpu_graphs, + gaudi_config, + bf16_full_eval, + ) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessorLDM3D(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + self.to(self._device) + + @property + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. 
+ """ + if not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. 
+ """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + if prompt is not None and isinstance(prompt, str): + num_prompts = 1 + elif prompt is not None and isinstance(prompt, list): + num_prompts = len(prompt) + else: + num_prompts = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * num_prompts + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif num_prompts != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has {len(negative_prompt)} elements, but `prompt`:" + f" {prompt} has {num_prompts}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." 
+ ) + else: + uncond_tokens = negative_prompt + + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(num_prompts * num_images_per_prompt, seq_len, -1) + + return prompt_embeds, negative_prompt_embeds + + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + rgb_feature_extractor_input = feature_extractor_input[0] + safety_checker_input = self.feature_extractor(rgb_feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def prepare_latents(self, num_images, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (num_images, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != num_images: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective number" + f" of images of {num_images}. Make sure the number of images matches the length of the generators." 
+ ) + + if latents is None: + # torch.randn is broken on HPU so running it on CPU + rand_device = "cpu" if device.type == "hpu" else device + if isinstance(generator, list): + shape = (1,) + shape[1:] + latents = [ + torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) + for i in range(num_images) + ] + latents = torch.cat(latents, dim=0).to(device) + else: + latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @classmethod + def _split_inputs_into_batches(cls, batch_size, latents, text_embeddings, uncond_embeddings): + # Use torch.split to generate num_batches batches of size batch_size + latents_batches = list(torch.split(latents, batch_size)) + text_embeddings_batches = list(torch.split(text_embeddings, batch_size)) + if uncond_embeddings is not None: + uncond_embeddings_batches = list(torch.split(uncond_embeddings, batch_size)) + + # If the last batch has less samples than batch_size, pad it with dummy samples + num_dummy_samples = 0 + if latents_batches[-1].shape[0] < batch_size: + num_dummy_samples = batch_size - latents_batches[-1].shape[0] + # Pad latents_batches + sequence_to_stack = (latents_batches[-1],) + tuple( + torch.zeros_like(latents_batches[-1][0][None, :]) for _ in range(num_dummy_samples) + ) + latents_batches[-1] = torch.vstack(sequence_to_stack) + # Pad text_embeddings_batches + sequence_to_stack = (text_embeddings_batches[-1],) + tuple( + torch.zeros_like(text_embeddings_batches[-1][0][None, :]) for _ in range(num_dummy_samples) + ) + text_embeddings_batches[-1] = torch.vstack(sequence_to_stack) + # Pad uncond_embeddings_batches if necessary + if uncond_embeddings is not None: + sequence_to_stack = (uncond_embeddings_batches[-1],) + tuple( + torch.zeros_like(uncond_embeddings_batches[-1][0][None, :]) for _ in range(num_dummy_samples) + ) + uncond_embeddings_batches[-1] = torch.vstack(sequence_to_stack) + + # Stack batches in the same tensor + latents_batches = torch.stack(latents_batches) + if uncond_embeddings is not None: + # For classifier free guidance, we need to do two forward passes. 
+ # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + for i, (uncond_embeddings_batch, text_embeddings_batch) in enumerate( + zip(uncond_embeddings_batches, text_embeddings_batches[:]) + ): + text_embeddings_batches[i] = torch.cat([uncond_embeddings_batch, text_embeddings_batch]) + text_embeddings_batches = torch.stack(text_embeddings_batches) + + return latents_batches, text_embeddings_batches, num_dummy_samples + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + batch_size: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated images. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated images. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + batch_size (`int`, *optional*, defaults to 1): + The number of images in a batch. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. 
+ latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated randomly. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.GaudiStableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + guidance_rescale (`float`, *optional*, defaults to 0.7): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. + + Returns: + [`~diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.GaudiStableDiffusionPipelineOutput`] or `tuple`: + [`~diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.GaudiStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + with torch.autocast(device_type="hpu", dtype=torch.bfloat16, enabled=self.gaudi_config.use_torch_autocast): + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) + + # 2. 
Define call parameters + if prompt is not None and isinstance(prompt, str): + num_prompts = 1 + elif prompt is not None and isinstance(prompt, list): + num_prompts = len(prompt) + else: + num_prompts = prompt_embeds.shape[0] + num_batches = ceil((num_images_per_prompt * num_prompts) / batch_size) + logger.info( + f"{num_prompts} prompt(s) received, {num_images_per_prompt} generation(s) per prompt," + f" {batch_size} sample(s) per batch, {num_batches} total batch(es)." + ) + if num_batches < 3: + logger.warning("The first two iterations are slower so it is recommended to feed more batches.") + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds, negative_prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device="cpu") + timesteps = self.scheduler.timesteps.to(device) + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + num_prompts * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Split into batches (HPU-specific step) + latents_batches, text_embeddings_batches, num_dummy_samples = self._split_inputs_into_batches( + batch_size, + latents, + prompt_embeds, + negative_prompt_embeds, + ) + + outputs = { + "images": [], + "has_nsfw_concept": [], + } + t0 = time.time() + t1 = t0 + + # 8. 
Denoising loop + for j in self.progress_bar(range(num_batches)): + # The throughput is calculated from the 3rd iteration + # because compilation occurs in the first two iterations + if j == 2: + t1 = time.time() + + latents_batch = latents_batches[0] + latents_batches = torch.roll(latents_batches, shifts=-1, dims=0) + text_embeddings_batch = text_embeddings_batches[0] + text_embeddings_batches = torch.roll(text_embeddings_batches, shifts=-1, dims=0) + + for i in range(num_inference_steps): + timestep = timesteps[0] + timesteps = torch.roll(timesteps, shifts=-1, dims=0) + + capture = True if self.use_hpu_graphs and i < 2 else False + + # expand the latents if we are doing classifier free guidance + latent_model_input = ( + torch.cat([latents_batch] * 2) if do_classifier_free_guidance else latents_batch + ) + # latent_model_input = self.scheduler.scale_model_input(latent_model_input, timestep) + + # predict the noise residual + noise_pred = self.unet_hpu( + latent_model_input, + timestep, + text_embeddings_batch, + cross_attention_kwargs, + capture, + ) + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents_batch = self.scheduler.step( + noise_pred, latents_batch, **extra_step_kwargs, return_dict=False + )[0] + + if not self.use_hpu_graphs: + self.htcore.mark_step() + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + callback(i, timestep, latents_batch) + + if not output_type == "latent": + # 8. 
Post-processing + image = self.vae.decode(latents_batch / self.vae.config.scaling_factor, return_dict=False)[0] + else: + image = latents_batch + outputs["images"].append(image) + + self.scheduler.reset_timestep_dependent_params() + + if not self.use_hpu_graphs: + self.htcore.mark_step() + + speed_metrics_prefix = "generation" + speed_measures = speed_metrics( + split=speed_metrics_prefix, + start_time=t0, + num_samples=num_batches * batch_size if t1 == t0 else (num_batches - 2) * batch_size, + num_steps=num_batches, + start_time_after_warmup=t1, + ) + logger.info(f"Speed metrics: {speed_measures}") + + # Remove dummy generations if needed + if num_dummy_samples > 0: + outputs["images"][-1] = outputs["images"][-1][:-num_dummy_samples] + + # Process generated images + for i, image in enumerate(outputs["images"][:]): + if i == 0: + outputs["images"].clear() + + if output_type == "latent": + has_nsfw_concept = None + else: + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + rgb, depth = self.image_processor.postprocess( + image, output_type=output_type, do_denormalize=do_denormalize + ) + + if output_type == "pil": + outputs["images"] += image + else: + outputs["images"] += [*image] + + if has_nsfw_concept is not None: + outputs["has_nsfw_concept"] += has_nsfw_concept + else: + outputs["has_nsfw_concept"] = None + + if not return_dict: + return ((rgb, depth), has_nsfw_concept) + + return GaudiStableDiffusionLDM3DPipelineOutput( + rgb=rgb, + depth=depth, + nsfw_content_detected=has_nsfw_concept, + throughput=speed_measures[f"{speed_metrics_prefix}_samples_per_second"], + ) + + @torch.no_grad() + def unet_hpu(self, latent_model_input, timestep, encoder_hidden_states, cross_attention_kwargs, capture): + if self.use_hpu_graphs: + return self.capture_replay(latent_model_input, timestep, encoder_hidden_states, capture) + else: + return self.unet( + latent_model_input, + timestep, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + @torch.no_grad() + def capture_replay(self, latent_model_input, timestep, encoder_hidden_states, capture): + inputs = [latent_model_input, timestep, encoder_hidden_states, False] + h = self.ht.hpu.graphs.input_hash(inputs) + cached = self.cache.get(h) + + if capture: + # Capture the graph and cache it + with self.ht.hpu.stream(self.hpu_stream): + graph = self.ht.hpu.HPUGraph() + graph.capture_begin() + outputs = self.unet(inputs[0], inputs[1], inputs[2], inputs[3])[0] + graph.capture_end() + graph_inputs = inputs + graph_outputs = outputs + self.cache[h] = self.ht.hpu.graphs.CachedParams(graph_inputs, graph_outputs, graph) + return outputs + + # Replay the cached graph with updated inputs + self.ht.hpu.graphs.copy_to(cached.graph_inputs, inputs) + cached.graph.replay() + self.ht.core.hpu.default_stream().synchronize() + + return cached.graph_outputs diff --git a/tests/test_diffusers.py b/tests/test_diffusers.py index 2eab7623da..dc93c7def7 100644 --- a/tests/test_diffusers.py +++ b/tests/test_diffusers.py @@ -27,7 +27,12 @@ from transformers.testing_utils import slow from optimum.habana import GaudiConfig -from optimum.habana.diffusers import GaudiDDIMScheduler, GaudiDiffusionPipeline, GaudiStableDiffusionPipeline +from optimum.habana.diffusers import ( + GaudiDDIMScheduler, + GaudiDiffusionPipeline, + 
GaudiStableDiffusionLDM3DPipeline, + GaudiStableDiffusionPipeline, +) from optimum.habana.utils import set_seed @@ -609,3 +614,35 @@ def test_no_generation_regression(self): self.assertEqual(image.shape, (512, 512, 3)) self.assertLess(np.abs(expected_slice - image[-3:, -3:, -1].flatten()).max(), 5e-3) + + @slow + def test_no_generation_regression_ldm3d(self): + model_name = "Intel/ldm3d-4c" + # fp32 + with hmp.disable_casts(): + scheduler = GaudiDDIMScheduler.from_pretrained(model_name, subfolder="scheduler") + pipeline = GaudiStableDiffusionLDM3DPipeline.from_pretrained( + model_name, + scheduler=scheduler, + safety_checker=None, + use_habana=True, + use_hpu_graphs=True, + gaudi_config=GaudiConfig(use_habana_mixed_precision=False), + ) + set_seed(27) + outputs = pipeline( + prompt="An image of a squirrel in Picasso style", + output_type="np", + ) + + expected_slice_rgb = np.array([0.7083766, 1.0, 1.0, 0.70610344, 0.9867363, 1.0, 0.7214538, 1.0, 1.0]) + expected_slice_depth = np.array( + [0.919621, 0.92072034, 0.9184986, 0.91994286, 0.9242079, 0.93387043, 0.92345214, 0.93558526, 0.9223714] + ) + rgb = outputs.rgb[0] + depth = outputs.depth[0] + + self.assertEqual(rgb.shape, (512, 512, 3)) + self.assertEqual(depth.shape, (512, 512, 1)) + self.assertLess(np.abs(expected_slice_rgb - rgb[-3:, -3:, -1].flatten()).max(), 5e-3) + self.assertLess(np.abs(expected_slice_depth - depth[-3:, -3:, -1].flatten()).max(), 5e-3)