Skip to content

Commit

Permalink
add continuation feature
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Mar 5, 2021
1 parent d72cf2b commit 9032564
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 3 deletions.
23 changes: 23 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,29 @@ recon_img = patches_to_images(top_level)
loss = F.mse_loss(img, recon_img)
loss.backward()
```

You can pass the state of the column and levels back into the model to continue where you left off (for example, if you are processing consecutive frames of a slow-moving video, as mentioned in the paper)

```python
import torch
from glom_pytorch import Glom

model = Glom(
dim = 512,
levels = 6,
image_size = 224,
patch_size = 14
)

img1 = torch.randn(1, 3, 224, 224)
img2 = torch.randn(1, 3, 224, 224)
img3 = torch.randn(1, 3, 224, 224)

levels1 = model(img1, iters = 12) # show image 1 for 12 iterations
levels2 = model(img2, levels = levels1, iters = 10) # image 2 for 10 iterations
levels3 = model(img3, levels = levels2, iters = 6) # image 3 for 6 iterations
```

### Appreciation

Thanks goes out to <a href="https://github.com/cfoster0">Cfoster0</a> for reviewing the code
Expand Down
5 changes: 3 additions & 2 deletions glom_pytorch/glom_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def __init__(
# consensus attention
self.attention = ConsensusAttention(num_patches_side, attend_self = consensus_self, local_consensus_radius = local_consensus_radius)

def forward(self, img, iters = None, return_all = False):
def forward(self, img, iters = None, levels = None, return_all = False):
b, h, w, _, device = *img.shape, img.device
iters = default(iters, self.levels * 2) # need to have twice the number of levels of iterations in order for information to propagate up and back down. can be overridden

Expand All @@ -121,7 +121,8 @@ def forward(self, img, iters = None, return_all = False):
bottom_level = tokens
bottom_level = rearrange(bottom_level, 'b n d -> b n () d')

levels = repeat(self.init_levels, 'l d -> b n l d', b = b, n = n)
if not exists(levels):
levels = repeat(self.init_levels, 'l d -> b n l d', b = b, n = n)

hiddens = [levels]

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'glom-pytorch',
packages = find_packages(),
version = '0.0.10',
version = '0.0.11',
license='MIT',
description = 'Glom - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit 9032564

Please sign in to comment.