
[MLIR][memref] Fix normalization issue in memref.load #107771

Merged

4 commits merged into llvm:main from dev/NormalizeMemref on Oct 9, 2024

Conversation

DarshanRamakant
Contributor

DarshanRamakant commented Sep 8, 2024

This change fixes a normalization issue with memref.load when the associated affine layout map reduces the dimensionality, i.e., has fewer results than inputs. Since memref.load, unlike affine.load, cannot carry the layout map itself, the remapped indices are materialized through affine.apply ops and the load is recreated on the normalized memref.
This PR fixes #82675
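For context, a minimal sketch of the pattern this change fixes, using the rank-reducing layout map from the test added below (function and value names are illustrative):

#reduce = affine_map<(i, j) -> (4 * i + j)>  // layout map: two indices collapse into one linear offset

func.func @sketch(%m: memref<4x4xf32, #reduce>, %i: index, %j: index) -> f32 {
  // Before normalization the load carries two indices that are interpreted through #reduce.
  %v = memref.load %m[%i, %j] : memref<4x4xf32, #reduce>
  return %v : f32
}

After -normalize-memrefs, the memref type becomes memref<16xf32> and the two indices are combined by an explicit affine.apply, roughly:

  %k = affine.apply #reduce(%i, %j)
  %v = memref.load %m[%k] : memref<16xf32>

Simply substituting the normalized memref without remapping the indices, as the old code path did, would leave a one-dimensional memref with a two-index load.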

@llvmbot
Collaborator

llvmbot commented Sep 8, 2024

@llvm/pr-subscribers-mlir-affine

@llvm/pr-subscribers-mlir

Author: None (DarshanRamakant)

Changes

This change fixes a normalization issue with memref.load when the associated affine layout map reduces the dimensionality.


Full diff: https://github.com/llvm/llvm-project/pull/107771.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Affine/Utils/Utils.cpp (+83-1)
  • (modified) mlir/test/Dialect/MemRef/normalize-memrefs.mlir (+30)
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index 898467d573362b..70bfb322932346 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -27,6 +27,7 @@
 #include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/IntegerSet.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/LogicalResult.h"
 #include <optional>
 
 #define DEBUG_TYPE "affine-utils"
@@ -1146,7 +1147,88 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
       // is set.
       return failure();
     }
-    op->setOperand(memRefOperandPos, newMemRef);
+
+    // Check if it is a memref.load
+    auto memrefLoad = dyn_cast<memref::LoadOp>(op);
+    bool isReductionLike =
+        indexRemap.getNumResults() < indexRemap.getNumInputs();
+    if (!memrefLoad || !isReductionLike) {
+      op->setOperand(memRefOperandPos, newMemRef);
+      return success();
+    }
+
+    unsigned oldMapNumInputs = oldMemRefRank;
+    SmallVector<Value, 4> oldMapOperands(
+        op->operand_begin() + memRefOperandPos + 1,
+        op->operand_begin() + memRefOperandPos + 1 + oldMapNumInputs);
+    SmallVector<Value, 4> oldMemRefOperands;
+    oldMemRefOperands.assign(oldMapOperands.begin(), oldMapOperands.end());
+    SmallVector<Value, 4> remapOperands;
+    remapOperands.reserve(extraOperands.size() + oldMemRefRank +
+                          symbolOperands.size());
+    remapOperands.append(extraOperands.begin(), extraOperands.end());
+    remapOperands.append(oldMemRefOperands.begin(), oldMemRefOperands.end());
+    remapOperands.append(symbolOperands.begin(), symbolOperands.end());
+
+    SmallVector<Value, 4> remapOutputs;
+    remapOutputs.reserve(oldMemRefRank);
+    SmallVector<Value, 4> affineApplyOps;
+
+    if (indexRemap &&
+        indexRemap != builder.getMultiDimIdentityMap(indexRemap.getNumDims())) {
+      // Remapped indices.
+      for (auto resultExpr : indexRemap.getResults()) {
+        auto singleResMap = AffineMap::get(
+            indexRemap.getNumDims(), indexRemap.getNumSymbols(), resultExpr);
+        auto afOp = builder.create<AffineApplyOp>(op->getLoc(), singleResMap,
+                                                  remapOperands);
+        remapOutputs.push_back(afOp);
+        affineApplyOps.push_back(afOp);
+      }
+    } else {
+      // No remapping specified.
+      remapOutputs.assign(remapOperands.begin(), remapOperands.end());
+    }
+
+    SmallVector<Value, 4> newMapOperands;
+    newMapOperands.reserve(newMemRefRank);
+
+    // Prepend 'extraIndices' in 'newMapOperands'.
+    for (Value extraIndex : extraIndices) {
+      assert((isValidDim(extraIndex) || isValidSymbol(extraIndex)) &&
+             "invalid memory op index");
+      newMapOperands.push_back(extraIndex);
+    }
+
+    // Append 'remapOutputs' to 'newMapOperands'.
+    newMapOperands.append(remapOutputs.begin(), remapOutputs.end());
+
+    // Create new fully composed AffineMap for new op to be created.
+    assert(newMapOperands.size() == newMemRefRank);
+
+    OperationState state(op->getLoc(), op->getName());
+    // Construct the new operation using this memref.
+    state.operands.reserve(newMapOperands.size() + extraIndices.size());
+    state.operands.push_back(newMemRef);
+
+    // Insert the new memref map operands.
+    state.operands.append(newMapOperands.begin(), newMapOperands.end());
+
+    state.types.reserve(op->getNumResults());
+    for (auto result : op->getResults())
+      state.types.push_back(result.getType());
+
+    // Add attribute for 'newMap', other Attributes do not change.
+    // auto newMapAttr = AffineMapAttr::get(newMap);
+    for (auto namedAttr : op->getAttrs()) {
+      state.attributes.push_back(namedAttr);
+    }
+
+    // Create the new operation.
+    auto *repOp = builder.create(state);
+    op->replaceAllUsesWith(repOp);
+    op->erase();
+
     return success();
   }
   // Perform index rewrites for the dereferencing op and then replace the op
diff --git a/mlir/test/Dialect/MemRef/normalize-memrefs.mlir b/mlir/test/Dialect/MemRef/normalize-memrefs.mlir
index c7af033a22a2c6..ca485f9fddbc8d 100644
--- a/mlir/test/Dialect/MemRef/normalize-memrefs.mlir
+++ b/mlir/test/Dialect/MemRef/normalize-memrefs.mlir
@@ -363,3 +363,33 @@ func.func @memref_with_strided_offset(%arg0: tensor<128x512xf32>, %arg1: index,
   %1 = bufferization.to_tensor %cast : memref<16x512xf32, strided<[?, ?], offset: ?>>
   return %1 : tensor<16x512xf32>
 }
+
+#map0 = affine_map<(i,k) -> (2*(i mod 2) + (k mod 2) + 4 * (i floordiv 2) + 8*(k floordiv 2))>
+#map1 = affine_map<(k,j) -> ((k mod 2) + 2 * (j mod 2) + 8 * (k floordiv 2) + 4*(j floordiv 2))>
+#map2 = affine_map<(i,j) -> (4*i+j)>
+// CHECK-LABEL: func @memref_load_with_reduction_map
+func.func @memref_load_with_reduction_map(%arg0 :  memref<4x4xf32,#map2>) -> () {
+  %0 = memref.alloc() : memref<4x8xf32,#map0>
+  %1 = memref.alloc() : memref<8x4xf32,#map1>
+  %2 = memref.alloc() : memref<4x4xf32,#map2>
+  // CHECK-NOT:  memref<4x8xf32>
+  // CHECK-NOT:  memref<8x4xf32>
+  // CHECK-NOT:  memref<4x4xf32>
+  %cst = arith.constant 3.0 : f32
+  %cst0 = arith.constant 0 : index
+  affine.for %i = 0 to 4 {
+    affine.for %j = 0 to 8 {
+      affine.for %k = 0 to 8 {
+        // CHECK: affine.apply #map{{.*}}(%{{.*}}, %{{.*}})
+        // CHECK: memref.load %alloc[%{{.*}}] : memref<32xf32>
+        %a = memref.load %0[%i, %k] : memref<4x8xf32,#map0>
+        %b = memref.load %1[%k, %j] :memref<8x4xf32,#map1>
+        %c = memref.load %2[%i, %j] : memref<4x4xf32,#map2>
+        %3 = arith.mulf %a, %b : f32
+        %4 = arith.addf %3, %c : f32
+        affine.store %4, %arg0[%i, %j] : memref<4x4xf32,#map2>
+      }
+    }
+  }
+  return
+}
\ No newline at end of file

@llvmbot
Collaborator

llvmbot commented Sep 8, 2024

@llvm/pr-subscribers-mlir-memref

Author: None (DarshanRamakant)

Changes

This change fixes a normalization issue with memref.load when the associated affine layout map reduces the dimensionality.

Full diff: https://github.com/llvm/llvm-project/pull/107771.diff (identical to the diff shown above)


github-actions bot commented Sep 14, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

Member

Lewuathe left a comment

I have left one comment about improving the test coverage. Once that point is addressed, this looks good to me overall.

@@ -1093,6 +1094,90 @@ void mlir::affine::affineScalarReplace(func::FuncOp f, DominanceInfo &domInfo,
op->erase();
}

// Private helper function to transform memref.load with reduced rank.
Member

👍

affine.for %j = 0 to 8 {
affine.for %k = 0 to 8 {
// CHECK: affine.apply #map{{.*}}(%{{.*}}, %{{.*}})
// CHECK: memref.load %alloc[%{{.*}}] : memref<32xf32>
Member

This test shows that normalization does not fail, but it would be better to also check that the affine-map calculation is preserved precisely, as we do for affine.load, if possible.

How about checking that affine.apply computes the same logic as the following, to ensure consistency?

          %0 = affine.load %alloc[%arg1 * 2 + %arg3 + (%arg3 floordiv 2) * 6] : memref<32xf32>
          %1 = affine.load %alloc_0[%arg3 + %arg2 * 2 + (%arg3 floordiv 2) * 6] : memref<32xf32>
          %2 = affine.load %alloc_1[%arg1 * 4 + %arg2] : memref<16xf32>

https://godbolt.org/z/bPhYGrhKa

Contributor Author

DarshanRamakant commented Sep 20, 2024

The output IR for the test after the transform will be like below:

#map = affine_map<(d0, d1) -> ((d0 mod 2) * 2 + d1 mod 2 + (d0 floordiv 2) * 4 + (d1 floordiv 2) * 8)>
#map1 = affine_map<(d0, d1) -> (d0 mod 2 + (d1 mod 2) * 2 + (d0 floordiv 2) * 8 + (d1 floordiv 2) * 4)>
#map2 = affine_map<(d0, d1) -> (d0 * 4 + d1)>

module {
  func.func @memref_load_with_reduction_map(%arg0: memref<16xf32>) {
    %alloc = memref.alloc() : memref<32xf32>
    %alloc_0 = memref.alloc() : memref<32xf32>
    %alloc_1 = memref.alloc() : memref<16xf32>
    %cst = arith.constant 3.000000e+00 : f32
    %c0 = arith.constant 0 : index
    affine.for %arg1 = 0 to 4 {
      affine.for %arg2 = 0 to 8 {
        affine.for %arg3 = 0 to 8 {
          %0 = affine.apply #map(%arg1, %arg3)
          %1 = memref.load %alloc[%0] : memref<32xf32>
          %2 = affine.apply #map1(%arg3, %arg2)
          %3 = memref.load %alloc_0[%2] : memref<32xf32>
          %4 = affine.apply #map2(%arg1, %arg2)
          %5 = memref.load %alloc_1[%4] : memref<16xf32>
          %6 = arith.mulf %1, %3 : f32
          %7 = arith.addf %6, %5 : f32
          affine.store %7, %arg0[%arg1 * 4 + %arg2] : memref<16xf32>
        }
      }
    }
    return
  }
}

The affine.apply ops use the original maps. Do you want me to put those map names in the match?

Contributor Author

Made the test stricter.

Member

Lewuathe commented Sep 26, 2024

Sorry for the late response. What I meant was to match the affine-map logic itself, so it would be great if we could match the map names.

Is it possible to match with something like the code below? It checks the affine-map correspondence by name.

// CHECK-DAG: #[[$MAP1:#map[0-9]+]] = affine_map<(d0, d1) -> ((d0 mod 2) * 2 + d1 mod 2 + (d0 floordiv 2) * 4 + (d1 floordiv 2) * 8)>

...

// CHECK: %[[INDEX0:.*]] = affine.apply #[[$MAP1]]{{.*}}(%{{.*}}, %{{.*}})

Contributor Author

DarshanRamakant commented Sep 28, 2024

I have already tried to match like that, with the code below:

#map0 = affine_map<(i,k) -> (2 * (i mod 2) + (k mod 2) + 4 * (i floordiv 2) + 8 * (k floordiv 2))>
#map1 = affine_map<(k,j) -> ((k mod 2) + 2 * (j mod 2) + 8 * (k floordiv 2) + 4 * (j floordiv 2))>
#map2 = affine_map<(i,j) -> (4 * i + j)>
// CHECK-LABEL: func @memref_load_with_reduction_map
// CHECK-DAG : #[[MAP1:map[0-9]+]] = affine_map<(d0, d1) -> ((d0 mod 2) * 2 + d1 mod 2 + (d0 floordiv 2) * 4 + (d1 floordiv 2) * 8)>

But it fails to match. I guess this is because of the preceding "CHECK-LABEL" match, which may start the scan after that point.
I also tried moving this match to the beginning of the file, but again, after the "CHECK-LABEL" directive, the previously matched variable definitions are discarded.
The only way I can do this is by creating a new test. Should I create a new test file?
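For reference, the FileCheck rule at play: each CHECK-LABEL match splits the input into blocks, and ordinary [[...]] pattern variables are discarded at block boundaries; only variables whose names begin with $ are global and survive a CHECK-LABEL. A minimal sketch of the difference, reusing a map from this test:

// Works: $-prefixed variables are global and remain defined across the label.
// CHECK-DAG: #[[$MAP:.*]] = affine_map<(d0, d1) -> (d0 * 4 + d1)>
// CHECK-LABEL: func @memref_load_with_reduction_map
// CHECK: affine.apply #[[$MAP]](%{{.*}}, %{{.*}})

// Fails: MAP is a local variable, so its definition is cleared at the
// CHECK-LABEL boundary before the CHECK line can use it.
// CHECK-DAG: #[[MAP:.*]] = affine_map<(d0, d1) -> (d0 * 4 + d1)>
// CHECK-LABEL: func @memref_load_with_reduction_map
// CHECK: affine.apply #[[MAP]](%{{.*}}, %{{.*}})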

Member

Lewuathe commented Oct 3, 2024

I guess this is because of the preceding "CHECK-LABEL" match, which may start the scan after that point.

Ah, I see. That makes sense. Thank you for trying.

Member

Lewuathe commented Oct 3, 2024

@DarshanRamakant
I experimented to make the test pass.

I found we could put the affine-map CHECK-DAG directives at the top of the test file and use them in the function after the CHECK-LABEL. That works because the affine maps are defined at the top of the module, I guess.

// CHECK-DAG: #[[$REDUCE_MAP1:.*]] = affine_map<(d0, d1) -> ((d0 mod 2) * 2 + d1 mod 2 + (d0 floordiv 2) * 4 + (d1 floordiv 2) * 8)>
// CHECK-DAG: #[[$REDUCE_MAP2:.*]] = affine_map<(d0, d1) -> (d0 mod 2 + (d1 mod 2) * 2 + (d0 floordiv 2) * 8 + (d1 floordiv 2) * 4)>
// CHECK-DAG: #[[$REDUCE_MAP3:.*]] = affine_map<(d0, d1) -> (d0 * 4 + d1)>

...

// CHECK-LABEL: func @memref_load_with_reduction_map
func.func @memref_load_with_reduction_map(%arg0 :  memref<4x4xf32,#map2>) -> () {
  %0 = memref.alloc() : memref<4x8xf32,#map0>
  %1 = memref.alloc() : memref<8x4xf32,#map1>
  %2 = memref.alloc() : memref<4x4xf32,#map2>
  // CHECK-NOT:  memref<4x8xf32>
  // CHECK-NOT:  memref<8x4xf32>
  // CHECK-NOT:  memref<4x4xf32>
  %cst = arith.constant 3.0 : f32
  %cst0 = arith.constant 0 : index
  affine.for %i = 0 to 4 {
    affine.for %j = 0 to 8 {
      affine.for %k = 0 to 8 {
        // CHECK: %[[INDEX0:.*]] = affine.apply #[[$REDUCE_MAP1]](%{{.*}}, %{{.*}})
        // CHECK: memref.load %alloc[%[[INDEX0]]] : memref<32xf32>
        %a = memref.load %0[%i, %k] : memref<4x8xf32,#map0>
        // CHECK: %[[INDEX1:.*]] = affine.apply #[[$REDUCE_MAP2]](%{{.*}}, %{{.*}})
        // CHECK: memref.load %alloc_0[%[[INDEX1]]] : memref<32xf32>
        %b = memref.load %1[%k, %j] :memref<8x4xf32,#map1>
        // CHECK: %[[INDEX2:.*]] = affine.apply #[[$REDUCE_MAP3]](%{{.*}}, %{{.*}})
        // CHECK: memref.load %alloc_1[%[[INDEX2]]] : memref<16xf32>
        %c = memref.load %2[%i, %j] : memref<4x4xf32,#map2>
        %3 = arith.mulf %a, %b : f32
        %4 = arith.addf %3, %c : f32
        affine.store %4, %arg0[%i, %j] : memref<4x4xf32,#map2>
      }
    }
  }
  return
}

Member

Lewuathe commented Oct 3, 2024

Here is the change I tried in order to check the affine-map definitions in the test:

Lewuathe@aa77879

Contributor Author

Thanks for taking the time and giving me this fix! I verified that it passes and have updated my patch with the change. Thanks again.

DarshanRamakant force-pushed the dev/NormalizeMemref branch 3 times, most recently from a03d990 to d1e7b10 on September 22, 2024, 13:03
This change will fix the normalization issue with
memref.load when the associated affine map is
reducing the dimension.
This PR fixes llvm#82675
Member

Lewuathe left a comment

Thank you for updating the patch accordingly. It looks good to me now!

@DarshanRamakant
Contributor Author

Thank you for updating the patch accordingly. It looks good to me now!

Thanks @Lewuathe. Can you please merge this PR?

@Lewuathe Lewuathe merged commit aabddc9 into llvm:main Oct 9, 2024
8 checks passed
@Lewuathe
Member

Lewuathe commented Oct 9, 2024

@DarshanRamakant Merged. Thank you for creating the patch to fix this issue!

Successfully merging this pull request may close these issues.

[MLIR][memref] The Normalizing Memref on removing extra input arguments on memref.load (#82675)