From 560bd91f795a056e2188bd2b9850696a7800a908 Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Fri, 26 Apr 2024 04:12:10 +0000 Subject: [PATCH] Fix failing CI --- .../TorchOnnxToTorch/DefaultDomainGtoP.cpp | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index cf33338ae426..7a150794cb4b 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -52,14 +52,9 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( rewriter.getF64FloatAttr(beta)); // Expression: alpha * x + beta - Value alphaMulX = rewriter.create( - binder.getLoc(), resultType, tensorOperand, constAlpha); - Value constOne = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getF64FloatAttr(1.0)); - Value alphaMulXPlusBeta = rewriter.create( - binder.getLoc(), resultType, alphaMulX, constBeta, - /*alpha=*/constOne); + Value alpha_x_plus_beta = rewriter.create( + binder.getLoc(), resultType, tensorOperand, constBeta, + /*alpha=*/constAlpha); // Expression: min(1, alpha * x + beta) Value constantOne = rewriter.create( @@ -67,7 +62,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( Value oneTensor = createRank0Tensor(rewriter, binder.getLoc(), resultType, constantOne); Value minExpression = rewriter.create( - binder.getLoc(), resultType, oneTensor, alphaMulXPlusBeta); + binder.getLoc(), resultType, oneTensor, alpha_x_plus_beta); // Expression: max(0, min(1, alpha * x + beta)) Value constantZero = rewriter.create(