Skip to content

Commit

Permalink
Fold flow reshape with mismatching dyn dims
Browse files Browse the repository at this point in the history
`flow.tensor.reshape` is foldable when there is a single mismatching
dynamic dimension (unequal SSA values) because they must be the same.
  • Loading branch information
IanWood1 committed Oct 3, 2024
1 parent d341128 commit 51c7a8b
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 9 deletions.
11 changes: 5 additions & 6 deletions compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -744,24 +744,23 @@ static uint64_t getFlattenedIndex(ShapedType type, ArrayRef<uint64_t> index) {

static bool compareShapesEqual(ShapedType lhsType, ValueRange lhsDynamicDims,
ShapedType rhsType, ValueRange rhsDynamicDims) {
if (lhsType.hasStaticShape() && rhsType.hasStaticShape() &&
lhsType == rhsType) {
if (lhsType.hasStaticShape() && rhsType.hasStaticShape()) {
// Static shape equivalence means we can fast-path the check.
return true;
return lhsType == rhsType;
}
if (lhsType.getRank() != rhsType.getRank()) {
return false;
}
unsigned dynamicDimIndex = 0;
unsigned numNonmatchingSSADims = 0;
for (unsigned i = 0; i < lhsType.getRank(); ++i) {
if (lhsType.isDynamicDim(i) != rhsType.isDynamicDim(i)) {
// Static/dynamic dimension mismatch - definitely differ.
return false;
} else if (lhsType.isDynamicDim(i)) {
unsigned j = dynamicDimIndex++;
if (lhsDynamicDims[j] != rhsDynamicDims[j]) {
// Dynamic dimensions with different SSA values - probably differ.
return false;
numNonmatchingSSADims++;
}
} else {
if (lhsType.getDimSize(i) != rhsType.getDimSize(i)) {
Expand All @@ -770,7 +769,7 @@ static bool compareShapesEqual(ShapedType lhsType, ValueRange lhsDynamicDims,
}
}
}
return true;
return numNonmatchingSSADims < 2;
}

//===----------------------------------------------------------------------===//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ util.func public @reshapeNoOpDynamic(%arg0: tensor<4x?xf32>, %dim: index) -> ten

// CHECK-LABEL: @reshapeDynamicDifferent
util.func public @reshapeDynamicDifferent(%arg0: tensor<4x?xf32>, %dim0: index, %dim1: index) -> tensor<4x?xf32> {
// CHECK-NEXT: flow.tensor.reshape %arg0
// CHECK-NEXT: util.return %arg0 : tensor<4x?xf32>
%0 = flow.tensor.reshape %arg0 : tensor<4x?xf32>{%dim0} -> tensor<4x?xf32>{%dim1}
util.return %0 : tensor<4x?xf32>
}
Expand All @@ -123,10 +123,9 @@ util.func public @reshapeDynamicDifferent(%arg0: tensor<4x?xf32>, %dim0: index,
// CHECK-SAME: %[[ARG:.+]]: tensor<4x?xf32>,
// CHECK-SAME: %[[DIM0:.+]]: index, %[[DIM1:.+]]: index, %[[DIM2:.+]]: index
util.func public @flattenReshapeChain(%arg0: tensor<4x?xf32>, %dim0: index, %dim1: index, %dim2: index) -> tensor<4x?xf32> {
// CHECK-NEXT: %[[RET:.+]] = flow.tensor.reshape %[[ARG]] : tensor<4x?xf32>{%[[DIM0]]} -> tensor<4x?xf32>{%[[DIM2]]}
%0 = flow.tensor.reshape %arg0 : tensor<4x?xf32>{%dim0} -> tensor<4x?xf32>{%dim1}
%1 = flow.tensor.reshape %0 : tensor<4x?xf32>{%dim1} -> tensor<4x?xf32>{%dim2}
// CHECK-NEXT: util.return %[[RET]]
// CHECK-NEXT: util.return %[[ARG]]
util.return %1 : tensor<4x?xf32>
}

Expand Down

0 comments on commit 51c7a8b

Please sign in to comment.