Skip to content

Commit

Permalink
Bet on the new self-conditioning technique out of Geoffrey Hinton's group
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Aug 12, 2022
1 parent be3bb86 commit 7c5477b
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 26 deletions.
11 changes: 11 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -1253,4 +1253,15 @@ For detailed information on training the diffusion prior, please refer to the [d
}
```

```bibtex
@misc{chen2022analog,
title = {Analog Bits: Generating Discrete Data using Diffusion Models with Self-Conditioning},
author = {Ting Chen and Ruixiang Zhang and Geoffrey Hinton},
year = {2022},
eprint = {2208.04202},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
```

*Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>
101 changes: 76 additions & 25 deletions dalle2_pytorch/dalle2_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -870,7 +870,7 @@ def forward(self, x, mask = None, attn_bias = None):

# attention

attn = sim.softmax(dim = -1)
attn = sim.softmax(dim = -1, dtype = torch.float32)
attn = self.dropout(attn)

# aggregate values
Expand Down Expand Up @@ -1157,17 +1157,17 @@ def p_mean_variance(self, x, t, text_cond, clip_denoised = False, cond_scale = 1
pred = self.net.forward_with_cond_scale(x, t, cond_scale = cond_scale, **text_cond)

if self.predict_x_start:
x_recon = pred
x_start = pred
else:
x_recon = self.noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
x_start = self.noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)

if clip_denoised and not self.predict_x_start:
x_recon.clamp_(-1., 1.)
x_start.clamp_(-1., 1.)

if self.predict_x_start and self.sampling_clamp_l2norm:
x_recon = l2norm(x_recon) * self.image_embed_scale
x_start = l2norm(x_start) * self.image_embed_scale

model_mean, posterior_variance, posterior_log_variance = self.noise_scheduler.q_posterior(x_start=x_recon, x_t=x, t=t)
model_mean, posterior_variance, posterior_log_variance = self.noise_scheduler.q_posterior(x_start=x_start, x_t=x, t=t)
return model_mean, posterior_variance, posterior_log_variance

@torch.no_grad()
Expand Down Expand Up @@ -1571,7 +1571,7 @@ def forward(self, x, context, mask = None):
mask = rearrange(mask, 'b j -> b 1 1 j')
sim = sim.masked_fill(~mask, max_neg_value)

attn = sim.softmax(dim = -1)
attn = sim.softmax(dim = -1, dtype = torch.float32)

out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
Expand Down Expand Up @@ -1700,6 +1700,7 @@ def __init__(
attn_heads = 16,
lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
lowres_noise_cond = False, # for conditioning on low resolution noising, based on Imagen
self_cond = False,
sparse_attn = False,
cosine_sim_cross_attn = False,
cosine_sim_self_attn = False,
Expand Down Expand Up @@ -1735,12 +1736,21 @@ def __init__(

self.lowres_cond = lowres_cond

# whether to do self conditioning

self.self_cond = self_cond

# determine dimensions

self.channels = channels
self.channels_out = default(channels_out, channels)

init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
# initial number of channels depends on
# (1) low resolution conditioning from cascading ddpm paper, conditioned on previous unet output in the cascade
# (2) self conditioning (bit diffusion paper)

init_channels = channels * (1 + int(lowres_cond) + int(self_cond))

init_dim = default(init_dim, dim)

self.init_conv = CrossEmbedLayer(init_channels, dim_out = init_dim, kernel_sizes = init_cross_embed_kernel_sizes, stride = 1) if init_cross_embed else nn.Conv2d(init_channels, init_dim, init_conv_kernel_size, padding = init_conv_kernel_size // 2)
Expand Down Expand Up @@ -1994,14 +2004,23 @@ def forward(
text_cond_drop_prob = 0.,
blur_sigma = None,
blur_kernel_size = None,
disable_checkpoint = False
disable_checkpoint = False,
self_cond = None
):
batch_size, device = x.shape[0], x.device

# add low resolution conditioning, if present

assert not (self.lowres_cond and not exists(lowres_cond_img)), 'low resolution conditioning image must be present'

# concat self conditioning, if needed

if self.self_cond:
self_cond = default(self_cond, lambda: torch.zeros_like(x))
x = torch.cat((x, self_cond), dim = 1)

# concat low resolution conditioning

if exists(lowres_cond_img):
x = torch.cat((x, lowres_cond_img), dim = 1)

Expand Down Expand Up @@ -2571,23 +2590,23 @@ def dynamic_threshold(self, x):
x = x.clamp(-s, s) / s
return x

def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None, lowres_noise_level = None):
def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, lowres_cond_img = None, self_cond = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None, lowres_noise_level = None):
assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the decoder was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'

pred = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level))
pred = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, lowres_noise_level = lowres_noise_level))

if learned_variance:
pred, var_interp_frac_unnormalized = pred.chunk(2, dim = 1)

if predict_x_start:
x_recon = pred
x_start = pred
else:
x_recon = noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
x_start = noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)

if clip_denoised:
x_recon = self.dynamic_threshold(x_recon)
x_start = self.dynamic_threshold(x_start)

model_mean, posterior_variance, posterior_log_variance = noise_scheduler.q_posterior(x_start=x_recon, x_t=x, t=t)
model_mean, posterior_variance, posterior_log_variance = noise_scheduler.q_posterior(x_start=x_start, x_t=x, t=t)

if learned_variance:
# if learned variance, posterior variance and posterior log variance are predicted by the network
Expand All @@ -2603,16 +2622,17 @@ def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodin
posterior_log_variance = var_interp_frac * max_log + (1 - var_interp_frac) * min_log
posterior_variance = posterior_log_variance.exp()

return model_mean, posterior_variance, posterior_log_variance
return model_mean, posterior_variance, posterior_log_variance, x_start

@torch.no_grad()
def p_sample(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_noise_level = None):
def p_sample(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, cond_scale = 1., lowres_cond_img = None, self_cond = None, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_noise_level = None):
b, *_, device = *x.shape, x.device
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start, noise_scheduler = noise_scheduler, learned_variance = learned_variance, lowres_noise_level = lowres_noise_level)
model_mean, _, model_log_variance, x_start = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, clip_denoised = clip_denoised, predict_x_start = predict_x_start, noise_scheduler = noise_scheduler, learned_variance = learned_variance, lowres_noise_level = lowres_noise_level)
noise = torch.randn_like(x)
# no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
pred = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
return pred, x_start

@torch.no_grad()
def p_sample_loop_ddpm(
Expand All @@ -2638,6 +2658,8 @@ def p_sample_loop_ddpm(
b = shape[0]
img = torch.randn(shape, device = device)

x_start = None # for self-conditioning

is_inpaint = exists(inpaint_image)
resample_times = inpaint_resample_times if is_inpaint else 1

Expand Down Expand Up @@ -2665,13 +2687,16 @@ def p_sample_loop_ddpm(
noised_inpaint_image = noise_scheduler.q_sample(inpaint_image, t = times)
img = (img * ~inpaint_mask) + (noised_inpaint_image * inpaint_mask)

img = self.p_sample(
self_cond = x_start if unet.self_cond else None

img, x_start = self.p_sample(
unet,
img,
times,
image_embed = image_embed,
text_encodings = text_encodings,
cond_scale = cond_scale,
self_cond = self_cond,
lowres_cond_img = lowres_cond_img,
lowres_noise_level = lowres_noise_level,
predict_x_start = predict_x_start,
Expand Down Expand Up @@ -2730,6 +2755,8 @@ def p_sample_loop_ddim(

img = torch.randn(shape, device = device)

x_start = None # for self-conditioning

if not is_latent_diffusion:
lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)

Expand All @@ -2750,7 +2777,9 @@ def p_sample_loop_ddim(
noised_inpaint_image = noise_scheduler.q_sample(inpaint_image, t = time_cond)
img = (img * ~inpaint_mask) + (noised_inpaint_image * inpaint_mask)

pred = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level)
self_cond = x_start if unet.self_cond else None

pred = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, self_cond = self_cond, lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level)

if learned_variance:
pred, _ = pred.chunk(2, dim = 1)
Expand Down Expand Up @@ -2810,13 +2839,35 @@ def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres

x_noisy = noise_scheduler.q_sample(x_start = x_start, t = times, noise = noise)

model_output = unet(
x_noisy,
times,
# unet kwargs

unet_kwargs = dict(
image_embed = image_embed,
text_encodings = text_encodings,
lowres_cond_img = lowres_cond_img,
lowres_noise_level = lowres_noise_level,
)

# self conditioning

self_cond = None

if unet.self_cond and random.random() < 0.5:
with torch.no_grad():
self_cond = unet(x_noisy, times, **unet_kwargs)

if learned_variance:
self_cond, _ = self_cond.chunk(2, dim = 1)

self_cond = self_cond.detach()

# forward to get model prediction

model_output = unet(
x_noisy,
times,
**unet_kwargs,
self_cond = self_cond,
image_cond_drop_prob = self.image_cond_drop_prob,
text_cond_drop_prob = self.text_cond_drop_prob,
)
Expand Down Expand Up @@ -2847,7 +2898,7 @@ def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres
# if learning the variance, also include the extra weight kl loss

true_mean, _, true_log_variance_clipped = noise_scheduler.q_posterior(x_start = x_start, x_t = x_noisy, t = times)
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x_noisy, t = times, image_embed = image_embed, noise_scheduler = noise_scheduler, clip_denoised = clip_denoised, learned_variance = True, model_output = model_output)
model_mean, _, model_log_variance, _ = self.p_mean_variance(unet, x = x_noisy, t = times, image_embed = image_embed, noise_scheduler = noise_scheduler, clip_denoised = clip_denoised, learned_variance = True, model_output = model_output)

# kl loss with detached model predicted mean, for stability reasons as in paper

Expand Down
2 changes: 1 addition & 1 deletion dalle2_pytorch/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '1.5.0'
__version__ = '1.6.0'

0 comments on commit 7c5477b

Please sign in to comment.