Skip to content

Commit

Permalink
Several fixes to Flux ControlNet pipelines (#9472)
Browse files Browse the repository at this point in the history
* fix flux controlnet pipelines

---------

Co-authored-by: yiyixuxu <[email protected]>
  • Loading branch information
vladmandic and yiyixuxu authored Sep 20, 2024
1 parent 2b443a5 commit 14a1b86
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 13 deletions.
11 changes: 10 additions & 1 deletion src/diffusers/pipelines/auto_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,14 @@
StableDiffusionXLControlNetPipeline,
)
from .deepfloyd_if import IFImg2ImgPipeline, IFInpaintingPipeline, IFPipeline
from .flux import FluxControlNetPipeline, FluxImg2ImgPipeline, FluxInpaintPipeline, FluxPipeline
from .flux import (
FluxControlNetImg2ImgPipeline,
FluxControlNetInpaintPipeline,
FluxControlNetPipeline,
FluxImg2ImgPipeline,
FluxInpaintPipeline,
FluxPipeline,
)
from .hunyuandit import HunyuanDiTPipeline
from .kandinsky import (
KandinskyCombinedPipeline,
Expand Down Expand Up @@ -128,6 +135,7 @@
("stable-diffusion-xl-controlnet-pag", StableDiffusionXLControlNetPAGImg2ImgPipeline),
("lcm", LatentConsistencyModelImg2ImgPipeline),
("flux", FluxImg2ImgPipeline),
("flux-controlnet", FluxControlNetImg2ImgPipeline),
]
)

Expand All @@ -143,6 +151,7 @@
("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetInpaintPipeline),
("stable-diffusion-xl-pag", StableDiffusionXLPAGInpaintPipeline),
("flux", FluxInpaintPipeline),
("flux-controlnet", FluxControlNetInpaintPipeline),
]
)

Expand Down
19 changes: 11 additions & 8 deletions src/diffusers/pipelines/flux/pipeline_flux_controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -729,7 +729,7 @@ def __call__(
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=dtype,
dtype=self.vae.dtype,
)
height, width = control_image.shape[-2:]

Expand Down Expand Up @@ -763,7 +763,7 @@ def __call__(
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=dtype,
dtype=self.vae.dtype,
)
height, width = control_image_.shape[-2:]

Expand Down Expand Up @@ -840,12 +840,10 @@ def __call__(
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)

# handle guidance
if self.transformer.config.guidance_embeds:
guidance = torch.tensor([guidance_scale], device=device)
guidance = guidance.expand(latents.shape[0])
else:
guidance = None
guidance = (
torch.tensor([guidance_scale], device=device) if self.controlnet.config.guidance_embeds else None
)
guidance = guidance.expand(latents.shape[0]) if guidance is not None else None

# controlnet
controlnet_block_samples, controlnet_single_block_samples = self.controlnet(
Expand All @@ -863,6 +861,11 @@ def __call__(
return_dict=False,
)

guidance = (
torch.tensor([guidance_scale], device=device) if self.transformer.config.guidance_embeds else None
)
guidance = guidance.expand(latents.shape[0]) if guidance is not None else None

noise_pred = self.transformer(
hidden_states=latents,
timestep=timestep / 1000,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -767,7 +767,7 @@ def __call__(
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=dtype,
dtype=self.vae.dtype,
)
height, width = control_image.shape[-2:]

Expand Down Expand Up @@ -798,7 +798,7 @@ def __call__(
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=dtype,
dtype=self.vae.dtype,
)
height, width = control_image_.shape[-2:]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -899,7 +899,7 @@ def __call__(
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=dtype,
dtype=self.vae.dtype,
)
height, width = control_image.shape[-2:]

Expand Down Expand Up @@ -933,7 +933,7 @@ def __call__(
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=dtype,
dtype=self.vae.dtype,
)
height, width = control_image_.shape[-2:]

Expand Down

0 comments on commit 14a1b86

Please sign in to comment.