Skip to content

Commit

Permalink
Decompositions aten.eye
Browse files Browse the repository at this point in the history
  • Loading branch information
mgehre-amd committed Jul 10, 2023
1 parent a5670e8 commit 4d9fd29
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 1 deletion.
5 changes: 4 additions & 1 deletion e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
"RepeatInterleaveStaticModule_basic",
"RepeatInterleaveFillModule_basic",
# tm_tensor.scatter' op mismatch in shape of indices and update value at dim#0
"IndexPutImpl2DNoneIndexBroadcastStaticModule_basic"
"IndexPutImpl2DNoneIndexBroadcastStaticModule_basic",
# Unimplemented operator 'aten.eye.m'
"EyeStaticModule_basic",
}

TORCHDYNAMO_XFAIL_SET = {
Expand Down Expand Up @@ -1235,6 +1237,7 @@
### Tests additionally passing in make_fx_tosa
"CumsumStaticModule_basic",
"CumsumStaticNegativeDimModule_basic",
"EyeStaticModule_basic",
"NativeGroupNormBackwardModule_basic",
"SliceWholeTensorModule_basic",
"TensorFloatModule_basic",
Expand Down
1 change: 1 addition & 0 deletions python/torch_mlir/dynamo.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def _get_decomposition_table():
aten.index_select,
aten.linalg_vector_norm,
aten.index_select,
aten.eye,
]
# TODO: enable test once 2.1.0 is stable
if torch_version_for_comparison() >= version.parse("2.1.0.dev"):
Expand Down
16 changes: 16 additions & 0 deletions python/torch_mlir_e2e_test/test_suite/constant_alloc.py
Original file line number Diff line number Diff line change
Expand Up @@ -1527,3 +1527,19 @@ def forward(self, a):
@register_test_case(module_factory=lambda: NewEmptyStridedModuleDefaultDtype())
def NewEmptyStridedModuleDefaultDtype_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 3, 4))

# ==============================================================================


class EyeStaticModule(torch.nn.Module):
@export
@annotate_args([
None,
])
def forward(self):
return torch.ops.aten.eye(3, 5)


@register_test_case(module_factory=lambda: EyeStaticModule())
def EyeStaticModule_basic(module, tu: TestUtils):
module.forward()

0 comments on commit 4d9fd29

Please sign in to comment.