diff --git a/tiled_diffusion.py b/tiled_diffusion.py index 816e099..e0d4f95 100644 --- a/tiled_diffusion.py +++ b/tiled_diffusion.py @@ -385,17 +385,19 @@ def process_controlnet(self, x_shape, x_dtype, c_in: dict, cond_or_uncond: List, # # Below can be in this if clause because self.refresh will trigger on resolution change, # e.g. cause of ConditioningSetArea, so that particular case isn't cached atm. + cf = control.compression_ratio if cns.shape[0] != batch_size: cns = repeat_to_batch_size(cns, batch_size) if shifts is not None: control.cns = cns cns = cns.roll(shifts=tuple(x * cf for x in shifts), dims=(-2,-1)) - control.cond_hint = torch.cat([cns[:, :, bbox[1]*cf:bbox[3]*cf, bbox[0]*cf:bbox[2]*cf] for bbox in bboxes], dim=0) + control.cond_hint = torch.cat([cns[:, :, bbox[1]*cf:bbox[3]*cf, bbox[0]*cf:bbox[2]*cf] for bbox in bboxes], dim=0).to(device=cns.device) self.control_params[tuple_key][param_id][batch_id] = control.cond_hint else: if hasattr(control,'cns') and shifts is not None: + cf = control.compression_ratio cns = control.cns.roll(shifts=tuple(x * cf for x in shifts), dims=(-2,-1)) - control.cond_hint = torch.cat([cns[:, :, bbox[1]*cf:bbox[3]*cf, bbox[0]*cf:bbox[2]*cf] for bbox in bboxes], dim=0) + control.cond_hint = torch.cat([cns[:, :, bbox[1]*cf:bbox[3]*cf, bbox[0]*cf:bbox[2]*cf] for bbox in bboxes], dim=0).to(device=cns.device) else: control.cond_hint = self.control_params[tuple_key][param_id][batch_id] control = control.previous_controlnet