From 594e72039e60e521c93d90106dae6172dfa23d7e Mon Sep 17 00:00:00 2001 From: Kye Date: Tue, 26 Sep 2023 21:37:31 -0400 Subject: [PATCH] clean up --- mmca/main.py | 8 +++++--- simple_example.py | 1 - 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/mmca/main.py b/mmca/main.py index b596cf5..7e9f923 100644 --- a/mmca/main.py +++ b/mmca/main.py @@ -1,7 +1,8 @@ -import torch -from torch import nn -from einops import rearrange +import torch import torch.nn.functional as F +from einops import rearrange +from torch import nn + # from zeta.nn import FlashAttention class MultiModalCausalAttention(nn.Module): @@ -103,6 +104,7 @@ def __init__( embed_dim=dim, num_heads=heads ) + self.cross_attn = nn.MultiheadAttention( embed_dim=dim, num_heads=heads diff --git a/simple_example.py b/simple_example.py index ba52ef7..5b88c0f 100644 --- a/simple_example.py +++ b/simple_example.py @@ -1,7 +1,6 @@ import torch from mmca.main import SimpleMMCA - # Define the dimensions dim = 512 head = 8