Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[torch] Fix unsqueezed output shape in canonicalization of AtenUnflat…
…tenIntOp (llvm#3730) Fixes iree-org/iree#18562. During canonicalization pass on `AtenUnflattenIntOp`, if the second dim was statically equal to one, we would create an `AtenAddIntOp` to add one to the dimension obtained from `op.getDim()`. This, when passed into `Torch::unsqueezeTensor()`, would make it get interpreted as non-constant, which would lead to MLIR failing an assertion when `UnsqueezeOp` would later get lowered into `ExpandShapeOp`, as the output of the `UnsqueezeOp` would consist of only dynamic dims. This patch fixes this behavior, by extracting the integer value from the dim if it was constant, and then emitting a `ConstantIntOp` from (dim+1). This creates an output with static shape.
- Loading branch information