Skip to content

Commit

Permalink
[Relay][Bugfix] Fix conv transpose with default strides in ONNX front…
Browse files Browse the repository at this point in the history
…end (#15868)

* fix strides in ConvTranspose converter

* Update test_forward.py
  • Loading branch information
jikechao authored Oct 4, 2023
1 parent 7a1f7d0 commit e754bc2
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 1 deletion.
4 changes: 3 additions & 1 deletion python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -931,6 +931,7 @@ def _impl_v11(cls, inputs, attr, params):
data = inputs[0]
input_shape = infer_shape(data)
ndim = len(input_shape)
num_spatial_dims = ndim - 2
if "auto_pad" in attr or "output_shape" in attr:
if "auto_pad" in attr:
attr["auto_pad"] = attr["auto_pad"].decode("utf-8")
Expand All @@ -941,7 +942,8 @@ def _impl_v11(cls, inputs, attr, params):
kndim = len(kernel_shape)
dilations = attr.get("dilations", [1] * kndim)
output_padding = attr.get("output_padding", [0] * kndim)
strides = attr["strides"]
# this is meant to handle the field 'strides' being optional for opsets 11+
strides = attr.get("strides", [1] * num_spatial_dims)
total_pad = [0] * kndim
# https://github.com/onnx/onnx/blob/main/docs/Operators.md#ConvTranspose
if "output_shape" in attr:
Expand Down
9 changes: 9 additions & 0 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -3340,6 +3340,15 @@ def repeat(num, dims):
repeat(1, dims),
auto_pad="SAME_UPPER",
)
# Convolution with default stride
verify_convtranspose_with_padding(
(1, 1) + repeat(5, dims),
(1, 1) + repeat(3, dims),
2 * repeat(1, dims),
repeat(3, dims),
None,
repeat(1, dims),
)
# Convolution with dilation
# TODO(mbrookhart): Relay doesn't currently support convtranspose with dilation
# verify_convtranspose_with_padding(
Expand Down

0 comments on commit e754bc2

Please sign in to comment.