diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index 36963a986d..e77f649f69 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -1112,7 +1112,7 @@ def ordered_inputs(self, model) -> Dict[str, Dict[int, str]]: class VaeEncoderOnnxConfig(VisionOnnxConfig): - ATOL_FOR_VALIDATION = 1e-2 + ATOL_FOR_VALIDATION = 1e-4 # The ONNX export of a CLIPText architecture, an other Stable Diffusion component, needs the Trilu # operator support, available since opset 14 DEFAULT_ONNX_OPSET = 14 @@ -1132,12 +1132,12 @@ def inputs(self) -> Dict[str, Dict[int, str]]: @property def outputs(self) -> Dict[str, Dict[int, str]]: return { - "latent_sample": {0: "batch_size", 2: "height_latent", 3: "width_latent"}, + "latent_parameters": {0: "batch_size", 2: "height_latent", 3: "width_latent"}, } class VaeDecoderOnnxConfig(VisionOnnxConfig): - ATOL_FOR_VALIDATION = 1e-3 + ATOL_FOR_VALIDATION = 1e-4 # The ONNX export of a CLIPText architecture, an other Stable Diffusion component, needs the Trilu # operator support, available since opset 14 DEFAULT_ONNX_OPSET = 14 diff --git a/optimum/exporters/utils.py b/optimum/exporters/utils.py index e2125736c4..949b54f468 100644 --- a/optimum/exporters/utils.py +++ b/optimum/exporters/utils.py @@ -46,11 +46,6 @@ from diffusers import ( DiffusionPipeline, - LatentConsistencyModelImg2ImgPipeline, - LatentConsistencyModelPipeline, - StableDiffusionImg2ImgPipeline, - StableDiffusionInpaintPipeline, - StableDiffusionPipeline, StableDiffusionXLImg2ImgPipeline, StableDiffusionXLInpaintPipeline, StableDiffusionXLPipeline, @@ -92,27 +87,13 @@ def _get_submodels_for_export_diffusion( Returns the components of a Stable Diffusion model. """ - is_stable_diffusion = isinstance( - pipeline, (StableDiffusionPipeline, StableDiffusionImg2ImgPipeline, StableDiffusionInpaintPipeline) - ) is_stable_diffusion_xl = isinstance( pipeline, (StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline, StableDiffusionXLInpaintPipeline) ) - is_latent_consistency_model = isinstance( - pipeline, (LatentConsistencyModelPipeline, LatentConsistencyModelImg2ImgPipeline) - ) - if is_stable_diffusion_xl: projection_dim = pipeline.text_encoder_2.config.projection_dim - elif is_stable_diffusion: - projection_dim = pipeline.text_encoder.config.projection_dim - elif is_latent_consistency_model: - projection_dim = pipeline.text_encoder.config.projection_dim else: - raise ValueError( - f"The export of a DiffusionPipeline model with the class name {pipeline.__class__.__name__} is currently not supported in Optimum. " - "Please open an issue or submit a PR to add the support." 
- ) + projection_dim = pipeline.text_encoder.config.projection_dim models_for_export = {} @@ -139,7 +120,8 @@ def _get_submodels_for_export_diffusion( vae_encoder = copy.deepcopy(pipeline.vae) if not is_torch_greater_or_equal_than_2_1: vae_encoder = override_diffusers_2_0_attn_processors(vae_encoder) - vae_encoder.forward = lambda sample: {"latent_sample": vae_encoder.encode(x=sample)["latent_dist"].sample()} + # we return the distribution parameters to be able to recreate it in the decoder + vae_encoder.forward = lambda sample: {"latent_parameters": vae_encoder.encode(x=sample)["latent_dist"].parameters} models_for_export["vae_encoder"] = vae_encoder # VAE Decoder https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/models/vae.py#L600 diff --git a/optimum/onnx/utils.py b/optimum/onnx/utils.py index b52c4f4cda..c014c1b342 100644 --- a/optimum/onnx/utils.py +++ b/optimum/onnx/utils.py @@ -71,6 +71,22 @@ def _get_external_data_paths(src_paths: List[Path], dst_paths: List[Path]) -> Tu return src_paths, dst_paths +def _get_model_external_data_paths(model_path: Path) -> List[Path]: + """ + Gets external data paths from the model. + """ + + onnx_model = onnx.load(str(model_path), load_external_data=False) + model_tensors = _get_initializer_tensors(onnx_model) + # filter out tensors that are not external data + model_tensors_ext = [ + ExternalDataInfo(tensor).location + for tensor in model_tensors + if tensor.HasField("data_location") and tensor.data_location == onnx.TensorProto.EXTERNAL + ] + return [model_path.parent / tensor_name for tensor_name in model_tensors_ext] + + def check_model_uses_external_data(model: onnx.ModelProto) -> bool: """ Checks if the model uses external data. diff --git a/optimum/onnxruntime/__init__.py b/optimum/onnxruntime/__init__.py index 1cb5b7c47b..4e25a43690 100644 --- a/optimum/onnxruntime/__init__.py +++ b/optimum/onnxruntime/__init__.py @@ -79,7 +79,9 @@ "ORTStableDiffusionInpaintPipeline", "ORTStableDiffusionXLPipeline", "ORTStableDiffusionXLImg2ImgPipeline", + "ORTStableDiffusionXLInpaintPipeline", "ORTLatentConsistencyModelPipeline", + "ORTLatentConsistencyModelImg2ImgPipeline", "ORTPipelineForImage2Image", "ORTPipelineForInpainting", "ORTPipelineForText2Image", @@ -92,6 +94,8 @@ "ORTStableDiffusionInpaintPipeline", "ORTStableDiffusionXLPipeline", "ORTStableDiffusionXLImg2ImgPipeline", + "ORTStableDiffusionXLInpaintPipeline", + "ORTLatentConsistencyModelImg2ImgPipeline", "ORTLatentConsistencyModelPipeline", "ORTPipelineForImage2Image", "ORTPipelineForInpainting", @@ -148,6 +152,7 @@ except OptionalDependencyNotAvailable: from ..utils.dummy_diffusers_objects import ( ORTDiffusionPipeline, + ORTLatentConsistencyModelImg2ImgPipeline, ORTLatentConsistencyModelPipeline, ORTPipelineForImage2Image, ORTPipelineForInpainting, @@ -156,11 +161,13 @@ ORTStableDiffusionInpaintPipeline, ORTStableDiffusionPipeline, ORTStableDiffusionXLImg2ImgPipeline, + ORTStableDiffusionXLInpaintPipeline, ORTStableDiffusionXLPipeline, ) else: from .modeling_diffusion import ( ORTDiffusionPipeline, + ORTLatentConsistencyModelImg2ImgPipeline, ORTLatentConsistencyModelPipeline, ORTPipelineForImage2Image, ORTPipelineForInpainting, @@ -169,6 +176,7 @@ ORTStableDiffusionInpaintPipeline, ORTStableDiffusionPipeline, ORTStableDiffusionXLImg2ImgPipeline, + ORTStableDiffusionXLInpaintPipeline, ORTStableDiffusionXLPipeline, ) else: diff --git a/optimum/onnxruntime/base.py b/optimum/onnxruntime/base.py index 0e54bafed7..845780cafa 100644 --- a/optimum/onnxruntime/base.py +++ 
b/optimum/onnxruntime/base.py @@ -71,6 +71,25 @@ def dtype(self): return None + def to(self, *args, device: Optional[Union[torch.device, str, int]] = None, dtype: Optional[torch.dtype] = None): + for arg in args: + if isinstance(arg, torch.device): + device = arg + elif isinstance(arg, torch.dtype): + dtype = arg + + if device is not None and device != self.device: + raise ValueError( + "Cannot change the device of a model part without changing the device of the parent model. " + "Please use the `to` method of the parent model to change the device." + ) + + if dtype is not None and dtype != self.dtype: + raise NotImplementedError( + f"Cannot change the dtype of the model from {self.dtype} to {dtype}. " + f"Please export the model with the desired dtype." + ) + @abstractmethod def forward(self, *args, **kwargs): pass diff --git a/optimum/onnxruntime/modeling_diffusion.py b/optimum/onnxruntime/modeling_diffusion.py index 18cd38c5f2..87fcb68c7e 100644 --- a/optimum/onnxruntime/modeling_diffusion.py +++ b/optimum/onnxruntime/modeling_diffusion.py @@ -13,10 +13,11 @@ # limitations under the License. import importlib +import inspect import logging import os import shutil -import warnings +from abc import abstractmethod from collections import OrderedDict from pathlib import Path from tempfile import TemporaryDirectory @@ -24,23 +25,25 @@ import numpy as np import torch -from diffusers import ( +from diffusers.configuration_utils import ConfigMixin +from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution +from diffusers.pipelines import ( AutoPipelineForImage2Image, AutoPipelineForInpainting, AutoPipelineForText2Image, - ConfigMixin, - DDIMScheduler, + LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline, - LMSDiscreteScheduler, - PNDMScheduler, StableDiffusionImg2ImgPipeline, StableDiffusionInpaintPipeline, StableDiffusionPipeline, StableDiffusionXLImg2ImgPipeline, + StableDiffusionXLInpaintPipeline, StableDiffusionXLPipeline, ) +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.schedulers import SchedulerMixin from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME -from diffusers.utils import CONFIG_NAME, is_invisible_watermark_available +from diffusers.utils.constants import CONFIG_NAME from huggingface_hub import snapshot_download from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE from huggingface_hub.utils import validate_hf_hub_args @@ -51,14 +54,7 @@ import onnxruntime as ort from ..exporters.onnx import main_export -from ..onnx.utils import _get_external_data_paths -from ..pipelines.diffusers.pipeline_latent_consistency import LatentConsistencyPipelineMixin -from ..pipelines.diffusers.pipeline_stable_diffusion import StableDiffusionPipelineMixin -from ..pipelines.diffusers.pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipelineMixin -from ..pipelines.diffusers.pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipelineMixin -from ..pipelines.diffusers.pipeline_stable_diffusion_xl import StableDiffusionXLPipelineMixin -from ..pipelines.diffusers.pipeline_stable_diffusion_xl_img2img import StableDiffusionXLImg2ImgPipelineMixin -from ..pipelines.diffusers.pipeline_utils import VaeImageProcessor +from ..onnx.utils import _get_model_external_data_paths from ..utils import ( DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER, DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER, @@ -66,12 +62,12 @@ DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER, DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER, ) -from .base import 
ORTModelPart from .io_binding import TypeHelper from .modeling_ort import ONNX_MODEL_END_DOCSTRING, ORTModel from .utils import ( ONNX_WEIGHTS_NAME, get_provider_for_device, + np_to_pt_generators, parse_device, validate_provider_availability, ) @@ -80,380 +76,287 @@ logger = logging.getLogger(__name__) -class ORTPipeline(ORTModel): - auto_model_class = None - model_type = "onnx_pipeline" - +# TODO: support from_pipe() +# TODO: Instead of ORTModel, it makes sense to have a compositional ORTMixin +# TODO: instead of one bloated __init__, we should consider an __init__ per pipeline +class ORTDiffusionPipeline(ORTModel, DiffusionPipeline): config_name = "model_index.json" - sub_component_config_name = "config.json" + auto_model_class = DiffusionPipeline def __init__( self, - vae_decoder_session: ort.InferenceSession, + scheduler: "SchedulerMixin", unet_session: ort.InferenceSession, - tokenizer: CLIPTokenizer, - config: Dict[str, Any], - scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], - feature_extractor: Optional[CLIPFeatureExtractor] = None, + vae_decoder_session: ort.InferenceSession, + # optional pipeline models vae_encoder_session: Optional[ort.InferenceSession] = None, text_encoder_session: Optional[ort.InferenceSession] = None, text_encoder_2_session: Optional[ort.InferenceSession] = None, - tokenizer_2: Optional[CLIPTokenizer] = None, + # optional pipeline submodels + tokenizer: Optional["CLIPTokenizer"] = None, + tokenizer_2: Optional["CLIPTokenizer"] = None, + feature_extractor: Optional["CLIPFeatureExtractor"] = None, + # stable diffusion xl specific arguments + force_zeros_for_empty_prompt: bool = True, + requires_aesthetics_score: bool = False, + add_watermarker: Optional[bool] = None, + # onnxruntime specific arguments use_io_binding: Optional[bool] = None, model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, + **kwargs, ): - """ - Args: - vae_decoder_session (`ort.InferenceSession`): - The ONNX Runtime inference session associated to the VAE decoder - unet_session (`ort.InferenceSession`): - The ONNX Runtime inference session associated to the U-NET. - tokenizer (`CLIPTokenizer`): - Tokenizer of class - [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer) - for the text encoder. - config (`Dict[str, Any]`): - A config dictionary from which the model components will be instantiated. Make sure to only load - configuration files of compatible classes. - scheduler (`Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]`): - A scheduler to be used in combination with the U-NET component to denoise the encoded image latents. - feature_extractor (`Optional[CLIPFeatureExtractor]`, defaults to `None`): - A model extracting features from generated images to be used as inputs for the `safety_checker` - vae_encoder_session (`Optional[ort.InferenceSession]`, defaults to `None`): - The ONNX Runtime inference session associated to the VAE encoder. - text_encoder_session (`Optional[ort.InferenceSession]`, defaults to `None`): - The ONNX Runtime inference session associated to the text encoder. - tokenizer_2 (`Optional[CLIPTokenizer]`, defaults to `None`): - Tokenizer of class - [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer) - for the second text encoder. - use_io_binding (`Optional[bool]`, defaults to `None`): - Whether to use IOBinding during inference to avoid memory copy between the host and devices. 
Defaults to - `True` if the device is CUDA, otherwise defaults to `False`. - model_save_dir (`Optional[str]`, defaults to `None`): - The directory under which the model exported to ONNX was saved. - """ - self.shared_attributes_init( - model=vae_decoder_session, - use_io_binding=use_io_binding, - model_save_dir=model_save_dir, - ) - self._internal_dict = config - self.vae_decoder = ORTModelVaeDecoder(vae_decoder_session, self) - self.vae_decoder_model_path = Path(vae_decoder_session._model_path) self.unet = ORTModelUnet(unet_session, self) - self.unet_model_path = Path(unet_session._model_path) - - if text_encoder_session is not None: - self.text_encoder_model_path = Path(text_encoder_session._model_path) - self.text_encoder = ORTModelTextEncoder(text_encoder_session, self) - else: - self.text_encoder_model_path = None - self.text_encoder = None - - if vae_encoder_session is not None: - self.vae_encoder_model_path = Path(vae_encoder_session._model_path) - self.vae_encoder = ORTModelVaeEncoder(vae_encoder_session, self) - else: - self.vae_encoder_model_path = None - self.vae_encoder = None + self.vae_decoder = ORTModelVaeDecoder(vae_decoder_session, self) + self.vae_encoder = ORTModelVaeEncoder(vae_encoder_session, self) if vae_encoder_session is not None else None + self.text_encoder = ( + ORTModelTextEncoder(text_encoder_session, self) if text_encoder_session is not None else None + ) + self.text_encoder_2 = ( + ORTModelTextEncoder(text_encoder_2_session, self) if text_encoder_2_session is not None else None + ) + # We wrap the VAE Decoder & Encoder in a single object to simulate diffusers API + self.vae = ORTWrapperVae(self.vae_encoder, self.vae_decoder) - if text_encoder_2_session is not None: - self.text_encoder_2_model_path = Path(text_encoder_2_session._model_path) - self.text_encoder_2 = ORTModelTextEncoder(text_encoder_2_session, self) - else: - self.text_encoder_2_model_path = None - self.text_encoder_2 = None + # we allow passing these as torch models for now + self.image_encoder = kwargs.pop("image_encoder", None) # TODO: maybe implement ORTModelImageEncoder + self.safety_checker = kwargs.pop("safety_checker", None) # TODO: maybe implement ORTModelSafetyChecker + self.scheduler = scheduler self.tokenizer = tokenizer self.tokenizer_2 = tokenizer_2 - self.scheduler = scheduler self.feature_extractor = feature_extractor - self.safety_checker = None - - sub_models = { - DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER: self.text_encoder, - DIFFUSION_MODEL_UNET_SUBFOLDER: self.unet, - DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER: self.vae_decoder, - DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER: self.vae_encoder, - DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER: self.text_encoder_2, - } - - # Modify config to keep the resulting model compatible with diffusers pipelines - for name in sub_models.keys(): - self._internal_dict[name] = ( - ("diffusers", "OnnxRuntimeModel") if sub_models[name] is not None else (None, None) - ) - self._internal_dict.pop("vae", None) - - if "block_out_channels" in self.vae_decoder.config: - self.vae_scale_factor = 2 ** (len(self.vae_decoder.config["block_out_channels"]) - 1) - else: - self.vae_scale_factor = 8 - - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) - - @staticmethod - def load_model( - vae_decoder_path: Union[str, Path], - text_encoder_path: Union[str, Path], - unet_path: Union[str, Path], - vae_encoder_path: Optional[Union[str, Path]] = None, - text_encoder_2_path: Optional[Union[str, Path]] = None, - provider: str = "CPUExecutionProvider", - 
session_options: Optional[ort.SessionOptions] = None, - provider_options: Optional[Dict] = None, - ): - """ - Creates three inference sessions for respectively the VAE decoder, the text encoder and the U-NET models. - The default provider is `CPUExecutionProvider` to match the default behaviour in PyTorch/TensorFlow/JAX. - Args: - vae_decoder_path (`Union[str, Path]`): - The path to the VAE decoder ONNX model. - text_encoder_path (`Union[str, Path]`): - The path to the text encoder ONNX model. - unet_path (`Union[str, Path]`): - The path to the U-NET ONNX model. - vae_encoder_path (`Union[str, Path]`, defaults to `None`): - The path to the VAE encoder ONNX model. - text_encoder_2_path (`Union[str, Path]`, defaults to `None`): - The path to the second text decoder ONNX model. - provider (`str`, defaults to `"CPUExecutionProvider"`): - ONNX Runtime provider to use for loading the model. See https://onnxruntime.ai/docs/execution-providers/ - for possible providers. - session_options (`Optional[ort.SessionOptions]`, defaults to `None`): - ONNX Runtime session options to use for loading the model. Defaults to `None`. - provider_options (`Optional[Dict]`, defaults to `None`): - Provider option dictionary corresponding to the provider used. See available options - for each provider: https://onnxruntime.ai/docs/api/c/group___global.html . Defaults to `None`. - """ - vae_decoder = ORTModel.load_model(vae_decoder_path, provider, session_options, provider_options) - unet = ORTModel.load_model(unet_path, provider, session_options, provider_options) - - sessions = { - "vae_encoder": vae_encoder_path, - "text_encoder": text_encoder_path, - "text_encoder_2": text_encoder_2_path, + all_pipeline_init_args = { + "vae": self.vae, + "unet": self.unet, + "text_encoder": self.text_encoder, + "text_encoder_2": self.text_encoder_2, + "safety_checker": self.safety_checker, + "image_encoder": self.image_encoder, + "scheduler": self.scheduler, + "tokenizer": self.tokenizer, + "tokenizer_2": self.tokenizer_2, + "feature_extractor": self.feature_extractor, + "requires_aesthetics_score": requires_aesthetics_score, + "force_zeros_for_empty_prompt": force_zeros_for_empty_prompt, + "add_watermarker": add_watermarker, } - for key, value in sessions.items(): - if value is not None and value.is_file(): - sessions[key] = ORTModel.load_model(value, provider, session_options, provider_options) - else: - sessions[key] = None + diffusers_pipeline_args = {} + for key in inspect.signature(self.auto_model_class).parameters.keys(): + if key in all_pipeline_init_args: + diffusers_pipeline_args[key] = all_pipeline_init_args[key] + # inits diffusers pipeline specific attributes (registers modules and config) + self.auto_model_class.__init__(self, **diffusers_pipeline_args) - return vae_decoder, sessions["text_encoder"], unet, sessions["vae_encoder"], sessions["text_encoder_2"] + # inits ort specific attributes + self.shared_attributes_init( + model=unet_session, use_io_binding=use_io_binding, model_save_dir=model_save_dir, **kwargs + ) def _save_pretrained(self, save_directory: Union[str, Path]): save_directory = Path(save_directory) - src_to_dst_path = { - self.vae_decoder_model_path: save_directory / DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER / ONNX_WEIGHTS_NAME, - self.text_encoder_model_path: save_directory / DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER / ONNX_WEIGHTS_NAME, - self.unet_model_path: save_directory / DIFFUSION_MODEL_UNET_SUBFOLDER / ONNX_WEIGHTS_NAME, - } - sub_models_to_save = { - self.vae_encoder_model_path: 
DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER, - self.text_encoder_2_model_path: DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER, + models_to_save_paths = { + (self.unet, save_directory / DIFFUSION_MODEL_UNET_SUBFOLDER), + (self.vae_decoder, save_directory / DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER), + (self.vae_encoder, save_directory / DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER), + (self.text_encoder, save_directory / DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER), + (self.text_encoder_2, save_directory / DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER), } - for path, subfolder in sub_models_to_save.items(): - if path is not None: - src_to_dst_path[path] = save_directory / subfolder / ONNX_WEIGHTS_NAME - - # TODO: Modify _get_external_data_paths to give dictionnary - src_paths = list(src_to_dst_path.keys()) - dst_paths = list(src_to_dst_path.values()) - # Add external data paths in case of large models - src_paths, dst_paths = _get_external_data_paths(src_paths, dst_paths) - - for src_path, dst_path in zip(src_paths, dst_paths): - dst_path.parent.mkdir(parents=True, exist_ok=True) - shutil.copyfile(src_path, dst_path) - config_path = src_path.parent / self.sub_component_config_name - if config_path.is_file(): - shutil.copyfile(config_path, dst_path.parent / self.sub_component_config_name) + for model, save_path in models_to_save_paths: + if model is not None: + model_path = Path(model.session._model_path) + save_path.mkdir(parents=True, exist_ok=True) + # copy onnx model + shutil.copyfile(model_path, save_path / ONNX_WEIGHTS_NAME) + # copy external onnx data + external_data_paths = _get_model_external_data_paths(model_path) + for external_data_path in external_data_paths: + shutil.copyfile(external_data_path, save_path / external_data_path.name) + # copy model config + config_path = model_path.parent / CONFIG_NAME + if config_path.is_file(): + config_save_path = save_path / CONFIG_NAME + shutil.copyfile(config_path, config_save_path) self.scheduler.save_pretrained(save_directory / "scheduler") - if self.feature_extractor is not None: - self.feature_extractor.save_pretrained(save_directory / "feature_extractor") if self.tokenizer is not None: self.tokenizer.save_pretrained(save_directory / "tokenizer") if self.tokenizer_2 is not None: self.tokenizer_2.save_pretrained(save_directory / "tokenizer_2") + if self.feature_extractor is not None: + self.feature_extractor.save_pretrained(save_directory / "feature_extractor") @classmethod def _from_pretrained( cls, model_id: Union[str, Path], config: Dict[str, Any], - use_auth_token: Optional[Union[bool, str]] = None, - token: Optional[Union[bool, str]] = None, + subfolder: str = "", + force_download: bool = False, + local_files_only: bool = False, revision: Optional[str] = None, + trust_remote_code: bool = False, cache_dir: str = HUGGINGFACE_HUB_CACHE, - vae_decoder_file_name: str = ONNX_WEIGHTS_NAME, - text_encoder_file_name: str = ONNX_WEIGHTS_NAME, + token: Optional[Union[bool, str]] = None, unet_file_name: str = ONNX_WEIGHTS_NAME, + vae_decoder_file_name: str = ONNX_WEIGHTS_NAME, vae_encoder_file_name: str = ONNX_WEIGHTS_NAME, + text_encoder_file_name: str = ONNX_WEIGHTS_NAME, text_encoder_2_file_name: str = ONNX_WEIGHTS_NAME, - local_files_only: bool = False, + use_io_binding: Optional[bool] = None, provider: str = "CPUExecutionProvider", - session_options: Optional[ort.SessionOptions] = None, provider_options: Optional[Dict[str, Any]] = None, - use_io_binding: Optional[bool] = None, + session_options: Optional[ort.SessionOptions] = None, model_save_dir: Optional[Union[str, 
Path, TemporaryDirectory]] = None, **kwargs, ): - if use_auth_token is not None: - warnings.warn( - "The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.", - FutureWarning, + if use_io_binding: + raise ValueError( + "IOBinding is not yet available for diffusion pipelines, please set `use_io_binding` to False." ) - if token is not None: - raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") - token = use_auth_token - - if provider == "TensorrtExecutionProvider": - raise ValueError("The provider `'TensorrtExecutionProvider'` is not supported") - model_id = str(model_id) - patterns = set(config.keys()) - sub_models_to_load = patterns.intersection({"feature_extractor", "tokenizer", "tokenizer_2", "scheduler"}) - - if not os.path.isdir(model_id): - patterns.update({"vae_encoder", "vae_decoder"}) - allow_patterns = {os.path.join(k, "*") for k in patterns if not k.startswith("_")} + if not os.path.isdir(str(model_id)): + all_components = {key for key in config.keys() if not key.startswith("_")} | {"vae_encoder", "vae_decoder"} + allow_patterns = {os.path.join(component, "*") for component in all_components} allow_patterns.update( { - vae_decoder_file_name, - text_encoder_file_name, unet_file_name, + vae_decoder_file_name, vae_encoder_file_name, + text_encoder_file_name, text_encoder_2_file_name, SCHEDULER_CONFIG_NAME, - CONFIG_NAME, cls.config_name, + CONFIG_NAME, } ) - # Downloads all repo's files matching the allowed patterns - model_id = snapshot_download( + model_save_folder = snapshot_download( model_id, cache_dir=cache_dir, + force_download=force_download, local_files_only=local_files_only, - token=token, revision=revision, + token=token, allow_patterns=allow_patterns, ignore_patterns=["*.msgpack", "*.safetensors", "*.bin", "*.xml"], ) - new_model_save_dir = Path(model_id) + else: + model_save_folder = str(model_id) + + model_save_path = Path(model_save_folder) + + if model_save_dir is None: + model_save_dir = model_save_path + + model_paths = { + "unet": model_save_path / DIFFUSION_MODEL_UNET_SUBFOLDER / unet_file_name, + "vae_decoder": model_save_path / DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER / vae_decoder_file_name, + "vae_encoder": model_save_path / DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER / vae_encoder_file_name, + "text_encoder": model_save_path / DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER / text_encoder_file_name, + "text_encoder_2": model_save_path / DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER / text_encoder_2_file_name, + } + + sessions = {} + for model, path in model_paths.items(): + if kwargs.get(model, None) is not None: + # this allows passing a model directly to from_pretrained + sessions[f"{model}_session"] = kwargs.pop(model) + else: + sessions[f"{model}_session"] = ( + ORTModel.load_model(path, provider, session_options, provider_options) if path.is_file() else None + ) - sub_models = {} - for name in sub_models_to_load: - library_name, library_classes = config[name] - if library_classes is not None: + submodels = {} + for submodel in {"scheduler", "tokenizer", "tokenizer_2", "feature_extractor"}: + if kwargs.get(submodel, None) is not None: + submodels[submodel] = kwargs.pop(submodel) + elif config.get(submodel, (None, None))[0] is not None: + library_name, library_classes = config.get(submodel) library = importlib.import_module(library_name) class_obj = getattr(library, library_classes) load_method = getattr(class_obj, "from_pretrained") # Check if the module is in a subdirectory - 
if (new_model_save_dir / name).is_dir(): - sub_models[name] = load_method(new_model_save_dir / name) + if (model_save_path / submodel).is_dir(): + submodels[submodel] = load_method(model_save_path / submodel) else: - sub_models[name] = load_method(new_model_save_dir) - - vae_decoder, text_encoder, unet, vae_encoder, text_encoder_2 = cls.load_model( - vae_decoder_path=new_model_save_dir / DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER / vae_decoder_file_name, - text_encoder_path=new_model_save_dir / DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER / text_encoder_file_name, - unet_path=new_model_save_dir / DIFFUSION_MODEL_UNET_SUBFOLDER / unet_file_name, - vae_encoder_path=new_model_save_dir / DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER / vae_encoder_file_name, - text_encoder_2_path=( - new_model_save_dir / DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER / text_encoder_2_file_name - ), - provider=provider, - session_options=session_options, - provider_options=provider_options, - ) - - if model_save_dir is None: - model_save_dir = new_model_save_dir + submodels[submodel] = load_method(model_save_path) - if use_io_binding: - raise ValueError( - "IOBinding is not yet available for stable diffusion model, please set `use_io_binding` to False." - ) + # same as DiffusionPipeline.from_pretraoned, if called directly, it loads the class in the config + if cls.__name__ == "ORTDiffusionPipeline": + class_name = config["_class_name"] + ort_pipeline_class = _get_ort_class(class_name) + else: + ort_pipeline_class = cls - return cls( - vae_decoder_session=vae_decoder, - text_encoder_session=text_encoder, - unet_session=unet, - config=config, - tokenizer=sub_models.get("tokenizer", None), - scheduler=sub_models.get("scheduler"), - feature_extractor=sub_models.get("feature_extractor", None), - tokenizer_2=sub_models.get("tokenizer_2", None), - vae_encoder_session=vae_encoder, - text_encoder_2_session=text_encoder_2, + ort_pipeline = ort_pipeline_class( + **sessions, + **submodels, use_io_binding=use_io_binding, model_save_dir=model_save_dir, + **kwargs, ) + # same as in DiffusionPipeline.from_pretrained, we save where the model was instantiated from + ort_pipeline.register_to_config(_name_or_path=config.get("_name_or_path", str(model_id))) + + return ort_pipeline + @classmethod - def _from_transformers( + def _export( cls, model_id: str, - config: Optional[str] = None, - use_auth_token: Optional[Union[bool, str]] = None, - token: Optional[Union[bool, str]] = None, - revision: str = "main", - force_download: bool = True, - cache_dir: str = HUGGINGFACE_HUB_CACHE, + config: Dict[str, Any], subfolder: str = "", + force_download: bool = False, local_files_only: bool = False, + revision: Optional[str] = None, trust_remote_code: bool = False, + cache_dir: str = HUGGINGFACE_HUB_CACHE, + token: Optional[Union[bool, str]] = None, + use_io_binding: Optional[bool] = None, provider: str = "CPUExecutionProvider", session_options: Optional[ort.SessionOptions] = None, provider_options: Optional[Dict[str, Any]] = None, - use_io_binding: Optional[bool] = None, task: Optional[str] = None, - ) -> "ORTPipeline": - if use_auth_token is not None: - warnings.warn( - "The `use_auth_token` argument is deprecated and will be removed soon. 
Please use the `token` argument instead.", - FutureWarning, - ) - if token is not None: - raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") - token = use_auth_token - + **kwargs, + ) -> "ORTDiffusionPipeline": if task is None: task = cls._auto_model_to_task(cls.auto_model_class) - save_dir = TemporaryDirectory() - save_dir_path = Path(save_dir.name) + # we continue passing the model_save_dir from here on to avoid it being cleaned up + # might be better to use a persistent temporary directory such as the one implemented in + # https://gist.github.com/twolfson/2929dc1163b0a76d2c2b66d51f9bc808 + model_save_dir = TemporaryDirectory() + model_save_path = Path(model_save_dir.name) main_export( - model_name_or_path=model_id, - output=save_dir_path, - task=task, + model_id, + output=model_save_path, do_validation=False, no_post_process=True, - subfolder=subfolder, + token=token, revision=revision, cache_dir=cache_dir, - token=token, - local_files_only=local_files_only, + subfolder=subfolder, force_download=force_download, + local_files_only=local_files_only, trust_remote_code=trust_remote_code, + library_name="diffusers", + task=task, ) return cls._from_pretrained( - save_dir_path, + model_save_path, config=config, provider=provider, - session_options=session_options, provider_options=provider_options, + session_options=session_options, use_io_binding=use_io_binding, - model_save_dir=save_dir, + model_save_dir=model_save_dir, + **kwargs, ) def to(self, device: Union[torch.device, str, int]): @@ -471,19 +374,22 @@ def to(self, device: Union[torch.device, str, int]): device, provider_options = parse_device(device) provider = get_provider_for_device(device) - validate_provider_availability(provider) # raise error if the provider is not available + validate_provider_availability(provider) if device.type == "cuda" and self.providers[0] == "TensorrtExecutionProvider": return self - self.vae_decoder.session.set_providers([provider], provider_options=[provider_options]) - self.text_encoder.session.set_providers([provider], provider_options=[provider_options]) self.unet.session.set_providers([provider], provider_options=[provider_options]) + self.vae_decoder.session.set_providers([provider], provider_options=[provider_options]) if self.vae_encoder is not None: self.vae_encoder.session.set_providers([provider], provider_options=[provider_options]) + if self.text_encoder is not None: + self.text_encoder.session.set_providers([provider], provider_options=[provider_options]) + if self.text_encoder_2 is not None: + self.text_encoder_2.session.set_providers([provider], provider_options=[provider_options]) - self.providers = self.vae_decoder.session.get_providers() + self.providers = self.unet.session.get_providers() self._device = device return self @@ -495,41 +401,142 @@ def _load_config(cls, config_name_or_path: Union[str, os.PathLike], **kwargs): def _save_config(self, save_directory): self.save_config(save_directory) + @property + def components(self) -> Dict[str, Any]: + components = { + "vae": self.vae, + "unet": self.unet, + "text_encoder": self.text_encoder, + "text_encoder_2": self.text_encoder_2, + "safety_checker": self.safety_checker, + "image_encoder": self.image_encoder, + } + components = {k: v for k, v in components.items() if v is not None} + return components -class ORTPipelinePart(ORTModelPart): - CONFIG_NAME = "config.json" + def __call__(self, *args, **kwargs): + # we do this to keep numpy random states support for now + # TODO: deprecate and 
add warnings when a random state is passed - def __init__(self, session: ort.InferenceSession, parent_model: ORTPipeline): - config_path = Path(session._model_path).parent / self.CONFIG_NAME + args = list(args) + for i in range(len(args)): + args[i] = np_to_pt_generators(args[i], self.device) - if config_path.is_file(): - # TODO: use FrozenDict - self.config = parent_model._dict_from_json_file(config_path) - else: - self.config = {} + for k, v in kwargs.items(): + kwargs[k] = np_to_pt_generators(v, self.device) + + return self.auto_model_class.__call__(self, *args, **kwargs) - super().__init__(session, parent_model) + +class ORTPipelinePart(ConfigMixin): + config_name: str = CONFIG_NAME + + def __init__(self, session: ort.InferenceSession, parent_pipeline: ORTDiffusionPipeline): + self.session = session + self.parent_pipeline = parent_pipeline + + self.input_names = {input_key.name: idx for idx, input_key in enumerate(self.session.get_inputs())} + self.output_names = {output_key.name: idx for idx, output_key in enumerate(self.session.get_outputs())} + self.input_dtypes = {input_key.name: input_key.type for input_key in self.session.get_inputs()} + self.output_dtypes = {output_key.name: output_key.type for output_key in self.session.get_outputs()} + + config_file_path = Path(session._model_path).parent / self.config_name + if not config_file_path.is_file(): + # config is mandatory for the model part to be used for inference + raise ValueError(f"Configuration file for {self.__class__.__name__} not found at {config_file_path}") + config_dict = self._dict_from_json_file(config_file_path) + self.register_to_config(**config_dict) @property - def input_dtype(self): - # for backward compatibility and diffusion mixins (will be standardized in the future) - return {name: TypeHelper.ort_type_to_numpy_type(ort_type) for name, ort_type in self.input_dtypes.items()} + def device(self): + return self.parent_pipeline.device + @property + def dtype(self): + for dtype in self.input_dtypes.values(): + torch_dtype = TypeHelper.ort_type_to_torch_type(dtype) + if torch_dtype.is_floating_point: + return torch_dtype + + for dtype in self.output_dtypes.values(): + torch_dtype = TypeHelper.ort_type_to_torch_type(dtype) + if torch_dtype.is_floating_point: + return torch_dtype + + return None + + def to(self, *args, device: Optional[Union[torch.device, str, int]] = None, dtype: Optional[torch.dtype] = None): + for arg in args: + if isinstance(arg, torch.device): + device = arg + elif isinstance(arg, (int, str)): + device = torch.device(arg) + elif isinstance(arg, torch.dtype): + dtype = arg + + if device is not None and device != self.device: + raise ValueError( + "Cannot change the device of a pipeline part without changing the device of the parent pipeline. " + "Please use the `to` method of the parent pipeline to change the device." + ) -class ORTModelTextEncoder(ORTPipelinePart): - def forward(self, input_ids: Union[np.ndarray, torch.Tensor]): - use_torch = isinstance(input_ids, torch.Tensor) + if dtype is not None and dtype != self.dtype: + raise NotImplementedError( + f"Cannot change the dtype of the pipeline from {self.dtype} to {dtype}. " + f"Please export the pipeline with the desired dtype." 
+ ) - model_inputs = {"input_ids": input_ids} + def prepare_onnx_inputs(self, use_torch: bool, **inputs: Union[torch.Tensor, np.ndarray]) -> Dict[str, np.ndarray]: + onnx_inputs = {} - onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs) - onnx_outputs = self.session.run(None, onnx_inputs) - model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs) + # converts pytorch inputs into numpy inputs for onnx + for input_name in self.input_names.keys(): + onnx_inputs[input_name] = inputs.pop(input_name) - return ModelOutput(**model_outputs) + if use_torch: + onnx_inputs[input_name] = onnx_inputs[input_name].numpy(force=True) + + if onnx_inputs[input_name].dtype != self.input_dtypes[input_name]: + onnx_inputs[input_name] = onnx_inputs[input_name].astype( + TypeHelper.ort_type_to_numpy_type(self.input_dtypes[input_name]) + ) + + return onnx_inputs + + def prepare_onnx_outputs( + self, use_torch: bool, *onnx_outputs: np.ndarray + ) -> Dict[str, Union[torch.Tensor, np.ndarray]]: + model_outputs = {} + + # converts onnxruntime outputs into tensor for standard outputs + for output_name, idx in self.output_names.items(): + model_outputs[output_name] = onnx_outputs[idx] + + if use_torch: + model_outputs[output_name] = torch.from_numpy(model_outputs[output_name]).to(self.device) + + return model_outputs + + @abstractmethod + def forward(self, *args, **kwargs): + raise NotImplementedError + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) class ORTModelUnet(ORTPipelinePart): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # can be missing from models exported long ago + if not hasattr(self.config, "time_cond_proj_dim"): + logger.warning( + "The `time_cond_proj_dim` attribute is missing from the UNet configuration. " + "Please re-export the model with newer version of optimum and diffusers." 
+ ) + self.register_to_config(time_cond_proj_dim=None) + def forward( self, sample: Union[np.ndarray, torch.Tensor], @@ -538,9 +545,15 @@ def forward( text_embeds: Optional[Union[np.ndarray, torch.Tensor]] = None, time_ids: Optional[Union[np.ndarray, torch.Tensor]] = None, timestep_cond: Optional[Union[np.ndarray, torch.Tensor]] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + added_cond_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = False, ): use_torch = isinstance(sample, torch.Tensor) + if len(timestep.shape) == 0: + timestep = timestep.unsqueeze(0) + model_inputs = { "sample": sample, "timestep": timestep, @@ -548,171 +561,323 @@ def forward( "text_embeds": text_embeds, "time_ids": time_ids, "timestep_cond": timestep_cond, + **(cross_attention_kwargs or {}), + **(added_cond_kwargs or {}), } - onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs) + onnx_inputs = self.prepare_onnx_inputs(use_torch, **model_inputs) onnx_outputs = self.session.run(None, onnx_inputs) - model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs) + model_outputs = self.prepare_onnx_outputs(use_torch, *onnx_outputs) + + if return_dict: + return model_outputs return ModelOutput(**model_outputs) -class ORTModelVaeDecoder(ORTPipelinePart): - def forward(self, latent_sample: Union[np.ndarray, torch.Tensor]): - use_torch = isinstance(latent_sample, torch.Tensor) +class ORTModelTextEncoder(ORTPipelinePart): + def forward( + self, + input_ids: Union[np.ndarray, torch.Tensor], + attention_mask: Optional[Union[np.ndarray, torch.Tensor]] = None, + output_hidden_states: Optional[bool] = None, + return_dict: bool = False, + ): + use_torch = isinstance(input_ids, torch.Tensor) - model_inputs = {"latent_sample": latent_sample} + model_inputs = {"input_ids": input_ids} - onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs) + onnx_inputs = self.prepare_onnx_inputs(use_torch, **model_inputs) onnx_outputs = self.session.run(None, onnx_inputs) - model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs) + model_outputs = self.prepare_onnx_outputs(use_torch, *onnx_outputs) + + if output_hidden_states: + model_outputs["hidden_states"] = [] + for i in range(self.config.num_hidden_layers): + model_outputs["hidden_states"].append(model_outputs.pop(f"hidden_states.{i}")) + model_outputs["hidden_states"].append(model_outputs.get("last_hidden_state")) + else: + for i in range(self.config.num_hidden_layers): + model_outputs.pop(f"hidden_states.{i}", None) + + if return_dict: + return model_outputs return ModelOutput(**model_outputs) class ORTModelVaeEncoder(ORTPipelinePart): - def forward(self, sample: Union[np.ndarray, torch.Tensor]): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # can be missing from models exported long ago + if not hasattr(self.config, "scaling_factor"): + logger.warning( + "The `scaling_factor` attribute is missing from the VAE encoder configuration. " + "Please re-export the model with newer version of optimum and diffusers." 
+ ) + self.register_to_config(scaling_factor=2 ** (len(self.config.block_out_channels) - 1)) + + def forward( + self, + sample: Union[np.ndarray, torch.Tensor], + generator: Optional[torch.Generator] = None, + return_dict: bool = False, + ): use_torch = isinstance(sample, torch.Tensor) model_inputs = {"sample": sample} - onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs) + onnx_inputs = self.prepare_onnx_inputs(use_torch, **model_inputs) onnx_outputs = self.session.run(None, onnx_inputs) - model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs) + model_outputs = self.prepare_onnx_outputs(use_torch, *onnx_outputs) + + if "latent_sample" in model_outputs: + model_outputs["latents"] = model_outputs.pop("latent_sample") + + if "latent_parameters" in model_outputs: + model_outputs["latent_dist"] = DiagonalGaussianDistribution( + parameters=model_outputs.pop("latent_parameters") + ) + + if return_dict: + return model_outputs return ModelOutput(**model_outputs) +class ORTModelVaeDecoder(ORTPipelinePart): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # can be missing from models exported long ago + if not hasattr(self.config, "scaling_factor"): + logger.warning( + "The `scaling_factor` attribute is missing from the VAE decoder configuration. " + "Please re-export the model with newer version of optimum and diffusers." + ) + self.register_to_config(scaling_factor=2 ** (len(self.config.block_out_channels) - 1)) + + def forward( + self, + latent_sample: Union[np.ndarray, torch.Tensor], + generator: Optional[torch.Generator] = None, + return_dict: bool = False, + ): + use_torch = isinstance(latent_sample, torch.Tensor) + + model_inputs = {"latent_sample": latent_sample} + + onnx_inputs = self.prepare_onnx_inputs(use_torch, **model_inputs) + onnx_outputs = self.session.run(None, onnx_inputs) + model_outputs = self.prepare_onnx_outputs(use_torch, *onnx_outputs) + + if "latent_sample" in model_outputs: + model_outputs["latents"] = model_outputs.pop("latent_sample") + + if return_dict: + return model_outputs + + return ModelOutput(**model_outputs) + + +class ORTWrapperVae(ORTPipelinePart): + def __init__(self, encoder: ORTModelVaeEncoder, decoder: ORTModelVaeDecoder): + self.decoder = decoder + self.encoder = encoder + + @property + def config(self): + return self.decoder.config + + @property + def dtype(self): + return self.decoder.dtype + + @property + def device(self): + return self.decoder.device + + def decode(self, *args, **kwargs): + return self.decoder(*args, **kwargs) + + def encode(self, *args, **kwargs): + return self.encoder(*args, **kwargs) + + def to(self, *args, **kwargs): + self.decoder.to(*args, **kwargs) + if self.encoder is not None: + self.encoder.to(*args, **kwargs) + + @add_end_docstrings(ONNX_MODEL_END_DOCSTRING) -class ORTStableDiffusionPipeline(ORTPipeline, StableDiffusionPipelineMixin): +class ORTStableDiffusionPipeline(ORTDiffusionPipeline, StableDiffusionPipeline): """ ONNX Runtime-powered stable diffusion pipeline corresponding to [diffusers.StableDiffusionPipeline](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/text2img#diffusers.StableDiffusionPipeline). 
""" main_input_name = "prompt" + export_feature = "text-to-image" auto_model_class = StableDiffusionPipeline - __call__ = StableDiffusionPipelineMixin.__call__ - @add_end_docstrings(ONNX_MODEL_END_DOCSTRING) -class ORTStableDiffusionImg2ImgPipeline(ORTPipeline, StableDiffusionImg2ImgPipelineMixin): +class ORTStableDiffusionImg2ImgPipeline(ORTDiffusionPipeline, StableDiffusionImg2ImgPipeline): """ ONNX Runtime-powered stable diffusion pipeline corresponding to [diffusers.StableDiffusionImg2ImgPipeline](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/img2img#diffusers.StableDiffusionImg2ImgPipeline). """ - main_input_name = "prompt" + main_input_name = "image" + export_feature = "image-to-image" auto_model_class = StableDiffusionImg2ImgPipeline - __call__ = StableDiffusionImg2ImgPipelineMixin.__call__ - @add_end_docstrings(ONNX_MODEL_END_DOCSTRING) -class ORTStableDiffusionInpaintPipeline(ORTPipeline, StableDiffusionInpaintPipelineMixin): +class ORTStableDiffusionInpaintPipeline(ORTDiffusionPipeline, StableDiffusionInpaintPipeline): """ ONNX Runtime-powered stable diffusion pipeline corresponding to [diffusers.StableDiffusionInpaintPipeline](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/inpaint#diffusers.StableDiffusionInpaintPipeline). """ main_input_name = "prompt" + export_feature = "inpainting" auto_model_class = StableDiffusionInpaintPipeline - __call__ = StableDiffusionInpaintPipelineMixin.__call__ - @add_end_docstrings(ONNX_MODEL_END_DOCSTRING) -class ORTLatentConsistencyModelPipeline(ORTPipeline, LatentConsistencyPipelineMixin): +class ORTStableDiffusionXLPipeline(ORTDiffusionPipeline, StableDiffusionXLPipeline): """ - ONNX Runtime-powered stable diffusion pipeline corresponding to [diffusers.LatentConsistencyModelPipeline](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/latent_consistency#diffusers.LatentConsistencyModelPipeline). + ONNX Runtime-powered stable diffusion pipeline corresponding to [diffusers.StableDiffusionXLPipeline](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#diffusers.StableDiffusionXLPipeline). 
""" main_input_name = "prompt" - auto_model_class = LatentConsistencyModelPipeline - - __call__ = LatentConsistencyPipelineMixin.__call__ - + export_feature = "text-to-image" + auto_model_class = StableDiffusionXLPipeline -class ORTStableDiffusionXLPipelineBase(ORTPipeline): - def __init__( + def _get_add_time_ids( self, - vae_decoder_session: ort.InferenceSession, - text_encoder_session: ort.InferenceSession, - unet_session: ort.InferenceSession, - config: Dict[str, Any], - tokenizer: CLIPTokenizer, - scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], - feature_extractor: Optional[CLIPFeatureExtractor] = None, - vae_encoder_session: Optional[ort.InferenceSession] = None, - text_encoder_2_session: Optional[ort.InferenceSession] = None, - tokenizer_2: Optional[CLIPTokenizer] = None, - use_io_binding: Optional[bool] = None, - model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, - add_watermarker: Optional[bool] = None, + original_size, + crops_coords_top_left, + target_size, + dtype, + text_encoder_projection_dim=None, ): - super().__init__( - vae_decoder_session=vae_decoder_session, - text_encoder_session=text_encoder_session, - unet_session=unet_session, - config=config, - tokenizer=tokenizer, - scheduler=scheduler, - feature_extractor=feature_extractor, - vae_encoder_session=vae_encoder_session, - text_encoder_2_session=text_encoder_2_session, - tokenizer_2=tokenizer_2, - use_io_binding=use_io_binding, - model_save_dir=model_save_dir, - ) + add_time_ids = list(original_size + crops_coords_top_left + target_size) - add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids - if add_watermarker: - if not is_invisible_watermark_available(): - raise ImportError( - "`add_watermarker` requires invisible-watermark to be installed, which can be installed with `pip install invisible-watermark`." - ) - from ..pipelines.diffusers.watermark import StableDiffusionXLWatermarker +@add_end_docstrings(ONNX_MODEL_END_DOCSTRING) +class ORTStableDiffusionXLImg2ImgPipeline(ORTDiffusionPipeline, StableDiffusionXLImg2ImgPipeline): + """ + ONNX Runtime-powered stable diffusion pipeline corresponding to [diffusers.StableDiffusionXLImg2ImgPipeline](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#diffusers.StableDiffusionXLImg2ImgPipeline). 
+ """ - self.watermark = StableDiffusionXLWatermarker() + main_input_name = "prompt" + export_feature = "image-to-image" + auto_model_class = StableDiffusionXLImg2ImgPipeline + + def _get_add_time_ids( + self, + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype, + text_encoder_projection_dim=None, + ): + if self.config.requires_aesthetics_score: + add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) + add_neg_time_ids = list( + negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,) + ) else: - self.watermark = None + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype) + + return add_time_ids, add_neg_time_ids @add_end_docstrings(ONNX_MODEL_END_DOCSTRING) -class ORTStableDiffusionXLPipeline(ORTStableDiffusionXLPipelineBase, StableDiffusionXLPipelineMixin): +class ORTStableDiffusionXLInpaintPipeline(ORTDiffusionPipeline, StableDiffusionXLInpaintPipeline): """ - ONNX Runtime-powered stable diffusion pipeline corresponding to [diffusers.StableDiffusionXLPipeline](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#diffusers.StableDiffusionXLPipeline). + ONNX Runtime-powered stable diffusion pipeline corresponding to [diffusers.StableDiffusionXLInpaintPipeline](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#diffusers.StableDiffusionXLInpaintPipeline). """ - main_input_name = "prompt" - auto_model_class = StableDiffusionXLPipeline + main_input_name = "image" + export_feature = "inpainting" + auto_model_class = StableDiffusionXLInpaintPipeline + + def _get_add_time_ids( + self, + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype, + text_encoder_projection_dim=None, + ): + if self.config.requires_aesthetics_score: + add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) + add_neg_time_ids = list( + negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,) + ) + else: + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype) - __call__ = StableDiffusionXLPipelineMixin.__call__ + return add_time_ids, add_neg_time_ids @add_end_docstrings(ONNX_MODEL_END_DOCSTRING) -class ORTStableDiffusionXLImg2ImgPipeline(ORTStableDiffusionXLPipelineBase, StableDiffusionXLImg2ImgPipelineMixin): +class ORTLatentConsistencyModelPipeline(ORTDiffusionPipeline, LatentConsistencyModelPipeline): """ - ONNX Runtime-powered stable diffusion pipeline corresponding to [diffusers.StableDiffusionXLImg2ImgPipeline](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#diffusers.StableDiffusionXLImg2ImgPipeline). 
+ ONNX Runtime-powered stable diffusion pipeline corresponding to [diffusers.LatentConsistencyModelPipeline](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/latent_consistency#diffusers.LatentConsistencyModelPipeline). """ main_input_name = "prompt" - auto_model_class = StableDiffusionXLImg2ImgPipeline + export_feature = "text-to-image" + auto_model_class = LatentConsistencyModelPipeline + + +@add_end_docstrings(ONNX_MODEL_END_DOCSTRING) +class ORTLatentConsistencyModelImg2ImgPipeline(ORTDiffusionPipeline, LatentConsistencyModelImg2ImgPipeline): + """ + ONNX Runtime-powered stable diffusion pipeline corresponding to [diffusers.LatentConsistencyModelImg2ImgPipeline](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/latent_consistency_img2img#diffusers.LatentConsistencyModelImg2ImgPipeline). + """ - __call__ = StableDiffusionXLImg2ImgPipelineMixin.__call__ + main_input_name = "image" + export_feature = "image-to-image" + auto_model_class = LatentConsistencyModelImg2ImgPipeline SUPPORTED_ORT_PIPELINES = [ ORTStableDiffusionPipeline, ORTStableDiffusionImg2ImgPipeline, ORTStableDiffusionInpaintPipeline, - ORTLatentConsistencyModelPipeline, ORTStableDiffusionXLPipeline, ORTStableDiffusionXLImg2ImgPipeline, + ORTStableDiffusionXLInpaintPipeline, + ORTLatentConsistencyModelPipeline, + ORTLatentConsistencyModelImg2ImgPipeline, ] -def _get_pipeline_class(pipeline_class_name: str, throw_error_if_not_exist: bool = True): +def _get_ort_class(pipeline_class_name: str, throw_error_if_not_exist: bool = True): for ort_pipeline_class in SUPPORTED_ORT_PIPELINES: if ( ort_pipeline_class.__name__ == pipeline_class_name @@ -724,31 +889,6 @@ def _get_pipeline_class(pipeline_class_name: str, throw_error_if_not_exist: bool raise ValueError(f"ORTDiffusionPipeline can't find a pipeline linked to {pipeline_class_name}") -class ORTDiffusionPipeline(ConfigMixin): - config_name = "model_index.json" - - @classmethod - @validate_hf_hub_args - def from_pretrained(cls, pretrained_model_or_path, **kwargs): - load_config_kwargs = { - "force_download": kwargs.get("force_download", False), - "resume_download": kwargs.get("resume_download", None), - "local_files_only": kwargs.get("local_files_only", False), - "cache_dir": kwargs.get("cache_dir", None), - "revision": kwargs.get("revision", None), - "proxies": kwargs.get("proxies", None), - "token": kwargs.get("token", None), - } - - config = cls.load_config(pretrained_model_or_path, **load_config_kwargs) - config = config[0] if isinstance(config, tuple) else config - class_name = config["_class_name"] - - ort_pipeline_class = _get_pipeline_class(class_name) - - return ort_pipeline_class.from_pretrained(pretrained_model_or_path, **kwargs) - - ORT_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict( [ ("stable-diffusion", ORTStableDiffusionPipeline), @@ -761,12 +901,14 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs): [ ("stable-diffusion", ORTStableDiffusionImg2ImgPipeline), ("stable-diffusion-xl", ORTStableDiffusionXLImg2ImgPipeline), + ("latent-consistency", ORTLatentConsistencyModelImg2ImgPipeline), ] ) ORT_INPAINT_PIPELINES_MAPPING = OrderedDict( [ ("stable-diffusion", ORTStableDiffusionInpaintPipeline), + ("stable-diffusion-xl", ORTStableDiffusionXLInpaintPipeline), ] ) @@ -777,7 +919,7 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs): ] -def _get_task_class(mapping, pipeline_class_name): +def _get_task_ort_class(mapping, pipeline_class_name): def _get_model_name(pipeline_class_name): for ort_pipelines_mapping in 
SUPPORTED_ORT_PIPELINES_MAPPINGS: for model_name, ort_pipeline_class in ort_pipelines_mapping.items(): @@ -801,7 +943,8 @@ class ORTPipelineForTask(ConfigMixin): config_name = "model_index.json" @classmethod - def from_pretrained(cls, pretrained_model_or_path, **kwargs): + @validate_hf_hub_args + def from_pretrained(cls, pretrained_model_or_path, **kwargs) -> ORTDiffusionPipeline: load_config_kwargs = { "force_download": kwargs.get("force_download", False), "resume_download": kwargs.get("resume_download", None), @@ -815,7 +958,7 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs): config = config[0] if isinstance(config, tuple) else config class_name = config["_class_name"] - ort_pipeline_class = _get_task_class(cls.ort_pipelines_mapping, class_name) + ort_pipeline_class = _get_task_ort_class(cls.ort_pipelines_mapping, class_name) return ort_pipeline_class.from_pretrained(pretrained_model_or_path, **kwargs) diff --git a/optimum/onnxruntime/modeling_ort.py b/optimum/onnxruntime/modeling_ort.py index 17bd3e2a4e..9b29afa566 100644 --- a/optimum/onnxruntime/modeling_ort.py +++ b/optimum/onnxruntime/modeling_ort.py @@ -938,7 +938,7 @@ def _prepare_onnx_inputs( onnx_inputs[input_name] = inputs.pop(input_name) if use_torch: - onnx_inputs[input_name] = onnx_inputs[input_name].cpu().detach().numpy() + onnx_inputs[input_name] = onnx_inputs[input_name].numpy(force=True) if onnx_inputs[input_name].dtype != self.input_dtypes[input_name]: onnx_inputs[input_name] = onnx_inputs[input_name].astype( diff --git a/optimum/onnxruntime/modeling_seq2seq.py b/optimum/onnxruntime/modeling_seq2seq.py index fda3ca82bb..27e0dc01b4 100644 --- a/optimum/onnxruntime/modeling_seq2seq.py +++ b/optimum/onnxruntime/modeling_seq2seq.py @@ -67,7 +67,7 @@ if check_if_transformers_greater("4.25.0"): from transformers.generation import GenerationMixin else: - from transformers.generation_utils import GenerationMixin + from transformers.generation_utils import GenerationMixin # type: ignore if check_if_transformers_greater("4.43.0"): diff --git a/optimum/onnxruntime/utils.py b/optimum/onnxruntime/utils.py index 985980e31b..128e2406f1 100644 --- a/optimum/onnxruntime/utils.py +++ b/optimum/onnxruntime/utils.py @@ -403,3 +403,18 @@ def evaluation_loop( metrics = {} return EvalLoopOutput(predictions=all_preds, label_ids=all_labels, metrics=metrics, num_samples=len(dataset)) + + +def np_to_pt_generators(np_object, device): + if isinstance(np_object, np.random.RandomState): + return torch.Generator(device=device).manual_seed(int(np_object.get_state()[1][0])) + elif isinstance(np_object, np.random.Generator): + return torch.Generator(device=device).manual_seed(int(np_object.bit_generator.state[1][0])) + elif isinstance(np_object, list) and isinstance(np_object[0], (np.random.RandomState, np.random.Generator)): + return [np_to_pt_generators(a, device) for a in np_object] + elif isinstance(np_object, dict) and isinstance( + next(iter(np_object.values())), (np.random.RandomState, np.random.Generator) + ): + return {k: np_to_pt_generators(v, device) for k, v in np_object.items()} + else: + return np_object diff --git a/optimum/pipelines/diffusers/pipeline_latent_consistency.py b/optimum/pipelines/diffusers/pipeline_latent_consistency.py deleted file mode 100644 index 630d463de7..0000000000 --- a/optimum/pipelines/diffusers/pipeline_latent_consistency.py +++ /dev/null @@ -1,230 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. 
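# A small sketch of what the np_to_pt_generators helper added to
# optimum/onnxruntime/utils.py does for a single numpy RNG: it seeds a
# torch.Generator from the first word of the numpy state, so the torch-based
# diffusers pipelines can accept the numpy generators the old mixins took.
import numpy as np
import torch

rng = np.random.RandomState(42)
seed = int(rng.get_state()[1][0])                        # first word of the MT19937 key
generator = torch.Generator(device="cpu").manual_seed(seed)
noise = torch.randn(1, 4, 64, 64, generator=generator)  # deterministic torch sampling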
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -from typing import Callable, List, Optional, Union - -import numpy as np -import torch -from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput - -from .pipeline_stable_diffusion import StableDiffusionPipelineMixin - - -logger = logging.getLogger(__name__) - - -class LatentConsistencyPipelineMixin(StableDiffusionPipelineMixin): - # Adapted from https://github.com/huggingface/diffusers/blob/v0.22.0/src/diffusers/pipelines/latent_consistency/pipeline_latent_consistency.py#L264 - def __call__( - self, - prompt: Optional[Union[str, List[str]]] = None, - height: Optional[int] = None, - width: Optional[int] = None, - num_inference_steps: int = 4, - original_inference_steps: int = None, - guidance_scale: float = 8.5, - num_images_per_prompt: int = 1, - generator: Optional[Union[np.random.RandomState, torch.Generator]] = None, - latents: Optional[np.ndarray] = None, - prompt_embeds: Optional[np.ndarray] = None, - output_type: str = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, np.ndarray], None]] = None, - callback_steps: int = 1, - ): - r""" - Function invoked when calling the pipeline for generation. - - Args: - prompt (`Optional[Union[str, List[str]]]`, defaults to None): - The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. - instead. - height (`Optional[int]`, defaults to None): - The height in pixels of the generated image. - width (`Optional[int]`, defaults to None): - The width in pixels of the generated image. - num_inference_steps (`int`, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - guidance_scale (`float`, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - num_images_per_prompt (`int`, defaults to 1): - The number of images to generate per prompt. - generator (`Optional[Union[np.random.RandomState, torch.Generator]]`, defaults to `None`): - A np.random.RandomState to make generation deterministic. - latents (`Optional[np.ndarray]`, defaults to `None`): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will ge generated by sampling using the supplied random `generator`. - prompt_embeds (`Optional[np.ndarray]`, defaults to `None`): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. 
If not - provided, text embeddings will be generated from `prompt` input argument. - output_type (`str`, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - callback (Optional[Callable], defaults to `None`): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. - guidance_rescale (`float`, defaults to 0.0): - Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are - Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of - [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). - Guidance rescale factor should fix overexposure when using zero terminal SNR. - - Returns: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. - When returning a tuple, the first element is a list with the generated images, and the second element is a - list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, according to the `safety_checker`. - """ - height = height or self.unet.config["sample_size"] * self.vae_scale_factor - width = width or self.unet.config["sample_size"] * self.vae_scale_factor - - # Don't need to get negative prompts due to LCM guided distillation - negative_prompt = None - negative_prompt_embeds = None - - # check inputs. 
Raise error if not correct - self.check_inputs( - prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds - ) - - # define call parameters - if isinstance(prompt, str): - batch_size = 1 - elif isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - if generator is None: - generator = np.random.RandomState() - - prompt_embeds = self._encode_prompt( - prompt, - num_images_per_prompt, - False, - negative_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - ) - - # set timesteps - self.scheduler.set_timesteps(num_inference_steps, original_inference_steps=original_inference_steps) - timesteps = self.scheduler.timesteps - - latents = self.prepare_latents( - batch_size * num_images_per_prompt, - self.unet.config["in_channels"], - height, - width, - prompt_embeds.dtype, - generator, - latents, - ) - - bs = batch_size * num_images_per_prompt - # get Guidance Scale Embedding - w = np.full(bs, guidance_scale - 1, dtype=prompt_embeds.dtype) - w_embedding = self.get_guidance_scale_embedding( - w, embedding_dim=self.unet.config["time_cond_proj_dim"], dtype=prompt_embeds.dtype - ) - - # Adapted from diffusers to extend it for other runtimes than ORT - timestep_dtype = self.unet.input_dtype.get("timestep", np.float32) - - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - for i, t in enumerate(self.progress_bar(timesteps)): - timestep = np.array([t], dtype=timestep_dtype) - noise_pred = self.unet( - sample=latents, - timestep=timestep, - encoder_hidden_states=prompt_embeds, - timestep_cond=w_embedding, - )[0] - - # compute the previous noisy sample x_t -> x_t-1 - latents, denoised = self.scheduler.step( - torch.from_numpy(noise_pred), t, torch.from_numpy(latents), return_dict=False - ) - latents, denoised = latents.numpy(), denoised.numpy() - - # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) - - if output_type == "latent": - image = denoised - has_nsfw_concept = None - else: - denoised /= self.vae_decoder.config["scaling_factor"] - # it seems likes there is a strange result for using half-precision vae decoder if batchsize>1 - image = np.concatenate( - [self.vae_decoder(latent_sample=denoised[i : i + 1])[0] for i in range(denoised.shape[0])] - ) - image, has_nsfw_concept = self.run_safety_checker(image) - - if has_nsfw_concept is None: - do_denormalize = [True] * image.shape[0] - else: - do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] - - image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) - - if not return_dict: - return (image, has_nsfw_concept) - - return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) - - # Adapted from https://github.com/huggingface/diffusers/blob/v0.22.0/src/diffusers/pipelines/latent_consistency/pipeline_latent_consistency.py#L264 - def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=None): - """ - See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 - - Args: - timesteps (`torch.Tensor`): - generate embedding vectors at these timesteps - embedding_dim (`int`, *optional*, defaults to 512): - dimension of the embeddings to generate - dtype: - data type of the generated embeddings - - Returns: - 
`torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)` - """ - w = w * 1000 - half_dim = embedding_dim // 2 - emb = np.log(10000.0) / (half_dim - 1) - emb = np.exp(np.arange(half_dim, dtype=dtype) * -emb) - emb = w[:, None] * emb[None, :] - emb = np.concatenate([np.sin(emb), np.cos(emb)], axis=1) - - if embedding_dim % 2 == 1: # zero pad - emb = np.pad(emb, [(0, 0), (0, 1)]) - - assert emb.shape == (w.shape[0], embedding_dim) - return emb diff --git a/optimum/pipelines/diffusers/pipeline_stable_diffusion.py b/optimum/pipelines/diffusers/pipeline_stable_diffusion.py deleted file mode 100644 index 6cc47fab1b..0000000000 --- a/optimum/pipelines/diffusers/pipeline_stable_diffusion.py +++ /dev/null @@ -1,427 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import inspect -import logging -from typing import Callable, List, Optional, Union - -import numpy as np -import torch -from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput - -from .pipeline_utils import DiffusionPipelineMixin, rescale_noise_cfg - - -logger = logging.getLogger(__name__) - - -class StableDiffusionPipelineMixin(DiffusionPipelineMixin): - # Copied from https://github.com/huggingface/diffusers/blob/v0.17.1/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L114 - def _encode_prompt( - self, - prompt: Union[str, List[str]], - num_images_per_prompt: int, - do_classifier_free_guidance: bool, - negative_prompt: Optional[Union[str, list]], - prompt_embeds: Optional[np.ndarray] = None, - negative_prompt_embeds: Optional[np.ndarray] = None, - ): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`Union[str, List[str]]`): - prompt to be encoded - num_images_per_prompt (`int`): - number of images that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not - negative_prompt (`Optional[Union[str, list]]`): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). - prompt_embeds (`Optional[np.ndarray]`, defaults to `None`): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`Optional[np.ndarray]`, defaults to `None`): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. 
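# Numpy sketch of the sinusoidal guidance-scale embedding removed above: the
# LCM __call__ passes w = guidance_scale - 1 per sample, and the helper scales
# it by 1000 before building sin/cos features. Shapes here are illustrative.
import numpy as np

def guidance_scale_embedding(w, embedding_dim=256, dtype=np.float32):
    w = w * 1000
    half_dim = embedding_dim // 2
    emb = np.log(10000.0) / (half_dim - 1)
    emb = np.exp(np.arange(half_dim, dtype=dtype) * -emb)
    emb = w[:, None] * emb[None, :]
    emb = np.concatenate([np.sin(emb), np.cos(emb)], axis=1)
    if embedding_dim % 2 == 1:  # zero pad odd dimensions
        emb = np.pad(emb, [(0, 0), (0, 1)])
    return emb

w = np.full(4, 8.5 - 1, dtype=np.float32)   # a batch of 4 images, guidance_scale=8.5
print(guidance_scale_embedding(w).shape)    # (4, 256)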
- """ - if isinstance(prompt, str): - batch_size = 1 - elif isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - if prompt_embeds is None: - # get prompt text embeddings - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="np", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids - - if not np.array_equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode( - untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] - ) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" - ) - - prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0] - - prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0) - - # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance and negative_prompt_embeds is None: - uncond_tokens: List[str] - if negative_prompt is None: - uncond_tokens = [""] * batch_size - elif type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif isinstance(negative_prompt, str): - uncond_tokens = [negative_prompt] * batch_size - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) - else: - uncond_tokens = negative_prompt - - max_length = prompt_embeds.shape[1] - uncond_input = self.tokenizer( - uncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="np", - ) - negative_prompt_embeds = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0] - - if do_classifier_free_guidance: - negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0) - - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - prompt_embeds = np.concatenate([negative_prompt_embeds, prompt_embeds]) - - return prompt_embeds - - # Copied from https://github.com/huggingface/diffusers/blob/v0.17.1/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L217 - def check_inputs( - self, - prompt: Union[str, List[str]], - height: Optional[int], - width: Optional[int], - callback_steps: int, - negative_prompt: Optional[str] = None, - prompt_embeds: Optional[np.ndarray] = None, - negative_prompt_embeds: Optional[np.ndarray] = None, - ): - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." 
- ) - - if prompt is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt is None and prompt_embeds is None: - raise ValueError( - "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." - ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if negative_prompt is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" - f" {negative_prompt_embeds}. Please make sure to only forward one of the two." - ) - - if prompt_embeds is not None and negative_prompt_embeds is not None: - if prompt_embeds.shape != negative_prompt_embeds.shape: - raise ValueError( - "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" - f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" - f" {negative_prompt_embeds.shape}." - ) - - # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents - def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, generator, latents=None): - shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - - if latents is None: - if isinstance(generator, np.random.RandomState): - latents = generator.randn(*shape).astype(dtype) - elif isinstance(generator, torch.Generator): - latents = torch.randn(*shape, generator=generator).numpy().astype(dtype) - else: - raise ValueError( - f"Expected `generator` to be of type `np.random.RandomState` or `torch.Generator`, but got" - f" {type(generator)}." - ) - elif latents.shape != shape: - raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") - - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * np.float64(self.scheduler.init_noise_sigma) - - return latents - - # Adapted from https://github.com/huggingface/diffusers/blob/v0.17.1/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L264 - def __call__( - self, - prompt: Optional[Union[str, List[str]]] = None, - height: Optional[int] = None, - width: Optional[int] = None, - num_inference_steps: int = 50, - guidance_scale: float = 7.5, - negative_prompt: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: int = 1, - eta: float = 0.0, - generator: Optional[Union[np.random.RandomState, torch.Generator]] = None, - latents: Optional[np.ndarray] = None, - prompt_embeds: Optional[np.ndarray] = None, - negative_prompt_embeds: Optional[np.ndarray] = None, - output_type: str = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, np.ndarray], None]] = None, - callback_steps: int = 1, - guidance_rescale: float = 0.0, - ): - r""" - Function invoked when calling the pipeline for generation. 
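# Sketch of how the removed numpy mixin batches embeddings for classifier-free
# guidance: negative and positive prompt embeddings are stacked so the UNet can
# process both halves in one forward pass. Shapes below are placeholder values.
import numpy as np

num_images_per_prompt = 2
pos = np.ones((1, 77, 768), dtype=np.float32)    # stand-in for text_encoder(prompt)
neg = np.zeros((1, 77, 768), dtype=np.float32)   # stand-in for text_encoder(negative prompt)
pos = np.repeat(pos, num_images_per_prompt, axis=0)
neg = np.repeat(neg, num_images_per_prompt, axis=0)
prompt_embeds = np.concatenate([neg, pos])       # (2 * batch * num_images_per_prompt, 77, 768)
print(prompt_embeds.shape)                       # (4, 77, 768)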
- - Args: - prompt (`Optional[Union[str, List[str]]]`, defaults to None): - The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. - instead. - height (`Optional[int]`, defaults to None): - The height in pixels of the generated image. - width (`Optional[int]`, defaults to None): - The width in pixels of the generated image. - num_inference_steps (`int`, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - guidance_scale (`float`, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - negative_prompt (`Optional[Union[str, list]]`): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` - is less than `1`). - num_images_per_prompt (`int`, defaults to 1): - The number of images to generate per prompt. - eta (`float`, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`Optional[Union[np.random.RandomState, torch.Generator]]`, defaults to `None`):: - A np.random.RandomState to make generation deterministic. - latents (`Optional[np.ndarray]`, defaults to `None`): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will ge generated by sampling using the supplied random `generator`. - prompt_embeds (`Optional[np.ndarray]`, defaults to `None`): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`Optional[np.ndarray]`, defaults to `None`): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - output_type (`str`, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - callback (Optional[Callable], defaults to `None`): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. 
- guidance_rescale (`float`, defaults to 0.0): - Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are - Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of - [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). - Guidance rescale factor should fix overexposure when using zero terminal SNR. - - Returns: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. - When returning a tuple, the first element is a list with the generated images, and the second element is a - list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, according to the `safety_checker`. - """ - height = height or self.unet.config.get("sample_size", 64) * self.vae_scale_factor - width = width or self.unet.config.get("sample_size", 64) * self.vae_scale_factor - - # check inputs. Raise error if not correct - self.check_inputs( - prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds - ) - - # define call parameters - if isinstance(prompt, str): - batch_size = 1 - elif isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - if generator is None: - generator = np.random.RandomState() - - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 - - prompt_embeds = self._encode_prompt( - prompt, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - ) - - # set timesteps - self.scheduler.set_timesteps(num_inference_steps) - timesteps = self.scheduler.timesteps - - latents = self.prepare_latents( - batch_size * num_images_per_prompt, - self.unet.config.get("in_channels", 4), - height, - width, - prompt_embeds.dtype, - generator, - latents, - ) - - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - # Adapted from diffusers to extend it for other runtimes than ORT - timestep_dtype = self.unet.input_dtype.get("timestep", np.float32) - - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - for i, t in enumerate(self.progress_bar(timesteps)): - # expand the latents if we are doing classifier free guidance - latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t) - latent_model_input = latent_model_input.cpu().numpy() - - # predict the noise residual - timestep = np.array([t], dtype=timestep_dtype) - noise_pred = self.unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds) - noise_pred = noise_pred[0] - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - if guidance_rescale > 0.0: - # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf - noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) - - # compute the previous noisy sample x_t -> x_t-1 - scheduler_output = self.scheduler.step( - torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs - ) - latents = scheduler_output.prev_sample.numpy() - - # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) - - if output_type == "latent": - image = latents - has_nsfw_concept = None - else: - latents /= self.vae_decoder.config.get("scaling_factor", 0.18215) - # it seems likes there is a strange result for using half-precision vae decoder if batchsize>1 - image = np.concatenate( - [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])] - ) - image, has_nsfw_concept = self.run_safety_checker(image) - - if has_nsfw_concept is None: - do_denormalize = [True] * image.shape[0] - else: - do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] - - image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) - - if not return_dict: - return (image, has_nsfw_concept) - - return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) - - def run_safety_checker(self, image: np.ndarray): - if self.safety_checker is None: - has_nsfw_concept = None - else: - feature_extractor_input = self.image_processor.numpy_to_pil(image) - safety_checker_input = self.feature_extractor( - feature_extractor_input, return_tensors="np" - ).pixel_values.astype(image.dtype) - images, has_nsfw_concept = [], [] - for i in range(image.shape[0]): - image_i, has_nsfw_concept_i = self.safety_checker( - clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1] - ) - images.append(image_i) - has_nsfw_concept.append(has_nsfw_concept_i[0]) - image = np.concatenate(images) - - return image, has_nsfw_concept diff --git a/optimum/pipelines/diffusers/pipeline_stable_diffusion_img2img.py 
b/optimum/pipelines/diffusers/pipeline_stable_diffusion_img2img.py deleted file mode 100644 index a66035a789..0000000000 --- a/optimum/pipelines/diffusers/pipeline_stable_diffusion_img2img.py +++ /dev/null @@ -1,309 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import inspect -from typing import Callable, List, Optional, Union - -import numpy as np -import PIL.Image -import torch -from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput - -from .pipeline_stable_diffusion import StableDiffusionPipelineMixin - - -class StableDiffusionImg2ImgPipelineMixin(StableDiffusionPipelineMixin): - # Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionImg2ImgPipeline.check_inputs - def check_inputs( - self, - prompt: Union[str, List[str]], - strength: float, - callback_steps: int, - negative_prompt: Optional[Union[str, List[str]]] = None, - prompt_embeds: Optional[np.ndarray] = None, - negative_prompt_embeds: Optional[np.ndarray] = None, - ): - if strength < 0 or strength > 1: - raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") - - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." - ) - - if prompt is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt is None and prompt_embeds is None: - raise ValueError( - "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." - ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if negative_prompt is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" - f" {negative_prompt_embeds}. Please make sure to only forward one of the two." - ) - - if prompt_embeds is not None and negative_prompt_embeds is not None: - if prompt_embeds.shape != negative_prompt_embeds.shape: - raise ValueError( - "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" - f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" - f" {negative_prompt_embeds.shape}." 
- ) - - # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents - def prepare_latents(self, image, timesteps, batch_size, num_images_per_prompt, dtype, generator=None): - batch_size = batch_size * num_images_per_prompt - - if image.shape[1] == 4: - init_latents = image - else: - init_latents = self.vae_encoder(sample=image)[0] * self.vae_decoder.config.get("scaling_factor", 0.18215) - - if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: - # expand init_latents for batch_size - additional_image_per_prompt = batch_size // init_latents.shape[0] - init_latents = np.concatenate([init_latents] * additional_image_per_prompt, axis=0) - elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: - raise ValueError( - f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." - ) - else: - init_latents = np.concatenate([init_latents], axis=0) - - # add noise to latents using the timesteps - if isinstance(generator, np.random.RandomState): - noise = generator.randn(*init_latents.shape).astype(dtype) - elif isinstance(generator, torch.Generator): - noise = torch.randn(*init_latents.shape, generator=generator).numpy().astype(dtype) - else: - raise ValueError( - f"Expected `generator` to be of type `np.random.RandomState` or `torch.Generator`, but got" - f" {type(generator)}." - ) - - init_latents = self.scheduler.add_noise( - torch.from_numpy(init_latents), torch.from_numpy(noise), torch.from_numpy(timesteps) - ).numpy() - - return init_latents - - # Adapted from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionImg2ImgPipeline.__call__ - def __call__( - self, - prompt: Optional[Union[str, List[str]]] = None, - image: Union[np.ndarray, PIL.Image.Image] = None, - strength: float = 0.8, - num_inference_steps: int = 50, - guidance_scale: float = 7.5, - negative_prompt: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: int = 1, - eta: float = 0.0, - generator: Optional[Union[np.random.RandomState, torch.Generator]] = None, - prompt_embeds: Optional[np.ndarray] = None, - negative_prompt_embeds: Optional[np.ndarray] = None, - output_type: str = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, np.ndarray], None]] = None, - callback_steps: int = 1, - ): - r""" - Function invoked when calling the pipeline for generation. - - - Args: - prompt (`Optional[Union[str, List[str]]]`, defaults to None): - The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. - instead. - image (`Union[np.ndarray, PIL.Image.Image]`): - `Image`, or tensor representing an image batch which will be upscaled. - strength (`float`, defaults to 0.8): - Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` - will be used as a starting point, adding more noise to it the larger the `strength`. The number of - denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will - be maximum and the denoising process will run for the full number of iterations specified in - `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. - num_inference_steps (`int`, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. 
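# Sketch of the image-to-latent initialisation done by the removed img2img
# prepare_latents: VAE-encode the input image, apply the scaling factor, then
# let the scheduler add noise at the starting timestep. The 0.18215 scaling
# factor and the alpha value below are illustrative, not read from a config.
import numpy as np

scaling_factor = 0.18215
init_latents = np.random.randn(1, 4, 64, 64).astype(np.float32) * scaling_factor  # stands in for vae_encoder(image) * scaling
noise = np.random.randn(*init_latents.shape).astype(np.float32)
alpha_prod_t = 0.5  # cumulative alpha at the chosen start timestep
# scheduler.add_noise performs a DDPM-style mix of the clean latents and noise:
noisy_latents = np.sqrt(alpha_prod_t) * init_latents + np.sqrt(1 - alpha_prod_t) * noise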
- guidance_scale (`float`, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - negative_prompt (`Optional[Union[str, list]]`): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` - is less than `1`). - num_images_per_prompt (`int`, defaults to 1): - The number of images to generate per prompt. - eta (`float`, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`Optional[Union[np.random.RandomState, torch.Generator]]`, defaults to `None`): - A np.random.RandomState to make generation deterministic. - prompt_embeds (`Optional[np.ndarray]`, defaults to `None`): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`Optional[np.ndarray]`, defaults to `None`): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - output_type (`str`, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - callback (Optional[Callable], defaults to `None`): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. - - - Returns: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. - When returning a tuple, the first element is a list with the generated images, and the second element is a - list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, according to the `safety_checker`. - """ - - # check inputs. 
Raise error if not correct - self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) - - # define call parameters - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - if generator is None: - generator = np.random.RandomState() - - # set timesteps - self.scheduler.set_timesteps(num_inference_steps) - - image = self.image_processor.preprocess(image) - - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 - - prompt_embeds = self._encode_prompt( - prompt, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - ) - - latents_dtype = prompt_embeds.dtype - image = image.astype(latents_dtype) - scaling_factor = self.vae_decoder.config.get("scaling_factor", 0.18215) - - # get the original timestep using init_timestep - offset = self.scheduler.config.get("steps_offset", 0) - init_timestep = int(num_inference_steps * strength) + offset - init_timestep = min(init_timestep, num_inference_steps) - - timesteps = self.scheduler.timesteps.numpy()[-init_timestep] - timesteps = np.array([timesteps] * batch_size * num_images_per_prompt) - - # 5. Prepare latent variables - latents = self.prepare_latents(image, timesteps, batch_size, num_images_per_prompt, latents_dtype, generator) - - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
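# Sketch of the strength -> timestep bookkeeping in the removed img2img
# __call__: `strength` controls how far into the schedule denoising starts.
# The numbers are example values (50 steps, strength 0.8, steps_offset 1).
num_inference_steps, strength, offset = 50, 0.8, 1
init_timestep = min(int(num_inference_steps * strength) + offset, num_inference_steps)
t_start = max(num_inference_steps - init_timestep + offset, 0)
print(init_timestep, t_start)  # 41 10 -> noise is added at index -41, and 40 steps (10..49) are run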
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - t_start = max(num_inference_steps - init_timestep + offset, 0) - timesteps = self.scheduler.timesteps[t_start:].numpy() - - # Adapted from diffusers to extend it for other runtimes than ORT - timestep_dtype = self.unet.input_dtype.get("timestep", np.float32) - - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - for i, t in enumerate(self.progress_bar(timesteps)): - # expand the latents if we are doing classifier free guidance - latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t) - latent_model_input = latent_model_input.cpu().numpy() - - # predict the noise residual - timestep = np.array([t], dtype=timestep_dtype) - noise_pred = self.unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds)[ - 0 - ] - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - scheduler_output = self.scheduler.step( - torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs - ) - latents = scheduler_output.prev_sample.numpy() - - # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - if callback is not None and i % callback_steps == 0: - step_idx = i // getattr(self.scheduler, "order", 1) - callback(step_idx, t, latents) - - if output_type == "latent": - image = latents - has_nsfw_concept = None - else: - latents /= scaling_factor - # it seems likes there is a strange result for using half-precision vae decoder if batchsize>1 - image = np.concatenate( - [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])] - ) - image, has_nsfw_concept = self.run_safety_checker(image) - - if has_nsfw_concept is None: - do_denormalize = [True] * image.shape[0] - else: - do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] - - image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) - - if not return_dict: - return (image, has_nsfw_concept) - - return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/optimum/pipelines/diffusers/pipeline_stable_diffusion_inpaint.py b/optimum/pipelines/diffusers/pipeline_stable_diffusion_inpaint.py deleted file mode 100644 index cb3c7db96e..0000000000 --- a/optimum/pipelines/diffusers/pipeline_stable_diffusion_inpaint.py +++ /dev/null @@ -1,353 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -import inspect -from typing import Callable, List, Optional, Union - -import numpy as np -import PIL.Image -import torch -from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput -from diffusers.utils import PIL_INTERPOLATION - -from .pipeline_stable_diffusion import StableDiffusionPipelineMixin - - -def prepare_mask_and_masked_image(image, mask, latents_shape, vae_scale_factor): - image = np.array( - image.convert("RGB").resize((latents_shape[1] * vae_scale_factor, latents_shape[0] * vae_scale_factor)) - ) - image = image[None].transpose(0, 3, 1, 2) - image = image.astype(np.float32) / 127.5 - 1.0 - - image_mask = np.array( - mask.convert("L").resize((latents_shape[1] * vae_scale_factor, latents_shape[0] * vae_scale_factor)) - ) - masked_image = image * (image_mask < 127.5) - - mask = mask.resize((latents_shape[1], latents_shape[0]), PIL_INTERPOLATION["nearest"]) - mask = np.array(mask.convert("L")) - mask = mask.astype(np.float32) / 255.0 - mask = mask[None, None] - mask[mask < 0.5] = 0 - mask[mask >= 0.5] = 1 - - return mask, masked_image - - -class StableDiffusionInpaintPipelineMixin(StableDiffusionPipelineMixin): - # Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline.check_inputs - def check_inputs( - self, - prompt: Union[str, List[str]], - height: Optional[int], - width: Optional[int], - callback_steps: int, - negative_prompt: Optional[str] = None, - prompt_embeds: Optional[np.ndarray] = None, - negative_prompt_embeds: Optional[np.ndarray] = None, - ): - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." - ) - - if prompt is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt is None and prompt_embeds is None: - raise ValueError( - "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." - ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if negative_prompt is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" - f" {negative_prompt_embeds}. Please make sure to only forward one of the two." - ) - - if prompt_embeds is not None and negative_prompt_embeds is not None: - if prompt_embeds.shape != negative_prompt_embeds.shape: - raise ValueError( - "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" - f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" - f" {negative_prompt_embeds.shape}." 
- ) - - @torch.no_grad() - def __call__( - self, - prompt: Union[str, List[str]], - image: PIL.Image.Image, - mask_image: PIL.Image.Image, - height: Optional[int] = None, - width: Optional[int] = None, - num_inference_steps: int = 50, - guidance_scale: float = 7.5, - negative_prompt: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: int = 1, - eta: float = 0.0, - generator: Optional[Union[np.random.RandomState, torch.Generator]] = None, - latents: Optional[np.ndarray] = None, - prompt_embeds: Optional[np.ndarray] = None, - negative_prompt_embeds: Optional[np.ndarray] = None, - output_type: str = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, np.ndarray], None]] = None, - callback_steps: int = 1, - ): - r""" - Function invoked when calling the pipeline for generation. - - Args: - prompt (`Union[str, List[str]]`): - The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. - instead. - image (`PIL.Image.Image`): - `Image`, or tensor representing an image batch which will be upscaled. - mask_image (`PIL.Image.Image`): - `Image`, or tensor representing a masked image batch which will be upscaled. - height (`Optional[int]`, defaults to None): - The height in pixels of the generated image. - width (`Optional[int]`, defaults to None): - The width in pixels of the generated image. - num_inference_steps (`int`, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - guidance_scale (`float`, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - negative_prompt (`Optional[Union[str, list]]`): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` - is less than `1`). - num_images_per_prompt (`int`, defaults to 1): - The number of images to generate per prompt. - eta (`float`, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`Optional[np.random.RandomState]`, defaults to `None`):: - A np.random.RandomState to make generation deterministic. - latents (`Optional[np.ndarray]`, defaults to `None`): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will ge generated by sampling using the supplied random `generator`. - prompt_embeds (`Optional[np.ndarray]`, defaults to `None`): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`Optional[np.ndarray]`, defaults to `None`): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. 
If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - output_type (`str`, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - callback (Optional[Callable], defaults to `None`): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. - - Returns: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. - When returning a tuple, the first element is a list with the generated images, and the second element is a - list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, according to the `safety_checker`. - """ - height = height or self.unet.config.get("sample_size", 64) * self.vae_scale_factor - width = width or self.unet.config.get("sample_size", 64) * self.vae_scale_factor - - # check inputs. Raise error if not correct - self.check_inputs( - prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds - ) - - # define call parameters - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - if generator is None: - generator = np.random.RandomState() - - # set timesteps - self.scheduler.set_timesteps(num_inference_steps) - timesteps = self.scheduler.timesteps - - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 - - prompt_embeds = self._encode_prompt( - prompt, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - ) - - num_channels_latents = self.vae_decoder.config.get("latent_channels", 4) - num_channels_unet = self.unet.config.get("in_channels", 9) - latents_shape = ( - batch_size * num_images_per_prompt, - num_channels_latents, - height // self.vae_scale_factor, - width // self.vae_scale_factor, - ) - latents_dtype = prompt_embeds.dtype - - if latents is None: - if isinstance(generator, np.random.RandomState): - latents = generator.randn(*latents_shape).astype(latents_dtype) - elif isinstance(generator, torch.Generator): - latents = torch.randn(*latents_shape, generator=generator).numpy().astype(latents_dtype) - else: - raise ValueError( - f"Expected `generator` to be of type `np.random.RandomState` or `torch.Generator`, but got" - f" {type(generator)}." 
- ) - elif latents.shape != latents_shape: - raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") - - # prepare mask and masked_image - mask, masked_image = prepare_mask_and_masked_image( - image, mask_image, latents_shape[-2:], self.vae_scale_factor - ) - mask = mask.astype(latents.dtype) - masked_image = masked_image.astype(latents.dtype) - - masked_image_latents = self.vae_encoder(sample=masked_image)[0] - - scaling_factor = self.vae_decoder.config.get("scaling_factor", 0.18215) - masked_image_latents = scaling_factor * masked_image_latents - - # duplicate mask and masked_image_latents for each generation per prompt - mask = mask.repeat(batch_size * num_images_per_prompt, 0) - masked_image_latents = masked_image_latents.repeat(batch_size * num_images_per_prompt, 0) - - mask = np.concatenate([mask] * 2) if do_classifier_free_guidance else mask - masked_image_latents = ( - np.concatenate([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents - ) - - # check that sizes of mask, masked image and latents match - if num_channels_unet == 9: - # default case for runwayml/stable-diffusion-inpainting - num_channels_mask = mask.shape[1] - num_channels_masked_image = masked_image_latents.shape[1] - if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: - raise ValueError( - f"Incorrect configuration settings! The config of `pipeline.unet`: expects" - f" {num_channels_unet} but received `num_channels_latents`: {num_channels_latents} +" - f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" - f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" - " `pipeline.unet` or your `mask_image` or `image` input." - ) - elif num_channels_unet != 4: - raise ValueError( - f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {num_channels_unet}." - ) - - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * np.float64(self.scheduler.init_noise_sigma) - - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - # Adapted from diffusers to extend it for other runtimes than ORT - timestep_dtype = self.unet.input_dtype.get("timestep", np.float32) - - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - for i, t in enumerate(self.progress_bar(timesteps)): - # expand the latents if we are doing classifier free guidance - latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents - # concat latents, mask, masked_image_latnets in the channel dimension - latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t) - latent_model_input = latent_model_input.cpu().numpy() - if num_channels_unet == 9: - latent_model_input = np.concatenate([latent_model_input, mask, masked_image_latents], axis=1) - - # predict the noise residual - timestep = np.array([t], dtype=timestep_dtype) - noise_pred = self.unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds)[ - 0 - ] - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - scheduler_output = self.scheduler.step( - torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs - ) - latents = scheduler_output.prev_sample.numpy() - - # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) - - if output_type == "latent": - image = latents - has_nsfw_concept = None - else: - latents /= scaling_factor - # it seems likes there is a strange result for using half-precision vae decoder if batchsize>1 - image = np.concatenate( - [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])] - ) - image, has_nsfw_concept = self.run_safety_checker(image) - - if has_nsfw_concept is None: - do_denormalize = [True] * image.shape[0] - else: - do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] - - image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) - - if not return_dict: - return (image, has_nsfw_concept) - - return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/optimum/pipelines/diffusers/pipeline_stable_diffusion_xl.py b/optimum/pipelines/diffusers/pipeline_stable_diffusion_xl.py deleted file mode 100644 index 0407c16a77..0000000000 --- a/optimum/pipelines/diffusers/pipeline_stable_diffusion_xl.py +++ /dev/null @@ -1,506 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -import inspect -import logging -from typing import Any, Callable, Dict, List, Optional, Tuple, Union - -import numpy as np -import torch -from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput - -from .pipeline_utils import DiffusionPipelineMixin, rescale_noise_cfg - - -logger = logging.getLogger(__name__) - - -class StableDiffusionXLPipelineMixin(DiffusionPipelineMixin): - # Adapted from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt - def _encode_prompt( - self, - prompt: Union[str, List[str]], - num_images_per_prompt: int, - do_classifier_free_guidance: bool, - negative_prompt: Optional[Union[str, list]], - prompt_embeds: Optional[np.ndarray] = None, - negative_prompt_embeds: Optional[np.ndarray] = None, - pooled_prompt_embeds: Optional[np.ndarray] = None, - negative_pooled_prompt_embeds: Optional[np.ndarray] = None, - ): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`Union[str, List[str]]`): - prompt to be encoded - num_images_per_prompt (`int`): - number of images that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not - negative_prompt (`Optional[Union[str, list]]`): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). - prompt_embeds (`Optional[np.ndarray]`, defaults to `None`): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`Optional[np.ndarray]`, defaults to `None`): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - pooled_prompt_embeds (`Optional[np.ndarray]`, defaults to `None`): - Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. - If not provided, pooled text embeddings will be generated from `prompt` input argument. - negative_pooled_prompt_embeds (`Optional[np.ndarray]`, defaults to `None`): - Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` - input argument. 
- """ - if isinstance(prompt, str): - batch_size = 1 - elif isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - # Define tokenizers and text encoders - tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] - text_encoders = ( - [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] - ) - - if prompt_embeds is None: - prompt_embeds_list = [] - for tokenizer, text_encoder in zip(tokenizers, text_encoders): - # get prompt text embeddings - text_inputs = tokenizer( - prompt, - padding="max_length", - max_length=tokenizer.model_max_length, - truncation=True, - return_tensors="np", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="np").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not np.array_equal( - text_input_ids, untruncated_ids - ): - removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {tokenizer.model_max_length} tokens: {removed_text}" - ) - - prompt_embeds = text_encoder( - input_ids=text_input_ids.astype(text_encoder.input_dtype.get("input_ids", np.int32)) - ) - pooled_prompt_embeds = prompt_embeds[0] - prompt_embeds = prompt_embeds[-2] - prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0) - prompt_embeds_list.append(prompt_embeds) - - prompt_embeds = np.concatenate(prompt_embeds_list, axis=-1) - - # get unconditional embeddings for classifier free guidance - zero_out_negative_prompt = negative_prompt is None and self.config["force_zeros_for_empty_prompt"] - if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: - negative_prompt_embeds = np.zeros_like(prompt_embeds) - negative_pooled_prompt_embeds = np.zeros_like(pooled_prompt_embeds) - elif do_classifier_free_guidance and negative_prompt_embeds is None: - negative_prompt = negative_prompt or "" - if prompt is not None and type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif isinstance(negative_prompt, str): - uncond_tokens = [negative_prompt] - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." 
- ) - else: - uncond_tokens = negative_prompt - - negative_prompt_embeds_list = [] - for tokenizer, text_encoder in zip(tokenizers, text_encoders): - max_length = prompt_embeds.shape[1] - uncond_input = tokenizer( - uncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="np", - ) - negative_prompt_embeds = text_encoder( - input_ids=uncond_input.input_ids.astype(text_encoder.input_dtype.get("input_ids", np.int32)) - ) - negative_pooled_prompt_embeds = negative_prompt_embeds[0] - negative_prompt_embeds = negative_prompt_embeds[-2] - - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0) - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - negative_prompt_embeds_list.append(negative_prompt_embeds) - negative_prompt_embeds = np.concatenate(negative_prompt_embeds_list, axis=-1) - - pooled_prompt_embeds = np.repeat(pooled_prompt_embeds, num_images_per_prompt, axis=0) - negative_pooled_prompt_embeds = np.repeat(negative_pooled_prompt_embeds, num_images_per_prompt, axis=0) - - return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds - - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.check_inputs - def check_inputs( - self, - prompt: Union[str, List[str]], - height: Optional[int], - width: Optional[int], - callback_steps: int, - negative_prompt: Optional[str] = None, - prompt_embeds: Optional[np.ndarray] = None, - negative_prompt_embeds: Optional[np.ndarray] = None, - pooled_prompt_embeds: Optional[np.ndarray] = None, - negative_pooled_prompt_embeds: Optional[np.ndarray] = None, - ): - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." - ) - - if prompt is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt is None and prompt_embeds is None: - raise ValueError( - "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." - ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if negative_prompt is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" - f" {negative_prompt_embeds}. Please make sure to only forward one of the two." - ) - - if prompt_embeds is not None and negative_prompt_embeds is not None: - if prompt_embeds.shape != negative_prompt_embeds.shape: - raise ValueError( - "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" - f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" - f" {negative_prompt_embeds.shape}." 
- ) - - if prompt_embeds is not None and pooled_prompt_embeds is None: - raise ValueError( - "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." - ) - - if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: - raise ValueError( - "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." - ) - - # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents - def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, generator, latents=None): - shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - - if latents is None: - if isinstance(generator, np.random.RandomState): - latents = generator.randn(*shape).astype(dtype) - elif isinstance(generator, torch.Generator): - latents = torch.randn(*shape, generator=generator).numpy().astype(dtype) - else: - raise ValueError( - f"Expected `generator` to be of type `np.random.RandomState` or `torch.Generator`, but got" - f" {type(generator)}." - ) - elif latents.shape != shape: - raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") - - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * np.float64(self.scheduler.init_noise_sigma) - - return latents - - # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs - def prepare_extra_step_kwargs(self, generator, eta): - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - - extra_step_kwargs = {} - - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - if accepts_eta: - extra_step_kwargs["eta"] = eta - - return extra_step_kwargs - - # Adapted from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.__call__ - def __call__( - self, - prompt: Optional[Union[str, List[str]]] = None, - height: Optional[int] = None, - width: Optional[int] = None, - num_inference_steps: int = 50, - guidance_scale: float = 5.0, - negative_prompt: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: int = 1, - eta: float = 0.0, - generator: Optional[Union[np.random.RandomState, torch.Generator]] = None, - latents: Optional[np.ndarray] = None, - prompt_embeds: Optional[np.ndarray] = None, - negative_prompt_embeds: Optional[np.ndarray] = None, - pooled_prompt_embeds: Optional[np.ndarray] = None, - negative_pooled_prompt_embeds: Optional[np.ndarray] = None, - output_type: str = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, np.ndarray], None]] = None, - callback_steps: int = 1, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - guidance_rescale: float = 0.0, - original_size: Optional[Tuple[int, int]] = None, - crops_coords_top_left: Tuple[int, int] = (0, 0), - target_size: Optional[Tuple[int, int]] = None, - ): - r""" - Function invoked when calling the pipeline for generation. - - Args: - prompt (`Optional[Union[str, List[str]]]`, defaults to None): - The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. - instead. - height (`Optional[int]`, defaults to None): - The height in pixels of the generated image. - width (`Optional[int]`, defaults to None): - The width in pixels of the generated image. - num_inference_steps (`int`, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - guidance_scale (`float`, defaults to 5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - negative_prompt (`Optional[Union[str, list]]`): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` - is less than `1`). - num_images_per_prompt (`int`, defaults to 1): - The number of images to generate per prompt. - eta (`float`, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`Optional[Union[np.random.RandomState, torch.Generator]]`, defaults to `None`):: - A np.random.RandomState to make generation deterministic. - latents (`Optional[np.ndarray]`, defaults to `None`): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. 
If not provided, a latents - tensor will ge generated by sampling using the supplied random `generator`. - prompt_embeds (`Optional[np.ndarray]`, defaults to `None`): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`Optional[np.ndarray]`, defaults to `None`): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - output_type (`str`, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] instead of a - plain tuple. - callback (Optional[Callable], defaults to `None`): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. - guidance_rescale (`float`, defaults to 0.7): - Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are - Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of - [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). - Guidance rescale factor should fix overexposure when using zero terminal SNR. - - Returns: - [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a `tuple. - When returning a tuple, the first element is a list with the generated images, and the second element is a - list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, according to the `safety_checker`. - """ - - # 0. Default height and width to unet - height = height or self.unet.config["sample_size"] * self.vae_scale_factor - width = width or self.unet.config["sample_size"] * self.vae_scale_factor - - original_size = original_size or (height, width) - target_size = target_size or (height, width) - - # 1. Check inputs. Raise error if not correct - self.check_inputs( - prompt, - height, - width, - callback_steps, - negative_prompt, - prompt_embeds, - negative_prompt_embeds, - pooled_prompt_embeds, - negative_pooled_prompt_embeds, - ) - - # 2. Define call parameters - if isinstance(prompt, str): - batch_size = 1 - elif isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - if generator is None: - generator = np.random.RandomState() - - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 - - # 3. 
Encode input prompt - ( - prompt_embeds, - negative_prompt_embeds, - pooled_prompt_embeds, - negative_pooled_prompt_embeds, - ) = self._encode_prompt( - prompt, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - pooled_prompt_embeds=pooled_prompt_embeds, - negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, - ) - - # 4. Prepare timesteps - self.scheduler.set_timesteps(num_inference_steps) - timesteps = self.scheduler.timesteps - - # 5. Prepare latent variables - latents = self.prepare_latents( - batch_size * num_images_per_prompt, - self.unet.config.get("in_channels", 4), - height, - width, - prompt_embeds.dtype, - generator, - latents, - ) - - # 6. Prepare extra step kwargs - extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - - # 7. Prepare added time ids & embeddings - add_text_embeds = pooled_prompt_embeds - add_time_ids = (original_size + crops_coords_top_left + target_size,) - add_time_ids = np.array(add_time_ids, dtype=prompt_embeds.dtype) - - if do_classifier_free_guidance: - prompt_embeds = np.concatenate((negative_prompt_embeds, prompt_embeds), axis=0) - add_text_embeds = np.concatenate((negative_pooled_prompt_embeds, add_text_embeds), axis=0) - add_time_ids = np.concatenate((add_time_ids, add_time_ids), axis=0) - add_time_ids = np.repeat(add_time_ids, batch_size * num_images_per_prompt, axis=0) - - # Adapted from diffusers to extend it for other runtimes than ORT - timestep_dtype = self.unet.input_dtype.get("timestep", np.float32) - - # 8. Denoising loop - - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - for i, t in enumerate(self.progress_bar(timesteps)): - # expand the latents if we are doing classifier free guidance - latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t) - latent_model_input = latent_model_input.cpu().numpy() - - # predict the noise residual - timestep = np.array([t], dtype=timestep_dtype) - noise_pred = self.unet( - sample=latent_model_input, - timestep=timestep, - encoder_hidden_states=prompt_embeds, - text_embeds=add_text_embeds, - time_ids=add_time_ids, - ) - noise_pred = noise_pred[0] - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - if guidance_rescale > 0.0: - # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf - noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) - - # compute the previous noisy sample x_t -> x_t-1 - scheduler_output = self.scheduler.step( - torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs - ) - latents = scheduler_output.prev_sample.numpy() - - # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - if callback is not None and i % callback_steps == 0: - step_idx = i // getattr(self.scheduler, "order", 1) - callback(step_idx, t, latents) - - if output_type == "latent": - image = latents - else: - latents /= self.vae_decoder.config.get("scaling_factor", 0.18215) - # it seems likes there is a strange result for using half-precision vae decoder if batchsize>1 - image = np.concatenate( - [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])] - ) - # apply watermark if available - if self.watermark is not None: - image = self.watermark.apply_watermark(image) - image = self.image_processor.postprocess(image, output_type=output_type) - - if not return_dict: - return (image,) - - return StableDiffusionXLPipelineOutput(images=image) diff --git a/optimum/pipelines/diffusers/pipeline_stable_diffusion_xl_img2img.py b/optimum/pipelines/diffusers/pipeline_stable_diffusion_xl_img2img.py deleted file mode 100644 index 19988599b6..0000000000 --- a/optimum/pipelines/diffusers/pipeline_stable_diffusion_xl_img2img.py +++ /dev/null @@ -1,515 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import inspect -import logging -from typing import Any, Callable, Dict, List, Optional, Tuple, Union - -import numpy as np -import PIL.Image -import torch -from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput - -from .pipeline_utils import DiffusionPipelineMixin, rescale_noise_cfg - - -logger = logging.getLogger(__name__) - - -class StableDiffusionXLImg2ImgPipelineMixin(DiffusionPipelineMixin): - # Adapted from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt - def _encode_prompt( - self, - prompt: Union[str, List[str]], - num_images_per_prompt: int, - do_classifier_free_guidance: bool, - negative_prompt: Optional[Union[str, list]], - prompt_embeds: Optional[np.ndarray] = None, - negative_prompt_embeds: Optional[np.ndarray] = None, - pooled_prompt_embeds: Optional[np.ndarray] = None, - negative_pooled_prompt_embeds: Optional[np.ndarray] = None, - ): - r""" - Encodes the prompt into text encoder hidden states. 
- - Args: - prompt (`Union[str, List[str]]`): - prompt to be encoded - num_images_per_prompt (`int`): - number of images that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not - negative_prompt (`Optional[Union[str, list]]`): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). - prompt_embeds (`Optional[np.ndarray]`, defaults to `None`): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`Optional[np.ndarray]`, defaults to `None`): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - pooled_prompt_embeds (`Optional[np.ndarray]`, defaults to `None`): - Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. - If not provided, pooled text embeddings will be generated from `prompt` input argument. - negative_pooled_prompt_embeds (`Optional[np.ndarray]`, defaults to `None`): - Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` - input argument. - """ - if isinstance(prompt, str): - batch_size = 1 - elif isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - # Define tokenizers and text encoders - tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] - text_encoders = ( - [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] - ) - - if prompt_embeds is None: - prompt_embeds_list = [] - for tokenizer, text_encoder in zip(tokenizers, text_encoders): - # get prompt text embeddings - text_inputs = tokenizer( - prompt, - padding="max_length", - max_length=tokenizer.model_max_length, - truncation=True, - return_tensors="np", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="np").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not np.array_equal( - text_input_ids, untruncated_ids - ): - removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {tokenizer.model_max_length} tokens: {removed_text}" - ) - - prompt_embeds = text_encoder( - input_ids=text_input_ids.astype(text_encoder.input_dtype.get("input_ids", np.int32)) - ) - pooled_prompt_embeds = prompt_embeds[0] - prompt_embeds = prompt_embeds[-2] - prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0) - prompt_embeds_list.append(prompt_embeds) - - prompt_embeds = np.concatenate(prompt_embeds_list, axis=-1) - - # get unconditional embeddings for classifier free guidance - zero_out_negative_prompt = negative_prompt is None and self.config["force_zeros_for_empty_prompt"] - if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: - negative_prompt_embeds = np.zeros_like(prompt_embeds) - negative_pooled_prompt_embeds = 
np.zeros_like(pooled_prompt_embeds) - elif do_classifier_free_guidance and negative_prompt_embeds is None: - negative_prompt = negative_prompt or "" - if prompt is not None and type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif isinstance(negative_prompt, str): - uncond_tokens = [negative_prompt] - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) - else: - uncond_tokens = negative_prompt - - negative_prompt_embeds_list = [] - for tokenizer, text_encoder in zip(tokenizers, text_encoders): - max_length = prompt_embeds.shape[1] - uncond_input = tokenizer( - uncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="np", - ) - - negative_prompt_embeds = text_encoder( - input_ids=uncond_input.input_ids.astype(text_encoder.input_dtype.get("input_ids", np.int32)) - ) - negative_pooled_prompt_embeds = negative_prompt_embeds[0] - negative_prompt_embeds = negative_prompt_embeds[-2] - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0) - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - negative_prompt_embeds_list.append(negative_prompt_embeds) - negative_prompt_embeds = np.concatenate(negative_prompt_embeds_list, axis=-1) - - pooled_prompt_embeds = np.repeat(pooled_prompt_embeds, num_images_per_prompt, axis=0) - negative_pooled_prompt_embeds = np.repeat(negative_pooled_prompt_embeds, num_images_per_prompt, axis=0) - - return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds - - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.check_inputs - def check_inputs( - self, - prompt: Union[str, List[str]], - strength: float, - callback_steps: int, - negative_prompt: Optional[str] = None, - prompt_embeds: Optional[np.ndarray] = None, - negative_prompt_embeds: Optional[np.ndarray] = None, - ): - if strength < 0 or strength > 1: - raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") - - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." - ) - - if prompt is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt is None and prompt_embeds is None: - raise ValueError( - "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
- ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if negative_prompt is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" - f" {negative_prompt_embeds}. Please make sure to only forward one of the two." - ) - - if prompt_embeds is not None and negative_prompt_embeds is not None: - if prompt_embeds.shape != negative_prompt_embeds.shape: - raise ValueError( - "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" - f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" - f" {negative_prompt_embeds.shape}." - ) - - def get_timesteps(self, num_inference_steps, strength): - # get the original timestep using init_timestep - init_timestep = min(int(num_inference_steps * strength), num_inference_steps) - t_start = max(num_inference_steps - init_timestep, 0) - timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :].numpy() - - return timesteps, num_inference_steps - t_start - - # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents - def prepare_latents(self, image, timesteps, batch_size, num_images_per_prompt, dtype, generator=None): - batch_size = batch_size * num_images_per_prompt - - if image.shape[1] == 4: - init_latents = image - else: - init_latents = self.vae_encoder(sample=image)[0] * self.vae_decoder.config.get("scaling_factor", 0.18215) - - if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: - # expand init_latents for batch_size - additional_image_per_prompt = batch_size // init_latents.shape[0] - init_latents = np.concatenate([init_latents] * additional_image_per_prompt, axis=0) - elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: - raise ValueError( - f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." - ) - else: - init_latents = np.concatenate([init_latents], axis=0) - - # add noise to latents using the timesteps - if isinstance(generator, np.random.RandomState): - noise = generator.randn(*init_latents.shape).astype(dtype) - elif isinstance(generator, torch.Generator): - noise = torch.randn(*init_latents.shape, generator=generator).numpy().astype(dtype) - else: - raise ValueError( - f"Expected `generator` to be of type `np.random.RandomState` or `torch.Generator`, but got" - f" {type(generator)}." 
- ) - - init_latents = self.scheduler.add_noise( - torch.from_numpy(init_latents), torch.from_numpy(noise), torch.from_numpy(timesteps) - ) - init_latents = init_latents.numpy() - - return init_latents - - def _get_add_time_ids( - self, original_size, crops_coords_top_left, target_size, aesthetic_score, negative_aesthetic_score, dtype - ): - if self.config.get("requires_aesthetics_score"): - add_time_ids = (original_size + crops_coords_top_left + (aesthetic_score,),) - add_neg_time_ids = (original_size + crops_coords_top_left + (negative_aesthetic_score,),) - else: - add_time_ids = (original_size + crops_coords_top_left + target_size,) - add_neg_time_ids = (original_size + crops_coords_top_left + target_size,) - - add_time_ids = np.array(add_time_ids, dtype=dtype) - add_neg_time_ids = np.array(add_neg_time_ids, dtype=dtype) - - return add_time_ids, add_neg_time_ids - - # Adapted from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.__call__ - def __call__( - self, - prompt: Optional[Union[str, List[str]]] = None, - image: Union[np.ndarray, PIL.Image.Image] = None, - strength: float = 0.3, - num_inference_steps: int = 50, - guidance_scale: float = 5.0, - negative_prompt: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: int = 1, - eta: float = 0.0, - generator: Optional[Union[np.random.RandomState, torch.Generator]] = None, - latents: Optional[np.ndarray] = None, - prompt_embeds: Optional[np.ndarray] = None, - negative_prompt_embeds: Optional[np.ndarray] = None, - pooled_prompt_embeds: Optional[np.ndarray] = None, - negative_pooled_prompt_embeds: Optional[np.ndarray] = None, - output_type: str = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, np.ndarray], None]] = None, - callback_steps: int = 1, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - guidance_rescale: float = 0.0, - original_size: Optional[Tuple[int, int]] = None, - crops_coords_top_left: Tuple[int, int] = (0, 0), - target_size: Optional[Tuple[int, int]] = None, - aesthetic_score: float = 6.0, - negative_aesthetic_score: float = 2.5, - ): - r""" - Function invoked when calling the pipeline for generation. - - Args: - prompt (`Optional[Union[str, List[str]]]`, defaults to None): - The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. - instead. - image (`Union[np.ndarray, PIL.Image.Image]`): - `Image`, or tensor representing an image batch which will be upscaled. - strength (`float`, defaults to 0.8): - Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` - will be used as a starting point, adding more noise to it the larger the `strength`. The number of - denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will - be maximum and the denoising process will run for the full number of iterations specified in - `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. - num_inference_steps (`int`, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - guidance_scale (`float`, defaults to 5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. 
Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - negative_prompt (`Optional[Union[str, list]]`): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` - is less than `1`). - num_images_per_prompt (`int`, defaults to 1): - The number of images to generate per prompt. - eta (`float`, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`Optional[np.random.RandomState]`, defaults to `None`):: - A np.random.RandomState to make generation deterministic. - latents (`Optional[np.ndarray]`, defaults to `None`): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will ge generated by sampling using the supplied random `generator`. - prompt_embeds (`Optional[np.ndarray]`, defaults to `None`): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`Optional[np.ndarray]`, defaults to `None`): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - output_type (`str`, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] instead of a - plain tuple. - callback (Optional[Callable], defaults to `None`): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. - guidance_rescale (`float`, defaults to 0.7): - Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are - Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of - [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). - Guidance rescale factor should fix overexposure when using zero terminal SNR. - - Returns: - [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a `tuple. - When returning a tuple, the first element is a list with the generated images, and the second element is a - list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, according to the `safety_checker`. - """ - # 0. Check inputs. Raise error if not correct - self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) - - # 1. 
Define call parameters - if isinstance(prompt, str): - batch_size = 1 - elif isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - if generator is None: - generator = np.random.RandomState() - - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 - - # 2. Encode input prompt - ( - prompt_embeds, - negative_prompt_embeds, - pooled_prompt_embeds, - negative_pooled_prompt_embeds, - ) = self._encode_prompt( - prompt, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - pooled_prompt_embeds=pooled_prompt_embeds, - negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, - ) - - # 3. Preprocess image - image = self.image_processor.preprocess(image) - - # 4. Prepare timesteps - self.scheduler.set_timesteps(num_inference_steps) - - timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength) - latent_timestep = np.repeat(timesteps[:1], batch_size * num_images_per_prompt, axis=0) - timestep_dtype = self.unet.input_dtype.get("timestep", np.float32) - - latents_dtype = prompt_embeds.dtype - image = image.astype(latents_dtype) - - # 5. Prepare latent variables - latents = self.prepare_latents( - image, latent_timestep, batch_size, num_images_per_prompt, latents_dtype, generator - ) - - # 6. Prepare extra step kwargs - extra_step_kwargs = {} - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - if accepts_eta: - extra_step_kwargs["eta"] = eta - - height, width = latents.shape[-2:] - height = height * self.vae_scale_factor - width = width * self.vae_scale_factor - original_size = original_size or (height, width) - target_size = target_size or (height, width) - - # 8. Prepare added time ids & embeddings - add_text_embeds = pooled_prompt_embeds - add_time_ids, add_neg_time_ids = self._get_add_time_ids( - original_size, - crops_coords_top_left, - target_size, - aesthetic_score, - negative_aesthetic_score, - dtype=prompt_embeds.dtype, - ) - - if do_classifier_free_guidance: - prompt_embeds = np.concatenate((negative_prompt_embeds, prompt_embeds), axis=0) - add_text_embeds = np.concatenate((negative_pooled_prompt_embeds, add_text_embeds), axis=0) - add_time_ids = np.concatenate((add_time_ids, add_time_ids), axis=0) - add_time_ids = np.repeat(add_time_ids, batch_size * num_images_per_prompt, axis=0) - - # 8. 
Denoising loop - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - for i, t in enumerate(self.progress_bar(timesteps)): - # expand the latents if we are doing classifier free guidance - latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t) - latent_model_input = latent_model_input.cpu().numpy() - - # predict the noise residual - timestep = np.array([t], dtype=timestep_dtype) - noise_pred = self.unet( - sample=latent_model_input, - timestep=timestep, - encoder_hidden_states=prompt_embeds, - text_embeds=add_text_embeds, - time_ids=add_time_ids, - ) - noise_pred = noise_pred[0] - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - if guidance_rescale > 0.0: - # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf - noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) - - # compute the previous noisy sample x_t -> x_t-1 - scheduler_output = self.scheduler.step( - torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs - ) - latents = scheduler_output.prev_sample.numpy() - - # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - if callback is not None and i % callback_steps == 0: - step_idx = i // getattr(self.scheduler, "order", 1) - callback(step_idx, t, latents) - - if output_type == "latent": - image = latents - else: - latents /= self.vae_decoder.config.get("scaling_factor", 0.18215) - # it seems likes there is a strange result for using half-precision vae decoder if batchsize>1 - image = np.concatenate( - [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])] - ) - # apply watermark if available - if self.watermark is not None: - image = self.watermark.apply_watermark(image) - image = self.image_processor.postprocess(image, output_type=output_type) - - if not return_dict: - return (image,) - - return StableDiffusionXLPipelineOutput(images=image) diff --git a/optimum/pipelines/diffusers/pipeline_utils.py b/optimum/pipelines/diffusers/pipeline_utils.py deleted file mode 100644 index e9d5986b61..0000000000 --- a/optimum/pipelines/diffusers/pipeline_utils.py +++ /dev/null @@ -1,282 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -import warnings -from typing import List, Optional, Union - -import numpy as np -import PIL.Image -import torch -from diffusers import ConfigMixin -from diffusers.image_processor import VaeImageProcessor as DiffusersVaeImageProcessor -from diffusers.utils.pil_utils import PIL_INTERPOLATION -from PIL import Image -from tqdm.auto import tqdm - - -class DiffusionPipelineMixin(ConfigMixin): - # Copied from https://github.com/huggingface/diffusers/blob/v0.12.1/src/diffusers/pipelines/pipeline_utils.py#L812 - @staticmethod - def numpy_to_pil(images): - """ - Converts a numpy image or a batch of images to a PIL image. - """ - if images.ndim == 3: - images = images[None, ...] - images = (images * 255).round().astype("uint8") - if images.shape[-1] == 1: - # special case for grayscale (single channel) images - pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images] - else: - pil_images = [Image.fromarray(image) for image in images] - - return pil_images - - # Copied from https://github.com/huggingface/diffusers/blob/v0.12.1/src/diffusers/pipelines/pipeline_utils.py#L827 - def progress_bar(self, iterable=None, total=None): - if not hasattr(self, "_progress_bar_config"): - self._progress_bar_config = {} - elif not isinstance(self._progress_bar_config, dict): - raise ValueError( - f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}." - ) - - if iterable is not None: - return tqdm(iterable, **self._progress_bar_config) - elif total is not None: - return tqdm(total=total, **self._progress_bar_config) - else: - raise ValueError("Either `total` or `iterable` has to be defined.") - - -# Adapted from https://github.com/huggingface/diffusers/blob/v0.18.1/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L58 -def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): - """ - Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and - Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 - """ - std_text = np.std(noise_pred_text, axis=tuple(range(1, noise_pred_text.ndim)), keepdims=True) - std_cfg = np.std(noise_cfg, axis=tuple(range(1, noise_cfg.ndim)), keepdims=True) - # rescale the results from guidance (fixes overexposure) - noise_pred_rescaled = noise_cfg * (std_text / std_cfg) - # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images - noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg - return noise_cfg - - -class VaeImageProcessor(DiffusersVaeImageProcessor): - # Adapted from diffusers.VaeImageProcessor.denormalize - @staticmethod - def denormalize(images: np.ndarray): - """ - Denormalize an image array to [0,1]. - """ - return np.clip(images / 2 + 0.5, 0, 1) - - # Adapted from diffusers.VaeImageProcessor.preprocess - def preprocess( - self, - image: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray], - height: Optional[int] = None, - width: Optional[int] = None, - ) -> np.ndarray: - """ - Preprocess the image input. Accepted formats are PIL images, NumPy arrays or PyTorch tensors. 
- """ - supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor) - - do_convert_grayscale = getattr(self.config, "do_convert_grayscale", False) - # Expand the missing dimension for 3-dimensional pytorch tensor or numpy array that represents grayscale image - if do_convert_grayscale and isinstance(image, (torch.Tensor, np.ndarray)) and image.ndim == 3: - if isinstance(image, torch.Tensor): - # if image is a pytorch tensor could have 2 possible shapes: - # 1. batch x height x width: we should insert the channel dimension at position 1 - # 2. channnel x height x width: we should insert batch dimension at position 0, - # however, since both channel and batch dimension has same size 1, it is same to insert at position 1 - # for simplicity, we insert a dimension of size 1 at position 1 for both cases - image = image.unsqueeze(1) - else: - # if it is a numpy array, it could have 2 possible shapes: - # 1. batch x height x width: insert channel dimension on last position - # 2. height x width x channel: insert batch dimension on first position - if image.shape[-1] == 1: - image = np.expand_dims(image, axis=0) - else: - image = np.expand_dims(image, axis=-1) - - if isinstance(image, supported_formats): - image = [image] - elif not (isinstance(image, list) and all(isinstance(i, supported_formats) for i in image)): - raise ValueError( - f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support {', '.join(supported_formats)}" - ) - - if isinstance(image[0], PIL.Image.Image): - if self.config.do_convert_rgb: - image = [self.convert_to_rgb(i) for i in image] - elif do_convert_grayscale: - image = [self.convert_to_grayscale(i) for i in image] - if self.config.do_resize: - height, width = self.get_height_width(image[0], height, width) - image = [self.resize(i, height, width) for i in image] - image = self.reshape(self.pil_to_numpy(image)) - else: - if isinstance(image[0], torch.Tensor): - image = [self.pt_to_numpy(elem) for elem in image] - image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0) - else: - image = self.reshape(np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0)) - - if do_convert_grayscale and image.ndim == 3: - image = np.expand_dims(image, 1) - - # don't need any preprocess if the image is latents - if image.shape[1] == 4: - return image - - if self.config.do_resize: - height, width = self.get_height_width(image, height, width) - image = self.resize(image, height, width) - - # expected range [0,1], normalize to [-1,1] - do_normalize = self.config.do_normalize - if image.min() < 0 and do_normalize: - warnings.warn( - "Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] " - f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]", - FutureWarning, - ) - do_normalize = False - - if do_normalize: - image = self.normalize(image) - - if getattr(self.config, "do_binarize", False): - image = self.binarize(image) - - return image - - # Adapted from diffusers.VaeImageProcessor.postprocess - def postprocess( - self, - image: np.ndarray, - output_type: str = "pil", - do_denormalize: Optional[List[bool]] = None, - ): - if not isinstance(image, np.ndarray): - raise ValueError( - f"Input for postprocessing is in incorrect format: {type(image)}. 
We only support np array" - ) - if output_type not in ["latent", "np", "pil"]: - deprecation_message = ( - f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: " - "`pil`, `np`, `pt`, `latent`" - ) - warnings.warn(deprecation_message, FutureWarning) - output_type = "np" - - if output_type == "latent": - return image - - if do_denormalize is None: - do_denormalize = [self.config.do_normalize] * image.shape[0] - - image = np.stack( - [self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])], axis=0 - ) - - image = image.transpose((0, 2, 3, 1)) - - if output_type == "pil": - image = self.numpy_to_pil(image) - - return image - - def get_height_width( - self, - image: Union[PIL.Image.Image, np.ndarray], - height: Optional[int] = None, - width: Optional[int] = None, - ): - """ - This function return the height and width that are downscaled to the next integer multiple of - `vae_scale_factor`. - - Args: - image(`PIL.Image.Image`, `np.ndarray`): - The image input, can be a PIL image, numpy array or pytorch tensor. if it is a numpy array, should have - shape `[batch, height, width]` or `[batch, height, width, channel]` if it is a pytorch tensor, should - have shape `[batch, channel, height, width]`. - height (`int`, *optional*, defaults to `None`): - The height in preprocessed image. If `None`, will use the height of `image` input. - width (`int`, *optional*`, defaults to `None`): - The width in preprocessed. If `None`, will use the width of the `image` input. - """ - height = height or (image.height if isinstance(image, PIL.Image.Image) else image.shape[-2]) - width = width or (image.width if isinstance(image, PIL.Image.Image) else image.shape[-1]) - # resize to integer multiple of vae_scale_factor - width, height = (x - x % self.config.vae_scale_factor for x in (width, height)) - return height, width - - # Adapted from diffusers.VaeImageProcessor.numpy_to_pt - @staticmethod - def numpy_to_pt(images: np.ndarray) -> torch.FloatTensor: - """ - Convert a NumPy image to a PyTorch tensor. - """ - if images.ndim == 3: - images = images[..., None] - - images = torch.from_numpy(images) - return images - - # Adapted from diffusers.VaeImageProcessor.pt_to_numpy - @staticmethod - def pt_to_numpy(images: torch.FloatTensor) -> np.ndarray: - """ - Convert a PyTorch tensor to a NumPy image. - """ - images = images.cpu().float().numpy() - return images - - @staticmethod - def reshape(images: np.ndarray) -> np.ndarray: - """ - Reshape inputs to expected shape. - """ - if images.ndim == 3: - images = images[..., None] - - return images.transpose(0, 3, 1, 2) - - # TODO : remove after diffusers v0.21.0 release - def resize( - self, - image: Union[PIL.Image.Image, np.ndarray, torch.Tensor], - height: Optional[int] = None, - width: Optional[int] = None, - ) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]: - """ - Resize image. 
- """ - if isinstance(image, PIL.Image.Image): - image = image.resize((width, height), resample=PIL_INTERPOLATION[self.config.resample]) - elif isinstance(image, torch.Tensor): - image = torch.nn.functional.interpolate(image, size=(height, width)) - elif isinstance(image, np.ndarray): - image = self.numpy_to_pt(image) - image = torch.nn.functional.interpolate(image, size=(height, width)) - image = self.pt_to_numpy(image) - return image diff --git a/optimum/pipelines/diffusers/watermark.py b/optimum/pipelines/diffusers/watermark.py deleted file mode 100644 index b3cd622eda..0000000000 --- a/optimum/pipelines/diffusers/watermark.py +++ /dev/null @@ -1,31 +0,0 @@ -import numpy as np -from imwatermark import WatermarkEncoder - - -WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110 -WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]] - - -# Adapted from https://github.com/huggingface/diffusers/blob/v0.18.1/src/diffusers/pipelines/stable_diffusion_xl/watermark.py#L12 -class StableDiffusionXLWatermarker: - def __init__(self): - self.watermark = WATERMARK_BITS - self.encoder = WatermarkEncoder() - self.encoder.set_watermark("bits", self.watermark) - - def apply_watermark(self, images: np.array): - # can't encode images that are smaller than 256 - if images.shape[-1] < 256: - return images - - # cv2 doesn't support float16 - if images.dtype == np.float16: - images = images.astype(np.float32) - - images = (255 * (images / 2 + 0.5)).transpose((0, 2, 3, 1)) - - images = np.array([self.encoder.encode(image, "dwtDct") for image in images]).transpose((0, 3, 1, 2)) - - np.clip(2 * (images / 255 - 0.5), -1.0, 1.0, out=images) - - return images diff --git a/tests/exporters/onnx/test_onnx_export.py b/tests/exporters/onnx/test_onnx_export.py index d1471aa218..7671d6cd2e 100644 --- a/tests/exporters/onnx/test_onnx_export.py +++ b/tests/exporters/onnx/test_onnx_export.py @@ -43,7 +43,7 @@ from optimum.exporters.onnx.constants import SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED from optimum.exporters.onnx.model_configs import WhisperOnnxConfig from optimum.exporters.onnx.utils import get_speecht5_models_for_export -from optimum.utils import ONNX_WEIGHTS_NAME, DummyPastKeyValuesGenerator, NormalizedTextConfig +from optimum.utils import DummyPastKeyValuesGenerator, NormalizedTextConfig from optimum.utils.save_utils import maybe_load_preprocessors from optimum.utils.testing_utils import grid_parameters, require_diffusers @@ -292,27 +292,22 @@ def _onnx_export( gc.collect() - def _onnx_export_sd(self, model_type: str, model_name: str, device="cpu"): + def _onnx_export_diffusion_models(self, model_type: str, model_name: str, device="cpu"): pipeline = TasksManager.get_model_from_task(model_type, model_name, device=device) models_and_onnx_configs = get_diffusion_models_for_export(pipeline) - output_names = [os.path.join(name_dir, ONNX_WEIGHTS_NAME) for name_dir in models_and_onnx_configs] - model, _ = models_and_onnx_configs["vae_encoder"] - model.forward = lambda sample: {"latent_sample": model.encode(x=sample)["latent_dist"].parameters} with TemporaryDirectory() as tmpdirname: _, onnx_outputs = export_models( models_and_onnx_configs=models_and_onnx_configs, opset=14, output_dir=Path(tmpdirname), - output_names=output_names, device=device, ) validate_models_outputs( models_and_onnx_configs=models_and_onnx_configs, onnx_named_outputs=onnx_outputs, output_dir=Path(tmpdirname), - atol=1e-3, - onnx_files_subpaths=output_names, + atol=1e-4, use_subprocess=False, ) @@ -403,7 +398,7 @@ def 
test_tensorflow_export( @require_vision @require_diffusers def test_pytorch_export_for_diffusion_models(self, model_type, model_name): - self._onnx_export_sd(model_type, model_name) + self._onnx_export_diffusion_models(model_type, model_name) @parameterized.expand(PYTORCH_DIFFUSION_MODEL.items()) @require_torch @@ -414,7 +409,7 @@ def test_pytorch_export_for_diffusion_models(self, model_type, model_name): @pytest.mark.run_slow @pytest.mark.gpu_test def test_pytorch_export_for_diffusion_models_cuda(self, model_type, model_name): - self._onnx_export_sd(model_type, model_name, device="cuda") + self._onnx_export_diffusion_models(model_type, model_name, device="cuda") class CustomWhisperOnnxConfig(WhisperOnnxConfig): diff --git a/tests/onnxruntime/test_diffusion.py b/tests/onnxruntime/test_diffusion.py index 9f480b2d1a..956566f0e1 100644 --- a/tests/onnxruntime/test_diffusion.py +++ b/tests/onnxruntime/test_diffusion.py @@ -12,10 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import unittest import numpy as np -import PIL import pytest import torch from diffusers import ( @@ -24,6 +22,7 @@ AutoPipelineForText2Image, DiffusionPipeline, ) +from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker from diffusers.utils import load_image from parameterized import parameterized from transformers.testing_utils import require_torch_gpu @@ -35,8 +34,7 @@ ORTPipelineForInpainting, ORTPipelineForText2Image, ) -from optimum.pipelines.diffusers.pipeline_utils import VaeImageProcessor -from optimum.utils.testing_utils import grid_parameters, require_diffusers, require_ort_rocm +from optimum.utils.testing_utils import grid_parameters, require_diffusers def get_generator(framework, seed): @@ -72,16 +70,8 @@ def _generate_images(height=128, width=128, batch_size=1, channel=3, input_type= return [image] * batch_size -def to_np(image): - if isinstance(image[0], PIL.Image.Image): - return np.stack([np.array(i) for i in image], axis=0) - elif isinstance(image, torch.Tensor): - return image.cpu().numpy().transpose(0, 2, 3, 1) - return image - - class ORTPipelineForText2ImageTest(ORTModelTestMixin): - SUPPORTED_ARCHITECTURES = ["latent-consistency", "stable-diffusion", "stable-diffusion-xl"] + SUPPORTED_ARCHITECTURES = ["stable-diffusion", "stable-diffusion-xl", "latent-consistency"] ORTMODEL_CLASS = ORTPipelineForText2Image AUTOMODEL_CLASS = AutoPipelineForText2Image @@ -126,17 +116,16 @@ def test_ort_pipeline_class_dispatch(self, model_arch: str): def test_num_images_per_prompt(self, model_arch: str): model_args = {"test_name": model_arch, "model_arch": model_arch} self._setup(model_args) - pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch]) - self.assertEqual(pipeline.vae_scale_factor, 2) - self.assertEqual(pipeline.vae_decoder.config["latent_channels"], 4) - self.assertEqual(pipeline.unet.config["in_channels"], 4) - height, width, batch_size = 64, 64, 1 - inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) + pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch]) - for num_images in [1, 3]: - outputs = pipeline(**inputs, num_images_per_prompt=num_images).images - self.assertEqual(outputs.shape, (batch_size * num_images, height, width, 3)) + for batch_size in [1, 3]: + for height in [64, 128]: + for width in [64, 128]: + for num_images_per_prompt in [1, 3]: + inputs = 
self.generate_inputs(height=height, width=width, batch_size=batch_size) + outputs = pipeline(**inputs, num_images_per_prompt=num_images_per_prompt).images + self.assertEqual(outputs.shape, (batch_size * num_images_per_prompt, height, width, 3)) @parameterized.expand(SUPPORTED_ARCHITECTURES) @require_diffusers @@ -150,61 +139,13 @@ def test_compare_to_diffusers_pipeline(self, model_arch: str): ort_pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch]) diffusers_pipeline = self.AUTOMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch]) - if model_arch == "latent-consistency": - # Latent Consistency Model (LCM) doesn't support deterministic outputs beyond the first inference step - # TODO: Investigate why this is the case - inputs["num_inference_steps"] = 1 - - for output_type in ["latent", "np"]: + for output_type in ["latent", "np", "pt"]: inputs["output_type"] = output_type - ort_output = ort_pipeline(**inputs, generator=torch.Generator().manual_seed(SEED)).images - diffusers_output = diffusers_pipeline(**inputs, generator=torch.Generator().manual_seed(SEED)).images - - self.assertTrue( - np.allclose(ort_output, diffusers_output, atol=1e-4), - np.testing.assert_allclose(ort_output, diffusers_output, atol=1e-4), - ) - self.assertEqual(ort_pipeline.device, diffusers_pipeline.device) + ort_output = ort_pipeline(**inputs, generator=get_generator("pt", SEED)).images + diffusers_output = diffusers_pipeline(**inputs, generator=get_generator("pt", SEED)).images - @parameterized.expand( - grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "provider": ["CUDAExecutionProvider"]}) - ) - @require_torch_gpu - @pytest.mark.cuda_ep_test - @require_diffusers - def test_pipeline_on_gpu(self, test_name: str, model_arch: str, provider: str): - model_args = {"test_name": test_name, "model_arch": model_arch} - self._setup(model_args) - pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[test_name], provider=provider) - height, width, batch_size = 32, 64, 1 - inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) - outputs = pipeline(**inputs).images - # Verify model devices - self.assertEqual(pipeline.device.type.lower(), "cuda") - # Verify model outptus - self.assertIsInstance(outputs, np.ndarray) - self.assertEqual(outputs.shape, (batch_size, height, width, 3)) - - @parameterized.expand( - grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "provider": ["ROCMExecutionProvider"]}) - ) - @require_torch_gpu - @require_ort_rocm - @pytest.mark.rocm_ep_test - @require_diffusers - def test_pipeline_on_rocm_ep(self, test_name: str, model_arch: str, provider: str): - model_args = {"test_name": test_name, "model_arch": model_arch} - self._setup(model_args) - pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[test_name], provider=provider) - height, width, batch_size = 64, 32, 1 - inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) - outputs = pipeline(**inputs).images - # Verify model devices - self.assertEqual(pipeline.device.type.lower(), "cuda") - # Verify model outptus - self.assertIsInstance(outputs, np.ndarray) - self.assertEqual(outputs.shape, (batch_size, height, width, 3)) + np.testing.assert_allclose(ort_output, diffusers_output, atol=1e-4, rtol=1e-2) @parameterized.expand(SUPPORTED_ARCHITECTURES) @require_diffusers @@ -220,7 +161,7 @@ def __init__(self): self.has_been_called = False self.number_of_steps = 0 - def __call__(self, step: int, timestep: int, latents: np.ndarray) -> None: + def 
__call__(self, *args, **kwargs) -> None: self.has_been_called = True self.number_of_steps += 1 @@ -243,17 +184,21 @@ def __call__(self, step: int, timestep: int, latents: np.ndarray) -> None: def test_shape(self, model_arch: str): model_args = {"test_name": model_arch, "model_arch": model_arch} self._setup(model_args) - height, width, batch_size = 128, 64, 1 + pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch]) + + height, width, batch_size = 128, 64, 1 inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) - for output_type in ["np", "pil", "latent"]: + for output_type in ["pil", "np", "pt", "latent"]: inputs["output_type"] = output_type outputs = pipeline(**inputs).images if output_type == "pil": self.assertEqual((len(outputs), outputs[0].height, outputs[0].width), (batch_size, height, width)) elif output_type == "np": self.assertEqual(outputs.shape, (batch_size, height, width, 3)) + elif output_type == "pt": + self.assertEqual(outputs.shape, (batch_size, 3, height, width)) else: self.assertEqual( outputs.shape, @@ -263,9 +208,6 @@ def test_shape(self, model_arch: str): @parameterized.expand(SUPPORTED_ARCHITECTURES) @require_diffusers def test_image_reproducibility(self, model_arch: str): - if model_arch in ["latent-consistency"]: - pytest.skip("Latent Consistency Model (LCM) doesn't support deterministic outputs") - model_args = {"test_name": model_arch, "model_arch": model_arch} self._setup(model_args) @@ -279,14 +221,11 @@ def test_image_reproducibility(self, model_arch: str): ort_outputs_2 = pipeline(**inputs, generator=get_generator(generator_framework, SEED)) ort_outputs_3 = pipeline(**inputs, generator=get_generator(generator_framework, SEED + 1)) - self.assertTrue(np.array_equal(ort_outputs_1.images[0], ort_outputs_2.images[0])) self.assertFalse(np.array_equal(ort_outputs_1.images[0], ort_outputs_3.images[0])) + np.testing.assert_allclose(ort_outputs_1.images[0], ort_outputs_2.images[0], atol=1e-4, rtol=1e-2) @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_negative_prompt(self, model_arch: str): - if model_arch in ["latent-consistency"]: - pytest.skip("Latent Consistency Model (LCM) does not support negative prompts") - model_args = {"test_name": model_arch, "model_arch": model_arch} self._setup(model_args) @@ -295,9 +234,8 @@ def test_negative_prompt(self, model_arch: str): negative_prompt = ["This is a negative prompt"] pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch]) - image_slice_1 = pipeline( - **inputs, negative_prompt=negative_prompt, generator=np.random.RandomState(SEED) - ).images[0, -3:, -3:, -1] + + images_1 = pipeline(**inputs, negative_prompt=negative_prompt, generator=get_generator("pt", SEED)).images prompt = inputs.pop("prompt") if model_arch == "stable-diffusion-xl": @@ -306,39 +244,96 @@ def test_negative_prompt(self, model_arch: str): inputs["negative_prompt_embeds"], inputs["pooled_prompt_embeds"], inputs["negative_pooled_prompt_embeds"], - ) = pipeline._encode_prompt(prompt, 1, False, negative_prompt) + ) = pipeline.encode_prompt( + prompt=prompt, + num_images_per_prompt=1, + device=torch.device("cpu"), + do_classifier_free_guidance=True, + negative_prompt=negative_prompt, + ) else: - text_ids = pipeline.tokenizer( - prompt, - max_length=pipeline.tokenizer.model_max_length, - padding="max_length", - return_tensors="np", - truncation=True, - ).input_ids - negative_text_ids = pipeline.tokenizer( - negative_prompt, - max_length=pipeline.tokenizer.model_max_length, - 
padding="max_length", - return_tensors="np", - truncation=True, - ).input_ids - inputs["prompt_embeds"] = pipeline.text_encoder(text_ids)[0] - inputs["negative_prompt_embeds"] = pipeline.text_encoder(negative_text_ids)[0] - - image_slice_2 = pipeline(**inputs, generator=np.random.RandomState(SEED)).images[0, -3:, -3:, -1] - - self.assertTrue(np.allclose(image_slice_1, image_slice_2, rtol=1e-1)) + inputs["prompt_embeds"], inputs["negative_prompt_embeds"] = pipeline.encode_prompt( + prompt=prompt, + num_images_per_prompt=1, + device=torch.device("cpu"), + do_classifier_free_guidance=True, + negative_prompt=negative_prompt, + ) + + images_2 = pipeline(**inputs, generator=get_generator("pt", SEED)).images + + np.testing.assert_allclose(images_1, images_2, atol=1e-4, rtol=1e-2) + + @parameterized.expand( + grid_parameters( + { + "model_arch": SUPPORTED_ARCHITECTURES, + "provider": ["CUDAExecutionProvider", "ROCMExecutionProvider", "TensorrtExecutionProvider"], + } + ) + ) + @pytest.mark.rocm_ep_test + @pytest.mark.cuda_ep_test + @pytest.mark.trt_ep_test + @require_torch_gpu + @require_diffusers + def test_pipeline_on_gpu(self, test_name: str, model_arch: str, provider: str): + model_args = {"test_name": test_name, "model_arch": model_arch} + self._setup(model_args) + + height, width, batch_size = 32, 64, 1 + inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) + + pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[test_name], provider=provider) + + outputs = pipeline(**inputs).images + + self.assertIsInstance(outputs, np.ndarray) + self.assertEqual(outputs.shape, (batch_size, height, width, 3)) + + @parameterized.expand(["stable-diffusion", "latent-consistency"]) + @require_diffusers + def test_safety_checker(self, model_arch: str): + model_args = {"test_name": model_arch, "model_arch": model_arch} + self._setup(model_args) + + safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker") + + pipeline = self.AUTOMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch], safety_checker=safety_checker) + ort_pipeline = self.ORTMODEL_CLASS.from_pretrained( + self.onnx_model_dirs[model_arch], safety_checker=safety_checker + ) + + self.assertIsInstance(pipeline.safety_checker, StableDiffusionSafetyChecker) + self.assertIsInstance(ort_pipeline.safety_checker, StableDiffusionSafetyChecker) + + height, width, batch_size = 32, 64, 1 + inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) + + ort_output = ort_pipeline(**inputs, generator=get_generator("pt", SEED)) + diffusers_output = pipeline(**inputs, generator=get_generator("pt", SEED)) + + ort_nsfw_content_detected = ort_output.nsfw_content_detected + diffusers_nsfw_content_detected = diffusers_output.nsfw_content_detected + + self.assertTrue(ort_nsfw_content_detected is not None) + self.assertTrue(diffusers_nsfw_content_detected is not None) + self.assertEqual(ort_nsfw_content_detected, diffusers_nsfw_content_detected) + + ort_images = ort_output.images + diffusers_images = diffusers_output.images + np.testing.assert_allclose(ort_images, diffusers_images, atol=1e-4, rtol=1e-2) class ORTPipelineForImage2ImageTest(ORTModelTestMixin): - SUPPORTED_ARCHITECTURES = ["stable-diffusion", "stable-diffusion-xl"] + SUPPORTED_ARCHITECTURES = ["stable-diffusion", "stable-diffusion-xl", "latent-consistency"] AUTOMODEL_CLASS = AutoPipelineForImage2Image ORTMODEL_CLASS = ORTPipelineForImage2Image TASK = "image-to-image" - def generate_inputs(self, 
height=128, width=128, batch_size=1, channel=3, input_type="np"): + def generate_inputs(self, height=128, width=128, batch_size=1, channel=3, input_type="pil"): inputs = _generate_prompts(batch_size=batch_size) inputs["image"] = _generate_images( @@ -369,11 +364,6 @@ def test_ort_pipeline_class_dispatch(self, model_arch: str): self.assertEqual(ort_pipeline.auto_model_class, auto_pipeline.__class__) - # auto_pipeline = DiffusionPipeline.from_pretrained(MODEL_NAMES[model_arch]) - # ort_pipeline = ORTDiffusionPipeline.from_pretrained(self.onnx_model_dirs[model_arch]) - - # self.assertEqual(ort_pipeline.auto_model_class, auto_pipeline.__class__) - @parameterized.expand(SUPPORTED_ARCHITECTURES) @require_diffusers def test_num_images_per_prompt(self, model_arch: str): @@ -381,68 +371,18 @@ def test_num_images_per_prompt(self, model_arch: str): self._setup(model_args) pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch]) - self.assertEqual(pipeline.vae_scale_factor, 2) - self.assertEqual(pipeline.vae_decoder.config["latent_channels"], 4) - self.assertEqual(pipeline.unet.config["in_channels"], 4) - - batch_size, height = 1, 32 - for width in [64, 32]: - inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) - for num_images in [1, 3]: - outputs = pipeline(**inputs, num_images_per_prompt=num_images).images - self.assertEqual(outputs.shape, (batch_size * num_images, height, width, 3)) - - @parameterized.expand( - grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "provider": ["CUDAExecutionProvider"]}) - ) - @require_torch_gpu - @pytest.mark.cuda_ep_test - @require_diffusers - def test_pipeline_on_gpu(self, test_name: str, model_arch: str, provider: str): - model_args = {"test_name": test_name, "model_arch": model_arch} - self._setup(model_args) - - height, width, batch_size = 32, 64, 1 - inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) - pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[test_name], provider=provider) - outputs = pipeline(**inputs).images - # Verify model devices - self.assertEqual(pipeline.device.type.lower(), "cuda") - # Verify model outptus - self.assertIsInstance(outputs, np.ndarray) - self.assertEqual(outputs.shape, (batch_size, height, width, 3)) - - @parameterized.expand( - grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "provider": ["ROCMExecutionProvider"]}) - ) - @require_torch_gpu - @require_ort_rocm - @pytest.mark.rocm_ep_test - @require_diffusers - def test_pipeline_on_rocm_ep(self, test_name: str, model_arch: str, provider: str): - model_args = {"test_name": test_name, "model_arch": model_arch} - self._setup(model_args) - - height, width, batch_size = 32, 64, 1 - inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) - - pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[test_name], provider=provider) - outputs = pipeline(**inputs).images - # Verify model devices - self.assertEqual(pipeline.device.type.lower(), "cuda") - # Verify model outptus - self.assertIsInstance(outputs, np.ndarray) - self.assertEqual(outputs.shape, (batch_size, height, width, 3)) + for batch_size in [1, 3]: + for height in [64, 128]: + for width in [64, 128]: + for num_images_per_prompt in [1, 3]: + inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) + outputs = pipeline(**inputs, num_images_per_prompt=num_images_per_prompt).images + self.assertEqual(outputs.shape, (batch_size * num_images_per_prompt, height, 
width, 3)) @parameterized.expand(SUPPORTED_ARCHITECTURES) @require_diffusers def test_callback(self, model_arch: str): - if model_arch in ["stable-diffusion"]: - pytest.skip( - "Stable Diffusion For Img2Img doesn't behave as expected with callbacks (doesn't call it every step with callback_steps=1)" - ) - model_args = {"test_name": model_arch, "model_arch": model_arch} self._setup(model_args) @@ -455,7 +395,7 @@ def __init__(self): self.has_been_called = False self.number_of_steps = 0 - def __call__(self, step: int, timestep: int, latents: np.ndarray) -> None: + def __call__(self, *args, **kwargs) -> None: self.has_been_called = True self.number_of_steps += 1 @@ -478,18 +418,21 @@ def test_shape(self, model_arch: str): self._setup(model_args) pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch]) - height, width, batch_size = 32, 64, 1 - for input_type in ["np", "pil", "pt"]: + height, width, batch_size = 128, 64, 1 + + for input_type in ["pil", "np", "pt"]: inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size, input_type=input_type) - for output_type in ["np", "pil", "latent"]: + for output_type in ["pil", "np", "pt", "latent"]: inputs["output_type"] = output_type outputs = pipeline(**inputs).images if output_type == "pil": self.assertEqual((len(outputs), outputs[0].height, outputs[0].width), (batch_size, height, width)) elif output_type == "np": self.assertEqual(outputs.shape, (batch_size, height, width, 3)) + elif output_type == "pt": + self.assertEqual(outputs.shape, (batch_size, 3, height, width)) else: self.assertEqual( outputs.shape, @@ -499,27 +442,26 @@ def test_shape(self, model_arch: str): @parameterized.expand(SUPPORTED_ARCHITECTURES) @require_diffusers def test_compare_to_diffusers_pipeline(self, model_arch: str): - pytest.skip("Img2Img models do not support support output reproducibility for some reason") - model_args = {"test_name": model_arch, "model_arch": model_arch} self._setup(model_args) height, width, batch_size = 128, 128, 1 inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) + diffusers_pipeline = self.AUTOMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch]) ort_pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch]) - ort_output = ort_pipeline(**inputs, generator=torch.Generator().manual_seed(SEED)).images - diffusers_pipeline = self.AUTOMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch]) - diffusers_output = diffusers_pipeline(**inputs, generator=torch.Generator().manual_seed(SEED)).images + for output_type in ["latent", "np", "pt"]: + inputs["output_type"] = output_type - self.assertTrue(np.allclose(ort_output, diffusers_output, rtol=1e-2)) + ort_output = ort_pipeline(**inputs, generator=get_generator("pt", SEED)).images + diffusers_output = diffusers_pipeline(**inputs, generator=get_generator("pt", SEED)).images + + np.testing.assert_allclose(ort_output, diffusers_output, atol=1e-4, rtol=1e-2) @parameterized.expand(SUPPORTED_ARCHITECTURES) @require_diffusers def test_image_reproducibility(self, model_arch: str): - pytest.skip("Img2Img models do not support support output reproducibility for some reason") - model_args = {"test_name": model_arch, "model_arch": model_arch} self._setup(model_args) @@ -533,12 +475,73 @@ def test_image_reproducibility(self, model_arch: str): ort_outputs_2 = pipeline(**inputs, generator=get_generator(generator_framework, SEED)) ort_outputs_3 = pipeline(**inputs, generator=get_generator(generator_framework, SEED + 1)) - 
self.assertTrue(np.array_equal(ort_outputs_1.images[0], ort_outputs_2.images[0])) self.assertFalse(np.array_equal(ort_outputs_1.images[0], ort_outputs_3.images[0])) + np.testing.assert_allclose(ort_outputs_1.images[0], ort_outputs_2.images[0], atol=1e-4, rtol=1e-2) + + @parameterized.expand( + grid_parameters( + { + "model_arch": SUPPORTED_ARCHITECTURES, + "provider": ["CUDAExecutionProvider", "ROCMExecutionProvider", "TensorrtExecutionProvider"], + } + ) + ) + @pytest.mark.rocm_ep_test + @pytest.mark.cuda_ep_test + @pytest.mark.trt_ep_test + @require_torch_gpu + @require_diffusers + def test_pipeline_on_gpu(self, test_name: str, model_arch: str, provider: str): + model_args = {"test_name": test_name, "model_arch": model_arch} + self._setup(model_args) + + height, width, batch_size = 32, 64, 1 + inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) + + pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[test_name], provider=provider) + self.assertEqual(pipeline.device.type, "cuda") + + outputs = pipeline(**inputs).images + self.assertIsInstance(outputs, np.ndarray) + self.assertEqual(outputs.shape, (batch_size, height, width, 3)) + + @parameterized.expand(["stable-diffusion", "latent-consistency"]) + @require_diffusers + def test_safety_checker(self, model_arch: str): + model_args = {"test_name": model_arch, "model_arch": model_arch} + self._setup(model_args) + + safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker") + + pipeline = self.AUTOMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch], safety_checker=safety_checker) + ort_pipeline = self.ORTMODEL_CLASS.from_pretrained( + self.onnx_model_dirs[model_arch], safety_checker=safety_checker + ) + + self.assertIsInstance(pipeline.safety_checker, StableDiffusionSafetyChecker) + self.assertIsInstance(ort_pipeline.safety_checker, StableDiffusionSafetyChecker) + + height, width, batch_size = 32, 64, 1 + inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) + + ort_output = ort_pipeline(**inputs, generator=get_generator("pt", SEED)) + diffusers_output = pipeline(**inputs, generator=get_generator("pt", SEED)) + + ort_nsfw_content_detected = ort_output.nsfw_content_detected + diffusers_nsfw_content_detected = diffusers_output.nsfw_content_detected + + self.assertTrue(ort_nsfw_content_detected is not None) + self.assertTrue(diffusers_nsfw_content_detected is not None) + self.assertEqual(ort_nsfw_content_detected, diffusers_nsfw_content_detected) + + ort_images = ort_output.images + diffusers_images = diffusers_output.images + + np.testing.assert_allclose(ort_images, diffusers_images, atol=1e-4, rtol=1e-2) class ORTPipelineForInpaintingTest(ORTModelTestMixin): - SUPPORTED_ARCHITECTURES = ["stable-diffusion"] + SUPPORTED_ARCHITECTURES = ["stable-diffusion", "stable-diffusion-xl"] AUTOMODEL_CLASS = AutoPipelineForInpainting ORTMODEL_CLASS = ORTPipelineForInpainting @@ -546,18 +549,16 @@ class ORTPipelineForInpaintingTest(ORTModelTestMixin): TASK = "inpainting" def generate_inputs(self, height=128, width=128, batch_size=1, channel=3, input_type="pil"): - assert batch_size == 1, "Inpainting models only support batch_size=1" - assert input_type == "pil", "Inpainting models only support input_type='pil'" - inputs = _generate_prompts(batch_size=batch_size) inputs["image"] = _generate_images( - height=height, width=width, batch_size=1, channel=channel, input_type="pil" - )[0] + height=height, width=width, batch_size=batch_size, 
channel=channel, input_type=input_type + ) inputs["mask_image"] = _generate_images( - height=height, width=width, batch_size=1, channel=channel, input_type="pil" - )[0] + height=height, width=width, batch_size=batch_size, channel=1, input_type=input_type + ) + inputs["strength"] = 0.75 inputs["height"] = height inputs["width"] = width @@ -583,11 +584,6 @@ def test_ort_pipeline_class_dispatch(self, model_arch: str): self.assertEqual(ort_pipeline.auto_model_class, auto_pipeline.__class__) - # auto_pipeline = DiffusionPipeline.from_pretrained(MODEL_NAMES[model_arch]) - # ort_pipeline = ORTDiffusionPipeline.from_pretrained(self.onnx_model_dirs[model_arch]) - - # self.assertEqual(ort_pipeline.auto_model_class, auto_pipeline.__class__) - @parameterized.expand(SUPPORTED_ARCHITECTURES) @require_diffusers def test_num_images_per_prompt(self, model_arch: str): @@ -595,59 +591,14 @@ def test_num_images_per_prompt(self, model_arch: str): self._setup(model_args) pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch]) - self.assertEqual(pipeline.vae_scale_factor, 2) - self.assertEqual(pipeline.vae_decoder.config["latent_channels"], 4) - self.assertEqual(pipeline.unet.config["in_channels"], 4) - batch_size, height = 1, 32 - for width in [64, 32]: - inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) - for num_images in [1, 3]: - outputs = pipeline(**inputs, num_images_per_prompt=num_images).images - self.assertEqual(outputs.shape, (batch_size * num_images, height, width, 3)) - - @parameterized.expand( - grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "provider": ["CUDAExecutionProvider"]}) - ) - @require_torch_gpu - @pytest.mark.cuda_ep_test - @require_diffusers - def test_pipeline_on_gpu(self, test_name: str, model_arch: str, provider: str): - model_args = {"test_name": test_name, "model_arch": model_arch} - self._setup(model_args) - - height, width, batch_size = 32, 64, 1 - inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) - - pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[test_name], provider=provider) - outputs = pipeline(**inputs).images - # Verify model devices - self.assertEqual(pipeline.device.type.lower(), "cuda") - # Verify model outptus - self.assertIsInstance(outputs, np.ndarray) - self.assertEqual(outputs.shape, (batch_size, height, width, 3)) - - @parameterized.expand( - grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "provider": ["ROCMExecutionProvider"]}) - ) - @require_torch_gpu - @require_ort_rocm - @pytest.mark.rocm_ep_test - @require_diffusers - def test_pipeline_on_rocm_ep(self, test_name: str, model_arch: str, provider: str): - model_args = {"test_name": test_name, "model_arch": model_arch} - self._setup(model_args) - - height, width, batch_size = 32, 64, 1 - inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) - - pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[test_name], provider=provider) - outputs = pipeline(**inputs).images - # Verify model devices - self.assertEqual(pipeline.device.type.lower(), "cuda") - # Verify model outptus - self.assertIsInstance(outputs, np.ndarray) - self.assertEqual(outputs.shape, (batch_size, height, width, 3)) + for batch_size in [1, 3]: + for height in [64, 128]: + for width in [64, 128]: + for num_images_per_prompt in [1, 3]: + inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) + outputs = pipeline(**inputs, 
num_images_per_prompt=num_images_per_prompt).images + self.assertEqual(outputs.shape, (batch_size * num_images_per_prompt, height, width, 3)) @parameterized.expand(SUPPORTED_ARCHITECTURES) @require_diffusers @@ -664,7 +615,7 @@ def __init__(self): self.has_been_called = False self.number_of_steps = 0 - def __call__(self, step: int, timestep: int, latents: np.ndarray) -> None: + def __call__(self, *args, **kwargs) -> None: self.has_been_called = True self.number_of_steps += 1 @@ -687,18 +638,21 @@ def test_shape(self, model_arch: str): self._setup(model_args) pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch]) - height, width, batch_size = 32, 64, 1 - for input_type in ["pil"]: + height, width, batch_size = 128, 64, 1 + + for input_type in ["pil", "np", "pt"]: inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size, input_type=input_type) - for output_type in ["np", "pil", "latent"]: + for output_type in ["pil", "np", "pt", "latent"]: inputs["output_type"] = output_type outputs = pipeline(**inputs).images if output_type == "pil": self.assertEqual((len(outputs), outputs[0].height, outputs[0].width), (batch_size, height, width)) elif output_type == "np": self.assertEqual(outputs.shape, (batch_size, height, width, 3)) + elif output_type == "pt": + self.assertEqual(outputs.shape, (batch_size, 3, height, width)) else: self.assertEqual( outputs.shape, @@ -708,11 +662,6 @@ def test_shape(self, model_arch: str): @parameterized.expand(SUPPORTED_ARCHITECTURES) @require_diffusers def test_compare_to_diffusers_pipeline(self, model_arch: str): - if model_arch in ["stable-diffusion"]: - pytest.skip( - "Stable Diffusion For Inpainting fails, it was used to be compared to StableDiffusionPipeline for some reason which is the text-to-image variant" - ) - model_args = {"test_name": model_arch, "model_arch": model_arch} self._setup(model_args) @@ -722,23 +671,13 @@ def test_compare_to_diffusers_pipeline(self, model_arch: str): height, width, batch_size = 64, 64, 1 inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) - latents_shape = ( - batch_size, - ort_pipeline.vae_decoder.config["latent_channels"], - height // ort_pipeline.vae_scale_factor, - width // ort_pipeline.vae_scale_factor, - ) + for output_type in ["latent", "np", "pt"]: + inputs["output_type"] = output_type - np_latents = np.random.rand(*latents_shape).astype(np.float32) - torch_latents = torch.from_numpy(np_latents) + ort_output = ort_pipeline(**inputs, generator=get_generator("pt", SEED)).images + diffusers_output = diffusers_pipeline(**inputs, generator=get_generator("pt", SEED)).images - ort_output = ort_pipeline(**inputs, latents=np_latents).images - diffusers_output = diffusers_pipeline(**inputs, latents=torch_latents).images - - self.assertTrue( - np.allclose(ort_output, diffusers_output, atol=1e-4), - np.testing.assert_allclose(ort_output, diffusers_output, atol=1e-4), - ) + np.testing.assert_allclose(ort_output, diffusers_output, atol=1e-4, rtol=1e-2) @parameterized.expand(SUPPORTED_ARCHITECTURES) @require_diffusers @@ -756,38 +695,65 @@ def test_image_reproducibility(self, model_arch: str): ort_outputs_2 = pipeline(**inputs, generator=get_generator(generator_framework, SEED)) ort_outputs_3 = pipeline(**inputs, generator=get_generator(generator_framework, SEED + 1)) - self.assertTrue(np.array_equal(ort_outputs_1.images[0], ort_outputs_2.images[0])) self.assertFalse(np.array_equal(ort_outputs_1.images[0], ort_outputs_3.images[0])) + 
np.testing.assert_allclose(ort_outputs_1.images[0], ort_outputs_2.images[0], atol=1e-4, rtol=1e-2) + + @parameterized.expand( + grid_parameters( + { + "model_arch": SUPPORTED_ARCHITECTURES, + "provider": ["CUDAExecutionProvider", "ROCMExecutionProvider", "TensorrtExecutionProvider"], + } + ) + ) + @pytest.mark.rocm_ep_test + @pytest.mark.cuda_ep_test + @pytest.mark.trt_ep_test + @require_torch_gpu + @require_diffusers + def test_pipeline_on_gpu(self, test_name: str, model_arch: str, provider: str): + model_args = {"test_name": test_name, "model_arch": model_arch} + self._setup(model_args) + + height, width, batch_size = 32, 64, 1 + inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) + + pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[test_name], provider=provider) + self.assertEqual(pipeline.device, "cuda") + + outputs = pipeline(**inputs).images + self.assertIsInstance(outputs, np.ndarray) + self.assertEqual(outputs.shape, (batch_size, height, width, 3)) + + @parameterized.expand(["stable-diffusion"]) + @require_diffusers + def test_safety_checker(self, model_arch: str): + model_args = {"test_name": model_arch, "model_arch": model_arch} + self._setup(model_args) + + safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker") + + pipeline = self.AUTOMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch], safety_checker=safety_checker) + ort_pipeline = self.ORTMODEL_CLASS.from_pretrained( + self.onnx_model_dirs[model_arch], safety_checker=safety_checker + ) + + self.assertIsInstance(pipeline.safety_checker, StableDiffusionSafetyChecker) + self.assertIsInstance(ort_pipeline.safety_checker, StableDiffusionSafetyChecker) + + height, width, batch_size = 32, 64, 1 + inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) + + ort_output = ort_pipeline(**inputs, generator=get_generator("pt", SEED)) + diffusers_output = pipeline(**inputs, generator=get_generator("pt", SEED)) + + ort_nsfw_content_detected = ort_output.nsfw_content_detected + diffusers_nsfw_content_detected = diffusers_output.nsfw_content_detected + self.assertTrue(ort_nsfw_content_detected is not None) + self.assertTrue(diffusers_nsfw_content_detected is not None) + self.assertEqual(ort_nsfw_content_detected, diffusers_nsfw_content_detected) -class ImageProcessorTest(unittest.TestCase): - def test_vae_image_processor_pt(self): - image_processor = VaeImageProcessor(do_resize=False, do_normalize=True) - input_pt = torch.stack(_generate_images(height=8, width=8, batch_size=1, input_type="pt")) - input_np = to_np(input_pt) - - for output_type in ["np", "pil"]: - out = image_processor.postprocess(image_processor.preprocess(input_pt), output_type=output_type) - out_np = to_np(out) - in_np = (input_np * 255).round() if output_type == "pil" else input_np - self.assertTrue(np.allclose(in_np, out_np, atol=1e-6)) - - def test_vae_image_processor_np(self): - image_processor = VaeImageProcessor(do_resize=False, do_normalize=True) - input_np = np.stack(_generate_images(height=8, width=8, input_type="np")) - for output_type in ["np", "pil"]: - out = image_processor.postprocess(image_processor.preprocess(input_np), output_type=output_type) - out_np = to_np(out) - in_np = (input_np * 255).round() if output_type == "pil" else input_np - self.assertTrue(np.allclose(in_np, out_np, atol=1e-6)) - - def test_vae_image_processor_pil(self): - image_processor = VaeImageProcessor(do_resize=False, do_normalize=True) - input_pil = 
_generate_images(height=8, width=8, batch_size=1, input_type="pil") - - for output_type in ["np", "pil"]: - out = image_processor.postprocess(image_processor.preprocess(input_pil), output_type=output_type) - for i, o in zip(input_pil, out): - in_np = np.array(i) - out_np = to_np(out) if output_type == "pil" else (to_np(out) * 255).round() - self.assertTrue(np.allclose(in_np, out_np, atol=1e-6)) + ort_images = ort_output.images + diffusers_images = diffusers_output.images + np.testing.assert_allclose(ort_images, diffusers_images, atol=1e-4, rtol=1e-2) diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py index f6771ce761..665f253c48 100644 --- a/tests/onnxruntime/test_modeling.py +++ b/tests/onnxruntime/test_modeling.py @@ -148,7 +148,7 @@ def __init__(self, *args, **kwargs): self.ONNX_SEQ2SEQ_MODEL_ID = "optimum/t5-small" self.LARGE_ONNX_SEQ2SEQ_MODEL_ID = "facebook/mbart-large-en-ro" self.TINY_ONNX_SEQ2SEQ_MODEL_ID = "fxmarty/sshleifer-tiny-mbart-onnx" - self.TINY_ONNX_STABLE_DIFFUSION_MODEL_ID = "hf-internal-testing/tiny-random-OnnxStableDiffusionPipeline" + self.TINY_ONNX_STABLE_DIFFUSION_MODEL_ID = "optimum-internal-testing/tiny-stable-diffusion-onnx" def test_load_model_from_local_path(self): model = ORTModel.from_pretrained(self.LOCAL_MODEL_PATH) @@ -222,17 +222,17 @@ def test_load_seq2seq_model_from_empty_cache(self): @require_diffusers def test_load_stable_diffusion_model_from_cache(self): _ = ORTStableDiffusionPipeline.from_pretrained(self.TINY_ONNX_STABLE_DIFFUSION_MODEL_ID) # caching - model = ORTStableDiffusionPipeline.from_pretrained( self.TINY_ONNX_STABLE_DIFFUSION_MODEL_ID, local_files_only=True ) - self.assertIsInstance(model.text_encoder, ORTModelTextEncoder) self.assertIsInstance(model.vae_decoder, ORTModelVaeDecoder) self.assertIsInstance(model.vae_encoder, ORTModelVaeEncoder) self.assertIsInstance(model.unet, ORTModelUnet) self.assertIsInstance(model.config, Dict) + model(prompt="This is a sanity test prompt", num_inference_steps=2) + @require_diffusers def test_load_stable_diffusion_model_from_empty_cache(self): dirpath = os.path.join( @@ -325,6 +325,8 @@ def test_load_stable_diffusion_model_from_hub(self): self.assertIsInstance(model.unet, ORTModelUnet) self.assertIsInstance(model.config, Dict) + model(prompt="This is a sanity test prompt", num_inference_steps=2) + @require_diffusers @require_torch_gpu @pytest.mark.cuda_ep_test @@ -339,6 +341,8 @@ def test_load_stable_diffusion_model_cuda_provider(self): self.assertListEqual(model.vae_encoder.session.get_providers(), model.providers) self.assertEqual(model.device, torch.device("cuda:0")) + model(prompt="This is a sanity test prompt", num_inference_steps=2) + @require_diffusers @require_torch_gpu @require_ort_rocm @@ -354,6 +358,8 @@ def test_load_stable_diffusion_model_rocm_provider(self): self.assertListEqual(model.vae_encoder.session.get_providers(), model.providers) self.assertEqual(model.device, torch.device("cuda:0")) + model(prompt="This is a sanity test prompt", num_inference_steps=2) + @require_diffusers def test_load_stable_diffusion_model_cpu_provider(self): model = ORTStableDiffusionPipeline.from_pretrained( @@ -366,6 +372,8 @@ def test_load_stable_diffusion_model_cpu_provider(self): self.assertListEqual(model.vae_encoder.session.get_providers(), model.providers) self.assertEqual(model.device, torch.device("cpu")) + model(prompt="This is a sanity test prompt", num_inference_steps=2) + @require_diffusers def test_load_stable_diffusion_model_unknown_provider(self): with 
self.assertRaises(ValueError): diff --git a/tests/onnxruntime/utils_onnxruntime_tests.py b/tests/onnxruntime/utils_onnxruntime_tests.py index 17f3b391b0..5071d0081a 100644 --- a/tests/onnxruntime/utils_onnxruntime_tests.py +++ b/tests/onnxruntime/utils_onnxruntime_tests.py @@ -171,6 +171,11 @@ class ORTModelTestMixin(unittest.TestCase): "np": np.ndarray, } + TASK = None + + ORTMODEL_CLASS = None + AUTOMODEL_CLASS = None + @classmethod def setUpClass(cls): cls.onnx_model_dirs = {}
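
Note on the test refactor above: the reworked diffusion tests converge on one comparison pattern — build the reference diffusers pipeline and its ORTPipelineFor* counterpart, drive both with the same torch.Generator seed, and compare outputs with np.testing.assert_allclose(..., atol=1e-4, rtol=1e-2) rather than a bare np.allclose; the new test_safety_checker variants additionally pass a StableDiffusionSafetyChecker into both pipelines and check that nsfw_content_detected agrees. The following is a rough standalone sketch of that pattern, not code from the suite: the checkpoint id, seed, image size, and the export=True / safety_checker keyword handling on the ORT auto class are illustrative assumptions.

    # Rough sketch of the comparison pattern used by the refactored tests.
    # Assumptions: the checkpoint id and SEED are placeholders (not the suite's
    # MODEL_NAMES entries), and ORTPipelineForText2Image.from_pretrained accepts
    # export=True plus a safety_checker override the same way the fixtures do.
    import numpy as np
    import torch
    from diffusers import AutoPipelineForText2Image
    from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
    from optimum.onnxruntime import ORTPipelineForText2Image

    SEED = 0
    model_id = "hf-internal-testing/tiny-stable-diffusion-torch"  # placeholder tiny checkpoint

    safety_checker = StableDiffusionSafetyChecker.from_pretrained(
        "CompVis/stable-diffusion-safety-checker"
    )

    diffusers_pipe = AutoPipelineForText2Image.from_pretrained(model_id, safety_checker=safety_checker)
    ort_pipe = ORTPipelineForText2Image.from_pretrained(model_id, export=True, safety_checker=safety_checker)

    inputs = {
        "prompt": "a photo of an astronaut riding a horse",
        "num_inference_steps": 2,
        "height": 64,
        "width": 64,
        "output_type": "np",
    }

    # Same seed on both sides so the initial latents (and therefore the outputs) match.
    diffusers_out = diffusers_pipe(**inputs, generator=torch.Generator().manual_seed(SEED))
    ort_out = ort_pipe(**inputs, generator=torch.Generator().manual_seed(SEED))

    # The safety-checker flags should agree, and the images should match within
    # a small absolute tolerance plus a relative tolerance for larger values.
    assert ort_out.nsfw_content_detected == diffusers_out.nsfw_content_detected
    np.testing.assert_allclose(ort_out.images, diffusers_out.images, atol=1e-4, rtol=1e-2)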