
Commit 9e7586d

cleanup
lucidrains committed Jan 3, 2022
1 parent dc54d3a commit 9e7586d
Showing 1 changed file with 15 additions and 20 deletions.
35 changes: 15 additions & 20 deletions nuwa_pytorch/nuwa_pytorch.py
@@ -569,6 +569,19 @@ def __init__(
 
         self.to_logits = nn.Linear(dim, num_image_tokens)
 
+    def embed_text(self, text, mask = None):
+        batch, seq_len, device = *text.shape, text.device
+        assert seq_len <= self.text_max_seq_len, 'your input text has a greater length than what was designated on initialization'
+
+        tokens = self.text_embedding(text)
+        pos_emb = self.text_pos_embedding(torch.arange(seq_len, device = device))
+        tokens = tokens + rearrange(pos_emb, 'n d -> 1 n d')
+
+        return self.text_transformer(
+            tokens,
+            mask = mask
+        )
+
     @torch.no_grad()
     @eval_decorator
     def generate(
@@ -580,16 +593,7 @@ def generate(
         temperature = 1.
     ):
         batch, seq_len, device = *text.shape, text.device
-        assert seq_len <= self.text_max_seq_len, 'your input text has a greater length than what was designated on initialization'
-
-        tokens = self.text_embedding(text)
-        pos_emb = self.text_pos_embedding(torch.arange(seq_len, device = device))
-        tokens = tokens + rearrange(pos_emb, 'n d -> 1 n d')
-
-        text_embeds = self.text_transformer(
-            tokens,
-            mask = text_mask
-        )
+        text_embeds = self.embed_text(text, mask = text_mask)
 
         bos = repeat(self.video_bos, 'd -> b 1 d', b = batch)
 
@@ -632,16 +636,7 @@ def forward(
         return_loss = False
     ):
         batch, seq_len, device = *text.shape, text.device
-        assert seq_len <= self.text_max_seq_len, 'your input text has a greater length than what was designated on initialization'
-
-        tokens = self.text_embedding(text)
-        pos_emb = self.text_pos_embedding(torch.arange(seq_len, device = device))
-        tokens = tokens + rearrange(pos_emb, 'n d -> 1 n d')
-
-        text_embeds = self.text_transformer(
-            tokens,
-            mask = text_mask
-        )
+        text_embeds = self.embed_text(text, mask = text_mask)
 
         frame_indices = self.vae.get_video_indices(video)
         frame_indices = rearrange(frame_indices, 'b ... -> b (...)')
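For context on the refactor, here is a minimal, self-contained sketch (not the repository code) of the pattern this commit applies: the text-encoding steps that were previously duplicated in generate and forward are factored into a single embed_text helper that both call. The TinyTextToVideo class, its dimensions, and the nn.Identity stand-in for the text transformer are illustrative placeholders; masking and the video branch are omitted.

import torch
from torch import nn

class TinyTextToVideo(nn.Module):
    def __init__(self, dim = 64, num_text_tokens = 256, text_max_seq_len = 128):
        super().__init__()
        self.text_max_seq_len = text_max_seq_len
        self.text_embedding = nn.Embedding(num_text_tokens, dim)
        self.text_pos_embedding = nn.Embedding(text_max_seq_len, dim)
        self.text_transformer = nn.Identity()  # placeholder for the real text encoder stack

    def embed_text(self, text, mask = None):
        batch, seq_len, device = *text.shape, text.device
        assert seq_len <= self.text_max_seq_len, 'text is longer than text_max_seq_len'

        tokens = self.text_embedding(text)
        pos_emb = self.text_pos_embedding(torch.arange(seq_len, device = device))
        tokens = tokens + pos_emb.unsqueeze(0)  # add positions, broadcast over the batch
        return self.text_transformer(tokens)    # mask handling omitted in this sketch

    @torch.no_grad()
    def generate(self, text, text_mask = None):
        # both entry points now share one text-encoding path
        return self.embed_text(text, mask = text_mask)

    def forward(self, text, text_mask = None):
        return self.embed_text(text, mask = text_mask)

# usage
model = TinyTextToVideo()
text = torch.randint(0, 256, (2, 32))
text_embeds = model(text)  # shape (2, 32, 64)

Keeping a single embedding path means the positional-embedding and sequence-length checks only have to be maintained in one place.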
