Skip to content

Commit

Permalink
[mlir][linalg] Bugfix for InlineScalarOperands
Browse files Browse the repository at this point in the history
This PR fixes a bug that occurs when `scalarOperand` is already a plain scalar: the value
should be used directly rather than read through `tensor.extract`.
  • Loading branch information
CoTinker committed Oct 8, 2024
1 parent 6c331e5 commit 4050324
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 3 deletions.
8 changes: 5 additions & 3 deletions mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,11 @@ struct InlineScalarOperands : public OpRewritePattern<GenericOp> {
for (auto idx : indices)
indicesValues.emplace_back(
rewriter.create<arith::ConstantIndexOp>(loc, idx));
Value extractedValue = rewriter.create<tensor::ExtractOp>(
loc, opOperand->get(), indicesValues);
body->getArgument(idx).replaceAllUsesWith(extractedValue);
Value scalarValue = opOperand->get();
if (isa<RankedTensorType>(scalarValue.getType()))
scalarValue =
rewriter.create<tensor::ExtractOp>(loc, scalarValue, indicesValues);
body->getArgument(idx).replaceAllUsesWith(scalarValue);
body->eraseArgument(idx);
}

Expand Down
24 changes: 24 additions & 0 deletions mlir/test/Dialect/Linalg/inline-scalar-operands.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,27 @@ func.func @inline_oned(%arg0: tensor<4xf32>, %scalar: tensor<1xf32>) -> tensor<4
} -> tensor<4xf32>
return %1 : tensor<4xf32>
}

// -----

// Test case: a linalg.generic whose second input is a bare f32 value rather
// than a tensor. That input is indexed through #map3, which has no results.
// The expected output keeps a single input operand and two identity indexing
// maps on the generic op, and the divf in the body reads the function's f32
// argument directly.
// CHECK: #[[MAP:.*]] = affine_map<(d0) -> (d0)>
// #map2 is the identity map used for the tensor input and the output.
#map2 = affine_map<(d0) -> (d0)>
// #map3 produces no indices; it is the map attached to the f32 input below.
#map3 = affine_map<(d0) -> ()>

// CHECK: func @inline_scalar(%[[ARG:.*]]: tensor<4xf32>, %[[SCALAR:.*]]: f32)
func.func @inline_scalar(%arg0: tensor<4xf32>, %scalar: f32) -> tensor<4xf32> {
%0 = tensor.empty() : tensor<4xf32>
// After the rewrite only the tensor operand is expected in ins(...), and the
// indexing_maps list is expected to hold two entries instead of three.
// CHECK: linalg.generic {indexing_maps = [#[[MAP]], #[[MAP]]],
// CHECK-SAME: iterator_types = ["parallel"]} ins(%[[ARG]] : tensor<4xf32>)
%1 = linalg.generic {indexing_maps = [#map2, #map3, #map2],
iterator_types = ["parallel"]}
ins(%arg0, %scalar : tensor<4xf32>, f32)
outs(%0 : tensor<4xf32>) {
// The input block has three arguments (%arg2 corresponds to %scalar); the
// expected block has two, with the former %arg2 use replaced by the
// function argument captured as SCALAR above.
// CHECK: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32)
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
// CHECK: arith.divf %[[IN]], %[[SCALAR]] : f32
%2 = arith.divf %arg1, %arg2 : f32
linalg.yield %2 : f32
} -> tensor<4xf32>
return %1 : tensor<4xf32>
}

0 comments on commit 4050324

Please sign in to comment.