
[sparse] propagate sparsity properly when decomposing torch operations.
Peiming Liu authored May 15, 2024
1 parent ba32b9c commit ccb772c
Showing 11 changed files with 146 additions and 13 deletions.
25 changes: 20 additions & 5 deletions include/torch-mlir/Dialect/Torch/IR/TorchTypes.h
@@ -53,6 +53,9 @@ class BaseTensorType : public Type {
/// convenient API.
Type getOptionalDtype() const;

/// Get the raw optional sparse tensor encoding.
Attribute getOptionalSparsity() const;

/// Return true if this type has a list of sizes.
bool hasSizes() const { return getOptionalSizes().has_value(); }

@@ -93,6 +96,10 @@ class BaseTensorType : public Type {
Type getWithSizesAndDtype(std::optional<ArrayRef<int64_t>> optionalSizes,
Type optionalDtype) const;

Type getWithSizesAndDtypeAndSparsity(
std::optional<ArrayRef<int64_t>> optionalSizes, Type optionalDtype,
Attribute optionalSparsity) const;

/// Return a type with the same shape and dtype as this one, but with
/// value semantics.
ValueTensorType getWithValueSemantics() const;
@@ -129,23 +136,31 @@ namespace Torch {

inline std::optional<ArrayRef<int64_t>>
BaseTensorType::getOptionalSizes() const {
if (auto tensor = dyn_cast<NonValueTensorType>())
if (auto tensor = mlir::dyn_cast<NonValueTensorType>(*this))
return tensor.getOptionalSizes();
if (auto tensor = dyn_cast<ValueTensorType>())
if (auto tensor = mlir::dyn_cast<ValueTensorType>(*this))
return tensor.getOptionalSizes();
llvm_unreachable("not a BaseTensorType!");
}

inline Type BaseTensorType::getOptionalDtype() const {
if (auto tensor = dyn_cast<NonValueTensorType>())
if (auto tensor = mlir::dyn_cast<NonValueTensorType>(*this))
return tensor.getOptionalDtype();
if (auto tensor = dyn_cast<ValueTensorType>())
if (auto tensor = mlir::dyn_cast<ValueTensorType>(*this))
return tensor.getOptionalDtype();
llvm_unreachable("not a BaseTensorType!");
}

inline Attribute BaseTensorType::getOptionalSparsity() const {
if (auto tensor = mlir::dyn_cast<NonValueTensorType>(*this))
return tensor.getOptionalSparsity();
if (auto tensor = mlir::dyn_cast<ValueTensorType>(*this))
return tensor.getOptionalSparsity();
llvm_unreachable("not a BaseTensorType!");
}

inline bool BaseTensorType::classof(Type type) {
return type.isa<NonValueTensorType, ValueTensorType>();
return mlir::isa<NonValueTensorType, ValueTensorType>(type);
}

} // namespace Torch
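For downstream code, the two new accessors compose naturally: a pattern can read the encoding off any tensor type and rebuild a result type without silently dropping it. A minimal sketch (the free function makeTypeWithSameSparsity is hypothetical; the member functions are the ones declared above):

    #include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"

    using namespace mlir;
    using namespace mlir::torch::Torch;

    // Rebuild a tensor type with new sizes, carrying over both the dtype
    // and the (possibly null) sparse tensor encoding of the original type.
    static Type makeTypeWithSameSparsity(BaseTensorType tensorType,
                                         ArrayRef<int64_t> newSizes) {
      return tensorType.getWithSizesAndDtypeAndSparsity(
          newSizes, tensorType.getOptionalDtype(),
          tensorType.getOptionalSparsity());
    }

This is exactly the shape of the change in DecomposeComplexOps.cpp below, where computeReductionType switches from getWithSizesAndDtype to the sparsity-preserving variant.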
28 changes: 28 additions & 0 deletions include/torch-mlir/Dialect/Torch/Utils/SparsityUtils.h
@@ -0,0 +1,28 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#ifndef TORCHMLIR_DIALECT_TORCH_SPARSITY_UTILS_H
#define TORCHMLIR_DIALECT_TORCH_SPARSITY_UTILS_H

#include "mlir/IR/Attributes.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LogicalResult.h"

namespace mlir {
namespace torch {
namespace Torch {

// Create a new SparseTensorEncodingAttr based on the provided `attr`, but with
// a new dense level inserted at `dim`.
FailureOr<Attribute> getSparsityWithDenseLTAtDim(Attribute attr, Value dim);

} // namespace Torch
} // namespace torch
} // namespace mlir

#endif // TORCHMLIR_DIALECT_TORCH_SPARSITY_UTILS_H
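To pin down the contract: a null attribute passes through unchanged (a dense tensor stays dense), a constant `dim` inserts one dense level at that position, and a non-constant `dim` or a non-identity dimension-to-level map yields failure. A hedged usage sketch showing how a decomposition would combine this utility with the new BaseTensorType accessors (the wrapper unsqueezedResultType is illustrative, not part of the patch):

    #include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
    #include "torch-mlir/Dialect/Torch/Utils/SparsityUtils.h"

    using namespace mlir;
    using namespace mlir::torch::Torch;

    // Compute the result type of unsqueezing `inputType` at `dim` while
    // propagating sparsity. Returns failure when the encoding cannot be
    // propagated, so the caller can bail out of the decomposition.
    static FailureOr<Type> unsqueezedResultType(BaseTensorType inputType,
                                                ArrayRef<int64_t> newSizes,
                                                Value dim) {
      FailureOr<Attribute> enc =
          getSparsityWithDenseLTAtDim(inputType.getOptionalSparsity(), dim);
      if (failed(enc))
        return failure(); // non-constant dim or unsupported encoding
      return inputType.getWithSizesAndDtypeAndSparsity(
          newSizes, inputType.getOptionalDtype(), *enc);
    }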
6 changes: 4 additions & 2 deletions lib/Conversion/TorchToLinalg/DataMovement.cpp
@@ -1880,9 +1880,11 @@ class ConvertAtenSliceTensorOp : public OpConversionPattern<AtenSliceTensorOp> {
op, adaptor, rewriter, resultShape, offsets, strides))) {
return failure();
}

SmallVector<int64_t> dynShape(resultType.getRank(), ShapedType::kDynamic);
auto sliceType = RankedTensorType::get(
dynShape, resultType.getElementType(), resultType.getEncoding());
Value result = rewriter.create<tensor::ExtractSliceOp>(
loc, input, offsets, resultShape, strides);
loc, sliceType, input, offsets, resultShape, strides);

rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, result);
return success();
12 changes: 12 additions & 0 deletions lib/Dialect/Torch/IR/TorchTypes.cpp
@@ -235,6 +235,18 @@ Type BaseTensorType::getWithSizesAndDtype(
llvm_unreachable("not a BaseTensorType!");
}

Type BaseTensorType::getWithSizesAndDtypeAndSparsity(
std::optional<ArrayRef<int64_t>> optionalSizes, Type optionalDtype,
Attribute optionalSparsity) const {
if (mlir::isa<NonValueTensorType>(*this))
return NonValueTensorType::get(getContext(), optionalSizes, optionalDtype,
optionalSparsity);
if (mlir::isa<ValueTensorType>(*this))
return ValueTensorType::get(getContext(), optionalSizes, optionalDtype,
optionalSparsity);
llvm_unreachable("not a BaseTensorType!");
}

ValueTensorType BaseTensorType::getWithValueSemantics() const {
if (auto tensor = dyn_cast<NonValueTensorType>())
return tensor.getWithValueSemantics();
4 changes: 2 additions & 2 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -71,10 +71,10 @@ static Type computeReductionType(PatternRewriter &rewriter, Operation *op,
}
}

Type resultType = tensorType.getWithSizesAndDtype(
Type resultType = tensorType.getWithSizesAndDtypeAndSparsity(
!tensorType.hasSizes() ? std::optional<ArrayRef<int64_t>>()
: llvm::ArrayRef(sizes),
tensorType.getOptionalDtype());
tensorType.getOptionalDtype(), tensorType.getOptionalSparsity());
return resultType;
}

1 change: 1 addition & 0 deletions lib/Dialect/Torch/Utils/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_dialect_library(TorchMLIRTorchUtils
Utils.cpp
SparsityUtils.cpp
TorchUpstream.cpp

ADDITIONAL_HEADER_DIRS
55 changes: 55 additions & 0 deletions lib/Dialect/Torch/Utils/SparsityUtils.cpp
@@ -0,0 +1,55 @@
//===----------------------------------------------------------------------===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "torch-mlir/Dialect/Torch/Utils/SparsityUtils.h"
#include "mlir/Dialect/SparseTensor/IR/Enums.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "llvm/ADT/SmallVector.h"
#include <cstdint>

using namespace mlir;
using namespace mlir::sparse_tensor;
using namespace mlir::torch;
using namespace mlir::torch::Torch;

FailureOr<Attribute> Torch::getSparsityWithDenseLTAtDim(Attribute attr,
Value dim) {
if (!attr)
return Attribute();

auto enc = cast<SparseTensorEncodingAttr>(attr);
int64_t dimInt = 0;
int64_t rank = enc.getDimRank() + 1;
if (matchPattern(dim, m_TorchConstantInt(&dimInt))) {
dimInt = toPositiveDim(dimInt, rank);
if (!isValidDim(dimInt, rank)) {
return failure();
}
if (!enc.isIdentity()) {
// TODO: support block sparsity and permutation (CSC).
return failure();
}
auto denseLT = *LevelType::buildLvlType(LevelFormat::Dense, true, true);
SmallVector<LevelType> lvlTps = llvm::to_vector(enc.getLvlTypes());
lvlTps.insert(lvlTps.begin() + dimInt, denseLT);
auto dim2Lvl = AffineMap::getMultiDimIdentityMap(rank, attr.getContext());
return SparseTensorEncodingAttr::get(
enc.getContext(), lvlTps, dim2Lvl, AffineMap(), enc.getPosWidth(),
enc.getCrdWidth(), enc.getExplicitVal(), enc.getImplicitVal());
}
// We do not know how to handle a dynamic dimension.
return failure();
}
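Concretely, for a CSR-style input the insertion above behaves as follows (a worked illustration; the printed attribute forms assume current sparse_tensor dialect syntax):

    // Input encoding (CSR):
    //   #sparse_tensor.encoding<{ map = (d0, d1) ->
    //       (d0 : dense, d1 : compressed) }>
    //
    // With `dim` a torch.constant.int 0, getSparsityWithDenseLTAtDim
    // returns:
    //   #sparse_tensor.encoding<{ map = (d0, d1, d2) ->
    //       (d0 : dense, d1 : dense, d2 : compressed) }>
    //
    // The level types grow from [dense, compressed] to
    // [dense, dense, compressed], the dimension-to-level map is rebuilt
    // as the rank-3 identity, and the pos/crd bit widths as well as the
    // explicit/implicit values are carried over unchanged.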
10 changes: 8 additions & 2 deletions lib/Dialect/Torch/Utils/Utils.cpp
@@ -11,6 +11,7 @@
#include "mlir/IR/BuiltinDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
#include "torch-mlir/Dialect/Torch/Utils/SparsityUtils.h"

using namespace mlir;
using namespace mlir::torch;
@@ -318,6 +319,11 @@ FailureOr<Value> Torch::unsqueezeTensor(PatternRewriter &rewriter,
if (!inputType.hasSizes()) {
return rewriter.notifyMatchFailure(op, "input tensor must have size");
}
FailureOr<Attribute> enc =
getSparsityWithDenseLTAtDim(inputType.getOptionalSparsity(), dim);
if (failed(enc)) {
return failure();
}

SmallVector<int64_t> unsqueezedShape;
ArrayRef<int64_t> inputShape = inputType.getSizes();
@@ -334,8 +340,8 @@
} else {
unsqueezedShape.resize(unsqueezedRank, kUnknownSize);
}
Type unsqueezedType = inputType.getWithSizesAndDtype(
unsqueezedShape, inputType.getOptionalDtype());
Type unsqueezedType = inputType.getWithSizesAndDtypeAndSparsity(
unsqueezedShape, inputType.getOptionalDtype(), enc.value());
Value unsqueezed = rewriter.create<AtenUnsqueezeOp>(
op->getLoc(), unsqueezedType, input, dim);
return unsqueezed;
6 changes: 4 additions & 2 deletions projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py
@@ -138,8 +138,6 @@ def invoke(*args):
"builtin.module("
+ ",".join(
[
"func.func(refback-generalize-tensor-pad)",
"func.func(refback-generalize-tensor-concat)",
# Apply some optimizations. It would be great if MLIR had more useful
# optimizations that worked out of the box here.
# Note: When measured, this doesn't seem to actually help that much
@@ -157,6 +155,10 @@
"sparse-storage-specifier-to-llvm",
# Buffer deallocation pass does not know how to handle realloc.
"func.func(expand-realloc)",
# Generalize pad and concat after the sparse compiler, as they are handled
# differently when the operations involve sparse operands.
"func.func(refback-generalize-tensor-pad)",
"func.func(refback-generalize-tensor-concat)",
# Bufferize.
"func.func(scf-bufferize)",
"func.func(tm-tensor-bufferize)",
10 changes: 10 additions & 0 deletions test/python/fx_importer/sparse_test.py
@@ -134,6 +134,16 @@ def sparse_export(
# elif opname == "_to_dense":
# # hack (assumes we never really want the to_dense for now)
# node.meta["sparsity"] = node.args[0].meta.get("sparsity", None)
elif opname == "select" and node.args[0].meta.get("sparsity", None):
dim = len(node.meta.get("val").shape)
node.meta["sparsity"] = SparsityMeta(
torch.sparse_coo, 0, dim, 0, None, torch.int64, torch.int64
)
elif opname == "stack" and node.args[0][0].meta.get("sparsity", None):
dim = len(node.meta.get("val").shape)
node.meta["sparsity"] = SparsityMeta(
torch.sparse_coo, 0, dim - 1, 1, None, torch.int64, torch.int64
)
return prog


2 changes: 2 additions & 0 deletions utils/bazel/torch-mlir-overlay/BUILD.bazel
@@ -90,13 +90,15 @@ gentbl_cc_library(
cc_library(
name = "TorchMLIRTorchDialectUtils",
srcs = [
"lib/Dialect/Torch/Utils/SparsityUtils.cpp",
"lib/Dialect/Torch/Utils/TorchUpstream.cpp",
"lib/Dialect/Torch/Utils/Utils.cpp",
],
hdrs = [
"include/torch-mlir/Dialect/Torch/IR/TorchOps.h",
"include/torch-mlir/Dialect/Torch/IR/TorchTraits.h",
"include/torch-mlir/Dialect/Torch/IR/TorchTypes.h",
"include/torch-mlir/Dialect/Torch/Utils/SparsityUtils.h",
"include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h",
"include/torch-mlir/Dialect/Torch/Utils/Utils.h",
],
