From d16737840155581edacc56d58c372ad6f027577b Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Fri, 29 Jul 2022 12:48:20 -0700
Subject: [PATCH] add cosine sim for self attention as well, as a setting

---
 dalle2_pytorch/dalle2_pytorch.py | 20 +++++++++++++-------
 dalle2_pytorch/version.py        |  2 +-
 2 files changed, 14 insertions(+), 8 deletions(-)

diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index 8e496040..e22a0737 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -701,11 +701,12 @@ def __init__(
         dropout = 0.,
         causal = False,
         rotary_emb = None,
-        pb_relax_alpha = 128
+        cosine_sim = True,
+        cosine_sim_scale = 16
     ):
         super().__init__()
-        self.pb_relax_alpha = pb_relax_alpha
-        self.scale = dim_head ** -0.5 * (pb_relax_alpha ** -1)
+        self.scale = cosine_sim_scale if cosine_sim else (dim_head ** -0.5)
+        self.cosine_sim = cosine_sim
         self.heads = heads
         inner_dim = dim_head * heads
 
@@ -745,6 +746,13 @@ def forward(self, x, mask = None, attn_bias = None):
         k = torch.cat((nk, k), dim = -2)
         v = torch.cat((nv, v), dim = -2)
 
+        # whether to use cosine sim
+
+        if self.cosine_sim:
+            q, k = map(l2norm, (q, k))
+
+        q, k = map(lambda t: t * math.sqrt(self.scale), (q, k))
+
         # calculate query / key similarities
 
         sim = einsum('b h i d, b j d -> b h i j', q, k)
@@ -770,9 +778,6 @@ def forward(self, x, mask = None, attn_bias = None):
 
         # attention
 
-        sim = sim - sim.amax(dim = -1, keepdim = True).detach()
-        sim = sim * self.pb_relax_alpha
-
         attn = sim.softmax(dim = -1)
         attn = self.dropout(attn)
 
@@ -1604,6 +1609,7 @@ def __init__(
         lowres_noise_cond = False,          # for conditioning on low resolution noising, based on Imagen
         sparse_attn = False,
         cosine_sim_cross_attn = False,
+        cosine_sim_self_attn = False,
         attend_at_middle = True,            # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
         cond_on_text_encodings = False,
         max_text_len = 256,
@@ -1724,7 +1730,7 @@ def __init__(
 
         # attention related params
 
-        attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head)
+        attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head, cosine_sim = cosine_sim_self_attn)
 
         self_attn = cast_tuple(self_attn, num_stages)
 
diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py
index 98d186be..4e7c72a5 100644
--- a/dalle2_pytorch/version.py
+++ b/dalle2_pytorch/version.py
@@ -1 +1 @@
-__version__ = '1.4.2'
+__version__ = '1.4.3'
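
Note on the change: the patch replaces the previous PB-relax stabilization (pb_relax_alpha, with its amax subtraction before the softmax) by optional cosine-similarity self-attention, mirroring the existing cosine_sim_cross_attn setting. The snippet below is only a minimal, self-contained sketch of the scaling scheme the diff uses -- l2-normalize queries and keys, then multiply each by sqrt(cosine_sim_scale) so the similarity logits carry the scale exactly once. It is not the library's Attention class; the single-head shapes and the standalone cosine_sim_attention / l2norm helpers are simplifications introduced here.

# minimal sketch (assumed simplification, not the library's Attention class)
import math
import torch
import torch.nn.functional as F
from torch import einsum

def l2norm(t):
    # normalize along the feature dimension, as the l2norm helper in the diff does
    return F.normalize(t, dim = -1)

def cosine_sim_attention(q, k, v, cosine_sim = True, cosine_sim_scale = 16, dim_head = 64):
    # scale chosen exactly as in the patched __init__
    scale = cosine_sim_scale if cosine_sim else dim_head ** -0.5

    if cosine_sim:
        q, k = map(l2norm, (q, k))

    # split the scale across q and k so the logits end up multiplied by `scale` once
    q, k = map(lambda t: t * math.sqrt(scale), (q, k))

    sim = einsum('b i d, b j d -> b i j', q, k)
    attn = sim.softmax(dim = -1)
    return einsum('b i j, b j d -> b i d', attn, v)

# usage: with cosine sim the logits are bounded by cosine_sim_scale in magnitude,
# which is why the old PB-relax max-subtraction trick is deleted in this patch
q = k = v = torch.randn(2, 128, 64)
out = cosine_sim_attention(q, k, v)   # (2, 128, 64)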