
Commit d167378
add cosine sim for self attention as well, as a setting
lucidrains committed Jul 29, 2022
1 parent 2d67d58 commit d167378
Showing 2 changed files with 14 additions and 8 deletions.
20 changes: 13 additions & 7 deletions dalle2_pytorch/dalle2_pytorch.py
@@ -701,11 +701,12 @@ def __init__(
         dropout = 0.,
         causal = False,
         rotary_emb = None,
-        pb_relax_alpha = 128
+        cosine_sim = True,
+        cosine_sim_scale = 16
     ):
         super().__init__()
-        self.pb_relax_alpha = pb_relax_alpha
-        self.scale = dim_head ** -0.5 * (pb_relax_alpha ** -1)
+        self.scale = cosine_sim_scale if cosine_sim else (dim_head ** -0.5)
+        self.cosine_sim = cosine_sim

         self.heads = heads
         inner_dim = dim_head * heads
@@ -745,6 +746,13 @@ def forward(self, x, mask = None, attn_bias = None):
         k = torch.cat((nk, k), dim = -2)
         v = torch.cat((nv, v), dim = -2)

+        # whether to use cosine sim
+
+        if self.cosine_sim:
+            q, k = map(l2norm, (q, k))
+
+        q, k = map(lambda t: t * math.sqrt(self.scale), (q, k))
+
         # calculate query / key similarities

         sim = einsum('b h i d, b j d -> b h i j', q, k)
@@ -770,9 +778,6 @@ def forward(self, x, mask = None, attn_bias = None):

         # attention

-        sim = sim - sim.amax(dim = -1, keepdim = True).detach()
-        sim = sim * self.pb_relax_alpha
-
         attn = sim.softmax(dim = -1)
         attn = self.dropout(attn)

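Taken together, the hunks above swap the PB-relax rescaling for an optional cosine-similarity attention: queries and keys are l2-normalized and each multiplied by sqrt(self.scale), so the similarity logits become cosine similarities scaled by cosine_sim_scale (16 by default). The snippet below is a minimal, self-contained sketch of that scoring path under assumed toy shapes; it is not code from the repository, and the l2norm helper here simply stands in for unit normalization along the feature dimension.

# minimal sketch (not library code) of the cosine-sim scoring path above;
# shapes follow the einsum 'b h i d, b j d -> b h i j': queries are per-head,
# keys share a single head

import math
import torch
import torch.nn.functional as F
from torch import einsum

def l2norm(t):
    return F.normalize(t, dim = -1)

cosine_sim = True
dim_head = 64
cosine_sim_scale = 16
scale = cosine_sim_scale if cosine_sim else dim_head ** -0.5

b, h, n = 2, 8, 32                          # batch, heads, sequence length
q = torch.randn(b, h, n, dim_head)
k = torch.randn(b, n, dim_head)

if cosine_sim:
    q, k = map(l2norm, (q, k))              # unit-norm rows, so logits are cosine similarities

# multiplying each of q and k by sqrt(scale) scales the similarity matrix by `scale`
q, k = map(lambda t: t * math.sqrt(scale), (q, k))

sim = einsum('b h i d, b j d -> b h i j', q, k)
attn = sim.softmax(dim = -1)                # bounded logits, so no max-subtraction / pb_relax step here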
@@ -1604,6 +1609,7 @@ def __init__(
         lowres_noise_cond = False,    # for conditioning on low resolution noising, based on Imagen
         sparse_attn = False,
         cosine_sim_cross_attn = False,
+        cosine_sim_self_attn = False,
         attend_at_middle = True,      # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
         cond_on_text_encodings = False,
         max_text_len = 256,
@@ -1724,7 +1730,7 @@ def __init__(

         # attention related params

-        attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head)
+        attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head, cosine_sim = cosine_sim_self_attn)

         self_attn = cast_tuple(self_attn, num_stages)

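The new cosine_sim_self_attn flag is threaded into attn_kwargs above, so the self-attention layers the Unet builds pick up the setting, mirroring the existing cosine_sim_cross_attn option. A hypothetical way to enable it follows; the other constructor arguments are illustrative placeholders, not values taken from this diff.

from dalle2_pytorch import Unet

# hypothetical configuration; only the two cosine_sim_* flags relate to this commit
unet = Unet(
    dim = 128,
    image_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    dim_mults = (1, 2, 4, 8),
    cosine_sim_cross_attn = True,   # existing setting for cross attention
    cosine_sim_self_attn = True     # new setting added by this commit
)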
2 changes: 1 addition & 1 deletion dalle2_pytorch/version.py
@@ -1 +1 @@
-__version__ = '1.4.2'
+__version__ = '1.4.3'
