Fix transformer-engine attention import (#795)
Renamed upstream
jennifgcrl authored Nov 8, 2024
1 parent 74b4108 commit e5deb47
Showing 1 changed file with 5 additions and 6 deletions: src/levanter/models/attention.py
@@ -310,8 +310,8 @@ def _te_flash_attention(
     precision: PrecisionLike = None,
     block_size: Optional[int] = None,
 ):
-    from transformer_engine.jax.fused_attn import fused_attn  # noqa: F401
-    from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType  # noqa: F401
+    from transformer_engine.jax.attention import fused_attn  # noqa: F401
+    from transformer_engine.jax.attention import AttnBiasType, AttnMaskType, QKVLayout  # noqa: F401
 
     attention_dtype = attention_dtype or query.dtype
     query = query.astype(attention_dtype)
@@ -358,14 +358,13 @@ def _te_flash_attention(
         raise NotImplementedError("Using bias with flash attention on GPU is not currently implemented.")
 
     attn_output = fused_attn(
-        q=q_,
-        k=k_,
-        v=v_,
+        qkv=(q_, k_, v_),
         bias=fused_attn_bias,
         mask=fused_attn_mask,
         seed=prng,
         attn_bias_type=attn_bias_type,
         attn_mask_type=attn_mask_type,
+        qkv_layout=QKVLayout.BSHD_BSHD_BSHD,
         scaling_factor=scaling_factor,
         dropout_probability=dropout,
         is_training=is_training,
@@ -402,7 +401,7 @@ def _te_flash_attention(
 
 
 def _te_materialize_mask(KPos, QPos, batch_size, mask):
-    from transformer_engine.jax.fused_attn import AttnMaskType
+    from transformer_engine.jax.attention import AttnMaskType
 
     if isinstance(mask, NamedArray):
         raise NotImplementedError(
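For context, the change only tracks an upstream module rename in Transformer Engine's JAX bindings: transformer_engine.jax.fused_attn became transformer_engine.jax.attention. A minimal compatibility sketch, not part of this commit, that tolerates either module path (the fallback branch and the availability of QKVLayout in the old module are assumptions about pre-rename releases):

# Compatibility sketch for the upstream rename of transformer_engine.jax.fused_attn
# to transformer_engine.jax.attention. Not part of this commit; the fallback branch
# assumes an older Transformer Engine release that still ships the old module.
try:
    # Post-rename location, as used in this commit.
    from transformer_engine.jax.attention import AttnBiasType, AttnMaskType, QKVLayout, fused_attn
except ImportError:
    # Pre-rename location; whether QKVLayout is importable here is an assumption.
    from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType, QKVLayout, fused_attn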
