feat: support conv dynamo converter
fix a squeeze bug

minor fix and issue pytorch#2185

add conv validator
zewenli98 committed Aug 29, 2023
1 parent 0e5a497 commit 0644709
Showing 5 changed files with 189 additions and 8 deletions.
30 changes: 30 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -843,3 +843,33 @@ def aten_ops_isinf(
        name,
        args[0],
    )


def conv_param_validator(conv_node: Node) -> bool:
    # aten.convolution.default args are (input, weight, bias, stride, padding, dilation,
    # transposed, output_padding, groups); only non-transposed convolutions with zero
    # output padding are supported.
    return (not conv_node.args[6]) and (conv_node.args[7] in ([0], [0, 0], [0, 0, 0]))


@dynamo_tensorrt_converter(
    torch.ops.aten.convolution.default, capability_validator=conv_param_validator
)
def aten_ops_convolution(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    return impl.conv.convNd(
        network,
        target,
        source_ir=SourceIR.ATEN,
        name=name,
        is_conv1d=len(args[3]) == 1,
        input=args[0],
        weight=args[1],
        bias=args[2],
        stride=args[3],
        padding=args[4],
        dilation=args[5],
        groups=args[8],
    )
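As a rough illustration of what the validator screens for (a minimal sketch: FakeConvNode is a hypothetical stand-in for the torch.fx.Node the real validator receives), aten.convolution.default is only claimed when transposed is false and output_padding is all zeros; rejected nodes can fall back to eager PyTorch execution under partial compilation.

from typing import Any, NamedTuple, Tuple


class FakeConvNode(NamedTuple):
    args: Tuple[Any, ...]


def conv_param_validator(conv_node: FakeConvNode) -> bool:
    # Positional args follow aten.convolution.default:
    # (input, weight, bias, stride, padding, dilation, transposed, output_padding, groups)
    return (not conv_node.args[6]) and (conv_node.args[7] in ([0], [0, 0], [0, 0, 0]))


# Plain 2D convolution: transposed=False, output_padding=[0, 0] -> accepted
print(conv_param_validator(FakeConvNode(("x", "w", None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1))))  # True
# Transposed convolution with non-zero output_padding -> rejected
print(conv_param_validator(FakeConvNode(("x", "w", None, [2, 2], [1, 1], [1, 1], True, [1, 1], 1))))   # False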
26 changes: 25 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/converter_utils.py
@@ -1,7 +1,7 @@
import functools
import logging
import re
-from typing import List, Optional
+from typing import Any, List, Optional, Tuple

import tensorrt as trt
import torch
@@ -164,3 +164,27 @@ def broadcastable(
get_axes_for_reduce_op = functools.partial(
    get_axes_for_reduce_op, has_implicit_batch_dimension=False
)


def extend_attr_to_tuple(
    val: Any,
    num_elem: int,
) -> Tuple[Any, ...]:
    """
    If `val` is not a tuple or a list (or is a length-1 sequence), make a tuple of
    size `num_elem` by replicating `val` `num_elem` times.

    Args:
        val (Any): Value that we want to process.
        num_elem (int): Number of elements in the resulting tuple.

    Returns:
        A tuple.
    """
    if not isinstance(val, (tuple, list)):
        val = (val,) * num_elem
    elif len(val) == 1:
        val = (val[0],) * num_elem

    if isinstance(val, list):
        val = tuple(val)
    return val
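For reference, a few illustrative calls (assuming the definition above) showing how scalars and length-1 sequences are broadcast to the requested length:

extend_attr_to_tuple(2, 3)       # -> (2, 2, 2)
extend_attr_to_tuple([1], 2)     # -> (1, 1)
extend_attr_to_tuple((1, 2), 2)  # -> (1, 2)
extend_attr_to_tuple([1, 2], 3)  # -> (1, 2); longer lists are converted to tuples but not length-checked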
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/__init__.py
@@ -4,6 +4,7 @@
    activation,
    cast,
    condition,
    conv,
    elementwise,
    embedding,
    matmul,
129 changes: 129 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/conv.py
@@ -0,0 +1,129 @@
from typing import Optional, Sequence, Union

import numpy as np

# @manual=//deeplearning/trt/python:py_tensorrt
import tensorrt as trt
import torch
from torch.fx.node import Target
from torch_tensorrt.dynamo.conversion import aten_ops_converters
from torch_tensorrt.dynamo.conversion.converter_utils import extend_attr_to_tuple
from torch_tensorrt.fx.converters.converter_utils import (
    SourceIR,
    get_dyn_range,
    get_trt_tensor,
    has_dynamic_shape,
    mark_as_int8_layer,
    set_layer_name,
    to_numpy,
)
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor


def convNd(
    network: TRTNetwork,
    target: Union[Target, str],
    source_ir: Optional[SourceIR],
    name: str,
    is_conv1d: bool,
    input: TRTTensor,
    weight: Union[TRTTensor, torch.Tensor],
    bias: Optional[Union[TRTTensor, torch.Tensor]],
    stride: Optional[Union[int, Sequence[int]]],
    padding: Optional[Union[int, Sequence[int]]],
    dilation: Optional[Union[int, Sequence[int]]],
    groups: Optional[int],
    scale: Optional[Union[torch.Tensor, float]] = None,
    zero_point: Optional[Union[torch.Tensor, float]] = None,
) -> TRTTensor:
    if has_dynamic_shape(input.shape):
        assert input.shape[1] != -1, "Channel dim can't be dynamic for convolution."

    if is_conv1d:
        # Apply an unsqueeze operation to transform the conv1d problem into conv2d
        input = aten_ops_converters.aten_ops_unsqueeze(
            network, target, (input, -1), {}, name + "_unsqueeze"
        )

    # Process bias terms
    if isinstance(bias, (torch.Tensor, np.ndarray)):
        # Transform the bias constant into a Numpy array
        bias = to_numpy(bias)

    elif isinstance(bias, TRTTensor):
        bias = get_trt_tensor(network, bias, f"{name}_bias")

    elif bias is not None:
        raise RuntimeError(
            f"Convolution {name} has bias of type {type(bias)}, Expected Torch Tensor or TRT Tensor"
        )

    # Process weight terms
    if network.has_explicit_precision or isinstance(weight, TRTTensor):
        weight = get_trt_tensor(network, weight, f"{name}_weight")
        # Append new dimension (unsqueeze) if the convolution is 1d
        if is_conv1d:
            weight = aten_ops_converters.aten_ops_unsqueeze(
                network, target, (weight, -1), {}, name + "_unsqueeze_weight"
            )

    elif isinstance(weight, (torch.Tensor, np.ndarray)):
        # Transform the weight constant into a Numpy array
        weight = to_numpy(weight)

        # Append new dimension (unsqueeze) if the convolution is 1d
        if is_conv1d:
            weight = np.expand_dims(weight, -1)

    else:
        raise RuntimeError(
            f"Convolution {name} has weight of type {type(weight)}, Expect Optional[Tensor]"
        )

    # add conv layer
    conv_layer = network.add_convolution_nd(
        input=input,
        num_output_maps=weight.shape[0],
        kernel_shape=weight.shape[2:],
        kernel=trt.Weights() if isinstance(weight, TRTTensor) else weight,
        bias=trt.Weights() if isinstance(bias, TRTTensor) else bias,
    )

    # If the weight is a TRTTensor, set it as an input of the layer
    if isinstance(weight, TRTTensor):
        conv_layer.set_input(1, weight)

    # If the bias is a TRTTensor, set it as an input of the layer
    if isinstance(bias, TRTTensor):
        conv_layer.set_input(2, bias)

    # Expand parameters manually for Conv1D computations
    if is_conv1d:
        padding = tuple(padding) + (0,)
        stride = extend_attr_to_tuple(stride, 2)
        dilation = extend_attr_to_tuple(dilation, 2)

    set_layer_name(conv_layer, target, name, source_ir)

    # Set relevant attributes of convolution layer
    conv_layer.padding_nd = padding
    conv_layer.stride_nd = stride
    conv_layer.dilation_nd = dilation

    if groups is not None:
        conv_layer.num_groups = groups

    # Handle quantization cases
    if scale is not None and zero_point is not None:
        # Assume the dtype of activation is torch.quint8
        mark_as_int8_layer(conv_layer, get_dyn_range(scale, zero_point, torch.quint8))

    result = conv_layer.get_output(0)

    if is_conv1d:
        # Apply a squeeze operation to transform the conv2d problem back into conv1d
        result = aten_ops_converters.aten_ops_squeeze(
            network, target, (result, -1), {}, name + "_squeeze"
        )

    return result
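The 1D path above mirrors a standard trick: run a conv1d as a conv2d over an extra trailing unit dimension. A minimal PyTorch sketch of the same equivalence (illustrative only, independent of the converter):

import torch
import torch.nn.functional as F

# conv1d expressed as conv2d: unsqueeze a trailing spatial dim of size 1,
# extend stride to 2D, pad the extra dim with 0, then squeeze it back.
x = torch.randn(2, 3, 16)   # (N, C_in, L)
w = torch.randn(8, 3, 5)    # (C_out, C_in, K)

ref = F.conv1d(x, w, stride=1, padding=2)

out = F.conv2d(
    x.unsqueeze(-1),        # (N, C_in, L, 1)
    w.unsqueeze(-1),        # (C_out, C_in, K, 1)
    stride=(1, 1),
    padding=(2, 0),
).squeeze(-1)

assert torch.allclose(ref, out, atol=1e-5)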
11 changes: 4 additions & 7 deletions py/torch_tensorrt/dynamo/conversion/impl/squeeze.py
@@ -18,11 +18,6 @@ def squeeze(
    input: TRTTensor,
    dim: Optional[Any] = None,
) -> TRTTensor:
-    if not isinstance(input, TRTTensor):
-        raise RuntimeError(
-            f"squeeze received input {input} that is not part "
-            "of the TensorRT region!"
-        )
    dims = []
    if dim is not None:
        if isinstance(dim, int):
@@ -35,6 +30,7 @@
    # dim, which is a very rare case. For now we just claim not supporting dim=None.
    assert not (len(dims) == 0), "We don't support dim=None right now for squeeze."

+    new_dims = []
    for dim in dims:
        dim = get_positive_dim(
            dim,
@@ -48,13 +44,14 @@
        assert (
            len(get_dynamic_dims(input.shape)) <= 1
        ), "Currently more than one dynamic dim for input to squeeze is not supported."
+        new_dims.append(dim)

    output_shape = []
    for i, s in enumerate(input.shape):
-        if (i in dims) and s == 1:
+        if (i in new_dims) and s == 1:
            continue
        output_shape.append(s)
    layer = network.add_shuffle(input)
    layer.reshape_dims = tuple(output_shape)
-    set_layer_name(layer, target, name)
+    set_layer_name(layer, target, name, source_ir)
    return layer.get_output(0)
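A rough illustration of the squeeze bug this commit fixes: the positive dims computed inside the loop were previously discarded, so a negative dim such as the -1 passed from the conv1d path never matched an enumerate index and the dimension was silently kept.

shape = [2, 8, 16, 1]
dims = [-1]                                  # dim as received, e.g. from the conv1d path above
new_dims = [d % len(shape) for d in dims]    # -> [3], roughly what get_positive_dim yields

old = [s for i, s in enumerate(shape) if not (i in dims and s == 1)]      # old check
new = [s for i, s in enumerate(shape) if not (i in new_dims and s == 1)]  # fixed check
print(old)  # [2, 8, 16, 1] -- nothing squeezed
print(new)  # [2, 8, 16]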
