Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
yyp0 committed Oct 31, 2024
1 parent 809442a commit 29e285b
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def aten〇silu〡shape(self: List[int]) -> List[int]:
def aten〇exp〡shape(self: List[int]) -> List[int]:
return upstream_shape_functions.unary(self)

def aten〇exp〡shape(self: List[int]) -> List[int]:
def aten〇exp2〡shape(self: List[int]) -> List[int]:
return upstream_shape_functions.unary(self)

def aten〇expm1〡shape(self: List[int]) -> List[int]:
Expand Down
23 changes: 0 additions & 23 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4373,29 +4373,6 @@ def PowIntFloatModule_basic(module, tu: TestUtils):
# ==============================================================================


class Exp2StaticModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([3, 2], torch.float32, True),
]
)
def forward(self, x):
return torch.ops.aten.exp2(x)


@register_test_case(module_factory=lambda: Exp2StaticModule())
def Exp2StaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 2))


# ==============================================================================


class BaddbmmDynamicModule(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down
23 changes: 23 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -2881,6 +2881,29 @@ def ElementwiseSgnModule_basic(module, tu: TestUtils):
# ==============================================================================


class Exp2StaticModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([3, 2], torch.float32, True),
]
)
def forward(self, x):
return torch.ops.aten.exp2(x)


@register_test_case(module_factory=lambda: Exp2StaticModule())
def Exp2StaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 2))


# ==============================================================================


class ElementwisePowModule(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down

0 comments on commit 29e285b

Please sign in to comment.