Skip to content

Commit

Permalink
same for text encodings for decoder ddpm training
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Apr 18, 2022
1 parent 6fee4fc commit 82328f1
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
4 changes: 2 additions & 2 deletions dalle2_pytorch/dalle2_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1214,7 +1214,7 @@ def sample(self, image_embed, text = None, cond_scale = 1.):

return img

def forward(self, image, text = None, image_embed = None, unet_number = None):
def forward(self, image, text = None, image_embed = None, text_encodings = None, unet_number = None):
assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
unet_number = default(unet_number, 1)
assert 1 <= unet_number <= len(self.unets)
Expand All @@ -1233,7 +1233,7 @@ def forward(self, image, text = None, image_embed = None, unet_number = None):
if not exists(image_embed):
image_embed = self.get_image_embed(image)

text_encodings = self.get_text_encodings(text) if exists(text) else None
text_encodings = self.get_text_encodings(text) if exists(text) and not exists(text_encodings) else None

lowres_cond_img = image if index > 0 else None
ddpm_image = resize_image_to(image, target_image_size)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.0.23',
version = '0.0.24',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',
Expand Down

0 comments on commit 82328f1

Please sign in to comment.