diff --git a/docs/development.md b/docs/development.md index 56ae3dbf0728..154b398f1ca1 100644 --- a/docs/development.md +++ b/docs/development.md @@ -71,10 +71,10 @@ cmake -GNinja -Bbuild \ `# use ccache to cache build results` \ -DCMAKE_C_COMPILER_LAUNCHER=ccache -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \ `# use LLD to link in seconds, rather than minutes` \ - `# if using clang <= 13, replace --ld-path=lld with -fuse-ld=lld` \ - -DCMAKE_EXE_LINKER_FLAGS_INIT="--ld-path=lld" \ - -DCMAKE_MODULE_LINKER_FLAGS_INIT="--ld-path=lld" \ - -DCMAKE_SHARED_LINKER_FLAGS_INIT="--ld-path=lld" \ + `# if using clang <= 13, replace --ld-path=ld.lld with -fuse-ld=lld` \ + -DCMAKE_EXE_LINKER_FLAGS_INIT="--ld-path=ld.lld" \ + -DCMAKE_MODULE_LINKER_FLAGS_INIT="--ld-path=ld.lld" \ + -DCMAKE_SHARED_LINKER_FLAGS_INIT="--ld-path=ld.lld" \ `# Enabling libtorch binary cache instead of downloading the latest libtorch everytime.` \ `# Testing against a mismatched version of libtorch may cause failures` \ -DLIBTORCH_CACHE=ON \ diff --git a/externals/llvm-project b/externals/llvm-project index 18808c7be688..f320c79aae1f 160000 --- a/externals/llvm-project +++ b/externals/llvm-project @@ -1 +1 @@ -Subproject commit 18808c7be688436de2bedcae13d27250f29d49a8 +Subproject commit f320c79aae1f06fbeb2908ce1ac1b8dad119b5ad diff --git a/externals/stablehlo b/externals/stablehlo index c44d9af8d487..dd48ec58d3bb 160000 --- a/externals/stablehlo +++ b/externals/stablehlo @@ -1 +1 @@ -Subproject commit c44d9af8d4879adccf1054cb61a53377ae5898cb +Subproject commit dd48ec58d3bb8d674adf56715d4394102538fa84 diff --git a/include/torch-mlir-c/TorchTypes.h b/include/torch-mlir-c/TorchTypes.h index b214e147d5d9..dd7cfb5c428f 100644 --- a/include/torch-mlir-c/TorchTypes.h +++ b/include/torch-mlir-c/TorchTypes.h @@ -220,6 +220,19 @@ MLIR_CAPI_EXPORTED MlirType torchMlirTorchQUInt8TypeGet(MlirContext context); /// Gets the !torch.quint8 typeid. MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchQUInt8TypeGetTypeID(void); +//===----------------------------------------------------------------------===// +// torch.qint16 type. +//===----------------------------------------------------------------------===// + +/// Checks whether the given type is a !torch.qint16 type +MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchQInt16(MlirType t); + +/// Gets the !torch.qint16 type. +MLIR_CAPI_EXPORTED MlirType torchMlirTorchQInt16TypeGet(MlirContext context); + +/// Gets the !torch.qint16 typeid. +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchQInt16TypeGetTypeID(void); + //===----------------------------------------------------------------------===// // torch.tensor type. 
//===----------------------------------------------------------------------===// diff --git a/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h index 0de85f4eebe5..f296b6dfee5c 100644 --- a/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h +++ b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h @@ -110,6 +110,18 @@ struct OpBinder { return success(); } + ParseResult tensorListOperandAtIndex(Value &valueIdx, int64_t idx) { + if (idx >= op->getNumOperands()) + return failure(); + valueIdx = op->getOperand(idx); + auto tt = dyn_cast(valueIdx.getType()); + if (!tt) + return failure(); + if (!toValidTensorType(tt.getContainedType())) + return failure(); + return success(); + } + ParseResult tensorListResultType(Torch::ListType &type0) { if (op->getNumResults() != 1) return failure(); diff --git a/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h b/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h index 97d004e367ba..4bf6c845c68a 100644 --- a/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h +++ b/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h @@ -36,7 +36,7 @@ Value createConstantIntList(OpBinder binder, ConversionPatternRewriter &rewriter, SmallVector cstInput); -Type getQTorchTypeFromTorchIntType(Type ty); +Torch::ValueTensorType getQTorchTypeFromTorchIntType(Type ty); template Value getItemOp(OpBinder binder, ConversionPatternRewriter &rewriter, @@ -96,6 +96,16 @@ m_OnnxListOfConstantInts(SmallVectorImpl &bind_values) { std::optional onnxDtypeIntToTorchDtypeInt(int64_t dtypeIntOnnx); +LogicalResult createTorchTransposeOp(ConversionPatternRewriter &rewriter, + Location loc, Value input, int64_t dimA, + int64_t dimB, Value &transposed); + +LogicalResult createTorchPermuteOp(OpBinder binder, + ConversionPatternRewriter &rewriter, + Location loc, Value input, + SmallVector permuteDims, + Value &permuted); + } // namespace mlir::torch::onnx_c #endif // TORCHMLIR_CONVERSION_TORCHONNXTOTORCH_UTILS_H diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 802dad692e31..6f385be3cb4b 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -256,6 +256,106 @@ def Torch_AtenLeakyRelu_Op : Torch_Op<"aten.leaky_relu_", [ }]; } +def Torch_AtenRreluOp : Torch_Op<"aten.rrelu", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::rrelu : (Tensor, Scalar, Scalar, bool, Generator?) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchScalarType:$lower, + AnyTorchScalarType:$upper, + Torch_BoolType:$training, + AnyTorchOptionalGeneratorType:$generator + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenRreluOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 5, 1); + } + void AtenRreluOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 5, 1); + } + }]; +} + +def Torch_AtenRrelu_Op : Torch_Op<"aten.rrelu_", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::rrelu_ : (Tensor, Scalar, Scalar, bool, Generator?) 
-> (Tensor)`"; + let arguments = (ins + Torch_NonValueTensorType:$self, + AnyTorchScalarType:$lower, + AnyTorchScalarType:$upper, + Torch_BoolType:$training, + AnyTorchOptionalGeneratorType:$generator + ); + let results = (outs + AnyTorchOptionalNonValueTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenRrelu_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 5, 1); + } + void AtenRrelu_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 5, 1); + } + }]; +} + +def Torch_AtenCeluOp : Torch_Op<"aten.celu", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::celu : (Tensor, Scalar) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchScalarType:$alpha + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenCeluOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenCeluOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + +def Torch_AtenCelu_Op : Torch_Op<"aten.celu_", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::celu_ : (Tensor, Scalar) -> (Tensor)`"; + let arguments = (ins + Torch_NonValueTensorType:$self, + AnyTorchScalarType:$alpha + ); + let results = (outs + AnyTorchOptionalNonValueTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenCelu_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenCelu_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + def Torch_AtenSeluOp : Torch_Op<"aten.selu", [ AllowsTypeRefinement, HasValueSemantics, @@ -4810,53 +4910,6 @@ def Torch_AtenPreluOp : Torch_Op<"aten.prelu", [ }]; } -def Torch_AtenCeluOp : Torch_Op<"aten.celu", [ - AllowsTypeRefinement, - HasValueSemantics, - ReadOnly - ]> { - let summary = "Generated op for `aten::celu : (Tensor, Scalar) -> (Tensor)`"; - let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchScalarType:$alpha - ); - let results = (outs - AnyTorchOptionalTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenCeluOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); - } - void AtenCeluOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); - } - }]; -} - -def Torch_AtenCelu_Op : Torch_Op<"aten.celu_", [ - IsTrailingUnderscoreInplaceVariant, - AllowsTypeRefinement - ]> { - let summary = "Generated op for `aten::celu_ : (Tensor, Scalar) -> (Tensor)`"; - let arguments = (ins - Torch_NonValueTensorType:$self, - AnyTorchScalarType:$alpha - ); - let results = (outs - AnyTorchOptionalNonValueTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenCelu_Op::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); - } - void AtenCelu_Op::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); - } - }]; -} - def Torch_AtenRealOp : Torch_Op<"aten.real", [ AllowsTypeRefinement, ReadOnly @@ -6766,6 +6819,31 @@ def Torch_AtenMaxPool2dOp : 
Torch_Op<"aten.max_pool2d", [ }]; } +def Torch_AtenMaxUnpool2dOp : Torch_Op<"aten.max_unpool2d", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::max_unpool2d : (Tensor, Tensor, int[]) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$indices, + AnyTorchListOfTorchIntType:$output_size + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenMaxUnpool2dOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); + } + void AtenMaxUnpool2dOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); + } + }]; +} + def Torch_AtenMaxPool2dWithIndicesOp : Torch_Op<"aten.max_pool2d_with_indices", [ AllowsTypeRefinement, HasValueSemantics, @@ -6854,6 +6932,33 @@ def Torch_AtenMaxPool3dOp : Torch_Op<"aten.max_pool3d", [ }]; } +def Torch_AtenMaxUnpool3dOp : Torch_Op<"aten.max_unpool3d", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::max_unpool3d : (Tensor, Tensor, int[], int[], int[]) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$indices, + AnyTorchListOfTorchIntType:$output_size, + AnyTorchListOfTorchIntType:$stride, + AnyTorchListOfTorchIntType:$padding + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenMaxUnpool3dOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 5, 1); + } + void AtenMaxUnpool3dOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 5, 1); + } + }]; +} + def Torch_AtenMaxPool3dWithIndicesOp : Torch_Op<"aten.max_pool3d_with_indices", [ AllowsTypeRefinement, HasValueSemantics, @@ -16197,11 +16302,11 @@ def Torch_PrimsVarOp : Torch_Op<"prims.var", [ HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `prims::var : (Tensor, int[]?, float, int?) -> (Tensor)`"; + let summary = "Generated op for `prims::var : (Tensor, int[]?, float?, int?) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$inp, AnyTorchOptionalListOfTorchIntType:$dims, - Torch_FloatType:$correction, + AnyTorchOptionalFloatType:$correction, AnyTorchOptionalIntType:$output_dtype ); let results = (outs diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchOps.td b/include/torch-mlir/Dialect/Torch/IR/TorchOps.td index 65f514c2ede9..03563287883c 100644 --- a/include/torch-mlir/Dialect/Torch/IR/TorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/TorchOps.td @@ -11,6 +11,7 @@ #define TORCH_OPS include "torch-mlir/Dialect/Torch/IR/TorchTypes.td" +include "mlir/IR/BuiltinAttributes.td" include "mlir/IR/OpAsmInterface.td" include "mlir/IR/SymbolInterfaces.td" include "mlir/Interfaces/CastInterfaces.td" @@ -1337,4 +1338,67 @@ def Torch_DtypeCalculateYieldDtypesOp : Torch_Op<"dtype.calculate.yield.dtypes", let hasVerifier = 1; } +//===----------------------------------------------------------------------===// +// Symbolic shape modeling ops for TorchDynamo frontend. 
+//===----------------------------------------------------------------------===// + +def Torch_SymbolicIntOp : Torch_Op<"symbolic_int", [Pure]> { + let summary = "Symbolic int representing a dynamic dimension"; + let description = [{ + The `torch.symbolic_int` operation captures a dynamic dimension on the + global function arguments as exported by TorchDynamo (torch.export). + It associates the shape symbols (e.g. "s0", "s1") with the + global SSA values (e.g. `%0`, `%1`) that are then referenced + to bind shapes on op results. + + Additionally, the operation annotates `min_val` and `max_val` attributes + denoting the range constraints for the dynamic dimension. This may be + useful for modeling runtime shape guards or compile-time optimizations + based on the shape bounds (min, opt, max) on results of ops / regions. + + Example: + ``` + %0 = torch.symbolic_int "s0" {min_val = 5, max_val = 10} : !torch.int + %1 = torch.symbolic_int "s1" {min_val = 2, max_val = 20} : !torch.int + ``` + }]; + let arguments = (ins + StrAttr:$symbol_name, + I64Attr:$min_val, + I64Attr:$max_val + ); + let results = (outs + Torch_IntType:$result + ); + let assemblyFormat = [{ + $symbol_name ` ` `{` `min_val` `=` $min_val `,` `max_val` `=` $max_val `}` attr-dict `:` type($result) + }]; +} + +def Torch_BindSymbolicShapeOp : Torch_Op<"bind_symbolic_shape", []> { + let summary = "Binds shape expressions to tensors using an affine map indexed by shape symbols"; + let description = [{ + The `torch.bind_symbolic_shape` operation binds shape expressions + useful to compute the dynamic dimensions of a tensor. It takes a + variadic list of SSA symbols that map 1:1 to the local symbols declared + in the affine map. The affine map contains a list of affine shape + expressions for each dim, where the terminals are the declared + symbols. + + Example: + ``` + torch.bind_symbolic_shape %arg0, [%0, %1], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32> + torch.bind_symbolic_shape %out0, [%0, %1, %2], affine_map<()[s0, s1, s2] -> (s0, s1 * 2 + s2, 3)> : !torch.vtensor<[?,?,3],f32> + ``` + }]; + let arguments = (ins + Torch_ValueTensorType:$operand, + Variadic:$shape_symbols, + Builtin_AffineMapAttr:$shape_expressions + ); + let results = (outs); + let hasCustomAssemblyFormat = 1; + let hasVerifier = 1; +} + +#endif // TORCH_OPS diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td b/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td index 279e694540f9..367b08610cd8 100644 --- a/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td +++ b/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td @@ -315,6 +315,16 @@ def Torch_QInt8Type : Torch_Type<"QInt8", "qint8"> { }]; } +def Torch_QInt16Type : Torch_Type<"QInt16", "qint16"> { + let summary = "Type modeling `ScalarType::QInt16`, which doesn't yet exist"; + let description = [{ + PyTorch does not have 16-bit integer quantization support. + + This torch type is added to provide a target for 16-bit quantization + schemes coming from imported ONNX models.
+ }]; +} + def Torch_QUInt8Type : Torch_Type<"QUInt8", "quint8"> { let summary = "Type modeling `ScalarType::QUInt8`"; let description = [{ diff --git a/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h b/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h index 043dd92549b2..e2b57538d7e6 100644 --- a/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h +++ b/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h @@ -86,24 +86,34 @@ enum class TypeKind { // at:: and c10:: parts of the macro are never used within the compiler -- we // only use this for the enum values. #define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(_) \ - _(uint8_t, Byte) /* 0 */ \ - _(int8_t, Char) /* 1 */ \ - _(int16_t, Short) /* 2 */ \ - _(int, Int) /* 3 */ \ - _(int64_t, Long) /* 4 */ \ - _(at::Half, Half) /* 5 */ \ - _(float, Float) /* 6 */ \ - _(double, Double) /* 7 */ \ - _(c10::complex, ComplexHalf) /* 8 */ \ - _(c10::complex, ComplexFloat) /* 9 */ \ - _(c10::complex, ComplexDouble) /* 10 */ \ - _(bool, Bool) /* 11 */ \ - _(c10::qint8, QInt8) /* 12 */ \ - _(c10::quint8, QUInt8) /* 13 */ \ - _(c10::qint32, QInt32) /* 14 */ \ - _(at::BFloat16, BFloat16) /* 15 */ \ - _(c10::quint4x2, QUInt4x2) /* 16 */ \ - _(c10::quint2x4, QUInt2x4) /* 17 */ + _(uint8_t, Byte) /* 0 */ \ + _(int8_t, Char) /* 1 */ \ + _(int16_t, Short) /* 2 */ \ + _(int, Int) /* 3 */ \ + _(int64_t, Long) /* 4 */ \ + _(at::Half, Half) /* 5 */ \ + _(float, Float) /* 6 */ \ + _(double, Double) /* 7 */ \ + _(c10::complex, ComplexHalf) /* 8 */ \ + _(c10::complex, ComplexFloat) /* 9 */ \ + _(c10::complex, ComplexDouble) /* 10 */ \ + _(bool, Bool) /* 11 */ \ + _(c10::qint8, QInt8) /* 12 */ \ + _(c10::quint8, QUInt8) /* 13 */ \ + _(c10::qint32, QInt32) /* 14 */ \ + _(at::BFloat16, BFloat16) /* 15 */ \ + _(c10::quint4x2, QUInt4x2) /* 16 */ \ + _(c10::quint2x4, QUInt2x4) /* 17 */ \ + _(c10::bits1x8, Bits1x8) /* 18 */ \ + _(c10::bits2x4, Bits2x4) /* 19 */ \ + _(c10::bits4x2, Bits4x2) /* 20 */ \ + _(c10::bits8, Bits8) /* 21 */ \ + _(c10::bits16, Bits16) /* 22 */ \ + _(c10::Float8_e5m2, Float8_e5m2) /* 23 */ \ + _(c10::Float8_e4m3fn, Float8_e4m3fn) /* 24 */ \ + _(c10::Float8_e5m2fnuz, Float8_e5m2fnuz) /* 25 */ \ + _(c10::Float8_e4m3fnuz, Float8_e4m3fnuz) /* 26 */ \ + _(c10::qint16, QInt16) /* 27 */ enum class ScalarType : int8_t { #define DEFINE_ENUM(_1, n) n, diff --git a/include/torch-mlir/Dialect/Torch/Utils/Utils.h b/include/torch-mlir/Dialect/Torch/Utils/Utils.h index f31795bd9233..79b14f478459 100644 --- a/include/torch-mlir/Dialect/Torch/Utils/Utils.h +++ b/include/torch-mlir/Dialect/Torch/Utils/Utils.h @@ -161,6 +161,10 @@ LogicalResult getTransposedType(BaseTensorType inType, int64_t dimA, // Torch flags, user options, etc). Type getDefaultAccType(PatternRewriter &rewriter, Type inputType); +LogicalResult getPermutedType(BaseTensorType inType, + SmallVector permuteDims, + Type &permutedType); + } // namespace Torch } // namespace torch } // namespace mlir diff --git a/include/torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.td b/include/torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.td index bbc176feb4d4..f7bb2775385b 100644 --- a/include/torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.td +++ b/include/torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.td @@ -25,9 +25,7 @@ class TorchConversion_Op traits = []> // Conversions to backend types. 
//===----------------------------------------------------------------------===// -def TorchConversion_ToBuiltinTensorOp : TorchConversion_Op<"to_builtin_tensor", [ - DeclareOpInterfaceMethods - ]> { +def TorchConversion_ToBuiltinTensorOp : TorchConversion_Op<"to_builtin_tensor"> { let summary = "Convert a `!torch.vtensor` to a `tensor`"; let description = [{ This op only operates on ValueTensorType, to avoid conflating conversions diff --git a/include/torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h b/include/torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h index de188b4f4e8f..b0a085eab7f0 100644 --- a/include/torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h +++ b/include/torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h @@ -25,6 +25,11 @@ void getBackendTypeConversionDependentDialects(DialectRegistry ®istry); /// boundary (which currently consist only of builtin types). void setupBackendTypeConversion(ConversionTarget &target, TypeConverter &typeConverter); + +#ifdef TORCH_MLIR_ENABLE_STABLEHLO +void setupBackendTypeConversionForStablehlo(ConversionTarget &target, + TypeConverter &typeConverter); +#endif } // namespace TorchConversion } // namespace torch } // namespace mlir diff --git a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h index 5aa1f0688c7c..4613d518fe53 100644 --- a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h +++ b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h @@ -63,6 +63,13 @@ struct StablehloBackendPipelineOptions void createTorchBackendToStablehloBackendPipeline( OpPassManager &pm, const StablehloBackendPipelineOptions &options); + +std::unique_ptr> +createFuncBackendTypeConversionForStablehloPass(); + +std::unique_ptr> +createFinalizingBackendTypeConversionForStablehloPass(); + std::unique_ptr> createVerifyStablehloBackendContractPass(); #endif diff --git a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td index 73654c6f8034..690c53879075 100644 --- a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td +++ b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td @@ -21,6 +21,17 @@ def FuncBackendTypeConversion : Pass<"torch-func-backend-type-conversion", "Modu }]; } +#ifdef TORCH_MLIR_ENABLE_STABLEHLO +def FuncBackendTypeConversionForStablehlo : Pass<"torch-func-backend-type-conversion-for-stablehlo", "ModuleOp"> { + let summary = "Convert functions to operate on builtin tensors for stablehlo backend"; + let constructor = "mlir::torch::TorchConversion::createFuncBackendTypeConversionForStablehloPass()"; + let description = [{ + Partial type conversion pass analogous in scope to the upstream + `func-bufferize` pass. See details there. 
+ }]; +} +#endif // TORCH_MLIR_ENABLE_STABLEHLO + def FinalizingBackendTypeConversion : InterfacePass<"torch-finalizing-backend-type-conversion", "mlir::FunctionOpInterface"> { let summary = "Finalizes a partial conversion to builtin tensors"; @@ -32,6 +43,19 @@ def FinalizingBackendTypeConversion }]; } +#ifdef TORCH_MLIR_ENABLE_STABLEHLO +def FinalizingBackendTypeConversionForStablehlo + : InterfacePass<"torch-finalizing-backend-type-conversion-for-stablehlo", "mlir::FunctionOpInterface"> { + let summary = "Finalizes a partial conversion to builtin tensors for stablehlo"; + let constructor = + "mlir::torch::TorchConversion::createFinalizingBackendTypeConversionForStablehloPass()"; + let description = [{ + Analogous in scope to the upstream `finalizing-bufferize` pass. + See details there. + }]; +} +#endif // TORCH_MLIR_ENABLE_STABLEHLO + def VerifyLinalgOnTensorsBackendContract : Pass<"torch-verify-linalg-on-tensors-backend-contract", "ModuleOp"> { let summary = "Verifies conformity to the linalg-on-tensors backend contract"; let constructor = "mlir::torch::TorchConversion::createVerifyLinalgOnTensorsBackendContractPass()"; diff --git a/include/torch-mlir/InitAll.h b/include/torch-mlir/InitAll.h index 42eb3c6a1ffb..19b2c474d787 100644 --- a/include/torch-mlir/InitAll.h +++ b/include/torch-mlir/InitAll.h @@ -18,6 +18,9 @@ namespace torch { // Registers all dialects that this project produces and any dependencies. void registerAllDialects(mlir::DialectRegistry ®istry); +// Registers all necessary dialect extensions for this project +void registerAllExtensions(mlir::DialectRegistry ®istry); + // Registers dialects that may be needed to parse torch-mlir inputs and // test cases. void registerOptionalInputDialects(mlir::DialectRegistry ®istry); diff --git a/lib/CAPI/TorchTypes.cpp b/lib/CAPI/TorchTypes.cpp index 399915459e40..6402e44a3701 100644 --- a/lib/CAPI/TorchTypes.cpp +++ b/lib/CAPI/TorchTypes.cpp @@ -269,6 +269,22 @@ MlirTypeID torchMlirTorchQUInt8TypeGetTypeID() { return wrap(Torch::QUInt8Type::getTypeID()); } +//===----------------------------------------------------------------------===// +// torch.qint16 type. +//===----------------------------------------------------------------------===// + +bool torchMlirTypeIsATorchQInt16(MlirType t) { + return isa(unwrap(t)); +} + +MlirType torchMlirTorchQInt16TypeGet(MlirContext context) { + return wrap(Torch::QInt16Type::get(unwrap(context))); +} + +MlirTypeID torchMlirTorchQInt16TypeGetTypeID() { + return wrap(Torch::QInt16Type::getTypeID()); +} + //===----------------------------------------------------------------------===// // torch.tensor type. 
//===----------------------------------------------------------------------===// diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index c0b622005900..249a8ad4f104 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -13,6 +13,7 @@ set(LinkedLibs MLIRMemRefDialect MLIRSCFDialect MLIRTensorDialect + MLIRTensorInferTypeOpInterfaceImpl MLIRTosaDialect MLIRSupport diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index cb5affbbba27..dcb28129ae95 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -18,23 +18,6 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::onnx_c; -static LogicalResult createTorchTransposeOp(ConversionPatternRewriter &rewriter, - Location loc, Value input, - int64_t dimA, int64_t dimB, - Value &transposed) { - Type transposedType; - if (failed(getTransposedType(cast(input.getType()), - dimA, dimB, transposedType))) - return failure(); - Value cstDimA = rewriter.create( - loc, rewriter.getI64IntegerAttr(dimA)); - Value cstDimB = rewriter.create( - loc, rewriter.getI64IntegerAttr(dimB)); - transposed = rewriter.create( - loc, transposedType, input, cstDimA, cstDimB); - return success(); -} - namespace { LogicalResult windowFunctionImpl(OpBinder binder, ConversionPatternRewriter &rewriter, @@ -458,9 +441,17 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( cstKernel.push_back(rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(i))); } - for (int64_t i : padding) { + // Onnx pads format: [x1_begin, x2_begin…x1_end, x2_end,…] + // Pytorch pads format: [x1, x2,...] or [x], assume begin==end for all + // axes x. + int64_t paddingSizeHalf = padding.size() / 2; + for (int64_t i = 0; i < paddingSizeHalf; ++i) { + // Check if onnx padding attribute is symmetric. 
+ if (padding[i] != padding[i + paddingSizeHalf]) + return rewriter.notifyMatchFailure( + binder.op, "onnx padding attribute is not symmetric"); cstPadding.push_back(rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(i))); + binder.getLoc(), rewriter.getI64IntegerAttr(padding[i]))); } for (int64_t i : strides) { cstStrides.push_back(rewriter.create( @@ -754,7 +745,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( std::numeric_limits::lowest())) return failure(); auto minSplatAttr = SplatElementsAttr::get( - resultType.toBuiltinTensor().clone(resultDtype), + resultType.toBuiltinTensor(), rewriter.getFloatAttr(resultDtype, minValue)); min = rewriter.create( binder.getLoc(), resultType, minSplatAttr); @@ -765,7 +756,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( std::numeric_limits::max())) return failure(); auto maxSplatAttr = SplatElementsAttr::get( - resultType.toBuiltinTensor().clone(resultDtype), + resultType.toBuiltinTensor(), rewriter.getFloatAttr(resultDtype, maxValue)); max = rewriter.create( binder.getLoc(), resultType, maxSplatAttr); @@ -846,7 +837,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( return success(); }); patterns.onOp( - "Concat", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + "Concat", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; SmallVector tensors; int64_t dim; @@ -878,7 +869,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( if (binder.op->hasAttr("torch.onnx.value_float") && !binder.f32FloatAttr(floatValue, "value_float", 0.0)) { auto splatAttr = - SplatElementsAttr::get(resultType.toBuiltinTensor().clone(dtype), + SplatElementsAttr::get(resultType.toBuiltinTensor(), rewriter.getFloatAttr(dtype, floatValue)); rewriter.replaceOpWithNewOp( binder.op, resultType, splatAttr); @@ -889,7 +880,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( if (binder.op->hasAttr("torch.onnx.value_int") && !binder.s64IntegerAttr(intValue, "value_int", 0)) { auto splatAttr = - SplatElementsAttr::get(resultType.toBuiltinTensor().clone(dtype), + SplatElementsAttr::get(resultType.toBuiltinTensor(), rewriter.getIntegerAttr(dtype, intValue)); rewriter.replaceOpWithNewOp( binder.op, resultType, splatAttr); @@ -949,8 +940,8 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( for (auto intVal : intValues) { apValues.push_back(APInt(dtype.getIntOrFloatBitWidth(), intVal)); } - auto attr = DenseElementsAttr::get( - resultType.toBuiltinTensor().clone(dtype), apValues); + auto attr = + DenseElementsAttr::get(resultType.toBuiltinTensor(), apValues); rewriter.replaceOpWithNewOp( binder.op, resultType, attr); return success(); @@ -968,7 +959,6 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( return rewriter.notifyMatchFailure( binder.op, "unsupported conversion: auto_pad != NOTSET"); } - Torch::ValueTensorType resultType; Value input, weight; int64_t group; @@ -1051,23 +1041,94 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( SmallVector cstPadding, cstStrides, cstDilations, cstOutputPadding; + Value paddedInput = input; + Value paddingList; if (padding.size() != 2 * (rank - 2)) { for (int64_t i : padding) { cstPadding.push_back(rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(i))); } + paddingList = rewriter.create( + binder.getLoc(), + Torch::ListType::get( + Torch::IntType::get(binder.op->getContext())), + cstPadding); } else { + // ONNX offers pads in the format listing all starting dims, then all + // ending dims, e.g. 
{t, l, b, r} for conv2d. Torch by default accepts + // only starting dims, e.g. {t, l}. However, we can support padding at + // the beginning and end of each dimension by first performing + // torch.nn.functional.pad on the input. But this requires the pad + // values to be rearranged since torch pad() takes pads in the order + // rightmost dim start and end, then next to last, and so on, e.g. {l, + // r, t, b}. + bool matchedPads = true; for (unsigned i = 0; i < padding.size() / 2; i++) { if (padding[i] != padding[i + (padding.size() / 2)]) { - // TODO: Add support for different padding values for the - // beginning and ending along each spatial axis - return rewriter.notifyMatchFailure( - binder.op, - "unsupported conversion: padding values for the beginning " - "and ending along each spatial axis must be equal"); + matchedPads = false; + break; } - cstPadding.push_back(rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(padding[i]))); + } + if (matchedPads) { + for (unsigned i = 0; i < padding.size() / 2; i++) { + cstPadding.push_back(rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(padding[i]))); + } + paddingList = rewriter.create( + binder.getLoc(), + Torch::ListType::get( + Torch::IntType::get(binder.op->getContext())), + cstPadding); + } else { + SmallVector padsRearrange; + SmallVector inputPaddingList; + for (uint32_t i = 0; i < padding.size() / 2; i++) { + padsRearrange.emplace_back(rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(padding[i]))); + padsRearrange.emplace_back(rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr( + padding[(padding.size() / 2) + i]))); + inputPaddingList.emplace_back( + rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(0))); + } + // The conv op itself will have no padding since the actual padding + // is performed using the torch.pad preceding it. 
+ paddingList = rewriter.create( + binder.getLoc(), + Torch::ListType::get( + Torch::IntType::get(binder.op->getContext())), + inputPaddingList); + Value padsSizeList = + rewriter + .create( + binder.getLoc(), + Torch::ListType::get( + rewriter.getType()), + padsRearrange) + .getResult(); + Value modeVal = rewriter.create( + binder.getLoc(), rewriter.getStringAttr("constant")); + Value constantValue; + auto inputTensorType = + cast(input.getType()); + if (isa(inputTensorType.getDtype())) + constantValue = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(0)); + if (isa(inputTensorType.getDtype())) + constantValue = rewriter.create( + binder.getLoc(), rewriter.getF64FloatAttr(0.0f)); + // Pad output shape must be computed explicitly from the pad values + SmallVector newInputShape(inputTensorType.getSizes()); + for (uint32_t i = 0; i < padding.size() / 2; i++) { + newInputShape[2 + i] += + padding[i] + padding[(padding.size() / 2) + i]; + } + auto padTy = rewriter.getType( + newInputShape, inputTensorType.getDtype()); + paddedInput = rewriter.create( + binder.getLoc(), padTy, input, padsSizeList, modeVal, + constantValue); } } for (int64_t i : dilations) { @@ -1082,10 +1143,6 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.getLoc(), rewriter.getI64IntegerAttr(0)); cstOutputPadding = {cstZero, cstZero}; - Value paddingList = rewriter.create( - binder.getLoc(), - Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), - cstPadding); Value dilationsList = rewriter.create( binder.getLoc(), Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), @@ -1112,7 +1169,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.getLoc(), rewriter.getI64IntegerAttr(group)); rewriter.replaceOpWithNewOp( - binder.op, resultType, input, weight, bias, stridesList, + binder.op, resultType, paddedInput, weight, bias, stridesList, paddingList, dilationsList, transposed, outputPaddingList, cstGroup); return success(); @@ -1666,21 +1723,12 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( "requires known result dtype"); if (scaleTy.getSizes().size() == 0 || (scaleTy.getSizes().size() == 1 && scaleTy.getSizes()[0] == 1)) { - Type qTy = operandTy.getDtype(); - - if (qTy.isUnsignedInteger(8)) { - qTy = rewriter.getType(); - } else if (qTy.isSignedInteger(8)) { - qTy = rewriter.getType(); - } else if (qTy.isSignedInteger(32)) { - qTy = rewriter.getType(); - } else { + auto qTensorTy = getQTorchTypeFromTorchIntType(operandTy); + if (!qTensorTy) { return rewriter.notifyMatchFailure(binder.op, "unsupported result dtype"); } - auto qTensorTy = rewriter.getType( - resultType.getOptionalSizes(), qTy); scale = rewriter.create( binder.getLoc(), rewriter.getType(), scale); zeropoint = rewriter.create( @@ -2223,9 +2271,9 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( // Extract the fill value and dtype // ONNX requires value attr to be a tensor if (!attr) { - attr = DenseElementsAttr::get( - resultType.toBuiltinTensor().clone(resultDType), - rewriter.getFloatAttr(resultDType, 0.0)); + attr = + DenseElementsAttr::get(resultType.toBuiltinTensor(), + rewriter.getFloatAttr(resultDType, 0.0)); } // If its a dense resource attr we need to convert to a dense type: diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index e55756eb4305..3c6d82e103b5 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -408,20 +408,11 @@ 
void mlir::torch::onnx_c::populateDefaultDomainGtoP( binder.getLoc(), rewriter.getType(), rewriter.getF64FloatAttr(1.0)); - auto q = [&](Type qty) -> Type { - if (qty.isSignedInteger(8)) - return rewriter.getType(); - if (qty.isUnsignedInteger(8)) - return rewriter.getType(); - if (qty.isSignedInteger(32)) - return rewriter.getType(); - return {}; - }; + auto lhsQTy = getQTorchTypeFromTorchIntType(lhsTy); + auto rhsQTy = getQTorchTypeFromTorchIntType(rhsTy); - Type lhsQTy = rewriter.getType( - lhsTy.getOptionalSizes(), q(lhsTy.getDtype())); - Type rhsQTy = rewriter.getType( - rhsTy.getOptionalSizes(), q(rhsTy.getDtype())); + if (!lhsQTy || !rhsQTy) + return rewriter.notifyMatchFailure(binder.op, "failed to get qtype"); lhs = rewriter.create( binder.getLoc(), lhsQTy, lhs, scale, lhsZp); @@ -1918,4 +1909,82 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( return success(); }); + patterns.onOp( + "MaxUnpool", 9, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + // TODO: Add support for `output_shape` arg. + if (binder.op->getNumOperands() == 3) + return rewriter.notifyMatchFailure( + binder.op, "unimplemented: output_shape arg is not supported"); + + Torch::ValueTensorType resultType; + Value data, indices; + if (binder.tensorOperandAtIndex(data, 0) || + binder.tensorOperandAtIndex(indices, 1) || + binder.tensorResultType(resultType)) + return rewriter.notifyMatchFailure( + binder.op, "data/indices/resultType bind failure"); + std::optional maybeRank = Torch::getTensorRank(data); + if (!maybeRank) + return rewriter.notifyMatchFailure(binder.op, + "Unimplemented: unranked tensor"); + int64_t rank = *maybeRank; + int64_t spatial = rank - 2; + + if (rank <= 3 || rank > 5) + return rewriter.notifyMatchFailure(binder.op, + "Unimplemented: MaxUnpool support " + "only present for rank 4/5 input"); + + if (!(resultType.hasSizes() && resultType.areAllSizesKnown())) + return rewriter.notifyMatchFailure( + binder.op, "unimplemented: expected result to have all shapes " + "statically known"); + + SmallVector resultShape(resultType.getSizes()); + Value resultShapeList = + createConstantIntList(binder, rewriter, resultShape); + if (rank == 4) { + rewriter.replaceOpWithNewOp( + binder.op, resultType, data, indices, resultShapeList); + return success(); + } + + SmallVector padding, strides; + if (binder.s64IntegerArrayAttr(padding, "pads", {})) + return rewriter.notifyMatchFailure(binder.op, "pads bind failure"); + if (!padding.empty() && + padding.size() != static_cast(2 * spatial)) + return rewriter.notifyMatchFailure( + binder.op, "padding list must contain (begin,end) pair for each " + "spatial axis"); + if (binder.s64IntegerArrayAttr(strides, "strides", {})) + return rewriter.notifyMatchFailure(binder.op, "strides bind failure"); + if (!strides.empty() && strides.size() != static_cast(spatial)) + return rewriter.notifyMatchFailure( + binder.op, "strides list size does not match the number of axes"); + + if (padding.empty()) + padding.resize(spatial, 0); + if (strides.empty()) + strides.resize(spatial, 1); + + // If the padding is symmetric we can push the padding + // operation to the torch operator. 
+ if (padding.size() == static_cast(2 * spatial)) { + bool equal = true; + for (int i = 0; i < spatial; ++i) { + equal = equal && (padding[i] == padding[i + spatial]); + } + if (equal) + padding.resize(spatial); + } + + Value paddingList = createConstantIntList(binder, rewriter, padding); + Value stridesList = createConstantIntList(binder, rewriter, strides); + + rewriter.replaceOpWithNewOp( + binder.op, resultType, data, indices, resultShapeList, stridesList, + paddingList); + return success(); + }); } diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 4b06c185bc39..edcbaa7d5173 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -177,22 +177,13 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( "requires known result dtype"); if (scaleTy.getSizes().size() == 0) { - Type qTy = resultType.getDtype(); - - if (qTy.isUnsignedInteger(8)) { - qTy = rewriter.getType(); - } else if (qTy.isSignedInteger(8)) { - qTy = rewriter.getType(); - } else if (qTy.isSignedInteger(32)) { - qTy = rewriter.getType(); - } else { + auto qTensorTy = getQTorchTypeFromTorchIntType(resultType); + if (!qTensorTy) { return rewriter.notifyMatchFailure(binder.op, "unsupported result dtype"); } - auto qTensorTy = rewriter.getType( - resultType.getOptionalSizes(), qTy); - auto torchqTy = Torch::getScalarTypeForType(qTy); + auto torchqTy = Torch::getScalarTypeForType(qTensorTy.getDtype()); Value tyConst = rewriter.create( binder.getLoc(), rewriter.getType(), @@ -311,8 +302,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( c = rewriter.create(binder.getLoc(), cTy, c); - cTy = dyn_cast( - getQTorchTypeFromTorchIntType(resultType)); + cTy = getQTorchTypeFromTorchIntType(resultType); Value dtyVal = rewriter.create( binder.getLoc(), rewriter.getType(), rewriter.getIntegerAttr( @@ -2963,28 +2953,15 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( Torch::ValueTensorType resultType; llvm::SmallVector operands; std::string mode, nearest_mode, coordTfMode; + int64_t antialias, exclude_outside; + float extrapolation_value; Value noneVal = rewriter.create(binder.getLoc()); - if (auto attr = binder.op->getAttr("torch.onnx.antialias")) { - return rewriter.notifyMatchFailure( - binder.op, - "unimplemented: support not present for antialias attribute"); - } if (auto attr = binder.op->getAttr("torch.onnx.axes")) { return rewriter.notifyMatchFailure( binder.op, "unimplemented: support not present for axes attribute"); } - if (auto attr = binder.op->getAttr("torch.onnx.exclude_outside")) { - return rewriter.notifyMatchFailure( - binder.op, "unimplemented: support not present for " - "exclude_outside attribute"); - } - if (auto attr = binder.op->getAttr("torch.onnx.extrapolation_value")) { - return rewriter.notifyMatchFailure( - binder.op, "unimplemented: support not present for " - "extrapolation_value attribute"); - } if (auto attr = binder.op->getAttr("torch.onnx.keep_aspect_ratio_policy")) { return rewriter.notifyMatchFailure( @@ -2997,17 +2974,40 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( binder.customOpNameStringAttr(mode, "mode", "nearest") || binder.customOpNameStringAttr( coordTfMode, "coordinate_transformation_mode", "half_pixel") || - binder.customOpNameStringAttr(nearest_mode, "nearest_mode", "")) + binder.s64IntegerAttr(antialias, "antialias", 0) || + binder.s64IntegerAttr(exclude_outside, "exclude_outside", 0) || + 
binder.f32FloatAttr(extrapolation_value, "extrapolation_value", + 0.0) || + binder.customOpNameStringAttr(nearest_mode, "nearest_mode", + "round_prefer_floor")) return failure(); + if (antialias != 0) { + return rewriter.notifyMatchFailure( + binder.op, + "unimplemented: support not present for antialias attribute"); + } + if (exclude_outside != 0) { + return rewriter.notifyMatchFailure( + binder.op, "unimplemented: support not present for " + "exclude_outside attribute"); + } + if (extrapolation_value != 0.0) { + return rewriter.notifyMatchFailure( + binder.op, "unimplemented: support not present for " + "extrapolation_value attribute"); + } if (coordTfMode == "tf_crop_and_resize") return rewriter.notifyMatchFailure( binder.op, "unimplemented: coordinate transformation mode: " "tf_crop_and_resize"); - if (mode == "nearest" && nearest_mode != "floor") { + + if (mode == "nearest" && coordTfMode != "asymmetric" && + coordTfMode != "half_pixel") { return rewriter.notifyMatchFailure( - binder.op, "unimplemented: support not present for nearest_mode " - "except floor"); + binder.op, "unimplemented: support not present for coord tf mode " + "except asymmetric and half_pixel"); } + unsigned rank = dyn_cast(operands[0].getType()) .getSizes() .size(); @@ -3109,6 +3109,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( // apparently asymmetric if (coordTfMode != "asymmetric" && coordTfMode != "align_corners") modeStr = (modeStr + "_") + coordTfMode; + if (nearest_mode != "floor" && nearest_mode != "") + modeStr = modeStr + "," + nearest_mode; modeStrValue = rewriter.create(binder.getLoc(), modeStr); } @@ -3134,4 +3136,308 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( /*Torch_BoolType:$antialias*/ cstFalse); return success(); }); + patterns.onOp( + "SpaceToDepth", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value input; + int64_t blockSize; + std::string mode; + if (binder.tensorOperand(input) || + binder.s64IntegerAttr(blockSize, "blocksize") || + binder.customOpNameStringAttr(mode, "mode", "DCR") || + binder.tensorResultType(resultType)) + return failure(); + auto inputTy = dyn_cast(input.getType()); + if (!inputTy || !inputTy.hasSizes()) { + return rewriter.notifyMatchFailure( + binder.op, "Expected input type having sizes"); + } + SmallVector inputSizes{inputTy.getSizes()}; + if (inputSizes.size() != 4) { + return rewriter.notifyMatchFailure(binder.op, + "Expected input rank to be 4"); + } + + Value b = rewriter.create( + binder.getLoc(), input, + rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(0))); + Value c = rewriter.create( + binder.getLoc(), input, + rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(1))); + Value h = rewriter.create( + binder.getLoc(), input, + rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(2))); + Value w = rewriter.create( + binder.getLoc(), input, + rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(3))); + Value cstBlockSize = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(blockSize)); + Value cstBlockSizeSquare = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(blockSize * blockSize)); + Value hDivBlockSize = rewriter.create( + binder.getLoc(), h, cstBlockSize); + Value wDivBlockSize = rewriter.create( + binder.getLoc(), w, cstBlockSize); + hDivBlockSize = rewriter.create(binder.getLoc(), + hDivBlockSize); + wDivBlockSize = rewriter.create(binder.getLoc(), + wDivBlockSize); + + // The implementation is as 
follows: + // tmp = np.reshape( + // x, [b, c, h // blocksize, blocksize, w // blocksize, blocksize] + // ) + // tmp = np.transpose(tmp, [0, 3, 5, 1, 2, 4]) + // y = np.reshape(tmp, [b, c * (blocksize**2), h // blocksize, w // + // blocksize]) + Value reshapeSizesList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(input.getContext())), + llvm::SmallVector{b, c, hDivBlockSize, cstBlockSize, + wDivBlockSize, cstBlockSize}); + int64_t hDivBlockSizeInt = inputSizes[2] == Torch::kUnknownSize + ? Torch::kUnknownSize + : inputSizes[2] / blockSize; + int64_t wDivBlockSizeInt = inputSizes[3] == Torch::kUnknownSize + ? Torch::kUnknownSize + : inputSizes[3] / blockSize; + SmallVector reshapeSizesInt{inputSizes[0], inputSizes[1], + hDivBlockSizeInt, blockSize, + wDivBlockSizeInt, blockSize}; + Value reshapedInput = rewriter.create( + binder.getLoc(), + inputTy.getWithSizesAndDtype(reshapeSizesInt, + inputTy.getOptionalDtype()), + input, reshapeSizesList); + + SmallVector permuteDimsInt{0, 3, 5, 1, 2, 4}; + Value permutedInput; + if (failed(createTorchPermuteOp(binder, rewriter, binder.getLoc(), + reshapedInput, permuteDimsInt, + permutedInput))) + return rewriter.notifyMatchFailure( + binder.op, "Failed to create Torch Permute op"); + + Value cMulBlockSizeSquare = rewriter.create( + binder.getLoc(), c, cstBlockSizeSquare); + reshapeSizesList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(input.getContext())), + llvm::SmallVector{b, cMulBlockSizeSquare, hDivBlockSize, + wDivBlockSize}); + rewriter.replaceOpWithNewOp( + binder.op, resultType, permutedInput, reshapeSizesList); + return success(); + }); + patterns.onOp( + "Shrink", 9, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Location loc = binder.getLoc(); + Torch::ValueTensorType resultType; + Value input; + float bias, lambd; + if (binder.tensorOperand(input) || + binder.f32FloatAttr(bias, "bias", 0.0) || + binder.f32FloatAttr(lambd, "lambd", 0.5) || + binder.tensorResultType(resultType)) { + return failure(); + } + + Torch::ValueTensorType inputType = + cast(input.getType()); + if (!isa(inputType.getDtype())) + return rewriter.notifyMatchFailure( + binder.op, "unimplemented: non-floating point dtype"); + + // The formula of this operator is: If x < -lambd, y = x + bias; If x > + // lambd, y = x - bias; Otherwise, y = 0. 
+ // The implementation is based on the following algorithm: + // Shrink (input) => (output) + // { + // Lambd = Constant () + // LambdCast = CastLike (Lambd, input) + // Bias = Constant () + // BiasCast = CastLike (Bias, input) + // Zero = Constant () + // ZeroCast = CastLike (Zero, input) + // NegLmbda = Neg (LambdCast) + // InputLessThanNegLambda = Less (input, NegLmbda) + // InputAddBias = Add (input, BiasCast) + // InputSubBias = Sub (input, BiasCast) + // LambdaLessThanInput = Less (LambdCast, input) + // InputSubBiasOrZero = Where (LambdaLessThanInput, InputSubBias, + // ZeroCast) output = Where (InputLessThanNegLambda, InputAddBias, + // InputSubBiasOrZero) + // } + Value constLambd = rewriter.create( + loc, rewriter.getFloatAttr(rewriter.getF64Type(), lambd)); + Value constBias = rewriter.create( + loc, rewriter.getFloatAttr(rewriter.getF64Type(), bias)); + Value constZero = rewriter.create( + loc, rewriter.getFloatAttr(rewriter.getF64Type(), 0.0)); + Value constOne = rewriter.create( + loc, rewriter.getFloatAttr(rewriter.getF64Type(), 1.0)); + Value constNegLambd = rewriter.create( + loc, rewriter.getFloatAttr(rewriter.getF64Type(), -lambd)); + + Value inputLTNegLambd = rewriter.create( + loc, inputType, input, constNegLambd); + Value inputPlusBias = rewriter.create( + loc, inputType, input, constBias, /*alpha=*/constOne); + Value inputSubBias = rewriter.create( + loc, inputType, input, constBias, /*alpha=*/constOne); + Value inputGTLambd = rewriter.create( + loc, inputType, input, constLambd); + + Value inputSubBiasOrZero = + rewriter.create( + loc, resultType, inputGTLambd, inputSubBias, constZero); + rewriter.replaceOpWithNewOp( + binder.op, resultType, inputLTNegLambd, inputPlusBias, + inputSubBiasOrZero); + return success(); + }); + patterns.onOp("SequenceAt", 11, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value inputSequence, position; + if (binder.tensorListOperandAtIndex(inputSequence, 0) || + binder.tensorOperandAtIndex(position, 1) || + binder.tensorResultType(resultType)) + return failure(); + + Value index = rewriter.create( + binder.getLoc(), rewriter.getType(), + position); + rewriter.replaceOpWithNewOp( + binder.op, resultType, inputSequence, index); + return success(); + }); + patterns.onOp( + "SequenceEmpty", 11, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ListType resultType; + int64_t dtypeIntOnnx; + if (binder.s64IntegerAttr(dtypeIntOnnx, "dtype", 1) || + binder.tensorListResultType(resultType)) + return failure(); + + std::optional dtypeIntTorch = + onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx); + if (!dtypeIntTorch.has_value()) { + return rewriter.notifyMatchFailure( + binder.op, + "unimplemented support for the given dtype conversion"); + } + Value constDtype = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(dtypeIntTorch.value())); + + Value shapeList = createConstantIntList(binder, rewriter, {}); + Value cstNone = rewriter.create(binder.getLoc()); + + Value self = rewriter.create( + binder.op->getLoc(), resultType.getContainedType(), shapeList, + /*dtype=*/constDtype, + /*layout=*/cstNone, + /*device=*/cstNone, /*pinMemory=*/cstNone, + /*memoryFormat=*/cstNone); + + rewriter.replaceOpWithNewOp( + binder.op, resultType, llvm::SmallVector{self}); + return success(); + }); + patterns.onOp( + "SequenceErase", 11, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ListType resultType; + Value inputSequence, position; + if 
(binder.tensorListOperandAtIndex(inputSequence, 0) || + binder.tensorListResultType(resultType)) + return failure(); + + Value length = rewriter.create( + binder.getLoc(), rewriter.getType(), inputSequence); + + Value cstNone = rewriter.create(binder.getLoc()); + Value cstOne = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(1)); + if (binder.op->getNumOperands() == 1) { + // If True, it means that the `position` arg is missing and + // the last tensor from the list has to be erased. + Value lengthMinusOne = rewriter.create( + binder.getLoc(), length, cstOne); + rewriter.replaceOpWithNewOp( + binder.op, resultType, inputSequence, /*start=*/cstNone, + /*end=*/lengthMinusOne, /*step=*/cstOne); + return success(); + } + + if (binder.tensorOperandAtIndex(position, 1)) + return failure(); + + Value positionInt = rewriter.create( + binder.getLoc(), rewriter.getType(), position); + // Handling negative position value. + Value cstZero = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(0)); + Value isPositionNegative = rewriter.create( + binder.getLoc(), positionInt, cstZero); + isPositionNegative = rewriter.create( + binder.getLoc(), isPositionNegative); + Value finalOffset = rewriter.create( + binder.getLoc(), isPositionNegative, length); + positionInt = rewriter.create( + binder.getLoc(), positionInt, finalOffset); + + Value listBeforePosition = rewriter.create( + binder.getLoc(), resultType, inputSequence, /*start=*/cstNone, + /*end=*/positionInt, /*step=*/cstOne); + Value positionPlusOne = rewriter.create( + binder.getLoc(), positionInt, cstOne); + Value listAfterPosition = rewriter.create( + binder.getLoc(), resultType, inputSequence, + /*start=*/positionPlusOne, + /*end=*/length, /*step=*/cstOne); + + rewriter.replaceOpWithNewOp( + binder.op, resultType, listBeforePosition, listAfterPosition); + return success(); + }); + patterns.onOp( + "SequenceInsert", 11, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ListType resultType; + Value inputSequence, position, insertValue; + if (binder.tensorListOperandAtIndex(inputSequence, 0) || + binder.tensorOperandAtIndex(insertValue, 1) || + binder.tensorListResultType(resultType)) + return failure(); + + if (binder.op->getNumOperands() == 1) { + // If True, it means that the `position` arg is missing and + // the tensor has to be inserted at the end of the list. 
+ Value length = rewriter.create( + binder.getLoc(), rewriter.getType(), + inputSequence); + rewriter.replaceOpWithNewOp( + binder.op, inputSequence, /*idx=*/length, + /*el=*/insertValue); + return success(); + } + + if (binder.tensorOperandAtIndex(position, 2)) + return failure(); + + Value positionInt = rewriter.create( + binder.getLoc(), rewriter.getType(), position); + rewriter.create(binder.getLoc(), inputSequence, + /*idx=*/positionInt, + /*el=*/insertValue); + rewriter.replaceOp(binder.op, inputSequence); + return success(); + }); } diff --git a/lib/Conversion/TorchOnnxToTorch/Utils.cpp b/lib/Conversion/TorchOnnxToTorch/Utils.cpp index dec13490666e..bec6ade4270c 100644 --- a/lib/Conversion/TorchOnnxToTorch/Utils.cpp +++ b/lib/Conversion/TorchOnnxToTorch/Utils.cpp @@ -28,7 +28,8 @@ Value mlir::torch::onnx_c::createConstantIntList( cstValue); } -Type mlir::torch::onnx_c::getQTorchTypeFromTorchIntType(Type ty) { +Torch::ValueTensorType +mlir::torch::onnx_c::getQTorchTypeFromTorchIntType(Type ty) { Torch::ValueTensorType tty = dyn_cast(ty); if (!tty) return nullptr; @@ -40,6 +41,8 @@ Type mlir::torch::onnx_c::getQTorchTypeFromTorchIntType(Type ty) { dty = Torch::QUInt8Type::get(ctx); if (dty.isSignedInteger(8)) dty = Torch::QInt8Type::get(ctx); + if (dty.isSignedInteger(16)) + dty = Torch::QInt16Type::get(ctx); if (dty.isSignedInteger(32)) dty = Torch::QInt32Type::get(ctx); @@ -97,3 +100,33 @@ mlir::torch::onnx_c::onnxDtypeIntToTorchDtypeInt(int64_t dtypeIntOnnx) { return dtypeIntTorch; } + +LogicalResult mlir::torch::onnx_c::createTorchTransposeOp( + ConversionPatternRewriter &rewriter, Location loc, Value input, + int64_t dimA, int64_t dimB, Value &transposed) { + Type transposedType; + if (failed(getTransposedType(cast(input.getType()), + dimA, dimB, transposedType))) + return failure(); + Value cstDimA = rewriter.create( + loc, rewriter.getI64IntegerAttr(dimA)); + Value cstDimB = rewriter.create( + loc, rewriter.getI64IntegerAttr(dimB)); + transposed = rewriter.create( + loc, transposedType, input, cstDimA, cstDimB); + return success(); +} + +LogicalResult mlir::torch::onnx_c::createTorchPermuteOp( + OpBinder binder, ConversionPatternRewriter &rewriter, Location loc, + Value input, SmallVector permuteDims, Value &permuted) { + Type permutedType; + if (failed( + Torch::getPermutedType(cast(input.getType()), + permuteDims, permutedType))) + return failure(); + Value permuteDimsList = createConstantIntList(binder, rewriter, permuteDims); + permuted = rewriter.create(loc, permutedType, input, + permuteDimsList); + return success(); +} diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index b9b0fb0ae5d7..dc8b5d431002 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -2579,6 +2579,8 @@ class ConvertSparseOperatorOp : public OpConversionPattern { SmallVector ConvertSparseOperatorOp::legalizedNames = { "torch.aten._to_dense", "torch.aten._to_sparse", "torch.aten._to_csr", "torch.aten._to_csc", "torch.aten._to_bsr", "torch.aten._to_bsc", + "torch.aten.to_dense", "torch.aten.to_sparse", "torch.aten.to_csr", + "torch.aten.to_csc", "torch.aten.to_bsr", "torch.aten.to_bsc", }; } // namespace diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp index 373ed076551b..c2e89e078eca 100644 --- a/lib/Conversion/TorchToLinalg/Linear.cpp +++ b/lib/Conversion/TorchToLinalg/Linear.cpp @@ -146,12 +146,11 @@ class ConvertAtenMmOp : public 
OpConversionPattern { "mismatching contracting dimension for torch.aten.mm")); } - auto resultTy = cast(op.getType()); - auto resultDTy = resultTy.toBuiltinTensor().getElementType(); - Type newResultType = getTypeConverter()->convertType(op.getType()); - Type elementType = cast(newResultType).getElementType(); - auto accumulatorDType = getDefaultAccType(rewriter, resultDTy); - if (accumulatorDType != resultDTy) { + TensorType resultType = + cast(getTypeConverter()->convertType(op.getType())); + Type elementType = resultType.getElementType(); + auto accumulatorDType = getDefaultAccType(rewriter, elementType); + if (accumulatorDType != resultType.getElementType()) { elementType = accumulatorDType; } Value zeroFill = createZeroInitTensor( @@ -197,18 +196,16 @@ class ConvertAtenMmOp : public OpConversionPattern { .getResult(0); } - if (accumulatorDType != resultDTy) { - Type resultElementType = - cast(newResultType).getElementType(); + if (accumulatorDType != resultType.getElementType()) { matmul = torch_to_linalg::convertTensorToElementType( - rewriter, loc, matmul, resultElementType); + rewriter, loc, matmul, resultType.getElementType()); } // When constructed with just dynamic sizes, EmptyOp will have a result // type which has all `?`'s for dimensions, which might not be the result // type of `op`. The constraints on later linalg ops means that the result // of the MatmulOp will have this type too. So cast it to the desired type // so that in the end we have the original result type. - rewriter.replaceOpWithNewOp(op, newResultType, matmul); + rewriter.replaceOpWithNewOp(op, resultType, matmul); return success(); } @@ -829,7 +826,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { op, "lhs and rhs of convolution must either be both int or fp"); } - if (inputZp && weightZp && !isa(bias.getType())) { + if (inputZp && !isa(bias.getType())) { auto biasDTy = cast(bias.getType()).getElementType(); if (!biasDTy.isInteger(32)) { return rewriter.notifyMatchFailure( @@ -1123,7 +1120,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { // - grouped 1d-3d // - grouped 1d-3d (quantized) // - ungrouped 1d-3d - if (groupSize == 1 && !inputZp && !weightZp) { + if (groupSize == 1 && !inputZp) { switch (numSpatialDims) { case 1: conv = rewriter @@ -1164,7 +1161,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { return success(); } - if (groupSize == 1 && inputZp && weightZp) { + if (groupSize == 1 && inputZp) { // The quantized version uses a different channel ordering so we need to // permute the tensors in order to use the existing path. We should // eventually directly support this channel ordering. 
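The accumulator handling in ConvertAtenMmOp above reduces to multiplying and accumulating in a wider element type and then converting to the result element type. A standalone sketch of that idea with int8 operands and an int32 accumulator (names and types here are illustrative; the pattern builds the equivalent with a linalg matmul plus an element-type conversion):

#include <cstdint>
#include <vector>

std::vector<int8_t> matmulI8(const std::vector<int8_t> &a, // m x k, row-major
                             const std::vector<int8_t> &b, // k x n, row-major
                             int m, int k, int n) {
  // Accumulate in a wider type than the result element type ...
  std::vector<int32_t> acc(size_t(m) * n, 0);
  for (int i = 0; i < m; ++i)
    for (int p = 0; p < k; ++p)
      for (int j = 0; j < n; ++j)
        acc[size_t(i) * n + j] +=
            int32_t(a[size_t(i) * k + p]) * int32_t(b[size_t(p) * n + j]);
  // ... then narrow back, as the accumulatorDType != element-type branch does.
  std::vector<int8_t> out(acc.size());
  for (size_t i = 0; i < acc.size(); ++i)
    out[i] = static_cast<int8_t>(acc[i]);
  return out;
}

The same idea applies to floating-point results, for example an f16 result accumulated in f32.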
@@ -1224,10 +1221,6 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { return success(); } - if (inputZp || weightZp) - return rewriter.notifyMatchFailure( - op, "unimplemented: quantized grouped convolutions"); - if (numSpatialDims != 2) return rewriter.notifyMatchFailure( op, "unimplemented: only 2D grouped convolution supported"); @@ -1238,7 +1231,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { auto weightShape = makeShapeTorchCompatible( cast(weight.getType()).getShape()); if (weightShape[0] != kUnknownSize && inShape[1] == groupSize && - weightShape[0] % inShape[1] == 0 && weightShape[1] == 1) { + weightShape[0] % inShape[1] == 0 && weightShape[1] == 1 && !inputZp) { // Collapse weight shape SmallVector collapsedDims = {{0, 1}, {2}, {3}}; SmallVector collapsedShape{ @@ -1325,13 +1318,22 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { auto expandOutputTensor = expandGroups(outputTensor, 1); // TODO: add 1D and 3D case - conv = rewriter - .create( - loc, expandOutputTensor.getResultType(), - ValueRange{paddedInputExpanded, weightExpanded}, - expandOutputTensor.getResult(), stridesAttr, dilationAttr) - .getResult(0); - + if (!inputZp) { + conv = rewriter + .create( + loc, expandOutputTensor.getResultType(), + ValueRange{paddedInputExpanded, weightExpanded}, + expandOutputTensor.getResult(), stridesAttr, dilationAttr) + .getResult(0); + } else { + conv = rewriter + .create( + loc, expandOutputTensor.getResultType(), + ValueRange{paddedInputExpanded, weightExpanded, inputZp, + weightZp}, + expandOutputTensor.getResult(), stridesAttr, dilationAttr) + .getResult(0); + } conv = rewriter.create( loc, outputTensor.getType(), conv, expandOutputTensor.getReassociationIndices()); diff --git a/lib/Conversion/TorchToLinalg/Pooling.cpp b/lib/Conversion/TorchToLinalg/Pooling.cpp index 36fa9dc56f82..d80f3d4272e4 100644 --- a/lib/Conversion/TorchToLinalg/Pooling.cpp +++ b/lib/Conversion/TorchToLinalg/Pooling.cpp @@ -619,13 +619,6 @@ class ConvertAtenAvgPoolOp : public OpConversionPattern { return rewriter.notifyMatchFailure( op, "count_include_pad must be a constant"); - // If the padding is zero then there is no padding to include. - if (!countIncludePad && - !llvm::all_of(paddingInts, [](int64_t p) { return p == 0; })) { - return rewriter.notifyMatchFailure( - op, "unimplemented: count_include_pad is expected to be true"); - } - // `sumPool` contains the result of sumpool operation over the input. Value sumPool, paddedInput; SmallVector outTensorShape; @@ -635,9 +628,142 @@ class ConvertAtenAvgPoolOp : public OpConversionPattern { paddingInts, dilationInts, rewriter.getZeroAttr(inputElementType), outTensorShape, paddedInput, sumPool))) return rewriter.notifyMatchFailure(op, "unable to compute sumpool"); - // } - Value divisor = kernelSizeIntValues[0]; + // Compute the average of sumPool. + Value outputTensor = rewriter.create( + loc, getAsOpFoldResult(outTensorShape), resultElementType); + SmallVector indexingMapsAvg( + 2, rewriter.getMultiDimIdentityMap(Dim + 2)); + SmallVector iteratorTypesAvg( + Dim + 2, utils::IteratorType::parallel); + Value avgPool; + Value divisor; + // Case1: AtenAvgPool1d/2dOp with countIncludePad=false support. 
+ if constexpr (std::is_same()) { + auto selfType = cast(self.getType()); + const int64_t selfRank = selfType.getRank(); + int64_t wDim = toPositiveDim(-1, selfRank); + int64_t hDim = toPositiveDim(-2, selfRank); + Value inputHeight = getDimOp(rewriter, loc, self, hDim); + Value inputWidth = getDimOp(rewriter, loc, self, wDim); + RankedTensorType sumPoolType = cast(sumPool.getType()); + const int64_t rank = sumPoolType.getRank(); + int dimH = toPositiveDim(-2, rank); + int dimW = toPositiveDim(-1, rank); + avgPool = + rewriter + .create( + loc, outputTensor.getType(), sumPool, outputTensor, + /*indexingMaps=*/indexingMapsAvg, + /*iteratorTypes=*/iteratorTypesAvg, + [&](OpBuilder &b, Location loc, ValueRange args) { + // The algorithm for computing the divisor with + // count_include_pad is manily based on pytorch + // implementation. The following code is comment + // with pytorch code. + // https://github.com/pytorch/pytorch/blob/4a6dfbe4806b361c43210dfd56db64c4097c66bb/aten/src/ATen/native/cpu/AvgPoolKernel.cpp#L78 + Value indexOh = + b.create(loc, /*value=*/dimH); + Value oh = castIndexToInt64(b, loc, indexOh); + Value indexOw = + b.create(loc, /*value=*/dimW); + Value ow = castIndexToInt64(b, loc, indexOw); + + // int64_t ih0 = oh * dH - padH; + Value dH = rewriter.create( + loc, rewriter.getI64IntegerAttr(strideInts[0])); + Value padH = rewriter.create( + loc, rewriter.getI64IntegerAttr(paddingInts[0])); + Value ohDH = b.create(loc, oh, dH); + Value ih0 = b.create(loc, ohDH, padH); + // int64_t iw0 = ow * dW - padW; + Value dW = rewriter.create( + loc, rewriter.getI64IntegerAttr(strideInts[1])); + Value padW = rewriter.create( + loc, rewriter.getI64IntegerAttr(paddingInts[1])); + Value owDW = b.create(loc, ow, dW); + Value iw0 = b.create(loc, owDW, padW); + // int64_t ih1 = std::min(ih0 + kH, input_height + padH); + Value ih = castIndexToInt64(b, loc, inputHeight); + Value ih0KH = b.create( + loc, ih0, kernelSizeIntValues[0]); + Value ihPadH = b.create(loc, ih, padH); + Value ih1 = b.create(loc, ih0KH, ihPadH); + // int64_t iw1 = std::min(iw0 + kW, input_width + padW); + Value iw = castIndexToInt64(b, loc, inputWidth); + Value iw0KW = b.create( + loc, iw0, kernelSizeIntValues[1]); + Value iwPadW = b.create(loc, iw, padW); + Value iw1 = b.create(loc, iw0KW, iwPadW); + // int64_t pool_size = (ih1 - ih0) * (iw1 - iw0); + Value ih1Ih0 = b.create(loc, ih1, ih0); + Value iw1Iw0 = b.create(loc, iw1, iw0); + Value poolSize = + b.create(loc, ih1Ih0, iw1Iw0); + // ih0 = std::max(ih0, 0); + Value cstZero = rewriter.create( + loc, rewriter.getI64IntegerAttr(0)); + Value ih0Clamped = + b.create(loc, ih0, cstZero); + // iw0 = std::max(iw0, 0); + Value iw0Clamped = + b.create(loc, iw0, cstZero); + // ih1 = std::min(ih1, input_height); + Value ih1Clamped = b.create(loc, ih1, ih); + // iw1 = std::min(iw1, input_width); + Value iw1Clamped = b.create(loc, iw1, iw); + // if (divisor_override.has_value()) { + // divisor = divisor_override.value(); + // } else { + // if(count_include_pad) { + // divisor = pool_size; + // } else { + // divisor = (ih1 - ih0) * (iw1 - iw0); + // } + // } + if (countIncludePad) { + divisor = convertScalarToDtype(b, loc, poolSize, + resultElementType); + } else { + Value ih1_ih0 = + b.create(loc, ih1Clamped, ih0Clamped); + Value iw1_iw0 = + b.create(loc, iw1Clamped, iw0Clamped); + divisor = b.create(loc, ih1_ih0, iw1_iw0); + } + // AtenAvgPool2/3dOp has an optional divisor_override + // attribute while AtenAvgPool1dOp does not. 
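The generic-op body above is easier to read as scalar arithmetic. A standalone sketch of the per-output-element divisor it computes, following the cited PyTorch AvgPoolKernel logic; the helper name avgPool2dDivisor is illustrative and not part of the patch:

#include <algorithm>
#include <cstdint>

int64_t avgPool2dDivisor(int64_t oh, int64_t ow,        // output coordinates
                         int64_t kH, int64_t kW,        // kernel size
                         int64_t dH, int64_t dW,        // stride
                         int64_t padH, int64_t padW,    // padding
                         int64_t inputH, int64_t inputW,
                         bool countIncludePad) {
  int64_t ih0 = oh * dH - padH;
  int64_t iw0 = ow * dW - padW;
  int64_t ih1 = std::min(ih0 + kH, inputH + padH);
  int64_t iw1 = std::min(iw0 + kW, inputW + padW);
  int64_t poolSize = (ih1 - ih0) * (iw1 - iw0);          // padded window size
  ih0 = std::max(ih0, int64_t(0));
  iw0 = std::max(iw0, int64_t(0));
  ih1 = std::min(ih1, inputH);
  iw1 = std::min(iw1, inputW);
  // count_include_pad divides by the padded window size; otherwise only the
  // in-bounds part of the window contributes (divisor_override, when set,
  // takes precedence over both).
  return countIncludePad ? poolSize : (ih1 - ih0) * (iw1 - iw0);
}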
+ if constexpr (std::is_same()) { + if (!isa( + op.getDivisorOverride().getType())) + divisor = adaptor.getDivisorOverride(); + } + + divisor = convertScalarToDtype(b, loc, divisor, + resultElementType); + Value avg; + if (isa(resultElementType)) + avg = b.create(loc, args[0], divisor); + else if (isa(resultElementType)) + avg = b.create(loc, args[0], divisor); + b.create(loc, avg); + }) + .getResult(0); + rewriter.replaceOpWithNewOp(op, resultType, avgPool); + return success(); + } + + // TODO: Add support for count_include_pad equal to `False` in + // AtenAvgPool1/3dOp. + if (!countIncludePad && + !llvm::all_of(paddingInts, [](int64_t p) { return p == 0; })) { + return rewriter.notifyMatchFailure( + op, "unimplemented: count_include_pad is expected to be true for " + "AtenAvgPool3dOp"); + } + + // Case2: AtenAvgPool1/3dOp without count_include_pad equal to `False`. + divisor = kernelSizeIntValues[0]; for (uint32_t i = 1; i < kernelSizeIntValues.size(); i++) { divisor = rewriter.create(loc, divisor, kernelSizeIntValues[i]); @@ -648,29 +774,20 @@ class ConvertAtenAvgPoolOp : public OpConversionPattern { : adaptor.getDivisorOverride(); } divisor = convertScalarToDtype(rewriter, loc, divisor, resultElementType); - - Value outputTensor = rewriter.create( - loc, getAsOpFoldResult(outTensorShape), resultElementType); - SmallVector indexingMapsAvg( - 2, rewriter.getMultiDimIdentityMap(Dim + 2)); - SmallVector iteratorTypesAvg( - Dim + 2, utils::IteratorType::parallel); - Value avgPool = - rewriter - .create( - loc, outputTensor.getType(), sumPool, outputTensor, - /*indexingMaps=*/indexingMapsAvg, - /*iteratorTypes=*/iteratorTypesAvg, - [&](OpBuilder &b, Location loc, ValueRange args) { - Value avg; - if (isa(resultElementType)) - avg = b.create(loc, args[0], divisor); - else if (isa(resultElementType)) - avg = b.create(loc, args[0], divisor); - b.create(loc, avg); - }) - .getResult(0); - + avgPool = rewriter + .create( + loc, outputTensor.getType(), sumPool, outputTensor, + /*indexingMaps=*/indexingMapsAvg, + /*iteratorTypes=*/iteratorTypesAvg, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value avg; + if (isa(resultElementType)) + avg = b.create(loc, args[0], divisor); + else if (isa(resultElementType)) + avg = b.create(loc, args[0], divisor); + b.create(loc, avg); + }) + .getResult(0); rewriter.replaceOpWithNewOp(op, resultType, avgPool); return success(); } diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index f7c40c147262..1330174699a5 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -149,59 +149,18 @@ static Value createFpOpWithDtype(OpBuilder &b, const TypeConverter *converter, return convertScalarToDtype(b, loc, newOp, outTy, std::nullopt, outTTy); } -template -static Value createCompareTensorOp(OpBuilder &b, Location loc, OpTy op, - Value lhs, Value rhs) { - static_assert(std::is_same() || - std::is_same() || - std::is_same() || - std::is_same() || - std::is_same() || - std::is_same(), - "unimplemented: op type not supported"); - - Type lhsDtype = lhs.getType(); - Type rhsDtype = rhs.getType(); - - // TODO: Type promotion in case of different `lhsDtype` and `rhsDtype` needs - // to be handled. 
- if (lhsDtype != rhsDtype) { - op.emitError("unimplemented: lhs and rhs dtype must be same"); - return nullptr; - } - - Type elementalType = cast(op.getSelf().getType()).getDtype(); - if constexpr (std::is_same()) { - return createLessThan(b, loc, elementalType, lhs, rhs); - } - if constexpr (std::is_same()) { - return createLessThanOrEqual(b, loc, elementalType, lhs, rhs); - } - if constexpr (std::is_same()) { - return createGreaterThan(b, loc, elementalType, lhs, rhs); - } - if constexpr (std::is_same()) { - return createGreaterThanOrEqual(b, loc, elementalType, lhs, rhs); - } - if constexpr (std::is_same()) { - return createEqual(b, loc, elementalType, lhs, rhs); - } - if constexpr (std::is_same()) { - return createNotEqual(b, loc, elementalType, lhs, rhs); - } - llvm_unreachable("unimplemented: op type not supported"); -} +template +struct is_any_same : std::disjunction...> {}; template -static Value createCompareScalarOp(OpBuilder &b, Location loc, OpTy op, - Value lhs, Value rhs) { - static_assert(std::is_same() || - std::is_same() || - std::is_same() || - std::is_same() || - std::is_same() || - std::is_same(), - "unimplemented: op type not supported"); +static Value createCompareOp(OpBuilder &b, Location loc, OpTy op, Value lhs, + Value rhs) { + static_assert( + is_any_same(), + "unimplemented: op type not supported"); Type lhsDtype = lhs.getType(); Type rhsDtype = rhs.getType(); @@ -229,22 +188,22 @@ static Value createCompareScalarOp(OpBuilder &b, Location loc, OpTy op, return nullptr; } - if constexpr (std::is_same()) { + if constexpr (is_any_same()) { return createLessThan(b, loc, elementalType, lhs, rhs); } - if constexpr (std::is_same()) { + if constexpr (is_any_same()) { return createLessThanOrEqual(b, loc, elementalType, lhs, rhs); } - if constexpr (std::is_same()) { + if constexpr (is_any_same()) { return createGreaterThan(b, loc, elementalType, lhs, rhs); } - if constexpr (std::is_same()) { + if constexpr (is_any_same()) { return createGreaterThanOrEqual(b, loc, elementalType, lhs, rhs); } - if constexpr (std::is_same()) { + if constexpr (is_any_same()) { return createEqual(b, loc, elementalType, lhs, rhs); } - if constexpr (std::is_same()) { + if constexpr (is_any_same()) { return createNotEqual(b, loc, elementalType, lhs, rhs); } llvm_unreachable("unimplemented: op type not supported"); @@ -892,28 +851,22 @@ static Value createLinalgPayloadCalculationForElementwiseOp( return b.create(loc, lhs, rhs); } if (auto ltTensor = dyn_cast(op)) { - return createCompareTensorOp(b, loc, ltTensor, payloadArgs[0], - payloadArgs[1]); + return createCompareOp(b, loc, ltTensor, payloadArgs[0], payloadArgs[1]); } if (auto leTensor = dyn_cast(op)) { - return createCompareTensorOp(b, loc, leTensor, payloadArgs[0], - payloadArgs[1]); + return createCompareOp(b, loc, leTensor, payloadArgs[0], payloadArgs[1]); } if (auto gtTensor = dyn_cast(op)) { - return createCompareTensorOp(b, loc, gtTensor, payloadArgs[0], - payloadArgs[1]); + return createCompareOp(b, loc, gtTensor, payloadArgs[0], payloadArgs[1]); } if (auto geTensor = dyn_cast(op)) { - return createCompareTensorOp(b, loc, geTensor, payloadArgs[0], - payloadArgs[1]); + return createCompareOp(b, loc, geTensor, payloadArgs[0], payloadArgs[1]); } if (auto eqTensor = dyn_cast(op)) { - return createCompareTensorOp(b, loc, eqTensor, payloadArgs[0], - payloadArgs[1]); + return createCompareOp(b, loc, eqTensor, payloadArgs[0], payloadArgs[1]); } if (auto neTensor = dyn_cast(op)) { - return createCompareTensorOp(b, loc, neTensor, 
payloadArgs[0], - payloadArgs[1]); + return createCompareOp(b, loc, neTensor, payloadArgs[0], payloadArgs[1]); } if (auto div = dyn_cast(op)) { AtenDivTensorOp::Adaptor adaptor(operands); @@ -996,27 +949,27 @@ static Value createLinalgPayloadCalculationForElementwiseOp( } if (auto gtScalar = dyn_cast(op)) { - return createCompareScalarOp(b, loc, gtScalar, payloadArgs[0], operands[1]); + return createCompareOp(b, loc, gtScalar, payloadArgs[0], operands[1]); } if (auto geScalar = dyn_cast(op)) { - return createCompareScalarOp(b, loc, geScalar, payloadArgs[0], operands[1]); + return createCompareOp(b, loc, geScalar, payloadArgs[0], operands[1]); } if (auto eqScalar = dyn_cast(op)) { - return createCompareScalarOp(b, loc, eqScalar, payloadArgs[0], operands[1]); + return createCompareOp(b, loc, eqScalar, payloadArgs[0], operands[1]); } if (auto neScalar = dyn_cast(op)) { - return createCompareScalarOp(b, loc, neScalar, payloadArgs[0], operands[1]); + return createCompareOp(b, loc, neScalar, payloadArgs[0], operands[1]); } if (auto ltScalar = dyn_cast(op)) { - return createCompareScalarOp(b, loc, ltScalar, payloadArgs[0], operands[1]); + return createCompareOp(b, loc, ltScalar, payloadArgs[0], operands[1]); } if (auto leScalar = dyn_cast(op)) { - return createCompareScalarOp(b, loc, leScalar, payloadArgs[0], operands[1]); + return createCompareOp(b, loc, leScalar, payloadArgs[0], operands[1]); } if (auto whereSelf = dyn_cast(op)) { @@ -1197,6 +1150,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp( } if (auto atenToDtype = dyn_cast(op)) { Value input = payloadArgs[0]; + Type inputElementType = + cast(atenToDtype.getSelf().getType()).getDtype(); Type dtype = cast(converter->convertType(atenToDtype.getType())) .getElementType(); @@ -1215,7 +1170,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp( } resultElementType = *maybeResultElementType; Value result = convertScalarToDtype(b, loc, input, dtype, - /*srcOriginalDtype=*/std::nullopt, + /*srcOriginalDtype=*/inputElementType, /*dstOriginalDtype=*/resultElementType); return result; } diff --git a/lib/Conversion/TorchToLinalg/Utils.cpp b/lib/Conversion/TorchToLinalg/Utils.cpp index 0aa919fe04a6..46b51558f13d 100644 --- a/lib/Conversion/TorchToLinalg/Utils.cpp +++ b/lib/Conversion/TorchToLinalg/Utils.cpp @@ -559,6 +559,8 @@ bool torch_to_linalg::isUnsignedTorchType(Type type) { return false; if (isa(type)) return true; + if (isa(type)) + return false; if (isa(type)) return false; if (auto intTy = dyn_cast(type)) diff --git a/lib/Conversion/TorchToStablehlo/Basic.cpp b/lib/Conversion/TorchToStablehlo/Basic.cpp index 715f89ff9063..4d75979027cf 100644 --- a/lib/Conversion/TorchToStablehlo/Basic.cpp +++ b/lib/Conversion/TorchToStablehlo/Basic.cpp @@ -277,8 +277,8 @@ class ConvertAtenTensorToScalarLikeOp : public OpConversionPattern { ConversionPatternRewriter &rewriter) const override { auto inputType = dyn_cast(adaptor.getA().getType()); if (!inputType) - op.emitError("only Tensor types supported in StableHLO"); + Location loc = op.getLoc(); Value input = adaptor.getA(); SmallVector inputSizes = getTensorSizes(rewriter, loc, input); @@ -290,14 +290,24 @@ class ConvertAtenTensorToScalarLikeOp : public OpConversionPattern { for (int64_t i = 0; i < inputRank; i++) checkDimEqualHelper(rewriter, loc, inputSizes[i], constantOne); + // handle unsigned interger + if (inputType.getElementType().isUnsignedInteger()) { + input = rewriter.create( + loc, input, + rewriter.getIntegerType( + 
inputType.getElementType().getIntOrFloatBitWidth())); + } + Value constantZero = rewriter.create(loc, rewriter.getIndexAttr(0)); SmallVector indices(inputRank, constantZero); Value result = rewriter.create(loc, input, indices); Type resultType = this->getTypeConverter()->convertType(op->getResult(0).getType()); - rewriter.replaceOp(op, convertScalarToDtype(rewriter, loc, result, - resultType, inputDtype)); + rewriter.replaceOp( + op, + convertScalarToDtype(rewriter, loc, result, resultType, inputDtype, + /*srcOriginalDtype=*/inputType.getElementType())); return success(); } }; diff --git a/lib/Conversion/TorchToStablehlo/GatherScatter.cpp b/lib/Conversion/TorchToStablehlo/GatherScatter.cpp index a551e0521852..05c52483c254 100644 --- a/lib/Conversion/TorchToStablehlo/GatherScatter.cpp +++ b/lib/Conversion/TorchToStablehlo/GatherScatter.cpp @@ -13,7 +13,9 @@ #include "PopulatePatterns.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "stablehlo/dialect/ChloOps.h" #include "stablehlo/dialect/StablehloOps.h" #include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h" #include "torch-mlir/Conversion/Utils/Utils.h" @@ -900,7 +902,6 @@ LogicalResult ConvertAtenOp::matchAndRewrite( for (int64_t i = maxIndexRank; i < inputRank; ++i) { updateWindowDims.push_back(i); } - llvm::outs() << "maxIndexRank: " << maxIndexRank << "\n"; auto scatterDimensionNumbers = stablehlo::ScatterDimensionNumbersAttr::get( rewriter.getContext(), /*updateWindowDims=*/updateWindowDims, @@ -941,6 +942,412 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } +// AtenGridSamplerOp +// See +// https://github.com/pytorch/pytorch/blob/ec58f1f74ebcec744d2ab90ad34abd09c1018e92/torch/_decomp/decompositions.py#L3923-L4086 +namespace { +template +static Value getConstantLike(OpBuilder &b, Location loc, T constant, + Value val) { + Type ty = getElementTypeOrSelf(val.getType()); + auto getAttr = [&]() -> Attribute { + if (isa(ty)) + return b.getIntegerAttr(ty, constant); + if (isa(ty)) + return b.getFloatAttr(ty, constant); + if (auto complexTy = dyn_cast(ty)) + return complex::NumberAttr::get(complexTy, constant, 0); + llvm_unreachable("unhandled element type"); + }; + return b.create(loc, cast(getAttr()), + val); +} + +template +static Value getConstTensor(ConversionPatternRewriter &rewriter, Operation *op, + ArrayRef values, ArrayRef shape, + Type ty) { + Location loc = op->getLoc(); + RankedTensorType valueType = RankedTensorType::get(shape, ty); + auto valueAttr = DenseElementsAttr::get(valueType, values); + return rewriter.create(loc, valueType, valueAttr); +} + +template +static Value getConstScalarTensor(ConversionPatternRewriter &rewriter, + Operation *op, T value, Type ty) { + return getConstTensor(rewriter, op, ArrayRef{value}, {}, ty); +} + +// Helper function to lower AtenGridSamplerOp. +static Value unnormalize(ConversionPatternRewriter &rewriter, Operation *op, + Value coords, int64_t size, Type elemTy, + bool alignCorners) { + Location loc = op->getLoc(); + APFloat pointFive(cast(elemTy).getFloatSemantics(), "0.5"); + APFloat sizeFloat = + APFloat(cast(elemTy).getFloatSemantics(), size); + APFloat one = APFloat(cast(elemTy).getFloatSemantics(), 1); + APFloat zero = APFloat(cast(elemTy).getFloatSemantics(), 0); + + // double mul = alignCorners ? (size * 0.5 - 0.5) : (size * 0.5); + // double ofs = size * 0.5 - 0.5; + APFloat mul = + alignCorners ? 
sizeFloat * pointFive - pointFive : sizeFloat * pointFive; + APFloat ofs = sizeFloat * pointFive - pointFive; + Value constMul = getConstScalarTensor(rewriter, op, mul, elemTy); + Value constOfs = getConstScalarTensor(rewriter, op, ofs, elemTy); + + // use chlo::BroadcastMulOp to multiply constMul with coords. + DenseI64ArrayAttr bcastDimensions; + Value mulResult = rewriter.create(loc, coords, constMul, + bcastDimensions); + // use chlo::BroadcastAddOp to add constOfs to mulResult. + Value result = rewriter.create(loc, mulResult, constOfs, + bcastDimensions); + return result; +} + +static Value computeCoordinates(ConversionPatternRewriter &rewriter, + Operation *op, Value coords, int64_t size, + Type elemTy, int64_t padding_mode) { + // TODO: add support for padding_mode 1 and 2. + return coords; +} + +static Value computeSourceIndex(ConversionPatternRewriter &rewriter, + Operation *op, Value coords, int64_t size, + Type elemTy, int64_t padding_mode, + bool alignCorners) { + Value coordsUn = + unnormalize(rewriter, op, coords, size, elemTy, alignCorners); + return computeCoordinates(rewriter, op, coordsUn, size, elemTy, padding_mode); +} + +// def in_bounds_cond(xs: Tensor, ys: Tensor) -> Tensor: +// return torch.logical_and( +// 0 <= xs, torch.logical_and(xs < iW, torch.logical_and(0 <= ys, ys +// < iH)) +// ) +static Value inBoundsCond(ConversionPatternRewriter &rewriter, Operation *op, + Value xs, Value ys, int64_t ih, int64_t iw, + Type elemTy) { + Location loc = op->getLoc(); + APFloat zeroFloat = + APFloat(cast(elemTy).getFloatSemantics(), 0); + Value zero = getConstScalarTensor(rewriter, op, zeroFloat, elemTy); + APFloat iwFloat = + APFloat(cast(elemTy).getFloatSemantics(), iw); + APFloat ihFloat = + APFloat(cast(elemTy).getFloatSemantics(), ih); + + Value iwFloatValue = getConstScalarTensor(rewriter, op, iwFloat, elemTy); + Value ihFloatValue = getConstScalarTensor(rewriter, op, ihFloat, elemTy); + + chlo::ComparisonTypeAttr compareTypeAttr = chlo::ComparisonTypeAttr::get( + rewriter.getContext(), chlo::ComparisonType::FLOAT); + chlo::ComparisonDirectionAttr compareLTAttr = + chlo::ComparisonDirectionAttr::get(rewriter.getContext(), + chlo::ComparisonDirection::LT); + chlo::ComparisonDirectionAttr compareGEAttr = + chlo::ComparisonDirectionAttr::get(rewriter.getContext(), + chlo::ComparisonDirection::GE); + DenseI64ArrayAttr bcastDimensions; + Value cond1 = rewriter.create( + loc, xs, zero, bcastDimensions, compareGEAttr, compareTypeAttr); + Value cond2 = rewriter.create( + loc, xs, iwFloatValue, bcastDimensions, compareLTAttr, compareTypeAttr); + Value cond3 = rewriter.create( + loc, ys, zero, bcastDimensions, compareGEAttr, compareTypeAttr); + Value cond4 = rewriter.create( + loc, ys, ihFloatValue, bcastDimensions, compareLTAttr, compareTypeAttr); + Value cond5 = + rewriter.create(loc, cond1, cond2, bcastDimensions); + Value cond6 = + rewriter.create(loc, cond3, cond4, bcastDimensions); + return rewriter.create(loc, cond5, cond6, + bcastDimensions); +} +// def clip(xs: Tensor, ys: Tensor, ws: Tensor) -> TensorSequenceType: +// cond = in_bounds_cond(xs, ys) +// # To clip to inside valid coordinates, we map the coordinates +// # to (x, y) = (0, 0) and also set the weight to 0 +// # We also change the shape of the tensor to the appropriate one for +// # broadcasting with N_idx, C_idx for the purposes of advanced +// indexing c = C if _expand_grid else 1 +// return tuple( +// torch.where(cond, t, 0).view(N, c, oH, oW) +// for t in (xs.to(dtype=torch.int64), ys.to(dtype=torch.int64), 
ws) +// ) +SmallVector clip(ConversionPatternRewriter &rewriter, Operation *op, + Value xs, Value ys, Value ws, int64_t N, int64_t oH, + int64_t oW, int64_t iH, int64_t iW, Type elemTy) { + Location loc = op->getLoc(); + auto indexElemTy = rewriter.getI64Type(); + auto indexTy = RankedTensorType::get(mlir::ArrayRef{1}, indexElemTy); + + Value zeroIntValue = rewriter.create( + loc, indexTy, DenseIntElementsAttr::get(indexTy, ArrayRef{0})); + + APFloat zeroAPFloat = + APFloat(cast(elemTy).getFloatSemantics(), 0); + Value zeroFloatValue = + getConstScalarTensor(rewriter, op, zeroAPFloat, elemTy); + Value cond = inBoundsCond(rewriter, op, xs, ys, iH, iW, elemTy); + Value xsInt = rewriter.create(loc, xs, indexElemTy); + Value ysInt = rewriter.create(loc, ys, indexElemTy); + + Value selectXs = rewriter.create( + loc, ArrayRef{cond, xsInt, zeroIntValue}); + Value selectYs = rewriter.create( + loc, ArrayRef{cond, ysInt, zeroIntValue}); + Value selectWs = rewriter.create( + loc, ArrayRef{cond, ws, zeroFloatValue}); + + SmallVector sizes = {N, 1, oH, oW}; + Value reshapedXs = rewriter.create( + loc, RankedTensorType::get(sizes, indexElemTy), selectXs); + Value reshapedYs = rewriter.create( + loc, RankedTensorType::get(sizes, indexElemTy), selectYs); + Value reshapedWs = rewriter.create( + loc, RankedTensorType::get(sizes, elemTy), selectWs); + return SmallVector{reshapedXs, reshapedYs, reshapedWs}; +} + +Value getSummand(ConversionPatternRewriter &rewriter, Operation *op, + Value input, Value ix, Value iy, Value w, int64_t N, + int64_t oH, int64_t oW, int64_t iH, int64_t iW, Value Nidx, + Value CIdx, RankedTensorType outType, Type elemTy) { + Location loc = op->getLoc(); + auto inputTensorType = cast(input.getType()); + SmallVector clipValues = + clip(rewriter, op, ix, iy, w, N, oH, oW, iH, iW, elemTy); + Value idxX = clipValues[0]; + Value idxY = clipValues[1]; + Value idxW = clipValues[2]; + SmallVector indexTensors{Nidx, CIdx, idxY, idxX}; + + int maxIndexRank = -1; + auto gatherIndicesInfo = + broadcastAndConcatIndices(input.getDefiningOp(), rewriter, indexTensors, + outType.getShape(), maxIndexRank); + auto gatherIndices = *gatherIndicesInfo; + int64_t numIndicesDim = indexTensors.size(); + int64_t indexVecDim = maxIndexRank; + + SmallVector offsetDims; + SmallVector collapsedDims; + SmallVector startIndexMap; + for (int64_t i = 0; i < numIndicesDim; ++i) { + collapsedDims.push_back(i); + startIndexMap.push_back(i); + } + for (int64_t i = numIndicesDim; i < inputTensorType.getRank(); i++) { + offsetDims.push_back(i + maxIndexRank - numIndicesDim); + } + auto dimsAttr = stablehlo::GatherDimensionNumbersAttr::get( + rewriter.getContext(), + /*offsetDims=*/offsetDims, + /*collapsedSliceDims=*/collapsedDims, + /*operandBatchingDims=*/{}, + /*startIndicesBatchingDims=*/{}, + /*startIndexMap=*/startIndexMap, + /*indexVecDim=*/indexVecDim); + + SmallVector sliceSizes; + auto inputShape = makeShapeTorchCompatible(inputTensorType.getShape()); + for (int64_t i = 0; i < inputTensorType.getRank(); ++i) { + if (i < numIndicesDim) { + sliceSizes.push_back(1); + } else { + sliceSizes.push_back(inputShape[i]); + } + } + + Value gather = rewriter.create( + loc, input, gatherIndices, dimsAttr, + rewriter.getDenseI64ArrayAttr(sliceSizes)); + // use chlo::BroadcastMulOp to multiply idxW with gather. 
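The bilinear path that getSummand feeds is easier to follow as scalar code. A rough standalone model of one grid point on one (N, C) plane with zeros padding (padding_mode == 0); the function name and the row-major img layout are assumptions for illustration:

#include <cmath>
#include <cstdint>
#include <vector>

float gridSampleBilinear(const std::vector<float> &img, // iH x iW, row-major
                         int64_t iH, int64_t iW,
                         float x, float y,              // normalized coords
                         bool alignCorners) {
  // unnormalize(): map [-1, 1] grid coordinates to input pixel coordinates.
  auto unnormalize = [&](float coord, int64_t size) {
    float mul = alignCorners ? (size * 0.5f - 0.5f) : size * 0.5f;
    float ofs = size * 0.5f - 0.5f;
    return coord * mul + ofs;
  };
  float ix = unnormalize(x, iW), iy = unnormalize(y, iH);
  float ix_nw = std::floor(ix), iy_nw = std::floor(iy);
  float ix_se = ix_nw + 1, iy_se = iy_nw + 1;
  // clip(): out-of-bounds corners contribute zero (index and weight -> 0).
  auto sample = [&](float xs, float ys, float w) {
    bool inBounds = xs >= 0 && xs < iW && ys >= 0 && ys < iH;
    return inBounds ? img[int64_t(ys) * iW + int64_t(xs)] * w : 0.0f;
  };
  // The four weighted summands, w_nw = (ix_se - ix) * (iy_se - iy), and so on.
  return sample(ix_nw, iy_nw, (ix_se - ix) * (iy_se - iy)) +
         sample(ix_se, iy_nw, (ix - ix_nw) * (iy_se - iy)) +
         sample(ix_nw, iy_se, (ix_se - ix) * (iy - iy_nw)) +
         sample(ix_se, iy_se, (ix - ix_nw) * (iy - iy_nw));
}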
+ DenseI64ArrayAttr bcastDimensions; + return rewriter.create(loc, gather, idxW, + bcastDimensions); +} + +} // namespace +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenGridSamplerOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Location loc = op->getLoc(); + Value input = adaptor.getInput(); + Value grid = adaptor.getGrid(); + + int64_t interpolationMode; + if (!matchPattern(op.getInterpolationMode(), + m_TorchConstantInt(&interpolationMode))) + return rewriter.notifyMatchFailure( + op, "interpolation_mode must be an integer constant"); + int64_t paddingMode; + if (!matchPattern(op.getPaddingMode(), m_TorchConstantInt(&paddingMode))) + return rewriter.notifyMatchFailure( + op, "padding_mode must be an integer constant"); + + if (interpolationMode != 0 && interpolationMode != 1) + return rewriter.notifyMatchFailure( + op, "only support interpolation_mode = 0 (bilinear) or 1(nearest)"); + + if (paddingMode != 0) + return rewriter.notifyMatchFailure(op, + "only support paddingMode = 0 (Zero)"); + + bool alignCorners = false; + if (!matchPattern(op.getAlignCorners(), m_TorchConstantBool(&alignCorners))) + return rewriter.notifyMatchFailure( + op, "alignCorners must be a boolean constant"); + + RankedTensorType inputTy = cast(input.getType()); + RankedTensorType gridTy = cast(grid.getType()); + RankedTensorType outTy = + cast(getTypeConverter()->convertType(op.getType())); + Type elemTy = inputTy.getElementType(); + if (inputTy.getRank() != 4) + return rewriter.notifyMatchFailure(op, "input must be a 4D tensor"); + if (gridTy.getRank() != 4) + return rewriter.notifyMatchFailure(op, "grid must be a 4D tensor"); + + auto inputSize = inputTy.getShape(); + auto gridSize = gridTy.getShape(); + int64_t N = inputSize[0]; + int64_t C = inputSize[1]; + int64_t iH = inputSize[2]; + int64_t iW = inputSize[3]; + int64_t oH = gridSize[1]; + int64_t oW = gridSize[2]; + // grid is a 4D tensor with shape (N, oH, oW, 2) + + Type indexElemTy = rewriter.getI64Type(); + RankedTensorType indexTy = + RankedTensorType::get(mlir::ArrayRef{1}, indexElemTy); + Value constN = rewriter.create( + loc, indexTy, DenseIntElementsAttr::get(indexTy, {N})); + Value constC = rewriter.create( + loc, indexTy, DenseIntElementsAttr::get(indexTy, {C})); + APFloat one = APFloat(cast(elemTy).getFloatSemantics(), 1); + APFloat zero = APFloat(cast(elemTy).getFloatSemantics(), 0); + + Value constOneFloat = getConstScalarTensor(rewriter, op, one, elemTy); + + auto NidxFlatten = rewriter.create( + loc, RankedTensorType::get(mlir::ArrayRef{N}, indexElemTy), + constN, 0); + auto CidxFlatten = rewriter.create( + loc, RankedTensorType::get(mlir::ArrayRef{C}, indexElemTy), + constC, 0); + + // Reshape NidxFlatten to 4D tensor (N, 1, 1, 1) + auto NidxSizes = mlir::SmallVector{N, 1, 1, 1}; + auto Nidx = rewriter.create( + loc, RankedTensorType::get(NidxSizes, indexElemTy), NidxFlatten); + + // Reshape CidxFlatten to 4D tensor (1, C, 1, 1) + auto CidxSizes = mlir::SmallVector{1, C, 1, 1}; + auto Cidx = rewriter.create( + loc, RankedTensorType::get(CidxSizes, indexElemTy), CidxFlatten); + + llvm::SmallVector stride(4, 1); + auto gridX = rewriter.create( + loc, + RankedTensorType::get(mlir::SmallVector{N, oH, oW, 1}, + gridTy.getElementType()), + grid, mlir::SmallVector{0, 0, 0, 0}, + mlir::SmallVector{N, oH, oW, 1}, stride); + auto gridY = rewriter.create( + loc, + RankedTensorType::get(mlir::SmallVector{N, oH, oW, 1}, + gridTy.getElementType()), + grid, mlir::SmallVector{0, 0, 0, 1}, + 
mlir::SmallVector{N, oH, oW, 2}, stride); + // squeeze last dimension + auto gridXshape = mlir::SmallVector{N, oH, oW}; + + auto gridXReshape = rewriter.create( + loc, RankedTensorType::get(gridXshape, gridTy.getElementType()), gridX); + auto gridYReshape = rewriter.create( + loc, RankedTensorType::get(gridXshape, gridTy.getElementType()), gridY); + + if (interpolationMode == 0) { + Value ix = computeSourceIndex(rewriter, op, gridXReshape, iW, elemTy, + paddingMode, alignCorners); + Value iy = computeSourceIndex(rewriter, op, gridYReshape, iH, elemTy, + paddingMode, alignCorners); + Value ix_nw = rewriter.create(loc, ix); + Value iy_nw = rewriter.create(loc, iy); + + DenseI64ArrayAttr bcastDimensions; + Value ix_ne = rewriter.create( + loc, ix_nw, constOneFloat, bcastDimensions); + Value iy_ne = iy_nw; + Value ix_sw = ix_nw; + Value iy_sw = rewriter.create( + loc, iy_nw, constOneFloat, bcastDimensions); + Value ix_se = ix_ne; + Value iy_se = iy_sw; + + // w_nw = (ix_se - ix) * (iy_se - iy) + // w_ne = (ix - ix_sw) * (iy_sw - iy) + // w_sw = (ix_ne - ix) * (iy - iy_ne) + // w_se = (ix - ix_nw) * (iy - iy_nw) + Value w_nw = rewriter.create( + loc, + rewriter.create(loc, ix_se, ix, bcastDimensions), + rewriter.create(loc, iy_se, iy, bcastDimensions), + bcastDimensions); + Value w_ne = rewriter.create( + loc, + rewriter.create(loc, ix, ix_sw, bcastDimensions), + rewriter.create(loc, iy_sw, iy, bcastDimensions), + bcastDimensions); + Value w_sw = rewriter.create( + loc, + rewriter.create(loc, ix_ne, ix, bcastDimensions), + rewriter.create(loc, iy, iy_ne, bcastDimensions), + bcastDimensions); + Value w_se = rewriter.create( + loc, + rewriter.create(loc, ix, ix_nw, bcastDimensions), + rewriter.create(loc, iy, iy_nw, bcastDimensions), + bcastDimensions); + + Value summand_nw = getSummand(rewriter, op, input, ix_nw, iy_nw, w_nw, N, + oH, oW, iH, iW, Nidx, Cidx, outTy, elemTy); + Value summand_ne = getSummand(rewriter, op, input, ix_ne, iy_ne, w_ne, N, + oH, oW, iH, iW, Nidx, Cidx, outTy, elemTy); + Value summand_sw = getSummand(rewriter, op, input, ix_sw, iy_sw, w_sw, N, + oH, oW, iH, iW, Nidx, Cidx, outTy, elemTy); + Value summand_se = getSummand(rewriter, op, input, ix_se, iy_se, w_se, N, + oH, oW, iH, iW, Nidx, Cidx, outTy, elemTy); + + // summand_nw + summand_ne + summand_sw + summand_se + Value sum = rewriter.create(loc, summand_nw, summand_ne); + sum = rewriter.create(loc, sum, summand_sw); + sum = rewriter.create(loc, sum, summand_se); + rewriter.replaceOp(op, sum); + } else if (interpolationMode == 1) { + Value ix = computeSourceIndex(rewriter, op, gridXReshape, iW, elemTy, + paddingMode, alignCorners); + Value iy = computeSourceIndex(rewriter, op, gridYReshape, iH, elemTy, + paddingMode, alignCorners); + Value ix_round = rewriter.create(loc, ix); + Value iy_round = rewriter.create(loc, iy); + Value oneTensor = getConstantLike(rewriter, loc, 1.0, ix_round); + Value summand = + getSummand(rewriter, op, input, ix_round, iy_round, oneTensor, N, oH, + oW, iH, iW, Nidx, Cidx, outTy, elemTy); + rewriter.replaceOp(op, summand); + } + return success(); +} + void mlir::torch::torch_to_stablehlo:: populateGatherScatterOpPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, @@ -957,6 +1364,7 @@ void mlir::torch::torch_to_stablehlo:: INSERT_ATENOP_PATTERN(AtenSliceScatterOp); INSERT_ATENOP_PATTERN(AtenIndexTensorHackedTwinOp); INSERT_ATENOP_PATTERN(AtenIndexPutHackedTwinOp); + INSERT_ATENOP_PATTERN(AtenGridSamplerOp); #undef INSERT_ATENOP_PATTERN #define 
INSERT_ATEN_SCATTER_PATTERN(AtenOp, reduceType) \ diff --git a/lib/Conversion/TorchToStablehlo/Linear.cpp b/lib/Conversion/TorchToStablehlo/Linear.cpp index b6e9d9ba90a8..82002292ec4a 100644 --- a/lib/Conversion/TorchToStablehlo/Linear.cpp +++ b/lib/Conversion/TorchToStablehlo/Linear.cpp @@ -591,25 +591,32 @@ class ConvertAtenConvolutionOp : public ConvertAtenOp { auto weightShape = weightTy.getShape(); auto nDims = inputTy.getRank(); + auto weightDims = weightTy.getRank(); + auto kernelDims = weightDims - 2; + auto nSpatialDims = nDims - 2; auto convOutTy = outType; // Transpose weight SmallVector perm(nDims); SmallVector transposeShape(nDims); - for (int i = 0; i < nDims; i++) { - if (i < 2) - perm[i] = nDims - 2 + i; + // 1d: kernelDims = 1, [0, 1, 2] => [2, 1, 0] + // 2d: kernelDims = 2, [0, 1, 2, 3] => [2, 3, 1, 0] + // 3d: kernelDims = 3, [0, 1, 2, 3, 4] => [2, 3, 4, 1, 0] + for (int i = 0; i < weightDims; i++) { + if (i < kernelDims) + perm[i] = 2 + i; else - perm[i] = nDims - i - 1; + perm[i] = kernelDims + 1 - i; transposeShape[i] = weightShape[perm[i]]; } + auto reverseDim = llvm::to_vector<4>(llvm::seq(0, kernelDims)); auto transposeTy = RankedTensorType::get(transposeShape, weightTy.getElementType()); auto transposeOp = rewriter.create( op->getLoc(), transposeTy, weight, perm); auto reverseOp = rewriter.create( - op->getLoc(), transposeOp, ArrayRef{0, 1}); + op->getLoc(), transposeOp, reverseDim); // Prepare for transposed convolution SmallVector stablehloStrideVec(nSpatialDims, 1); diff --git a/lib/Conversion/TorchToStablehlo/TorchToStablehlo.cpp b/lib/Conversion/TorchToStablehlo/TorchToStablehlo.cpp index 6830e13f810a..ec9aa7a45493 100644 --- a/lib/Conversion/TorchToStablehlo/TorchToStablehlo.cpp +++ b/lib/Conversion/TorchToStablehlo/TorchToStablehlo.cpp @@ -51,7 +51,8 @@ class ConvertTorchToStablehlo TypeConverter typeConverter; typeConverter.addConversion([](Type type) { return type; }); - TorchConversion::setupBackendTypeConversion(target, typeConverter); + TorchConversion::setupBackendTypeConversionForStablehlo(target, + typeConverter); RewritePatternSet patterns(context); diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index 1669cb43fbc0..4b01d88223b7 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -1311,7 +1311,7 @@ static OpFoldResult naryFolderHelper(ArrayRef operands, Type ty, return nullptr; auto dty = resultTy.getDtype(); - auto resultBTy = resultTy.toBuiltinTensor().clone(dty); + auto resultBTy = resultTy.toBuiltinTensor(); auto fpTy = dyn_cast(dty); auto intTy = dyn_cast(dty); @@ -1521,7 +1521,7 @@ OpFoldResult AtenEqTensorOp::fold(FoldAdaptor adaptor) { if (!ty || !ty.hasDtype() || !ty.hasSizes()) return nullptr; - auto bty = ty.toBuiltinTensor().clone(ty.getDtype()); + auto bty = ty.toBuiltinTensor(); if (!bty.hasStaticShape()) return nullptr; @@ -1635,7 +1635,6 @@ static OpFoldResult comparisonScaleFolder(DenseElementsAttr lhs, Attribute rhs, return nullptr; auto ctx = lhs.getContext(); - auto resultETy = resultTy.getDtype(); auto tensorETy = cast(lhs.getType()).getElementType(); if (lhs.isSplat()) { if (auto intAttr = dyn_cast(rhs)) { @@ -1647,8 +1646,7 @@ static OpFoldResult comparisonScaleFolder(DenseElementsAttr lhs, Attribute rhs, unsign ? 
tensorAP.getZExtValue() : tensorAP.getSExtValue(), !unsign); auto resultBool = intFolder(tensorAP, scalarAP, unsign); auto resultAP = IntegerAttr::get(IntegerType::get(ctx, 1), resultBool); - return DenseElementsAttr::get(resultTy.toBuiltinTensor().clone(resultETy), - resultAP); + return DenseElementsAttr::get(resultTy.toBuiltinTensor(), resultAP); } if (auto floatAttr = dyn_cast(rhs)) { @@ -1657,8 +1655,7 @@ static OpFoldResult comparisonScaleFolder(DenseElementsAttr lhs, Attribute rhs, auto resultBool = fpFolder(tensorAP.convertToDouble(), scalarAP.convertToDouble()); auto resultAP = IntegerAttr::get(IntegerType::get(ctx, 1), resultBool); - return DenseElementsAttr::get(resultTy.toBuiltinTensor().clone(resultETy), - resultAP); + return DenseElementsAttr::get(resultTy.toBuiltinTensor(), resultAP); } return nullptr; } @@ -1681,8 +1678,7 @@ static OpFoldResult comparisonScaleFolder(DenseElementsAttr lhs, Attribute rhs, auto resultBool = intFolder(tensorAP, scalarAP, unsign); values.push_back(resultBool); } - return DenseElementsAttr::get(resultTy.toBuiltinTensor().clone(resultETy), - values); + return DenseElementsAttr::get(resultTy.toBuiltinTensor(), values); } if (auto floatAttr = dyn_cast(rhs)) { @@ -1693,8 +1689,7 @@ static OpFoldResult comparisonScaleFolder(DenseElementsAttr lhs, Attribute rhs, fpFolder(tensorAP.convertToDouble(), scalarAP.convertToDouble()); values.push_back(resultBool); } - return DenseElementsAttr::get(resultTy.toBuiltinTensor().clone(resultETy), - values); + return DenseElementsAttr::get(resultTy.toBuiltinTensor(), values); } return nullptr; @@ -1844,7 +1839,7 @@ static OpFoldResult unaryPromoteFolder(DenseElementsAttr operand, if (!fpTy && !intTy) return nullptr; - auto resultBTy = resultTy.toBuiltinTensor().clone(resultTy.getDtype()); + auto resultBTy = resultTy.toBuiltinTensor(); bool splat = operand.isSplat(); bool withinMaxFold = resultBTy.hasStaticShape() && resultBTy.getNumElements() <= kMaxFold; @@ -2192,7 +2187,7 @@ OpFoldResult AtenSelectIntOp::fold(FoldAdaptor adaptor) { return nullptr; auto selfTy = cast(self.getType()); - auto bty = ty.toBuiltinTensor().clone(ty.getDtype()); + auto bty = ty.toBuiltinTensor(); if (!bty.hasStaticShape()) return nullptr; @@ -2671,8 +2666,7 @@ LogicalResult AtenSortOp::fold(FoldAdaptor adaptor, if (!indicesTensorType.hasDtype()) return failure(); - auto indicesType = - indicesTensorType.toBuiltinTensor().clone(indicesTensorType.getDtype()); + auto indicesType = indicesTensorType.toBuiltinTensor(); if (!indicesType || !indicesType.hasStaticShape()) return failure(); @@ -3627,9 +3621,8 @@ OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) { return nullptr; if (input && input.isSplat()) - return DenseElementsAttr::get( - outType.toBuiltinTensor().clone(inType.getDtype()), - input.getSplatValue()); + return DenseElementsAttr::get(outType.toBuiltinTensor(), + input.getSplatValue()); int count = 1; for (auto dim : outType.getSizes()) @@ -3667,8 +3660,7 @@ OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) { for (int i = begin; i < limit; i += stride) values.push_back(input.getValues()[i]); - return DenseElementsAttr::get( - outType.toBuiltinTensor().clone(inType.getDtype()), values); + return DenseElementsAttr::get(outType.toBuiltinTensor(), values); } // If the input and output shapes are the same we can just fold: @@ -4007,7 +3999,7 @@ OpFoldResult AtenTensorOp::fold(FoldAdaptor adaptor) { if (!resultTy || !resultTy.hasSizes() || !resultTy.hasDtype()) return nullptr; Type eTy = resultTy.getDtype(); - ShapedType 
shapedTy = resultTy.toBuiltinTensor().clone(eTy); + ShapedType shapedTy = resultTy.toBuiltinTensor(); SmallVector data; if (matchPattern(getData(), m_TorchListOfConstantInts(data)) && @@ -4028,7 +4020,7 @@ OpFoldResult AtenTensorIntOp::fold(FoldAdaptor adaptor) { if (!resultTy || !resultTy.hasSizes() || !resultTy.hasDtype()) return nullptr; Type eTy = resultTy.getDtype(); - ShapedType shapedTy = resultTy.toBuiltinTensor().clone(eTy); + ShapedType shapedTy = resultTy.toBuiltinTensor(); int64_t data; if (matchPattern(getT(), m_TorchConstantInt(&data))) { @@ -4048,7 +4040,7 @@ OpFoldResult AtenTensorFloatOp::fold(FoldAdaptor adaptor) { if (!resultTy || !resultTy.hasSizes() || !resultTy.hasDtype()) return nullptr; Type eTy = resultTy.getDtype(); - ShapedType shapedTy = resultTy.toBuiltinTensor().clone(eTy); + ShapedType shapedTy = resultTy.toBuiltinTensor(); double data; if (matchPattern(getT(), m_TorchConstantFloat(&data))) { @@ -4221,7 +4213,7 @@ OpFoldResult AtenIndexSelectOp::fold(FoldAdaptor adaptor) { : selfAttr.getValues()[indexInt]; auto dty = resultTy.getDtype(); - auto attrTy = resultTy.toBuiltinTensor().clone(dty); + auto attrTy = resultTy.toBuiltinTensor(); if (auto floatAttr = dyn_cast(splattr)) return DenseElementsAttr::get( attrTy, FloatAttr::get(dty, floatAttr.getValueAsDouble())); @@ -4414,7 +4406,7 @@ static Attribute getBroadcastedAttr(Attribute attr, ValueTensorType ty) { if (!valueDense.isSplat()) return nullptr; auto splattr = valueDense.getSplatValue(); - auto attrty = ty.toBuiltinTensor().clone(dty); + auto attrty = ty.toBuiltinTensor(); return DenseElementsAttr::get(attrty, splattr); } @@ -4422,7 +4414,7 @@ static Attribute getBroadcastedAttr(Attribute attr, ValueTensorType ty) { if (!isa(dty)) return nullptr; int64_t intval = intAttr.getInt(); - auto attrty = ty.toBuiltinTensor().clone(dty); + auto attrty = ty.toBuiltinTensor(); return DenseElementsAttr::get(attrty, IntegerAttr::get(dty, intval)); } @@ -4430,7 +4422,7 @@ static Attribute getBroadcastedAttr(Attribute attr, ValueTensorType ty) { if (!isa(dty)) return nullptr; double dblval = fpAttr.getValueAsDouble(); - auto attrty = ty.toBuiltinTensor().clone(dty); + auto attrty = ty.toBuiltinTensor(); return DenseElementsAttr::get(attrty, FloatAttr::get(dty, dblval)); } @@ -5118,3 +5110,65 @@ LogicalResult InitializeGlobalSlotsOp::verify() { return emitOpError("expected number of operands to match number of slots"); return success(); } + +//===----------------------------------------------------------------------===// +// BindSymbolicShapeOp +//===----------------------------------------------------------------------===// + +// +// torch.bind_symbolic_shape %6, [%0, %1, %2], affine_map<()[s0, s1, s2] -> +// (s0, s1 * 2 + s2, 3)> : !torch.vtensor<[?,?,3],f32> +// + +ParseResult BindSymbolicShapeOp::parse(OpAsmParser &parser, + OperationState &result) { + OpAsmParser::UnresolvedOperand operand; + SmallVector shapeSymbols; + AffineMapAttr shapeExpressions; + Type operandType; + + if (parser.parseOperand(operand) || parser.parseComma() || + parser.parseLSquare() || parser.parseOperandList(shapeSymbols) || + parser.parseRSquare() || parser.parseComma() || + parser.parseAttribute(shapeExpressions, "shape_expressions", + result.attributes) || + parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonType(operandType)) { + return failure(); + } + + if (parser.resolveOperand(operand, operandType, result.operands) || + parser.resolveOperands(shapeSymbols, + parser.getBuilder().getType(), + result.operands)) { + 
return failure(); + } + + return success(); +} + +// Use a custom printer here to avoid the AffineMap from getting hoisted +// when printed. This makes it so the AffineMap is printed inline with the op. +void BindSymbolicShapeOp::print(OpAsmPrinter &p) { + p << " " << getOperand() << ", ["; + llvm::interleaveComma(getShapeSymbols(), p); + p << "], " << "affine_map<" << getShapeExpressions().getValue() << ">"; + p.printOptionalAttrDict((*this)->getAttrs(), + /*elidedAttrs=*/{"shape_expressions"}); + p << " : " << getOperand().getType(); +} + +LogicalResult BindSymbolicShapeOp::verify() { + if (getShapeSymbols().empty()) + return emitOpError() << "requires non-empty shapeSymbols"; + + for (auto symbol : getShapeSymbols()) { + Operation *definingOp = symbol.getDefiningOp(); + if (!isa(definingOp)) { + return emitOpError() + << "shape symbol must be produced by a SymbolicIntOp"; + } + } + + return success(); +} diff --git a/lib/Dialect/Torch/IR/TorchTypes.cpp b/lib/Dialect/Torch/IR/TorchTypes.cpp index 6735bb37e48b..c46865ee5fed 100644 --- a/lib/Dialect/Torch/IR/TorchTypes.cpp +++ b/lib/Dialect/Torch/IR/TorchTypes.cpp @@ -185,7 +185,8 @@ static bool isValidTorchDtype(Type dtype) { dtype = cast(dtype).getElementType(); } // Torch quantized types. - if (isa(dtype)) + if (isa(dtype)) return true; // Builtin floating point types. if (isa(dtype)) @@ -453,12 +454,7 @@ ValueTensorType::getWithLeastStaticInformation(MLIRContext *context) { } static Type convertDtypeToBuiltinElementType(MLIRContext *context, Type dtype) { - if (auto floatType = dyn_cast(dtype)) { - return dtype; - } else if (auto integerType = dyn_cast(dtype)) { - return IntegerType::get(context, integerType.getWidth(), - IntegerType::Signless); - } else if (isa(dtype)) { + if (isa(dtype)) { return dtype; } @@ -468,6 +464,9 @@ static Type convertDtypeToBuiltinElementType(MLIRContext *context, Type dtype) { if (isa(dtype)) return IntegerType::get(context, 8, IntegerType::Signless); + if (isa(dtype)) + return IntegerType::get(context, 16, IntegerType::Signless); + if (isa(dtype)) return IntegerType::get(context, 32, IntegerType::Signless); @@ -480,11 +479,11 @@ static Type convertDtypeToBuiltinElementType(MLIRContext *context, Type dtype) { TensorType ValueTensorType::toBuiltinTensor() const { if (!hasDtype()) return nullptr; - if (!hasSizes()) - return UnrankedTensorType::get(getDtype()); Type elementType = convertDtypeToBuiltinElementType(getContext(), getDtype()); if (!elementType) return nullptr; + if (!hasSizes()) + return UnrankedTensorType::get(elementType); return RankedTensorType::get(makeShapeLLVMCompatible(getSizes()), elementType, getOptionalSparsity()); } diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 945c439d6000..92ed767adb26 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -7074,6 +7074,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.rrelu\"(%arg0: !torch.list, %arg1: !torch.float, %arg2: !torch.float, %arg3: !torch.bool, %arg4: !torch.any) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.selu\"(%arg0: !torch.list) -> 
!torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -7130,7 +7134,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = torch.prim.ListConstruct : () -> !torch.list\n" " return %0 : !torch.list\n" " }\n" -" func.func @\"__torch_mlir_shape_fn.prims.var\"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.float, %arg3: !torch.optional) -> !torch.list {\n" +" func.func @\"__torch_mlir_shape_fn.prims.var\"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.optional, %arg3: !torch.optional) -> !torch.list {\n" " %none = torch.constant.none\n" " %false = torch.constant.bool false\n" " %0 = torch.derefine %none : !torch.none to !torch.any\n" @@ -9312,6 +9316,16 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %1 = call @__torch__.torch.jit._shape_functions.conv_forwards(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %false, %0, %int1) : (!torch.list, !torch.list, !torch.optional>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int) -> !torch.list\n" " return %1 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.conv_transpose1d\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.int, %arg7: !torch.list) -> !torch.list {\n" +" %true = torch.constant.bool true\n" +" %0 = call @\"__torch_mlir_shape_fn.aten.convolution\"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg7, %true, %arg5, %arg6) : (!torch.list, !torch.list, !torch.optional>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.conv_transpose3d.input\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.int, %arg7: !torch.list) -> !torch.list {\n" +" %true = torch.constant.bool true\n" +" %0 = call @\"__torch_mlir_shape_fn.aten.convolution\"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg7, %true, %arg5, %arg6) : (!torch.list, !torch.list, !torch.optional>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten._convolution\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.bool, %arg7: !torch.list, %arg8: !torch.int, %arg9: !torch.bool, %arg10: !torch.bool, %arg11: !torch.bool, %arg12: !torch.bool) -> !torch.list {\n" " %0 = call @\"__torch_mlir_shape_fn.aten.convolution\"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8) : (!torch.list, !torch.list, !torch.optional>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int) -> !torch.list\n" " return %0 : !torch.list\n" @@ -10815,6 +10829,26 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.rrelu\"(%arg0: !torch.tuple, %arg1: !torch.number, %arg2: !torch.number, %arg3: !torch.bool, %arg4: !torch.any) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %true = torch.constant.bool true\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" 
%1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %3 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" torch.prim.If.yield %3 : !torch.bool\n" +" }\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.relu6\"(%arg0: !torch.tuple) -> !torch.int {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" @@ -12016,10 +12050,18 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.conv_transpose1d\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.int, %arg7: !torch.list) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.conv_transpose2d.input\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.int, %arg7: !torch.list) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.conv_transpose3d.input\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.int, %arg7: !torch.list) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.convolution\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.bool, %arg7: !torch.list, %arg8: !torch.int) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" @@ -12800,7 +12842,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple, !torch.bool) -> !torch.int\n" " return %0 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.prims.var\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.float, %arg3: !torch.optional) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.prims.var\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.optional, %arg3: !torch.optional) -> !torch.int {\n" " %true = torch.constant.bool true\n" " %0 = call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple, !torch.bool) -> !torch.int\n" " return %0 : !torch.int\n" diff --git a/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp b/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp index 2cbfe2642045..af937ac10b0e 100644 --- a/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp +++ b/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp @@ -62,7 +62,7 
@@ class AdjustCallingConventionForFunc // TODO: add tuple type. conversion.addInputs(type.index(), type.value()); } - rewriter.applySignatureConversion(&func.getBody(), conversion, + rewriter.applySignatureConversion(&func.getBody().front(), conversion, typeConverter); SmallVector newResultTypes; diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index f2d54d8db7c4..19a04bcc1336 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -2520,6 +2520,77 @@ class DecomposeAtenPreluOp : public OpRewritePattern { } // namespace +// rrelu = max(0, x) + min(0, alpha * x) +// if in training mode, the alpha is sampled from uniform distribution (lower, +// upper) if in testing mode, the alpha is (lower + upper) / 2 +namespace { +class DecomposeAtenRreluOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenRreluOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value self = op.getSelf(); + Value lower = op.getLower(); + Value upper = op.getUpper(); + auto resType = cast(op.getType()); + if (!resType.hasDtype()) { + return rewriter.notifyMatchFailure(op, "result should have dtype"); + } + + bool training; + if (!matchPattern(op.getTraining(), m_TorchConstantBool(&training))) { + return rewriter.notifyMatchFailure(op, "training should be a constant"); + } + + Value constantZeroFloat = + rewriter.create(loc, rewriter.getF64FloatAttr(0.0)); + Value constantOneFloat = + rewriter.create(loc, rewriter.getF64FloatAttr(1.0)); + Value constantTwoFloat = + rewriter.create(loc, rewriter.getF64FloatAttr(2.0)); + + Value alpha; + if (training) { + // Create a uniform random op with low and high set to `lower` and + // `upper`, respectively. 
+ Value none = rewriter.create(loc); + Value emptyTensor = rewriter.create( + loc, resType, self, constantZeroFloat, /*dtype=*/none, + /*layout=*/none, + /*device=*/none, /*pin_memoty=*/none, /*memory_format=*/none); + alpha = rewriter.create(loc, resType, emptyTensor, + /*from=*/lower, /*to=*/upper, + /*generator=*/none); + } else { + Value half = rewriter.create(loc, constantTwoFloat.getType(), + lower, upper); + alpha = rewriter.create(loc, constantTwoFloat.getType(), half, + constantTwoFloat); + } + + Value zeroTensor = + createRank0Tensor(rewriter, loc, resType, constantZeroFloat); + Value positiveOutput = + rewriter.create(loc, resType, zeroTensor, self); + + Value scaledSelf; + if (training) { + scaledSelf = rewriter.create(loc, resType, self, alpha); + } else { + scaledSelf = rewriter.create(loc, resType, self, alpha); + } + + Value negativeOutput = + rewriter.create(loc, resType, zeroTensor, scaledSelf); + Value rreluOutput = rewriter.create( + loc, resType, positiveOutput, negativeOutput, constantOneFloat); + rewriter.replaceOp(op, rreluOutput); + return success(); + } +}; +} // namespace + // CELU(x)=max(0,x)+min(0,alpha∗(exp(x/alpha)−1)) namespace { class DecomposeAtenCeluOp : public OpRewritePattern { @@ -2585,7 +2656,36 @@ class DecomposeAtenLerpScalarOp : public OpRewritePattern { auto weightedDelta = rewriter.create(loc, inputType, delta, op.getWeight()); - auto lerp = rewriter.create(loc, inputType, start, + auto lerp = rewriter.create(loc, resType, start, + weightedDelta, cstOne); + rewriter.replaceOp(op, lerp); + return success(); + } +}; +} // namespace + +namespace { +class DecomposeAtenLerpTensorOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenLerpTensorOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + auto resType = cast(op.getType()); + if (!resType.hasDtype()) { + return rewriter.notifyMatchFailure(op, "result should have dtype"); + } + Value cstOne = + rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + auto start = op.getSelf(); + auto inputType = cast(start.getType()); + + auto delta = rewriter.create(loc, inputType, op.getEnd(), + start, cstOne); + + auto weightedDelta = + rewriter.create(loc, inputType, delta, op.getWeight()); + auto lerp = rewriter.create(loc, resType, start, weightedDelta, cstOne); rewriter.replaceOp(op, lerp); return success(); @@ -3633,6 +3733,25 @@ class DecomposeAtenConv3dOp : public OpRewritePattern { }; } // namespace +// Decompose aten.conv_transpose1d to aten.convolution +namespace { +class DecomposeAtenConvTranspose1dOp + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenConvTranspose1dOp op, + PatternRewriter &rewriter) const override { + + Value cstTrue = rewriter.create(op.getLoc(), true); + rewriter.replaceOpWithNewOp( + op, op->getResultTypes(), op.getInput(), op.getWeight(), op.getBias(), + op.getStride(), op.getPadding(), op.getDilation(), + /*transposed=*/cstTrue, op.getOutputPadding(), op.getGroups()); + return success(); + } +}; +} // namespace + // Decompose aten.conv_transpose2d to aten.convolution namespace { class DecomposeAtenConvTranspose2dOp @@ -3652,6 +3771,25 @@ class DecomposeAtenConvTranspose2dOp }; } // namespace +// Decompose aten.conv_transpose3d to aten.convolution +namespace { +class DecomposeAtenConvTranspose3dOp + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult 
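The DecomposeAtenRreluOp pattern above implements rrelu(x) = max(0, x) + min(0, alpha * x), where alpha is drawn from U(lower, upper) in training mode and fixed at (lower + upper) / 2 in eval mode. A minimal PyTorch sketch of the eval-mode behaviour the decomposition is expected to reproduce (reference only, not part of the patch):

import torch

def rrelu_eval_reference(x: torch.Tensor, lower: float = 0.125, upper: float = 1.0 / 3.0) -> torch.Tensor:
    # Eval mode uses the fixed negative slope (lower + upper) / 2.
    alpha = (lower + upper) / 2
    # max(0, x) + min(0, alpha * x)
    return torch.clamp(x, min=0) + torch.clamp(alpha * x, max=0)

x = torch.randn(8)
expected = torch.nn.functional.rrelu(x, lower=0.125, upper=1.0 / 3.0, training=False)
assert torch.allclose(rrelu_eval_reference(x), expected)

In training mode the same structure applies, but alpha is a per-element sample from the uniform distribution, which is why the pattern materializes an empty tensor and fills it with aten.uniform before the multiply.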
matchAndRewrite(AtenConvTranspose3dInputOp op, + PatternRewriter &rewriter) const override { + + Value cstTrue = rewriter.create(op.getLoc(), true); + rewriter.replaceOpWithNewOp( + op, op->getResultTypes(), op.getInput(), op.getWeight(), op.getBias(), + op.getStride(), op.getPadding(), op.getDilation(), + /*transposed=*/cstTrue, op.getOutputPadding(), op.getGroups()); + return success(); + } +}; +} // namespace + // The convolution backward op is decomposed as follows: // inputH, inputW = input.shape[2:] // output_padding_ = [ @@ -8085,7 +8223,9 @@ class DecomposeComplexOpsPass DecomposeAten_ConvolutionLikeOp>( patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); @@ -8118,6 +8258,7 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); @@ -8196,6 +8337,7 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); diff --git a/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp b/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp index 38bc4d275bf1..5925dd07e185 100644 --- a/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp +++ b/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp @@ -378,7 +378,7 @@ class FuseQuantizedOpsPass : public FuseQuantizedOpsBase { QuantizeOperandsPastCommutingOps, QuantizeOperandsPastCommutingOps, QuantizeOperandsPastCommutingOps, - QuantizeOperandsPastCommutingOps, + QuantizeOperandsPastCommutingOps, QuantizeAccumulator, QuantizeAccumulator, QuantizeResultLikeOperand, QuantizeBias>( context); diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index b73044c9bd40..2f748285ddd6 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -428,7 +428,9 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); @@ -481,6 +483,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); @@ -505,6 +508,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); diff --git a/lib/Dialect/Torch/Transforms/MatchQuantizedOps.cpp b/lib/Dialect/Torch/Transforms/MatchQuantizedOps.cpp index c237ede12479..b571003940cb 100644 --- a/lib/Dialect/Torch/Transforms/MatchQuantizedOps.cpp +++ 
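The conv_transpose1d/2d/3d decompositions above all reduce to aten.convolution with transposed=true and the remaining arguments forwarded unchanged. A small PyTorch-level sketch of that equivalence (illustrative check, not from the patch):

import torch

x = torch.randn(1, 2, 6)
w = torch.randn(2, 5, 2)  # (in_channels, out_channels / groups, kernel)
ref = torch.ops.aten.conv_transpose1d(x, w, None, [2], [1], [0], 1, [1])
# The decomposition emits the generic convolution op with transposed=True.
dec = torch.ops.aten.convolution(x, w, None, [2], [1], [1], True, [0], 1)
assert torch.allclose(ref, dec)

The output size follows the usual transposed-convolution rule, (in - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1, which is also what the new conv_transpose shape functions in the abstract interp library compute by delegating to the aten.convolution shape function.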
b/lib/Dialect/Torch/Transforms/MatchQuantizedOps.cpp @@ -21,10 +21,12 @@ using namespace mlir::torch::Torch; namespace { Type getQuantizedType(MLIRContext *context, Type t) { - if (t.isSignlessInteger(8)) + if (t.isSignlessInteger(8) || t.isUnsignedInteger(8)) return Torch::QUInt8Type::get(context); if (t.isInteger(8) || t.isSignedInteger(8)) return Torch::QInt8Type::get(context); + if (t.isInteger(16)) + return Torch::QInt16Type::get(context); if (t.isInteger(32)) return Torch::QInt32Type::get(context); return {}; diff --git a/lib/Dialect/Torch/Utils/TorchUpstream.cpp b/lib/Dialect/Torch/Utils/TorchUpstream.cpp index 2dce14ef964c..c4c42f7fe0e3 100644 --- a/lib/Dialect/Torch/Utils/TorchUpstream.cpp +++ b/lib/Dialect/Torch/Utils/TorchUpstream.cpp @@ -21,7 +21,7 @@ static inline bool isQIntType(ScalarType t) { // Don't forget to extend this when adding new QInt types return t == ScalarType::QInt8 || t == ScalarType::QUInt8 || t == ScalarType::QInt32 || t == ScalarType::QUInt4x2 || - t == ScalarType::QUInt2x4; + t == ScalarType::QUInt2x4 || t == ScalarType::QInt16; } //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp index 883d7555bb3e..7ba3157b8986 100644 --- a/lib/Dialect/Torch/Utils/Utils.cpp +++ b/lib/Dialect/Torch/Utils/Utils.cpp @@ -112,6 +112,8 @@ torch_upstream::ScalarType Torch::getScalarTypeForType(Type type) { return torch_upstream::ScalarType::QUInt8; if (isa(type)) return torch_upstream::ScalarType::QInt8; + if (isa(type)) + return torch_upstream::ScalarType::QInt16; if (isa(type)) return torch_upstream::ScalarType::QInt32; if (isa(type)) { @@ -123,6 +125,14 @@ torch_upstream::ScalarType Torch::getScalarTypeForType(Type type) { if (complexElemType.isF64()) return torch_upstream::ScalarType::ComplexDouble; } + if (isa(type)) + return torch_upstream::ScalarType::Float8_e5m2; + if (isa(type)) + return torch_upstream::ScalarType::Float8_e4m3fn; + if (isa(type)) + return torch_upstream::ScalarType::Float8_e5m2fnuz; + if (isa(type)) + return torch_upstream::ScalarType::Float8_e4m3fnuz; llvm::report_fatal_error("unhandled type for getScalarTypeForType"); } Type Torch::getTypeForTorchType( @@ -163,6 +173,8 @@ Torch::getTypeForScalarType(MLIRContext *context, return QUInt8Type::get(context); case torch_upstream::ScalarType::QInt8: return QInt8Type::get(context); + case torch_upstream::ScalarType::QInt16: + return QInt16Type::get(context); case torch_upstream::ScalarType::QInt32: return QInt32Type::get(context); case torch_upstream::ScalarType::ComplexHalf: @@ -171,6 +183,14 @@ Torch::getTypeForScalarType(MLIRContext *context, return mlir::ComplexType::get(Float32Type::get(context)); case torch_upstream::ScalarType::ComplexDouble: return mlir::ComplexType::get(Float64Type::get(context)); + case torch_upstream::ScalarType::Float8_e5m2: + return Float8E5M2Type::get(context); + case torch_upstream::ScalarType::Float8_e4m3fn: + return Float8E4M3FNType::get(context); + case torch_upstream::ScalarType::Float8_e5m2fnuz: + return Float8E5M2FNUZType::get(context); + case torch_upstream::ScalarType::Float8_e4m3fnuz: + return Float8E4M3FNUZType::get(context); case torch_upstream::ScalarType::Undefined: return failure(); default: @@ -613,6 +633,24 @@ LogicalResult Torch::getTransposedType(BaseTensorType inType, int64_t dimA, return success(); } +LogicalResult Torch::getPermutedType(BaseTensorType inType, + SmallVector permuteDims, + Type &permutedType) { + if (!inType.hasSizes()) + return 
failure(); + + SmallVector shape(inType.getSizes()); + if (shape.size() != permuteDims.size()) + return failure(); + + SmallVector permutedShape; + for (unsigned i = 0; i < shape.size(); i++) + permutedShape.push_back(shape[permuteDims[i]]); + permutedType = inType.getWithSizesAndDtype(llvm::ArrayRef(permutedShape), + inType.getOptionalDtype()); + return success(); +} + Type Torch::getDefaultAccType(PatternRewriter &rewriter, Type inputType) { if (inputType.isF16()) return rewriter.getF32Type(); diff --git a/lib/Dialect/TorchConversion/IR/TorchConversionOps.cpp b/lib/Dialect/TorchConversion/IR/TorchConversionOps.cpp index bd66bbe55330..3a667b81d942 100644 --- a/lib/Dialect/TorchConversion/IR/TorchConversionOps.cpp +++ b/lib/Dialect/TorchConversion/IR/TorchConversionOps.cpp @@ -23,7 +23,18 @@ static bool haveSameSizeAndElementType(TensorType lhs, TensorType rhs) { if (lhs.hasRank() != rhs.hasRank()) return false; bool sameSize = lhs.hasRank() ? lhs.getShape().equals(rhs.getShape()) : true; - bool sameElementType = lhs.getElementType() == rhs.getElementType(); + bool sameElementType = false; + // Namely, it is worth mentioning that the backends can have different + // expectations for signedness when converting from and to the builtin MLIR + // types. Therefore, the verifier cannot expect the input and output types to + // match in their signedness. + if (isa(lhs.getElementType()) && + isa(rhs.getElementType())) { + sameElementType = lhs.getElementType().getIntOrFloatBitWidth() == + rhs.getElementType().getIntOrFloatBitWidth(); + } else { + sameElementType = lhs.getElementType() == rhs.getElementType(); + } return sameElementType && sameSize; } @@ -42,18 +53,6 @@ LogicalResult ToBuiltinTensorOp::verify() { return success(); } -LogicalResult ToBuiltinTensorOp::inferReturnTypes( - MLIRContext *context, std::optional location, ValueRange operands, - DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, - SmallVectorImpl &inferredReturnTypes) { - auto resultType = - cast(operands[0].getType()).toBuiltinTensor(); - if (!resultType) - return failure(); - inferredReturnTypes.push_back(resultType); - return success(); -} - //===----------------------------------------------------------------------===// // FromBuiltinTensorOp //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp b/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp index 3c0ad51fb520..0f2533e063f0 100644 --- a/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp +++ b/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp @@ -23,22 +23,22 @@ void mlir::torch::TorchConversion::getBackendTypeConversionDependentDialects( // Type conversion setup. 
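The new Torch::getPermutedType helper above derives the result shape by indexing the input sizes with the permutation: result dimension i takes the size of input dimension permuteDims[i]. The same rule in a few lines of Python (sketch only):

def permuted_shape(shape, permute_dims):
    # The caller is expected to pass a permutation with one entry per dimension.
    assert len(shape) == len(permute_dims)
    return [shape[d] for d in permute_dims]

assert permuted_shape([2, 3, 5], [2, 0, 1]) == [5, 2, 3]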
//===----------------------------------------------------------------------===// -static void -setupValueTensorToBuiltinTensorConversion(ConversionTarget &target, - TypeConverter &typeConverter) { +using ValueTensorTypeConversionFn = + std::function(Torch::ValueTensorType)>; + +static void setupValueTensorToBuiltinTensorConversion( + ConversionTarget &target, TypeConverter &typeConverter, + const ValueTensorTypeConversionFn &conversionFn) { target.addLegalOp(); - typeConverter.addConversion( - [](Torch::ValueTensorType type) -> std::optional { - return type.toBuiltinTensor(); - }); + typeConverter.addConversion(conversionFn); typeConverter.addTargetMaterialization([](OpBuilder &builder, TensorType type, ValueRange inputs, Location loc) -> Value { assert(inputs.size() == 1); if (!isa(inputs[0].getType())) return {}; - return builder.create(loc, inputs[0]); + return builder.create(loc, type, inputs[0]); }); auto sourceMaterialization = [](OpBuilder &builder, Torch::ValueTensorType type, @@ -162,9 +162,54 @@ static void setupTorchGeneratorToI64Conversion(ConversionTarget &target, void mlir::torch::TorchConversion::setupBackendTypeConversion( ConversionTarget &target, TypeConverter &typeConverter) { - setupValueTensorToBuiltinTensorConversion(target, typeConverter); + auto valueTensorTypeConversion = + [](Torch::ValueTensorType type) -> std::optional { + auto builtinType = type.toBuiltinTensor(); + if (!builtinType) + return std::nullopt; + + // convert any integer type to signless + if (type.getDtype().isInteger()) { + return builtinType.clone(IntegerType::get( + builtinType.getContext(), type.getDtype().getIntOrFloatBitWidth(), + IntegerType::Signless)); + } + + return builtinType; + }; + setupValueTensorToBuiltinTensorConversion(target, typeConverter, + valueTensorTypeConversion); + setupTorchBoolToI1Conversion(target, typeConverter); + setupTorchIntToI64Conversion(target, typeConverter); + setupTorchFloatToF64Conversion(target, typeConverter); + setupTorchGeneratorToI64Conversion(target, typeConverter); +} + +#ifdef TORCH_MLIR_ENABLE_STABLEHLO +void mlir::torch::TorchConversion::setupBackendTypeConversionForStablehlo( + ConversionTarget &target, TypeConverter &typeConverter) { + auto valueTensorTypeConversion = + [](Torch::ValueTensorType type) -> std::optional { + auto builtinType = type.toBuiltinTensor(); + if (!builtinType) + return std::nullopt; + + // convert signed integer type to signless, keep unsigned as unsigned + if (type.getDtype().isUnsignedInteger()) { + return builtinType.clone(type.getDtype()); + } else if (type.getDtype().isSignedInteger()) { + return builtinType.clone(IntegerType::get( + builtinType.getContext(), type.getDtype().getIntOrFloatBitWidth(), + IntegerType::Signless)); + } + + return builtinType; + }; + setupValueTensorToBuiltinTensorConversion(target, typeConverter, + valueTensorTypeConversion); setupTorchBoolToI1Conversion(target, typeConverter); setupTorchIntToI64Conversion(target, typeConverter); setupTorchFloatToF64Conversion(target, typeConverter); setupTorchGeneratorToI64Conversion(target, typeConverter); } +#endif diff --git a/lib/Dialect/TorchConversion/Transforms/BackendTypeConversionPasses.cpp b/lib/Dialect/TorchConversion/Transforms/BackendTypeConversionPasses.cpp index b99ece8946dc..90767fb2ccb5 100644 --- a/lib/Dialect/TorchConversion/Transforms/BackendTypeConversionPasses.cpp +++ b/lib/Dialect/TorchConversion/Transforms/BackendTypeConversionPasses.cpp @@ -26,6 +26,32 @@ using namespace mlir::torch::TorchConversion; 
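The two value-tensor conversion lambdas above differ only in how integer element types enter the backend contract: the default contract makes every integer tensor signless, while the StableHLO contract keeps unsigned integers and only strips the sign from signed ones. A compact, string-level restatement of that rule (sketch only):

def backend_int_type(width: int, signedness: str, stablehlo: bool) -> str:
    # Default contract: every integer becomes signless i<width>.
    # StableHLO contract: unsigned stays ui<width>, signed becomes signless.
    if stablehlo and signedness == "unsigned":
        return f"ui{width}"
    return f"i{width}"

assert backend_int_type(8, "unsigned", stablehlo=False) == "i8"
assert backend_int_type(8, "unsigned", stablehlo=True) == "ui8"
assert backend_int_type(32, "signed", stablehlo=True) == "i32"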
//===----------------------------------------------------------------------===// namespace { + +void populateFuncBackendTypeConversionPatterns(TypeConverter &typeConverter, + RewritePatternSet &patterns, + ConversionTarget &target) { + populateFunctionOpInterfaceTypeConversionPattern(patterns, + typeConverter); + target.addDynamicallyLegalOp([&](func::FuncOp op) { + return typeConverter.isSignatureLegal(op.getFunctionType()) && + typeConverter.isLegal(&op.getBody()); + }); + populateCallOpTypeConversionPattern(patterns, typeConverter); + target.addDynamicallyLegalOp( + [&](func::CallOp op) { return typeConverter.isLegal(op); }); + + populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter); + populateReturnOpTypeConversionPattern(patterns, typeConverter); + target.addLegalOp(); + + target.markUnknownOpDynamicallyLegal([&](Operation *op) { + return isNotBranchOpInterfaceOrReturnLikeOp(op) || + isLegalForBranchOpInterfaceTypeConversionPattern(op, + typeConverter) || + isLegalForReturnOpTypeConversionPattern(op, typeConverter); + }); +} + struct FuncBackendTypeConversionPass : public FuncBackendTypeConversionBase { using FuncBackendTypeConversionBase< @@ -43,31 +69,41 @@ struct FuncBackendTypeConversionPass typeConverter.addConversion([](Type type) { return type; }); TorchConversion::setupBackendTypeConversion(target, typeConverter); - populateFunctionOpInterfaceTypeConversionPattern( - patterns, typeConverter); - target.addDynamicallyLegalOp([&](func::FuncOp op) { - return typeConverter.isSignatureLegal(op.getFunctionType()) && - typeConverter.isLegal(&op.getBody()); - }); - populateCallOpTypeConversionPattern(patterns, typeConverter); - target.addDynamicallyLegalOp( - [&](func::CallOp op) { return typeConverter.isLegal(op); }); - - populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter); - populateReturnOpTypeConversionPattern(patterns, typeConverter); - target.addLegalOp(); - - target.markUnknownOpDynamicallyLegal([&](Operation *op) { - return isNotBranchOpInterfaceOrReturnLikeOp(op) || - isLegalForBranchOpInterfaceTypeConversionPattern(op, - typeConverter) || - isLegalForReturnOpTypeConversionPattern(op, typeConverter); - }); + populateFuncBackendTypeConversionPatterns(typeConverter, patterns, target); if (failed(applyFullConversion(module, target, std::move(patterns)))) signalPassFailure(); } }; + +#ifdef TORCH_MLIR_ENABLE_STABLEHLO +struct FuncBackendTypeConversionForStablehloPass + : public FuncBackendTypeConversionForStablehloBase< + FuncBackendTypeConversionForStablehloPass> { + using FuncBackendTypeConversionForStablehloBase< + FuncBackendTypeConversionForStablehloPass>:: + FuncBackendTypeConversionForStablehloBase; + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + void runOnOperation() override { + auto module = getOperation(); + auto *context = &getContext(); + + TypeConverter typeConverter; + RewritePatternSet patterns(context); + ConversionTarget target(*context); + typeConverter.addConversion([](Type type) { return type; }); + TorchConversion::setupBackendTypeConversionForStablehlo(target, + typeConverter); + + populateFuncBackendTypeConversionPatterns(typeConverter, patterns, target); + + if (failed(applyFullConversion(module, target, std::move(patterns)))) + signalPassFailure(); + } +}; +#endif // TORCH_MLIR_ENABLE_STABLEHLO } // namespace std::unique_ptr> @@ -75,6 +111,13 @@ mlir::torch::TorchConversion::createFuncBackendTypeConversionPass() { return std::make_unique(); } +#ifdef 
TORCH_MLIR_ENABLE_STABLEHLO +std::unique_ptr> mlir::torch::TorchConversion:: + createFuncBackendTypeConversionForStablehloPass() { + return std::make_unique(); +} +#endif // TORCH_MLIR_ENABLE_STABLEHLO + //===----------------------------------------------------------------------===// // FinalizingBackendTypeConversionPass //===----------------------------------------------------------------------===// @@ -170,9 +213,61 @@ struct FinalizingBackendTypeConversionPass stripTorchAttrs(func); } }; + +#ifdef TORCH_MLIR_ENABLE_STABLEHLO +struct FinalizingBackendTypeConversionForStablehloPass + : public FinalizingBackendTypeConversionForStablehloBase< + FinalizingBackendTypeConversionForStablehloPass> { + using FinalizingBackendTypeConversionForStablehloBase< + FinalizingBackendTypeConversionForStablehloPass>:: + FinalizingBackendTypeConversionForStablehloBase; + + void runOnOperation() override { + auto func = getOperation(); + auto *context = &getContext(); + + TypeConverter typeConverter; + RewritePatternSet patterns(context); + ConversionTarget target(*context); + + typeConverter.addConversion([](Type type) { return type; }); + TorchConversion::setupBackendTypeConversionForStablehlo(target, + typeConverter); + + // Mark materializations as illegal in this pass (since we are finalizing) + // and add patterns that eliminate them. + setupFinalization(target, patterns, typeConverter); + + // If all result types are legal, and all block arguments are legal, then + // all types in the program are legal. + // + // We also check that the operand types are legal to avoid creating invalid + // IR. For example, this prevents the patterns from updating + // the types of the operands to a return op without updating the enclosing + // function. + target.markUnknownOpDynamicallyLegal( + [&](Operation *op) { return typeConverter.isLegal(op); }); + + if (failed(applyFullConversion(func, target, std::move(patterns)))) + signalPassFailure(); + + // Drop attributes that are no longer used after conversion out of Torch. + stripTorchAttrs(func); + } +}; +#endif // TORCH_MLIR_ENABLE_STABLEHLO } // namespace std::unique_ptr> mlir::torch::TorchConversion::createFinalizingBackendTypeConversionPass() { return std::make_unique(); } + +#ifdef TORCH_MLIR_ENABLE_STABLEHLO +std::unique_ptr> mlir::torch:: + TorchConversion::createFinalizingBackendTypeConversionForStablehloPass() { + return std::make_unique(); +} +#endif // TORCH_MLIR_ENABLE_STABLEHLO diff --git a/lib/Dialect/TorchConversion/Transforms/Passes.cpp b/lib/Dialect/TorchConversion/Transforms/Passes.cpp index 3a829dae4e15..4b878c007006 100644 --- a/lib/Dialect/TorchConversion/Transforms/Passes.cpp +++ b/lib/Dialect/TorchConversion/Transforms/Passes.cpp @@ -153,10 +153,11 @@ void TorchConversion::createTorchBackendToStablehloBackendPipeline( // Finish the type conversion from `torch` types to the types of the // StableHLO backend contract. - pm.addPass(TorchConversion::createFuncBackendTypeConversionPass()); + pm.addPass( + TorchConversion::createFuncBackendTypeConversionForStablehloPass()); pm.addNestedPass(createCanonicalizerPass()); pm.addNestedPass( - TorchConversion::createFinalizingBackendTypeConversionPass()); + TorchConversion::createFinalizingBackendTypeConversionForStablehloPass()); // Verify that we have lowered to Stablehlo ops. 
pm.addPass(TorchConversion::createVerifyStablehloBackendContractPass()); diff --git a/lib/InitAll.cpp b/lib/InitAll.cpp index 93ca5ef20372..f2013fdcb2f2 100644 --- a/lib/InitAll.cpp +++ b/lib/InitAll.cpp @@ -18,6 +18,7 @@ #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/IR/Dialect.h" #include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h" @@ -40,7 +41,11 @@ void mlir::torch::registerAllDialects(mlir::DialectRegistry ®istry) { registry.insert(); registry.insert(); registry.insert(); +} + +void mlir::torch::registerAllExtensions(mlir::DialectRegistry ®istry) { mlir::func::registerInlinerExtension(registry); + tensor::registerInferTypeOpInterfaceExternalModels(registry); } // TODO: Break this up when backends are separated. @@ -63,6 +68,7 @@ void mlir::torch::registerAllPasses() { mlir::stablehlo::registerStablehloLegalizeToLinalgPass(); mlir::stablehlo::registerStablehloAggressiveSimplificationPass(); mlir::stablehlo::registerStablehloRefineShapesPass(); + mlir::stablehlo::registerStablehloConvertToSignlessPass(); #endif #ifdef TORCH_MLIR_ENABLE_REFBACKEND diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 184de282d85e..cc85d6491d6c 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -30,6 +30,7 @@ "InterpolateDynamicModule_sizes_nearest", "InterpolateStaticModule_scales_bilinear_align_corners", "InterpolateDynamicModule_scales_recompute_bilinear", + "ElementwiseFloatTensorGtIntTensorModule_basic", } LINALG_CRASHING_SET = { @@ -38,6 +39,7 @@ # Out of bounds access "ConvolutionModule2DTranspose_basic", "Conv_Transpose2dModule_basic", + "Conv_Transpose2dStaticModule_basic", "ConvolutionModule2DTransposeStrided_basic", "ConvolutionModule2DTransposeStridedStatic_basic", } @@ -298,6 +300,7 @@ "QuantizedReluInt8_basic", "QuantizedReluUint8_basic", "Conv2dQInt8Module_basic", + "Conv2dQInt8Module_grouped", "ConvTranspose2DQInt8_basic", # Dynamo not supporting conv_tbc "ConvTbcModule_basic", @@ -398,6 +401,7 @@ "ContainsIntList_True", "Conv1dNoPaddingGroupModule_basic", "Conv2dQInt8Module_basic", + "Conv2dQInt8Module_grouped", "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier", "ConvTbcModule_basic", "ConvTranspose2DQInt8_basic", @@ -412,6 +416,10 @@ "ElementwiseDequantizePerTensorModule_basic", "ElementwiseQuantizePerTensorModule_basic", "ElementwiseQuantizePerTensorUIntModule_basic", + "ElementwiseRreluEvalModule_basic", + "ElementwiseRreluEvalStaticModule_basic", + "ElementwiseRreluTrainModule_basic", + "ElementwiseRreluTrainStaticModule_basic", "ElementwiseToDtypeI64ToUI8Module_basic", "FakeQuantizePerTensorAffineCachemaskModule_basic", "FakeQuantizePerTensorAffineDynamicShapeModule_basic", @@ -569,6 +577,7 @@ "ContainsIntList_False", "ContainsIntList_True", "Conv2dQInt8Module_basic", + "Conv2dQInt8Module_grouped", "ConvTbcModule_basic", "ConvTranspose2DQInt8_basic", "ConvolutionBackwardModule2DPadded_basic", @@ -849,6 +858,8 @@ "SplitWithSizes_Module_basic", "TensorSplitSections_GetItemModule_basic", "TensorSplitSections_ListUnpackModule_basic", + "EmptyModule_uint8", + "TypeConversionUint8ToF32Module_basic", "AtenLinear1D_basic", "AtenLinear2D_basic", "AtenLinear3DBias_basic", @@ -908,6 +919,7 @@ "Aten_CastLongModule_basic", "AvgPool1dStaticModule_basic", 
"AvgPool2dStaticModule_basic", + "AvgPool2dCountIncludePadFalseStaticModule_basic", "AvgPool3dStaticModule_basic", "BaddbmmBroadcast1DInputModule_basic", "BaddbmmBroadcast2DInputModule_basic", @@ -949,6 +961,9 @@ "Convolution2DGroupsStatic_basic", "ConvolutionBackwardModule2DStatic_basic", "ConvolutionModule2DTransposeStridedStatic_basic", + "Conv_Transpose1dStaticModule_basic", + "Conv_Transpose2dStaticModule_basic", + "Conv_Transpose3dStaticModule_basic", "ConstantPad2dStaticModule_basic", "ConstantPadNdModule_basic", "ConstantPadNdPartialStaticModule_basic", @@ -1050,12 +1065,15 @@ "ElementwiseRemainderTensorModule_Float_basic", "ElementwiseRemainderTensorModule_Int_Float_basic", "ElementwiseRemainderTensorModule_Int_basic", + "ElementwiseRreluEvalStaticModule_basic", + "ElementwiseRreluTrainStaticModule_basic", "ElementwiseRsqrtModule_basic", "ElementwiseSigmoidModule_basic", "ElementwiseSinModule_basic", "ElementwiseSqrtModule_basic", "ElementwiseTanIntModule_basic", "ElementwiseTanModule_basic", + "ElementwiseTernaryStaticShapeModule_basic", "ElementwiseToDtypeF32ToI64Module_basic", "ElementwiseToDtypeI64ToI8Module_basic", "ElementwiseToDtypeIdentityModule_basic", @@ -1105,6 +1123,10 @@ "GeIntModule_basic", "GeluBackwardModule_basic", "GluStaticModule_basic", + "GridSamplerBasic1_basic", + "GridSamplerBasic2_basic", + "GridSamplerBasic3_basic", + "GridSamplerBasic4_basic", "GtFloatIntModule_basic", "GtIntModule_basic", "IndexTensorModule3dInputStatic_basic", @@ -1511,6 +1533,7 @@ # Write the TOSA set as a "passing" set as it is very early in development # and very few tests work yet. TOSA_PASS_SET = { + "AvgPool2dCountIncludePadFalseStaticModule_basic", "TensorSplitSections_GetItemModule_basic", "TensorSplitSections_ListUnpackModule_basic", "AtenLinear2D_basic", @@ -1525,6 +1548,7 @@ "AtenDotModule_basic", "ElementwiseFloatTensorGtIntScalarModule_basic", "ElementwiseLogSigmoidModule_basic", + "ElementwiseTernaryStaticShapeModule_basic", "ElementwiseTruncModule_basic", "ElementwiseTruncIntModule_basic", "ElementwiseSgnModule_basic", @@ -1791,6 +1815,8 @@ "ElementwiseRemainderScalarModule_Float_basic", "ElementwiseRemainderScalarModule_Int_basic", "ElementwiseRemainderScalarModule_Int_Float_basic", + "ElementwiseRreluEvalModule_basic", + "ElementwiseRreluEvalStaticModule_basic", "ElementwiseRsqrtModule_basic", "ElementwiseSeluModule_basic", "ElementwiseSigmoidModule_basic", @@ -2101,6 +2127,7 @@ "CumsumStaticNegativeDimModule_basic", "CumsumInputDtypeInt32Module_basic", "EyeStaticModule_basic", + "AvgPool2dCountIncludePadFalseStaticModule_basic", "AtenLinear1D_basic", "AtenLinearMatVec_basic", "AtenLinearVecMatBias_basic", @@ -2155,6 +2182,10 @@ "IndexPutImpl1DFloatNonAccumulateModule_basic", "IndexPutImpl1DIntNonAccumulateModule_basic", # It appears that you're trying to get value out of a tracing tensor + # failed to legalize operation 'torch.aten.rrelu_with_noise' + "ElementwiseRreluEvalModule_basic", + "ElementwiseRreluEvalStaticModule_basic", + # It appears that you're trying to get value out of a tracing tensor "PrimListUnpackNumMismatchModule_basic", # RuntimeError: shape '[2, -1, 6]' is invalid for input of size 210 "ViewDynamicExpandCollapseWithParallelUnknownDimModule_basic", @@ -2321,6 +2352,7 @@ "ElementwiseBitwiseAndScalarInt32Module_basic", "ElementwiseBitwiseAndScalarInt8Module_basic", "Conv2dQInt8Module_basic", + "Conv2dQInt8Module_grouped", "ConvTranspose2DQInt8_basic", } @@ -2472,6 +2504,7 @@ "Conv2dModule_basic", "Conv2dNoPaddingModule_basic", 
"Conv2dQInt8Module_basic", + "Conv2dQInt8Module_grouped", "Conv2dWithPaddingDilationStrideModule_basic", "Conv2dWithPaddingModule_basic", "Conv3dModule_basic", @@ -2505,9 +2538,6 @@ "ElementwiseBitwiseAndScalarInt64Module_basic", "ElementwiseBitwiseAndScalarInt8Module_basic", "ElementwiseBitwiseAndStaticShapeModule_basic", - "ElementwiseBitwiseLeftShiftInt32Module_basic", - "ElementwiseBitwiseLeftShiftInt64Module_basic", - "ElementwiseBitwiseLeftShiftInt8Module_basic", "ElementwiseBitwiseNotInt32Module_basic", "ElementwiseBitwiseNotInt64Module_basic", "ElementwiseBitwiseOrModule_basic", @@ -2855,6 +2885,8 @@ "PrimsIotaModule_basic", # Failure - unknown "BernoulliModule_basic", + "Conv_Transpose1dModule_basic", + "Conv_Transpose3dModule_basic", "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier", "CopyWithDifferentDTypesAndSizesModule_basic", "CopyWithDifferentDTypesModule_basic", @@ -2874,6 +2906,7 @@ "ElementwiseTanIntModule_basic", "ElementwiseToDtypeI64ToUI8Module_basic", "ElementwiseUnaryIntModule_basic", + "ElementwiseFloatTensorGtIntTensorModule_basic", "MaskedFillTensorFloatValueModule_basic", "NativeDropoutTrainModule_basic", "NativeDropoutTrainStaticShapeModule_basic", @@ -2904,6 +2937,14 @@ "RepeatInterleaveSelfIntNoDimModule_basic", } +if torch_version_for_comparison() < version.parse("2.4.0.dev"): + ONNX_XFAIL_SET = ONNX_XFAIL_SET | { + # torch.onnx.errors.UnsupportedOperatorError: Exporting the operator 'aten::bitwise_left_shift' to ONNX opset version 17 is not supported. + "ElementwiseBitwiseLeftShiftInt32Module_basic", + "ElementwiseBitwiseLeftShiftInt64Module_basic", + "ElementwiseBitwiseLeftShiftInt8Module_basic", + } + ONNX_CRASHING_SET = { "FakeQuantizePerTensorAffineModule_basic", @@ -2924,6 +2965,7 @@ "ScatterReduceIntSumModuleIncludeSelf", # Nondeterministically passes or fails with mismatching numerics "ConvolutionModule2DTransposeStridedStatic_basic", + "Conv_Transpose2dStaticModule_basic", # The following test sporadically stopped producing correct numerics for the golden value in the CI. # For now, we are removing the test until this issue has been debugged. 
"QuantizedMLP_basic", @@ -3052,6 +3094,7 @@ "ContainsIntList_True", "Conv1dModule_basic", "Conv2dQInt8Module_basic", + "Conv2dQInt8Module_grouped", "Conv2dWithPaddingDilationStrideStaticModule_grouped", "Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier", "Conv3dModule_basic", @@ -3838,6 +3881,7 @@ "Conv2dModule_basic", "Conv2dNoPaddingModule_basic", "Conv2dQInt8Module_basic", + "Conv2dQInt8Module_grouped", "Conv2dWithPaddingDilationStrideModule_basic", "Conv2dWithPaddingDilationStrideStaticModule_grouped", "Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier", @@ -3977,6 +4021,7 @@ "ElementwiseExpm1IntModule_basic", "ElementwiseExpm1Module_basic", "ElementwiseFlattenBroadcastModule_basic", + "ElementwiseFloatTensorGtIntTensorModule_basic", "ElementwiseFmodTensor_Float_basic", "ElementwiseFmodTensor_Int_Float_basic", "ElementwiseFmodTensor_Int_basic", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index d4f162d64398..b7376bc669c2 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -555,6 +555,9 @@ def aten〇prelu〡shape(self: List[int], weight: List[int]) -> List[int]: def aten〇celu〡shape(self: List[int], alpha: float = 1.) -> List[int]: return upstream_shape_functions.unary(self) +def aten〇rrelu〡shape(self: List[int], lower: float = 0.125, upper: float = 0.33333333333333331, training: bool = False, generator: Any = None) -> List[int]: + return upstream_shape_functions.unary(self) + def aten〇selu〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) @@ -597,7 +600,7 @@ def aten〇mean〡shape(self: List[int], dtype: Optional[int] = None) -> List[in def aten〇var〡shape(self: List[int], unbiased: bool = True) -> List[int]: return [] -def prims〇var〡shape(inp: List[int], dims: Optional[List[int]], correction: float, output_dtype: Optional[int] = None) -> List[int]: +def prims〇var〡shape(inp: List[int], dims: Optional[List[int]], correction: Optional[float] = 1, output_dtype: Optional[int] = None) -> List[int]: return upstream_shape_functions.sum_mean_dim(inp, dims, False, None) def aten〇var〇dim〡shape(self: List[int], dim: Optional[List[int]], unbiased: bool = True, keepdim: bool = False) -> List[int]: @@ -1627,6 +1630,12 @@ def aten〇convolution〡shape(input: List[int], weight: List[int], bias: Option def aten〇conv1d〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None, stride: List[int] = (1,), padding: List[int] = (0,), dilation: List[int] = (1,), groups: int = 1) -> List[int]: return upstream_shape_functions.conv_forwards(input, weight, bias, stride, padding, dilation, transposed=False, output_padding=[], groups=1) +def aten〇conv_transpose1d〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None, stride: List[int] = (1,), padding: List[int] = (0,), output_padding: List[int] = (0,), groups: int = 1, dilation: List[int] = (1,)) -> List[int]: + return aten〇convolution〡shape(input, weight, bias, stride, padding, dilation, True, output_padding, groups) + +def aten〇conv_transpose3d〇input〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None, stride: List[int] = (1, 1, 1,), padding: List[int] = (0, 0, 0,), output_padding: List[int] = (0, 0, 0,), groups: int = 1, dilation: List[int] = (1, 1, 1,)) -> List[int]: + 
return aten〇convolution〡shape(input, weight, bias, stride, padding, dilation, True, output_padding, groups) + def aten〇_convolution〡shape(input: List[int], weight: List[int], bias: Optional[List[int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, benchmark: bool, deterministic: bool, cudnn_enabled: bool, allow_tf32: bool) -> List[int]: return aten〇convolution〡shape(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups) @@ -2810,6 +2819,12 @@ def aten〇celu〡dtype(self_rank_dtype: Tuple[int, int], alpha: Union[int, floa self_rank, self_dtype = self_rank_dtype return self_dtype +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bool, *all_integer_dtypes()})) +def aten〇rrelu〡dtype(self_rank_dtype: Tuple[int, int], lower: Union[int, float, complex] = 0.125, upper: Union[int, float, complex] = 0.33333333333333331, training: bool = False, generator: Any = None) -> int: + self_rank, self_dtype = self_rank_dtype + assert is_float_dtype(self_dtype) or is_complex_dtype(self_dtype) + return self_dtype + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bool})) def aten〇relu6〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype @@ -3635,6 +3650,10 @@ def aten〇conv3d〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: input_rank, input_dtype = input_rank_dtype return input_dtype +def aten〇conv_transpose1d〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]] = None, stride: List[int] = (1,), padding: List[int] = (0,), output_padding: List[int] = (0,), groups: int = 1, dilation: List[int] = (1,)) -> int: + input_rank, input_dtype = input_rank_dtype + return input_dtype + @check_dtype_function( _check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1, 1, 1), (1, 1, 1, 1)]) + [Invocation(TensorOfShape(1, 1, 1, 1, dtype=torch.bool), TensorOfShape(1, 1, 1, 1, dtype=torch.float32)), @@ -3646,6 +3665,10 @@ def aten〇conv_transpose2d〇input〡dtype(input_rank_dtype: Tuple[int, int], w input_rank, input_dtype = input_rank_dtype return input_dtype +def aten〇conv_transpose3d〇input〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]] = None, stride: List[int] = (1, 1, 1,), padding: List[int] = (0, 0, 0,), output_padding: List[int] = (0, 0, 0,), groups: int = 1, dilation: List[int] = (1, 1, 1,)) -> int: + input_rank, input_dtype = input_rank_dtype + return input_dtype + convolution_kwargs = { "stride" : [1, 1], "padding" : [0, 0], "dilation" : [1, 1], "transposed" : False, "output_padding" : [0, 0], "groups" : 1} @check_dtype_function( @@ -4310,7 +4333,7 @@ def aten〇var〇correction〡dtype(self_rank_dtype: Tuple[int, int], dim: Optio return aten〇std〡dtype(self_rank_dtype) @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dims=[], correction=0.0)) -def prims〇var〡dtype(inp_rank_dtype: Tuple[int, int], dims: Optional[List[int]], correction: float, output_dtype: Optional[int] = None) -> int: +def prims〇var〡dtype(inp_rank_dtype: Tuple[int, int], dims: Optional[List[int]], correction: Optional[float] = 1, output_dtype: Optional[int] = None) -> int: return aten〇std〡dtype(inp_rank_dtype) @check_dtype_function( diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py 
b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index d23ac9ac45d4..1f86a60cc0e6 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -301,6 +301,8 @@ def emit_with_mutating_variants(key, **kwargs): "aten::relu : (Tensor) -> (Tensor)", "aten::relu6 : (Tensor) -> (Tensor)", "aten::leaky_relu : (Tensor, Scalar) -> (Tensor)", + "aten::rrelu : (Tensor, Scalar, Scalar, bool, Generator?) -> (Tensor)", + "aten::celu : (Tensor, Scalar) -> (Tensor)", "aten::selu : (Tensor) -> (Tensor)", "aten::sigmoid : (Tensor) -> (Tensor)", "aten::sinh : (Tensor) -> (Tensor)", @@ -472,7 +474,6 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::floor_divide : (Tensor, Tensor) -> (Tensor)") emit("aten::softplus : (Tensor, Scalar, Scalar) -> (Tensor)") emit("aten::prelu : (Tensor, Tensor) -> (Tensor)") - emit_with_mutating_variants("aten::celu : (Tensor, Scalar) -> (Tensor)") emit("aten::real : (Tensor) -> (Tensor)") emit("aten::imag : (Tensor) -> (Tensor)") emit("aten::view_as_complex : (Tensor) -> (Tensor)") @@ -596,6 +597,7 @@ def emit_with_mutating_variants(key, **kwargs): ) emit("aten::max_pool1d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)") emit("aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)") + emit("aten::max_unpool2d : (Tensor, Tensor, int[]) -> (Tensor)") emit( "aten::max_pool2d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)", has_canonicalizer=True, @@ -604,6 +606,7 @@ def emit_with_mutating_variants(key, **kwargs): "aten::max_pool2d_with_indices_backward : (Tensor, Tensor, int[], int[], int[], int[], bool, Tensor) -> (Tensor)" ) emit("aten::max_pool3d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)") + emit("aten::max_unpool3d : (Tensor, Tensor, int[], int[], int[]) -> (Tensor)") emit( "aten::max_pool3d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)" ) @@ -1128,7 +1131,7 @@ def emit_with_mutating_variants(key, **kwargs): # ========================================================================== emit("prims::convert_element_type : (Tensor, int) -> (Tensor)", has_folder=True) - emit("prims::var : (Tensor, int[]?, float, int?) -> (Tensor)") + emit("prims::var : (Tensor, int[]?, float?, int?) -> (Tensor)") emit("prims::sqrt : (Tensor) -> (Tensor)") emit("prims::collapse : (Tensor, int, int) -> (Tensor)") emit("prims::split_dim : (Tensor, int, int) -> (Tensor)") diff --git a/projects/pt1/python/torch_mlir_e2e_test/configs/fx_importer_backend.py b/projects/pt1/python/torch_mlir_e2e_test/configs/fx_importer_backend.py index 4cda217a14eb..11a6ef6ffd6f 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/configs/fx_importer_backend.py +++ b/projects/pt1/python/torch_mlir_e2e_test/configs/fx_importer_backend.py @@ -49,6 +49,9 @@ def run(self, artifact: torch.nn.Module, trace: Trace) -> Trace: prog, output_type=self._output_type, func_name=artifact.__class__.__name__, + # While the current e2e tests don't exercise symbolic shapes, + # enabling this here ensures they don't regress either. 
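This patch threads an import_symbolic_shape_expressions option through the fx importer, and the e2e fx-importer config here enables it. A minimal sketch of driving the same option directly via FxImporter.import_program; the setup boilerplate (context creation, dialect registration, module names) is an assumption, not taken from the patch:

import torch
from torch_mlir import ir
from torch_mlir.dialects import torch as torch_d
from torch_mlir.extras.fx_importer import FxImporter

class Scale(torch.nn.Module):
    def forward(self, x):
        return x * 2.0

ctx = ir.Context()
torch_d.register_dialect(ctx)
importer = FxImporter(context=ctx)
batch = torch.export.Dim("batch")
prog = torch.export.export(Scale(), (torch.randn(4, 3),), dynamic_shapes={"x": {0: batch}})
importer.import_program(prog, func_name="main", import_symbolic_shape_expressions=True)
print(importer.module)

With the flag set, the importer records the symbolic guards collected from the exported program (see set_symbolic_guards below) so that the sympy shape expressions captured by torch.export can be translated during import.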
+ import_symbolic_shape_expressions=True, ) module = self._backend.compile(module) backend_module = self._backend.load(module) diff --git a/projects/pt1/python/torch_mlir_e2e_test/stablehlo_backends/linalg_on_tensors.py b/projects/pt1/python/torch_mlir_e2e_test/stablehlo_backends/linalg_on_tensors.py index 61050de8fd6c..25c6405b7436 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/stablehlo_backends/linalg_on_tensors.py +++ b/projects/pt1/python/torch_mlir_e2e_test/stablehlo_backends/linalg_on_tensors.py @@ -23,6 +23,7 @@ [ "func.func(stablehlo-aggressive-simplification)", "stablehlo-legalize-to-linalg", + "stablehlo-convert-to-signless", "canonicalize", ] ) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py index a410b5ba5ddb..4ba757f2f655 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py @@ -877,6 +877,66 @@ def ConvolutionModule2DTransposeNonUnitOutputPadding_basic(module, tu: TestUtils module.forward(tu.rand(1, 2, 4, 4), tu.rand(2, 2, 3, 3)) +class Conv_Transpose1dModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), + ] + ) + def forward(self, inputVec, weight): + return torch.ops.aten.conv_transpose1d( + inputVec, + weight, + bias=None, + stride=[2], + padding=[1], + dilation=[1], + output_padding=[0], + groups=1, + ) + + +@register_test_case(module_factory=lambda: Conv_Transpose1dModule()) +def Conv_Transpose1dModule_basic(module, tu: TestUtils): + module.forward(tu.rand(5, 2, 6), tu.rand(2, 5, 2)) + + +class Conv_Transpose1dStaticModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([5, 2, 6], torch.float32, True), + ([2, 5, 2], torch.float32, True), + ] + ) + def forward(self, inputVec, weight): + return torch.ops.aten.conv_transpose1d( + inputVec, + weight, + bias=None, + stride=[2], + padding=[1], + dilation=[1], + output_padding=[0], + groups=1, + ) + + +@register_test_case(module_factory=lambda: Conv_Transpose1dStaticModule()) +def Conv_Transpose1dStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(5, 2, 6), tu.rand(2, 5, 2)) + + class Conv_Transpose2dModule(torch.nn.Module): def __init__(self): super().__init__() @@ -907,6 +967,96 @@ def Conv_Transpose2dModule_basic(module, tu: TestUtils): module.forward(tu.rand(5, 2, 5, 6), tu.rand(2, 5, 2, 2)) +class Conv_Transpose2dStaticModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([5, 2, 5, 6], torch.float32, True), + ([2, 5, 2, 2], torch.float32, True), + ] + ) + def forward(self, inputVec, weight): + return torch.ops.aten.conv_transpose2d( + inputVec, + weight, + bias=None, + stride=[2, 2], + padding=[1, 1], + dilation=[1, 1], + output_padding=[0, 0], + groups=1, + ) + + +@register_test_case(module_factory=lambda: Conv_Transpose2dStaticModule()) +def Conv_Transpose2dStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(5, 2, 5, 6), tu.rand(2, 5, 2, 2)) + + +class Conv_Transpose3dModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1, -1, -1], torch.float32, True), + ([-1, -1, -1, -1, -1], torch.float32, True), + ] + ) + def forward(self, inputVec, weight): + return torch.ops.aten.conv_transpose3d( + inputVec, + 
weight, + bias=None, + stride=[2, 2, 2], + padding=[1, 1, 1], + dilation=[1, 1, 1], + output_padding=[0, 0, 0], + groups=1, + ) + + +@register_test_case(module_factory=lambda: Conv_Transpose3dModule()) +def Conv_Transpose3dModule_basic(module, tu: TestUtils): + module.forward(tu.rand(5, 2, 5, 6, 7), tu.rand(2, 5, 2, 2, 2)) + + +class Conv_Transpose3dStaticModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([5, 2, 5, 6, 7], torch.float32, True), + ([2, 5, 2, 2, 2], torch.float32, True), + ] + ) + def forward(self, inputVec, weight): + return torch.ops.aten.conv_transpose3d( + inputVec, + weight, + bias=None, + stride=[2, 2, 2], + padding=[1, 1, 1], + dilation=[1, 1, 1], + output_padding=[0, 0, 0], + groups=1, + ) + + +@register_test_case(module_factory=lambda: Conv_Transpose3dStaticModule()) +def Conv_Transpose3dStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(5, 2, 5, 6, 7), tu.rand(2, 5, 2, 2, 2)) + + class UpSampleNearest2d(torch.nn.Module): def __init__(self): super().__init__() @@ -1124,7 +1274,8 @@ def ConvTbcModule_basic(module, tu: TestUtils): class Conv2dQInt8Module(torch.nn.Module): - def __init__(self): + def __init__(self, groups=1): + self.groups = groups super().__init__() @export @@ -1153,7 +1304,7 @@ def forward(self, inputVec, weight, bias): stride=[1, 1], padding=[0, 0], dilation=[1, 1], - groups=1, + groups=self.groups, ) @@ -1165,13 +1316,12 @@ def Conv2dQInt8Module_basic(module, tu: TestUtils): module.forward(inputVec, weight, bias) -N = 10 -Cin = 5 -Cout = 7 -Hin = 10 -Win = 8 -Hker = 3 -Wker = 2 +@register_test_case(module_factory=lambda: Conv2dQInt8Module(groups=2)) +def Conv2dQInt8Module_grouped(module, tu: TestUtils): + inputVec = tu.randint(2, 8, 7, 8, low=-128, high=127).to(torch.int8) + weight = tu.randint(6, 4, 3, 2, low=-128, high=127).to(torch.int8) + bias = torch.rand(6) + module.forward(inputVec, weight, bias) class ConvTranspose2DQInt8Module(torch.nn.Module): @@ -1211,6 +1361,13 @@ def forward(self, input, weight, bias): @register_test_case(module_factory=lambda: ConvTranspose2DQInt8Module()) def ConvTranspose2DQInt8_basic(module, tu: TestUtils): + N = 10 + Cin = 5 + Cout = 7 + Hin = 10 + Win = 8 + Hker = 3 + Wker = 2 module.forward( tu.randint(N, Cin, Hin, Win, low=-128, high=127).to(torch.int8), tu.randint(Cin, Cout, Hker, Wker, low=-128, high=127).to(torch.int8), diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py index 83a6e397a1e6..cbcfb0f40c6d 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -414,6 +414,31 @@ def ElementwiseTernaryModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseTernaryStaticShapeModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([5, 4, 3], torch.float32, True), + ([4, 3], torch.float32, True), + ([3], torch.float32, True), + ] + ) + def forward(self, a, b, c): + return torch.lerp(a, b, c) + + +@register_test_case(module_factory=lambda: ElementwiseTernaryStaticShapeModule()) +def ElementwiseTernaryStaticShapeModule_basic(module, tu: TestUtils): + module.forward(tu.rand(5, 4, 3), tu.rand(4, 3), tu.rand(3)) + + +# ============================================================================== + + class 
ElementwiseAtenWhereSelfModule(torch.nn.Module): def __init__(self): super().__init__() @@ -1037,6 +1062,100 @@ def ElementwiseCeluModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseRreluTrainModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) + def forward(self, x): + res = torch.ops.aten.rrelu(x, 0.4, 0.6, True) + return torch.mean(res), torch.std(res) + + +@register_test_case(module_factory=lambda: ElementwiseRreluTrainModule()) +def ElementwiseRreluTrainModule_basic(module, tu: TestUtils): + module.forward(tu.rand(1024, 1536)) + + +# ============================================================================== + + +class ElementwiseRreluTrainStaticModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([1024, 1536], torch.float32, True), + ] + ) + def forward(self, x): + res = torch.ops.aten.rrelu(x, 0.1, 0.9, True) + return torch.mean(res), torch.std(res) + + +@register_test_case(module_factory=lambda: ElementwiseRreluTrainStaticModule()) +def ElementwiseRreluTrainStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(1024, 1536)) + + +# ============================================================================== + + +class ElementwiseRreluEvalModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) + def forward(self, x): + return torch.ops.aten.rrelu(x, 0.4, 0.6, False) + + +@register_test_case(module_factory=lambda: ElementwiseRreluEvalModule()) +def ElementwiseRreluEvalModule_basic(module, tu: TestUtils): + module.forward(tu.rand(5, 3, low=-1, high=1)) + + +# ============================================================================== + + +class ElementwiseRreluEvalStaticModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([5, 3], torch.float32, True), + ] + ) + def forward(self, x): + return torch.ops.aten.rrelu(x, 0.1, 0.9, False) + + +@register_test_case(module_factory=lambda: ElementwiseRreluEvalStaticModule()) +def ElementwiseRreluEvalStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(5, 3, low=-1, high=1)) + + +# ============================================================================== + + class ElementwiseCeluStaticModule(torch.nn.Module): def __init__(self): super().__init__() diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise_comparison.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise_comparison.py index 7fdfb454d362..304bc422e4d2 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise_comparison.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise_comparison.py @@ -599,6 +599,51 @@ def ElementwiseLtIntTensorModule_basic(module, tu: TestUtils): module.forward(tu.randint(3, 5, high=10), tu.randint(5, high=10)) +class ElementwiseIntTensorLtFloatTensorModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1], torch.int64, True), + ([-1], torch.float64, True), + ] + ) + def forward(self, x, y): + return torch.lt(x, y) + + +@register_test_case(module_factory=lambda: ElementwiseIntTensorLtFloatTensorModule()) +def ElementwiseIntTensorLtFloatTensorModule_basic(module, tu: TestUtils): + 
module.forward(tu.randint(3, 5, high=10), tu.rand(5, high=10).to(torch.float64)) + + +class ElementwiseFloatTensorGtIntTensorModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1], torch.int32, True), + ] + ) + def forward(self, x, y): + return torch.gt(x, y) + + +@register_test_case(module_factory=lambda: ElementwiseIntTensorLtFloatTensorModule()) +def ElementwiseFloatTensorGtIntTensorModule_basic(module, tu: TestUtils): + module.forward( + tu.rand(3, 5, high=10).to(torch.float32), + tu.randint(5, high=10, dtype=torch.int32), + ) + + # ============================================================================== diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py index bbcfd15d9712..1de40096c006 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py @@ -1017,6 +1017,35 @@ def AvgPool2dStaticModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 2, 10, 20, low=-1)) +class AvgPool2dCountIncludePadFalseStaticModule(torch.nn.Module): + + def __init__(self): + super().__init__() + self.ap2d = torch.nn.AvgPool2d( + kernel_size=[3, 3], + stride=[1, 1], + padding=[1, 1], + ceil_mode=False, + count_include_pad=False, + divisor_override=None, + ) + + @export + @annotate_args( + [ + None, + ([32, 384, 25, 25], torch.float32, True), + ] + ) + def forward(self, x): + return self.ap2d(x) + + +@register_test_case(module_factory=lambda: AvgPool2dCountIncludePadFalseStaticModule()) +def AvgPool2dCountIncludePadFalseStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(32, 384, 25, 25, low=-1)) + + class AvgPool2dDivisorOverrideModule(torch.nn.Module): def __init__(self): super().__init__() diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/type_conversion.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/type_conversion.py index 5d3d085d5e2b..df78262fff96 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/type_conversion.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/type_conversion.py @@ -136,6 +136,26 @@ def TypeConversionI1ToF64Module_basic(module, tu: TestUtils): module.forward(tensor) +class TypeConversionUint8ToF32Module(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([3], torch.uint8, True), + ] + ) + def forward(self, x): + return x.to(torch.float) + + +@register_test_case(module_factory=lambda: TypeConversionUint8ToF32Module()) +def TypeConversionUint8ToF32Module_basic(module, tu: TestUtils): + module.forward(torch.tensor([0, 1, 255]).to(torch.uint8)) + + # ============================================================================== diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index 9981ed30e607..2a73325c7d76 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -14,6 +14,8 @@ import logging import operator import re +import sympy +import math from dataclasses import dataclass from types import BuiltinMethodType, BuiltinFunctionType from typing import ( @@ -81,6 +83,14 @@ ) from ..ir import ( + AffineAddExpr, + AffineConstantExpr, + AffineExpr, + AffineMap, + AffineMapAttr, + AffineModExpr, + AffineMulExpr, + AffineSymbolExpr, Attribute, Block, Context, @@ -89,6 +99,10 @@ FloatAttr, BF16Type, ComplexType, + 
Float8E5M2Type, + Float8E4M3FNType, + Float8E5M2FNUZType, + Float8E4M3FNUZType, F16Type, F32Type, F64Type, @@ -137,6 +151,10 @@ torch.complex32: "complex", torch.complex64: "complex", torch.complex128: "complex", + torch.float8_e5m2: "f8E5M2", + torch.float8_e4m3fn: "f8E4M3FN", + torch.float8_e5m2fnuz: "f8E5M2FNUZ", + torch.float8_e4m3fnuz: "f8E4M3FNUZ", } TORCH_DTYPE_TO_MLIR_TYPE: Dict[torch.dtype, Callable[[], IrType]] = { @@ -155,6 +173,10 @@ torch.complex32: lambda: ComplexType.get(F16Type.get()), torch.complex64: lambda: ComplexType.get(F32Type.get()), torch.complex128: lambda: ComplexType.get(F64Type.get()), + torch.float8_e5m2: lambda: Float8E5M2Type.get(), + torch.float8_e5m2fnuz: lambda: Float8E5M2FNUZType.get(), + torch.float8_e4m3fn: lambda: Float8E4M3FNType.get(), + torch.float8_e4m3fnuz: lambda: Float8E4M3FNUZType.get(), } TORCH_DTYPE_TO_NPY_TYPE = { @@ -193,6 +215,10 @@ # torch.quint8: 13, # torch.qint32 14 torch.bfloat16: 15, + torch.float8_e5m2: 23, + torch.float8_e4m3fn: 24, + torch.float8_e5m2fnuz: 25, + torch.float8_e4m3fnuz: 26, } TORCH_MEMORY_FORMAT_TO_INT = { @@ -258,6 +284,71 @@ SYMBOLIC_TORCH_OPS = {key for key in SYMBOLIC_OP_TO_TORCH_OP} +@dataclass +class RangeConstraint: + min_val: int + max_val: int + + +def sympy_expr_to_semi_affine_expr( + expr: sympy.Expr, symbols_map: Dict[str, AffineSymbolExpr] +) -> AffineExpr: + """Translate sympy expressions to MLIR (semi-)affine expressions. + + Recursively traverse the sympy expr AST and build the affine expr. + This is not a perfect translation. Sympy expressions are much more + expressive and not as constrained as affine (linear) expressions are. + However, for the most part, we don't need to support all of sympy. + PyTorch only uses a subset of sympy for capturing and expressing + symbolic shapes, and among what's supported, we expect the semi-affine + expressions (https://mlir.llvm.org/docs/Dialects/Affine/#semi-affine-maps) + to be sufficient. + """ + if isinstance(expr, sympy.Symbol): + return symbols_map[str(expr)] + elif isinstance(expr, (int, sympy.Integer)): + return AffineConstantExpr.get(expr) + # This handles both add (`s0 + c`) and subtract (`s0 - c`). + # The expression is `sympy.Add` in both cases but with args + # (s0, c) in first case and (s0, -c) in the second case. + elif isinstance(expr, sympy.Add): + affine_expr = AffineConstantExpr.get(0) + for arg in expr.args: + affine_expr = AffineAddExpr.get( + affine_expr, sympy_expr_to_semi_affine_expr(arg, symbols_map) + ) + return affine_expr + elif isinstance(expr, sympy.Mul): + affine_expr = AffineConstantExpr.get(1) + for arg in expr.args: + affine_expr = AffineMulExpr.get( + affine_expr, sympy_expr_to_semi_affine_expr(arg, symbols_map) + ) + return affine_expr + elif isinstance(expr, sympy.Pow): + base, exp = expr.args + # Only integer exponent is supported + # So, s1 ** s0 isn't allowed. + assert isinstance(exp, (int, sympy.Integer)) + assert exp > 0, "Only positive exponents supported in sympy.Pow" + affine_expr = AffineConstantExpr.get(1) + for _ in range(exp): + affine_expr = AffineMulExpr.get( + affine_expr, sympy_expr_to_semi_affine_expr(base, symbols_map) + ) + return affine_expr + elif isinstance(expr, sympy.Mod): + dividend, divisor = expr.args + return AffineModExpr.get( + sympy_expr_to_semi_affine_expr(dividend, symbols_map), + sympy_expr_to_semi_affine_expr(divisor, symbols_map), + ) + else: + raise NotImplementedError( + f"Translation of sympy.Expr of type {type(expr)} not implemented yet." 
+ ) + + @dataclass(frozen=True) class SparsityMeta: """ @@ -478,6 +569,7 @@ def import_program( *, func_name: str = "main", func_visibility: Optional[str] = None, + import_symbolic_shape_expressions: bool = False, ) -> Operation: """Imports an ExportedProgram according to our chosen canonical representation. @@ -527,6 +619,10 @@ def import_program( sig = prog.graph_signature + # Populate symbolic guards for dynamic shapes (if any) + if import_symbolic_shape_expressions: + self._cc.set_symbolic_guards(prog) + # Invert the (producer, node_name) maps for mutated user inputs and mutated # buffers. This is because we hit-detect based on the input node name. mutated_user_inputs = { @@ -682,7 +778,9 @@ def import_program( # Import all nodes and return. node_importer.import_nodes( - all_producer_nodes.values(), skip_placeholders_outputs=True + all_producer_nodes.values(), + skip_placeholders_outputs=True, + import_symbolic_shape_expressions=import_symbolic_shape_expressions, ) node_importer.return_node_values(loc, user_outputs) self.symbol_table.insert(func_op) @@ -694,6 +792,7 @@ def import_frozen_program( *, func_name: str = "main", func_visibility: Optional[str] = None, + import_symbolic_shape_expressions: bool = False, ) -> Operation: """Imports a consolidated torch.export.ExportedProgram instance. @@ -728,6 +827,10 @@ def import_frozen_program( state_dict = prog.state_dict arg_replacements: Dict[str, Any] = {} + # Populate symbolic guards for dynamic shapes (if any) + if import_symbolic_shape_expressions: + self._cc.set_symbolic_guards(prog) + # If there is no "constants" attribute, consult the "state_dict". Otherwise, only look # at "constants". Relevant upstream patch: https://github.com/pytorch/pytorch/pull/118969 if hasattr(prog, "constants"): @@ -774,7 +877,10 @@ def import_frozen_program( g.erase_node(node) return self.import_stateless_graph( - g, func_name=func_name, func_visibility=func_visibility + g, + func_name=func_name, + func_visibility=func_visibility, + import_symbolic_shape_expressions=import_symbolic_shape_expressions, ) def import_graph_module(self, gm: GraphModule) -> Operation: @@ -791,6 +897,7 @@ def import_stateless_graph( *, func_name: str = "main", func_visibility: Optional[str] = None, + import_symbolic_shape_expressions: bool = False, ) -> Operation: """Low-level import of a functionalized, assumed stateless Graph as a func. @@ -815,7 +922,9 @@ def import_stateless_graph( self._cc, entry_block, ) - node_importer.import_nodes(g.nodes) + node_importer.import_nodes( + g.nodes, import_symbolic_shape_expressions=import_symbolic_shape_expressions + ) self.symbol_table.insert(func) return func @@ -870,6 +979,7 @@ class ContextCache: "_c", "_dtype_to_type", "_tensor_metadata_cache", + "_symbolic_guards", "_py_attr_tracker", # Types. "torch_bool_type", @@ -888,6 +998,7 @@ def __init__( self._tensor_metadata_cache: Dict[ Tuple[torch.Size, torch.dtype, Optional[SparsityMeta], bool], IrType ] = {} + self._symbolic_guards: Dict = {} self._py_attr_tracker = py_attr_tracker or RefTracker() # Common types. 
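To make the sympy-to-affine helper added above concrete, here is a minimal sketch (not part of the patch) of driving sympy_expr_to_semi_affine_expr by hand. It assumes a torch-mlir build whose Python bindings expose torch_mlir.ir, and that importing the helper directly from torch_mlir.extras.fx_importer is acceptable for experimentation:

import sympy
from torch_mlir import ir
from torch_mlir.extras.fx_importer import sympy_expr_to_semi_affine_expr

s0, s1 = sympy.symbols("s0 s1")
with ir.Context():
    # The position of each symbol in this map becomes its symbol index in the
    # resulting affine_map, mirroring how the bind-op builder further down in
    # this patch constructs its symbols_map.
    symbols_map = {
        "s0": ir.AffineSymbolExpr.get(0),
        "s1": ir.AffineSymbolExpr.get(1),
    }
    # 3*s0 + (s1 mod 2) stays within the semi-affine subset handled above.
    expr = sympy_expr_to_semi_affine_expr(3 * s0 + sympy.Mod(s1, 2), symbols_map)
    shape_map = ir.AffineMap.get(0, 2, [expr])
    print(shape_map)  # e.g. ()[s0, s1] -> (s0 * 3 + s1 mod 2)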
@@ -1037,6 +1148,52 @@ def get_node_location(self, node: torch_fx.Node) -> Optional[Location]: return Location.file(filename, line, col=0, context=self._c) return Location.unknown(context=self._c) + def set_symbolic_guards( + self, prog: torch.export.ExportedProgram + ) -> Dict[str, RangeConstraint]: + + def _sympy_int_to_int(val: sympy.Expr, adjust_func: Callable): + # Convert simple sympy Integers into concrete int + if val == sympy.oo: + return math.inf + if val == -sympy.oo: + return -math.inf + if isinstance(val, sympy.Integer): + return int(val) + # TODO: Remove this adjustment when fractional ranges are removed + return adjust_func(val) + + contains_symbolic_ints = False + for val in prog.range_constraints.values(): + if ( + isinstance(val.lower, sympy.Integer) + and isinstance(val.upper, sympy.Integer) + and not val.is_bool + ): + contains_symbolic_ints = True + break + if contains_symbolic_ints: + # Build a map from shape symbol name to `RangeConstraint` object + # capturing `min_val`` and `max_val`` constraints for that + # symbol. Translate sympy integers to regular integers. + # + # Example: + # { + # 's0': RangeConstraint(min_val=5, max_val=10), + # 's1': RangeConstraint(min_val=0, max_val=100), + # 's3': RangeConstraint(min_val=0, max_val=9223372036854775806), + # } + self._symbolic_guards = { + str(k): RangeConstraint( + _sympy_int_to_int(v.lower, math.ceil), + _sympy_int_to_int(v.upper, math.floor), + ) + for k, v in prog.range_constraints.items() + } + + def get_symbolic_guards(self) -> Dict[str, RangeConstraint]: + return self._symbolic_guards + class GraphNodeImporter: """Imports graph nodes into an MLIR function. @@ -1050,6 +1207,7 @@ class GraphNodeImporter: "_cc", "_on_node_produced", "_v", + "_symbol_to_value", "_multi_result_nodes", "fx_importer", ] @@ -1068,6 +1226,8 @@ def __init__( # Map of (Node, result_index) to MLIR Value or a callback that lazily # constructs and returns a value. self._v: Dict[Union[Callable[[], Value], Tuple[torch_fx.Node, int]], Value] = {} + # Map of Shape Symbol to MLIR Value + self._symbol_to_value: Dict[str, Value] = {} # Map of node name to hook that should be called when it is produced. self._on_node_produced: Dict[str, Callable[[Value], None]] = {} # Statically multi-result nodes which we have de-tupled are noted here. 
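The range-constraint bookkeeping above and the shape-symbol binding below are only exercised when callers opt in via the new import_symbolic_shape_expressions flag, which this patch also threads through fx.export_and_import (see the python/torch_mlir/fx.py hunk further down). A rough usage sketch, with a made-up module and dimension name, might look like:

import torch
from torch.export import Dim
from torch_mlir import fx

class Scale(torch.nn.Module):
    def forward(self, x):
        return x * 2.0

# "batch" is a hypothetical dynamic dimension; its [3, 64] range is expected
# to surface as min_val/max_val on the generated torch.symbolic_int op.
batch = Dim("batch", min=3, max=64)
module = fx.export_and_import(
    Scale(),
    torch.randn(8, 16),
    dynamic_shapes={"x": {0: batch}},
    import_symbolic_shape_expressions=True,
)
print(module)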
@@ -1108,6 +1268,28 @@ def resolve_node_value(self, node: Node, result_index: int = 0) -> Value: self._v[key] = value return value + def bind_symbol_value( + self, + shape_symbol: str, + value: Value, + ): + """Binds a shape symbol to a global SSA value (and asserts if already bound).""" + assert ( + shape_symbol not in self._symbol_to_value + ), f"Symbol already has a value: {shape_symbol}" + self._symbol_to_value[shape_symbol] = value + + def resolve_symbol_value(self, shape_symbol: str) -> Value: + """Resolves a shape symbol to a value.""" + try: + binding = self._symbol_to_value[shape_symbol] + except KeyError: + raise KeyError( + f"Shape symbol {shape_symbol} has not been bound to an MLIR value" + ) + if isinstance(binding, Value): + return binding + def import_mutable_to_vtensor( self, loc: Location, node: Node, mutable_value: Value, producer_node_name: str ) -> Value: @@ -1190,10 +1372,20 @@ def return_node_values(self, loc, nodes: List[Node]): func_dialect.ReturnOp(operands, loc=loc) def import_nodes( - self, nodes: Iterable[Node], *, skip_placeholders_outputs: bool = False + self, + nodes: Iterable[Node], + *, + skip_placeholders_outputs: bool = False, + import_symbolic_shape_expressions: bool = False, ): with InsertionPoint(self._b): loc = Location.unknown() + + # Import dynamic shape symbols and guards (if any) + if import_symbolic_shape_expressions: + symbolic_guards = self._cc.get_symbolic_guards() + self._import_shape_symbols_with_guards(loc, symbolic_guards) + num_placeholders = 0 for node in nodes: op = node.op @@ -1253,6 +1445,9 @@ def import_nodes( operands = [self._import_argument(loc, arg) for arg in node.args[0]] func_dialect.ReturnOp(operands, loc=loc) + if import_symbolic_shape_expressions: + self._create_bind_symbolic_shape_ops(loc, node) + def _promote_symbolic_scalar_int_float(self, loc, graph, param): temp_target = torch.ops.aten.Float.Scalar temp_node = Node( @@ -1516,6 +1711,69 @@ def _import_torch_op_overload( for i, value in enumerate(operation.results): self.bind_node_value(node, value, i) + def _import_shape_symbols_with_guards( + self, loc: Location, symbolic_guards: Dict[str, RangeConstraint] + ): + for symbol, constraints in symbolic_guards.items(): + # Create torch.sym_int ops + operation = Operation.create( + name="torch.symbolic_int", + attributes={ + "symbol_name": StringAttr.get(symbol), + "min_val": self._cc.integer_attr(constraints.min_val, 64), + "max_val": self._cc.integer_attr(constraints.max_val, 64), + }, + results=[self._cc.torch_int_type], + loc=loc, + ) + self.bind_symbol_value(symbol, operation.result) + + def _create_bind_symbolic_shape_ops(self, loc: Location, node: torch_fx.Node): + node_val = node.meta.get("val") + if (node_val is not None) and isinstance(node_val, TorchFakeTensor): + # Only create bind ops if the shapes contain symbolic sizes. + # Query the bool attribute `_has_symbolic_sizes_strides` on node.meta["val"]. 
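+ # As a purely illustrative example (the SSA names and range below are made
+ # up), a fake tensor of size (s0, 3) whose symbol was imported with range
+ # [2, 64] would lead to IR roughly of the form:
+ #   %int_s0 = torch.symbolic_int "s0" {min_val = 2, max_val = 64} : !torch.int
+ #   torch.bind_symbolic_shape %val, [%int_s0], affine_map<()[s0] -> (s0, 3)> : !torch.vtensor<[?,3],f32>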
+ if node_val._has_symbolic_sizes_strides: + # Read node metadata to obtain shape symbols and expressions + symbols_set = set() + shape_exprs = [] + for s in node_val.size(): + if isinstance(s, torch.SymInt): + symbols_set.update(s.node.expr.free_symbols) + shape_exprs.append(s.node.expr) + else: + assert isinstance(s, int) + shape_exprs.append(s) + + # Map from sympy shape symbols to local symbols in the affine map + symbols_set = sorted(symbols_set, key=lambda x: x.name) + symbols_map = { + str(symbol): AffineSymbolExpr.get(i) + for i, symbol in enumerate(symbols_set) + } + + # Convert symbolic shape expressions into affine expressions + affine_exprs = [ + sympy_expr_to_semi_affine_expr(expr, symbols_map) + for expr in shape_exprs + ] + + affine_map = AffineMap.get(0, len(symbols_set), affine_exprs) + + # Build operand list + operand_list = [] + operand_list.append(self.resolve_node_value(node)) + for symbol in symbols_map.keys(): + operand_list.append(self.resolve_symbol_value(symbol)) + + # Create torch.bind_symbolic_shape ops + Operation.create( + name="torch.bind_symbolic_shape", + attributes={"shape_expressions": AffineMapAttr.get(affine_map)}, + operands=operand_list, + loc=loc, + ) + def _import_argument( self, loc: Location, arg: NodeArgument, expected_jit_type=None ) -> Value: @@ -1600,6 +1858,10 @@ def _import_literal(self, py_value: Any) -> Value: user_value = self.fx_importer._hooks.resolve_literal(self, py_value) if user_value is not None: assert isinstance(user_value, Value) + if orig_value is not None: + user_value = self._convert_type( + user_value, torch.Tensor, orig_value.dtype, orig_value.size() + ) return user_value # Default conversion path. diff --git a/python/torch_mlir/extras/onnx_importer.py b/python/torch_mlir/extras/onnx_importer.py index e0d3529d942e..4c1e0b9e9aed 100644 --- a/python/torch_mlir/extras/onnx_importer.py +++ b/python/torch_mlir/extras/onnx_importer.py @@ -34,6 +34,7 @@ ) from e from typing import Optional, List, Dict, Tuple +import warnings from dataclasses import dataclass @@ -579,6 +580,10 @@ def tensor_proto_to_builtin_type(self, tp: onnx.TensorProto) -> IrType: def type_proto_to_type(self, tp: onnx.TypeProto) -> IrType: if tp == "": + warnings.warn( + "Found a node without a valid type proto. Consider updating the opset_version of" + " the model and/or running the importer with the flag '--clear-domain'." 
+ ) return self.get_none_type() tt = tp.tensor_type diff --git a/python/torch_mlir/fx.py b/python/torch_mlir/fx.py index b8765b65984a..5cd7d2d6e1f1 100644 --- a/python/torch_mlir/fx.py +++ b/python/torch_mlir/fx.py @@ -54,6 +54,7 @@ def export_and_import( fx_importer: Optional[FxImporter] = None, dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None, experimental_support_mutation: bool = False, + import_symbolic_shape_expressions: bool = False, hooks: Optional[FxImporterHooks] = None, decomposition_table: Optional[Dict[torch._ops.OperatorBase, Callable]] = None, func_name: str = "main", @@ -79,9 +80,17 @@ def export_and_import( if experimental_support_mutation: if torch.__version__ < "2.3.0.dev20240207": warnings.warn("Mutable program import only supported on PyTorch 2.3+") - fx_importer.import_program(prog, func_name=func_name) + fx_importer.import_program( + prog, + func_name=func_name, + import_symbolic_shape_expressions=import_symbolic_shape_expressions, + ) else: - fx_importer.import_frozen_program(prog, func_name=func_name) + fx_importer.import_frozen_program( + prog, + func_name=func_name, + import_symbolic_shape_expressions=import_symbolic_shape_expressions, + ) return _module_lowering( enable_ir_printing, OutputType.get(output_type), fx_importer.module diff --git a/python/torch_mlir/tools/import_onnx/__main__.py b/python/torch_mlir/tools/import_onnx/__main__.py index 92ae3c7eb356..bca87cee7f59 100644 --- a/python/torch_mlir/tools/import_onnx/__main__.py +++ b/python/torch_mlir/tools/import_onnx/__main__.py @@ -20,6 +20,7 @@ import sys import onnx +import onnx.version from ...extras import onnx_importer @@ -81,6 +82,16 @@ def load_onnx_model(args: argparse.Namespace) -> onnx.ModelProto: raw_model = onnx.load(args.input_file, load_external_data=False) onnx.load_external_data_for_model(raw_model, args.data_dir) + if args.opset_version: + raw_model = onnx.version_converter.convert_version( + raw_model, args.opset_version + ) + + if args.clear_domain: + graph = raw_model.graph + for n in graph.node: + n.ClearField("domain") + # Run the checker to test whether the file is above the threshold for # in-memory shape inference. If not, go ahead and do the shape inference. try: @@ -149,6 +160,14 @@ def parse_arguments(argv=None) -> argparse.Namespace: action=argparse.BooleanOptionalAction, help="Toggle data propogation for onnx shape inference", ) + parser.add_argument( + "--clear-domain", + dest="clear_domain", + default=False, + action=argparse.BooleanOptionalAction, + help="If enabled, this will clear the domain attribute from each node" + " in the onnx graph before performing shape inference.", + ) parser.add_argument( "--keep-temps", action="store_true", help="Keep intermediate files" ) @@ -170,6 +189,12 @@ def parse_arguments(argv=None) -> argparse.Namespace: " Defaults to the directory of the input file.", type=Path, ) + parser.add_argument( + "--opset-version", + help="Allows specification of a newer opset_version to update the model" + " to before importing to MLIR. 
This can sometime assist with shape inference.", + type=int, + ) args = parser.parse_args(argv) return args diff --git a/pytorch-hash.txt b/pytorch-hash.txt index 3424cb46aad1..ef6ddf92e034 100644 --- a/pytorch-hash.txt +++ b/pytorch-hash.txt @@ -1 +1 @@ -1b7523fbe9d0a0c81930673f4374c6e69fa293b6 +b94ddab65bbb15cca98bca857b173bfc4abdb7b5 diff --git a/pytorch-requirements.txt b/pytorch-requirements.txt index 0ec9c7c7c856..c106bc5a30d8 100644 --- a/pytorch-requirements.txt +++ b/pytorch-requirements.txt @@ -4,4 +4,4 @@ # release page, and we use this page as an additional source for the wheels. -f https://xilinx.github.io/torch-mlir/package-index/ --pre -torch==2.4.0.dev20240505 +torch==2.4.0.dev20240604 diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir index a87ec4f8f43f..5b33fd17471b 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir @@ -748,6 +748,19 @@ func.func @test_dequantizelinear_si8(%arg0: !torch.vtensor<[6],si8>, %arg1: !tor // ----- +// CHECK-LABEL: @test_dequantizelinear_si16 +func.func @test_dequantizelinear_si16(%arg0: !torch.vtensor<[6],si16>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],si16>) -> !torch.vtensor<[6],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64} { + %0 = torch.operator "onnx.DequantizeLinear"(%arg0, %arg1, %arg2) : (!torch.vtensor<[6],si16>, !torch.vtensor<[],f32>, !torch.vtensor<[],si16>) -> !torch.vtensor<[6],f32> + // CHECK: %[[SCALE:.+]] = torch.aten.item %arg1 : !torch.vtensor<[],f32> -> !torch.float + // CHECK: %[[ZP:.+]] = torch.aten.item %arg2 : !torch.vtensor<[],si16> -> !torch.int + // CHECK: %[[MAKE:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[SCALE]], %[[ZP]] + // CHECK: %[[DEQ:.+]] = torch.aten.dequantize.self %[[MAKE]] + // CHECK: return %[[DEQ]] + return %0 : !torch.vtensor<[6],f32> +} + +// ----- + // CHECK-LABEL: @test_dequantizelinear_ui8 func.func @test_dequantizelinear_ui8(%arg0: !torch.vtensor<[6],ui8>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],ui8>) -> !torch.vtensor<[6],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64} { %0 = torch.operator "onnx.DequantizeLinear"(%arg0, %arg1, %arg2) : (!torch.vtensor<[6],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>) -> !torch.vtensor<[6],f32> @@ -946,12 +959,12 @@ func.func @test_averagepool_with_padding(%arg0: !torch.vtensor<[1,20,64,48],f32> func.func @test_conv_with_strides_no_padding(%arg0: !torch.vtensor<[1,1,7,5],f32>, %arg1: !torch.vtensor<[1,1,3,3],f32>) -> !torch.vtensor<[1,1,3,2],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[C0:.*]] = torch.constant.int 0 // CHECK: %[[C0_0:.*]] = torch.constant.int 0 + // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0_0]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[C1:.*]] = torch.constant.int 1 // CHECK: %[[C1_0:.*]] = torch.constant.int 1 // CHECK: %[[C2:.*]] = torch.constant.int 2 // CHECK: %[[C2_0:.*]] = torch.constant.int 2 // CHECK: %[[C0_1:.*]] = torch.constant.int 0 - // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0_0]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[DILATIONS:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1_0]] : (!torch.int, 
!torch.int) -> !torch.list // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C2]], %[[C2_0]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[OUTPUT_PADDING:.*]] = torch.prim.ListConstruct %[[C0_1]], %[[C0_1]] : (!torch.int, !torch.int) -> !torch.list @@ -969,12 +982,12 @@ func.func @test_conv_with_strides_no_padding(%arg0: !torch.vtensor<[1,1,7,5],f32 func.func @test_conv_with_strides_padding(%arg0: !torch.vtensor<[1,1,7,5],f32>, %arg1: !torch.vtensor<[1,1,3,3],f32>) -> !torch.vtensor<[1,1,4,3],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[C1:.*]] = torch.constant.int 1 // CHECK: %[[C1_0:.*]] = torch.constant.int 1 + // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1_0]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[C1_1:.*]] = torch.constant.int 1 // CHECK: %[[C1_2:.*]] = torch.constant.int 1 // CHECK: %[[C2:.*]] = torch.constant.int 2 // CHECK: %[[C2_0:.*]] = torch.constant.int 2 // CHECK: %[[C0:.*]] = torch.constant.int 0 - // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1_0]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[DILATIONS:.*]] = torch.prim.ListConstruct %[[C1_1]], %[[C1_2]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C2]], %[[C2_0]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[OUTPUT_PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0]] : (!torch.int, !torch.int) -> !torch.list @@ -992,12 +1005,12 @@ func.func @test_conv_with_strides_padding(%arg0: !torch.vtensor<[1,1,7,5],f32>, func.func @test_conv_with_bias_strides_padding(%arg0: !torch.vtensor<[?,?,224,224],f32>, %arg1: !torch.vtensor<[64,3,7,7],f32>, %arg2: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,64,112,112],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[C3:.*]] = torch.constant.int 3 // CHECK: %[[C3_0:.*]] = torch.constant.int 3 + // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C3]], %[[C3_0]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[C1:.*]] = torch.constant.int 1 // CHECK: %[[C1_0:.*]] = torch.constant.int 1 // CHECK: %[[C2:.*]] = torch.constant.int 2 // CHECK: %[[C2_0:.*]] = torch.constant.int 2 // CHECK: %[[C0:.*]] = torch.constant.int 0 - // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C3]], %[[C3_0]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[DILATIONS:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1_0]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C2]], %[[C2_0]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[OUTPUT_PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0]] : (!torch.int, !torch.int) -> !torch.list @@ -1649,8 +1662,8 @@ func.func @ints_constant() -> !torch.vtensor<[2], si64> attributes {torch.onnx_m // ----- -// CHECK-LABEL: @dense_constant -func.func @dense_constant() -> () attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64} { +// CHECK-LABEL: @dense_resource_constant +func.func @dense_resource_constant() -> () attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64} { // CHECK: torch.vtensor.literal(dense<[0, 10, 128, 17000]> : tensor<4xsi32>) : !torch.vtensor<[4],si32> %0 = 
torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_int32> : tensor<4xsi32>} : () -> !torch.vtensor<[4],si32> // CHECK: torch.vtensor.literal(dense<[0.000000e+00, 1.000000e+01, 1.280000e+02, 1.700000e+04]> : tensor<4xf32>) : !torch.vtensor<[4],f32> diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir index 865648c40d4f..227eac7d9665 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir @@ -1087,3 +1087,42 @@ func.func @test_lpnormalization(%arg0: !torch.vtensor<[3,4,5,6,7],f32>) -> !torc %0 = torch.operator "onnx.LpNormalization"(%arg0) {torch.onnx.axis = 2 : si64, torch.onnx.p = 2 : si64} : (!torch.vtensor<[3,4,5,6,7],f32>) -> !torch.vtensor<[3,4,1,6,7],f32> return %0 : !torch.vtensor<[3,4,1,6,7],f32> } + +// ----- + +// CHECK-LABEL: func.func @test_maxunpool_export_without_output_shape +func.func @test_maxunpool_export_without_output_shape(%arg0: !torch.vtensor<[1,1,2,2],f32>, %arg1: !torch.vtensor<[1,1,2,2],si64>) -> !torch.vtensor<[1,1,4,4],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[INT1_0:.*]] = torch.constant.int 1 + // CHECK: %[[INT4:.*]] = torch.constant.int 4 + // CHECK: %[[INT4_0:.*]] = torch.constant.int 4 + // CHECK: %[[OUTPUT_SHAPE:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT1_0]], %[[INT4]], %[[INT4_0]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[RESULT:.*]] = torch.aten.max_unpool2d %arg0, %arg1, %[[OUTPUT_SHAPE]] : !torch.vtensor<[1,1,2,2],f32>, !torch.vtensor<[1,1,2,2],si64>, !torch.list -> !torch.vtensor<[1,1,4,4],f32> + // return %[[RESULT]] : !torch.vtensor<[1,1,4,4],f32> + %0 = torch.operator "onnx.MaxUnpool"(%arg0, %arg1) {torch.onnx.kernel_shape = [2 : si64, 2 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,1,2,2],f32>, !torch.vtensor<[1,1,2,2],si64>) -> !torch.vtensor<[1,1,4,4],f32> + return %0 : !torch.vtensor<[1,1,4,4],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_maxunpool3d_export_without_output_shape +func.func @test_maxunpool3d_export_without_output_shape(%arg0: !torch.vtensor<[1,1,2,2,2],f32>, %arg1: !torch.vtensor<[1,1,2,2,2],si64>) -> !torch.vtensor<[1,1,4,4,4],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[INT1_0:.*]] = torch.constant.int 1 + // CHECK: %[[INT4:.*]] = torch.constant.int 4 + // CHECK: %[[INT4_0:.*]] = torch.constant.int 4 + // CHECK: %[[INT4_1:.*]] = torch.constant.int 4 + // CHECK: %[[OUTPUT_SHAPE:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT1_0]], %[[INT4]], %[[INT4_0]], %[[INT4_1]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[INT0_1:.*]] = torch.constant.int 0 + // CHECK: %[[INT0_2:.*]] = torch.constant.int 0 + // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[INT0]], %[[INT0_1]], %[[INT0_2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[INT2:.*]] = torch.constant.int 2 + // CHECK: %[[INT2_1:.*]] = torch.constant.int 2 + // CHECK: %[[INT2_2:.*]] = torch.constant.int 2 
+ // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT2_1]], %[[INT2_2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[RESULT:.*]] = torch.aten.max_unpool3d %arg0, %arg1, %[[OUTPUT_SHAPE]], %[[STRIDE]], %[[PADDING]] : !torch.vtensor<[1,1,2,2,2],f32>, !torch.vtensor<[1,1,2,2,2],si64>, !torch.list, !torch.list, !torch.list -> !torch.vtensor<[1,1,4,4,4],f32> + // return %[[RESULT]] : !torch.vtensor<[1,1,4,4,4],f32> + %0 = torch.operator "onnx.MaxUnpool"(%arg0, %arg1) {torch.onnx.kernel_shape = [2 : si64, 2 : si64, 2 : si64], torch.onnx.strides = [2 : si64, 2 : si64, 2 : si64]} : (!torch.vtensor<[1,1,2,2,2],f32>, !torch.vtensor<[1,1,2,2,2],si64>) -> !torch.vtensor<[1,1,4,4,4],f32> + return %0 : !torch.vtensor<[1,1,4,4,4],f32> +} diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index b9fda48f9d17..79958da59c77 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -60,12 +60,12 @@ func.func @test_qlinearconv_nobias(%arg0: !torch.vtensor<[1,1,7,7],ui8>, %arg1: // CHECK: %[[B:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg3, %[[bScale]], %[[bZp]] : !torch.vtensor<[1,1,1,1],ui8>, !torch.float, !torch.int -> !torch.vtensor<[1,1,1,1],!torch.quint8> // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 // CHECK: %[[INT0_1:.+]] = torch.constant.int 0 + // CHECK: %[[PAD:.+]] = torch.prim.ListConstruct %[[INT0_0]], %[[INT0_1]] // CHECK: %[[INT1_0:.+]] = torch.constant.int 1 // CHECK: %[[INT1_1:.+]] = torch.constant.int 1 // CHECK: %[[INT1_2:.+]] = torch.constant.int 1 // CHECK: %[[INT1_3:.+]] = torch.constant.int 1 // CHECK: %[[INT0_2:.+]] = torch.constant.int 0 - // CHECK: %[[PAD:.+]] = torch.prim.ListConstruct %[[INT0_0]], %[[INT0_1]] // CHECK: %[[KERNEL:.+]] = torch.prim.ListConstruct %[[INT1_0]], %[[INT1_1]] // CHECK: %[[DILATION:.+]] = torch.prim.ListConstruct %[[INT1_2]], %[[INT1_3]] // CHECK: %[[STRIDE:.+]] = torch.prim.ListConstruct %[[INT0_2]], %[[INT0_2]] @@ -99,12 +99,12 @@ func.func @test_qlinearconv_bias(%arg0: !torch.vtensor<[1,1,7,7],ui8>, %arg1: !t // CHECK: %[[B:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg3, %[[bScale]], %[[bZp]] : !torch.vtensor<[1,1,1,1],ui8>, !torch.float, !torch.int -> !torch.vtensor<[1,1,1,1],!torch.quint8> // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 // CHECK: %[[INT0_1:.+]] = torch.constant.int 0 + // CHECK: %[[PAD:.+]] = torch.prim.ListConstruct %[[INT0_0]], %[[INT0_1]] // CHECK: %[[INT1_0:.+]] = torch.constant.int 1 // CHECK: %[[INT1_1:.+]] = torch.constant.int 1 // CHECK: %[[INT1_2:.+]] = torch.constant.int 1 // CHECK: %[[INT1_3:.+]] = torch.constant.int 1 // CHECK: %[[INT0_2:.+]] = torch.constant.int 0 - // CHECK: %[[PAD:.+]] = torch.prim.ListConstruct %[[INT0_0]], %[[INT0_1]] // CHECK: %[[KERNEL:.+]] = torch.prim.ListConstruct %[[INT1_0]], %[[INT1_1]] // CHECK: %[[DILATION:.+]] = torch.prim.ListConstruct %[[INT1_2]], %[[INT1_3]] // CHECK: %[[STRIDE:.+]] = torch.prim.ListConstruct %[[INT0_2]], %[[INT0_2]] @@ -1085,15 +1085,17 @@ func.func @test_reduce_sum_empty_set_non_reduced_axis_zero(%arg0: !torch.vtensor // ----- // CHECK-LABEL: func.func @test_reduce_sum_keepdims_example -func.func @test_reduce_sum_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = 
"backend-test", torch.onnx_meta.producer_version = ""} { +func.func @test_reduce_sum_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[VAL_1:.*]] = torch.vtensor.literal(dense<2> : tensor<1xsi64>) : !torch.vtensor<[1],si64> // CHECK: %[[INT0:.+]] = torch.constant.int 0 // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 - // CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[SELECT:.+]] = torch.aten.select.int %[[VAL_1]], %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> // CHECK: %[[DIM:.+]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %1 : (!torch.int) -> !torch.list + // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[DIM]] : (!torch.int) -> !torch.list // CHECK: %[[TRUE:.+]] = torch.constant.bool true // CHECK: %[[NONE:.+]] = torch.constant.none // CHECK: torch.aten.sum.dim_IntList %arg0, %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,1,2],f32> + %arg1 = torch.vtensor.literal(dense<2> : tensor<1xsi64>) : !torch.vtensor<[1],si64> %0 = torch.operator "onnx.ReduceSum"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> return %0 : !torch.vtensor<[3,1,2],f32> } @@ -2215,6 +2217,19 @@ func.func @test_sce_mean_3d_log_prob(%arg0: !torch.vtensor<[3,5,2],f32>, %arg1: // ----- +// CHECK-LABEL: func.func @test_resize_sizes_nearest +func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + %none = torch.constant.none + // CHECK: %[[STR:.+]] = torch.constant.str "nearest_half_pixel,round_prefer_floor" + // CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %[[STR]], %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32> + %0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) { + torch.onnx.coordinate_transformation_mode = "half_pixel", + torch.onnx.mode = "nearest"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> + return %0 : !torch.vtensor<[?,?,?,?],f32> +} + +// ----- + // CHECK-LABEL: func.func @test_resize_sizes_linear func.func @test_resize_sizes_linear(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?], f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { @@ -2223,3 +2238,340 @@ f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_ve %0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {torch.onnx.mode = "linear"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, 
!torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> return %0 : !torch.vtensor<[?,?,?,?],f32> } + +// ----- + +// CHECK-LABEL: @test_spacetodepth_example +func.func @test_spacetodepth_example(%arg0: !torch.vtensor<[1,1,4,6],f32>) -> !torch.vtensor<[1,4,2,3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[B:.*]] = torch.aten.size.int %arg0, %[[C0]] : !torch.vtensor<[1,1,4,6],f32>, !torch.int -> !torch.int + // CHECK: %[[C1:.*]] = torch.constant.int 1 + // CHECK: %[[C:.*]] = torch.aten.size.int %arg0, %[[C1]] : !torch.vtensor<[1,1,4,6],f32>, !torch.int -> !torch.int + // CHECK: %[[C2:.*]] = torch.constant.int 2 + // CHECK: %[[H:.*]] = torch.aten.size.int %arg0, %[[C2]] : !torch.vtensor<[1,1,4,6],f32>, !torch.int -> !torch.int + // CHECK: %[[C3:.*]] = torch.constant.int 3 + // CHECK: %[[W:.*]] = torch.aten.size.int %arg0, %[[C3]] : !torch.vtensor<[1,1,4,6],f32>, !torch.int -> !torch.int + // CHECK: %[[C2_0:.*]] = torch.constant.int 2 + // CHECK: %[[C4:.*]] = torch.constant.int 4 + // CHECK: %[[H_DIV_BS:.*]] = torch.aten.div.int %[[H]], %[[C2_0]] : !torch.int, !torch.int -> !torch.float + // CHECK: %[[W_DIV_BS:.*]] = torch.aten.div.int %[[W]], %[[C2_0]] : !torch.int, !torch.int -> !torch.float + // CHECK: %[[H_DIV_BS_INT:.*]] = torch.aten.Int.float %[[H_DIV_BS]] : !torch.float -> !torch.int + // CHECK: %[[W_DIV_BS_INT:.*]] = torch.aten.Int.float %[[W_DIV_BS]] : !torch.float -> !torch.int + // CHECK: %[[RESHAPE_LIST:.*]] = torch.prim.ListConstruct %[[B]], %[[C]], %[[H_DIV_BS_INT]], %[[C2_0]], %[[W_DIV_BS_INT]], %[[C2_0]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[RESHAPE:.*]] = torch.aten.reshape %arg0, %[[RESHAPE_LIST]] : !torch.vtensor<[1,1,4,6],f32>, !torch.list -> !torch.vtensor<[1,1,2,2,3,2],f32> + // CHECK: %[[C0_0:.*]] = torch.constant.int 0 + // CHECK: %[[C3_0:.*]] = torch.constant.int 3 + // CHECK: %[[C5:.*]] = torch.constant.int 5 + // CHECK: %[[C1_0:.*]] = torch.constant.int 1 + // CHECK: %[[C2_1:.*]] = torch.constant.int 2 + // CHECK: %[[C4_0:.*]] = torch.constant.int 4 + // CHECK: %[[PERMUTE_DIMS:.*]] = torch.prim.ListConstruct %[[C0_0]], %[[C3_0]], %[[C5]], %[[C1_0]], %[[C2_1]], %[[C4_0]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[PERMUTE:.*]] = torch.aten.permute %[[RESHAPE]], %[[PERMUTE_DIMS]] : !torch.vtensor<[1,1,2,2,3,2],f32>, !torch.list -> !torch.vtensor<[1,2,2,1,2,3],f32> + // CHECK: %[[MUL:.*]] = torch.aten.mul.int %[[C]], %[[C4]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[RESHAPE_LIST_0:.*]] = torch.prim.ListConstruct %[[B]], %[[MUL]], %[[H_DIV_BS_INT]], %[[W_DIV_BS_INT]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[RESULT:.*]] = torch.aten.reshape %[[PERMUTE]], %[[RESHAPE_LIST_0]] : !torch.vtensor<[1,2,2,1,2,3],f32>, !torch.list -> !torch.vtensor<[1,4,2,3],f32> + // CHECK: return %[[RESULT]] : !torch.vtensor<[1,4,2,3],f32 + %0 = torch.operator "onnx.SpaceToDepth"(%arg0) {torch.onnx.blocksize = 2 : si64} : (!torch.vtensor<[1,1,4,6],f32>) -> !torch.vtensor<[1,4,2,3],f32> + return %0 : !torch.vtensor<[1,4,2,3],f32> +} + +// ----- + +// CHECK-LABEL: @test_spacetodepth +func.func @test_spacetodepth(%arg0: !torch.vtensor<[2,2,6,6],f32>) -> !torch.vtensor<[2,8,3,3],f32> attributes 
{torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[B:.*]] = torch.aten.size.int %arg0, %[[C0]] : !torch.vtensor<[2,2,6,6],f32>, !torch.int -> !torch.int + // CHECK: %[[C1:.*]] = torch.constant.int 1 + // CHECK: %[[C:.*]] = torch.aten.size.int %arg0, %[[C1]] : !torch.vtensor<[2,2,6,6],f32>, !torch.int -> !torch.int + // CHECK: %[[C2:.*]] = torch.constant.int 2 + // CHECK: %[[H:.*]] = torch.aten.size.int %arg0, %[[C2]] : !torch.vtensor<[2,2,6,6],f32>, !torch.int -> !torch.int + // CHECK: %[[C3:.*]] = torch.constant.int 3 + // CHECK: %[[W:.*]] = torch.aten.size.int %arg0, %[[C3]] : !torch.vtensor<[2,2,6,6],f32>, !torch.int -> !torch.int + // CHECK: %[[C2_0:.*]] = torch.constant.int 2 + // CHECK: %[[C4:.*]] = torch.constant.int 4 + // CHECK: %[[H_DIV_BS:.*]] = torch.aten.div.int %[[H]], %[[C2_0]] : !torch.int, !torch.int -> !torch.float + // CHECK: %[[W_DIV_BS:.*]] = torch.aten.div.int %[[W]], %[[C2_0]] : !torch.int, !torch.int -> !torch.float + // CHECK: %[[H_DIV_BS_INT:.*]] = torch.aten.Int.float %[[H_DIV_BS]] : !torch.float -> !torch.int + // CHECK: %[[W_DIV_BS_INT:.*]] = torch.aten.Int.float %[[W_DIV_BS]] : !torch.float -> !torch.int + // CHECK: %[[RESHAPE_LIST:.*]] = torch.prim.ListConstruct %[[B]], %[[C]], %[[H_DIV_BS_INT]], %[[C2_0]], %[[W_DIV_BS_INT]], %[[C2_0]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[RESHAPE:.*]] = torch.aten.reshape %arg0, %[[RESHAPE_LIST]] : !torch.vtensor<[2,2,6,6],f32>, !torch.list -> !torch.vtensor<[2,2,3,2,3,2],f32> + // CHECK: %[[C0_0:.*]] = torch.constant.int 0 + // CHECK: %[[C3_0:.*]] = torch.constant.int 3 + // CHECK: %[[C5:.*]] = torch.constant.int 5 + // CHECK: %[[C1_0:.*]] = torch.constant.int 1 + // CHECK: %[[C2_1:.*]] = torch.constant.int 2 + // CHECK: %[[C4_0:.*]] = torch.constant.int 4 + // CHECK: %[[PERMUTE_DIMS:.*]] = torch.prim.ListConstruct %[[C0_0]], %[[C3_0]], %[[C5]], %[[C1_0]], %[[C2_1]], %[[C4_0]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[PERMUTE:.*]] = torch.aten.permute %[[RESHAPE]], %[[PERMUTE_DIMS]] : !torch.vtensor<[2,2,3,2,3,2],f32>, !torch.list -> !torch.vtensor<[2,2,2,2,3,3],f32> + // CHECK: %[[MUL:.*]] = torch.aten.mul.int %[[C]], %[[C4]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[RESHAPE_LIST_0:.*]] = torch.prim.ListConstruct %[[B]], %[[MUL]], %[[H_DIV_BS_INT]], %[[W_DIV_BS_INT]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[RESULT:.*]] = torch.aten.reshape %[[PERMUTE]], %[[RESHAPE_LIST_0]] : !torch.vtensor<[2,2,2,2,3,3],f32>, !torch.list -> !torch.vtensor<[2,8,3,3],f32> + // CHECK: return %[[RESULT]] : !torch.vtensor<[2,8,3,3],f32 + %0 = torch.operator "onnx.SpaceToDepth"(%arg0) {torch.onnx.blocksize = 2 : si64} : (!torch.vtensor<[2,2,6,6],f32>) -> !torch.vtensor<[2,8,3,3],f32> + return %0 : !torch.vtensor<[2,8,3,3],f32> +} + +// ----- + +// CHECK-LABEL: @test_spacetodepth +func.func @test_spacetodepth_dynamic_dims(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[B:.*]] = torch.aten.size.int %arg0, %[[C0]] : 
!torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[C1:.*]] = torch.constant.int 1 + // CHECK: %[[C:.*]] = torch.aten.size.int %arg0, %[[C1]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[C2:.*]] = torch.constant.int 2 + // CHECK: %[[H:.*]] = torch.aten.size.int %arg0, %[[C2]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[C3:.*]] = torch.constant.int 3 + // CHECK: %[[W:.*]] = torch.aten.size.int %arg0, %[[C3]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[C2_0:.*]] = torch.constant.int 2 + // CHECK: %[[C4:.*]] = torch.constant.int 4 + // CHECK: %[[H_DIV_BS:.*]] = torch.aten.div.int %[[H]], %[[C2_0]] : !torch.int, !torch.int -> !torch.float + // CHECK: %[[W_DIV_BS:.*]] = torch.aten.div.int %[[W]], %[[C2_0]] : !torch.int, !torch.int -> !torch.float + // CHECK: %[[H_DIV_BS_INT:.*]] = torch.aten.Int.float %[[H_DIV_BS]] : !torch.float -> !torch.int + // CHECK: %[[W_DIV_BS_INT:.*]] = torch.aten.Int.float %[[W_DIV_BS]] : !torch.float -> !torch.int + // CHECK: %[[RESHAPE_LIST:.*]] = torch.prim.ListConstruct %[[B]], %[[C]], %[[H_DIV_BS_INT]], %[[C2_0]], %[[W_DIV_BS_INT]], %[[C2_0]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[RESHAPE:.*]] = torch.aten.reshape %arg0, %[[RESHAPE_LIST]] : !torch.vtensor<[?,?,?,?],f32>, !torch.list -> !torch.vtensor<[?,?,?,2,?,2],f32> + // CHECK: %[[C0_0:.*]] = torch.constant.int 0 + // CHECK: %[[C3_0:.*]] = torch.constant.int 3 + // CHECK: %[[C5:.*]] = torch.constant.int 5 + // CHECK: %[[C1_0:.*]] = torch.constant.int 1 + // CHECK: %[[C2_1:.*]] = torch.constant.int 2 + // CHECK: %[[C4_0:.*]] = torch.constant.int 4 + // CHECK: %[[PERMUTE_DIMS:.*]] = torch.prim.ListConstruct %[[C0_0]], %[[C3_0]], %[[C5]], %[[C1_0]], %[[C2_1]], %[[C4_0]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[PERMUTE:.*]] = torch.aten.permute %[[RESHAPE]], %[[PERMUTE_DIMS]] : !torch.vtensor<[?,?,?,2,?,2],f32>, !torch.list -> !torch.vtensor<[?,2,2,?,?,?],f32> + // CHECK: %[[MUL:.*]] = torch.aten.mul.int %[[C]], %[[C4]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[RESHAPE_LIST_0:.*]] = torch.prim.ListConstruct %[[B]], %[[MUL]], %[[H_DIV_BS_INT]], %[[W_DIV_BS_INT]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[RESULT:.*]] = torch.aten.reshape %[[PERMUTE]], %[[RESHAPE_LIST_0]] : !torch.vtensor<[?,2,2,?,?,?],f32>, !torch.list -> !torch.vtensor<[?,?,?,?],f32> + // CHECK: return %[[RESULT]] : !torch.vtensor<[?,?,?,?],f32 + %0 = torch.operator "onnx.SpaceToDepth"(%arg0) {torch.onnx.blocksize = 2 : si64} : (!torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> + return %0 : !torch.vtensor<[?,?,?,?],f32> +} + +// ----- + +// CHECK-LABEL: func.func @Shrink +func.func @Shrink(%arg0: !torch.vtensor<[5],f32>) -> !torch.vtensor<[5],f32> attributes {torch.onnx_meta.ir_version = 5 : si64, torch.onnx_meta.opset_version = 10 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %float1.500000e00 = torch.constant.float 1.500000e+00 + // CHECK: %float1.500000e00_0 = torch.constant.float 1.500000e+00 + // CHECK: %float0.000000e00 = torch.constant.float 0.000000e+00 + // CHECK: %float1.000000e00 = torch.constant.float 1.000000e+00 + // CHECK: %float-1.500000e00 = torch.constant.float -1.500000e+00 + // CHECK: %0 = torch.aten.lt.Scalar %arg0, %float-1.500000e00 : !torch.vtensor<[5],f32>, 
!torch.float -> !torch.vtensor<[5],f32> + // CHECK: %1 = torch.aten.add.Scalar %arg0, %float1.500000e00_0, %float1.000000e00 : !torch.vtensor<[5],f32>, !torch.float, !torch.float -> !torch.vtensor<[5],f32> + // CHECK: %2 = torch.aten.sub.Scalar %arg0, %float1.500000e00_0, %float1.000000e00 : !torch.vtensor<[5],f32>, !torch.float, !torch.float -> !torch.vtensor<[5],f32> + // CHECK: %3 = torch.aten.gt.Scalar %arg0, %float1.500000e00 : !torch.vtensor<[5],f32>, !torch.float -> !torch.vtensor<[5],f32> + // CHECK: %4 = torch.aten.where.ScalarOther %3, %2, %float0.000000e00 : !torch.vtensor<[5],f32>, !torch.vtensor<[5],f32>, !torch.float -> !torch.vtensor<[5],f32> + // CHECK: %5 = torch.aten.where.self %0, %1, %4 : !torch.vtensor<[5],f32>, !torch.vtensor<[5],f32>, !torch.vtensor<[5],f32> -> !torch.vtensor<[5],f32> + // CHECK: return %5 : !torch.vtensor<[5],f32> + %0 = torch.operator "onnx.Shrink"(%arg0) {torch.onnx.bias = 1.500000e+00 : f32, torch.onnx.lambd = 1.500000e+00 : f32} : (!torch.vtensor<[5],f32>) -> !torch.vtensor<[5],f32> + return %0 : !torch.vtensor<[5],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_shrink_hard +func.func @test_shrink_hard(%arg0: !torch.vtensor<[5],f32>) -> !torch.vtensor<[5],f32> attributes {torch.onnx_meta.ir_version = 4 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %float1.500000e00 = torch.constant.float 1.500000e+00 + // CHECK: %float0.000000e00 = torch.constant.float 0.000000e+00 + // CHECK: %float0.000000e00_0 = torch.constant.float 0.000000e+00 + // CHECK: %float1.000000e00 = torch.constant.float 1.000000e+00 + // CHECK: %float-1.500000e00 = torch.constant.float -1.500000e+00 + // CHECK: %0 = torch.aten.lt.Scalar %arg0, %float-1.500000e00 : !torch.vtensor<[5],f32>, !torch.float -> !torch.vtensor<[5],f32> + // CHECK: %1 = torch.aten.add.Scalar %arg0, %float0.000000e00, %float1.000000e00 : !torch.vtensor<[5],f32>, !torch.float, !torch.float -> !torch.vtensor<[5],f32> + // CHECK: %2 = torch.aten.sub.Scalar %arg0, %float0.000000e00, %float1.000000e00 : !torch.vtensor<[5],f32>, !torch.float, !torch.float -> !torch.vtensor<[5],f32> + // CHECK: %3 = torch.aten.gt.Scalar %arg0, %float1.500000e00 : !torch.vtensor<[5],f32>, !torch.float -> !torch.vtensor<[5],f32> + // CHECK: %4 = torch.aten.where.ScalarOther %3, %2, %float0.000000e00_0 : !torch.vtensor<[5],f32>, !torch.vtensor<[5],f32>, !torch.float -> !torch.vtensor<[5],f32> + // CHECK: %5 = torch.aten.where.self %0, %1, %4 : !torch.vtensor<[5],f32>, !torch.vtensor<[5],f32>, !torch.vtensor<[5],f32> -> !torch.vtensor<[5],f32> + // CHECK: return %5 : !torch.vtensor<[5],f32> + %0 = torch.operator "onnx.Shrink"(%arg0) {torch.onnx.lambd = 1.500000e+00 : f32} : (!torch.vtensor<[5],f32>) -> !torch.vtensor<[5],f32> + return %0 : !torch.vtensor<[5],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_sequence_at +func.func @test_sequence_at(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[2,3,4],f32>, %arg2: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[VTENSOR:.*]] = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si64> + // CHECK: %[[VTENSOR_1:.*]] = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si64> + // CHECK: %[[SEQUENCE:.*]] = torch.prim.ListConstruct %arg0, 
%arg1, %arg2 : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>) -> !torch.list> + // CHECK: %[[LENGTH:.*]] = torch.aten.len.t %[[SEQUENCE]] : !torch.list> -> !torch.int + // CHECK: %[[NONE:.+]] = torch.constant.none + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[ITEM:.*]] = torch.aten.item %[[VTENSOR]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[CMP:.*]] = torch.aten.lt.int %[[ITEM]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[COND:.*]] = torch.aten.Int.bool %[[CMP]] : !torch.bool -> !torch.int + // CHECK: %[[OFFSET:.*]] = torch.aten.mul.int %[[COND]], %[[LENGTH]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[POSITION:.*]] = torch.aten.add.int %[[ITEM]], %[[OFFSET]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[NONE]], %[[POSITION]], %[[INT1]] : !torch.list>, !torch.none, !torch.int, !torch.int -> !torch.list> + // CHECK: %[[ADD:.*]] = torch.aten.add.int %[[POSITION]], %[[INT1]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE_1:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[ADD]], %[[LENGTH]], %[[INT1]] : !torch.list>, !torch.int, !torch.int, !torch.int -> !torch.list> + // CHECK: %[[CONCAT_LIST:.*]] = torch.aten.add.t %[[SLICE]], %[[SLICE_1]] : !torch.list>, !torch.list> -> !torch.list> + // CHECK: %[[ITEM_0:.*]] = torch.aten.item %[[VTENSOR_1]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[RESULT:.*]] = torch.aten.__getitem__.t %[[CONCAT_LIST]], %[[ITEM_0]] : !torch.list>, !torch.int -> !torch.vtensor<[2,3,4],f32> + // CHECK: return %[[RESULT]] : !torch.vtensor<[2,3,4],f32> + %0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<1> : tensor} : () -> !torch.vtensor<[],si64> + %1 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<1> : tensor} : () -> !torch.vtensor<[],si64> + %2 = torch.operator "onnx.SequenceConstruct"(%arg0, %arg1, %arg2) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>) -> !torch.list> + %3 = torch.operator "onnx.SequenceErase"(%2, %0) : (!torch.list>, !torch.vtensor<[],si64>) -> !torch.list> + %4 = torch.operator "onnx.SequenceAt"(%3, %1) : (!torch.list>, !torch.vtensor<[],si64>) -> !torch.vtensor<[2,3,4],f32> + return %4 : !torch.vtensor<[2,3,4],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_sequence_insert +func.func @test_sequence_insert(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[2,3,4],f32>, %arg2: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[VTENSOR:.*]] = torch.vtensor.literal(dense<-3> : tensor) : !torch.vtensor<[],si64> + // CHECK: %[[VTENSOR_1:.*]] = torch.vtensor.literal(dense<-1> : tensor) : !torch.vtensor<[],si64> + // CHECK: %[[VTENSOR_2:.*]] = torch.vtensor.literal(dense<-1> : tensor) : !torch.vtensor<[],si64> + // CHECK: %[[SEQUENCE:.*]] = torch.prim.ListConstruct %arg0, %arg1, %arg2 : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>) -> !torch.list> + // CHECK: %[[LENGTH:.*]] = torch.aten.len.t %[[SEQUENCE]] : !torch.list> -> !torch.int + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[ITEM:.*]] = torch.aten.item %[[VTENSOR]] : 
!torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[CMP:.*]] = torch.aten.lt.int %[[ITEM]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[COND:.*]] = torch.aten.Int.bool %[[CMP]] : !torch.bool -> !torch.int + // CHECK: %[[OFFSET:.*]] = torch.aten.mul.int %[[COND]], %[[LENGTH]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[POSITION:.*]] = torch.aten.add.int %[[ITEM]], %[[OFFSET]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[NONE]], %[[POSITION]], %[[INT1]] : !torch.list>, !torch.none, !torch.int, !torch.int -> !torch.list> + // CHECK: %[[ADD:.*]] = torch.aten.add.int %[[POSITION]], %[[INT1]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE_1:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[ADD]], %[[LENGTH]], %[[INT1]] : !torch.list>, !torch.int, !torch.int, !torch.int -> !torch.list> + // CHECK: %[[CONCAT_LIST:.*]] = torch.aten.add.t %[[SLICE]], %[[SLICE_1]] : !torch.list>, !torch.list> -> !torch.list> + // CHECK: %[[ITEM_0:.*]] = torch.aten.item %[[VTENSOR_1]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: torch.aten.insert.t %[[CONCAT_LIST]], %[[ITEM_0]], %arg0 : !torch.list>, !torch.int, !torch.vtensor<[2,3,4],f32> + // CHECK: %[[ITEM_1:.*]] = torch.aten.item %[[VTENSOR_2]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[RESULT:.*]] = torch.aten.__getitem__.t %[[CONCAT_LIST]], %[[ITEM_1]] : !torch.list>, !torch.int -> !torch.vtensor<[2,3,4],f32> + // CHECK: return %[[RESULT]] : !torch.vtensor<[2,3,4],f32> + %0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<-3> : tensor} : () -> !torch.vtensor<[],si64> + %1 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<-1> : tensor} : () -> !torch.vtensor<[],si64> + %2 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<-1> : tensor} : () -> !torch.vtensor<[],si64> + %3 = torch.operator "onnx.SequenceConstruct"(%arg0, %arg1, %arg2) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>) -> !torch.list> + %4 = torch.operator "onnx.SequenceErase"(%3, %0) : (!torch.list>, !torch.vtensor<[],si64>) -> !torch.list> + %5 = torch.operator "onnx.SequenceInsert"(%4, %arg0, %1) : (!torch.list>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[],si64>) -> !torch.list> + %6 = torch.operator "onnx.SequenceAt"(%5, %2) : (!torch.list>, !torch.vtensor<[],si64>) -> !torch.vtensor<[2,3,4],f32> + return %6 : !torch.vtensor<[2,3,4],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_sequence_erase_at_beginning +func.func @test_sequence_erase_at_beginning(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[2,3,4],f32>, %arg2: !torch.vtensor<[2,3,4],f32>) -> !torch.list> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[VTENSOR:.*]] = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> + // CHECK: %[[SEQUENCE:.*]] = torch.prim.ListConstruct %arg0, %arg1, %arg2 : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>) -> !torch.list> + // CHECK: %[[LENGTH:.*]] = torch.aten.len.t %[[SEQUENCE]] : !torch.list> -> !torch.int + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[ITEM:.*]] = torch.aten.item %[[VTENSOR]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // 
CHECK: %[[CMP:.*]] = torch.aten.lt.int %[[ITEM]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[COND:.*]] = torch.aten.Int.bool %[[CMP]] : !torch.bool -> !torch.int + // CHECK: %[[OFFSET:.*]] = torch.aten.mul.int %[[COND]], %[[LENGTH]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[POSITION:.*]] = torch.aten.add.int %[[ITEM]], %[[OFFSET]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[NONE]], %[[POSITION]], %[[INT1]] : !torch.list>, !torch.none, !torch.int, !torch.int -> !torch.list> + // CHECK: %[[ADD:.*]] = torch.aten.add.int %[[POSITION]], %[[INT1]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE_1:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[ADD]], %[[LENGTH]], %[[INT1]] : !torch.list>, !torch.int, !torch.int, !torch.int -> !torch.list> + // CHECK: %[[RESULT:.*]] = torch.aten.add.t %[[SLICE]], %[[SLICE_1]] : !torch.list>, !torch.list> -> !torch.list> + // CHECK: return %[[RESULT]] : !torch.list> + %0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor} : () -> !torch.vtensor<[],si64> + %3 = torch.operator "onnx.SequenceConstruct"(%arg0, %arg1, %arg2) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>) -> !torch.list> + %4 = torch.operator "onnx.SequenceErase"(%3, %0) : (!torch.list>, !torch.vtensor<[],si64>) -> !torch.list> + return %4 : !torch.list> +} + +// ----- + +// CHECK-LABEL: func.func @test_sequence_erase_at_end +func.func @test_sequence_erase_at_end(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[2,3,4],f32>, %arg2: !torch.vtensor<[2,3,4],f32>) -> !torch.list> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[VTENSOR:.*]] = torch.vtensor.literal(dense<2> : tensor) : !torch.vtensor<[],si64> + // CHECK: %[[SEQUENCE:.*]] = torch.prim.ListConstruct %arg0, %arg1, %arg2 : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>) -> !torch.list> + // CHECK: %[[LENGTH:.*]] = torch.aten.len.t %[[SEQUENCE]] : !torch.list> -> !torch.int + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[ITEM:.*]] = torch.aten.item %[[VTENSOR]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[CMP:.*]] = torch.aten.lt.int %[[ITEM]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[COND:.*]] = torch.aten.Int.bool %[[CMP]] : !torch.bool -> !torch.int + // CHECK: %[[OFFSET:.*]] = torch.aten.mul.int %[[COND]], %[[LENGTH]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[POSITION:.*]] = torch.aten.add.int %[[ITEM]], %[[OFFSET]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[NONE]], %[[POSITION]], %[[INT1]] : !torch.list>, !torch.none, !torch.int, !torch.int -> !torch.list> + // CHECK: %[[ADD:.*]] = torch.aten.add.int %[[POSITION]], %[[INT1]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE_1:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[ADD]], %[[LENGTH]], %[[INT1]] : !torch.list>, !torch.int, !torch.int, !torch.int -> !torch.list> + // CHECK: %[[RESULT:.*]] = torch.aten.add.t %[[SLICE]], %[[SLICE_1]] : !torch.list>, !torch.list> -> !torch.list> + // CHECK: return %[[RESULT]] : !torch.list> + %0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<2> : tensor} : () -> 
!torch.vtensor<[],si64> + %3 = torch.operator "onnx.SequenceConstruct"(%arg0, %arg1, %arg2) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>) -> !torch.list> + %4 = torch.operator "onnx.SequenceErase"(%3, %0) : (!torch.list>, !torch.vtensor<[],si64>) -> !torch.list> + return %4 : !torch.list> +} + +// ----- + +// CHECK-LABEL: func.func @test_sequence_erase_negative_idx +func.func @test_sequence_erase_negative_idx(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[2,3,4],f32>, %arg2: !torch.vtensor<[2,3,4],f32>) -> !torch.list> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[VTENSOR:.*]] = torch.vtensor.literal(dense<-2> : tensor) : !torch.vtensor<[],si64> + // CHECK: %[[SEQUENCE:.*]] = torch.prim.ListConstruct %arg0, %arg1, %arg2 : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>) -> !torch.list> + // CHECK: %[[LENGTH:.*]] = torch.aten.len.t %[[SEQUENCE]] : !torch.list> -> !torch.int + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[ITEM:.*]] = torch.aten.item %[[VTENSOR]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[CMP:.*]] = torch.aten.lt.int %[[ITEM]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[COND:.*]] = torch.aten.Int.bool %[[CMP]] : !torch.bool -> !torch.int + // CHECK: %[[OFFSET:.*]] = torch.aten.mul.int %[[COND]], %[[LENGTH]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[POSITION:.*]] = torch.aten.add.int %[[ITEM]], %[[OFFSET]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[NONE]], %[[POSITION]], %[[INT1]] : !torch.list>, !torch.none, !torch.int, !torch.int -> !torch.list> + // CHECK: %[[ADD:.*]] = torch.aten.add.int %[[POSITION]], %[[INT1]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE_1:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[ADD]], %[[LENGTH]], %[[INT1]] : !torch.list>, !torch.int, !torch.int, !torch.int -> !torch.list> + // CHECK: %[[RESULT:.*]] = torch.aten.add.t %[[SLICE]], %[[SLICE_1]] : !torch.list>, !torch.list> -> !torch.list> + // CHECK: return %[[RESULT]] : !torch.list> + %0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<-2> : tensor} : () -> !torch.vtensor<[],si64> + %3 = torch.operator "onnx.SequenceConstruct"(%arg0, %arg1, %arg2) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>) -> !torch.list> + %4 = torch.operator "onnx.SequenceErase"(%3, %0) : (!torch.list>, !torch.vtensor<[],si64>) -> !torch.list> + return %4 : !torch.list> +} + +// ----- + +// CHECK-LABEL: func.func @test_sequence_erase_empty +func.func @test_sequence_erase_empty() -> !torch.list> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[VTENSOR:.*]] = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> + // CHECK: %[[INT6:.*]] = torch.constant.int 6 + // CHECK: %[[SHAPE_LIST:.*]] = torch.prim.ListConstruct : () -> !torch.list + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[EMPTY_TENSOR:.*]] = torch.aten.empty.memory_format %[[SHAPE_LIST]], %[[INT6]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list, !torch.int, 
!torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32> + // CHECK: %[[SEQUENCE:.*]] = torch.prim.ListConstruct %[[EMPTY_TENSOR]] : (!torch.vtensor<[],f32>) -> !torch.list> + // CHECK: %[[LENGTH:.*]] = torch.aten.len.t %[[SEQUENCE]] : !torch.list> -> !torch.int + // CHECK: %[[NONE_0:.*]] = torch.constant.none + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[ITEM:.*]] = torch.aten.item %[[VTENSOR]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[CMP:.*]] = torch.aten.lt.int %[[ITEM]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[COND:.*]] = torch.aten.Int.bool %[[CMP]] : !torch.bool -> !torch.int + // CHECK: %[[OFFSET:.*]] = torch.aten.mul.int %[[COND]], %[[LENGTH]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[POSITION:.*]] = torch.aten.add.int %[[ITEM]], %[[OFFSET]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[NONE_0]], %[[POSITION]], %[[INT1]] : !torch.list>, !torch.none, !torch.int, !torch.int -> !torch.list> + // CHECK: %[[ADD:.*]] = torch.aten.add.int %[[POSITION]], %[[INT1]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE_1:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[ADD]], %[[LENGTH]], %[[INT1]] : !torch.list>, !torch.int, !torch.int, !torch.int -> !torch.list> + // CHECK: %[[RESULT:.*]] = torch.aten.add.t %[[SLICE]], %[[SLICE_1]] : !torch.list>, !torch.list> -> !torch.list> + // CHECK: return %[[RESULT]] : !torch.list> + %0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor} : () -> !torch.vtensor<[],si64> + %1 = torch.operator "onnx.SequenceEmpty"() : () -> !torch.list> + %4 = torch.operator "onnx.SequenceErase"(%1, %0) : (!torch.list>, !torch.vtensor<[],si64>) -> !torch.list> + return %4 : !torch.list> +} + +// ----- + +// CHECK-LABEL: func.func @test_sequence_empty +func.func @test_sequence_empty() -> !torch.list> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT6:.*]] = torch.constant.int 6 + // CHECK: %[[SHAPE_LIST:.*]] = torch.prim.ListConstruct : () -> !torch.list + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[EMPTY_TENSOR:.*]] = torch.aten.empty.memory_format %[[SHAPE_LIST]], %[[INT6]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32> + // CHECK: %[[RESULT:.*]] = torch.prim.ListConstruct %[[EMPTY_TENSOR]] : (!torch.vtensor<[],f32>) -> !torch.list> + // CHECK: return %[[RESULT]] : !torch.list> + %0 = torch.operator "onnx.SequenceEmpty"() : () -> !torch.list> + return %0 : !torch.list> +} diff --git a/test/Conversion/TorchToLinalg/resize.mlir b/test/Conversion/TorchToLinalg/resize.mlir index d9860aaa9258..64198d03f2a1 100644 --- a/test/Conversion/TorchToLinalg/resize.mlir +++ b/test/Conversion/TorchToLinalg/resize.mlir @@ -49,8 +49,8 @@ func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1 // CHECK: %[[x23:.*]] = arith.index_cast %[[x13]] : index to i64 // CHECK: %[[x24:.*]] = arith.sitofp %[[x23]] : i64 to f32 // CHECK: %[[x25:.*]] = arith.divf %[[x24]], %[[x21]] : f32 - // CHECK: %[[x29:.*]] = math.floor %[[x25]] : f32 - // CHECK: %[[x31:.*]] = arith.fptosi %[[x29]] : f32 to i64 + // CHECK: %[[x26:.*]] = math.floor %[[x25]] : f32 + // CHECK: %[[x31:.*]] = arith.fptosi %[[x26]] : 
f32 to i64 // CHECK: %[[x32:.*]] = arith.index_cast %[[x31]] : i64 to index // CHECK: %[[x16:.*]] = arith.sitofp %[[c4_i64:.*]] : i64 to f32 // CHECK: %[[x20:.*]] = arith.sitofp %[[x7:.*]] : i64 to f32 @@ -58,8 +58,8 @@ func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1 // CHECK: %[[x26:.*]] = arith.index_cast %[[x14]] : index to i64 // CHECK: %[[x27:.*]] = arith.sitofp %[[x26]] : i64 to f32 // CHECK: %[[x28:.*]] = arith.divf %[[x27]], %[[x22]] : f32 - // CHECK: %[[x30:.*]] = math.floor %[[x28]] : f32 - // CHECK: %[[x33:.*]] = arith.fptosi %[[x30]] : f32 to i64 + // CHECK: %[[x29:.*]] = math.floor %[[x28]] : f32 + // CHECK: %[[x33:.*]] = arith.fptosi %[[x29]] : f32 to i64 // CHECK: %[[x34:.*]] = arith.index_cast %[[x33]] : i64 to index // CHECK: %[[extracted:.*]] = tensor.extract %[[x0:.*]][%[[x11]], %[[x12]], %[[x32]], %[[x34]]] : tensor<1x1x2x4xf32> // CHECK: linalg.yield %[[extracted]] : f32 @@ -129,8 +129,8 @@ func.func @test_resize_nearest_3d(%arg0: !torch.vtensor<[?,?,?,?,?],f32>, %arg1: // CHECK: %[[x23:.*]] = arith.index_cast %[[x13]] : index to i64 // CHECK: %[[x24:.*]] = arith.sitofp %[[x23]] : i64 to f32 // CHECK: %[[x25:.*]] = arith.divf %[[x24]], %[[x21]] : f32 - // CHECK: %[[x29:.*]] = math.floor %[[x25]] : f32 - // CHECK: %[[x31:.*]] = arith.fptosi %[[x29]] : f32 to i64 + // CHECK: %[[floor:.*]] = math.floor %[[x25]] : f32 + // CHECK: %[[x31:.*]] = arith.fptosi %[[floor]] : f32 to i64 // CHECK: %[[x32:.*]] = arith.index_cast %[[x31]] : i64 to index // CHECK: %[[x34:.*]] = arith.index_cast %[[Wfptosi:.*]] : i64 to index // CHECK: %[[x35:.*]] = arith.index_cast %[[Dfptosi:.*]] : i64 to index diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir index 180b6aac5dd3..250f11cf67a1 100644 --- a/test/Dialect/Torch/canonicalize.mlir +++ b/test/Dialect/Torch/canonicalize.mlir @@ -3026,3 +3026,35 @@ func.func @torch.aten.clone$no_fold(%arg0: !torch.vtensor<[1,2,50,4],f32>) -> (! 
%1 = torch.copy.to_tensor %0 : !torch.tensor return %1 : !torch.tensor } + + +// ----- + +// CHECK-LABEL: @torch.symbolic_int$canonicalize( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?],f32>, +// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> { +// CHECK: %[[S0:.*]] = torch.symbolic_int "s0" {min_val = 3, max_val = 6} : !torch.int +// CHECK-NOT: %[[S1:.*]] = torch.symbolic_int "s0 + 1" {min_val = 4, max_val = 7} : !torch.int +// CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]]], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32> +// CHECK: torch.bind_symbolic_shape %[[ARG1]], [%[[S0]]], affine_map<()[s0] -> (s0 + 1)> : !torch.vtensor<[?],f32> +// CHECK: %[[V1:.*]] = torch.aten.slice.Tensor %[[ARG1]], {{.*}}, {{.*}}, {{.*}}, {{.*}} : !torch.vtensor<[?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?],f32> +// CHECK: torch.bind_symbolic_shape %[[V1]], [%[[S0]]], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32> +// CHECK: %[[V2:.*]] = torch.aten.add.Tensor %[[ARG0]], {{.*}}, {{.*}} : !torch.vtensor<[?],f32>, !torch.vtensor<[?],f32>, !torch.int -> !torch.vtensor<[?],f32> +// CHECK: torch.bind_symbolic_shape %[[V2]], [%[[S0]]], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32> +// CHECK: return %[[V2]] : !torch.vtensor<[?],f32> +func.func @torch.symbolic_int$canonicalize(%arg0: !torch.vtensor<[?],f32>, %arg1: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> { + %0 = torch.symbolic_int "s0" {min_val = 3, max_val = 6} : !torch.int + %1 = torch.symbolic_int "s0 + 1" {min_val = 4, max_val = 7} : !torch.int + torch.bind_symbolic_shape %arg0, [%0], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32> + torch.bind_symbolic_shape %arg1, [%0], affine_map<()[s0] -> (s0 + 1)> : !torch.vtensor<[?],f32> + %int0 = torch.constant.int 0 + %int1 = torch.constant.int 1 + %int9223372036854775807 = torch.constant.int 9223372036854775807 + %int1_0 = torch.constant.int 1 + %2 = torch.aten.slice.Tensor %arg1, %int0, %int1, %int9223372036854775807, %int1_0 : !torch.vtensor<[?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?],f32> + torch.bind_symbolic_shape %2, [%0], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32> + %int1_1 = torch.constant.int 1 + %3 = torch.aten.add.Tensor %arg0, %2, %int1_1 : !torch.vtensor<[?],f32>, !torch.vtensor<[?],f32>, !torch.int -> !torch.vtensor<[?],f32> + torch.bind_symbolic_shape %3, [%0], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32> + return %3 : !torch.vtensor<[?],f32> +} diff --git a/test/Dialect/Torch/invalid.mlir b/test/Dialect/Torch/invalid.mlir index 63aa1e3755a9..5b732788faef 100644 --- a/test/Dialect/Torch/invalid.mlir +++ b/test/Dialect/Torch/invalid.mlir @@ -375,3 +375,22 @@ func.func @foo(%arg0: !torch.vtensor<[64,64],f32,#SV>) -> !torch.vtensor<[64,64] // expected-error @+1 {{invalid sparsity encoding attribute}} func.func private @tensor.sparse() -> !torch.vtensor<[64,64],f32,12345> + + +// ----- + +func.func @torch.symbolic_int$no_shape_symbols(%arg0: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> { + %0 = torch.symbolic_int "s0" {min_val = 3, max_val = 6} : !torch.int + // expected-error @+1 {{op requires non-empty shapeSymbols}} + torch.bind_symbolic_shape %arg0, [], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32> + return %arg0 : !torch.vtensor<[?],f32> +} + +// ----- + +func.func @torch.symbolic_int$no_shape_symbols(%arg0: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> { + %int0 = torch.constant.int 0 + // expected-error @+1 {{shape symbol 
must be produced by a SymbolicIntOp}} + torch.bind_symbolic_shape %arg0, [%int0], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32> + return %arg0 : !torch.vtensor<[?],f32> +} diff --git a/test/Dialect/Torch/ops.mlir b/test/Dialect/Torch/ops.mlir index 8242321c3303..29ab52f9dab0 100644 --- a/test/Dialect/Torch/ops.mlir +++ b/test/Dialect/Torch/ops.mlir @@ -174,6 +174,7 @@ func.func @number_type_subtypes(%arg0: !torch.tensor, %arg1: !torch.list, % func.func private @tensor_legal_dtype$torch.qint8() -> !torch.tensor<*,!torch.qint8> func.func private @tensor_legal_dtype$torch.quint8() -> !torch.tensor<*,!torch.quint8> +func.func private @tensor_legal_dtype$torch.qint16() -> !torch.tensor<*,!torch.qint16> func.func @prim_list_construct$valid_shape_subtype(%arg0: !torch.vtensor<[1,53,56,96],f16>, %arg1: !torch.vtensor<[1,3,56,96],f16>) -> !torch.list> { %arg2 = "torch.prim.ListConstruct"(%arg0, %arg1) : (!torch.vtensor<[1,53,56,96],f16>, !torch.vtensor<[1,3,56,96],f16>) -> !torch.list> diff --git a/test/python/fx_importer/basic_test.py b/test/python/fx_importer/basic_test.py index fde318630077..fbc8fdff32f3 100644 --- a/test/python/fx_importer/basic_test.py +++ b/test/python/fx_importer/basic_test.py @@ -89,6 +89,11 @@ def forward(self, x): @run # CHECK-LABEL: test_import_frozen_exported_program_with_dynamic_shapes # CHECK: func.func @test_net(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[?,4],f32>) -> !torch.vtensor<[?,4],f32> +# CHECK: %[[S0:.*]] = torch.symbolic_int "s0" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int +# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]]], affine_map<()[s0] -> (s0, 4)> : !torch.vtensor<[?,4],f32> +# CHECK: %[[TANH:.*]] = torch.aten.tanh %[[ARG0]] : !torch.vtensor<[?,4],f32> -> !torch.vtensor<[?,4],f32> +# CHECK: torch.bind_symbolic_shape %[[TANH]], [%[[S0]]], affine_map<()[s0] -> (s0, 4)> : !torch.vtensor<[?,4],f32> +# CHECK: return %[[TANH]] : !torch.vtensor<[?,4],f32> def test_import_frozen_exported_program_with_dynamic_shapes(): class Basic(nn.Module): def __init__(self): @@ -100,7 +105,11 @@ def forward(self, x): batch = Dim("batch") dynamic_shapes = {"x": {0: batch}} m = fx.export_and_import( - Basic(), torch.randn(3, 4), dynamic_shapes=dynamic_shapes, func_name="test_net" + Basic(), + torch.randn(3, 4), + dynamic_shapes=dynamic_shapes, + func_name="test_net", + import_symbolic_shape_expressions=True, ) print(m) @@ -108,6 +117,12 @@ def forward(self, x): @run # CHECK-LABEL: test_broadcast_with_dynamic_shapes # CHECK: func.func @test_net(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[1,2],f32>, %[[ARG1:[a-zA-Z0-9]+]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,2],f32> +# CHECK: %[[S0:.*]] = torch.symbolic_int "s0" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int +# CHECK: torch.bind_symbolic_shape %[[ARG1]], [%[[S0]]], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32> +# CHECK: torch.aten.size.int +# CHECK: torch.prim.ListConstruct +# CHECK: %[[EXPAND:.*]] = torch.aten.expand +# CHECK: torch.bind_symbolic_shape %[[EXPAND]], [%[[S0]]], affine_map<()[s0] -> (s0, 2)> : !torch.vtensor<[?,2],f32> def test_broadcast_with_dynamic_shapes(): class Basic(nn.Module): def __init__(self): @@ -127,7 +142,12 @@ def forward(self, x, y): } m = fx.export_and_import( - Basic(), x, y, dynamic_shapes=dynamic_shapes, func_name="test_net" + Basic(), + x, + y, + dynamic_shapes=dynamic_shapes, + func_name="test_net", + import_symbolic_shape_expressions=True, ) print(m) diff --git a/test/python/fx_importer/custom_op_test.py 
b/test/python/fx_importer/custom_op_test.py new file mode 100644 index 000000000000..dbbc5ba057af --- /dev/null +++ b/test/python/fx_importer/custom_op_test.py @@ -0,0 +1,86 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +# RUN: %PYTHON %s | FileCheck %s + +import torch +import torch.nn as nn +from torch.export import Dim +from torch.library import Library, impl, impl_abstract + +from torch_mlir import fx + + +def run(f): + print(f"{f.__name__}") + print("-" * len(f.__name__)) + f() + print() + + +@run +# CHECK-LABEL: test_tanh_sigmoid_cat_custom_op +# CHECK: func.func @main( +# CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[?,?,3],f32>, +# CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: !torch.vtensor<[?,?,3],f32>, +# CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: !torch.vtensor<[?,?,3],f32>) -> !torch.vtensor<[?,?,3],f32> { +# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = 5, max_val = 10} : !torch.int +# CHECK: %[[S1:.+]] = torch.symbolic_int "s1" {min_val = {{[0-9]+}}, max_val = 100} : !torch.int +# CHECK: %[[S2:.+]] = torch.symbolic_int "s3" {min_val = {{[0-9]+}}, max_val = 50} : !torch.int +# CHECK: %[[S3:.+]] = torch.symbolic_int "s5" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int +# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]], %[[S1]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32> +# CHECK: torch.bind_symbolic_shape %[[ARG1]], [%[[S0]], %[[S2]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32> +# CHECK: torch.bind_symbolic_shape %[[ARG2]], [%[[S0]], %[[S3]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32> +# CHECK: %[[OP:.+]] = torch.operator "torch.my_custom_library.tanh_sigmoid_cat_op"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) : (!torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>) -> !torch.vtensor<[?,?,3],f32> +# CHECK: torch.bind_symbolic_shape %[[OP]], [%[[S0]], %[[S1]], %[[S2]], %[[S3]]], affine_map<()[s0, s1, s2, s3] -> (s0, s2 + s3 + s1 * 2, 3)> : !torch.vtensor<[?,?,3],f32> +# CHECK: return %[[OP]] : !torch.vtensor<[?,?,3],f32> +def test_tanh_sigmoid_cat_custom_op(): + + m = Library("my_custom_library", "DEF") + m.define("tanh_sigmoid_cat_op(Tensor x, Tensor y, Tensor z) -> Tensor") + + @impl(m, "tanh_sigmoid_cat_op", "CompositeExplicitAutograd") + def custom_op(x, y, z): + a = torch.tanh(x) + b = torch.sigmoid(y) + return torch.cat((a, a, b, z), dim=1) + + @impl_abstract("my_custom_library::tanh_sigmoid_cat_op") + def custom_op_meta(x, y, z): + result = custom_op(x, y, z) + return torch.empty_like(result) + + class TanhSigmoidCatCustomOp(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y, z): + return torch.ops.my_custom_library.tanh_sigmoid_cat_op(x, y, z) + + # Sample inputs + x = torch.randn(5, 2, 3) + y = torch.randn(5, 6, 3) + z = torch.randn(5, 4, 3) + + # Dynamic dim constraints + dim_n = Dim("n", min=5, max=10) + dim_x1 = Dim("x1", max=100) + dim_y1 = Dim("y1", max=50) + dim_z1 = Dim("z1") + dynamic_shapes = { + "x": {0: dim_n, 1: dim_x1}, + "y": {0: dim_n, 1: dim_y1}, + "z": {0: dim_n, 1: dim_z1}, + } + + m = fx.export_and_import( + TanhSigmoidCatCustomOp(), + x, + y, + z, + dynamic_shapes=dynamic_shapes, + import_symbolic_shape_expressions=True, + ) + print(m) diff --git a/test/python/fx_importer/sparse_test.py 
b/test/python/fx_importer/sparse_test.py index 0a1a91193750..7c7198ef6f61 100644 --- a/test/python/fx_importer/sparse_test.py +++ b/test/python/fx_importer/sparse_test.py @@ -125,7 +125,7 @@ def sparse_export( # Zero preserving elt-wise unary op. if opname in {"abs", "neg", "relu", "sin"}: node.meta["sparsity"] = node.args[0].meta.get("sparsity", None) - elif opname == "_to_sparse": + elif opname == "_to_sparse" or opname == "to_sparse": dim = len(node.meta.get("val").shape) node.meta["sparsity"] = SparsityMeta( torch.sparse_coo, 0, dim, 0, None, torch.int64, torch.int64 @@ -339,15 +339,14 @@ def forward(self, x, v): @run # -# CHECK-LABEL: test_sparse_SpMM # CHECK: #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton(soa)), posWidth = 64, crdWidth = 64 }> # CHECK: func.func @main( # CHECK-SAME: %[[A:.*0]]: !torch.vtensor<[8,8],f32,#[[$COO]]>, # CHECK-SAME: %[[B:.*1]]: !torch.vtensor<[8,8],f32>) -> !torch.vtensor<[8,8],f32> { -# CHECK: %[[R:.*]] = torch.aten.mm %[[A]], %[[B]] : !torch.vtensor<[8,8],f32,#[[$COO]]>, !torch.vtensor<[8,8],f32> -> !torch.vtensor<[8,8],f32> +# CHECK: %[[R:.*]] = torch.aten.{{matmul|mm}} %[[A]], %[[B]] : !torch.vtensor<[8,8],f32,#[[$COO]]>, !torch.vtensor<[8,8],f32> -> !torch.vtensor<[8,8],f32> # CHECK: return %[[R]] : !torch.vtensor<[8,8],f32> # CHECK: } -# +## # CHECK: torch.sparse # CHECK: tensor({{\[}}[8., 8., 8., 8., 8., 8., 8., 8.], # CHECK-COUNT-6: [8., 8., 8., 8., 8., 8., 8., 8.], @@ -516,7 +515,7 @@ def forward(self, x): # CHECK: %[[N1:.*]] = torch.constant.none # CHECK: %[[N2:.*]] = torch.constant.none # CHECK: %[[N3:.*]] = torch.constant.none -# CHECK: %[[R:.*]] = torch.operator "torch.aten._to_sparse"(%[[A]], %[[N1]], %[[N2]], %[[N3]]) : (!torch.vtensor<[2,2,2],f32>, !torch.none, !torch.none, !torch.none) -> !torch.vtensor<[2,2,2],f32,#[[$COO]]> +# CHECK: %[[R:.*]] = torch.operator "torch.aten.{{to_sparse|_to_sparse}}"(%[[A]], %[[N1]], %[[N2]], %[[N3]]) : (!torch.vtensor<[2,2,2],f32>, !torch.none, !torch.none, !torch.none) -> !torch.vtensor<[2,2,2],f32,#[[$COO]]> # CHECK: return %[[R]] : !torch.vtensor<[2,2,2],f32,#[[$COO]]> # CHECK: } # @@ -648,8 +647,8 @@ def forward(self, X): # CHECK: func.func @main( # CHECK-SAME: %[[A:.*]]: !torch.vtensor<[4,4],f32>) -> !torch.vtensor<[4,4],f32> { # ... more IR ... -# CHECK: %[[D:.*]] = torch.operator "torch.aten._to_sparse" -# CHECK: %[[R:.*]] = torch.aten.mm %[[D]], %[[A]] +# CHECK: %[[D:.*]] = torch.operator "torch.aten.{{to_sparse|_to_sparse}}" +# CHECK: %[[R:.*]] = torch.aten.{{matmul|mm}} %[[D]], %[[A]] # CHECK return %[[R]] : !torch.vtensor<[4,4],f32> # CHECK: } # diff --git a/test/python/fx_importer/symbolic_shape_expr_test.py b/test/python/fx_importer/symbolic_shape_expr_test.py new file mode 100644 index 000000000000..3215e0f8213d --- /dev/null +++ b/test/python/fx_importer/symbolic_shape_expr_test.py @@ -0,0 +1,463 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +# RUN: %PYTHON %s | FileCheck %s +# This file contains tests of various op special forms that the fx_importer +# handles. 
+ +import torch +import torch.export +import torch.nn as nn +from torch.export import Dim + +from torch_mlir import fx + + +def run(f): + print(f"{f.__name__}") + print("-" * len(f.__name__)) + f() + print() + + +@run +# CHECK-LABEL: test_tanh_sigmoid_cat +# CHECK: func.func @main( +# CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[?,?,3],f32>, +# CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: !torch.vtensor<[?,?,3],f32>, +# CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: !torch.vtensor<[?,?,3],f32>) -> !torch.vtensor<[?,?,3],f32> { +# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = 5, max_val = 10} : !torch.int +# CHECK: %[[S1:.+]] = torch.symbolic_int "s1" {min_val = {{[0-9]+}}, max_val = 100} : !torch.int +# CHECK: %[[S2:.+]] = torch.symbolic_int "s3" {min_val = {{[0-9]+}}, max_val = 50} : !torch.int +# CHECK: %[[S3:.+]] = torch.symbolic_int "s5" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int +# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]], %[[S1]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32> +# CHECK: torch.bind_symbolic_shape %[[ARG1]], [%[[S0]], %[[S2]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32> +# CHECK: torch.bind_symbolic_shape %[[ARG2]], [%[[S0]], %[[S3]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32> +# CHECK: %[[TANH:.+]] = torch.aten.tanh %[[ARG0]] : !torch.vtensor<[?,?,3],f32> -> !torch.vtensor<[?,?,3],f32> +# CHECK: torch.bind_symbolic_shape %[[TANH]], [%[[S0]], %[[S1]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32> +# CHECK: %[[SIG:.+]] = torch.aten.sigmoid %[[ARG1]] : !torch.vtensor<[?,?,3],f32> -> !torch.vtensor<[?,?,3],f32> +# CHECK: torch.bind_symbolic_shape %[[SIG]], [%[[S0]], %[[S2]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32> +# CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[TANH]], %[[TANH]], %[[SIG]], %[[ARG2]] : (!torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>) -> !torch.list +# CHECK: %[[CAT:.+]] = torch.aten.cat %[[LIST]], {{.*}} : !torch.list, !torch.int -> !torch.vtensor<[?,?,3],f32> +# CHECK: torch.bind_symbolic_shape %[[CAT]], [%[[S0]], %[[S1]], %[[S2]], %[[S3]]], affine_map<()[s0, s1, s2, s3] -> (s0, s2 + s3 + s1 * 2, 3)> : !torch.vtensor<[?,?,3],f32> +# CHECK: return %[[CAT]] : !torch.vtensor<[?,?,3],f32> +def test_tanh_sigmoid_cat(): + class TanhSigmoidCat(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y, z): + a = torch.tanh(x) + b = torch.sigmoid(y) + return torch.cat((a, a, b, z), dim=1) + + # Sample inputs + x = torch.randn(5, 2, 3) + y = torch.randn(5, 6, 3) + z = torch.randn(5, 4, 3) + + # Dynamic dim constraints + dim_n = Dim("n", min=5, max=10) + dim_x1 = Dim("x1", max=100) + dim_y1 = Dim("y1", max=50) + dim_z1 = Dim("z1") + dynamic_shapes = { + "x": {0: dim_n, 1: dim_x1}, + "y": {0: dim_n, 1: dim_y1}, + "z": {0: dim_n, 1: dim_z1}, + } + + m = fx.export_and_import( + TanhSigmoidCat(), + x, + y, + z, + dynamic_shapes=dynamic_shapes, + import_symbolic_shape_expressions=True, + ) + print(m) + + +@run +# CHECK-LABEL: test_symbolic_dim_differ_by_one +# CHECK: func.func @main(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[?],f32>, %[[ARG1:[a-zA-Z0-9]+]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> attributes {torch.assume_strict_symbolic_shapes} { +# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = 3, max_val = 6} : !torch.int +# This appears in torch-nightly, but not in torch-stable (re-enable once we've moved 
torch-stable to 2.4+) +# CHECK-DISABLED: %[[S1:.+]] = torch.symbolic_int "s0 + 1" {min_val = 4, max_val = 7} : !torch.int +# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]]], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32> +# CHECK: torch.bind_symbolic_shape %[[ARG1]], [%[[S0]]], affine_map<()[s0] -> (s0 + 1)> : !torch.vtensor<[?],f32> +# CHECK: %[[SLICE:.+]] = torch.aten.slice.Tensor %arg1, {{.*}}, {{.*}}, {{.*}}, {{.*}} : !torch.vtensor<[?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?],f32> +# CHECK: torch.bind_symbolic_shape %[[SLICE]], [%[[S0]]], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32> +# CHECK: %[[ADD:.+]] = torch.aten.add.Tensor %[[ARG0]], %[[SLICE]], {{.*}} : !torch.vtensor<[?],f32>, !torch.vtensor<[?],f32>, !torch.int -> !torch.vtensor<[?],f32> +# CHECK: torch.bind_symbolic_shape %[[ADD]], [%[[S0]]], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32> +# CHECK: return %[[ADD]] : !torch.vtensor<[?],f32> +def test_symbolic_dim_differ_by_one(): + class SymbolicDimDifferByOne(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + return x + y[1:] + + # Sample inputs + x = torch.randn(5) + y = torch.randn(6) + + # Dynamic dim constraints + dimx = Dim("dimx", min=3, max=6) + dimy = dimx + 1 + dynamic_shapes = { + "x": {0: dimx}, + "y": {0: dimy}, + } + + m = fx.export_and_import( + SymbolicDimDifferByOne(), + x, + y, + dynamic_shapes=dynamic_shapes, + experimental_support_mutation=True, + import_symbolic_shape_expressions=True, + ) + print(m) + + +@run +# CHECK-LABEL: test_outer_with_squared_shape +# CHECK: func.func @main(%[[ARG0:.+]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> { +# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int +# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]]], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32> +# CHECK: %[[VIEW1:.+]] = torch.aten.view %[[ARG0]], {{.*}} : !torch.vtensor<[?],f32>, !torch.list -> !torch.vtensor<[?,1],f32> +# CHECK: torch.bind_symbolic_shape %[[VIEW1]], [%[[S0]]], affine_map<()[s0] -> (s0, 1)> : !torch.vtensor<[?,1],f32> +# CHECK: %[[MUL:.+]] = torch.aten.mul.Tensor %[[VIEW1]], %[[ARG0]] : !torch.vtensor<[?,1],f32>, !torch.vtensor<[?],f32> -> !torch.vtensor<[?,?],f32> +# CHECK: torch.bind_symbolic_shape %[[MUL]], [%[[S0]]], affine_map<()[s0] -> (s0, s0)> : !torch.vtensor<[?,?],f32> +# CHECK: %[[VIEW2:.+]] = torch.aten.view %[[MUL]], {{.*}} : !torch.vtensor<[?,?],f32>, !torch.list -> !torch.vtensor<[?],f32> +# CHECK: torch.bind_symbolic_shape %[[VIEW2]], [%[[S0]]], affine_map<()[s0] -> (s0 * s0)> : !torch.vtensor<[?],f32> +# CHECK: return %[[VIEW2]] : !torch.vtensor<[?],f32> +def test_outer_with_squared_shape(): + class OuterWithSquaredShape(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.outer(x, x).flatten() + + # Sample inputs + x = torch.rand(10) + + # Dynamic dim constraints + batch = Dim("batch") + dynamic_shapes = {"x": {0: batch}} + + m = fx.export_and_import( + OuterWithSquaredShape(), + x, + dynamic_shapes=dynamic_shapes, + import_symbolic_shape_expressions=True, + ) + print(m) + + +@run +# CHECK-LABEL: test_slice_tensor_static_output +# CHECK: func.func @main(%[[ARG0:.+]]: !torch.vtensor<[?,3],f32>) -> !torch.vtensor<[2,1],f32> { +# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = 3, max_val = 9223372036854775806} : !torch.int +# CHECK: torch.bind_symbolic_shape %[[ARG0]], 
[%[[S0]]], affine_map<()[s0] -> (s0, 3)> : !torch.vtensor<[?,3],f32> +# CHECK: %[[SLICE1:.+]] = torch.aten.slice.Tensor %[[ARG0]], {{.*}}, {{.*}}, {{.*}}, {{.*}} : !torch.vtensor<[?,3],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,3],f32> +# CHECK: %[[SLICE2:.+]] = torch.aten.slice.Tensor %[[SLICE1]], {{.*}}, {{.*}}, {{.*}}, {{.*}} : !torch.vtensor<[2,3],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,1],f32> +# CHECK: return %[[SLICE2]] : !torch.vtensor<[2,1],f32> +def test_slice_tensor_static_output(): + class SliceTensorStaticOutput(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x[0:2, :1] + + # Sample inputs + x = torch.randn(4, 3) + + # Dynamic dim constraints + batch = Dim("batch", min=3) + dynamic_shapes = {"x": {0: batch}} + + m = fx.export_and_import( + SliceTensorStaticOutput(), + x, + dynamic_shapes=dynamic_shapes, + import_symbolic_shape_expressions=True, + ) + print(m) + + +@run +# CHECK-LABEL: test_slice_tensor_dynamic_output +# CHECK: func.func @main(%[[ARG0:.+]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> { +# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = 5, max_val = 9223372036854775806} : !torch.int +# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]]], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32> +# CHECK: %[[SLICE:.+]] = torch.aten.slice.Tensor %[[ARG0]], {{.*}}, {{.*}}, {{.*}}, {{.*}} : !torch.vtensor<[?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?],f32> +# CHECK: torch.bind_symbolic_shape %[[SLICE]], [%[[S0]]], affine_map<()[s0] -> (s0 - 5)> : !torch.vtensor<[?],f32> +# CHECK: return %[[SLICE]] : !torch.vtensor<[?],f32> +def test_slice_tensor_dynamic_output(): + class SliceTensorDynamicOutput(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x[5:] + + # Sample inputs + x = torch.randn(10) + + # Dynamic dim constraints + dimx = Dim("dimx", min=5) + dynamic_shapes = {"x": {0: dimx}} + + m = fx.export_and_import( + SliceTensorDynamicOutput(), + x, + dynamic_shapes=dynamic_shapes, + import_symbolic_shape_expressions=True, + ) + print(m) + + +@run +# CHECK-LABEL: test_div_tensor_mixed_ranks +# CHECK: func.func @main(%[[ARG0:.+]]: !torch.vtensor<[],f32>, %[[ARG1:.+]]: !torch.vtensor<[?,3],f32>) -> !torch.vtensor<[?,3],f32> { +# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int +# CHECK: torch.bind_symbolic_shape %[[ARG1]], [%[[S0]]], affine_map<()[s0] -> (s0, 3)> : !torch.vtensor<[?,3],f32> +# CHECK: %[[DIV:.+]] = torch.aten.div.Tensor %[[ARG0]], %[[ARG1]] : !torch.vtensor<[],f32>, !torch.vtensor<[?,3],f32> -> !torch.vtensor<[?,3],f32> +# CHECK: torch.bind_symbolic_shape %[[DIV]], [%[[S0]]], affine_map<()[s0] -> (s0, 3)> : !torch.vtensor<[?,3],f32> +# CHECK: return %[[DIV]] : !torch.vtensor<[?,3],f32> +def test_div_tensor_mixed_ranks(): + class DivTensorMixedRanks(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + div = torch.div(x, y) + return div + + # Sample inputs + x = torch.tensor(10.0) + y = torch.randn(2, 3) + + # Dynamic dim constraints + batch = Dim("batch") + dynamic_shapes = {"x": None, "y": {0: batch}} + + m = fx.export_and_import( + DivTensorMixedRanks(), + x, + y, + dynamic_shapes=dynamic_shapes, + import_symbolic_shape_expressions=True, + ) + print(m) + + +@run +# CHECK-LABEL: 
test_shape_div +# CHECK: func.func @main(%[[ARG0:.+]]: !torch.vtensor<[?,7],f32>) -> !torch.vtensor<[?,5],f32> { +# This appears in torch-nightly, but not in torch-stable (re-enable once we've moved torch-stable to 2.4+) +# CHECK-DISABLED: %[[S0:.+]] = torch.symbolic_int "5*s1" {min_val = 0, max_val = 5000} : !torch.int +# CHECK: %[[S1:.+]] = torch.symbolic_int "s1" {min_val = 2, max_val = 1000} : !torch.int +# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S1]]], affine_map<()[s0] -> (s0 * 5, 7)> : !torch.vtensor<[?,7],f32> +# CHECK: %[[VIEW:.+]] = torch.aten.view %[[ARG0]], {{.*}} : !torch.vtensor<[?,7],f32>, !torch.list -> !torch.vtensor<[?,5],f32> +# CHECK: torch.bind_symbolic_shape %[[VIEW]], [%[[S1]]], affine_map<()[s0] -> (s0 * 7, 5)> : !torch.vtensor<[?,5],f32> +# CHECK: return %[[VIEW]] : !torch.vtensor<[?,5],f32> +def test_shape_div(): + class ShapeDiv(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x.reshape(-1, 5) + + # Sample inputs + x = torch.rand(10, 7) + + # Dynamic dim constraints + batch = Dim("batch", max=1000) * 5 + dynamic_shapes = {"x": {0: batch}} + + m = fx.export_and_import( + ShapeDiv(), + x, + dynamic_shapes=dynamic_shapes, + import_symbolic_shape_expressions=True, + ) + print(m) + + +@run +# CHECK-LABEL: test_broadcast_unit_dim_to_static_with_unchanged_dim_dynamic +# CHECK: func.func @main(%[[ARG0:.+]]: !torch.vtensor<[1,?],f32>) -> !torch.vtensor<[3,?],f32> { +# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int +# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]]], affine_map<()[s0] -> (1, s0)> : !torch.vtensor<[1,?],f32> +# CHECK: %[[EXPAND:.+]] = torch.aten.expand %[[ARG0]], {{.*}}, {{.*}} : !torch.vtensor<[1,?],f32>, !torch.list, !torch.bool -> !torch.vtensor<[3,?],f32> +# CHECK: torch.bind_symbolic_shape %[[EXPAND]], [%[[S0]]], affine_map<()[s0] -> (3, s0)> : !torch.vtensor<[3,?],f32> +# CHECK: return %[[EXPAND]] : !torch.vtensor<[3,?],f32> +def test_broadcast_unit_dim_to_static_with_unchanged_dim_dynamic(): + class BroadcastUnitDimToStaticWithUnchangedDimDynamic(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.broadcast_to(x, (3, -1)) + + # Sample inputs + x = torch.randn(1, 2) + + # Dynamic dim constraints + dim_1 = Dim("dim_1") + dynamic_shapes = {"x": {1: dim_1}} + + m = fx.export_and_import( + BroadcastUnitDimToStaticWithUnchangedDimDynamic(), + x, + dynamic_shapes=dynamic_shapes, + import_symbolic_shape_expressions=True, + ) + print(m) + + +@run +# CHECK-LABEL: test_broadcast_unit_dim_to_dynamic_with_unchanged_dim_static +# CHECK: func.func @main(%[[ARG0:.+]]: !torch.vtensor<[1,2],f32>, %[[ARG1:.+]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,2],f32> { +# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int +# CHECK: torch.bind_symbolic_shape %[[ARG1]], [%[[S0]]], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32> +# CHECK: %[[EXPAND:.+]] = torch.aten.expand %[[ARG0]], {{.*}}, {{.*}} : !torch.vtensor<[1,2],f32>, !torch.list, !torch.bool -> !torch.vtensor<[?,2],f32> +# CHECK: torch.bind_symbolic_shape %[[EXPAND]], [%[[S0]]], affine_map<()[s0] -> (s0, 2)> : !torch.vtensor<[?,2],f32> +# CHECK: return %3 : !torch.vtensor<[?,2],f32> +def test_broadcast_unit_dim_to_dynamic_with_unchanged_dim_static(): + class BroadcastUnitDimToDynamicWithUnchangedDimStatic(torch.nn.Module): + def 
__init__(self): + super().__init__() + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return torch.broadcast_to(x, (y.shape[0], -1)) + + # Sample inputs + x = torch.randn(1, 2) + y = torch.randn(10) + + # Dynamic dim constraints + dim_0 = Dim("dim_0") + dynamic_shapes = {"x": {}, "y": {0: dim_0}} + + m = fx.export_and_import( + BroadcastUnitDimToDynamicWithUnchangedDimStatic(), + x, + y, + dynamic_shapes=dynamic_shapes, + import_symbolic_shape_expressions=True, + ) + print(m) + + +@run +# CHECK-LABEL: test_broadcast_unit_dim_to_dynamic_with_unchanged_dim_dynamic +# CHECK: func.func @main(%[[ARG0:.+]]: !torch.vtensor<[1,?],f32>, %[[ARG1:.+]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?],f32> { +# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int +# CHECK: %[[S1:.+]] = torch.symbolic_int "s1" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int +# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]]], affine_map<()[s0] -> (1, s0)> : !torch.vtensor<[1,?],f32> +# CHECK: torch.bind_symbolic_shape %[[ARG1]], [%[[S1]]], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32> +# CHECK: %[[EXPAND:.+]] = torch.aten.expand %[[ARG0]], {{.*}}, {{.*}} : !torch.vtensor<[1,?],f32>, !torch.list, !torch.bool -> !torch.vtensor<[?,?],f32> +# CHECK: torch.bind_symbolic_shape %[[EXPAND]], [%[[S0]], %[[S1]]], affine_map<()[s0, s1] -> (s1, s0)> : !torch.vtensor<[?,?],f32> +# CHECK: return %[[EXPAND]] : !torch.vtensor<[?,?],f32> +def test_broadcast_unit_dim_to_dynamic_with_unchanged_dim_dynamic(): + class BroadcastUnitDimToDynamicWithUnchangedDimDynamic(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return torch.broadcast_to(x, (y.shape[0], -1)) + + # Sample inputs + x = torch.randn(1, 2) + y = torch.randn(10) + + # Dynamic dim constraints + dim_0 = Dim("dim_0") + dim_1 = Dim("dim_1") + dynamic_shapes = {"x": {1: dim_1}, "y": {0: dim_0}} + + m = fx.export_and_import( + BroadcastUnitDimToDynamicWithUnchangedDimDynamic(), + x, + y, + dynamic_shapes=dynamic_shapes, + import_symbolic_shape_expressions=True, + ) + print(m) + + +@run +# CHECK-LABEL: test_broadcast_unit_dim_to_dynamic_with_rank_increase +# CHECK: func.func @main(%[[ARG0:.+]]: !torch.vtensor<[1,2],f32>, %[[ARG1:.+]]: !torch.vtensor<[?,3,2],f32>) -> !torch.vtensor<[?,3,2],f32> { +# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int +# CHECK: torch.bind_symbolic_shape %[[ARG1]], [%[[S0]]], affine_map<()[s0] -> (s0, 3, 2)> : !torch.vtensor<[?,3,2],f32> +# CHECK: %[[EXPAND:.+]] = torch.aten.expand %[[ARG0]], {{.*}}, {{.*}} : !torch.vtensor<[1,2],f32>, !torch.list, !torch.bool -> !torch.vtensor<[?,3,2],f32> +# CHECK: torch.bind_symbolic_shape %[[EXPAND]], [%[[S0]]], affine_map<()[s0] -> (s0, 3, 2)> : !torch.vtensor<[?,3,2],f32> +# CHECK: return %[[EXPAND]] : !torch.vtensor<[?,3,2],f32> +def test_broadcast_unit_dim_to_dynamic_with_rank_increase(): + class BroadcastUnitDimToDynamicWithRankIncrease(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return torch.broadcast_to(x, y.size()) + + # Sample inputs + x = torch.randn(1, 2) + y = torch.randn(4, 3, 2) + + # Dynamic dim constraints + dim_0 = Dim("dim_0") + dynamic_shapes = {"x": {}, "y": {0: dim_0}} + + m = fx.export_and_import( + BroadcastUnitDimToDynamicWithRankIncrease(), + x, + y, + 
dynamic_shapes=dynamic_shapes, + import_symbolic_shape_expressions=True, + ) + print(m) + + +@run +# CHECK-LABEL: test_gather_elements +# CHECK: func.func @main(%[[ARG0:.+]]: !torch.vtensor<[?,3],f32>, %[[ARG1:.+]]: !torch.vtensor<[2,3],si64>) -> !torch.vtensor<[2,3],f32> { +# CHECK: %[[S0]] = torch.symbolic_int "s0" {min_val = 3, max_val = 9223372036854775806} : !torch.int +# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]]], affine_map<()[s0] -> (s0, 3)> : !torch.vtensor<[?,3],f32> +# CHECK: %[[GATHER:.+]] = torch.aten.gather %[[ARG0]], {{.*}}, {{.*}}, {{.*}} : !torch.vtensor<[?,3],f32>, !torch.int, !torch.vtensor<[2,3],si64>, !torch.bool -> !torch.vtensor<[2,3],f32> +# CHECK: return %[[GATHER]] : !torch.vtensor<[2,3],f32> +def test_gather_elements(): + class GatherElements(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return torch.gather(x, 0, y) + + # Sample inputs + x = torch.randn(4, 3) + y = torch.tensor([[0, 0, 0], [1, 1, 1]]) + + # Dynamic dim constraints + batch = Dim("batch", min=3) + dynamic_shapes = {"x": {0: batch}, "y": {}} + + m = fx.export_and_import( + GatherElements(), + x, + y, + dynamic_shapes=dynamic_shapes, + import_symbolic_shape_expressions=True, + ) + print(m) diff --git a/test/python/fx_importer/sympy_to_affine_expr_test.py b/test/python/fx_importer/sympy_to_affine_expr_test.py new file mode 100644 index 000000000000..0c366040d216 --- /dev/null +++ b/test/python/fx_importer/sympy_to_affine_expr_test.py @@ -0,0 +1,69 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +# RUN: %PYTHON %s | FileCheck %s +# This file contains tests checking translating sympy expressions to (semi-)affine expressions. 
+ +from sympy import Symbol +from torch_mlir.extras.fx_importer import sympy_expr_to_semi_affine_expr + +from torch_mlir.ir import ( + AffineSymbolExpr, + Context, +) + + +def run(f): + print(f"{f.__name__}") + print("-" * len(f.__name__)) + f() + print() + + +@run +# CHECK-LABEL: test_sympy_to_semi_affine_expr_translation +def test_sympy_to_semi_affine_expr_translation(): + with Context(): + s0 = Symbol("s0", positive=True, integer=True) + s1 = Symbol("s1", positive=True, integer=True) + + symbols_set = sorted({s0, s1}, key=lambda x: x.name) + symbols_map = { + str(symbol): AffineSymbolExpr.get(i) for i, symbol in enumerate(symbols_set) + } + + SYMPY_EXPRS = [ + # CHECK: 10 + (10), + # CHECK: s0 + (s0), + # CHECK: s0 + (s0 + 0), + # CHECK: s0 + 1 + (s0 + 1), + # CHECK: s0 + (s0 * 1), + # CHECK: s0 * 2 + (s0 * 2), + # CHECK: s0 * s0 + (s0 * s0), + # CHECK: s0 * s1 + (s0 * s1), + # CHECK: s0 * s0 + (s0**2), + # CHECK: (s0 * s0) * s0 + (s0**3), + # CHECK: ((((s0 * s0) * s0) * s0) * s0) * s0 + ((s0**2) ** 3), + # CHECK: ((((((s0 * s0) * s0) * s0) * s0) * s0) * s0) * s0 + (s0 ** (2**3)), + # CHECK: s0 mod 10 + (s0 % 10), + # CHECK: s0 - s1 * 2 + 5 + (s0 + 5 - 2 * s1), + ] + + for expr in SYMPY_EXPRS: + print(sympy_expr_to_semi_affine_expr(expr, symbols_map)) diff --git a/test/python/fx_importer/v2.3/types_test.py b/test/python/fx_importer/v2.3/types_test.py index 19dee8b7b2cb..eccea125cea1 100644 --- a/test/python/fx_importer/v2.3/types_test.py +++ b/test/python/fx_importer/v2.3/types_test.py @@ -36,8 +36,13 @@ def forward(self, x): x = x + 1.0 return x.shape[0] + # CHECK: %[[S0:.*]] = torch.symbolic_int "s0" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int + # CHECK: torch.bind_symbolic_shape %arg0, [%[[S0]]], affine_map<()[s0] -> (s0, 4)> : !torch.vtensor<[?,4],f32> # CHECK: torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,4],f32>, !torch.int -> !torch.int m = fx.export_and_import( - Basic(), torch.randn(3, 4), dynamic_shapes={"x": {0: torch.export.Dim("b")}} + Basic(), + torch.randn(3, 4), + dynamic_shapes={"x": {0: torch.export.Dim("b")}}, + import_symbolic_shape_expressions=True, ) print(m) diff --git a/tools/torch-mlir-opt/torch-mlir-opt.cpp b/tools/torch-mlir-opt/torch-mlir-opt.cpp index 2750ee2b7145..0fa392de43b3 100644 --- a/tools/torch-mlir-opt/torch-mlir-opt.cpp +++ b/tools/torch-mlir-opt/torch-mlir-opt.cpp @@ -7,6 +7,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/MemRef/Transforms/Passes.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Tools/mlir-opt/MlirOptMain.h" @@ -33,8 +34,13 @@ int main(int argc, char **argv) { registerStripDebugInfoPass(); registerSymbolDCEPass(); + // memref passes used in torch-backend-to-linalg-on-tensors-backend-pipeline + memref::registerExpandOpsPass(); + memref::registerResolveShapedTypeResultDimsPass(); + DialectRegistry registry; mlir::torch::registerAllDialects(registry); + mlir::torch::registerAllExtensions(registry); mlir::torch::registerOptionalInputDialects(registry); #ifdef TORCH_MLIR_ENABLE_STABLEHLO diff --git a/torchvision-requirements.txt b/torchvision-requirements.txt index fc005dedeedb..4d96adc70fe5 100644 --- a/torchvision-requirements.txt +++ b/torchvision-requirements.txt @@ -4,4 +4,4 @@ # release page, and we use this page as an additional source for the wheels. 
-f https://xilinx.github.io/torch-mlir/package-index/ --pre -torchvision==0.19.0.dev20240505 +torchvision==0.19.0.dev20240604 diff --git a/utils/bazel/torch-mlir-overlay/BUILD.bazel b/utils/bazel/torch-mlir-overlay/BUILD.bazel index d21d1acad337..e7ac2ca1cab2 100644 --- a/utils/bazel/torch-mlir-overlay/BUILD.bazel +++ b/utils/bazel/torch-mlir-overlay/BUILD.bazel @@ -64,6 +64,7 @@ gentbl_cc_library( td_file = "include/torch-mlir/Dialect/Torch/IR/TorchOps.td", deps = [ ":MLIRTorchOpsIncGenTdFiles", + "@llvm-project//mlir:BuiltinDialectTdFiles", ], ) @@ -329,7 +330,10 @@ gentbl_cc_library( strip_include_prefix = "include", tbl_outs = [ ( - ["-gen-pass-decls"], + [ + "-gen-pass-decls", + "-DTORCH_MLIR_ENABLE_STABLEHLO", + ], "include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h.inc", ), ], @@ -496,6 +500,9 @@ cc_library( "lib/Conversion/TorchToStablehlo/*.cpp", ]), hdrs = glob(["include/torch-mlir/Conversion/TorchToStablehlo/*.h"]), + defines = [ + "TORCH_MLIR_ENABLE_STABLEHLO", + ], strip_include_prefix = "include", deps = [ ":TorchMLIRConversionPassesIncGen", @@ -556,6 +563,9 @@ cc_library( "lib/Dialect/TorchConversion/Transforms/*.h", ]), hdrs = glob(["include/torch-mlir/Dialect/TorchConversion/Transforms/*.h"]), + defines = [ + "TORCH_MLIR_ENABLE_STABLEHLO", + ], strip_include_prefix = "include", deps = [ ":TorchMLIRTorchBackendTypeConversion", @@ -891,6 +901,7 @@ cc_library( "@llvm-project//mlir:Dialect", "@llvm-project//mlir:DialectUtils", "@llvm-project//mlir:IR", + "@llvm-project//mlir:TensorInferTypeOpInterfaceImpl", "@stablehlo//:linalg_passes", "@stablehlo//:stablehlo_passes", ],
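The onnx.SequenceErase / SequenceInsert / SequenceAt tests at the top of this hunk all check the same index-normalization pattern: the position is computed as item + (item < 0) * len(seq), and the erased element is dropped by concatenating the two surrounding slices. A plain-Python sketch of that logic (the function name is illustrative, not part of the patch):

# Mirrors the aten ops checked above: lt.int -> Int.bool -> mul.int -> add.int,
# followed by two slice.t ops joined with add.t.
def sequence_erase(seq, idx):
    length = len(seq)
    offset = int(idx < 0) * length   # 0 for a non-negative index, len(seq) otherwise
    position = idx + offset          # normalized, non-negative erase position
    return seq[:position] + seq[position + 1:]

For example, sequence_erase([a, b, c], -3) yields [b, c], which is exactly what the first test's dense<-3> constant exercises.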
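All of the new fx importer tests thread import_symbolic_shape_expressions=True through fx.export_and_import so that torch.symbolic_int and torch.bind_symbolic_shape ops show up in the imported module. A minimal driver in the same style, assuming a torch-mlir build with Python bindings and a torch.export that accepts dynamic_shapes; the module and names below are illustrative, not part of the patch:

import torch
import torch.nn as nn
from torch.export import Dim

from torch_mlir import fx


class Tanh(nn.Module):
    def forward(self, x):
        return torch.tanh(x)


# One dynamic batch dimension; the symbolic shape bindings are only emitted
# when the flag below is passed (the tests above check for exactly those ops).
m = fx.export_and_import(
    Tanh(),
    torch.randn(3, 4),
    dynamic_shapes={"x": {0: Dim("batch")}},
    func_name="main",
    import_symbolic_shape_expressions=True,
)
print(m)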
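sympy_to_affine_expr_test.py drives sympy_expr_to_semi_affine_expr directly, and the helper can be called the same way outside the test harness. A small sketch using only names that appear in that test; the concrete expression is illustrative:

from sympy import Symbol

from torch_mlir.extras.fx_importer import sympy_expr_to_semi_affine_expr
from torch_mlir.ir import AffineSymbolExpr, Context

with Context():
    s0 = Symbol("s0", positive=True, integer=True)
    s1 = Symbol("s1", positive=True, integer=True)
    # Map each sympy symbol name to an MLIR affine symbol position.
    symbols_map = {"s0": AffineSymbolExpr.get(0), "s1": AffineSymbolExpr.get(1)}
    # Prints a semi-affine AffineExpr such as "s0 * 2 + s1".
    print(sympy_expr_to_semi_affine_expr(2 * s0 + s1, symbols_map))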