From 2b46bcb98c8e8fdb250cb8ff2e20874f3ccdd768 Mon Sep 17 00:00:00 2001 From: Robin Rombach <38811725+rromb@users.noreply.github.com> Date: Mon, 17 Jan 2022 21:24:19 +0100 Subject: [PATCH] Update ddpm.py clean up no.1 --- ldm/models/diffusion/ddpm.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ldm/models/diffusion/ddpm.py b/ldm/models/diffusion/ddpm.py index 983512969..bbedd04cf 100644 --- a/ldm/models/diffusion/ddpm.py +++ b/ldm/models/diffusion/ddpm.py @@ -461,7 +461,7 @@ def __init__(self, self.instantiate_cond_stage(cond_stage_config) self.cond_stage_forward = cond_stage_forward self.clip_denoised = False - self.bbox_tokenizer = None # # TODO: special class? + self.bbox_tokenizer = None self.restarted_from_ckpt = False if ckpt_path is not None: @@ -598,7 +598,7 @@ def get_weighting(self, h, w, Ly, Lx, device): weighting = weighting * L_weighting return weighting - def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code ! + def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code """ :param x: img of size (bs, c, h, w) :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1]) @@ -793,7 +793,7 @@ def differentiable_decode_first_stage(self, z, predict_cids=False, force_not_qua z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) # 2. apply model loop over last dim - if isinstance(self.first_stage_model, VQModelInterface): # todo ask what this is + if isinstance(self.first_stage_model, VQModelInterface): output_list = [self.first_stage_model.decode(z[:, :, :, :, i], force_not_quantize=predict_cids or force_not_quantize) for i in range(z.shape[-1])] @@ -901,7 +901,7 @@ def apply_model(self, x_noisy, t, cond, return_ids=False): if hasattr(self, "split_input_params"): assert len(cond) == 1 # todo can only deal with one conditioning atm - assert not return_ids # todo dont know what this is -> I exclude --> Good + assert not return_ids ks = self.split_input_params["ks"] # eg. (128, 128) stride = self.split_input_params["stride"] # eg. (64, 64)