diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index 3a237035..9dd6f35c 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -63,11 +63,16 @@ def default(val, d):
         return val
     return d() if callable(d) else d
 
-def cast_tuple(val, length = 1):
+def cast_tuple(val, length = None):
     if isinstance(val, list):
         val = tuple(val)
 
-    return val if isinstance(val, tuple) else ((val,) * length)
+    out = val if isinstance(val, tuple) else ((val,) * default(length, 1))
+
+    if exists(length):
+        assert len(out) == length
+
+    return out
 
 def module_device(module):
     return next(module.parameters()).device
@@ -1341,6 +1346,7 @@ def __init__(
         dim_mults=(1, 2, 4, 8),
         channels = 3,
         channels_out = None,
+        self_attn = False,
         attn_dim_head = 32,
         attn_heads = 16,
         lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
@@ -1387,6 +1393,8 @@ def __init__(
         dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
         in_out = list(zip(dims[:-1], dims[1:]))
 
+        num_stages = len(in_out)
+
         # time, image embeddings, and optional text encoding
 
         cond_dim = default(cond_dim, dim)
@@ -1450,14 +1458,16 @@ def __init__(
 
         attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head)
 
+        self_attn = cast_tuple(self_attn, num_stages)
+
+        create_self_attn = lambda dim: EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(dim, **attn_kwargs)))
+
         # resnet block klass
 
-        resnet_groups = cast_tuple(resnet_groups, len(in_out))
+        resnet_groups = cast_tuple(resnet_groups, num_stages)
         top_level_resnet_group = first(resnet_groups)
 
-        num_resnet_blocks = cast_tuple(num_resnet_blocks, len(in_out))
-
-        assert len(resnet_groups) == len(in_out)
+        num_resnet_blocks = cast_tuple(num_resnet_blocks, num_stages)
 
         # downsample klass
 
@@ -1479,9 +1489,9 @@ def __init__(
         self.ups = nn.ModuleList([])
         num_resolutions = len(in_out)
 
-        skip_connect_dims = [] # keeping track of skip connection dimensions
+        skip_connect_dims = [] # keeping track of skip connection dimensions
 
-        for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks) in enumerate(zip(in_out, resnet_groups, num_resnet_blocks)):
+        for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks, layer_self_attn) in enumerate(zip(in_out, resnet_groups, num_resnet_blocks, self_attn)):
             is_first = ind == 0
             is_last = ind >= (num_resolutions - 1)
             layer_cond_dim = cond_dim if not is_first else None
@@ -1489,30 +1499,42 @@ def __init__(
             dim_layer = dim_out if memory_efficient else dim_in
             skip_connect_dims.append(dim_layer)
 
+            attention = nn.Identity()
+            if layer_self_attn:
+                attention = create_self_attn(dim_layer)
+            elif sparse_attn:
+                attention = Residual(LinearAttention(dim_layer, **attn_kwargs))
+
             self.downs.append(nn.ModuleList([
                 downsample_klass(dim_in, dim_out = dim_out) if memory_efficient else None,
                 ResnetBlock(dim_layer, dim_layer, time_cond_dim = time_cond_dim, groups = groups),
-                Residual(LinearAttention(dim_layer, **attn_kwargs)) if sparse_attn else nn.Identity(),
                 nn.ModuleList([ResnetBlock(dim_layer, dim_layer, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
+                attention,
                 downsample_klass(dim_layer, dim_out = dim_out) if not is_last and not memory_efficient else nn.Conv2d(dim_layer, dim_out, 1)
             ]))
 
         mid_dim = dims[-1]
 
         self.mid_block1 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
-        self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
+        self.mid_attn = create_self_attn(mid_dim)
         self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
 
-        for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks) in enumerate(zip(reversed(in_out), reversed(resnet_groups), reversed(num_resnet_blocks))):
+        for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks, layer_self_attn) in enumerate(zip(reversed(in_out), reversed(resnet_groups), reversed(num_resnet_blocks), reversed(self_attn))):
             is_last = ind >= (len(in_out) - 1)
             layer_cond_dim = cond_dim if not is_last else None
 
             skip_connect_dim = skip_connect_dims.pop()
 
+            attention = nn.Identity()
+            if layer_self_attn:
+                attention = create_self_attn(dim_out)
+            elif sparse_attn:
+                attention = Residual(LinearAttention(dim_out, **attn_kwargs))
+
             self.ups.append(nn.ModuleList([
                 ResnetBlock(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
-                Residual(LinearAttention(dim_out, **attn_kwargs)) if sparse_attn else nn.Identity(),
                 nn.ModuleList([ResnetBlock(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
+                attention,
                 upsample_klass(dim_out, dim_in) if not is_last or memory_efficient else nn.Identity()
             ]))
 
@@ -1690,18 +1712,19 @@ def forward(
 
         hiddens = []
 
-        for pre_downsample, init_block, sparse_attn, resnet_blocks, post_downsample in self.downs:
+        for pre_downsample, init_block, resnet_blocks, attn, post_downsample in self.downs:
             if exists(pre_downsample):
                 x = pre_downsample(x)
 
             x = init_block(x, t, c)
-            x = sparse_attn(x)
-            hiddens.append(x)
 
             for resnet_block in resnet_blocks:
                 x = resnet_block(x, t, c)
                 hiddens.append(x)
 
+            x = attn(x)
+            hiddens.append(x)
+
             if exists(post_downsample):
                 x = post_downsample(x)
 
@@ -1714,15 +1737,15 @@ def forward(
 
         connect_skip = lambda fmap: torch.cat((fmap, hiddens.pop() * self.skip_connect_scale), dim = 1)
 
-        for init_block, sparse_attn, resnet_blocks, upsample in self.ups:
+        for init_block, resnet_blocks, attn, upsample in self.ups:
             x = connect_skip(x)
             x = init_block(x, t, c)
-            x = sparse_attn(x)
 
             for resnet_block in resnet_blocks:
                 x = connect_skip(x)
                 x = resnet_block(x, t, c)
 
+            x = attn(x)
             x = upsample(x)
 
         x = torch.cat((x, r), dim = 1)
diff --git a/dalle2_pytorch/train_configs.py b/dalle2_pytorch/train_configs.py
index baaa2225..c1902b63 100644
--- a/dalle2_pytorch/train_configs.py
+++ b/dalle2_pytorch/train_configs.py
@@ -216,6 +216,7 @@ class UnetConfig(BaseModel):
     cond_on_text_encodings: bool = None
     cond_dim: int = None
    channels: int = 3
+    self_attn: ListOrTuple(int)
     attn_dim_head: int = 32
     attn_heads: int = 16
 
diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py
index 00d1ab54..090480db 100644
--- a/dalle2_pytorch/version.py
+++ b/dalle2_pytorch/version.py
@@ -1 +1 @@
-__version__ = '0.15.2'
+__version__ = '0.15.3'
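
Usage note (not part of the patch): a minimal sketch of how the new per-stage self_attn flag might be passed when constructing a Unet. The other constructor arguments shown are taken from the existing API; the concrete values are illustrative assumptions only.

# hypothetical usage sketch for the new `self_attn` argument
# one flag per resolution stage; a single bool is broadcast to all stages by cast_tuple

from dalle2_pytorch import Unet

unet = Unet(
    dim = 128,
    dim_mults = (1, 2, 4, 8),        # four stages -> self_attn tuple must have length 4
    image_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    self_attn = (False, True, True, True)  # skip full self attention at the highest resolution stage
)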