Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support conv dynamo converter #1

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -843,3 +843,33 @@ def aten_ops_isinf(
name,
args[0],
)


def conv_param_validator(conv_node: Node) -> bool:
    """Decide whether a convolution node can be handled by the converter.

    A node is accepted only when its seventh positional argument is falsy and
    its eighth positional argument is an all-zero list of length 1, 2 or 3.

    NOTE(review): in the `aten.convolution` schema these positions are
    presumably `transposed` and `output_padding` — confirm against the op
    signature.

    Args:
        conv_node: FX node whose `args` follow the convolution argument order.

    Returns:
        True if the node is supported, False otherwise.
    """
    # Position 6: must be falsy for the node to be supported.
    transposed_flag = conv_node.args[6]
    if transposed_flag:
        return False

    # Position 7: only zero-valued lists (1d / 2d / 3d) are supported.
    output_pad = conv_node.args[7]
    return output_pad in ([0], [0, 0], [0, 0, 0])


@dynamo_tensorrt_converter(
    torch.ops.aten.convolution.default, capability_validator=conv_param_validator
)
def aten_ops_convolution(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert `aten.convolution.default` by delegating to `impl.conv.convNd`.

    Positional arguments are forwarded as: 0 -> input, 1 -> weight, 2 -> bias,
    3 -> stride, 4 -> padding, 5 -> dilation, 8 -> groups. Positions 6 and 7
    are not forwarded; they are inspected by `conv_param_validator`, which is
    registered as this converter's capability validator.

    Args:
        network: Network the convolution is added to.
        target: Target of the node being converted.
        args: Positional arguments of the convolution node.
        kwargs: Keyword arguments of the node (unused here).
        name: Base name for the created layers.

    Returns:
        Whatever `impl.conv.convNd` returns for the forwarded arguments.
    """
    conv_arguments = dict(
        source_ir=SourceIR.ATEN,
        name=name,
        # A stride with a single entry marks the convolution as 1d.
        is_conv1d=len(args[3]) == 1,
        input=args[0],
        weight=args[1],
        bias=args[2],
        stride=args[3],
        padding=args[4],
        dilation=args[5],
        groups=args[8],
    )
    return impl.conv.convNd(network, target, **conv_arguments)
26 changes: 25 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/converter_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import functools
import logging
import re
from typing import List, Optional
from typing import Any, List, Optional, Tuple

import tensorrt as trt
import torch
Expand Down Expand Up @@ -164,3 +164,27 @@ def broadcastable(
# Rebind `get_axes_for_reduce_op` to a partial of itself with
# `has_implicit_batch_dimension` pre-set to False.
get_axes_for_reduce_op = functools.partial(
get_axes_for_reduce_op, has_implicit_batch_dimension=False
)


def extend_attr_to_tuple(
    val: Any,
    num_elem: int,
) -> Tuple[Any, ...]:
    """
    Normalize an attribute value into a tuple.

    * A value that is neither a tuple nor a list is replicated `num_elem` times.
    * A tuple or list holding exactly one element has that element replicated
      `num_elem` times.
    * Any other list is converted to a tuple; any other tuple is returned
      unchanged (its length is not checked against `num_elem`).

    Args:
        val (Any): Value that we want to process.
        num_elem (int): Number of repetitions used when `val` is a scalar or a
            single-element sequence.

    Returns:
        A tuple.
    """
    if not isinstance(val, (tuple, list)):
        # Scalar: broadcast it.
        return (val,) * num_elem

    if len(val) == 1:
        # Single-element sequence: broadcast its only element.
        return (val[0],) * num_elem

    # Multi-element (or empty) sequence: only the container type is normalized.
    return tuple(val) if isinstance(val, list) else val
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
activation,
cast,
condition,
conv,
elementwise,
embedding,
matmul,
Expand Down
129 changes: 129 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/conv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
from typing import Optional, Sequence, Union

import numpy as np

# @manual=//deeplearning/trt/python:py_tensorrt
import tensorrt as trt
import torch
from torch.fx.node import Target
from torch_tensorrt.dynamo.conversion import impl
from torch_tensorrt.dynamo.conversion.converter_utils import extend_attr_to_tuple
from torch_tensorrt.fx.converters.converter_utils import (
SourceIR,
get_dyn_range,
get_trt_tensor,
has_dynamic_shape,
mark_as_int8_layer,
set_layer_name,
to_numpy,
)
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor


def convNd(
    network: TRTNetwork,
    target: Union[Target, str],
    source_ir: Optional[SourceIR],
    name: str,
    is_conv1d: bool,
    input: TRTTensor,
    weight: Union[TRTTensor, torch.Tensor],
    bias: Optional[Union[TRTTensor, torch.Tensor]],
    stride: Optional[Union[int, Sequence[int]]],
    padding: Optional[Union[int, Sequence[int]]],
    dilation: Optional[Union[int, Sequence[int]]],
    groups: Optional[int],
    scale: Optional[Union[torch.Tensor, float]] = None,
    zero_point: Optional[Union[torch.Tensor, float]] = None,
) -> TRTTensor:
    """Add an N-d convolution layer to `network` and return its output tensor.

    A 1d convolution is expressed as a 2d one: the input and the weight get a
    trailing dimension appended, padding/stride/dilation are extended to two
    entries, and the trailing dimension is squeezed from the result.

    Args:
        network: Network the layer is added to.
        target: Target of the node being converted (used for layer naming).
        source_ir: IR the node originates from (used for layer naming).
        name: Base name for the layers created here.
        is_conv1d: Whether the convolution is one-dimensional.
        input: Input tensor of the convolution.
        weight: Kernel, either a constant (torch tensor / numpy array) or a
            TRT tensor.
        bias: Optional bias, either a constant or a TRT tensor.
        stride: Stride assigned to the layer's `stride_nd`.
        padding: Padding assigned to the layer's `padding_nd`.
        dilation: Dilation assigned to the layer's `dilation_nd`.
        groups: Number of groups; left at the layer default when None.
        scale: Quantization scale; used only together with `zero_point`.
        zero_point: Quantization zero point; used only together with `scale`.

    Returns:
        The output tensor of the convolution (squeezed back for conv1d).

    Raises:
        AssertionError: If the input has a dynamic shape and its channel
            dimension (index 1) is dynamic.
        RuntimeError: If `bias` or `weight` has an unsupported type.
    """
    if has_dynamic_shape(input.shape):
        assert input.shape[1] != -1, "Channel dim can't be dynamic for convolution."

    if is_conv1d:
        # Apply an unsqueeze operation to transform the conv1d problem into conv2d
        input = impl.unsqueeze.unsqueeze(
            network, target, source_ir, name + "_unsqueeze_conv1d", input, -1
        )

    # Process bias terms
    if isinstance(bias, (torch.Tensor, np.ndarray)):
        # Transform the bias constant into a Numpy array
        bias = to_numpy(bias)

    elif isinstance(bias, TRTTensor):
        bias = get_trt_tensor(network, bias, f"{name}_bias")

    elif bias is not None:
        raise RuntimeError(
            f"Convolution {name} has bias of type {type(bias)}, Expected Torch Tensor or TRT Tensor"
        )

    # Process weight terms
    if network.has_explicit_precision or isinstance(weight, TRTTensor):
        weight = get_trt_tensor(network, weight, f"{name}_weight")
        # Append new dimension (unsqueeze) if the convolution is 1d
        if is_conv1d:
            # The result must be bound to `weight`: binding it to `input`
            # would discard the unsqueezed activation and leave the kernel
            # without the extra trailing dimension.
            weight = impl.unsqueeze.unsqueeze(
                network, target, source_ir, name + "_unsqueeze_weight", weight, -1
            )

    elif isinstance(weight, (torch.Tensor, np.ndarray)):
        # Transform the weight constant into a Numpy array
        weight = to_numpy(weight)

        # Append new dimension (unsqueeze) if the convolution is 1d
        if is_conv1d:
            weight = np.expand_dims(weight, -1)

    else:
        raise RuntimeError(
            f"Convolution {name} has weight of type {type(weight)}, Expect Optional[Tensor]"
        )

    # add conv layer
    # Tensor-valued kernel/bias are passed as empty Weights here and attached
    # as layer inputs below.
    conv_layer = network.add_convolution_nd(
        input=input,
        num_output_maps=weight.shape[0],
        kernel_shape=weight.shape[2:],
        kernel=trt.Weights() if isinstance(weight, TRTTensor) else weight,
        bias=trt.Weights() if isinstance(bias, TRTTensor) else bias,
    )

    # If the weight is a TRTTensor, set it as an input of the layer
    if isinstance(weight, TRTTensor):
        conv_layer.set_input(1, weight)

    # If the bias is a TRTTensor, set it as an input of the layer
    if isinstance(bias, TRTTensor):
        conv_layer.set_input(2, bias)

    # Expand parameters manually for Conv1D computations
    if is_conv1d:
        # The appended trailing dimension gets zero padding.
        padding = tuple(padding) + (0,)
        stride = extend_attr_to_tuple(stride, 2)
        dilation = extend_attr_to_tuple(dilation, 2)

    set_layer_name(conv_layer, target, name, source_ir)

    # Set relevant attributes of convolution layer
    conv_layer.padding_nd = padding
    conv_layer.stride_nd = stride
    conv_layer.dilation_nd = dilation

    if groups is not None:
        conv_layer.num_groups = groups

    # Handle quantization cases
    if scale is not None and zero_point is not None:
        # Assume the dtype of activation is torch.quint8
        mark_as_int8_layer(conv_layer, get_dyn_range(scale, zero_point, torch.quint8))

    result = conv_layer.get_output(0)

    if is_conv1d:
        # Apply a squeeze operation to transform the conv2d problem back into conv1d
        result = impl.squeeze.squeeze(
            network, target, source_ir, name + "_squeeze_conv1d", result, -1
        )

    return result
11 changes: 4 additions & 7 deletions py/torch_tensorrt/dynamo/conversion/impl/squeeze.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,6 @@ def squeeze(
input: TRTTensor,
dim: Optional[Any] = None,
) -> TRTTensor:
if not isinstance(input, TRTTensor):
raise RuntimeError(
f"squeeze received input {input} that is not part "
"of the TensorRT region!"
)
dims = []
if dim is not None:
if isinstance(dim, int):
Expand All @@ -35,6 +30,7 @@ def squeeze(
# dim, which is a very rare case. For now we just claim not supporting dim=None.
assert not (len(dims) == 0), "We don't support dim=None right now for squeeze."

new_dims = []
for dim in dims:
dim = get_positive_dim(
dim,
Expand All @@ -48,13 +44,14 @@ def squeeze(
assert (
len(get_dynamic_dims(input.shape)) <= 1
), "Currently more than one dynamic dim for input to squeeze is not supported."
new_dims.append(dim)

output_shape = []
for i, s in enumerate(input.shape):
if (i in dims) and s == 1:
if (i in new_dims) and s == 1:
continue
output_shape.append(s)
layer = network.add_shuffle(input)
layer.reshape_dims = tuple(output_shape)
set_layer_name(layer, target, name)
set_layer_name(layer, target, name, source_ir)
return layer.get_output(0)
Loading