From 9d929b006d9383afd5aeb6b367fbe5ea597da6ff Mon Sep 17 00:00:00 2001
From: Stella Laurenzo
Date: Mon, 5 Feb 2024 21:13:49 -0800
Subject: [PATCH] Conditionally add the sdpa decomposition only for Torch < 2.2.0. (#396)

This was causing an exception at startup and blocking any other work on Torch 2.2.0 compat.
---
 core/shark_turbine/dynamo/utils.py | 101 +++++++++++++++--------------
 1 file changed, 54 insertions(+), 47 deletions(-)

diff --git a/core/shark_turbine/dynamo/utils.py b/core/shark_turbine/dynamo/utils.py
index 6429c2444..14f3a07a2 100644
--- a/core/shark_turbine/dynamo/utils.py
+++ b/core/shark_turbine/dynamo/utils.py
@@ -10,54 +10,61 @@ from typing import Dict, List, Tuple, Optional