Skip to content

Commit

Permalink
always work in the l2normed space for image and text embeddings
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Apr 14, 2022
1 parent a1a8a78 commit 5e06cde
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
5 changes: 3 additions & 2 deletions dalle2_pytorch/dalle2_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,12 +374,13 @@ def get_image_embed(self, image):
image_encoding = self.clip.visual_transformer(image)
image_cls = image_encoding[:, 0]
image_embed = self.clip.to_visual_latent(image_cls)
return image_embed
return l2norm(image_embed)

# Encode `text` with the CLIP text transformer and package the conditioning values.
# Returns a dict with three keys:
#   text_encodings - every position after the first along dim 1 of the transformer output
#   text_embed     - the first position projected by `to_text_latent`, then passed through `l2norm`
#   mask           - the elementwise result of `text != 0`
def get_text_cond(self, text):
text_encodings = self.clip.text_transformer(text)
# split along dim 1: index 0 feeds the latent projection, the remainder is returned as-is
# NOTE(review): index 0 is presumably a CLS token (hence `text_cls`) - confirm against the CLIP implementation
text_cls, text_encodings = text_encodings[:, 0], text_encodings[:, 1:]
text_embed = self.clip.to_text_latent(text_cls)
# mirrors get_image_embed, which also returns `l2norm(...)` of its latent
# NOTE(review): `l2norm` is defined outside this view; presumably unit L2 normalization - confirm
text_embed = l2norm(text_embed)
# NOTE(review): `text != 0` presumably marks non-padding token ids (assumes pad id 0) - confirm
return dict(text_encodings = text_encodings, text_embed = text_embed, mask = text != 0)

def q_mean_variance(self, x_start, t):
Expand Down Expand Up @@ -750,7 +751,7 @@ def get_image_embed(self, image):
image_encoding = self.clip.visual_transformer(image)
image_cls = image_encoding[:, 0]
image_embed = self.clip.to_visual_latent(image_cls)
return image_embed
return l2norm(image_embed)

def q_mean_variance(self, x_start, t):
mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.0.4',
version = '0.0.5',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',
Expand Down

0 comments on commit 5e06cde

Please sign in to comment.