diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchTypes.h b/include/torch-mlir/Dialect/Torch/IR/TorchTypes.h
index c8d1c5051f28..163ed6300878 100644
--- a/include/torch-mlir/Dialect/Torch/IR/TorchTypes.h
+++ b/include/torch-mlir/Dialect/Torch/IR/TorchTypes.h
@@ -53,6 +53,9 @@ class BaseTensorType : public Type {
   /// convenient API.
   Type getOptionalDtype() const;
 
+  /// Get the raw optional sparse tensor encoding.
+  Attribute getOptionalSparsity() const;
+
   /// Return true if this type has a list of sizes.
   bool hasSizes() const { return getOptionalSizes().has_value(); }
 
@@ -93,6 +96,10 @@ class BaseTensorType : public Type {
   Type getWithSizesAndDtype(std::optional<ArrayRef<int64_t>> optionalSizes,
                             Type optionalDtype) const;
 
+  Type getWithSizesAndDtypeAndSparsity(
+      std::optional<ArrayRef<int64_t>> optionalSizes, Type optionalDtype,
+      Attribute optionalSparsity) const;
+
   /// Return a type with the same shape and dtype as this one, but with
   /// value semantics.
   ValueTensorType getWithValueSemantics() const;
@@ -129,23 +136,31 @@ namespace Torch {
 
 inline std::optional<ArrayRef<int64_t>>
 BaseTensorType::getOptionalSizes() const {
-  if (auto tensor = dyn_cast<NonValueTensorType>())
+  if (auto tensor = mlir::dyn_cast<NonValueTensorType>(*this))
     return tensor.getOptionalSizes();
-  if (auto tensor = dyn_cast<ValueTensorType>())
+  if (auto tensor = mlir::dyn_cast<ValueTensorType>(*this))
     return tensor.getOptionalSizes();
   llvm_unreachable("not a BaseTensorType!");
 }
 
 inline Type BaseTensorType::getOptionalDtype() const {
-  if (auto tensor = dyn_cast<NonValueTensorType>())
+  if (auto tensor = mlir::dyn_cast<NonValueTensorType>(*this))
     return tensor.getOptionalDtype();
-  if (auto tensor = dyn_cast<ValueTensorType>())
+  if (auto tensor = mlir::dyn_cast<ValueTensorType>(*this))
     return tensor.getOptionalDtype();
   llvm_unreachable("not a BaseTensorType!");
 }
 
+inline Attribute BaseTensorType::getOptionalSparsity() const {
+  if (auto tensor = mlir::dyn_cast<NonValueTensorType>(*this))
+    return tensor.getOptionalSparsity();
+  if (auto tensor = mlir::dyn_cast<ValueTensorType>(*this))
+    return tensor.getOptionalSparsity();
+  llvm_unreachable("not a BaseTensorType!");
+}
+
 inline bool BaseTensorType::classof(Type type) {
-  return type.isa<NonValueTensorType, ValueTensorType>();
+  return mlir::isa<NonValueTensorType, ValueTensorType>(type);
 }
 
 } // namespace Torch
diff --git a/include/torch-mlir/Dialect/Torch/Utils/SparsityUtils.h b/include/torch-mlir/Dialect/Torch/Utils/SparsityUtils.h
new file mode 100644
index 000000000000..e29054790e5c
--- /dev/null
+++ b/include/torch-mlir/Dialect/Torch/Utils/SparsityUtils.h
@@ -0,0 +1,28 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Also available under a BSD-style license. See LICENSE.
+//
+//===----------------------------------------------------------------------===//
+#ifndef TORCHMLIR_DIALECT_TORCH_SPARSITY_UTILS_H
+#define TORCHMLIR_DIALECT_TORCH_SPARSITY_UTILS_H
+
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Support/LogicalResult.h"
+
+namespace mlir {
+namespace torch {
+namespace Torch {
+
+// Create a new SparseTensorEncodingAttr based on the provided `attr`, but with
+// a new dense level inserted at `dim`.
+FailureOr<Attribute> getSparsityWithDenseLTAtDim(Attribute attr, Value dim);
+
+} // namespace Torch
+} // namespace torch
+} // namespace mlir
+
+#endif // TORCHMLIR_DIALECT_TORCH_SPARSITY_UTILS_H
diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp
index d8dd75a9a233..a5b07b947af6 100644
--- a/lib/Conversion/TorchToLinalg/DataMovement.cpp
+++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp
@@ -1880,9 +1880,11 @@ class ConvertAtenSliceTensorOp : public OpConversionPattern<AtenSliceTensorOp> {
             op, adaptor, rewriter, resultShape, offsets, strides))) {
       return failure();
     }
-
+    SmallVector<int64_t> dynShape(resultType.getRank(), ShapedType::kDynamic);
+    auto sliceType = RankedTensorType::get(
+        dynShape, resultType.getElementType(), resultType.getEncoding());
     Value result = rewriter.create<tensor::ExtractSliceOp>(
-        loc, input, offsets, resultShape, strides);
+        loc, sliceType, input, offsets, resultShape, strides);
 
     rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, result);
     return success();
