
Commit

allow for specifying whether text should attend to only the most recent media item, or all previous media items
lucidrains committed Jun 8, 2022
1 parent 44920f4 commit 2f3606a
Showing 3 changed files with 19 additions and 6 deletions.
5 changes: 3 additions & 2 deletions flamingo_pytorch/flamingo_palm.py
@@ -209,7 +209,8 @@ def __init__(
         cross_attn_every=3,
         img_encoder=None,
         perceiver_num_latents=64,
-        perceiver_depth=2
+        perceiver_depth=2,
+        only_attend_immediate_media=False
     ):
         super().__init__()

@@ -231,7 +232,7 @@ def __init__(
         for ind in range(depth):
             self.layers.append(nn.ModuleList([
                 Residual(ParallelTransformerBlock(dim=dim, dim_head=dim_head, heads=heads, ff_mult=ff_mult)),
-                GatedCrossAttentionBlock(dim=dim, dim_head=dim_head, heads=heads) if not (ind % cross_attn_every) else None
+                GatedCrossAttentionBlock(dim=dim, dim_head=dim_head, heads=heads, only_attend_immediate_media=only_attend_immediate_media) if not (ind % cross_attn_every) else None
             ]))

         self.to_logits = nn.Sequential(
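The new flag simply threads through from the model constructor down to each cross-attention block. A minimal construction sketch, assuming the argument names shown in the repository README (everything here besides only_attend_immediate_media predates this commit, and the values are illustrative):

from flamingo_pytorch import FlamingoPaLM

model = FlamingoPaLM(
    num_tokens = 20000,                   # vocabulary size
    dim = 1024,
    depth = 12,
    heads = 8,
    dim_head = 64,
    media_token_id = 3,                   # token id marking media positions in the text
    cross_attn_every = 3,                 # gated cross attention every third layer
    perceiver_num_latents = 64,
    perceiver_depth = 2,
    only_attend_immediate_media = True    # new: attend only to the most recent media item
)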
19 changes: 15 additions & 4 deletions flamingo_pytorch/flamingo_pytorch.py
@@ -119,7 +119,8 @@ def __init__(
         *,
         dim,
         dim_head = 64,
-        heads = 8
+        heads = 8,
+        only_attend_immediate_media = False
     ):
         super().__init__()
         self.scale = dim_head ** -0.5
@@ -132,6 +133,10 @@ def __init__(
         self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
         self.to_out = nn.Linear(inner_dim, dim, bias = False)

+        # whether text attends only to the immediately preceding image, or to all previous images
+
+        self.only_attend_immediate_media = only_attend_immediate_media
+
     def forward(
         self,
         x,
@@ -156,7 +161,12 @@ def forward
         if exists(media_locations):
             text_time = media_locations.cumsum(dim = -1) # at each boolean of True, increment the time counter (relative to media time)
             media_time = torch.arange(t, device = x.device) + 1
-            text_to_media_mask = rearrange(text_time, 'b i -> b 1 i 1') >= repeat(media_time, 'j -> 1 1 1 (j m)', m = m)
+
+            # text time must equal media time if only attending to the most immediate image
+            # otherwise, text time need only be greater than or equal to media time (attending to all previous images / media)
+            mask_op = torch.eq if self.only_attend_immediate_media else torch.ge
+
+            text_to_media_mask = mask_op(rearrange(text_time, 'b i -> b 1 i 1'), repeat(media_time, 'j -> 1 1 1 (j m)', m = m))
             sim = sim.masked_fill(~text_to_media_mask, -torch.finfo(sim.dtype).max)

         sim = sim - sim.amax(dim = -1, keepdim = True).detach()
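To make the new mask concrete, here is a small standalone sketch (shapes and values are illustrative assumptions, not from the commit) showing how torch.eq versus torch.ge changes which media items each text token may attend to:

import torch
from einops import rearrange, repeat

# toy sequence of 5 text tokens; media items arrive before tokens 0 and 3
media_locations = torch.tensor([[True, False, False, True, False]])
text_time = media_locations.cumsum(dim = -1)   # tensor([[1, 1, 1, 2, 2]])

num_media, latents_per_media = 2, 1
media_time = torch.arange(num_media) + 1       # tensor([1, 2])

for mask_op in (torch.eq, torch.ge):
    mask = mask_op(
        rearrange(text_time, 'b i -> b 1 i 1'),
        repeat(media_time, 'j -> 1 1 1 (j m)', m = latents_per_media)
    )
    print(mask_op.__name__, mask.squeeze().int().tolist())

# eq -> [[1, 0], [1, 0], [1, 0], [0, 1], [0, 1]]  (immediate media only)
# ge -> [[1, 0], [1, 0], [1, 0], [1, 1], [1, 1]]  (all previous media)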
@@ -173,10 +183,11 @@ def __init__(
         dim,
         dim_head = 64,
         heads = 8,
-        ff_mult = 4
+        ff_mult = 4,
+        only_attend_immediate_media = False
     ):
         super().__init__()
-        self.attn = MaskedCrossAttention(dim = dim, dim_head = dim_head, heads = heads)
+        self.attn = MaskedCrossAttention(dim = dim, dim_head = dim_head, heads = heads, only_attend_immediate_media = only_attend_immediate_media)
         self.attn_gate = nn.Parameter(torch.tensor([0.]))

         self.ff = FeedForward(dim, mult = ff_mult)
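For context on the surrounding block: the zero-initialized attn_gate follows the Flamingo paper's tanh gating, so the cross-attention branch contributes nothing at the start of training and the frozen language model is initially unchanged. A self-contained sketch of that gating pattern (an assumption about the forward pass, which this diff does not show):

import torch
import torch.nn as nn

class GatedResidual(nn.Module):
    # sketch of tanh-gated residual wrapping, per the Flamingo paper
    def __init__(self, fn):
        super().__init__()
        self.fn = fn
        self.gate = nn.Parameter(torch.tensor([0.]))  # zero init: branch starts as a no-op

    def forward(self, x, *args, **kwargs):
        # tanh(0) == 0, so only the residual path is active at initialization
        return self.fn(x, *args, **kwargs) * self.gate.tanh() + x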
1 change: 1 addition & 0 deletions setup.py
@@ -9,6 +9,7 @@
   author = 'Phil Wang',
   author_email = '[email protected]',
   url = 'https://github.com/lucidrains/flamingo-pytorch',
+  long_description_content_type = 'text/markdown',
   keywords = [
     'artificial intelligence',
     'deep learning',
