Commit 2ad9433
SD related changes.
shiimizu committed Sep 5, 2024
1 parent 5a82065 commit 2ad9433
Showing 1 changed file with 14 additions and 2 deletions.
16 changes: 14 additions & 2 deletions tiled_diffusion.py
@@ -552,8 +552,17 @@ def __call__(self, model_function: BaseModel.apply_model, args: dict):
             sigmas = self.sigmas = store.sigmas
             shift_method = store.model_options.get('tiled_diffusion_shift_method', 'random')
             seed = store.model_options.get('tiled_diffusion_seed', store.extra_args.get('seed', 0))
-            shift_height = torch.randint(0, self.tile_height, (len(sigmas)-1,), generator=torch.Generator(device='cpu').manual_seed(seed), device='cpu')
-            shift_width = torch.randint(0, self.tile_width, (len(sigmas)-1,), generator=torch.Generator(device='cpu').manual_seed(seed), device='cpu')
+            th = self.tile_height
+            tw = self.tile_width
+            cf = self.compression
+            if 'effnet' in c_in:
+                cf = x_in.shape[-1] * self.compression // c_in['effnet'].shape[-1]  # compression factor
+                th = self.height // cf
+                tw = self.width // cf
+            shift_height = torch.randint(0, th, (len(sigmas)-1,), generator=torch.Generator(device='cpu').manual_seed(seed), device='cpu')
+            shift_height = (shift_height * cf / self.compression).round().to(torch.int32)
+            shift_width = torch.randint(0, tw, (len(sigmas)-1,), generator=torch.Generator(device='cpu').manual_seed(seed), device='cpu')
+            shift_width = (shift_width * cf / self.compression).round().to(torch.int32)
             if shift_method == "sorted":
                 shift_height = shift_height.sort().values
                 shift_width = shift_width.sort().values
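Editor's note: the hunk above changes how the per-step tile shifts are drawn. When the conds carry a lower-resolution 'effnet' tensor (as in Stable Cascade-style models), the shift range is taken in that grid's units and then rescaled back into latent units. Below is a minimal standalone sketch of that rescaling pattern; the function, argument names, and default values are illustrative and not taken from the repository.

    # Minimal standalone sketch (not repository code): draw per-step tile shifts in one
    # grid's units and rescale them to latent units. Only the rescaling pattern mirrors
    # the hunk above; all names here are made up.
    import torch

    def make_shifts(num_steps: int, tile_h: int, tile_w: int, seed: int,
                    latent_compression: int = 8, effnet_size=None, image_size=None):
        th, tw, cf = tile_h, tile_w, latent_compression
        if effnet_size is not None:
            # Express the shift range in the conditioning grid's units instead.
            eff_h, eff_w = effnet_size
            img_h, img_w = image_size
            cf = img_w // eff_w                      # conditioning compression factor
            th, tw = img_h // cf, img_w // cf
        gen = torch.Generator(device='cpu').manual_seed(seed)
        sh_h = torch.randint(0, th, (num_steps,), generator=gen, device='cpu')
        sh_w = torch.randint(0, tw, (num_steps,), generator=gen, device='cpu')
        # Convert back to latent units so the latent grid can be rolled directly.
        # When no effnet conditioning is present, cf == latent_compression and this is a no-op.
        sh_h = (sh_h * cf / latent_compression).round().to(torch.int32)
        sh_w = (sh_w * cf / latent_compression).round().to(torch.int32)
        return sh_h, sh_w

    # Example: 20 steps, 96x96 latent tiles, a 1024x1024 image with a 24x24 effnet map.
    sh_h, sh_w = make_shifts(20, 96, 96, seed=0,
                             effnet_size=(24, 24), image_size=(1024, 1024))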
@@ -585,6 +594,7 @@ def __call__(self, model_function: BaseModel.apply_model, args: dict):
                     if isinstance(v, torch.Tensor):
                         if len(v.shape) == len(x_tile.shape):
                             bboxes_ = bboxes
+                            sh_h_new, sh_w_new = sh_h, sh_w
                             if v.shape[-2:] != x_in.shape[-2:]:
                                 cf = x_in.shape[-1] * self.compression // v.shape[-1]  # compression factor
                                 bboxes_ = self.get_grid_bbox(
@@ -597,6 +607,8 @@ def __call__(self, model_function: BaseModel.apply_model, args: dict):
                                     x_in.device,
                                     self.get_tile_weights,
                                 )
+                                sh_h_new, sh_w_new = round(sh_h * self.compression / cf), round(sh_w * self.compression / cf)
+                            v = v.roll(shifts=(sh_h_new, sh_w_new), dims=(-2,-1))
                             v = torch.cat([v[bbox_.slicer] for bbox_ in bboxes_[batch_id]])
                             if v.shape[0] != x_tile.shape[0]:
                                 v = repeat_to_batch_size(v, x_tile.shape[0])
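Editor's note: the last two hunks roll auxiliary conditioning tensors that live at a different spatial resolution by a correspondingly rescaled shift, so they stay aligned with the shifted latent. A small self-contained illustration of that idea follows; the names and shapes are made up and not taken from the repository.

    # Standalone sketch (not repository code): keep an auxiliary conditioning tensor
    # aligned with a shifted latent when the two live at different resolutions.
    import torch

    def roll_aligned(latent, cond, sh_h, sh_w):
        # Shift the latent by (sh_h, sh_w) cells of its own grid.
        latent = latent.roll(shifts=(sh_h, sh_w), dims=(-2, -1))
        if cond.shape[-2:] != latent.shape[-2:]:
            # Rescale the shift by the ratio of spatial sizes so both grids move
            # by the same fraction of their extent.
            sh_h = round(sh_h * cond.shape[-2] / latent.shape[-2])
            sh_w = round(sh_w * cond.shape[-1] / latent.shape[-1])
        cond = cond.roll(shifts=(sh_h, sh_w), dims=(-2, -1))
        return latent, cond

    # Example: a 64x64 latent with a 24x24 conditioning map.
    latent = torch.randn(1, 4, 64, 64)
    cond = torch.randn(1, 16, 24, 24)
    latent, cond = roll_aligned(latent, cond, sh_h=10, sh_w=5)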
