From e4196a4deb17f3fa783fb1df1819112aebf121c5 Mon Sep 17 00:00:00 2001
From: Kye
Date: Tue, 26 Sep 2023 22:06:31 -0400
Subject: [PATCH] zeta mmca

---
 mmca/main.py | 44 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 43 insertions(+), 1 deletion(-)

diff --git a/mmca/main.py b/mmca/main.py
index 7e9f923..f5ef4fe 100644
--- a/mmca/main.py
+++ b/mmca/main.py
@@ -11,7 +11,6 @@ def __init__(
         self,
         dim,
         heads=8,
         dropout=0.1,
-
     ):
         super().__init__()
         self.heads = heads
@@ -118,3 +117,46 @@ def forward(self, v, t):
         t = self.cross_attn(t, t, t)[0] + self.cross_attn(t, v, v)[0]
 
         return t
+
+
+# NOTE(review): mid-file import -- move to the top of mmca/main.py in a follow-up.
+from zeta.nn import FlashAttention
+
+
+class ZetaMMCA(nn.Module):
+    """Multi-modal causal attention built on zeta's FlashAttention.
+
+    Visual tokens self-attend; textual tokens attend to themselves and to
+    the updated visual tokens. Only the textual stream is returned.
+    """
+
+    def __init__(
+        self,
+        flash=True,
+        causal=True,
+        dropout=0.1,
+    ):
+        super().__init__()
+
+        # Separate modules for the self- and cross-attention paths.
+        self.self_attn = FlashAttention(
+            flash=flash,
+            causal=causal,
+            dropout=dropout,
+        )
+        self.cross_attn = FlashAttention(
+            flash=flash,
+            causal=causal,
+            dropout=dropout,
+        )
+
+    def forward(self, v, t):
+        # Self attention for visual tokens.
+        # NOTE(review): the [0] index assumes a tuple return like
+        # nn.MultiheadAttention; confirm zeta's FlashAttention matches.
+        v = self.self_attn(v, v, v)[0]
+
+        # Cross attention for textual tokens (text->text plus text->visual).
+        t = self.cross_attn(t, t, t)[0] + self.cross_attn(t, v, v)[0]
+
+        return t