diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
index 1a60d1cc284e..6930906d54d4 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
@@ -744,15 +744,15 @@ static uint64_t getFlattenedIndex(ShapedType type, ArrayRef<uint64_t> index) {
 
 static bool compareShapesEqual(ShapedType lhsType, ValueRange lhsDynamicDims,
                                ShapedType rhsType, ValueRange rhsDynamicDims) {
-  if (lhsType.hasStaticShape() && rhsType.hasStaticShape() &&
-      lhsType == rhsType) {
+  if (lhsType.hasStaticShape() && rhsType.hasStaticShape()) {
     // Static shape equivalence means we can fast-path the check.
-    return true;
+    return lhsType == rhsType;
   }
   if (lhsType.getRank() != rhsType.getRank()) {
     return false;
   }
   unsigned dynamicDimIndex = 0;
+  unsigned numNonmatchingSSADims = 0;
   for (unsigned i = 0; i < lhsType.getRank(); ++i) {
     if (lhsType.isDynamicDim(i) != rhsType.isDynamicDim(i)) {
       // Static/dynamic dimension mismatch - definitely differ.
@@ -760,8 +760,7 @@ static bool compareShapesEqual(ShapedType lhsType, ValueRange lhsDynamicDims,
     } else if (lhsType.isDynamicDim(i)) {
       unsigned j = dynamicDimIndex++;
       if (lhsDynamicDims[j] != rhsDynamicDims[j]) {
-        // Dynamic dimensions with different SSA values - probably differ.
-        return false;
+        numNonmatchingSSADims++;
       }
     } else {
       if (lhsType.getDimSize(i) != rhsType.getDimSize(i)) {
@@ -770,7 +769,7 @@ static bool compareShapesEqual(ShapedType lhsType, ValueRange lhsDynamicDims,
       }
     }
   }
-  return true;
+  return numNonmatchingSSADims <= 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir b/compiler/src/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir
index 959e398533c6..1559c3cbf3b2 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir
@@ -112,26 +112,38 @@ util.func public @reshapeNoOpDynamic(%arg0: tensor<4x?xf32>, %dim: index) -> tensor<4x?xf32> {
 
 // CHECK-LABEL: @reshapeDynamicDifferent
 util.func public @reshapeDynamicDifferent(%arg0: tensor<4x?xf32>, %dim0: index, %dim1: index) -> tensor<4x?xf32> {
-  // CHECK-NEXT: flow.tensor.reshape %arg0
+  // CHECK-NEXT: util.return %arg0 : tensor<4x?xf32>
   %0 = flow.tensor.reshape %arg0 : tensor<4x?xf32>{%dim0} -> tensor<4x?xf32>{%dim1}
   util.return %0 : tensor<4x?xf32>
 }
 
 // -----
 
-// CHECK-LABEL: @flattenReshapeChain
+// CHECK-LABEL: @foldReshapeChain
 // CHECK-SAME: %[[ARG:.+]]: tensor<4x?xf32>,
 // CHECK-SAME: %[[DIM0:.+]]: index, %[[DIM1:.+]]: index, %[[DIM2:.+]]: index
-util.func public @flattenReshapeChain(%arg0: tensor<4x?xf32>, %dim0: index, %dim1: index, %dim2: index) -> tensor<4x?xf32> {
-  // CHECK-NEXT: %[[RET:.+]] = flow.tensor.reshape %[[ARG]] : tensor<4x?xf32>{%[[DIM0]]} -> tensor<4x?xf32>{%[[DIM2]]}
+util.func public @foldReshapeChain(%arg0: tensor<4x?xf32>, %dim0: index, %dim1: index, %dim2: index) -> tensor<4x?xf32> {
   %0 = flow.tensor.reshape %arg0 : tensor<4x?xf32>{%dim0} -> tensor<4x?xf32>{%dim1}
   %1 = flow.tensor.reshape %0 : tensor<4x?xf32>{%dim1} -> tensor<4x?xf32>{%dim2}
-  // CHECK-NEXT: util.return %[[RET]]
+  // CHECK-NEXT: util.return %[[ARG]]
   util.return %1 : tensor<4x?xf32>
 }
 
 // -----
 
+// CHECK-LABEL: @flattenReshapeChain
+// CHECK-SAME: %[[ARG:.+]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[DIM0:.+]]: index, %[[DIM1:.+]]: index, %[[DIM2:.+]]: index, %[[DIM3:.+]]: index, %[[DIM4:.+]]: index, %[[DIM5:.+]]: index
+util.func public @flattenReshapeChain(%arg0: tensor<?x?xf32>, %dim0: index, %dim1: index, %dim2: index, %dim3 : index, %dim4 : index, %dim5 : index) -> tensor<?x?xf32> {
+  // CHECK-NEXT: %[[RET:.+]] = flow.tensor.reshape %[[ARG]] : tensor<?x?xf32>{%[[DIM0]], %[[DIM1]]} -> tensor<?x?xf32>{%[[DIM4]], %[[DIM5]]}
+  %0 = flow.tensor.reshape %arg0 : tensor<?x?xf32>{%dim0, %dim1} -> tensor<?x?xf32>{%dim2, %dim3}
+  %1 = flow.tensor.reshape %0 : tensor<?x?xf32>{%dim2, %dim3} -> tensor<?x?xf32>{%dim4, %dim5}
+  // CHECK-NEXT: util.return %[[RET]]
+  util.return %1 : tensor<?x?xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @flattenReshapeBitcastChain
 // CHECK-SAME: %[[ARG:.+]]: tensor<4x?xi16>,
 // CHECK-SAME: %[[DIM0:.+]]: index, %[[DIM1:.+]]: index, %[[DIM2:.+]]: index