diff --git a/include/torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h b/include/torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h index 1c31880011c5..9067b7e24665 100644 --- a/include/torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h +++ b/include/torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h @@ -52,9 +52,9 @@ Value scalarToStablehloTensor(ConversionPatternRewriter &rewriter, Value promoteType(PatternRewriter &rewriter, Location loc, Value input, Type outElementType); -FailureOr getBroadcastResultShape(PatternRewriter &rewriter, - Operation *op, ArrayRef tensors, - size_t dimSizeIndexBits); +FailureOr>> +getBroadcastResultShape(PatternRewriter &rewriter, Operation *op, + ArrayRef tensors, size_t dimSizeIndexBits); Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, Value input, TensorType outType, diff --git a/lib/Conversion/TorchToStablehlo/GatherScatter.cpp b/lib/Conversion/TorchToStablehlo/GatherScatter.cpp index dc8289b713b2..d764e9040252 100644 --- a/lib/Conversion/TorchToStablehlo/GatherScatter.cpp +++ b/lib/Conversion/TorchToStablehlo/GatherScatter.cpp @@ -220,16 +220,10 @@ namespace { FailureOr broadcastAndConcatIndices(Operation *op, ConversionPatternRewriter &rewriter, SmallVector indexTensors, - llvm::ArrayRef inputShape, size_t dimSizeIndexBits, int &maxIndexRank) { // Step 1: broadcast indices tensors - SmallVector indicesShape; - SmallVector expandShape; - SmallVector concatShape; - bool allIndexStaticShape = true; - Value bcastSizeTensor; // concat index tensor into to indices tensor for concat for (size_t i = 0; i < indexTensors.size(); i++) { @@ -242,20 +236,15 @@ FailureOr broadcastAndConcatIndices(Operation *op, maxIndexRank = std::max(maxIndexRank, (int)indexTensorType.getRank()); } - if (!allIndexStaticShape) { - auto bcastSizeTensorInfo = hlo::getBroadcastResultShape( - rewriter, op, indexTensors, dimSizeIndexBits); - if (failed(bcastSizeTensorInfo)) { - return failure(); - } - bcastSizeTensor = *bcastSizeTensorInfo; - } - - for (int i = 0; i < maxIndexRank; i++) { - indicesShape.push_back(inputShape[i]); - expandShape.push_back(inputShape[i]); - concatShape.push_back(inputShape[i]); + auto bcastSizeInfo = hlo::getBroadcastResultShape(rewriter, op, indexTensors, + dimSizeIndexBits); + if (failed(bcastSizeInfo)) { + return failure(); } + Value bcastSizeTensor = (*bcastSizeInfo).first; + auto indicesShape = (*bcastSizeInfo).second; + SmallVector expandShape(indicesShape.begin(), indicesShape.end()); + SmallVector concatShape(indicesShape.begin(), indicesShape.end()); expandShape.push_back(1); concatShape.push_back(indexTensors.size()); @@ -890,9 +879,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( indicesTorchType); int maxIndexRank = -1; - auto gatherIndicesInfo = - broadcastAndConcatIndices(op, rewriter, indexTensors, outShape, - options.dimSizeIndexBits, maxIndexRank); + auto gatherIndicesInfo = broadcastAndConcatIndices( + op, rewriter, indexTensors, options.dimSizeIndexBits, maxIndexRank); if (failed(gatherIndicesInfo)) { return rewriter.notifyMatchFailure( op, "failed to generate broadcasted indices"); @@ -949,6 +937,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto outType = cast(getTypeConverter()->convertType(op.getType())); auto inputType = cast(input.getType()); + auto inputShape = inputType.getShape(); + auto inputRank = inputType.getRank(); auto valuesType = cast(values.getType()); int64_t valueRank = valuesType.getRank(); auto valuesShape = valuesType.getShape(); @@ -968,15 +958,59 @@ LogicalResult ConvertAtenOp::matchAndRewrite( indicesTorchType); int maxIndexRank = -1; - auto scatterIndicesInfo = - broadcastAndConcatIndices(op, rewriter, indexTensors, valuesShape, - options.dimSizeIndexBits, maxIndexRank); + auto scatterIndicesInfo = broadcastAndConcatIndices( + op, rewriter, indexTensors, options.dimSizeIndexBits, maxIndexRank); if (failed(scatterIndicesInfo)) { return rewriter.notifyMatchFailure( op, "failed to generate broadcasted indices"); } auto scatterIndices = *scatterIndicesInfo; + // unsqueeze values to handle absent dimensions of size 1. + llvm::ArrayRef scatterIndicesShape = + (cast(scatterIndices.getType())).getShape(); + SmallVector expectedValuesShape( + scatterIndicesShape.begin(), scatterIndicesShape.begin() + maxIndexRank); + for (int64_t i = indexCnt; i < inputRank; i++) { + expectedValuesShape.push_back(inputShape[i]); + } + int64_t expectedValuesShapeIdx = expectedValuesShape.size() - 1; + int64_t valuesShapeIdx = valuesShape.size() - 1; + SmallVector unsqzDims; + while (expectedValuesShapeIdx >= 0 && valuesShapeIdx >= 0) { + if (valuesShape[valuesShapeIdx] == + expectedValuesShape[expectedValuesShapeIdx]) { + expectedValuesShapeIdx--; + valuesShapeIdx--; + } else if (expectedValuesShape[expectedValuesShapeIdx] == 1) { + unsqzDims.push_back(expectedValuesShapeIdx); + expectedValuesShapeIdx--; + } else { + return rewriter.notifyMatchFailure(op, + "invalid values argument provided"); + } + } + if (valuesShapeIdx >= 0) { + return rewriter.notifyMatchFailure(op, "invalid values argument provided"); + } + while (expectedValuesShapeIdx >= 0) { + unsqzDims.push_back(expectedValuesShapeIdx); + expectedValuesShapeIdx--; + } + + if (!unsqzDims.empty()) { + std::reverse(unsqzDims.begin(), unsqzDims.end()); + auto newValuesInfo = hlo::unsqueezeTensor(rewriter, op, values, unsqzDims); + if (failed(newValuesInfo)) { + return rewriter.notifyMatchFailure(op, + "invalid values argument provided"); + } + values = *newValuesInfo; + valuesType = cast(values.getType()); + valueRank = valuesType.getRank(); + valuesShape = valuesType.getShape(); + } + // create stablehlo::ScatterOp int64_t indexVecDim = maxIndexRank; SmallVector scatterDimOperandDimMap; @@ -1216,9 +1250,9 @@ Value getSummand(ConversionPatternRewriter &rewriter, Operation *op, SmallVector indexTensors{Nidx, CIdx, idxY, idxX}; int maxIndexRank = -1; - auto gatherIndicesInfo = broadcastAndConcatIndices( - input.getDefiningOp(), rewriter, indexTensors, outType.getShape(), - dimSizeIndexBits, maxIndexRank); + auto gatherIndicesInfo = + broadcastAndConcatIndices(input.getDefiningOp(), rewriter, indexTensors, + dimSizeIndexBits, maxIndexRank); auto gatherIndices = *gatherIndicesInfo; int64_t numIndicesDim = indexTensors.size(); int64_t indexVecDim = maxIndexRank; diff --git a/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp b/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp index 8b2ec2ed53fe..7e263cecba27 100644 --- a/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp +++ b/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp @@ -322,9 +322,9 @@ getDimIndexOfTensor(PatternRewriter &rewriter, Operation *op, Value value) { return getDimIndexOfTensor(rewriter, op, value, dims); } -FailureOr getBroadcastResultShape(PatternRewriter &rewriter, - Operation *op, ArrayRef tensors, - size_t dimSizeIndexBits) { +FailureOr>> +getBroadcastResultShape(PatternRewriter &rewriter, Operation *op, + ArrayRef tensors, size_t dimSizeIndexBits) { SmallVector> tensorSizes; int maxRank = 0; @@ -337,10 +337,11 @@ FailureOr getBroadcastResultShape(PatternRewriter &rewriter, } SmallVector bcastSizeTensors; + SmallVector bcastSizes; for (int outDim = 0; outDim < maxRank; ++outDim) { // loop dimensions. int dynamicDimCnt = 0; int staticDimCnt = 0; - int64_t staticDimSize; + int64_t dimSize; Value dimSizeTensor = rewriter.create( op->getLoc(), rewriter.getIntegerAttr(rewriter.getIntegerType(dimSizeIndexBits), 1)); @@ -357,6 +358,7 @@ FailureOr getBroadcastResultShape(PatternRewriter &rewriter, if (tensorSizes[i][inDim] == ShapedType::kDynamic || tensorSizes[i][inDim] == kUnknownSize) { dynamicDimCnt++; + dimSize = ShapedType::kDynamic; auto dimSizeTensorInfo = hlo::getDimSizesOfTensor( rewriter, op, tensors[i], {inDim}, dimSizeIndexBits); if (failed(dimSizeTensorInfo)) { @@ -371,12 +373,12 @@ FailureOr getBroadcastResultShape(PatternRewriter &rewriter, return failure(); } // we already found static dim size not equal with this, fail. - if (staticDimCnt > 0 && staticDimSize != tensorSizes[i][inDim]) { + if (staticDimCnt > 0 && dimSize != tensorSizes[i][inDim]) { return failure(); } staticDimCnt++; - staticDimSize = tensorSizes[i][inDim]; + dimSize = tensorSizes[i][inDim]; auto dimSizeTensorInfo = hlo::getDimSizesOfTensor( rewriter, op, tensors[i], {inDim}, dimSizeIndexBits); if (failed(dimSizeTensorInfo)) { @@ -389,12 +391,14 @@ FailureOr getBroadcastResultShape(PatternRewriter &rewriter, // if (dynamicDimCnt > 1) { // return failure(); // } - + bcastSizes.push_back(dimSize); bcastSizeTensors.push_back(dimSizeTensor); } std::reverse(bcastSizeTensors.begin(), bcastSizeTensors.end()); - return rewriter.create(op->getLoc(), bcastSizeTensors) - .getResult(); + return std::pair>( + rewriter.create(op->getLoc(), bcastSizeTensors) + .getResult(), + bcastSizes); } FailureOr unsqueezeTensor(PatternRewriter &rewriter, Operation *op,