From 8ea9c30ebaeadb0a61ddf5e001753865a0bee543 Mon Sep 17 00:00:00 2001
From: Ren Pang
Date: Sun, 2 Jul 2023 17:48:03 -0700
Subject: [PATCH 1/2] Update unet.py

---
 cm/unet.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/cm/unet.py b/cm/unet.py
index e0fbbab..e6b4e53 100644
--- a/cm/unet.py
+++ b/cm/unet.py
@@ -366,7 +366,7 @@ def forward(self, qkv, attn_mask=None, key_padding_mask=None, need_weights=False
             qkv, "b (three h d) s -> b s three h d", three=3, h=self.num_heads
         )
         qkv, _ = self.inner_attn(
-            qkv,
+            qkv.contiguous(),
             key_padding_mask=key_padding_mask,
             need_weights=need_weights,
             causal=self.causal,

From da90d8b7ab517be70c311bcaffc32e2deec7c745 Mon Sep 17 00:00:00 2001
From: Ren Pang
Date: Sun, 2 Jul 2023 17:51:34 -0700
Subject: [PATCH 2/2] remove factory_kwargs

---
 cm/unet.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/cm/unet.py b/cm/unet.py
index e6b4e53..bb4d8ef 100644
--- a/cm/unet.py
+++ b/cm/unet.py
@@ -344,7 +344,7 @@ def __init__(
         from flash_attn.flash_attention import FlashAttention
 
         assert batch_first
-        factory_kwargs = {"device": device, "dtype": dtype}
+        # factory_kwargs = {"device": device, "dtype": dtype}
         super().__init__()
         self.embed_dim = embed_dim
         self.num_heads = num_heads
@@ -357,7 +357,7 @@ def __init__(
         assert self.head_dim in [16, 32, 64], "Only support head_dim == 16, 32, or 64"
 
         self.inner_attn = FlashAttention(
-            attention_dropout=attention_dropout, **factory_kwargs
+            attention_dropout=attention_dropout, # **factory_kwargs
         )
         self.rearrange = rearrange
 
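
Note (not part of the patches above): a minimal, self-contained sketch of why PATCH 1/2
passes qkv.contiguous() into the inner attention. The rearrange pattern used in
QKVFlashAttention.forward produces a permuted, non-contiguous view, and fused attention
kernels such as FlashAttention generally expect contiguous inputs. The shapes and
variable names below are illustrative only and assume torch and einops are installed;
this is not code from cm/unet.py.

    import torch
    from einops import rearrange

    # Illustrative sizes: batch, heads, head_dim, sequence length.
    b, h, d, s = 2, 8, 64, 16
    qkv = torch.randn(b, 3 * h * d, s)  # packed qkv, channels-first as in unet.py

    # Same rearrange pattern as QKVFlashAttention.forward.
    qkv = rearrange(qkv, "b (three h d) s -> b s three h d", three=3, h=h)

    print(qkv.is_contiguous())               # False: rearrange returns a strided view
    print(qkv.contiguous().is_contiguous())  # True: .contiguous() copies to dense memory

PATCH 2/2 comments out the device/dtype factory kwargs, presumably because the
FlashAttention module's constructor does not accept them, only the dropout setting
that is still passed.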