Skip to content

Commit

Permalink
[torch] Add OnnxToTorch lowering for onnx.ReduceL2 (#3175)
Browse files Browse the repository at this point in the history
Adds OnnxToTorch lowering for the ReduceL2 op.
  • Loading branch information
vinayakdsci authored Apr 23, 2024
1 parent 3c252cd commit cff2f08
Show file tree
Hide file tree
Showing 3 changed files with 154 additions and 6 deletions.
55 changes: 55 additions & 0 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -891,6 +891,61 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
/*storeValue=*/operand, keepDims,
noop_with_empty_axes, false);
});
patterns.onOp(
"ReduceL2", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value operand;
int64_t keepDims, noop_with_empty_axes;
if (binder.tensorOperandAtIndex(operand, 0) ||
binder.tensorResultType(resultType) ||
binder.s64IntegerAttr(keepDims, "keepdims", 1) ||
binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes",
0))
return failure();

// A ReduceL2 op is equivalent to the following sequence of operations:
// Mul(x, x) -> ReduceSum -> CastF32 -> Sqrt -> CastLike(resultType)
Value squareOfOperand = rewriter.create<Torch::AtenMulTensorOp>(
binder.getLoc(), operand.getType(), operand, operand);

auto reducedSum =
reducedSumImpl(binder, rewriter, squareOfOperand, resultType,
operand, keepDims, noop_with_empty_axes, true);
if (failed(reducedSum))
return rewriter.notifyMatchFailure(
binder.op,
"Failed to perform sum operation on square of operand");

Value castDType = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(/*Float32Type*/ 6));

Value noneVal = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
Value constFalse =
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);

// Perform an AtenToDtype op on the squared sum of the operand, stored
// now in operand itself.
auto size = operand.getType()
.dyn_cast<Torch::ValueTensorType>()
.getOptionalSizes();
auto f32ResultType = rewriter.getType<Torch::ValueTensorType>(
size, rewriter.getF32Type());
Value operandCast = rewriter.create<Torch::AtenToDtypeOp>(
binder.getLoc(), f32ResultType, operand, castDType,
/*non_blocking=*/constFalse, /*copy=*/constFalse,
/*memory_format=*/noneVal);

Value operandSqrt = rewriter.create<Torch::AtenSqrtOp>(
binder.getLoc(), f32ResultType, operandCast);

Value resultDtype = Torch::getDtypeIntValueForType(
rewriter, binder.getLoc(), resultType.getDtype());
rewriter.replaceOpWithNewOp<Torch::AtenToDtypeOp>(
binder.op, resultType, operandSqrt, resultDtype,
/*non_blocking=*/constFalse, /*copy=*/constFalse,
/*memory_format=*/noneVal);
return success();
});
patterns.onOp("ReduceSum", 1,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Expand Down
6 changes: 0 additions & 6 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2615,12 +2615,6 @@
"BernoulliPModule_basic",
"BernoulliTensorModule_basic",

# Failure - onnx_lowering: onnx.ReduceL2
"LinalgNormKeepDimModule_basic",
"LinalgNormModule_basic",
"NormalizeModule_basic",
"ReduceL2NormModule_basic",

# Failure - onnx_lowering: onnx.ReduceProd
"ReduceProdDimIntFloatModule_basic",

Expand Down
99 changes: 99 additions & 0 deletions test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -761,6 +761,105 @@ func.func @test_reduce_l1_do_not_keepdims_example(%arg0:!torch.vtensor<[3,2,2],f

// -----

// CHECK-LABEL: func.func @test_reduce_l2_default_axes_keepdims_example
func.func @test_reduce_l2_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[MULT:.+]] = torch.aten.mul.Tensor %arg0, %arg0 : !torch.vtensor<[3,2,2],f32>, !torch.vtensor<[3,2,2],f32> -> !torch.vtensor<[3,2,2],f32>
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
// CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct : () -> !torch.list<int>
// CHECK: %[[TRUE_0:.+]] = torch.constant.bool true
// CHECK: %[[NONE_0:.+]] = torch.constant.none
// CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[MULT]], %[[DIMS]], %[[TRUE_0]], %[[NONE_0]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32>
// CHECK: %[[INT6_0:.+]] = torch.constant.int 6
// CHECK: %[[NONE_1:.+]] = torch.constant.none
// CHECK: %[[FALSE:.+]] = torch.constant.bool false
// CHECK: %[[CAST:.+]] = torch.aten.to.dtype %[[SUM]], %[[INT6_0]], %[[FALSE]], %[[FALSE]], %[[NONE_1]] : !torch.vtensor<[1,1,1],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32>
// CHECK: %[[SQRT:.+]] = torch.aten.sqrt %[[CAST]] : !torch.vtensor<[1,1,1],f32> -> !torch.vtensor<[1,1,1],f32>
// CHECK: %[[INT6_1:.+]] = torch.constant.int 6
// CHECK: %[[CASTLIKE:.+]] = torch.aten.to.dtype %[[SQRT]], %[[INT6_1]], %[[FALSE]], %[[FALSE]], %[[NONE_1]] : !torch.vtensor<[1,1,1],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32>
// CHECK: return %[[CASTLIKE]] : !torch.vtensor<[1,1,1],f32>
%0 = torch.operator "onnx.ReduceL2"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32>
return %0 : !torch.vtensor<[1,1,1],f32>
}

