diff --git a/README.md b/README.md index a485e335..68a453f7 100644 --- a/README.md +++ b/README.md @@ -218,7 +218,6 @@ unet1 = Unet( unet2 = Unet( dim = 16, image_embed_dim = 512, - lowres_cond = True, # subsequent unets must have this turned on (and first unet must have this turned off) cond_dim = 128, channels = 3, dim_mults = (1, 2, 4, 8, 16) @@ -349,8 +348,7 @@ unet2 = Unet( image_embed_dim = 512, cond_dim = 128, channels = 3, - dim_mults = (1, 2, 4, 8, 16), - lowres_cond = True + dim_mults = (1, 2, 4, 8, 16) ).cuda() decoder = Decoder( diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 4f84123c..24b4e649 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -816,6 +816,11 @@ def __init__( attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention) ): super().__init__() + # save locals to take care of some hyperparameters for cascading DDPM + + self._locals = locals() + del self._locals['self'] + del self._locals['__class__'] # for eventual cascading diffusion @@ -896,6 +901,15 @@ def __init__( nn.Conv2d(dim, out_dim, 1) ) + # if the current settings for the unet are not correct + # for cascading DDPM, then reinit the unet with the right settings + def force_lowres_cond(self, lowres_cond): + if lowres_cond == self.lowres_cond: + return self + + updated_kwargs = {**self._locals, 'lowres_cond': lowres_cond} + return self.__class__(**updated_kwargs) + def forward_with_cond_scale( self, *args, @@ -1021,7 +1035,17 @@ def __init__( self.clip_image_size = clip.image_size self.channels = clip.image_channels - self.unets = nn.ModuleList(unet) + # automatically take care of ensuring that first unet is unconditional + # while the rest of the unets are conditioned on the low resolution image produced by previous unet + + self.unets = nn.ModuleList([]) + for ind, one_unet in enumerate(cast_tuple(unet)): + is_first = ind == 0 + one_unet = one_unet.force_lowres_cond(not is_first) + self.unets.append(one_unet) + + # unet image sizes + image_sizes = default(image_sizes, (clip.image_size,)) image_sizes = tuple(sorted(set(image_sizes))) diff --git a/setup.py b/setup.py index 8e44f4d3..c1bcccb8 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ 'dream = dalle2_pytorch.cli:dream' ], }, - version = '0.0.21', + version = '0.0.22', license='MIT', description = 'DALL-E 2', author = 'Phil Wang',