offer way to turn off initial cross embed convolutional module, for debugging upsampler artifacts
lucidrains committed Jul 16, 2022
1 parent a58a370 commit a2ee3fa
Showing 3 changed files with 4 additions and 2 deletions.
3 changes: 2 additions & 1 deletion dalle2_pytorch/dalle2_pytorch.py
@@ -1550,6 +1550,7 @@ def __init__(
         init_conv_kernel_size = 7,
         resnet_groups = 8,
         num_resnet_blocks = 2,
+        init_cross_embed = True,
         init_cross_embed_kernel_sizes = (3, 7, 15),
         cross_embed_downsample = False,
         cross_embed_downsample_kernel_sizes = (2, 4),
@@ -1578,7 +1579,7 @@ def __init__(
         init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
         init_dim = default(init_dim, dim)

-        self.init_conv = CrossEmbedLayer(init_channels, dim_out = init_dim, kernel_sizes = init_cross_embed_kernel_sizes, stride = 1)
+        self.init_conv = CrossEmbedLayer(init_channels, dim_out = init_dim, kernel_sizes = init_cross_embed_kernel_sizes, stride = 1) if init_cross_embed else nn.Conv2d(init_channels, init_dim, init_conv_kernel_size, padding = init_conv_kernel_size // 2)

         dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
         in_out = list(zip(dims[:-1], dims[1:]))
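For reference, a minimal sketch of the pattern this hunk introduces: when init_cross_embed is enabled, the stem runs several parallel convolutions with different kernel sizes and concatenates their outputs channel-wise; when disabled, a single plain Conv2d is used instead, which helps when bisecting upsampler artifacts. SimpleCrossEmbed below is an illustrative stand-in for the library's CrossEmbedLayer, not the actual implementation; the channel-split heuristic and the example sizes are assumptions.

```python
import torch
from torch import nn

class SimpleCrossEmbed(nn.Module):
    # Illustrative stand-in for CrossEmbedLayer: parallel convs at several
    # kernel sizes, outputs concatenated along the channel dimension.
    def __init__(self, dim_in, dim_out, kernel_sizes = (3, 7, 15), stride = 1):
        super().__init__()
        kernel_sizes = sorted(kernel_sizes)
        # split output channels across scales (assumed heuristic),
        # giving whatever remains to the last (largest-kernel) conv
        dim_scales = [dim_out // (2 ** (i + 1)) for i in range(len(kernel_sizes) - 1)]
        dim_scales.append(dim_out - sum(dim_scales))
        self.convs = nn.ModuleList([
            nn.Conv2d(dim_in, d, k, stride = stride, padding = (k - stride) // 2)
            for d, k in zip(dim_scales, kernel_sizes)
        ])

    def forward(self, x):
        return torch.cat([conv(x) for conv in self.convs], dim = 1)

# flag-gated stem, mirroring the diff above
init_cross_embed = False          # set False to debug with a plain conv stem
init_channels, init_dim, init_conv_kernel_size = 3, 128, 7  # assumed example sizes

init_conv = (
    SimpleCrossEmbed(init_channels, init_dim, kernel_sizes = (3, 7, 15), stride = 1)
    if init_cross_embed
    else nn.Conv2d(init_channels, init_dim, init_conv_kernel_size, padding = init_conv_kernel_size // 2)
)

x = torch.randn(1, init_channels, 64, 64)
print(init_conv(x).shape)         # torch.Size([1, 128, 64, 64]) either way
```

Either branch preserves spatial resolution and produces init_dim output channels, so the rest of the Unet is unaffected by which stem is chosen.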
1 change: 1 addition & 0 deletions dalle2_pytorch/train_configs.py
@@ -225,6 +225,7 @@ class UnetConfig(BaseModel):
     self_attn: ListOrTuple(int)
     attn_dim_head: int = 32
     attn_heads: int = 16
+    init_cross_embed: bool = True

     class Config:
         extra = "allow"
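With the matching field on the pydantic UnetConfig (which already permits extra keys via extra = "allow"), the flag can be toggled per unet from a training config. A hypothetical fragment, where every key other than init_cross_embed is a placeholder standing in for the rest of a unet config:

```python
# Hypothetical per-unet config fragment; only init_cross_embed comes from this commit.
unet_config = {
    "attn_dim_head": 32,
    "attn_heads": 16,
    "init_cross_embed": False,   # fall back to a plain Conv2d stem for debugging
    # ... remaining unet settings (dim, dim_mults, self_attn, ...) elided
}
```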
2 changes: 1 addition & 1 deletion dalle2_pytorch/version.py
@@ -1 +1 @@
-__version__ = '0.24.2'
+__version__ = '0.24.3'
