From faebf4c8b8f1ac0888df6b1ce9c40eb319cdb108 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Wed, 20 Apr 2022 11:40:24 -0700
Subject: [PATCH] from my vision transformer experience, dimension of
 attention head of 32 is sufficient for image feature maps

---
 dalle2_pytorch/dalle2_pytorch.py | 22 ++++++++++++++--------
 setup.py                         |  2 +-
 2 files changed, 15 insertions(+), 9 deletions(-)

diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index 7c2423b3..d8c41e84 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -464,11 +464,11 @@ def __init__(
         net,
         *,
         clip,
-        timesteps=1000,
-        cond_drop_prob=0.2,
-        loss_type="l1",
-        predict_x0=True,
-        beta_schedule="cosine",
+        timesteps = 1000,
+        cond_drop_prob = 0.2,
+        loss_type = "l1",
+        predict_x0 = True,
+        beta_schedule = "cosine",
     ):
         super().__init__()
         assert isinstance(clip, CLIP)
@@ -825,6 +825,8 @@ def __init__(
         out_dim = None,
         dim_mults=(1, 2, 4, 8),
         channels = 3,
+        attn_dim_head = 32,
+        attn_heads = 8,
         lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
         lowres_cond_upsample_mode = 'bilinear',
         blur_sigma = 0.1,
@@ -888,6 +890,10 @@ def __init__(
         self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim))
         self.null_text_embed = nn.Parameter(torch.randn(1, 1, cond_dim))

+        # attention related params
+
+        attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head)
+
         # layers

         self.downs = nn.ModuleList([])
@@ -901,7 +907,7 @@ def __init__(

             self.downs.append(nn.ModuleList([
                 ConvNextBlock(dim_in, dim_out, norm = ind != 0),
-                Residual(GridAttention(dim_out, window_size = sparse_attn_window)) if sparse_attn else nn.Identity(),
+                Residual(GridAttention(dim_out, window_size = sparse_attn_window, **attn_kwargs)) if sparse_attn else nn.Identity(),
                 ConvNextBlock(dim_out, dim_out, cond_dim = layer_cond_dim),
                 Downsample(dim_out) if not is_last else nn.Identity()
             ]))
@@ -909,7 +915,7 @@ def __init__(

         mid_dim = dims[-1]
         self.mid_block1 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
-        self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim))) if attend_at_middle else None
+        self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
         self.mid_block2 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)

         for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
@@ -918,7 +924,7 @@ def __init__(

             self.ups.append(nn.ModuleList([
                 ConvNextBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim),
-                Residual(GridAttention(dim_in, window_size = sparse_attn_window)) if sparse_attn else nn.Identity(),
+                Residual(GridAttention(dim_in, window_size = sparse_attn_window, **attn_kwargs)) if sparse_attn else nn.Identity(),
                 ConvNextBlock(dim_in, dim_in, cond_dim = layer_cond_dim),
                 Upsample(dim_in)
             ]))
diff --git a/setup.py b/setup.py
index e9db6d01..54959685 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.30',
+  version = '0.0.31',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',
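
Note: a minimal usage sketch of the two knobs this patch introduces, assuming
the hunk at @@ -825 belongs to this repo's Unet constructor; dim and
image_embed_dim below are illustrative placeholders, not part of this patch:

    from dalle2_pytorch.dalle2_pytorch import Unet

    # attn_dim_head / attn_heads are the parameters added by this patch;
    # every other argument is an assumed placeholder for illustration
    unet = Unet(
        dim = 128,             # base channel width, placeholder value
        image_embed_dim = 512, # assumed required by the Unet at this version
        attn_dim_head = 32,    # new: per-head dimension, defaults to 32
        attn_heads = 8         # new: number of attention heads, defaults to 8
    )

Leaving the new arguments at their defaults reproduces the behavior this patch
bakes in (8 heads of dimension 32) for the sparse grid attention in the down /
up stages and the full attention at the middle block.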