Skip to content

Commit

Permalink
[Torch] enhance fold of aten.slice.Tensor (llvm#3557)
Browse files Browse the repository at this point in the history
so that it could support folding slice with any static shape.
  • Loading branch information
qingyunqu authored Jul 23, 2024
1 parent 7884642 commit 21ad890
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 24 deletions.
47 changes: 29 additions & 18 deletions lib/Dialect/Torch/IR/TorchOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3625,42 +3625,53 @@ OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) {
return DenseElementsAttr::get(outType.toBuiltinTensor(),
input.getSplatValue<Attribute>());

int count = 1;
int64_t count = 1;
for (auto dim : outType.getSizes())
count = count * dim;

if (count == 0)
return {};
return nullptr;

if (!dim)
return nullptr;
int64_t dimInt = dim.getValue().getSExtValue();
if (dimInt < 0)
dimInt += inType.getSizes().size();

bool unaryNonDim = true;
for (int i = 0, s = outType.getSizes().size(); i < s; ++i)
unaryNonDim &= outType.getSizes()[i] == 1 || i == dimInt;

// Fold the slice if the output tensor is relatively small, currently
// coded to 16:
if (input && start && step && dim && count < 16 && unaryNonDim &&
count < 16) {
int64_t inCount = input.getNumElements();
constexpr int64_t kMaxFold = 16;
if (input && start && step && dim && count <= kMaxFold) {
int64_t begin = start.getValue().getSExtValue();
int64_t limit = end.getValue().getSExtValue();
int64_t stride = step.getValue().getSExtValue();
if (stride < 1)
return {};
int64_t limit = end.getValue().getSExtValue();
begin = begin < 0 ? begin + inCount : begin;
limit = limit < 0 ? limit + inCount : limit;
limit = limit < 0 ? inType.getSizes()[dimInt] : limit;
return nullptr;
begin = begin < 0 ? begin + inType.getSizes()[dimInt] : begin;
limit = limit < 0 ? limit + inType.getSizes()[dimInt] : limit;
limit = std::min(limit, inType.getSizes()[dimInt]);

llvm::SmallVector<Attribute> values;
for (int i = begin; i < limit; i += stride)
values.push_back(input.getValues<Attribute>()[i]);
int64_t inputRank = inType.getSizes().size();
llvm::SmallVector<int64_t> inputStrides(inputRank, 1);
for (int64_t i = inputRank - 2; i >= 0; i--) {
inputStrides[i] = inputStrides[i + 1] * inType.getSizes()[i + 1];
}

llvm::SmallVector<Attribute> values;
values.reserve(count);
auto recursiveIter = [&](auto &self, int64_t currDim, int64_t currOffset) {
if (currDim >= inputRank)
return;
size_t _begin = (currDim == dimInt) ? begin : 0;
size_t _limit = (currDim == dimInt) ? limit : inType.getSizes()[currDim];
size_t _stride = (currDim == dimInt) ? stride : 1;
for (size_t i = _begin; i < _limit; i += _stride) {
if (currDim == inputRank - 1) {
values.push_back(input.getValues<Attribute>()[currOffset + i]);
}
self(self, currDim + 1, currOffset + inputStrides[currDim] * i);
}
};
recursiveIter(recursiveIter, 0, 0);
return DenseElementsAttr::get(outType.toBuiltinTensor(), values);
}

Expand Down
27 changes: 21 additions & 6 deletions test/Dialect/Torch/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2139,15 +2139,15 @@ func.func @torch.aten.broadcast_to$fold_splat() -> !torch.vtensor<[3,4,2],f32> {

// -----

// CHECK-LABEL: @torch.aten.slice.tensor$fold_full_domain_slice
// CHECK-LABEL: @torch.aten.slice.tensor$not_fold_slice
// CHECK-SAME: %[[ARG0:.+]]: !torch.vtensor<[4],f32>
// CHECK: return %[[ARG0]] : !torch.vtensor<[4],f32>
func.func @torch.aten.slice.tensor$fold_full_domain_slice(%arg0: !torch.vtensor<[4],f32>) -> !torch.vtensor<[4],f32> {
// CHECK: torch.aten.slice.Tensor
func.func @torch.aten.slice.tensor$not_fold_slice(%arg0: !torch.vtensor<[4],f32>) -> !torch.vtensor<[3],f32> {
%int1 = torch.constant.int 1
%int-1 = torch.constant.int -1
%int0 = torch.constant.int 0
%0 = torch.aten.slice.Tensor %arg0, %int0, %int0, %int-1, %int1 : !torch.vtensor<[4], f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4], f32>
return %0 : !torch.vtensor<[4],f32>
%0 = torch.aten.slice.Tensor %arg0, %int0, %int0, %int-1, %int1 : !torch.vtensor<[4], f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3], f32>
return %0 : !torch.vtensor<[3],f32>
}

// CHECK-LABEL: @torch.aten.slice.tensor$fold_full_slice
Expand Down Expand Up @@ -2209,7 +2209,10 @@ func.func @torch.aten.slice.tensor$fold_small() -> (!torch.vtensor<[2],si32>) {
}

// -----

// CHECK-LABEL: func.func @torch.aten.slice.tensor$fold_dim_0() -> (!torch.vtensor<[1,1],f32>, !torch.vtensor<[1,1],f32>) {
// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<1.600000e+01> : tensor<1x1xf32>) : !torch.vtensor<[1,1],f32>
// CHECK: %[[CST0:.+]] = torch.vtensor.literal(dense<6.400000e+01> : tensor<1x1xf32>) : !torch.vtensor<[1,1],f32>
// CHECK: return %[[CST]], %[[CST0]]
func.func @torch.aten.slice.tensor$fold_dim_0() -> (!torch.vtensor<[1, 1],f32>, !torch.vtensor<[1, 1],f32>) {
%tensor = torch.vtensor.literal(dense<[[2.0],[4.0],[8.0],[16.0],[32.0],[64.0],[128.0],[256.0],[512.0],[1024.0]]> : tensor<10x1xf32>) : !torch.vtensor<[10, 1],f32>
%int0 = torch.constant.int 0
Expand All @@ -2224,6 +2227,18 @@ func.func @torch.aten.slice.tensor$fold_dim_0() -> (!torch.vtensor<[1, 1],f32>,
return %0, %1 : !torch.vtensor<[1, 1],f32>, !torch.vtensor<[1, 1], f32>
}

// -----
// CHECK-LABEL: func.func @torch.aten.slice.tensor$fold_dim_0_non_contiguous() -> !torch.vtensor<[4,1],si64> {
// CHECK{LITERAL}: %0 = torch.vtensor.literal(dense<[[28], [14], [7], [4]]> : tensor<4x1xsi64>) : !torch.vtensor<[4,1],si64>
// CHECK: return %0
func.func @torch.aten.slice.tensor$fold_dim_0_non_contiguous() -> (!torch.vtensor<[4,1],si64>) {
%int1 = torch.constant.int 1
%int2 = torch.constant.int 2
%0 = torch.vtensor.literal(dense<[[28, 28], [14, 14], [7, 7], [4, 4]]> : tensor<4x2xsi64>) : !torch.vtensor<[4,2],si64>
%1 = torch.aten.slice.Tensor %0, %int1, %int1, %int2, %int1 : !torch.vtensor<[4,2],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,1],si64>
return %1 : !torch.vtensor<[4,1],si64>
}

// -----

// CHECK-LABEL: func.func @torch.aten.rsub.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {
Expand Down

0 comments on commit 21ad890

Please sign in to comment.