Commit

Address PR comments

vivekkhandelwal1 committed Jun 20, 2024
1 parent 066177f commit a939959
Showing 2 changed files with 15 additions and 22 deletions.
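
Context, inferred from the diff itself: both files touch the ONNX HardSigmoid lowering, which implements y = max(0, min(1, alpha * x + beta)). The change builds the 0 and 1 clamp bounds fed to createRank0Tensor with Torch::ConstantFloatOp instead of Torch::ConstantIntOp (the existing float constOne is reused for the upper bound), and updates the FileCheck expectations to the float forms.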
15 changes: 6 additions & 9 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
@@ -46,7 +46,6 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
Value constAlpha = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
rewriter.getF64FloatAttr(alpha));
-
Value constBeta = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
rewriter.getF64FloatAttr(beta));
@@ -62,18 +61,16 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
/*alpha=*/constOne);

// Expression: min(1, alpha * x + beta)
- Value constantOne = rewriter.create<Torch::ConstantIntOp>(
- binder.getLoc(), rewriter.getI64IntegerAttr(1));
- Value oneTensor = createRank0Tensor(rewriter, binder.getLoc(),
- resultType, constantOne);
+ Value oneTensor =
+ createRank0Tensor(rewriter, binder.getLoc(), resultType, constOne);
Value minExpression = rewriter.create<Torch::AtenMinimumOp>(
binder.getLoc(), resultType, oneTensor, alphaMulXPlusBeta);

// Expression: max(0, min(1, alpha * x + beta))
- Value constantZero = rewriter.create<Torch::ConstantIntOp>(
- binder.getLoc(), rewriter.getI64IntegerAttr(0));
- Value zeroTensor = createRank0Tensor(rewriter, binder.getLoc(),
- resultType, constantZero);
+ Value constZero = rewriter.create<Torch::ConstantFloatOp>(
+ binder.getLoc(), rewriter.getF64FloatAttr(0.0));
+ Value zeroTensor =
+ createRank0Tensor(rewriter, binder.getLoc(), resultType, constZero);
rewriter.replaceOpWithNewOp<Torch::AtenMaximumOp>(
binder.op, resultType, zeroTensor, minExpression);
return success();
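
Stitched together with the context GitHub folds away, the handler after this commit reads roughly as below. This is a hedged reconstruction, not a verbatim copy of the file: the onOp registration, the binder plumbing, and the alpha/beta defaults (0.2 and 0.5, per the ONNX HardSigmoid spec) are assumptions filled in around the visible hunks.

patterns.onOp(
    "HardSigmoid", 6,
    [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // Sketch only: binder boilerplate and attribute defaults are assumed.
      Torch::ValueTensorType resultType;
      Value tensorOperand;
      float alpha, beta;
      if (binder.tensorOperand(tensorOperand) ||
          binder.f32FloatAttr(alpha, "alpha", 0.2f) || // assumed default
          binder.f32FloatAttr(beta, "beta", 0.5f) ||   // assumed default
          binder.tensorResultType(resultType))
        return failure();
      Value constAlpha = rewriter.create<Torch::ConstantFloatOp>(
          binder.getLoc(), rewriter.getType<Torch::FloatType>(),
          rewriter.getF64FloatAttr(alpha));
      Value constBeta = rewriter.create<Torch::ConstantFloatOp>(
          binder.getLoc(), rewriter.getType<Torch::FloatType>(),
          rewriter.getF64FloatAttr(beta));
      // Expression: alpha * x + beta. constOne is the float 1.0 required by
      // the add's alpha operand; it is reused below as the upper clamp bound.
      Value constOne = rewriter.create<Torch::ConstantFloatOp>(
          binder.getLoc(), rewriter.getType<Torch::FloatType>(),
          rewriter.getF64FloatAttr(1.0));
      Value alphaMulX = rewriter.create<Torch::AtenMulScalarOp>(
          binder.getLoc(), resultType, tensorOperand, constAlpha);
      Value alphaMulXPlusBeta = rewriter.create<Torch::AtenAddScalarOp>(
          binder.getLoc(), resultType, alphaMulX, constBeta,
          /*alpha=*/constOne);
      // Expression: min(1, alpha * x + beta)
      Value oneTensor =
          createRank0Tensor(rewriter, binder.getLoc(), resultType, constOne);
      Value minExpression = rewriter.create<Torch::AtenMinimumOp>(
          binder.getLoc(), resultType, oneTensor, alphaMulXPlusBeta);
      // Expression: max(0, min(1, alpha * x + beta)); the zero bound is now
      // also a float, matching the f32 element type of the result.
      Value constZero = rewriter.create<Torch::ConstantFloatOp>(
          binder.getLoc(), rewriter.getF64FloatAttr(0.0));
      Value zeroTensor =
          createRank0Tensor(rewriter, binder.getLoc(), resultType, constZero);
      rewriter.replaceOpWithNewOp<Torch::AtenMaximumOp>(
          binder.op, resultType, zeroTensor, minExpression);
      return success();
    });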
22 changes: 9 additions & 13 deletions test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
@@ -894,20 +894,18 @@ func.func @test_hardsigmoid_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vt
// CHECK: %[[ALPHA_MULTI_X:.*]] = torch.aten.mul.Scalar %arg0, %[[ALPHA_FLOAT]] : !torch.vtensor<[3],f32>, !torch.float -> !torch.vtensor<[3],f32>
// CHECK: %[[F1:.*]] = torch.constant.float 1.000000e+00
// CHECK: %[[ALPHA_MULTI_X_PLUS_BETA:.*]] = torch.aten.add.Scalar %[[ALPHA_MULTI_X]], %[[BETA_FLOAT]], %[[F1]] : !torch.vtensor<[3],f32>, !torch.float, !torch.float -> !torch.vtensor<[3],f32>
- // CHECK: %[[INT_1:.*]] = torch.constant.int 1
// CHECK: %[[TENSOR_DIMENSION_LIST_FOR_ONE:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
// CHECK: %[[NONE_FOR_ONE:.*]] = torch.constant.none
// CHECK: %[[INT_TYPE_FOR_TENSOR_ONE:.*]] = torch.constant.int 6
- // CHECK: %[[ONE_TENSOR:.*]] = torch.aten.full %[[TENSOR_DIMENSION_LIST_FOR_ONE:.*]], %[[INT_1:.*]], %[[INT_TYPE_FOR_TENSOR_ONE:.*]], %[[NONE_FOR_ONE:.*]], %[[NONE_1:.*]], %[[NONE_1:.*]] : !torch.list<int>, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32>
+ // CHECK: %[[ONE_TENSOR:.*]] = torch.aten.full %[[TENSOR_DIMENSION_LIST_FOR_ONE:.*]], %[[F1:.*]], %[[INT_TYPE_FOR_TENSOR_ONE:.*]], %[[NONE_FOR_ONE:.*]], %[[NONE_1:.*]], %[[NONE_1:.*]] : !torch.list<int>, !torch.float, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32>
// CHECK: %[[MIN_EXPRESSION:.*]] = torch.aten.minimum %[[ONE_TENSOR:.*]], %[[ALPHA_MULTI_X_PLUS_BETA:.*]] : !torch.vtensor<[],f32>, !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
- // CHECK: %[[INT_0:.*]] = torch.constant.int 0
+ // CHECK: %[[F0:.*]] = torch.constant.float 0.000000e+00
// CHECK: %[[TENSOR_DIMENSION_LIST_FOR_ZERO:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
// CHECK: %[[NONE_FOR_ZERO:.*]] = torch.constant.none
// CHECK: %[[INT_TYPE_FOR_TENSOR_ZERO:.*]] = torch.constant.int 6
- // CHECK: %[[ZERO_TENSOR:.*]] = torch.aten.full %[[TENSOR_DIMENSION_LIST_FOR_ZERO:.*]], %[[INT_0:.*]], %[[INT_TYPE_FOR_TENSOR_ZERO:.*]], %[[NONE_FOR_ZERO:.*]], %none_0, %none_0 : !torch.list<int>, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32>
+ // CHECK: %[[ZERO_TENSOR:.*]] = torch.aten.full %[[TENSOR_DIMENSION_LIST_FOR_ZERO:.*]], %[[F0:.*]], %[[INT_TYPE_FOR_TENSOR_ZERO:.*]], %[[NONE_FOR_ZERO:.*]], %none_0, %none_0 : !torch.list<int>, !torch.float, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32>
// CHECK: %[[RESULT:.*]] = torch.aten.maximum %[[ZERO_TENSOR:.*]], %[[MIN_EXPRESSION:.*]] : !torch.vtensor<[],f32>, !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
// CHECK: return %[[RESULT:.*]] : !torch.vtensor<[3],f32>
-
%0 = torch.operator "onnx.HardSigmoid"(%arg0) {torch.onnx.alpha = 5.000000e-01 : f32, torch.onnx.beta = 6.000000e-01 : f32} : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32>
return %0 : !torch.vtensor<[3],f32>
}
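
The CHECK lines above spell out what createRank0Tensor expands to: an empty torch.prim.ListConstruct for the shape plus a torch.aten.full whose fill value is the scalar passed in. A hedged sketch of such a helper, inferred from that emitted IR rather than copied from torch-mlir's actual utility, looks like this (the hardcoded dtype code 6 mirrors the torch.constant.int 6 for f32 in these tests; the real helper presumably derives it from the result type):

// Inferred sketch; the name and exact builder calls are assumptions.
static Value createRank0TensorSketch(PatternRewriter &rewriter, Location loc,
                                     Torch::BaseTensorType inputType,
                                     Value scalar) {
  // Rank-0 result type: same dtype as the op result, empty shape.
  SmallVector<int64_t> sizes;
  Type rank0Ty = inputType.getWithSizesAndDtype(ArrayRef(sizes),
                                                inputType.getOptionalDtype());
  // torch.prim.ListConstruct : () -> !torch.list<int>  (empty size list)
  Value sizeList = rewriter.create<Torch::PrimListConstructOp>(
      loc, Torch::ListType::get(Torch::IntType::get(rewriter.getContext())),
      ValueRange{});
  Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
  // torch.constant.int 6: the torch dtype code for float32 in the tests.
  Value dtype = rewriter.create<Torch::ConstantIntOp>(
      loc, rewriter.getI64IntegerAttr(6));
  // torch.aten.full %sizes, %scalar, %dtype, %layout, %device, %pin_memory:
  // the scalar becomes the fill value, which is why its Torch type (now
  // !torch.float rather than !torch.int) shows up in the checks above.
  return rewriter.create<Torch::AtenFullOp>(loc, rank0Ty, sizeList, scalar,
                                            dtype, /*layout=*/none,
                                            /*device=*/none,
                                            /*pin_memory=*/none);
}

Presumably this fill-value/dtype agreement is what the PR comments asked for: with the old integer constants, an f32 rank-0 tensor was being filled from a !torch.int scalar.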
@@ -921,17 +919,16 @@ func.func @test_hardsigmoid(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtenso
// CHECK: %[[ALPHA_MULTI_X:.*]] = torch.aten.mul.Scalar %arg0, %[[ALPHA_FLOAT]] : !torch.vtensor<[3,4,5],f32>, !torch.float -> !torch.vtensor<[3,4,5],f32>
// CHECK: %[[F1:.*]] = torch.constant.float 1.000000e+00
// CHECK: %[[ALPHA_MULTI_X_PLUS_BETA:.*]] = torch.aten.add.Scalar %[[ALPHA_MULTI_X]], %[[BETA_FLOAT]], %[[F1]] : !torch.vtensor<[3,4,5],f32>, !torch.float, !torch.float -> !torch.vtensor<[3,4,5],f32>
- // CHECK: %[[INT_1:.*]] = torch.constant.int 1
// CHECK: %[[TENSOR_DIMENSION_LIST_FOR_ONE:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
// CHECK: %[[NONE_FOR_ONE:.*]] = torch.constant.none
// CHECK: %[[INT_TYPE_FOR_TENSOR_ONE:.*]] = torch.constant.int 6
- // CHECK: %[[ONE_TENSOR:.*]] = torch.aten.full %[[TENSOR_DIMENSION_LIST_FOR_ONE:.*]], %[[INT_1:.*]], %[[INT_TYPE_FOR_TENSOR_ONE:.*]], %[[NONE_FOR_ONE:.*]], %[[NONE_1:.*]], %[[NONE_1:.*]] : !torch.list<int>, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32>
+ // CHECK: %[[ONE_TENSOR:.*]] = torch.aten.full %[[TENSOR_DIMENSION_LIST_FOR_ONE:.*]], %[[F1:.*]], %[[INT_TYPE_FOR_TENSOR_ONE:.*]], %[[NONE_FOR_ONE:.*]], %[[NONE_1:.*]], %[[NONE_1:.*]] : !torch.list<int>, !torch.float, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32>
// CHECK: %[[MIN_EXPRESSION:.*]] = torch.aten.minimum %[[ONE_TENSOR:.*]], %[[ALPHA_MULTI_X_PLUS_BETA:.*]] : !torch.vtensor<[],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
- // CHECK: %[[INT_0:.*]] = torch.constant.int 0
+ // CHECK: %[[F0:.*]] = torch.constant.float 0.000000e+00
// CHECK: %[[TENSOR_DIMENSION_LIST_FOR_ZERO:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
// CHECK: %[[NONE_FOR_ZERO:.*]] = torch.constant.none
// CHECK: %[[INT_TYPE_FOR_TENSOR_ZERO:.*]] = torch.constant.int 6
- // CHECK: %[[ZERO_TENSOR:.*]] = torch.aten.full %[[TENSOR_DIMENSION_LIST_FOR_ZERO:.*]], %[[INT_0:.*]], %[[INT_TYPE_FOR_TENSOR_ZERO:.*]], %[[NONE_FOR_ZERO:.*]], %none_0, %none_0 : !torch.list<int>, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32>
+ // CHECK: %[[ZERO_TENSOR:.*]] = torch.aten.full %[[TENSOR_DIMENSION_LIST_FOR_ZERO:.*]], %[[F0:.*]], %[[INT_TYPE_FOR_TENSOR_ZERO:.*]], %[[NONE_FOR_ZERO:.*]], %none_0, %none_0 : !torch.list<int>, !torch.float, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32>
// CHECK: %[[RESULT:.*]] = torch.aten.maximum %[[ZERO_TENSOR:.*]], %[[MIN_EXPRESSION:.*]] : !torch.vtensor<[],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
// CHECK: return %[[RESULT:.*]] : !torch.vtensor<[3,4,5],f32>
%0 = torch.operator "onnx.HardSigmoid"(%arg0) {torch.onnx.alpha = 5.000000e-01 : f32, torch.onnx.beta = 6.000000e-01 : f32} : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32>
@@ -947,17 +944,16 @@ func.func @test_hardsigmoid_default(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torc
// CHECK: %[[ALPHA_MULTI_X:.*]] = torch.aten.mul.Scalar %arg0, %[[ALPHA_FLOAT]] : !torch.vtensor<[3,4,5],f32>, !torch.float -> !torch.vtensor<[3,4,5],f32>
// CHECK: %[[F1:.*]] = torch.constant.float 1.000000e+00
// CHECK: %[[ALPHA_MULTI_X_PLUS_BETA:.*]] = torch.aten.add.Scalar %[[ALPHA_MULTI_X]], %[[BETA_FLOAT]], %[[F1]] : !torch.vtensor<[3,4,5],f32>, !torch.float, !torch.float -> !torch.vtensor<[3,4,5],f32>
- // CHECK: %[[INT_1:.*]] = torch.constant.int 1
// CHECK: %[[TENSOR_DIMENSION_LIST_FOR_ONE:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
// CHECK: %[[NONE_FOR_ONE:.*]] = torch.constant.none
// CHECK: %[[INT_TYPE_FOR_TENSOR_ONE:.*]] = torch.constant.int 6
- // CHECK: %[[ONE_TENSOR:.*]] = torch.aten.full %[[TENSOR_DIMENSION_LIST_FOR_ONE:.*]], %[[INT_1:.*]], %[[INT_TYPE_FOR_TENSOR_ONE:.*]], %[[NONE_FOR_ONE:.*]], %[[NONE_1:.*]], %[[NONE_1:.*]] : !torch.list<int>, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32>
+ // CHECK: %[[ONE_TENSOR:.*]] = torch.aten.full %[[TENSOR_DIMENSION_LIST_FOR_ONE:.*]], %[[F1:.*]], %[[INT_TYPE_FOR_TENSOR_ONE:.*]], %[[NONE_FOR_ONE:.*]], %[[NONE_1:.*]], %[[NONE_1:.*]] : !torch.list<int>, !torch.float, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32>
// CHECK: %[[MIN_EXPRESSION:.*]] = torch.aten.minimum %[[ONE_TENSOR:.*]], %[[ALPHA_MULTI_X_PLUS_BETA:.*]] : !torch.vtensor<[],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
- // CHECK: %[[INT_0:.*]] = torch.constant.int 0
+ // CHECK: %[[F0:.*]] = torch.constant.float 0.000000e+00
// CHECK: %[[TENSOR_DIMENSION_LIST_FOR_ZERO:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
// CHECK: %[[NONE_FOR_ZERO:.*]] = torch.constant.none
// CHECK: %[[INT_TYPE_FOR_TENSOR_ZERO:.*]] = torch.constant.int 6
- // CHECK: %[[ZERO_TENSOR:.*]] = torch.aten.full %[[TENSOR_DIMENSION_LIST_FOR_ZERO:.*]], %[[INT_0:.*]], %[[INT_TYPE_FOR_TENSOR_ZERO:.*]], %[[NONE_FOR_ZERO:.*]], %none_0, %none_0 : !torch.list<int>, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32>
+ // CHECK: %[[ZERO_TENSOR:.*]] = torch.aten.full %[[TENSOR_DIMENSION_LIST_FOR_ZERO:.*]], %[[F0:.*]], %[[INT_TYPE_FOR_TENSOR_ZERO:.*]], %[[NONE_FOR_ZERO:.*]], %none_0, %none_0 : !torch.list<int>, !torch.float, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32>
// CHECK: torch.aten.maximum %[[ZERO_TENSOR:.*]], %[[MIN_EXPRESSION:.*]] : !torch.vtensor<[],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
%0 = torch.operator "onnx.HardSigmoid"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32>
return %0 : !torch.vtensor<[3,4,5],f32>
