Skip to content

Commit

Permalink
bring in two tricks from the cogview paper for reducing the chances of overflow, for attention and layernorm
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jul 5, 2022
1 parent e1fe308 commit 3bdf85a
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 5 deletions.
19 changes: 15 additions & 4 deletions dalle2_pytorch/dalle2_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,10 +629,13 @@ def __init__(
heads = 8,
dropout = 0.,
causal = False,
rotary_emb = None
rotary_emb = None,
pb_relax_alpha = 32 ** 2
):
super().__init__()
self.scale = dim_head ** -0.5
self.pb_relax_alpha = pb_relax_alpha
self.scale = dim_head ** -0.5 * (pb_relax_alpha ** -1)

self.heads = heads
inner_dim = dim_head * heads

Expand Down Expand Up @@ -696,6 +699,9 @@ def forward(self, x, mask = None, attn_bias = None):

# attention

sim = sim - sim.amax(dim = -1, keepdim = True).detach()
sim = sim * self.pb_relax_alpha

attn = sim.softmax(dim = -1, dtype = torch.float32)
attn = self.dropout(attn)

Expand Down Expand Up @@ -1210,10 +1216,12 @@ def __init__(
dim_head = 64,
heads = 8,
dropout = 0.,
norm_context = False
norm_context = False,
pb_relax_alpha = 32 ** 2
):
super().__init__()
self.scale = dim_head ** -0.5
self.pb_relax_alpha = pb_relax_alpha
self.scale = dim_head ** -0.5 * (pb_relax_alpha ** -1)
self.heads = heads
inner_dim = dim_head * heads

Expand Down Expand Up @@ -1259,6 +1267,9 @@ def forward(self, x, context, mask = None):
mask = rearrange(mask, 'b j -> b 1 1 j')
sim = sim.masked_fill(~mask, max_neg_value)

sim = sim - sim.amax(dim = -1, keepdim = True).detach()
sim = sim * self.pb_relax_alpha

attn = sim.softmax(dim = -1, dtype = torch.float32)

out = einsum('b h i j, b h j d -> b h i d', attn, v)
Expand Down
2 changes: 1 addition & 1 deletion dalle2_pytorch/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.16.0'
__version__ = '0.16.1'

0 comments on commit 3bdf85a

Please sign in to comment.