Speed up PreShiftToken
borzunov committed Dec 21, 2021
Parent: 1ff47c6 · Commit: 94fda36
Showing 1 changed file with 43 additions and 11 deletions.
dalle_pytorch/transformer.py: 43 additions & 11 deletions
@@ -1,6 +1,7 @@
+from collections import deque
 from collections.abc import Iterable
 from functools import partial
-from itertools import islice, cycle, product
+from itertools import islice, cycle

 import torch
 from torch import nn, einsum
@@ -103,18 +104,30 @@ def __init__(self, fn, image_size, seq_len):
         self.fn = fn
         self.image_size = image_size
         self.seq_len = seq_len
+        self.img_seq_len = image_size ** 2
+        self.text_len = seq_len - self.img_seq_len + 1

     def forward(self, x, cache=None, cache_key=None, **kwargs):
-        n0 = x.shape[1]
-        if exists(cache):
-            if cache_key in cache:
-                x = torch.cat([cache[cache_key], x], dim=-2)
-            cache[cache_key] = x
+        seq_len, image_size, text_len = self.seq_len, self.image_size, self.text_len
+
+        if exists(cache) and cache_key in cache:
+            offset = cache['offset']
+            assert offset >= text_len, "cached inference for text is not supported"
+            q = cache[cache_key]
+            assert isinstance(q, deque) and len(q) == image_size
+
+            x_top, x_left, *x_pass = x[:, -1].chunk(4, dim=-1)
+
+            q.append((x_top, x_left))
+            x_top = q.popleft()[0]
+            x_left = q[-2][1]
+            if (offset - text_len) % image_size == 0:
+                x_left = torch.zeros_like(x_left)
+
+            x = torch.cat((x_top, x_left, *x_pass), dim=-1)
+            return self.fn(x[:, None], cache=cache, **kwargs)

         n = x.shape[1]
-        seq_len, image_size = self.seq_len, self.image_size
-        img_seq_len = image_size ** 2
-        text_len = seq_len - img_seq_len + 1
         padding = seq_len - n + 1

         # get text and image tokens
@@ -139,8 +152,22 @@ def forward(self, x, cache=None, cache_key=None, **kwargs):
         # merge text and image sequence back together

         x_img = rearrange(x_img, 'b h w d -> b (h w) d')
-        x = torch.cat((x_text, x_img[:, :-padding]), dim = 1)
-        return self.fn(x[:, -n0:], cache=cache, **kwargs)
+        x_img = x_img[:, :-padding]
+        x = torch.cat((x_text, x_img), dim = 1)
+
+        if exists(cache):
+            dummy_top, dummy_left, *_ = x[:, -1].chunk(4, dim=-1)
+            dummy_top, dummy_left = torch.zeros_like(dummy_top), torch.zeros_like(dummy_left)
+
+            q = deque()
+            x_img = x_img[:, -image_size:]
+            for _ in range(image_size - x_img.shape[1]):
+                q.append((dummy_top, dummy_left))
+            for i in range(x_img.shape[1]):
+                q.append(x_img[:, i].chunk(4, dim=-1)[:2])
+            cache[cache_key] = q
+
+        return self.fn(x, cache=cache, **kwargs)

 # main transformer class

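How the new cache works: with caching enabled, only the newest image token reaches PreShiftToken, yet the token shift needs chunks from two earlier tokens, the one directly above (image_size positions back) and the one immediately to the left. The commit therefore keeps a deque of the last image_size tokens' (top, left) chunks: appending the new pair and popping the oldest recovers the top neighbour in O(1), q[-2] is the left neighbour, and the left chunk is zeroed at the start of each image row, i.e. when (offset - text_len) % image_size == 0 (cache['offset'], the number of tokens consumed so far, is assumed to be maintained by the surrounding code). The block after the merge step primes the deque from the last generated row, padding with zero dummies while fewer than image_size image tokens exist. A minimal, self-contained sketch of the trick, with illustrative names and shapes rather than the library's API:

    from collections import deque

    import torch

    # Illustrative sizes, not taken from the commit.
    image_size = 4           # side length of the image token grid
    dim = 8                  # feature dim; chunk(4) splits it into four pieces

    # One (top_chunk, left_chunk) pair per token of the last generated row,
    # pre-filled with zeros so the first image row sees an all-zero row above.
    q = deque()
    for _ in range(image_size):
        pad = torch.zeros(1, dim // 4)
        q.append((pad, pad))

    def shift_one_token(x, q, step_in_row):
        # x: (batch, dim) features of the newest image token.
        x_top, x_left, *x_pass = x.chunk(4, dim=-1)
        q.append((x_top, x_left))
        top = q.popleft()[0]         # token image_size steps back = one row above
        left = q[-2][1]              # previous token = left neighbour
        if step_in_row == 0:         # a row start has no left neighbour
            left = torch.zeros_like(left)
        return torch.cat((top, left, *x_pass), dim=-1)

    x = torch.randn(1, dim)
    out = shift_one_token(x, q, step_in_row=0)
    assert out.shape == x.shape and len(q) == image_size

Each decoded token now costs O(1) shift work and O(image_size) cache memory, instead of re-concatenating the cached prefix and re-shifting the whole sequence as the old code path did.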
@@ -277,6 +304,11 @@ def forward(self, x, **kwargs):
         return self.layers(x, rotary_pos_emb = self.pos_emb, **kwargs)

     def _get_static_mask(self, attn_type):
+        # For attn_type = "axial_{row,col}", the sparse implementation is
+        # most efficient for training, but full attention with a static mask
+        # is most efficient for inference, since caching is implemented only
+        # in that case.
+
         img_seq_len = self.image_fmap_size ** 2
         text_len = self.seq_len + 1 - img_seq_len

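The comment above records a design trade-off: sparse axial kernels are fastest for training, but key/value caching is implemented only for full attention, so at inference the axial pattern is instead expressed as a static boolean mask over a dense attention matrix. A rough sketch of how such a mask could look for axial_row attention, assuming text positions precede the flattened image grid (the actual _get_static_mask body is outside this diff, so names and layout here are assumptions):

    import torch

    def axial_row_static_mask(seq_len, text_len, image_fmap_size):
        # True = attention allowed. Every position may attend to the text;
        # each image token may additionally attend within its own grid row.
        mask = torch.zeros(seq_len, seq_len, dtype=torch.bool)
        mask[:, :text_len] = True
        for row in range(image_fmap_size):
            begin = text_len + row * image_fmap_size
            end = begin + image_fmap_size
            mask[begin:end, begin:end] = True
        return mask

    # 2 text positions + a 3x3 image grid = 11 positions in total.
    m = axial_row_static_mask(seq_len=11, text_len=2, image_fmap_size=3)
    assert m.shape == (11, 11)

An axial_col variant would instead connect image tokens in the same column (stride image_fmap_size through the flattened grid); causal ordering is assumed to be enforced by a separate causal mask.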
