Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][linalg] Bugfix for InlineScalarOperands #111534

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

CoTinker
Copy link
Contributor

@CoTinker CoTinker commented Oct 8, 2024

This PR fixes a bug where scalarOperand is a simple scalar and should be used directly, rather than accessed via tensor.extract. Fixes #111243.

This PR fixes a bug where `scalarOperand` is a simple scalar and should be
used directly, rather than accessed via `tensor.extract`.
@llvmbot
Copy link
Collaborator

llvmbot commented Oct 8, 2024

@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: Longsheng Mou (CoTinker)

Changes

This PR fixes a bug where scalarOperand is a simple scalar and should be used directly, rather than accessed via tensor.extract. Fixes #111243.


Full diff: https://github.com/llvm/llvm-project/pull/111534.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp (+5-3)
  • (modified) mlir/test/Dialect/Linalg/inline-scalar-operands.mlir (+24)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp b/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp
index 6db51f4b84d112..a8b46905733b8c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp
@@ -78,9 +78,11 @@ struct InlineScalarOperands : public OpRewritePattern<GenericOp> {
       for (auto idx : indices)
         indicesValues.emplace_back(
             rewriter.create<arith::ConstantIndexOp>(loc, idx));
-      Value extractedValue = rewriter.create<tensor::ExtractOp>(
-          loc, opOperand->get(), indicesValues);
-      body->getArgument(idx).replaceAllUsesWith(extractedValue);
+      Value scalarValue = opOperand->get();
+      if (isa<RankedTensorType>(scalarValue.getType()))
+        scalarValue =
+            rewriter.create<tensor::ExtractOp>(loc, scalarValue, indicesValues);
+      body->getArgument(idx).replaceAllUsesWith(scalarValue);
       body->eraseArgument(idx);
     }
 
diff --git a/mlir/test/Dialect/Linalg/inline-scalar-operands.mlir b/mlir/test/Dialect/Linalg/inline-scalar-operands.mlir
index 93d5b8779c7461..8384b307d2dfbd 100644
--- a/mlir/test/Dialect/Linalg/inline-scalar-operands.mlir
+++ b/mlir/test/Dialect/Linalg/inline-scalar-operands.mlir
@@ -46,3 +46,27 @@ func.func @inline_oned(%arg0: tensor<4xf32>, %scalar: tensor<1xf32>) -> tensor<4
     } -> tensor<4xf32>
   return %1 : tensor<4xf32>
 }
+
+// -----
+
+// CHECK: #[[MAP:.*]] = affine_map<(d0) -> (d0)>
+#map2 = affine_map<(d0) -> (d0)>
+#map3 = affine_map<(d0) -> ()>
+
+// CHECK: func @inline_scalar(%[[ARG:.*]]: tensor<4xf32>, %[[SCALAR:.*]]: f32)
+func.func @inline_scalar(%arg0: tensor<4xf32>, %scalar: f32) -> tensor<4xf32> {
+    %0 = tensor.empty() : tensor<4xf32>
+    // CHECK: linalg.generic {indexing_maps = [#[[MAP]], #[[MAP]]],
+    // CHECK-SAME: iterator_types = ["parallel"]} ins(%[[ARG]] : tensor<4xf32>)
+    %1 = linalg.generic {indexing_maps = [#map2, #map3, #map2],
+                         iterator_types = ["parallel"]}
+                         ins(%arg0, %scalar : tensor<4xf32>, f32)
+                         outs(%0 : tensor<4xf32>) {
+    // CHECK: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32)
+    ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
+      // CHECK: arith.divf %[[IN]], %[[SCALAR]] : f32
+      %2 = arith.divf %arg1, %arg2 : f32
+      linalg.yield %2 : f32
+    } -> tensor<4xf32>
+  return %1 : tensor<4xf32>
+}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[mlir] Crash when using --linalg-inline-scalar-operands
2 participants