diff --git a/examples/stable_diffusion/README.md b/examples/stable_diffusion/README.md new file mode 100644 index 000000000..71eff9f96 --- /dev/null +++ b/examples/stable_diffusion/README.md @@ -0,0 +1,28 @@ +# Stable Diffusion XL + +This document elaborates how to build the [SDXL](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) model to runnable engines on single or multiple GPUs and perform a image generation task using these engines. + +The design of distributed parallel inference comes from the CVPR 2024 paper [Distrifusion](https://github.com/mit-han-lab/distrifuser). In order to reduce the difficulty of implementation, all communications in the example are synchronous. + +## Usage + +### 1. Build TensorRT Engine(s) + +```bash +# 1 gpu +python build_sdxl_unet.py --size 1024 + +# 2 gpus +mpirun -n 2 python build_sdxl_unet.py --size 1024 +``` + +### 2. Generate images using the engine(s) + + +```bash +# 1 gpu +python run_sdxl.py --size 1024 --prompt "flowers, rabbit" + +# 2 gpus +mpirun -n 2 python run_sdxl.py --size 1024 --prompt "flowers, rabbit" +``` diff --git a/examples/stable_diffusion/build_sdxl_unet.py b/examples/stable_diffusion/build_sdxl_unet.py new file mode 100755 index 000000000..65e462e8e --- /dev/null +++ b/examples/stable_diffusion/build_sdxl_unet.py @@ -0,0 +1,144 @@ +import argparse +import os + +import tensorrt as trt +import torch +from diffusers import DiffusionPipeline + +import tensorrt_llm +from tensorrt_llm.builder import Builder +from tensorrt_llm.mapping import Mapping +from tensorrt_llm.models.unet.pp.unet_pp import DistriUNetPP +from tensorrt_llm.models.unet.unet_2d_condition import UNet2DConditionModel +from tensorrt_llm.models.unet.weights import load_from_hf_unet +from tensorrt_llm.network import net_guard + +parser = argparse.ArgumentParser(description='build the UNet TensorRT engine.') +parser.add_argument('--size', type=int, default=1024, help='image size') +parser.add_argument('--output_dir', + type=str, + default=None, + help='output directory') + +args = parser.parse_args() + +size = args.size +sample_size = size // 8 + +world_size = tensorrt_llm.mpi_world_size() +rank = tensorrt_llm.mpi_rank() +output_dir = f'sdxl_s{size}_w{world_size}' if args.output_dir is None else args.output_dir +if rank == 0 and not os.path.exists(output_dir): + os.makedirs(output_dir) + +device_per_batch = world_size // 2 if world_size > 1 else 1 +batch_group = 2 if world_size > 1 else 1 + +# Use tp_size to indicate the size of patch parallelism +# Use pp_size to indicate the size of batch parallelism +mapping = Mapping(world_size=world_size, + rank=rank, + tp_size=device_per_batch, + pp_size=batch_group) + +torch.cuda.set_device(tensorrt_llm.mpi_rank()) + +tensorrt_llm.logger.set_level('verbose') +builder = Builder() +builder_config = builder.create_builder_config( + name='UNet2DConditionModel', + precision='float16', + timing_cache='model.cache', + profiling_verbosity='detailed', + tensor_parallel=world_size, + precision_constraints= + None, # do not use obey or the precision error will be too large +) + +pipeline = DiffusionPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16) +model = UNet2DConditionModel( + sample_size=sample_size, + in_channels=4, + out_channels=4, + center_input_sample=False, + flip_sin_to_cos=True, + freq_shift=0, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "UpBlock2D"), + block_out_channels=(320, 640, 1280), + layers_per_block=2, + downsample_padding=1, + mid_block_scale_factor=1.0, + act_fn="silu", + norm_num_groups=32, + norm_eps=1e-5, + cross_attention_dim=2048, + attention_head_dim=[5, 10, 20], + addition_embed_type="text_time", + addition_time_embed_dim=256, + projection_class_embeddings_input_dim=2816, + transformer_layers_per_block=[1, 2, 10], + use_linear_projection=True, + dtype=trt.float16, +) + +load_from_hf_unet(pipeline.unet, model) +model = DistriUNetPP(model, mapping) + +# Module -> Network +network = builder.create_network() +network.plugin_config.to_legacy_setting() +if mapping.world_size > 1: + network.plugin_config.set_nccl_plugin('float16') + +with net_guard(network): + # Prepare + network.set_named_parameters(model.named_parameters()) + + # Forward + sample = tensorrt_llm.Tensor( + name='sample', + dtype=trt.float16, + shape=[2, 4, sample_size, sample_size], + ) + timesteps = tensorrt_llm.Tensor( + name='timesteps', + dtype=trt.float16, + shape=[ + 1, + ], + ) + encoder_hidden_states = tensorrt_llm.Tensor( + name='encoder_hidden_states', + dtype=trt.float16, + shape=[2, 77, 2048], + ) + text_embeds = tensorrt_llm.Tensor( + name='text_embeds', + dtype=trt.float16, + shape=[2, 1280], + ) + time_ids = tensorrt_llm.Tensor( + name='time_ids', + dtype=trt.float16, + shape=[2, 6], + ) + + output = model(sample, timesteps, encoder_hidden_states, text_embeds, + time_ids) + + # Mark outputs + output_dtype = trt.float16 + output.mark_output('pred', output_dtype) + +# Network -> Engine +engine = builder.build_engine(network, builder_config) +assert engine is not None, 'Failed to build engine.' + +engine_name = f'sdxl_unet_s{size}_w{world_size}_r{rank}.engine' +engine_path = os.path.join(output_dir, engine_name) +with open(engine_path, 'wb') as f: + f.write(engine) +builder.save_config(builder_config, os.path.join(output_dir, 'config.json')) diff --git a/examples/stable_diffusion/pipeline_stable_diffusion_xl.py b/examples/stable_diffusion/pipeline_stable_diffusion_xl.py new file mode 100755 index 000000000..2c8230868 --- /dev/null +++ b/examples/stable_diffusion/pipeline_stable_diffusion_xl.py @@ -0,0 +1,1364 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import json +import os +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import tensorrt as trt +import torch +from diffusers.image_processor import PipelineImageInput, VaeImageProcessor +from diffusers.loaders import (FromSingleFileMixin, IPAdapterMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin) +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.models.attention_processor import (AttnProcessor2_0, + LoRAAttnProcessor2_0, + LoRAXFormersAttnProcessor, + XFormersAttnProcessor) +from diffusers.models.lora import adjust_lora_scale_text_encoder +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.stable_diffusion_xl.pipeline_output import \ + StableDiffusionXLPipelineOutput +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils import (USE_PEFT_BACKEND, deprecate, + is_invisible_watermark_available, + is_torch_xla_available, logging, + replace_example_docstring, scale_lora_layers, + unscale_lora_layers) +from diffusers.utils.torch_utils import randn_tensor +from transformers import (CLIPImageProcessor, CLIPTextModel, + CLIPTextModelWithProjection, CLIPTokenizer, + CLIPVisionModelWithProjection) + +import tensorrt_llm +from tensorrt_llm.runtime import Session, TensorInfo + +if is_invisible_watermark_available(): + from diffusers.pipelines.stable_diffusion_xl.watermark import \ + StableDiffusionXLWatermarker + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableDiffusionXLPipeline + + >>> pipe = StableDiffusionXLPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> prompt = "a photo of an astronaut riding a horse on mars" + >>> image = pipe(prompt).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), + keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + \ + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, + `timesteps` must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default + timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` + must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None: + accepts_timesteps = "timesteps" in set( + inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class StableDiffusionXLPipeline( + DiffusionPipeline, + FromSingleFileMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, + IPAdapterMixin, +): + r""" + Pipeline for text-to-image generation using Stable Diffusion XL. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + In addition the pipeline inherits the following loading methods: + - *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] + - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`] + + as well as the following saving methods: + - *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion XL uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([` CLIPTextModelWithProjection`]): + Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the + [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) + variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`CLIPTokenizer`): + Second Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`): + Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of + `stabilityai/stable-diffusion-xl-base-1-0`. + add_watermarker (`bool`, *optional*): + Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to + watermark output images. If not defined, it will default to True if the package is installed, otherwise no + watermarker will be used. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae" + _optional_components = [ + "tokenizer", + "tokenizer_2", + "text_encoder", + "text_encoder_2", + "image_encoder", + "feature_extractor", + ] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "add_text_embeds", + "add_time_ids", + "negative_pooled_prompt_embeds", + "negative_add_time_ids", + ] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + image_encoder: CLIPVisionModelWithProjection = None, + feature_extractor: CLIPImageProcessor = None, + force_zeros_for_empty_prompt: bool = True, + add_watermarker: Optional[bool] = None, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + scheduler=scheduler, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + ) + self.register_to_config( + force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.vae_scale_factor = 2**(len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor) + + self.default_sample_size = self.unet.config.sample_size + + add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available( + ) + + if add_watermarker: + self.watermark = StableDiffusionXLWatermarker() + else: + self.watermark = None + + self.execution_device = torch.device('cpu') + self.engine = {} + + def to( + self, + torch_device: Optional[Union[str, torch.device]] = None, + torch_dtype: Optional[torch.dtype] = None, + silence_dtype_warnings: bool = False, + ): + super().to(torch_device) + if isinstance(torch_device, str): + torch_device = torch.device(torch_device) + self.execution_device = torch_device + return self + + def prepare(self, path, size): + self.unet.cpu() + torch.cuda.empty_cache() + + def trt_dtype_to_torch(dtype): + if dtype == trt.float16: + return torch.float16 + elif dtype == trt.float32: + return torch.float32 + elif dtype == trt.int32: + return torch.int32 + else: + raise TypeError("%s is not supported" % dtype) + + config_path = os.path.join(path, 'config.json') + with open(config_path, 'r') as f: + config = json.load(f) + config['builder_config']['precision'] + world_size = config['builder_config']['tensor_parallel'] + + runtime_world_size = tensorrt_llm.mpi_world_size() + assert world_size == runtime_world_size, f'Engine world size ({world_size}) != Runtime world size ({runtime_world_size})' + runtime_rank = tensorrt_llm.mpi_rank() if world_size > 1 else 0 + torch.cuda.set_device(runtime_rank) + + serialize_file = f'sdxl_unet_s{size}_w{world_size}_r{runtime_rank}.engine' + serialize_path = os.path.join(path, serialize_file) + self.stream = torch.cuda.current_stream().cuda_stream + print(f'Loading engine from {serialize_path}') + with open(serialize_path, 'rb') as f: + engine_buffer = f.read() + print(f'Creating session from engine') + self.session = Session.from_serialized_engine(engine_buffer) + + output_info = self.session.infer_shapes([ + TensorInfo('sample', trt.DataType.HALF, + [2, 4, size // 8, size // 8]), + TensorInfo('timesteps', trt.DataType.HALF, [ + 1, + ]), + TensorInfo('encoder_hidden_states', trt.DataType.HALF, + [2, 77, 2048]), + TensorInfo('text_embeds', trt.DataType.HALF, [2, 1280]), + TensorInfo('time_ids', trt.DataType.HALF, [2, 6]), + ]) + self.outputs = { + t.name: torch.empty(tuple(t.shape), + dtype=trt_dtype_to_torch(t.dtype), + device='cuda') + for t in output_info + } + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def encode_prompt( + self, + prompt: str, + prompt_2: Optional[str] = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[str] = None, + negative_prompt_2: Optional[str] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + device = self.execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance( + self, StableDiffusionXLLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, + lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, + lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2 + ] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ([ + self.text_encoder, self.text_encoder_2 + ] if self.text_encoder is not None else [self.text_encoder_2]) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # textual inversion: procecss multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, + text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, + padding="longest", + return_tensors="pt").input_ids + + if untruncated_ids.shape[ + -1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids): + removed_text = tokenizer.batch_decode( + untruncated_ids[:, tokenizer.model_max_length - 1:-1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}") + + prompt_embeds = text_encoder(text_input_ids.to(device), + output_hidden_states=True) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + # "2" because SDXL always indexes from the penultimate layer. + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + + 2)] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like( + pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * \ + [negative_prompt] if isinstance( + negative_prompt, str) else negative_prompt + negative_prompt_2 = (batch_size * [negative_prompt_2] if isinstance( + negative_prompt_2, str) else negative_prompt_2) + + uncond_tokens: List[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}.") + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`.") + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip( + uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt( + negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[ + -2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, + dim=-1) + + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, + device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, + device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, + seq_len, -1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to( + dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to( + dtype=self.unet.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat( + 1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view( + batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat( + 1, num_images_per_prompt).view(bs_embed * num_images_per_prompt, -1) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat( + 1, num_images_per_prompt).view(bs_embed * num_images_per_prompt, + -1) + + if self.text_encoder is not None: + if isinstance( + self, + StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance( + self, + StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, + return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, + dim=0) + + uncond_image_embeds = torch.zeros_like(image_embeds) + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set( + inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set( + inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + prompt_2, + height, + width, + callback_steps, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError( + f"`height` and `width` have to be divisible by 8 but are {height} and {width}." + ) + + if callback_steps is not None and (not isinstance(callback_steps, int) + or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs + for k in callback_on_step_end_tensor_inputs): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two.") + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two.") + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) + and not isinstance(prompt, list)): + raise ValueError( + f"`prompt` has to be of type `str` or `list` but is {type(prompt)}" + ) + elif prompt_2 is not None and (not isinstance(prompt_2, str) + and not isinstance(prompt_2, list)): + raise ValueError( + f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}" + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}.") + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None): + shape = (batch_size, num_channels_latents, + height // self.vae_scale_factor, + width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, + generator=generator, + device=device, + dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def _get_add_time_ids(self, + original_size, + crops_coords_top_left, + target_size, + dtype, + text_encoder_projection_dim=None): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + + text_encoder_projection_dim) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae + def upcast_vae(self): + dtype = self.vae.dtype + self.vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + LoRAXFormersAttnProcessor, + LoRAAttnProcessor2_0, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(dtype) + self.vae.decoder.conv_in.to(dtype) + self.vae.decoder.mid_block.to(dtype) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu + def enable_freeu(self, s1: float, s2: float, b1: float, b2: float): + r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497. + + The suffixes after the scaling factors represent the stages where they are being applied. + + Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values + that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. + + Args: + s1 (`float`): + Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to + mitigate "oversmoothing effect" in the enhanced denoising process. + s2 (`float`): + Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to + mitigate "oversmoothing effect" in the enhanced denoising process. + b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. + b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. + """ + if not hasattr(self, "unet"): + raise ValueError("The pipeline must have `unet` for using FreeU.") + self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu + def disable_freeu(self): + """Disables the FreeU mechanism if enabled.""" + self.unet.disable_freeu() + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding(self, + w, + embedding_dim=512, + dtype=torch.float32): + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + timesteps (`torch.Tensor`): + generate embedding vectors at these timesteps + embedding_dim (`int`, *optional*, defaults to 512): + dimension of the embeddings to generate + dtype: + data type of the generated embeddings + + Returns: + `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)` + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def denoising_end(self): + return self._denoising_end + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + timesteps: List[int] = None, + denoising_end: Optional[float] = None, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, + List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + original_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Optional[Tuple[int, int]] = None, + negative_original_size: Optional[Tuple[int, int]] = None, + negative_crops_coords_top_left: Tuple[int, int] = (0, 0), + negative_target_size: Optional[Tuple[int, int]] = None, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + **kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise as determined by the discrete timesteps selected by the + scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a + "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + + # 0. Default height and width to unet + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + callback_steps, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._denoising_end = denoising_end + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self.execution_device + + # 3. Encode input prompt + lora_scale = (self.cross_attention_kwargs.get("scale", None) + if self.cross_attention_kwargs is not None else None) + + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=lora_scale, + clip_skip=self.clip_skip, + ) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps) + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + add_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + if negative_original_size is not None and negative_target_size is not None: + negative_add_time_ids = self._get_add_time_ids( + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + else: + negative_add_time_ids = add_time_ids + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], + dim=0) + add_text_embeds = torch.cat( + [negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], + dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat( + batch_size * num_images_per_prompt, 1) + + if ip_adapter_image is not None: + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_images_per_prompt) + if self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds]) + image_embeds = image_embeds.to(device) + + # 8. Denoising loop + num_warmup_steps = max( + len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + # 8.1 Apply denoising_end + if (self.denoising_end is not None + and isinstance(self.denoising_end, float) + and self.denoising_end > 0 and self.denoising_end < 1): + discrete_timestep_cutoff = int( + round(self.scheduler.config.num_train_timesteps - + (self.denoising_end * + self.scheduler.config.num_train_timesteps))) + num_inference_steps = len( + list( + filter(lambda ts: ts >= discrete_timestep_cutoff, + timesteps))) + timesteps = timesteps[:num_inference_steps] + + # 9. Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - + 1).repeat( + batch_size * + num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, + embedding_dim=self.unet.config.time_cond_proj_dim).to( + device=device, dtype=latents.dtype) + + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat( + [latents] * + 2) if self.do_classifier_free_guidance else latents + + latent_model_input = self.scheduler.scale_model_input( + latent_model_input, t) + + # predict the noise residual + added_cond_kwargs = { + "text_embeds": add_text_embeds, + "time_ids": add_time_ids + } + if ip_adapter_image is not None: + added_cond_kwargs["image_embeds"] = image_embeds + + t = t.to(latent_model_input.dtype) + feed_dict = { + 'sample': latent_model_input, + 'timesteps': t.unsqueeze(0), + 'encoder_hidden_states': prompt_embeds, + 'text_embeds': add_text_embeds, + 'time_ids': add_time_ids, + } + ok = self.session.run(feed_dict, self.outputs, self.stream) + assert ok, "Runtime execution failed" + noise_pred = self.outputs['pred'] + torch.cuda.synchronize() + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * \ + (noise_pred_text - noise_pred_uncond) + + if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg( + noise_pred, + noise_pred_text, + guidance_rescale=self.guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, + t, + latents, + **extra_step_kwargs, + return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end( + self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop( + "prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop( + "negative_prompt_embeds", negative_prompt_embeds) + add_text_embeds = callback_outputs.pop( + "add_text_embeds", add_text_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", + negative_pooled_prompt_embeds) + add_time_ids = callback_outputs.pop("add_time_ids", + add_time_ids) + negative_add_time_ids = callback_outputs.pop( + "negative_add_time_ids", negative_add_time_ids) + + # call the callback, if provided + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps and + (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to( + next(iter(self.vae.post_quant_conv.parameters())).dtype) + + image = self.vae.decode(latents / self.vae.config.scaling_factor, + return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + image = latents + + if not output_type == "latent": + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + + image = self.image_processor.postprocess(image, + output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, ) + + return StableDiffusionXLPipelineOutput(images=image) diff --git a/examples/stable_diffusion/run_sdxl.py b/examples/stable_diffusion/run_sdxl.py new file mode 100644 index 000000000..51ee1f222 --- /dev/null +++ b/examples/stable_diffusion/run_sdxl.py @@ -0,0 +1,58 @@ +import argparse +import time + +import numpy as np +import torch +from pipeline_stable_diffusion_xl import StableDiffusionXLPipeline + +import tensorrt_llm + +world_size = tensorrt_llm.mpi_world_size() +rank = tensorrt_llm.mpi_rank() + +parser = argparse.ArgumentParser( + description='run SDXL with the UNet TensorRT engine.') +parser.add_argument('--size', type=int, default=1024) +parser.add_argument('--seed', type=int, default=233) +parser.add_argument('--num_inference_steps', type=int, default=50) +parser.add_argument( + '--prompt', + type=str, + default= + "masterpiece, gouache painting, 1girl, distant view, lone boat, willow trees" +) +parser.add_argument('--model_dir', + type=str, + default=None, + help='model directory') + +args = parser.parse_args() +size = args.size +seed = args.seed +prompt = args.prompt +num_inference_steps = args.num_inference_steps +model_dir = f'sdxl_s{size}_w{world_size}' if args.model_dir is None else args.model_dir + +pipeline = StableDiffusionXLPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", + torch_dtype=torch.float16, + use_safetensors=True, +) +pipeline.set_progress_bar_config(disable=rank != 0) +pipeline.prepare(f'sdxl_s{size}_w{world_size}', size) +pipeline.to('cuda') + +li = [] +for i in range(10): + st = time.time() + image = pipeline(num_inference_steps=num_inference_steps, + prompt=prompt, + generator=torch.Generator(device="cuda").manual_seed(seed), + height=size, + width=size).images[0] + ed = time.time() + li.append(ed - st) + +if rank == 0: + print(f'Avg latency: {np.sum(li[-7:]) / 7.0}s') + image.save(f"output.png") diff --git a/tensorrt_llm/builder.py b/tensorrt_llm/builder.py index 65390bcc7..84654596d 100644 --- a/tensorrt_llm/builder.py +++ b/tensorrt_llm/builder.py @@ -134,6 +134,7 @@ def create_builder_config(self, profiling_verbosity: str = "layer_names_only", use_strip_plan: bool = False, weight_streaming: bool = False, + precision_constraints: Optional[str] = "obey", **kwargs) -> BuilderConfig: ''' @brief Create a builder config with given precisions and timing cache @param precision: one of allowed precisions, defined in Builder._ALLOWED_PRECISIONS @@ -178,16 +179,18 @@ def create_builder_config(self, fp8 = quant_mode.has_fp8_qdq() or quant_mode.has_fp8_kv_cache() if precision == 'float16' or precision == trt.DataType.HALF: config.set_flag(trt.BuilderFlag.FP16) - config.set_flag(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS) + if precision_constraints == 'obey': + config.set_flag(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS) elif precision == 'bfloat16' or precision == trt.DataType.BF16: config.set_flag(trt.BuilderFlag.BF16) - config.set_flag(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS) + if precision_constraints == 'obey': + config.set_flag(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS) if int8: config.set_flag(trt.BuilderFlag.INT8) - if fp8: config.set_flag(trt.BuilderFlag.FP8) - config.set_flag(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS) + if precision_constraints == 'obey': + config.set_flag(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS) config.set_preview_feature(trt.PreviewFeature.PROFILE_SHARING_0806, True) diff --git a/tensorrt_llm/functional.py b/tensorrt_llm/functional.py index b89ffff99..28c708c22 100644 --- a/tensorrt_llm/functional.py +++ b/tensorrt_llm/functional.py @@ -493,6 +493,12 @@ def split(self, split_size_or_sections, dim=0): ''' return split(self, split_size_or_sections, dim) + def select(self, dim, index): + ''' + See functional.select. + ''' + return select(self, dim, index) + def unbind(self, dim=0): ''' See functional.unbind. @@ -1145,7 +1151,8 @@ def slice(input: Tensor, starts: Union[Tensor, Sequence[int]], sizes: Union[Tensor, Sequence[int]], strides: Union[Tensor, Sequence[int]] = None, - mode: trt.SampleMode = None) -> Tensor: + mode: trt.SampleMode = None, + fill_value: Union[float, Tensor] = None) -> Tensor: ''' Add an operation to extract a slice from a tensor. @@ -1220,6 +1227,9 @@ def slice(input: Tensor, if isinstance(strides, Tensor) or strides is None: trt_strides = [1 for _ in range(input_ndim)] + if fill_value is not None and isinstance(fill_value, float): + fill_value = constant(fp32_array(fill_value)) + layer = default_trtnet().add_slice(input.trt_tensor, start=trt_starts, shape=trt_sizes, @@ -1236,6 +1246,9 @@ def slice(input: Tensor, if isinstance(strides, Tensor): layer.set_input(3, strides.trt_tensor) + if mode is trt.SampleMode.FILL and isinstance(fill_value, Tensor): + layer.set_input(4, fill_value.trt_tensor) + return _create_tensor(layer.get_output(0), layer) @@ -2971,7 +2984,7 @@ def log_softmax(input: Tensor, dim: int) -> Tensor: def reduce(input: Tensor, op: trt.ReduceOperation, - dim: int, + dim: Union[int, Tuple[int]], keepdim: bool = False) -> Tensor: ''' Add an reduction operation to do along a dimension. @@ -3010,7 +3023,9 @@ def reduce(input: Tensor, min = partial(reduce, op=trt.ReduceOperation.MIN) -def mean(input: Tensor, dim: int, keepdim: bool = False) -> Tensor: +def mean(input: Tensor, + dim: Union[int, Tuple[int]], + keepdim: bool = False) -> Tensor: ''' Add an operation to compute the mean along a dimension. @@ -3248,15 +3263,19 @@ def group_norm(input: Tensor, ] + [input.size(i) for i in range(2, ndim)]) x = input.view(new_shape) - reduce_dim = tuple(range(2, ndim + 1)) - ux = x.mean(dim=reduce_dim, keepdim=True) - numerator = x - ux - varx = numerator * numerator - varx = varx.mean(dim=reduce_dim, keepdim=True) - - denom = varx + eps - denom = denom.sqrt() - y = numerator / denom + # instance norm + w_shape = [1, num_groups] + [1 for i in range(ndim - 1)] + instance_weight = constant(np.ones(w_shape, dtype=trt_dtype_to_np(x.dtype))) + instance_bias = constant(np.zeros(w_shape, dtype=trt_dtype_to_np(x.dtype))) + axes_mask = 0 + for i in range(2, x.ndim()): + axes_mask |= 1 << i + layer = default_trtnet().add_normalization(x.trt_tensor, + instance_weight.trt_tensor, + instance_bias.trt_tensor, + axes_mask) + layer.epsilon = eps + y = _create_tensor(layer.get_output(0), layer) y = y.view(old_shape) new_shape = concat([num_channels] + [1 for _ in range(2, ndim)]) diff --git a/tensorrt_llm/models/unet/attention.py b/tensorrt_llm/models/unet/attention.py index 6c803f914..c56605106 100644 --- a/tensorrt_llm/models/unet/attention.py +++ b/tensorrt_llm/models/unet/attention.py @@ -105,25 +105,35 @@ def _transpose_for_scores(tensor, heads): def _attention(query, key, value, scale): - attention_scores = matmul(query, key.transpose(-1, -2)) - attention_scores = attention_scores * scale + # Multiply scale first to avoid overflow + # Do not use use_fp32_acc or it will be very slow + attention_scores = matmul(query * math.sqrt(scale), + key.transpose(-1, -2) * math.sqrt(scale), + use_fp32_acc=False) attention_probs = softmax(attention_scores, dim=-1) - hidden_states = matmul(attention_probs, value) + hidden_states = matmul(attention_probs, value, use_fp32_acc=False) hidden_states = hidden_states.permute([0, 2, 1, 3]) return hidden_states class SelfAttention(Module): - def __init__(self, query_dim: int, heads: int = 8, dim_head: int = 64): + def __init__(self, + query_dim: int, + heads: int = 8, + dim_head: int = 64, + dtype=None): super().__init__() self.inner_dim = dim_head * heads self.scale = dim_head**-0.5 self.heads = heads self._slice_size = None - self.to_qkv = Linear(query_dim, 3 * self.inner_dim, bias=False) - self.to_out = Linear(self.inner_dim, query_dim) + self.to_qkv = Linear(query_dim, + 3 * self.inner_dim, + bias=False, + dtype=dtype) + self.to_out = Linear(self.inner_dim, query_dim, dtype=dtype) def forward(self, hidden_states, mask=None): assert not hidden_states.is_dynamic() @@ -148,7 +158,8 @@ def __init__(self, query_dim: int, context_dim: Optional[int] = None, heads: int = 8, - dim_head: int = 64): + dim_head: int = 64, + dtype=None): super().__init__() self.inner_dim = dim_head * heads context_dim = context_dim if context_dim is not None else query_dim @@ -156,9 +167,12 @@ def __init__(self, self.heads = heads self._slice_size = None - self.to_q = Linear(query_dim, self.inner_dim, bias=False) - self.to_kv = Linear(context_dim, 2 * self.inner_dim, bias=False) - self.to_out = Linear(self.inner_dim, query_dim) + self.to_q = Linear(query_dim, self.inner_dim, bias=False, dtype=dtype) + self.to_kv = Linear(context_dim, + 2 * self.inner_dim, + bias=False, + dtype=dtype) + self.to_out = Linear(self.inner_dim, query_dim, dtype=dtype) def forward(self, hidden_states, context=None, mask=None): assert not hidden_states.is_dynamic() @@ -182,12 +196,16 @@ def forward(self, hidden_states, context=None, mask=None): class FeedForward(Module): - def __init__(self, dim: int, dim_out: Optional[int] = None, mult: int = 4): + def __init__(self, + dim: int, + dim_out: Optional[int] = None, + mult: int = 4, + dtype=None): super().__init__() inner_dim = int(dim * mult) dim_out = dim_out if dim_out is not None else dim - self.proj_in = Linear(dim, inner_dim * 2) - self.proj_out = Linear(inner_dim, dim_out) + self.proj_in = Linear(dim, inner_dim * 2, dtype=dtype) + self.proj_out = Linear(inner_dim, dim_out, dtype=dtype) def forward(self, hidden_states): x = self.proj_in(hidden_states) @@ -203,20 +221,23 @@ def __init__( n_heads: int, d_head: int, context_dim: Optional[int] = None, + dtype=None, ): super().__init__() self.attn1 = SelfAttention(query_dim=dim, heads=n_heads, - dim_head=d_head) # is a self-attention - self.ff = FeedForward(dim) + dim_head=d_head, + dtype=dtype) # is a self-attention + self.ff = FeedForward(dim, dtype=dtype) self.attn2 = CrossAttention( query_dim=dim, context_dim=context_dim, heads=n_heads, - dim_head=d_head) # is self-attn if context is none - self.norm1 = LayerNorm(dim) - self.norm2 = LayerNorm(dim) - self.norm3 = LayerNorm(dim) + dim_head=d_head, + dtype=dtype) # is self-attn if context is none + self.norm1 = LayerNorm(dim, dtype=dtype) + self.norm2 = LayerNorm(dim, dtype=dtype) + self.norm3 = LayerNorm(dim, dtype=dtype) def forward(self, hidden_states, context=None): hidden_states = self.attn1(self.norm1(hidden_states)) + hidden_states @@ -236,8 +257,11 @@ def __init__( num_layers: int = 1, norm_num_groups: int = 32, cross_attention_dim: Optional[int] = None, + use_linear_projection: bool = False, + dtype=None, ): super().__init__() + self.use_linear_projection = use_linear_projection self.num_attention_heads = num_attention_heads self.attention_head_dim = attention_head_dim inner_dim = num_attention_heads * attention_head_dim @@ -245,39 +269,64 @@ def __init__( self.norm = GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, - affine=True) - - self.proj_in = Conv2d(in_channels, - inner_dim, - kernel_size=(1, 1), - stride=(1, 1), - padding=(0, 0)) + affine=True, + dtype=dtype) + + if use_linear_projection: + self.proj_in = Linear(in_channels, inner_dim, dtype=dtype) + else: + self.proj_in = Conv2d(in_channels, + inner_dim, + kernel_size=(1, 1), + stride=(1, 1), + padding=(0, 0), + dtype=dtype) self.transformer_blocks = ModuleList([ BasicTransformerBlock(inner_dim, num_attention_heads, attention_head_dim, - context_dim=cross_attention_dim) - for d in range(num_layers) + context_dim=cross_attention_dim, + dtype=dtype) for d in range(num_layers) ]) - self.proj_out = Conv2d(inner_dim, - in_channels, - kernel_size=(1, 1), - stride=(1, 1), - padding=(0, 0)) + + if use_linear_projection: + self.proj_out = Linear(inner_dim, in_channels, dtype=dtype) + else: + self.proj_out = Conv2d(inner_dim, + in_channels, + kernel_size=(1, 1), + stride=(1, 1), + padding=(0, 0), + dtype=dtype) def forward(self, hidden_states, context=None): assert not hidden_states.is_dynamic() batch, _, height, weight = hidden_states.size() residual = hidden_states hidden_states = self.norm(hidden_states) - hidden_states = self.proj_in(hidden_states) - inner_dim = hidden_states.size()[1] - hidden_states = hidden_states.permute([0, 2, 3, 1]).view( - [batch, height * weight, inner_dim]) + + if not self.use_linear_projection: + hidden_states = self.proj_in(hidden_states) + inner_dim = hidden_states.size()[1] + hidden_states = hidden_states.permute([0, 2, 3, 1]).view( + [batch, height * weight, inner_dim]) + else: + inner_dim = hidden_states.size()[1] + hidden_states = hidden_states.permute([0, 2, 3, 1]).view( + [batch, height * weight, inner_dim]) + hidden_states = self.proj_in(hidden_states) + for block in self.transformer_blocks: hidden_states = block(hidden_states, context=context) - hidden_states = hidden_states.view([batch, height, weight, - inner_dim]).permute([0, 3, 1, 2]) - hidden_states = self.proj_out(hidden_states) + + if not self.use_linear_projection: + hidden_states = hidden_states.view( + [batch, height, weight, inner_dim]).permute([0, 3, 1, 2]) + hidden_states = self.proj_out(hidden_states) + else: + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.view( + [batch, height, weight, inner_dim]).permute([0, 3, 1, 2]) + return hidden_states + residual diff --git a/tensorrt_llm/models/unet/embeddings.py b/tensorrt_llm/models/unet/embeddings.py index dca425519..0a0a77b99 100644 --- a/tensorrt_llm/models/unet/embeddings.py +++ b/tensorrt_llm/models/unet/embeddings.py @@ -25,7 +25,8 @@ def get_timestep_embedding(timesteps, flip_sin_to_cos=False, downscale_freq_shift=1.0, scale=1.0, - max_period=10000): + max_period=10000, + dtype=None): """ This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. :param timesteps: a 1-D Tensor of N indices, one per batch element. @@ -69,14 +70,14 @@ def get_timestep_embedding(timesteps, class TimestepEmbedding(Module): - def __init__(self, channel, time_embed_dim, act_fn="silu"): + def __init__(self, channel, time_embed_dim, act_fn="silu", dtype=None): super().__init__() - self.linear_1 = Linear(channel, time_embed_dim) + self.linear_1 = Linear(channel, time_embed_dim, dtype=dtype) self.act = None if act_fn == "silu": self.act = silu - self.linear_2 = Linear(time_embed_dim, time_embed_dim) + self.linear_2 = Linear(time_embed_dim, time_embed_dim, dtype=dtype) def forward(self, sample): sample = self.linear_1(sample) @@ -90,11 +91,16 @@ def forward(self, sample): class Timesteps(Module): - def __init__(self, num_channels, flip_sin_to_cos, downscale_freq_shift): + def __init__(self, + num_channels, + flip_sin_to_cos, + downscale_freq_shift, + dtype=None): super().__init__() self.num_channels = num_channels self.flip_sin_to_cos = flip_sin_to_cos self.downscale_freq_shift = downscale_freq_shift + self.dtype = dtype def forward(self, timesteps): t_emb = get_timestep_embedding( @@ -102,5 +108,5 @@ def forward(self, timesteps): self.num_channels, flip_sin_to_cos=self.flip_sin_to_cos, downscale_freq_shift=self.downscale_freq_shift, - ) + dtype=self.dtype) return t_emb diff --git a/tensorrt_llm/models/unet/pp/__init__.py b/tensorrt_llm/models/unet/pp/__init__.py new file mode 100755 index 000000000..e69de29bb diff --git a/tensorrt_llm/models/unet/pp/attention.py b/tensorrt_llm/models/unet/pp/attention.py new file mode 100755 index 000000000..f64d666b0 --- /dev/null +++ b/tensorrt_llm/models/unet/pp/attention.py @@ -0,0 +1,107 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from ....functional import allgather, split +from ....mapping import Mapping +from ....module import Module +from ..attention import CrossAttention, SelfAttention, _attention + + +class DistriSelfAttentionPP(Module): + + def __init__(self, module: SelfAttention, mapping: Mapping = Mapping()): + super().__init__() + self.mapping = mapping + self.module = module + + def forward(self, hidden_states): + mapping = self.mapping + attn = self.module + + batch_size, sequence_length, _ = hidden_states.shape + + qkv = attn.to_qkv(hidden_states) + + query, kv = split(qkv, [attn.inner_dim, attn.inner_dim * 2], dim=2) + + if mapping.tp_size == 1: + full_kv = kv + else: + full_kv = allgather(kv, group=mapping.tp_group, gather_dim=1) + + key, value = split(full_kv, full_kv.shape[-1] // 2, dim=-1) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view([batch_size, -1, attn.heads, + head_dim]).transpose(1, 2) + key = key.view([batch_size, -1, attn.heads, head_dim]).transpose(1, 2) + value = value.view([batch_size, -1, attn.heads, + head_dim]).transpose(1, 2) + + hidden_states = _attention(query, key, value, attn.scale) + + hidden_states = hidden_states.view( + [batch_size, -1, attn.heads * head_dim]) + + # linear proj + hidden_states = attn.to_out(hidden_states) + + return hidden_states + + +class DistriCrossAttentionPP(Module): + + def __init__(self, module: CrossAttention, mapping: Mapping = Mapping()): + super().__init__() + self.mapping = mapping + self.module = module + self.kv_cache = None + + def forward(self, hidden_states, context): + attn = self.module + recompute_kv = self.kv_cache is None + + if context is None: + context = hidden_states + + batch_size, sequence_length, _ = context.shape + + query = attn.to_q(hidden_states) + + if recompute_kv or self.kv_cache is None: + kv = attn.to_kv(context) + self.kv_cache = kv + else: + kv = self.kv_cache + key, value = split(kv, kv.shape[-1] // 2, dim=-1) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view([batch_size, -1, attn.heads, + head_dim]).transpose(1, 2) + key = key.view([batch_size, -1, attn.heads, head_dim]).transpose(1, 2) + value = value.view([batch_size, -1, attn.heads, + head_dim]).transpose(1, 2) + + hidden_states = _attention(query, key, value, scale=attn.scale) + + hidden_states = hidden_states.view( + [batch_size, -1, attn.heads * head_dim]) + + # linear proj + hidden_states = attn.to_out(hidden_states) + return hidden_states diff --git a/tensorrt_llm/models/unet/pp/conv2d.py b/tensorrt_llm/models/unet/pp/conv2d.py new file mode 100755 index 000000000..921f1b794 --- /dev/null +++ b/tensorrt_llm/models/unet/pp/conv2d.py @@ -0,0 +1,119 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import tensorrt as trt + +from ....functional import allgather, concat, conv2d, slice, stack, unsqueeze +from ....layers import Conv2d +from ....mapping import Mapping +from ....module import Module + + +def pad(input, pad): + assert input.ndim() == 4 + n, c, h, w = input.shape + padded_input = slice(input, + starts=[0, 0, -pad[2], -pad[0]], + sizes=[n, c, pad[2] + h + pad[3], pad[0] + w + pad[1]], + mode=trt.SampleMode.FILL, + fill_value=0.0) + return padded_input + + +class DistriConv2dPP(Module): + + def __init__(self, + conv: Conv2d, + mapping: Mapping = Mapping(), + is_first_layer: bool = False): + super().__init__() + self.mapping = mapping + self.conv = conv + self.is_first_layer = is_first_layer + + def sliced_forward(self, x): + mapping = self.mapping + b, c, h, w = x.shape + assert h % mapping.tp_size == 0 + + stride = self.conv.stride[0] + padding = self.conv.padding[0] + + output_h = x.shape[2] // stride // mapping.tp_size + idx = mapping.tp_rank + h_begin = output_h * idx * stride - padding + h_end = output_h * (idx + 1) * stride + padding + final_padding = [padding, padding, 0, 0] + if h_begin < 0: + h_begin = 0 + final_padding[2] = padding + if h_end > h: + h_end = h + final_padding[3] = padding + sliced_input = slice(x, [0, 0, h_begin, 0], [b, c, h_end - h_begin, w]) + padded_input = pad(sliced_input, final_padding) + return conv2d(padded_input, + self.conv.weight.value, + None if self.conv.bias is None else self.conv.bias.value, + stride=self.conv.stride, + padding=(0, 0)) + + def forward(self, x, *args, **kwargs): + mapping = self.mapping + if self.is_first_layer: + full_x = x + output = self.sliced_forward(full_x) + else: + boundary_size = self.conv.padding[0] + + def create_padded_x(x, boundaries): + if mapping.tp_rank == 0: + b = boundaries.select(0, mapping.tp_rank + 1).select(0, 0) + concat_x = concat([x, b], dim=2) + padded_x = pad(concat_x, [0, 0, boundary_size, 0]) + elif mapping.tp_rank == mapping.tp_size - 1: + b = boundaries.select(0, mapping.tp_rank - 1).select(0, 1) + concat_x = concat([b, x], dim=2) + padded_x = pad(concat_x, [0, 0, 0, boundary_size]) + else: + b0 = boundaries.select(0, mapping.tp_rank - 1).select(0, 1) + b1 = boundaries.select(0, mapping.tp_rank + 1).select(0, 0) + padded_x = concat( + [ + b0, + x, + b1, + ], + dim=2, + ) + return padded_x + + n, c, h, w = x.shape + b0 = slice(x, [0, 0, 0, 0], [n, c, boundary_size, w]) + b1 = slice(x, [0, 0, h - boundary_size, 0], + [n, c, boundary_size, w]) + boundary = stack([b0, b1], dim=0) + + boundaries = allgather(unsqueeze(boundary, 0), + group=mapping.tp_group) + padded_x = create_padded_x(x, boundaries) + output = conv2d( + padded_x, + self.conv.weight.value, + self.conv.bias.value, + stride=self.conv.stride, + padding=(0, self.conv.padding[1]), + ) + + return output diff --git a/tensorrt_llm/models/unet/pp/groupnorm.py b/tensorrt_llm/models/unet/pp/groupnorm.py new file mode 100755 index 000000000..6cbf4d2d0 --- /dev/null +++ b/tensorrt_llm/models/unet/pp/groupnorm.py @@ -0,0 +1,57 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from ....functional import allreduce, pow, select, stack +from ....layers import GroupNorm +from ....mapping import Mapping +from ....module import Module + + +class DistriGroupNorm(Module): + + def __init__(self, + module: GroupNorm, + mapping: Mapping = Mapping(), + is_first_layer: bool = False): + super().__init__() + self.mapping = mapping + self.module = module + + def forward(self, x, *args, **kwargs): + mapping = self.mapping + module = self.module + n, c, h, w = x.shape + num_groups = module.num_groups + group_size = c // num_groups + + x = x.view([n, num_groups, group_size, h, w]) + x_mean = x.mean(dim=4, keepdim=True).mean(dim=(3, 2), keepdim=True) + x2_mean = pow(x, 2.0).mean(dim=4, keepdim=True).mean(dim=(3, 2), + keepdim=True) + mean = stack([x_mean, x2_mean], dim=0) + mean = allreduce(mean, mapping.tp_group) + mean = mean / (mapping.tp_size * 1.0) + x_mean = select(mean, 0, 0) + x2_mean = select(mean, 0, 1) + var = x2_mean - pow(x_mean, 2.0) + num_elements = group_size * h * w + var = var * (num_elements / (num_elements - 1)) + std = (var + module.eps).sqrt() + output = (x - x_mean) / std + output = output.view([n, c, h, w]) + if module.affine: + output = output * module.weight.value.view([1, -1, 1, 1]) + output = output + module.bias.value.view([1, -1, 1, 1]) + + return output diff --git a/tensorrt_llm/models/unet/pp/unet_pp.py b/tensorrt_llm/models/unet/pp/unet_pp.py new file mode 100755 index 000000000..9b1a7d64a --- /dev/null +++ b/tensorrt_llm/models/unet/pp/unet_pp.py @@ -0,0 +1,131 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from ....functional import allgather, concat, slice, stack +from ....layers import Conv2d, GroupNorm +from ....mapping import Mapping +from ....module import Module +from ..attention import CrossAttention, SelfAttention +from ..unet_2d_condition import UNet2DConditionModel +from .attention import DistriCrossAttentionPP, DistriSelfAttentionPP +from .conv2d import DistriConv2dPP +from .groupnorm import DistriGroupNorm + + +class DistriUNetPP(Module): + + def __init__(self, + model: UNet2DConditionModel, + mapping: Mapping = Mapping()): + super().__init__() + self.mapping = mapping + self.model = model + if mapping.tp_size > 1: + for name, module in model.named_modules(): + if isinstance(module, DistriConv2dPP) or isinstance(module, DistriSelfAttentionPP) \ + or isinstance(module, DistriCrossAttentionPP) or isinstance(module, DistriGroupNorm): + continue + for subname, submodule in module.named_children(): + if isinstance(submodule, Conv2d): + kernel_size = submodule.kernel_size + if kernel_size == (1, 1) or kernel_size == 1: + continue + wrapped_submodule = DistriConv2dPP( + submodule, + mapping, + is_first_layer=subname == "conv_in") + setattr(module, subname, wrapped_submodule) + elif isinstance(submodule, SelfAttention): + wrapped_submodule = DistriSelfAttentionPP( + submodule, mapping) + setattr(module, subname, wrapped_submodule) + elif isinstance(submodule, CrossAttention): + wrapped_submodule = DistriCrossAttentionPP( + submodule, mapping) + setattr(module, subname, wrapped_submodule) + elif isinstance(submodule, GroupNorm): + wrapped_submodule = DistriGroupNorm(submodule, mapping) + setattr(module, subname, wrapped_submodule) + + def forward(self, + sample, + timesteps, + encoder_hidden_states, + text_embeds=None, + time_ids=None): + mapping = self.mapping + b, c, h, w = sample.shape + + if mapping.world_size == 1: + output = self.model( + sample, + timesteps, + encoder_hidden_states, + text_embeds=text_embeds, + time_ids=time_ids, + ) + elif mapping.pp_size > 1: + assert b == 2 and mapping.pp_size == 2 + batch_idx = mapping.pp_rank + # sample[batch_idx : batch_idx + 1] + sample = slice(sample, [batch_idx, 0, 0, 0], [1, c, h, w]) + e_shape = encoder_hidden_states.shape + encoder_hidden_states = slice( + encoder_hidden_states, [batch_idx, 0, 0], + [1, e_shape[1], e_shape[2] + ]) # encoder_hidden_states[batch_idx : batch_idx + 1] + if text_embeds: + t_shape = text_embeds.shape + # text_embeds[batch_idx : batch_idx + 1] + text_embeds = slice(text_embeds, [batch_idx, 0], + [1, t_shape[1]]) + if time_ids: + t_shape = time_ids.shape + # time_ids[batch_idx : batch_idx + 1] + time_ids = slice(time_ids, [batch_idx, 0], [1, t_shape[1]]) + output = self.model( + sample, + timesteps, + encoder_hidden_states, + text_embeds=text_embeds, + time_ids=time_ids, + ) + output = allgather( + output, + [i for i in range(mapping.world_size)], + ) + patch_list = [] + for i in range(mapping.tp_size): + patch_list.append(output.select(dim=0, index=i)) + b1 = concat(patch_list, dim=1) + patch_list = [] + for i in range(mapping.tp_size, mapping.world_size): + patch_list.append(output.select(dim=0, index=i)) + b2 = concat(patch_list, dim=1) + output = stack([b1, b2], dim=0) + else: + output = self.model( + sample, + timesteps, + encoder_hidden_states, + text_embeds=text_embeds, + time_ids=time_ids, + ) + output = allgather(output, mapping.tp_group, 2) + + return output + + @property + def add_embedding(self): + return self.model.add_embedding diff --git a/tensorrt_llm/models/unet/resnet.py b/tensorrt_llm/models/unet/resnet.py index 463a5835f..38a936100 100644 --- a/tensorrt_llm/models/unet/resnet.py +++ b/tensorrt_llm/models/unet/resnet.py @@ -14,7 +14,7 @@ # limitations under the License. from functools import partial -from ...functional import avg_pool2d, interpolate, silu, view +from ...functional import avg_pool2d, interpolate, silu from ...layers import (AvgPool2d, Conv2d, ConvTranspose2d, GroupNorm, Linear, Mish) from ...module import Module @@ -26,7 +26,8 @@ def __init__(self, channels: int, use_conv=False, use_conv_transpose=False, - out_channels=None) -> None: + out_channels=None, + dtype=None) -> None: super().__init__() self.channels = channels @@ -35,11 +36,17 @@ def __init__(self, self.use_conv_transpose = use_conv_transpose self.use_conv = use_conv if self.use_conv_transpose: - self.conv = ConvTranspose2d(channels, self.out_channels, 4, 2, 1) + self.conv = ConvTranspose2d(channels, + self.out_channels, + 4, + 2, + 1, + dtype=dtype) elif use_conv: self.conv = Conv2d(self.channels, self.out_channels, (3, 3), - padding=(1, 1)) + padding=(1, 1), + dtype=dtype) else: self.conv = None @@ -72,7 +79,8 @@ def __init__(self, channels, use_conv=False, out_channels=None, - padding=1) -> None: + padding=1, + dtype=None) -> None: super().__init__() self.channels = channels self.out_channels = out_channels or channels @@ -84,7 +92,8 @@ def __init__(self, self.conv = Conv2d(self.channels, self.out_channels, (3, 3), stride=stride, - padding=(padding, padding)) + padding=(padding, padding), + dtype=dtype) else: assert self.channels == self.out_channels self.conv = AvgPool2d(kernel_size=stride, stride=stride) @@ -121,6 +130,7 @@ def __init__( use_in_shortcut=None, up=False, down=False, + dtype=None, ): super().__init__() self.pre_norm = pre_norm @@ -140,27 +150,33 @@ def __init__( self.norm1 = GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, - affine=True) + affine=True, + dtype=dtype) self.conv1 = Conv2d(in_channels, out_channels, kernel_size=(3, 3), stride=(1, 1), - padding=(1, 1)) + padding=(1, 1), + dtype=dtype) if temb_channels is not None: - self.time_emb_proj = Linear(temb_channels, out_channels) + self.time_emb_proj = Linear(temb_channels, + out_channels, + dtype=dtype) else: self.time_emb_proj = None self.norm2 = GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, - affine=True) + affine=True, + dtype=dtype) self.conv2 = Conv2d(out_channels, out_channels, kernel_size=(3, 3), stride=(1, 1), - padding=(1, 1)) + padding=(1, 1), + dtype=dtype) if non_linearity == "swish": self.nonlinearity = lambda x: silu(x) @@ -177,7 +193,9 @@ def __init__( scale_factor=2.0, mode="nearest") else: - self.upsample = Upsample2D(in_channels, use_conv=False) + self.upsample = Upsample2D(in_channels, + use_conv=False, + dtype=dtype) elif self.down: if kernel == "sde_vp": @@ -186,7 +204,8 @@ def __init__( self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, - name="op") + name="op", + dtype=dtype) self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut @@ -195,7 +214,8 @@ def __init__( out_channels, kernel_size=(1, 1), stride=(1, 1), - padding=(0, 0)) + padding=(0, 0), + dtype=dtype) else: self.conv_shortcut = None @@ -210,12 +230,17 @@ def forward(self, input_tensor, temb): elif self.downsample is not None: input_tensor = self.downsample(input_tensor) hidden_states = self.downsample(hidden_states) + hidden_states = self.conv1(hidden_states) - if temb is not None: + if self.time_emb_proj is not None: temb = self.time_emb_proj(self.nonlinearity(temb)) new_shape = list(temb.size()) new_shape.extend([1, 1]) #[:,:,None,None] ->view - hidden_states = hidden_states + view(temb, new_shape) + temb = temb.view(new_shape) + + assert self.time_embedding_norm == "default" + if temb is not None: + hidden_states = hidden_states + temb hidden_states = self.norm2(hidden_states) hidden_states = self.nonlinearity(hidden_states) diff --git a/tensorrt_llm/models/unet/unet_2d_blocks.py b/tensorrt_llm/models/unet/unet_2d_blocks.py index b45dfa1c8..024792212 100644 --- a/tensorrt_llm/models/unet/unet_2d_blocks.py +++ b/tensorrt_llm/models/unet/unet_2d_blocks.py @@ -12,6 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Tuple, Union + from ...functional import concat from ...module import Module, ModuleList from .attention import AttentionBlock, Transformer2DModel @@ -28,8 +30,11 @@ def get_down_block( resnet_eps, resnet_act_fn, attn_num_head_channels, + transformer_layers_per_block=1, cross_attention_dim=None, downsample_padding=None, + use_linear_projection=False, + dtype=None, ): down_block_type = down_block_type[7:] if down_block_type.startswith( "UNetRes") else down_block_type @@ -43,6 +48,7 @@ def get_down_block( resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, downsample_padding=downsample_padding, + dtype=dtype, ) elif down_block_type == "CrossAttnDownBlock2D": if cross_attention_dim is None: @@ -50,6 +56,7 @@ def get_down_block( "cross_attention_dim must be specified for CrossAttnUpBlock2D") return CrossAttnDownBlock2D( num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, @@ -59,6 +66,8 @@ def get_down_block( downsample_padding=downsample_padding, cross_attention_dim=cross_attention_dim, attn_num_head_channels=attn_num_head_channels, + use_linear_projection=use_linear_projection, + dtype=dtype, ) raise ValueError(f"{down_block_type} does not exist.") @@ -75,7 +84,10 @@ def get_up_block( resnet_eps, resnet_act_fn, attn_num_head_channels, + transformer_layers_per_block=1, cross_attention_dim=None, + use_linear_projection=False, + dtype=None, ): up_block_type = up_block_type[7:] if up_block_type.startswith( "UNetRes") else up_block_type @@ -89,6 +101,7 @@ def get_up_block( add_upsample=add_upsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, + dtype=dtype, ) elif up_block_type == "CrossAttnUpBlock2D": if cross_attention_dim is None: @@ -96,6 +109,7 @@ def get_up_block( "cross_attention_dim must be specified for CrossAttnUpBlock2D") return CrossAttnUpBlock2D( num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, in_channels=in_channels, out_channels=out_channels, prev_output_channel=prev_output_channel, @@ -105,6 +119,8 @@ def get_up_block( resnet_act_fn=resnet_act_fn, cross_attention_dim=cross_attention_dim, attn_num_head_channels=attn_num_head_channels, + use_linear_projection=use_linear_projection, + dtype=dtype, ) raise ValueError(f"{up_block_type} does not exist.") @@ -127,6 +143,7 @@ def __init__( resnet_pre_norm: bool = True, output_scale_factor=1.0, add_upsample=True, + dtype=None, ): super().__init__() resnets = [] @@ -148,6 +165,7 @@ def __init__( non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, + dtype=dtype, )) self.resnets = ModuleList(resnets) @@ -156,7 +174,8 @@ def __init__( self.upsamplers = ModuleList([ Upsample2D(out_channels, use_conv=True, - out_channels=out_channels) + out_channels=out_channels, + dtype=dtype) ]) else: self.upsamplers = None @@ -198,6 +217,7 @@ def __init__( output_scale_factor=1.0, add_downsample=True, downsample_padding=1, + dtype=None, ): super().__init__() resnets = [] @@ -216,6 +236,7 @@ def __init__( non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, + dtype=dtype, )) self.resnets = ModuleList(resnets) @@ -225,7 +246,8 @@ def __init__( Downsample2D(out_channels, use_conv=True, out_channels=out_channels, - padding=downsample_padding) + padding=downsample_padding, + dtype=dtype) ]) else: self.downsamplers = None @@ -260,6 +282,7 @@ def __init__( attn_num_head_channels=1, attention_type="default", output_scale_factor=1.0, + dtype=None, **kwargs, ): super().__init__() @@ -281,6 +304,7 @@ def __init__( non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, + dtype=dtype, ) ] attentions = [] @@ -293,6 +317,7 @@ def __init__( rescale_output_factor=output_scale_factor, eps=resnet_eps, num_groups=resnet_groups, + dtype=dtype, )) resnets.append( ResnetBlock2D( @@ -306,6 +331,7 @@ def __init__( non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, + dtype=dtype, )) self.attentions = ModuleList(attentions) @@ -333,6 +359,7 @@ def __init__( temb_channels: int, dropout: float = 0.0, num_layers: int = 1, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", @@ -343,6 +370,8 @@ def __init__( attention_type="default", output_scale_factor=1.0, add_upsample=True, + use_linear_projection: bool = False, + dtype=None, ): super().__init__() resnets = [] @@ -351,6 +380,11 @@ def __init__( self.attention_type = attention_type self.attn_num_head_channels = attn_num_head_channels + # support for variable transformer layers per block + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block + ] * num_layers + for i in range(num_layers): res_skip_channels = in_channels if (i == num_layers - 1) else out_channels @@ -368,15 +402,19 @@ def __init__( non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, + dtype=dtype, )) attentions.append( Transformer2DModel(in_channels=out_channels, + num_layers=transformer_layers_per_block[i], num_attention_heads=attn_num_head_channels, attention_head_dim=out_channels // attn_num_head_channels, norm_num_groups=resnet_groups, - cross_attention_dim=cross_attention_dim)) + use_linear_projection=use_linear_projection, + cross_attention_dim=cross_attention_dim, + dtype=dtype)) self.attentions = ModuleList(attentions) self.resnets = ModuleList(resnets) @@ -384,7 +422,8 @@ def __init__( self.upsamplers = ModuleList([ Upsample2D(out_channels, use_conv=True, - out_channels=out_channels) + out_channels=out_channels, + dtype=dtype) ]) else: self.upsamplers = None @@ -423,6 +462,7 @@ def __init__( temb_channels: int, dropout: float = 0.0, num_layers: int = 1, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", @@ -434,6 +474,8 @@ def __init__( output_scale_factor=1.0, downsample_padding=1, add_downsample=True, + use_linear_projection: bool = False, + dtype=None, ): super().__init__() resnets = [] @@ -442,6 +484,11 @@ def __init__( self.attention_type = attention_type self.attn_num_head_channels = attn_num_head_channels + # support for variable transformer layers per block + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block + ] * num_layers + for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels resnets.append( @@ -456,14 +503,18 @@ def __init__( non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, + dtype=dtype, )) attentions.append( Transformer2DModel(in_channels=out_channels, + num_layers=transformer_layers_per_block[i], num_attention_heads=attn_num_head_channels, attention_head_dim=out_channels // attn_num_head_channels, norm_num_groups=resnet_groups, - cross_attention_dim=cross_attention_dim)) + use_linear_projection=use_linear_projection, + cross_attention_dim=cross_attention_dim, + dtype=dtype)) self.attentions = ModuleList(attentions) self.resnets = ModuleList(resnets) @@ -472,7 +523,8 @@ def __init__( Downsample2D(out_channels, use_conv=True, out_channels=out_channels, - padding=downsample_padding) + padding=downsample_padding, + dtype=dtype) ]) else: self.downsamplers = None @@ -505,6 +557,7 @@ def __init__( temb_channels: int, dropout: float = 0.0, num_layers: int = 1, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", @@ -514,6 +567,8 @@ def __init__( attention_type="default", output_scale_factor=1.0, cross_attention_dim=1280, + use_linear_projection: bool = False, + dtype=None, **kwargs, ): super().__init__() @@ -523,44 +578,50 @@ def __init__( resnet_groups = resnet_groups if resnet_groups is not None else min( in_channels // 4, 32) + # support for variable transformer layers per block + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block + ] * num_layers + # there is always at least one resnet resnets = [ - ResnetBlock2D( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) + ResnetBlock2D(in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + dtype=dtype) ] attentions = [] - for _ in range(num_layers): + for i in range(num_layers): attentions.append( Transformer2DModel(in_channels=in_channels, + num_layers=transformer_layers_per_block[i], num_attention_heads=attn_num_head_channels, attention_head_dim=in_channels // attn_num_head_channels, norm_num_groups=resnet_groups, - cross_attention_dim=cross_attention_dim)) + use_linear_projection=use_linear_projection, + cross_attention_dim=cross_attention_dim, + dtype=dtype)) resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - )) + ResnetBlock2D(in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + dtype=dtype)) self.attentions = ModuleList(attentions) self.resnets = ModuleList(resnets) diff --git a/tensorrt_llm/models/unet/unet_2d_condition.py b/tensorrt_llm/models/unet/unet_2d_condition.py index e76b462e9..96531df8d 100644 --- a/tensorrt_llm/models/unet/unet_2d_condition.py +++ b/tensorrt_llm/models/unet/unet_2d_condition.py @@ -12,7 +12,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from ...functional import silu +from typing import Optional + +from ...functional import cast, concat, silu from ...layers import Conv2d, GroupNorm from ...module import Module, ModuleList from .embeddings import TimestepEmbedding, Timesteps @@ -42,28 +44,57 @@ def __init__( norm_num_groups=32, norm_eps=1e-5, cross_attention_dim=1280, + transformer_layers_per_block=1, attention_head_dim=8, + use_linear_projection=False, + addition_embed_type: Optional[str] = None, + addition_time_embed_dim: Optional[int] = None, + projection_class_embeddings_input_dim: Optional[int] = None, + dtype=None, ): super().__init__() self.sample_size = sample_size + self.addition_embed_type = addition_embed_type time_embed_dim = block_out_channels[0] * 4 # input self.conv_in = Conv2d(in_channels, block_out_channels[0], kernel_size=(3, 3), - padding=(1, 1)) + padding=(1, 1), + dtype=dtype) # time - self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, - freq_shift) + self.time_proj = Timesteps(block_out_channels[0], + flip_sin_to_cos, + freq_shift, + dtype=dtype) timestep_input_dim = block_out_channels[0] self.time_embedding = TimestepEmbedding(timestep_input_dim, - time_embed_dim) + time_embed_dim, + dtype=dtype) + + if addition_embed_type == "text_time": + self.add_time_proj = Timesteps(addition_time_embed_dim, + flip_sin_to_cos, + freq_shift, + dtype=dtype) + self.add_embedding = TimestepEmbedding( + projection_class_embeddings_input_dim, + time_embed_dim, + dtype=dtype) + down_blocks = [] up_blocks = [] + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim, ) * len(down_block_types) + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block + ] * len(down_block_types) + # down output_channel = block_out_channels[0] for i, down_block_type in enumerate(down_block_types): @@ -74,6 +105,7 @@ def __init__( down_block = get_down_block( down_block_type, num_layers=layers_per_block, + transformer_layers_per_block=transformer_layers_per_block[i], in_channels=input_channel, out_channels=output_channel, temb_channels=time_embed_dim, @@ -81,9 +113,10 @@ def __init__( resnet_eps=norm_eps, resnet_act_fn=act_fn, cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attention_head_dim, + attn_num_head_channels=attention_head_dim[i], downsample_padding=downsample_padding, - ) + use_linear_projection=use_linear_projection, + dtype=dtype) down_blocks.append(down_block) self.down_blocks = ModuleList(down_blocks) # mid @@ -93,13 +126,19 @@ def __init__( resnet_eps=norm_eps, resnet_act_fn=act_fn, output_scale_factor=mid_block_scale_factor, + transformer_layers_per_block=transformer_layers_per_block[-1], resnet_time_scale_shift="default", cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attention_head_dim, + attn_num_head_channels=attention_head_dim[-1], resnet_groups=norm_num_groups, + use_linear_projection=use_linear_projection, + dtype=dtype, ) # up reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_attention_head_dim = list(reversed(attention_head_dim)) + reversed_transformer_layers_per_block = list( + reversed(transformer_layers_per_block)) output_channel = reversed_block_out_channels[0] for i, up_block_type in enumerate(up_block_types): prev_output_channel = output_channel @@ -113,6 +152,8 @@ def __init__( up_block = get_up_block( up_block_type, num_layers=layers_per_block + 1, + transformer_layers_per_block= + reversed_transformer_layers_per_block[i], in_channels=input_channel, out_channels=output_channel, prev_output_channel=prev_output_channel, @@ -121,7 +162,9 @@ def __init__( resnet_eps=norm_eps, resnet_act_fn=act_fn, cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attention_head_dim, + attn_num_head_channels=reversed_attention_head_dim[i], + use_linear_projection=use_linear_projection, + dtype=dtype, ) up_blocks.append(up_block) prev_output_channel = output_channel @@ -129,16 +172,35 @@ def __init__( # out self.conv_norm_out = GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, - eps=norm_eps) + eps=norm_eps, + dtype=dtype) self.conv_act = silu self.conv_out = Conv2d(block_out_channels[0], out_channels, (3, 3), - padding=(1, 1)) - - def forward(self, sample, timesteps, encoder_hidden_states): + padding=(1, 1), + dtype=dtype) + + def forward(self, + sample, + timesteps, + encoder_hidden_states, + text_embeds=None, + time_ids=None): + # time t_emb = self.time_proj(timesteps) emb = self.time_embedding(t_emb) + aug_emb = None + if self.addition_embed_type == "text_time": + assert text_embeds is not None and time_ids is not None + time_embeds = self.add_time_proj(time_ids.view([-1])) + time_embeds = time_embeds.view([text_embeds.shape[0], -1]) + add_embeds = concat([text_embeds, time_embeds], dim=1) + add_embeds = cast(add_embeds, emb.dtype) + aug_emb = self.add_embedding(add_embeds) + + emb = emb + aug_emb if aug_emb is not None else emb + sample = self.conv_in(sample) down_block_res_samples = (sample, ) diff --git a/tensorrt_llm/models/unet/weights.py b/tensorrt_llm/models/unet/weights.py index 76153d8f0..9a8e4c7c6 100644 --- a/tensorrt_llm/models/unet/weights.py +++ b/tensorrt_llm/models/unet/weights.py @@ -31,8 +31,9 @@ def update_crossattn_downblock_2d_weight(src, dst): for index, value in enumerate(src.attentions): update_transformer_2d_model_weight(dst.attentions[index], value) - for index, value in enumerate(src.downsamplers): - dst.downsamplers[index].conv.update_parameters(value.conv) + if src.downsamplers: + for index, value in enumerate(src.downsamplers): + dst.downsamplers[index].conv.update_parameters(value.conv) def update_transformer_2d_model_weight(gm, m): @@ -160,17 +161,25 @@ def update_unet_2d_condition_model_weights(src, dst): dst.conv_in.update_parameters(src.conv_in) dst.time_embedding.update_parameters(src.time_embedding) + if src.config.addition_embed_type: + dst.add_embedding.update_parameters(src.add_embedding) - for index, value in enumerate(src.down_blocks[:-1]): - update_crossattn_downblock_2d_weight(value, dst.down_blocks[index]) - - update_downblock_2d_weight(src.down_blocks[-1], dst.down_blocks[-1]) + for index, type in enumerate(src.config.down_block_types): + if type == 'CrossAttnDownBlock2D': + update_crossattn_downblock_2d_weight(src.down_blocks[index], + dst.down_blocks[index]) + elif type == 'DownBlock2D': + update_downblock_2d_weight(src.down_blocks[index], + dst.down_blocks[index]) update_unet_mid_block_2d_weight(src.mid_block, dst.mid_block) - update_upblock_2d_weight(src.up_blocks[0], dst.up_blocks[0]) - for index, value in enumerate(src.up_blocks[1:]): - update_crossattn_upblock_2d_weight(value, dst.up_blocks[index + 1]) + for index, type in enumerate(src.config.up_block_types): + if type == 'CrossAttnUpBlock2D': + update_crossattn_upblock_2d_weight(src.up_blocks[index], + dst.up_blocks[index]) + elif type == 'UpBlock2D': + update_upblock_2d_weight(src.up_blocks[index], dst.up_blocks[index]) dst.conv_norm_out.update_parameters(src.conv_norm_out)