diff --git a/tiled_diffusion.py b/tiled_diffusion.py index 5292eac..816e099 100644 --- a/tiled_diffusion.py +++ b/tiled_diffusion.py @@ -552,8 +552,17 @@ def __call__(self, model_function: BaseModel.apply_model, args: dict): sigmas = self.sigmas = store.sigmas shift_method = store.model_options.get('tiled_diffusion_shift_method', 'random') seed = store.model_options.get('tiled_diffusion_seed', store.extra_args.get('seed', 0)) - shift_height = torch.randint(0, self.tile_height, (len(sigmas)-1,), generator=torch.Generator(device='cpu').manual_seed(seed), device='cpu') - shift_width = torch.randint(0, self.tile_width, (len(sigmas)-1,), generator=torch.Generator(device='cpu').manual_seed(seed), device='cpu') + th = self.tile_height + tw = self.tile_width + cf = self.compression + if 'effnet' in c_in: + cf = x_in.shape[-1] * self.compression // c_in['effnet'].shape[-1] # compression factor + th = self.height // cf + tw = self.width // cf + shift_height = torch.randint(0, th, (len(sigmas)-1,), generator=torch.Generator(device='cpu').manual_seed(seed), device='cpu') + shift_height = (shift_height * cf / self.compression).round().to(torch.int32) + shift_width = torch.randint(0, tw, (len(sigmas)-1,), generator=torch.Generator(device='cpu').manual_seed(seed), device='cpu') + shift_width = (shift_width * cf / self.compression).round().to(torch.int32) if shift_method == "sorted": shift_height = shift_height.sort().values shift_width = shift_width.sort().values @@ -585,6 +594,7 @@ def __call__(self, model_function: BaseModel.apply_model, args: dict): if isinstance(v, torch.Tensor): if len(v.shape) == len(x_tile.shape): bboxes_ = bboxes + sh_h_new, sh_w_new = sh_h, sh_w if v.shape[-2:] != x_in.shape[-2:]: cf = x_in.shape[-1] * self.compression // v.shape[-1] # compression factor bboxes_ = self.get_grid_bbox( @@ -597,6 +607,8 @@ def __call__(self, model_function: BaseModel.apply_model, args: dict): x_in.device, self.get_tile_weights, ) + sh_h_new, sh_w_new = round(sh_h * self.compression / cf), round(sh_w * self.compression / cf) + v = v.roll(shifts=(sh_h_new, sh_w_new), dims=(-2,-1)) v = torch.cat([v[bbox_.slicer] for bbox_ in bboxes_[batch_id]]) if v.shape[0] != x_tile.shape[0]: v = repeat_to_batch_size(v, x_tile.shape[0])