Skip to content

Commit

Permalink
[ONNX] Fix Onnx.Selu lowering and canonicalizer for IntImplicit op
Browse files Browse the repository at this point in the history
Signed-Off By: Vivek Khandelwal <[email protected]>
  • Loading branch information
vivekkhandelwal1 committed Apr 24, 2024
1 parent cff2f08 commit f1b16d8
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 14 deletions.
32 changes: 27 additions & 5 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -851,11 +851,14 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
float alpha, gamma;
Value operand;
if (binder.tensorOperand(operand) ||
binder.f32FloatAttr(alpha, "alpha") ||
binder.f32FloatAttr(gamma, "gamma") ||
binder.f32FloatAttr(alpha, "alpha", 1.67326) ||
binder.f32FloatAttr(gamma, "gamma", 1.0507) ||
binder.tensorResultType(resultType))
return failure();

Torch::ValueTensorType inputType =
operand.getType().cast<Torch::ValueTensorType>();

Value vAlpha = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
rewriter.getFloatAttr(rewriter.getF64Type(), alpha));
Expand All @@ -864,12 +867,31 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
rewriter.getFloatAttr(rewriter.getF64Type(), gamma));

Value vInputScale = rewriter.create<Torch::ConstantFloatOp>(
Value cstOne = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
rewriter.getFloatAttr(rewriter.getF64Type(), 1.0));

rewriter.replaceOpWithNewOp<Torch::AtenEluOp>(
binder.op, resultType, operand, vAlpha, vScale, vInputScale);
Value cstNone = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
Value zeroTensor = rewriter.create<Torch::AtenZerosLikeOp>(
binder.getLoc(), resultType, operand, cstNone, cstNone, cstNone,
cstNone, cstNone);
Value exp = rewriter.create<Torch::AtenExpOp>(binder.getLoc(),
resultType, operand);
Value expMulAlpha = rewriter.create<Torch::AtenMulScalarOp>(
binder.getLoc(), resultType, exp, vAlpha);
Value expMulAlphaSubAlpha = rewriter.create<Torch::AtenSubScalarOp>(
binder.getLoc(), resultType, expMulAlpha, vAlpha, cstOne);
Value neg = rewriter.create<Torch::AtenMulScalarOp>(
binder.getLoc(), resultType, expMulAlphaSubAlpha, vScale);
Value pos = rewriter.create<Torch::AtenMulScalarOp>(
binder.getLoc(), resultType, operand, vScale);
Type compareType = inputType.getWithSizesAndDtype(
inputType.getOptionalSizes(), rewriter.getI1Type());
Value xLessThanZero = rewriter.create<Torch::AtenLtTensorOp>(
binder.getLoc(), compareType, operand, zeroTensor);

rewriter.replaceOpWithNewOp<Torch::AtenWhereSelfOp>(
binder.op, resultType, xLessThanZero, neg, pos);
return success();
});
patterns.onOp("ReduceL1", 1,
Expand Down
20 changes: 14 additions & 6 deletions lib/Dialect/Torch/IR/TorchOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,19 +140,27 @@ static Value getScalarIntValue(Value input, Location loc,
return nullptr;

Type inputDtype = inputTensorType.getOptionalDtype();
if (!inputDtype || !inputDtype.isInteger(64))
if (!inputDtype || !(inputDtype.isInteger(64) || inputDtype.isInteger(1)))
return nullptr;

std::optional<unsigned> inputRank = getTensorRank(input);
if (!inputRank || *inputRank != 0)
return nullptr;

if (auto valueTensorLiteralOp = input.getDefiningOp<ValueTensorLiteralOp>()) {
auto val = valueTensorLiteralOp.getValue()
.cast<DenseIntElementsAttr>()
.getSplatValue<int64_t>();
return rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(val));
if (inputDtype.isInteger(64)) {
auto val = valueTensorLiteralOp.getValue()
.cast<DenseIntElementsAttr>()
.getSplatValue<int64_t>();
return rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(val));
} else {
auto val = valueTensorLiteralOp.getValue()
.cast<DenseIntElementsAttr>()
.getSplatValue<bool>();
return rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(val));
}
} else if (auto primNumToTensorScalarOp =
input.getDefiningOp<PrimNumToTensorScalarOp>()) {
return primNumToTensorScalarOp.getA();
Expand Down
3 changes: 0 additions & 3 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2154,7 +2154,6 @@
"ElementwiseAtenFloorDivideTensorNegativeModule_basic",
"ElementwiseLog10IntModule_basic",
"ElementwiseLog2IntModule_basic",
"ElementwiseSeluModule_basic",
"FlipModuleStaticShape_basic",
"FlipNegativeIndexModule_basic",
"HardsigmoidModule_basic",
Expand Down Expand Up @@ -2669,8 +2668,6 @@
"CopyWithDifferentDTypesModule_basic",
"CosineSimilarityStaticBroadcastModule_basic",
"CumsumInputDtypeInt32Module_basic",
"DropoutTrainModule_basic",
"DropoutTrainStaticShapeModule_basic",
"ElementwiseAcosIntModule_basic",
"ElementwiseAsinIntModule_basic",
"ElementwiseAtanTensorIntModule_basic",
Expand Down

0 comments on commit f1b16d8

Please sign in to comment.