Commit
[Stablehlo] Support AtenPowScalarOp, AtenTanOp, AtenAsinhOp, AtenAcoshOp, AtenAtanhOp, Atan2Op (#3233)
Xinyu Yang authored Apr 26, 2024
1 parent 634a796 commit ac85338
Showing 3 changed files with 126 additions and 7 deletions.
43 changes: 43 additions & 0 deletions lib/Conversion/TorchToStablehlo/Basic.cpp
@@ -960,6 +960,43 @@ LogicalResult ConvertAtenOp<AtenPowTensorScalarOp>::matchAndRewrite(
  return success();
}

// AtenPowScalarOp
template <>
LogicalResult ConvertAtenOp<AtenPowScalarOp>::matchAndRewrite(
    AtenPowScalarOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Value lhs = adaptor.getSelf();
  auto lhsType = dyn_cast<TensorType>(lhs.getType());
  Value rhs = adaptor.getExponent();
  auto rhsType = dyn_cast<TensorType>(rhs.getType());

  if (!rhsType)
    return op.emitError("only Tensor types supported in StableHLO");

  auto outType = cast<TensorType>(
      OpConversionPattern<AtenPowScalarOp>::getTypeConverter()->convertType(
          op.getType()));

  Type outElemTy = outType.getElementType();
  if (!outElemTy.isIntOrFloat()) {
    return op.emitError(
        "only floating-point or integer datatype legalization supported");
  }

  if (!lhsType) {
    lhs = hlo::scalarToStablehloTensor(rewriter, op, lhs, outElemTy);
  }
  DenseI64ArrayAttr bcastDimensions;
  lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType);
  rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outType);
  auto loc = op.getLoc();
  Value result = rewriter.create<chlo::BroadcastPowOp>(loc, outType, lhs, rhs,
                                                       bcastDimensions);

  rewriter.replaceOp(op, result);
  return success();
}
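For orientation: aten.pow.Scalar is the variant with a scalar base and a tensor exponent, the mirror image of AtenPowTensorScalarOp above, which is why the scalar self gets wrapped into a StableHLO tensor before the chlo::BroadcastPowOp is built. A minimal eager-mode sketch of the behavior this pattern has to reproduce (plain PyTorch, nothing torch-mlir-specific):

import torch

# aten.pow.Scalar: scalar self (the base), tensor exponent.
exponent = torch.tensor([1.0, 2.0, 3.0])
print(torch.pow(2.0, exponent))  # tensor([2., 4., 8.])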

// PrimNumToTensorScalarOp
template <>
LogicalResult ConvertAtenOp<PrimNumToTensorScalarOp>::matchAndRewrite(
@@ -2050,11 +2087,15 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
  INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenTanhOp, stablehlo::TanhOp);
  INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenSinOp, stablehlo::SineOp);
  INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenCosOp, stablehlo::CosineOp);
  INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenTanOp, chlo::TanOp);
  INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenAsinOp, chlo::AsinOp);
  INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenSinhOp, chlo::SinhOp);
  INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenAcosOp, chlo::AcosOp);
  INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenCoshOp, chlo::CoshOp);
  INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenAtanOp, chlo::AtanOp);
  INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenAsinhOp, chlo::AsinhOp);
  INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenAcoshOp, chlo::AcoshOp);
  INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenAtanhOp, chlo::AtanhOp);
#undef INSERT_UNARY_PROMOTE_TO_FP_PATTERN
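The four new entries (Tan, Asinh, Acosh, Atanh) reuse the existing promote-to-fp helper because these CHLO ops are only defined on floating-point element types; integer inputs are promoted first, matching eager PyTorch. A quick sketch of the behavior being matched (plain PyTorch, no torch-mlir needed):

import torch

x = torch.tensor([1, 2], dtype=torch.int64)
# Integer inputs to tan/asinh/acosh/atanh promote to the default float dtype,
# which is what INSERT_UNARY_PROMOTE_TO_FP_PATTERN mirrors in the lowering.
print(torch.tan(x).dtype)    # torch.float32
print(torch.asinh(x).dtype)  # torch.float32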

#define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal) \
@@ -2137,6 +2178,7 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
  INSERT_ATENOP_PATTERN(AtenTensorIntOp);
  INSERT_ATENOP_PATTERN(AtenReciprocalOp);
  INSERT_ATENOP_PATTERN(AtenPowTensorScalarOp);
  INSERT_ATENOP_PATTERN(AtenPowScalarOp);
  INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp);
  INSERT_ATENOP_PATTERN(AtenScalarImplicitOp);
  INSERT_ATENOP_PATTERN(AtenContiguousOp);
@@ -2181,5 +2223,6 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
  INSERT_BINARY_BROADCAST_PATTERN(AtenBitwiseAndTensorOp, chlo::BroadcastAndOp);
  INSERT_BINARY_BROADCAST_PATTERN(AtenBitwiseOrTensorOp, chlo::BroadcastOrOp);
  INSERT_BINARY_BROADCAST_PATTERN(AtenBitwiseXorTensorOp, chlo::BroadcastXorOp);
  INSERT_BINARY_BROADCAST_PATTERN(AtenAtan2Op, chlo::BroadcastAtan2Op);
#undef INSERT_BINARY_BROADCAST_PATTERN
}
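AtenAtan2Op can go through the generic binary-broadcast macro because chlo::BroadcastAtan2Op, like the other chlo::Broadcast* ops used here, carries implicit NumPy-style broadcasting, so the pattern needs no explicit shape handling. The eager behavior it matches, as a sketch:

import torch

x = torch.rand(4, 5, 6)
y = torch.rand(6)
# The trailing dimensions line up, so y broadcasts across x.
print(torch.atan2(x, y).shape)  # torch.Size([4, 5, 6])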
20 changes: 13 additions & 7 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -622,16 +622,10 @@
"DiagonalModule_with_offset",
"DivFloatModule_basic",
"DivIntModule_basic",
"ElementwiseAcoshIntModule_basic",
"ElementwiseAcoshModule_basic",
"ElementwiseAddScalar_NumToTensorFloat_Module_basic",
"ElementwiseAsinhIntModule_basic",
"ElementwiseAsinhModule_basic",
"ElementwiseAtan2FloatIntModule_basic",
"ElementwiseAtan2TensorFloatModule_basic",
"ElementwiseAtan2TensorIntModule_basic",
"ElementwiseAtanhIntModule_basic",
"ElementwiseAtanhModule_basic",
"ElementwiseBitwiseLeftShiftInt32Module_basic",
"ElementwiseBitwiseLeftShiftInt64Module_basic",
"ElementwiseBitwiseLeftShiftInt8Module_basic",
@@ -643,7 +637,6 @@
"ElementwiseErfIntModule_basic",
"ElementwiseLogitModule_basic",
"ElementwiseMulTensorComplexModule_basic",
"ElementwisePowScalarModule_basic",
"ElementwiseQuantizePerTensorModule_basic",
"ElementwiseQuantizePerTensorUIntModule_basic",
"ElementwiseReciprocalIntModule_basic",
@@ -992,6 +985,15 @@
"DropoutEvalIntModule_basic",
"ElementwiseAbsFloatModule_basic",
"ElementwiseAbsIntModule_basic",
"ElementwiseAcoshIntModule_basic",
"ElementwiseAcoshModule_basic",
"ElementwiseAsinhIntModule_basic",
"ElementwiseAsinhModule_basic",
"ElementwiseAtanhIntModule_basic",
"ElementwiseAtanhModule_basic",
"ElementwiseAtan2TensorFloatStaticModule_basic",
"ElementwiseAtan2TensorIntStaticModule_basic",
"ElementwiseAtan2FloatIntStaticModule_basic",
"ElementwiseAddScalar_NumToTensorFloat_Module_basic",
"ElementwiseAddScalar_TensorLiteralInt32_Module_basic",
"ElementwiseAtenIsinfOpModule_basic",
@@ -1057,6 +1059,7 @@
"ElementwiseNegModule_basic",
"ElementwiseOrTensorStaticShapeModule_basic",
"ElementwiseAndScalarStaticShapeModule_basic",
"ElementwisePowScalarModule_basic",
"ElementwisePowTensorBroadcastStaticModule_basic",
"ElementwisePowTensorStaticModule_basic",
"ElementwisePreluStaticModule_basic",
@@ -1069,6 +1072,8 @@
"ElementwiseSigmoidModule_basic",
"ElementwiseSinModule_basic",
"ElementwiseSqrtModule_basic",
"ElementwiseTanIntModule_basic",
"ElementwiseTanModule_basic",
"ElementwiseToDtypeF32ToI64Module_basic",
"ElementwiseToDtypeI64ToI8Module_basic",
"ElementwiseToDtypeIdentityModule_basic",
@@ -2180,6 +2185,7 @@
"AvgPool2dDivisorOverrideModule_basic",
"BroadcastDynamicDimModule_basic",
"ElementwiseAtan2TensorIntModule_basic",
"ElementwiseAtan2TensorIntStaticModule_basic",
"ElementwiseAtenFloorDivideScalarNegativeModule_basic",
"ElementwiseAtenFloorDivideTensorNegativeModule_basic",
"ElementwiseLog10IntModule_basic",
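These hunks move the newly supported tests out of a failure-expectation set and into a pass set, plus one addition to what appears to be another backend's xfail list; the exact set names fall outside the visible diff context. A hypothetical sketch of how such sets gate the e2e runner (names assumed for illustration, not taken from the real file):

# Hypothetical names; the real sets live in projects/pt1/e2e_testing/xfail_sets.py.
STABLEHLO_PASS_SET = {
    "ElementwisePowScalarModule_basic",
    "ElementwiseTanModule_basic",
}

def expected_to_pass(test_name: str) -> bool:
    # Tests listed in a pass set are expected to compile and run correctly.
    return test_name in STABLEHLO_PASS_SET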
70 changes: 70 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
@@ -1792,6 +1792,28 @@ def ElementwiseAtan2TensorFloatModule_basic(module, tu: TestUtils):
# ==============================================================================


class ElementwiseAtan2TensorFloatStaticModule(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([4, 5, 6], torch.float32, True),
        ([4, 5, 6], torch.float32, True),
    ])
    def forward(self, a, b):
        return torch.atan2(a, b)


@register_test_case(module_factory=lambda: ElementwiseAtan2TensorFloatStaticModule())
def ElementwiseAtan2TensorFloatStaticModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(4, 5, 6), tu.rand(4, 5, 6))


# ==============================================================================

class ElementwiseAtan2TensorIntModule(torch.nn.Module):

    def __init__(self):
@@ -1816,6 +1838,30 @@ def ElementwiseAtan2TensorIntModule_basic(module, tu: TestUtils):
# ==============================================================================


class ElementwiseAtan2TensorIntStaticModule(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([4, 5, 6], torch.int32, True),
        ([4, 5, 6], torch.int64, True),
    ])
    def forward(self, a, b):
        return torch.atan2(a, b)


@register_test_case(module_factory=lambda: ElementwiseAtan2TensorIntStaticModule())
def ElementwiseAtan2TensorIntStaticModule_basic(module, tu: TestUtils):
    module.forward(
        tu.randint(4, 5, 6, low=1, high=10).type(torch.int32),
        tu.randint(4, 5, 6, low=1, high=10))


# ==============================================================================


class ElementwiseAtan2FloatIntModule(torch.nn.Module):

    def __init__(self):
@@ -1840,6 +1886,30 @@ def ElementwiseAtan2FloatIntModule_basic(module, tu: TestUtils):
# ==============================================================================


class ElementwiseAtan2FloatIntStaticModule(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([4, 5, 6], torch.int32, True),
        ([4, 5, 6], torch.float64, True),
    ])
    def forward(self, a, b):
        return torch.atan2(a, b)


@register_test_case(module_factory=lambda: ElementwiseAtan2FloatIntStaticModule())
def ElementwiseAtan2FloatIntStaticModule_basic(module, tu: TestUtils):
    module.forward(tu.randint(4, 5, 6, low=1, high=10).to(torch.int32),
                   tu.rand(4, 5, 6).double())
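Together the three static modules pin down PyTorch's type promotion for atan2, which the StableHLO lowering must reproduce: an all-integer pair promotes to the default float dtype, while an int32/float64 pair promotes to float64. A quick eager-mode check (sketch, plain PyTorch):

import torch

ints = torch.randint(1, 10, (4, 5, 6), dtype=torch.int32)
doubles = torch.rand(4, 5, 6, dtype=torch.float64)

print(torch.atan2(ints, ints.long()).dtype)  # torch.float32 (default float)
print(torch.atan2(ints, doubles).dtype)      # torch.float64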


# ==============================================================================


class ElementwiseLogModule(torch.nn.Module):

    def __init__(self):