-
Notifications
You must be signed in to change notification settings - Fork 11.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MLIR][memref] Fix normalization issue in memref.load #107771
Conversation
@llvm/pr-subscribers-mlir-affine @llvm/pr-subscribers-mlir Author: None (DarshanRamakant) Changes: This change will fix the normalization issue with `memref.load` when the associated affine map reduces the dimension. This PR fixes #82675. Full diff: https://github.com/llvm/llvm-project/pull/107771.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index 898467d573362b..70bfb322932346 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -27,6 +27,7 @@
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/LogicalResult.h"
#include <optional>
#define DEBUG_TYPE "affine-utils"
@@ -1146,7 +1147,88 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
// is set.
return failure();
}
- op->setOperand(memRefOperandPos, newMemRef);
+
+ // Check if it is a memref.load
+ auto memrefLoad = dyn_cast<memref::LoadOp>(op);
+ bool isReductionLike =
+ indexRemap.getNumResults() < indexRemap.getNumInputs();
+ if (!memrefLoad || !isReductionLike) {
+ op->setOperand(memRefOperandPos, newMemRef);
+ return success();
+ }
+
+ unsigned oldMapNumInputs = oldMemRefRank;
+ SmallVector<Value, 4> oldMapOperands(
+ op->operand_begin() + memRefOperandPos + 1,
+ op->operand_begin() + memRefOperandPos + 1 + oldMapNumInputs);
+ SmallVector<Value, 4> oldMemRefOperands;
+ oldMemRefOperands.assign(oldMapOperands.begin(), oldMapOperands.end());
+ SmallVector<Value, 4> remapOperands;
+ remapOperands.reserve(extraOperands.size() + oldMemRefRank +
+ symbolOperands.size());
+ remapOperands.append(extraOperands.begin(), extraOperands.end());
+ remapOperands.append(oldMemRefOperands.begin(), oldMemRefOperands.end());
+ remapOperands.append(symbolOperands.begin(), symbolOperands.end());
+
+ SmallVector<Value, 4> remapOutputs;
+ remapOutputs.reserve(oldMemRefRank);
+ SmallVector<Value, 4> affineApplyOps;
+
+ if (indexRemap &&
+ indexRemap != builder.getMultiDimIdentityMap(indexRemap.getNumDims())) {
+ // Remapped indices.
+ for (auto resultExpr : indexRemap.getResults()) {
+ auto singleResMap = AffineMap::get(
+ indexRemap.getNumDims(), indexRemap.getNumSymbols(), resultExpr);
+ auto afOp = builder.create<AffineApplyOp>(op->getLoc(), singleResMap,
+ remapOperands);
+ remapOutputs.push_back(afOp);
+ affineApplyOps.push_back(afOp);
+ }
+ } else {
+ // No remapping specified.
+ remapOutputs.assign(remapOperands.begin(), remapOperands.end());
+ }
+
+ SmallVector<Value, 4> newMapOperands;
+ newMapOperands.reserve(newMemRefRank);
+
+ // Prepend 'extraIndices' in 'newMapOperands'.
+ for (Value extraIndex : extraIndices) {
+ assert((isValidDim(extraIndex) || isValidSymbol(extraIndex)) &&
+ "invalid memory op index");
+ newMapOperands.push_back(extraIndex);
+ }
+
+ // Append 'remapOutputs' to 'newMapOperands'.
+ newMapOperands.append(remapOutputs.begin(), remapOutputs.end());
+
+ // Create new fully composed AffineMap for new op to be created.
+ assert(newMapOperands.size() == newMemRefRank);
+
+ OperationState state(op->getLoc(), op->getName());
+ // Construct the new operation using this memref.
+ state.operands.reserve(newMapOperands.size() + extraIndices.size());
+ state.operands.push_back(newMemRef);
+
+ // Insert the new memref map operands.
+ state.operands.append(newMapOperands.begin(), newMapOperands.end());
+
+ state.types.reserve(op->getNumResults());
+ for (auto result : op->getResults())
+ state.types.push_back(result.getType());
+
+ // Add attribute for 'newMap', other Attributes do not change.
+ // auto newMapAttr = AffineMapAttr::get(newMap);
+ for (auto namedAttr : op->getAttrs()) {
+ state.attributes.push_back(namedAttr);
+ }
+
+ // Create the new operation.
+ auto *repOp = builder.create(state);
+ op->replaceAllUsesWith(repOp);
+ op->erase();
+
return success();
}
// Perform index rewrites for the dereferencing op and then replace the op
diff --git a/mlir/test/Dialect/MemRef/normalize-memrefs.mlir b/mlir/test/Dialect/MemRef/normalize-memrefs.mlir
index c7af033a22a2c6..ca485f9fddbc8d 100644
--- a/mlir/test/Dialect/MemRef/normalize-memrefs.mlir
+++ b/mlir/test/Dialect/MemRef/normalize-memrefs.mlir
@@ -363,3 +363,33 @@ func.func @memref_with_strided_offset(%arg0: tensor<128x512xf32>, %arg1: index,
%1 = bufferization.to_tensor %cast : memref<16x512xf32, strided<[?, ?], offset: ?>>
return %1 : tensor<16x512xf32>
}
+
+#map0 = affine_map<(i,k) -> (2*(i mod 2) + (k mod 2) + 4 * (i floordiv 2) + 8*(k floordiv 2))>
+#map1 = affine_map<(k,j) -> ((k mod 2) + 2 * (j mod 2) + 8 * (k floordiv 2) + 4*(j floordiv 2))>
+#map2 = affine_map<(i,j) -> (4*i+j)>
+// CHECK-LABEL: func @memref_load_with_reduction_map
+func.func @memref_load_with_reduction_map(%arg0 : memref<4x4xf32,#map2>) -> () {
+ %0 = memref.alloc() : memref<4x8xf32,#map0>
+ %1 = memref.alloc() : memref<8x4xf32,#map1>
+ %2 = memref.alloc() : memref<4x4xf32,#map2>
+ // CHECK-NOT: memref<4x8xf32>
+ // CHECK-NOT: memref<8x4xf32>
+ // CHECK-NOT: memref<4x4xf32>
+ %cst = arith.constant 3.0 : f32
+ %cst0 = arith.constant 0 : index
+ affine.for %i = 0 to 4 {
+ affine.for %j = 0 to 8 {
+ affine.for %k = 0 to 8 {
+ // CHECK: affine.apply #map{{.*}}(%{{.*}}, %{{.*}})
+ // CHECK: memref.load %alloc[%{{.*}}] : memref<32xf32>
+ %a = memref.load %0[%i, %k] : memref<4x8xf32,#map0>
+ %b = memref.load %1[%k, %j] :memref<8x4xf32,#map1>
+ %c = memref.load %2[%i, %j] : memref<4x4xf32,#map2>
+ %3 = arith.mulf %a, %b : f32
+ %4 = arith.addf %3, %c : f32
+ affine.store %4, %arg0[%i, %j] : memref<4x4xf32,#map2>
+ }
+ }
+ }
+ return
+}
\ No newline at end of file
|
@llvm/pr-subscribers-mlir-memref Author: None (DarshanRamakant) Changes: This change will fix the normalization issue with `memref.load` when the associated affine map reduces the dimension. This PR fixes #82675. Full diff: https://github.com/llvm/llvm-project/pull/107771.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index 898467d573362b..70bfb322932346 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -27,6 +27,7 @@
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/LogicalResult.h"
#include <optional>
#define DEBUG_TYPE "affine-utils"
@@ -1146,7 +1147,88 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
// is set.
return failure();
}
- op->setOperand(memRefOperandPos, newMemRef);
+
+ // Check if it is a memref.load
+ auto memrefLoad = dyn_cast<memref::LoadOp>(op);
+ bool isReductionLike =
+ indexRemap.getNumResults() < indexRemap.getNumInputs();
+ if (!memrefLoad || !isReductionLike) {
+ op->setOperand(memRefOperandPos, newMemRef);
+ return success();
+ }
+
+ unsigned oldMapNumInputs = oldMemRefRank;
+ SmallVector<Value, 4> oldMapOperands(
+ op->operand_begin() + memRefOperandPos + 1,
+ op->operand_begin() + memRefOperandPos + 1 + oldMapNumInputs);
+ SmallVector<Value, 4> oldMemRefOperands;
+ oldMemRefOperands.assign(oldMapOperands.begin(), oldMapOperands.end());
+ SmallVector<Value, 4> remapOperands;
+ remapOperands.reserve(extraOperands.size() + oldMemRefRank +
+ symbolOperands.size());
+ remapOperands.append(extraOperands.begin(), extraOperands.end());
+ remapOperands.append(oldMemRefOperands.begin(), oldMemRefOperands.end());
+ remapOperands.append(symbolOperands.begin(), symbolOperands.end());
+
+ SmallVector<Value, 4> remapOutputs;
+ remapOutputs.reserve(oldMemRefRank);
+ SmallVector<Value, 4> affineApplyOps;
+
+ if (indexRemap &&
+ indexRemap != builder.getMultiDimIdentityMap(indexRemap.getNumDims())) {
+ // Remapped indices.
+ for (auto resultExpr : indexRemap.getResults()) {
+ auto singleResMap = AffineMap::get(
+ indexRemap.getNumDims(), indexRemap.getNumSymbols(), resultExpr);
+ auto afOp = builder.create<AffineApplyOp>(op->getLoc(), singleResMap,
+ remapOperands);
+ remapOutputs.push_back(afOp);
+ affineApplyOps.push_back(afOp);
+ }
+ } else {
+ // No remapping specified.
+ remapOutputs.assign(remapOperands.begin(), remapOperands.end());
+ }
+
+ SmallVector<Value, 4> newMapOperands;
+ newMapOperands.reserve(newMemRefRank);
+
+ // Prepend 'extraIndices' in 'newMapOperands'.
+ for (Value extraIndex : extraIndices) {
+ assert((isValidDim(extraIndex) || isValidSymbol(extraIndex)) &&
+ "invalid memory op index");
+ newMapOperands.push_back(extraIndex);
+ }
+
+ // Append 'remapOutputs' to 'newMapOperands'.
+ newMapOperands.append(remapOutputs.begin(), remapOutputs.end());
+
+ // Create new fully composed AffineMap for new op to be created.
+ assert(newMapOperands.size() == newMemRefRank);
+
+ OperationState state(op->getLoc(), op->getName());
+ // Construct the new operation using this memref.
+ state.operands.reserve(newMapOperands.size() + extraIndices.size());
+ state.operands.push_back(newMemRef);
+
+ // Insert the new memref map operands.
+ state.operands.append(newMapOperands.begin(), newMapOperands.end());
+
+ state.types.reserve(op->getNumResults());
+ for (auto result : op->getResults())
+ state.types.push_back(result.getType());
+
+ // Add attribute for 'newMap', other Attributes do not change.
+ // auto newMapAttr = AffineMapAttr::get(newMap);
+ for (auto namedAttr : op->getAttrs()) {
+ state.attributes.push_back(namedAttr);
+ }
+
+ // Create the new operation.
+ auto *repOp = builder.create(state);
+ op->replaceAllUsesWith(repOp);
+ op->erase();
+
return success();
}
// Perform index rewrites for the dereferencing op and then replace the op
diff --git a/mlir/test/Dialect/MemRef/normalize-memrefs.mlir b/mlir/test/Dialect/MemRef/normalize-memrefs.mlir
index c7af033a22a2c6..ca485f9fddbc8d 100644
--- a/mlir/test/Dialect/MemRef/normalize-memrefs.mlir
+++ b/mlir/test/Dialect/MemRef/normalize-memrefs.mlir
@@ -363,3 +363,33 @@ func.func @memref_with_strided_offset(%arg0: tensor<128x512xf32>, %arg1: index,
%1 = bufferization.to_tensor %cast : memref<16x512xf32, strided<[?, ?], offset: ?>>
return %1 : tensor<16x512xf32>
}
+
+#map0 = affine_map<(i,k) -> (2*(i mod 2) + (k mod 2) + 4 * (i floordiv 2) + 8*(k floordiv 2))>
+#map1 = affine_map<(k,j) -> ((k mod 2) + 2 * (j mod 2) + 8 * (k floordiv 2) + 4*(j floordiv 2))>
+#map2 = affine_map<(i,j) -> (4*i+j)>
+// CHECK-LABEL: func @memref_load_with_reduction_map
+func.func @memref_load_with_reduction_map(%arg0 : memref<4x4xf32,#map2>) -> () {
+ %0 = memref.alloc() : memref<4x8xf32,#map0>
+ %1 = memref.alloc() : memref<8x4xf32,#map1>
+ %2 = memref.alloc() : memref<4x4xf32,#map2>
+ // CHECK-NOT: memref<4x8xf32>
+ // CHECK-NOT: memref<8x4xf32>
+ // CHECK-NOT: memref<4x4xf32>
+ %cst = arith.constant 3.0 : f32
+ %cst0 = arith.constant 0 : index
+ affine.for %i = 0 to 4 {
+ affine.for %j = 0 to 8 {
+ affine.for %k = 0 to 8 {
+ // CHECK: affine.apply #map{{.*}}(%{{.*}}, %{{.*}})
+ // CHECK: memref.load %alloc[%{{.*}}] : memref<32xf32>
+ %a = memref.load %0[%i, %k] : memref<4x8xf32,#map0>
+ %b = memref.load %1[%k, %j] :memref<8x4xf32,#map1>
+ %c = memref.load %2[%i, %j] : memref<4x4xf32,#map2>
+ %3 = arith.mulf %a, %b : f32
+ %4 = arith.addf %3, %c : f32
+ affine.store %4, %arg0[%i, %j] : memref<4x4xf32,#map2>
+ }
+ }
+ }
+ return
+}
\ No newline at end of file
|
9a52249
to
f8611e5
Compare
✅ With the latest revision this PR passed the C/C++ code formatter. |
f8611e5
to
5bfbe62
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have left one comment to improve the test coverage. If that point is improved, I think it looks good to me overall.
@@ -1093,6 +1094,90 @@ void mlir::affine::affineScalarReplace(func::FuncOp f, DominanceInfo &domInfo, | |||
op->erase(); | |||
} | |||
|
|||
// Private helper function to transform memref.load with reduced rank. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
👍
affine.for %j = 0 to 8 { | ||
affine.for %k = 0 to 8 { | ||
// CHECK: affine.apply #map{{.*}}(%{{.*}}, %{{.*}}) | ||
// CHECK: memref.load %alloc[%{{.*}}] : memref<32xf32> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can know the normalization does not fail with this test. But it would be better to check the calculation of the affine_map is kept precisely as well as we do with affine.load
if possible.
How about checking that `affine.apply` is invoking the same logic as follows, to ensure consistency?
%0 = affine.load %alloc[%arg1 * 2 + %arg3 + (%arg3 floordiv 2) * 6] : memref<32xf32>
%1 = affine.load %alloc_0[%arg3 + %arg2 * 2 + (%arg3 floordiv 2) * 6] : memref<32xf32>
%2 = affine.load %alloc_1[%arg1 * 4 + %arg2] : memref<16xf32>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The output IR for the test after the transform will be like below
`
#map = affine_map<(d0, d1) -> ((d0 mod 2) * 2 + d1 mod 2 + (d0 floordiv 2) * 4 + (d1 floordiv 2) * 8)>
#map1 = affine_map<(d0, d1) -> (d0 mod 2 + (d1 mod 2) * 2 + (d0 floordiv 2) * 8 + (d1 floordiv 2) * 4)>
#map2 = affine_map<(d0, d1) -> (d0 * 4 + d1)>
module {
func.func @memref_load_with_reduction_map(%arg0: memref<16xf32>) {
%alloc = memref.alloc() : memref<32xf32>
%alloc_0 = memref.alloc() : memref<32xf32>
%alloc_1 = memref.alloc() : memref<16xf32>
%cst = arith.constant 3.000000e+00 : f32
%c0 = arith.constant 0 : index
affine.for %arg1 = 0 to 4 {
affine.for %arg2 = 0 to 8 {
affine.for %arg3 = 0 to 8 {
%0 = affine.apply #map(%arg1, %arg3)
%1 = memref.load %alloc[%0] : memref<32xf32>
%2 = affine.apply #map1(%arg3, %arg2)
%3 = memref.load %alloc_0[%2] : memref<32xf32>
%4 = affine.apply #map2(%arg1, %arg2)
%5 = memref.load %alloc_1[%4] : memref<16xf32>
%6 = arith.mulf %1, %3 : f32
%7 = arith.addf %6, %5 : f32
affine.store %7, %arg0[%arg1 * 4 + %arg2] : memref<16xf32>
}
}
}
return
}`
The affine.apply() is using the original maps. Do you want me to put those map names in while matching?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Made the test more stricter
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry for the late response. What I meant was to match the affine map logic itself. So it would be great if we can match the map names.
Is it possible to match with code something like this? It checks the affine map correspondence by its name.
// CHECK-DAG: #[[$MAP1:#map[0-9]+]] = affine_map<(d0, d1) -> ((d0 mod 2) * 2 + d1 mod 2 + (d0 floordiv 2) * 4 + (d1 floordiv 2) * 8)>
...
// CHECK: %[[INDEX0:.*]] = affine.apply #[[$MAP1]]{{.*}}(%{{.*}}, %{{.*}})
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have already tried to match like this with the below code :
#map0 = affine_map<(i,k) -> (2 * (i mod 2) + (k mod 2) + 4 * (i floordiv 2) + 8 * (k floordiv 2))> #map1 = affine_map<(k,j) -> ((k mod 2) + 2 * (j mod 2) + 8 * (k floordiv 2) + 4 * (j floordiv 2))> #map2 = affine_map<(i,j) -> (4 * i + j)> // CHECK-LABEL: func @memref_load_with_reduction_map // CHECK-DAG : #[[MAP1:map[0-9]+]] = affine_map<(d0, d1) -> ((d0 mod 2) * 2 + d1 mod 2 + (d0 floordiv 2) * 4 + (d1 floordiv 2) * 8)>
But it fails to match. I guess this is because of the "CHECK-LABEL" match previously, which may start scan after this match.
I also tried to move this match to the beginning of the file, but again, after the "CHECK-LABEL" command, the previously matched variable definitions are destroyed.
Only way I can do this is by creating a new test for this case. Should I create a new test file ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess this is because of the "CHECK-LABEL" match previously, which may start scan after this match.
Ah, I see. That makes sense. Thank you for trying.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@DarshanRamakant
I experimented to make the test pass somehow.
I found we could put the affine map definitions on top of the test file and use them in the function after the `CHECK-LABEL`. That's because the affine maps are defined at the top of the module, I guess.
// CHECK-DAG: #[[$REDUCE_MAP1:.*]] = affine_map<(d0, d1) -> ((d0 mod 2) * 2 + d1 mod 2 + (d0 floordiv 2) * 4 + (d1 floordiv 2) * 8)>
// CHECK-DAG: #[[$REDUCE_MAP2:.*]] = affine_map<(d0, d1) -> (d0 mod 2 + (d1 mod 2) * 2 + (d0 floordiv 2) * 8 + (d1 floordiv 2) * 4)>
// CHECK-DAG: #[[$REDUCE_MAP3:.*]] = affine_map<(d0, d1) -> (d0 * 4 + d1)>
...
// CHECK-LABEL: func @memref_load_with_reduction_map
func.func @memref_load_with_reduction_map(%arg0 : memref<4x4xf32,#map2>) -> () {
%0 = memref.alloc() : memref<4x8xf32,#map0>
%1 = memref.alloc() : memref<8x4xf32,#map1>
%2 = memref.alloc() : memref<4x4xf32,#map2>
// CHECK-NOT: memref<4x8xf32>
// CHECK-NOT: memref<8x4xf32>
// CHECK-NOT: memref<4x4xf32>
%cst = arith.constant 3.0 : f32
%cst0 = arith.constant 0 : index
affine.for %i = 0 to 4 {
affine.for %j = 0 to 8 {
affine.for %k = 0 to 8 {
// CHECK: %[[INDEX0:.*]] = affine.apply #[[$REDUCE_MAP1]](%{{.*}}, %{{.*}})
// CHECK: memref.load %alloc[%[[INDEX0]]] : memref<32xf32>
%a = memref.load %0[%i, %k] : memref<4x8xf32,#map0>
// CHECK: %[[INDEX1:.*]] = affine.apply #[[$REDUCE_MAP2]](%{{.*}}, %{{.*}})
// CHECK: memref.load %alloc_0[%[[INDEX1]]] : memref<32xf32>
%b = memref.load %1[%k, %j] :memref<8x4xf32,#map1>
// CHECK: %[[INDEX2:.*]] = affine.apply #[[$REDUCE_MAP3]](%{{.*}}, %{{.*}})
// CHECK: memref.load %alloc_1[%[[INDEX2]]] : memref<16xf32>
%c = memref.load %2[%i, %j] : memref<4x4xf32,#map2>
%3 = arith.mulf %a, %b : f32
%4 = arith.addf %3, %c : f32
affine.store %4, %arg0[%i, %j] : memref<4x4xf32,#map2>
}
}
}
return
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here is the change I tried to check the affine map definition in the spec.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for taking time and giving me this fix ! I could also verify that this passes. I have modified my patch with the change now. Thanks again.
a03d990
to
d1e7b10
Compare
This change will fix the normalization issue with memref.load when the associated affine map is reducing the dimension. This PR fixes llvm#82675
fbe266c
to
48e081b
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for updating the patch accordingly. It looks good to me now!
Thanks @Lewuathe . Can you please merge this PR ? |
@DarshanRamakant Merged. Thank you for creating the patch to fix this issue! |
This change will fix the normalization issue with
memref.load when the associated affine map is
reducing the dimension.
This PR fixes #82675