diff --git a/lib/Conversion/TorchToStablehlo/Basic.cpp b/lib/Conversion/TorchToStablehlo/Basic.cpp index 1858b1a6d7ca..2334d0180865 100644 --- a/lib/Conversion/TorchToStablehlo/Basic.cpp +++ b/lib/Conversion/TorchToStablehlo/Basic.cpp @@ -1978,6 +1978,36 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } +// AtenBitwiseLeftShiftTensorOp +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenBitwiseLeftShiftTensorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Value lhs = adaptor.getSelf(); + Value rhs = adaptor.getOther(); + + auto resultType = + cast(getTypeConverter()->convertType(op.getType())); + rhs = hlo::promoteAndBroadcast(rewriter, rhs, resultType); + rewriter.replaceOpWithNewOp(op, lhs, rhs); + return success(); +} + +// AtenBitwiseRightShiftTensorOp +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenBitwiseRightShiftTensorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Value lhs = adaptor.getSelf(); + Value rhs = adaptor.getOther(); + + auto resultType = + cast(getTypeConverter()->convertType(op.getType())); + rhs = hlo::promoteAndBroadcast(rewriter, rhs, resultType); + rewriter.replaceOpWithNewOp(op, lhs, rhs); + return success(); +} + void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, const TorchToStablehloOptions &options) { @@ -2137,6 +2167,8 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality( INSERT_ATENOP_PATTERN(AtenFlipOp); INSERT_ATENOP_PATTERN(AtenRemainderTensorOp); INSERT_ATENOP_PATTERN(AtenFmodTensorOp); + INSERT_ATENOP_PATTERN(AtenBitwiseLeftShiftTensorOp); + INSERT_ATENOP_PATTERN(AtenBitwiseRightShiftTensorOp); #undef INSERT_ATENOP_PATTERN #define INSERT_BINARY_BROADCAST_PATTERN(AtenOp, StablehloOp) \ diff --git a/test/Conversion/TorchToStablehlo/basic.mlir b/test/Conversion/TorchToStablehlo/basic.mlir index 5f096205ea8c..92888616a67b 100644 --- a/test/Conversion/TorchToStablehlo/basic.mlir +++ b/test/Conversion/TorchToStablehlo/basic.mlir @@ -315,3 +315,34 @@ func.func @torch.aten.uniform(%arg0: !torch.vtensor<[32, 64],f64>) -> !torch.vte %0 = torch.aten.uniform %arg0, %float0, %float1, %none : !torch.vtensor<[32, 64],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[32, 64],f64> return %0 : !torch.vtensor<[32, 64],f64> } + +// ----- + +// CHECK-LABEL: func.func @torch.aten.bitwise_left_shift.Tensor( +// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[3,4],si32>, +// CHECK-SAME: %[[ARG_1:.*]]: !torch.vtensor<[3,1],si32>) -> !torch.vtensor<[3,4],si32> { +// CHECK: %[[VAL_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0:.*]] : !torch.vtensor<[3,4],si32> -> tensor<3x4xi32> +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[ARG_1:.*]] : !torch.vtensor<[3,1],si32> -> tensor<3x1xi32> +// CHECK: %[[VAL_2:.*]] = stablehlo.broadcast_in_dim %[[VAL_1:.*]], dims = [0, 1] : (tensor<3x1xi32>) -> tensor<3x4xi32> +// CHECK: %[[VAL_3:.*]] = stablehlo.shift_left %[[VAL_0:.*]], %[[VAL_2:.*]] : tensor<3x4xi32> +// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3:.*]] : tensor<3x4xi32> -> !torch.vtensor<[3,4],si32> +// CHECK: return %[[VAL_4:.*]] : !torch.vtensor<[3,4],si32> +func.func @torch.aten.bitwise_left_shift.Tensor(%arg0: !torch.vtensor<[3,4],si32>, %arg1: !torch.vtensor<[3,1],si32>) -> !torch.vtensor<[3,4],si32> { + %0 = torch.aten.bitwise_left_shift.Tensor %arg0, %arg1 : !torch.vtensor<[3,4],si32>, !torch.vtensor<[3,1],si32> -> !torch.vtensor<[3,4],si32> + return %0 : !torch.vtensor<[3,4],si32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.bitwise_right_shift.Tensor( +// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[3,4],si64>, +// CHECK-SAME: %[[ARG_1:.*]]: !torch.vtensor<[3,4],si64>) -> !torch.vtensor<[3,4],si64> { +// CHECK: %[[VAL_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0:.*]] : !torch.vtensor<[3,4],si64> -> tensor<3x4xi64> +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[ARG_1:.*]] : !torch.vtensor<[3,4],si64> -> tensor<3x4xi64> +// CHECK: %[[VAL_2:.*]] = stablehlo.shift_right_arithmetic %[[VAL_0:.*]], %[[VAL_1:.*]] : tensor<3x4xi64> +// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2:.*]] : tensor<3x4xi64> -> !torch.vtensor<[3,4],si64> +// CHECK: return %[[VAL_3:.*]] : !torch.vtensor<[3,4],si64> +func.func @torch.aten.bitwise_right_shift.Tensor(%arg0: !torch.vtensor<[3,4],si64>, %arg1: !torch.vtensor<[3,4],si64>) -> !torch.vtensor<[3,4],si64> { + %0 = torch.aten.bitwise_right_shift.Tensor %arg0, %arg1 : !torch.vtensor<[3,4],si64>, !torch.vtensor<[3,4],si64> -> !torch.vtensor<[3,4],si64> + return %0 : !torch.vtensor<[3,4],si64> +}