diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h index c437293e49c06..42e0e00995fe0 100644 --- a/src/tir/schedule/analysis.h +++ b/src/tir/schedule/analysis.h @@ -19,6 +19,7 @@ #ifndef TVM_TIR_SCHEDULE_ANALYSIS_H_ #define TVM_TIR_SCHEDULE_ANALYSIS_H_ +#include #include #include @@ -331,16 +332,14 @@ Buffer GetNthAccessBuffer(const ScheduleState& self, const Block& block, int n, /*! * \brief Convert the `init` and `body` of the input block to BufferStores - * \tparam in_schedule Whether the function is called by schedule primitives * \param self The schedule state * \param block The block to be analyzed * \return The BufferStores of the `init` and `body` of the input block * \throw ScheduleError If the `init` or `body` is not BufferStore, or they don't write to the same * buffer */ -template -std::pair GetBufferStoresFromReductionBlock(const ScheduleState& self, - const Block& block); +std::pair GetBufferStoresFromReductionBlock( + const Optional& self, const Block& block); /*! * \brief Check whether the input array of IterVars only contains data-parallel and reduction block @@ -363,16 +362,14 @@ bool ReductionIterNotIndexOutputBuffer(const Block& block); /*! * \brief Given a reduction identity and a reduction combiner, detect the corresponding commutative * reducer, and extract the combiner lhs and combiner rhs - * \tparam in_schedule Whether the function is called by schedule primitives * \param self The schedule state * \param identity The reduction identity to be analyzed * \param combiner The reduction combiner to be analyzed * \return The corresponding CommReducer, the combiner lhs and the combiner rhs * \throw ScheduleError If no corresponding commutative reducer can be matched */ -template std::tuple GetReducerAndCombinerLhsRhs( - const ScheduleState& self, const PrimExpr& identity, const BufferStore& combiner); + const Optional& self, const PrimExpr& identity, const BufferStore& combiner); /******** Commutative Reducer ********/ @@ -381,7 +378,7 @@ std::tuple GetReducerAndCombinerLhsRhs( * \return The list of the registered reducer-getter functions * \sa ReducerRegistry */ -std::vector> GetReducerGetters(); +std::vector> GetReducerGetters(); /*! * \brief Given the input identity and the combiner BufferStore of a reduction, extract the diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index 672fc0f602f44..f869b5431cbc6 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -519,11 +519,8 @@ bool GetVarsTouchedByBlockIters(const BlockRealize& block_realize, set = reduce_vars; } else { has_block_vars_of_other_types = true; - } - if (set == nullptr) { continue; } - Array vars_in_binding = UndefinedVars(iter_value); for (const Var& var : vars_in_binding) { set->insert(var.get()); @@ -1166,14 +1163,13 @@ class InitBodyNotSameBufferAccessError : public ScheduleError { Block block_; }; -template -std::pair GetBufferStoresFromReductionBlock(const ScheduleState& self, - const Block& block) { - const char* error_str1 = +std::pair GetBufferStoresFromReductionBlock( + const Optional& self, const Block& block) { + static constexpr const char* error_str1 = "ValueError: The `init` and `body` of the reduction block are required to be both " "BufferStore so that rfactor or cross-thread reduction can be applied. However, a reduction " "block that doesn't meet this requirement is "; - const char* error_str2 = + static constexpr const char* error_str2 = "ValueError: The `init` and `body` of the reduction block are required to have the same " "buffer access pattern so that rfactor or cross-thread reduction can be applied. However, a " "reduction block that doesn't meet this requirement is "; @@ -1181,15 +1177,15 @@ std::pair GetBufferStoresFromReductionBlock(const Sche const auto* init = block->init.as(); const auto* body = block->body.as(); if (!(init && body)) { - if (in_schedule) { - throw InitBodyNotBufferStoreError(self->mod, block, init != nullptr, body != nullptr); + if (self.defined()) { + throw InitBodyNotBufferStoreError(self.value()->mod, block, init != nullptr, body != nullptr); } else { LOG(FATAL) << error_str1 << block; } } if (!init->buffer.same_as(body->buffer)) { - if (in_schedule) { - throw InitBodyNotSameBufferAccessError(self->mod, block); + if (self.defined()) { + throw InitBodyNotSameBufferAccessError(self.value()->mod, block); } else { LOG(FATAL) << error_str2 << block; } @@ -1197,8 +1193,8 @@ std::pair GetBufferStoresFromReductionBlock(const Sche int ndim = static_cast(init->buffer->shape.size()); for (int i = 0; i < ndim; ++i) { if (!ExprDeepEqual()(init->indices[i], body->indices[i])) { - if (in_schedule) { - throw InitBodyNotSameBufferAccessError(self->mod, block); + if (self.defined()) { + throw InitBodyNotSameBufferAccessError(self.value()->mod, block); } else { LOG(FATAL) << error_str2 << block; } @@ -1231,26 +1227,30 @@ bool ReductionIterNotIndexOutputBuffer(const Block& block) { for (const BufferRegion& write_region : block->writes) { buffer_written.insert(write_region->buffer.get()); } + auto f_uses_reduction_block_var = [&](const PrimExpr& expr) -> bool { + return UsesVar(expr, [&](const VarNode* var) { // + return reduction_block_iters.count(var); + }); + }; bool affected = false; PreOrderVisit(block->body, [&](const ObjectRef& obj) { if (affected) { return false; } - if (const auto* store = obj.as()) { - ICHECK(buffer_written.count(store->buffer.get())) - << "ValueError: The buffer \"" << store->buffer - << "\" is written in the block but is not in the block's signature"; - for (const PrimExpr& index : store->indices) { - if (UsesVar(index, [&reduction_block_iters](const VarNode* var) { - return reduction_block_iters.count(var); - })) { - affected = true; - return false; - } + const auto* store = obj.as(); + if (!store) { + return true; + } + ICHECK(buffer_written.count(store->buffer.get())) + << "ValueError: The buffer \"" << store->buffer + << "\" is written in the block but is not in the block's signature"; + for (const PrimExpr& index : store->indices) { + if (f_uses_reduction_block_var(index)) { + affected = true; + return false; } - return false; } - return true; + return false; }); return !affected; } @@ -1281,15 +1281,14 @@ class NoMatchedReducerError : public ScheduleError { BufferStore combiner_; }; -template std::tuple GetReducerAndCombinerLhsRhs( - const ScheduleState& self, const PrimExpr& identity, const BufferStore& combiner) { + const Optional& self, const PrimExpr& identity, const BufferStore& combiner) { CommReducer reducer{nullptr}; PrimExpr combiner_lhs{nullptr}, combiner_rhs{nullptr}; bool matched = FromIdentityCombiner(identity, combiner, &reducer, &combiner_lhs, &combiner_rhs); if (!matched) { - if (in_schedule) { - throw NoMatchedReducerError(self->mod, identity, combiner); + if (self.defined()) { + throw NoMatchedReducerError(self.value()->mod, identity, combiner); } else { LOG(FATAL) << "ValueError: No matched reducer for the identity and the combiner of the " "reduction block. So rfactor and cross-thread reduction cannot be applied."; @@ -1298,15 +1297,6 @@ std::tuple GetReducerAndCombinerLhsRhs( return std::make_tuple(std::move(reducer), std::move(combiner_lhs), std::move(combiner_rhs)); } -template std::pair GetBufferStoresFromReductionBlock( - const ScheduleState& self, const Block& block); -template std::pair GetBufferStoresFromReductionBlock( - const ScheduleState& self, const Block& block); -template std::tuple GetReducerAndCombinerLhsRhs( - const ScheduleState& self, const PrimExpr& identity, const BufferStore& combiner); -template std::tuple GetReducerAndCombinerLhsRhs( - const ScheduleState& self, const PrimExpr& identity, const BufferStore& combiner); - /******** Commutative Reducer ********/ bool MatchReducer(const CommReducer& reducer, const PrimExpr& identity, const PrimExpr& combiner, diff --git a/src/tir/schedule/primitive/reduction.cc b/src/tir/schedule/primitive/reduction.cc index 18c4e5da13153..9c330765ef38b 100644 --- a/src/tir/schedule/primitive/reduction.cc +++ b/src/tir/schedule/primitive/reduction.cc @@ -1041,9 +1041,9 @@ StmtSRef RFactor(ScheduleState self, const StmtSRef& rf_loop_sref, int factor_ax BufferStore update; CommReducer reducer; PrimExpr combiner_lhs, combiner_rhs; - std::tie(init, update) = GetBufferStoresFromReductionBlock(self, block); + std::tie(init, update) = GetBufferStoresFromReductionBlock(self, block); std::tie(reducer, combiner_lhs, combiner_rhs) = - GetReducerAndCombinerLhsRhs(self, init->value, update); + GetReducerAndCombinerLhsRhs(self, init->value, update); // Step 6. Check whether `factor_axis` is in a correct range, and convert it to non-negative if it // is negative. diff --git a/src/tir/transforms/lower_cross_thread_reduction.cc b/src/tir/transforms/lower_cross_thread_reduction.cc index d15a5f16a7833..b327cee4b6ada 100644 --- a/src/tir/transforms/lower_cross_thread_reduction.cc +++ b/src/tir/transforms/lower_cross_thread_reduction.cc @@ -283,14 +283,13 @@ class CrossThreadReductionTransformer : public StmtMutator { // both be BufferStores with the same buffer and indices. BufferStore init; BufferStore update; - std::tie(init, update) = - GetBufferStoresFromReductionBlock(ScheduleState{nullptr}, GetRef(block)); + std::tie(init, update) = GetBufferStoresFromReductionBlock(NullOpt, GetRef(block)); // Condition 5. Extract the commutative reducer, combiner lhs and combiner rhs from the // reduction identity and the reduction combiner. PrimExpr combiner_lhs; std::tie(reducer_, combiner_lhs, combiner_rhs_) = - GetReducerAndCombinerLhsRhs(ScheduleState{nullptr}, init->value, update); + GetReducerAndCombinerLhsRhs(NullOpt, init->value, update); // Condition 6. The block should be the last block under the first reduction-related loop. bool visit = false;