Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TensorIR] Cross-Thread Reduction #9360

Merged
merged 6 commits into from
Nov 14, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 5 additions & 8 deletions src/tir/schedule/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#ifndef TVM_TIR_SCHEDULE_ANALYSIS_H_
#define TVM_TIR_SCHEDULE_ANALYSIS_H_

#include <tvm/arith/analyzer.h>
#include <tvm/tir/schedule/state.h>

#include <tuple>
Expand Down Expand Up @@ -331,16 +332,14 @@ Buffer GetNthAccessBuffer(const ScheduleState& self, const Block& block, int n,

/*!
* \brief Convert the `init` and `body` of the input block to BufferStores
* \tparam in_schedule Whether the function is called by schedule primitives
* \param self The schedule state. When it is not defined, failures are reported through LOG(FATAL) instead of a ScheduleError
* \param block The block to be analyzed
* \return The BufferStores of the `init` and `body` of the input block
* \throw ScheduleError If the `init` or `body` is not a BufferStore, or they don't write to the same
* buffer with the same indices
*/
template <bool in_schedule>
std::pair<BufferStore, BufferStore> GetBufferStoresFromReductionBlock(const ScheduleState& self,
const Block& block);
std::pair<BufferStore, BufferStore> GetBufferStoresFromReductionBlock(
const Optional<ScheduleState>& self, const Block& block);

/*!
* \brief Check whether the input array of IterVars only contains data-parallel and reduction block
Expand All @@ -363,16 +362,14 @@ bool ReductionIterNotIndexOutputBuffer(const Block& block);
/*!
* \brief Given a reduction identity and a reduction combiner, detect the corresponding commutative
* reducer, and extract the combiner lhs and combiner rhs
* \tparam in_schedule Whether the function is called by schedule primitives
* \param self The schedule state. When it is not defined, a match failure is reported through LOG(FATAL) instead of a ScheduleError
* \param identity The reduction identity to be analyzed
* \param combiner The reduction combiner to be analyzed
* \return The corresponding CommReducer, the combiner lhs and the combiner rhs
* \throw ScheduleError If no corresponding commutative reducer can be matched
*/
template <bool in_schedule>
std::tuple<CommReducer, PrimExpr, PrimExpr> GetReducerAndCombinerLhsRhs(
const ScheduleState& self, const PrimExpr& identity, const BufferStore& combiner);
const Optional<ScheduleState>& self, const PrimExpr& identity, const BufferStore& combiner);

/******** Commutative Reducer ********/

