Skip to content

Commit

Permalink
optimize for clarity
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Apr 29, 2022
1 parent 3eea7f3 commit 40db958
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 7 deletions.
9 changes: 3 additions & 6 deletions flamingo_pytorch/flamingo_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,10 @@ def forward(self, x, latents):
q = self.to_q(latents)

# the paper differs from Perceiver in which they also concat the key / values derived from the latents to be attended to
lk, lv = self.to_kv(latents).chunk(2, dim = -1)
k, v = self.to_kv(x).chunk(2, dim = -1)
kv_input = torch.cat((x, latents), dim = -2)
k, v = self.to_kv(kv_input).chunk(2, dim = -1)

k, v, lk, lv, q = rearrange_many((k, v, lk, lv, q), 'b t n (h d) -> b h t n d', h = h)

k = torch.cat((k, lk), dim = -2)
v = torch.cat((v, lv), dim = -2)
q, k, v = rearrange_many((q, k, v), 'b t n (h d) -> b h t n d', h = h)

q = q * self.scale

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'flamingo-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.11',
version = '0.0.12',
license='MIT',
description = 'Flamingo - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit 40db958

Please sign in to comment.