Add aten left/right shift op conversion support
penguin-wwy committed Apr 25, 2024
1 parent 7be22bb commit ce9827d
Showing 2 changed files with 63 additions and 0 deletions.
lib/Conversion/TorchToStablehlo/Basic.cpp (32 additions, 0 deletions)
@@ -1978,6 +1978,36 @@ LogicalResult ConvertAtenOp<AtenFmodTensorOp>::matchAndRewrite(
  return success();
}

// AtenBitwiseLeftShiftTensorOp
template <>
LogicalResult ConvertAtenOp<AtenBitwiseLeftShiftTensorOp>::matchAndRewrite(
    AtenBitwiseLeftShiftTensorOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Value lhs = adaptor.getSelf();
  Value rhs = adaptor.getOther();

  auto resultType =
      cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
  rhs = hlo::promoteAndBroadcast(rewriter, rhs, resultType);
  rewriter.replaceOpWithNewOp<stablehlo::ShiftLeftOp>(op, lhs, rhs);
  return success();
}

// AtenBitwiseRightShiftTensorOp
template <>
LogicalResult ConvertAtenOp<AtenBitwiseRightShiftTensorOp>::matchAndRewrite(
    AtenBitwiseRightShiftTensorOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Value lhs = adaptor.getSelf();
  Value rhs = adaptor.getOther();

  auto resultType =
      cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
  rhs = hlo::promoteAndBroadcast(rewriter, rhs, resultType);
  rewriter.replaceOpWithNewOp<stablehlo::ShiftRightArithmeticOp>(op, lhs, rhs);
  return success();
}
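
The right-shift pattern above lowers to stablehlo.shift_right_arithmetic, so the sign bit is replicated for signed element types; this matches how PyTorch's bitwise_right_shift behaves on signed integer tensors, whereas stablehlo.shift_right_logical would zero-fill the high bits. A minimal, self-contained C++ sketch of the difference (illustrative only, not part of this commit):

// Illustrative only: arithmetic vs. logical right shift on a negative value.
// On two's-complement platforms, signed >> replicates the sign bit, which is
// the behaviour stablehlo.shift_right_arithmetic provides.
#include <cstdint>
#include <cstdio>

int main() {
  int32_t x = -8;
  int32_t arithmetic = x >> 1;                      // -4: sign bit preserved
  uint32_t logical = static_cast<uint32_t>(x) >> 1; // 2147483644: zero-filled
  std::printf("arithmetic: %d, logical: %u\n", arithmetic, logical);
  return 0;
}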

void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
    TypeConverter &typeConverter, RewritePatternSet &patterns,
    ConversionTarget &target, const TorchToStablehloOptions &options) {
@@ -2137,6 +2167,8 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
  INSERT_ATENOP_PATTERN(AtenFlipOp);
  INSERT_ATENOP_PATTERN(AtenRemainderTensorOp);
  INSERT_ATENOP_PATTERN(AtenFmodTensorOp);
  INSERT_ATENOP_PATTERN(AtenBitwiseLeftShiftTensorOp);
  INSERT_ATENOP_PATTERN(AtenBitwiseRightShiftTensorOp);
#undef INSERT_ATENOP_PATTERN

#define INSERT_BINARY_BROADCAST_PATTERN(AtenOp, StablehloOp) \
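
The INSERT_ATENOP_PATTERN macro used above is defined earlier in Basic.cpp and is not part of this diff. It both marks the ATen op as illegal for the conversion target and registers the corresponding ConvertAtenOp pattern; a rough sketch of what such a macro looks like (assumed, not the verbatim definition):

// Assumed sketch of the registration macro; the real definition lives earlier
// in Basic.cpp. `context` stands for the MLIRContext obtained from the pattern
// set, e.g. via patterns.getContext().
#define INSERT_ATENOP_PATTERN(AtenOp)                                          \
  target.addIllegalOp<AtenOp>();                                               \
  patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context, options)

With the two new ops registered this way, any aten.bitwise_left_shift.Tensor or aten.bitwise_right_shift.Tensor left unconverted causes the conversion to fail, since the ops are now illegal for the target.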
test/Conversion/TorchToStablehlo/basic.mlir (31 additions, 0 deletions)
@@ -315,3 +315,34 @@ func.func @torch.aten.uniform(%arg0: !torch.vtensor<[32, 64],f64>) -> !torch.vte
  %0 = torch.aten.uniform %arg0, %float0, %float1, %none : !torch.vtensor<[32, 64],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[32, 64],f64>
  return %0 : !torch.vtensor<[32, 64],f64>
}

// -----

// CHECK-LABEL: func.func @torch.aten.bitwise_left_shift.Tensor(
// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[3,4],si32>,
// CHECK-SAME: %[[ARG_1:.*]]: !torch.vtensor<[3,1],si32>) -> !torch.vtensor<[3,4],si32> {
// CHECK: %[[VAL_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0:.*]] : !torch.vtensor<[3,4],si32> -> tensor<3x4xi32>
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[ARG_1:.*]] : !torch.vtensor<[3,1],si32> -> tensor<3x1xi32>
// CHECK: %[[VAL_2:.*]] = stablehlo.broadcast_in_dim %[[VAL_1:.*]], dims = [0, 1] : (tensor<3x1xi32>) -> tensor<3x4xi32>
// CHECK: %[[VAL_3:.*]] = stablehlo.shift_left %[[VAL_0:.*]], %[[VAL_2:.*]] : tensor<3x4xi32>
// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3:.*]] : tensor<3x4xi32> -> !torch.vtensor<[3,4],si32>
// CHECK: return %[[VAL_4:.*]] : !torch.vtensor<[3,4],si32>
func.func @torch.aten.bitwise_left_shift.Tensor(%arg0: !torch.vtensor<[3,4],si32>, %arg1: !torch.vtensor<[3,1],si32>) -> !torch.vtensor<[3,4],si32> {
  %0 = torch.aten.bitwise_left_shift.Tensor %arg0, %arg1 : !torch.vtensor<[3,4],si32>, !torch.vtensor<[3,1],si32> -> !torch.vtensor<[3,4],si32>
  return %0 : !torch.vtensor<[3,4],si32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.bitwise_right_shift.Tensor(
// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[3,4],si64>,
// CHECK-SAME: %[[ARG_1:.*]]: !torch.vtensor<[3,4],si64>) -> !torch.vtensor<[3,4],si64> {
// CHECK: %[[VAL_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0:.*]] : !torch.vtensor<[3,4],si64> -> tensor<3x4xi64>
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[ARG_1:.*]] : !torch.vtensor<[3,4],si64> -> tensor<3x4xi64>
// CHECK: %[[VAL_2:.*]] = stablehlo.shift_right_arithmetic %[[VAL_0:.*]], %[[VAL_1:.*]] : tensor<3x4xi64>
// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2:.*]] : tensor<3x4xi64> -> !torch.vtensor<[3,4],si64>
// CHECK: return %[[VAL_3:.*]] : !torch.vtensor<[3,4],si64>
func.func @torch.aten.bitwise_right_shift.Tensor(%arg0: !torch.vtensor<[3,4],si64>, %arg1: !torch.vtensor<[3,4],si64>) -> !torch.vtensor<[3,4],si64> {
  %0 = torch.aten.bitwise_right_shift.Tensor %arg0, %arg1 : !torch.vtensor<[3,4],si64>, !torch.vtensor<[3,4],si64> -> !torch.vtensor<[3,4],si64>
  return %0 : !torch.vtensor<[3,4],si64>
}
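
In the first test above, the [3,1] shift-amount operand is broadcast along the second dimension of the [3,4] input (the stablehlo.broadcast_in_dim in the CHECK lines) before the element-wise shift, so each row is shifted by its own amount. A small C++ sketch of the resulting arithmetic (illustrative only; the values are made up for the example):

// Illustrative only: a per-row shift amount broadcast across each row of a
// 3x4 input, mirroring the broadcast_in_dim + shift_left sequence above.
#include <array>
#include <cstdint>
#include <cstdio>

int main() {
  std::array<std::array<int32_t, 4>, 3> input = {{{1, 2, 3, 4},
                                                  {5, 6, 7, 8},
                                                  {9, 10, 11, 12}}};
  std::array<int32_t, 3> shift = {0, 1, 2}; // one shift amount per row
  for (int i = 0; i < 3; ++i) {
    for (int j = 0; j < 4; ++j)
      std::printf("%d ", input[i][j] << shift[i]);
    std::printf("\n");
  }
  return 0;
}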
