
Commit

use some magic just this once to remove the need for researchers to think
lucidrains committed Apr 18, 2022
1 parent 7214df4 commit 960a798
Showing 3 changed files with 27 additions and 5 deletions.
4 changes: 1 addition & 3 deletions README.md
@@ -218,7 +218,6 @@ unet1 = Unet(
unet2 = Unet(
dim = 16,
image_embed_dim = 512,
- lowres_cond = True, # subsequent unets must have this turned on (and first unet must have this turned off)
cond_dim = 128,
channels = 3,
dim_mults = (1, 2, 4, 8, 16)
@@ -349,8 +348,7 @@ unet2 = Unet(
image_embed_dim = 512,
cond_dim = 128,
channels = 3,
- dim_mults = (1, 2, 4, 8, 16),
- lowres_cond = True
+ dim_mults = (1, 2, 4, 8, 16)
).cuda()

decoder = Decoder(
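For reference, the simplified README usage after this commit looks roughly like the sketch below: neither unet sets lowres_cond by hand, since the Decoder now forces the flag for each stage. Only the unet2 hyperparameters and the image_sizes keyword are visible in this diff; the CLIP settings, the unet1 settings, and the remaining Decoder keywords are illustrative assumptions based on the README of this version.

from dalle2_pytorch import CLIP, Unet, Decoder

# a small CLIP, hyperparameters illustrative
clip = CLIP(
    dim_text = 512,
    dim_image = 512,
    dim_latent = 512,
    num_text_tokens = 49408,
    text_enc_depth = 1,
    text_seq_len = 256,
    text_heads = 8,
    visual_enc_depth = 1,
    visual_image_size = 256,
    visual_patch_size = 32,
    visual_heads = 8
)

# first unet - the Decoder will leave it unconditional
unet1 = Unet(
    dim = 128,
    image_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    dim_mults = (1, 2, 4, 8)
)

# second unet - no lowres_cond = True needed anymore, the Decoder forces it on
unet2 = Unet(
    dim = 16,
    image_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    dim_mults = (1, 2, 4, 8, 16)
)

decoder = Decoder(
    clip = clip,
    unet = (unet1, unet2),       # cascading unets, lowres conditioning handled automatically
    image_sizes = (128, 256)     # one target resolution per unet
)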
26 changes: 25 additions & 1 deletion dalle2_pytorch/dalle2_pytorch.py
@@ -816,6 +816,11 @@ def __init__(
attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
):
super().__init__()
+ # save locals to take care of some hyperparameters for cascading DDPM
+
+ self._locals = locals()
+ del self._locals['self']
+ del self._locals['__class__']

# for eventual cascading diffusion

@@ -896,6 +901,15 @@ def __init__(
nn.Conv2d(dim, out_dim, 1)
)

+ # if the current settings for the unet are not correct
+ # for cascading DDPM, then reinit the unet with the right settings
+ def force_lowres_cond(self, lowres_cond):
+ if lowres_cond == self.lowres_cond:
+ return self
+
+ updated_kwargs = {**self._locals, 'lowres_cond': lowres_cond}
+ return self.__class__(**updated_kwargs)

def forward_with_cond_scale(
self,
*args,
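Taken together, the two hunks above let a unet rebuild itself: __init__ snapshots its own arguments with locals(), and force_lowres_cond re-instantiates the class with only the lowres_cond flag overridden. Below is a minimal standalone sketch of that pattern; ToyUnet and its fields are illustrative stand-ins, not the library's Unet.

import torch.nn as nn

class ToyUnet(nn.Module):
    def __init__(self, dim, lowres_cond = False):
        super().__init__()

        # snapshot the constructor arguments before any other locals exist,
        # then drop the entries that are not real keyword arguments
        self._locals = locals()
        del self._locals['self']
        del self._locals['__class__']

        self.lowres_cond = lowres_cond
        self.net = nn.Linear(dim, dim)

    def force_lowres_cond(self, lowres_cond):
        # already configured correctly, so return the same module
        if lowres_cond == self.lowres_cond:
            return self

        # otherwise rebuild from the saved hyperparameters, overriding
        # only the low resolution conditioning flag
        updated_kwargs = {**self._locals, 'lowres_cond': lowres_cond}
        return self.__class__(**updated_kwargs)

unet = ToyUnet(dim = 32)               # lowres_cond defaults to False
unet = unet.force_lowres_cond(True)    # returns a freshly constructed ToyUnet
assert unet.lowres_cond

As in the real Unet, the snapshot is taken immediately after super().__init__(), so it contains only the constructor's keyword arguments.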
@@ -1021,7 +1035,17 @@ def __init__(
self.clip_image_size = clip.image_size
self.channels = clip.image_channels

- self.unets = nn.ModuleList(unet)
+ # automatically take care of ensuring that first unet is unconditional
+ # while the rest of the unets are conditioned on the low resolution image produced by previous unet
+
+ self.unets = nn.ModuleList([])
+ for ind, one_unet in enumerate(cast_tuple(unet)):
+ is_first = ind == 0
+ one_unet = one_unet.force_lowres_cond(not is_first)
+ self.unets.append(one_unet)
+
+ # unet image sizes
+
image_sizes = default(image_sizes, (clip.image_size,))
image_sizes = tuple(sorted(set(image_sizes)))

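The last hunk is where the "magic" happens for users: the Decoder walks the unets it is given, forces the first to be unconditional and every later one to condition on the previous stage's low resolution output, and stores the possibly re-instantiated unets in an nn.ModuleList. Here is a rough sketch of that normalization logic, with a dataclass stand-in for the unet and an assumed cast_tuple helper mirroring the library's.

from dataclasses import dataclass, replace

@dataclass
class StubUnet:
    dim: int
    lowres_cond: bool = False

    def force_lowres_cond(self, lowres_cond):
        # stand-in for the method sketched earlier: no-op if the flag already
        # matches, otherwise return a copy with the flag overridden
        if lowres_cond == self.lowres_cond:
            return self
        return replace(self, lowres_cond = lowres_cond)

def cast_tuple(val):
    # assumed mirror of the library helper: wrap a lone value in a tuple
    return val if isinstance(val, tuple) else (val,)

def normalize_unets(unet):
    # the real Decoder appends into an nn.ModuleList; a plain list is used here
    unets = []
    for ind, one_unet in enumerate(cast_tuple(unet)):
        is_first = ind == 0
        # first unet: unconditional; all later unets: conditioned on lowres
        unets.append(one_unet.force_lowres_cond(not is_first))
    return unets

unets = normalize_unets((StubUnet(dim = 128), StubUnet(dim = 16)))
print([u.lowres_cond for u in unets])   # [False, True]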
2 changes: 1 addition & 1 deletion setup.py
@@ -10,7 +10,7 @@
'dream = dalle2_pytorch.cli:dream'
],
},
- version = '0.0.21',
+ version = '0.0.22',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',
