
Commit

Fix transformer-engine attention import
jennifgcrl committed Nov 8, 2024
1 parent be80580 commit 9f2530b
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions src/levanter/models/attention.py
@@ -310,8 +310,8 @@ def _te_flash_attention(
     precision: PrecisionLike = None,
     block_size: Optional[int] = None,
 ):
-    from transformer_engine.jax.fused_attn import fused_attn # noqa: F401
-    from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType # noqa: F401
+    from transformer_engine.jax.attention import fused_attn # noqa: F401
+    from transformer_engine.jax.attention import AttnBiasType, AttnMaskType # noqa: F401
 
     attention_dtype = attention_dtype or query.dtype
     query = query.astype(attention_dtype)
@@ -402,7 +402,7 @@ def _te_flash_attention(


 def _te_materialize_mask(KPos, QPos, batch_size, mask):
-    from transformer_engine.jax.fused_attn import AttnMaskType
+    from transformer_engine.jax.attention import AttnMaskType
 
     if isinstance(mask, NamedArray):
         raise NotImplementedError(
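
Note (not part of this commit): if the import needed to keep working across transformer_engine releases that expose fused attention under either the older transformer_engine.jax.fused_attn path or the newer transformer_engine.jax.attention path touched by this diff, a try/except fallback is a common pattern. The sketch below is illustrative only; the helper name _import_te_fused_attn is hypothetical and does not exist in levanter.

def _import_te_fused_attn():
    # Hypothetical compatibility shim: prefer the newer module path adopted
    # in this commit, and fall back to the pre-rename location on older
    # transformer_engine releases.
    try:
        from transformer_engine.jax.attention import AttnBiasType, AttnMaskType, fused_attn
    except ImportError:
        from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType, fused_attn
    return fused_attn, AttnBiasType, AttnMaskType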