// -----

// CHECK-LABEL: func.func @test_reduce_l2_do_not_keepdims_example_expanded
func.func @test_reduce_l2_do_not_keepdims_example_expanded(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[MULT:.+]] = torch.aten.mul.Tensor %arg0, %arg0 : !torch.vtensor<[3,2,2],f32>, !torch.vtensor<[3,2,2],f32> -> !torch.vtensor<[3,2,2],f32>
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
// CHECK: %[[INT0_1:.+]] = torch.constant.int 0
// CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0_0]], %[[INT0_1]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: %[[ITEM:.+]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[FALSE_0:.+]] = torch.constant.bool false
// CHECK: %[[NONE_0:.+]] = torch.constant.none
// CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[MULT]], %[[DIMS]], %[[FALSE_0]], %[[NONE_0]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,2],f32>
// CHECK: %[[INT6_0:.+]] = torch.constant.int 6
// CHECK: %[[NONE_1:.+]] = torch.constant.none
// CHECK: %[[FALSE_1:.+]] = torch.constant.bool false
// CHECK: %[[CAST:.+]] = torch.aten.to.dtype %[[SUM]], %[[INT6_0]], %[[FALSE_1]], %[[FALSE_1]], %[[NONE_1]] : !torch.vtensor<[3,2],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2],f32>
// CHECK: %[[SQRT:.+]] = torch.aten.sqrt %[[CAST]] : !torch.vtensor<[3,2],f32> -> !torch.vtensor<[3,2],f32>
// CHECK: %[[INT6_1:.+]] = torch.constant.int 6
// CHECK: %[[CASTLIKE:.+]] = torch.aten.to.dtype %[[SQRT]], %[[INT6_1]], %[[FALSE_1]], %[[FALSE_1]], %[[NONE_1]] : !torch.vtensor<[3,2],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2],f32>
// CHECK: return %[[CASTLIKE]] : !torch.vtensor<[3,2],f32>
%0 = torch.operator "onnx.ReduceL2"(%arg0, %arg1) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32>
return %0 : !torch.vtensor<[3,2],f32>
}

// -----

