Commit

fix bug with bottom up
lucidrains committed Mar 6, 2021
1 parent 3ee326e commit a972c16
Showing 2 changed files with 6 additions and 7 deletions.
11 changes: 5 additions & 6 deletions glom_pytorch/glom_pytorch.py
@@ -26,9 +26,9 @@ def __init__(self, *, dim, groups, mult = 4):
         total_dim = dim * groups # levels * dim
         self.net = nn.Sequential(
             Rearrange('b n l d -> b (l d) n'),
-            nn.Conv1d(total_dim, total_dim * 4, 1, groups = groups),
+            nn.Conv1d(total_dim, total_dim * mult, 1, groups = groups),
             nn.GELU(),
-            nn.Conv1d(total_dim * 4, total_dim, 1, groups = groups),
+            nn.Conv1d(total_dim * mult, total_dim, 1, groups = groups),
             Rearrange('b (l d) n -> b n l d', l = groups)
         )

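For reference, a minimal sketch (not part of the commit; the sizes are made up) of what this first change does: the hidden width of the grouped feed-forward now follows the mult argument instead of the hard-coded 4, while input and output shapes stay the same.

import torch
from glom_pytorch.glom_pytorch import GroupedFeedForward

# with mult = 2 the hidden Conv1d width is total_dim * 2; before this commit it
# stayed at total_dim * 4 no matter what mult was passed
ff = GroupedFeedForward(dim = 64, groups = 5, mult = 2)
x = torch.randn(1, 16, 5, 64)   # (batch, tokens, groups, dim)
out = ff(x)                     # shape preserved: (1, 16, 5, 64)
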
@@ -101,7 +101,7 @@ def __init__(
         self.init_levels = nn.Parameter(torch.randn(levels, dim))

         # bottom-up and top-down
-        self.bottom_up = GroupedFeedForward(dim = dim, groups = levels - 1)
+        self.bottom_up = GroupedFeedForward(dim = dim, groups = levels)
         self.top_down = GroupedFeedForward(dim = dim, groups = levels - 1)

         # consensus attention
@@ -131,10 +131,9 @@ def forward(self, img, iters = None, levels = None, return_all = False):
         for _ in range(iters):
             levels_with_input = torch.cat((bottom_level, levels), dim = -2) # each iteration, attach original input (with positional embedding) at the bottom level

-            bottom_up_out = self.bottom_up(levels_with_input[..., 1:-1, :])
-            top_down_out = self.top_down(levels_with_input[..., 2:, :] + pos_embs) # positional embeddings given to top-down networks
+            bottom_up_out = self.bottom_up(levels_with_input[..., :-1, :])

-            bottom_up_out = torch.cat((bottom_level, bottom_up_out), dim = -2)
+            top_down_out = self.top_down(levels_with_input[..., 2:, :] + pos_embs) # positional embeddings given to top-down networks
             top_down_out = F.pad(top_down_out, (0, 0, 0, 1), value = 0.)

             consensus = self.attention(levels)
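For context, a rough shape sketch (not from the commit; batch, token, level, and dim sizes are assumed) of the bottom-up fix. With the input attached at the bottom, each learned level should take its bottom-up signal from the row directly below it, so the bottom-up network now sees every row except the topmost one. That is why it is built with groups = levels and the separate concatenation of bottom_level is dropped.

import torch

b, n, L, d = 1, 16, 5, 64                           # assumed sizes
bottom_level = torch.randn(b, n, 1, d)              # input tokens as level 0
levels = torch.randn(b, n, L, d)                    # learned levels 1..L
levels_with_input = torch.cat((bottom_level, levels), dim = -2)   # (b, n, L + 1, d)

# new: all rows but the top, i.e. L groups -> GroupedFeedForward(groups = levels)
bottom_up_in = levels_with_input[..., :-1, :]       # (b, n, L, d)

# old: only the L - 1 middle rows, with bottom_level concatenated back on afterwards,
# so the lowest level received the raw input rather than a bottom-up transform of it
old_bottom_up_in = levels_with_input[..., 1:-1, :]  # (b, n, L - 1, d)
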
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'glom-pytorch',
   packages = find_packages(),
-  version = '0.0.11',
+  version = '0.0.12',
   license='MIT',
   description = 'Glom - Pytorch',
   author = 'Phil Wang',