Commit 9d929b0

Conditionally add the sdpa decomposition only for Torch < 2.2.0. (#396)
The unconditional registration was causing an exception at startup and blocking
any other work on Torch 2.2.0 compat.
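
The fix, shown in the diff below, performs the registration only when the installed Torch predates 2.2.0. One design note: torch.__version__ is a plain string, so the < "2.2.0" comparison is lexicographic. That is adequate for the 2.x releases current at the time, but it would misorder a hypothetical "2.10.0", which sorts before "2.2.0" as a string. A minimal sketch of a parsed gate, not part of this commit (the flag name is invented here):

    # Not part of this commit: a version gate that parses torch.__version__
    # rather than comparing strings ("2.10.0" < "2.2.0" is True lexicographically).
    from packaging import version  # assumes the packaging library is available
    import torch

    NEEDS_MANUAL_SDPA_DECOMP = version.parse(torch.__version__).release < (2, 2)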
stellaraccident authored Feb 6, 2024
1 parent 5fe6bb2 commit 9d929b0
Showing 1 changed file with 54 additions and 47 deletions.
core/shark_turbine/dynamo/utils.py
@@ -10,54 +10,61 @@
 from typing import Dict, List, Tuple, Optional


-@register_decomposition(torch.ops.aten._scaled_dot_product_flash_attention.default)
-def scaled_dot_product_flash_attention(
-    query,
-    key,
-    value,
-    dropout_p: float = 0.0,
-    is_causal: bool = False,
-    return_debug_mask: bool = False,
-    *,
-    scale: float = None,
-) -> Tuple[Tensor, Tensor, Tensor, Tensor, int, int, Tensor, Tensor, Tensor]:
-    dtype = query.dtype
-    batchSize, num_head, qSize, headSize = (
-        query.shape[0],
-        query.shape[1],
-        query.shape[2],
-        query.shape[3],
-    )
+if torch.__version__ < "2.2.0":
+    # Torch versions prior to 2.2.0 lacked some decompositions, which we
+    # add manually.
+    @register_decomposition(torch.ops.aten._scaled_dot_product_flash_attention.default)
+    def scaled_dot_product_flash_attention(
+        query,
+        key,
+        value,
+        dropout_p: float = 0.0,
+        is_causal: bool = False,
+        return_debug_mask: bool = False,
+        *,
+        scale: float = None,
+    ) -> Tuple[Tensor, Tensor, Tensor, Tensor, int, int, Tensor, Tensor, Tensor]:
+        dtype = query.dtype
+        batchSize, num_head, qSize, headSize = (
+            query.shape[0],
+            query.shape[1],
+            query.shape[2],
+            query.shape[3],
+        )

-    logsumexp = torch.empty([batchSize, qSize, num_head, headSize], dtype=torch.float)
-    cum_seq_q, cum_seq_k = torch.empty([], dtype=torch.long), torch.empty(
-        [], dtype=torch.long
-    )
-    max_q, max_k = 0, 0
-    philox_seed, philox_offset = torch.empty([], dtype=torch.long), torch.empty(
-        [], dtype=torch.long
-    )
-    debug_attn_mask = torch.empty(
-        [],
-        dtype=query.dtype,
-        device="cpu",
-        requires_grad=query.requires_grad,
-    )
-    output, _ = torch.ops.aten._scaled_dot_product_attention_math.default(
-        query, key, value, None, dropout_p, is_causal, None, scale=scale
-    )
-    output = output.transpose(1, 2).contiguous(memory_format=torch.contiguous_format)
-    return (
-        output.transpose(1, 2),
-        logsumexp,
-        cum_seq_q,
-        cum_seq_k,
-        max_q,
-        max_k,
-        philox_seed,
-        philox_offset,
-        debug_attn_mask,
-    )
+        logsumexp = torch.empty(
+            [batchSize, qSize, num_head, headSize], dtype=torch.float
+        )
+        cum_seq_q, cum_seq_k = torch.empty([], dtype=torch.long), torch.empty(
+            [], dtype=torch.long
+        )
+        max_q, max_k = 0, 0
+        philox_seed, philox_offset = torch.empty([], dtype=torch.long), torch.empty(
+            [], dtype=torch.long
+        )
+        debug_attn_mask = torch.empty(
+            [],
+            dtype=query.dtype,
+            device="cpu",
+            requires_grad=query.requires_grad,
+        )
+        output, _ = torch.ops.aten._scaled_dot_product_attention_math.default(
+            query, key, value, None, dropout_p, is_causal, None, scale=scale
+        )
+        output = output.transpose(1, 2).contiguous(
+            memory_format=torch.contiguous_format
+        )
+        return (
+            output.transpose(1, 2),
+            logsumexp,
+            cum_seq_q,
+            cum_seq_k,
+            max_q,
+            max_k,
+            philox_seed,
+            philox_offset,
+            debug_attn_mask,
+        )


 # manually add decomposition to bypass the error that comes
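
For reference, a small usage sketch of the math fallback that the guarded decomposition lowers flash attention to. Shapes and values are illustrative only; the call mirrors the one made inside the decomposition (no attention mask, no dropout, non-causal, default scale):

    import torch

    # Illustrative shapes: (batch, num_heads, seq_len, head_dim).
    q = torch.randn(2, 4, 8, 16)
    k = torch.randn(2, 4, 8, 16)
    v = torch.randn(2, 4, 8, 16)

    # The same call the decomposition makes; returns (output, attention_weights).
    out, attn = torch.ops.aten._scaled_dot_product_attention_math.default(
        q, k, v, None, 0.0, False, None, scale=None
    )
    assert out.shape == (2, 4, 8, 16)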
