[MLIR][TORCH] Add e2e support for aten.pow.Scalar op
Signed-Off By: Vivek Khandelwal <[email protected]>
vivekkhandelwal1 committed Aug 31, 2023
1 parent aa15f0d commit 5c43daa
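For context, aten.pow.Scalar is the torch.pow overload whose base (self) is a Python scalar and whose exponent is a tensor; the power is applied elementwise. A minimal eager-PyTorch illustration of the op's semantics (not part of this commit):

    import torch

    # Scalar base, tensor exponent, computed elementwise.
    exp = torch.tensor([1.0, 2.0, 3.0])
    print(torch.pow(2.0, exp))  # tensor([2., 4., 8.])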
Showing 5 changed files with 77 additions and 21 deletions.
53 changes: 33 additions & 20 deletions lib/Conversion/TorchToLinalg/Uncategorized.cpp
@@ -660,6 +660,18 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     divTensorMode.emitError("invalid rounding mode");
     return nullptr;
   }
+
+  if (auto pow = dyn_cast<AtenPowScalarOp>(op)) {
+    Type dtype = pow.getType().cast<ValueTensorType>().getDtype();
+    if (!dtype.isa<mlir::FloatType>()) {
+      pow.emitError("unimplemented: non-floating point dtype");
+      return nullptr;
+    }
+    Value selfPromoted = convertScalarToDtype(b, loc, operands[0], dtype);
+    Value expPromoted = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
+    return b.create<math::PowFOp>(loc, selfPromoted, expPromoted);
+  }
+
   if (auto pow = dyn_cast<AtenPowTensorScalarOp>(op)) {
     if (!pow.getType()
             .cast<ValueTensorType>()
@@ -1162,20 +1174,20 @@ class ConvertElementwiseOp : public ConversionPattern {
 AtenLerpTensorOp, AtenSigmoidOp, AtenExpOp, AtenExpm1Op,
 AtenMinimumOp, AtenMaximumOp, AtenToDtypeOp, AtenClampOp,
 AtenRsubScalarOp, AtenMulScalarOp, AtenLogOp, AtenErfOp,
-AtenSqrtOp, AtenFloorOp, AtenPowTensorScalarOp,
+AtenSqrtOp, AtenFloorOp, AtenPowScalarOp, AtenPowTensorScalarOp,
 AtenPowTensorTensorOp, AtenLog2Op, AtenLog1pOp, AtenRsqrtOp,
 AtenDivScalarOp, AtenRemainderScalarOp, AtenAbsOp,
 AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenBitwiseOrTensorOp,
 AtenBitwiseXorTensorOp, AtenGtScalarOp, AtenGeScalarOp,
 AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp,
-AtenCeilOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp, AtenNeTensorOp,
-AtenLtTensorOp, AtenLeTensorOp, AtenSubScalarOp, AtenAddScalarOp,
-AtenThresholdOp, AtenThresholdBackwardOp, AtenHardtanhBackwardOp,
-AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenNegOp,
-AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp,
-AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp, AtenTrilOp,
-AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp,
-AtenAtanOp, AtenRealOp, AtenImagOp>(op))
+AtenCeilOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp,
+AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp, AtenSubScalarOp,
+AtenAddScalarOp, AtenThresholdOp, AtenThresholdBackwardOp,
+AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp,
+AtenNeScalarOp, AtenNegOp, AtenMaskedFillTensorOp, AtenLogicalOrOp,
+AtenLogicalAndOp, AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp,
+AtenTrilOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp,
+AtenFillTensorOp, AtenAtanOp, AtenRealOp, AtenImagOp>(op))
 return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
 
 if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
@@ -1697,17 +1709,18 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
 AtenLerpTensorOp, AtenSigmoidOp, AtenMinimumOp, AtenAtan2Op,
 AtenMaximumOp, AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp, AtenLogOp,
 AtenErfOp, AtenSqrtOp, AtenFloorOp, AtenCeilOp, AtenPreluOp,
-AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenLog2Op, AtenLog1pOp,
-AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp,
-AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp, AtenGtScalarOp,
-AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp,
-AtenWhereSelfOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp, AtenNeTensorOp,
-AtenLtTensorOp, AtenLeTensorOp, AtenThresholdOp, AtenThresholdBackwardOp,
-AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp,
-AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp, AtenAtanOp,
-AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp, AtenTrilOp,
-AtenRemainderScalarOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp,
-AtenFillTensorOp, AtenRealOp, AtenImagOp>();
+AtenPowScalarOp, AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenLog2Op,
+AtenLog1pOp, AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp,
+AtenBitwiseAndTensorOp, AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp,
+AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp,
+AtenLeScalarOp, AtenWhereSelfOp, AtenGtTensorOp, AtenGeTensorOp,
+AtenEqTensorOp, AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp,
+AtenThresholdOp, AtenThresholdBackwardOp, AtenHardtanhBackwardOp,
+AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenMaskedFillTensorOp,
+AtenLogicalOrOp, AtenLogicalAndOp, AtenAtanOp, AtenLogicalXorOp,
+AtenLogicalNotOp, AtenTriuOp, AtenTrilOp, AtenRemainderScalarOp,
+AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp,
+AtenRealOp, AtenImagOp>();
 patterns.add<ConvertElementwiseOp>(typeConverter, context);
 target.addIllegalOp<AtenNllLossForwardOp>();
 patterns.add<ConvertAtenDetachOp>(typeConverter, context);
13 changes: 13 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -6522,6 +6522,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 " return %0 : !torch.list<int>\n"
 " }\n"
+" func.func @\"__torch_mlir_shape_fn.aten.pow.Scalar\"(%arg0: !torch.float, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
+" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg1) : (!torch.list<int>) -> !torch.list<int>\n"
+" return %0 : !torch.list<int>\n"
+" }\n"
 " func.func @\"__torch_mlir_shape_fn.aten.pow.Tensor_Scalar\"(%arg0: !torch.list<int>, %arg1: !torch.float) -> !torch.list<int> {\n"
 " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 " return %0 : !torch.list<int>\n"
@@ -9747,6 +9751,15 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 " %6 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%3, %5) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
 " return %6 : !torch.int\n"
 " }\n"
+" func.func @\"__torch_mlir_dtype_fn.aten.pow.Scalar\"(%arg0: !torch.number, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
+" %none = torch.constant.none\n"
+" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+" %1 = torch.prim.ListConstruct %none, %0#0 : (!torch.none, !torch.int) -> !torch.list<optional<int>>\n"
+" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.number) -> !torch.int\n"
+" %3 = torch.prim.ListConstruct %2, %0#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
+" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
+" return %4 : !torch.int\n"
+" }\n"
 " func.func @\"__torch_mlir_dtype_fn.aten.pow.Tensor_Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
 " %none = torch.constant.none\n"
 " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
1 change: 0 additions & 1 deletion python/torch_mlir/__init__.py
@@ -446,7 +446,6 @@ def compile(model: torch.nn.Module,
     if output_type == OutputType.RAW:
         return mb.module
 
-    # mb.module.dump()
     option_string = "{backend-legal-ops=" + ",".join(backend_legal_ops) + \
         " extra-library=" + extra_library_file_name + "}"
     run_pipeline_with_repro_report(
9 changes: 9 additions & 0 deletions python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
@@ -267,6 +267,9 @@ def aten〇remainder〇Scalar〡shape(self: List[int], other: float) -> List[int]:
 def aten〇floor_divide〇Scalar〡shape(self: List[int], other: float) -> List[int]:
     return upstream_shape_functions.unary(self)
 
+def aten〇pow〇Scalar〡shape(self: float, exponent: List[int]) -> List[int]:
+    return upstream_shape_functions.unary(exponent)
+
 def aten〇pow〇Tensor_Scalar〡shape(self: List[int], exponent: float) -> List[int]:
     return upstream_shape_functions.unary(self)
 
@@ -2704,6 +2707,12 @@ def aten〇floor_divide〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other
     dtypes = [self_dtype, get_dtype_of_scalar(other)]
     return promote_dtypes(ranks, dtypes)
 
+def aten〇pow〇Scalar〡dtype(self: Union[int, float, complex], exponent_rank_dtype: Tuple[int, int]) -> int:
+    exponent_rank, exponent_dtype = exponent_rank_dtype
+    ranks: List[Optional[int]] = [None, exponent_rank]
+    dtypes = [get_dtype_of_scalar(self), exponent_dtype]
+    return promote_dtypes(ranks, dtypes)
+
 @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, exponent=1) +
                       _check_tensors_with_the_same_dtype(num_of_tensors=1, exponent=1.0))
 def aten〇pow〇Tensor_Scalar〡dtype(self_rank_dtype: Tuple[int, int], exponent: Union[int, float, complex]) -> int:
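The new dtype function above mirrors PyTorch's scalar/tensor type promotion: the implied dtype of the scalar base is promoted against the exponent tensor's dtype via promote_dtypes. A rough eager-PyTorch sketch of the promotion behavior being modeled (illustration only, not part of the diff):

    import torch

    # A float scalar base promotes an integer exponent tensor to the default float dtype.
    print(torch.pow(2.0, torch.arange(3)).dtype)    # torch.float32
    # An integer scalar base with an integer exponent tensor stays integral.
    print(torch.pow(2, torch.arange(3)).dtype)      # torch.int64
    # torch.result_type exposes the same promotion rule directly.
    print(torch.result_type(2.0, torch.arange(3)))  # torch.float32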
22 changes: 22 additions & 0 deletions python/torch_mlir_e2e_test/test_suite/elementwise.py
@@ -1468,6 +1468,28 @@ def ElementwisePowTensorBroadcastStaticModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class ElementwisePowScalarModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([3, 4], torch.float32, True),
+    ])
+    def forward(self, exp):
+        return torch.pow(2.0, exp)
+
+
+@register_test_case(module_factory=lambda: ElementwisePowScalarModule())
+def ElementwisePowScalarModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4))
+
+
+# ==============================================================================
+
+
 class ElementwiseToDtypeF32ToI64Module(torch.nn.Module):
 
     def __init__(self):
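The new test above runs through the e2e harness, which (assuming the usual torch-mlir flow) compares the compiled module's output against eager PyTorch. A rough standalone sketch of the numerical check it implies:

    import torch

    # Mirrors ElementwisePowScalarModule_basic: 3x4 float32 exponents, scalar base 2.0.
    exp = torch.rand(3, 4)
    expected = 2.0 ** exp         # elementwise scalar-base power in eager mode
    actual = torch.pow(2.0, exp)  # the overload the new lowering supports
    assert torch.allclose(actual, expected)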
