diff --git a/dalle_pytorch/transformer.py b/dalle_pytorch/transformer.py
index ed55994e..395a3ce4 100644
--- a/dalle_pytorch/transformer.py
+++ b/dalle_pytorch/transformer.py
@@ -1,6 +1,7 @@
+from collections import deque
 from collections.abc import Iterable
 from functools import partial
-from itertools import islice, cycle, product
+from itertools import islice, cycle
 
 import torch
 from torch import nn, einsum
@@ -103,18 +104,30 @@ def __init__(self, fn, image_size, seq_len):
         self.fn = fn
         self.image_size = image_size
         self.seq_len = seq_len
+        self.img_seq_len = image_size ** 2
+        self.text_len = seq_len - self.img_seq_len + 1
 
     def forward(self, x, cache=None, cache_key=None, **kwargs):
-        n0 = x.shape[1]
-        if exists(cache):
-            if cache_key in cache:
-                x = torch.cat([cache[cache_key], x], dim=-2)
-            cache[cache_key] = x
+        seq_len, image_size, text_len = self.seq_len, self.image_size, self.text_len
+
+        if exists(cache) and cache_key in cache:
+            offset = cache['offset']
+            assert offset >= text_len, "cached inference for text is not supported"
+            q = cache[cache_key]
+            assert isinstance(q, deque) and len(q) == image_size
+
+            x_top, x_left, *x_pass = x[:, -1].chunk(4, dim=-1)
+
+            q.append((x_top, x_left))
+            x_top = q.popleft()[0]
+            x_left = q[-2][1]
+            if (offset - text_len) % image_size == 0:
+                x_left = torch.zeros_like(x_left)
+
+            x = torch.cat((x_top, x_left, *x_pass), dim=-1)
+            return self.fn(x[:, None], cache=cache, **kwargs)
 
         n = x.shape[1]
-        seq_len, image_size = self.seq_len, self.image_size
-        img_seq_len = image_size ** 2
-        text_len = seq_len - img_seq_len + 1
         padding = seq_len - n + 1
 
         # get text and image tokens
@@ -139,8 +152,22 @@ def forward(self, x, cache=None, cache_key=None, **kwargs):
         # merge text and image sequence back together
 
         x_img = rearrange(x_img, 'b h w d -> b (h w) d')
-        x = torch.cat((x_text, x_img[:, :-padding]), dim = 1)
-        return self.fn(x[:, -n0:], cache=cache, **kwargs)
+        x_img = x_img[:, :-padding]
+        x = torch.cat((x_text, x_img), dim = 1)
+
+        if exists(cache):
+            dummy_top, dummy_left, *_ = x[:, -1].chunk(4, dim=-1)
+            dummy_top, dummy_left = torch.zeros_like(dummy_top), torch.zeros_like(dummy_left)
+
+            q = deque()
+            x_img = x_img[:, -image_size:]
+            for _ in range(image_size - x_img.shape[1]):
+                q.append((dummy_top, dummy_left))
+            for i in range(x_img.shape[1]):
+                q.append(x_img[:, i].chunk(4, dim=-1)[:2])
+            cache[cache_key] = q
+
+        return self.fn(x, cache=cache, **kwargs)
 
 # main transformer class
 
@@ -277,6 +304,11 @@ def forward(self, x, **kwargs):
         return self.layers(x, rotary_pos_emb = self.pos_emb, **kwargs)
 
     def _get_static_mask(self, attn_type):
+        # In case of attn_type = "axial_{row,col}",
+        # the sparse implementation is most efficient for training,
+        # but the full attention with a static mask is most efficient for inference
+        # since caching is implemented in this case.
+
         img_seq_len = self.image_fmap_size ** 2
         text_len = self.seq_len + 1 - img_seq_len
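
For reference, a minimal standalone sketch of the idea behind the deque cache in the patch: for image tokens, the full-sequence path takes the first feature quarter of each token from the token one row above and the second quarter from the token one step to the left, and a rolling deque of the last `image_size` (top, left) chunks reproduces exactly that for the newest position one token at a time. This is not the patch or the library API; `shift_full` and `shift_streamed` are hypothetical names, and plain `reshape` stands in for the `rearrange` calls used in the repository.

```python
from collections import deque

import torch
import torch.nn.functional as F

def shift_full(x_img, image_size):
    # Full-sequence path: view image tokens as an (h, w) grid; the first feature
    # quarter comes from the token one row up, the second from the token one
    # step to the left (zeros at the borders), the rest passes through.
    b, n, d = x_img.shape
    x = x_img.reshape(b, image_size, image_size, d)
    top, left, *passthrough = x.chunk(4, dim = -1)
    top  = F.pad(top,  (0, 0, 0, 0, 1, 0))[:, :-1]   # each row sees the row above
    left = F.pad(left, (0, 0, 1, 0))[:, :, :-1]      # each column sees the column to its left
    return torch.cat((top, left, *passthrough), dim = -1).reshape(b, n, d)

def shift_streamed(x_img, image_size):
    # Streamed path: same result one token at a time, keeping a deque of the
    # (top, left) chunks of the last `image_size` tokens, as the cache does.
    b, n, d = x_img.shape
    zero = torch.zeros(b, d // 4)
    q = deque([(zero, zero)] * image_size)
    out = []
    for i in range(n):
        top, left, *passthrough = x_img[:, i].chunk(4, dim = -1)
        q.append((top, left))
        top_neighbor = q.popleft()[0]                               # token image_size steps back = row above
        left_neighbor = q[-2][1] if i % image_size != 0 else zero   # previous token, zeroed at row starts
        out.append(torch.cat((top_neighbor, left_neighbor, *passthrough), dim = -1))
    return torch.stack(out, dim = 1)

image_size, dim = 4, 8
x = torch.randn(2, image_size ** 2, dim)
assert torch.allclose(shift_full(x, image_size), shift_streamed(x, image_size))
```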