diff --git a/dalle_pytorch/dalle_pytorch.py b/dalle_pytorch/dalle_pytorch.py index c2e1094e..f636a85f 100644 --- a/dalle_pytorch/dalle_pytorch.py +++ b/dalle_pytorch/dalle_pytorch.py @@ -65,7 +65,7 @@ def generate_images( if exists(clipper): scores = clipper(text_seq, img_seq, return_loss = False) - return images, scores.diag() + return images, scores return images @@ -202,11 +202,13 @@ def forward( text_latents, image_latents = map(lambda t: F.normalize(t, p = 2, dim = -1), (text_latents, image_latents)) - sim = einsum('i d, j d -> i j', text_latents, image_latents) * self.temperature.exp() + temp = self.temperature.exp() if not return_loss: + sim = einsum('n d, n d -> n', text_latents, image_latents) * temp return sim + sim = einsum('i d, j d -> i j', text_latents, image_latents) * temp labels = torch.arange(b, device = device) loss = F.cross_entropy(sim, labels) return loss