diff --git a/example.py b/example.py index efc81ff..866d9f6 100644 --- a/example.py +++ b/example.py @@ -1,4 +1,4 @@ -import torch +import torch from mmca.main import MultiModalCausalAttention @@ -7,10 +7,10 @@ x = torch.randn(1, 10, 512) y = torch.randn(1, 20, 512) -#create a mask for the text +# create a mask for the text # mask = torch.ones(1, 20).bool() x, y = attn(x, y) print(x) -# print(y) \ No newline at end of file +# print(y) diff --git a/mmca/__Init__.py b/mmca/__Init__.py index b75b9ee..b4d08bf 100644 --- a/mmca/__Init__.py +++ b/mmca/__Init__.py @@ -1 +1 @@ -from mmca.main import MultiModalCausalAttention, SimpleMMCA \ No newline at end of file +from mmca.main import MultiModalCausalAttention, SimpleMMCA diff --git a/mmca/main.py b/mmca/main.py index 3b9f1e4..1be2e00 100644 --- a/mmca/main.py +++ b/mmca/main.py @@ -3,6 +3,7 @@ from einops import rearrange from torch import nn + class MultiModalCausalAttention(nn.Module): def __init__( self, @@ -12,57 +13,48 @@ def __init__( ): super().__init__() self.heads = heads - self.scale = dim ** -0.5 + self.scale = dim**-0.5 self.to_qkv = nn.Linear(dim, dim * 3, bias=False) - - self.to_out = nn.Sequential( - nn.Linear(dim, dim), - nn.Dropout(dropout) - ) - def forward( - self, - visual_features, - textual_features, - mask=None - ): + self.to_out = nn.Sequential(nn.Linear(dim, dim), nn.Dropout(dropout)) + + def forward(self, visual_features, textual_features, mask=None): b, n, _, h = *visual_features.shape, self.heads qkv_visual = self.to_qkv(visual_features).chunk(3, dim=-1) qkv_textual = self.to_qkv(textual_features).chunk(3, dim=-1) q_visual, k_visual, v_visual = map( - lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), - qkv_visual + lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), qkv_visual ) - + q_textual, k_textual, v_textual = map( - lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), - qkv_textual + lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), qkv_textual ) - dots_visual = torch.einsum( - "bhid,bhjd->bhij", - q_visual, - k_visual - ) * self.scale + dots_visual = torch.einsum("bhid,bhjd->bhij", q_visual, k_visual) * self.scale - dots_textual = torch.einsum( - "bhid,bhjd->bhij", - q_textual, - k_textual, - ) * self.scale + dots_textual = ( + torch.einsum( + "bhid,bhjd->bhij", + q_textual, + k_textual, + ) + * self.scale + ) if mask is not None: mask = F.pad(mask.flatten(1), (1, 0), value=True) - assert mask.shape[-1] == dots_textual.shape[-1], "mask has incorrect dimensions" + assert ( + mask.shape[-1] == dots_textual.shape[-1] + ), "mask has incorrect dimensions" mask = mask[:, None, :] * mask[:, :, None] dots_textual.masked_fill(~mask, float("-inf")) del mask - + attn_visual = dots_visual.softmax(dim=-1) attn_textual = dots_textual.softmax(dim=-1) @@ -71,25 +63,19 @@ def forward( attn_visual, v_visual, ) - + out_textual = torch.einsum( "bhij,bhjd->bhid", attn_textual, v_textual, ) - out_visual = rearrange( - out_visual, - "b h n d -> b n (h d)" - ) + out_visual = rearrange(out_visual, "b h n d -> b n (h d)") - out_textual = rearrange( - out_textual, - "b h n d -> b n (h d)" - ) + out_textual = rearrange(out_textual, "b h n d -> b n (h d)") return self.to_out(out_visual), self.to_out(out_textual) - + class SimpleMMCA(nn.Module): def __init__( @@ -98,22 +84,16 @@ def __init__( heads, ): super().__init__() - - self.self_attn = nn.MultiheadAttention( - embed_dim=dim, - num_heads=heads - ) - - self.cross_attn = nn.MultiheadAttention( - embed_dim=dim, - num_heads=heads - ) - + + self.self_attn = nn.MultiheadAttention(embed_dim=dim, num_heads=heads) + + self.cross_attn = nn.MultiheadAttention(embed_dim=dim, num_heads=heads) + def forward(self, v, t): - #self attention for visual tokens + # self attention for visual tokens v = self.self_attn(v, v, v)[0] - #cross attention for textual tokens + # cross attention for textual tokens t = self.cross_attn(t, t, t)[0] + self.cross_attn(t, v, v)[0] return t diff --git a/simple_example.py b/simple_example.py index 5b88c0f..2b39271 100644 --- a/simple_example.py +++ b/simple_example.py @@ -1,4 +1,4 @@ -import torch +import torch from mmca.main import SimpleMMCA # Define the dimensions @@ -7,14 +7,14 @@ seq_len = 10 batch_size = 32 -#attn +# attn attn = SimpleMMCA(dim=dim, heads=head) -#random tokens +# random tokens v = torch.randn(batch_size, seq_len, dim) t = torch.randn(batch_size, seq_len, dim) -#pass the tokens throught attn +# pass the tokens throught attn tokens = attn(v, t) -print(tokens) \ No newline at end of file +print(tokens)