[stablehlo] fix: enhance torch's index-like op lowering to stablehlo's gather/scatter
Vremold committed Oct 30, 2024
1 parent 8b0bf2e commit c65996c
Showing 3 changed files with 78 additions and 40 deletions.
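In brief: broadcastAndConcatIndices used to size the broadcasted index tensors from a caller-supplied hint (outShape or valuesShape), which can break down whenever the index tensors' own broadcast shape differs from that hint. This commit derives the shape from the index tensors themselves — hlo::getBroadcastResultShape (declared in StablehloLegalizeUtils.h, implemented below) now returns both a runtime shape tensor and the static dimension sizes — and the aten.index_put lowering additionally unsqueezes values so its rank matches the scatter updates it must provide.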
6 changes: 3 additions & 3 deletions include/torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h
@@ -52,9 +52,9 @@ Value scalarToStablehloTensor(ConversionPatternRewriter &rewriter,
 Value promoteType(PatternRewriter &rewriter, Location loc, Value input,
                   Type outElementType);
 
-FailureOr<Value> getBroadcastResultShape(PatternRewriter &rewriter,
-                                         Operation *op, ArrayRef<Value> tensors,
-                                         size_t dimSizeIndexBits);
+FailureOr<std::pair<Value, SmallVector<int64_t>>>
+getBroadcastResultShape(PatternRewriter &rewriter, Operation *op,
+                        ArrayRef<Value> tensors, size_t dimSizeIndexBits);
 
 Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, Value input,
                           TensorType outType,
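The new pair hands callers the broadcast result in two forms: a runtime shape tensor for dynamic code paths and static int64_t sizes (with ShapedType::kDynamic holes) for constructing ranked result types. As a rough, self-contained sketch of the static half — plain C++ with a hypothetical broadcastShapes helper, not the torch-mlir function itself:

```cpp
#include <algorithm>
#include <cstdint>
#include <optional>
#include <vector>

// Standalone sketch of NumPy-style broadcast-shape inference; kDynamic plays
// the role of ShapedType::kDynamic. Illustration only.
constexpr int64_t kDynamic = -1;

std::optional<std::vector<int64_t>>
broadcastShapes(const std::vector<std::vector<int64_t>> &shapes) {
  size_t maxRank = 0;
  for (const auto &s : shapes)
    maxRank = std::max(maxRank, s.size());

  std::vector<int64_t> result(maxRank, 1);
  for (size_t outDim = 0; outDim < maxRank; ++outDim) {
    for (const auto &s : shapes) {
      size_t offset = maxRank - s.size(); // shapes are right-aligned
      if (outDim < offset)
        continue; // this tensor has no dimension here; treat as size 1
      int64_t d = s[outDim - offset];
      if (d == kDynamic) {
        result[outDim] = kDynamic; // any dynamic input makes the dim dynamic
      } else if (d != 1) {
        if (result[outDim] != 1 && result[outDim] != kDynamic &&
            result[outDim] != d)
          return std::nullopt; // incompatible static sizes
        if (result[outDim] != kDynamic)
          result[outDim] = d;
      }
    }
  }
  return result;
}
```

Right-alignment is the usual NumPy rule: a missing dimension behaves like size 1, a dynamic input dimension makes the output dimension dynamic, and two distinct static sizes (neither of them 1) are an error.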
90 changes: 62 additions & 28 deletions lib/Conversion/TorchToStablehlo/GatherScatter.cpp
@@ -220,16 +220,10 @@ namespace {
 FailureOr<Value> broadcastAndConcatIndices(Operation *op,
                                            ConversionPatternRewriter &rewriter,
                                            SmallVector<Value> indexTensors,
-                                           llvm::ArrayRef<int64_t> inputShape,
                                            size_t dimSizeIndexBits,
                                            int &maxIndexRank) {
   // Step 1: broadcast indices tensors
-  SmallVector<int64_t> indicesShape;
-  SmallVector<int64_t> expandShape;
-  SmallVector<int64_t> concatShape;
 
-  bool allIndexStaticShape = true;
-  Value bcastSizeTensor;
 
   // concat index tensor into to indices tensor for concat
   for (size_t i = 0; i < indexTensors.size(); i++) {
@@ -242,20 +236,15 @@ FailureOr<Value> broadcastAndConcatIndices(Operation *op,
     maxIndexRank = std::max(maxIndexRank, (int)indexTensorType.getRank());
   }
 
-  if (!allIndexStaticShape) {
-    auto bcastSizeTensorInfo = hlo::getBroadcastResultShape(
-        rewriter, op, indexTensors, dimSizeIndexBits);
-    if (failed(bcastSizeTensorInfo)) {
-      return failure();
-    }
-    bcastSizeTensor = *bcastSizeTensorInfo;
-  }
-
-  for (int i = 0; i < maxIndexRank; i++) {
-    indicesShape.push_back(inputShape[i]);
-    expandShape.push_back(inputShape[i]);
-    concatShape.push_back(inputShape[i]);
+  auto bcastSizeInfo = hlo::getBroadcastResultShape(rewriter, op, indexTensors,
+                                                    dimSizeIndexBits);
+  if (failed(bcastSizeInfo)) {
+    return failure();
   }
+  Value bcastSizeTensor = (*bcastSizeInfo).first;
+  auto indicesShape = (*bcastSizeInfo).second;
+  SmallVector<int64_t> expandShape(indicesShape.begin(), indicesShape.end());
+  SmallVector<int64_t> concatShape(indicesShape.begin(), indicesShape.end());
   expandShape.push_back(1);
   concatShape.push_back(indexTensors.size());
 
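In PyTorch terms, the helper broadcasts every index tensor to the common shape computed above (expandShape gains a trailing 1), then concatenates them along that trailing axis (concatShape gains a trailing indexTensors.size()), so each row of the result is one multi-dimensional coordinate. A toy model of the final layout, assuming the index tensors are already broadcast and flattened — hypothetical names, scalar C++ instead of stablehlo ops:

```cpp
#include <cstdint>
#include <vector>

// Builds the n x k coordinate table that broadcastAndConcatIndices encodes as
// a tensor: row r is the k-dimensional coordinate contributed by the k index
// tensors at flat position r.
std::vector<std::vector<int64_t>>
concatIndices(const std::vector<std::vector<int64_t>> &bcastIndexTensors) {
  size_t n = bcastIndexTensors.empty() ? 0 : bcastIndexTensors[0].size();
  std::vector<std::vector<int64_t>> table(
      n, std::vector<int64_t>(bcastIndexTensors.size()));
  for (size_t r = 0; r < n; ++r)
    for (size_t k = 0; k < bcastIndexTensors.size(); ++k)
      table[r][k] = bcastIndexTensors[k][r];
  return table;
}
```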
@@ -890,9 +879,8 @@ LogicalResult ConvertAtenOp<AtenIndexTensorHackedTwinOp>::matchAndRewrite(
                      indicesTorchType);
 
   int maxIndexRank = -1;
-  auto gatherIndicesInfo =
-      broadcastAndConcatIndices(op, rewriter, indexTensors, outShape,
-                                options.dimSizeIndexBits, maxIndexRank);
+  auto gatherIndicesInfo = broadcastAndConcatIndices(
+      op, rewriter, indexTensors, options.dimSizeIndexBits, maxIndexRank);
   if (failed(gatherIndicesInfo)) {
     return rewriter.notifyMatchFailure(
         op, "failed to generate broadcasted indices");
@@ -949,6 +937,8 @@ LogicalResult ConvertAtenOp<AtenIndexPutHackedTwinOp>::matchAndRewrite(
   auto outType =
       cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
   auto inputType = cast<RankedTensorType>(input.getType());
+  auto inputShape = inputType.getShape();
+  auto inputRank = inputType.getRank();
   auto valuesType = cast<RankedTensorType>(values.getType());
   int64_t valueRank = valuesType.getRank();
   auto valuesShape = valuesType.getShape();
@@ -968,15 +958,59 @@
                      indicesTorchType);
 
   int maxIndexRank = -1;
-  auto scatterIndicesInfo =
-      broadcastAndConcatIndices(op, rewriter, indexTensors, valuesShape,
-                                options.dimSizeIndexBits, maxIndexRank);
+  auto scatterIndicesInfo = broadcastAndConcatIndices(
+      op, rewriter, indexTensors, options.dimSizeIndexBits, maxIndexRank);
   if (failed(scatterIndicesInfo)) {
     return rewriter.notifyMatchFailure(
         op, "failed to generate broadcasted indices");
   }
   auto scatterIndices = *scatterIndicesInfo;
 
+  // unsqueeze values to handle absent dimensions of size 1.
+  llvm::ArrayRef<int64_t> scatterIndicesShape =
+      (cast<RankedTensorType>(scatterIndices.getType())).getShape();
+  SmallVector<int64_t> expectedValuesShape(
+      scatterIndicesShape.begin(), scatterIndicesShape.begin() + maxIndexRank);
+  for (int64_t i = indexCnt; i < inputRank; i++) {
+    expectedValuesShape.push_back(inputShape[i]);
+  }
+  int64_t expectedValuesShapeIdx = expectedValuesShape.size() - 1;
+  int64_t valuesShapeIdx = valuesShape.size() - 1;
+  SmallVector<int64_t> unsqzDims;
+  while (expectedValuesShapeIdx >= 0 && valuesShapeIdx >= 0) {
+    if (valuesShape[valuesShapeIdx] ==
+        expectedValuesShape[expectedValuesShapeIdx]) {
+      expectedValuesShapeIdx--;
+      valuesShapeIdx--;
+    } else if (expectedValuesShape[expectedValuesShapeIdx] == 1) {
+      unsqzDims.push_back(expectedValuesShapeIdx);
+      expectedValuesShapeIdx--;
+    } else {
+      return rewriter.notifyMatchFailure(op,
+                                         "invalid values argument provided");
+    }
+  }
+  if (valuesShapeIdx >= 0) {
+    return rewriter.notifyMatchFailure(op, "invalid values argument provided");
+  }
+  while (expectedValuesShapeIdx >= 0) {
+    unsqzDims.push_back(expectedValuesShapeIdx);
+    expectedValuesShapeIdx--;
+  }
+
+  if (!unsqzDims.empty()) {
+    std::reverse(unsqzDims.begin(), unsqzDims.end());
+    auto newValuesInfo = hlo::unsqueezeTensor(rewriter, op, values, unsqzDims);
+    if (failed(newValuesInfo)) {
+      return rewriter.notifyMatchFailure(op,
+                                         "invalid values argument provided");
+    }
+    values = *newValuesInfo;
+    valuesType = cast<RankedTensorType>(values.getType());
+    valueRank = valuesType.getRank();
+    valuesShape = valuesType.getShape();
+  }
+
   // create stablehlo::ScatterOp
   int64_t indexVecDim = maxIndexRank;
   SmallVector<int64_t> scatterDimOperandDimMap;
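The new unsqueeze step walks the actual and expected values shapes right to left: equal sizes consume a dimension from both, an expected size-1 dimension missing from values is recorded for unsqueezing, any other mismatch rejects the pattern, and leftover leading expected dimensions are unsqueezed too. A standalone sketch of just that alignment, with a hypothetical unsqueezeDims and dynamic dimensions ignored:

```cpp
#include <algorithm>
#include <cstdint>
#include <optional>
#include <vector>

// Right-to-left shape alignment as in the AtenIndexPutHackedTwinOp lowering:
// returns the dimensions (in increasing order) at which size-1 axes must be
// inserted into `values` so it lines up with `expected`, or nullopt if the
// shapes cannot be reconciled. Illustration, not the torch-mlir code.
std::optional<std::vector<int64_t>>
unsqueezeDims(const std::vector<int64_t> &values,
              const std::vector<int64_t> &expected) {
  int64_t e = (int64_t)expected.size() - 1;
  int64_t v = (int64_t)values.size() - 1;
  std::vector<int64_t> dims;
  while (e >= 0 && v >= 0) {
    if (values[v] == expected[e]) {
      --e; --v;            // dimensions agree; consume both
    } else if (expected[e] == 1) {
      dims.push_back(e--); // values is missing this size-1 axis
    } else {
      return std::nullopt; // genuine mismatch
    }
  }
  if (v >= 0)
    return std::nullopt;   // values has extra leading dims
  while (e >= 0)
    dims.push_back(e--);   // remaining leading axes are inserted as well
  std::reverse(dims.begin(), dims.end());
  return dims;
}
```

For example, values of shape {4, 3} against an expected {4, 1, 3} yields dims {1}, reshaping values to {4, 1, 3} before the scatter is built.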
@@ -1216,9 +1250,9 @@ Value getSummand(ConversionPatternRewriter &rewriter, Operation *op,
   SmallVector<Value> indexTensors{Nidx, CIdx, idxY, idxX};
 
   int maxIndexRank = -1;
-  auto gatherIndicesInfo = broadcastAndConcatIndices(
-      input.getDefiningOp(), rewriter, indexTensors, outType.getShape(),
-      dimSizeIndexBits, maxIndexRank);
+  auto gatherIndicesInfo =
+      broadcastAndConcatIndices(input.getDefiningOp(), rewriter, indexTensors,
+                                dimSizeIndexBits, maxIndexRank);
   auto gatherIndices = *gatherIndicesInfo;
   int64_t numIndicesDim = indexTensors.size();
   int64_t indexVecDim = maxIndexRank;
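For reference, the op being lowered above, aten.index_put.hacked_twin, has simple scalar semantics in the non-accumulating case: copy the input, then write each element of values at its coordinate. A toy two-indexed-dims sketch:

```cpp
#include <cstdint>
#include <vector>

// Scalar picture of torch.index_put / stablehlo.scatter on a 2-D input:
// the result starts as a copy of input and each coordinate row receives the
// matching element of values. Illustration only.
std::vector<std::vector<float>>
indexPut2d(std::vector<std::vector<float>> input,           // taken by copy
           const std::vector<std::vector<int64_t>> &coords, // n x 2 table
           const std::vector<float> &values) {              // n elements
  for (size_t r = 0; r < coords.size(); ++r)
    input[coords[r][0]][coords[r][1]] = values[r];
  return input;
}
```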
22 changes: 13 additions & 9 deletions lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp
@@ -322,9 +322,9 @@ getDimIndexOfTensor(PatternRewriter &rewriter, Operation *op, Value value) {
   return getDimIndexOfTensor(rewriter, op, value, dims);
 }
 
-FailureOr<Value> getBroadcastResultShape(PatternRewriter &rewriter,
-                                         Operation *op, ArrayRef<Value> tensors,
-                                         size_t dimSizeIndexBits) {
+FailureOr<std::pair<Value, SmallVector<int64_t>>>
+getBroadcastResultShape(PatternRewriter &rewriter, Operation *op,
+                        ArrayRef<Value> tensors, size_t dimSizeIndexBits) {
   SmallVector<ArrayRef<int64_t>> tensorSizes;
 
   int maxRank = 0;
@@ -337,10 +337,11 @@ FailureOr<Value> getBroadcastResultShape(PatternRewriter &rewriter,
   }
 
   SmallVector<Value> bcastSizeTensors;
+  SmallVector<int64_t> bcastSizes;
   for (int outDim = 0; outDim < maxRank; ++outDim) { // loop dimensions.
     int dynamicDimCnt = 0;
     int staticDimCnt = 0;
-    int64_t staticDimSize;
+    int64_t dimSize;
     Value dimSizeTensor = rewriter.create<mlir::arith::ConstantOp>(
         op->getLoc(),
         rewriter.getIntegerAttr(rewriter.getIntegerType(dimSizeIndexBits), 1));
@@ -357,6 +358,7 @@ FailureOr<Value> getBroadcastResultShape(PatternRewriter &rewriter,
       if (tensorSizes[i][inDim] == ShapedType::kDynamic ||
           tensorSizes[i][inDim] == kUnknownSize) {
         dynamicDimCnt++;
+        dimSize = ShapedType::kDynamic;
         auto dimSizeTensorInfo = hlo::getDimSizesOfTensor(
             rewriter, op, tensors[i], {inDim}, dimSizeIndexBits);
         if (failed(dimSizeTensorInfo)) {
@@ -371,12 +373,12 @@ FailureOr<Value> getBroadcastResultShape(PatternRewriter &rewriter,
           return failure();
         }
         // we already found static dim size not equal with this, fail.
-        if (staticDimCnt > 0 && staticDimSize != tensorSizes[i][inDim]) {
+        if (staticDimCnt > 0 && dimSize != tensorSizes[i][inDim]) {
           return failure();
         }
 
         staticDimCnt++;
-        staticDimSize = tensorSizes[i][inDim];
+        dimSize = tensorSizes[i][inDim];
         auto dimSizeTensorInfo = hlo::getDimSizesOfTensor(
             rewriter, op, tensors[i], {inDim}, dimSizeIndexBits);
         if (failed(dimSizeTensorInfo)) {
@@ -389,12 +391,14 @@ FailureOr<Value> getBroadcastResultShape(PatternRewriter &rewriter,
     // if (dynamicDimCnt > 1) {
     //   return failure();
     // }
-
+    bcastSizes.push_back(dimSize);
     bcastSizeTensors.push_back(dimSizeTensor);
   }
   std::reverse(bcastSizeTensors.begin(), bcastSizeTensors.end());
-  return rewriter.create<tensor::FromElementsOp>(op->getLoc(), bcastSizeTensors)
-      .getResult();
+  return std::pair<Value, SmallVector<int64_t>>(
+      rewriter.create<tensor::FromElementsOp>(op->getLoc(), bcastSizeTensors)
+          .getResult(),
+      bcastSizes);
 }
 
 FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
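Note how bcastSizes mirrors bcastSizeTensors element for element, propagating ShapedType::kDynamic wherever any input dimension is dynamic. Reusing the hypothetical broadcastShapes sketch from earlier (assumed in scope), a quick usage check:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Usage check for the broadcastShapes sketch above:
// {5, 1, kDynamic} and {4, 1} broadcast to {5, 4, kDynamic}.
int main() {
  auto shape = broadcastShapes({{5, 1, kDynamic}, {4, 1}});
  assert(shape.has_value());
  assert((*shape == std::vector<int64_t>{5, 4, kDynamic}));
  return 0;
}
```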
