diff --git a/stable_audio_tools/models/transformer.py b/stable_audio_tools/models/transformer.py index b68118a..65965b4 100644 --- a/stable_audio_tools/models/transformer.py +++ b/stable_audio_tools/models/transformer.py @@ -276,7 +276,7 @@ def __init__( dim_context = None, causal = False, zero_init_output=True, - qk_norm = False, + qk_norm: Literal['l2', 'ln', 'none'] = 'none', natten_kernel_size = None ): super().__init__() @@ -302,6 +302,10 @@ def __init__( self.qk_norm = qk_norm + if self.qk_norm == "ln": + self.q_norm = nn.LayerNorm(dim_heads, elementwise_affine=True, eps=1.0e-6) + self.k_norm = nn.LayerNorm(dim_heads, elementwise_affine=True, eps=1.0e-6) + # Using 1d neighborhood attention self.natten_kernel_size = natten_kernel_size if natten_kernel_size is not None: @@ -416,9 +420,12 @@ def forward( q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v)) # Normalize q and k for cosine sim attention - if self.qk_norm: + if self.qk_norm == "l2": q = F.normalize(q, dim=-1) k = F.normalize(k, dim=-1) + elif self.qk_norm == "ln": + q = self.q_norm(q) + k = self.k_norm(k) if rotary_pos_emb is not None and not has_context: freqs, _ = rotary_pos_emb