forked from llvm/torch-mlir
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[sparse] propagate sparsity properly when decompose torch operations. (…
- Loading branch information
Peiming Liu
authored
May 15, 2024
1 parent
ba32b9c
commit ccb772c
Showing
11 changed files
with
146 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
//===----------------------------------------------------------------------===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// Also available under a BSD-style license. See LICENSE. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
#ifndef TORCHMLIR_DIALECT_TORCH_SPARSITY_UTILS_H | ||
#define TORCHMLIR_DIALECT_TORCH_SPARSITY_UTILS_H | ||
|
||
#include "mlir/IR/Attributes.h" | ||
#include "mlir/IR/Value.h" | ||
#include "mlir/Support/LogicalResult.h" | ||
|
||
namespace mlir { | ||
namespace torch { | ||
namespace Torch { | ||
|
||
// Create a new SparseTensorEncodingAttr based on the provided `attr`, but with | ||
// a new dense level inserted at `dim`. | ||
FailureOr<Attribute> getSparsityWithDenseLTAtDim(Attribute attr, Value dim); | ||
|
||
} // namespace Torch | ||
} // namespace torch | ||
} // namespace mlir | ||
|
||
#endif // TORCHMLIR_DIALECT_TORCH_SPARSITY_UTILS_H |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
//===----------------------------------------------------------------------===// | ||
// | ||
// This file is licensed under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// Also available under a BSD-style license. See LICENSE. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "torch-mlir/Dialect/Torch/Utils/SparsityUtils.h" | ||
#include "mlir/Dialect/SparseTensor/IR/Enums.h" | ||
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" | ||
#include "mlir/IR/Attributes.h" | ||
#include "mlir/IR/BuiltinDialect.h" | ||
#include "mlir/Support/LLVM.h" | ||
#include "mlir/Support/LogicalResult.h" | ||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" | ||
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" | ||
#include "torch-mlir/Dialect/Torch/Utils/Utils.h" | ||
#include "llvm/ADT/SmallVector.h" | ||
#include <cstdint> | ||
|
||
using namespace mlir; | ||
using namespace mlir::sparse_tensor; | ||
using namespace mlir::torch; | ||
using namespace mlir::torch::Torch; | ||
|
||
FailureOr<Attribute> Torch::getSparsityWithDenseLTAtDim(Attribute attr, | ||
Value dim) { | ||
if (!attr) | ||
return Attribute(); | ||
|
||
auto enc = cast<SparseTensorEncodingAttr>(attr); | ||
int64_t dimInt = 0; | ||
int64_t rank = enc.getDimRank() + 1; | ||
if (matchPattern(dim, m_TorchConstantInt(&dimInt))) { | ||
dimInt = toPositiveDim(dimInt, rank); | ||
if (!isValidDim(dimInt, rank)) { | ||
return failure(); | ||
} | ||
if (!enc.isIdentity()) { | ||
// TODO: support block sparsity and permutation (CSC). | ||
return failure(); | ||
} | ||
auto denseLT = *LevelType::buildLvlType(LevelFormat::Dense, true, true); | ||
SmallVector<LevelType> lvlTps = llvm::to_vector(enc.getLvlTypes()); | ||
lvlTps.insert(lvlTps.begin() + dimInt, denseLT); | ||
auto dim2Lvl = AffineMap::getMultiDimIdentityMap(rank, attr.getContext()); | ||
return SparseTensorEncodingAttr::get( | ||
enc.getContext(), lvlTps, dim2Lvl, AffineMap(), enc.getPosWidth(), | ||
enc.getCrdWidth(), enc.getExplicitVal(), enc.getImplicitVal()); | ||
} | ||
// Do not know how to handle dynamic dimension. | ||
return failure(); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters