Commit

add ability to specify full self attention on specific stages in the unet

lucidrains committed Jul 1, 2022
1 parent 282c359 commit 3d23ba4
Showing 3 changed files with 42 additions and 18 deletions.
57 changes: 40 additions & 17 deletions dalle2_pytorch/dalle2_pytorch.py
@@ -63,11 +63,16 @@ def default(val, d):
return val
return d() if callable(d) else d

def cast_tuple(val, length = 1):
def cast_tuple(val, length = None):
if isinstance(val, list):
val = tuple(val)

return val if isinstance(val, tuple) else ((val,) * length)
out = val if isinstance(val, tuple) else ((val,) * default(length, 1))

if exists(length):
assert len(out) == length

return out

def module_device(module):
return next(module.parameters()).device
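
To make the behaviour change concrete, a few hypothetical calls against the new cast_tuple (illustrative only, not part of the diff):

    cast_tuple(64, 4)            # -> (64, 64, 64, 64): scalars are repeated to `length`
    cast_tuple((1, 2, 4, 8), 4)  # -> (1, 2, 4, 8): tuples pass through, length is asserted
    cast_tuple([1, 2], 2)        # -> (1, 2): lists are coerced to tuples first
    cast_tuple(False)            # -> (False,): with length = None a default of 1 is used
    # cast_tuple((1, 2), 4)      # would now raise AssertionError instead of passing silently
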
@@ -1341,6 +1346,7 @@ def __init__(
dim_mults=(1, 2, 4, 8),
channels = 3,
channels_out = None,
self_attn = False,
attn_dim_head = 32,
attn_heads = 16,
lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
@@ -1387,6 +1393,8 @@ def __init__(
dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))

num_stages = len(in_out)

# time, image embeddings, and optional text encoding

cond_dim = default(cond_dim, dim)
@@ -1450,14 +1458,16 @@ def __init__(

attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head)

self_attn = cast_tuple(self_attn, num_stages)

create_self_attn = lambda dim: EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(dim, **attn_kwargs)))

# resnet block klass

resnet_groups = cast_tuple(resnet_groups, len(in_out))
resnet_groups = cast_tuple(resnet_groups, num_stages)
top_level_resnet_group = first(resnet_groups)

num_resnet_blocks = cast_tuple(num_resnet_blocks, len(in_out))

assert len(resnet_groups) == len(in_out)
num_resnet_blocks = cast_tuple(num_resnet_blocks, num_stages)

# downsample klass

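For orientation, a rough paraphrase of what the create_self_attn helper defined above amounts to, assuming EinopsToAndFrom simply rearranges around the wrapped module and Residual adds the input back (a sketch, not the repository's literal classes):

    # Sketch only: full self attention over a feature map, as create_self_attn wires it up.
    # The map is flattened to a token sequence ('b c h w' -> 'b (h w) c'), attention runs over
    # every spatial position, and the result is reshaped back with a residual add.
    def full_self_attn(x, attn):                            # x: (b, c, h, w), attn: token-wise attention
        b, c, h, w = x.shape
        tokens = x.reshape(b, c, h * w).transpose(1, 2)     # 'b c h w' -> 'b (h w) c'
        tokens = attn(tokens) + tokens                      # Residual(Attention(...))
        return tokens.transpose(1, 2).reshape(b, c, h, w)   # back to 'b c h w'
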
@@ -1479,40 +1489,52 @@ def __init__(
self.ups = nn.ModuleList([])
num_resolutions = len(in_out)

skip_connect_dims = [] # keeping track of skip connection dimensions
skip_connect_dims = [] # keeping track of skip connection dimensions

for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks) in enumerate(zip(in_out, resnet_groups, num_resnet_blocks)):
for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks, layer_self_attn) in enumerate(zip(in_out, resnet_groups, num_resnet_blocks, self_attn)):
is_first = ind == 0
is_last = ind >= (num_resolutions - 1)
layer_cond_dim = cond_dim if not is_first else None

dim_layer = dim_out if memory_efficient else dim_in
skip_connect_dims.append(dim_layer)

attention = nn.Identity()
if layer_self_attn:
attention = create_self_attn(dim_layer)
elif sparse_attn:
attention = Residual(LinearAttention(dim_layer, **attn_kwargs))

self.downs.append(nn.ModuleList([
downsample_klass(dim_in, dim_out = dim_out) if memory_efficient else None,
ResnetBlock(dim_layer, dim_layer, time_cond_dim = time_cond_dim, groups = groups),
Residual(LinearAttention(dim_layer, **attn_kwargs)) if sparse_attn else nn.Identity(),
nn.ModuleList([ResnetBlock(dim_layer, dim_layer, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
attention,
downsample_klass(dim_layer, dim_out = dim_out) if not is_last and not memory_efficient else nn.Conv2d(dim_layer, dim_out, 1)
]))

mid_dim = dims[-1]

self.mid_block1 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
self.mid_attn = create_self_attn(mid_dim)
self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])

for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks) in enumerate(zip(reversed(in_out), reversed(resnet_groups), reversed(num_resnet_blocks))):
for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks, layer_self_attn) in enumerate(zip(reversed(in_out), reversed(resnet_groups), reversed(num_resnet_blocks), reversed(self_attn))):
is_last = ind >= (len(in_out) - 1)
layer_cond_dim = cond_dim if not is_last else None

skip_connect_dim = skip_connect_dims.pop()

attention = nn.Identity()
if layer_self_attn:
attention = create_self_attn(dim_out)
elif sparse_attn:
attention = Residual(LinearAttention(dim_out, **attn_kwargs))

self.ups.append(nn.ModuleList([
ResnetBlock(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
Residual(LinearAttention(dim_out, **attn_kwargs)) if sparse_attn else nn.Identity(),
nn.ModuleList([ResnetBlock(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
attention,
upsample_klass(dim_out, dim_in) if not is_last or memory_efficient else nn.Identity()
]))

@@ -1690,18 +1712,19 @@ def forward(

hiddens = []

for pre_downsample, init_block, sparse_attn, resnet_blocks, post_downsample in self.downs:
for pre_downsample, init_block, resnet_blocks, attn, post_downsample in self.downs:
if exists(pre_downsample):
x = pre_downsample(x)

x = init_block(x, t, c)
x = sparse_attn(x)
hiddens.append(x)

for resnet_block in resnet_blocks:
x = resnet_block(x, t, c)
hiddens.append(x)

x = attn(x)
hiddens.append(x)

if exists(post_downsample):
x = post_downsample(x)

@@ -1714,15 +1737,15 @@ def forward(

connect_skip = lambda fmap: torch.cat((fmap, hiddens.pop() * self.skip_connect_scale), dim = 1)

for init_block, sparse_attn, resnet_blocks, upsample in self.ups:
for init_block, resnet_blocks, attn, upsample in self.ups:
x = connect_skip(x)
x = init_block(x, t, c)
x = sparse_attn(x)

for resnet_block in resnet_blocks:
x = connect_skip(x)
x = resnet_block(x, t, c)

x = attn(x)
x = upsample(x)

x = torch.cat((x, r), dim = 1)
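
As a usage sketch (hypothetical values; only self_attn and the number of stages relate to this commit), the new flag can be given as a single boolean or one entry per resolution stage; a stage set to True gets the full self-attention block from create_self_attn, otherwise it falls back to linear attention when sparse_attn is enabled, or nn.Identity():

    # Hypothetical instantiation; remaining constructor arguments are left at their defaults
    from dalle2_pytorch import Unet

    unet = Unet(
        dim = 128,
        dim_mults = (1, 2, 4, 8),               # four stages -> num_stages = 4
        self_attn = (False, True, True, True)   # full self attention on all but the first stage
    )
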
1 change: 1 addition & 0 deletions dalle2_pytorch/train_configs.py
@@ -216,6 +216,7 @@ class UnetConfig(BaseModel):
cond_on_text_encodings: bool = None
cond_dim: int = None
channels: int = 3
self_attn: ListOrTuple(int)
attn_dim_head: int = 32
attn_heads: int = 16

2 changes: 1 addition & 1 deletion dalle2_pytorch/version.py
@@ -1 +1 @@
__version__ = '0.15.2'
__version__ = '0.15.3'
