Skip to content

Commit

Permalink
from my vision transformer experience, dimension of attention head of…
Browse files Browse the repository at this point in the history
… 32 is sufficient for image feature maps
  • Loading branch information
lucidrains committed Apr 20, 2022
1 parent b8e8d3c commit faebf4c
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 9 deletions.
22 changes: 14 additions & 8 deletions dalle2_pytorch/dalle2_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,11 +464,11 @@ def __init__(
net,
*,
clip,
timesteps=1000,
cond_drop_prob=0.2,
loss_type="l1",
predict_x0=True,
beta_schedule="cosine",
timesteps = 1000,
cond_drop_prob = 0.2,
loss_type = "l1",
predict_x0 = True,
beta_schedule = "cosine",
):
super().__init__()
assert isinstance(clip, CLIP)
Expand Down Expand Up @@ -825,6 +825,8 @@ def __init__(
out_dim = None,
dim_mults=(1, 2, 4, 8),
channels = 3,
attn_dim_head = 32,
attn_heads = 8,
lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
lowres_cond_upsample_mode = 'bilinear',
blur_sigma = 0.1,
Expand Down Expand Up @@ -888,6 +890,10 @@ def __init__(
self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim))
self.null_text_embed = nn.Parameter(torch.randn(1, 1, cond_dim))

# attention related params

attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head)

# layers

self.downs = nn.ModuleList([])
Expand All @@ -901,15 +907,15 @@ def __init__(

self.downs.append(nn.ModuleList([
ConvNextBlock(dim_in, dim_out, norm = ind != 0),
Residual(GridAttention(dim_out, window_size = sparse_attn_window)) if sparse_attn else nn.Identity(),
Residual(GridAttention(dim_out, window_size = sparse_attn_window, **attn_kwargs)) if sparse_attn else nn.Identity(),
ConvNextBlock(dim_out, dim_out, cond_dim = layer_cond_dim),
Downsample(dim_out) if not is_last else nn.Identity()
]))

mid_dim = dims[-1]

self.mid_block1 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim))) if attend_at_middle else None
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
self.mid_block2 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)

for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
Expand All @@ -918,7 +924,7 @@ def __init__(

self.ups.append(nn.ModuleList([
ConvNextBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim),
Residual(GridAttention(dim_in, window_size = sparse_attn_window)) if sparse_attn else nn.Identity(),
Residual(GridAttention(dim_in, window_size = sparse_attn_window, **attn_kwargs)) if sparse_attn else nn.Identity(),
ConvNextBlock(dim_in, dim_in, cond_dim = layer_cond_dim),
Upsample(dim_in)
]))
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.0.30',
version = '0.0.31',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',
Expand Down

0 comments on commit faebf4c

Please sign in to comment.