From 0644709e8abec4a9abe35cd0da60e45b8defaa76 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Mon, 21 Aug 2023 16:14:41 -0700 Subject: [PATCH 1/2] feat: support conv dynamo converter fix a squeeze bug minor fix and issue #2185 add conv validator --- .../dynamo/conversion/aten_ops_converters.py | 30 ++++ .../dynamo/conversion/converter_utils.py | 26 +++- .../dynamo/conversion/impl/__init__.py | 1 + .../dynamo/conversion/impl/conv.py | 129 ++++++++++++++++++ .../dynamo/conversion/impl/squeeze.py | 11 +- 5 files changed, 189 insertions(+), 8 deletions(-) create mode 100644 py/torch_tensorrt/dynamo/conversion/impl/conv.py diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 451d218ee7..6cd44f4855 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -843,3 +843,33 @@ def aten_ops_isinf( name, args[0], ) + + +def conv_param_validator(conv_node: Node) -> bool: + return (not conv_node.args[6]) and (conv_node.args[7] in ([0], [0, 0], [0, 0, 0])) + + +@dynamo_tensorrt_converter( + torch.ops.aten.convolution.default, capability_validator=conv_param_validator +) +def aten_ops_convolution( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.conv.convNd( + network, + target, + source_ir=SourceIR.ATEN, + name=name, + is_conv1d=len(args[3]) == 1, + input=args[0], + weight=args[1], + bias=args[2], + stride=args[3], + padding=args[4], + dilation=args[5], + groups=args[8], + ) diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py index e33bf09903..6f9d7b6f1d 100644 --- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py +++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py @@ -1,7 +1,7 @@ import functools import logging import re -from typing import List, Optional +from typing import Any, List, Optional, Tuple import tensorrt as trt import torch @@ -164,3 +164,27 @@ def broadcastable( get_axes_for_reduce_op = functools.partial( get_axes_for_reduce_op, has_implicit_batch_dimension=False ) + + +def extend_attr_to_tuple( + val: Any, + num_elem: int, +) -> Tuple[Any, ...]: + """ + If `val` is not a tuple or a list, then we make a tuple of size `num_elem` by + replicating `val` `num_elem` times. + + Args: + val (Any): Value that we want to process. + + Returns: + A tuple. 
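+        The returned tuple has `num_elem` entries when `val` is a scalar or a
+        single-element sequence; a sequence with more than one element is simply
+        converted to a tuple of its original elements.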
+ """ + if not isinstance(val, (tuple, list)): + val = (val,) * num_elem + elif len(val) == 1: + val = (val[0],) * num_elem + + if isinstance(val, list): + val = tuple(val) + return val diff --git a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py index 6bd315871c..4ee7fd2bed 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py @@ -4,6 +4,7 @@ activation, cast, condition, + conv, elementwise, embedding, matmul, diff --git a/py/torch_tensorrt/dynamo/conversion/impl/conv.py b/py/torch_tensorrt/dynamo/conversion/impl/conv.py new file mode 100644 index 0000000000..285da2a04c --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/conv.py @@ -0,0 +1,129 @@ +from typing import Optional, Sequence, Union + +import numpy as np + +# @manual=//deeplearning/trt/python:py_tensorrt +import tensorrt as trt +import torch +from torch.fx.node import Target +from torch_tensorrt.dynamo.conversion import aten_ops_converters +from torch_tensorrt.dynamo.conversion.converter_utils import extend_attr_to_tuple +from torch_tensorrt.fx.converters.converter_utils import ( + SourceIR, + get_dyn_range, + get_trt_tensor, + has_dynamic_shape, + mark_as_int8_layer, + set_layer_name, + to_numpy, +) +from torch_tensorrt.fx.types import TRTNetwork, TRTTensor + + +def convNd( + network: TRTNetwork, + target: Union[Target, str], + source_ir: Optional[SourceIR], + name: str, + is_conv1d: bool, + input: TRTTensor, + weight: Union[TRTTensor, torch.Tensor], + bias: Optional[Union[TRTTensor, torch.Tensor]], + stride: Optional[Union[int, Sequence[int]]], + padding: Optional[Union[int, Sequence[int]]], + dilation: Optional[Union[int, Sequence[int]]], + groups: Optional[int], + scale: Optional[Union[torch.Tensor, float]] = None, + zero_point: Optional[Union[torch.Tensor, float]] = None, +) -> TRTTensor: + if has_dynamic_shape(input.shape): + assert input.shape[1] != -1, "Channel dim can't be dynamic for convolution." 
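+        # Only the batch and spatial dimensions may be dynamic; the channel dimension
+        # must be static so it can be matched against the convolution weight shape.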
+ + if is_conv1d: + # Apply an unsqueeze operation to transform the conv1d problem into conv2d + input = aten_ops_converters.aten_ops_unsqueeze( + network, target, (input, -1), {}, name + "_unsqueeze" + ) + + # Process bias terms + if isinstance(bias, (torch.Tensor, np.ndarray)): + # Transform the bias constant into a Numpy array + bias = to_numpy(bias) + + elif isinstance(bias, TRTTensor): + bias = get_trt_tensor(network, bias, f"{name}_bias") + + elif bias is not None: + raise RuntimeError( + f"Convolution {name} has bias of type {type(bias)}, Expected Torch Tensor or TRT Tensor" + ) + + # Process weight terms + if network.has_explicit_precision or isinstance(weight, TRTTensor): + weight = get_trt_tensor(network, weight, f"{name}_weight") + # Append new dimension (unsqueeze) if the convolution is 1d + if is_conv1d: + weight = aten_ops_converters.aten_ops_unsqueeze( + network, target, (weight, -1), {}, name + "_unsqueeze_weight" + ) + + elif isinstance(weight, (torch.Tensor, np.ndarray)): + # Transform the weight constant into a Numpy array + weight = to_numpy(weight) + + # Append new dimension (unsqueeze) if the convolution is 1d + if is_conv1d: + weight = np.expand_dims(weight, -1) + + else: + raise RuntimeError( + f"Convolution {name} has weight of type {type(weight)}, Expect Optional[Tensor]" + ) + + # add conv layer + conv_layer = network.add_convolution_nd( + input=input, + num_output_maps=weight.shape[0], + kernel_shape=weight.shape[2:], + kernel=trt.Weights() if isinstance(weight, TRTTensor) else weight, + bias=trt.Weights() if isinstance(bias, TRTTensor) else bias, + ) + + # If the weight is a TRTTensor, set it as an input of the layer + if isinstance(weight, TRTTensor): + conv_layer.set_input(1, weight) + + # If the bias is a TRTTensor, set it as an input of the layer + if isinstance(bias, TRTTensor): + conv_layer.set_input(2, bias) + + # Expand parameters manually for Conv1D computations + if is_conv1d: + padding = tuple(padding) + (0,) + stride = extend_attr_to_tuple(stride, 2) + dilation = extend_attr_to_tuple(dilation, 2) + + set_layer_name(conv_layer, target, name, source_ir) + + # Set relevant attributes of convolution layer + conv_layer.padding_nd = padding + conv_layer.stride_nd = stride + conv_layer.dilation_nd = dilation + + if groups is not None: + conv_layer.num_groups = groups + + # Handle quantization cases + if scale is not None and zero_point is not None: + # Assume the dtype of activation is torch.quint8 + mark_as_int8_layer(conv_layer, get_dyn_range(scale, zero_point, torch.quint8)) + + result = conv_layer.get_output(0) + + if is_conv1d: + # Apply a squeeze operation to transform the conv2d problem back into conv1d + result = aten_ops_converters.aten_ops_squeeze( + network, target, (result, -1), {}, name + "_squeeze" + ) + + return result diff --git a/py/torch_tensorrt/dynamo/conversion/impl/squeeze.py b/py/torch_tensorrt/dynamo/conversion/impl/squeeze.py index 46e0620590..6d0d1198ce 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/squeeze.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/squeeze.py @@ -18,11 +18,6 @@ def squeeze( input: TRTTensor, dim: Optional[Any] = None, ) -> TRTTensor: - if not isinstance(input, TRTTensor): - raise RuntimeError( - f"squeeze received input {input} that is not part " - "of the TensorRT region!" - ) dims = [] if dim is not None: if isinstance(dim, int): @@ -35,6 +30,7 @@ def squeeze( # dim, which is a very rare case. For now we just claim not supporting dim=None. 
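+    # The assert below enforces that restriction; the loop that follows then
+    # normalizes each requested dim to a positive index before it is used in the
+    # output-shape computation.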
         assert not (len(dims) == 0), "We don't support dim=None right now for squeeze."
+    new_dims = []
     for dim in dims:
         dim = get_positive_dim(
             dim,
@@ -48,13 +44,14 @@ def squeeze(
         assert (
             len(get_dynamic_dims(input.shape)) <= 1
         ), "Currently more than one dynamic dim for input to squeeze is not supported."
+        new_dims.append(dim)
 
     output_shape = []
     for i, s in enumerate(input.shape):
-        if (i in dims) and s == 1:
+        if (i in new_dims) and s == 1:
             continue
         output_shape.append(s)
     layer = network.add_shuffle(input)
     layer.reshape_dims = tuple(output_shape)
-    set_layer_name(layer, target, name)
+    set_layer_name(layer, target, name, source_ir)
     return layer.get_output(0)

From e3a7f2fe2f82315c4d76ace6b2d707157c7ce791 Mon Sep 17 00:00:00 2001
From: Evan Li 
Date: Tue, 29 Aug 2023 16:10:09 -0700
Subject: [PATCH 2/2] change to invoke the implementations directly

---
 py/torch_tensorrt/dynamo/conversion/impl/conv.py | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/impl/conv.py b/py/torch_tensorrt/dynamo/conversion/impl/conv.py
index 285da2a04c..ff7deb0962 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/conv.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/conv.py
@@ -6,7 +6,7 @@
 import tensorrt as trt
 import torch
 from torch.fx.node import Target
-from torch_tensorrt.dynamo.conversion import aten_ops_converters
+from torch_tensorrt.dynamo.conversion import impl
 from torch_tensorrt.dynamo.conversion.converter_utils import extend_attr_to_tuple
 from torch_tensorrt.fx.converters.converter_utils import (
     SourceIR,
@@ -41,8 +41,8 @@ def convNd(
 
     if is_conv1d:
         # Apply an unsqueeze operation to transform the conv1d problem into conv2d
-        input = aten_ops_converters.aten_ops_unsqueeze(
-            network, target, (input, -1), {}, name + "_unsqueeze"
+        input = impl.unsqueeze.unsqueeze(
+            network, target, source_ir, name + "_unsqueeze_conv1d", input, -1
         )
 
     # Process bias terms
@@ -63,8 +63,8 @@ def convNd(
         weight = get_trt_tensor(network, weight, f"{name}_weight")
         # Append new dimension (unsqueeze) if the convolution is 1d
         if is_conv1d:
-            weight = aten_ops_converters.aten_ops_unsqueeze(
-                network, target, (weight, -1), {}, name + "_unsqueeze_weight"
+            weight = impl.unsqueeze.unsqueeze(
+                network, target, source_ir, name + "_unsqueeze_weight", weight, -1
             )
 
     elif isinstance(weight, (torch.Tensor, np.ndarray)):
         # Transform the weight constant into a Numpy array
@@ -122,8 +122,8 @@ def convNd(
 
     if is_conv1d:
         # Apply a squeeze operation to transform the conv2d problem back into conv1d
-        result = aten_ops_converters.aten_ops_squeeze(
-            network, target, (result, -1), {}, name + "_squeeze"
+        result = impl.squeeze.squeeze(
+            network, target, source_ir, name + "_squeeze_conv1d", result, -1
         )
 
     return result
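
Note (not part of the patch): the validator and converter above index directly into the positional arguments of aten.convolution.default. As a minimal sketch of how those indices line up with the schema, assuming an arbitrary 1D convolution with example shapes:

# Sketch only: shows the positional layout of aten.convolution.default that
# conv_param_validator (args[6], args[7]) and aten_ops_convolution (args[0]..args[8])
# index into. Shapes and values are arbitrary examples.
import torch

x = torch.randn(1, 3, 32)  # args[0]: (N, C_in, L) input of a 1D convolution
w = torch.randn(8, 3, 5)   # args[1]: (C_out, C_in, kernel_size) weight
b = torch.randn(8)         # args[2]: bias

y = torch.ops.aten.convolution.default(
    x,
    w,
    b,
    [1],    # args[3]: stride; len == 1, so the converter sets is_conv1d=True
    [2],    # args[4]: padding
    [1],    # args[5]: dilation
    False,  # args[6]: transposed; the validator rejects transposed convolutions
    [0],    # args[7]: output_padding; the validator requires all zeros
    1,      # args[8]: groups
)
print(y.shape)  # torch.Size([1, 8, 32])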