
Commit

more efficient
lucidrains committed Mar 5, 2021
1 parent c1ea075 commit ec90bb1
Showing 2 changed files with 6 additions and 8 deletions.
glom_pytorch/glom_pytorch.py (12 changes: 5 additions & 7 deletions)

@@ -79,8 +79,8 @@ def __init__(
         self.init_levels = nn.Parameter(torch.randn(levels, dim))

         # bottom-up and top-down
-        self.bottom_up = GroupedFeedForward(dim = dim, groups = levels)
-        self.top_down = GroupedFeedForward(dim = dim, groups = levels)
+        self.bottom_up = GroupedFeedForward(dim = dim, groups = levels - 1)
+        self.top_down = GroupedFeedForward(dim = dim, groups = levels - 1)

         # consensus attention
         self.attention = ConsensusAttention(attend_self = consensus_self)
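With groups = levels - 1, each GroupedFeedForward holds one independent set of weights per level-to-level transition, matching the levels - 1 slices it is fed in the updated forward pass below. A minimal sketch of such a grouped feed-forward, assuming grouped 1x1 convolutions; this is illustrative only and may differ from the library's actual implementation:

import torch
from torch import nn

# Assumed sketch: one independent feed-forward per group (level transition),
# fused into grouped 1x1 convolutions so all transitions run in a single op.
class GroupedFeedForward(nn.Module):
    def __init__(self, *, dim, groups, mult = 4):
        super().__init__()
        total = dim * groups
        self.net = nn.Sequential(
            nn.Conv1d(total, total * mult, 1, groups = groups),
            nn.GELU(),
            nn.Conv1d(total * mult, total, 1, groups = groups)
        )

    def forward(self, x):
        # x: (batch, num_patches, groups, dim)
        b, n, g, d = x.shape
        x = x.permute(0, 2, 3, 1).reshape(b, g * d, n)    # fold (groups, dim) into channels, one contiguous block per group
        x = self.net(x)
        return x.reshape(b, g, d, n).permute(0, 3, 1, 2)  # back to (batch, num_patches, groups, dim)

Packing the per-transition weights into one grouped convolution applies all transitions at once instead of looping over levels, which is where the per-level independence comes from.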
@@ -101,19 +101,17 @@ def forward(self, img, iters = None, return_all = False):
         hiddens = [levels]

         for _ in range(iters):
-            levels = torch.cat((bottom_level, levels), dim = -2) # each iteration, attach original input (with positional embedding) at the bottom level
+            levels_with_input = torch.cat((bottom_level, levels), dim = -2) # each iteration, attach original input (with positional embedding) at the bottom level

-            bottom_up_out = self.bottom_up(levels[..., :-1, :])
-            top_down_out = self.top_down(levels[..., 1:, :])
+            bottom_up_out = self.bottom_up(levels_with_input[..., 1:-1, :])
+            top_down_out = self.top_down(levels_with_input[..., 2:, :])

             bottom_up_out = torch.cat((bottom_level, bottom_up_out), dim = -2)
             top_down_out = F.pad(top_down_out, (0, 0, 0, 1), value = 0.)

             consensus = self.attention(levels)

             levels = torch.stack((levels, bottom_up_out, top_down_out, consensus)).mean(dim = 0) # hinton said to use the weighted mean of (1) bottom up (2) top down (3) previous level value {t - 1} (4) consensus value
-            levels = levels[..., 1:, :] # excise out the bottom level

             hiddens.append(levels)

         if return_all:
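A quick shape check of the new slicing, using hypothetical placeholder sizes for batch, num_patches, num_levels, and dim: concatenating the input gives num_levels + 1 entries, and the slices handed to bottom_up and top_down each contain num_levels - 1 entries, which is why the trailing levels[..., 1:, :] excision is no longer needed.

import torch

# hypothetical sizes, purely to check the slicing arithmetic in this commit
batch, num_patches, num_levels, dim = 2, 64, 6, 512

levels = torch.randn(batch, num_patches, num_levels, dim)   # per-patch level embeddings
bottom_level = torch.randn(batch, num_patches, 1, dim)      # tokenized input with positional embedding

levels_with_input = torch.cat((bottom_level, levels), dim = -2)

assert levels_with_input.shape[-2] == num_levels + 1
assert levels_with_input[..., 1:-1, :].shape[-2] == num_levels - 1  # fed to bottom_up (groups = levels - 1)
assert levels_with_input[..., 2:, :].shape[-2] == num_levels - 1    # fed to top_down  (groups = levels - 1)

After the top-down output is zero-padded at the top level and the bottom-up output is prefixed with bottom_level, all four tensors in the torch.stack are back to num_levels entries, so the mean can be assigned to levels directly.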
setup.py (2 changes: 1 addition & 1 deletion)

@@ -3,7 +3,7 @@
 setup(
   name = 'glom-pytorch',
   packages = find_packages(),
-  version = '0.0.3',
+  version = '0.0.4',
   license='MIT',
   description = 'Glom - Pytorch',
   author = 'Phil Wang',
