From 718f53ff8a94baf3b7d0c4f307484171e6546d2a Mon Sep 17 00:00:00 2001 From: Ramiro Leal-Cavazos Date: Mon, 17 Jul 2023 09:52:04 -0700 Subject: [PATCH] Fix handling of `!torch.number` in abstract interpretation library (#2309) In PyTorch, the `NumberType` is equal to `Union[int, float, complex]`. However, the abstract interpretation library was treating the `NumberType` as `Union[int, float]`, resulting in type mismatches when reifying certain dtype functions. This commit fixes the type inconsistency by having the abstract interpretation functions take as an input a `Union[int, float, complex]` for the ops that take `!torch.number` inputs. --- .../Transforms/AbstractInterpLibrary.cpp | 158 +++++++++--------- .../ReifyAbstractInterpCalculationsUtils.cpp | 15 +- .../build_tools/abstract_interp_lib_gen.py | 104 ++++++------ .../jit_ir/build_tools/library_generator.py | 2 +- .../importer/jit_ir/build_tools/registry.py | 2 +- .../Torch/reify-dtype-calculations.mlir | 15 ++ 6 files changed, 159 insertions(+), 137 deletions(-) diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index d3209bb115c2..5891e88a1c68 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -7287,7 +7287,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = torch.prim.ListConstruct : () -> !torch.list\n" " return %0 : !torch.list\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.scalar_tensor\"(%arg0: !torch.union, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.scalar_tensor\"(%arg0: !torch.number, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional) -> !torch.int {\n" " %int6 = torch.constant.int 6\n" " %none = torch.constant.none\n" " %0 = torch.aten.__isnot__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" @@ -8026,7 +8026,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" " return %1 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.softplus\"(%arg0: !torch.tuple, %arg1: !torch.union, %arg2: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.softplus\"(%arg0: !torch.tuple, %arg1: !torch.number, %arg2: !torch.number) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" " %2 = torch.prim.If %1 -> (!torch.int) {\n" @@ -8166,7 +8166,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.clamp_max\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.clamp_max\"(%arg0: !torch.tuple, %arg1: !torch.number) -> !torch.int {\n" " %int4 = torch.constant.int 4\n" " %int11 = torch.constant.int 11\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" @@ -8178,7 +8178,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %2 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.clamp_min\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.clamp_min\"(%arg0: !torch.tuple, %arg1: !torch.number) -> !torch.int {\n" " %int4 = torch.constant.int 4\n" " %int11 = torch.constant.int 11\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" @@ -8190,7 +8190,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %2 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.clamp\"(%arg0: !torch.tuple, %arg1: !torch.union, %arg2: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.clamp\"(%arg0: !torch.tuple, %arg1: !torch.optional, %arg2: !torch.optional) -> !torch.int {\n" " %int4 = torch.constant.int 4\n" " %int11 = torch.constant.int 11\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" @@ -8206,7 +8206,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.constant_pad_nd\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.constant_pad_nd\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.number) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" @@ -8263,7 +8263,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.fill.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.fill.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.number) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" @@ -8315,7 +8315,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.hardtanh_backward\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.union, %arg3: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.hardtanh_backward\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.number, %arg3: !torch.number) -> !torch.int {\n" " %int6 = torch.constant.int 6\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" @@ -8326,7 +8326,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %2 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.hardtanh\"(%arg0: !torch.tuple, %arg1: !torch.union, %arg2: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.hardtanh\"(%arg0: !torch.tuple, %arg1: !torch.number, %arg2: !torch.number) -> !torch.int {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" " %int11 = torch.constant.int 11\n" @@ -8381,7 +8381,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %0#1 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.leaky_relu_backward\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.union, %arg3: !torch.bool) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.leaky_relu_backward\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.number, %arg3: !torch.bool) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" " %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" @@ -8396,11 +8396,11 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " func.func @\"__torch_mlir_dtype_fn.aten._log_softmax_backward_data\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" " return %arg3 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.masked_fill.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.masked_fill.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.number) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.masked_fill_.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.masked_fill_.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.number) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" @@ -8544,7 +8544,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.scatter.value\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.tuple, %arg3: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.scatter.value\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.tuple, %arg3: !torch.number) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" @@ -8591,7 +8591,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.threshold\"(%arg0: !torch.tuple, %arg1: !torch.union, %arg2: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.threshold\"(%arg0: !torch.tuple, %arg1: !torch.number, %arg2: !torch.number) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" @@ -8647,12 +8647,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.prim.abs.Scalar\"(%arg0: !torch.union) -> !torch.int {\n" -" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.union) -> !torch.int\n" +" func.func @\"__torch_mlir_dtype_fn.prim.abs.Scalar\"(%arg0: !torch.number) -> !torch.int {\n" +" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.number) -> !torch.int\n" " return %0 : !torch.int\n" " }\n" -" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0: !torch.union) -> !torch.int {\n" -" %0 = torch.prim.NumToTensor.Scalar %arg0 : !torch.union -> !torch.tensor\n" +" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0: !torch.number) -> !torch.int {\n" +" %0 = torch.prim.NumToTensor.Scalar %arg0 : !torch.number -> !torch.tensor\n" " %1 = torch.prim.dtype %0 : !torch.tensor -> !torch.int\n" " return %1 : !torch.int\n" " }\n" @@ -8710,7 +8710,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %2 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.eq.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.eq.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.number) -> !torch.int {\n" " %int11 = torch.constant.int 11\n" " return %int11 : !torch.int\n" " }\n" @@ -8718,11 +8718,11 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %int11 = torch.constant.int 11\n" " return %int11 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.ge.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.ge.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.number) -> !torch.int {\n" " %int11 = torch.constant.int 11\n" " return %int11 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.gt.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.gt.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.number) -> !torch.int {\n" " %int11 = torch.constant.int 11\n" " return %int11 : !torch.int\n" " }\n" @@ -8734,7 +8734,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %int11 = torch.constant.int 11\n" " return %int11 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.le.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.le.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.number) -> !torch.int {\n" " %int11 = torch.constant.int 11\n" " return %int11 : !torch.int\n" " }\n" @@ -8758,7 +8758,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %int11 = torch.constant.int 11\n" " return %int11 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.lt.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.lt.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.number) -> !torch.int {\n" " %int11 = torch.constant.int 11\n" " return %int11 : !torch.int\n" " }\n" @@ -8778,15 +8778,15 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %int11 = torch.constant.int 11\n" " return %int11 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.ne.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.ne.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.number) -> !torch.int {\n" " %int11 = torch.constant.int 11\n" " return %int11 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.add\"(%arg0: !torch.union, %arg1: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.add\"(%arg0: !torch.number, %arg1: !torch.number) -> !torch.int {\n" " %none = torch.constant.none\n" " %0 = torch.prim.ListConstruct %none, %none : (!torch.none, !torch.none) -> !torch.list>\n" -" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.union) -> !torch.int\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.number) -> !torch.int\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n" " %3 = torch.prim.ListConstruct %1, %2 : (!torch.int, !torch.int) -> !torch.list\n" " %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%0, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" @@ -8835,11 +8835,11 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %3 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.rsub.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union, %arg2: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.rsub.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.number, %arg2: !torch.number) -> !torch.int {\n" " %none = torch.constant.none\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n" " %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list\n" " %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" @@ -8852,7 +8852,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.add.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.add.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.number) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" " %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" @@ -9176,7 +9176,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.sub.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.sub.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.number) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" " %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" @@ -9184,7 +9184,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.threshold_backward\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.threshold_backward\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.number) -> !torch.int {\n" " %str = torch.constant.str \"AssertionError: Result dtype for aten.threshold_backward cannot be bool or float16\"\n" " %int11 = torch.constant.int 11\n" " %str_0 = torch.constant.str \"AssertionError: `self` cannot be complex\"\n" @@ -9358,7 +9358,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %4 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.addmm\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.union, %arg4: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.addmm\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.number, %arg4: !torch.number) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" " %2:2 = torch.prim.TupleUnpack %arg2 : !torch.tuple -> !torch.int, !torch.int\n" @@ -9376,7 +9376,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%3, %4) : (!torch.list>, !torch.list) -> !torch.int\n" " return %5 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.addcmul\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.addcmul\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.number) -> !torch.int {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" " %int11 = torch.constant.int 11\n" @@ -9409,7 +9409,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %8 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%6, %7) : (!torch.list>, !torch.list) -> !torch.int\n" " return %8 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.addcdiv\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.addcdiv\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.number) -> !torch.int {\n" " %int6 = torch.constant.int 6\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" @@ -9425,39 +9425,39 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %7 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.add.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union, %arg2: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.add.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.number, %arg2: !torch.number) -> !torch.int {\n" " %none = torch.constant.none\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n" " %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list\n" " %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.sub.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union, %arg2: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.sub.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.number, %arg2: !torch.number) -> !torch.int {\n" " %none = torch.constant.none\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n" " %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list\n" " %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.mul.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.mul.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.number) -> !torch.int {\n" " %none = torch.constant.none\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n" " %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list\n" " %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.div.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.div.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.number) -> !torch.int {\n" " %int6 = torch.constant.int 6\n" " %none = torch.constant.none\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n" " %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list\n" " %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%4) : (!torch.int) -> !torch.bool\n" @@ -9468,16 +9468,16 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %6 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.fmod.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.fmod.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.number) -> !torch.int {\n" " %none = torch.constant.none\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n" " %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list\n" " %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.floor_divide.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.floor_divide.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.number) -> !torch.int {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" @@ -9490,21 +9490,21 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " torch.prim.If.yield\n" " }\n" " %3 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" -" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n" " %5 = torch.prim.ListConstruct %0#1, %4 : (!torch.int, !torch.int) -> !torch.list\n" " %6 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%3, %5) : (!torch.list>, !torch.list) -> !torch.int\n" " return %6 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.pow.Tensor_Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.pow.Tensor_Scalar\"(%arg0: !torch.tuple, %arg1: !torch.number) -> !torch.int {\n" " %none = torch.constant.none\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n" " %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list\n" " %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.leaky_relu\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.leaky_relu\"(%arg0: !torch.tuple, %arg1: !torch.number) -> !torch.int {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" " %int11 = torch.constant.int 11\n" @@ -9517,7 +9517,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " torch.prim.If.yield\n" " }\n" " %2 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" -" %3 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %3 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n" " %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%3) : (!torch.int) -> !torch.bool\n" " torch.prim.If %4 -> () {\n" " %7 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" @@ -9536,16 +9536,16 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %6 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %5) : (!torch.list>, !torch.list) -> !torch.int\n" " return %6 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.remainder.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.remainder.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.number) -> !torch.int {\n" " %none = torch.constant.none\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n" " %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list\n" " %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.baddbmm\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.union, %arg4: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.baddbmm\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.number, %arg4: !torch.number) -> !torch.int {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" " %int5 = torch.constant.int 5\n" @@ -9590,14 +9590,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.where.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union, %arg2: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.where.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.number, %arg2: !torch.number) -> !torch.int {\n" " %int6 = torch.constant.int 6\n" " %int4 = torch.constant.int 4\n" " %false = torch.constant.bool false\n" -" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n" " %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0) : (!torch.int) -> !torch.bool\n" " %2 = torch.prim.If %1 -> (!torch.bool) {\n" -" %4 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg2) : (!torch.union) -> !torch.int\n" +" %4 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg2) : (!torch.number) -> !torch.int\n" " %5 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%4) : (!torch.int) -> !torch.bool\n" " torch.prim.If.yield %5 : !torch.bool\n" " } else {\n" @@ -9610,20 +9610,20 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %3 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.where.ScalarOther\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.where.ScalarOther\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.number) -> !torch.int {\n" " %none = torch.constant.none\n" " %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" " %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg2) : (!torch.union) -> !torch.int\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg2) : (!torch.number) -> !torch.int\n" " %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list\n" " %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.where.ScalarSelf\"(%arg0: !torch.tuple, %arg1: !torch.union, %arg2: !torch.tuple) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.where.ScalarSelf\"(%arg0: !torch.tuple, %arg1: !torch.number, %arg2: !torch.tuple) -> !torch.int {\n" " %none = torch.constant.none\n" " %0:2 = torch.prim.TupleUnpack %arg2 : !torch.tuple -> !torch.int, !torch.int\n" " %1 = torch.prim.ListConstruct %none, %0#0 : (!torch.none, !torch.int) -> !torch.list>\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n" " %3 = torch.prim.ListConstruct %2, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" " %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" @@ -9701,7 +9701,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %3 = torch.prim.TupleConstruct %0#1, %0#1, %2 : !torch.int, !torch.int, !torch.int -> !torch.tuple\n" " return %3 : !torch.tuple\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.arange\"(%arg0: !torch.union, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.arange\"(%arg0: !torch.number, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional) -> !torch.int {\n" " %int4 = torch.constant.int 4\n" " %int6 = torch.constant.int 6\n" " %str = torch.constant.str \"AssertionError: \"\n" @@ -9719,7 +9719,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " torch.prim.If.yield %2 : !torch.int\n" " } else {\n" -" %2 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.union) -> !torch.int\n" +" %2 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.number) -> !torch.int\n" " %3 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%2) : (!torch.int) -> !torch.bool\n" " %4 = torch.prim.If %3 -> (!torch.int) {\n" " torch.prim.If.yield %int6 : !torch.int\n" @@ -9730,7 +9730,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %1 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.arange.start\"(%arg0: !torch.union, %arg1: !torch.union, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.arange.start\"(%arg0: !torch.number, %arg1: !torch.number, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.int {\n" " %int4 = torch.constant.int 4\n" " %int6 = torch.constant.int 6\n" " %true = torch.constant.bool true\n" @@ -9749,12 +9749,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " torch.prim.If.yield %2 : !torch.int\n" " } else {\n" -" %2 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.union) -> !torch.int\n" +" %2 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.number) -> !torch.int\n" " %3 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%2) : (!torch.int) -> !torch.bool\n" " %4 = torch.prim.If %3 -> (!torch.bool) {\n" " torch.prim.If.yield %true : !torch.bool\n" " } else {\n" -" %6 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %6 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n" " %7 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%6) : (!torch.int) -> !torch.bool\n" " torch.prim.If.yield %7 : !torch.bool\n" " }\n" @@ -9767,7 +9767,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %1 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.arange.start_step\"(%arg0: !torch.union, %arg1: !torch.union, %arg2: !torch.union, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional, %arg6: !torch.optional) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.arange.start_step\"(%arg0: !torch.number, %arg1: !torch.number, %arg2: !torch.number, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional, %arg6: !torch.optional) -> !torch.int {\n" " %int4 = torch.constant.int 4\n" " %int6 = torch.constant.int 6\n" " %true = torch.constant.bool true\n" @@ -9786,19 +9786,19 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " torch.prim.If.yield %2 : !torch.int\n" " } else {\n" -" %2 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.union) -> !torch.int\n" +" %2 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.number) -> !torch.int\n" " %3 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%2) : (!torch.int) -> !torch.bool\n" " %4 = torch.prim.If %3 -> (!torch.bool) {\n" " torch.prim.If.yield %true : !torch.bool\n" " } else {\n" -" %7 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %7 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n" " %8 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%7) : (!torch.int) -> !torch.bool\n" " torch.prim.If.yield %8 : !torch.bool\n" " }\n" " %5 = torch.prim.If %4 -> (!torch.bool) {\n" " torch.prim.If.yield %true : !torch.bool\n" " } else {\n" -" %7 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg2) : (!torch.union) -> !torch.int\n" +" %7 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg2) : (!torch.number) -> !torch.int\n" " %8 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%7) : (!torch.int) -> !torch.bool\n" " torch.prim.If.yield %8 : !torch.bool\n" " }\n" @@ -9910,7 +9910,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple, !torch.bool) -> !torch.int\n" " return %0 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.std.correction\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.union, %arg3: !torch.bool) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.std.correction\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.optional, %arg3: !torch.bool) -> !torch.int {\n" " %true = torch.constant.bool true\n" " %0 = call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple, !torch.bool) -> !torch.int\n" " return %0 : !torch.int\n" @@ -9925,7 +9925,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple, !torch.bool) -> !torch.int\n" " return %0 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.var.correction\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.union, %arg3: !torch.bool) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.var.correction\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.optional, %arg3: !torch.bool) -> !torch.int {\n" " %true = torch.constant.bool true\n" " %0 = call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple, !torch.bool) -> !torch.int\n" " return %0 : !torch.int\n" @@ -9935,7 +9935,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple, !torch.bool) -> !torch.int\n" " return %0 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.linalg_vector_norm\"(%arg0: !torch.tuple, %arg1: !torch.union, %arg2: !torch.optional>, %arg3: !torch.bool, %arg4: !torch.optional) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.linalg_vector_norm\"(%arg0: !torch.tuple, %arg1: !torch.number, %arg2: !torch.optional>, %arg3: !torch.bool, %arg4: !torch.optional) -> !torch.int {\n" " %true = torch.constant.bool true\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" @@ -10061,7 +10061,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %1 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.full\"(%arg0: !torch.list, %arg1: !torch.union, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.full\"(%arg0: !torch.list, %arg1: !torch.number, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.int {\n" " %int6 = torch.constant.int 6\n" " %none = torch.constant.none\n" " %0 = torch.aten.__isnot__ %arg2, %none : !torch.optional, !torch.none -> !torch.bool\n" @@ -10069,7 +10069,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %2 = torch.prim.unchecked_cast %arg2 : !torch.optional -> !torch.int\n" " torch.prim.If.yield %2 : !torch.int\n" " } else {\n" -" %2 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %2 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n" " %3 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%2) : (!torch.int) -> !torch.bool\n" " %4 = torch.prim.If %3 -> (!torch.int) {\n" " torch.prim.If.yield %int6 : !torch.int\n" @@ -10116,7 +10116,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %2 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.full_like\"(%arg0: !torch.tuple, %arg1: !torch.union, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional, %arg6: !torch.optional) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.full_like\"(%arg0: !torch.tuple, %arg1: !torch.number, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional, %arg6: !torch.optional) -> !torch.int {\n" " %none = torch.constant.none\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1 = torch.aten.__is__ %arg2, %none : !torch.optional, !torch.none -> !torch.bool\n" @@ -10313,7 +10313,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %1 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.var_mean.correction\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.union, %arg3: !torch.bool) -> !torch.tuple {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.var_mean.correction\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.optional, %arg3: !torch.bool) -> !torch.tuple {\n" " %int7 = torch.constant.int 7\n" " %int10 = torch.constant.int 10\n" " %int6 = torch.constant.int 6\n" @@ -10485,8 +10485,8 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %5 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.prim.NumToTensor.Scalar\"(%arg0: !torch.union) -> !torch.int {\n" -" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.union) -> !torch.int\n" +" func.func @\"__torch_mlir_dtype_fn.prim.NumToTensor.Scalar\"(%arg0: !torch.number) -> !torch.int {\n" +" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.number) -> !torch.int\n" " return %0 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.softmax.int\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.optional) -> !torch.int {\n" diff --git a/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.cpp b/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.cpp index 8e6b5888bb02..290beb1da7c9 100644 --- a/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.cpp +++ b/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.cpp @@ -176,10 +176,17 @@ FailureOr Torch::adjustFunctionArg( return b.create(loc, desiredType, operand).getResult(); } - // !torch.union or !torch.union is the type used - // for (optional) `Scalar` inputs. At compile time, such inputs will usually - // be resolved to an `int` or a `float` so we need to derefine to match the - // library function signature. + // The type `!torch.number` can be an `int`, `float`, or `complex`. + // TODO: Add a new type `Torch::ComplexType` to handle the complex case. + if (desiredType.isa() && + operandType.isa()) { + return b.create(loc, desiredType, operand).getResult(); + } + + // !torch.union is the type used for optional + // `Scalar` inputs. At compile time, such inputs will usually be + // resolved to an `int`, `float`, or `None` so we need to derefine + // to match the library function signature. if (auto unionType = desiredType.dyn_cast()) { if (llvm::all_of(unionType.getContainedTypes(), [](Type containedType) { return containedType diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py index 232ef262d268..e7794e0a33be 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py @@ -772,7 +772,7 @@ def aten〇scalar_tensor〡shape(s: float, dtype: Optional[int] = None, layout: return [] @check_dtype_function([Invocation(-1), Invocation(-1.0)]) -def aten〇scalar_tensor〡dtype(s: Union[int, float], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: +def aten〇scalar_tensor〡dtype(s: Union[int, float, complex], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: if dtype is not None: return dtype else: @@ -1314,7 +1314,7 @@ def aten〇erf〡dtype(self_rank_dtype: Tuple[int, int]) -> int: return _get_dtype_of_floating_point_op(self_dtype) @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) -def aten〇softplus〡dtype(self_rank_dtype: Tuple[int, int], beta: Union[int, float] = 1, threshold: Union[int, float] = 20) -> int: +def aten〇softplus〡dtype(self_rank_dtype: Tuple[int, int], beta: Union[int, float, complex] = 1, threshold: Union[int, float, complex] = 20) -> int: self_rank, self_dtype = self_rank_dtype if is_integer_dtype(self_dtype): return self_dtype @@ -1395,21 +1395,21 @@ def aten〇ceil〡dtype(self_rank_dtype: Tuple[int, int]) -> int: return self_dtype @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, max=0)) -def aten〇clamp_max〡dtype(self_rank_dtype: Tuple[int, int], max: Union[int, float]) -> int: +def aten〇clamp_max〡dtype(self_rank_dtype: Tuple[int, int], max: Union[int, float, complex]) -> int: self_rank, self_dtype = self_rank_dtype if self_dtype == torch.bool: return torch.int64 return self_dtype @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, min=0)) -def aten〇clamp_min〡dtype(self_rank_dtype: Tuple[int, int], min: Union[int, float]) -> int: +def aten〇clamp_min〡dtype(self_rank_dtype: Tuple[int, int], min: Union[int, float, complex]) -> int: self_rank, self_dtype = self_rank_dtype if self_dtype == torch.bool: return torch.int64 return self_dtype @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, min=-1, max=1)) -def aten〇clamp〡dtype(self_rank_dtype: Tuple[int, int], min: Optional[Union[int, float]] = None, max: Optional[Union[int, float]] = None) -> int: +def aten〇clamp〡dtype(self_rank_dtype: Tuple[int, int], min: Optional[Union[int, float, complex]] = None, max: Optional[Union[int, float, complex]] = None) -> int: self_rank, self_dtype = self_rank_dtype if self_dtype == torch.bool: return torch.int64 @@ -1421,7 +1421,7 @@ def aten〇clone〡dtype(self_rank_dtype: Tuple[int, int], memory_format: Option return self_dtype @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, pad=[1, 1])) -def aten〇constant_pad_nd〡dtype(self_rank_dtype: Tuple[int, int], pad: List[int], value: Union[int, float] = 0) -> int: +def aten〇constant_pad_nd〡dtype(self_rank_dtype: Tuple[int, int], pad: List[int], value: Union[int, float, complex] = 0) -> int: self_rank, self_dtype = self_rank_dtype return self_dtype @@ -1478,7 +1478,7 @@ def aten〇expand〡dtype(self_rank_dtype: Tuple[int, int], size: List[int], imp return self_dtype @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, value=0)) -def aten〇fill〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], value: Union[int, float]) -> int: +def aten〇fill〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], value: Union[int, float, complex]) -> int: self_rank, self_dtype = self_rank_dtype return self_dtype @@ -1537,14 +1537,14 @@ def aten〇hardswish〡dtype(self_rank_dtype: Tuple[int, int]) -> int: return self_dtype @check_dtype_function(_check_two_tensor_op(min_val=0.2, max_val=0.5)) -def aten〇hardtanh_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], min_val: Union[int, float], max_val: Union[int, float]) -> int: +def aten〇hardtanh_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], min_val: Union[int, float, complex], max_val: Union[int, float, complex]) -> int: grad_output_rank, grad_output_dtype = grad_output_rank_dtype if is_integer_dtype(grad_output_dtype): return torch.float32 return grad_output_dtype @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.uint8, torch.bool})) -def aten〇hardtanh〡dtype(self_rank_dtype: Tuple[int, int], min_val: Union[int, float] = -1, max_val: Union[int, float] = 1) -> int: +def aten〇hardtanh〡dtype(self_rank_dtype: Tuple[int, int], min_val: Union[int, float, complex] = -1, max_val: Union[int, float, complex] = 1) -> int: self_rank, self_dtype = self_rank_dtype assert self_dtype not in [torch.uint8, torch.bool] return self_dtype @@ -1597,7 +1597,7 @@ def aten〇layer_norm〡dtype(input_rank_dtype: Tuple[int, int], normalized_shap return input_dtype @check_dtype_function(_check_two_tensor_op(negative_slope=0.1, self_is_result=False)) -def aten〇leaky_relu_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], negative_slope: Union[int, float], self_is_result: bool) -> int: +def aten〇leaky_relu_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], negative_slope: Union[int, float, complex], self_is_result: bool) -> int: grad_output_rank, grad_output_dtype = grad_output_rank_dtype self_rank, self_dtype = self_rank_dtype ranks: List[Optional[int]] = [grad_output_rank, self_rank] @@ -1617,12 +1617,12 @@ def aten〇_log_softmax_backward_data〡dtype(grad_output_rank_dtype: Tuple[int, return input_dtype @check_dtype_function(_check_tensors_with_the_same_dtype(None, [(3,)], None, None, TensorOfShape(1, dtype=torch.bool), 0)) -def aten〇masked_fill〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], mask_rank_dtype: Tuple[int, int], value: Union[int, float]) -> int: +def aten〇masked_fill〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], mask_rank_dtype: Tuple[int, int], value: Union[int, float, complex]) -> int: self_rank, self_dtype = self_rank_dtype return self_dtype @check_dtype_function(_check_tensors_with_the_same_dtype(None, [(3,)], None, None, TensorOfShape(1, dtype=torch.bool), 0)) -def aten〇masked_fill_〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], mask_rank_dtype: Tuple[int, int], value: Union[int, float]) -> int: +def aten〇masked_fill_〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], mask_rank_dtype: Tuple[int, int], value: Union[int, float, complex]) -> int: self_rank, self_dtype = self_rank_dtype return self_dtype @@ -1766,7 +1766,7 @@ def aten〇scatter〇src〡dtype(self_rank_dtype: Tuple[int, int], dim: int, ind @check_dtype_function( [Invocation(TensorOfShape(3, dtype=dtype), 0, TensorOfShape(3, dtype=torch.int64), 1.0) for dtype in _SORTED_TORCH_TYPES]) -def aten〇scatter〇value〡dtype(self_rank_dtype: Tuple[int, int], dim: int, index_rank_dtype: Tuple[int, int], value: Union[int, float]) -> int: +def aten〇scatter〇value〡dtype(self_rank_dtype: Tuple[int, int], dim: int, index_rank_dtype: Tuple[int, int], value: Union[int, float, complex]) -> int: self_rank, self_dtype = self_rank_dtype return self_dtype @@ -1820,7 +1820,7 @@ def aten〇tanh_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], output return promoted_dtype @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, threshold=0, value=0)) -def aten〇threshold〡dtype(self_rank_dtype: Tuple[int, int], threshold: Union[int, float], value: Union[int, float]) -> int: +def aten〇threshold〡dtype(self_rank_dtype: Tuple[int, int], threshold: Union[int, float, complex], value: Union[int, float, complex]) -> int: self_rank, self_dtype = self_rank_dtype return self_dtype @@ -1890,7 +1890,7 @@ def aten〇zero_〡dtype(self_rank_dtype: Tuple[int, int]) -> int: return self_dtype @check_dtype_function([Invocation(-1), Invocation(-1.0)]) -def prim〇abs〇Scalar〡dtype(a: Union[int, float]) -> int: +def prim〇abs〇Scalar〡dtype(a: Union[int, float, complex]) -> int: return get_dtype_of_scalar(a) @check_dtype_function(_check_tensors_with_the_same_dtype( @@ -1931,7 +1931,7 @@ def aten〇any〡dtype(self_rank_dtype: Tuple[int, int]) -> int: @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0)) -def aten〇eq〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int: +def aten〇eq〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float, complex]) -> int: return torch.bool @check_dtype_function(_check_two_tensor_op()) @@ -1941,13 +1941,13 @@ def aten〇eq〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtyp @check_dtype_function( _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0)) -def aten〇ge〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int: +def aten〇ge〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float, complex]) -> int: return torch.bool @check_dtype_function( _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0)) -def aten〇gt〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int: +def aten〇gt〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float, complex]) -> int: return torch.bool @check_dtype_function(_check_two_tensor_op()) @@ -1961,7 +1961,7 @@ def aten〇ge〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtyp @check_dtype_function( _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0)) -def aten〇le〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int: +def aten〇le〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float, complex]) -> int: return torch.bool @check_dtype_function(_check_two_tensor_op()) @@ -1988,7 +1988,7 @@ def aten〇logical_xor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtyp @check_dtype_function( _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0)) -def aten〇lt〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int: +def aten〇lt〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float, complex]) -> int: return torch.bool @check_dtype_function(_check_two_tensor_op()) @@ -2010,7 +2010,7 @@ def aten〇ne〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtyp @check_dtype_function( _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0)) -def aten〇ne〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int: +def aten〇ne〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float, complex]) -> int: return torch.bool @check_dtype_function([ @@ -2019,7 +2019,7 @@ def aten〇ne〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[in Invocation(0, 0.0), # int, float Invocation(0, 0), # int, int ]) -def aten〇add〡dtype(a: Union[int, float], b: Union[int, float]) -> int: +def aten〇add〡dtype(a: Union[int, float, complex], b: Union[int, float, complex]) -> int: ranks: List[Optional[int]] = [None, None] dtypes = [get_dtype_of_scalar(a), get_dtype_of_scalar(b)] return promote_dtypes(ranks, dtypes) @@ -2044,7 +2044,7 @@ def aten〇fft_fft〡dtype(self_rank_dtype: Tuple[int, int], n: Optional[int] = @check_dtype_function( _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0)) -def aten〇rsub〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float], alpha: Union[int, float] = 1) -> int: +def aten〇rsub〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float, complex], alpha: Union[int, float, complex] = 1) -> int: self_rank, self_dtype = self_rank_dtype return promote_dtypes([self_rank, None], [self_dtype, get_dtype_of_scalar(other)]) @@ -2057,7 +2057,7 @@ def aten〇__and__〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank return promote_dtypes(ranks, dtypes) @check_dtype_function(_check_two_tensor_op()) -def aten〇add〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int], alpha: Union[int, float] = 1) -> int: +def aten〇add〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int], alpha: Union[int, float, complex] = 1) -> int: other_rank, other_dtype = other_rank_dtype self_rank, self_dtype = self_rank_dtype ranks: List[Optional[int]] = [self_rank, other_rank] @@ -2238,7 +2238,7 @@ def aten〇mv〡dtype(self_rank_dtype: Tuple[int, int], vec_rank_dtype: Tuple[in return promote_dtypes(ranks, dtypes) @check_dtype_function(_check_two_tensor_op()) -def aten〇sub〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int], alpha: Union[int, float] = 1) -> int: +def aten〇sub〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int], alpha: Union[int, float, complex] = 1) -> int: other_rank, other_dtype = other_rank_dtype self_rank, self_dtype = self_rank_dtype ranks: List[Optional[int]] = [self_rank, other_rank] @@ -2249,7 +2249,7 @@ def aten〇sub〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dty # https://github.com/pytorch/pytorch/issues/100921 # TODO: This should be fixed by switching to FakeTensor instead of Meta tensor @check_dtype_function(_check_two_tensor_op(tensor_device="cpu", input_error_types={torch.complex64, torch.complex128}, output_error_types={torch.bool}, threshold=0)) -def aten〇threshold_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], threshold: Union[int, float]) -> int: +def aten〇threshold_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], threshold: Union[int, float, complex]) -> int: self_rank, self_dtype = self_rank_dtype grad_output_rank, grad_output_dtype = grad_output_rank_dtype assert not is_complex_dtype(grad_output_dtype), "`grad_output` cannot be complex" @@ -2433,7 +2433,7 @@ def aten〇bincount〡dtype(self_rank_dtype: Tuple[int, int], weights_rank_dtype Invocation(TensorOfShape(3, 3, dtype=torch.int32), TensorOfShape(3, 4, dtype=torch.float32), TensorOfShape(4, 3, dtype=torch.float32))]) -def aten〇addmm〡dtype(self_rank_dtype: Tuple[int, int], mat1_rank_dtype: Tuple[int, int], mat2_rank_dtype: Tuple[int, int], beta: Union[int, float] = 1, alpha: Union[int, float] = 1) -> int: +def aten〇addmm〡dtype(self_rank_dtype: Tuple[int, int], mat1_rank_dtype: Tuple[int, int], mat2_rank_dtype: Tuple[int, int], beta: Union[int, float, complex] = 1, alpha: Union[int, float, complex] = 1) -> int: self_rank, self_dtype = self_rank_dtype mat1_rank, mat1_dtype = mat1_rank_dtype mat2_rank, mat2_dtype = mat2_rank_dtype @@ -2477,7 +2477,7 @@ def aten〇lerp〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], end_rank_dtyp Invocation(TensorOfShape(3, 3, dtype=torch.int32), TensorOfShape(3, 3, dtype=torch.float32), TensorOfShape(3, 3, dtype=torch.float32))]) -def aten〇addcmul〡dtype(self_rank_dtype: Tuple[int, int], tensor1_rank_dtype: Tuple[int, int], tensor2_rank_dtype: Tuple[int, int], value: Union[int, float] = 1) -> int: +def aten〇addcmul〡dtype(self_rank_dtype: Tuple[int, int], tensor1_rank_dtype: Tuple[int, int], tensor2_rank_dtype: Tuple[int, int], value: Union[int, float, complex] = 1) -> int: self_rank, self_dtype = self_rank_dtype tensor1_rank, tensor1_dtype = tensor1_rank_dtype tensor2_rank, tensor2_dtype = tensor2_rank_dtype @@ -2503,7 +2503,7 @@ def aten〇addcmul〡dtype(self_rank_dtype: Tuple[int, int], tensor1_rank_dtype: Invocation(TensorOfShape(3, 3, dtype=torch.int32), TensorOfShape(3, 3, dtype=torch.float32), TensorOfShape(3, 3, dtype=torch.float32))]) -def aten〇addcdiv〡dtype(self_rank_dtype: Tuple[int, int], tensor1_rank_dtype: Tuple[int, int], tensor2_rank_dtype: Tuple[int, int], value: Union[int, float] = 1) -> int: +def aten〇addcdiv〡dtype(self_rank_dtype: Tuple[int, int], tensor1_rank_dtype: Tuple[int, int], tensor2_rank_dtype: Tuple[int, int], value: Union[int, float, complex] = 1) -> int: self_rank, self_dtype = self_rank_dtype tensor1_rank, tensor1_dtype = tensor1_rank_dtype tensor2_rank, tensor2_dtype = tensor2_rank_dtype @@ -2517,7 +2517,7 @@ def aten〇addcdiv〡dtype(self_rank_dtype: Tuple[int, int], tensor1_rank_dtype: @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, other=1) + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=1.0)) -def aten〇add〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float], alpha: Union[int, float] = 1) -> int: +def aten〇add〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float, complex], alpha: Union[int, float, complex] = 1) -> int: self_rank, self_dtype = self_rank_dtype ranks: List[Optional[int]] = [self_rank, None] dtypes = [self_dtype, get_dtype_of_scalar(other)] @@ -2526,7 +2526,7 @@ def aten〇add〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[i @check_dtype_function( _check_tensors_with_the_same_dtype(num_of_tensors=1, other=1) + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=1.0)) -def aten〇sub〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float], alpha: Union[int, float] = 1) -> int: +def aten〇sub〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float, complex], alpha: Union[int, float, complex] = 1) -> int: self_rank, self_dtype = self_rank_dtype ranks: List[Optional[int]] = [self_rank, None] dtypes = [self_dtype, get_dtype_of_scalar(other)] @@ -2534,7 +2534,7 @@ def aten〇sub〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[i @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, other=1) + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=1.0)) -def aten〇mul〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int: +def aten〇mul〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float, complex]) -> int: self_rank, self_dtype = self_rank_dtype ranks: List[Optional[int]] = [self_rank, None] dtypes = [self_dtype, get_dtype_of_scalar(other)] @@ -2542,7 +2542,7 @@ def aten〇mul〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[i @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, other=1) + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=1.0)) -def aten〇div〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int: +def aten〇div〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float, complex]) -> int: self_rank, self_dtype = self_rank_dtype ranks: List[Optional[int]] = [self_rank, None] dtypes = [self_dtype, get_dtype_of_scalar(other)] @@ -2554,7 +2554,7 @@ def aten〇div〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[i @check_dtype_function( _check_tensors_with_the_same_dtype(num_of_tensors=1, other=1) + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=1.0)) -def aten〇fmod〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int: +def aten〇fmod〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float, complex]) -> int: self_rank, self_dtype = self_rank_dtype ranks: List[Optional[int]] = [self_rank, None] dtypes = [self_dtype, get_dtype_of_scalar(other)] @@ -2563,7 +2563,7 @@ def aten〇fmod〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[ @check_dtype_function( _check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=1) + _check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=1.0)) -def aten〇floor_divide〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int: +def aten〇floor_divide〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float, complex]) -> int: self_rank, self_dtype = self_rank_dtype assert not is_complex_dtype(self_dtype) ranks: List[Optional[int]] = [self_rank, None] @@ -2572,7 +2572,7 @@ def aten〇floor_divide〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, exponent=1) + _check_tensors_with_the_same_dtype(num_of_tensors=1, exponent=1.0)) -def aten〇pow〇Tensor_Scalar〡dtype(self_rank_dtype: Tuple[int, int], exponent: Union[int, float]) -> int: +def aten〇pow〇Tensor_Scalar〡dtype(self_rank_dtype: Tuple[int, int], exponent: Union[int, float, complex]) -> int: self_rank, self_dtype = self_rank_dtype ranks: List[Optional[int]] = [self_rank, None] dtypes = [self_dtype, get_dtype_of_scalar(exponent)] @@ -2581,7 +2581,7 @@ def aten〇pow〇Tensor_Scalar〡dtype(self_rank_dtype: Tuple[int, int], exponen @check_dtype_function( _check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bool}, negative_slope=1) + _check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bool, torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64}, negative_slope=1.0)) -def aten〇leaky_relu〡dtype(self_rank_dtype: Tuple[int, int], negative_slope: Union[int, float] = 0.01) -> int: +def aten〇leaky_relu〡dtype(self_rank_dtype: Tuple[int, int], negative_slope: Union[int, float, complex] = 0.01) -> int: self_rank, self_dtype = self_rank_dtype assert self_dtype != torch.bool ranks: List[Optional[int]] = [self_rank, None] @@ -2594,7 +2594,7 @@ def aten〇leaky_relu〡dtype(self_rank_dtype: Tuple[int, int], negative_slope: @check_dtype_function( _check_tensors_with_the_same_dtype(num_of_tensors=1, other=1) + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=1.0)) -def aten〇remainder〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int: +def aten〇remainder〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float, complex]) -> int: self_rank, self_dtype = self_rank_dtype ranks: List[Optional[int]] = [self_rank, None] dtypes = [self_dtype, get_dtype_of_scalar(other)] @@ -2611,7 +2611,7 @@ def aten〇remainder〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: U TensorOfShape(1, 1, 1, dtype=torch.float64, device="cpu"), TensorOfShape(1, 1, 1, dtype=torch.float16, device="cpu"), TensorOfShape(1, 1, 1, dtype=torch.int64, device="cpu")), ErrorInvocation( TensorOfShape(1, 1, 1, dtype=torch.float64, device="cpu"), TensorOfShape(1, 1, 1, dtype=torch.bfloat16, device="cpu"), TensorOfShape(1, 1, 1, dtype=torch.float16, device="cpu"))]) -def aten〇baddbmm〡dtype(self_rank_dtype: Tuple[int, int], batch1_rank_dtype: Tuple[int, int], batch2_rank_dtype: Tuple[int, int], beta: Union[int, float] = 1, alpha: Union[int, float] = 1) -> int: +def aten〇baddbmm〡dtype(self_rank_dtype: Tuple[int, int], batch1_rank_dtype: Tuple[int, int], batch2_rank_dtype: Tuple[int, int], beta: Union[int, float, complex] = 1, alpha: Union[int, float, complex] = 1) -> int: batch1_rank, batch1_dtype = batch1_rank_dtype batch2_rank, batch2_dtype = batch2_rank_dtype assert batch1_dtype not in [torch.bool, torch.float16] @@ -2637,7 +2637,7 @@ def aten〇where〇self〡dtype(condition_rank_dtype: Tuple[int, int], self_rank Invocation(NonZeroDTensorWithDtype(torch.bool), 0, 0.0), Invocation(NonZeroDTensorWithDtype(torch.bool), 0.0, 0), Invocation(NonZeroDTensorWithDtype(torch.bool), 0.0, 0.0)]) -def aten〇where〇Scalar〡dtype(condition_rank_dtype: Tuple[int, int], self: Union[int, float], other: Union[int, float]) -> int: +def aten〇where〇Scalar〡dtype(condition_rank_dtype: Tuple[int, int], self: Union[int, float, complex], other: Union[int, float, complex]) -> int: if is_integer_dtype(get_dtype_of_scalar(self)) and is_integer_dtype(get_dtype_of_scalar(other)): return torch.int64 return torch.float32 @@ -2646,7 +2646,7 @@ def aten〇where〇Scalar〡dtype(condition_rank_dtype: Tuple[int, int], self: U Invocation(NonZeroDTensorWithDtype(torch.bool), NonZeroDTensorWithDtype(torch.int64), 0.0), Invocation(NonZeroDTensorWithDtype(torch.bool), NonZeroDTensorWithDtype(torch.float16), 0), Invocation(NonZeroDTensorWithDtype(torch.bool), NonZeroDTensorWithDtype(torch.float64), 0.0)]) -def aten〇where〇ScalarOther〡dtype(condition_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int: +def aten〇where〇ScalarOther〡dtype(condition_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], other: Union[int, float, complex]) -> int: self_rank, self_dtype = self_rank_dtype ranks: List[Optional[int]] = [self_rank, None] dtypes = [self_dtype, get_dtype_of_scalar(other)] @@ -2656,7 +2656,7 @@ def aten〇where〇ScalarOther〡dtype(condition_rank_dtype: Tuple[int, int], se Invocation(NonZeroDTensorWithDtype(torch.bool), 0.0, NonZeroDTensorWithDtype(torch.int64)), Invocation(NonZeroDTensorWithDtype(torch.bool), 0, NonZeroDTensorWithDtype(torch.float16)), Invocation(NonZeroDTensorWithDtype(torch.bool), 0.0, NonZeroDTensorWithDtype(torch.float64))]) -def aten〇where〇ScalarSelf〡dtype(condition_rank_dtype: Tuple[int, int], self: Union[int, float], other_rank_dtype: Tuple[int, int]) -> int: +def aten〇where〇ScalarSelf〡dtype(condition_rank_dtype: Tuple[int, int], self: Union[int, float, complex], other_rank_dtype: Tuple[int, int]) -> int: other_rank, other_dtype = other_rank_dtype ranks: List[Optional[int]] = [None, other_rank] dtypes = [get_dtype_of_scalar(self), other_dtype] @@ -2755,7 +2755,7 @@ def aten〇native_batch_norm〡dtype(input_rank_dtype: Tuple[int, int], weight_r ErrorInvocation(end=0, dtype=torch.complex64), # Dtype specified Invocation(end=0, dtype=torch.float16), # Dtype specified Invocation(end=0, dtype=torch.int16)]) # Dtype specified -def aten〇arange〡dtype(end: Union[int, float], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: +def aten〇arange〡dtype(end: Union[int, float, complex], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: if dtype is not None: assert not is_complex_dtype(dtype) return dtype @@ -2769,7 +2769,7 @@ def aten〇arange〡dtype(end: Union[int, float], dtype: Optional[int] = None, l ErrorInvocation(start=0, end=10, dtype=torch.complex64), # Dtype specified Invocation(start=0, end=10, dtype=torch.float16), # Dtype specified Invocation(start=0, end=10, dtype=torch.int16)]) # Dtype specified -def aten〇arange〇start〡dtype(start: Union[int, float], end: Union[int, float], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: +def aten〇arange〇start〡dtype(start: Union[int, float, complex], end: Union[int, float, complex], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: if dtype is not None: assert not is_complex_dtype(dtype) return dtype @@ -2785,7 +2785,7 @@ def aten〇arange〇start〡dtype(start: Union[int, float], end: Union[int, floa ErrorInvocation(start=0, end=10, step=1, dtype=torch.complex64), # Dtype specified Invocation(start=0, end=10, step=1, dtype=torch.float16), # Dtype specified Invocation(start=0, end=10, step=1, dtype=torch.int16)]) # Dtype specified -def aten〇arange〇start_step〡dtype(start: Union[int, float], end: Union[int, float], step: Union[int, float] = 1, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: +def aten〇arange〇start_step〡dtype(start: Union[int, float, complex], end: Union[int, float, complex], step: Union[int, float, complex] = 1, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: if dtype is not None: assert not is_complex_dtype(dtype) return dtype @@ -2876,7 +2876,7 @@ def aten〇std〇dim〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[Lis return aten〇std〡dtype(self_rank_dtype) @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) -def aten〇std〇correction〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[List[int]] = None, correction: Optional[Union[int, float]] = None, keepdim: bool = False) -> int: +def aten〇std〇correction〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[List[int]] = None, correction: Optional[Union[int, float, complex]] = None, keepdim: bool = False) -> int: return aten〇std〡dtype(self_rank_dtype) @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) @@ -2888,7 +2888,7 @@ def aten〇var〇dim〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[Lis return aten〇std〡dtype(self_rank_dtype) @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) -def aten〇var〇correction〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[List[int]] = None, correction: Optional[Union[int, float]] = None, keepdim: bool = False) -> int: +def aten〇var〇correction〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[List[int]] = None, correction: Optional[Union[int, float, complex]] = None, keepdim: bool = False) -> int: return aten〇std〡dtype(self_rank_dtype) @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dims=[], correction=0.0)) @@ -2906,7 +2906,7 @@ def prims〇var〡dtype(inp_rank_dtype: Tuple[int, int], dims: Optional[List[int num_of_tensors=1, error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64, torch.bfloat16, torch.float16, torch.float32, torch.float64}, dtype=torch.complex128) + [ErrorInvocation(NonZeroDTensorWithDtype(torch.float32), dtype=torch.int32)]) -def aten〇linalg_vector_norm〡dtype(self_rank_dtype: Tuple[int, int], ord: Union[int, float] = 2, dim: Optional[List[int]] = None, keepdim: bool = False, dtype: Optional[int] = None) -> int: +def aten〇linalg_vector_norm〡dtype(self_rank_dtype: Tuple[int, int], ord: Union[int, float, complex] = 2, dim: Optional[List[int]] = None, keepdim: bool = False, dtype: Optional[int] = None) -> int: self_rank, self_dtype = self_rank_dtype assert not is_integer_dtype(self_dtype) if dtype is not None: @@ -2971,7 +2971,7 @@ def aten〇empty〇memory_format〡dtype(size: List[int], dtype: Optional[int] = Invocation([1], 0.0, dtype=torch.int32), Invocation([1], 0.0, dtype=torch.float16), Invocation([1], 0.0, dtype=torch.complex64)]) -def aten〇full〡dtype(size: List[int], fill_value: Union[int, float], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: +def aten〇full〡dtype(size: List[int], fill_value: Union[int, float, complex], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: if dtype is not None: return dtype fill_value_dtype = get_dtype_of_scalar(fill_value) @@ -3009,7 +3009,7 @@ def aten〇empty_like〡dtype(self_rank_dtype: Tuple[int, int], dtype: Optional[ _check_tensors_with_the_same_dtype(num_of_tensors=1, fill_value=0.0, dtype=torch.float16) + _check_tensors_with_the_same_dtype(num_of_tensors=1, fill_value=0.0, dtype=torch.int32) + _check_tensors_with_the_same_dtype(num_of_tensors=1, fill_value=0.0, dtype=torch.complex64)) -def aten〇full_like〡dtype(self_rank_dtype: Tuple[int, int], fill_value: Union[int, float], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> int: +def aten〇full_like〡dtype(self_rank_dtype: Tuple[int, int], fill_value: Union[int, float, complex], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> int: self_rank, self_dtype = self_rank_dtype return self_dtype if dtype is None else dtype @@ -3143,7 +3143,7 @@ def aten〇randn〇generator〡dtype(size: List[int], generator: Any, dtype: Opt return dtype @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types=all_integer_dtypes())) -def aten〇var_mean〇correction〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[List[int]] = None, correction: Optional[Union[int, float]] = None, keepdim: bool = False) -> Tuple[int, int]: +def aten〇var_mean〇correction〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[List[int]] = None, correction: Optional[Union[int, float, complex]] = None, keepdim: bool = False) -> Tuple[int, int]: self_rank, self_dtype = self_rank_dtype assert not is_integer_dtype(self_dtype) if self_dtype == torch.complex64: @@ -3220,7 +3220,7 @@ def aten〇ScalarImplicit〡dtype(a_rank_dtype: Tuple[int, int]) -> int: assert False, "Unexpected dtype!" @check_dtype_function([Invocation(0), Invocation(0.0)]) -def prim〇NumToTensor〇Scalar〡dtype(a: Union[int, float]) -> int: +def prim〇NumToTensor〇Scalar〡dtype(a: Union[int, float, complex]) -> int: return get_dtype_of_scalar(a) @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0) + diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/library_generator.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/library_generator.py index 3cfc4a24aa74..99b2bc058c2c 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/library_generator.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/library_generator.py @@ -63,7 +63,7 @@ def get_priority_of_dtype(dtype: int) -> int: return 11 assert False, "Cannot determine priority of dtype" -def get_dtype_of_scalar(scalar: Union[int, float]) -> int: +def get_dtype_of_scalar(scalar: Union[int, float, complex]) -> int: # This is hacky. `NumToTensor` is the only PyTorch op for scalars # that when `jit.script`ed converts a float scalar to a tensor # with dtype that corresponds to Python's `float`. diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/registry.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/registry.py index 0aa5a1960eb0..54f552290fcb 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/registry.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/registry.py @@ -49,7 +49,7 @@ def _get_default_value(arg: "SIG_ATTR_TYPE") -> str: def _pytype_to_fn_pytype_common(pytype: str) -> str: if "number" in pytype: - return pytype.replace("number", "Union[int, float]") + return pytype.replace("number", "Union[int, float, complex]") # `torch.device` is lowercase. if pytype == "Device": return "device" diff --git a/test/Dialect/Torch/reify-dtype-calculations.mlir b/test/Dialect/Torch/reify-dtype-calculations.mlir index 265497ddf324..9aec26662b69 100644 --- a/test/Dialect/Torch/reify-dtype-calculations.mlir +++ b/test/Dialect/Torch/reify-dtype-calculations.mlir @@ -72,3 +72,18 @@ func.func @turn_tensors_into_rank_and_dtype_args(%arg0: !torch.vtensor, %arg1: ! %0 = torch.aten.floor_divide %arg0, %arg1 : !torch.vtensor, !torch.vtensor -> !torch.vtensor return %0 : !torch.vtensor } + +// ----- + +// CHECK-LABEL: func.func private @__torch_mlir_dtype_fn.aten.arange( + +// CHECK-LABEL: func.func @derefine_int_to_number() -> !torch.vtensor { +// CHECK: %[[INT1:.*]] = torch.constant.int 1 +// CHECK: %[[NUMBER:.*]] = torch.derefine %[[INT1]] : !torch.int to !torch.number +// CHECK: {{.*}} = func.call @__torch_mlir_dtype_fn.aten.arange(%[[NUMBER]], {{.*}}) : (!torch.number, {{.*}}) -> !torch.int +func.func @derefine_int_to_number() -> !torch.vtensor { + %int1 = torch.constant.int 1 + %none = torch.constant.none + %0 = torch.aten.arange %int1, %none, %none, %none, %none : !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor + return %0 : !torch.vtensor +}