Commit: cleanup

lucidrains committed Mar 6, 2021
1 parent 8f614f4 commit b3b6b87
Showing 2 changed files with 5 additions and 6 deletions.
README.md: 10 changes (5 additions, 5 deletions)
@@ -4,7 +4,7 @@

## GLOM - Pytorch

- An attempt at the implementation of <a href="https://arxiv.org/abs/2102.12627">Glom</a>, Geoffrey Hinton's new idea that integrates concepts from neural fields, top-down-bottom-up processing, and attention (consensus between columns) for emergent part-whole hierarchies from data.
+ An implementation of <a href="https://arxiv.org/abs/2102.12627">Glom</a>, Geoffrey Hinton's new idea that integrates concepts from neural fields, top-down-bottom-up processing, and attention (consensus between columns) for learning emergent part-whole hierarchies from data.

## Install

@@ -29,9 +29,9 @@ img = torch.randn(1, 3, 224, 224)
levels = model(img, iters = 12) # (1, 256, 6, 512) - (batch - patches - levels - dimension)
```

- If you were to pass the `return_all = True` keyword argument on forward, you will be returned all the column and level states per iteration, including the initial state (so number of iterations + 1). You can then use this to attach any losses to any time step you wish, to induce emergence of whatever.
+ Pass the `return_all = True` keyword argument on forward, and you will be returned all the column and level states per iteration (including the initial state, so number of iterations + 1). You can then use this to attach any losses to any level outputs at any time step.

- This also gives you access to all the data for clustering for the theorized islands that should form.
+ It also gives you access to all the level data across iterations for clustering, from which one can inspect for the island formations theorized in the paper.
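
As a minimal sketch of that island inspection (assuming the `Glom` configuration used in the README examples; the thresholded cosine-similarity heuristic is illustrative, not from the repository):

```python
import torch
import torch.nn.functional as F
from glom_pytorch import Glom

model = Glom(dim = 512, levels = 6, image_size = 224, patch_size = 14)

img = torch.randn(1, 3, 224, 224)
all_levels = model(img, iters = 12, return_all = True) # (13, 1, 256, 6, 512) - (time, batch, patches, levels, dimension)

# look for islands: nearly identical top-level vectors across patch columns
top = F.normalize(all_levels[-1, 0, :, -1], dim = -1)  # (256, 512) - top level vectors at the final iteration
sim = top @ top.t()                                    # (256, 256) - cosine similarity between columns
islands = sim > 0.9                                    # boolean adjacency; the 0.9 threshold is arbitrary
```

Connected components of `islands` would then correspond to candidate part-whole groupings.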

```python
import torch
@@ -78,7 +78,7 @@ patches_to_images = nn.Sequential(
    Rearrange('b (h w) (p1 p2 c) -> b c (h p1) (w p2)', p1 = 14, p2 = 14, h = (224 // 14))
)

- top_level = all_levels[6, :, :, -1]
+ top_level = all_levels[7, :, :, -1] # get the top level embeddings after the 7th iteration (index 0 is the initial state)
recon_img = patches_to_images(top_level)

# do self-supervised learning by denoising
@@ -104,7 +104,7 @@ img1 = torch.randn(1, 3, 224, 224)
img2 = torch.randn(1, 3, 224, 224)
img3 = torch.randn(1, 3, 224, 224)

- levels1 = model(img1, iters = 12) # show image 1 for 12 iterations
+ levels1 = model(img1, iters = 12) # image 1 for 12 iterations
levels2 = model(img2, levels = levels1, iters = 10) # image 2 for 10 iterations
levels3 = model(img3, levels = levels2, iters = 6) # image 3 for 6 iterations
```
glom_pytorch/glom_pytorch.py: 1 change (0 additions, 1 deletion)
@@ -65,7 +65,6 @@ def forward(self, levels):
        sim.masked_fill_(self_mask, TOKEN_ATTEND_SELF_VALUE)

        if self.local_consensus_radius > 0:
-           h = w = int(sqrt(n))
            max_neg_value = -torch.finfo(sim.dtype).max
            sim.masked_fill_(self.non_local_mask, max_neg_value)
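
The deleted `h = w = int(sqrt(n))` line appears to have been unused, consistent with the "cleanup" commit message. For context, `self.non_local_mask` is plausibly a precomputed boolean mask derived from distances on the patch grid; the sketch below is an assumed reconstruction for illustration, not code from this commit, and the helper name `build_non_local_mask` is hypothetical:

```python
import torch

def build_non_local_mask(num_patches, radius):
    # assumes the patches form a square grid, as in the README examples ((224 // 14) ** 2 = 256 patches)
    side = int(num_patches ** 0.5)
    coords = torch.stack(torch.meshgrid(
        torch.arange(side), torch.arange(side), indexing = 'ij'
    ), dim = -1).reshape(-1, 2).float() # (n, 2) - grid coordinates of each patch
    dist = torch.cdist(coords, coords)  # (n, n) - euclidean distances between patches
    return dist > radius                # True where consensus attention between columns is masked out
```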

