From ec90bb12afb6cb223165a582570915efc5a2b659 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Fri, 5 Mar 2021 08:13:26 -0800 Subject: [PATCH] more efficient --- glom_pytorch/glom_pytorch.py | 12 +++++------- setup.py | 2 +- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/glom_pytorch/glom_pytorch.py b/glom_pytorch/glom_pytorch.py index b761dae..663b166 100644 --- a/glom_pytorch/glom_pytorch.py +++ b/glom_pytorch/glom_pytorch.py @@ -79,8 +79,8 @@ def __init__( self.init_levels = nn.Parameter(torch.randn(levels, dim)) # bottom-up and top-down - self.bottom_up = GroupedFeedForward(dim = dim, groups = levels) - self.top_down = GroupedFeedForward(dim = dim, groups = levels) + self.bottom_up = GroupedFeedForward(dim = dim, groups = levels - 1) + self.top_down = GroupedFeedForward(dim = dim, groups = levels - 1) # consensus attention self.attention = ConsensusAttention(attend_self = consensus_self) @@ -101,10 +101,10 @@ def forward(self, img, iters = None, return_all = False): hiddens = [levels] for _ in range(iters): - levels = torch.cat((bottom_level, levels), dim = -2) # each iteration, attach original input (with positional embedding) at the bottom level + levels_with_input = torch.cat((bottom_level, levels), dim = -2) # each iteration, attach original input (with positional embedding) at the bottom level - bottom_up_out = self.bottom_up(levels[..., :-1, :]) - top_down_out = self.top_down(levels[..., 1:, :]) + bottom_up_out = self.bottom_up(levels_with_input[..., 1:-1, :]) + top_down_out = self.top_down(levels_with_input[..., 2:, :]) bottom_up_out = torch.cat((bottom_level, bottom_up_out), dim = -2) top_down_out = F.pad(top_down_out, (0, 0, 0, 1), value = 0.) @@ -112,8 +112,6 @@ def forward(self, img, iters = None, return_all = False): consensus = self.attention(levels) levels = torch.stack((levels, bottom_up_out, top_down_out, consensus)).mean(dim = 0) # hinton said to use the weighted mean of (1) bottom up (2) top down (3) previous level value {t - 1} (4) consensus value - levels = levels[..., 1:, :] # excise out the bottom level - hiddens.append(levels) if return_all: diff --git a/setup.py b/setup.py index e4b874d..75a40e1 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'glom-pytorch', packages = find_packages(), - version = '0.0.3', + version = '0.0.4', license='MIT', description = 'Glom - Pytorch', author = 'Phil Wang',