diff --git a/mmca/main.py b/mmca/main.py
index 3d28e45..b596cf5 100644
--- a/mmca/main.py
+++ b/mmca/main.py
@@ -88,4 +88,31 @@ def forward(
             "b h n d -> b n (h d)"
         )
 
-        return self.to_out(out_visual), self.to_out(out_textual)
\ No newline at end of file
+        return self.to_out(out_visual), self.to_out(out_textual)
+
+
+class SimpleMMCA(nn.Module):
+    def __init__(
+        self,
+        dim,
+        heads,
+    ):
+        super().__init__()
+
+        self.self_attn = nn.MultiheadAttention(
+            embed_dim=dim,
+            num_heads=heads
+        )
+        self.cross_attn = nn.MultiheadAttention(
+            embed_dim=dim,
+            num_heads=heads
+        )
+
+    def forward(self, v, t):
+        # self-attention for the visual tokens
+        v = self.self_attn(v, v, v)[0]
+
+        # textual self-attention plus text-to-vision cross-attention
+        t = self.cross_attn(t, t, t)[0] + self.cross_attn(t, v, v)[0]
+
+        return t
diff --git a/simple_example.py b/simple_example.py
new file mode 100644
index 0000000..ba52ef7
--- /dev/null
+++ b/simple_example.py
@@ -0,0 +1,21 @@
+import torch
+from mmca.main import SimpleMMCA
+
+
+# Define the dimensions
+dim = 512
+head = 8
+seq_len = 10
+batch_size = 32
+
+# attention block
+attn = SimpleMMCA(dim=dim, heads=head)
+
+# random visual and textual tokens
+v = torch.randn(batch_size, seq_len, dim)
+t = torch.randn(batch_size, seq_len, dim)
+
+# pass the tokens through the attention block
+tokens = attn(v, t)
+
+print(tokens)
\ No newline at end of file
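
Note on tensor layout, with a minimal usage sketch: `nn.MultiheadAttention` defaults to `batch_first=False`, i.e. it interprets inputs as `(seq, batch, embed)`, while `simple_example.py` builds `(batch_size, seq_len, dim)` tensors. The sketch below is an illustrative batch-first variant of the new block; the class name `SimpleMMCABatchFirst` and the `batch_first=True` flags are assumptions made for this example and are not part of the diff.

```python
import torch
from torch import nn


class SimpleMMCABatchFirst(nn.Module):
    """Illustrative batch-first variant of the SimpleMMCA block added above.

    Visual tokens attend to themselves; textual tokens attend to themselves
    and cross-attend to the self-attended visual tokens, and the two textual
    streams are summed. Only the updated textual tokens are returned,
    mirroring the block in this diff.
    """

    def __init__(self, dim, heads):
        super().__init__()
        # batch_first=True so both modules accept (batch, seq, dim) tensors
        self.self_attn = nn.MultiheadAttention(embed_dim=dim, num_heads=heads, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(embed_dim=dim, num_heads=heads, batch_first=True)

    def forward(self, v, t):
        # self-attention over the visual tokens
        v = self.self_attn(v, v, v)[0]
        # textual self-attention plus text-to-vision cross-attention
        t = self.cross_attn(t, t, t)[0] + self.cross_attn(t, v, v)[0]
        return t


if __name__ == "__main__":
    attn = SimpleMMCABatchFirst(dim=512, heads=8)
    v = torch.randn(32, 10, 512)  # (batch_size, seq_len, dim) visual tokens
    t = torch.randn(32, 10, 512)  # (batch_size, seq_len, dim) textual tokens
    print(attn(v, t).shape)       # torch.Size([32, 10, 512])
```

With this layout the output shape matches the textual input, `(batch_size, seq_len, dim)`; under the default layout, the diff's own `SimpleMMCA` computes the same attention if its inputs are first transposed to `(seq_len, batch_size, dim)`.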