fix imports in community pipeline for semantic guidance for flux
Marlon154 committed Jan 20, 2025
1 parent 7829f3d commit a5892d7
Showing 1 changed file with 97 additions and 85 deletions.
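For orientation before the diff: this file ships as a diffusers community pipeline, loaded by passing custom_pipeline="pipeline_flux_semantic_guidance" to DiffusionPipeline.from_pretrained (see the example docstring in the hunks below). A minimal usage sketch, assuming the editing_prompt and edit_guidance_scale arguments from the __call__ signature further down; the concrete prompts and device placement are illustrative only:

import torch
from diffusers import DiffusionPipeline

# Load this file as a custom pipeline on top of FLUX.1-dev weights.
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    custom_pipeline="pipeline_flux_semantic_guidance",
    torch_dtype=torch.bfloat16,
).to("cuda")

# Semantic guidance: steer generation toward (or away from) edit concepts.
image = pipe(
    prompt="a photo of a cat",      # illustrative prompt
    editing_prompt=["sunglasses"],  # concept to push the generation toward
    edit_guidance_scale=5.0,
    num_inference_steps=28,
    guidance_scale=3.5,
).images[0]
image.save("cat_sunglasses.png")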
examples/community/pipeline_flux_semantic_guidance.py — 182 changes: 97 additions & 85 deletions
@@ -13,22 +13,25 @@
# limitations under the License.

import inspect
from typing import Any, Callable, Dict, List, Optional, Union, Tuple
from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np
import torch
from transformers import (
CLIPImageProcessor,
CLIPTextModel,
CLIPTokenizer,
CLIPVisionModelWithProjection,
T5EncoderModel,
T5TokenizerFast,
CLIPVisionModelWithProjection,
CLIPImageProcessor
)
from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
from diffusers.image_processor import VaeImageProcessor, PipelineImageInput
from diffusers.loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin, FromSingleFileMixin, FluxIPAdapterMixin

from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
from diffusers.models.autoencoders import AutoencoderKL
from diffusers.models.transformers import FluxTransformer2DModel
from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import (
USE_PEFT_BACKEND,
@@ -39,7 +42,7 @@
unscale_lora_layers,
)
from diffusers.utils.torch_utils import randn_tensor
from diffusers.pipelines.pipeline_utils import DiffusionPipeline


if is_torch_xla_available():
import torch_xla.core.xla_model as xm
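The reordering in the hunk above follows the isort-style convention used across diffusers: standard library first, then third-party packages, then first-party diffusers modules, each group alphabetized. A condensed sketch of the resulting grouping, with names abbreviated from the imports above:

# Standard library
import inspect
from typing import Any, Callable, Dict, List, Optional, Union

# Third-party
import numpy as np
import torch
from transformers import CLIPImageProcessor, CLIPTextModel  # ...

# First-party, alphabetized by module path
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin
from diffusers.pipelines.pipeline_utils import DiffusionPipeline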
@@ -57,7 +60,7 @@
>>> from diffusers import DiffusionPipeline
>>> pipe = DiffusionPipeline.from_pretrained(
>>> "black-forest-labs/FLUX.1-dev",
>>> "black-forest-labs/FLUX.1-dev",
>>> custom_pipeline="pipeline_flux_semantic_guidance",
>>> torch_dtype=torch.bfloat16
>>> )
@@ -319,7 +322,6 @@ def _get_clip_prompt_embeds(

return prompt_embeds


def encode_prompt(
self,
prompt: Union[str, List[str]],
@@ -400,18 +402,18 @@ def encode_prompt(
return prompt_embeds, pooled_prompt_embeds, text_ids

def encode_text_with_editing(
self,
prompt: Union[str, List[str]],
prompt_2: Union[str, List[str]],
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
editing_prompt: Optional[List[str]] = None,
editing_prompt_2: Optional[List[str]] = None,
editing_prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_editing_prompt_embeds: Optional[torch.FloatTensor] = None,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
max_sequence_length: int = 512,
lora_scale: Optional[float] = None,
self,
prompt: Union[str, List[str]],
prompt_2: Union[str, List[str]],
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
editing_prompt: Optional[List[str]] = None,
editing_prompt_2: Optional[List[str]] = None,
editing_prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_editing_prompt_embeds: Optional[torch.FloatTensor] = None,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
max_sequence_length: int = 512,
lora_scale: Optional[float] = None,
):
"""
Encode text prompts with editing prompts and negative prompts for semantic guidance.
@@ -500,8 +502,15 @@ def encode_text_with_editing(
editing_prompt_embeds[idx] = torch.cat([editing_prompt_embeds[idx]] * batch_size, dim=0)
pooled_editing_prompt_embeds[idx] = torch.cat([pooled_editing_prompt_embeds[idx]] * batch_size, dim=0)

return (prompt_embeds, pooled_prompt_embeds, editing_prompt_embeds,
pooled_editing_prompt_embeds, text_ids, edit_text_ids, enabled_editing_prompts)
return (
prompt_embeds,
pooled_prompt_embeds,
editing_prompt_embeds,
pooled_editing_prompt_embeds,
text_ids,
edit_text_ids,
enabled_editing_prompts,
)

def encode_image(self, image, device, num_images_per_prompt):
dtype = next(self.image_encoder.parameters()).dtype
@@ -546,27 +555,27 @@ def prepare_ip_adapter_image_embeds(
return ip_adapter_image_embeds

def check_inputs(
self,
prompt,
prompt_2,
height,
width,
negative_prompt=None,
negative_prompt_2=None,
prompt_embeds=None,
negative_prompt_embeds=None,
pooled_prompt_embeds=None,
negative_pooled_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
max_sequence_length=None,
self,
prompt,
prompt_2,
height,
width,
negative_prompt=None,
negative_prompt_2=None,
prompt_embeds=None,
negative_prompt_embeds=None,
pooled_prompt_embeds=None,
negative_pooled_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
max_sequence_length=None,
):
if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
logger.warning(
f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
)

if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
@@ -743,47 +752,47 @@ def interrupt(self):
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None,
negative_prompt: Union[str, List[str]] = None,
negative_prompt_2: Optional[Union[str, List[str]]] = None,
true_cfg_scale: float = 1.0,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 28,
sigmas: Optional[List[float]] = None,
guidance_scale: float = 3.5,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
ip_adapter_image: Optional[PipelineImageInput] = None,
ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
negative_ip_adapter_image: Optional[PipelineImageInput] = None,
negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 512,
editing_prompt: Optional[Union[str, List[str]]] = None,
editing_prompt_2: Optional[Union[str, List[str]]] = None,
editing_prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_editing_prompt_embeds: Optional[torch.FloatTensor] = None,
reverse_editing_direction: Optional[Union[bool, List[bool]]] = False,
edit_guidance_scale: Optional[Union[float, List[float]]] = 5,
edit_warmup_steps: Optional[Union[int, List[int]]] = 8,
edit_cooldown_steps: Optional[Union[int, List[int]]] = None,
edit_threshold: Optional[Union[float, List[float]]] = 0.9,
edit_momentum_scale: Optional[float] = 0.1,
edit_mom_beta: Optional[float] = 0.4,
edit_weights: Optional[List[float]] = None,
sem_guidance: Optional[List[torch.Tensor]] = None,
self,
prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None,
negative_prompt: Union[str, List[str]] = None,
negative_prompt_2: Optional[Union[str, List[str]]] = None,
true_cfg_scale: float = 1.0,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 28,
sigmas: Optional[List[float]] = None,
guidance_scale: float = 3.5,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
ip_adapter_image: Optional[PipelineImageInput] = None,
ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
negative_ip_adapter_image: Optional[PipelineImageInput] = None,
negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 512,
editing_prompt: Optional[Union[str, List[str]]] = None,
editing_prompt_2: Optional[Union[str, List[str]]] = None,
editing_prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_editing_prompt_embeds: Optional[torch.FloatTensor] = None,
reverse_editing_direction: Optional[Union[bool, List[bool]]] = False,
edit_guidance_scale: Optional[Union[float, List[float]]] = 5,
edit_warmup_steps: Optional[Union[int, List[int]]] = 8,
edit_cooldown_steps: Optional[Union[int, List[int]]] = None,
edit_threshold: Optional[Union[float, List[float]]] = 0.9,
edit_momentum_scale: Optional[float] = 0.1,
edit_mom_beta: Optional[float] = 0.4,
edit_weights: Optional[List[float]] = None,
sem_guidance: Optional[List[torch.Tensor]] = None,
):
r"""
Function invoked when calling the pipeline for generation.
@@ -1037,7 +1046,9 @@ def __call__(
min_edit_warmup_steps = 0

if edit_cooldown_steps:
tmp_e_cooldown_steps = edit_cooldown_steps if isinstance(edit_cooldown_steps, list) else [edit_cooldown_steps]
tmp_e_cooldown_steps = (
edit_cooldown_steps if isinstance(edit_cooldown_steps, list) else [edit_cooldown_steps]
)
max_edit_cooldown_steps = min(max(tmp_e_cooldown_steps), num_inference_steps)
else:
max_edit_cooldown_steps = num_inference_steps
@@ -1110,7 +1121,9 @@ def __call__(

if enable_edit_guidance and max_edit_cooldown_steps >= i >= min_edit_warmup_steps:
noise_pred_edit_concepts = []
for e_embed, pooled_e_embed, e_text_id in zip(editing_prompts_embeds, pooled_editing_prompt_embeds, edit_text_ids):
for e_embed, pooled_e_embed, e_text_id in zip(
editing_prompts_embeds, pooled_editing_prompt_embeds, edit_text_ids
):
noise_pred_edit = self.transformer(
hidden_states=latents,
timestep=timestep / 1000,
@@ -1160,7 +1173,6 @@ def __call__(
# noise_guidance_edit = torch.zeros_like(noise_guidance)
warmup_inds = []
for c, noise_pred_edit_concept in enumerate(noise_pred_edit_concepts):

if isinstance(edit_guidance_scale, list):
edit_guidance_scale_c = edit_guidance_scale[c]
else:
@@ -1247,9 +1259,7 @@ def __call__(
concept_weights_tmp = concept_weights_tmp / concept_weights_tmp.sum(dim=0)
# concept_weights_tmp = torch.nan_to_num(concept_weights_tmp)

noise_guidance_edit_tmp = torch.index_select(
noise_guidance_edit.to(device), 0, warmup_inds
)
noise_guidance_edit_tmp = torch.index_select(noise_guidance_edit.to(device), 0, warmup_inds)
noise_guidance_edit_tmp = torch.einsum(
"cb,cbij->bij", concept_weights_tmp, noise_guidance_edit_tmp
)
@@ -1325,4 +1335,6 @@ def __call__(
if not return_dict:
return (image,)

return FluxPipelineOutput(image, )
return FluxPipelineOutput(
image,
)
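The hunks above show only fragments of the semantic-guidance step. As a reading aid, here is a minimal, self-contained sketch of how SEGA-style per-concept guidance of this kind is aggregated; the tensor shapes, helper name, and exact masking details are assumptions for illustration, not the pipeline's verbatim code:

import torch

def aggregate_edit_guidance(
    noise_pred_uncond: torch.Tensor,               # (batch, seq, dim)
    noise_pred_edit_concepts: list[torch.Tensor],  # one tensor per edit concept
    edit_guidance_scale: float = 5.0,
    edit_threshold: float = 0.9,
    reverse_editing_direction: bool = False,
) -> torch.Tensor:
    """Scale each concept's edit direction, keep only its strongest
    components (quantile mask), and sum the surviving guidance."""
    guidance = torch.zeros_like(noise_pred_uncond)
    for noise_pred_edit in noise_pred_edit_concepts:
        direction = noise_pred_edit - noise_pred_uncond
        if reverse_editing_direction:
            direction = -direction  # push away from the concept instead
        direction = direction * edit_guidance_scale
        # Zero out everything below the per-sample `edit_threshold` quantile,
        # mirroring the pipeline's thresholding of weak guidance components.
        flat = direction.abs().reshape(direction.shape[0], -1)
        cutoff = torch.quantile(flat.float(), edit_threshold, dim=1, keepdim=True)
        mask = (flat >= cutoff).reshape_as(direction)
        guidance = guidance + torch.where(mask, direction, torch.zeros_like(direction))
    return guidance

In the actual pipeline this term is additionally smoothed across denoising steps by a momentum buffer (edit_momentum_scale, edit_mom_beta) and gated per concept by the warmup/cooldown windows visible in the hunks above.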
