Skip to content

Commit

Permalink
fix a squeeze bug
Browse files Browse the repository at this point in the history
  • Loading branch information
zewenli98 committed Aug 23, 2023
1 parent b2ac5f0 commit 9d52c6a
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 9 deletions.
23 changes: 22 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/converter_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
import re
from typing import List, Optional
from typing import Any, List, Optional, Tuple

import tensorrt as trt
import torch
Expand Down Expand Up @@ -157,3 +157,24 @@ def broadcastable(
if not (a_shape[i] == b_shape[i] or a_shape[i] == 1 or b_shape[i] == 1):
return False
return True


def extend_attr_to_tuple(
    val: Any,
    num_elem: int,
) -> Tuple[Any, ...]:
    """Normalize an attribute value to a tuple.

    A scalar (anything that is neither a tuple nor a list) is replicated
    ``num_elem`` times; a list is converted to a tuple; a tuple is returned
    as-is.

    Args:
        val (Any): Value to normalize.
        num_elem (int): Number of times to replicate a scalar ``val``.

    Returns:
        A tuple of length ``num_elem`` when ``val`` was a scalar, otherwise
        a tuple with the same elements as ``val``.
    """
    # Sequences keep their own length; only scalars are replicated.
    if isinstance(val, (tuple, list)):
        return tuple(val)
    return (val,) * num_elem
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/conversion/impl/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
import torch
from torch.fx.node import Target
from torch_tensorrt.dynamo.conversion import aten_ops_converters
from torch_tensorrt.dynamo.conversion.converter_utils import extend_attr_to_tuple
from torch_tensorrt.fx.converters.converter_utils import (
SourceIR,
extend_attr_to_tuple,
get_dyn_range,
get_trt_tensor,
has_dynamic_shape,
Expand Down
11 changes: 4 additions & 7 deletions py/torch_tensorrt/dynamo/conversion/impl/squeeze.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,6 @@ def squeeze(
input: TRTTensor,
dim: Optional[Any] = None,
) -> TRTTensor:
if not isinstance(input, TRTTensor):
raise RuntimeError(
f"squeeze received input {input} that is not part "
"of the TensorRT region!"
)
dims = []
if dim is not None:
if isinstance(dim, int):
Expand All @@ -35,6 +30,7 @@ def squeeze(
# dim, which is a very rare case. For now we just claim not supporting dim=None.
assert not (len(dims) == 0), "We don't support dim=None right now for squeeze."

new_dims = []
for dim in dims:
dim = get_positive_dim(
dim,
Expand All @@ -48,13 +44,14 @@ def squeeze(
assert (
len(get_dynamic_dims(input.shape)) <= 1
), "Currently more than one dynamic dim for input to squeeze is not supported."
new_dims.append(dim)

output_shape = []
for i, s in enumerate(input.shape):
if (i in dims) and s == 1:
if (i in new_dims) and s == 1:
continue
output_shape.append(s)
layer = network.add_shuffle(input)
layer.reshape_dims = tuple(output_shape)
set_layer_name(layer, target, name)
set_layer_name(layer, target, name, source_ir)
return layer.get_output(0)

0 comments on commit 9d52c6a

Please sign in to comment.