Skip to content

Commit

Permalink
Address PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
vivekkhandelwal1 committed Jun 20, 2024
1 parent 066177f commit fad7ddb
Showing 1 changed file with 9 additions and 10 deletions.
19 changes: 9 additions & 10 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,19 +61,18 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
binder.getLoc(), resultType, alphaMulX, constBeta,
/*alpha=*/constOne);

// Expression: min(1, alpha * x + beta)
Value constantOne = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(1));
Value oneTensor = createRank0Tensor(rewriter, binder.getLoc(),
resultType, constantOne);
// Expression: min(1.0, alpha * x + beta)
Value oneTensor =
createRank0Tensor(rewriter, binder.getLoc(), resultType, constOne);
Value minExpression = rewriter.create<Torch::AtenMinimumOp>(
binder.getLoc(), resultType, oneTensor, alphaMulXPlusBeta);

// Expression: max(0, min(1, alpha * x + beta))
Value constantZero = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(0));
Value zeroTensor = createRank0Tensor(rewriter, binder.getLoc(),
resultType, constantZero);
// Expression: max(0.0, min(1.0, alpha * x + beta))
Value constZero = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
rewriter.getF64FloatAttr(0.0));
Value zeroTensor =
createRank0Tensor(rewriter, binder.getLoc(), resultType, constZero);
rewriter.replaceOpWithNewOp<Torch::AtenMaximumOp>(
binder.op, resultType, zeroTensor, minExpression);
return success();
Expand Down

0 comments on commit fad7ddb

Please sign in to comment.