diff --git a/src/Conversion/ONNXToKrnl/Tensor/GatherND.cpp b/src/Conversion/ONNXToKrnl/Tensor/GatherND.cpp index 5fb414a859..32629284af 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/GatherND.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/GatherND.cpp @@ -247,8 +247,18 @@ struct ONNXGatherNDOpLowering : public OpConversionPattern { }); // Finally reshape 'outputDataBuffer' to the shape of the output. - Value reshapedOutput = emitMemRefReinterpretCastOp( - rewriter, loc, data, shapeHelper.getOutputDims(), convertedType); + DimsExpr newOutputShape; + for (int64_t dim : outputShape) { + if ( dim > 0) { + LiteralIndexExpr outputDim(dim); + newOutputShape.emplace_back(outputDim); + } else { + newOutputShape.emplace_back(QuestionmarkIndexExpr(/*isFloat*/ false)); + } + } + + Value reshapedOutput = + create.mem.reinterpretCast(outputDataBuffer, newOutputShape); LLVM_DEBUG(llvm::dbgs() << "reshapedOutput: " << reshapedOutput << "\n"); rewriter.replaceOp(op, reshapedOutput);