[torch] Fix unsqueezed output shape in canonicalization of AtenUnflattenIntOp (llvm#3730)

Fixes iree-org/iree#18562.

During the canonicalization pass on `AtenUnflattenIntOp`, if the second dim
was statically equal to one, we would create an `AtenAddIntOp` to add one to
the dimension obtained from `op.getDim()`. When this value was passed into
`Torch::unsqueezeTensor()`, it would be interpreted as non-constant, so the
output of the resulting `UnsqueezeOp` consisted of only dynamic dims, and
MLIR would fail an assertion when the `UnsqueezeOp` was later lowered into
`ExpandShapeOp`.

This patch fixes the behavior by extracting the integer value from the dim
when it is constant, and then emitting a `ConstantIntOp` holding (dim + 1).
This produces an output with a static shape.
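The decision the patch encodes can be sketched in plain Python. This is a toy model only, not the real MLIR C++: `match_constant_int` and the tuple encoding of a runtime op are illustrative stand-ins for `matchPattern(..., m_TorchConstantInt(...))` and `AtenAddIntOp`.

```python
def match_constant_int(dim_value):
    # Stand-in for matchPattern(unflattenDim, m_TorchConstantInt(&dimAsInt)):
    # return the integer if the value is a compile-time constant, else None.
    return dim_value if isinstance(dim_value, int) else None

def unsqueeze_dim(dim_value):
    # Pick the (dim + 1) value handed to Torch::unsqueezeTensor().
    # Old behavior: always build a runtime add, which hides a constant dim
    # and forces a fully dynamic output shape.
    # Fixed behavior: if dim is constant, fold (dim + 1) into a new constant
    # so the unsqueezed output keeps a static shape.
    dim_as_int = match_constant_int(dim_value)
    if dim_as_int is None:
        # dim only known at runtime: emit an add op (modeled as a tuple)
        return ("add", dim_value, 1)
    # dim is constant: emit a folded constant instead of an add op
    return dim_as_int + 1

print(unsqueeze_dim(2))      # constant dim folds to the constant 3
print(unsqueeze_dim("dyn"))  # dynamic dim still becomes a runtime add
```

The point of the fold is that downstream shape inference only sees a constant dim when the IR value itself is a constant op, not the result of an arithmetic op.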
vinayakdsci authored Sep 24, 2024
1 parent e4f2bdf commit 6773288
Showing 1 changed file with 19 additions and 3 deletions: lib/Dialect/Torch/IR/TorchOps.cpp
@@ -2189,6 +2189,9 @@ void AtenUnflattenIntOp::getCanonicalizationPatterns(
     if (dim0 != 1 && dim1 != 1)
       return failure();
     Value unflattenDim = op.getDim();
+    int64_t dimAsInt;
+    bool dimWasConstant =
+        matchPattern(unflattenDim, m_TorchConstantInt(&dimAsInt));
     Value self = op.getSelf();
     Value cstMOne = rewriter.create<Torch::ConstantIntOp>(op.getLoc(), -1);
     // the runtime asserts below are introduced to catch malformed unflatten ops
@@ -2217,9 +2220,22 @@ void AtenUnflattenIntOp::getCanonicalizationPatterns(
     }
     if (dim1 == 1) {
       // unsqueeze at dim + 1
-      Value cstOne = rewriter.create<Torch::ConstantIntOp>(op.getLoc(), 1);
-      Value dimPlusOne =
-          rewriter.create<AtenAddIntOp>(op.getLoc(), unflattenDim, cstOne);
+      Value dimPlusOne;
+      if (!dimWasConstant) {
+        Value cstOne = rewriter.create<Torch::ConstantIntOp>(op.getLoc(), 1);
+        dimPlusOne =
+            rewriter.create<AtenAddIntOp>(op.getLoc(), unflattenDim, cstOne);
+      } else {
+        // If dim was constant, creating an AtenAddIntOp will make
+        // Torch::unsqueezeTensor() interpret it as still not being a constant,
+        // and the resultant shape would consist of only dynamic dims. To fix
+        // this, emit a ConstantIntOp for (dim + 1) to avoid an assertion
+        // failure, when AtenUnsqueezeOp is in a later pass converted to
+        // ExpandShapeOp, which is bound to fail shape inference in MLIR if
+        // output dims are dynamic.
+        dimPlusOne = rewriter.create<Torch::ConstantIntOp>(
+            op.getLoc(), rewriter.getI64IntegerAttr(dimAsInt + 1));
+      }
       FailureOr<Value> maybeUnsqueeze =
           Torch::unsqueezeTensor(rewriter, op, self, dimPlusOne);
       if (failed(maybeUnsqueeze))
