diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index e0e14c3c..47a612ff 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -629,10 +629,13 @@ def __init__( heads = 8, dropout = 0., causal = False, - rotary_emb = None + rotary_emb = None, + pb_relax_alpha = 32 ** 2 ): super().__init__() - self.scale = dim_head ** -0.5 + self.pb_relax_alpha = pb_relax_alpha + self.scale = dim_head ** -0.5 * (pb_relax_alpha ** -1) + self.heads = heads inner_dim = dim_head * heads @@ -696,6 +699,9 @@ def forward(self, x, mask = None, attn_bias = None): # attention + sim = sim - sim.amax(dim = -1, keepdim = True).detach() + sim = sim * self.pb_relax_alpha + attn = sim.softmax(dim = -1, dtype = torch.float32) attn = self.dropout(attn) @@ -1210,10 +1216,12 @@ def __init__( dim_head = 64, heads = 8, dropout = 0., - norm_context = False + norm_context = False, + pb_relax_alpha = 32 ** 2 ): super().__init__() - self.scale = dim_head ** -0.5 + self.pb_relax_alpha = pb_relax_alpha + self.scale = dim_head ** -0.5 * (pb_relax_alpha ** -1) self.heads = heads inner_dim = dim_head * heads @@ -1259,6 +1267,9 @@ def forward(self, x, context, mask = None): mask = rearrange(mask, 'b j -> b 1 1 j') sim = sim.masked_fill(~mask, max_neg_value) + sim = sim - sim.amax(dim = -1, keepdim = True).detach() + sim = sim * self.pb_relax_alpha + attn = sim.softmax(dim = -1, dtype = torch.float32) out = einsum('b h i j, b h j d -> b h i d', attn, v) diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py index 8911e95c..9513287c 100644 --- a/dalle2_pytorch/version.py +++ b/dalle2_pytorch/version.py @@ -1 +1 @@ -__version__ = '0.16.0' +__version__ = '0.16.1'