From 524ff99216e99805dd1fe003fb9c2947fb8a9770 Mon Sep 17 00:00:00 2001
From: ptrifunovic98 <156185835+ptrifunovic98@users.noreply.github.com>
Date: Wed, 13 Mar 2024 20:17:22 +0100
Subject: [PATCH] Implement lowering of torch.aten.linalg_cross (#2986)

Closes [nod-ai/SHARK-Turbine#497](https://github.com/nod-ai/SHARK-Turbine/issues/497)
---
 .../Dialect/Torch/IR/GeneratedTorchOps.td     |  26 ++++
 lib/Dialect/Torch/IR/TorchOps.cpp             |  90 ++++++++++++++
 .../Transforms/AbstractInterpLibrary.cpp      |  76 ++++++++++++
 .../Torch/Transforms/DecomposeComplexOps.cpp  | 112 ++++++++++++++++++
 .../Transforms/LowerToBackendContract.cpp     |   1 +
 projects/pt1/e2e_testing/xfail_sets.py        |   3 +
 .../build_tools/abstract_interp_lib_gen.py    |  24 ++++
 .../build_tools/torch_ods_gen.py              |   1 +
 .../torch_mlir_e2e_test/test_suite/matmul.py  | 111 +++++++++++++++++
 9 files changed, 444 insertions(+)

diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 3e92d40992b8..7c0d7a73e89c 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -11732,6 +11732,32 @@ def Torch_AtenLinspaceOp : Torch_Op<"aten.linspace", [
   }];
 }
 
+def Torch_AtenLinalgCrossOp : Torch_Op<"aten.linalg_cross", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::linalg_cross : (Tensor, Tensor, int) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$other,
+    Torch_IntType:$dim
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenLinalgCrossOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 3, 1);
+    }
+    void AtenLinalgCrossOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 3, 1);
+    }
+  }];
+  let hasVerifier = 1;
+}
+
 def Torch_AtenAliasCopyOp : Torch_Op<"aten.alias_copy", [
   AllowsTypeRefinement,
   HasValueSemantics,
diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index 9ecf0e3e262e..30e1ff987aa9 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -4278,6 +4278,96 @@ LogicalResult AtenPermuteOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// AtenLinalgCrossOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult AtenLinalgCrossOp::verify() {
+
+  auto selfType = getSelf().getType().cast<BaseTensorType>();
+  auto otherType = getOther().getType().cast<BaseTensorType>();
+
+  if (!selfType.hasDtype() || !otherType.hasDtype() || !selfType.hasSizes() ||
+      !otherType.hasSizes()) {
+    return success();
+  }
+
+  Type selfDtype = selfType.getDtype();
+  Type otherDtype = otherType.getDtype();
+
+  // the operation succeeds only if both inputs have the same dtype
+  if (selfDtype != otherDtype) {
+    return emitOpError("input tensors must have the same dtype, but got ")
+           << selfDtype << " and " << otherDtype;
+  }
+
+  // Check if any of the input tensors has torch.bool dtype.
+  // The operation does not support this type.
+  // The docs state that only float, double, cfloat and cdouble dtypes are
+  // supported, but, when testing, it fails only for boolean dtype. Update to
+  // fit the docs if necessary.
+  // https://pytorch.org/docs/stable/generated/torch.linalg.cross.html
+  if (selfDtype.isSignlessInteger(1) || otherDtype.isSignlessInteger(1)) {
+    return emitOpError("input tensors must not have bool dtype");
+  }
+
+  ArrayRef<int64_t> selfShape = selfType.getSizes();
+  ArrayRef<int64_t> otherShape = otherType.getSizes();
+
+  int64_t selfRank = selfShape.size();
+  int64_t otherRank = otherShape.size();
+
+  // check if both input tensors have the same number of dims
+  if (selfRank != otherRank) {
+    return emitOpError("input tensors must have the same number of dimensions, "
+                       "but got ")
+           << selfRank << " and " << otherRank;
+  }
+
+  // convert dim to an integer type
+  int64_t dim;
+  if (!matchPattern(getDim(), m_TorchConstantInt(&dim))) {
+    return success();
+  }
+
+  // check if dim is in the correct range
+  if (dim >= selfRank || dim < -selfRank) {
+    return emitOpError("dim expected to be in range of [")
+           << -selfRank << ", " << selfRank - 1 << "], but got " << dim;
+  }
+
+  // compensate for possible negative dim value
+  if (dim < 0) {
+    dim += selfRank;
+  }
+
+  // check if the size of the dimensions specified by 'dim' is equal to 3
+  // (required by the operation)
+  if ((selfShape[dim] != 3 && selfShape[dim] != kUnknownSize) ||
+      (otherShape[dim] != 3 && otherShape[dim] != kUnknownSize)) {
+    return emitOpError("inputs dimension ")
+           << dim << " must have length 3, but got " << selfShape[dim]
+           << " and " << otherShape[dim];
+  }
+
+  // Check if there is a disparity between dimension sizes.
+  // Dimensions at the same index must either have the same size,
+  // or one of them must be equal to 1.
+  int32_t i = 0;
+  for (auto [selfCurrent, otherCurrent] :
+       llvm::zip_equal(selfShape, otherShape)) {
+    if (selfCurrent != otherCurrent && selfCurrent != 1 && otherCurrent != 1) {
+      return emitOpError("the size of first tensor (")
+             << selfCurrent << ") must match the size of second tensor ("
+             << otherCurrent << ") at dimension " << i
+             << " or one of them must be 1";
+    }
+    ++i;
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // DtypeCalculateYieldDtypesOp
 //===----------------------------------------------------------------------===//
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index 19c84617a2a1..c55a2421f5be 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -6793,6 +6793,57 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.linalg_cross\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.int) -> !torch.list<int> {\n"
+"    %str = torch.constant.str \"AssertionError: \"\n"
+"    %str_0 = torch.constant.str \"the size of first tensor ({}) must match the size of second tensor ({}) at dimension {}\"\n"
+"    %true = torch.constant.bool true\n"
+"    %none = torch.constant.none\n"
+"    %str_1 = torch.constant.str \"AssertionError: inputs must have the same number of dimensions\"\n"
+"    %int1 = torch.constant.int 1\n"
+"    %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
+"    %1 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
+"    %2 = torch.aten.eq.int %0, %1 : !torch.int, !torch.int -> !torch.bool\n"
+"    torch.prim.If %2 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %3 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
+"    torch.prim.Loop %3, %true, init() {\n"
+"    ^bb0(%arg3: !torch.int):\n"
+"      %5 = torch.aten.__getitem__.t %arg0, %arg3 : !torch.list<int>, !torch.int -> !torch.int\n"
+"      %6 = torch.aten.__getitem__.t %arg1, %arg3 : !torch.list<int>, !torch.int -> !torch.int\n"
+"      %7 = torch.aten.eq.int %5, %6 : !torch.int, !torch.int -> !torch.bool\n"
+"      %8 = torch.prim.If %7 -> (!torch.bool) {\n"
+"        torch.prim.If.yield %true : !torch.bool\n"
+"      } else {\n"
+"        %10 = torch.aten.__getitem__.t %arg0, %arg3 : !torch.list<int>, !torch.int -> !torch.int\n"
+"        %11 = torch.aten.eq.int %10, %int1 : !torch.int, !torch.int -> !torch.bool\n"
+"        torch.prim.If.yield %11 : !torch.bool\n"
+"      }\n"
+"      %9 = torch.prim.If %8 -> (!torch.bool) {\n"
+"        torch.prim.If.yield %true : !torch.bool\n"
+"      } else {\n"
+"        %10 = torch.aten.__getitem__.t %arg1, %arg3 : !torch.list<int>, !torch.int -> !torch.int\n"
+"        %11 = torch.aten.eq.int %10, %int1 : !torch.int, !torch.int -> !torch.bool\n"
+"        torch.prim.If.yield %11 : !torch.bool\n"
+"      }\n"
+"      torch.prim.If %9 -> () {\n"
+"        torch.prim.If.yield\n"
+"      } else {\n"
+"        %10 = torch.aten.__getitem__.t %arg0, %arg3 : !torch.list<int>, !torch.int -> !torch.int\n"
+"        %11 = torch.aten.__getitem__.t %arg1, %arg3 : !torch.list<int>, !torch.int -> !torch.int\n"
+"        %12 = torch.aten.format(%str_0, %10, %11, %arg3) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str\n"
+"        %13 = torch.aten.add.str %str, %12 : !torch.str, !torch.str -> !torch.str\n"
+"        torch.prim.RaiseException %13, %none : !torch.str, !torch.none\n"
+"        torch.prim.If.yield\n"
+"      }\n"
+"      torch.prim.Loop.condition %true, iter()\n"
+"    } : (!torch.int, !torch.bool) -> ()\n"
+"    %4 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
+"    return %4 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten._log_softmax_backward_data\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
@@ -10033,6 +10084,31 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.linalg_cross\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.int) -> !torch.int {\n"
+"    %int11 = torch.constant.int 11\n"
+"    %none = torch.constant.none\n"
+"    %str = torch.constant.str \"AssertionError: \"\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
+"    %3 = torch.aten.eq.int %0#1, %1#1 : !torch.int, !torch.int -> !torch.bool\n"
+"    torch.prim.If %3 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %4 = torch.aten.ne.int %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n"
+"    torch.prim.If %4 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %5 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
+"    %6 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %5) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
+"    return %6 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten._log_softmax_backward_data\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n"
 "    return %arg3 : !torch.int\n"
 "  }\n"
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 09d5b90f0eeb..5335cbba9bb0 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -1823,6 +1823,117 @@ class DecomposeAtenMvOp : public OpRewritePattern<AtenMvOp> {
 };
 } // namespace
 
+// Decompose aten.linalg_cross into: aten.broadcast_to, aten.index_select,
+// aten.add.Tensor and aten.mul.Tensor. See
+// https://github.com/pytorch/pytorch/blob/ed3c256b61f05720843454a9282aa7c903da2c81/torch/_refs/linalg/__init__.py#L70.
+// def linalg_cross(self: Tensor, other: Tensor, dim: int = -1):
+//     broadcast_shape = compute_broadcast_shape(self, other)
+//     a = torch.broadcast_to(self, broadcast_shape)
+//     b = torch.broadcast_to(other, broadcast_shape)
+//     idx = torch.arange(3)
+//     return a.index_select(dim, (idx + 1) % 3) *
+//            b.index_select(dim, (idx + 2) % 3) -
+//            a.index_select(dim, (idx + 2) % 3) *
+//            b.index_select(dim, (idx + 1) % 3)
+namespace {
+class DecomposeAtenLinalgCrossOp : public OpRewritePattern<AtenLinalgCrossOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenLinalgCrossOp op,
+                                PatternRewriter &rewriter) const override {
+
+    Location loc = op.getLoc();
+    Value self = op.getSelf();
+    Value other = op.getOther();
+    Type opType = op.getType();
+    Value dim = op.getDim();
+
+    auto resType = self.getType().cast<BaseTensorType>();
+    if (!resType.hasDtype()) {
+      return rewriter.notifyMatchFailure(op, "result should have dtype");
+    }
+
+    Type dtype = resType.getDtype();
+    if (dtype.isa<mlir::ComplexType>()) {
+      return rewriter.notifyMatchFailure(
+          op, "lowering of aten.linalg_cross for complex inputs dtype is "
+              "currently unimplemented");
+    }
+
+    // calculate common shape for broadcast
+    SmallVector<int64_t> broadcastShape;
+    SmallVector<Value> broadcastShapeValue;
+    computeBroadcastShape(rewriter, loc, self, other, broadcastShape,
+                          broadcastShapeValue);
+
+    Type broadcastType = ValueTensorType::get(
+        op.getContext(), llvm::ArrayRef(broadcastShape), dtype);
+
+    Value indexBroadcastShapeTorchList = rewriter.create<PrimListConstructOp>(
+        loc, Torch::ListType::get(Torch::IntType::get(op.getContext())),
+        broadcastShapeValue);
+
+    // broadcast tensors to common shape
+    auto a = rewriter.create<AtenBroadcastToOp>(loc, broadcastType, self,
+                                                indexBroadcastShapeTorchList);
+    auto b = rewriter.create<AtenBroadcastToOp>(loc, broadcastType, other,
+                                                indexBroadcastShapeTorchList);
+
+    // create constants
+    Value constOne = rewriter.create<Torch::ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(1));
+    Value constTwo = rewriter.create<Torch::ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(2));
+    Value constThree = rewriter.create<Torch::ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(3));
+    Value none = rewriter.create<ConstantNoneOp>(loc);
+
+    // idx = torch.arange(3)
+    auto outType = opType.dyn_cast<BaseTensorType>();
+    auto arangeType = outType.getWithSizesAndDtype(
+        llvm::ArrayRef<int64_t>(3),
+        IntegerType::get(op.getContext(), 64, IntegerType::Signed));
+    auto idx = rewriter.create<AtenArangeOp>(
+        loc, arangeType, constThree, /*dtype=*/none, /*layout=*/none,
+        /*device=*/none, /*pin_memory=*/none);
+
+    // (idx + 1) and (idx + 2)
+    auto idxPlusOne = rewriter.create<AtenAddScalarOp>(loc, arangeType, idx,
+                                                       constOne, constOne);
+    auto idxPlusTwo = rewriter.create<AtenAddScalarOp>(loc, arangeType, idx,
+                                                       constTwo, constOne);
+
+    // (idx + 1) % 3 and (idx + 2) % 3
+    auto idxPlusOneRemainderThree = rewriter.create<AtenRemainderScalarOp>(
+        loc, arangeType, idxPlusOne, constThree);
+    auto idxPlusTwoRemainderThree = rewriter.create<AtenRemainderScalarOp>(
+        loc, arangeType, idxPlusTwo, constThree);
+
+    // a.index_select(dim, (idx + 1) % 3) * b.index_select(dim, (idx + 2) % 3)
+    auto idxSelectAPlusOne = rewriter.create<AtenIndexSelectOp>(
+        loc, opType, a, dim, idxPlusOneRemainderThree);
+    auto idxSelectBPlusTwo = rewriter.create<AtenIndexSelectOp>(
+        loc, opType, b, dim, idxPlusTwoRemainderThree);
+    auto firstMul = rewriter.create<AtenMulTensorOp>(
+        loc, opType, idxSelectAPlusOne, idxSelectBPlusTwo);
+
+    // a.index_select(dim, (idx + 2) % 3) * b.index_select(dim, (idx + 1) % 3)
+    auto idxSelectAPlusTwo = rewriter.create<AtenIndexSelectOp>(
+        loc, opType, a, dim, idxPlusTwoRemainderThree);
+    auto idxSelectBPlusOne = rewriter.create<AtenIndexSelectOp>(
+        loc, opType, b, dim, idxPlusOneRemainderThree);
+    auto secondMul = rewriter.create<AtenMulTensorOp>(
+        loc, opType, idxSelectAPlusTwo, idxSelectBPlusOne);
+
+    // subtract the results of the two multiplications from above
+    rewriter.replaceOpWithNewOp<AtenSubTensorOp>(op, opType, firstMul,
+                                                 secondMul, constOne);
+
+    return success();
+  }
+};
+} // namespace
+
 // Decompose aten.pixel_shuffle into: prims.split_dim, aten.permute, and
 // prims.collapse operations.
 //
@@ -7081,6 +7192,7 @@ class DecomposeComplexOpsPass
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenLinalgCrossOp>(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(
diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
index f52c46789350..44a3986ac52f 100644
--- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
+++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -395,6 +395,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
+  target.addIllegalOp<AtenLinalgCrossOp>();
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 885164e35cb3..e262f52e3018 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -2107,6 +2107,9 @@
     "ReduceMinAlongDimUnsignedInt_basic",
     "TensorsStackNegativeDimModule_basic",
     "TensorsStackPromoteDTypeModule_basic",
+
+    # Failure - "RuntimeError: linalg.cross: inputs dimension 1 must have length 3. Got 1 and 1"
Got 1 and 1" + "AtenLinalgCrossDynamic_basic" } ONNX_CRASHING_SET = { } diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 8ef43b0082b0..1f6e450da987 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -384,6 +384,17 @@ def aten〇clone〡shape(self: List[int], memory_format: Optional[int] = None) - def aten〇lift_fresh_copy〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) +@check_shape_function([ + Invocation(TensorOfShape(1, 2, 3), TensorOfShape(4, 1, 3)), # two dimensions to broadcast, self[0] and other[1] + ErrorInvocation(TensorOfShape(3), TensorOfShape(2, 3)), # different number of dimensions + ErrorInvocation(TensorOfShape(2, 3), TensorOfShape(4, 3)) # non-broadcastable dimensions +]) +def aten〇linalg_cross〡shape(self: List[int], other: List[int], dim: int = -1) -> List[int]: + assert len(self) == len(other), "inputs must have the same number of dimensions" + for i in range(len(self)): + assert (self[i] == other[i]) or self[i] == 1 or other[i] == 1, f"the size of first tensor ({self[i]}) must match the size of second tensor ({other[i]}) at dimension {i}" + return upstream_shape_functions.broadcast(self, other) + def aten〇_log_softmax_backward_data〡shape(grad_output: List[int], output: List[int], dim: int, input_dtype: int) -> List[int]: return upstream_shape_functions.unary(grad_output) @@ -2381,6 +2392,19 @@ def aten〇lift_fresh_copy〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype return self_dtype +@check_dtype_function( + _check_tensors_with_the_same_dtype(tensor_device="cpu", tensor_shapes=[(2,3), (2,3)], error_types={torch.bool}) + # same dtype + [ErrorInvocation(TensorOfShape(2, 3, dtype=torch.int32, device="cpu"), TensorOfShape(2, 3, dtype=torch.float16, device="cpu"))] #different dtypes +) +def aten〇linalg_cross〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int], dim: int = -1) -> int: + self_rank, self_dtype = self_rank_dtype + other_rank, other_dtype = other_rank_dtype + ranks: List[Optional[int]] = [self_rank, other_rank] + assert self_dtype == other_dtype + assert self_dtype != torch.bool + dtypes = [self_dtype, other_dtype] + return promote_dtypes(ranks, dtypes) + @check_dtype_function( _check_two_tensor_op(dim=0, input_dtype=torch.float32) + _check_two_tensor_op(dim=0, input_dtype=torch.float64)) diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 055f7127c9f2..fa469e035064 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -687,6 +687,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::fmod.Tensor : (Tensor, Tensor) -> (Tensor)") emit("aten::unique_consecutive : (Tensor, bool, bool, int?) -> (Tensor, Tensor, Tensor)") emit("aten::linspace : (Scalar, Scalar, int, int?, int?, Device?, bool?) 
-> (Tensor)") + emit("aten::linalg_cross : (Tensor, Tensor, int) -> (Tensor)", has_verifier=True) # Functionalization ops emit("aten::alias_copy : (Tensor) -> (Tensor)") diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py index 72a4097bc302..80f02a7b5dc8 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py @@ -289,3 +289,114 @@ def forward(self, x, y): def AtenMmQuint8_basic(module, tu: TestUtils): module.forward(tu.randint(3, 4, low=-128, high=127).to(torch.int8), tu.randint(4, 3, low=-128, high=127).to(torch.int8)) + +# ============================================================================== + +class AtenLinalgCrossInt(torch.nn.Module): + + @export + @annotate_args([ + None, + ([2, 3], torch.int64, True), + ([2, 3], torch.int64, True), + ]) + def forward(self, a, b): + return torch.ops.aten.linalg_cross(a, b) + + +@register_test_case(module_factory=lambda: AtenLinalgCrossInt()) +def AtenLinalgCrossInt_basic(module, tu: TestUtils): + module.forward(tu.randint(2, 3), tu.randint(2, 3)) + +# ============================================================================== + +class AtenLinalgCrossFloat(torch.nn.Module): + + @export + @annotate_args([ + None, + ([2, 3], torch.float32, True), + ([2, 3], torch.float32, True), + ]) + def forward(self, a, b): + return torch.ops.aten.linalg_cross(a, b) + + +@register_test_case(module_factory=lambda: AtenLinalgCrossFloat()) +def AtenLinalgCrossFloat_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 3), tu.rand(2, 3)) + + +# ============================================================================== + +class AtenLinalgCrossBroadcast(torch.nn.Module): + + @export + @annotate_args([ + None, + ([1, 4, 3], torch.float32, True), + ([5, 4, 3], torch.float32, True), + ]) + def forward(self, a, b): + return torch.ops.aten.linalg_cross(a, b) + + +@register_test_case(module_factory=lambda: AtenLinalgCrossBroadcast()) +def AtenLinalgCrossBroadcast_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 4, 3), tu.rand(5, 4, 3)) + +# ============================================================================== + +class AtenLinalgCrossCustomDim(torch.nn.Module): + + @export + @annotate_args([ + None, + ([1, 4, 3, 2, 2], torch.float32, True), + ([5, 4, 3, 2, 1], torch.float32, True), + ]) + def forward(self, a, b): + return torch.ops.aten.linalg_cross(a, b, dim=2) + + +@register_test_case(module_factory=lambda: AtenLinalgCrossCustomDim()) +def AtenLinalgCrossCustomDim_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 4, 3, 2, 2), tu.rand(5, 4, 3, 2, 1)) + +# ============================================================================== + +class AtenLinalgCrossNegativeDim(torch.nn.Module): + + @export + @annotate_args([ + None, + ([1, 4, 3, 2, 2], torch.float32, True), + ([5, 4, 3, 2, 1], torch.float32, True), + ]) + def forward(self, a, b): + return torch.ops.aten.linalg_cross(a, b, dim=-3) + + +@register_test_case(module_factory=lambda: AtenLinalgCrossNegativeDim()) +def AtenLinalgCrossNegativeDim_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 4, 3, 2, 2), tu.rand(5, 4, 3, 2, 1)) + +# ============================================================================== + +class AtenLinalgCrossDynamic(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1, -1], torch.float32, True), + ([-1, -1, -1, -1], 
+        ([-1, -1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, a, b):
+        return torch.ops.aten.linalg_cross(a, b, dim=1)
+
+
+@register_test_case(module_factory=lambda: AtenLinalgCrossDynamic())
+def AtenLinalgCrossDynamic_basic(module, tu: TestUtils):
+    module.forward(tu.rand(4, 3, 1, 6), tu.rand(4, 3, 7, 1))
\ No newline at end of file
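
Reviewer note (not part of the patch): the arithmetic behind DecomposeAtenLinalgCrossOp can be sanity-checked in eager PyTorch by evaluating the same index_select/remainder formula and comparing it against torch.linalg.cross. The sketch below is a minimal, assumption-laden illustration; the helper name reference_linalg_cross is made up for this note, while torch.broadcast_shapes, torch.broadcast_to, Tensor.index_select, and torch.linalg.cross are standard PyTorch APIs.

import torch

def reference_linalg_cross(a: torch.Tensor, b: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # Mirror of the decomposition: broadcast both operands to a common shape,
    # then combine index_select slices taken at (idx + 1) % 3 and (idx + 2) % 3.
    shape = torch.broadcast_shapes(a.shape, b.shape)
    a = torch.broadcast_to(a, shape)
    b = torch.broadcast_to(b, shape)
    dim = dim if dim >= 0 else dim + a.ndim  # normalize a negative dim
    idx = torch.arange(3)
    return (a.index_select(dim, (idx + 1) % 3) * b.index_select(dim, (idx + 2) % 3)
            - a.index_select(dim, (idx + 2) % 3) * b.index_select(dim, (idx + 1) % 3))

# Shapes taken from the AtenLinalgCrossBroadcast e2e test above.
a, b = torch.rand(1, 4, 3), torch.rand(5, 4, 3)
assert torch.allclose(reference_linalg_cross(a, b), torch.linalg.cross(a, b, dim=-1))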