diff --git a/nuwa_pytorch/nuwa_pytorch.py b/nuwa_pytorch/nuwa_pytorch.py
index c0feb5b..cba6a93 100644
--- a/nuwa_pytorch/nuwa_pytorch.py
+++ b/nuwa_pytorch/nuwa_pytorch.py
@@ -6,7 +6,6 @@
 from einops.layers.torch import Rearrange, Reduce
 
 from vector_quantize_pytorch import VectorQuantize as VQ
-from axial_positional_embedding import AxialPositionalEmbedding
 
 import torchvision
 
@@ -489,6 +488,30 @@ def forward(
 
         return self.norm(x)
 
+# positional embedding
+
+class AxialPositionalEmbedding(nn.Module):
+    def __init__(
+        self,
+        dim,
+        *,
+        shape
+    ):
+        super().__init__()
+        self.dim = dim
+        frames, height, width = shape
+
+        self.pos_frames = nn.Parameter(torch.randn(frames, dim))
+        self.pos_height = nn.Parameter(torch.randn(height, dim))
+        self.pos_width = nn.Parameter(torch.randn(width, dim))
+
+    def forward(self):
+        pos_frames = rearrange(self.pos_frames, 'f d -> f 1 1 d')
+        pos_height = rearrange(self.pos_height, 'h d -> 1 h 1 d')
+        pos_width = rearrange(self.pos_width, 'w d -> 1 1 w d')
+        positions = pos_frames + pos_height + pos_width
+        return rearrange(positions, 'f h w d -> 1 (f h w) d')
+
 # sampling helpers
 
 def top_k(logits, thres = 0.5):
@@ -548,10 +571,7 @@ def __init__(
         self.max_video_frames = max_video_frames
         video_shape = (max_video_frames, fmap_size, fmap_size)
 
-        self.video_pos_emb = AxialPositionalEmbedding(
-            dim = dim,
-            axial_shape = video_shape
-        )
+        self.video_pos_emb = AxialPositionalEmbedding(dim, shape = video_shape)
 
         self.video_transformer = Transformer(
             dim = dim,
@@ -600,9 +620,11 @@ def generate(
         video_indices = torch.empty((batch, 0), device = device, dtype = torch.long)
         total_video_tokens = self.video_fmap_size * self.video_fmap_size * self.max_video_frames
 
-        for _ in range(total_video_tokens):
+        pos_emb = self.video_pos_emb()
+
+        for ind in range(total_video_tokens):
             frame_embeddings = self.image_embedding(video_indices)
-            frame_embeddings = self.video_pos_emb(frame_embeddings) + frame_embeddings
+            frame_embeddings = pos_emb[:, :ind] + frame_embeddings
             frame_embeddings = torch.cat((bos, frame_embeddings), dim = 1)
 
             frame_embeddings = self.video_transformer(
@@ -643,7 +665,7 @@ def forward(
         frame_indices_input = frame_indices[:, :-1] if return_loss else frame_indices
 
         frame_embeddings = self.image_embedding(frame_indices_input)
-        frame_embeddings = self.video_pos_emb(frame_embeddings) + frame_embeddings
+        frame_embeddings = self.video_pos_emb()[:, :-1] + frame_embeddings
 
         bos = repeat(self.video_bos, 'd -> b 1 d', b = batch)
         frame_embeddings = torch.cat((bos, frame_embeddings), dim = 1)
diff --git a/setup.py b/setup.py
index 280a67a..6fa34b7 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'nuwa-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.3',
+  version = '0.0.4',
   license='MIT',
   description = 'NÜWA - Pytorch',
   author = 'Phil Wang',
@@ -15,7 +15,6 @@
     'transformers'
   ],
   install_requires=[
-    'axial_positional_embedding',
    'einops>=0.3',
    'torch>=1.6',
    'torchvision',
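
Quick usage sketch of the new in-repo AxialPositionalEmbedding introduced above (not part of the diff; dim = 512 and shape = (10, 32, 32) are arbitrary example values):

# minimal sketch, assuming the class stays in nuwa_pytorch/nuwa_pytorch.py as in this diff
from nuwa_pytorch.nuwa_pytorch import AxialPositionalEmbedding

pos_emb = AxialPositionalEmbedding(dim = 512, shape = (10, 32, 32))  # (frames, height, width)
positions = pos_emb()        # takes no input; returns the broadcast sum of the three axial embeddings
print(positions.shape)       # torch.Size([1, 10240, 512]) -> (1, frames * height * width, dim)

Because the module returns the full flattened position grid rather than consuming an input sequence, the callers slice it (pos_emb[:, :ind] during generation, [:, :-1] in forward) to match however many video tokens are present.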