diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py
index ed0f1bb843..72846c9007 100644
--- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py
+++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py
@@ -1,6 +1,6 @@
 import logging
 import re
-from typing import List, Optional
+from typing import Any, List, Optional, Tuple
 
 import tensorrt as trt
 import torch
@@ -157,3 +157,25 @@ def broadcastable(
         if not (a_shape[i] == b_shape[i] or a_shape[i] == 1 or b_shape[i] == 1):
             return False
     return True
+
+
+def extend_attr_to_tuple(
+    val: Any,
+    num_elem: int,
+) -> Tuple[Any, ...]:
+    """
+    If `val` is not a tuple or a list, replicate it `num_elem` times into a
+    tuple. If `val` is a list, convert it to a tuple.
+
+    Args:
+        val (Any): Value to process.
+        num_elem (int): Number of replications when `val` is a scalar.
+
+    Returns:
+        A tuple of length `num_elem`, or `tuple(val)` if `val` was a sequence.
+    """
+    if not isinstance(val, (tuple, list)):
+        val = (val,) * num_elem
+    if isinstance(val, list):
+        val = tuple(val)
+    return val
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/conv.py b/py/torch_tensorrt/dynamo/conversion/impl/conv.py
index 93ed5b1c99..78a1276882 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/conv.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/conv.py
@@ -7,9 +7,9 @@
 import torch
 from torch.fx.node import Target
 from torch_tensorrt.dynamo.conversion import aten_ops_converters
+from torch_tensorrt.dynamo.conversion.converter_utils import extend_attr_to_tuple
 from torch_tensorrt.fx.converters.converter_utils import (
     SourceIR,
-    extend_attr_to_tuple,
     get_dyn_range,
     get_trt_tensor,
     has_dynamic_shape,
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/squeeze.py b/py/torch_tensorrt/dynamo/conversion/impl/squeeze.py
index 46e0620590..6d0d1198ce 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/squeeze.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/squeeze.py
@@ -18,11 +18,6 @@ def squeeze(
     input: TRTTensor,
     dim: Optional[Any] = None,
 ) -> TRTTensor:
-    if not isinstance(input, TRTTensor):
-        raise RuntimeError(
-            f"squeeze received input {input} that is not part "
-            "of the TensorRT region!"
-        )
     dims = []
     if dim is not None:
         if isinstance(dim, int):
@@ -35,6 +30,7 @@ def squeeze(
     # dim, which is a very rare case. For now we just claim not supporting dim=None.
     assert not (len(dims) == 0), "We don't support dim=None right now for squeeze."
 
+    new_dims = []
     for dim in dims:
         dim = get_positive_dim(
             dim,
@@ -48,13 +44,14 @@ def squeeze(
         assert (
             len(get_dynamic_dims(input.shape)) <= 1
         ), "Currently more than one dynamic dim for input to squeeze is not supported."
+        new_dims.append(dim)
 
     output_shape = []
     for i, s in enumerate(input.shape):
-        if (i in dims) and s == 1:
+        if (i in new_dims) and s == 1:
             continue
         output_shape.append(s)
     layer = network.add_shuffle(input)
     layer.reshape_dims = tuple(output_shape)
-    set_layer_name(layer, target, name)
+    set_layer_name(layer, target, name, source_ir)
     return layer.get_output(0)
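Sanity check for the relocated `extend_attr_to_tuple` helper (a hypothetical REPL session; the behavior follows directly from the code added above):

    >>> from torch_tensorrt.dynamo.conversion.converter_utils import extend_attr_to_tuple
    >>> extend_attr_to_tuple(3, 2)       # scalar replicated num_elem times
    (3, 3)
    >>> extend_attr_to_tuple([1, 2], 2)  # list converted to a tuple as-is
    (1, 2)
    >>> extend_attr_to_tuple((1, 2), 2)  # tuple passed through unchanged
    (1, 2)

The `new_dims` change in squeeze.py fixes negative-axis handling: the old loop normalized each `dim` with `get_positive_dim` but then discarded the result, so the membership test `i in dims` compared non-negative loop indices against the raw (possibly negative) values, and e.g. `dim=-1` never matched. A minimal sketch of the corrected shape computation, assuming `get_positive_dim(d, n)` maps a negative `d` to `d + n`:

    # Plain-Python mirror of the patched loop, not the converter itself.
    shape = (2, 3, 1)
    dims = [-1]                                                 # user-supplied axis
    new_dims = [d if d >= 0 else d + len(shape) for d in dims]  # -> [2]
    output_shape = [
        s for i, s in enumerate(shape) if not (i in new_dims and s == 1)
    ]
    assert output_shape == [2, 3]  # the old `i in dims` test would have kept the 1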