// CHECK-LABEL: func.func @test_reduce_l2_keep_dims_example
func.func @test_reduce_l2_keep_dims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[MULT:.+]] = torch.aten.mul.Tensor %arg0, %arg0 : !torch.vtensor<[3,2,2],f32>, !torch.vtensor<[3,2,2],f32> -> !torch.vtensor<[3,2,2],f32>
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
// CHECK: %[[INT0_1:.+]] = torch.constant.int 0
// CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0_0]], %[[INT0_1]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: %[[ITEM:.+]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[TRUE:.+]] = torch.constant.bool true
// CHECK: %[[NONE_0:.+]] = torch.constant.none
// CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[MULT]], %[[DIMS]], %[[TRUE]], %[[NONE_0]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,2,1],f32>
// CHECK: %[[INT6_0:.+]] = torch.constant.int 6
// CHECK: %[[NONE_1:.+]] = torch.constant.none
// CHECK: %[[FALSE:.+]] = torch.constant.bool false
// CHECK: %[[CAST:.+]] = torch.aten.to.dtype %[[SUM]], %[[INT6_0]], %[[FALSE]], %[[FALSE]], %[[NONE_1]] : !torch.vtensor<[3,2,1],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2,1],f32>
// CHECK: %[[SQRT:.+]] = torch.aten.sqrt %[[CAST]] : !torch.vtensor<[3,2,1],f32> -> !torch.vtensor<[3,2,1],f32>
// CHECK: %[[INT6_1:.+]] = torch.constant.int 6
// CHECK: %[[CASTLIKE:.+]] = torch.aten.to.dtype %[[SQRT]], %[[INT6_1]], %[[FALSE]], %[[FALSE]], %[[NONE_1]] : !torch.vtensor<[3,2,1],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2,1],f32>
// CHECK: return %[[CASTLIKE]] : !torch.vtensor<[3,2,1],f32>

%0 = torch.operator "onnx.ReduceL2"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2,1],f32>
return %0 : !torch.vtensor<[3,2,1],f32>
}

// -----

// CHECK-LABEL: func.func @test_reduce_l2_keep_dims_int_input_example
func.func @test_reduce_l2_keep_dims_int_input_example(%arg0: !torch.vtensor<[3,2,2],si64>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[MULT:.+]] = torch.aten.mul.Tensor %arg0, %arg0 : !torch.vtensor<[3,2,2],si64>, !torch.vtensor<[3,2,2],si64> -> !torch.vtensor<[3,2,2],si64>
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
// CHECK: %[[INT0_1:.+]] = torch.constant.int 0
// CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0_0]], %[[INT0_1]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: %[[ITEM:.+]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[TRUE:.+]] = torch.constant.bool true
// CHECK: %[[NONE_0:.+]] = torch.constant.none
// CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[MULT]], %[[DIMS]], %[[TRUE]], %[[NONE_0]] : !torch.vtensor<[3,2,2],si64>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,2,1],f32>
// CHECK: %[[INT6_0:.+]] = torch.constant.int 6
// CHECK: %[[NONE_1:.+]] = torch.constant.none
// CHECK: %[[FALSE:.+]] = torch.constant.bool false
// CHECK: %[[CAST:.+]] = torch.aten.to.dtype %[[SUM]], %[[INT6_0]], %[[FALSE]], %[[FALSE]], %[[NONE_1]] : !torch.vtensor<[3,2,1],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2,1],f32>
// CHECK: %[[SQRT:.+]] = torch.aten.sqrt %[[CAST]] : !torch.vtensor<[3,2,1],f32> -> !torch.vtensor<[3,2,1],f32>
// CHECK: %[[INT6_1:.+]] = torch.constant.int 6
// CHECK: %[[CASTLIKE:.+]] = torch.aten.to.dtype %[[SQRT]], %[[INT6_1]], %[[FALSE]], %[[FALSE]], %[[NONE_1]] : !torch.vtensor<[3,2,1],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2,1],f32>
// CHECK: return %[[CASTLIKE]] : !torch.vtensor<[3,2,1],f32>

%0 = torch.operator "onnx.ReduceL2"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2,1],f32>
return %0 : !torch.vtensor<[3,2,1],f32>
}

// -----

// CHECK-LABEL: func.func @test_reduce_sum_default_axes_keepdims_example
func.func @test_reduce_sum_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT0:.+]] = torch.constant.int 0
Expand Down

0 comments on commit cff2f08

Please sign in to comment.