diff --git a/lib/Dialect/Torch/IR/TorchTypes.cpp b/lib/Dialect/Torch/IR/TorchTypes.cpp
index c162166cdd13..d1906d6989af 100644
--- a/lib/Dialect/Torch/IR/TorchTypes.cpp
+++ b/lib/Dialect/Torch/IR/TorchTypes.cpp
@@ -235,6 +235,18 @@ Type BaseTensorType::getWithSizesAndDtype(
   llvm_unreachable("not a BaseTensorType!");
 }
 
+Type BaseTensorType::getWithSizesAndDtypeAndSparsity(
+    std::optional<ArrayRef<int64_t>> optionalSizes, Type optionalDtype,
+    Attribute optionalSparsity) const {
+  if (mlir::isa<NonValueTensorType>(*this))
+    return NonValueTensorType::get(getContext(), optionalSizes, optionalDtype,
+                                   optionalSparsity);
+  if (mlir::isa<ValueTensorType>(*this))
+    return ValueTensorType::get(getContext(), optionalSizes, optionalDtype,
+                                optionalSparsity);
+  llvm_unreachable("not a BaseTensorType!");
+}
+
 ValueTensorType BaseTensorType::getWithValueSemantics() const {
   if (auto tensor = dyn_cast<NonValueTensorType>())
     return tensor.getWithValueSemantics();
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 9fad15e132ff..54b852dcf06d 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -71,10 +71,10 @@ static Type computeReductionType(PatternRewriter &rewriter, Operation *op,
     }
   }
 
-  Type resultType = tensorType.getWithSizesAndDtype(
+  Type resultType = tensorType.getWithSizesAndDtypeAndSparsity(
       !tensorType.hasSizes() ? std::optional<ArrayRef<int64_t>>()
                              : llvm::ArrayRef(sizes),
-      tensorType.getOptionalDtype());
+      tensorType.getOptionalDtype(), tensorType.getOptionalSparsity());
 
   return resultType;
 }
