diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index e5f4fea4f46c..b6dbdc2c7b8c 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -4093,6 +4093,25 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } +Value wrapNegativeIndices(Value index, int maxIndex, Operation *op, + ConversionPatternRewriter &rewriter) { + + auto zeroValue = tosa::getConstTensor(rewriter, op, 0, {}).value(); + auto maxIndexValue = + tosa::getConstTensor(rewriter, op, maxIndex, {}).value(); + + auto indexType = dyn_cast(index.getType()); + + auto wrappedIndicesOp = tosa::CreateOpAndInfer( + rewriter, op->getLoc(), indexType, maxIndexValue, index); + auto boolType = indexType.clone(rewriter.getIntegerType(1)); + auto isNegativeIndices = tosa::CreateOpAndInfer( + rewriter, op->getLoc(), boolType, zeroValue, index); + return tosa::CreateOpAndInfer(rewriter, op->getLoc(), + indexType, isNegativeIndices, + wrappedIndicesOp, index); +} + template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenIndexTensorHackedTwinOp op, OpAdaptor adaptor, @@ -4124,6 +4143,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto outType = getTypeConverter()->convertType(op.getType()); + Operation *indicesTf; + // Support for multiple indexes if (indexTensors.size() > 1) { // t[i, i] @@ -4157,6 +4178,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( index); } + index = wrapNegativeIndices(index, inputTensorType.getShape()[i], op, + rewriter); // Expand last dim of index to tf indices [2,3] -> [2,3,1] SmallVector indiceShapeOneDim; for (auto shape : indexShape) { @@ -4299,49 +4322,39 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto indicesShapeConcat = indexesShape[0]; uint64_t lastDim = indexesRank[0]; indicesShapeConcat.push_back(indicesTfConcatTensors.size()); - auto indicesTf = tosa::CreateOpAndInfer( + indicesTf = tosa::CreateOpAndInfer( rewriter, op->getLoc(), GetTypeFromTensorShape(indicesShapeConcat, rewriter.getIntegerType(32)), indicesTfConcatTensors, lastDim); - if (!indicesTf) { - return rewriter.notifyMatchFailure( - op, "Convert TorchIndex To TfIndices fail."); - } - // do the tf gathernp algorithm with tf style indices as input. - auto result = tosa::convertGatherNdOp(rewriter, op, outType, input, - indicesTf.getResult()); + } else { - if (!result) { - return rewriter.notifyMatchFailure( - op, "Convert GatherNdOp fail for index tensor."); + // Single index + auto index = indexTensors[0]; + auto indexType = dyn_cast(index.getType()); + auto indexShape = indexType.getShape(); + // index i64 to i32 for tosa compatible + if (indexType.getElementType() != rewriter.getIntegerType(32)) { + index = rewriter.create( + op->getLoc(), + RankedTensorType::get(indexShape, rewriter.getIntegerType(32)), + index); } - rewriter.replaceOp(op, {result.value()}); - return success(); - } + index = + wrapNegativeIndices(index, inputTensorType.getShape()[0], op, rewriter); - // Support for multiple index - auto index = indexTensors[0]; - auto indexType = dyn_cast(index.getType()); - auto indexShape = indexType.getShape(); - // index i64 to i32 for tosa compatible - if (indexType.getElementType() != rewriter.getIntegerType(32)) { - index = rewriter.create( - op->getLoc(), - RankedTensorType::get(indexShape, rewriter.getIntegerType(32)), index); - } - - // Expand last dim of index to tf indices [2,3] -> [2,3,1] - SmallVector indicesShape; - for (auto shape : indexShape) { - indicesShape.push_back(shape); + // Expand last dim of index to tf indices [2,3] -> [2,3,1] + SmallVector indicesShape; + for (auto shape : indexShape) { + indicesShape.push_back(shape); + } + indicesShape.push_back(1); + indicesTf = tosa::CreateOpAndInfer( + rewriter, op->getLoc(), + RankedTensorType::get(indicesShape, rewriter.getIntegerType(32)), index, + rewriter.getDenseI64ArrayAttr(indicesShape)); } - indicesShape.push_back(1); - auto indicesTf = tosa::CreateOpAndInfer( - rewriter, op->getLoc(), - RankedTensorType::get(indicesShape, rewriter.getIntegerType(32)), index, - rewriter.getDenseI64ArrayAttr(indicesShape)); if (!indicesTf) { return rewriter.notifyMatchFailure(op, @@ -4349,7 +4362,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( } // do the tf gathernp algorithm with tf style indices as input. auto result = tosa::convertGatherNdOp(rewriter, op, outType, input, - indicesTf.getResult()); + indicesTf->getResult(0)); if (!result) { return rewriter.notifyMatchFailure( diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index e370a1d8b73d..82ca24443162 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1698,7 +1698,6 @@ "ArangeStartOutModule_basic", "ScatterSrcStaticModule_basic", # Runtime op verification: Out of bounds access - "IndexTensorNegativeIndexModule_basic", "ReduceAllDimEmpty_basic", } @@ -1706,7 +1705,6 @@ "ScatterSrcModule_basic", "ScatterSrcStaticModule_basic", "HBC_basic", - "IndexTensorNegativeIndexModule_basic", "InterpolateDynamicModule_scales_recompute_bilinear", "InterpolateDynamicModule_sizes_bilinear", "InterpolateDynamicModule_sizes_nearest", @@ -2162,6 +2160,7 @@ "HardswishRandomModule_basic", "HardtanhBackward_basic", "IndexTensorMultiIndexStaticModule_basic", + "IndexTensorNegativeIndexModule_basic", "IndexTensorStaticModule_basic", "IscloseStaticModuleTrue_basic", "IscloseStaticModule_basic", @@ -3635,7 +3634,6 @@ "IndexPutImpl3DFloatNonAccumulateModule_basic", "IndexPutImplIndexWithNoneModule_basic", "IndexSelectRank0IdxModule_basic", - "IndexTensorNegativeIndexModule_basic", "InterpolateDynamicModule_sizes_bilinear", "InterpolateDynamicModule_sizes_nearest", "InterpolateStaticModule_scales_bilinear_align_corners", diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir index e412bb390c35..ed6f909c4a1b 100644 --- a/test/Conversion/TorchToTosa/basic.mlir +++ b/test/Conversion/TorchToTosa/basic.mlir @@ -2131,3 +2131,35 @@ func.func @torch.aten.diag_embed$basic(%arg0: !torch.vtensor<[2,3,4],f32>) -> !t %0 = torch.aten.diag_embed %arg0, %int0, %int-2, %int-1 : !torch.vtensor<[2,3,4],f32>, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,3,4,4],f32> return %0 : !torch.vtensor<[2,3,4,4],f32> } + +// ----- + +// CHECK-LABEL: func.func @torch.aten.index.Tensor_hacked_twin( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2,4,2],si64>, +// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[],si64>) -> !torch.vtensor<[4,2],si64> { +// CHECK: %[[VAL_0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[2,4,2],si64> -> tensor<2x4x2xi64> +// CHECK: %[[VAL_1:.*]] = torch.prim.ListConstruct %[[ARG1]] : (!torch.vtensor<[],si64>) -> !torch.list +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[],si64> -> tensor +// CHECK: %[[VAL_3:.*]] = tosa.cast %[[VAL_2]] : (tensor) -> tensor +// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<0> : tensor}> : () -> tensor +// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<2> : tensor}> : () -> tensor +// CHECK: %[[VAL_6:.*]] = tosa.add %[[VAL_5]], %[[VAL_3]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_7:.*]] = tosa.greater %[[VAL_4]], %[[VAL_3]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_8:.*]] = tosa.select %[[VAL_7]], %[[VAL_6]], %[[VAL_3]] : (tensor, tensor, tensor) -> tensor +// CHECK: %[[VAL_9:.*]] = tosa.reshape %[[VAL_8]] {new_shape = array} : (tensor) -> tensor<1xi32> +// CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_0]] {new_shape = array} : (tensor<2x4x2xi64>) -> tensor<1x2x8xi64> +// CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_9]] {new_shape = array} : (tensor<1xi32>) -> tensor<1x1xi32> +// CHECK: %[[VAL_12:.*]] = "tosa.const"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK: %[[VAL_13:.*]] = tosa.mul %[[VAL_11]], %[[VAL_12]] {shift = 0 : i8} : (tensor<1x1xi32>, tensor<1xi32>) -> tensor<1x1xi32> +// CHECK: %[[VAL_14:.*]] = tosa.reduce_sum %[[VAL_13]] {axis = 1 : i32} : (tensor<1x1xi32>) -> tensor<1x1xi32> +// CHECK: %[[VAL_15:.*]] = tosa.reshape %[[VAL_14]] {new_shape = array} : (tensor<1x1xi32>) -> tensor<1x1xi32> +// CHECK: %[[VAL_16:.*]] = tosa.gather %[[VAL_10]], %[[VAL_15]] : (tensor<1x2x8xi64>, tensor<1x1xi32>) -> tensor<1x1x8xi64> +// CHECK: %[[VAL_17:.*]] = tosa.reshape %[[VAL_16]] {new_shape = array} : (tensor<1x1x8xi64>) -> tensor<4x2xi64> +// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[VAL_17]] : tensor<4x2xi64> -> !torch.vtensor<[4,2],si64> +// CHECK: return %[[RESULT]] : !torch.vtensor<[4,2],si64> + +func.func @torch.aten.index.Tensor_hacked_twin(%arg0: !torch.vtensor<[2,4,2],si64>, %arg1: !torch.vtensor<[],si64>) -> !torch.vtensor<[4,2],si64> { + %0 = torch.prim.ListConstruct %arg1 : (!torch.vtensor<[],si64>) -> !torch.list + %1 = torch.aten.index.Tensor_hacked_twin %arg0, %0 : !torch.vtensor<[2,4,2],si64>, !torch.list -> !torch.vtensor<[4,2],si64> + return %1 : !torch.vtensor<[4,2],si64> + }