
Commit

Revert excess changes in attentions
borzunov committed Dec 21, 2021
1 parent 2b77018 commit bff3114
Showing 2 changed files with 35 additions and 56 deletions.
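Not part of the commit itself: a minimal usage sketch of the cache interface that Attention.forward keeps after this revert. The class path, the layer key 'attn_0', the tensor shapes, and the two-step decoding loop are assumptions for illustration, not code from the repository:

import torch
from dalle_pytorch.attention import Attention

# one full-attention layer; the constructor arguments mirror the signature
# visible in the first hunk below (dim, seq_len, causal, heads, dim_head)
attn = Attention(dim = 512, seq_len = 256, causal = True, heads = 8, dim_head = 64).eval()

cache = {}                          # shared dict, one entry per layer
prompt = torch.randn(1, 4, 512)     # embeddings of a 4-token prefix

with torch.no_grad():
    out = attn(prompt, cache = cache, cache_key = 'attn_0')      # first call stores k, v under 'attn_0'
    new_token = torch.randn(1, 1, 512)                           # embedding of the next token only
    out = attn(new_token, cache = cache, cache_key = 'attn_0')   # reuses and extends the cached k, v

The second call computes q, k, v for a single position, concatenates the cached keys and values in front of the new ones, and skips the causal mask (see the `not using_cache` guard further down).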
77 changes: 35 additions & 42 deletions dalle_pytorch/attention.py
@@ -6,8 +6,6 @@
 import torch.nn.functional as F
 from einops import rearrange, repeat
 
-from dalle_pytorch.cache import Cached
-
 # helpers
 
 def exists(val):
@@ -41,6 +39,8 @@ def apply_rotary_emb(freqs, t):
     return torch.cat((t, t_right), dim = -1)
 
 def apply_pos_emb(pos_emb, qkv):
+    n = qkv[0].shape[-2]
+    pos_emb = pos_emb[..., :n, :]
     return tuple(map(lambda t: apply_rotary_emb(pos_emb, t), qkv))
 
 # classes
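A small shape sketch (shapes assumed for illustration, not from the repository) of what the two lines added to apply_pos_emb above do: rotary embeddings prepared for the full sequence are trimmed to however many positions q, k, v actually carry, so callers no longer have to pre-slice them:

import torch

seq_len, dim_head = 16, 64
pos_emb = torch.randn(1, 1, seq_len, dim_head)   # full-length rotary frequencies
q = torch.randn(1, 8, 3, dim_head)               # only 3 positions in this forward pass

n = q.shape[-2]                                  # what apply_pos_emb now reads from qkv[0]
print(pos_emb[..., :n, :].shape)                 # torch.Size([1, 1, 3, 64])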
@@ -65,30 +65,24 @@ def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, dropou
     def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key = None):
         b, n, _, h, device = *x.shape, self.heads, x.device
         softmax = torch.softmax if not self.stable else stable_softmax
+        using_cache = exists(cache) and cache_key in cache
 
-        qkv_key = f'{cache_key}_qkv'
-        if exists(cache) and qkv_key in cache:
-            qkv = self.to_qkv(x).chunk(3, dim = -1)
-            q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
+        qkv = self.to_qkv(x).chunk(3, dim = -1)
+        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
 
-            if exists(rotary_pos_emb):
-                q, k, v = apply_pos_emb(rotary_pos_emb[..., n - 1:n, :], (q, k, v)) # FIXME: Fix rotary index here
+        if exists(rotary_pos_emb):
+            if using_cache:
+                rotary_pos_emb = rotary_pos_emb[..., n - 1:, :] # FIXME: Fix rotary index here
+            q, k, v = apply_pos_emb(rotary_pos_emb, (q, k, v))
 
-            q *= self.scale
+        q = q * self.scale
 
-            k_top, v_top = cache[qkv_key]
+        if using_cache:
+            k_top, v_top = cache[cache_key]
             k = torch.cat([k_top, k], dim=-2)
             v = torch.cat([v_top, v], dim=-2)
-        else:
-            qkv = self.to_qkv(x).chunk(3, dim = -1)
-            q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
-
-            if exists(rotary_pos_emb):
-                q, k, v = apply_pos_emb(rotary_pos_emb[..., :n, :], (q, k, v))
-
-            q *= self.scale
         if exists(cache):
-            cache[qkv_key] = (k, v)
+            cache[cache_key] = k, v
 
         dots = q @ k.swapaxes(-1, -2)
         mask_value = max_neg_value(dots)
@@ -98,17 +92,16 @@ def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key
             dots.masked_fill_(~mask, mask_value)
             del mask
 
-        # if self.causal: # TODO:
-        #     i, j = dots.shape[-2:]
-        #     mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
-        #     dots.masked_fill_(mask, mask_value)
+        if self.causal and not using_cache: # causality is naturally enforced if we run the cached inference
+            i, j = dots.shape[-2:]
+            mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
+            dots.masked_fill_(mask, mask_value)
 
         attn = softmax(dots, dim=-1)
 
         out = attn @ v
         out = rearrange(out, 'b h n d -> b n (h d)')
         out = self.to_out(out)
 
         return out
 
 # sparse attention with convolutional pattern, as mentioned in the blog post. customizable kernel size and dilation
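Not from the commit: a short check of why the re-enabled causal mask above can be skipped when `using_cache` is set. For a full sequence (i == j) the strict upper-triangular mask hides future positions; for a cached single-token step (i = 1, j = full length) the same expression masks nothing, since the new query sits at the last position and may attend to every stored key. The sizes are illustrative only:

import torch

def causal_mask(i, j):
    # same expression as in the hunk above, minus the device argument
    return torch.ones(i, j).triu_(j - i + 1).bool()

print(causal_mask(4, 4))
# tensor([[False,  True,  True,  True],
#         [False, False,  True,  True],
#         [False, False, False,  True],
#         [False, False, False, False]])

print(causal_mask(1, 4))
# tensor([[False, False, False, False]]) -- nothing to mask in the cached step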
@@ -128,14 +121,14 @@ def __init__(self, dim, seq_len, image_size = 32, kernel_size = 5, dilation = 1,
 
         self.stable = stable
 
-        self.to_qkv = Cached(nn.Linear(dim, inner_dim * 3, bias = False))
+        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
 
-        self.to_out = Cached(nn.Sequential(
+        self.to_out = nn.Sequential(
             nn.Linear(inner_dim, dim),
             nn.Dropout(dropout)
-        ))
+        )
 
-    def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key = None):
+    def forward(self, x, mask = None, rotary_pos_emb = None):
         b, n, _, h, img_size, kernel_size, dilation, seq_len, device = *x.shape, self.heads, self.image_size, self.kernel_size, self.dilation, self.seq_len, x.device
         softmax = torch.softmax if not self.stable else stable_softmax
 
@@ -152,7 +145,7 @@ def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key
 
         # derive query / keys / values
 
-        qkv = self.to_qkv(x, cache = cache, cache_key = f'{cache_key}_qkv').chunk(3, dim = -1)
+        qkv = self.to_qkv(x).chunk(3, dim = -1)
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), qkv)
 
         if exists(rotary_pos_emb):
@@ -229,7 +222,7 @@ def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key
         out = torch.cat((out_text, out_image), dim = 1)
 
         out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
-        out = self.to_out(out, cache = cache, cache_key = f'{cache_key}_out')
+        out = self.to_out(out)
         return out[:, :n]
 
 # sparse axial causal attention
@@ -248,14 +241,14 @@ def __init__(self, dim, seq_len, image_size = 32, axis = 0, heads = 8, dim_head
 
         self.stable = stable
 
-        self.to_qkv = Cached(nn.Linear(dim, inner_dim * 3, bias = False))
+        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
 
-        self.to_out = Cached(nn.Sequential(
+        self.to_out = nn.Sequential(
             nn.Linear(inner_dim, dim),
             nn.Dropout(dropout)
-        ))
+        )
 
-    def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key = None):
+    def forward(self, x, mask = None, rotary_pos_emb = None):
         b, n, _, h, img_size, axis, seq_len, device = *x.shape, self.heads, self.image_size, self.axis, self.seq_len, x.device
         softmax = torch.softmax if not self.stable else stable_softmax
 
@@ -272,7 +265,7 @@ def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key
 
         # derive queries / keys / values
 
-        qkv = self.to_qkv(x, cache = cache, cache_key = f'{cache_key}_qkv').chunk(3, dim = -1)
+        qkv = self.to_qkv(x).chunk(3, dim = -1)
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), qkv)
 
         if exists(rotary_pos_emb):
@@ -284,15 +277,15 @@ def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key
 
         # text attention
 
-        dots_text = q_text @ k_text.swapaxes(-1, -2)
+        dots_text = einsum('b i d, b j d -> b i j', q_text, k_text)
         mask_value = max_neg_value(dots_text)
 
         i, j = dots_text.shape[-2:]
         text_causal_mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
         dots_text.masked_fill_(text_causal_mask, mask_value)
 
         attn_text = softmax(dots_text, dim = -1)
-        out_text = attn_text @ v_text
+        out_text = einsum('b i j, b j d -> b i d', attn_text, v_text)
 
         # image attention
 
@@ -305,8 +298,8 @@ def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key
 
         # similarity
 
-        dots_image_to_image = q_img @ k_img.swapaxes(-1, -2)
-        dots_image_to_text = q_img @ k_text[:, None].swapaxes(-1, -2)
+        dots_image_to_image = einsum('b x i d, b x j d -> b x i j', q_img, k_img)
+        dots_image_to_text = einsum('b x i d, b j d -> b x i j', q_img, k_text)
 
         dots = torch.cat((dots_image_to_text, dots_image_to_image), dim = -1)
 
@@ -329,8 +322,8 @@ def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key
 
         attn_image_to_text, attn_image_to_image = attn[..., :text_len], attn[..., text_len:]
 
-        out_image_to_image = attn_image_to_image @ v_img
-        out_image_to_text = attn_image_to_text @ v_text[:, None]
+        out_image_to_image = einsum('b x i j, b x j d -> b x i d', attn_image_to_image, v_img)
+        out_image_to_text = einsum('b x i j, b j d -> b x i d', attn_image_to_text, v_text)
 
         out_image = out_image_to_image + out_image_to_text
 
@@ -343,7 +336,7 @@ def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key
         out = torch.cat((out_text, out_image), dim = 1)
 
         out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
-        out = self.to_out(out, cache = cache, cache_key = f'{cache_key}_out')
+        out = self.to_out(out)
         return out[:, :n]
 
 # microsoft sparse attention CUDA kernel
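Not from the commit: a quick numerical check that the einsum forms restored in the axial-attention hunks above match the batched matmuls they replace, including the image-to-text case where k_text is broadcast across the axial dimension x. The sizes are arbitrary illustrative values:

import torch
from torch import einsum

b, x, i, j, d = 2, 3, 4, 5, 8
q_img  = torch.randn(b, x, i, d)
k_img  = torch.randn(b, x, j, d)
k_text = torch.randn(b, j, d)

# image-to-image: a plain batched matmul over the last two dimensions
assert torch.allclose(
    einsum('b x i d, b x j d -> b x i j', q_img, k_img),
    q_img @ k_img.swapaxes(-1, -2),
    atol = 1e-5)

# image-to-text: einsum broadcasts k_text over x, matching the explicit
# k_text[:, None] in the expression it replaces
assert torch.allclose(
    einsum('b x i d, b j d -> b x i j', q_img, k_text),
    q_img @ k_text[:, None].swapaxes(-1, -2),
    atol = 1e-5)

print('einsum and matmul forms agree')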
14 changes: 0 additions & 14 deletions dalle_pytorch/cache.py
@@ -1,19 +1,5 @@
-import torch
 import torch.nn as nn
 
-# helpers
-
-def exists(val):
-    return val is not None
-
-class Cached(nn.Module):
-    def __init__(self, fn):
-        super().__init__()
-        self.fn = fn
-
-    def forward(self, x, *, cache=None, cache_key=None, **kwargs):
-        return self.fn(x, **kwargs)
-
 class FixCacheKey(nn.Module):
     def __init__(self, cache_key, fn):
         super().__init__()
