diff --git a/src/Conversion/ONNXToKrnl/Tensor/GatherND.cpp b/src/Conversion/ONNXToKrnl/Tensor/GatherND.cpp index 2de30bbdc6..5fb414a859 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/GatherND.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/GatherND.cpp @@ -247,14 +247,8 @@ struct ONNXGatherNDOpLowering : public OpConversionPattern { }); // Finally reshape 'outputDataBuffer' to the shape of the output. - DimsExpr newOutputShape; - for (int64_t dim : outputShape) { - LiteralIndexExpr outputDim(dim); - newOutputShape.emplace_back(outputDim); - } - - Value reshapedOutput = - create.mem.reinterpretCast(outputDataBuffer, newOutputShape); + Value reshapedOutput = emitMemRefReinterpretCastOp( + rewriter, loc, data, shapeHelper.getOutputDims(), convertedType); LLVM_DEBUG(llvm::dbgs() << "reshapedOutput: " << reshapedOutput << "\n"); rewriter.replaceOp(op, reshapedOutput); diff --git a/src/Conversion/ONNXToKrnl/Tensor/ScatterND.cpp b/src/Conversion/ONNXToKrnl/Tensor/ScatterND.cpp index dd82884f5c..cd4ef79d6a 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/ScatterND.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/ScatterND.cpp @@ -96,7 +96,7 @@ struct ONNXScatterNDOpLowering : public OpConversionPattern { Value indexVal = createKrnl.loadIE(indices, indicesAccessFct); IndexExpr index = NonAffineIndexExpr(indexVal); outputAccessFct.emplace_back(index); - } else { + } else if (i < loopInd.size() - 1) { IndexExpr index = SymbolIndexExpr(loopInd[i]); outputAccessFct.emplace_back(index); }