Skip to content

Commit

Permalink
Revert excess changes
Browse files Browse the repository at this point in the history
  • Loading branch information
borzunov committed Nov 9, 2021
1 parent 14eb932 commit c9f462a
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 2 deletions.
4 changes: 2 additions & 2 deletions dalle_pytorch/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def forward(self, x, mask = None, rotary_pos_emb = None):
if exists(rotary_pos_emb):
q, k, v = apply_pos_emb(rotary_pos_emb, (q, k, v))

q = q * self.scale
q *= self.scale

((q_text, q_img), (k_text, k_img), (v_text, v_img)) = map(lambda t: (t[:, :-img_seq_len], t[:, -img_seq_len:]), (q, k, v))

Expand Down Expand Up @@ -252,7 +252,7 @@ def forward(self, x, mask = None, rotary_pos_emb = None):
if exists(rotary_pos_emb):
q, k, v = apply_pos_emb(rotary_pos_emb, (q, k, v))

q = q * self.scale
q *= self.scale

((q_text, q_img), (k_text, k_img), (v_text, v_img)) = map(lambda t: (t[:, :-img_seq_len], t[:, -img_seq_len:]), (q, k, v))

Expand Down
1 change: 1 addition & 0 deletions dalle_pytorch/dalle_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,7 @@ def forward(
return loss

# main DALL-E class

class DALLE(nn.Module):
def __init__(
self,
Expand Down

0 comments on commit c9f462a

Please sign in to comment.