Decompositions aten.eye
mgehre-amd committed Jul 10, 2023
1 parent a5670e8 commit 7a48e90
Showing 4 changed files with 47 additions and 1 deletion.
8 changes: 7 additions & 1 deletion e2e_testing/xfail_sets.py
@@ -20,7 +20,9 @@
"RepeatInterleaveStaticModule_basic",
"RepeatInterleaveFillModule_basic",
# tm_tensor.scatter' op mismatch in shape of indices and update value at dim#0
"IndexPutImpl2DNoneIndexBroadcastStaticModule_basic"
"IndexPutImpl2DNoneIndexBroadcastStaticModule_basic",
# Unimplemented operator 'aten.eye.m'
"EyeStaticModule_basic",
}

TORCHDYNAMO_XFAIL_SET = {
@@ -248,6 +250,7 @@
# ERROR: 'torch.aten.sub.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float'
"ElementwiseSubScalarFloatModule_basic",
"ElementwiseSubScalarIntModule_basic",
"RsubIntStaticModule_noalpha_basic",

# ERROR: Exception: Unsupported: missing default value for argument 0 in schema for aten.div.Tensor_mode
"ElementwiseDivRoundingModeFloorModule_basic",
@@ -646,6 +649,7 @@
"RsubFloatModule_noalpha_basic",
"RsubIntModule_basic",
"RsubIntModule_noalpha_basic",
"RsubIntStaticModule_noalpha_basic",
"RsubInt0d_NumToTensor_Module_basic",
"ScalarTensorDefaultDtypeModule_basic",
"ScalarTensorFloat32Module_basic",
@@ -1235,6 +1239,7 @@
### Tests additionally passing in make_fx_tosa
"CumsumStaticModule_basic",
"CumsumStaticNegativeDimModule_basic",
"EyeStaticModule_basic",
"NativeGroupNormBackwardModule_basic",
"SliceWholeTensorModule_basic",
"TensorFloatModule_basic",
@@ -1252,6 +1257,7 @@
"NormalizeModule_basic",
"ReduceFrobeniusNormKeepDimModule_basic",
"ReduceFrobeniusNormModule_basic",
"RsubIntStaticModule_noalpha_basic",
}) - {
### Test failing in make_fx_tosa but not in tosa

1 change: 1 addition & 0 deletions python/torch_mlir/dynamo.py
@@ -69,6 +69,7 @@ def _get_decomposition_table():
aten.index_select,
aten.linalg_vector_norm,
aten.index_select,
aten.eye,
]
# TODO: enable test once 2.1.0 is stable
if torch_version_for_comparison() >= version.parse("2.1.0.dev"):
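Adding aten.eye to _get_decomposition_table() lets TorchDynamo rewrite aten.eye / aten.eye.m into simpler ops before the graph reaches torch-mlir. A minimal sketch of the idea, assuming the usual arange-and-compare formulation (the exact upstream decomposition may differ):

import torch

def eye_sketch(n, m, dtype=torch.float32):
    # Build eye(n, m) from simpler ops: broadcast-compare a column of
    # row indices against a row of column indices, then cast to dtype.
    rows = torch.arange(n).unsqueeze(-1)   # shape (n, 1)
    cols = torch.arange(m).unsqueeze(0)    # shape (1, m)
    return (rows == cols).to(dtype)

assert torch.equal(eye_sketch(3, 5), torch.eye(3, 5))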
16 changes: 16 additions & 0 deletions python/torch_mlir_e2e_test/test_suite/constant_alloc.py
@@ -1527,3 +1527,19 @@ def forward(self, a):
@register_test_case(module_factory=lambda: NewEmptyStridedModuleDefaultDtype())
def NewEmptyStridedModuleDefaultDtype_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 3, 4))

# ==============================================================================


class EyeStaticModule(torch.nn.Module):
    @export
    @annotate_args([
        None,
    ])
    def forward(self):
        return torch.ops.aten.eye(3, 5)


@register_test_case(module_factory=lambda: EyeStaticModule())
def EyeStaticModule_basic(module, tu: TestUtils):
    module.forward()
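For reference, the static sizes (3, 5) exercise the non-square aten.eye.m overload; run eagerly, the op produces a 3x5 matrix with ones on the main diagonal (a quick sanity check outside the e2e harness):

import torch

print(torch.ops.aten.eye(3, 5))
# tensor([[1., 0., 0., 0., 0.],
#         [0., 1., 0., 0., 0.],
#         [0., 0., 1., 0., 0.]])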
23 changes: 23 additions & 0 deletions python/torch_mlir_e2e_test/test_suite/elementwise.py
@@ -844,6 +844,29 @@ def forward(self, x):
def RsubIntModule_noalpha_basic(module, tu: TestUtils):
    module.forward(tu.randint(3, 4, high=100))


# ==============================================================================


class RsubIntStaticModule_noalpha(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.int64, True),
    ])
    def forward(self, x):
        return torch.rsub(x, 2.)


@register_test_case(module_factory=lambda: RsubIntStaticModule_noalpha())
def RsubIntStaticModule_noalpha_basic(module, tu: TestUtils):
    module.forward(tu.randint(3, 4, high=100))
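The "noalpha" variants call torch.rsub with only the scalar argument, so alpha defaults to 1 and the op reduces to a reversed subtraction, 2. - x. A quick eager-mode illustration of what the new static-shape test computes (toy input chosen for brevity):

import torch

x = torch.tensor([[1, 2], [3, 4]])
# rsub swaps the operand order: rsub(x, 2.) == 2. - x (alpha defaults to 1).
print(torch.rsub(x, 2.))
# tensor([[ 1.,  0.],
#         [-1., -2.]])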


# ==============================================================================


