From 8c48135a426b84fa412b031fc92e12826ff60b31 Mon Sep 17 00:00:00 2001 From: Prashant Kumar Date: Wed, 1 May 2024 12:06:53 +0530 Subject: [PATCH] [linalg] Fix bug for conversion of complex dtype (#3269) The conversion of complex type wasn't supported or checked; the support and required tests were added. Fixes: https://github.com/iree-org/iree/issues/17226#issuecomment-2087779158 --- lib/Conversion/Utils/Utils.cpp | 21 ++++++++++++++ projects/pt1/e2e_testing/xfail_sets.py | 2 ++ .../test_suite/elementwise.py | 28 +++++++++++++++++++ 3 files changed, 51 insertions(+) diff --git a/lib/Conversion/Utils/Utils.cpp b/lib/Conversion/Utils/Utils.cpp index bae25cc7ac60..e014fbeaa9d4 100644 --- a/lib/Conversion/Utils/Utils.cpp +++ b/lib/Conversion/Utils/Utils.cpp @@ -10,6 +10,7 @@ #include "torch-mlir/Conversion/Utils/Utils.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" @@ -349,6 +350,26 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype, return b.create(loc, dtype, scalar); } + if (auto dtypeComplex = dyn_cast(dtype)) { + if (auto scalarComplex = dyn_cast(scalarType)) { + auto dtypeElemType = dtypeComplex.getElementType(); + + // Extract the real and imaginary parts of the scalar. + // Cast them to the target element type, and create a new complex + // value with the target complex type. + Value realVal = b.create(loc, scalar); + Value imgVal = b.create(loc, scalar); + + realVal = convertScalarToDtype(b, loc, realVal, dtypeElemType); + imgVal = convertScalarToDtype(b, loc, imgVal, dtypeElemType); + + return b.create(loc, dtypeComplex, realVal, imgVal); + } + mlir::emitError(loc) << "unsupported scalar type for convertScalarToDtype " + << scalarType << "(scalar type) -> " << dtype + << "(dtype)"; + } + llvm_unreachable("convertScalarToDtype should handle all the types"); } diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 25d8fa9be5a2..33f1ed702273 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -575,6 +575,7 @@ "ElementwiseErfIntModule_basic", "ElementwiseLogitModule_basic", "ElementwiseMulTensorComplexModule_basic", + "ElementwiseMulTensorComplexDiffModule_basic", "ElementwiseQuantizePerTensorModule_basic", "ElementwiseQuantizePerTensorUIntModule_basic", "ElementwiseReciprocalIntModule_basic", @@ -2314,6 +2315,7 @@ "ElementwiseExpm1Module_basic", "ElementwiseFmodTensor_Int_basic", "ElementwiseMulTensorComplexModule_basic", + "ElementwiseMulTensorComplexDiffModule_basic", "ElementwiseOrTensorModule_basic", "ElementwiseOrTensorStaticShapeModule_basic", "ElementwiseQuantizePerTensorModule_basic", diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py index 8e287584295b..a26fd9809f13 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -1839,6 +1839,34 @@ def ElementwiseMulTensorComplexModule_basic(module, tu: TestUtils): # ============================================================================== +# torch.complex32 is not supported by the refbackend. +class ElementwiseMulTensorComplexDiffModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1], torch.complex64, True), + ([-1], torch.complex128, True), + ] + ) + def forward(self, a, b): + return torch.mul(a, b) + + +@register_test_case(module_factory=lambda: ElementwiseMulTensorComplexDiffModule()) +def ElementwiseMulTensorComplexDiffModule_basic(module, tu: TestUtils): + module.forward( + tu.randint(4, high=10).type(torch.complex64), + tu.randint(4, high=10).type(torch.complex128), + ) + + +# ============================================================================== + + class ElementwiseMishModule(torch.nn.Module): def __init__(self): super().__init__()