Skip to content

Commit

Permalink
[reland] SDPA decomposition
Browse files Browse the repository at this point in the history
  • Loading branch information
HolyWu committed Dec 22, 2024
1 parent 25075e2 commit d996bf4
Show file tree
Hide file tree
Showing 9 changed files with 533 additions and 731 deletions.
34 changes: 2 additions & 32 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -2730,40 +2730,10 @@ def aten_ops_max_pool(
)


def attention_validator(
node: Node, settings: Optional[CompilationSettings] = None
) -> bool:
# Currently, `attn_mask` is not supported
return args_bounds_check(node.args, 3) is None


@dynamo_tensorrt_converter(torch.ops.aten.reshape.default, supports_dynamic_shapes=True)
@dynamo_tensorrt_converter(
torch.nn.functional.scaled_dot_product_attention,
capability_validator=attention_validator,
supports_dynamic_shapes=True,
torch.ops.aten._reshape_copy.default, supports_dynamic_shapes=True
)
def tensorrt_scaled_dot_product_attention(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.attention.scaled_dot_product_attention(
ctx,
target,
SourceIR.TORCHTRT_LOWERED,
name,
args[0],
args[1],
args[2],
args_bounds_check(args, 5, False),
kwargs.get("scale", None),
)


@dynamo_tensorrt_converter(torch.ops.aten.reshape.default, supports_dynamic_shapes=True)
@dynamo_tensorrt_converter(torch.ops.aten.view.default, supports_dynamic_shapes=True)
@enforce_tensor_types(
{
0: (TRTTensor,),
Expand Down
1 change: 0 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/impl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
activation,
addmm,
arange,
attention,
cast,
cat,
condition,
Expand Down
165 changes: 0 additions & 165 deletions py/torch_tensorrt/dynamo/conversion/impl/attention.py

This file was deleted.

133 changes: 132 additions & 1 deletion py/torch_tensorrt/dynamo/lowering/_decompositions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
from enum import Enum, auto
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
from torch._decomp import register_decomposition
Expand Down Expand Up @@ -435,6 +435,137 @@ def full_like_decomposition(*args, **kwargs) -> torch.Tensor:
return torch.full(shape, fill_value, dtype=kwargs["dtype"], device=kwargs["device"])


@register_torch_trt_decomposition(aten.view.default, registry=TORCH_TRT_DECOMPOSITIONS)
def view_decomposition(x: torch.Tensor, size: List[torch.SymInt]) -> torch.Tensor:
return aten._reshape_copy.default(x, size)


@register_torch_trt_decomposition(
aten.scaled_dot_product_attention, registry=TORCH_TRT_DECOMPOSITIONS
)
def scaled_dot_product_attention_decomposition(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
*,
scale: Optional[float] = None,
enable_gqa: bool = False,
) -> torch.Tensor:
L, S = query.size(-2), key.size(-2)
device = query.device
attn_bias = torch.zeros(L, S, dtype=query.dtype, device=device)

if is_causal:
assert attn_mask is None, "attn_mask must be None when is_causal=True"
temp_mask = torch.ones(L, S, dtype=torch.bool, device=device).tril(diagonal=0)
attn_bias = attn_bias.masked_fill(temp_mask.logical_not(), float("-inf"))

if attn_mask is not None:
if attn_mask.dtype == torch.bool:
attn_bias = attn_bias.masked_fill(attn_mask.logical_not(), float("-inf"))
else:
attn_bias = attn_mask + attn_bias

if enable_gqa:
key = key.repeat_interleave(query.size(-3) // key.size(-3), -3)
value = value.repeat_interleave(query.size(-3) // value.size(-3), -3)

attn_weight = query @ key.transpose(-2, -1)

if scale is None:
scale = torch.sqrt(torch.scalar_tensor(query.size(-1), dtype=torch.int))
attn_weight = attn_weight / scale
else:
attn_weight = attn_weight * scale

attn_weight = attn_weight + attn_bias
attn_weight = torch.softmax(attn_weight, dim=-1)
return attn_weight @ value


@register_torch_trt_decomposition(
aten._scaled_dot_product_flash_attention, registry=TORCH_TRT_DECOMPOSITIONS
)
def scaled_dot_product_flash_attention_decomposition(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
dropout_p: float = 0.0,
is_causal: bool = False,
return_debug_mask: bool = False,
*,
scale: Optional[float] = None,
) -> Tuple[
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.SymInt,
torch.SymInt,
torch.Tensor,
torch.Tensor,
torch.Tensor,
]:
attn = scaled_dot_product_attention_decomposition(
query, key, value, None, dropout_p, is_causal, scale=scale
)
return attn, None, None, None, 0, 0, None, None, None


@register_torch_trt_decomposition(
aten._scaled_dot_product_efficient_attention, registry=TORCH_TRT_DECOMPOSITIONS
)
def scaled_dot_product_efficient_attention_decomposition(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_bias: Optional[torch.Tensor],
compute_log_sumexp: bool,
dropout_p: float = 0.0,
is_causal: bool = False,
*,
scale: Optional[float] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
attn = scaled_dot_product_attention_decomposition(
query, key, value, attn_bias, dropout_p, is_causal, scale=scale
)
return attn, None, None, None


@register_torch_trt_decomposition(
aten._scaled_dot_product_cudnn_attention, registry=TORCH_TRT_DECOMPOSITIONS
)
def scaled_dot_product_cudnn_attention_decomposition(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_bias: Optional[torch.Tensor],
compute_log_sumexp: bool,
dropout_p: float = 0.0,
is_causal: bool = False,
return_debug_mask: bool = False,
*,
scale: Optional[float] = None,
) -> Tuple[
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.SymInt,
torch.SymInt,
torch.Tensor,
torch.Tensor,
torch.Tensor,
]:
attn = scaled_dot_product_attention_decomposition(
query, key, value, attn_bias, dropout_p, is_causal, scale=scale
)
return attn, None, None, None, 0, 0, None, None, None


def get_decompositions(
enable_experimental_decompositions: bool = False,
) -> Dict[OpOverload, Callable[[Any], Any]]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,12 @@
from .accumulate_fp32_matmul import accumulate_fp32_matmul
from .constant_folding import constant_fold
from .fuse_prims_broadcast import fuse_prims_broadcast
from .lower_scaled_dot_product_attention import lower_scaled_dot_product_attention
from .pass_manager import DynamoPassManager
from .remove_assert_scalar import remove_assert_scalar
from .remove_detach import remove_detach
from .remove_input_alias_fixing_clones import remove_input_alias_fixing_clones
from .repair_input_as_output import repair_input_as_output
from .replace_max_pool_with_indices import replace_max_pool_with_indices
from .view_to_reshape import view_to_reshape

ATEN_POST_LOWERING_PASSES = DynamoPassManager.build_from_passlist(
[
Expand All @@ -23,8 +21,6 @@
repair_input_as_output,
fuse_prims_broadcast,
replace_max_pool_with_indices,
lower_scaled_dot_product_attention,
view_to_reshape,
remove_assert_scalar,
accumulate_fp32_matmul,
]
Expand Down
Loading

0 comments on commit d996bf4

Please sign in to comment.