Skip to content

Commit

Permalink
Fix FLUX Union ControlNet.
Browse files Browse the repository at this point in the history
For #36
  • Loading branch information
shiimizu committed Sep 5, 2024
1 parent 2ad9433 commit 311d7fc
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions tiled_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,17 +385,19 @@ def process_controlnet(self, x_shape, x_dtype, c_in: dict, cond_or_uncond: List,
#
# Below can be in this if clause because self.refresh will trigger on resolution change,
# e.g. cause of ConditioningSetArea, so that particular case isn't cached atm.
cf = control.compression_ratio
if cns.shape[0] != batch_size:
cns = repeat_to_batch_size(cns, batch_size)
if shifts is not None:
control.cns = cns
cns = cns.roll(shifts=tuple(x * cf for x in shifts), dims=(-2,-1))
control.cond_hint = torch.cat([cns[:, :, bbox[1]*cf:bbox[3]*cf, bbox[0]*cf:bbox[2]*cf] for bbox in bboxes], dim=0)
control.cond_hint = torch.cat([cns[:, :, bbox[1]*cf:bbox[3]*cf, bbox[0]*cf:bbox[2]*cf] for bbox in bboxes], dim=0).to(device=cns.device)
self.control_params[tuple_key][param_id][batch_id] = control.cond_hint
else:
if hasattr(control,'cns') and shifts is not None:
cf = control.compression_ratio
cns = control.cns.roll(shifts=tuple(x * cf for x in shifts), dims=(-2,-1))
control.cond_hint = torch.cat([cns[:, :, bbox[1]*cf:bbox[3]*cf, bbox[0]*cf:bbox[2]*cf] for bbox in bboxes], dim=0)
control.cond_hint = torch.cat([cns[:, :, bbox[1]*cf:bbox[3]*cf, bbox[0]*cf:bbox[2]*cf] for bbox in bboxes], dim=0).to(device=cns.device)
else:
control.cond_hint = self.control_params[tuple_key][param_id][batch_id]
control = control.previous_controlnet
Expand Down

0 comments on commit 311d7fc

Please sign in to comment.