Skip to content

Commit

Permalink
[Relax] Add support for ONNX LPPool
Browse files Browse the repository at this point in the history
adding support for ONNX LPPool and refactoring frontend tests
  • Loading branch information
Hzfengsy committed Nov 21, 2024
1 parent 2123b8c commit 44a3e11
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 181 deletions.
22 changes: 20 additions & 2 deletions python/tvm/relax/frontend/onnx/onnx_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2254,6 +2254,7 @@ def _impl_v1(cls, bb, inputs, attr, params):
kernel_shape = attr.get("kernel_shape")
pads = attr.get("pads", 0)
strides = attr.get("strides", [1] * (ndim - 2))
count_include_pad = attr.get("count_include_pad", False)

assert len(kernel_shape) in [1, 2, 3], "Currently only 1D/2D/3D/ pooling is supported."

Expand Down Expand Up @@ -2298,7 +2299,7 @@ def _impl_v1(cls, bb, inputs, attr, params):
pads = tuple([val for pair in zip(*pads) for val in pair])

op = getattr(relax.op.nn, cls.name + str(len(kernel_shape)) + "d")
return op(data, kernel_shape, strides, pads, dilations, ceil_mode)
return op(data, kernel_shape, strides, pads, dilations, ceil_mode, count_include_pad)

@classmethod
def _get_input_spatial_shape(cls, tensor):
Expand All @@ -2318,6 +2319,23 @@ class AveragePool(Pool):
name = "avg_pool"


class LpPool(OnnxOpConverter):
"""Converts an onnx LpPool node into an equivalent Relax expression."""

@classmethod
def _impl_v1(cls, bb, inputs, attr, params):
dtype = inputs[0].struct_info.dtype
p = attr.get("p", 2.0)
reci_p = relax.const(1.0 / p, dtype=dtype)
# emit for get struct_info
data = bb.emit(relax.op.power(inputs[0], relax.const(p, dtype=dtype)))
attr.update({"count_include_pad": True})
avg_pool = AveragePool._impl_v1(bb, [data], attr, params)
kernels = attr["kernel_shape"]
out = avg_pool * relax.const(_np.prod(kernels).astype(dtype))
return relax.op.power(out, reci_p)


class GlobalAveragePool(OnnxOpConverter):
"""Converts an onnx GlobalAveragePool node into an equivalent Relax expression."""

Expand Down Expand Up @@ -3172,7 +3190,7 @@ def _get_convert_map():
"Tile": Tile,
"AveragePool": AveragePool,
"MaxPool": MaxPool,
# "LpPool": LpPool,
"LpPool": LpPool,
"GlobalAveragePool": GlobalAveragePool,
"GlobalMaxPool": GlobalMaxPool,
"GlobalLpPool": GlobalLpPool,
Expand Down
221 changes: 42 additions & 179 deletions tests/python/relax/test_frontend_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -2272,194 +2272,57 @@ def test_batch_norm():
check_correctness(model, opset=15)


