feat: support many elementwise dynamo converters (#2263)
zewenli98 authored Sep 8, 2023
1 parent ba2a300 commit 40f8064
Showing 18 changed files with 1,394 additions and 66 deletions.
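
For orientation, here is a minimal sketch (not part of this commit; the module and shapes are hypothetical) of a model whose traced aten graph exercises several of the overloads registered below -- add.Tensor with a non-default alpha, div.Tensor_mode, pow.Tensor_Scalar, and maximum.default. Compiling a graph like this through the dynamo frontend is what the new converters target.

import torch

class ElementwiseToy(torch.nn.Module):
    # Hypothetical example; each line lowers to an aten overload covered
    # by one of the converters added in this commit.
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        a = torch.add(x, y, alpha=2)                # aten.add.Tensor (alpha != 1)
        b = torch.div(a, y, rounding_mode="floor")  # aten.div.Tensor_mode
        c = torch.pow(b, 2)                         # aten.pow.Tensor_Scalar
        return torch.maximum(c, x)                  # aten.maximum.default

x, y = torch.randn(2, 3), torch.rand(2, 3) + 1.0
print(ElementwiseToy()(x, y))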
373 changes: 315 additions & 58 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -1,16 +1,10 @@
import logging
from typing import Any, Dict, Optional, Sequence, Tuple, Union

import tensorrt as trt
import torch
from torch.fx.node import Argument, Node, Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion import impl
from torch_tensorrt.dynamo.conversion.converter_utils import (
    cast_int_int_div_trt_tensor,
    cast_trt_tensor,
)
from torch_tensorrt.fx.converters import acc_ops_converters
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor

from .converter_registry import dynamo_tensorrt_converter
@@ -48,58 +42,6 @@ def aten_ops_batch_norm(
    )


@dynamo_tensorrt_converter(torch.ops.aten.div.default)  # type: ignore[misc]
@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor_mode)  # type: ignore[misc]
@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor)  # type: ignore[misc]
def aten_ops_div(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    kwargs_new = {
        "input": args[0],
        "other": args[1],
    }
    # If both inputs are TRTTensors, cast any integer inputs to float32
    if isinstance(args[0], TRTTensor) and isinstance(args[1], TRTTensor):
        kwargs_new["input"], kwargs_new["other"] = cast_int_int_div_trt_tensor(
            network,
            kwargs_new["input"],
            kwargs_new["other"],
            name,
        )
    # If only one input is a TRTTensor and it has an integer dtype, cast it to float32
    elif isinstance(args[0], TRTTensor) and (
        kwargs_new["input"].dtype == trt.int8 or kwargs_new["input"].dtype == trt.int32
    ):
        kwargs_new["input"] = cast_trt_tensor(
            network, kwargs_new["input"], trt.float32, name, target
        )
    elif isinstance(args[1], TRTTensor) and (
        kwargs_new["other"].dtype == trt.int8 or kwargs_new["other"].dtype == trt.int32
    ):
        kwargs_new["other"] = cast_trt_tensor(
            network, kwargs_new["other"], trt.float32, name, target
        )
    rounding_mode = kwargs.get("rounding_mode")
    if rounding_mode is None:
        return acc_ops_converters.acc_ops_div(network, target, None, kwargs_new, name)
    elif rounding_mode == "floor":
        return acc_ops_converters.acc_ops_floor_div(
            network, target, None, kwargs_new, name
        )
    elif rounding_mode == "trunc":
        return impl.elementwise.trunc_div(
            network, target, SourceIR.ATEN, name, args[0], args[1]
        )
    else:
        raise RuntimeError(
            f"Target {target} does not support rounding mode {rounding_mode}"
        )


def embedding_param_validator(embedding_node: Node) -> bool:
    scale_grad_by_freq = args_bounds_check(embedding_node.args, 3)
    sparse = args_bounds_check(embedding_node.args, 4)
@@ -1004,6 +946,321 @@ def aten_ops_isinf(
    )


@dynamo_tensorrt_converter(torch.ops.aten.add.Tensor)
@dynamo_tensorrt_converter(torch.ops.aten.add.Scalar)
def aten_ops_add(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    other = args[1]
    alpha = kwargs.get("alpha", 1)

    if alpha != 1:
        other = impl.elementwise.mul(
            network,
            target,
            SourceIR.ATEN,
            name,
            other,
            alpha,
        )

    return impl.elementwise.add(
        network,
        target,
        SourceIR.ATEN,
        name,
        args[0],
        other,
    )


@dynamo_tensorrt_converter(torch.ops.aten.mul.Tensor)
@dynamo_tensorrt_converter(torch.ops.aten.mul.Scalar)
def aten_ops_mul(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    return impl.elementwise.mul(
        network,
        target,
        SourceIR.ATEN,
        name,
        args[0],
        args[1],
    )


@dynamo_tensorrt_converter(torch.ops.aten.maximum.default)
def aten_ops_max(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    return impl.elementwise.max(
        network,
        target,
        SourceIR.ATEN,
        name,
        args[0],
        args[1],
    )


@dynamo_tensorrt_converter(torch.ops.aten.minimum.default)
def aten_ops_min(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    return impl.elementwise.min(
        network,
        target,
        SourceIR.ATEN,
        name,
        args[0],
        args[1],
    )


@dynamo_tensorrt_converter(torch.ops.aten.sub.Tensor)
@dynamo_tensorrt_converter(torch.ops.aten.sub.Scalar)
def aten_ops_sub(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    other = args[1]
    alpha = kwargs.get("alpha", 1)

    if alpha != 1:
        other = impl.elementwise.mul(
            network,
            target,
            SourceIR.ATEN,
            name,
            other,
            alpha,
        )

    return impl.elementwise.sub(
        network,
        target,
        SourceIR.ATEN,
        name,
        args[0],
        other,
    )


@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor)
@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor_mode)
@dynamo_tensorrt_converter(torch.ops.aten.div.Scalar)
@dynamo_tensorrt_converter(torch.ops.aten.div.Scalar_mode)
def aten_ops_div(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    rounding_mode = kwargs.get("rounding_mode")

    if rounding_mode is None:
        return impl.elementwise.div(
            network,
            target,
            SourceIR.ATEN,
            name,
            args[0],
            args[1],
        )
    elif rounding_mode == "floor":
        return impl.elementwise.floor_divide(
            network,
            target,
            SourceIR.ATEN,
            name,
            args[0],
            args[1],
        )
    elif rounding_mode == "trunc":
        return impl.elementwise.trunc_div(
            network,
            target,
            SourceIR.ATEN,
            name,
            args[0],
            args[1],
        )
    else:
        raise RuntimeError(
            f"Target {target} does not support rounding mode {rounding_mode}"
        )


@dynamo_tensorrt_converter(torch.ops.aten.pow.Tensor_Tensor)
@dynamo_tensorrt_converter(torch.ops.aten.pow.Scalar)
@dynamo_tensorrt_converter(torch.ops.aten.pow.Tensor_Scalar)
def aten_ops_pow(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    return impl.elementwise.pow(
        network,
        target,
        SourceIR.ATEN,
        name,
        args[0],
        args[1],
    )


@dynamo_tensorrt_converter(torch.ops.aten.floor_divide.default)
@dynamo_tensorrt_converter(torch.ops.aten.floor_divide.Scalar)
def aten_ops_floor_div(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    return impl.elementwise.floor_divide(
        network,
        target,
        SourceIR.ATEN,
        name,
        args[0],
        args[1],
    )


@dynamo_tensorrt_converter(torch.ops.aten.logical_and.default)
def aten_ops_logical_and(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    return impl.elementwise.logical_and(
        network,
        target,
        SourceIR.ATEN,
        name,
        args[0],
        args[1],
    )


@dynamo_tensorrt_converter(torch.ops.aten.logical_or.default)
def aten_ops_logical_or(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    return impl.elementwise.logical_or(
        network,
        target,
        SourceIR.ATEN,
        name,
        args[0],
        args[1],
    )


@dynamo_tensorrt_converter(torch.ops.aten.logical_xor.default)
def aten_ops_logical_xor(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    return impl.elementwise.logical_xor(
        network,
        target,
        SourceIR.ATEN,
        name,
        args[0],
        args[1],
    )


@dynamo_tensorrt_converter(torch.ops.aten.eq.Tensor)
@dynamo_tensorrt_converter(torch.ops.aten.eq.Scalar)
def aten_ops_equal(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    return impl.elementwise.eq(
        network,
        target,
        SourceIR.ATEN,
        name,
        args[0],
        args[1],
    )


@dynamo_tensorrt_converter(torch.ops.aten.gt.Tensor)
@dynamo_tensorrt_converter(torch.ops.aten.gt.Scalar)
def aten_ops_greater(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    return impl.elementwise.gt(
        network,
        target,
        SourceIR.ATEN,
        name,
        args[0],
        args[1],
    )


@dynamo_tensorrt_converter(torch.ops.aten.lt.Tensor)
@dynamo_tensorrt_converter(torch.ops.aten.lt.Scalar)
def aten_ops_less(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    return impl.elementwise.lt(
        network,
        target,
        SourceIR.ATEN,
        name,
        args[0],
        args[1],
    )


def conv_param_validator(conv_node: Node) -> bool:
    return (not conv_node.args[6]) and (conv_node.args[7] in ([0], [0, 0], [0, 0, 0]))
