[MLIR][TORCH] Add support for bitwise_right_shift and bitwise_and.Scalar op

Signed-Off By: Vivek Khandelwal <[email protected]>
vivekkhandelwal1 committed Oct 2, 2023
1 parent c434736 commit 9293326
Showing 7 changed files with 299 additions and 18 deletions.
2 changes: 2 additions & 0 deletions e2e_testing/xfail_sets.py
@@ -1421,4 +1421,6 @@
"UniformStaticShapeModule_basic",
"AtenEmbeddingBagStaticModule_basic",
"EmptyStridedModule_basic",
"ElementwiseBitwiseAndScalarInt64Module_basic",
"ElementwiseBitwiseAndScalarInt32Module_basic",
}
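The two xfail entries above refer to new e2e test modules that exercise aten.bitwise_and.Scalar; their definitions are added in the e2e test suite (not rendered in this view). As orientation only, here is a minimal sketch of what such a module conventionally looks like in the torch-mlir e2e framework; the class name, shapes, and value ranges are illustrative assumptions, not the committed code:

import torch

from torch_mlir_e2e_test.framework import TestUtils
from torch_mlir_e2e_test.registry import register_test_case
from torch_mlir_e2e_test.annotations import annotate_args, export


class ElementwiseBitwiseAndScalarInt32Module(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.int32, True),  # dynamically shaped int32 operand
    ])
    def forward(self, x):
        # Tensor & Python-int scalar goes through aten.bitwise_and.Scalar.
        return torch.bitwise_and(x, 100)


@register_test_case(module_factory=lambda: ElementwiseBitwiseAndScalarInt32Module())
def ElementwiseBitwiseAndScalarInt32Module_basic(module, tu: TestUtils):
    module.forward(tu.randint(3, 4, low=-1000, high=1000).to(torch.int32))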
94 changes: 94 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -2844,6 +2844,53 @@ def Torch_AtenBitwiseAnd_TensorOp : Torch_Op<"aten.bitwise_and_.Tensor", [
}];
}

def Torch_AtenBitwiseAndScalarOp : Torch_Op<"aten.bitwise_and.Scalar", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::bitwise_and.Scalar : (Tensor, Scalar) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchScalarType:$other
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenBitwiseAndScalarOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenBitwiseAndScalarOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}

def Torch_AtenBitwiseAnd_ScalarOp : Torch_Op<"aten.bitwise_and_.Scalar", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::bitwise_and_.Scalar : (Tensor, Scalar) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchScalarType:$other
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenBitwiseAnd_ScalarOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenBitwiseAnd_ScalarOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}

def Torch_AtenBitwiseOrTensorOp : Torch_Op<"aten.bitwise_or.Tensor", [
AllowsTypeRefinement,
HasValueSemantics,
@@ -2938,6 +2985,53 @@ def Torch_AtenBitwiseXor_TensorOp : Torch_Op<"aten.bitwise_xor_.Tensor", [
}];
}

def Torch_AtenBitwiseRightShiftTensorOp : Torch_Op<"aten.bitwise_right_shift.Tensor", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::bitwise_right_shift.Tensor : (Tensor, Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$other
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenBitwiseRightShiftTensorOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenBitwiseRightShiftTensorOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}

def Torch_AtenBitwiseRightShift_TensorOp : Torch_Op<"aten.bitwise_right_shift_.Tensor", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::bitwise_right_shift_.Tensor : (Tensor, Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$other
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenBitwiseRightShift_TensorOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenBitwiseRightShift_TensorOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}

def Torch_AtenThresholdOp : Torch_Op<"aten.threshold", [
AllowsTypeRefinement,
HasValueSemantics,
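For reference, the four new op definitions above model the functional and trailing-underscore in-place variants of the corresponding ATen operators. A short illustrative example of the eager PyTorch calls they correspond to (operand values are arbitrary):

import torch

a = torch.tensor([12, 7, -5], dtype=torch.int64)
b = torch.tensor([1, 2, 1], dtype=torch.int64)

# aten::bitwise_and.Scalar : (Tensor, Scalar) -> (Tensor)
anded = torch.bitwise_and(a, 6)
# aten::bitwise_and_.Scalar (in-place variant)
a.bitwise_and_(6)

# aten::bitwise_right_shift.Tensor : (Tensor, Tensor) -> (Tensor)
shifted = torch.bitwise_right_shift(a, b)
# aten::bitwise_right_shift_.Tensor (in-place variant)
a.bitwise_right_shift_(b)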
68 changes: 50 additions & 18 deletions lib/Conversion/TorchToLinalg/Uncategorized.cpp
@@ -300,6 +300,19 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
return b.create<arith::AndIOp>(loc, lhs, rhs);
}
if (auto bitwiseAndScalar = dyn_cast<AtenBitwiseAndScalarOp>(op)) {
Type dtype = converter->convertType(bitwiseAndScalar.getType())
.cast<RankedTensorType>()
.getElementType();
if (!dtype.isa<mlir::IntegerType>()) {
bitwiseAndScalar.emitError(
"bitwise_and.Scalar does not support non-integer input dtype.");
return nullptr;
}
Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
Value other = convertScalarToDtype(b, loc, operands[1], dtype);
return b.create<arith::AndIOp>(loc, self, other);
}
if (auto bitwiseOrTensor = dyn_cast<AtenBitwiseOrTensorOp>(op)) {
if (bitwiseOrTensor.getType()
.cast<ValueTensorType>()
@@ -332,6 +345,20 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
return b.create<arith::XOrIOp>(loc, lhs, rhs);
}
if (auto bitwiseRightShiftTensor =
dyn_cast<AtenBitwiseRightShiftTensorOp>(op)) {
Type dtype = converter->convertType(bitwiseRightShiftTensor.getType())
.cast<RankedTensorType>()
.getElementType();
if (!dtype.isa<mlir::IntegerType>()) {
bitwiseRightShiftTensor.emitError(
"Bitwise_Right_Shift op does not support non-integer input dtype.");
return nullptr;
}
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
return b.create<arith::ShRSIOp>(loc, lhs, rhs);
}
if (isa<AtenLogicalOrOp, AtenLogicalAndOp, AtenLogicalXorOp>(op)) {
MLIRContext *context = op->getContext();
Type floatDtype = mlir::FloatType::getF64(context);
@@ -571,7 +598,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
if (dtype.isa<mlir::FloatType>()) {
return b.create<arith::MulFOp>(loc, lhs, rhs);
} else if(dtype.isa<mlir::ComplexType>()) {
} else if (dtype.isa<mlir::ComplexType>()) {
return b.create<complex::MulOp>(loc, lhs, rhs);
} else {
return b.create<arith::MulIOp>(loc, lhs, rhs);
@@ -1066,7 +1093,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
.getElementType();

Value self = payloadArgs[0];
Value threshold = convertScalarToDtype(b, loc, adaptor.getThreshold(), dtype);
Value threshold =
convertScalarToDtype(b, loc, adaptor.getThreshold(), dtype);
Value value = convertScalarToDtype(b, loc, adaptor.getValue(), dtype);

Value predicate;
@@ -1088,7 +1116,8 @@

Value grad = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
Value self = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
Value threshold = convertScalarToDtype(b, loc, adaptor.getThreshold(), dtype);
Value threshold =
convertScalarToDtype(b, loc, adaptor.getThreshold(), dtype);
Value constantZero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(dtype));

Value predicate;
@@ -1197,10 +1226,11 @@ class ConvertElementwiseOp : public ConversionPattern {
AtenMinimumOp, AtenMaximumOp, AtenToDtypeOp, AtenClampOp,
AtenRsubScalarOp, AtenMulScalarOp, AtenLogOp, AtenErfOp,
AtenSqrtOp, AtenFloorOp, AtenPowScalarOp, AtenPowTensorScalarOp,
AtenPowTensorTensorOp, AtenLog2Op, AtenLog10Op, AtenLog1pOp, AtenRsqrtOp,
AtenDivScalarOp, AtenRemainderScalarOp, AtenAbsOp,
AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenBitwiseOrTensorOp,
AtenBitwiseXorTensorOp, AtenGtScalarOp, AtenGeScalarOp,
AtenPowTensorTensorOp, AtenLog2Op, AtenLog10Op, AtenLog1pOp,
AtenRsqrtOp, AtenDivScalarOp, AtenRemainderScalarOp, AtenAbsOp,
AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenBitwiseAndScalarOp,
AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp,
AtenBitwiseRightShiftTensorOp, AtenGtScalarOp, AtenGeScalarOp,
AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp,
AtenCeilOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp,
AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp, AtenSubScalarOp,
@@ -1699,7 +1729,8 @@ class ConvertAtenDetachOp : public OpConversionPattern<AtenDetachOp> {
return failure();

Type resultType = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, adaptor.getSelf());
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType,
adaptor.getSelf());
return success();
}
};
@@ -1735,16 +1766,17 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
AtenErfOp, AtenSqrtOp, AtenFloorOp, AtenCeilOp, AtenPreluOp,
AtenPowScalarOp, AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenLog2Op,
AtenLog10Op, AtenLog1pOp, AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp,
AtenBitwiseAndTensorOp, AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp,
AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp,
AtenLeScalarOp, AtenWhereSelfOp, AtenGtTensorOp, AtenGeTensorOp,
AtenEqTensorOp, AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp,
AtenThresholdOp, AtenThresholdBackwardOp, AtenHardtanhBackwardOp,
AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenMaskedFillTensorOp,
AtenLogicalOrOp, AtenLogicalAndOp, AtenAtanOp, AtenLogicalXorOp,
AtenLogicalNotOp, AtenTriuOp, AtenTrilOp, AtenRemainderScalarOp,
AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp,
AtenRealOp, AtenImagOp>();
AtenBitwiseAndTensorOp, AtenBitwiseAndScalarOp, AtenBitwiseOrTensorOp,
AtenBitwiseXorTensorOp, AtenBitwiseRightShiftTensorOp, AtenGtScalarOp,
AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp,
AtenWhereSelfOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp,
AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp, AtenThresholdOp,
AtenThresholdBackwardOp, AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp,
AtenCosOp, AtenNeScalarOp, AtenMaskedFillTensorOp, AtenLogicalOrOp,
AtenLogicalAndOp, AtenAtanOp, AtenLogicalXorOp, AtenLogicalNotOp,
AtenTriuOp, AtenTrilOp, AtenRemainderScalarOp, AtenBitwiseNotOp,
AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp, AtenRealOp,
AtenImagOp>();
patterns.add<ConvertElementwiseOp>(typeConverter, context);
target.addIllegalOp<AtenNllLossForwardOp>();
patterns.add<ConvertAtenDetachOp>(typeConverter, context);
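Note that the aten.bitwise_right_shift.Tensor payload above lowers to arith::ShRSIOp, an arithmetic (sign-extending) right shift. A small eager-mode check of the behavior this is intended to mirror, assuming PyTorch's usual semantics for signed integer dtypes:

import torch

a = torch.tensor([-8, 16, -1], dtype=torch.int32)
b = torch.tensor([1, 2, 3], dtype=torch.int32)

# Signed right shift preserves the sign bit (arithmetic shift),
# matching arith.shrsi rather than a logical zero-fill shift.
print(torch.bitwise_right_shift(a, b))  # tensor([-4,  4, -1], dtype=torch.int32)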
25 changes: 25 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -7410,10 +7410,18 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.bitwise_and.Scalar\"(%arg0: !torch.list<int>, %arg1: !torch.float) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.bitwise_xor.Tensor\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.bitwise_right_shift.Tensor\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.bitwise_not\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
@@ -9201,6 +9209,15 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.bitwise_and.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list<optional<int>>\n"
" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n"
" %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.bitwise_or.Tensor\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
@@ -9217,6 +9234,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.bitwise_right_shift.Tensor\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
" %3 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.bmm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
@@ -796,9 +796,15 @@ def aten〇bitwise_or〇Tensor〡shape(self: List[int], other: List[int]) -> List[int]:
def aten〇bitwise_and〇Tensor〡shape(self: List[int], other: List[int]) -> List[int]:
return upstream_shape_functions.broadcast(self, other)

def aten〇bitwise_and〇Scalar〡shape(self: List[int], other: float) -> List[int]:
return upstream_shape_functions.unary(self)

def aten〇bitwise_xor〇Tensor〡shape(self: List[int], other: List[int]) -> List[int]:
return upstream_shape_functions.broadcast(self, other)

def aten〇bitwise_right_shift〇Tensor〡shape(self: List[int], other: List[int]) -> List[int]:
return upstream_shape_functions.broadcast(self, other)

def aten〇bitwise_not〡shape(self: List[int]) -> List[int]:
return upstream_shape_functions.unary(self)

@@ -2265,6 +2271,14 @@ def aten〇bitwise_and〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int:
dtypes = [self_dtype, other_dtype]
return promote_dtypes(ranks, dtypes)

@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, other=1) +
_check_tensors_with_the_same_dtype(num_of_tensors=1, other=1.0))
def aten〇bitwise_and〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float, complex]) -> int:
self_rank, self_dtype = self_rank_dtype
ranks: List[Optional[int]] = [self_rank, None]
dtypes = [self_dtype, get_dtype_of_scalar(other)]
return promote_dtypes(ranks, dtypes)

@check_dtype_function(_check_two_tensor_op())
def aten〇bitwise_or〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int:
other_rank, other_dtype = other_rank_dtype
@@ -2281,6 +2295,14 @@ def aten〇bitwise_xor〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int:
dtypes = [self_dtype, other_dtype]
return promote_dtypes(ranks, dtypes)

@check_dtype_function(_check_two_tensor_op())
def aten〇bitwise_right_shift〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int:
other_rank, other_dtype = other_rank_dtype
self_rank, self_dtype = self_rank_dtype
ranks: List[Optional[int]] = [self_rank, other_rank]
dtypes = [self_dtype, other_dtype]
return promote_dtypes(ranks, dtypes)

@check_dtype_function(
_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 4), (2, 4, 3)]) +
# Different width
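The aten〇bitwise_and〇Scalar dtype function above promotes the tensor's dtype against the scalar's dtype via promote_dtypes, so a plain Python int does not widen an integer tensor. A quick illustration of the expected result, assuming PyTorch's standard type-promotion rules:

import torch

x32 = torch.ones(4, dtype=torch.int32)
x64 = torch.ones(4, dtype=torch.int64)

# The int scalar participates in promotion as a "weak" dtype,
# so the tensor's integer dtype is preserved in the result.
print(torch.bitwise_and(x32, 5).dtype)  # torch.int32
print(torch.bitwise_and(x64, 5).dtype)  # torch.int64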
@@ -301,8 +301,10 @@ def emit_with_mutating_variants(key, **kwargs):
"aten::abs : (Tensor) -> (Tensor)",
"aten::reciprocal : (Tensor) -> (Tensor)",
"aten::bitwise_and.Tensor : (Tensor, Tensor) -> (Tensor)",
"aten::bitwise_and.Scalar : (Tensor, Scalar) -> (Tensor)",
"aten::bitwise_or.Tensor : (Tensor, Tensor) -> (Tensor)",
"aten::bitwise_xor.Tensor : (Tensor, Tensor) -> (Tensor)",
"aten::bitwise_right_shift.Tensor : (Tensor, Tensor) -> (Tensor)",
"aten::threshold : (Tensor, Scalar, Scalar) -> (Tensor)",
"aten::square : (Tensor) -> (Tensor)",
"aten::unsqueeze : (Tensor, int) -> (Tensor)",
(Diff for the seventh changed file did not load in this view.)
