From 7c5477b26d4cd0cbc4d5afdbface8e6e5f46b368 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Fri, 12 Aug 2022 11:36:08 -0700 Subject: [PATCH] bet on the new self-conditioning technique out of geoffrey hintons group --- README.md | 11 ++++ dalle2_pytorch/dalle2_pytorch.py | 101 +++++++++++++++++++++++-------- dalle2_pytorch/version.py | 2 +- 3 files changed, 88 insertions(+), 26 deletions(-) diff --git a/README.md b/README.md index c88b700a..cd3fb3d9 100644 --- a/README.md +++ b/README.md @@ -1253,4 +1253,15 @@ For detailed information on training the diffusion prior, please refer to the [d } ``` +```bibtex +@misc{chen2022analog, + title = {Analog Bits: Generating Discrete Data using Diffusion Models with Self-Conditioning}, + author = {Ting Chen and Ruixiang Zhang and Geoffrey Hinton}, + year = {2022}, + eprint = {2208.04202}, + archivePrefix = {arXiv}, + primaryClass = {cs.CV} +} +``` + *Creating noise from data is easy; creating data from noise is generative modeling.* - Yang Song's paper diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 388d8208..31ce1577 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -870,7 +870,7 @@ def forward(self, x, mask = None, attn_bias = None): # attention - attn = sim.softmax(dim = -1) + attn = sim.softmax(dim = -1, dtype = torch.float32) attn = self.dropout(attn) # aggregate values @@ -1157,17 +1157,17 @@ def p_mean_variance(self, x, t, text_cond, clip_denoised = False, cond_scale = 1 pred = self.net.forward_with_cond_scale(x, t, cond_scale = cond_scale, **text_cond) if self.predict_x_start: - x_recon = pred + x_start = pred else: - x_recon = self.noise_scheduler.predict_start_from_noise(x, t = t, noise = pred) + x_start = self.noise_scheduler.predict_start_from_noise(x, t = t, noise = pred) if clip_denoised and not self.predict_x_start: - x_recon.clamp_(-1., 1.) + x_start.clamp_(-1., 1.) if self.predict_x_start and self.sampling_clamp_l2norm: - x_recon = l2norm(x_recon) * self.image_embed_scale + x_start = l2norm(x_start) * self.image_embed_scale - model_mean, posterior_variance, posterior_log_variance = self.noise_scheduler.q_posterior(x_start=x_recon, x_t=x, t=t) + model_mean, posterior_variance, posterior_log_variance = self.noise_scheduler.q_posterior(x_start=x_start, x_t=x, t=t) return model_mean, posterior_variance, posterior_log_variance @torch.no_grad() @@ -1571,7 +1571,7 @@ def forward(self, x, context, mask = None): mask = rearrange(mask, 'b j -> b 1 1 j') sim = sim.masked_fill(~mask, max_neg_value) - attn = sim.softmax(dim = -1) + attn = sim.softmax(dim = -1, dtype = torch.float32) out = einsum('b h i j, b h j d -> b h i d', attn, v) out = rearrange(out, 'b h n d -> b n (h d)') @@ -1700,6 +1700,7 @@ def __init__( attn_heads = 16, lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/ lowres_noise_cond = False, # for conditioning on low resolution noising, based on Imagen + self_cond = False, sparse_attn = False, cosine_sim_cross_attn = False, cosine_sim_self_attn = False, @@ -1735,12 +1736,21 @@ def __init__( self.lowres_cond = lowres_cond + # whether to do self conditioning + + self.self_cond = self_cond + # determine dimensions self.channels = channels self.channels_out = default(channels_out, channels) - init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis + # initial number of channels depends on + # (1) low resolution conditioning from cascading ddpm paper, conditioned on previous unet output in the cascade + # (2) self conditioning (bit diffusion paper) + + init_channels = channels * (1 + int(lowres_cond) + int(self_cond)) + init_dim = default(init_dim, dim) self.init_conv = CrossEmbedLayer(init_channels, dim_out = init_dim, kernel_sizes = init_cross_embed_kernel_sizes, stride = 1) if init_cross_embed else nn.Conv2d(init_channels, init_dim, init_conv_kernel_size, padding = init_conv_kernel_size // 2) @@ -1994,7 +2004,8 @@ def forward( text_cond_drop_prob = 0., blur_sigma = None, blur_kernel_size = None, - disable_checkpoint = False + disable_checkpoint = False, + self_cond = None ): batch_size, device = x.shape[0], x.device @@ -2002,6 +2013,14 @@ def forward( assert not (self.lowres_cond and not exists(lowres_cond_img)), 'low resolution conditioning image must be present' + # concat self conditioning, if needed + + if self.self_cond: + self_cond = default(self_cond, lambda: torch.zeros_like(x)) + x = torch.cat((x, self_cond), dim = 1) + + # concat low resolution conditioning + if exists(lowres_cond_img): x = torch.cat((x, lowres_cond_img), dim = 1) @@ -2571,23 +2590,23 @@ def dynamic_threshold(self, x): x = x.clamp(-s, s) / s return x - def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None, lowres_noise_level = None): + def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, lowres_cond_img = None, self_cond = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None, lowres_noise_level = None): assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the decoder was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)' - pred = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level)) + pred = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, lowres_noise_level = lowres_noise_level)) if learned_variance: pred, var_interp_frac_unnormalized = pred.chunk(2, dim = 1) if predict_x_start: - x_recon = pred + x_start = pred else: - x_recon = noise_scheduler.predict_start_from_noise(x, t = t, noise = pred) + x_start = noise_scheduler.predict_start_from_noise(x, t = t, noise = pred) if clip_denoised: - x_recon = self.dynamic_threshold(x_recon) + x_start = self.dynamic_threshold(x_start) - model_mean, posterior_variance, posterior_log_variance = noise_scheduler.q_posterior(x_start=x_recon, x_t=x, t=t) + model_mean, posterior_variance, posterior_log_variance = noise_scheduler.q_posterior(x_start=x_start, x_t=x, t=t) if learned_variance: # if learned variance, posterio variance and posterior log variance are predicted by the network @@ -2603,16 +2622,17 @@ def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodin posterior_log_variance = var_interp_frac * max_log + (1 - var_interp_frac) * min_log posterior_variance = posterior_log_variance.exp() - return model_mean, posterior_variance, posterior_log_variance + return model_mean, posterior_variance, posterior_log_variance, x_start @torch.no_grad() - def p_sample(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_noise_level = None): + def p_sample(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, cond_scale = 1., lowres_cond_img = None, self_cond = None, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_noise_level = None): b, *_, device = *x.shape, x.device - model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start, noise_scheduler = noise_scheduler, learned_variance = learned_variance, lowres_noise_level = lowres_noise_level) + model_mean, _, model_log_variance, x_start = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, clip_denoised = clip_denoised, predict_x_start = predict_x_start, noise_scheduler = noise_scheduler, learned_variance = learned_variance, lowres_noise_level = lowres_noise_level) noise = torch.randn_like(x) # no noise when t == 0 nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) - return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + pred = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + return pred, x_start @torch.no_grad() def p_sample_loop_ddpm( @@ -2638,6 +2658,8 @@ def p_sample_loop_ddpm( b = shape[0] img = torch.randn(shape, device = device) + x_start = None # for self-conditioning + is_inpaint = exists(inpaint_image) resample_times = inpaint_resample_times if is_inpaint else 1 @@ -2665,13 +2687,16 @@ def p_sample_loop_ddpm( noised_inpaint_image = noise_scheduler.q_sample(inpaint_image, t = times) img = (img * ~inpaint_mask) + (noised_inpaint_image * inpaint_mask) - img = self.p_sample( + self_cond = x_start if unet.self_cond else None + + img, x_start = self.p_sample( unet, img, times, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, + self_cond = self_cond, lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level, predict_x_start = predict_x_start, @@ -2730,6 +2755,8 @@ def p_sample_loop_ddim( img = torch.randn(shape, device = device) + x_start = None # for self-conditioning + if not is_latent_diffusion: lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img) @@ -2750,7 +2777,9 @@ def p_sample_loop_ddim( noised_inpaint_image = noise_scheduler.q_sample(inpaint_image, t = time_cond) img = (img * ~inpaint_mask) + (noised_inpaint_image * inpaint_mask) - pred = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level) + self_cond = x_start if unet.self_cond else None + + pred = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, self_cond = self_cond, lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level) if learned_variance: pred, _ = pred.chunk(2, dim = 1) @@ -2810,13 +2839,35 @@ def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres x_noisy = noise_scheduler.q_sample(x_start = x_start, t = times, noise = noise) - model_output = unet( - x_noisy, - times, + # unet kwargs + + unet_kwargs = dict( image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level, + ) + + # self conditioning + + self_cond = None + + if unet.self_cond and random.random() < 0.5: + with torch.no_grad(): + self_cond = unet(x_noisy, times, **unet_kwargs) + + if learned_variance: + self_cond, _ = self_cond.chunk(2, dim = 1) + + self_cond = self_cond.detach() + + # forward to get model prediction + + model_output = unet( + x_noisy, + times, + **unet_kwargs, + self_cond = self_cond, image_cond_drop_prob = self.image_cond_drop_prob, text_cond_drop_prob = self.text_cond_drop_prob, ) @@ -2847,7 +2898,7 @@ def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres # if learning the variance, also include the extra weight kl loss true_mean, _, true_log_variance_clipped = noise_scheduler.q_posterior(x_start = x_start, x_t = x_noisy, t = times) - model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x_noisy, t = times, image_embed = image_embed, noise_scheduler = noise_scheduler, clip_denoised = clip_denoised, learned_variance = True, model_output = model_output) + model_mean, _, model_log_variance, _ = self.p_mean_variance(unet, x = x_noisy, t = times, image_embed = image_embed, noise_scheduler = noise_scheduler, clip_denoised = clip_denoised, learned_variance = True, model_output = model_output) # kl loss with detached model predicted mean, for stability reasons as in paper diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py index 77f1c8e6..bcd8d54e 100644 --- a/dalle2_pytorch/version.py +++ b/dalle2_pytorch/version.py @@ -1 +1 @@ -__version__ = '1.5.0' +__version__ = '1.6.0'