# NOTE: `extract_image_patches` is a module-level alias selected by TensorFlow version:
# tf.image.extract_patches, tf.compat.v1.extract_image_patches or tf.extract_image_patches.
@check_opset_min_version(9, "EyeLike and ConstantOfShape")
def test_extract_image_patches(self):
    """Run the ExtractImagePatches conversion over every conv fixture and several dilation rates."""
    rate_pairs = [[1, 1], [1, 4], [4, 1], [3, 3]]
    for rates in rate_pairs:
        # The op receives a four-element rates list: the pair wrapped with 1 on both ends.
        full_rates = [1] + rates + [1]
        for _, padding, x_shape, sizes, strides in get_conv_getdata():
            # func is consumed by _run_test_case within this same iteration, so closing
            # over the loop variables is safe here.
            def func(x):
                return extract_image_patches(
                    x,
                    sizes=sizes,
                    strides=strides,
                    rates=full_rates,
                    padding=padding,
                    name=_TFOUTPUT
                )

            feed = {_INPUT: make_xval(x_shape)}
            self._run_test_case(func, [_OUTPUT], feed)
@tf_op("ExtractImagePatches")
class ExtractImagePatches:
    """Converts TF ExtractImagePatches into an ONNX Conv against an identity (one-hot) kernel."""

    @classmethod
    def version_9(cls, ctx, node, **kwargs):
        """Replace `node` with a Transpose/Reshape/Conv subgraph producing the same patches.

        Every channel is folded into the batch axis, the resulting single-channel
        images are convolved with a kernel holding one one-hot filter per patch
        position, and the result is reshaped back to the original output shape.

        Args:
            ctx: graph being converted; used to create nodes and constants.
            node: the ExtractImagePatches node; it is removed from `ctx`.
            **kwargs: unused.

        An output shape containing a 0 is rejected through `utils.make_sure`.
        """
        input_shape = ctx.get_shape(node.input[0])
        output_shape = node.output_shapes[0]

        sizes = node.get_attr_value("ksizes")
        strides = node.get_attr_value("strides")
        rates = node.get_attr_value("rates")
        padding = node.get_attr_str("padding")

        # This implementation of ExtractImagePatches does not generalize
        # to outputs that are empty. For example:
        #
        #   tf.image.extract_patches(
        #       np.random.rand(1, 1, 1, 1), sizes=[1, 2, 2, 1], strides=[1, 1, 1, 1],
        #       rates=[1, 1, 1, 1], padding="VALID"
        #   )
        #
        # succeeds in TensorFlow and returns an empty tensor (no patch fits, so the
        # spatial output dimensions are 0), whereas attempting the same here results
        # in an "Invalid input shape" error for the "Conv" node.
        utils.make_sure(0 not in output_shape, "Empty ExtractImagePatches output is unsupported.")
        [_, size_rows, size_cols, _] = sizes

        # Size of the merged batch*channel axis. Multiplying an unknown dimension
        # would yield a bogus value (e.g. -1 * 3 == -3), so leave it as -1 and let
        # Reshape infer it instead.
        # NOTE(review): assumes unknown dimensions are reported as -1 - confirm.
        batch, channels = input_shape[0], input_shape[3]
        if -1 in (batch, channels):
            batch_channels = -1
        else:
            batch_channels = batch * channels

        # Transform input into [N * C, H, W, 1].
        channels_first = ctx.make_node("Transpose", inputs=node.input, attr=dict(perm=[0, 3, 1, 2]))
        merged_shape = ctx.make_const(utils.make_name("new_shape"), np.array([
            batch_channels,
            input_shape[1],
            input_shape[2],
            1,
        ], dtype=np.int64))
        transformed_input = ctx.make_node("Reshape", inputs=[channels_first.output[0], merged_shape.output[0]])

        # Create identity kernel: a k x k identity matrix viewed as [rows, cols, 1, k],
        # i.e. one filter per patch position that copies exactly that position.
        # NOTE(review): ConstantOfShape is given no value attribute, so the kernel
        # presumably uses the ONNX default element type - confirm for non-float32 inputs.
        k = size_rows * size_cols
        eye_size = ctx.make_const(utils.make_name("eye_size"), np.array([k, k], dtype=np.int64))
        zeros = ctx.make_node("ConstantOfShape", inputs=[eye_size.output[0]])
        eye = ctx.make_node("EyeLike", inputs=[zeros.output[0]])
        kernel_dims = ctx.make_const(utils.make_name("new_shape"), np.array([
            size_rows,
            size_cols,
            1,
            k,
        ], dtype=np.int64))
        identity_kernel = ctx.make_node("Reshape", inputs=[eye.output[0], kernel_dims.output[0]])

        # Construct placeholder convolution node in TF (NHWC) form, producing
        # [N * C, ?H, ?W, K]; it is rewritten for ONNX at the end of this method.
        convolution = ctx.make_node("Conv", inputs=[transformed_input.output[0], identity_kernel.output[0]],
                                    shapes=[[batch_channels, output_shape[1], output_shape[2], k]],
                                    attr=dict(strides=strides, dilations=rates, padding=padding, data_format="NHWC"),
                                    dtypes=node.output_dtypes)

        # Split the merged axis again ([N, C, ?H, ?W, K]), move channels last
        # ([N, ?H, ?W, K, C]) and flatten the trailing two axes into the final depth.
        split_dims = ctx.make_const(utils.make_name("new_shape"), np.array([
            input_shape[0],
            input_shape[3],
            output_shape[1],
            output_shape[2],
            k,
        ], dtype=np.int64))
        split = ctx.make_node("Reshape", inputs=[convolution.output[0], split_dims.output[0]])
        channels_last = ctx.make_node("Transpose", inputs=[split.output[0]], attr=dict(perm=[0, 2, 3, 4, 1]))
        final_dims = ctx.make_const(utils.make_name("new_shape"), np.array(output_shape, dtype=np.int64))
        output_node = ctx.make_node("Reshape", inputs=[channels_last.output[0], final_dims.output[0]])

        # Replace original node.
        ctx.replace_all_inputs(node.output[0], output_node.output[0])
        ctx.remove_node(node.name)

        # Transform convolution node using the same helpers as the regular conv handlers.
        kernel_shape = conv_kernel_shape(ctx, convolution, 1)
        conv_strides = conv_dims_attr(convolution, "strides")
        conv_dilations = conv_dims_attr(convolution, "dilations")
        add_padding(ctx, convolution, kernel_shape, conv_strides, conv_dilations)
        conv_convert_inputs(ctx, convolution, with_kernel=True)