Skip to content

Commit

Permalink
Fix for hardsigmoid
Browse files Browse the repository at this point in the history
  • Loading branch information
vivekkhandelwal1 committed Apr 24, 2024
1 parent 3097a1a commit 705355a
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 6 deletions.
13 changes: 9 additions & 4 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,17 +52,22 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
rewriter.getF64FloatAttr(beta));

// Expression: alpha * x + beta
Value alpha_x_plus_beta = rewriter.create<Torch::AtenAddScalarOp>(
binder.getLoc(), resultType, tensorOperand, constBeta,
/*alpha=*/constAlpha);
Value alphaMulX = rewriter.create<Torch::AtenMulScalarOp>(
binder.getLoc(), resultType, tensorOperand, constAlpha);
Value constOne = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
rewriter.getF64FloatAttr(1.0));
Value alphaMulXPlusBeta = rewriter.create<Torch::AtenAddScalarOp>(
binder.getLoc(), resultType, alphaMulX, constBeta,
/*alpha=*/constOne);

// Expression: min(1, alpha * x + beta)
Value constantOne = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(1));
Value oneTensor = createRank0Tensor(rewriter, binder.getLoc(),
resultType, constantOne);
Value minExpression = rewriter.create<Torch::AtenMinimumOp>(
binder.getLoc(), resultType, oneTensor, alpha_x_plus_beta);
binder.getLoc(), resultType, oneTensor, alphaMulXPlusBeta);

// Expression: max(0, min(1, alpha * x + beta))
Value constantZero = rewriter.create<Torch::ConstantIntOp>(
Expand Down
4 changes: 2 additions & 2 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2156,8 +2156,8 @@
"ElementwiseLog2IntModule_basic",
"FlipModuleStaticShape_basic",
"FlipNegativeIndexModule_basic",
"HardsigmoidModule_basic",
"HardsigmoidRandomModule_basic",
# "HardsigmoidModule_basic",
# "HardsigmoidRandomModule_basic",
"PixelShuffleModuleStaticRank4Float32_basic",
"ReflectionPad1dModule2dInput_Right",
"ReflectionPad1dModule2dInput_basic",
Expand Down

0 comments on commit 705355a

Please sign in to comment.