diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index b64e87822a0a..349b8695a36b 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -2254,6 +2254,7 @@ def _impl_v1(cls, bb, inputs, attr, params): kernel_shape = attr.get("kernel_shape") pads = attr.get("pads", 0) strides = attr.get("strides", [1] * (ndim - 2)) + count_include_pad = attr.get("count_include_pad", False) assert len(kernel_shape) in [1, 2, 3], "Currently only 1D/2D/3D/ pooling is supported." @@ -2298,7 +2299,7 @@ def _impl_v1(cls, bb, inputs, attr, params): pads = tuple([val for pair in zip(*pads) for val in pair]) op = getattr(relax.op.nn, cls.name + str(len(kernel_shape)) + "d") - return op(data, kernel_shape, strides, pads, dilations, ceil_mode) + return op(data, kernel_shape, strides, pads, dilations, ceil_mode, count_include_pad) @classmethod def _get_input_spatial_shape(cls, tensor): @@ -2318,6 +2319,23 @@ class AveragePool(Pool): name = "avg_pool" +class LpPool(OnnxOpConverter): + """Converts an onnx LpPool node into an equivalent Relax expression.""" + + @classmethod + def _impl_v1(cls, bb, inputs, attr, params): + dtype = inputs[0].struct_info.dtype + p = attr.get("p", 2.0) + reci_p = relax.const(1.0 / p, dtype=dtype) + # emit for get struct_info + data = bb.emit(relax.op.power(inputs[0], relax.const(p, dtype=dtype))) + attr.update({"count_include_pad": True}) + avg_pool = AveragePool._impl_v1(bb, [data], attr, params) + kernels = attr["kernel_shape"] + out = avg_pool * relax.const(_np.prod(kernels).astype(dtype)) + return relax.op.power(out, reci_p) + + class GlobalAveragePool(OnnxOpConverter): """Converts an onnx GlobalAveragePool node into an equivalent Relax expression.""" @@ -3172,7 +3190,7 @@ def _get_convert_map(): "Tile": Tile, "AveragePool": AveragePool, "MaxPool": MaxPool, - # "LpPool": LpPool, + "LpPool": LpPool, "GlobalAveragePool": GlobalAveragePool, "GlobalMaxPool": GlobalMaxPool, "GlobalLpPool": GlobalLpPool, diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index 89f08e5af91f..d528d513202a 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -2272,194 +2272,57 @@ def test_batch_norm(): check_correctness(model, opset=15) -def test_maxpool_and_averagepool(): - for pool_name in ["MaxPool", "AveragePool"]: +@pytest.mark.parametrize("pool_name", ["MaxPool", "AveragePool", "LpPool"]) +@pytest.mark.parametrize( + "shape, auto_pad, kernel_shape, strides, pads", + [ # Pool1D - verify_unary( - pool_name, - [1, 1, 32], - dict( - auto_pad="NOTSET", - kernel_shape=[3], - pads=[1, 1], - strides=[1], - ), - ) + ([1, 1, 32], "NOTSET", [3], [1], [1, 1]), # Pool1D with stride - verify_unary( - pool_name, - [1, 1, 32], - dict( - auto_pad="NOTSET", - kernel_shape=[3], - pads=[1, 2], - strides=[2], - ), - ) + ([1, 1, 32], "NOTSET", [3], [2], [1, 1]), # Pool1D with stride and autopadding - verify_unary( - pool_name, - [1, 1, 32], - dict( - auto_pad="SAME_UPPER", - kernel_shape=[7], - pads=None, - strides=[2], - ), - ) - verify_unary( - pool_name, - [1, 1, 32], - dict( - auto_pad="SAME_LOWER", - kernel_shape=[4], - pads=None, - strides=[4], - ), - ) - verify_unary( - pool_name, - [1, 1, 32], - dict( - auto_pad="VALID", - kernel_shape=[5], - pads=None, - strides=[5], - ), - ) - verify_unary( - pool_name, - [1, 1, 32], - dict( - auto_pad="SAME_UPPER", - kernel_shape=[3], - pads=None, - ), - ) + ([1, 1, 32], "SAME_UPPER", [7], [2], None), + ([1, 1, 32], "SAME_LOWER", [4], [4], None), + ([1, 1, 32], "VALID", [5], [5], None), + ([1, 1, 32], "SAME_UPPER", [3], [1], None), # Pool2D - verify_unary( - pool_name, - [1, 1, 32, 32], - dict( - auto_pad="NOTSET", - kernel_shape=[3, 3], - pads=[1, 1, 1, 1], - strides=[1, 1], - ), - ) + ([1, 1, 32, 32], "NOTSET", [3, 3], [1, 1], [1, 1, 1, 1]), # Pool2D with stride - verify_unary( - pool_name, - [1, 1, 32, 32], - dict( - auto_pad="NOTSET", - kernel_shape=[3, 3], - pads=[1, 1, 1, 1], - strides=[2, 2], - ), - ) + ([1, 1, 32, 32], "NOTSET", [3, 3], [2, 2], [1, 1, 1, 1]), # Pool2D with stride and autopadding - verify_unary( - pool_name, - [1, 1, 32, 32], - dict( - auto_pad="SAME_UPPER", - kernel_shape=[3, 7], - pads=None, - strides=[3, 2], - ), - ) - verify_unary( - pool_name, - [1, 1, 32, 32], - dict( - auto_pad="SAME_LOWER", - kernel_shape=[3, 3], - pads=None, - strides=[2, 2], - ), - ) - verify_unary( - pool_name, - [1, 1, 32, 32], - dict( - auto_pad="VALID", - kernel_shape=[3, 3], - pads=None, - strides=[2, 2], - ), - ) - verify_unary( - pool_name, - [1, 1, 32, 32], - dict( - auto_pad="SAME_UPPER", - kernel_shape=[3, 3], - pads=None, - ), - ) + ([1, 1, 32, 32], "SAME_UPPER", [3, 7], [3, 2], None), + ([1, 1, 32, 32], "SAME_LOWER", [3, 3], [2, 2], None), + ([1, 1, 32, 32], "VALID", [3, 3], [2, 2], None), + ([1, 1, 32, 32], "SAME_UPPER", [3, 3], [1, 1], None), # Pool3D - verify_unary( - pool_name, - [1, 1, 32, 32, 32], - dict( - auto_pad="NOTSET", - kernel_shape=[3, 3, 4], - pads=[1, 2, 1, 1, 2, 2], - strides=[1, 1, 1], - ), - ) + ([1, 1, 32, 32, 32], "NOTSET", [3, 3, 4], [1, 1, 1], [1, 2, 1, 1, 2, 2]), # Pool3D with stride - verify_unary( - pool_name, - [1, 1, 32, 32, 32], - dict( - auto_pad="NOTSET", - kernel_shape=[3, 4, 3], - pads=[1, 1, 1, 1, 1, 2], - strides=[2, 2, 3], - ), - ) + ([1, 1, 32, 32, 32], "NOTSET", [3, 4, 3], [2, 2, 3], [1, 1, 1, 1, 1, 2]), # Pool3D with stride and autopadding - verify_unary( - pool_name, - [1, 1, 32, 32, 32], - dict( - auto_pad="SAME_UPPER", - kernel_shape=[4, 3, 3], - pads=None, - strides=[3, 2, 2], - ), - ) - verify_unary( - pool_name, - [1, 1, 32, 32, 32], - dict( - auto_pad="SAME_LOWER", - kernel_shape=[3, 3, 4], - pads=None, - strides=[2, 2, 2], - ), - ) - verify_unary( - pool_name, - [1, 1, 32, 32, 32], - dict( - auto_pad="VALID", - kernel_shape=[3, 3, 5], - pads=None, - strides=[2, 2, 3], - ), - ) - verify_unary( - pool_name, - [1, 1, 32, 32, 32], - dict( - auto_pad="SAME_UPPER", - kernel_shape=[3, 3, 5], - pads=None, - ), - ) + ([1, 1, 32, 32, 32], "SAME_UPPER", [4, 3, 3], [3, 2, 2], None), + ([1, 1, 32, 32, 32], "SAME_LOWER", [3, 3, 4], [2, 2, 2], None), + ([1, 1, 32, 32, 32], "VALID", [3, 3, 5], [2, 2, 3], None), + ([1, 1, 32, 32, 32], "SAME_UPPER", [3, 3, 5], [1, 1, 1], None), + ], +) +def test_pool( + pool_name: str, + shape: List[int], + auto_pad: str, + kernel_shape: List[int], + strides: List[int], + pads: List[int], +): + verify_unary( + pool_name, + shape, + attrs={ + "kernel_shape": kernel_shape, + "strides": strides, + "pads": pads, + "auto_pad": auto_pad, + }, + ) def test_global_average_pool():