diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
index 1a60d1cc284ea..4928533b75fe5 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
@@ -744,15 +744,15 @@ static uint64_t getFlattenedIndex(ShapedType type, ArrayRef<uint64_t> index) {
 
 static bool compareShapesEqual(ShapedType lhsType, ValueRange lhsDynamicDims,
                                ShapedType rhsType, ValueRange rhsDynamicDims) {
-  if (lhsType.hasStaticShape() && rhsType.hasStaticShape() &&
-      lhsType == rhsType) {
+  if (lhsType.hasStaticShape() && rhsType.hasStaticShape()) {
     // Static shape equivalence means we can fast-path the check.
-    return true;
+    return lhsType == rhsType;
   }
   if (lhsType.getRank() != rhsType.getRank()) {
     return false;
   }
   unsigned dynamicDimIndex = 0;
+  unsigned numNonmatchingSSADims = 0;
   for (unsigned i = 0; i < lhsType.getRank(); ++i) {
     if (lhsType.isDynamicDim(i) != rhsType.isDynamicDim(i)) {
       // Static/dynamic dimension mismatch - definitely differ.
@@ -760,8 +760,7 @@ static bool compareShapesEqual(ShapedType lhsType, ValueRange lhsDynamicDims,
     } else if (lhsType.isDynamicDim(i)) {
       unsigned j = dynamicDimIndex++;
       if (lhsDynamicDims[j] != rhsDynamicDims[j]) {
-        // Dynamic dimensions with different SSA values - probably differ.
-        return false;
+        numNonmatchingSSADims++;
       }
     } else {
       if (lhsType.getDimSize(i) != rhsType.getDimSize(i)) {
@@ -770,7 +769,7 @@ static bool compareShapesEqual(ShapedType lhsType, ValueRange lhsDynamicDims,
       }
     }
   }
-  return true;
+  return numNonmatchingSSADims < 2;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir b/compiler/src/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir
index 959e398533c69..516d0d974a44a 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir
@@ -112,7 +112,7 @@ util.func public @reshapeNoOpDynamic(%arg0: tensor<4x?xf32>, %dim: index) -> ten
 
 // CHECK-LABEL: @reshapeDynamicDifferent
 util.func public @reshapeDynamicDifferent(%arg0: tensor<4x?xf32>, %dim0: index, %dim1: index) -> tensor<4x?xf32> {
-  // CHECK-NEXT: flow.tensor.reshape %arg0
+  // CHECK-NEXT: util.return %arg0 : tensor<4x?xf32>
   %0 = flow.tensor.reshape %arg0 : tensor<4x?xf32>{%dim0} -> tensor<4x?xf32>{%dim1}
   util.return %0 : tensor<4x?xf32>
 }
@@ -123,10 +123,9 @@ util.func public @reshapeDynamicDifferent(%arg0: tensor<4x?xf32>, %dim0: index,
 // CHECK-LABEL: @flattenReshapeChain
 // CHECK-SAME: %[[ARG:.+]]: tensor<4x?xf32>,
 // CHECK-SAME: %[[DIM0:.+]]: index, %[[DIM1:.+]]: index, %[[DIM2:.+]]: index
 util.func public @flattenReshapeChain(%arg0: tensor<4x?xf32>, %dim0: index, %dim1: index, %dim2: index) -> tensor<4x?xf32> {
-  // CHECK-NEXT: %[[RET:.+]] = flow.tensor.reshape %[[ARG]] : tensor<4x?xf32>{%[[DIM0]]} -> tensor<4x?xf32>{%[[DIM2]]}
   %0 = flow.tensor.reshape %arg0 : tensor<4x?xf32>{%dim0} -> tensor<4x?xf32>{%dim1}
   %1 = flow.tensor.reshape %0 : tensor<4x?xf32>{%dim1} -> tensor<4x?xf32>{%dim2}
-  // CHECK-NEXT: util.return %[[RET]]
+  // CHECK-NEXT: util.return %[[ARG]]
   util.return %1 : tensor<4x?xf32>
 }
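
Note for reviewers: a sketch of why `return numNonmatchingSSADims < 2` is sound in this context. `flow.tensor.reshape` preserves the total element count, so when the ranks match, every static dimension matches, and all dynamic dimensions but one are fed by identical SSA values, the one remaining pair of dynamic extents must be equal at runtime (as long as the matching product is nonzero). The standalone C++ sketch below illustrates that arithmetic only; the helper name `unknownDimMustMatch` and the code are hypothetical and not part of this patch.

// Standalone sketch (hypothetical names; not code from this patch).
// Reshape invariant: product(lhsShape) == product(rhsShape) at runtime.
// If all dimensions except the one at `unknown` are pairwise equal and
// their partial product is nonzero, equal totals force the remaining
// pair of extents to be equal as well.
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

static bool unknownDimMustMatch(const std::vector<int64_t> &lhsShape,
                                const std::vector<int64_t> &rhsShape,
                                size_t unknown) {
  if (lhsShape.size() != rhsShape.size())
    return false;
  int64_t partial = 1;
  for (size_t i = 0; i < lhsShape.size(); ++i) {
    if (i == unknown)
      continue;
    if (lhsShape[i] != rhsShape[i])
      return false; // more than one differing dimension: no conclusion
    partial *= lhsShape[i];
  }
  // total == partial * dim[unknown] on both sides; with equal nonzero
  // partial products, equal totals imply dim[unknown] is equal too.
  return partial != 0;
}

int main() {
  // Mirrors @reshapeDynamicDifferent: tensor<4x?xf32>{%dim0} ->
  // tensor<4x?xf32>{%dim1}. The static 4s match and are nonzero, so
  // %dim0 and %dim1 must carry equal values and the reshape is a no-op
  // (-1 stands in for the skipped dynamic dimension).
  assert(unknownDimMustMatch({4, -1}, {4, -1}, 1));
  return 0;
}

Counting mismatches instead of returning false on the first differing SSA value is what lets both new tests fold: exactly one unresolved dynamic dimension is tolerated, while two or more still compare as unequal.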