From 02340408b7bb909dce71269a031c699c4eb187f5 Mon Sep 17 00:00:00 2001 From: Vinayak Dev <104419489+vinayakdsci@users.noreply.github.com> Date: Tue, 25 Jun 2024 19:00:45 +0530 Subject: [PATCH] [torch] Add OnnxToTorch lowering for Onnx.STFT op (#3492) Adds OnnxToTorch lowering for `Onnx.STFT` op. --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 30 ++ .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 166 ++++++++++++ .../Transforms/AbstractInterpLibrary.cpp | 256 ++++++++++++++++++ .../build_tools/abstract_interp_lib_gen.py | 60 ++++ .../build_tools/torch_ods_gen.py | 3 + .../TorchOnnxToTorch/simple_ops_q_to_z.mlir | 49 ++++ 6 files changed, 564 insertions(+) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index b836b6bab5b6..c351d845c2f8 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -12533,6 +12533,36 @@ def Torch_AtenKthvalueOp : Torch_Op<"aten.kthvalue", [ let hasVerifier = 1; } +def Torch_AtenStftOp : Torch_Op<"aten.stft", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::stft : (Tensor, int, int?, int?, Tensor?, bool, bool?, bool?) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_IntType:$n_fft, + AnyTorchOptionalIntType:$hop_length, + AnyTorchOptionalIntType:$win_length, + AnyTorchOptionalTensorType:$window, + Torch_BoolType:$normalized, + AnyTorchOptionalBoolType:$onesided, + AnyTorchOptionalBoolType:$return_complex + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenStftOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 8, 1); + } + void AtenStftOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 8, 1); + } + }]; +} + def Torch_AtenAliasCopyOp : Torch_Op<"aten.alias_copy", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 63eac34270db..a6d05d7cc8b8 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -3300,4 +3300,170 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( /*Torch_BoolType:$antialias*/ cstFalse); return success(); }); + patterns.onOp( + "STFT", 17, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + // operands in order ->(signal, frameStep, window, frameLength*) + SmallVector operands; + int64_t onesided; + Torch::ValueTensorType resultType; + + if (binder.tensorOperandsList(operands) || + binder.s64IntegerAttr(onesided, "onesided", 1) || + binder.tensorResultType(resultType)) + return failure(); + + Value signal = operands[0]; + Value frameStep = operands[1]; + auto signalTy = cast(signal.getType()); + auto signalShape = signalTy.getSizes(); + auto resultShape = resultType.getSizes(); + + // There are two possible cases for optional inputs frameLength and + // window, which are that either 4 operands will be passed with window + // being !torch.none, or three operands will be passed, with window + // present and frameLength absent. In the former case, we simply create + // a rectangular window consisting of ones, and in the latter, we set + // frameLength equal to the the inputShape[-2] or windowShape[0] + // depending upon whether window was present or not. Note that it is + // possible that both window and frameLength can be none, which would + // mean that either only two operands were passed, or, in case of three + // operands, window was passed in as none, and frameLength was absent. + Value window = nullptr, frameLength = nullptr; + bool windowIsNone = true, frameLengthIsNone = true; + if (operands.size() == 3) { + window = operands[2]; + windowIsNone = isa(window.getType()); + } + if (operands.size() == 4) { + window = operands[2]; + frameLength = operands[3]; + windowIsNone = isa(window.getType()); + frameLengthIsNone = isa(frameLength.getType()); + } + + ArrayRef windowShape; + if (frameLengthIsNone) { + if (windowIsNone) { + frameLength = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr( + signalShape[signalShape.size() - 2])); + } else { + windowShape = + cast(window.getType()).getSizes(); + frameLength = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(windowShape[0])); + } + } + + Value frameLengthItem; + if (!frameLengthIsNone || windowIsNone) { + frameLengthItem = + getItemOp(binder, rewriter, frameLength); + } else { + frameLengthItem = frameLength; + } + Value frameStepItem = + getItemOp(binder, rewriter, frameStep); + + if (windowIsNone) { + auto onesResultTy = rewriter.getType( + ArrayRef({-1}), signalTy.getDtype()); + + Value none = rewriter.create(binder.getLoc()); + Value sizes = rewriter.create( + binder.getLoc(), + Torch::ListType::get( + Torch::IntType::get(binder.op->getContext())), + SmallVector{frameLengthItem}); + window = rewriter.create( + binder.getLoc(), onesResultTy, sizes, none, none, none, none); + } + + FailureOr complexDtype; + if (signalTy.getDtype().isBF16()) { + return rewriter.notifyMatchFailure( + binder.op, + "unimplemented: support for bfloat16 type is unimplemented."); + } + if (signalTy.getDtype().isF16()) { + complexDtype = Torch::getTypeForScalarType( + binder.op->getContext(), + torch::torch_upstream::ScalarType::ComplexHalf); + } else if (signalTy.getDtype().isF32()) { + complexDtype = Torch::getTypeForScalarType( + binder.op->getContext(), + torch::torch_upstream::ScalarType::ComplexFloat); + } else { + complexDtype = Torch::getTypeForScalarType( + binder.op->getContext(), + torch::torch_upstream::ScalarType::ComplexDouble); + } + + auto complexSignalTy = rewriter.getType( + ArrayRef({signalShape[0], signalShape[1]}), + complexDtype.value()); + + // The onnx STFT op always passes in a float input, and if the input + // is intended to be complex, its shape will be [batch][length][2], + // where [...][0] is the real component, and [...][1] is the complex + // component. This complex input has to be made torch compatible before + // being passed into torch.stft, so it is necessary to call + // AtenViewAsComplexOp. In case of real input, the shape of the signal + // will be [batch][length][1], and therefore it will have to be squeezed + // at dim=2, before being passed into torch.stft. + if (signalShape[2] == 2) { + signal = rewriter.create( + binder.getLoc(), complexSignalTy, signal); + } else { + Value two = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(2)); + auto newSignalTy = signalTy.getWithSizesAndDtype( + ArrayRef({signalShape[0], signalShape[1]}), + signalTy.getDtype()); + signal = rewriter.create( + binder.getLoc(), newSignalTy, signal, two); + } + + // In case the window is not given, we use frameLength + // as the length of the window. + Value windowLen; + if (!windowIsNone) { + windowLen = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(windowShape[0])); + } else { + windowLen = frameLengthItem; + } + + Value falseVal = + rewriter.create(binder.getLoc(), false); + Value trueVal = + rewriter.create(binder.getLoc(), true); + auto stftTy = complexSignalTy.getWithSizesAndDtype( + ArrayRef({resultShape[0], resultShape[2], resultShape[1]}), + complexSignalTy.getDtype()); + + // After torch.stft is called and the result is stored into the value + // stft, there is one thing to note: The resultType for the onnx op + // will have shape [batch][num_frames][length][2], while the shape of + // stft will be [batch][length][num_frames]. Before the value is + // converted to real through torch.view_as_real, we must permute the + // shape of stft to match the shape of resultType. Also, it is + // immaterial whether torch.view_as_real is called after or before the + // permutation; both outputs will be equivalent. + Value stft = rewriter.create( + binder.getLoc(), stftTy, signal, frameLengthItem, frameStepItem, + windowLen, window, falseVal, onesided ? trueVal : falseVal, + trueVal); + + auto permuteStftTy = complexSignalTy.getWithSizesAndDtype( + ArrayRef({resultShape[0], resultShape[1], resultShape[2]}), + complexSignalTy.getDtype()); + Value permuteDims = createConstantIntList(binder, rewriter, {0, 2, 1}); + Value permutedStft = rewriter.create( + binder.getLoc(), permuteStftTy, stft, permuteDims); + + rewriter.replaceOpWithNewOp( + binder.op, resultType, permutedStft); + return success(); + }); } diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index e9147d5853ec..537d3b6198a4 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -10143,6 +10143,125 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " func.func @\"__torch_mlir_shape_fn.aten.fft_fft\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.int, %arg3: !torch.optional) -> !torch.list {\n" " return %arg0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.stft\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional>, %arg5: !torch.bool, %arg6: !torch.optional, %arg7: !torch.optional) -> !torch.list {\n" +" %str = torch.constant.str \"AssertionError: Expected hop_length to be greater than 0\"\n" +" %str_0 = torch.constant.str \"AssertionError: Expected that 0 < n_fft <= len\"\n" +" %false = torch.constant.bool false\n" +" %none = torch.constant.none\n" +" %str_1 = torch.constant.str \"AssertionError: Expected input tensor to be of shape (B?,L), where B is an optional batch dimension\"\n" +" %true = torch.constant.bool true\n" +" %int1 = torch.constant.int 1\n" +" %int2 = torch.constant.int 2\n" +" %int0 = torch.constant.int 0\n" +" %int4 = torch.constant.int 4\n" +" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %1 = torch.aten.eq.int %0, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %24 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %25 = torch.aten.eq.int %24, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %25 : !torch.bool\n" +" }\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %4 = torch.aten.eq.int %3, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %5 = torch.prim.If %4 -> (!torch.optional) {\n" +" %24 = torch.derefine %none : !torch.none to !torch.optional\n" +" torch.prim.If.yield %24 : !torch.optional\n" +" } else {\n" +" %24 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %25 = torch.derefine %24 : !torch.int to !torch.optional\n" +" torch.prim.If.yield %25 : !torch.optional\n" +" }\n" +" %6 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %7 = torch.aten.eq.int %6, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %8 = torch.prim.If %7 -> (!torch.int) {\n" +" %24 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int\n" +" torch.prim.If.yield %24 : !torch.int\n" +" } else {\n" +" %24 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int\n" +" torch.prim.If.yield %24 : !torch.int\n" +" }\n" +" %9 = torch.aten.__is__ %arg2, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %10 = torch.prim.If %9 -> (!torch.int) {\n" +" %24 = torch.aten.floordiv.int %arg1, %int4 : !torch.int, !torch.int -> !torch.int\n" +" torch.prim.If.yield %24 : !torch.int\n" +" } else {\n" +" %24 = torch.prim.unchecked_cast %arg2 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %24 : !torch.int\n" +" }\n" +" %11 = torch.aten.gt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %12 = torch.prim.If %11 -> (!torch.bool) {\n" +" %24 = torch.aten.le.int %arg1, %8 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %24 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %12 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %13 = torch.aten.gt.int %10, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %13 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %14 = torch.prim.ListConstruct : () -> !torch.list\n" +" %15 = torch.aten.__isnot__ %5, %none : !torch.optional, !torch.none -> !torch.bool\n" +" torch.prim.If %15 -> () {\n" +" %24 = torch.prim.unchecked_cast %5 : !torch.optional -> !torch.int\n" +" %25 = torch.aten.append.t %14, %24 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.If.yield\n" +" }\n" +" %16 = torch.aten.__is__ %arg6, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %17 = torch.prim.If %16 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %24 = torch.prim.unchecked_cast %arg6 : !torch.optional -> !torch.bool\n" +" %25 = torch.operator \"aten.eq.bool\"(%24, %true) : (!torch.bool, !torch.bool) -> !torch.bool \n" +" torch.prim.If.yield %25 : !torch.bool\n" +" }\n" +" torch.prim.If %17 -> () {\n" +" %24 = torch.aten.floordiv.int %arg1, %int2 : !torch.int, !torch.int -> !torch.int\n" +" %25 = torch.aten.add.int %24, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %26 = torch.aten.append.t %14, %25 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.If.yield\n" +" } else {\n" +" %24 = torch.aten.append.t %14, %arg1 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.If.yield\n" +" }\n" +" %18 = torch.aten.sub.int %8, %arg1 : !torch.int, !torch.int -> !torch.int\n" +" %19 = torch.aten.floordiv.int %18, %10 : !torch.int, !torch.int -> !torch.int\n" +" %20 = torch.aten.add.int %int1, %19 : !torch.int, !torch.int -> !torch.int\n" +" %21 = torch.aten.append.t %14, %20 : !torch.list, !torch.int -> !torch.list\n" +" %22 = torch.aten.__isnot__ %arg7, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %23 = torch.prim.If %22 -> (!torch.bool) {\n" +" %24 = torch.prim.unchecked_cast %arg7 : !torch.optional -> !torch.bool\n" +" %25 = torch.operator \"aten.eq.bool\"(%24, %false) : (!torch.bool, !torch.bool) -> !torch.bool \n" +" torch.prim.If.yield %25 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %23 -> () {\n" +" %24 = torch.aten.append.t %14, %int2 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.If.yield\n" +" }\n" +" return %14 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.bincount\"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.int) -> !torch.list {\n" " %0 = call @__torch__.hacky_get_unknown_dimension_size() : () -> !torch.int\n" " %1 = torch.prim.ListConstruct %0 : (!torch.int) -> !torch.list\n" @@ -11607,6 +11726,143 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %3 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.stft\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional>, %arg5: !torch.bool, %arg6: !torch.optional, %arg7: !torch.optional) -> !torch.int {\n" +" %str = torch.constant.str \"AssertionError: Unsupported dtype\"\n" +" %int7 = torch.constant.int 7\n" +" %int10 = torch.constant.int 10\n" +" %int6 = torch.constant.int 6\n" +" %int9 = torch.constant.int 9\n" +" %int5 = torch.constant.int 5\n" +" %int8 = torch.constant.int 8\n" +" %none = torch.constant.none\n" +" %false = torch.constant.bool false\n" +" %true = torch.constant.bool true\n" +" %0 = torch.prim.Uninitialized : !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %3 = torch.prim.If %2 -> (!torch.bool) {\n" +" %7 = torch.aten.__isnot__ %arg7, %none : !torch.optional, !torch.none -> !torch.bool\n" +" torch.prim.If.yield %7 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %4 = torch.prim.If %3 -> (!torch.bool) {\n" +" %7 = torch.prim.unchecked_cast %arg7 : !torch.optional -> !torch.bool\n" +" torch.prim.If.yield %7 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %5:2 = torch.prim.If %4 -> (!torch.bool, !torch.int) {\n" +" torch.prim.If.yield %true, %1#1 : !torch.bool, !torch.int\n" +" } else {\n" +" %7 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %8 = torch.prim.If %7 -> (!torch.bool) {\n" +" %11 = torch.aten.__isnot__ %arg7, %none : !torch.optional, !torch.none -> !torch.bool\n" +" torch.prim.If.yield %11 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %9 = torch.prim.If %8 -> (!torch.bool) {\n" +" %11 = torch.prim.unchecked_cast %arg7 : !torch.optional -> !torch.bool\n" +" %12 = torch.aten.ne.bool %11, %true : !torch.bool, !torch.bool -> !torch.bool\n" +" torch.prim.If.yield %12 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %10:2 = torch.prim.If %9 -> (!torch.bool, !torch.int) {\n" +" %11 = torch.aten.eq.int %1#1, %int8 : !torch.int, !torch.int -> !torch.bool\n" +" %12:2 = torch.prim.If %11 -> (!torch.bool, !torch.int) {\n" +" torch.prim.If.yield %true, %int5 : !torch.bool, !torch.int\n" +" } else {\n" +" %13 = torch.aten.eq.int %1#1, %int9 : !torch.int, !torch.int -> !torch.bool\n" +" %14:2 = torch.prim.If %13 -> (!torch.bool, !torch.int) {\n" +" torch.prim.If.yield %true, %int6 : !torch.bool, !torch.int\n" +" } else {\n" +" %15 = torch.aten.eq.int %1#1, %int10 : !torch.int, !torch.int -> !torch.bool\n" +" %16:2 = torch.prim.If %15 -> (!torch.bool, !torch.int) {\n" +" torch.prim.If.yield %true, %int7 : !torch.bool, !torch.int\n" +" } else {\n" +" torch.prim.If.yield %false, %0 : !torch.bool, !torch.int\n" +" }\n" +" torch.prim.If.yield %16#0, %16#1 : !torch.bool, !torch.int\n" +" }\n" +" torch.prim.If.yield %14#0, %14#1 : !torch.bool, !torch.int\n" +" }\n" +" torch.prim.If.yield %12#0, %12#1 : !torch.bool, !torch.int\n" +" } else {\n" +" %11 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %12 = torch.prim.If %11 -> (!torch.bool) {\n" +" %15 = torch.aten.__isnot__ %arg7, %none : !torch.optional, !torch.none -> !torch.bool\n" +" torch.prim.If.yield %15 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %13 = torch.prim.If %12 -> (!torch.bool) {\n" +" %15 = torch.prim.unchecked_cast %arg7 : !torch.optional -> !torch.bool\n" +" torch.prim.If.yield %15 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %14:2 = torch.prim.If %13 -> (!torch.bool, !torch.int) {\n" +" %15 = torch.aten.eq.int %1#1, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" %16:2 = torch.prim.If %15 -> (!torch.bool, !torch.int) {\n" +" torch.prim.If.yield %true, %int8 : !torch.bool, !torch.int\n" +" } else {\n" +" %17 = torch.aten.eq.int %1#1, %int6 : !torch.int, !torch.int -> !torch.bool\n" +" %18:2 = torch.prim.If %17 -> (!torch.bool, !torch.int) {\n" +" torch.prim.If.yield %true, %int9 : !torch.bool, !torch.int\n" +" } else {\n" +" %19 = torch.aten.eq.int %1#1, %int7 : !torch.int, !torch.int -> !torch.bool\n" +" %20:2 = torch.prim.If %19 -> (!torch.bool, !torch.int) {\n" +" torch.prim.If.yield %true, %int10 : !torch.bool, !torch.int\n" +" } else {\n" +" torch.prim.If.yield %false, %0 : !torch.bool, !torch.int\n" +" }\n" +" torch.prim.If.yield %20#0, %20#1 : !torch.bool, !torch.int\n" +" }\n" +" torch.prim.If.yield %18#0, %18#1 : !torch.bool, !torch.int\n" +" }\n" +" torch.prim.If.yield %16#0, %16#1 : !torch.bool, !torch.int\n" +" } else {\n" +" %15 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %16 = torch.prim.If %15 -> (!torch.bool) {\n" +" %19 = torch.aten.__isnot__ %arg7, %none : !torch.optional, !torch.none -> !torch.bool\n" +" torch.prim.If.yield %19 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %17 = torch.prim.If %16 -> (!torch.bool) {\n" +" %19 = torch.prim.unchecked_cast %arg7 : !torch.optional -> !torch.bool\n" +" %20 = torch.aten.ne.bool %19, %true : !torch.bool, !torch.bool -> !torch.bool\n" +" torch.prim.If.yield %20 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %18:2 = torch.prim.If %17 -> (!torch.bool, !torch.int) {\n" +" torch.prim.If.yield %true, %1#1 : !torch.bool, !torch.int\n" +" } else {\n" +" %19 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %20:2 = torch.prim.If %19 -> (!torch.bool, !torch.int) {\n" +" torch.prim.If.yield %true, %int9 : !torch.bool, !torch.int\n" +" } else {\n" +" torch.prim.If.yield %false, %0 : !torch.bool, !torch.int\n" +" }\n" +" torch.prim.If.yield %20#0, %20#1 : !torch.bool, !torch.int\n" +" }\n" +" torch.prim.If.yield %18#0, %18#1 : !torch.bool, !torch.int\n" +" }\n" +" torch.prim.If.yield %14#0, %14#1 : !torch.bool, !torch.int\n" +" }\n" +" torch.prim.If.yield %10#0, %10#1 : !torch.bool, !torch.int\n" +" }\n" +" %6 = torch.prim.If %5#0 -> (!torch.int) {\n" +" torch.prim.If.yield %5#1 : !torch.int\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield %0 : !torch.int\n" +" }\n" +" return %6 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.rsub.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.number, %arg2: !torch.number) -> !torch.int {\n" " %none = torch.constant.none\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 018377e45c16..e77a1978b101 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -1976,6 +1976,35 @@ def aten〇stack〡shape(tensors: List[List[int]], dim: int = 0) -> List[int]: def aten〇fft_fft〡shape(self: List[int], n: Optional[int] = None, dim: int = -1, norm: Optional[str] = None) -> List[int]: return self +@check_shape_function([ + Invocation(TensorOfShape(1, 128), 16, None, 16, TensorOfShape(16), False, None, True) # With an explicit 1-D window. +]) +def aten〇stft〡shape(self: List[int], n_fft: int, hop_length: Optional[int] = None, win_length: Optional[int] = None, window: Optional[List[int]] = None, normalized: bool = False, onesided: Optional[bool] = None, return_complex: Optional[bool] = None) -> List[int]: + assert len(self) == 1 or len(self) == 2, "Expected input tensor to be of shape (B?,L), where B is an optional batch dimension" + + batch = None if len(self) == 1 else self[0] + length = self[0] if len(self) == 1 else self[1] + hop_length = (n_fft // 4) if hop_length is None else hop_length + assert n_fft > 0 and n_fft <= length, "Expected that 0 < n_fft <= len" + assert hop_length > 0, "Expected hop_length to be greater than 0" + + out: List[int] = [] + if batch is not None: + out.append(batch) # (B?,) + + if onesided is None or onesided == True: + out.append(n_fft//2 + 1) + else: + out.append(n_fft) # (B?,N,) + + # For this operator, center=False by default + out.append(1 + (length - n_fft)//hop_length) #(B?,N,T,) + + if return_complex is not None and bool(return_complex) == False: + out.append(2) # a length-2 dimension of real and imaginary components. This gives output shape (B?,N,T,C?). + + return out + class DummyClassType: def __init__(self): pass @@ -3307,6 +3336,37 @@ def aten〇fft_fft〡dtype(self_rank_dtype: Tuple[int, int], n: Optional[int] = else: assert False, "Unsupported dtype" +@check_dtype_function([ + Invocation(TensorOfShape(1,128, dtype=torch.complex64), n_fft=16, return_complex=False), # output dtype = torch.float32 + Invocation(TensorOfShape(1,128, dtype=torch.complex64), n_fft=16, return_complex=True), # output dtype = torch.complex64 + Invocation(TensorOfShape(1,128, dtype=torch.float32), n_fft=16, return_complex=True), # output dtype = torch.complex64 + Invocation(TensorOfShape(1,128, dtype=torch.float32), n_fft=16, return_complex=False), # output dtype = torch.float32 +]) +def aten〇stft〡dtype(self_rank_dtype: Tuple[int, int], n_fft: int, hop_length: Optional[int] = None, win_length: Optional[int] = None, window_rank_dtype: Optional[Tuple[int, int]] = None, normalized: bool = False, onesided: Optional[bool] = None, return_complex: Optional[bool] = None) -> int: + self_rank, self_dtype = self_rank_dtype + if is_complex_dtype(self_dtype) and return_complex is not None and return_complex: + return self_dtype + elif is_complex_dtype(self_dtype) and return_complex is not None and return_complex != True: + if self_dtype == torch.complex32: + return torch.float16 + elif self_dtype == torch.complex64: + return torch.float32 + elif self_dtype == torch.complex128: + return torch.float64 + elif is_float_dtype(self_dtype) and return_complex is not None and return_complex: + if self_dtype == torch.float16: + return torch.complex32 + elif self_dtype == torch.float32: + return torch.complex64 + elif self_dtype == torch.float64: + return torch.complex128 + elif is_float_dtype(self_dtype) and return_complex is not None and return_complex != True: + return self_dtype + elif is_integer_dtype(self_dtype): + return torch.complex64 + + assert False, "Unsupported dtype" + @check_dtype_function( _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0)) diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 9cf8b2602964..b21362f7c8ef 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -921,6 +921,9 @@ def emit_with_mutating_variants(key, **kwargs): "aten::kthvalue : (Tensor, int, int, bool) -> (Tensor, Tensor)", has_verifier=True, ) + emit( + "aten::stft : (Tensor, int, int?, int?, Tensor?, bool, bool?, bool?) -> (Tensor)" + ) # Functionalization ops emit("aten::alias_copy : (Tensor) -> (Tensor)") diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index 8e37e1d83202..445d54c8697f 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -2583,3 +2583,52 @@ func.func @test_upsample_bilinear(%arg0: !torch.vtensor<[1,1,2,2],f32>, %arg1: ! %0 = torch.operator "onnx.Upsample"(%arg0, %arg1) {torch.onnx.mode = "linear"} : (!torch.vtensor<[1,1,2,2],f32>, !torch.vtensor<[4],f32>) -> !torch.vtensor<[1,1,4,6],f32> return %0 : !torch.vtensor<[1,1,4,6],f32> } + +// ----- + +// CHECK-LABEL: func.func @test_stft +func.func @test_stft(%arg0: !torch.vtensor<[1,128,1],f32>, %arg1: !torch.vtensor<[],si64>, %arg2: !torch.vtensor<[],si64>) -> !torch.vtensor<[1,15,9,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[FRAMELEN:.*]] = torch.aten.item %arg2 : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[FRAMESTEP:.*]] = torch.aten.item %arg1 : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[NONEVAL:.*]] = torch.constant.none + // CHECK: %[[ONESSHAPE:.*]] = torch.prim.ListConstruct %[[FRAMELEN]] : (!torch.int) -> !torch.list + // CHECK: %[[ONESLIST:.*]] = torch.aten.ones %[[ONESSHAPE]], %[[NONEVAL]], %[[NONEVAL]], %[[NONEVAL]], %[[NONEVAL]] : !torch.list, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?],f32> + // CHECK: %[[INT2_0:.*]] = torch.constant.int 2 + // CHECK: %[[SQUEEZE:.*]] = torch.aten.squeeze.dim %arg0, %[[INT2_0]] : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + // CHECK: %[[FALSEVAL:.*]] = torch.constant.bool false + // CHECK: %[[TRUEVAL:.*]] = torch.constant.bool true + // CHECK: %[[STFT:.*]] = torch.aten.stft %[[SQUEEZE]], %[[FRAMELEN]], %[[FRAMESTEP]], %[[FRAMELEN]], %[[ONESLIST]], %[[FALSEVAL]], %[[TRUEVAL]], %[[TRUEVAL]] : !torch.vtensor<[1,128],f32>, !torch.int, !torch.int, !torch.int, !torch.vtensor<[?],f32>, !torch.bool, !torch.bool, !torch.bool -> !torch.vtensor<[1,9,15],complex> + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[INT2_1:.*]] = torch.constant.int 2 + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[PERMUTELIST:.*]] = torch.prim.ListConstruct %[[INT0]], %[[INT2_1]], %[[INT1]] : (!torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[PERMUTE:.*]] = torch.aten.permute %[[STFT]], %[[PERMUTELIST]] : !torch.vtensor<[1,9,15],complex>, !torch.list -> !torch.vtensor<[1,15,9],complex> + // CHECK: %[[VIEWASREAL:.*]] = torch.aten.view_as_real %[[PERMUTE]] : !torch.vtensor<[1,15,9],complex> -> !torch.vtensor<[1,15,9,2],f32> + // CHECK: return %[[VIEWASREAL]] : !torch.vtensor<[1,15,9,2],f32> + %none = torch.constant.none + %0 = torch.operator "onnx.STFT"(%arg0, %arg1, %none, %arg2) : (!torch.vtensor<[1,128,1],f32>, !torch.vtensor<[],si64>, !torch.none, !torch.vtensor<[],si64>) -> !torch.vtensor<[1,15,9,2],f32> + return %0 : !torch.vtensor<[1,15,9,2],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_stft_with_window +func.func @test_stft_with_window(%arg0: !torch.vtensor<[1,128,1],f32>, %arg1: !torch.vtensor<[],si64>, %arg2: !torch.vtensor<[16],f32>) -> !torch.vtensor<[1,15,9,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[FRAMELEN:.*]] = torch.constant.int 16 + // CHECK: %[[FRAMESTEP:.*]] = torch.aten.item %arg1 : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[INT2_0:.*]] = torch.constant.int 2 + // CHECK: %[[SQUEEZE:.*]] = torch.aten.squeeze.dim %arg0, %[[INT2_0]] : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + // CHECK: %[[WINDOWLEN:.*]] = torch.constant.int 16 + // CHECK: %[[FALSEVAL:.*]] = torch.constant.bool false + // CHECK: %[[TRUEVAL:.*]] = torch.constant.bool true + // CHECK: %[[STFT:.*]] = torch.aten.stft %[[SQUEEZE]], %[[FRAMELEN]], %[[FRAMESTEP]], %[[WINDOWLEN]], %arg2, %[[FALSEVAL]], %[[TRUEVAL]], %[[TRUEVAL]] : !torch.vtensor<[1,128],f32>, !torch.int, !torch.int, !torch.int, !torch.vtensor<[16],f32>, !torch.bool, !torch.bool, !torch.bool -> !torch.vtensor<[1,9,15],complex> + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[INT2_1:.*]] = torch.constant.int 2 + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[PERMUTEDIMS:.*]] = torch.prim.ListConstruct %[[INT0]], %[[INT2_1]], %[[INT1]] : (!torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[PERMUTE:.*]] = torch.aten.permute %[[STFT]], %[[PERMUTEDIMS]] : !torch.vtensor<[1,9,15],complex>, !torch.list -> !torch.vtensor<[1,15,9],complex> + // CHECK: %[[VIEWASREAL:.*]] = torch.aten.view_as_real %[[PERMUTE]] : !torch.vtensor<[1,15,9],complex> -> !torch.vtensor<[1,15,9,2],f32> + // CHECK: return %[[VIEWASREAL]] : !torch.vtensor<[1,15,9,2],f32> + %0 = torch.operator "onnx.STFT"(%arg0, %arg1, %arg2) : (!torch.vtensor<[1,128,1],f32>, !torch.vtensor<[],si64>, !torch.vtensor<[16],f32>) -> !torch.vtensor<[1,15,9,2],f32> + return %0 : !torch.vtensor<[1,15,9,2],f32> +}