Skip to content

Commit

Permalink
Code revision on analysis
Browse files Browse the repository at this point in the history
  • Loading branch information
junrushao committed Nov 12, 2021
1 parent a560a9b commit 4ac6c65
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 53 deletions.
13 changes: 5 additions & 8 deletions src/tir/schedule/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#ifndef TVM_TIR_SCHEDULE_ANALYSIS_H_
#define TVM_TIR_SCHEDULE_ANALYSIS_H_

#include <tvm/arith/analyzer.h>
#include <tvm/tir/schedule/state.h>

#include <tuple>
Expand Down Expand Up @@ -331,16 +332,14 @@ Buffer GetNthAccessBuffer(const ScheduleState& self, const Block& block, int n,

/*!
* \brief Convert the `init` and `body` of the input block to BufferStores
* \tparam in_schedule Whether the function is called by schedule primitives
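* \param self The schedule state. If it is not defined (e.g. `NullOpt` is passed), failures are
* reported through `LOG(FATAL)` instead of a ScheduleError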
* \param block The block to be analyzed
* \return The BufferStores of the `init` and `body` of the input block
* \throw ScheduleError If the `init` or `body` is not BufferStore, or they don't write to the same
* buffer
*/
template <bool in_schedule>
std::pair<BufferStore, BufferStore> GetBufferStoresFromReductionBlock(const ScheduleState& self,
const Block& block);
std::pair<BufferStore, BufferStore> GetBufferStoresFromReductionBlock(
const Optional<ScheduleState>& self, const Block& block);

/*!
* \brief Check whether the input array of IterVars only contains data-parallel and reduction block
Expand All @@ -363,16 +362,14 @@ bool ReductionIterNotIndexOutputBuffer(const Block& block);
/*!
* \brief Given a reduction identity and a reduction combiner, detect the corresponding commutative
* reducer, and extract the combiner lhs and combiner rhs
* \tparam in_schedule Whether the function is called by schedule primitives
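* \param self The schedule state. If it is not defined (e.g. `NullOpt` is passed), a failure to
* match a reducer is reported through `LOG(FATAL)` instead of a ScheduleError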
* \param identity The reduction identity to be analyzed
* \param combiner The reduction combiner to be analyzed
* \return The corresponding CommReducer, the combiner lhs and the combiner rhs
* \throw ScheduleError If no corresponding commutative reducer can be matched
*/
template <bool in_schedule>
std::tuple<CommReducer, PrimExpr, PrimExpr> GetReducerAndCombinerLhsRhs(
const ScheduleState& self, const PrimExpr& identity, const BufferStore& combiner);
const Optional<ScheduleState>& self, const PrimExpr& identity, const BufferStore& combiner);

/******** Commutative Reducer ********/

Expand All @@ -381,7 +378,7 @@ std::tuple<CommReducer, PrimExpr, PrimExpr> GetReducerAndCombinerLhsRhs(
* \return The list of the registered reducer-getter functions
* \sa ReducerRegistry
*/
std::vector<TypedPackedFunc<CommReducer(DataType)>> GetReducerGetters();
std::vector<runtime::TypedPackedFunc<CommReducer(DataType)>> GetReducerGetters();

/*!
* \brief Given the input identity and the combiner BufferStore of a reduction, extract the
Expand Down
70 changes: 30 additions & 40 deletions src/tir/schedule/analysis/analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -519,11 +519,8 @@ bool GetVarsTouchedByBlockIters(const BlockRealize& block_realize,
set = reduce_vars;
} else {
has_block_vars_of_other_types = true;
}
if (set == nullptr) {
continue;
}

Array<Var> vars_in_binding = UndefinedVars(iter_value);
for (const Var& var : vars_in_binding) {
set->insert(var.get());
Expand Down Expand Up @@ -1166,39 +1163,38 @@ class InitBodyNotSameBufferAccessError : public ScheduleError {
Block block_;
};

template <bool in_schedule>
std::pair<BufferStore, BufferStore> GetBufferStoresFromReductionBlock(const ScheduleState& self,
const Block& block) {
const char* error_str1 =
std::pair<BufferStore, BufferStore> GetBufferStoresFromReductionBlock(
const Optional<ScheduleState>& self, const Block& block) {
static constexpr const char* error_str1 =
"ValueError: The `init` and `body` of the reduction block are required to be both "
"BufferStore so that rfactor or cross-thread reduction can be applied. However, a reduction "
"block that doesn't meet this requirement is ";
const char* error_str2 =
static constexpr const char* error_str2 =
"ValueError: The `init` and `body` of the reduction block are required to have the same "
"buffer access pattern so that rfactor or cross-thread reduction can be applied. However, a "
"reduction block that doesn't meet this requirement is ";

const auto* init = block->init.as<BufferStoreNode>();
const auto* body = block->body.as<BufferStoreNode>();
if (!(init && body)) {
if (in_schedule) {
throw InitBodyNotBufferStoreError(self->mod, block, init != nullptr, body != nullptr);
if (self.defined()) {
throw InitBodyNotBufferStoreError(self.value()->mod, block, init != nullptr, body != nullptr);
} else {
LOG(FATAL) << error_str1 << block;
}
}
if (!init->buffer.same_as(body->buffer)) {
if (in_schedule) {
throw InitBodyNotSameBufferAccessError(self->mod, block);
if (self.defined()) {
throw InitBodyNotSameBufferAccessError(self.value()->mod, block);
} else {
LOG(FATAL) << error_str2 << block;
}
}
int ndim = static_cast<int>(init->buffer->shape.size());
for (int i = 0; i < ndim; ++i) {
if (!ExprDeepEqual()(init->indices[i], body->indices[i])) {
if (in_schedule) {
throw InitBodyNotSameBufferAccessError(self->mod, block);
if (self.defined()) {
throw InitBodyNotSameBufferAccessError(self.value()->mod, block);
} else {
LOG(FATAL) << error_str2 << block;
}
Expand Down Expand Up @@ -1231,26 +1227,30 @@ bool ReductionIterNotIndexOutputBuffer(const Block& block) {
for (const BufferRegion& write_region : block->writes) {
buffer_written.insert(write_region->buffer.get());
}
auto f_uses_reduction_block_var = [&](const PrimExpr& expr) -> bool {
return UsesVar(expr, [&](const VarNode* var) { //
return reduction_block_iters.count(var);
});
};
bool affected = false;
PreOrderVisit(block->body, [&](const ObjectRef& obj) {
if (affected) {
return false;
}
if (const auto* store = obj.as<BufferStoreNode>()) {
ICHECK(buffer_written.count(store->buffer.get()))
<< "ValueError: The buffer \"" << store->buffer
<< "\" is written in the block but is not in the block's signature";
for (const PrimExpr& index : store->indices) {
if (UsesVar(index, [&reduction_block_iters](const VarNode* var) {
return reduction_block_iters.count(var);
})) {
affected = true;
return false;
}
const auto* store = obj.as<BufferStoreNode>();
if (!store) {
return true;
}
ICHECK(buffer_written.count(store->buffer.get()))
<< "ValueError: The buffer \"" << store->buffer
<< "\" is written in the block but is not in the block's signature";
for (const PrimExpr& index : store->indices) {
if (f_uses_reduction_block_var(index)) {
affected = true;
return false;
}
return false;
}
return true;
return false;
});
return !affected;
}
Expand Down Expand Up @@ -1281,15 +1281,14 @@ class NoMatchedReducerError : public ScheduleError {
BufferStore combiner_;
};

template <bool in_schedule>
std::tuple<CommReducer, PrimExpr, PrimExpr> GetReducerAndCombinerLhsRhs(
const ScheduleState& self, const PrimExpr& identity, const BufferStore& combiner) {
const Optional<ScheduleState>& self, const PrimExpr& identity, const BufferStore& combiner) {
CommReducer reducer{nullptr};
PrimExpr combiner_lhs{nullptr}, combiner_rhs{nullptr};
bool matched = FromIdentityCombiner(identity, combiner, &reducer, &combiner_lhs, &combiner_rhs);
if (!matched) {
if (in_schedule) {
throw NoMatchedReducerError(self->mod, identity, combiner);
if (self.defined()) {
throw NoMatchedReducerError(self.value()->mod, identity, combiner);
} else {
LOG(FATAL) << "ValueError: No matched reducer for the identity and the combiner of the "
"reduction block. So rfactor and cross-thread reduction cannot be applied.";
Expand All @@ -1298,15 +1297,6 @@ std::tuple<CommReducer, PrimExpr, PrimExpr> GetReducerAndCombinerLhsRhs(
return std::make_tuple(std::move(reducer), std::move(combiner_lhs), std::move(combiner_rhs));
}

template std::pair<BufferStore, BufferStore> GetBufferStoresFromReductionBlock<true>(
const ScheduleState& self, const Block& block);
template std::pair<BufferStore, BufferStore> GetBufferStoresFromReductionBlock<false>(
const ScheduleState& self, const Block& block);
template std::tuple<CommReducer, PrimExpr, PrimExpr> GetReducerAndCombinerLhsRhs<true>(
const ScheduleState& self, const PrimExpr& identity, const BufferStore& combiner);
template std::tuple<CommReducer, PrimExpr, PrimExpr> GetReducerAndCombinerLhsRhs<false>(
const ScheduleState& self, const PrimExpr& identity, const BufferStore& combiner);

/******** Commutative Reducer ********/

bool MatchReducer(const CommReducer& reducer, const PrimExpr& identity, const PrimExpr& combiner,
Expand Down
4 changes: 2 additions & 2 deletions src/tir/schedule/primitive/reduction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1041,9 +1041,9 @@ StmtSRef RFactor(ScheduleState self, const StmtSRef& rf_loop_sref, int factor_ax
BufferStore update;
CommReducer reducer;
PrimExpr combiner_lhs, combiner_rhs;
std::tie(init, update) = GetBufferStoresFromReductionBlock<true>(self, block);
std::tie(init, update) = GetBufferStoresFromReductionBlock(self, block);
std::tie(reducer, combiner_lhs, combiner_rhs) =
GetReducerAndCombinerLhsRhs<true>(self, init->value, update);
GetReducerAndCombinerLhsRhs(self, init->value, update);

// Step 6. Check whether `factor_axis` is in a correct range, and convert it to non-negative if it
// is negative.
Expand Down
5 changes: 2 additions & 3 deletions src/tir/transforms/lower_cross_thread_reduction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -283,14 +283,13 @@ class CrossThreadReductionTransformer : public StmtMutator {
// both be BufferStores with the same buffer and indices.
BufferStore init;
BufferStore update;
std::tie(init, update) =
GetBufferStoresFromReductionBlock<false>(ScheduleState{nullptr}, GetRef<Block>(block));
std::tie(init, update) = GetBufferStoresFromReductionBlock(NullOpt, GetRef<Block>(block));

// Condition 5. Extract the commutative reducer, combiner lhs and combiner rhs from the
// reduction identity and the reduction combiner.
PrimExpr combiner_lhs;
std::tie(reducer_, combiner_lhs, combiner_rhs_) =
GetReducerAndCombinerLhsRhs<false>(ScheduleState{nullptr}, init->value, update);
GetReducerAndCombinerLhsRhs(NullOpt, init->value, update);

// Condition 6. The block should be the last block under the first reduction-related loop.
bool visit = false;
Expand Down

0 comments on commit 4ac6c65

Please sign in to comment.