Expand All @@ -381,7 +378,7 @@ std::tuple<CommReducer, PrimExpr, PrimExpr> GetReducerAndCombinerLhsRhs(
* \return The list of the registered reducer-getter functions
* \sa ReducerRegistry
*/
std::vector<TypedPackedFunc<CommReducer(DataType)>> GetReducerGetters();
std::vector<runtime::TypedPackedFunc<CommReducer(DataType)>> GetReducerGetters();

/*!
* \brief Given the input identity and the combiner BufferStore of a reduction, extract the
Expand Down
68 changes: 30 additions & 38 deletions src/tir/schedule/analysis/analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -523,7 +523,6 @@ bool GetVarsTouchedByBlockIters(const BlockRealize& block_realize,
if (set == nullptr) {
Hzfengsy marked this conversation as resolved.
Show resolved Hide resolved
continue;
}

Array<Var> vars_in_binding = UndefinedVars(iter_value);
for (const Var& var : vars_in_binding) {
set->insert(var.get());
Expand Down Expand Up @@ -1166,39 +1165,38 @@ class InitBodyNotSameBufferAccessError : public ScheduleError {
Block block_;
};

template <bool in_schedule>
std::pair<BufferStore, BufferStore> GetBufferStoresFromReductionBlock(const ScheduleState& self,
const Block& block) {
const char* error_str1 =
std::pair<BufferStore, BufferStore> GetBufferStoresFromReductionBlock(
const Optional<ScheduleState>& self, const Block& block) {
static constexpr const char* error_str1 =
"ValueError: The `init` and `body` of the reduction block are required to be both "
"BufferStore so that rfactor or cross-thread reduction can be applied. However, a reduction "
"block that doesn't meet this requirement is ";
const char* error_str2 =
static constexpr const char* error_str2 =
"ValueError: The `init` and `body` of the reduction block are required to have the same "
"buffer access pattern so that rfactor or cross-thread reduction can be applied. However, a "
"reduction block that doesn't meet this requirement is ";

const auto* init = block->init.as<BufferStoreNode>();
const auto* body = block->body.as<BufferStoreNode>();
if (!(init && body)) {
if (in_schedule) {
throw InitBodyNotBufferStoreError(self->mod, block, init != nullptr, body != nullptr);
if (self.defined()) {
throw InitBodyNotBufferStoreError(self.value()->mod, block, init != nullptr, body != nullptr);
} else {
LOG(FATAL) << error_str1 << block;
}
}
if (!init->buffer.same_as(body->buffer)) {
if (in_schedule) {
throw InitBodyNotSameBufferAccessError(self->mod, block);
if (self.defined()) {
throw InitBodyNotSameBufferAccessError(self.value()->mod, block);
} else {
LOG(FATAL) << error_str2 << block;
}
}
int ndim = static_cast<int>(init->buffer->shape.size());
for (int i = 0; i < ndim; ++i) {
if (!ExprDeepEqual()(init->indices[i], body->indices[i])) {
if (in_schedule) {
throw InitBodyNotSameBufferAccessError(self->mod, block);
if (self.defined()) {
throw InitBodyNotSameBufferAccessError(self.value()->mod, block);
} else {
LOG(FATAL) << error_str2 << block;
}
Expand Down Expand Up @@ -1231,26 +1229,30 @@ bool ReductionIterNotIndexOutputBuffer(const Block& block) {
for (const BufferRegion& write_region : block->writes) {
buffer_written.insert(write_region->buffer.get());
}
auto f_uses_reduction_block_var = [&](const PrimExpr& expr) -> bool {
return UsesVar(expr, [&](const VarNode* var) { //
return reduction_block_iters.count(var);
});
};
bool affected = false;
PreOrderVisit(block->body, [&](const ObjectRef& obj) {
if (affected) {
return false;
}
if (const auto* store = obj.as<BufferStoreNode>()) {
ICHECK(buffer_written.count(store->buffer.get()))
<< "ValueError: The buffer \"" << store->buffer
<< "\" is written in the block but is not in the block's signature";
for (const PrimExpr& index : store->indices) {
if (UsesVar(index, [&reduction_block_iters](const VarNode* var) {
return reduction_block_iters.count(var);
})) {
affected = true;
return false;
}
const auto* store = obj.as<BufferStoreNode>();
if (!store) {
return true;
}
ICHECK(buffer_written.count(store->buffer.get()))
<< "ValueError: The buffer \"" << store->buffer
<< "\" is written in the block but is not in the block's signature";
for (const PrimExpr& index : store->indices) {
if (f_uses_reduction_block_var(index)) {
affected = true;
return false;
}
return false;
}
return true;
return false;
});
return !affected;
}
Expand Down Expand Up @@ -1281,15 +1283,14 @@ class NoMatchedReducerError : public ScheduleError {
BufferStore combiner_;
};

template <bool in_schedule>
std::tuple<CommReducer, PrimExpr, PrimExpr> GetReducerAndCombinerLhsRhs(
const ScheduleState& self, const PrimExpr& identity, const BufferStore& combiner) {
const Optional<ScheduleState>& self, const PrimExpr& identity, const BufferStore& combiner) {
CommReducer reducer{nullptr};
PrimExpr combiner_lhs{nullptr}, combiner_rhs{nullptr};
bool matched = FromIdentityCombiner(identity, combiner, &reducer, &combiner_lhs, &combiner_rhs);
if (!matched) {
if (in_schedule) {
throw NoMatchedReducerError(self->mod, identity, combiner);
if (self.defined()) {
throw NoMatchedReducerError(self.value()->mod, identity, combiner);
} else {
LOG(FATAL) << "ValueError: No matched reducer for the identity and the combiner of the "
"reduction block. So rfactor and cross-thread reduction cannot be applied.";
Expand All @@ -1298,15 +1299,6 @@ std::tuple<CommReducer, PrimExpr, PrimExpr> GetReducerAndCombinerLhsRhs(
return std::make_tuple(std::move(reducer), std::move(combiner_lhs), std::move(combiner_rhs));
}

template std::pair<BufferStore, BufferStore> GetBufferStoresFromReductionBlock<true>(
const ScheduleState& self, const Block& block);
template std::pair<BufferStore, BufferStore> GetBufferStoresFromReductionBlock<false>(
const ScheduleState& self, const Block& block);
template std::tuple<CommReducer, PrimExpr, PrimExpr> GetReducerAndCombinerLhsRhs<true>(
const ScheduleState& self, const PrimExpr& identity, const BufferStore& combiner);
template std::tuple<CommReducer, PrimExpr, PrimExpr> GetReducerAndCombinerLhsRhs<false>(
const ScheduleState& self, const PrimExpr& identity, const BufferStore& combiner);

/******** Commutative Reducer ********/

bool MatchReducer(const CommReducer& reducer, const PrimExpr& identity, const PrimExpr& combiner,
Expand Down
4 changes: 2 additions & 2 deletions src/tir/schedule/primitive/reduction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1041,9 +1041,9 @@ StmtSRef RFactor(ScheduleState self, const StmtSRef& rf_loop_sref, int factor_ax
BufferStore update;
CommReducer reducer;
PrimExpr combiner_lhs, combiner_rhs;
std::tie(init, update) = GetBufferStoresFromReductionBlock<true>(self, block);
std::tie(init, update) = GetBufferStoresFromReductionBlock(self, block);
std::tie(reducer, combiner_lhs, combiner_rhs) =
GetReducerAndCombinerLhsRhs<true>(self, init->value, update);
GetReducerAndCombinerLhsRhs(self, init->value, update);

// Step 6. Check whether `factor_axis` is in a correct range, and convert it to non-negative if it
// is negative.
Expand Down
Loading