diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index bbd5005c44e3..327a02a9af6d 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -61,19 +61,18 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( binder.getLoc(), resultType, alphaMulX, constBeta, /*alpha=*/constOne); - // Expression: min(1, alpha * x + beta) - Value constantOne = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(1)); - Value oneTensor = createRank0Tensor(rewriter, binder.getLoc(), - resultType, constantOne); + // Expression: min(1.0, alpha * x + beta) + Value oneTensor = + createRank0Tensor(rewriter, binder.getLoc(), resultType, constOne); Value minExpression = rewriter.create( binder.getLoc(), resultType, oneTensor, alphaMulXPlusBeta); - // Expression: max(0, min(1, alpha * x + beta)) - Value constantZero = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(0)); - Value zeroTensor = createRank0Tensor(rewriter, binder.getLoc(), - resultType, constantZero); + // Expression: max(0.0, min(1.0, alpha * x + beta)) + Value constZero = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getF64FloatAttr(0.0)); + Value zeroTensor = + createRank0Tensor(rewriter, binder.getLoc(), resultType, constZero); rewriter.replaceOpWithNewOp( binder.op, resultType, zeroTensor, minExpression); return success();