[ONNX] Fix Onnx.Hardsigmoid lowering
Signed-Off By: Vivek Khandelwal <[email protected]>
vivekkhandelwal1 committed Apr 26, 2024
1 parent 7030eac commit ac1ca74
Showing 3 changed files with 18 additions and 9 deletions.
13 changes: 9 additions & 4 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
@@ -52,17 +52,22 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
             rewriter.getF64FloatAttr(beta));
 
         // Expression: alpha * x + beta
-        Value alpha_x_plus_beta = rewriter.create<Torch::AtenAddScalarOp>(
-            binder.getLoc(), resultType, tensorOperand, constBeta,
-            /*alpha=*/constAlpha);
+        Value alphaMulX = rewriter.create<Torch::AtenMulScalarOp>(
+            binder.getLoc(), resultType, tensorOperand, constAlpha);
+        Value constOne = rewriter.create<Torch::ConstantFloatOp>(
+            binder.getLoc(), rewriter.getType<Torch::FloatType>(),
+            rewriter.getF64FloatAttr(1.0));
+        Value alphaMulXPlusBeta = rewriter.create<Torch::AtenAddScalarOp>(
+            binder.getLoc(), resultType, alphaMulX, constBeta,
+            /*alpha=*/constOne);
 
         // Expression: min(1, alpha * x + beta)
         Value constantOne = rewriter.create<Torch::ConstantIntOp>(
             binder.getLoc(), rewriter.getI64IntegerAttr(1));
         Value oneTensor = createRank0Tensor(rewriter, binder.getLoc(),
                                             resultType, constantOne);
         Value minExpression = rewriter.create<Torch::AtenMinimumOp>(
-            binder.getLoc(), resultType, oneTensor, alpha_x_plus_beta);
+            binder.getLoc(), resultType, oneTensor, alphaMulXPlusBeta);
 
         // Expression: max(0, min(1, alpha * x + beta))
         Value constantZero = rewriter.create<Torch::ConstantIntOp>(
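Why the change: torch.aten.add.Scalar(self, other, alpha) computes self + alpha * other, so the old single-op lowering produced x + alpha * beta instead of the ONNX HardSigmoid expression alpha * x + beta. The fix multiplies by alpha first, then adds beta with a unit alpha. A minimal NumPy sketch (illustration only, not part of this commit) contrasting the two:

import numpy as np

def onnx_hardsigmoid(x, alpha=0.2, beta=0.5):
    # ONNX spec: y = max(0, min(1, alpha * x + beta))
    return np.maximum(0.0, np.minimum(1.0, alpha * x + beta))

def old_lowering(x, alpha, beta):
    # The pre-fix lowering emitted aten.add.Scalar(x, beta, alpha),
    # i.e. x + alpha * beta, before the same min/max clamping.
    return np.maximum(0.0, np.minimum(1.0, x + alpha * beta))

x = np.array([-1.0, 0.0, 1.0], dtype=np.float32)
print(onnx_hardsigmoid(x, 0.5, 0.6))  # [0.1 0.6 1. ]
print(old_lowering(x, 0.5, 0.6))      # [0.  0.3 1. ]  (wrong)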
2 changes: 0 additions & 2 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -2183,8 +2183,6 @@
     "ElementwiseSeluModule_basic",
     "FlipModuleStaticShape_basic",
     "FlipNegativeIndexModule_basic",
-    "HardsigmoidModule_basic",
-    "HardsigmoidRandomModule_basic",
     "PixelShuffleModuleStaticRank4Float32_basic",
     "ReflectionPad1dModule2dInput_Right",
     "ReflectionPad1dModule2dInput_basic",
12 changes: 9 additions & 3 deletions test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
@@ -627,7 +627,9 @@ func.func @test_pad_optional_constant(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !
 func.func @test_hardsigmoid_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 6 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: %[[ALPHA_FLOAT:.*]] = torch.constant.float 5.000000e-01
   // CHECK: %[[BETA_FLOAT:.*]] = torch.constant.float 0.60000002384185791
-  // CHECK: %[[ALPHA_MULTI_X_PLUS_BETA:.*]] = torch.aten.add.Scalar %arg0, %[[BETA_FLOAT:.*]], %[[ALPHA_FLOAT:.*]] : !torch.vtensor<[3],f32>, !torch.float, !torch.float -> !torch.vtensor<[3],f32>
+  // CHECK: %[[ALPHA_MULTI_X:.*]] = torch.aten.mul.Scalar %arg0, %[[ALPHA_FLOAT]] : !torch.vtensor<[3],f32>, !torch.float -> !torch.vtensor<[3],f32>
+  // CHECK: %[[F1:.*]] = torch.constant.float 1.000000e+00
+  // CHECK: %[[ALPHA_MULTI_X_PLUS_BETA:.*]] = torch.aten.add.Scalar %[[ALPHA_MULTI_X]], %[[BETA_FLOAT]], %[[F1]] : !torch.vtensor<[3],f32>, !torch.float, !torch.float -> !torch.vtensor<[3],f32>
   // CHECK: %[[INT_1:.*]] = torch.constant.int 1
   // CHECK: %[[TENSOR_DIMENSION_LIST_FOR_ONE:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
   // CHECK: %[[NONE_FOR_ONE:.*]] = torch.constant.none
@@ -652,7 +654,9 @@ func.func @test_hardsigmoid_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vt
 func.func @test_hardsigmoid(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 6 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: %[[ALPHA_FLOAT:.*]] = torch.constant.float 5.000000e-01
   // CHECK: %[[BETA_FLOAT:.*]] = torch.constant.float 0.60000002384185791
-  // CHECK: %[[ALPHA_MULTI_X_PLUS_BETA:.*]] = torch.aten.add.Scalar %arg0, %[[BETA_FLOAT:.*]], %[[ALPHA_FLOAT:.*]] : !torch.vtensor<[3,4,5],f32>, !torch.float, !torch.float -> !torch.vtensor<[3,4,5],f32>
+  // CHECK: %[[ALPHA_MULTI_X:.*]] = torch.aten.mul.Scalar %arg0, %[[ALPHA_FLOAT]] : !torch.vtensor<[3,4,5],f32>, !torch.float -> !torch.vtensor<[3,4,5],f32>
+  // CHECK: %[[F1:.*]] = torch.constant.float 1.000000e+00
+  // CHECK: %[[ALPHA_MULTI_X_PLUS_BETA:.*]] = torch.aten.add.Scalar %[[ALPHA_MULTI_X]], %[[BETA_FLOAT]], %[[F1]] : !torch.vtensor<[3,4,5],f32>, !torch.float, !torch.float -> !torch.vtensor<[3,4,5],f32>
   // CHECK: %[[INT_1:.*]] = torch.constant.int 1
   // CHECK: %[[TENSOR_DIMENSION_LIST_FOR_ONE:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
   // CHECK: %[[NONE_FOR_ONE:.*]] = torch.constant.none
@@ -676,7 +680,9 @@ func.func @test_hardsigmoid(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtenso
 func.func @test_hardsigmoid_default(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 6 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: %[[ALPHA_FLOAT:.*]] = torch.constant.float 0.20000000298023224
   // CHECK: %[[BETA_FLOAT:.*]] = torch.constant.float 5.000000e-01
-  // CHECK: %[[ALPHA_MULTI_X_PLUS_BETA:.*]] = torch.aten.add.Scalar %arg0, %[[BETA_FLOAT:.*]], %[[ALPHA_FLOAT:.*]] : !torch.vtensor<[3,4,5],f32>, !torch.float, !torch.float -> !torch.vtensor<[3,4,5],f32>
+  // CHECK: %[[ALPHA_MULTI_X:.*]] = torch.aten.mul.Scalar %arg0, %[[ALPHA_FLOAT]] : !torch.vtensor<[3,4,5],f32>, !torch.float -> !torch.vtensor<[3,4,5],f32>
+  // CHECK: %[[F1:.*]] = torch.constant.float 1.000000e+00
+  // CHECK: %[[ALPHA_MULTI_X_PLUS_BETA:.*]] = torch.aten.add.Scalar %[[ALPHA_MULTI_X]], %[[BETA_FLOAT]], %[[F1]] : !torch.vtensor<[3,4,5],f32>, !torch.float, !torch.float -> !torch.vtensor<[3,4,5],f32>
   // CHECK: %[[INT_1:.*]] = torch.constant.int 1
   // CHECK: %[[TENSOR_DIMENSION_LIST_FOR_ONE:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
   // CHECK: %[[NONE_FOR_ONE:.*]] = torch.constant.none
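For reference, a PyTorch sketch (an illustrative stand-in, not part of this commit) mirroring the op sequence these CHECK lines verify: mul.Scalar, add.Scalar with unit alpha, then clamping via minimum/maximum, with torch.tensor(1.0)/torch.tensor(0.0) standing in for the rank-0 tensors the lowering builds via createRank0Tensor:

import torch

def lowered_hardsigmoid(x, alpha, beta):
    alpha_mul_x = torch.mul(x, alpha)                        # aten.mul.Scalar
    alpha_mul_x_plus_beta = torch.add(alpha_mul_x, beta,     # aten.add.Scalar
                                      alpha=1)               # with unit alpha
    one = torch.tensor(1.0)    # rank-0 stand-in for createRank0Tensor(..., 1)
    zero = torch.tensor(0.0)   # rank-0 stand-in for createRank0Tensor(..., 0)
    # aten.minimum then aten.maximum: max(0, min(1, alpha * x + beta))
    return torch.maximum(zero, torch.minimum(one, alpha_mul_x_plus_beta))

# Same constants as @test_hardsigmoid_example (alpha = 0.5, beta = 0.6):
print(lowered_hardsigmoid(torch.tensor([-1.0, 0.0, 1.0]), 0.5, 0.6))
# tensor([0.1000, 0.6000, 1.0000])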
