diff --git a/src/levanter/models/attention.py b/src/levanter/models/attention.py
index 633feee68..e252bb63b 100644
--- a/src/levanter/models/attention.py
+++ b/src/levanter/models/attention.py
@@ -310,8 +310,8 @@ def _te_flash_attention(
     precision: PrecisionLike = None,
     block_size: Optional[int] = None,
 ):
-    from transformer_engine.jax.fused_attn import fused_attn  # noqa: F401
-    from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType  # noqa: F401
+    from transformer_engine.jax.attention import fused_attn  # noqa: F401
+    from transformer_engine.jax.attention import AttnBiasType, AttnMaskType  # noqa: F401
 
     attention_dtype = attention_dtype or query.dtype
     query = query.astype(attention_dtype)
@@ -402,7 +402,7 @@ def _te_flash_attention(
 
 
 def _te_materialize_mask(KPos, QPos, batch_size, mask):
-    from transformer_engine.jax.fused_attn import AttnMaskType
+    from transformer_engine.jax.attention import AttnMaskType
 
     if isinstance(mask, NamedArray):
         raise NotImplementedError(
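
Note (not part of the patch): the hunks above follow Transformer Engine's rename of its JAX fused-attention module from transformer_engine.jax.fused_attn to transformer_engine.jax.attention, keeping the fused_attn, AttnBiasType, and AttnMaskType symbols. A minimal sketch of a version-tolerant import guard, assuming those names exist unchanged under both module paths, could look like:

# Sketch only, not part of this patch: prefer the renamed module and fall
# back to the pre-rename path so both old and new Transformer Engine
# releases import cleanly.
try:
    from transformer_engine.jax.attention import AttnBiasType, AttnMaskType, fused_attn
except ImportError:
    from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType, fused_attn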