Skip to content

Commit

Permalink
change to invoke the implementations directly
Browse files Browse the repository at this point in the history
  • Loading branch information
zewenli98 committed Aug 29, 2023
1 parent 0644709 commit e3a7f2f
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions py/torch_tensorrt/dynamo/conversion/impl/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import tensorrt as trt
import torch
from torch.fx.node import Target
from torch_tensorrt.dynamo.conversion import aten_ops_converters
from torch_tensorrt.dynamo.conversion import impl
from torch_tensorrt.dynamo.conversion.converter_utils import extend_attr_to_tuple
from torch_tensorrt.fx.converters.converter_utils import (
SourceIR,
Expand Down Expand Up @@ -41,8 +41,8 @@ def convNd(

if is_conv1d:
# Apply an unsqueeze operation to transform the conv1d problem into conv2d
input = aten_ops_converters.aten_ops_unsqueeze(
network, target, (input, -1), {}, name + "_unsqueeze"
input = impl.unsqueeze.unsqueeze(
network, target, source_ir, name + "_unsqueeze_conv1d", input, -1
)

# Process bias terms
Expand All @@ -63,8 +63,8 @@ def convNd(
weight = get_trt_tensor(network, weight, f"{name}_weight")
# Append new dimension (unsqueeze) if the convolution is 1d
if is_conv1d:
weight = aten_ops_converters.aten_ops_unsqueeze(
network, target, (weight, -1), {}, name + "_unsqueeze_weight"
input = impl.unsqueeze.unsqueeze(
network, target, source_ir, name + "_unsqueeze_weight", weight, -1
)

elif isinstance(weight, (torch.Tensor, np.ndarray)):
Expand Down Expand Up @@ -122,8 +122,8 @@ def convNd(

if is_conv1d:
# Apply a squeeze operation to transform the conv2d problem back into conv1d
result = aten_ops_converters.aten_ops_squeeze(
network, target, (result, -1), {}, name + "_squeeze"
result = impl.squeeze.squeeze(
network, target, source_ir, name + "_squeeze_conv1d", result, -1
)

return result

0 comments on commit e3a7f2f

Please sign in to comment.