
Commit

updated all mixins, enabled all tests; all are passing except some reproducibility and comparison tests (7 failed, 35 passed)
IlyasMoutawwakil committed Sep 11, 2024
1 parent 2cd616e commit dceccca
Showing 9 changed files with 2,263 additions and 1,349 deletions.
20 changes: 15 additions & 5 deletions optimum/onnxruntime/modeling_diffusion.py
@@ -187,12 +187,11 @@ def __init__(
)
self._internal_dict.pop("vae", None)

if "block_out_channels" in self.vae_decoder.config:
self.vae_scale_factor = 2 ** (len(self.vae_decoder.config["block_out_channels"]) - 1)
else:
self.vae_scale_factor = 8

self.vae_scale_factor = 2 ** (len(self.vae_decoder.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.mask_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
)

@staticmethod
def load_model(
@@ -526,6 +525,11 @@ def forward(self, input_ids: Union[np.ndarray, torch.Tensor]):
onnx_outputs = self.session.run(None, onnx_inputs)
model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs)

if any("hidden_states" in model_output for model_output in model_outputs):
model_outputs["hidden_states"] = []
for i in range(self.config.num_hidden_layers):
model_outputs["hidden_states"].append(model_outputs.pop(f"hidden_states.{i}"))

return ModelOutput(**model_outputs)
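
To illustrate the remapping above in isolation, here is a small sketch with placeholder outputs; the dictionary contents and the `num_hidden_layers` value are made up for the example:

```python
# Placeholder ONNX outputs: per-layer hidden states arrive flattened as "hidden_states.{i}".
model_outputs = {"last_hidden_state": "t0", "hidden_states.0": "t1", "hidden_states.1": "t2"}
num_hidden_layers = 2

if any("hidden_states" in name for name in model_outputs):
    model_outputs["hidden_states"] = [
        model_outputs.pop(f"hidden_states.{i}") for i in range(num_hidden_layers)
    ]

# model_outputs is now {"last_hidden_state": "t0", "hidden_states": ["t1", "t2"]}
```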


@@ -567,6 +571,9 @@ def forward(self, latent_sample: Union[np.ndarray, torch.Tensor]):
onnx_outputs = self.session.run(None, onnx_inputs)
model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs)

if "latent_sample" in model_outputs:
model_outputs["latents"] = model_outputs.pop("latent_sample")

return ModelOutput(**model_outputs)


@@ -580,6 +587,9 @@ def forward(self, sample: Union[np.ndarray, torch.Tensor]):
onnx_outputs = self.session.run(None, onnx_inputs)
model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs)

if "latent_sample" in model_outputs:
model_outputs["latents"] = model_outputs.pop("latent_sample")

return ModelOutput(**model_outputs)


18 changes: 16 additions & 2 deletions optimum/pipelines/diffusers/pipeline_latent_consistency.py
@@ -23,7 +23,6 @@
from diffusers.utils.deprecation_utils import deprecate

from .pipeline_stable_diffusion import StableDiffusionPipelineMixin
from .pipeline_utils import patch_randn_tensor


logger = logging.getLogger(__name__)
@@ -239,11 +238,19 @@ def __call__(
second element is a list of `bool`s indicating whether the corresponding generated image contains
"not-safe-for-work" (nsfw) content.
"""
# must have a compatible torch device
device = self.device

# convert numpy arrays to torch tensors
prompt_embeds = self.np_to_pt(prompt_embeds, device) if isinstance(prompt_embeds, np.ndarray) else prompt_embeds
latents = self.np_to_pt(latents, device) if isinstance(latents, np.ndarray) else latents
prompt_embeds = (
self.np_to_pt(prompt_embeds, device) if isinstance(prompt_embeds, np.ndarray) else prompt_embeds
)
ip_adapter_image_embeds = (
[self.np_to_pt(i, device) if isinstance(i, np.ndarray) else i for i in ip_adapter_image_embeds]
if ip_adapter_image_embeds is not None
else ip_adapter_image_embeds
)

for k, v in kwargs.items():
if isinstance(v, np.ndarray):
@@ -253,6 +260,13 @@
elif isinstance(v, dict) and all(isinstance(i, np.ndarray) for i in v.values()):
kwargs[k] = {k: self.np_to_pt(v, device) for k, v in v.items()}

generator = (
self.np_to_pt(generator, device)
if (isinstance(generator, list) and isinstance(generator[0], np.random.RandomState))
or isinstance(generator, np.random.RandomState)
else generator
)

callback = kwargs.pop("callback", None)
callback_steps = kwargs.pop("callback_steps", None)
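
As an aside, the numpy-to-torch conversions above rely on the mixin's `np_to_pt` helper; a hedged stand-in for the plain-array case is sketched below (`np_to_pt_array` is a hypothetical name, and how the real helper converts `np.random.RandomState` generators is not shown):

```python
import numpy as np
import torch

def np_to_pt_array(array: np.ndarray, device: str) -> torch.Tensor:
    # Assumed behaviour for plain arrays only; not the actual optimum helper.
    return torch.from_numpy(array).to(device)

prompt_embeds = np.zeros((1, 77, 768), dtype=np.float32)
prompt_embeds = np_to_pt_array(prompt_embeds, "cpu") if isinstance(prompt_embeds, np.ndarray) else prompt_embeds
assert isinstance(prompt_embeds, torch.Tensor)
```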

98 changes: 87 additions & 11 deletions optimum/pipelines/diffusers/pipeline_stable_diffusion.py
@@ -22,8 +22,10 @@
from diffusers.image_processor import PipelineImageInput
from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import rescale_noise_cfg, retrieve_timesteps
from diffusers.utils.deprecation_utils import deprecate
from diffusers.utils.torch_utils import randn_tensor

from .pipeline_utils import DiffusionPipelineMixin, randn_tensor
from .pipeline_utils import DiffusionPipelineMixin


logger = logging.getLogger(__name__)
@@ -34,11 +36,11 @@ class StableDiffusionPipelineMixin(DiffusionPipelineMixin):

def encode_prompt(
self,
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt=None,
prompt: str,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
do_classifier_free_guidance: bool = True,
negative_prompt: Optional[str] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
lora_scale: Optional[float] = None,
@@ -74,6 +76,8 @@ def encode_prompt(
the output of the pre-final layer will be used for computing the prompt embeddings.
"""

device = device or self.device

if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
@@ -234,7 +238,10 @@ def run_safety_checker(self, image, device, dtype):

def decode_latents(self, latents):
latents = 1 / self.vae_decoder.config.scaling_factor * latents
image = self.vae_decoder(latents, return_dict=False)[0]
image = self.vae_decoder(
latents,
# return_dict=False,
)[0]
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
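
A minimal numeric sketch of the post-processing above; the scaling factor is only an illustrative value (in practice it comes from `vae_decoder.config.scaling_factor`), and the decoder output is faked:

```python
import torch

scaling_factor = 0.18215                  # illustrative; the real value comes from the VAE config
latents = torch.randn(1, 4, 64, 64)
latents = 1 / scaling_factor * latents    # undo the encoder-side scaling before decoding
decoded = torch.tensor([-1.0, 0.0, 1.0])  # stand-in for decoder output, which lives in [-1, 1]
image = (decoded / 2 + 0.5).clamp(0, 1)   # -> tensor([0.0000, 0.5000, 1.0000])
```
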
@@ -461,8 +468,10 @@ def __call__(
second element is a list of `bool`s indicating whether the corresponding generated image contains
"not-safe-for-work" (nsfw) content.
"""
# must have a compatible torch device
device = self.device

# convert numpy arrays to torch tensors
latents = self.np_to_pt(latents, device) if isinstance(latents, np.ndarray) else latents
prompt_embeds = (
self.np_to_pt(prompt_embeds, device) if isinstance(prompt_embeds, np.ndarray) else prompt_embeds
@@ -472,17 +481,41 @@
if isinstance(negative_prompt_embeds, np.ndarray)
else negative_prompt_embeds
)
ip_adapter_image_embeds = (
[self.np_to_pt(i, device) if isinstance(i, np.ndarray) else i for i in ip_adapter_image_embeds]
if ip_adapter_image_embeds is not None
else ip_adapter_image_embeds
)

for k, v in kwargs.items():
if isinstance(v, np.ndarray):
kwargs[k] = self.np_to_pt(v, device)
elif isinstance(v, list) and all(isinstance(i, np.ndarray) for i in v):
kwargs[k] = [self.np_to_pt(i, device) for i in v]
elif isinstance(v, dict) and all(isinstance(i, np.ndarray) for i in v.values()):
kwargs[k] = {k: self.np_to_pt(v, device) for k, v in v.items()}

generator = (
self.np_to_pt(generator, device)
if (isinstance(generator, list) and isinstance(generator[0], np.random.RandomState))
or isinstance(generator, np.random.RandomState)
else generator
)

callback = kwargs.pop("callback", None)
callback_steps = kwargs.pop("callback_steps", None)

if callback is not None:
logger.warning(
"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
deprecate(
"callback",
"1.0.0",
"Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
)
if callback_steps is not None:
logger.warning(
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
deprecate(
"callback_steps",
"1.0.0",
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
)

if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
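
For reference, a rough stand-in for what the `deprecate` calls above do; the exact behaviour of `diffusers.utils.deprecate` may differ, so this only assumes it surfaces a `FutureWarning` with the given message until the stated removal version:

```python
import warnings

def deprecate_sketch(name: str, removal_version: str, message: str) -> None:
    # Assumed behaviour; not the actual diffusers implementation.
    warnings.warn(f"`{name}` is deprecated (removal planned for {removal_version}). {message}", FutureWarning)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    deprecate_sketch(
        "callback",
        "1.0.0",
        "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
    )
assert issubclass(caught[0].category, FutureWarning)
```
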
@@ -665,3 +698,46 @@ def __call__(
return (image, has_nsfw_concept)

return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
def get_timesteps(self, num_inference_steps, strength, device):
# get the original timestep using init_timestep
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)

t_start = max(num_inference_steps - init_timestep, 0)
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
if hasattr(self.scheduler, "set_begin_index"):
self.scheduler.set_begin_index(t_start * self.scheduler.order)

return timesteps, num_inference_steps - t_start
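
A quick worked example of the arithmetic in `get_timesteps`, with illustrative numbers and a first-order scheduler assumed:

```python
# Illustrative numbers only; scheduler.order == 1 is assumed.
num_inference_steps, strength, order = 50, 0.6, 1
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)  # 30
t_start = max(num_inference_steps - init_timestep, 0)                          # 20
kept_steps = num_inference_steps - t_start                                     # 30 steps actually run
assert (init_timestep, t_start, kept_steps) == (30, 20, 30)
```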

# Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
def get_guidance_scale_embedding(
self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
) -> torch.Tensor:
"""
See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
Args:
w (`torch.Tensor`):
Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
embedding_dim (`int`, *optional*, defaults to 512):
Dimension of the embeddings to generate.
dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
Data type of the generated embeddings.
Returns:
`torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
"""
assert len(w.shape) == 1
w = w * 1000.0

half_dim = embedding_dim // 2
emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
emb = w.to(dtype)[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
if embedding_dim % 2 == 1: # zero pad
emb = torch.nn.functional.pad(emb, (0, 1))
assert emb.shape == (w.shape[0], embedding_dim)
return emb
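
For intuition, the same sinusoidal embedding computed standalone for a batch of two guidance scales; the dimensions are deliberately tiny and purely illustrative:

```python
import torch

w = torch.tensor([1.0, 7.5])           # two guidance scales
embedding_dim, half_dim = 8, 4
emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
emb = (w * 1000.0)[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
assert emb.shape == (2, 8)             # (len(w), embedding_dim); no padding needed since 8 is even
```
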
