Skip to content

Commit

Permalink
calculate similarities differently depending on return_loss in Clip f…
Browse files Browse the repository at this point in the history
…orward
  • Loading branch information
lucidrains authored Jan 7, 2021
1 parent 01b8bfd commit 99ca233
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions dalle_pytorch/dalle_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def generate_images(

if exists(clipper):
scores = clipper(text_seq, img_seq, return_loss = False)
return images, scores.diag()
return images, scores

return images

Expand Down Expand Up @@ -202,11 +202,13 @@ def forward(

text_latents, image_latents = map(lambda t: F.normalize(t, p = 2, dim = -1), (text_latents, image_latents))

sim = einsum('i d, j d -> i j', text_latents, image_latents) * self.temperature.exp()
temp = self.temperature.exp()

if not return_loss:
sim = einsum('n d, n d -> n', text_latents, image_latents) * temp
return sim

sim = einsum('i d, j d -> i j', text_latents, image_latents) * temp
labels = torch.arange(b, device = device)
loss = F.cross_entropy(sim, labels)
return loss
Expand Down

0 comments on commit 99ca233

Please sign in to comment.