Skip to content

Commit

Permalink
do not noise for the last step in ddim
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jul 10, 2022
1 parent 4878762 commit a598820
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
7 changes: 4 additions & 3 deletions dalle2_pytorch/dalle2_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1059,10 +1059,10 @@ def p_sample_loop_ddim(self, shape, text_cond, *, timesteps, eta = 1., cond_scal

c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
new_noise = torch.randn_like(image_embed)
noise = torch.randn_like(image_embed) if time_next > 0 else 0.

img = x_start * alpha_next.sqrt() + \
c1 * new_noise + \
c1 * noise + \
c2 * pred_noise

return image_embed
Expand Down Expand Up @@ -2275,9 +2275,10 @@ def p_sample_loop_ddim(self, unet, shape, image_embed, noise_scheduler, timestep

c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
noise = torch.randn_like(img) if time_next > 0 else 0.

img = x_start * alpha_next.sqrt() + \
c1 * torch.randn_like(img) + \
c1 * noise + \
c2 * pred_noise

img = self.unnormalize_img(img)
Expand Down
2 changes: 1 addition & 1 deletion dalle2_pytorch/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.19.3'
__version__ = '0.19.4'

0 comments on commit a598820

Please sign in to comment.