def test_maxpool_and_averagepool():
for pool_name in ["MaxPool", "AveragePool"]:
@pytest.mark.parametrize("pool_name", ["MaxPool", "AveragePool", "LpPool"])
@pytest.mark.parametrize(
"shape, auto_pad, kernel_shape, strides, pads",
[
# Pool1D
verify_unary(
pool_name,
[1, 1, 32],
dict(
auto_pad="NOTSET",
kernel_shape=[3],
pads=[1, 1],
strides=[1],
),
)
([1, 1, 32], "NOTSET", [3], [1], [1, 1]),
# Pool1D with stride
verify_unary(
pool_name,
[1, 1, 32],
dict(
auto_pad="NOTSET",
kernel_shape=[3],
pads=[1, 2],
strides=[2],
),
)
([1, 1, 32], "NOTSET", [3], [2], [1, 1]),
# Pool1D with stride and autopadding
verify_unary(
pool_name,
[1, 1, 32],
dict(
auto_pad="SAME_UPPER",
kernel_shape=[7],
pads=None,
strides=[2],
),
)
verify_unary(
pool_name,
[1, 1, 32],
dict(
auto_pad="SAME_LOWER",
kernel_shape=[4],
pads=None,
strides=[4],
),
)
verify_unary(
pool_name,
[1, 1, 32],
dict(
auto_pad="VALID",
kernel_shape=[5],
pads=None,
strides=[5],
),
)
verify_unary(
pool_name,
[1, 1, 32],
dict(
auto_pad="SAME_UPPER",
kernel_shape=[3],
pads=None,
),
)
([1, 1, 32], "SAME_UPPER", [7], [2], None),
([1, 1, 32], "SAME_LOWER", [4], [4], None),
([1, 1, 32], "VALID", [5], [5], None),
([1, 1, 32], "SAME_UPPER", [3], [1], None),
# Pool2D
verify_unary(
pool_name,
[1, 1, 32, 32],
dict(
auto_pad="NOTSET",
kernel_shape=[3, 3],
pads=[1, 1, 1, 1],
strides=[1, 1],
),
)
([1, 1, 32, 32], "NOTSET", [3, 3], [1, 1], [1, 1, 1, 1]),
# Pool2D with stride
verify_unary(
pool_name,
[1, 1, 32, 32],
dict(
auto_pad="NOTSET",
kernel_shape=[3, 3],
pads=[1, 1, 1, 1],
strides=[2, 2],
),
)
([1, 1, 32, 32], "NOTSET", [3, 3], [2, 2], [1, 1, 1, 1]),
# Pool2D with stride and autopadding
verify_unary(
pool_name,
[1, 1, 32, 32],
dict(
auto_pad="SAME_UPPER",
kernel_shape=[3, 7],
pads=None,
strides=[3, 2],
),
)
verify_unary(
pool_name,
[1, 1, 32, 32],
dict(
auto_pad="SAME_LOWER",
kernel_shape=[3, 3],
pads=None,
strides=[2, 2],
),
)
verify_unary(
pool_name,
[1, 1, 32, 32],
dict(
auto_pad="VALID",
kernel_shape=[3, 3],
pads=None,
strides=[2, 2],
),
)
verify_unary(
pool_name,
[1, 1, 32, 32],
dict(
auto_pad="SAME_UPPER",
kernel_shape=[3, 3],
pads=None,
),
)
([1, 1, 32, 32], "SAME_UPPER", [3, 7], [3, 2], None),
([1, 1, 32, 32], "SAME_LOWER", [3, 3], [2, 2], None),
([1, 1, 32, 32], "VALID", [3, 3], [2, 2], None),
([1, 1, 32, 32], "SAME_UPPER", [3, 3], [1, 1], None),
# Pool3D
verify_unary(
pool_name,
[1, 1, 32, 32, 32],
dict(
auto_pad="NOTSET",
kernel_shape=[3, 3, 4],
pads=[1, 2, 1, 1, 2, 2],
strides=[1, 1, 1],
),
)
([1, 1, 32, 32, 32], "NOTSET", [3, 3, 4], [1, 1, 1], [1, 2, 1, 1, 2, 2]),
# Pool3D with stride
verify_unary(
pool_name,
[1, 1, 32, 32, 32],
dict(
auto_pad="NOTSET",
kernel_shape=[3, 4, 3],
pads=[1, 1, 1, 1, 1, 2],
strides=[2, 2, 3],
),
)
([1, 1, 32, 32, 32], "NOTSET", [3, 4, 3], [2, 2, 3], [1, 1, 1, 1, 1, 2]),
# Pool3D with stride and autopadding
verify_unary(
pool_name,
[1, 1, 32, 32, 32],
dict(
auto_pad="SAME_UPPER",
kernel_shape=[4, 3, 3],
pads=None,
strides=[3, 2, 2],
),
)
verify_unary(
pool_name,
[1, 1, 32, 32, 32],
dict(
auto_pad="SAME_LOWER",
kernel_shape=[3, 3, 4],
pads=None,
strides=[2, 2, 2],
),
)
verify_unary(
pool_name,
[1, 1, 32, 32, 32],
dict(
auto_pad="VALID",
kernel_shape=[3, 3, 5],
pads=None,
strides=[2, 2, 3],
),
)
verify_unary(
pool_name,
[1, 1, 32, 32, 32],
dict(
auto_pad="SAME_UPPER",
kernel_shape=[3, 3, 5],
pads=None,
),
)
([1, 1, 32, 32, 32], "SAME_UPPER", [4, 3, 3], [3, 2, 2], None),
([1, 1, 32, 32, 32], "SAME_LOWER", [3, 3, 4], [2, 2, 2], None),
([1, 1, 32, 32, 32], "VALID", [3, 3, 5], [2, 2, 3], None),
([1, 1, 32, 32, 32], "SAME_UPPER", [3, 3, 5], [1, 1, 1], None),
],
)
def test_pool(
pool_name: str,
shape: List[int],
auto_pad: str,
kernel_shape: List[int],
strides: List[int],
pads: List[int],
):
verify_unary(
pool_name,
shape,
attrs={
"kernel_shape": kernel_shape,
"strides": strides,
"pads": pads,
"auto_pad": auto_pad,
},
)


def test_global_average_pool():
Expand Down

0 comments on commit 44a3e11

Please sign in to comment.