From 9e7586d7ca62e9bd01c340aaaa919b2c70546c7f Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Sun, 2 Jan 2022 18:25:34 -0800 Subject: [PATCH] cleanup --- nuwa_pytorch/nuwa_pytorch.py | 35 +++++++++++++++-------------------- 1 file changed, 15 insertions(+), 20 deletions(-) diff --git a/nuwa_pytorch/nuwa_pytorch.py b/nuwa_pytorch/nuwa_pytorch.py index 8fad4e3..c0feb5b 100644 --- a/nuwa_pytorch/nuwa_pytorch.py +++ b/nuwa_pytorch/nuwa_pytorch.py @@ -569,6 +569,19 @@ def __init__( self.to_logits = nn.Linear(dim, num_image_tokens) + def embed_text(self, text, mask = None): + batch, seq_len, device = *text.shape, text.device + assert seq_len <= self.text_max_seq_len, 'your input text has a greater length than what was designated on initialization' + + tokens = self.text_embedding(text) + pos_emb = self.text_pos_embedding(torch.arange(seq_len, device = device)) + tokens = tokens + rearrange(pos_emb, 'n d -> 1 n d') + + return self.text_transformer( + tokens, + mask = mask + ) + @torch.no_grad() @eval_decorator def generate( @@ -580,16 +593,7 @@ def generate( temperature = 1. ): batch, seq_len, device = *text.shape, text.device - assert seq_len <= self.text_max_seq_len, 'your input text has a greater length than what was designated on initialization' - - tokens = self.text_embedding(text) - pos_emb = self.text_pos_embedding(torch.arange(seq_len, device = device)) - tokens = tokens + rearrange(pos_emb, 'n d -> 1 n d') - - text_embeds = self.text_transformer( - tokens, - mask = text_mask - ) + text_embeds = self.embed_text(text, mask = text_mask) bos = repeat(self.video_bos, 'd -> b 1 d', b = batch) @@ -632,16 +636,7 @@ def forward( return_loss = False ): batch, seq_len, device = *text.shape, text.device - assert seq_len <= self.text_max_seq_len, 'your input text has a greater length than what was designated on initialization' - - tokens = self.text_embedding(text) - pos_emb = self.text_pos_embedding(torch.arange(seq_len, device = device)) - tokens = tokens + rearrange(pos_emb, 'n d -> 1 n d') - - text_embeds = self.text_transformer( - tokens, - mask = text_mask - ) + text_embeds = self.embed_text(text, mask = text_mask) frame_indices = self.vae.get_video_indices(video) frame_indices = rearrange(frame_indices, 'b ... -> b (...)')