Skip to content

Commit

Permalink
Fixed Conv3DTranspose with strides for data format channels_first (fixes
Browse files Browse the repository at this point in the history
 #1714)

While shape calculations for the input correctly distinguished between channels_first and channels_last, shape calculations for the inputs of the final Slice and Pad nodes always assumed channels_last format.

Signed-off-by: fthielke <[email protected]>
  • Loading branch information
fthielke committed Nov 22, 2021
1 parent 4245d8d commit a4327e1
Showing 1 changed file with 21 additions and 12 deletions.
33 changes: 21 additions & 12 deletions tf2onnx/onnx_opset/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,14 +504,15 @@ def version_1(cls, ctx, node, **kwargs):
use_strides_workaround = False
input_shape = ctx.make_node("Cast", [node.input[0]], attr={'to': TensorProto.INT64})
output_shape = ctx.make_node("Shape", [node.output[0]])
sp_index_start = 1 if is_channels_last(node) else 2
output_h = GraphBuilder(ctx).make_slice(
{"data": output_shape.output[0], "ends": [2], "starts": [1], "axes": [0]})
{"data": output_shape.output[0], "ends": [sp_index_start+1], "starts": [sp_index_start], "axes": [0]})
output_w = GraphBuilder(ctx).make_slice(
{"data": output_shape.output[0], "ends": [3], "starts": [2], "axes": [0]})
{"data": output_shape.output[0], "ends": [sp_index_start+2], "starts": [sp_index_start+1], "axes": [0]})
expect_h = GraphBuilder(ctx).make_slice(
{"data": input_shape.output[0], "ends": [2], "starts": [1], "axes": [0]})
{"data": input_shape.output[0], "ends": [sp_index_start+1], "starts": [sp_index_start], "axes": [0]})
expect_w = GraphBuilder(ctx).make_slice(
{"data": input_shape.output[0], "ends": [3], "starts": [2], "axes": [0]})
{"data": input_shape.output[0], "ends": [sp_index_start+2], "starts": [sp_index_start+1], "axes": [0]})
diff_h = ctx.make_node("Sub", [output_h, expect_h])
diff_w = ctx.make_node("Sub", [output_w, expect_w])
nonneg_diff_h = diff_h
Expand All @@ -528,10 +529,12 @@ def version_1(cls, ctx, node, **kwargs):
end_h = ctx.make_node("Add", [start_h.output[0], expect_h])
end_w = ctx.make_node("Add", [start_w.output[0], expect_w])
if spatial == 3:
output_d = GraphBuilder(ctx).make_slice(
{"data": output_shape.output[0], "ends": [4], "starts": [3], "axes": [0]})
expect_d = GraphBuilder(ctx).make_slice(
{"data": input_shape.output[0], "ends": [4], "starts": [3], "axes": [0]})
output_d = GraphBuilder(ctx).make_slice({
"data": output_shape.output[0], "ends": [sp_index_start+3], "starts": [sp_index_start+2], "axes": [0]
})
expect_d = GraphBuilder(ctx).make_slice({
"data": input_shape.output[0], "ends": [sp_index_start+3], "starts": [sp_index_start+2], "axes": [0]
})
diff_d = ctx.make_node("Sub", [output_d, expect_d])
nonneg_diff_d = diff_d
if use_strides_workaround:
Expand All @@ -543,12 +546,12 @@ def version_1(cls, ctx, node, **kwargs):
attr={"axis": 0})
ends = ctx.make_node("Concat", [end_h.output[0], end_w.output[0], end_d.output[0]], attr={"axis": 0})
slice_axes = ctx.make_const(utils.make_name(node.name + "_const_slice_axes"),
np.array([1, 2, 3], dtype=np.int64))
np.arange(sp_index_start, sp_index_start + 3, dtype=np.int64))
else:
starts = ctx.make_node("Concat", [start_h.output[0], start_w.output[0]], attr={"axis": 0})
ends = ctx.make_node("Concat", [end_h.output[0], end_w.output[0]], attr={"axis": 0})
slice_axes = ctx.make_const(utils.make_name(node.name + "_const_slice_axes"),
np.array([1, 2], dtype=np.int64))
np.arange(sp_index_start, sp_index_start + 2, dtype=np.int64))

slice_node = ctx.make_node("Slice",
[node.output[0], starts.output[0], ends.output[0], slice_axes.output[0]],
Expand All @@ -571,10 +574,16 @@ def version_1(cls, ctx, node, **kwargs):
neg_diff_d = ctx.make_node("Neg", [diff_d.output[0]])
shrink_d_by = ctx.make_node("Max", [neg_diff_d.output[0], const_zero.output[0]])
sdb = shrink_d_by.output[0]
pads = ctx.make_node("Concat", [cz, cz, cz, cz, cz, cz, shb, swb, sdb, cz], attr={"axis": 0})
if is_channels_last(node):
pads = ctx.make_node("Concat", [cz, cz, cz, cz, cz, cz, shb, swb, sdb, cz], attr={"axis": 0})
else:
pads = ctx.make_node("Concat", [cz, cz, cz, cz, cz, cz, cz, shb, swb, sdb], attr={"axis": 0})
padded_node = ctx.make_node("Pad", [slice_node.output[0], pads.output[0]])
else:
pads = ctx.make_node("Concat", [cz, cz, cz, cz, cz, shb, swb, cz], attr={"axis": 0})
if is_channels_last(node):
pads = ctx.make_node("Concat", [cz, cz, cz, cz, cz, shb, swb, cz], attr={"axis": 0})
else:
pads = ctx.make_node("Concat", [cz, cz, cz, cz, cz, cz, shb, swb], attr={"axis": 0})
padded_node = ctx.make_node("Pad", [slice_node.output[0], pads.output[0]])

final_node = padded_node
Expand Down

0 comments on commit a4327e1

Please sign in to comment.