diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py index 7ace41ffead3..1d4a74349b8d 100644 --- a/e2e_testing/xfail_sets.py +++ b/e2e_testing/xfail_sets.py @@ -20,7 +20,9 @@ "RepeatInterleaveStaticModule_basic", "RepeatInterleaveFillModule_basic", # tm_tensor.scatter' op mismatch in shape of indices and update value at dim#0 - "IndexPutImpl2DNoneIndexBroadcastStaticModule_basic" + "IndexPutImpl2DNoneIndexBroadcastStaticModule_basic", + # Unimplemented operator 'aten.eye.m' + "EyeStaticModule_basic", } TORCHDYNAMO_XFAIL_SET = { @@ -1235,6 +1237,7 @@ ### Tests additionally passing in make_fx_tosa "CumsumStaticModule_basic", "CumsumStaticNegativeDimModule_basic", + "EyeStaticModule_basic", "NativeGroupNormBackwardModule_basic", "SliceWholeTensorModule_basic", "TensorFloatModule_basic", diff --git a/python/torch_mlir/dynamo.py b/python/torch_mlir/dynamo.py index 9ae51e3b7ca4..023af1faa7df 100644 --- a/python/torch_mlir/dynamo.py +++ b/python/torch_mlir/dynamo.py @@ -69,6 +69,7 @@ def _get_decomposition_table(): aten.index_select, aten.linalg_vector_norm, aten.index_select, + aten.eye, ] # TODO: enable test once 2.1.0 is stable if torch_version_for_comparison() >= version.parse("2.1.0.dev"): diff --git a/python/torch_mlir_e2e_test/test_suite/constant_alloc.py b/python/torch_mlir_e2e_test/test_suite/constant_alloc.py index b50a2a1f02cd..1b92c8f17135 100644 --- a/python/torch_mlir_e2e_test/test_suite/constant_alloc.py +++ b/python/torch_mlir_e2e_test/test_suite/constant_alloc.py @@ -1527,3 +1527,19 @@ def forward(self, a): @register_test_case(module_factory=lambda: NewEmptyStridedModuleDefaultDtype()) def NewEmptyStridedModuleDefaultDtype_basic(module, tu: TestUtils): module.forward(tu.rand(2, 3, 4)) + +# ============================================================================== + + +class EyeStaticModule(torch.nn.Module): + @export + @annotate_args([ + None, + ]) + def forward(self): + return torch.ops.aten.eye(3, 5) + + +@register_test_case(module_factory=lambda: EyeStaticModule()) +def EyeStaticModule_basic(module, tu: TestUtils): + module.forward()