diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
index d26601c0de8d..6a4175c78985 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
@@ -847,15 +847,19 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
   patterns.onOp(
       "Selu", 6, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        // y = gamma * (alpha * e^x - alpha) for x <= 0, y = gamma * x for x > 0
         Torch::ValueTensorType resultType;
         float alpha, gamma;
         Value operand;
         if (binder.tensorOperand(operand) ||
-            binder.f32FloatAttr(alpha, "alpha") ||
-            binder.f32FloatAttr(gamma, "gamma") ||
+            binder.f32FloatAttr(alpha, "alpha", 1.67326) ||
+            binder.f32FloatAttr(gamma, "gamma", 1.0507) ||
             binder.tensorResultType(resultType))
           return failure();

+        Torch::ValueTensorType inputType =
+            operand.getType().cast<Torch::ValueTensorType>();
+
         Value vAlpha = rewriter.create<Torch::ConstantFloatOp>(
             binder.getLoc(), rewriter.getType<Torch::FloatType>(),
             rewriter.getFloatAttr(rewriter.getF64Type(), alpha));
@@ -864,12 +868,31 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
             binder.getLoc(), rewriter.getType<Torch::FloatType>(),
             rewriter.getFloatAttr(rewriter.getF64Type(), gamma));

-        Value vInputScale = rewriter.create<Torch::ConstantFloatOp>(
+        Value cstOne = rewriter.create<Torch::ConstantFloatOp>(
             binder.getLoc(), rewriter.getType<Torch::FloatType>(),
             rewriter.getFloatAttr(rewriter.getF64Type(), 1.0));

-        rewriter.replaceOpWithNewOp<Torch::AtenEluOp>(
-            binder.op, resultType, operand, vAlpha, vScale, vInputScale);
+        Value cstNone = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
+        Value zeroTensor = rewriter.create<Torch::AtenZerosLikeOp>(
+            binder.getLoc(), resultType, operand, cstNone, cstNone, cstNone,
+            cstNone, cstNone);
+        Value exp = rewriter.create<Torch::AtenExpOp>(binder.getLoc(),
+                                                      resultType, operand);
+        Value expMulAlpha = rewriter.create<Torch::AtenMulScalarOp>(
+            binder.getLoc(), resultType, exp, vAlpha);
+        Value expMulAlphaSubAlpha = rewriter.create<Torch::AtenSubScalarOp>(
+            binder.getLoc(), resultType, expMulAlpha, vAlpha, cstOne);
+        Value neg = rewriter.create<Torch::AtenMulScalarOp>(
+            binder.getLoc(), resultType, expMulAlphaSubAlpha, vScale);
+        Value pos = rewriter.create<Torch::AtenMulScalarOp>(
+            binder.getLoc(), resultType, operand, vScale);
+        Type compareType = inputType.getWithSizesAndDtype(
+            inputType.getOptionalSizes(), rewriter.getI1Type());
+        Value xLessThanZero = rewriter.create<Torch::AtenLtTensorOp>(
+            binder.getLoc(), compareType, operand, zeroTensor);
+
+        rewriter.replaceOpWithNewOp<Torch::AtenWhereSelfOp>(
+            binder.op, resultType, xLessThanZero, neg, pos);
         return success();
       });
   patterns.onOp("ReduceL1", 1,
diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index e768033ac87f..2cea874f341f 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -140,7 +140,7 @@ static Value getScalarIntValue(Value input, Location loc,
     return nullptr;

   Type inputDtype = inputTensorType.getOptionalDtype();
-  if (!inputDtype || !inputDtype.isInteger(64))
+  if (!inputDtype || !(inputDtype.isInteger(64) || inputDtype.isInteger(1)))
     return nullptr;

   std::optional<unsigned> inputRank = getTensorRank(input);
@@ -148,11 +148,19 @@ static Value getScalarIntValue(Value input, Location loc,
     return nullptr;

   if (auto valueTensorLiteralOp = input.getDefiningOp<ValueTensorLiteralOp>()) {
-    auto val = valueTensorLiteralOp.getValue()
-                   .cast<DenseIntElementsAttr>()
-                   .getSplatValue<int64_t>();
-    return rewriter.create<Torch::ConstantIntOp>(
-        loc, rewriter.getI64IntegerAttr(val));
+    if (inputDtype.isInteger(64)) {
+      auto val = valueTensorLiteralOp.getValue()
+                     .cast<DenseIntElementsAttr>()
+                     .getSplatValue<int64_t>();
+      return rewriter.create<Torch::ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr(val));
+    } else {
+      auto val = valueTensorLiteralOp.getValue()
+                     .cast<DenseIntElementsAttr>()
+                     .getSplatValue<bool>();
+      return rewriter.create<Torch::ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr(val));
+    }
   } else if (auto primNumToTensorScalarOp =
                  input.getDefiningOp<PrimNumToTensorScalarOp>()) {
     return primNumToTensorScalarOp.getA();
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index e426e998ebe0..f34b5a5e3b2c 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -2154,7 +2154,6 @@
     "ElementwiseAtenFloorDivideTensorNegativeModule_basic",
     "ElementwiseLog10IntModule_basic",
     "ElementwiseLog2IntModule_basic",
-    "ElementwiseSeluModule_basic",
     "FlipModuleStaticShape_basic",
     "FlipNegativeIndexModule_basic",
     "HardsigmoidModule_basic",
@@ -2669,8 +2668,6 @@
     "CopyWithDifferentDTypesModule_basic",
     "CosineSimilarityStaticBroadcastModule_basic",
     "CumsumInputDtypeInt32Module_basic",
-    "DropoutTrainModule_basic",
-    "DropoutTrainStaticShapeModule_basic",
     "ElementwiseAcosIntModule_basic",
     "ElementwiseAsinIntModule_basic",
     "ElementwiseAtanTensorIntModule_basic",
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
index c8d513a31d21..54b22e982779 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
@@ -582,10 +582,18 @@ func.func @test_softmax_negative_axis(%arg0: !torch.vtensor<[3,4,5],f32>) -> !to

 // CHECK-LABEL: func.func @test_selu
 func.func @test_selu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.opset_version = 6 : si64} {
-  // CHECK-DAG: %[[F1:.+]] = torch.constant.float 1
-  // CHECK-DAG: %[[F2:.+]] = torch.constant.float 2
-  // CHECK-DAG: %[[F3:.+]] = torch.constant.float 3
-  // CHECK: %[[ELU:.+]] = torch.aten.elu %arg0, %[[F2]], %[[F3]], %[[F1]]
+  // CHECK: %[[F2:.+]] = torch.constant.float 2.000000e+00
+  // CHECK: %[[F3:.+]] = torch.constant.float 3.000000e+00
+  // CHECK: %[[F1:.+]] = torch.constant.float 1.000000e+00
+  // CHECK: %[[NONE:.+]] = torch.constant.none
+  // CHECK: %[[ZEROS:.+]] = torch.aten.zeros_like %arg0, %none, %none, %none, %none, %none : !torch.vtensor<[3,4,5],f32>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3,4,5],f32>
+  // CHECK: %[[EXP:.+]] = torch.aten.exp %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
+  // CHECK: %[[MUL:.+]] = torch.aten.mul.Scalar %[[EXP]], %[[F2]] : !torch.vtensor<[3,4,5],f32>, !torch.float -> !torch.vtensor<[3,4,5],f32>
+  // CHECK: %[[SUB:.+]] = torch.aten.sub.Scalar %[[MUL]], %[[F2]], %[[F1]] : !torch.vtensor<[3,4,5],f32>, !torch.float, !torch.float -> !torch.vtensor<[3,4,5],f32>
+  // CHECK: %[[MUL_1:.+]] = torch.aten.mul.Scalar %[[SUB]], %[[F3]] : !torch.vtensor<[3,4,5],f32>, !torch.float -> !torch.vtensor<[3,4,5],f32>
+  // CHECK: %[[MUL_2:.+]] = torch.aten.mul.Scalar %arg0, %[[F3]] : !torch.vtensor<[3,4,5],f32>, !torch.float -> !torch.vtensor<[3,4,5],f32>
+  // CHECK: %[[LT:.+]] = torch.aten.lt.Tensor %arg0, %[[ZEROS]] : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],i1>
+  // CHECK: torch.aten.where.self %[[LT]], %[[MUL_1]], %[[MUL_2]] : !torch.vtensor<[3,4,5],i1>, !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
   %0 = torch.operator "onnx.Selu"(%arg0) {torch.onnx.alpha = 2.000000e+00 : f32, torch.onnx.gamma = 3.000000e+00 : f32} : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32>
   return %0 : !torch.vtensor<[3,4,5],f32>
 }
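
For reference, a minimal standalone sketch (not part of the patch) of the arithmetic the new lowering emits: where(x < 0, gamma * (alpha * exp(x) - alpha), gamma * x), checked against the ONNX Selu definition. The constants and helper names below are illustrative only; the defaults match the values passed to f32FloatAttr above.

    import math

    ALPHA, GAMMA = 1.67326, 1.0507  # ONNX defaults, matching the new f32FloatAttr defaults

    def selu_decomposed(x):
        # Mirrors the emitted ops: exp -> mul.Scalar -> sub.Scalar -> mul.Scalar,
        # then where(lt(x, 0), neg, pos).
        neg = GAMMA * (ALPHA * math.exp(x) - ALPHA)
        pos = GAMMA * x
        return neg if x < 0 else pos

    def selu_reference(x):
        # ONNX Selu: y = gamma * (alpha * e^x - alpha) for x <= 0, y = gamma * x otherwise.
        return GAMMA * (ALPHA * math.exp(x) - ALPHA) if x <= 0 else GAMMA * x

    for x in (-2.0, -0.5, 0.0, 1.5):
        assert math.isclose(selu_decomposed(x), selu_reference(x), abs_tol=1e-12)

The two functions agree at x = 0 as well, since both branches of the definition evaluate to 0 there, so the strict lt comparison used in the lowering is equivalent to the <= form in the comment.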