Skip to content

Commit

Permalink
[Linalg] Add an avg_pool2d e2e tests
Browse files Browse the repository at this point in the history
  • Loading branch information
AmosLewis committed Apr 25, 2024
1 parent 2eac8a9 commit 38627d4
Showing 1 changed file with 23 additions and 0 deletions.
23 changes: 23 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py
Original file line number Diff line number Diff line change
Expand Up @@ -841,6 +841,29 @@ def forward(self, x):
def AvgPool2dStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 2, 10, 20, low=-1))

class AvgPool2dFloatStaticModule(torch.nn.Module):

def __init__(self):
super().__init__()
self.ap2d = torch.nn.AvgPool2d(kernel_size=[3, 3],
stride=[1, 1],
padding=[1, 1],
ceil_mode=False,
count_include_pad=False,
divisor_override=None)

@export
@annotate_args([
None,
([32, 384, 25, 25], torch.float32, True),
])
def forward(self, x):
return self.ap2d(x)


@register_test_case(module_factory=lambda: AvgPool2dFloatStaticModule())
def AvgPool2dFloatStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(32, 384, 25, 25, low=-1))

class AvgPool2dDivisorOverrideModule(torch.nn.Module):

Expand Down

0 comments on commit 38627d4

Please sign in to comment.