From 311d7fcf5c898d2a8d515cc2927fcf461cec1c0f Mon Sep 17 00:00:00 2001 From: shiimizu Date: Wed, 4 Sep 2024 18:21:27 -0700 Subject: [PATCH] Fix FLUX Union ControlNet. For #36 --- tiled_diffusion.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tiled_diffusion.py b/tiled_diffusion.py index 816e099..e0d4f95 100644 --- a/tiled_diffusion.py +++ b/tiled_diffusion.py @@ -385,17 +385,19 @@ def process_controlnet(self, x_shape, x_dtype, c_in: dict, cond_or_uncond: List, # # Below can be in this if clause because self.refresh will trigger on resolution change, # e.g. cause of ConditioningSetArea, so that particular case isn't cached atm. + cf = control.compression_ratio if cns.shape[0] != batch_size: cns = repeat_to_batch_size(cns, batch_size) if shifts is not None: control.cns = cns cns = cns.roll(shifts=tuple(x * cf for x in shifts), dims=(-2,-1)) - control.cond_hint = torch.cat([cns[:, :, bbox[1]*cf:bbox[3]*cf, bbox[0]*cf:bbox[2]*cf] for bbox in bboxes], dim=0) + control.cond_hint = torch.cat([cns[:, :, bbox[1]*cf:bbox[3]*cf, bbox[0]*cf:bbox[2]*cf] for bbox in bboxes], dim=0).to(device=cns.device) self.control_params[tuple_key][param_id][batch_id] = control.cond_hint else: if hasattr(control,'cns') and shifts is not None: + cf = control.compression_ratio cns = control.cns.roll(shifts=tuple(x * cf for x in shifts), dims=(-2,-1)) - control.cond_hint = torch.cat([cns[:, :, bbox[1]*cf:bbox[3]*cf, bbox[0]*cf:bbox[2]*cf] for bbox in bboxes], dim=0) + control.cond_hint = torch.cat([cns[:, :, bbox[1]*cf:bbox[3]*cf, bbox[0]*cf:bbox[2]*cf] for bbox in bboxes], dim=0).to(device=cns.device) else: control.cond_hint = self.control_params[tuple_key][param_id][batch_id] control = control.previous_controlnet