diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py index 8b237cf50b43..15cf9bb79f91 100644 --- a/e2e_testing/xfail_sets.py +++ b/e2e_testing/xfail_sets.py @@ -1421,4 +1421,6 @@ "UniformStaticShapeModule_basic", "AtenEmbeddingBagStaticModule_basic", "EmptyStridedModule_basic", + "ElementwiseBitwiseAndScalarInt64Module_basic", + "ElementwiseBitwiseAndScalarInt32Module_basic", } diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 673fab897be0..4ecad92c662b 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -2844,6 +2844,53 @@ def Torch_AtenBitwiseAnd_TensorOp : Torch_Op<"aten.bitwise_and_.Tensor", [ }]; } +def Torch_AtenBitwiseAndScalarOp : Torch_Op<"aten.bitwise_and.Scalar", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::bitwise_and.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchScalarType:$other + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenBitwiseAndScalarOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenBitwiseAndScalarOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + +def Torch_AtenBitwiseAnd_ScalarOp : Torch_Op<"aten.bitwise_and_.Scalar", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::bitwise_and_.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchScalarType:$other + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenBitwiseAnd_ScalarOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenBitwiseAnd_ScalarOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + def Torch_AtenBitwiseOrTensorOp : Torch_Op<"aten.bitwise_or.Tensor", [ AllowsTypeRefinement, HasValueSemantics, @@ -2938,6 +2985,53 @@ def Torch_AtenBitwiseXor_TensorOp : Torch_Op<"aten.bitwise_xor_.Tensor", [ }]; } +def Torch_AtenBitwiseRightShiftTensorOp : Torch_Op<"aten.bitwise_right_shift.Tensor", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::bitwise_right_shift.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$other + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenBitwiseRightShiftTensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenBitwiseRightShiftTensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + +def Torch_AtenBitwiseRightShift_TensorOp : Torch_Op<"aten.bitwise_right_shift_.Tensor", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::bitwise_right_shift_.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$other + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenBitwiseRightShift_TensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenBitwiseRightShift_TensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + def Torch_AtenThresholdOp : Torch_Op<"aten.threshold", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index 8a6366990b94..b47e13c8619e 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -300,6 +300,19 @@ static Value createLinalgPayloadCalculationForElementwiseOp( Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); return b.create(loc, lhs, rhs); } + if (auto bitwiseAndScalar = dyn_cast(op)) { + Type dtype = converter->convertType(bitwiseAndScalar.getType()) + .cast() + .getElementType(); + if (!dtype.isa()) { + bitwiseAndScalar.emitError( + "bitwise_and.Scalar does not support non-integer input dtype."); + return nullptr; + } + Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype); + Value other = convertScalarToDtype(b, loc, operands[1], dtype); + return b.create(loc, self, other); + } if (auto bitwiseOrTensor = dyn_cast(op)) { if (bitwiseOrTensor.getType() .cast() @@ -332,6 +345,20 @@ static Value createLinalgPayloadCalculationForElementwiseOp( Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); return b.create(loc, lhs, rhs); } + if (auto bitwiseRightShiftTensor = + dyn_cast(op)) { + Type dtype = converter->convertType(bitwiseRightShiftTensor.getType()) + .cast() + .getElementType(); + if (!dtype.isa()) { + bitwiseRightShiftTensor.emitError( + "Bitwise_Right_Shift op does not support non-integer input dtype."); + return nullptr; + } + Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); + Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); + return b.create(loc, lhs, rhs); + } if (isa(op)) { MLIRContext *context = op->getContext(); Type floatDtype = mlir::FloatType::getF64(context); @@ -571,7 +598,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp( Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); if (dtype.isa()) { return b.create(loc, lhs, rhs); - } else if(dtype.isa()) { + } else if (dtype.isa()) { return b.create(loc, lhs, rhs); } else { return b.create(loc, lhs, rhs); @@ -1066,7 +1093,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp( .getElementType(); Value self = payloadArgs[0]; - Value threshold = convertScalarToDtype(b, loc, adaptor.getThreshold(), dtype); + Value threshold = + convertScalarToDtype(b, loc, adaptor.getThreshold(), dtype); Value value = convertScalarToDtype(b, loc, adaptor.getValue(), dtype); Value predicate; @@ -1088,7 +1116,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp( Value grad = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value self = convertScalarToDtype(b, loc, payloadArgs[1], dtype); - Value threshold = convertScalarToDtype(b, loc, adaptor.getThreshold(), dtype); + Value threshold = + convertScalarToDtype(b, loc, adaptor.getThreshold(), dtype); Value constantZero = b.create(loc, b.getZeroAttr(dtype)); Value predicate; @@ -1197,10 +1226,11 @@ class ConvertElementwiseOp : public ConversionPattern { AtenMinimumOp, AtenMaximumOp, AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp, AtenMulScalarOp, AtenLogOp, AtenErfOp, AtenSqrtOp, AtenFloorOp, AtenPowScalarOp, AtenPowTensorScalarOp, - AtenPowTensorTensorOp, AtenLog2Op, AtenLog10Op, AtenLog1pOp, AtenRsqrtOp, - AtenDivScalarOp, AtenRemainderScalarOp, AtenAbsOp, - AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenBitwiseOrTensorOp, - AtenBitwiseXorTensorOp, AtenGtScalarOp, AtenGeScalarOp, + AtenPowTensorTensorOp, AtenLog2Op, AtenLog10Op, AtenLog1pOp, + AtenRsqrtOp, AtenDivScalarOp, AtenRemainderScalarOp, AtenAbsOp, + AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenBitwiseAndScalarOp, + AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp, + AtenBitwiseRightShiftTensorOp, AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp, AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp, AtenSubScalarOp, @@ -1699,7 +1729,8 @@ class ConvertAtenDetachOp : public OpConversionPattern { return failure(); Type resultType = getTypeConverter()->convertType(op.getType()); - rewriter.replaceOpWithNewOp(op, resultType, adaptor.getSelf()); + rewriter.replaceOpWithNewOp(op, resultType, + adaptor.getSelf()); return success(); } }; @@ -1735,16 +1766,17 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality( AtenErfOp, AtenSqrtOp, AtenFloorOp, AtenCeilOp, AtenPreluOp, AtenPowScalarOp, AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenLog2Op, AtenLog10Op, AtenLog1pOp, AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp, - AtenBitwiseAndTensorOp, AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp, - AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, - AtenLeScalarOp, AtenWhereSelfOp, AtenGtTensorOp, AtenGeTensorOp, - AtenEqTensorOp, AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp, - AtenThresholdOp, AtenThresholdBackwardOp, AtenHardtanhBackwardOp, - AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenMaskedFillTensorOp, - AtenLogicalOrOp, AtenLogicalAndOp, AtenAtanOp, AtenLogicalXorOp, - AtenLogicalNotOp, AtenTriuOp, AtenTrilOp, AtenRemainderScalarOp, - AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp, - AtenRealOp, AtenImagOp>(); + AtenBitwiseAndTensorOp, AtenBitwiseAndScalarOp, AtenBitwiseOrTensorOp, + AtenBitwiseXorTensorOp, AtenBitwiseRightShiftTensorOp, AtenGtScalarOp, + AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, + AtenWhereSelfOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp, + AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp, AtenThresholdOp, + AtenThresholdBackwardOp, AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, + AtenCosOp, AtenNeScalarOp, AtenMaskedFillTensorOp, AtenLogicalOrOp, + AtenLogicalAndOp, AtenAtanOp, AtenLogicalXorOp, AtenLogicalNotOp, + AtenTriuOp, AtenTrilOp, AtenRemainderScalarOp, AtenBitwiseNotOp, + AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp, AtenRealOp, + AtenImagOp>(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 2fffbd313927..e8f5aa568f59 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -7410,10 +7410,18 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.bitwise_and.Scalar\"(%arg0: !torch.list, %arg1: !torch.float) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.bitwise_xor.Tensor\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.bitwise_right_shift.Tensor\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.bitwise_not\"(%arg0: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -9201,6 +9209,15 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.bitwise_and.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.number) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n" +" %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.bitwise_or.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" " %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" @@ -9217,6 +9234,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.bitwise_right_shift.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.bmm\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" " %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py index 7c507f53b70e..f74895d9dad6 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py @@ -796,9 +796,15 @@ def aten〇bitwise_or〇Tensor〡shape(self: List[int], other: List[int]) -> Lis def aten〇bitwise_and〇Tensor〡shape(self: List[int], other: List[int]) -> List[int]: return upstream_shape_functions.broadcast(self, other) +def aten〇bitwise_and〇Scalar〡shape(self: List[int], other: float) -> List[int]: + return upstream_shape_functions.unary(self) + def aten〇bitwise_xor〇Tensor〡shape(self: List[int], other: List[int]) -> List[int]: return upstream_shape_functions.broadcast(self, other) +def aten〇bitwise_right_shift〇Tensor〡shape(self: List[int], other: List[int]) -> List[int]: + return upstream_shape_functions.broadcast(self, other) + def aten〇bitwise_not〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) @@ -2265,6 +2271,14 @@ def aten〇bitwise_and〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_ dtypes = [self_dtype, other_dtype] return promote_dtypes(ranks, dtypes) +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, other=1) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=1.0)) +def aten〇bitwise_and〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float, complex]) -> int: + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [self_rank, None] + dtypes = [self_dtype, get_dtype_of_scalar(other)] + return promote_dtypes(ranks, dtypes) + @check_dtype_function(_check_two_tensor_op()) def aten〇bitwise_or〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: other_rank, other_dtype = other_rank_dtype @@ -2281,6 +2295,14 @@ def aten〇bitwise_xor〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_ dtypes = [self_dtype, other_dtype] return promote_dtypes(ranks, dtypes) +@check_dtype_function(_check_two_tensor_op()) +def aten〇bitwise_right_shift〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [self_rank, other_rank] + dtypes = [self_dtype, other_dtype] + return promote_dtypes(ranks, dtypes) + @check_dtype_function( _check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 4), (2, 4, 3)]) + # Different width diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py index 4c2b30d817db..56d18d3847d1 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py @@ -301,8 +301,10 @@ def emit_with_mutating_variants(key, **kwargs): "aten::abs : (Tensor) -> (Tensor)", "aten::reciprocal : (Tensor) -> (Tensor)", "aten::bitwise_and.Tensor : (Tensor, Tensor) -> (Tensor)", + "aten::bitwise_and.Scalar : (Tensor, Scalar) -> (Tensor)", "aten::bitwise_or.Tensor : (Tensor, Tensor) -> (Tensor)", "aten::bitwise_xor.Tensor : (Tensor, Tensor) -> (Tensor)", + "aten::bitwise_right_shift.Tensor : (Tensor, Tensor) -> (Tensor)", "aten::threshold : (Tensor, Scalar, Scalar) -> (Tensor)", "aten::square : (Tensor) -> (Tensor)", "aten::unsqueeze : (Tensor, int) -> (Tensor)", diff --git a/python/torch_mlir_e2e_test/test_suite/elementwise.py b/python/torch_mlir_e2e_test/test_suite/elementwise.py index 2df0a5513d4a..a0137f23e71b 100644 --- a/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -3515,3 +3515,107 @@ def forward(self, a, b): @register_test_case(module_factory=lambda: TupleModule()) def TupleModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 2), tu.rand(2, 2)) + + +# ============================================================================== + + +class ElementwiseBitwiseRightShiftInt64Module(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int64, True), + ([-1, -1], torch.int64, True), + ]) + def forward(self, lhs, rhs): + return torch.bitwise_right_shift(lhs, rhs) + + +@register_test_case(module_factory=lambda: ElementwiseBitwiseRightShiftInt64Module()) +def ElementwiseBitwiseRightShiftInt64Module_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, low=-1000, high=1000), tu.randint(3, 4, low=0, high=64)) + + +class ElementwiseBitwiseRightShiftInt32Module(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, 4], torch.int32, True), + ([-1, 1], torch.int32, True), + ]) + def forward(self, lhs, rhs): + return torch.bitwise_right_shift(lhs, rhs) + + +@register_test_case(module_factory=lambda: ElementwiseBitwiseRightShiftInt32Module()) +def ElementwiseBitwiseRightShiftInt32Module_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, low=-1000, high=1000).to(torch.int32), tu.randint(3, 1, low=0, high=32).to(torch.int32)) + + +class ElementwiseBitwiseRightShiftInt8Module(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int8, True), + ([-1, -1], torch.int8, True), + ]) + def forward(self, lhs, rhs): + return torch.bitwise_right_shift(lhs, rhs) + + +@register_test_case(module_factory=lambda: ElementwiseBitwiseRightShiftInt8Module()) +def ElementwiseBitwiseRightShiftInt8Module_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, low=-100, high=100).to(torch.int8), tu.randint(3, 4, low=0, high=8).to(torch.int8)) + + +# ============================================================================== + + +class ElementwiseBitwiseAndScalarInt64Module(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int64, True), + ]) + def forward(self, x): + return torch.bitwise_and(x, 15) + + +@register_test_case(module_factory=lambda: ElementwiseBitwiseAndScalarInt64Module()) +def ElementwiseBitwiseAndScalarInt64Module_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, low=-1000, high=1000)) + + +class ElementwiseBitwiseAndScalarInt32Module(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int32, True), + ]) + def forward(self, x): + return torch.bitwise_and(x, 100) + + +@register_test_case(module_factory=lambda: ElementwiseBitwiseAndScalarInt32Module()) +def ElementwiseBitwiseAndScalarInt32Module_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, low=-1000, high=1000).to(torch.int32))