Skip to content

Commit

Permalink
Fixing GatherND/ScatteND with dynamic-shape indices issues.
Browse files Browse the repository at this point in the history
Signed-off-by: Yasushi Negishi <[email protected]>
  • Loading branch information
negiyas committed Sep 6, 2023
1 parent 49233b0 commit db80e45
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 9 deletions.
10 changes: 2 additions & 8 deletions src/Conversion/ONNXToKrnl/Tensor/GatherND.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -247,14 +247,8 @@ struct ONNXGatherNDOpLowering : public OpConversionPattern<ONNXGatherNDOp> {
});

// Finally reshape 'outputDataBuffer' to the shape of the output.
DimsExpr newOutputShape;
for (int64_t dim : outputShape) {
LiteralIndexExpr outputDim(dim);
newOutputShape.emplace_back(outputDim);
}

Value reshapedOutput =
create.mem.reinterpretCast(outputDataBuffer, newOutputShape);
Value reshapedOutput = emitMemRefReinterpretCastOp(
rewriter, loc, data, shapeHelper.getOutputDims(), convertedType);
LLVM_DEBUG(llvm::dbgs() << "reshapedOutput: " << reshapedOutput << "\n");

rewriter.replaceOp(op, reshapedOutput);
Expand Down
2 changes: 1 addition & 1 deletion src/Conversion/ONNXToKrnl/Tensor/ScatterND.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ struct ONNXScatterNDOpLowering : public OpConversionPattern<ONNXScatterNDOp> {
Value indexVal = createKrnl.loadIE(indices, indicesAccessFct);
IndexExpr index = NonAffineIndexExpr(indexVal);
outputAccessFct.emplace_back(index);
} else {
} else if (i < loopInd.size() - 1) {
IndexExpr index = SymbolIndexExpr(loopInd[i]);
outputAccessFct.emplace_back(index);
}
Expand Down

0 comments on commit db80e45

Please sign in to comment.