Skip to content

Commit

Permalink
simple mmca
Browse files Browse the repository at this point in the history
  • Loading branch information
Kye committed Sep 27, 2023
1 parent 5f1d3bd commit ec5f2ad
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 1 deletion.
29 changes: 28 additions & 1 deletion mmca/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,4 +88,31 @@ def forward(
"b h n d -> b n (h d)"
)

return self.to_out(out_visual), self.to_out(out_textual)
return self.to_out(out_visual), self.to_out(out_textual)


class SimpleMMCA(nn.Module):
    """Simplified multi-modal cross attention over visual and textual tokens.

    Visual tokens are refined with self-attention; textual tokens are then
    updated as the sum of (a) attention over themselves and (b) attention
    over the refined visual tokens. Both terms reuse the same ``cross_attn``
    module, so their projection weights are shared — presumably intentional
    in this simplified variant (TODO confirm against the full MMCA).

    Args:
        dim: embedding dimension of both token streams.
        heads: number of attention heads (must divide ``dim``).
    """

    def __init__(
        self,
        dim,
        heads,
    ):
        super().__init__()

        # batch_first=True so inputs are interpreted as (batch, seq, dim),
        # matching how the example script builds its tensors. Without it,
        # nn.MultiheadAttention defaults to (seq, batch, dim) and would
        # attend across the batch dimension instead of the sequence.
        self.self_attn = nn.MultiheadAttention(
            embed_dim=dim,
            num_heads=heads,
            batch_first=True,
        )
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=dim,
            num_heads=heads,
            batch_first=True,
        )

    def forward(self, v, t):
        """Fuse visual tokens ``v`` into textual tokens ``t``.

        Args:
            v: visual tokens, shape (batch, seq_v, dim).
            t: textual tokens, shape (batch, seq_t, dim).

        Returns:
            Updated textual tokens, shape (batch, seq_t, dim).
        """
        # Self-attention for visual tokens; [0] drops the attention weights.
        v = self.self_attn(v, v, v)[0]

        # Textual self-term plus cross-attention term over the visual tokens.
        t = self.cross_attn(t, t, t)[0] + self.cross_attn(t, v, v)[0]

        return t
21 changes: 21 additions & 0 deletions simple_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import torch
from mmca.main import SimpleMMCA


# Demo configuration.
dim = 512          # embedding dimension
head = 8           # number of attention heads
seq_len = 10       # tokens per sequence
batch_size = 32    # sequences per batch

# Instantiate the attention module.
attn = SimpleMMCA(dim=dim, heads=head)

# Random visual and textual token batches, built as (batch, seq, dim).
v = torch.randn(batch_size, seq_len, dim)
t = torch.randn(batch_size, seq_len, dim)

# Run both token streams through the attention module.
tokens = attn(v, t)

print(tokens)

0 comments on commit ec5f2ad

Please sign in to comment.