From a972c16e5471e9ae98672f2df01fce8d1c98e4d7 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Sat, 6 Mar 2021 02:03:25 -0800
Subject: [PATCH] fix bug with bottom up

---
 glom_pytorch/glom_pytorch.py | 11 +++++------
 setup.py                     |  2 +-
 2 files changed, 6 insertions(+), 7 deletions(-)

diff --git a/glom_pytorch/glom_pytorch.py b/glom_pytorch/glom_pytorch.py
index c6d8125..c391e49 100644
--- a/glom_pytorch/glom_pytorch.py
+++ b/glom_pytorch/glom_pytorch.py
@@ -26,9 +26,9 @@ def __init__(self, *, dim, groups, mult = 4):
         total_dim = dim * groups # levels * dim
         self.net = nn.Sequential(
             Rearrange('b n l d -> b (l d) n'),
-            nn.Conv1d(total_dim, total_dim * 4, 1, groups = groups),
+            nn.Conv1d(total_dim, total_dim * mult, 1, groups = groups),
             nn.GELU(),
-            nn.Conv1d(total_dim * 4, total_dim, 1, groups = groups),
+            nn.Conv1d(total_dim * mult, total_dim, 1, groups = groups),
             Rearrange('b (l d) n -> b n l d', l = groups)
         )

@@ -101,7 +101,7 @@ def __init__(
         self.init_levels = nn.Parameter(torch.randn(levels, dim))

         # bottom-up and top-down
-        self.bottom_up = GroupedFeedForward(dim = dim, groups = levels - 1)
+        self.bottom_up = GroupedFeedForward(dim = dim, groups = levels)
         self.top_down = GroupedFeedForward(dim = dim, groups = levels - 1)

         # consensus attention
@@ -131,10 +131,9 @@ def forward(self, img, iters = None, levels = None, return_all = False):
         for _ in range(iters):
             levels_with_input = torch.cat((bottom_level, levels), dim = -2) # each iteration, attach original input (with positional embedding) at the bottom level

-            bottom_up_out = self.bottom_up(levels_with_input[..., 1:-1, :])
-            top_down_out = self.top_down(levels_with_input[..., 2:, :] + pos_embs) # positional embeddings given to top-down networks
+            bottom_up_out = self.bottom_up(levels_with_input[..., :-1, :])

-            bottom_up_out = torch.cat((bottom_level, bottom_up_out), dim = -2)
+            top_down_out = self.top_down(levels_with_input[..., 2:, :] + pos_embs) # positional embeddings given to top-down networks
             top_down_out = F.pad(top_down_out, (0, 0, 0, 1), value = 0.)

             consensus = self.attention(levels)

diff --git a/setup.py b/setup.py
index 8c34139..7d9a71c 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'glom-pytorch',
   packages = find_packages(),
-  version = '0.0.11',
+  version = '0.0.12',
   license='MIT',
   description = 'Glom - Pytorch',
   author = 'Phil Wang',
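
For context on why the slice changes together with the group count, here is a minimal shape sketch of the bottom-up input after this patch; the batch size, token count, and dimensions below are assumed purely for illustration and are not taken from the repository's tests.

import torch

# hypothetical sizes, chosen only to illustrate the slicing (not from the repo)
batch, num_tokens, levels, dim = 2, 16, 6, 32

# stand-in for levels_with_input = torch.cat((bottom_level, levels), dim = -2):
# the input row plus all hierarchy levels, i.e. levels + 1 rows on the level axis
levels_with_input = torch.randn(batch, num_tokens, levels + 1, dim)

# new slice: drop only the top level, leaving exactly `levels` rows,
# one per group of the bottom-up GroupedFeedForward(groups = levels)
bottom_up_in = levels_with_input[..., :-1, :]
assert bottom_up_in.shape == (batch, num_tokens, levels, dim)

# old slice: only levels - 1 rows, which is why the previous code had to
# concatenate bottom_level back onto the bottom-up output
old_bottom_up_in = levels_with_input[..., 1:-1, :]
assert old_bottom_up_in.shape == (batch, num_tokens, levels - 1, dim)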