diff --git a/lib/Dialect/Torch/Utils/CMakeLists.txt b/lib/Dialect/Torch/Utils/CMakeLists.txt
index 91088078891d..45b3e1b987aa 100644
--- a/lib/Dialect/Torch/Utils/CMakeLists.txt
+++ b/lib/Dialect/Torch/Utils/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(TorchMLIRTorchUtils
   Utils.cpp
+  SparsityUtils.cpp
   TorchUpstream.cpp
 
   ADDITIONAL_HEADER_DIRS
diff --git a/lib/Dialect/Torch/Utils/SparsityUtils.cpp b/lib/Dialect/Torch/Utils/SparsityUtils.cpp
new file mode 100644
index 000000000000..b2f1ef2d5280
--- /dev/null
+++ b/lib/Dialect/Torch/Utils/SparsityUtils.cpp
@@ -0,0 +1,55 @@
+//===----------------------------------------------------------------------===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Also available under a BSD-style license. See LICENSE.
+//
+//===----------------------------------------------------------------------===//
+
+#include "torch-mlir/Dialect/Torch/Utils/SparsityUtils.h"
+#include "mlir/Dialect/SparseTensor/IR/Enums.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
+#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
+#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
+#include "llvm/ADT/SmallVector.h"
+#include <cstdint>
+
+using namespace mlir;
+using namespace mlir::sparse_tensor;
+using namespace mlir::torch;
+using namespace mlir::torch::Torch;
+
+FailureOr<Attribute> Torch::getSparsityWithDenseLTAtDim(Attribute attr,
+                                                        Value dim) {
+  if (!attr)
+    return Attribute();
+
+  auto enc = cast<SparseTensorEncodingAttr>(attr);
+  int64_t dimInt = 0;
+  int64_t rank = enc.getDimRank() + 1;
+  if (matchPattern(dim, m_TorchConstantInt(&dimInt))) {
+    dimInt = toPositiveDim(dimInt, rank);
+    if (!isValidDim(dimInt, rank)) {
+      return failure();
+    }
+    if (!enc.isIdentity()) {
+      // TODO: support block sparsity and permutation (CSC).
+      return failure();
+    }
+    auto denseLT = *LevelType::buildLvlType(LevelFormat::Dense, true, true);
+    SmallVector<LevelType> lvlTps = llvm::to_vector(enc.getLvlTypes());
+    lvlTps.insert(lvlTps.begin() + dimInt, denseLT);
+    auto dim2Lvl = AffineMap::getMultiDimIdentityMap(rank, attr.getContext());
+    return SparseTensorEncodingAttr::get(
+        enc.getContext(), lvlTps, dim2Lvl, AffineMap(), enc.getPosWidth(),
+        enc.getCrdWidth(), enc.getExplicitVal(), enc.getImplicitVal());
+  }
+  // Do not know how to handle dynamic dimension.
+  return failure();
+}
diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp
index d634556c98a1..ed035b3030dd 100644
--- a/lib/Dialect/Torch/Utils/Utils.cpp
+++ b/lib/Dialect/Torch/Utils/Utils.cpp
@@ -11,6 +11,7 @@
 #include "mlir/IR/BuiltinDialect.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
+#include "torch-mlir/Dialect/Torch/Utils/SparsityUtils.h"
 
 using namespace mlir;
 using namespace mlir::torch;
@@ -318,6 +319,11 @@ FailureOr<Value> Torch::unsqueezeTensor(PatternRewriter &rewriter,
   if (!inputType.hasSizes()) {
     return rewriter.notifyMatchFailure(op, "input tensor must have size");
   }
+  FailureOr<Attribute> enc =
+      getSparsityWithDenseLTAtDim(inputType.getOptionalSparsity(), dim);
+  if (failed(enc)) {
+    return failure();
+  }
 
   SmallVector<int64_t> unsqueezedShape;
   ArrayRef<int64_t> inputShape = inputType.getSizes();
@@ -334,8 +340,8 @@ FailureOr<Value> Torch::unsqueezeTensor(PatternRewriter &rewriter,
   } else {
     unsqueezedShape.resize(unsqueezedRank, kUnknownSize);
   }
-  Type unsqueezedType = inputType.getWithSizesAndDtype(
-      unsqueezedShape, inputType.getOptionalDtype());
+  Type unsqueezedType = inputType.getWithSizesAndDtypeAndSparsity(
+      unsqueezedShape, inputType.getOptionalDtype(), enc.value());
   Value unsqueezed = rewriter.create<AtenUnsqueezeOp>(
       op->getLoc(), unsqueezedType, input, dim);
   return unsqueezed;
diff --git a/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py b/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py
index 8935a2a060fd..0179dd369893 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py
@@ -138,8 +138,6 @@ def invoke(*args):
     "builtin.module("
     + ",".join(
         [
-            "func.func(refback-generalize-tensor-pad)",
-            "func.func(refback-generalize-tensor-concat)",
            # Apply some optimizations. It would be great if MLIR had more useful
            # optimizations that worked out of the box here.
            # Note: When measured, this doesn't seem to actually help that much
@@ -157,6 +155,10 @@ def invoke(*args):
            "sparse-storage-specifier-to-llvm",
            # Buffer deallocation pass does not know how to handle realloc.
            "func.func(expand-realloc)",
+            # Generalize pad and concat after sparse compiler, as they are handled
+            # differently when the operations involve sparse operand.
+            "func.func(refback-generalize-tensor-pad)",
+            "func.func(refback-generalize-tensor-concat)",
            # Bufferize.
            "func.func(scf-bufferize)",
            "func.func(tm-tensor-bufferize)",
diff --git a/test/python/fx_importer/sparse_test.py b/test/python/fx_importer/sparse_test.py
index bfe404c92f1a..87d2e3d96d0e 100644
--- a/test/python/fx_importer/sparse_test.py
+++ b/test/python/fx_importer/sparse_test.py
@@ -134,6 +134,16 @@ def sparse_export(
         # elif opname == "_to_dense":
         #     # hack (assumes we never really want the to_dense for now)
         #     node.meta["sparsity"] = node.args[0].meta.get("sparsity", None)
+        elif opname == "select" and node.args[0].meta.get("sparsity", None):
+            dim = len(node.meta.get("val").shape)
+            node.meta["sparsity"] = SparsityMeta(
+                torch.sparse_coo, 0, dim, 0, None, torch.int64, torch.int64
+            )
+        elif opname == "stack" and node.args[0][0].meta.get("sparsity", None):
+            dim = len(node.meta.get("val").shape)
+            node.meta["sparsity"] = SparsityMeta(
+                torch.sparse_coo, 0, dim - 1, 1, None, torch.int64, torch.int64
+            )
 
     return prog
 
diff --git a/utils/bazel/torch-mlir-overlay/BUILD.bazel b/utils/bazel/torch-mlir-overlay/BUILD.bazel
index e62780ff9634..2118660a9b8b 100644
--- a/utils/bazel/torch-mlir-overlay/BUILD.bazel
+++ b/utils/bazel/torch-mlir-overlay/BUILD.bazel
@@ -90,6 +90,7 @@ gentbl_cc_library(
 cc_library(
     name = "TorchMLIRTorchDialectUtils",
     srcs = [
+        "lib/Dialect/Torch/Utils/SparsityUtils.cpp",
        "lib/Dialect/Torch/Utils/TorchUpstream.cpp",
        "lib/Dialect/Torch/Utils/Utils.cpp",
    ],
@@ -97,6 +98,7 @@
        "include/torch-mlir/Dialect/Torch/IR/TorchOps.h",
        "include/torch-mlir/Dialect/Torch/IR/TorchTraits.h",
        "include/torch-mlir/Dialect/Torch/IR/TorchTypes.h",
+        "include/torch-mlir/Dialect/Torch/Utils/SparsityUtils.h",
        "include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h",
        "include/torch-mlir/Dialect/Torch/Utils/Utils.h",
    ],