From 3c252cdd44f411ef67e3a759319be53b46396d44 Mon Sep 17 00:00:00 2001
From: Vivek Khandelwal
Date: Mon, 22 Apr 2024 22:28:07 +0530
Subject: [PATCH] [onnx] Add `onnx-to-torch` lowering for random ops (#3193)

This commit adds the OnnxToTorch lowering for Onnx's RandomNormal,
RandomNormalLike, RandomUniform, and RandomUniformLike ops.
---
 .../Conversion/TorchOnnxToTorch/Utils.h       |   2 +
 .../TorchOnnxToTorch/DefaultDomainAtoF.cpp    |  92 ++------
 .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp    | 214 ++++++++++++++++++
 lib/Conversion/TorchOnnxToTorch/Utils.cpp     |  38 ++++
 projects/pt1/e2e_testing/xfail_sets.py        |  14 +-
 .../TorchOnnxToTorch/simple_ops_q_to_z.mlir   |  62 +++++
 6 files changed, 339 insertions(+), 83 deletions(-)

diff --git a/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h b/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h
index 8e9de1ff5940..d4ace352a9bd 100644
--- a/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h
+++ b/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h
@@ -87,6 +87,8 @@ m_OnnxListOfConstantInts(SmallVectorImpl<int64_t> &bind_values) {
   return detail::onnx_list_of_constant_ints_op_binder(bind_values);
 }
 
+std::optional<int64_t> onnxDtypeIntToTorchDtypeInt(int64_t dtypeIntOnnx);
+
 } // namespace mlir::torch::onnx_c
 
 #endif // TORCHMLIR_CONVERSION_TORCHONNXTOTORCH_UTILS_H
diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
index 7032ddcd208e..b16e76e3afe5 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
@@ -9,6 +9,7 @@
 
 #include "mlir/IR/DialectResourceBlobManager.h"
 #include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h"
+#include "torch-mlir/Conversion/TorchOnnxToTorch/Utils.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
 #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
 #include "llvm/Support/FormatVariadic.h"
@@ -17,56 +18,6 @@ using namespace mlir;
 using namespace mlir::torch;
 using namespace mlir::torch::onnx_c;
 
-class Endian {
-private:
-  static constexpr uint32_t uint32_ = 0x01020304;
-  static constexpr uint8_t magic_ = (const uint8_t &)uint32_;
-
-public:
-  static constexpr bool little = magic_ == 0x04;
-  static constexpr bool big = magic_ == 0x01;
-  static_assert(little || big, "Cannot determine endianness!");
-
-private:
-  Endian() = delete;
-};
-
-static int64_t onnxDtypeIntToTorchDtypeInt(int64_t dtypeIntOnnx) {
-  // TODO: Add complete mapping.
-  // Where are the ONNX and PyTorch dtype enums defined?
-  // ONNX:
-  //  https://github.com/shouxieai/tensorRT_Pro/blob/main/onnx/onnx-ml.proto
-  // PyTorch:
-  //  https://github.com/llvm/torch-mlir/blob/main/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h#L88
-
-  int64_t dtypeIntTorch = [dtypeIntOnnx]() {
-    switch (dtypeIntOnnx) {
-    case 1:
-      return 6; // float
-    case 2:
-      return 0; // uint8
-    case 3:
-      return 1; // int8
-    case 6:
-      return 3; // int32
-    case 7:
-      return 4; // int64
-    case 9:
-      return 11; // bool
-    case 10:
-      return 5; // half
-    case 11:
-      return 7; // double
-    case 16:
-      return 15; // bfloat16
-    default:
-      return -1; // No dtype
-    }
-  }();
-
-  return dtypeIntTorch;
-}
-
 static LogicalResult createTorchTransposeOp(ConversionPatternRewriter &rewriter,
                                             Location loc, Value input,
                                             int64_t dimA, int64_t dimB,
@@ -428,7 +379,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
       [](OpBinder binder, ConversionPatternRewriter &rewriter) {
         Torch::ValueTensorType resultType;
         Value input;
-        int64_t dtypeIntOnnx, dtypeIntTorch;
+        int64_t dtypeIntOnnx;
         if (binder.tensorOperand(input) ||
             binder.s64IntegerAttr(dtypeIntOnnx, "dtype", -1) ||
             binder.tensorResultType(resultType))
@@ -452,16 +403,15 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
           rewriter.replaceOp(binder.op, bernoulli);
           return success();
         }
-        dtypeIntTorch = onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx);
-        if (dtypeIntTorch == -1) {
+        std::optional<int64_t> dtypeIntTorch =
+            onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx);
+        if (!dtypeIntTorch.has_value()) {
           return rewriter.notifyMatchFailure(
               binder.op,
               "unimplemented support for the given dtype conversion");
         }
         Value constDtype = rewriter.create<Torch::ConstantIntOp>(
-            binder.getLoc(), rewriter.getType<Torch::IntType>(),
-            rewriter.getIntegerAttr(rewriter.getIntegerType(64),
-                                    dtypeIntTorch));
+            binder.getLoc(), rewriter.getI64IntegerAttr(dtypeIntTorch.value()));
         Value cstFalse =
             rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
         rewriter.replaceOpWithNewOp<Torch::AtenToDtypeOp>(
@@ -539,25 +489,21 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
       "Cast", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
         Torch::ValueTensorType resultType;
         Value operand;
-        int64_t dtypeIntOnnx, dtypeIntTorch;
+        int64_t dtypeIntOnnx;
         if (binder.tensorOperand(operand) ||
             binder.s64IntegerAttr(dtypeIntOnnx, "to") ||
             binder.tensorResultType(resultType))
          return failure();
 
-        dtypeIntTorch = onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx);
-        if (dtypeIntTorch == -1) {
-          auto message = llvm::formatv("unimplemented support for the given "
-                                       "dtype conversion (onnx 'type' = {0})",
-                                       dtypeIntOnnx);
-          auto y = rewriter.notifyMatchFailure(binder.op, message);
-
-          return y;
+        std::optional<int64_t> dtypeIntTorch =
+            onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx);
+        if (!dtypeIntTorch.has_value()) {
+          return rewriter.notifyMatchFailure(
+              binder.op,
+              "unimplemented support for the given dtype conversion");
         }
         Value constDtype = rewriter.create<Torch::ConstantIntOp>(
-            binder.getLoc(), rewriter.getType<Torch::IntType>(),
-            rewriter.getIntegerAttr(rewriter.getIntegerType(64),
-                                    dtypeIntTorch));
+            binder.getLoc(), rewriter.getI64IntegerAttr(dtypeIntTorch.value()));
         Value none = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
         Value cstFalse =
             rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
@@ -1768,9 +1714,15 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
           Value mVal = rewriter.create<Torch::AtenSizeIntOp>(binder.getLoc(),
                                                              operand, cst1);
           Value noneVal = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
-          int64_t dtypeIntTorch = onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx);
+          std::optional<int64_t> dtypeIntTorch =
+              onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx);
+          if (!dtypeIntTorch.has_value()) {
+            return rewriter.notifyMatchFailure(
+                binder.op,
+                "unimplemented support for the given dtype conversion");
+          }
           Value dtypeVal = rewriter.create<Torch::ConstantIntOp>(
-              binder.getLoc(), rewriter.getI64IntegerAttr(dtypeIntTorch));
+              binder.getLoc(), rewriter.getI64IntegerAttr(dtypeIntTorch.value()));
 
           // diagonalIndex = 0 populates the main diagonal
           // diagonalIndex > 0 populates an upper diagonal
diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
index 7630fcfa1108..6c86ecb92789 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
@@ -2274,4 +2274,218 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
             binder.op, resultType, input, cstAlpha, value);
         return success();
       });
+  patterns.onOp(
+      "RandomNormal", 1,
+      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        SmallString<64> name("torch.onnx.seed");
+        auto seedAttr = binder.op->getAttr(name);
+        if (seedAttr)
+          return rewriter.notifyMatchFailure(
+              binder.op,
+              "unimplemented: support not present for seed attribute");
+
+        Torch::ValueTensorType resultType;
+        int64_t dtypeIntOnnx;
+        float mean, scale;
+        SmallVector<int64_t> shape;
+        if (binder.s64IntegerAttr(dtypeIntOnnx, "dtype", 1) ||
+            binder.f32FloatAttr(mean, "mean", 0.0) ||
+            binder.f32FloatAttr(scale, "scale", 1.0) ||
+            binder.s64IntegerArrayAttr(shape, "shape", {}) ||
+            binder.tensorResultType(resultType)) {
+          return failure();
+        }
+
+        std::optional<int64_t> dtypeIntTorch =
+            onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx);
+        if (!dtypeIntTorch.has_value()) {
+          return rewriter.notifyMatchFailure(
+              binder.op,
+              "unimplemented support for the given dtype conversion");
+        }
+        Value constDtype = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getI64IntegerAttr(dtypeIntTorch.value()));
+
+        Value shapeList = createConstantIntList(binder, rewriter, shape);
+        Value cstNone = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
+
+        Value self = rewriter.create<Torch::AtenEmptyMemoryFormatOp>(
+            binder.op->getLoc(), resultType, shapeList,
+            /*dtype=*/constDtype,
+            /*layout=*/cstNone,
+            /*device=*/cstNone, /*pinMemory=*/cstNone,
+            /*memoryFormat=*/cstNone);
+
+        Value cstMean = rewriter.create<Torch::ConstantFloatOp>(
+            binder.getLoc(), rewriter.getType<Torch::FloatType>(),
+            rewriter.getFloatAttr(rewriter.getF64Type(), mean));
+        Value cstStd = rewriter.create<Torch::ConstantFloatOp>(
+            binder.getLoc(), rewriter.getType<Torch::FloatType>(),
+            rewriter.getFloatAttr(rewriter.getF64Type(), scale));
+
+        rewriter.replaceOpWithNewOp<Torch::AtenNormalFunctionalOp>(
+            binder.op, resultType, self, cstMean, cstStd,
+            /*generator=*/cstNone);
+        return success();
+      });
+  patterns.onOp(
+      "RandomNormalLike", 1,
+      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        SmallString<64> name("torch.onnx.seed");
+        auto seedAttr = binder.op->getAttr(name);
+        if (seedAttr)
+          return rewriter.notifyMatchFailure(
+              binder.op,
+              "unimplemented: support not present for seed attribute");
+
+        Torch::ValueTensorType resultType;
+        int64_t dtypeIntOnnx;
+        float mean, scale;
+        SmallVector<int64_t> shape;
+        Value input;
+        if (binder.tensorOperand(input) ||
+            binder.s64IntegerAttr(dtypeIntOnnx, "dtype", 1) ||
+            binder.f32FloatAttr(mean, "mean", 0.0) ||
+            binder.f32FloatAttr(scale, "scale", 1.0) ||
+            binder.tensorResultType(resultType)) {
+          return failure();
+        }
+
+        std::optional<int64_t> dtypeIntTorch =
+            onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx);
+        if (!dtypeIntTorch.has_value()) {
+          return rewriter.notifyMatchFailure(
+              binder.op,
+              "unimplemented support for the given dtype conversion");
+        }
+        Value constDtype = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getI64IntegerAttr(dtypeIntTorch.value()));
+
+        Value cstNone =
+            rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
+        Value cstFalse =
+            rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
+        input = rewriter.create<Torch::AtenToDtypeOp>(
+            binder.op->getLoc(), resultType, input, constDtype,
+            /*non_blocking=*/cstFalse, /*copy=*/cstFalse,
+            /*memory_format=*/cstNone);
+
+        Value cstMean = rewriter.create<Torch::ConstantFloatOp>(
+            binder.getLoc(), rewriter.getType<Torch::FloatType>(),
+            rewriter.getFloatAttr(rewriter.getF64Type(), mean));
+        Value cstStd = rewriter.create<Torch::ConstantFloatOp>(
+            binder.getLoc(), rewriter.getType<Torch::FloatType>(),
+            rewriter.getFloatAttr(rewriter.getF64Type(), scale));
+
+        rewriter.replaceOpWithNewOp<Torch::AtenNormalFunctionalOp>(
+            binder.op, resultType, input, cstMean, cstStd,
+            /*generator=*/cstNone);
+        return success();
+      });
+  patterns.onOp(
+      "RandomUniform", 1,
+      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        SmallString<64> name("torch.onnx.seed");
+        auto seedAttr = binder.op->getAttr(name);
+        if (seedAttr)
+          return rewriter.notifyMatchFailure(
+              binder.op,
+              "unimplemented: support not present for seed attribute");
+
+        Torch::ValueTensorType resultType;
+        int64_t dtypeIntOnnx;
+        float high, low;
+        SmallVector<int64_t> shape;
+        if (binder.s64IntegerAttr(dtypeIntOnnx, "dtype", 1) ||
+            binder.f32FloatAttr(high, "high", 1.0) ||
+            binder.f32FloatAttr(low, "low", 0.0) ||
+            binder.s64IntegerArrayAttr(shape, "shape", {}) ||
+            binder.tensorResultType(resultType)) {
+          return failure();
+        }
+
+        std::optional<int64_t> dtypeIntTorch =
+            onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx);
+        if (!dtypeIntTorch.has_value()) {
+          return rewriter.notifyMatchFailure(
+              binder.op,
+              "unimplemented support for the given dtype conversion");
+        }
+        Value constDtype = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getI64IntegerAttr(dtypeIntTorch.value()));
+
+        Value shapeList = createConstantIntList(binder, rewriter, shape);
+        Value cstNone = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
+
+        Value self = rewriter.create<Torch::AtenEmptyMemoryFormatOp>(
+            binder.op->getLoc(), resultType, shapeList,
+            /*dtype=*/constDtype,
+            /*layout=*/cstNone,
+            /*device=*/cstNone, /*pinMemory=*/cstNone,
+            /*memoryFormat=*/cstNone);
+
+        Value cstHigh = rewriter.create<Torch::ConstantFloatOp>(
+            binder.getLoc(), rewriter.getType<Torch::FloatType>(),
+            rewriter.getFloatAttr(rewriter.getF64Type(), high));
+        Value cstLow = rewriter.create<Torch::ConstantFloatOp>(
+            binder.getLoc(), rewriter.getType<Torch::FloatType>(),
+            rewriter.getFloatAttr(rewriter.getF64Type(), low));
+
+        rewriter.replaceOpWithNewOp<Torch::AtenUniformOp>(
+            binder.op, resultType, self, cstLow, cstHigh,
+            /*generator=*/cstNone);
+        return success();
+      });
+  patterns.onOp(
+      "RandomUniformLike", 1,
+      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        SmallString<64> name("torch.onnx.seed");
+        auto seedAttr = binder.op->getAttr(name);
+        if (seedAttr)
+          return rewriter.notifyMatchFailure(
+              binder.op,
+              "unimplemented: support not present for seed attribute");
+
+        Torch::ValueTensorType resultType;
+        int64_t dtypeIntOnnx;
+        float high, low;
+        SmallVector<int64_t> shape;
+        Value input;
+        if (binder.tensorOperand(input) ||
+            binder.s64IntegerAttr(dtypeIntOnnx, "dtype", 1) ||
+            binder.f32FloatAttr(high, "high", 1.0) ||
+            binder.f32FloatAttr(low, "low", 0.0) ||
+            binder.tensorResultType(resultType)) {
+          return failure();
+        }
+
+        std::optional<int64_t> dtypeIntTorch =
+            onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx);
+        if (!dtypeIntTorch.has_value()) {
+          return rewriter.notifyMatchFailure(
+              binder.op,
+              "unimplemented support for the given dtype conversion");
+        }
+        Value constDtype = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getI64IntegerAttr(dtypeIntTorch.value()));
+
+        Value cstNone = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
+        Value cstFalse =
+            rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
+        input = rewriter.create<Torch::AtenToDtypeOp>(
+            binder.op->getLoc(), resultType, input, constDtype,
+            /*non_blocking=*/cstFalse, /*copy=*/cstFalse,
+            /*memory_format=*/cstNone);
+
+        Value cstHigh = rewriter.create<Torch::ConstantFloatOp>(
+            binder.getLoc(), rewriter.getType<Torch::FloatType>(),
+            rewriter.getFloatAttr(rewriter.getF64Type(), high));
+        Value cstLow = rewriter.create<Torch::ConstantFloatOp>(
+            binder.getLoc(), rewriter.getType<Torch::FloatType>(),
+            rewriter.getFloatAttr(rewriter.getF64Type(), low));
+
+        rewriter.replaceOpWithNewOp<Torch::AtenUniformOp>(
+            binder.op, resultType, input, cstLow, cstHigh,
+            /*generator=*/cstNone);
+        return success();
+      });
 }
diff --git a/lib/Conversion/TorchOnnxToTorch/Utils.cpp b/lib/Conversion/TorchOnnxToTorch/Utils.cpp
index 2d24303394dd..dec13490666e 100644
--- a/lib/Conversion/TorchOnnxToTorch/Utils.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/Utils.cpp
@@ -59,3 +59,41 @@ bool mlir::torch::onnx_c::areAllElementsDistinct(SmallVector<int64_t> array) {
   // as array's size.
   return (set.size() == array.size());
 }
+
+std::optional<int64_t>
+mlir::torch::onnx_c::onnxDtypeIntToTorchDtypeInt(int64_t dtypeIntOnnx) {
+  // TODO: Add complete mapping.
+  // Where are the ONNX and PyTorch dtype enums defined?
+  // ONNX:
+  //  https://github.com/shouxieai/tensorRT_Pro/blob/main/onnx/onnx-ml.proto
+  // PyTorch:
+  //  https://github.com/llvm/torch-mlir/blob/main/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h#L88
+
+  std::optional<int64_t> dtypeIntTorch =
+      [dtypeIntOnnx]() -> std::optional<int64_t> {
+    switch (dtypeIntOnnx) {
+    case 1:
+      return 6; // float
+    case 2:
+      return 0; // uint8
+    case 3:
+      return 1; // int8
+    case 6:
+      return 3; // int32
+    case 7:
+      return 4; // int64
+    case 9:
+      return 11; // bool
+    case 10:
+      return 5; // half
+    case 11:
+      return 7; // double
+    case 16:
+      return 15; // bfloat16
+    default:
+      return std::nullopt; // No dtype
+    }
+  }();
+
+  return dtypeIntTorch;
+}
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 64a9d3bb6169..323a39bf33cb 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -2605,27 +2605,15 @@
     # Failure - onnx_lowering: onnx.OneHot
     "OneHotModule_basic",
 
-    # Failure - onnx_lowering: onnx.RandomNormal
+    # ERROR: dtype (torch.float32) is not equal to golden dtype (torch.float64)
     "RandnDtypeDeviceModule_basic",
     "RandnGeneratorF64Module_basic",
     "RandnGeneratorModule_basic",
     "RandnModule_basic",
-
-    # Failure - onnx_lowering: onnx.RandomNormalLike
-    "RandnLikeDtypeModule_basic",
     "RandnLikeModule_basic",
-
-    # Failure - onnx_lowering: onnx.RandomUniform
-    "RandIntLowDtypeModule_basic",
-    "RandIntLowModule_basic",
-
-    # Failure - onnx_lowering: onnx.RandomUniformLike
     "BernoulliFloatModule_basic",
     "BernoulliPModule_basic",
     "BernoulliTensorModule_basic",
-    "RandLikeDtypeModule_basic",
-    "RandLikeModule_basic",
-    "RandModule_basic",
 
     # Failure - onnx_lowering: onnx.ReduceL2
     "LinalgNormKeepDimModule_basic",
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
index de3e796f4e5d..43849fbbd06e 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
@@ -1679,3 +1679,65 @@ func.func @test_triu_zero(%arg0: !torch.vtensor<[0,5],si64>, %arg1: !torch.vtens
   %0 = torch.operator "onnx.Trilu"(%arg0, %arg1) : (!torch.vtensor<[0,5],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[0,5],si64>
   return %0 : !torch.vtensor<[0,5],si64>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @test_random_normal
+func.func @test_random_normal() -> !torch.vtensor<[10],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK-DAG: %[[I6:.+]] = torch.constant.int 6
+  // CHECK-DAG: %[[I10:.+]] = torch.constant.int 10
+  // CHECK: %[[SHAPE:.+]] = torch.prim.ListConstruct %[[I10]] : (!torch.int) -> !torch.list<int>
+  // CHECK-DAG: %[[NONE:.+]] = torch.constant.none
+  // CHECK: %[[EMPTY_TENSOR:.+]] = torch.aten.empty.memory_format %[[SHAPE]], %[[I6]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[F0:.+]] = torch.constant.float 0.000000e+00
+  // CHECK-DAG: %[[F1:.+]] = torch.constant.float 1.000000e+00
+  // CHECK: torch.aten.normal_functional %[[EMPTY_TENSOR]], %[[F0]], %[[F1]], %[[NONE]] : !torch.vtensor<[10],f32>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[10],f32>
+  %0 = torch.operator "onnx.RandomNormal"() {torch.onnx.dtype = 1 : si64, torch.onnx.mean = 0.000000e+00 : f32, torch.onnx.scale = 1.000000e+00 : f32, torch.onnx.shape = [10 : si64]} : () -> !torch.vtensor<[10],f32>
+  return %0 : !torch.vtensor<[10],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_random_normal_like
+func.func @test_random_normal_like(%arg0: !torch.vtensor<[10],f32>) -> !torch.vtensor<[10],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK-DAG: %[[I6:.+]] = torch.constant.int 6
+  // CHECK-DAG: %[[NONE:.+]] = torch.constant.none
+  // CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false
+  // CHECK: %[[CAST:.+]] = torch.aten.to.dtype %arg0, %[[I6]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[10],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[F0:.+]] = torch.constant.float 0.000000e+00
+  // CHECK-DAG: %[[F1:.+]] = torch.constant.float 1.000000e+00
+  // CHECK: torch.aten.normal_functional %[[CAST]], %[[F0]], %[[F1]], %[[NONE]] : !torch.vtensor<[10],f32>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[10],f32>
+  %0 = torch.operator "onnx.RandomNormalLike"(%arg0) {torch.onnx.dtype = 1 : si64, torch.onnx.mean = 0.000000e+00 : f32, torch.onnx.scale = 1.000000e+00 : f32} : (!torch.vtensor<[10],f32>) -> !torch.vtensor<[10],f32>
+  return %0 : !torch.vtensor<[10],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_random_uniform
+func.func @test_random_uniform() -> !torch.vtensor<[10],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK-DAG: %[[I6:.+]] = torch.constant.int 6
+  // CHECK-DAG: %[[I10:.+]] = torch.constant.int 10
+  // CHECK: %[[SHAPE:.+]] = torch.prim.ListConstruct %[[I10]] : (!torch.int) -> !torch.list<int>
+  // CHECK-DAG: %[[NONE:.+]] = torch.constant.none
+  // CHECK: %[[EMPTY_TENSOR:.+]] = torch.aten.empty.memory_format %[[SHAPE]], %[[I6]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[F1:.+]] = torch.constant.float 1.000000e+00
+  // CHECK-DAG: %[[F0:.+]] = torch.constant.float 0.000000e+00
+  // CHECK: torch.aten.uniform %[[EMPTY_TENSOR]], %[[F0]], %[[F1]], %[[NONE]] : !torch.vtensor<[10],f32>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[10],f32>
+  %0 = torch.operator "onnx.RandomUniform"() {torch.onnx.dtype = 1 : si64, torch.onnx.high = 1.000000e+00 : f32, torch.onnx.low = 0.000000e+00 : f32, torch.onnx.shape = [10 : si64]} : () -> !torch.vtensor<[10],f32>
+  return %0 : !torch.vtensor<[10],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_random_uniform_like
+func.func @test_random_uniform_like(%arg0: !torch.vtensor<[10],f32>) -> !torch.vtensor<[10],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK-DAG: %[[I6:.+]] = torch.constant.int 6
+  // CHECK-DAG: %[[NONE:.+]] = torch.constant.none
+  // CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false
+  // CHECK: %[[CAST:.+]] = torch.aten.to.dtype %arg0, %[[I6]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[10],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[F0:.+]] = torch.constant.float 0.000000e+00
+  // CHECK-DAG: %[[F1:.+]] = torch.constant.float 1.000000e+00
+  // CHECK: torch.aten.uniform %[[CAST]], %[[F0]], %[[F1]], %[[NONE]] : !torch.vtensor<[10],f32>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[10],f32>
+  %0 = torch.operator "onnx.RandomUniformLike"(%arg0) {torch.onnx.dtype = 1 : si64, torch.onnx.high = 1.000000e+00 : f32, torch.onnx.low = 0.000000e+00 : f32} : (!torch.vtensor<[10],f32>) -> !torch.vtensor<[10],f32>
+  return %0 : !torch.vtensor<[10],f32>
+}
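
The committed tests above all use the default dtype (1, i.e. f32). As a rough, hedged sketch of what the RandomNormalLike pattern is expected to produce when the ONNX dtype attribute requests a different element type, the hypothetical example below (not part of this patch or its test suite) uses dtype 11 (double): onnxDtypeIntToTorchDtypeInt maps it to torch dtype 7 (float64), so the input is first cast with torch.aten.to.dtype and then filled by torch.aten.normal_functional. Exact SSA names and attribute ordering may differ from the real pass output.

// Hypothetical input, assumed names: RandomNormalLike with dtype = 11 on an f32 operand.
func.func @random_normal_like_f64(%arg0: !torch.vtensor<[10],f32>) -> !torch.vtensor<[10],f64> attributes {torch.onnx_meta.opset_version = 15 : si64} {
  %0 = torch.operator "onnx.RandomNormalLike"(%arg0) {torch.onnx.dtype = 11 : si64, torch.onnx.mean = 0.000000e+00 : f32, torch.onnx.scale = 1.000000e+00 : f32} : (!torch.vtensor<[10],f32>) -> !torch.vtensor<[10],f64>
  return %0 : !torch.vtensor<[10],f64>
}
// Expected shape of the lowering (ONNX dtype 11 -> torch dtype 7, i.e. float64):
//   %int7  = torch.constant.int 7
//   %none  = torch.constant.none
//   %false = torch.constant.bool false
//   %cast  = torch.aten.to.dtype %arg0, %int7, %false, %false, %none : !torch.vtensor<[10],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[10],f64>
//   %mean  = torch.constant.float 0.000000e+00
//   %std   = torch.constant.float 1.000000e+00
//   %out   = torch.aten.normal_functional %cast, %mean, %std, %none : !torch.vtensor<[10],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[10],f64>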