Skip to content

Commit

Permalink
[TorchToLinalg] Use linalg.broadcast instead of generic for conv …
Browse files Browse the repository at this point in the history
…bias (llvm#3661)

The current implementation uses a `linalg.generic` to broadcast the bias
tensor for the lowering of convolutions. This is suboptimal for later
pattern matching. This patch changes it to use the respective named op,
`linalg.broadcast`, instead.
  • Loading branch information
ubfx authored Aug 26, 2024
1 parent fa39d91 commit 638ef14
Showing 1 changed file with 9 additions and 14 deletions.
23 changes: 9 additions & 14 deletions lib/Conversion/TorchToLinalg/Linear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1080,21 +1080,16 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
return rewriter.notifyMatchFailure(op, "expect bias to be rank 1");

auto resultRank = cast<RankedTensorType>(initTensor.getType()).getRank();
SmallVector<AffineMap> indexingMaps = {
// bias is used to initialize the channels - dimension 1 of output
AffineMap::get(/*dimCount=*/resultRank, /*symbolCount=*/0,
rewriter.getAffineDimExpr(1), context),
rewriter.getMultiDimIdentityMap(resultRank)};
SmallVector<utils::IteratorType> iteratorTypes(
resultRank, utils::IteratorType::parallel);
SmallVector<int64_t, 4> addedDimensions;
// bias is used to initialize the channels - dimension 1 of
// output
for (int i = 0; i < resultRank; ++i)
if (i != 1)
addedDimensions.push_back(i);
outputTensor = rewriter
.create<linalg::GenericOp>(
loc, initTensor.getType(), bias, initTensor,
indexingMaps, iteratorTypes,
[](OpBuilder &b, Location loc, ValueRange args) {
b.create<linalg::YieldOp>(loc, args[0]);
})
.getResult(0);
.create<linalg::BroadcastOp>(loc, bias, initTensor,
addedDimensions)
->getResult(0);
}

auto stridesAttr = rewriter.getI64VectorAttr(strideInts);
Expand Down

0 comments on commit 638ef14

Please sign in to comment.