From 40f8064a9a300753320f22e0c527615db02e7f4b Mon Sep 17 00:00:00 2001 From: "Zewen (Evan) Li" Date: Fri, 8 Sep 2023 12:46:59 -0700 Subject: [PATCH] feat: support many elementwise dynamo converters (#2263) --- .../dynamo/conversion/aten_ops_converters.py | 373 +++++++++++++++--- .../dynamo/conversion/converter_utils.py | 11 +- .../conversion/impl/elementwise/base.py | 27 +- .../dynamo/conversion/impl/elementwise/ops.py | 242 +++++++++++- tests/py/dynamo/conversion/test_add_aten.py | 85 ++++ tests/py/dynamo/conversion/test_div_aten.py | 87 ++++ tests/py/dynamo/conversion/test_equal_aten.py | 70 ++++ .../dynamo/conversion/test_floor_div_aten.py | 67 ++++ .../py/dynamo/conversion/test_greater_aten.py | 70 ++++ tests/py/dynamo/conversion/test_less_aten.py | 70 ++++ .../conversion/test_logical_and_aten.py | 31 ++ .../dynamo/conversion/test_logical_or_aten.py | 31 ++ .../conversion/test_logical_xor_aten.py | 31 ++ tests/py/dynamo/conversion/test_max_aten.py | 31 ++ tests/py/dynamo/conversion/test_min_aten.py | 31 ++ tests/py/dynamo/conversion/test_mul_aten.py | 51 +++ tests/py/dynamo/conversion/test_pow_aten.py | 67 ++++ tests/py/dynamo/conversion/test_sub_aten.py | 85 ++++ 18 files changed, 1394 insertions(+), 66 deletions(-) create mode 100644 tests/py/dynamo/conversion/test_add_aten.py create mode 100644 tests/py/dynamo/conversion/test_div_aten.py create mode 100644 tests/py/dynamo/conversion/test_equal_aten.py create mode 100644 tests/py/dynamo/conversion/test_floor_div_aten.py create mode 100644 tests/py/dynamo/conversion/test_greater_aten.py create mode 100644 tests/py/dynamo/conversion/test_less_aten.py create mode 100644 tests/py/dynamo/conversion/test_logical_and_aten.py create mode 100644 tests/py/dynamo/conversion/test_logical_or_aten.py create mode 100644 tests/py/dynamo/conversion/test_logical_xor_aten.py create mode 100644 tests/py/dynamo/conversion/test_max_aten.py create mode 100644 tests/py/dynamo/conversion/test_min_aten.py create mode 100644 tests/py/dynamo/conversion/test_mul_aten.py create mode 100644 tests/py/dynamo/conversion/test_pow_aten.py create mode 100644 tests/py/dynamo/conversion/test_sub_aten.py diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index ca650b09f6..792f58955b 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -1,16 +1,10 @@ import logging from typing import Any, Dict, Optional, Sequence, Tuple, Union -import tensorrt as trt import torch from torch.fx.node import Argument, Node, Target from torch_tensorrt.dynamo._SourceIR import SourceIR from torch_tensorrt.dynamo.conversion import impl -from torch_tensorrt.dynamo.conversion.converter_utils import ( - cast_int_int_div_trt_tensor, - cast_trt_tensor, -) -from torch_tensorrt.fx.converters import acc_ops_converters from torch_tensorrt.fx.types import TRTNetwork, TRTTensor from .converter_registry import dynamo_tensorrt_converter @@ -48,58 +42,6 @@ def aten_ops_batch_norm( ) -@dynamo_tensorrt_converter(torch.ops.aten.div.default) # type: ignore[misc] -@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor_mode) # type: ignore[misc] -@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor) # type: ignore[misc] -def aten_ops_div( - network: TRTNetwork, - target: Target, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], - name: str, -) -> Union[TRTTensor, Sequence[TRTTensor]]: - kwargs_new = { - "input": args[0], - "other": args[1], - 
} - # If both are TRTTensor, both are cast to float32 - if isinstance(args[0], TRTTensor) and isinstance(args[1], TRTTensor): - kwargs_new["input"], kwargs_new["other"] = cast_int_int_div_trt_tensor( - network, - kwargs_new["input"], - kwargs_new["other"], - name, - ) - # If one is TRTTensor, it is cast to float32 - elif isinstance(args[0], TRTTensor) and ( - kwargs_new["input"].dtype == trt.int8 or kwargs_new["input"].dtype == trt.int32 - ): - kwargs_new["input"] = cast_trt_tensor( - network, kwargs_new["input"], trt.float32, name, target - ) - elif isinstance(args[1], TRTTensor) and ( - kwargs_new["other"].dtype == trt.int8 or kwargs_new["other"].dtype == trt.int32 - ): - kwargs_new["other"] = cast_trt_tensor( - network, kwargs_new["other"], trt.float32, name, target - ) - rounding_mode = kwargs.get("rounding_mode") - if rounding_mode is None: - return acc_ops_converters.acc_ops_div(network, target, None, kwargs_new, name) - elif rounding_mode == "floor": - return acc_ops_converters.acc_ops_floor_div( - network, target, None, kwargs_new, name - ) - elif rounding_mode == "trunc": - return impl.elementwise.trunc_div( - network, target, SourceIR.ATEN, name, args[0], args[1] - ) - else: - raise RuntimeError( - f"Target {target} does not support rounding mode {rounding_mode}" - ) - - def embedding_param_validator(embedding_node: Node) -> bool: scale_grad_by_freq = args_bounds_check(embedding_node.args, 3) sparse = args_bounds_check(embedding_node.args, 4) @@ -1004,6 +946,321 @@ def aten_ops_isinf( ) +@dynamo_tensorrt_converter(torch.ops.aten.add.Tensor) +@dynamo_tensorrt_converter(torch.ops.aten.add.Scalar) +def aten_ops_add( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + other = args[1] + alpha = kwargs.get("alpha", 1) + + if alpha != 1: + other = impl.elementwise.mul( + network, + target, + SourceIR.ATEN, + name, + other, + alpha, + ) + + return impl.elementwise.add( + network, + target, + SourceIR.ATEN, + name, + args[0], + other, + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.mul.Tensor) +@dynamo_tensorrt_converter(torch.ops.aten.mul.Scalar) +def aten_ops_mul( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.elementwise.mul( + network, + target, + SourceIR.ATEN, + name, + args[0], + args[1], + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.maximum.default) +def aten_ops_max( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.elementwise.max( + network, + target, + SourceIR.ATEN, + name, + args[0], + args[1], + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.minimum.default) +def aten_ops_min( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.elementwise.min( + network, + target, + SourceIR.ATEN, + name, + args[0], + args[1], + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.sub.Tensor) +@dynamo_tensorrt_converter(torch.ops.aten.sub.Scalar) +def aten_ops_sub( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + other = args[1] + alpha = kwargs.get("alpha", 1) + + if alpha != 1: + other = 
impl.elementwise.mul( + network, + target, + SourceIR.ATEN, + name, + other, + alpha, + ) + + return impl.elementwise.sub( + network, + target, + SourceIR.ATEN, + name, + args[0], + other, + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor) +@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor_mode) +@dynamo_tensorrt_converter(torch.ops.aten.div.Scalar) +@dynamo_tensorrt_converter(torch.ops.aten.div.Scalar_mode) +def aten_ops_div( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + rounding_mode = kwargs.get("rounding_mode") + + if rounding_mode is None: + return impl.elementwise.div( + network, + target, + SourceIR.ATEN, + name, + args[0], + args[1], + ) + elif rounding_mode == "floor": + return impl.elementwise.floor_divide( + network, + target, + SourceIR.ATEN, + name, + args[0], + args[1], + ) + elif rounding_mode == "trunc": + return impl.elementwise.trunc_div( + network, + target, + SourceIR.ATEN, + name, + args[0], + args[1], + ) + else: + raise RuntimeError( + f"Target {target} does not support rounding mode {rounding_mode}" + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.pow.Tensor_Tensor) +@dynamo_tensorrt_converter(torch.ops.aten.pow.Scalar) +@dynamo_tensorrt_converter(torch.ops.aten.pow.Tensor_Scalar) +def aten_ops_pow( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.elementwise.pow( + network, + target, + SourceIR.ATEN, + name, + args[0], + args[1], + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.floor_divide.default) +@dynamo_tensorrt_converter(torch.ops.aten.floor_divide.Scalar) +def aten_ops_floor_div( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.elementwise.floor_divide( + network, + target, + SourceIR.ATEN, + name, + args[0], + args[1], + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.logical_and.default) +def aten_ops_logical_and( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.elementwise.logical_and( + network, + target, + SourceIR.ATEN, + name, + args[0], + args[1], + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.logical_or.default) +def aten_ops_logical_or( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.elementwise.logical_or( + network, + target, + SourceIR.ATEN, + name, + args[0], + args[1], + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.logical_xor.default) +def aten_ops_logical_xor( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.elementwise.logical_xor( + network, + target, + SourceIR.ATEN, + name, + args[0], + args[1], + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.eq.Tensor) +@dynamo_tensorrt_converter(torch.ops.aten.eq.Scalar) +def aten_ops_equal( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.elementwise.eq( + network, + target, + SourceIR.ATEN, + name, + args[0], + args[1], 
+ ) + + +@dynamo_tensorrt_converter(torch.ops.aten.gt.Tensor) +@dynamo_tensorrt_converter(torch.ops.aten.gt.Scalar) +def aten_ops_greater( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.elementwise.gt( + network, + target, + SourceIR.ATEN, + name, + args[0], + args[1], + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.lt.Tensor) +@dynamo_tensorrt_converter(torch.ops.aten.lt.Scalar) +def aten_ops_less( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.elementwise.lt( + network, + target, + SourceIR.ATEN, + name, + args[0], + args[1], + ) + + def conv_param_validator(conv_node: Node) -> bool: return (not conv_node.args[6]) and (conv_node.args[7] in ([0], [0, 0], [0, 0, 0])) diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py index 6f9d7b6f1d..c5df3f9752 100644 --- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py +++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py @@ -124,9 +124,7 @@ def cast_int_int_div_trt_tensor( Returns: A list of lhs_val and rhs_val casted to the approriate datatype """ - if (lhs_val.dtype == trt.int8 or lhs_val.dtype == trt.int32) and ( - rhs_val.dtype == trt.int8 or rhs_val.dtype == trt.int32 - ): + if lhs_val.dtype == trt.int32 and rhs_val.dtype == trt.int32: lhs_val = cast_trt_tensor(network, lhs_val, trt.float32, name) rhs_val = cast_trt_tensor(network, rhs_val, trt.float32, name) return [lhs_val, rhs_val] @@ -188,3 +186,10 @@ def extend_attr_to_tuple( if isinstance(val, list): val = tuple(val) return val + + +def cast_int_or_float_to_bool(network: TRTNetwork, name: str, tensor: TRTTensor): + if tensor.dtype != trt.bool: + return cast_trt_tensor(network, tensor, trt.bool, name) + + return tensor diff --git a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py index 9ae7859fdc..46380cbec7 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py @@ -2,6 +2,7 @@ import warnings from typing import Any, Callable, Optional, Union +import numpy as np import tensorrt as trt import torch from torch.fx.node import Target @@ -24,12 +25,30 @@ def get_python_op_from_trt_elementwise_op( return operator.add elif trt_op == trt.ElementWiseOperation.PROD: return operator.mul + elif trt_op == trt.ElementWiseOperation.MAX: + return lambda a, b: max(a, b) + elif trt_op == trt.ElementWiseOperation.MIN: + return lambda a, b: min(a, b) elif trt_op == trt.ElementWiseOperation.SUB: return operator.sub elif trt_op == trt.ElementWiseOperation.DIV: return operator.truediv + elif trt_op == trt.ElementWiseOperation.POW: + return operator.pow elif trt_op == trt.ElementWiseOperation.FLOOR_DIV: return operator.floordiv + elif trt_op == trt.ElementWiseOperation.AND: + return lambda a, b: a and b + elif trt_op == trt.ElementWiseOperation.OR: + return lambda a, b: a or b + elif trt_op == trt.ElementWiseOperation.XOR: + return lambda a, b: (a or b) and not (a and b) + elif trt_op == trt.ElementWiseOperation.EQUAL: + return operator.eq + elif trt_op == trt.ElementWiseOperation.GREATER: + return operator.gt + elif trt_op == trt.ElementWiseOperation.LESS: + return operator.lt else: raise RuntimeError(f"{trt_op} is not 
supported yet!") @@ -75,10 +94,10 @@ def convert_binary_elementwise( is_rhs_trt_tensor = False if isinstance(lhs_val, TRTTensor): - lhs_dtype = unified_dtype_converter(lhs_val.dtype, Frameworks.TORCH) + lhs_dtype = unified_dtype_converter(lhs_val.dtype, Frameworks.NUMPY) is_lhs_trt_tensor = True if isinstance(rhs_val, TRTTensor): - rhs_dtype = unified_dtype_converter(rhs_val.dtype, Frameworks.TORCH) + rhs_dtype = unified_dtype_converter(rhs_val.dtype, Frameworks.NUMPY) is_rhs_trt_tensor = True if not is_lhs_trt_tensor and not is_rhs_trt_tensor: @@ -103,9 +122,9 @@ def convert_binary_elementwise( # dtype but we don't have a way to detect whether it makes sense for the # scalar to be float or half. Hence we go with the lhs dtype. if is_lhs_trt_tensor and isinstance(rhs_val, (float, int)): - rhs_val = torch.tensor([rhs_val], dtype=lhs_dtype) + rhs_val = np.array([rhs_val], dtype=lhs_dtype) if is_rhs_trt_tensor and isinstance(lhs_val, (float, int)): - lhs_val = torch.tensor([lhs_val], dtype=rhs_dtype) + lhs_val = np.array([lhs_val], dtype=rhs_dtype) # When lhs is scalar, and rhs has shape [1,], then currently the assert # will fail because lhs shape has fewer dimensions than rhs shape. This diff --git a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py index 3470328e44..f5d46efc17 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py @@ -1,9 +1,13 @@ -from typing import Optional +from typing import Optional, Union import numpy as np import tensorrt as trt from torch.fx.node import Target from torch_tensorrt.dynamo._SourceIR import SourceIR +from torch_tensorrt.dynamo.conversion.converter_utils import ( + cast_int_int_div_trt_tensor, + cast_int_or_float_to_bool, +) from torch_tensorrt.dynamo.conversion.impl.elementwise.base import ( convert_binary_elementwise, ) @@ -239,3 +243,239 @@ def _add_layer( input_val = clamp_max_layer.get_output(0) return input_val + + +def add( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + lhs_val: Union[TRTTensor, int, float], + rhs_val: Union[TRTTensor, int, float], +) -> TRTTensor: + return convert_binary_elementwise( + network, target, source_ir, name, trt.ElementWiseOperation.SUM, lhs_val, rhs_val + ) + + +def mul( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + lhs_val: Union[TRTTensor, int, float], + rhs_val: Union[TRTTensor, int, float], +) -> TRTTensor: + return convert_binary_elementwise( + network, + target, + source_ir, + name, + trt.ElementWiseOperation.PROD, + lhs_val, + rhs_val, + ) + + +def max( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + lhs_val: Union[TRTTensor, int, float], + rhs_val: Union[TRTTensor, int, float], +) -> TRTTensor: + return convert_binary_elementwise( + network, target, source_ir, name, trt.ElementWiseOperation.MAX, lhs_val, rhs_val + ) + + +def min( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + lhs_val: Union[TRTTensor, int, float], + rhs_val: Union[TRTTensor, int, float], +) -> TRTTensor: + return convert_binary_elementwise( + network, target, source_ir, name, trt.ElementWiseOperation.MIN, lhs_val, rhs_val + ) + + +def sub( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + lhs_val: Union[TRTTensor, int, float], + rhs_val: Union[TRTTensor, int, float], +) -> TRTTensor: + 
return convert_binary_elementwise( + network, target, source_ir, name, trt.ElementWiseOperation.SUB, lhs_val, rhs_val + ) + + +def div( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + lhs_val: Union[TRTTensor, int, float], + rhs_val: Union[TRTTensor, int, float], +) -> TRTTensor: + if isinstance(lhs_val, TRTTensor) and isinstance(rhs_val, TRTTensor): + lhs_val, rhs_val = cast_int_int_div_trt_tensor(network, lhs_val, rhs_val, name) + + return convert_binary_elementwise( + network, target, source_ir, name, trt.ElementWiseOperation.DIV, lhs_val, rhs_val + ) + + +def pow( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + lhs_val: Union[TRTTensor, int, float], + rhs_val: Union[TRTTensor, int, float], +) -> TRTTensor: + if isinstance(lhs_val, TRTTensor) and isinstance(rhs_val, TRTTensor): + lhs_val, rhs_val = cast_int_int_div_trt_tensor(network, lhs_val, rhs_val, name) + + return convert_binary_elementwise( + network, target, source_ir, name, trt.ElementWiseOperation.POW, lhs_val, rhs_val + ) + + +def floor_divide( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + lhs_val: Union[TRTTensor, int, float], + rhs_val: Union[TRTTensor, int, float], +) -> TRTTensor: + return convert_binary_elementwise( + network, + target, + source_ir, + name, + trt.ElementWiseOperation.FLOOR_DIV, + lhs_val, + rhs_val, + ) + + +def logical_and( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + lhs_val: Union[TRTTensor, int, float], + rhs_val: Union[TRTTensor, int, float], +) -> TRTTensor: + if isinstance(lhs_val, TRTTensor): + lhs_val = cast_int_or_float_to_bool(network, name, lhs_val) + + if isinstance(rhs_val, TRTTensor): + rhs_val = cast_int_or_float_to_bool(network, name, rhs_val) + + return convert_binary_elementwise( + network, target, source_ir, name, trt.ElementWiseOperation.AND, lhs_val, rhs_val + ) + + +def logical_or( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + lhs_val: Union[TRTTensor, int, float], + rhs_val: Union[TRTTensor, int, float], +) -> TRTTensor: + if isinstance(lhs_val, TRTTensor): + lhs_val = cast_int_or_float_to_bool(network, name, lhs_val) + + if isinstance(rhs_val, TRTTensor): + rhs_val = cast_int_or_float_to_bool(network, name, rhs_val) + + return convert_binary_elementwise( + network, target, source_ir, name, trt.ElementWiseOperation.OR, lhs_val, rhs_val + ) + + +def logical_xor( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + lhs_val: Union[TRTTensor, int, float], + rhs_val: Union[TRTTensor, int, float], +) -> TRTTensor: + if isinstance(lhs_val, TRTTensor): + lhs_val = cast_int_or_float_to_bool(network, name, lhs_val) + + if isinstance(rhs_val, TRTTensor): + rhs_val = cast_int_or_float_to_bool(network, name, rhs_val) + + return convert_binary_elementwise( + network, target, source_ir, name, trt.ElementWiseOperation.XOR, lhs_val, rhs_val + ) + + +def eq( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + lhs_val: Union[TRTTensor, int, float], + rhs_val: Union[TRTTensor, int, float], +) -> TRTTensor: + return convert_binary_elementwise( + network, + target, + source_ir, + name, + trt.ElementWiseOperation.EQUAL, + lhs_val, + rhs_val, + ) + + +def gt( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + lhs_val: Union[TRTTensor, int, float], + rhs_val: Union[TRTTensor, int, float], 
+) -> TRTTensor: + return convert_binary_elementwise( + network, + target, + source_ir, + name, + trt.ElementWiseOperation.GREATER, + lhs_val, + rhs_val, + ) + + +def lt( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + lhs_val: Union[TRTTensor, int, float], + rhs_val: Union[TRTTensor, int, float], +) -> TRTTensor: + return convert_binary_elementwise( + network, + target, + source_ir, + name, + trt.ElementWiseOperation.LESS, + lhs_val, + rhs_val, + ) diff --git a/tests/py/dynamo/conversion/test_add_aten.py b/tests/py/dynamo/conversion/test_add_aten.py new file mode 100644 index 0000000000..b9fec820c6 --- /dev/null +++ b/tests/py/dynamo/conversion/test_add_aten.py @@ -0,0 +1,85 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt import Input + +from .harness import DispatchTestCase + + +class TestAddConverter(DispatchTestCase): + @parameterized.expand( + [ + ("2d", (2, 1)), + ("3d", (2, 1, 2)), + ] + ) + def test_add_tensor(self, _, shape): + class add(nn.Module): + def forward(self, lhs_val, rhs_val): + return torch.add(lhs_val, rhs_val) + + inputs = [torch.randn(shape), torch.randn(shape)] + self.run_test( + add(), + inputs, + expected_ops={torch.ops.aten.add.Tensor}, + ) + + @parameterized.expand( + [ + ("2d", (2, 1), 1), + ("3d", (2, 1, 2), 2.0), + ] + ) + def test_add_tensor_alpha(self, _, shape, alpha): + class add(nn.Module): + def forward(self, lhs_val, rhs_val): + return torch.add(lhs_val, rhs_val, alpha=alpha) + + inputs = [torch.randn(shape), torch.randn(shape)] + self.run_test( + add(), + inputs, + expected_ops={torch.ops.aten.add.Tensor}, + ) + + @parameterized.expand( + [ + ("2d", (2, 1), 1.0), + ("3d", (2, 1, 2), 2), + ] + ) + def test_add_scalar(self, _, shape, scalar): + class add(nn.Module): + def forward(self, lhs_val): + return torch.add(lhs_val, scalar) + + inputs = [torch.randn(shape)] + self.run_test( + add(), + inputs, + expected_ops={torch.ops.aten.add.Tensor}, + ) + + @parameterized.expand( + [ + ("2d", (2, 1), 1.0, 1.0), + ("3d", (2, 1, 2), 2, 2), + ] + ) + def test_add_scalar_alpha(self, _, shape, scalar, alpha): + class add(nn.Module): + def forward(self, lhs_val): + return torch.add(lhs_val, scalar, alpha=alpha) + + inputs = [torch.randn(shape)] + self.run_test( + add(), + inputs, + expected_ops={torch.ops.aten.add.Tensor}, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/conversion/test_div_aten.py b/tests/py/dynamo/conversion/test_div_aten.py new file mode 100644 index 0000000000..49a13ea3a6 --- /dev/null +++ b/tests/py/dynamo/conversion/test_div_aten.py @@ -0,0 +1,87 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt import Input + +from .harness import DispatchTestCase + + +class TestDivConverter(DispatchTestCase): + @parameterized.expand( + [ + ("2d", (2, 1)), + ("3d", (2, 1, 2)), + ] + ) + def test_div_tensor(self, _, shape): + class div(nn.Module): + def forward(self, lhs_val, rhs_val): + return torch.div(lhs_val, rhs_val) + + inputs = [torch.randn(shape), torch.randn(shape)] + self.run_test( + div(), + inputs, + expected_ops={torch.ops.aten.div.Tensor}, + ) + + @parameterized.expand( + [ + ("2d", (2, 1), None), + ("3d", (2, 1, 2), "trunc"), + ("3d", (2, 3, 2), "floor"), + ] + ) + def test_div_tensor_rounding_mode(self, _, shape, rounding_mode): + class div(nn.Module): 
+ def forward(self, lhs_val, rhs_val): + return torch.div(lhs_val, rhs_val, rounding_mode=rounding_mode) + + inputs = [torch.randn(shape), torch.randn(shape)] + self.run_test( + div(), + inputs, + expected_ops={torch.ops.aten.div.Tensor_mode}, + ) + + @parameterized.expand( + [ + ("2d", (2, 1), -1), + ("3d", (2, 1, 2), 2.0), + ] + ) + def test_div_tensor(self, _, shape, scalar): + class div(nn.Module): + def forward(self, lhs_val): + return torch.div(lhs_val, scalar) + + inputs = [torch.randn(shape)] + self.run_test( + div(), + inputs, + expected_ops={torch.ops.aten.div.Tensor}, + ) + + @parameterized.expand( + [ + ("2d", (2, 1), 1.0, None), + ("3d", (2, 1, 2), 2.0, "trunc"), + ("3d", (2, 3, 2), -3.0, "floor"), + ] + ) + def test_div_tensor_rounding_mode(self, _, shape, scalar, rounding_mode): + class div(nn.Module): + def forward(self, lhs_val): + return torch.div(lhs_val, scalar, rounding_mode=rounding_mode) + + inputs = [torch.randn(shape)] + self.run_test( + div(), + inputs, + expected_ops={torch.ops.aten.div.Tensor_mode}, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/conversion/test_equal_aten.py b/tests/py/dynamo/conversion/test_equal_aten.py new file mode 100644 index 0000000000..edc2259487 --- /dev/null +++ b/tests/py/dynamo/conversion/test_equal_aten.py @@ -0,0 +1,70 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt import Input + +from .harness import DispatchTestCase + + +class TestEqualConverter(DispatchTestCase): + @parameterized.expand( + [ + ("2d", (2, 1)), + ("3d", (2, 1, 2)), + ] + ) + def test_equal_tensor(self, _, shape): + class equal(nn.Module): + def forward(self, lhs_val, rhs_val): + return lhs_val == rhs_val + + inputs = [torch.randn(shape), torch.randn(shape)] + self.run_test( + equal(), + inputs, + expected_ops={torch.ops.aten.eq.Tensor}, + output_dtypes=[torch.bool], + ) + + @parameterized.expand( + [ + ("2d", (2, 1), 1), + ("3d", (2, 1, 2), 2.0), + ] + ) + def test_equal_tensor_scalar(self, _, shape, scalar): + class equal(nn.Module): + def forward(self, lhs_val): + return lhs_val == torch.tensor(scalar) + + inputs = [torch.randn(shape)] + self.run_test( + equal(), + inputs, + expected_ops={torch.ops.aten.eq.Tensor}, + output_dtypes=[torch.bool], + ) + + @parameterized.expand( + [ + ("2d", (2, 1), 1), + ("3d", (2, 1, 2), 2.0), + ] + ) + def test_equal_scalar(self, _, shape, scalar): + class equal(nn.Module): + def forward(self, lhs_val): + return lhs_val == scalar + + inputs = [torch.randn(shape)] + self.run_test( + equal(), + inputs, + expected_ops={torch.ops.aten.eq.Scalar}, + output_dtypes=[torch.bool], + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/conversion/test_floor_div_aten.py b/tests/py/dynamo/conversion/test_floor_div_aten.py new file mode 100644 index 0000000000..329e8bca8a --- /dev/null +++ b/tests/py/dynamo/conversion/test_floor_div_aten.py @@ -0,0 +1,67 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt import Input + +from .harness import DispatchTestCase + + +class TestFloorDivConverter(DispatchTestCase): + @parameterized.expand( + [ + ("2d", (2, 1)), + ("3d", (2, 1, 2)), + ] + ) + def test_floor_div_default(self, _, shape): + class floor_div(nn.Module): + def forward(self, lhs_val, rhs_val): + return torch.floor_divide(lhs_val, rhs_val) + + inputs = 
[torch.randn(shape), torch.randn(shape)] + self.run_test( + floor_div(), + inputs, + expected_ops={torch.ops.aten.floor_divide.default}, + ) + + @parameterized.expand( + [ + ("2d", (2, 1), 1), + ("3d", (2, 1, 2), 2.0), + ] + ) + def test_floor_div_tensor_scalar(self, _, shape, scalar): + class floor_div(nn.Module): + def forward(self, lhs_val): + return torch.floor_divide(lhs_val, torch.tensor(scalar)) + + inputs = [torch.randn(shape)] + self.run_test( + floor_div(), + inputs, + expected_ops={torch.ops.aten.floor_divide.default}, + ) + + @parameterized.expand( + [ + ("2d", (2, 1), 1), + ("3d", (2, 1, 2), 2.0), + ] + ) + def test_floor_div_scalar(self, _, shape, scalar): + class floor_div(nn.Module): + def forward(self, lhs_val): + return torch.floor_divide(lhs_val, scalar) + + inputs = [torch.randn(shape)] + self.run_test( + floor_div(), + inputs, + expected_ops={torch.ops.aten.floor_divide.default}, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/conversion/test_greater_aten.py b/tests/py/dynamo/conversion/test_greater_aten.py new file mode 100644 index 0000000000..d677c1583f --- /dev/null +++ b/tests/py/dynamo/conversion/test_greater_aten.py @@ -0,0 +1,70 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt import Input + +from .harness import DispatchTestCase + + +class TestGreaterConverter(DispatchTestCase): + @parameterized.expand( + [ + ("2d", (2, 1)), + ("3d", (2, 1, 2)), + ] + ) + def test_greater_tensor(self, _, shape): + class greater(nn.Module): + def forward(self, lhs_val, rhs_val): + return lhs_val > rhs_val + + inputs = [torch.randn(shape), torch.randn(shape)] + self.run_test( + greater(), + inputs, + expected_ops={torch.ops.aten.gt.Tensor}, + output_dtypes=[torch.bool], + ) + + @parameterized.expand( + [ + ("2d", (2, 1), 1), + ("3d", (2, 1, 2), 2.0), + ] + ) + def test_greater_tensor_scalar(self, _, shape, scalar): + class greater(nn.Module): + def forward(self, lhs_val): + return lhs_val > torch.tensor(scalar) + + inputs = [torch.randn(shape)] + self.run_test( + greater(), + inputs, + expected_ops={torch.ops.aten.gt.Tensor}, + output_dtypes=[torch.bool], + ) + + @parameterized.expand( + [ + ("2d", (2, 1), 1), + ("3d", (2, 1, 2), 2.0), + ] + ) + def test_greater_scalar(self, _, shape, scalar): + class greater(nn.Module): + def forward(self, lhs_val): + return lhs_val > scalar + + inputs = [torch.randn(shape)] + self.run_test( + greater(), + inputs, + expected_ops={torch.ops.aten.gt.Scalar}, + output_dtypes=[torch.bool], + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/conversion/test_less_aten.py b/tests/py/dynamo/conversion/test_less_aten.py new file mode 100644 index 0000000000..35efb38791 --- /dev/null +++ b/tests/py/dynamo/conversion/test_less_aten.py @@ -0,0 +1,70 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt import Input + +from .harness import DispatchTestCase + + +class TestLessConverter(DispatchTestCase): + @parameterized.expand( + [ + ("2d", (2, 1)), + ("3d", (2, 1, 2)), + ] + ) + def test_less_tensor(self, _, shape): + class less(nn.Module): + def forward(self, lhs_val, rhs_val): + return lhs_val < rhs_val + + inputs = [torch.randn(shape), torch.randn(shape)] + self.run_test( + less(), + inputs, + expected_ops={torch.ops.aten.lt.Tensor}, + output_dtypes=[torch.bool], + ) + + 
@parameterized.expand( + [ + ("2d", (2, 1), 1), + ("3d", (2, 1, 2), 2.0), + ] + ) + def test_less_tensor_scalar(self, _, shape, scalar): + class less(nn.Module): + def forward(self, lhs_val): + return lhs_val < torch.tensor(scalar) + + inputs = [torch.randn(shape)] + self.run_test( + less(), + inputs, + expected_ops={torch.ops.aten.lt.Tensor}, + output_dtypes=[torch.bool], + ) + + @parameterized.expand( + [ + ("2d", (2, 1), 1), + ("3d", (2, 1, 2), 2.0), + ] + ) + def test_less_scalar(self, _, shape, scalar): + class less(nn.Module): + def forward(self, lhs_val): + return lhs_val < scalar + + inputs = [torch.randn(shape)] + self.run_test( + less(), + inputs, + expected_ops={torch.ops.aten.lt.Scalar}, + output_dtypes=[torch.bool], + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/conversion/test_logical_and_aten.py b/tests/py/dynamo/conversion/test_logical_and_aten.py new file mode 100644 index 0000000000..b9c1f383ba --- /dev/null +++ b/tests/py/dynamo/conversion/test_logical_and_aten.py @@ -0,0 +1,31 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt import Input + +from .harness import DispatchTestCase + + +class TestLogicalAndConverter(DispatchTestCase): + @parameterized.expand( + [ + ("2d", (2, 1)), + ("3d", (2, 1, 2)), + ] + ) + def test_logical_and(self, _, shape): + class logical_and(nn.Module): + def forward(self, lhs_val, rhs_val): + return torch.logical_and(lhs_val, rhs_val) + + inputs = [torch.randn(shape), torch.randn(shape)] + self.run_test( + logical_and(), + inputs, + expected_ops={torch.ops.aten.logical_and.default}, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/conversion/test_logical_or_aten.py b/tests/py/dynamo/conversion/test_logical_or_aten.py new file mode 100644 index 0000000000..df8e577932 --- /dev/null +++ b/tests/py/dynamo/conversion/test_logical_or_aten.py @@ -0,0 +1,31 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt import Input + +from .harness import DispatchTestCase + + +class TestLogicalOrConverter(DispatchTestCase): + @parameterized.expand( + [ + ("2d", (2, 1)), + ("3d", (2, 1, 2)), + ] + ) + def test_logical_or(self, _, shape): + class logical_or(nn.Module): + def forward(self, lhs_val, rhs_val): + return torch.logical_or(lhs_val, rhs_val) + + inputs = [torch.randn(shape), torch.randn(shape)] + self.run_test( + logical_or(), + inputs, + expected_ops={torch.ops.aten.logical_or.default}, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/conversion/test_logical_xor_aten.py b/tests/py/dynamo/conversion/test_logical_xor_aten.py new file mode 100644 index 0000000000..c31a31541d --- /dev/null +++ b/tests/py/dynamo/conversion/test_logical_xor_aten.py @@ -0,0 +1,31 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt import Input + +from .harness import DispatchTestCase + + +class TestLogicalXorConverter(DispatchTestCase): + @parameterized.expand( + [ + ("2d", (2, 1)), + ("3d", (2, 1, 2)), + ] + ) + def test_logical_xor(self, _, shape): + class logical_xor(nn.Module): + def forward(self, lhs_val, rhs_val): + return torch.logical_xor(lhs_val, rhs_val) + + inputs = [torch.randn(shape), torch.randn(shape)] + self.run_test( + logical_xor(), + inputs, + 
expected_ops={torch.ops.aten.logical_xor.default}, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/conversion/test_max_aten.py b/tests/py/dynamo/conversion/test_max_aten.py new file mode 100644 index 0000000000..2be1d9c74b --- /dev/null +++ b/tests/py/dynamo/conversion/test_max_aten.py @@ -0,0 +1,31 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt import Input + +from .harness import DispatchTestCase + + +class TestMaxConverter(DispatchTestCase): + @parameterized.expand( + [ + ("2d", (2, 1)), + ("3d", (2, 1, 2)), + ] + ) + def test_max(self, _, shape): + class max(nn.Module): + def forward(self, lhs_val, rhs_val): + return torch.max(lhs_val, rhs_val) + + inputs = [torch.randn(shape), torch.randn(shape)] + self.run_test( + max(), + inputs, + expected_ops={torch.ops.aten.maximum.default}, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/conversion/test_min_aten.py b/tests/py/dynamo/conversion/test_min_aten.py new file mode 100644 index 0000000000..35d0d7163f --- /dev/null +++ b/tests/py/dynamo/conversion/test_min_aten.py @@ -0,0 +1,31 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt import Input + +from .harness import DispatchTestCase + + +class TestMinConverter(DispatchTestCase): + @parameterized.expand( + [ + ("2d", (2, 1)), + ("3d", (2, 1, 2)), + ] + ) + def test_min(self, _, shape): + class min(nn.Module): + def forward(self, lhs_val, rhs_val): + return torch.min(lhs_val, rhs_val) + + inputs = [torch.randn(shape), torch.randn(shape)] + self.run_test( + min(), + inputs, + expected_ops={torch.ops.aten.minimum.default}, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/conversion/test_mul_aten.py b/tests/py/dynamo/conversion/test_mul_aten.py new file mode 100644 index 0000000000..fecd1e06f4 --- /dev/null +++ b/tests/py/dynamo/conversion/test_mul_aten.py @@ -0,0 +1,51 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt import Input + +from .harness import DispatchTestCase + + +class TestMulConverter(DispatchTestCase): + @parameterized.expand( + [ + ("2d", (2, 1)), + ("3d", (2, 1, 2)), + ] + ) + def test_mul_tensor(self, _, shape): + class mul(nn.Module): + def forward(self, lhs_val, rhs_val): + return torch.mul(lhs_val, rhs_val) + + inputs = [torch.randn(shape), torch.randn(shape)] + self.run_test( + mul(), + inputs, + expected_ops={torch.ops.aten.mul.Tensor}, + ) + + @parameterized.expand( + [ + ("2d_int", (2, 1), 1), + ("3d_int", (2, 1, 2), 2), + ("2d_float", (2, 1), 1.0), + ("3d_float", (2, 1, 2), 2.0), + ] + ) + def test_mul_scalar(self, _, shape, scalar): + class mul(nn.Module): + def forward(self, lhs_val): + return torch.mul(lhs_val, scalar) + + inputs = [torch.randn(shape)] + self.run_test( + mul(), + inputs, + expected_ops={torch.ops.aten.mul.Tensor}, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/conversion/test_pow_aten.py b/tests/py/dynamo/conversion/test_pow_aten.py new file mode 100644 index 0000000000..29dd74eb07 --- /dev/null +++ b/tests/py/dynamo/conversion/test_pow_aten.py @@ -0,0 +1,67 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import 
run_tests +from torch_tensorrt import Input + +from .harness import DispatchTestCase + + +class TestPowConverter(DispatchTestCase): + @parameterized.expand( + [ + ("2d", (2, 1)), + ("3d", (2, 1, 2)), + ] + ) + def test_pow_tensor_tensor(self, _, shape): + class pow(nn.Module): + def forward(self, lhs_val, rhs_val): + return torch.pow(lhs_val, rhs_val) + + inputs = [torch.randn(shape), torch.randn(shape)] + self.run_test( + pow(), + inputs, + expected_ops={torch.ops.aten.pow.Tensor_Tensor}, + ) + + @parameterized.expand( + [ + ("2d", (2, 1), 1), + ("3d", (2, 1, 2), 2.0), + ] + ) + def test_pow_scalar(self, _, shape, scalar): + class pow(nn.Module): + def forward(self, rhs_val): + return torch.pow(scalar, rhs_val) + + inputs = [torch.randn(shape)] + self.run_test( + pow(), + inputs, + expected_ops={torch.ops.aten.pow.Scalar}, + ) + + @parameterized.expand( + [ + ("2d", (2, 1), 1), + ("3d", (2, 1, 2), 2.0), + ] + ) + def test_pow_tensor_scalar(self, _, shape, scalar): + class pow(nn.Module): + def forward(self, lhs_val): + return torch.pow(lhs_val, scalar) + + inputs = [torch.randn(shape)] + self.run_test( + pow(), + inputs, + expected_ops={torch.ops.aten.pow.Tensor_Scalar}, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/conversion/test_sub_aten.py b/tests/py/dynamo/conversion/test_sub_aten.py new file mode 100644 index 0000000000..1ad7e340e3 --- /dev/null +++ b/tests/py/dynamo/conversion/test_sub_aten.py @@ -0,0 +1,85 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt import Input + +from .harness import DispatchTestCase + + +class TestSubConverter(DispatchTestCase): + @parameterized.expand( + [ + ("2d", (2, 1)), + ("3d", (2, 1, 2)), + ] + ) + def test_sub_tensor(self, _, shape): + class sub(nn.Module): + def forward(self, lhs_val, rhs_val): + return torch.sub(lhs_val, rhs_val) + + inputs = [torch.randn(shape), torch.randn(shape)] + self.run_test( + sub(), + inputs, + expected_ops={torch.ops.aten.sub.Tensor}, + ) + + @parameterized.expand( + [ + ("2d", (2, 1), 1), + ("3d", (2, 1, 2), 2.0), + ] + ) + def test_sub_tensor_alpha(self, _, shape, alpha): + class sub(nn.Module): + def forward(self, lhs_val, rhs_val): + return torch.sub(lhs_val, rhs_val, alpha=alpha) + + inputs = [torch.randn(shape), torch.randn(shape)] + self.run_test( + sub(), + inputs, + expected_ops={torch.ops.aten.sub.Tensor}, + ) + + @parameterized.expand( + [ + ("2d", (2, 1), 1.0), + ("3d", (2, 1, 2), 2), + ] + ) + def test_sub_scalar(self, _, shape, scalar): + class sub(nn.Module): + def forward(self, lhs_val): + return torch.sub(lhs_val, scalar) + + inputs = [torch.randn(shape)] + self.run_test( + sub(), + inputs, + expected_ops={torch.ops.aten.sub.Tensor}, + ) + + @parameterized.expand( + [ + ("2d", (2, 1), 1.0, 1.0), + ("3d", (2, 1, 2), 2, 2), + ] + ) + def test_sub_scalar_alpha(self, _, shape, scalar, alpha): + class sub(nn.Module): + def forward(self, lhs_val): + return torch.sub(lhs_val, scalar, alpha=alpha) + + inputs = [torch.randn(shape)] + self.run_test( + sub(), + inputs, + expected_ops={torch.ops.aten.sub.Tensor}, + ) + + +if __name__ == "__main__": + run_tests()
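
Note (illustrative sketch, not part of the patch): the aten_ops_add and aten_ops_sub converters above handle a non-default alpha by scaling the second operand with impl.elementwise.mul before emitting the add/sub. The PyTorch identity this lowering relies on is torch.add(a, b, alpha=k) == a + k * b (and torch.sub(a, b, alpha=k) == a - k * b); a minimal check with plain tensors:

    import torch

    # torch.add/sub with alpha fold the scale into the second operand,
    # mirroring how the converters emit mul(other, alpha) and then add/sub.
    a = torch.randn(2, 3)
    b = torch.randn(2, 3)
    alpha = 2.0

    assert torch.allclose(torch.add(a, b, alpha=alpha), a + alpha * b)
    assert torch.allclose(torch.sub(a, b, alpha=alpha), a - alpha * b)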