Skip to content

Commit

Permalink
[Im2col] Generate matmuls with expanded H and W dims
Browse files Browse the repository at this point in the history
Signed-off-by: Max Dawkins <[email protected]>
  • Loading branch information
Max191 committed Oct 9, 2024
1 parent 598a60e commit 22cb8aa
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 101 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir::iree_compiler::IREE::LinalgExt {
Expand Down Expand Up @@ -125,22 +124,22 @@ class ConvertConv2DNhwcHwcf final

auto loc = convOp.getLoc();

SmallVector<int64_t> colTensorShape = {n, oh * ow, fh * fw * ic};
SmallVector<int64_t> colTensorShape = {n, oh, ow, fh * fw * ic};

SmallVector<ReassociationIndices> outputReassocIndices = {{0}, {1, 2}, {3}};
auto reshapedOutputType =
RankedTensorType::get({n, oh * ow, oc}, outputType.getElementType());

Value colTensor = rewriter.create<tensor::EmptyOp>(
loc, colTensorShape, inputType.getElementType());
SmallVector<int64_t> strides(convOp.getStrides().getValues<int64_t>());
SmallVector<int64_t> dilations(convOp.getDilations().getValues<int64_t>());
SmallVector<OpFoldResult> kernelSize = {rewriter.getIndexAttr(fh),
rewriter.getIndexAttr(fw)};
SmallVector<OpFoldResult> mOffset = {rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> mBasis = {rewriter.getIndexAttr(1)};
SmallVector<OpFoldResult> kOffset = {rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> kBasis = {rewriter.getIndexAttr(1)};
OpFoldResult zero = rewriter.getIndexAttr(0);
OpFoldResult one = rewriter.getIndexAttr(1);
SmallVector<OpFoldResult> mOffset = {zero, zero};
SmallVector<OpFoldResult> mBasis = {rewriter.getIndexAttr(ow), one};
SmallVector<OpFoldResult> kOffset = {zero};
SmallVector<OpFoldResult> kBasis = {one};
SmallVector<int64_t> batchPos = {0};
SmallVector<int64_t> mPos = {1, 2};
SmallVector<int64_t> kPos = {3};
Expand All @@ -158,22 +157,21 @@ class ConvertConv2DNhwcHwcf final
Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>(
loc, reshapedFilterType, filter, filterReassocIndices);

Value reshapedOutput = rewriter.create<tensor::CollapseShapeOp>(
loc, reshapedOutputType, output, outputReassocIndices);

AffineExpr bDim, mDim, nDim, kDim;
bindDims(getContext(), bDim, mDim, nDim, kDim);
auto lhsMap = AffineMap::get(4, 0, {bDim, mDim, kDim}, getContext());
auto rhsMap = AffineMap::get(4, 0, {kDim, nDim}, getContext());
auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, getContext());
AffineExpr bDim, m0Dim, m1Dim, nDim, kDim;
bindDims(getContext(), bDim, m0Dim, m1Dim, nDim, kDim);
auto lhsMap =
AffineMap::get(5, 0, {bDim, m0Dim, m1Dim, kDim}, getContext());
auto rhsMap = AffineMap::get(5, 0, {kDim, nDim}, getContext());
auto resultMap =
AffineMap::get(5, 0, {bDim, m0Dim, m1Dim, nDim}, getContext());
auto parallel = utils::IteratorType::parallel;
auto reduction = utils::IteratorType::reduction;
SmallVector<utils::IteratorType> genericIterators = {parallel, parallel,
parallel, reduction};
SmallVector<utils::IteratorType> genericIterators = {
parallel, parallel, parallel, parallel, reduction};
auto genericOp = rewriter.create<linalg::GenericOp>(
loc, reshapedOutputType,
loc, outputType,
/*inputs=*/ValueRange{img2ColTensor, reshapedFilter},
/*outputs=*/ValueRange{reshapedOutput},
/*outputs=*/ValueRange{output},
ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
[](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
Value lhs = convertScalarToDtype(nestedBuilder, nestedLoc, args[0],
Expand All @@ -188,10 +186,7 @@ class ConvertConv2DNhwcHwcf final
});
Value result = genericOp.getResults().front();

auto reshapedResult = rewriter.create<tensor::ExpandShapeOp>(
loc, outputType, result, outputReassocIndices);

rewriter.replaceOp(convOp, ArrayRef<Value>{reshapedResult});
rewriter.replaceOp(convOp, ArrayRef<Value>{result});

return success();
}
Expand Down Expand Up @@ -254,18 +249,20 @@ class ConvertConv2DNchwFchw final

auto loc = convOp.getLoc();

SmallVector<int64_t> colTensorShape = {n, oh * ow, ic * fh * fw};
SmallVector<int64_t> colTensorShape = {n, oh, ow, fh * fw * ic};

Value colTensor = rewriter.create<tensor::EmptyOp>(
loc, colTensorShape, inputType.getElementType());
SmallVector<int64_t> strides(convOp.getStrides().getValues<int64_t>());
SmallVector<int64_t> dilations(convOp.getDilations().getValues<int64_t>());
SmallVector<OpFoldResult> kernelSize = {rewriter.getIndexAttr(fh),
rewriter.getIndexAttr(fw)};
SmallVector<OpFoldResult> mOffset = {rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> mBasis = {rewriter.getIndexAttr(1)};
SmallVector<OpFoldResult> kOffset = {rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> kBasis = {rewriter.getIndexAttr(1)};
OpFoldResult zero = rewriter.getIndexAttr(0);
OpFoldResult one = rewriter.getIndexAttr(1);
SmallVector<OpFoldResult> mOffset = {zero, zero};
SmallVector<OpFoldResult> mBasis = {rewriter.getIndexAttr(ow), one};
SmallVector<OpFoldResult> kOffset = {zero};
SmallVector<OpFoldResult> kBasis = {one};
SmallVector<int64_t> batchPos = {0};
SmallVector<int64_t> mPos = {2, 3};
SmallVector<int64_t> kPos = {1};
Expand All @@ -282,26 +279,21 @@ class ConvertConv2DNchwFchw final
Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>(
loc, reshapedFilterType, filter, filterReassocIndices);

SmallVector<ReassociationIndices> outputReassocIndices = {{0}, {1}, {2, 3}};
RankedTensorType reshapedOutputType =
RankedTensorType::get({n, oc, oh * ow}, outputType.getElementType());

Value reshapedOutput = rewriter.create<tensor::CollapseShapeOp>(
loc, reshapedOutputType, output, outputReassocIndices);

AffineExpr bDim, mDim, nDim, kDim;
bindDims(getContext(), bDim, mDim, nDim, kDim);
auto lhsMap = AffineMap::get(4, 0, {mDim, kDim}, getContext());
auto rhsMap = AffineMap::get(4, 0, {bDim, nDim, kDim}, getContext());
auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, getContext());
AffineExpr bDim, mDim, nDim0, nDim1, kDim;
bindDims(getContext(), bDim, mDim, nDim0, nDim1, kDim);
auto lhsMap = AffineMap::get(5, 0, {mDim, kDim}, getContext());
auto rhsMap =
AffineMap::get(5, 0, {bDim, nDim0, nDim1, kDim}, getContext());
auto resultMap =
AffineMap::get(5, 0, {bDim, mDim, nDim0, nDim1}, getContext());
auto parallel = utils::IteratorType::parallel;
auto reduction = utils::IteratorType::reduction;
SmallVector<utils::IteratorType> genericIterators = {parallel, parallel,
parallel, reduction};
SmallVector<utils::IteratorType> genericIterators = {
parallel, parallel, parallel, parallel, reduction};
auto genericOp = rewriter.create<linalg::GenericOp>(
loc, reshapedOutputType,
loc, outputType,
/*inputs=*/ValueRange{reshapedFilter, img2ColTensor},
/*outputs=*/ValueRange{reshapedOutput},
/*outputs=*/ValueRange{output},
ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
[](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
Value lhs = convertScalarToDtype(nestedBuilder, nestedLoc, args[0],
Expand All @@ -316,10 +308,7 @@ class ConvertConv2DNchwFchw final
});
Value result = genericOp.getResults().front();

auto reshapedResult = rewriter.create<tensor::ExpandShapeOp>(
loc, outputType, result, outputReassocIndices);

rewriter.replaceOp(convOp, ArrayRef<Value>{reshapedResult});
rewriter.replaceOp(convOp, ArrayRef<Value>{result});

return success();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,32 +7,30 @@ util.func public @conv_2d_nhwc_hwcf(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<
outs(%arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
util.return %0 : tensor<1x14x14x16xf32>
}
// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d4, d3)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>
// CHECK: util.func public @conv_2d_nhwc_hwcf(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<1x16x16x4xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<3x3x4x16xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<1x14x14x16xf32>
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x196x36xf32>
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x14x14x36xf32>
// CHECK: %[[IM2COL:.+]] = iree_linalg_ext.im2col
// CHECK-SAME: strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
// CHECK-SAME: m_offset = [0] * [1] k_offset = [0] * [1]
// CHECK-SAME: m_offset = [0, 0] * [14, 1] k_offset = [0] * [1]
// CHECK-SAME: batch_pos = [0] m_pos = [1, 2] k_pos = [3]
// CHECK-SAME: ins(%[[ARG0]] : tensor<1x16x16x4xf32>)
// CHECK-SAME: outs(%[[EMPTY]] : tensor<1x196x36xf32>) -> tensor<1x196x36xf32>
// CHECK-SAME: outs(%[[EMPTY]] : tensor<1x14x14x36xf32>) -> tensor<1x14x14x36xf32>
// CHECK-DAG: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG1]] {{\[}}[0, 1, 2], [3]] : tensor<3x3x4x16xf32> into tensor<36x16xf32>
// CHECK-DAG: %[[COLLAPSED_0:.+]] = tensor.collapse_shape %[[ARG2]] {{\[}}[0], [1, 2], [3]] : tensor<1x14x14x16xf32> into tensor<1x196x16xf32>
// CHECK: %[[MATMUL:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
// CHECK-SAME: ins(%[[IM2COL]], %[[COLLAPSED]] : tensor<1x196x36xf32>, tensor<36x16xf32>)
// CHECK-SAME: outs(%[[COLLAPSED_0]] : tensor<1x196x16xf32>) {
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]
// CHECK-SAME: ins(%[[IM2COL]], %[[COLLAPSED]] : tensor<1x14x14x36xf32>, tensor<36x16xf32>)
// CHECK-SAME: outs(%[[ARG2]] : tensor<1x14x14x16xf32>) {
// CHECK: arith.mulf
// CHECK: arith.addf
// CHECK: } -> tensor<1x196x16xf32>
// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0], [1, 2], [3]] output_shape [1, 14, 14, 16] : tensor<1x196x16xf32> into tensor<1x14x14x16xf32>
// CHECK: util.return %[[EXPANDED]] : tensor<1x14x14x16xf32>
// CHECK: } -> tensor<1x14x14x16xf32>
// CHECK: util.return %[[MATMUL]] : tensor<1x14x14x16xf32>

// -----

Expand All @@ -43,32 +41,30 @@ util.func public @conv_2d_nchw_fchw(%arg0: tensor<1x4x16x16xf32>, %arg1: tensor<
outs(%arg2: tensor<1x16x14x14xf32>) -> tensor<1x16x14x14xf32>
util.return %0 : tensor<1x16x14x14xf32>
}
// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d1, d4)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d3, d4)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>
// CHECK: util.func public @conv_2d_nchw_fchw(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<1x4x16x16xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<16x4x3x3xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<1x16x14x14xf32>
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x196x36xf32>
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x14x14x36xf32>
// CHECK: %[[IM2COL:.+]] = iree_linalg_ext.im2col
// CHECK-SAME: strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
// CHECK-SAME: m_offset = [0] * [1] k_offset = [0] * [1]
// CHECK-SAME: m_offset = [0, 0] * [14, 1] k_offset = [0] * [1]
// CHECK-SAME: batch_pos = [0] m_pos = [2, 3] k_pos = [1]
// CHECK-SAME: ins(%[[ARG0]] : tensor<1x4x16x16xf32>)
// CHECK-SAME: outs(%[[EMPTY]] : tensor<1x196x36xf32>) -> tensor<1x196x36xf32>
// CHECK-SAME: outs(%[[EMPTY]] : tensor<1x14x14x36xf32>) -> tensor<1x14x14x36xf32>
// CHECK-DAG: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG1]] {{\[}}[0], [1, 2, 3]] : tensor<16x4x3x3xf32> into tensor<16x36xf32>
// CHECK-DAG: %[[COLLAPSED_0:.+]] = tensor.collapse_shape %[[ARG2]] {{\[}}[0], [1], [2, 3]] : tensor<1x16x14x14xf32> into tensor<1x16x196xf32>
// CHECK: %[[MATMUL:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
// CHECK-SAME: ins(%[[COLLAPSED]], %[[IM2COL]] : tensor<16x36xf32>, tensor<1x196x36xf32>)
// CHECK-SAME: outs(%[[COLLAPSED_0]] : tensor<1x16x196xf32>) {
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]
// CHECK-SAME: ins(%[[COLLAPSED]], %[[IM2COL]] : tensor<16x36xf32>, tensor<1x14x14x36xf32>)
// CHECK-SAME: outs(%[[ARG2]] : tensor<1x16x14x14xf32>) {
// CHECK: arith.mulf
// CHECK: arith.addf
// CHECK: } -> tensor<1x16x196xf32>
// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0], [1], [2, 3]] output_shape [1, 16, 14, 14] : tensor<1x16x196xf32> into tensor<1x16x14x14xf32>
// CHECK: util.return %[[EXPANDED]] : tensor<1x16x14x14xf32>
// CHECK: } -> tensor<1x16x14x14xf32>
// CHECK: util.return %[[MATMUL]] : tensor<1x16x14x14xf32>

// -----

Expand All @@ -79,34 +75,32 @@ util.func public @conv_mixed_types(%arg0: tensor<1x16x16x4xf16>, %arg1: tensor<3
outs(%arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
util.return %0 : tensor<1x14x14x16xf32>
}
// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d4, d3)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>
// CHECK: util.func public @conv_mixed_types(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<1x16x16x4xf16>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<3x3x4x16xf16>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<1x14x14x16xf32>
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x196x36xf16>
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x14x14x36xf16>
// CHECK: %[[IM2COL:.+]] = iree_linalg_ext.im2col
// CHECK-SAME: strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
// CHECK-SAME: m_offset = [0] * [1] k_offset = [0] * [1]
// CHECK-SAME: m_offset = [0, 0] * [14, 1] k_offset = [0] * [1]
// CHECK-SAME: batch_pos = [0] m_pos = [1, 2] k_pos = [3]
// CHECK-SAME: ins(%[[ARG0]] : tensor<1x16x16x4xf16>)
// CHECK-SAME: outs(%[[EMPTY]] : tensor<1x196x36xf16>) -> tensor<1x196x36xf16>
// CHECK-SAME: outs(%[[EMPTY]] : tensor<1x14x14x36xf16>) -> tensor<1x14x14x36xf16>
// CHECK-DAG: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG1]] {{\[}}[0, 1, 2], [3]] : tensor<3x3x4x16xf16> into tensor<36x16xf16>
// CHECK-DAG: %[[COLLAPSED_0:.+]] = tensor.collapse_shape %[[ARG2]] {{\[}}[0], [1, 2], [3]] : tensor<1x14x14x16xf32> into tensor<1x196x16xf32>
// CHECK: %[[MATMUL:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
// CHECK-SAME: ins(%[[IM2COL]], %[[COLLAPSED]] : tensor<1x196x36xf16>, tensor<36x16xf16>)
// CHECK-SAME: outs(%[[COLLAPSED_0]] : tensor<1x196x16xf32>) {
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]
// CHECK-SAME: ins(%[[IM2COL]], %[[COLLAPSED]] : tensor<1x14x14x36xf16>, tensor<36x16xf16>)
// CHECK-SAME: outs(%[[ARG2]] : tensor<1x14x14x16xf32>) {
// CHECK: arith.extf
// CHECK: arith.extf
// CHECK: arith.mulf
// CHECK: arith.addf
// CHECK: } -> tensor<1x196x16xf32>
// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0], [1, 2], [3]] output_shape [1, 14, 14, 16] : tensor<1x196x16xf32> into tensor<1x14x14x16xf32>
// CHECK: util.return %[[EXPANDED]] : tensor<1x14x14x16xf32>
// CHECK: } -> tensor<1x14x14x16xf32>
// CHECK: util.return %[[MATMUL]] : tensor<1x14x14x16xf32>

// -----

Expand All @@ -117,31 +111,29 @@ util.func public @conv_strided(%arg0: tensor<1x16x16x4xf16>, %arg1: tensor<3x3x4
outs(%arg2: tensor<1x7x7x16xf32>) -> tensor<1x7x7x16xf32>
util.return %0 : tensor<1x7x7x16xf32>
}
// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d4, d3)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>
// CHECK: util.func public @conv_strided(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<1x16x16x4xf16>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<3x3x4x16xf16>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<1x7x7x16xf32>
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x49x36xf16>
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x7x7x36xf16>
// CHECK: %[[IM2COL:.+]] = iree_linalg_ext.im2col
// CHECK-SAME: strides = [2, 2] dilations = [1, 1] kernel_size = [3, 3]
// CHECK-SAME: m_offset = [0] * [1] k_offset = [0] * [1]
// CHECK-SAME: m_offset = [0, 0] * [7, 1] k_offset = [0] * [1]
// CHECK-SAME: batch_pos = [0] m_pos = [1, 2] k_pos = [3]
// CHECK-SAME: ins(%[[ARG0]] : tensor<1x16x16x4xf16>)
// CHECK-SAME: outs(%[[EMPTY]] : tensor<1x49x36xf16>) -> tensor<1x49x36xf16>
// CHECK-SAME: outs(%[[EMPTY]] : tensor<1x7x7x36xf16>) -> tensor<1x7x7x36xf16>
// CHECK-DAG: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG1]] {{\[}}[0, 1, 2], [3]] : tensor<3x3x4x16xf16> into tensor<36x16xf16>
// CHECK-DAG: %[[COLLAPSED_0:.+]] = tensor.collapse_shape %[[ARG2]] {{\[}}[0], [1, 2], [3]] : tensor<1x7x7x16xf32> into tensor<1x49x16xf32>
// CHECK: %[[MATMUL:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
// CHECK-SAME: ins(%[[IM2COL]], %[[COLLAPSED]] : tensor<1x49x36xf16>, tensor<36x16xf16>)
// CHECK-SAME: outs(%[[COLLAPSED_0]] : tensor<1x49x16xf32>) {
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]
// CHECK-SAME: ins(%[[IM2COL]], %[[COLLAPSED]] : tensor<1x7x7x36xf16>, tensor<36x16xf16>)
// CHECK-SAME: outs(%[[ARG2]] : tensor<1x7x7x16xf32>) {
// CHECK: arith.extf
// CHECK: arith.extf
// CHECK: arith.mulf
// CHECK: arith.addf
// CHECK: } -> tensor<1x49x16xf32>
// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0], [1, 2], [3]] output_shape [1, 7, 7, 16] : tensor<1x49x16xf32> into tensor<1x7x7x16xf32>
// CHECK: util.return %[[EXPANDED]] : tensor<1x7x7x16xf32>
// CHECK: } -> tensor<1x7x7x16xf32>
// CHECK: util.return %[[MATMUL]] : tensor<1x7x7x16xf32>

0 comments on commit 22cb8aa

Please sign in to comment.