Skip to content

Commit

Permalink
Fix backend test errors for onnx.GatherND lowering.
Browse files Browse the repository at this point in the history
Signed-off-by: Yasushi Negishi <[email protected]>
  • Loading branch information
negiyas committed Sep 14, 2023
1 parent cc4b350 commit 3644104
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions src/Conversion/ONNXToKrnl/Tensor/GatherND.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -247,8 +247,18 @@ struct ONNXGatherNDOpLowering : public OpConversionPattern<ONNXGatherNDOp> {
});

// Finally reshape 'outputDataBuffer' to the shape of the output.
Value reshapedOutput = emitMemRefReinterpretCastOp(
rewriter, loc, data, shapeHelper.getOutputDims(), convertedType);
DimsExpr newOutputShape;
for (int64_t dim : outputShape) {
if ( dim > 0) {
LiteralIndexExpr outputDim(dim);
newOutputShape.emplace_back(outputDim);
} else {
newOutputShape.emplace_back(QuestionmarkIndexExpr(/*isFloat*/ false));
}
}

Value reshapedOutput =
create.mem.reinterpretCast(outputDataBuffer, newOutputShape);
LLVM_DEBUG(llvm::dbgs() << "reshapedOutput: " << reshapedOutput << "\n");

rewriter.replaceOp(op, reshapedOutput);
Expand Down

0 comments on commit 3644104

Please sign in to comment.