Skip to content

Commit

Permalink
add support for video
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jul 20, 2022
1 parent aaa85ff commit 749f824
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 12 deletions.
42 changes: 31 additions & 11 deletions flamingo_pytorch/flamingo_palm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch
import torch.nn.functional as F
from einops import rearrange
from einops import rearrange, repeat
from torch import einsum, nn

from flamingo_pytorch.flamingo_pytorch import GatedCrossAttentionBlock, PerceiverResampler
Expand Down Expand Up @@ -210,13 +210,16 @@ def __init__(
img_encoder=None,
perceiver_num_latents=64,
perceiver_depth=2,
max_video_frames = None,
only_attend_immediate_media=True
):
super().__init__()

self.token_emb = nn.Embedding(num_tokens, dim)
self.media_token_id = media_token_id # you need to reserve a special token id for media

self.video_frame_pos_emb = nn.Parameter(torch.randn(max_video_frames, dim)) if exists(max_video_frames) else None

self.img_encoder = img_encoder
freeze_model_and_make_eval_(self.img_encoder)

Expand Down Expand Up @@ -247,12 +250,14 @@ def __init__(
def forward(
self,
text,
*,
images=None,
image_embeds=None
videos=None,
embeds=None
):
batch, device = text.shape[0], text.device

flamingo_mode = exists(images) or exists(image_embeds)
flamingo_mode = any([exists(t) for t in (images, videos, embeds)])

# automatically take care of freezing or unfreezing depending on what is passed in

Expand All @@ -271,9 +276,9 @@ def forward(

text_tokens = self.token_emb(text)

assert not (exists(images) and exists(image_embeds))
assert not (exists(embeds) and (exists(images) or exists(video)))

# encode images into embeddings
# encode videos or images into embeddings
# with the img_encoder passed in at init
# it can also accept precomputed image embeddings

Expand All @@ -282,12 +287,27 @@ def forward(
images = rearrange(images, 'b t ... -> (b t) ...')

with torch.no_grad():
image_embeds = self.img_encoder(images)
embeds = self.img_encoder(images)

embeds = rearrange(embeds, '(b t) ... -> b t ...', b = batch)

if exists(videos):
assert exists(self.img_encoder), 'img_encoder must be passed in for automatic video encoding'
batch, media, num_times, *_ = videos.shape
videos = rearrange(videos, '... c h w -> (...) c h w')

with torch.no_grad():
embeds = self.img_encoder(videos)

embeds = rearrange(embeds, '(b m t) ... -> b m t ...', b = batch, m = media, t = num_times)

video_time_pos_emb = repeat(self.video_frame_pos_emb[:num_times], 't d -> b m t n d', b = batch, m = media, n = embeds.shape[-2])
embeds = embeds + video_time_pos_emb
embeds = rearrange(embeds, 'b m t n d -> b m (t n) d')

image_embeds = rearrange(image_embeds, '(b t) ... -> b t ...', b = batch)
if exists(embeds):
embeds = self.perceiver_resampler(embeds)

if exists(image_embeds):
image_embeds = self.perceiver_resampler(image_embeds)

# go through layers

Expand All @@ -296,10 +316,10 @@ def forward(

# if image embeds exist and flamingo cross attention set for the layer
# do the cross attention
if exists(flamingo_cross_attn) and exists(image_embeds):
if exists(flamingo_cross_attn) and exists(embeds):
text_tokens = flamingo_cross_attn(
text_tokens,
image_embeds,
embeds,
media_locations = media_locations
)

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'flamingo-pytorch',
packages = find_packages(exclude=[]),
version = '0.1.0',
version = '0.1.1',
license='MIT',
description = 'Flamingo - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit 749f824

Please sign in to comment.