From 29e285b6149e59e0eb4d9575b83ed2d8f2d06f8b Mon Sep 17 00:00:00 2001 From: yyp0 Date: Thu, 31 Oct 2024 17:34:35 +0800 Subject: [PATCH] update --- .../build_tools/abstract_interp_lib_gen.py | 2 +- .../torch_mlir_e2e_test/test_suite/basic.py | 23 ------------------- .../test_suite/elementwise.py | 23 +++++++++++++++++++ 3 files changed, 24 insertions(+), 24 deletions(-) diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 2e5f8480a042..1bb4266d518a 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -216,7 +216,7 @@ def aten〇silu〡shape(self: List[int]) -> List[int]: def aten〇exp〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) -def aten〇exp〡shape(self: List[int]) -> List[int]: +def aten〇exp2〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) def aten〇expm1〡shape(self: List[int]) -> List[int]: diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py index 323773cc2a99..bef16f3efcd7 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py @@ -4373,29 +4373,6 @@ def PowIntFloatModule_basic(module, tu: TestUtils): # ============================================================================== -class Exp2StaticModule(torch.nn.Module): - def __init__(self): - super().__init__() - - @export - @annotate_args( - [ - None, - ([3, 2], torch.float32, True), - ] - ) - def forward(self, x): - return torch.ops.aten.exp2(x) - - -@register_test_case(module_factory=lambda: Exp2StaticModule()) -def Exp2StaticModule_basic(module, tu: TestUtils): - module.forward(tu.rand(3, 2)) - - -# ============================================================================== - - class BaddbmmDynamicModule(torch.nn.Module): def __init__(self): super().__init__() diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py index a62b901a91ec..e9098698f38f 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -2881,6 +2881,29 @@ def ElementwiseSgnModule_basic(module, tu: TestUtils): # ============================================================================== +class Exp2StaticModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([3, 2], torch.float32, True), + ] + ) + def forward(self, x): + return torch.ops.aten.exp2(x) + + +@register_test_case(module_factory=lambda: Exp2StaticModule()) +def Exp2StaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 2)) + + +# ============================================================================== + + class ElementwisePowModule(torch.nn.Module): def __init__(self): super().__init__()