do axial positional embedding manually

lucidrains committed Jan 3, 2022
1 parent 9e7586d commit 888f357
Showing 2 changed files with 31 additions and 10 deletions.
38 changes: 30 additions & 8 deletions nuwa_pytorch/nuwa_pytorch.py
@@ -6,7 +6,6 @@
 from einops.layers.torch import Rearrange, Reduce
 
 from vector_quantize_pytorch import VectorQuantize as VQ
-from axial_positional_embedding import AxialPositionalEmbedding
 
 import torchvision
@@ -489,6 +488,30 @@ def forward(
 
         return self.norm(x)
 
+# positional embedding
+
+class AxialPositionalEmbedding(nn.Module):
+    def __init__(
+        self,
+        dim,
+        *,
+        shape
+    ):
+        super().__init__()
+        self.dim = dim
+        frames, height, width = shape
+
+        self.pos_frames = nn.Parameter(torch.randn(frames, dim))
+        self.pos_height = nn.Parameter(torch.randn(height, dim))
+        self.pos_width = nn.Parameter(torch.randn(width, dim))
+
+    def forward(self):
+        pos_frames = rearrange(self.pos_frames, 'f d -> f 1 1 d')
+        pos_height = rearrange(self.pos_height, 'h d -> 1 h 1 d')
+        pos_width = rearrange(self.pos_width, 'w d -> 1 1 w d')
+        positions = pos_frames + pos_height + pos_width
+        return rearrange(positions, 'f h w d -> 1 (f h w) d')
+
 # sampling helpers
 
 def top_k(logits, thres = 0.5):
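The new module gives each axis (frames, height, width) its own learned table and broadcast-sums the three into one position per location, flattened to shape (1, f * h * w, dim); this costs (f + h + w) * dim parameters rather than the f * h * w * dim of a full table. A minimal sketch of that broadcasting sum, with made-up sizes rather than anything from the commit:

import torch
from einops import rearrange

frames, height, width, dim = 2, 3, 3, 4    # illustrative sizes only

pos_frames = torch.randn(frames, dim)
pos_height = torch.randn(height, dim)
pos_width = torch.randn(width, dim)

# broadcast each table across the other two axes, then sum into (f, h, w, d)
positions = (
    rearrange(pos_frames, 'f d -> f 1 1 d')
    + rearrange(pos_height, 'h d -> 1 h 1 d')
    + rearrange(pos_width, 'w d -> 1 1 w d')
)

positions = rearrange(positions, 'f h w d -> 1 (f h w) d')
print(positions.shape)    # torch.Size([1, 18, 4]), i.e. 1 x (2 * 3 * 3) x dim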
@@ -548,10 +571,7 @@ def __init__(
         self.max_video_frames = max_video_frames
         video_shape = (max_video_frames, fmap_size, fmap_size)
 
-        self.video_pos_emb = AxialPositionalEmbedding(
-            dim = dim,
-            axial_shape = video_shape
-        )
+        self.video_pos_emb = AxialPositionalEmbedding(dim, shape = video_shape)
 
         self.video_transformer = Transformer(
             dim = dim,
@@ -600,9 +620,11 @@ def generate(
         video_indices = torch.empty((batch, 0), device = device, dtype = torch.long)
         total_video_tokens = self.video_fmap_size * self.video_fmap_size * self.max_video_frames
 
-        for _ in range(total_video_tokens):
+        pos_emb = self.video_pos_emb()
+
+        for ind in range(total_video_tokens):
             frame_embeddings = self.image_embedding(video_indices)
-            frame_embeddings = self.video_pos_emb(frame_embeddings) + frame_embeddings
+            frame_embeddings = pos_emb[:, :ind] + frame_embeddings
             frame_embeddings = torch.cat((bos, frame_embeddings), dim = 1)
 
             frame_embeddings = self.video_transformer(
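In generate(), the positional sequence is now materialized once before the decoding loop, and step ind adds only the first ind positions, matching the ind tokens decoded so far; the removed line instead passed the embeddings through the old library module. A sketch of that alignment with dummy tensors and illustrative sizes, standing in for the real embedding and transformer calls:

import torch

total_video_tokens, dim = 6, 4                       # illustrative sizes only
pos_emb = torch.randn(1, total_video_tokens, dim)    # computed once, like self.video_pos_emb()

decoded = torch.empty(1, 0, dim)                     # stands in for image_embedding(video_indices)
for ind in range(total_video_tokens):
    conditioned = pos_emb[:, :ind] + decoded         # both operands are (1, ind, dim)
    # ... prepend bos, run the transformer, sample the next token ...
    decoded = torch.cat((decoded, torch.randn(1, 1, dim)), dim = 1)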
@@ -643,7 +665,7 @@ def forward(
         frame_indices_input = frame_indices[:, :-1] if return_loss else frame_indices
 
         frame_embeddings = self.image_embedding(frame_indices_input)
-        frame_embeddings = self.video_pos_emb(frame_embeddings) + frame_embeddings
+        frame_embeddings = self.video_pos_emb()[:, :-1] + frame_embeddings
 
         bos = repeat(self.video_bos, 'd -> b 1 d', b = batch)
         frame_embeddings = torch.cat((bos, frame_embeddings), dim = 1)
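In forward(), under teacher forcing the inputs drop the final token, so the precomputed positions are sliced with [:, :-1] to the same length; the BOS embedding, which carries no positional embedding in this scheme, is prepended afterwards. A shape sketch with dummy tensors and illustrative sizes:

import torch

batch, n, dim = 2, 5, 4                              # illustrative sizes only
frame_embeddings = torch.randn(batch, n - 1, dim)    # embeddings of frame_indices[:, :-1]
pos_emb = torch.randn(1, n, dim)                     # stands in for self.video_pos_emb()

frame_embeddings = pos_emb[:, :-1] + frame_embeddings    # lengths agree at n - 1
bos = torch.randn(batch, 1, dim)                         # stands in for the learned video_bos
frame_embeddings = torch.cat((bos, frame_embeddings), dim = 1)
print(frame_embeddings.shape)                            # torch.Size([2, 5, 4]), aligned with the targets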
3 changes: 1 addition & 2 deletions setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'nuwa-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.3',
+  version = '0.0.4',
   license='MIT',
   description = 'NÜWA - Pytorch',
   author = 'Phil Wang',
@@ -15,7 +15,6 @@
     'transformers'
   ],
   install_requires=[
-    'axial_positional_embedding',
     'einops>=0.3',
     'torch>=1.6',
     'torchvision',
