diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index 422883914cd5..0b561744062e 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -3625,12 +3625,11 @@ OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) { return DenseElementsAttr::get(outType.toBuiltinTensor(), input.getSplatValue()); - int count = 1; + int64_t count = 1; for (auto dim : outType.getSizes()) count = count * dim; - if (count == 0) - return {}; + return nullptr; if (!dim) return nullptr; @@ -3638,29 +3637,41 @@ OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) { if (dimInt < 0) dimInt += inType.getSizes().size(); - bool unaryNonDim = true; - for (int i = 0, s = outType.getSizes().size(); i < s; ++i) - unaryNonDim &= outType.getSizes()[i] == 1 || i == dimInt; - // Fold the slice if the output tensor is relatively small, currently // coded to 16: - if (input && start && step && dim && count < 16 && unaryNonDim && - count < 16) { - int64_t inCount = input.getNumElements(); + constexpr int64_t kMaxFold = 16; + if (input && start && step && dim && count <= kMaxFold) { int64_t begin = start.getValue().getSExtValue(); + int64_t limit = end.getValue().getSExtValue(); int64_t stride = step.getValue().getSExtValue(); if (stride < 1) - return {}; - int64_t limit = end.getValue().getSExtValue(); - begin = begin < 0 ? begin + inCount : begin; - limit = limit < 0 ? limit + inCount : limit; - limit = limit < 0 ? inType.getSizes()[dimInt] : limit; + return nullptr; + begin = begin < 0 ? begin + inType.getSizes()[dimInt] : begin; + limit = limit < 0 ? limit + inType.getSizes()[dimInt] : limit; limit = std::min(limit, inType.getSizes()[dimInt]); - llvm::SmallVector values; - for (int i = begin; i < limit; i += stride) - values.push_back(input.getValues()[i]); + int64_t inputRank = inType.getSizes().size(); + llvm::SmallVector inputStrides(inputRank, 1); + for (int64_t i = inputRank - 2; i >= 0; i--) { + inputStrides[i] = inputStrides[i + 1] * inType.getSizes()[i + 1]; + } + llvm::SmallVector values; + values.reserve(count); + auto recursiveIter = [&](auto &self, int64_t currDim, int64_t currOffset) { + if (currDim >= inputRank) + return; + size_t _begin = (currDim == dimInt) ? begin : 0; + size_t _limit = (currDim == dimInt) ? limit : inType.getSizes()[currDim]; + size_t _stride = (currDim == dimInt) ? stride : 1; + for (size_t i = _begin; i < _limit; i += _stride) { + if (currDim == inputRank - 1) { + values.push_back(input.getValues()[currOffset + i]); + } + self(self, currDim + 1, currOffset + inputStrides[currDim] * i); + } + }; + recursiveIter(recursiveIter, 0, 0); return DenseElementsAttr::get(outType.toBuiltinTensor(), values); } diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir index aa943a5a1e5a..f0b8ff3e8662 100644 --- a/test/Dialect/Torch/canonicalize.mlir +++ b/test/Dialect/Torch/canonicalize.mlir @@ -2139,15 +2139,15 @@ func.func @torch.aten.broadcast_to$fold_splat() -> !torch.vtensor<[3,4,2],f32> { // ----- -// CHECK-LABEL: @torch.aten.slice.tensor$fold_full_domain_slice +// CHECK-LABEL: @torch.aten.slice.tensor$not_fold_slice // CHECK-SAME: %[[ARG0:.+]]: !torch.vtensor<[4],f32> -// CHECK: return %[[ARG0]] : !torch.vtensor<[4],f32> -func.func @torch.aten.slice.tensor$fold_full_domain_slice(%arg0: !torch.vtensor<[4],f32>) -> !torch.vtensor<[4],f32> { +// CHECK: torch.aten.slice.Tensor +func.func @torch.aten.slice.tensor$not_fold_slice(%arg0: !torch.vtensor<[4],f32>) -> !torch.vtensor<[3],f32> { %int1 = torch.constant.int 1 %int-1 = torch.constant.int -1 %int0 = torch.constant.int 0 - %0 = torch.aten.slice.Tensor %arg0, %int0, %int0, %int-1, %int1 : !torch.vtensor<[4], f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4], f32> - return %0 : !torch.vtensor<[4],f32> + %0 = torch.aten.slice.Tensor %arg0, %int0, %int0, %int-1, %int1 : !torch.vtensor<[4], f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3], f32> + return %0 : !torch.vtensor<[3],f32> } // CHECK-LABEL: @torch.aten.slice.tensor$fold_full_slice @@ -2209,7 +2209,10 @@ func.func @torch.aten.slice.tensor$fold_small() -> (!torch.vtensor<[2],si32>) { } // ----- - +// CHECK-LABEL: func.func @torch.aten.slice.tensor$fold_dim_0() -> (!torch.vtensor<[1,1],f32>, !torch.vtensor<[1,1],f32>) { +// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<1.600000e+01> : tensor<1x1xf32>) : !torch.vtensor<[1,1],f32> +// CHECK: %[[CST0:.+]] = torch.vtensor.literal(dense<6.400000e+01> : tensor<1x1xf32>) : !torch.vtensor<[1,1],f32> +// CHECK: return %[[CST]], %[[CST0]] func.func @torch.aten.slice.tensor$fold_dim_0() -> (!torch.vtensor<[1, 1],f32>, !torch.vtensor<[1, 1],f32>) { %tensor = torch.vtensor.literal(dense<[[2.0],[4.0],[8.0],[16.0],[32.0],[64.0],[128.0],[256.0],[512.0],[1024.0]]> : tensor<10x1xf32>) : !torch.vtensor<[10, 1],f32> %int0 = torch.constant.int 0 @@ -2224,6 +2227,18 @@ func.func @torch.aten.slice.tensor$fold_dim_0() -> (!torch.vtensor<[1, 1],f32>, return %0, %1 : !torch.vtensor<[1, 1],f32>, !torch.vtensor<[1, 1], f32> } +// ----- +// CHECK-LABEL: func.func @torch.aten.slice.tensor$fold_dim_0_non_contiguous() -> !torch.vtensor<[4,1],si64> { +// CHECK{LITERAL}: %0 = torch.vtensor.literal(dense<[[28], [14], [7], [4]]> : tensor<4x1xsi64>) : !torch.vtensor<[4,1],si64> +// CHECK: return %0 +func.func @torch.aten.slice.tensor$fold_dim_0_non_contiguous() -> (!torch.vtensor<[4,1],si64>) { + %int1 = torch.constant.int 1 + %int2 = torch.constant.int 2 + %0 = torch.vtensor.literal(dense<[[28, 28], [14, 14], [7, 7], [4, 4]]> : tensor<4x2xsi64>) : !torch.vtensor<[4,2],si64> + %1 = torch.aten.slice.Tensor %0, %int1, %int1, %int2, %int1 : !torch.vtensor<[4,2],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,1],si64> + return %1 : !torch.vtensor<[4,1],si64> +} + // ----- // CHECK-LABEL: func.func @torch.aten.rsub.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {