diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index e6b0af9773d9..7922e978c381 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -357,6 +357,13 @@ TVM_DLL Pass PointerValueTypeRewrite(); */ TVM_DLL Pass HoistIfThenElse(); +/*! + * \brief Lower cross-thread reduction from thread + * bindings to intrinsic function calls. + * \return The pass. + */ +TVM_DLL Pass LowerCrossThreadReduction(); + /*! * \brief Lower block init stmt into IfThenElse stmts * \return The pass. diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index 722810e9aa5b..86f798caceba 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -577,6 +577,18 @@ def HoistIfThenElse(variant: Optional[str] = None): return _ffi_api.HoistIfThenElse() # type: ignore +def LowerCrossThreadReduction(): + """Lower cross-thread reduction from thread bindings to + intrinsic function calls. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.LowerCrossThreadReduction() # type: ignore + + def LowerInitBlock(): """Lower block init stmt into IfThenElse statements. diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index ad1f51ba6d71..f49409c2baee 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -234,6 +234,7 @@ Array CreatePassList(bool disable_loop_partition) { pass_list.push_back(tir::transform::InjectPrefetch()); pass_list.push_back(tir::transform::TextureFlatten()); pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers)); + pass_list.push_back(tir::transform::LowerCrossThreadReduction()); pass_list.push_back(tir::transform::LowerInitBlock()); pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation()); pass_list.push_back(tir::transform::ConvertBlocksToOpaque()); diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h index 5a2f46c910b4..42e0e00995fe 100644 --- a/src/tir/schedule/analysis.h +++ b/src/tir/schedule/analysis.h @@ -19,12 +19,17 @@ #ifndef TVM_TIR_SCHEDULE_ANALYSIS_H_ #define TVM_TIR_SCHEDULE_ANALYSIS_H_ +#include #include +#include #include #include +#include #include +#include "../../runtime/thread_storage_scope.h" + namespace tvm { namespace tir { @@ -323,6 +328,49 @@ struct ProducerConsumerSplit { */ Buffer GetNthAccessBuffer(const ScheduleState& self, const Block& block, int n, bool is_write); +/******** Reduction Block Related ********/ + +/*! + * \brief Convert the `init` and `body` of the input block to BufferStores + * \param self The schedule state + * \param block The block to be analyzed + * \return The BufferStores of the `init` and `body` of the input block + * \throw ScheduleError If the `init` or `body` is not BufferStore, or they don't write to the same + * buffer + */ +std::pair GetBufferStoresFromReductionBlock( + const Optional& self, const Block& block); + +/*! + * \brief Check whether the input array of IterVars only contains data-parallel and reduction block + * iters + * \param iters The input array of IterVars to be checked + * \return A boolean indicating whether the input array of IterVars only contains data-parallel and + * reduction block iters + */ +bool ContainsOnlyDataParAndReductionBlockIter(const Array& iters); + +/*! 
+ * \brief Check whether the block's reduction block iters are not used to index the block's output + * buffers + * \param block The block to be checked + * \return A boolean indicating whether the block's reduction block iters are not used to index the + * block's output buffer + */ +bool ReductionIterNotIndexOutputBuffer(const Block& block); + +/*! + * \brief Given a reduction identity and a reduction combiner, detect the corresponding commutative + * reducer, and extract the combiner lhs and combiner rhs + * \param self The schedule state + * \param identity The reduction identity to be analyzed + * \param combiner The reduction combiner to be analyzed + * \return The corresponding CommReducer, the combiner lhs and the combiner rhs + * \throw ScheduleError If no corresponding commutative reducer can be matched + */ +std::tuple GetReducerAndCombinerLhsRhs( + const Optional& self, const PrimExpr& identity, const BufferStore& combiner); + /******** Commutative Reducer ********/ /*! @@ -330,7 +378,7 @@ Buffer GetNthAccessBuffer(const ScheduleState& self, const Block& block, int n, * \return The list of the registered reducer-getter functions * \sa ReducerRegistry */ -std::vector> GetReducerGetters(); +std::vector> GetReducerGetters(); /*! * \brief Given the input identity and the combiner BufferStore of a reduction, extract the diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index e3a535e9b3d4..7e16bc92e4ce 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -153,15 +153,15 @@ Definition of a scope that is a stage pipeline: /*! * \brief Check the dominant property of a block: * the block is the only writer of its output, dominating the reader of its output buffers - * \param self The schedule state + * \param scope The block-scope of the block to be checked * \param block_sref The block whose dominant property is to be checked * \return A boolean indicating if the block is a dominant block */ -bool IsDominantBlock(const BlockScope& self, const StmtSRef& block_sref) { +bool IsDominantBlock(const BlockScope& scope, const StmtSRef& block_sref) { // Check whether the input block is the only writer of its outputs const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); const std::unordered_map, ObjectPtrHash, ObjectPtrEqual>& buffer_writers = - self->buffer_writers; + scope->buffer_writers; for (const BufferRegion& write_region : block->writes) { ICHECK(buffer_writers.count(write_region->buffer)) << "InternalError: buffer \"" << write_region->buffer->name @@ -279,14 +279,8 @@ int CheckReductionBlockErrorCode(const ScheduleState& self, const StmtSRef& bloc } // Cond 3. All block vars are either data parallel block vars or reduction block vars. Meanwhile, // we collect all the reduction block vars. - std::unordered_set reduction_block_vars; - reduction_block_vars.reserve(block->iter_vars.size()); - for (const IterVar& iter_var : block->iter_vars) { - if (iter_var->iter_type != kDataPar && iter_var->iter_type != kCommReduce) { - return 3; - } else if (iter_var->iter_type == kCommReduce) { - reduction_block_vars.insert(iter_var->var.get()); - } + if (!ContainsOnlyDataParAndReductionBlockIter(block->iter_vars)) { + return 3; } // Cond 4. Dominant: the block is the only writer of its output, dominating the reader of its // output buffers. @@ -294,33 +288,7 @@ int CheckReductionBlockErrorCode(const ScheduleState& self, const StmtSRef& bloc return 4; } // Cond 5. 
The reduction block vars are not used to index the output buffers. - std::unordered_set buffer_written; - buffer_written.reserve(block->writes.size()); - for (const BufferRegion& write_region : block->writes) { - buffer_written.insert(write_region->buffer.get()); - } - bool affected = false; - PreOrderVisit(block->body, [&](const ObjectRef& obj) { - if (affected) { - return false; - } - if (const auto* store = obj.as()) { - ICHECK(buffer_written.count(store->buffer.get())) - << "ValueError: The buffer \"" << store->buffer - << "\" is written in the block but is not in the block's signature"; - for (const PrimExpr& index : store->indices) { - if (UsesVar(index, [&reduction_block_vars](const VarNode* var) { - return reduction_block_vars.count(var); - })) { - affected = true; - return false; - } - } - return false; - } - return true; - }); - return !affected ? 0 : 5; + return ReductionIterNotIndexOutputBuffer(GetRef(block)) ? 0 : 5; } bool IsReductionBlock(const ScheduleState& self, const StmtSRef& block_sref, @@ -552,7 +520,9 @@ bool GetVarsTouchedByBlockIters(const BlockRealize& block_realize, } else { has_block_vars_of_other_types = true; } - + if (set == nullptr) { + continue; + } Array vars_in_binding = UndefinedVars(iter_value); for (const Var& var : vars_in_binding) { set->insert(var.get()); @@ -1128,6 +1098,207 @@ class PatternMatcher : public ExprVisitor { std::unordered_map filled_map_; }; +/******** Reduction Block Related ********/ + +class InitBodyNotBufferStoreError : public ScheduleError { + public: + explicit InitBodyNotBufferStoreError(IRModule mod, Block block, bool init_is_bufferstore, + bool body_is_bufferstore) + : mod_(std::move(mod)), + block_(std::move(block)), + init_is_bufferstore_(init_is_bufferstore), + body_is_bufferstore_(body_is_bufferstore) {} + + String FastErrorString() const final { + return "ScheduleError: The `init` and `body` of reduction block are required to be both " + "BufferStore so that rfactor or cross-thread reduction can be applied"; + } + + String DetailRenderTemplate() const final { + if (!init_is_bufferstore_ && !body_is_bufferstore_) { + return "The `init` and `body` of block {0} are required to be BufferStore so that rfactor or " + "cross-thread reduction can be applied"; + } else if (!init_is_bufferstore_) { + return "The `init` of block {0} is required to be BufferStore so that rfactor or cross-thread" + " reduction can be applied"; + } else { + ICHECK(!body_is_bufferstore_); + return "The `body` of block {0} is required to be BufferStore so that rfactor or cross-thread" + " reduction can be applied"; + } + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {block_}; } + + IRModule mod_; + Block block_; + bool init_is_bufferstore_; + bool body_is_bufferstore_; +}; + +class InitBodyNotSameBufferAccessError : public ScheduleError { + public: + explicit InitBodyNotSameBufferAccessError(IRModule mod, Block block) + : mod_(std::move(mod)), block_(std::move(block)) {} + + String FastErrorString() const final { + return "ScheduleError: The `init` and `body` of the reduction block are required to have the " + "same buffer access pattern"; + } + + String DetailRenderTemplate() const final { + std::ostringstream os; + const auto* init = block_->init.as(); + const auto* update = block_->body.as(); + os << "The `init` and `body` of the block {0} is required to have the same buffer access " + "pattern. 
However, in block {0} the `init` writes to " + << init->buffer->name << init->indices << ", and the `body` writes to " + << update->buffer->name << update->indices; + return os.str(); + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {block_}; } + + IRModule mod_; + Block block_; +}; + +std::pair GetBufferStoresFromReductionBlock( + const Optional& self, const Block& block) { + static constexpr const char* error_str1 = + "ValueError: The `init` and `body` of the reduction block are required to be both " + "BufferStore so that rfactor or cross-thread reduction can be applied. However, a reduction " + "block that doesn't meet this requirement is "; + static constexpr const char* error_str2 = + "ValueError: The `init` and `body` of the reduction block are required to have the same " + "buffer access pattern so that rfactor or cross-thread reduction can be applied. However, a " + "reduction block that doesn't meet this requirement is "; + + const auto* init = block->init.as(); + const auto* body = block->body.as(); + if (!(init && body)) { + if (self.defined()) { + throw InitBodyNotBufferStoreError(self.value()->mod, block, init != nullptr, body != nullptr); + } else { + LOG(FATAL) << error_str1 << block; + } + } + if (!init->buffer.same_as(body->buffer)) { + if (self.defined()) { + throw InitBodyNotSameBufferAccessError(self.value()->mod, block); + } else { + LOG(FATAL) << error_str2 << block; + } + } + int ndim = static_cast(init->buffer->shape.size()); + for (int i = 0; i < ndim; ++i) { + if (!ExprDeepEqual()(init->indices[i], body->indices[i])) { + if (self.defined()) { + throw InitBodyNotSameBufferAccessError(self.value()->mod, block); + } else { + LOG(FATAL) << error_str2 << block; + } + } + } + return std::make_pair(GetRef(init), GetRef(body)); +} + +bool ContainsOnlyDataParAndReductionBlockIter(const Array& iters) { + for (const IterVar& iter_var : iters) { + if (iter_var->iter_type != kDataPar && iter_var->iter_type != kCommReduce) { + return false; + } + } + return true; +} + +bool ReductionIterNotIndexOutputBuffer(const Block& block) { + // Step 1. Collect the reduction block iters. + std::unordered_set reduction_block_iters; + reduction_block_iters.reserve(block->iter_vars.size()); + for (const IterVar& iter_var : block->iter_vars) { + if (iter_var->iter_type == kCommReduce) { + reduction_block_iters.insert(iter_var->var.get()); + } + } + // Step 2. Check if the reduction block iters are used to index the output buffer. 
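+  // For example, with reduction iter vk in `B[vi] = B[vi] + A[vi, vk]`, the store index `vi`
+  // passes this check, whereas a store whose indices mention vk (e.g. `B[vk] = ...`) would fail.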
+ std::unordered_set buffer_written; + buffer_written.reserve(block->writes.size()); + for (const BufferRegion& write_region : block->writes) { + buffer_written.insert(write_region->buffer.get()); + } + auto f_uses_reduction_block_var = [&](const PrimExpr& expr) -> bool { + return UsesVar(expr, [&](const VarNode* var) { // + return reduction_block_iters.count(var); + }); + }; + bool affected = false; + PreOrderVisit(block->body, [&](const ObjectRef& obj) { + if (affected) { + return false; + } + const auto* store = obj.as(); + if (!store) { + return true; + } + ICHECK(buffer_written.count(store->buffer.get())) + << "ValueError: The buffer \"" << store->buffer + << "\" is written in the block but is not in the block's signature"; + for (const PrimExpr& index : store->indices) { + if (f_uses_reduction_block_var(index)) { + affected = true; + return false; + } + } + return false; + }); + return !affected; +} + +class NoMatchedReducerError : public ScheduleError { + public: + explicit NoMatchedReducerError(IRModule mod, PrimExpr identity, BufferStore combiner) + : mod_(std::move(mod)), identity_(std::move(identity)), combiner_(std::move(combiner)) {} + + String FastErrorString() const final { + return "ScheduleError: No matched reducer for the identity and the combiner of this reduction " + "block. So rfactor and cross-thread reduction cannot be applied."; + } + + String DetailRenderTemplate() const final { + std::ostringstream os; + os << "No matched reducer for identity " << identity_ << " and combiner " << combiner_ + << "In this case rfactor cannot be applied. You can check tvm::tir::ReducerRegistry for " + "default reducers or registering new reducers."; + return os.str(); + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {}; } + + IRModule mod_; + PrimExpr identity_; + BufferStore combiner_; +}; + +std::tuple GetReducerAndCombinerLhsRhs( + const Optional& self, const PrimExpr& identity, const BufferStore& combiner) { + CommReducer reducer{nullptr}; + PrimExpr combiner_lhs{nullptr}, combiner_rhs{nullptr}; + bool matched = FromIdentityCombiner(identity, combiner, &reducer, &combiner_lhs, &combiner_rhs); + if (!matched) { + if (self.defined()) { + throw NoMatchedReducerError(self.value()->mod, identity, combiner); + } else { + LOG(FATAL) << "ValueError: No matched reducer for the identity and the combiner of the " + "reduction block. 
So rfactor and cross-thread reduction cannot be applied."; + } + } + return std::make_tuple(std::move(reducer), std::move(combiner_lhs), std::move(combiner_rhs)); +} + /******** Commutative Reducer ********/ bool MatchReducer(const CommReducer& reducer, const PrimExpr& identity, const PrimExpr& combiner, diff --git a/src/tir/schedule/primitive/reduction.cc b/src/tir/schedule/primitive/reduction.cc index 0f851684de2a..9c330765ef38 100644 --- a/src/tir/schedule/primitive/reduction.cc +++ b/src/tir/schedule/primitive/reduction.cc @@ -370,69 +370,6 @@ class NotSerialLoopKindError : public ScheduleError { For loop_; }; -class InitBodyNotBufferStoreError : public ScheduleError { - public: - explicit InitBodyNotBufferStoreError(IRModule mod, Block block, bool init_is_bufferstore, - bool body_is_bufferstore) - : mod_(std::move(mod)), - block_(std::move(block)), - init_is_bufferstore_(init_is_bufferstore), - body_is_bufferstore_(body_is_bufferstore) {} - - String FastErrorString() const final { - return "ScheduleError: The `init` and `body` of reduction block are required to be both " - "BufferStore"; - } - - String DetailRenderTemplate() const final { - if (!init_is_bufferstore_ && !body_is_bufferstore_) { - return "The `init` and `body` of block {0} are required to be BufferStore so that rfactor " - "can be applied"; - } else if (!init_is_bufferstore_) { - return "The `init` of block {0} is required to be BufferStore so that rfactor can be applied"; - } else { - ICHECK(!body_is_bufferstore_); - return "The `body` of block {0} is required to be BufferStore so that rfactor can be applied"; - } - } - - IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {block_}; } - - IRModule mod_; - Block block_; - bool init_is_bufferstore_; - bool body_is_bufferstore_; -}; - -class InitBodyNotSameBufferAccessError : public ScheduleError { - public: - explicit InitBodyNotSameBufferAccessError(IRModule mod, Block block) - : mod_(std::move(mod)), block_(std::move(block)) {} - - String FastErrorString() const final { - return "ScheduleError: The `init` and `body` of the reduction block are required to have the " - "same buffer access pattern"; - } - - String DetailRenderTemplate() const final { - std::ostringstream os; - const auto* init = block_->init.as(); - const auto* update = block_->body.as(); - os << "The `init` and `body` of the block {0} is required to have the same buffer access " - "pattern. However, in block {0} the `init` writes to " - << init->buffer->name << init->indices << ", and the `body` writes to " - << update->buffer->name << update->indices; - return os.str(); - } - - IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {block_}; } - - IRModule mod_; - Block block_; -}; - class FactorAxisOutOfRangeError : public ScheduleError { public: explicit FactorAxisOutOfRangeError(IRModule mod, Buffer buffer, int factor_axis) @@ -473,32 +410,6 @@ class FactorAxisOutOfRangeError : public ScheduleError { int factor_axis_; }; -class NoMatchedReducerError : public ScheduleError { - public: - explicit NoMatchedReducerError(IRModule mod, PrimExpr identity, BufferStore combiner) - : mod_(std::move(mod)), identity_(std::move(identity)), combiner_(std::move(combiner)) {} - - String FastErrorString() const final { - return "ScheduleError: No matched reducer for the identity and the combiner of this reduction " - "block. 
So rfactor cannot be applied."; - } - - String DetailRenderTemplate() const final { - std::ostringstream os; - os << "No matched reducer for identity " << identity_ << " and combiner " << combiner_ - << "In this case rfactor cannot be applied. You can check tvm::tir::ReducerRegistry for " - "default reducers or registering new reducers."; - return os.str(); - } - - IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {}; } - - IRModule mod_; - PrimExpr identity_; - BufferStore combiner_; -}; - class LoopPropertyError : public ScheduleError { public: enum ErrorType { @@ -591,53 +502,6 @@ class LoopPropertyError : public ScheduleError { ErrorType error_type_; }; -/*! - * \brief Convert the `init` and `body` of the input block to BufferStores - * \param self The schedule state - * \param block The block to be analyzed - * \return The BufferStores of the `init` and `body` of the input block - * \throw ScheduleError If the `init` or `body` is not BufferStore, or they don't write to the same - * buffer - */ -std::pair GetBufferStoreNodes(const ScheduleState& self, - const Block& block) { - const auto* init = block->init.as(); - const auto* body = block->body.as(); - if (!(init && body)) { - throw InitBodyNotBufferStoreError(self->mod, block, init != nullptr, body != nullptr); - } - if (!init->buffer.same_as(body->buffer)) { - throw InitBodyNotSameBufferAccessError(self->mod, block); - } - int ndim = static_cast(init->buffer->shape.size()); - for (int i = 0; i < ndim; ++i) { - if (!ExprDeepEqual()(init->indices[i], body->indices[i])) { - throw InitBodyNotSameBufferAccessError(self->mod, block); - } - } - return std::make_pair(GetRef(init), GetRef(body)); -} - -/*! - * \brief Given a reduction identity and a reduction combiner, detect the corresponding commutative - * reducer, and extract the combiner lhs and combiner rhs - * \param self The schedule state - * \param identity The reduction identity to be analyzed - * \param combiner The reduction combiner to be analyzed - * \return The corresponding CommReducer, the combiner lhs and the combiner rhs - * \throw ScheduleError If no corresponding commutative reducer can be matched - */ -std::tuple GetReducerAndCombinerLhsRhs( - const ScheduleState& self, const PrimExpr& identity, const BufferStore& combiner) { - CommReducer reducer{nullptr}; - PrimExpr combiner_lhs{nullptr}, combiner_rhs{nullptr}; - bool matched = FromIdentityCombiner(identity, combiner, &reducer, &combiner_lhs, &combiner_rhs); - if (!matched) { - throw NoMatchedReducerError(self->mod, identity, combiner); - } - return std::make_tuple(std::move(reducer), std::move(combiner_lhs), std::move(combiner_rhs)); -} - /*! 
 * \brief For each loop in the given array of loops, associate its loop var with the loop itself
 * using a mapping
@@ -1177,7 +1041,7 @@ StmtSRef RFactor(ScheduleState self, const StmtSRef& rf_loop_sref, int factor_ax
   BufferStore update;
   CommReducer reducer;
   PrimExpr combiner_lhs, combiner_rhs;
-  std::tie(init, update) = GetBufferStoreNodes(self, block);
+  std::tie(init, update) = GetBufferStoresFromReductionBlock(self, block);
   std::tie(reducer, combiner_lhs, combiner_rhs) =
       GetReducerAndCombinerLhsRhs(self, init->value, update);
diff --git a/src/tir/transforms/lower_cross_thread_reduction.cc b/src/tir/transforms/lower_cross_thread_reduction.cc
new file mode 100644
index 000000000000..630c00f8c1f1
--- /dev/null
+++ b/src/tir/transforms/lower_cross_thread_reduction.cc
@@ -0,0 +1,645 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file lower_cross_thread_reduction.cc
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include "../schedule/analysis.h"
+#include "./ir_utils.h"
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief Check if a loop is bound to threadIdx.x/y/z
+ * \param loop The loop to be checked
+ * \return True if the loop is bound to threadIdx.x/y/z
+ */
+bool IsBoundToThreadIdx(const ForNode* loop) {
+  if (!loop->thread_binding.defined()) {
+    return false;
+  }
+  runtime::ThreadScope scope =
+      runtime::ThreadScope::Create(loop->thread_binding.value()->thread_tag);
+  return scope.rank == 1 && scope.dim_index >= 0;
+}
+
+/*!
+ * \brief Check the dominant property of a block:
+ * the block is the only writer of its output, dominating the reader of its output buffers
+ * \param scope_block The scope block of the block to be checked
+ * \param block The block whose dominant property is to be checked
+ * \return A boolean indicating if the block is a dominant block
+ */
+bool IsDominantBlock(const Block& scope_block, const Block& block) {
+  // Step 1. Count the number of writers for each buffer written by the scope block.
+  std::unordered_map<const BufferNode*, int> buffer_writer_cnt;
+  PreOrderVisit(scope_block->body, [&buffer_writer_cnt](const ObjectRef& obj) {
+    if (const auto* block = obj.as<BlockNode>()) {
+      for (const BufferRegion& buffer_region : block->writes) {
+        ++buffer_writer_cnt[buffer_region->buffer.get()];
+      }
+      return false;
+    }
+    return true;
+  });
+  // Step 2. Check whether `block` is the only writer of its outputs.
+  for (const BufferRegion& buffer_region : block->writes) {
+    ICHECK(buffer_writer_cnt.count(buffer_region->buffer.get()));
+    if (buffer_writer_cnt[buffer_region->buffer.get()] != 1) {
+      return false;
+    }
+  }
+  return true;
+}
+
+/*!
+ * \brief Check whether the input block is a reduction block.
+ * \param realize The block-realize to be checked
+ * \param loop_range_map The mapping from the loop variables outside the input block to their ranges
+ * \param scope_block The scope block of the input block
+ * \param analyzer The analyzer
+ * \return A boolean indicating whether the input block is a reduction block
+ * \note A similar check has been implemented in "src/tir/schedule/analysis.h", but that check is
+ * based on `tir.Schedule`. Here we have no schedule information, and thus we must implement the
+ * check again.
+ */
+bool IsReductionBlock(const BlockRealize& realize, const Map<Var, Range>& loop_range_map,
+                      const Block& scope_block, arith::Analyzer* analyzer) {
+  const auto* block = realize->block.as<BlockNode>();
+  // Cond 1. The block has the `init` statement.
+  if (!block->init.defined()) {
+    return false;
+  }
+  // Cond 2. All the block bindings are quasi-affine expressions.
+  if (!IsAffineBinding(realize, loop_range_map, analyzer)) {
+    return false;
+  }
+  // Cond 3. All block vars are either data-parallel block vars or reduction block vars.
+  if (!ContainsOnlyDataParAndReductionBlockIter(block->iter_vars)) {
+    return false;
+  }
+  // Cond 4. Dominant: the block is the only writer of its output, dominating the reader of its
+  // output buffers.
+  if (!IsDominantBlock(scope_block, GetRef<Block>(block))) {
+    return false;
+  }
+  // Cond 5. The reduction block vars are not used to index the output buffers.
+  return ReductionIterNotIndexOutputBuffer(GetRef<Block>(block));
+}
+
+/*!
+ * \brief Create an intermediate buffer with the specified name and data type
+ * \param name The specified name
+ * \param dtype The specified data type
+ * \return The created buffer
+ */
+Buffer MakeScratchpad(String name, const DataType& dtype) {
+  return Buffer(/*ptr=*/Var(name, PointerType(PrimType(dtype), "local")),
+                /*dtype=*/dtype,
+                /*shape=*/{Integer(1)},
+                /*strides=*/{Integer(1)},
+                /*elem_offset=*/PrimExpr{nullptr},
+                /*name=*/std::move(name),
+                /*data_alignment=*/0,
+                /*offset_factor=*/0,
+                /*buffer_type=*/kDefault);
+}
+
+/*!
+ * \brief Remove the BufferRegions whose buffer is the input buffer
+ * \param buffer_regions The array of BufferRegions to be filtered
+ * \param buffer_to_remove The specified buffer
+ * \return The mutated array of BufferRegions, no longer containing any BufferRegion of the input
+ * buffer
+ */
+Array<BufferRegion> RemoveBufferFromBufferRegions(const Array<BufferRegion>& buffer_regions,
+                                                  const Buffer& buffer_to_remove) {
+  Array<BufferRegion> res;
+  res.reserve(buffer_regions.size());
+  for (const BufferRegion& buffer_region : buffer_regions) {
+    if (!buffer_region->buffer.same_as(buffer_to_remove)) {
+      res.push_back(buffer_region);
+    }
+  }
+  return res;
+}
+
+/*!
+ * \brief Substitute a given source buffer with a given target buffer in statements or expressions
+ */
+class BufferReplacer : private StmtExprMutator {
+ public:
+  static Stmt Run(Buffer src_buffer, Buffer tgt_buffer, Stmt stmt) {
+    return BufferReplacer(src_buffer, tgt_buffer)(std::move(stmt));
+  }
+
+ private:
+  explicit BufferReplacer(Buffer src_buffer, Buffer tgt_buffer)
+      : src_buffer_(std::move(src_buffer)), tgt_buffer_(std::move(tgt_buffer)) {}
+
+  PrimExpr VisitExpr_(const BufferLoadNode* load) final {
+    return load->buffer.same_as(src_buffer_) ?
BufferLoad(tgt_buffer_, {0}) + : GetRef(load); + } + + Stmt VisitStmt_(const BufferStoreNode* store) final { + if (store->buffer.same_as(src_buffer_)) { + PrimExpr value = StmtExprMutator::VisitExpr(store->value); + return BufferStore(tgt_buffer_, value, {0}); + } else { + return StmtMutator::VisitStmt_(store); + } + } + + Buffer src_buffer_; + Buffer tgt_buffer_; +}; + +/*! + * \brief Substitute a given source block with a given target block, or remove the source block + * branch from the AST if the target block is undefined + */ +class InThreadReducerMaker : private StmtMutator { + public: + static Optional Make(const BlockRealizeNode* src_realize, + Optional tgt_realize, Stmt stmt) { + return InThreadReducerMaker(src_realize, std::move(tgt_realize))(std::move(stmt)); + } + + private: + explicit InThreadReducerMaker(const BlockRealizeNode* src_realize, + Optional tgt_realize) + : src_realize_(src_realize), tgt_realize_(tgt_realize) {} + Stmt VisitStmt_(const BlockRealizeNode* realize) final { + if (realize == src_realize_) { + return tgt_realize_.defined() // + ? tgt_realize_.value() + : Stmt{nullptr}; + } + return GetRef(realize); + } + + Stmt VisitStmt_(const ForNode* loop) final { + if (Optional opt_res = Downcast>(StmtMutator::VisitStmt_(loop))) { + For res = opt_res.value(); + if (res->thread_binding.defined()) { + return res->body; + } else { + return res; + } + } else { + return Stmt{nullptr}; + } + } + + Stmt VisitStmt_(const SeqStmtNode* seq) final { + Array stmts; + stmts.reserve(seq->size()); + for (const Stmt& stmt : seq->seq) { + if (Optional opt_res = VisitStmt(stmt)) { + stmts.push_back(opt_res.value()); + } + } + return stmts.empty() ? Stmt{nullptr} : SeqStmt::Flatten(stmts); + } + + const BlockRealizeNode* src_realize_; + Optional tgt_realize_; +}; + +/*! 
+ * \brief Create the lowered allreduce block transformed from the input reduction block + * \param reduction_block The input reduction block + * \param it_buffer The buffer to store in-thread reduction results + * \param ct_buffer The buffer to store cross-thread reduction results + * \param reducer The reduction function + * \param combiner_rhs The RHS of the combiner + * \param reduction_loops The reduction loops + */ +Stmt TransformReductionBlock(const BlockRealizeNode* realize, const Optional& it_buffer, + const Buffer& ct_buffer, const CommReducer& reducer, + const PrimExpr& combiner_rhs, + const std::vector& reduction_loops) { + const BlockNode* block = realize->block.get(); + Buffer wb_buffer = block->writes[0]->buffer; + Array wb_region = block->writes[0]->region; + + BufferRegion ct_buffer_region(ct_buffer, {Range::FromMinExtent(0, 1)}); + Optional it_buffer_region = NullOpt; + if (it_buffer.defined()) { + it_buffer_region = BufferRegion(it_buffer.value(), {Range::FromMinExtent(0, 1)}); + } + // In total, the block is transformed into at most 4 statements + // - Stmt 1: initialize the buffer for in-thread reduction + // - Stmt 2: do in-thread reduction + // - Stmt 3: do cross-thread reduction + // - Stmt 4: write cross-thread reduction result to the original buffer + Array stmts; + stmts.reserve(4); + // Stmt 1: initialize the buffer for in-thread reduction + if (it_buffer.defined()) { + BufferStore init = Downcast(block->init); + stmts.push_back(BlockRealize( + /*iter_values=*/{}, + /*predicate=*/const_true(), + /*block=*/ + Block(/*iter_vars=*/{}, + /*reads=*/{}, + /*writes=*/{it_buffer_region.value()}, + /*name_hint=*/block->name_hint + "_in_thread_init", + /*body=*/ + BufferStore(/*buffer=*/it_buffer.value(), + /*value=*/init->value, + /*indices=*/{Integer(0)})))); + } + // Stmt 2: do in-thread reduction + { + Optional new_realize = NullOpt; + // If need to generate in-thread reduction, + // then replace `wb_buffer` with `it_buffer` accordingly in given BlockRealize + // otherwise, directly remove given BlockRealize + if (it_buffer.defined()) { + ObjectPtr new_block = make_object(*block); + new_block->reads = RemoveBufferFromBufferRegions(std::move(new_block->reads), wb_buffer); + new_block->reads.push_back(it_buffer_region.value()); + new_block->writes = {it_buffer_region.value()}; + new_block->name_hint = new_block->name_hint + "_in_thread"; + new_block->body = + BufferReplacer::Run(wb_buffer, it_buffer.value(), std::move(new_block->body)); + new_block->init = NullOpt; + ObjectPtr n = make_object(*realize); + n->block = Block(new_block); + new_realize = BlockRealize(n); + } + For loop = GetRef(reduction_loops[0]); + if (Optional stmt = InThreadReducerMaker::Make(realize, new_realize, std::move(loop))) { + stmts.push_back(stmt.value()); + } + } + // Stmt 3: do cross-thread reduction + { + // Step 3.1. 
Create the parameters to the intrinsic + Array parameters; + parameters.reserve(reduction_loops.size() + 4); + // 1-st argument: size + parameters.push_back(make_const(DataType::UInt(32), 1)); + // 2-nd argument: source + if (it_buffer.defined()) { + parameters.push_back(BufferLoad(it_buffer.value(), {Integer(0)})); + } else { + parameters.push_back(combiner_rhs); + } + // 3-rd argument: predicate + parameters.push_back(const_true()); + // 4-th argument: destination + parameters.push_back(ct_buffer->data); + // next arguments: all the reduction threads + for (const ForNode* reduction_loop : reduction_loops) { + if (reduction_loop->thread_binding.defined()) { + parameters.push_back(reduction_loop->loop_var); + } + } + // Step 3.2. Create the block and the block-realize. + Array iter_vars{nullptr}; + Array bindings{nullptr}; + Array reads{nullptr}; + if (it_buffer.defined()) { + iter_vars = Array{}; + bindings = Array{}; + reads = {it_buffer_region.value()}; + } else { + iter_vars = block->iter_vars; + bindings = realize->iter_values; + reads = {RemoveBufferFromBufferRegions(block->reads, wb_buffer)}; + } + stmts.push_back(BlockRealize( + /*iter_values=*/std::move(bindings), + /*predicate=*/const_true(), + /*block=*/ + Block(/*iter_vars=*/std::move(iter_vars), + /*reads=*/std::move(reads), + /*writes=*/{ct_buffer_region}, + /*name_hint=*/block->name_hint + "_cross_thread", + /*body=*/ + AttrStmt(/*node=*/reducer, + /*attr_key=*/tir::attr::reduce_scope, + /*value=*/make_zero(DataType::Handle()), + /*body=*/ + Evaluate(Call(/*dtype=*/DataType::Handle(), + /*op=*/tir::builtin::tvm_thread_allreduce(), + /*args=*/std::move(parameters))))))); + } + // Stmt 4: write cross-thread reduction result to the original buffer + { + ICHECK_EQ(block->iter_vars.size(), realize->iter_values.size()); + int n_iter = static_cast(block->iter_vars.size()); + Array iter_vars; + Array bindings; + Map var_map; + iter_vars.reserve(n_iter); + bindings.reserve(n_iter); + for (int i = 0; i < n_iter; ++i) { + const IterVar& iter_var = block->iter_vars[i]; + const PrimExpr& binding = realize->iter_values[i]; + if (iter_var->iter_type != kCommReduce) { + IterVar new_iter_var{nullptr}; + { + ObjectPtr n = make_object(*iter_var.get()); + ObjectPtr v = make_object(*iter_var->var.get()); + n->var = Var(v); + new_iter_var = IterVar(n); + } + iter_vars.push_back(new_iter_var); + bindings.push_back(binding); + var_map.Set(iter_var->var, new_iter_var->var); + } + } + BufferStore update = Downcast(block->body); + update = Downcast(Substitute(std::move(update), var_map)); + stmts.push_back(BlockRealize( + /*iter_values=*/std::move(bindings), + /*predicate=*/const_true(), + /*block=*/ + Block( + /*iter_vars=*/std::move(iter_vars), + /*reads=*/{std::move(ct_buffer_region)}, + /*writes=*/{BufferRegion(wb_buffer, Substitute(wb_region, var_map))}, + /*name_hint=*/block->name_hint + "_write_back", + /*body=*/ + BufferStore(/*buffer=*/wb_buffer, + /*value=*/BufferLoad(ct_buffer, {Integer(0)}), + /*indices=*/update->indices)))); + } + // Final step: Wrap all the above four statements with the reduction loops bound to threadIdx + Stmt new_stmt = SeqStmt::Flatten(std::move(stmts)); + for (auto rit = reduction_loops.rbegin(); rit != reduction_loops.rend(); ++rit) { + const ForNode* loop = *rit; + if (loop->thread_binding.defined()) { + ObjectPtr n = make_object(*loop); + n->body = std::move(new_stmt); + new_stmt = For(n); + } + } + return new_stmt; +} + +/*! 
+ * \brief Detect cross-thread reduction pattern and then transform + */ +class CrossThreadReductionTransformer : public StmtMutator { + private: + // Check if the input block needs cross-thread reduction. + std::vector NeedCrossThreadReduction(const BlockRealizeNode* realize) { + // Step 0. If the block is the root block, just return. + if (block_stack_.empty()) { + return {}; + } + + // Step 1. If the block is not a reduction block, cross-thread reduction is not needed. + if (!IsReductionBlock(GetRef(realize), loop_range_map_, + GetRef(block_stack_.back()), &analyzer_)) { + return {}; + } + + // Step 2. Collect all the vars that appear in the bindings of reduction block iters. + std::unordered_set reduction_vars; + GetVarsTouchedByBlockIters(GetRef(realize), nullptr, &reduction_vars); + + // Step 3. Collect the loops whose loop vars appear in the bindings of reduction block iters. + // We call these loops "reduction-related". + // Step 4. See whether at least one reduction-related loop is bound to thread axis in GPU - if + // so, cross-thread reduction is needed. If none of the reduction-related loops is bound to + // thread axis, cross-thread reduction is not needed for the input block. + bool need = false; + std::vector reduction_loops; + for (const ForNode* loop : loop_stack_) { + if (reduction_vars.count(loop->loop_var.get())) { + // Step 3. Collect the loop. + reduction_loops.push_back(loop); + // Step 4. See whether the loop is bound to some thread axis. + if (loop->thread_binding.defined()) { + need = true; + } + } + } + return need ? reduction_loops : std::vector{}; + } + + // Given that the input block needs cross-thread reduction, check if cross-thread reduction can + // be applied to the block (i.e., the block satisfies all necessary conditions of cross-thread + // reduction). + std::tuple CheckCanApplyCrossThreadReduction( + const BlockNode* block, const std::vector& reduction_loops) const { + // Condition 1. The block being applied cross-thread reduction should write to single buffer. + CHECK_EQ(block->writes.size(), 1) + << "ValueError: Cross-thread reduction requires the block to only " + "write to single buffer. However, the block " + << block->name_hint << " writes to " << block->writes.size() << " buffer(s)."; + + // Condition 2. All the reduction-related loops should be the deepest among all statements + // outside the block (ignoring SeqStmt here). + int n_deepest_reduction_loops = 0; + for (auto rit = statement_stack_.rbegin() + 1; rit != statement_stack_.rend(); ++rit) { + const StmtNode* stmt = *rit; + if ((*rit)->IsInstance()) { + // Skip SeqStmt. + continue; + } + if (std::find(reduction_loops.begin(), reduction_loops.end(), + reinterpret_cast(stmt)) == reduction_loops.end()) { + break; + } + ++n_deepest_reduction_loops; + } + CHECK_EQ(n_deepest_reduction_loops, reduction_loops.size()) + << "ValueError: Cross-thread reduction requires all the reduction-related loops to be the " + "deepest among all statements outside the desired block. However, block " + << block->name_hint + << " needs cross-thread reduction, while the reduction-related loops outside of it are not " + "the deepest statements, which violates the condition."; + + // Condition 3. All the reduction-related loops that are bound to thread axes should only be + // bound to `threadIdx.x/y/z`. 
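+    // Loops bound to blockIdx.x/y/z cannot carry the reduction: tvm_thread_allreduce only
+    // synchronizes and reduces across the threads within a single thread block.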
+ int n_bound_reduction_loops = 0; + for (const ForNode* reduction_loop : reduction_loops) { + if (reduction_loop->thread_binding.defined()) { + ++n_bound_reduction_loops; + CHECK(IsBoundToThreadIdx(reduction_loop)) + << "ValueError: Cross-thread reduction requires all the reduction-related loops that " + "are bound to GPU thread axes to only be bound `threadIdx.x/y/z`. However, loop " + << reduction_loop->loop_var->name_hint << " violates the condition."; + } + } + + // Condition 4. Get the `init` identity and the `update` combiner of the reduction. They should + // both be BufferStores with the same buffer and indices; + // Extract the commutative reducer, combiner lhs and combiner rhs from the reduction identity + // and the reduction combiner. + BufferStore init{nullptr}; + BufferStore update{nullptr}; + CommReducer reducer{nullptr}; + PrimExpr combiner_lhs{nullptr}; + PrimExpr combiner_rhs{nullptr}; + std::tie(init, update) = GetBufferStoresFromReductionBlock(NullOpt, GetRef(block)); + std::tie(reducer, combiner_lhs, combiner_rhs) = + GetReducerAndCombinerLhsRhs(NullOpt, init->value, update); + + // Condition 5. The block should be the last block under the first reduction-related loop. + bool visit = false; + PreOrderVisit(GetRef(reduction_loops[0]), [block, &visit](const ObjectRef& obj) { + if (const auto* realize = obj.as()) { + CHECK(!visit) << "ValueError: Cross-thread reduction cannot be applied when the reduction " + "block isn't the last block under its first reduction-related loop"; + if (realize->block.get() == block) { + visit = true; + } + return false; + } + return true; + }); + return std::make_tuple(n_bound_reduction_loops, reducer, combiner_rhs); + } + + Stmt VisitStmt(const Stmt& stmt) final { + statement_stack_.push_back(stmt.get()); + Stmt result = StmtMutator::VisitStmt(stmt); + statement_stack_.pop_back(); + return result; + } + + Stmt VisitStmt_(const ForNode* loop) final { + loop_stack_.push_back(loop); + loop_range_map_.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); + Stmt result = StmtMutator::VisitStmt_(loop); + loop_stack_.pop_back(); + loop_range_map_.erase(loop->loop_var); + + // Replace `result` with the pre-stored result if `loop` appears as a key in `loop2new_stmt_`. + auto it = loop2new_stmt_.find(loop); + if (it != loop2new_stmt_.end()) { + return it->second; + } else { + return result; + } + } + + Stmt VisitStmt_(const BlockNode* block) final { + Map old_loop_range_map; + + block_stack_.push_back(block); + std::swap(old_loop_range_map, loop_range_map_); + Block new_block = Downcast(StmtMutator::VisitStmt_(block)); + block_stack_.pop_back(); + std::swap(old_loop_range_map, loop_range_map_); + + // Insert the new allocated buffers into the block's `alloc_buffers` field. + auto it = block2new_buffers_.find(block); + if (it != block2new_buffers_.end()) { + BlockNode* p_new_block = new_block.CopyOnWrite(); + for (const Buffer& new_buffer : it->second) { + if (new_buffer.defined()) { + p_new_block->alloc_buffers.push_back(new_buffer); + } + } + } + return new_block; + } + + Stmt VisitStmt_(const BlockRealizeNode* realize) final { + const BlockNode* block = realize->block.get(); + // Step 1. Check whether cross-thread reduction is needed. If no, skip this block. + std::vector reduction_loops = NeedCrossThreadReduction(realize); + if (reduction_loops.empty()) { + return StmtMutator::VisitStmt_(realize); + } + ++reduction_id_; + // Step 2. Check whether cross-thread reduction can be applied. 
If no, throw an exception on + // which condition the block violates. + int n_bound_reduction_loops = 0; + CommReducer reducer{nullptr}; + PrimExpr combiner_rhs{nullptr}; + std::tie(n_bound_reduction_loops, reducer, combiner_rhs) = + CheckCanApplyCrossThreadReduction(block, reduction_loops); + // Step 3. When not all the reduction-related loops are bound to thread axes, in-thread + // reduction is needed in this cross-thread reduction. + bool need_in_thread_reduction = + n_bound_reduction_loops < static_cast(reduction_loops.size()); + // Step 4. Create intermediate buffers, storing them in `ct_buffer` and + // `it_buffer`. Let the scope block allocate these new buffers. + std::vector& new_buffers = block2new_buffers_[block_stack_.back()]; + DataType dtype = block->writes[0]->buffer->dtype; + Buffer ct_buffer = MakeScratchpad("cross_thread_" + std::to_string(reduction_id_), dtype); + new_buffers.push_back(ct_buffer); + Optional it_buffer = NullOpt; + if (need_in_thread_reduction) { + it_buffer = MakeScratchpad("in_thread_" + std::to_string(reduction_id_), dtype); + new_buffers.push_back(it_buffer.value()); + } + // Step 5. Transform. + loop2new_stmt_[reduction_loops[0]] = TransformReductionBlock( + realize, it_buffer, ct_buffer, reducer, combiner_rhs, reduction_loops); + // Step 6. Return an empty statement, because the transformation result will be inserted when + // returning to the first reduction-related loop. + return Stmt{nullptr}; + } + + private: + int reduction_id_ = -1; + std::vector statement_stack_; + std::vector loop_stack_; + std::vector block_stack_; + std::unordered_map> block2new_buffers_; + std::unordered_map loop2new_stmt_; + Map loop_range_map_; + arith::Analyzer analyzer_; +}; + +PrimFunc LowerCrossThreadReduction(PrimFunc f) { + // Only apply this pass to TIR that is not from TE schedules + if (!IsFromLegacyTESchedule(f)) { + PrimFuncNode* fptr = f.CopyOnWrite(); + fptr->body = CrossThreadReductionTransformer()(f->body); + return f; + } else { + return f; + } +} + +namespace transform { + +Pass LowerCrossThreadReduction() { + auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { + return LowerCrossThreadReduction(std::move(f)); + }; + return CreatePrimFuncPass(pass_func, 0, "tir.LowerCrossThreadReduction", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.LowerCrossThreadReduction") + .set_body_typed(LowerCrossThreadReduction); + +} // namespace transform + +} // namespace tir +} // namespace tvm diff --git a/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py b/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py new file mode 100644 index 000000000000..4fa3ab0c550c --- /dev/null +++ b/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py @@ -0,0 +1,737 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +import sys + +import pytest +import tvm +from tvm import te +from tvm.script import tir as T + + +def _check(original, transformed): + mod = tvm.IRModule.from_expr(original) + mod = tvm.tir.transform.LowerCrossThreadReduction()(mod) + tvm.ir.assert_structural_equal(mod["main"], transformed, True) + + +def _check_fail(original): + mod = tvm.IRModule.from_expr(original) + with pytest.raises(ValueError): + tvm.tir.transform.LowerCrossThreadReduction()(mod) + + +@T.prim_func +def loop_split(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [128, 128], dtype="float32") + B = T.match_buffer(b, [128], dtype="float32") + for i, ko in T.grid(128, 4): + for ki in T.thread_binding(0, 32, thread="threadIdx.x"): + with T.block("B"): + vi = T.axis.S(128, i) + vk = T.axis.R(128, ko * 32 + ki) + T.reads([B[vi], A[vi, vk]]) + T.writes([B[vi]]) + with T.init(): + B[vi] = T.float32(0) + B[vi] = B[vi] + A[vi, vk] + + +@T.prim_func +def lowered_loop_split(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [128, 128], dtype="float32") + B = T.match_buffer(b, [128], dtype="float32") + reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local") + normal_reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local") + for i in T.serial(0, 128): + for ki in T.thread_binding(0, 32, thread="threadIdx.x"): + with T.block("B_in_thread_init"): + T.reads([]) + T.writes([normal_reduce_temp0[0]]) + normal_reduce_temp0[0] = T.float32(0) + for ko in T.serial(0, 4): + with T.block("B_normal_reduction"): + vi = T.axis.S(128, i) + vk = T.axis.R(128, ko * 32 + ki) + T.reads([A[vi, vk], normal_reduce_temp0[0]]) + T.writes([normal_reduce_temp0[0]]) + normal_reduce_temp0[0] = normal_reduce_temp0[0] + A[vi, vk] + with T.block("B_cross_thread_reduction"): + T.reads([normal_reduce_temp0[0]]) + T.writes([reduce_temp0[0]]) + T.attr( + T.comm_reducer(lambda x, y: x + y, [T.float32(0)]), + "reduce_scope", + T.reinterpret(T.uint64(0), dtype="handle"), + ) + T.evaluate( + T.tvm_thread_allreduce( + T.uint32(1), + normal_reduce_temp0[0], + True, + reduce_temp0.data, + ki, + dtype="handle", + ) + ) + with T.block("B_write_back"): + vi = T.axis.S(128, i) + T.reads([reduce_temp0[0]]) + T.writes([B[vi]]) + B[vi] = reduce_temp0[0] + + +@T.prim_func +def no_normal_reduction(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [128, 128], dtype="float32") + B = T.match_buffer(b, [128], dtype="float32") + for i in T.serial(0, 128): + for k in T.thread_binding(0, 128, thread="threadIdx.x"): + with T.block("B"): + vi, vk = T.axis.remap("SR", [i, k]) + T.reads([B[vi], A[vi, vk]]) + T.writes([B[vi]]) + with T.init(): + B[vi] = T.float32(0) + B[vi] = B[vi] + A[vi, vk] + + +@T.prim_func +def lowered_no_normal_reduction(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [128, 128], dtype="float32") + B = T.match_buffer(b, [128], dtype="float32") + reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local") + for i in T.serial(0, 128): + for k in T.thread_binding(0, 128, thread="threadIdx.x"): + with T.block("B_cross_thread_reduction"): + vi, vk = T.axis.remap("SR", [i, k]) + T.reads([A[vi, vk]]) + T.writes([reduce_temp0[0]]) + T.attr( + T.comm_reducer(lambda x, y: x + y, [T.float32(0)]), + "reduce_scope", + T.reinterpret(T.uint64(0), dtype="handle"), + ) + T.evaluate( + T.tvm_thread_allreduce( + T.uint32(1), A[vi, vk], True, reduce_temp0.data, k, dtype="handle" + ) + 
) + with T.block("B_write_back"): + vi = T.axis.spatial(128, i) + T.reads([reduce_temp0[0]]) + T.writes([B[vi]]) + B[vi] = reduce_temp0[0] + + +@T.prim_func +def two_bound_loops(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [128, 128], dtype="float32") + B = T.match_buffer(b, [128], dtype="float32") + for i in T.serial(0, 128): + for ko in T.thread_binding(0, 4, thread="threadIdx.x"): + for ki in T.thread_binding(0, 32, thread="threadIdx.y"): + with T.block("B"): + vi = T.axis.spatial(128, i) + vk = T.axis.reduce(128, ko * 32 + ki) + T.reads([B[vi], A[vi, vk]]) + T.writes([B[vi]]) + with T.init(): + B[vi] = T.float32(0) + B[vi] = B[vi] + A[vi, vk] + + +@T.prim_func +def lowered_two_bound_loops(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [128, 128], dtype="float32") + B = T.match_buffer(b, [128], dtype="float32") + reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local") + for i in T.serial(0, 128): + for ko in T.thread_binding(0, 4, thread="threadIdx.x"): + for ki in T.thread_binding(0, 32, thread="threadIdx.y"): + with T.block("B_cross_thread_reduction"): + vi = T.axis.spatial(128, i) + vk = T.axis.reduce(128, ko * 32 + ki) + T.reads([A[vi, vk]]) + T.writes([reduce_temp0[0]]) + T.attr( + T.comm_reducer(lambda x, y: x + y, [T.float32(0)]), + "reduce_scope", + T.reinterpret(T.uint64(0), dtype="handle"), + ) + T.evaluate( + T.tvm_thread_allreduce( + T.uint32(1), A[vi, vk], True, reduce_temp0.data, ko, ki, dtype="handle" + ) + ) + with T.block("B_write_back"): + vi = T.axis.spatial(128, i) + T.reads([reduce_temp0[0]]) + T.writes([B[vi]]) + B[vi] = reduce_temp0[0] + + +@T.prim_func +def multiple_blocks_under_reduction_loop(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [16, 16, 16], dtype="float32") + B = T.match_buffer(b, [16], dtype="float32") + B_rf_local = T.alloc_buffer([16, 16], dtype="float32", scope="local") + for i in T.thread_binding(0, 16, thread="blockIdx.x"): + for k0o in T.thread_binding(0, 4, thread="threadIdx.x"): + for k0i0, k1 in T.grid(4, 16): + with T.block("B_rf"): + vk0 = T.axis.spatial(16, k0o * 4 + k0i0) + vi, vk1 = T.axis.remap("SR", [i, k1]) + T.reads([B_rf_local[vk0, vi], A[vi, vk0, vk1]]) + T.writes([B_rf_local[vk0, vi]]) + with T.init(): + B_rf_local[vk0, vi] = T.float32(0) + B_rf_local[vk0, vi] = B_rf_local[vk0, vi] + A[vi, vk0, vk1] + for k0i1 in T.serial(0, 4): + with T.block("B"): + vk0 = T.axis.reduce(16, k0o * 4 + k0i1) + vi = T.axis.spatial(16, i) + T.reads([B[vi], B_rf_local[vk0, vi]]) + T.writes([B[vi]]) + with T.init(): + B[vi] = T.float32(0) + B[vi] = B[vi] + B_rf_local[vk0, vi] + + +@T.prim_func +def lowered_multiple_blocks_under_reduction_loop(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [16, 16, 16], dtype="float32") + B = T.match_buffer(b, [16], dtype="float32") + B_rf_local = T.alloc_buffer([16, 16], dtype="float32", scope="local") + reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local") + normal_reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local") + for i in T.thread_binding(0, 16, thread="blockIdx.x"): + for k0o in T.thread_binding(0, 4, thread="threadIdx.x"): + with T.block("B_in_thread_init"): + T.reads([]) + T.writes([normal_reduce_temp0[0]]) + normal_reduce_temp0[0] = T.float32(0) + for k0i0, k1 in T.grid(4, 16): + with T.block("B_rf"): + vk0 = T.axis.spatial(16, k0o * 4 + k0i0) + vi, vk1 = T.axis.remap("SR", [i, k1]) + T.reads([B_rf_local[vk0, vi], A[vi, vk0, vk1]]) + T.writes([B_rf_local[vk0, vi]]) + with T.init(): 
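+                        # Block "B_rf" has no thread-bound reduction loop of its own (its
+                        # reduction iter vk1 binds only the serial loop k1), so the pass keeps
+                        # it unchanged; only block "B" below is lowered.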
+ B_rf_local[vk0, vi] = T.float32(0) + B_rf_local[vk0, vi] = B_rf_local[vk0, vi] + A[vi, vk0, vk1] + for k0i1 in T.serial(0, 4): + with T.block("B_normal_reduction"): + vk0 = T.axis.reduce(16, k0o * 4 + k0i1) + vi = T.axis.spatial(16, i) + T.reads([B_rf_local[vk0, vi], normal_reduce_temp0[0]]) + T.writes([normal_reduce_temp0[0]]) + normal_reduce_temp0[0] = normal_reduce_temp0[0] + B_rf_local[vk0, vi] + with T.block("B_cross_thread_reduction"): + T.reads([normal_reduce_temp0[0]]) + T.writes([reduce_temp0[0]]) + T.attr( + T.comm_reducer(lambda x, y: x + y, [T.float32(0)]), + "reduce_scope", + T.reinterpret(T.uint64(0), dtype="handle"), + ) + T.evaluate( + T.tvm_thread_allreduce( + T.uint32(1), + normal_reduce_temp0[0], + True, + reduce_temp0.data, + k0o, + dtype="handle", + ) + ) + with T.block("B_write_back"): + vi = T.axis.spatial(16, i) + T.reads([reduce_temp0[0]]) + T.writes([B[vi]]) + B[vi] = reduce_temp0[0] + + +@T.prim_func +def with_block_predicate(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [128, 120], dtype="float32") + B = T.match_buffer(b, [128], dtype="float32") + for i, ko in T.grid(128, 4): + for ki in T.thread_binding(0, 32, thread="threadIdx.x"): + with T.block("B"): + vi = T.axis.spatial(128, i) + vk = T.axis.reduce(120, ko * 32 + ki) + T.where(ko * 32 + ki < 120) + T.reads([B[vi], A[vi, vk]]) + T.writes([B[vi]]) + with T.init(): + B[vi] = T.float32(0) + B[vi] = B[vi] + A[vi, vk] + + +@T.prim_func +def lowered_with_block_predicate(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [128, 120], dtype="float32") + B = T.match_buffer(b, [128], dtype="float32") + reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local") + normal_reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local") + for i in T.serial(0, 128): + for ki in T.thread_binding(0, 32, thread="threadIdx.x"): + with T.block("B_in_thread_init"): + T.reads([]) + T.writes([normal_reduce_temp0[0]]) + normal_reduce_temp0[0] = T.float32(0) + for ko in T.serial(0, 4): + with T.block("B_normal_reduction"): + vi = T.axis.spatial(128, i) + vk = T.axis.reduce(120, ko * 32 + ki) + T.where(ko * 32 + ki < 120) + T.reads([A[vi, vk], normal_reduce_temp0[0]]) + T.writes([normal_reduce_temp0[0]]) + normal_reduce_temp0[0] = normal_reduce_temp0[0] + A[vi, vk] + with T.block("B_cross_thread_reduction"): + T.reads([normal_reduce_temp0[0]]) + T.writes([reduce_temp0[0]]) + T.attr( + T.comm_reducer(lambda x, y: x + y, [T.float32(0)]), + "reduce_scope", + T.reinterpret(T.uint64(0), dtype="handle"), + ) + T.evaluate( + T.tvm_thread_allreduce( + T.uint32(1), + normal_reduce_temp0[0], + True, + reduce_temp0.data, + ki, + dtype="handle", + ) + ) + with T.block("B_write_back"): + vi = T.axis.spatial(128, i) + T.reads([reduce_temp0[0]]) + T.writes([B[vi]]) + B[vi] = reduce_temp0[0] + + +@T.prim_func +def reducer_max(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [128, 128], dtype="float32") + B = T.match_buffer(b, [128], dtype="float32") + for i in T.serial(0, 128): + for k in T.thread_binding(0, 128, thread="threadIdx.x"): + with T.block("B"): + vi, vk = T.axis.remap("SR", [i, k]) + T.reads([B[vi], A[vi, vk]]) + T.writes([B[vi]]) + with T.init(): + B[vi] = T.min_value("float32") + B[vi] = T.max(B[vi], A[vi, vk]) + + +@T.prim_func +def lowered_reducer_max(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [128, 128], dtype="float32") + B = T.match_buffer(b, [128], dtype="float32") + reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local") 
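+    # Every reduction lane is thread-bound here, so no in-thread scratchpad is needed; relative
+    # to the sum case, only the reducer lambda (T.max) and identity (T.min_value) below change.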
+    for i in T.serial(0, 128):
+        for k in T.thread_binding(0, 128, thread="threadIdx.x"):
+            with T.block("B_cross_thread_reduction"):
+                vi, vk = T.axis.remap("SR", [i, k])
+                T.reads([A[vi, vk]])
+                T.writes([reduce_temp0[0]])
+                T.attr(
+                    T.comm_reducer(lambda x, y: T.max(x, y), [T.min_value("float32")]),
+                    "reduce_scope",
+                    T.reinterpret(T.uint64(0), dtype="handle"),
+                )
+                T.evaluate(
+                    T.tvm_thread_allreduce(
+                        T.uint32(1), A[vi, vk], True, reduce_temp0.data, k, dtype="handle"
+                    )
+                )
+            with T.block("B_write_back"):
+                vi = T.axis.spatial(128, i)
+                T.reads([reduce_temp0[0]])
+                T.writes([B[vi]])
+                B[vi] = reduce_temp0[0]
+
+
+@T.prim_func
+def zero_rank_buffer(a: T.handle, b: T.handle) -> None:
+    A = T.match_buffer(a, [128], dtype="float32")
+    B = T.match_buffer(b, [], dtype="float32")
+    for k in T.thread_binding(0, 128, thread="threadIdx.x"):
+        with T.block("B"):
+            vk = T.axis.reduce(128, k)
+            T.reads([B[()], A[vk]])
+            T.writes([B[()]])
+            with T.init():
+                B[()] = T.float32(0)
+            B[()] = B[()] + A[vk]
+
+
+@T.prim_func
+def lowered_zero_rank_buffer(a: T.handle, b: T.handle) -> None:
+    A = T.match_buffer(a, [128], dtype="float32")
+    B = T.match_buffer(b, [], dtype="float32")
+    reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
+    for k in T.thread_binding(0, 128, thread="threadIdx.x"):
+        with T.block("B_cross_thread_reduction"):
+            vk = T.axis.reduce(128, k)
+            T.reads([A[vk]])
+            T.writes([reduce_temp0[0]])
+            T.attr(
+                T.comm_reducer(lambda x, y: x + y, [T.float32(0)]),
+                "reduce_scope",
+                T.reinterpret(T.uint64(0), dtype="handle"),
+            )
+            T.evaluate(
+                T.tvm_thread_allreduce(
+                    T.uint32(1), A[vk], True, reduce_temp0.data, k, dtype="handle"
+                )
+            )
+        with T.block("B_write_back"):
+            T.reads([reduce_temp0[0]])
+            T.writes([B[()]])
+            B[()] = reduce_temp0[0]
+
+
+@T.prim_func
+def multiple_bufferstore(a: T.handle, b: T.handle) -> None:
+    A = T.match_buffer(a, [128, 128], dtype="float32")
+    B = T.match_buffer(b, [128], dtype="float32")
+    C = T.alloc_buffer([], dtype="float32")
+    for i in T.serial(0, 128):
+        for k in T.thread_binding(0, 128, thread="threadIdx.x"):
+            with T.block("B"):
+                vi, vk = T.axis.remap("SR", [i, k])
+                T.reads([A[vi, vk], B[vi], C[()]])
+                T.writes([B[vi], C[()]])
+                with T.init():
+                    B[vi] = T.float32(0)
+                C[()] = A[vi, vk]
+                B[vi] = B[vi] + C[()]
+
+
+@T.prim_func
+def reduction_loop_not_deepest(a: T.handle, b: T.handle) -> None:
+    A = T.match_buffer(a, [128, 128], dtype="float32")
+    B = T.match_buffer(b, [128], dtype="float32")
+    for k in T.thread_binding(0, 128, thread="threadIdx.x"):
+        for i in T.serial(0, 128):
+            with T.block("B"):
+                vi, vk = T.axis.remap("SR", [i, k])
+                T.reads([B[vi], A[vi, vk]])
+                T.writes([B[vi]])
+                with T.init():
+                    B[vi] = T.float32(0)
+                B[vi] = B[vi] + A[vi, vk]
+
+
+@T.prim_func
+def reduction_loop_bound_to_blockidx(a: T.handle, b: T.handle) -> None:
+    A = T.match_buffer(a, [128, 128], dtype="float32")
+    B = T.match_buffer(b, [128], dtype="float32")
+    for i in T.serial(0, 128):
+        for k in T.thread_binding(0, 128, thread="blockIdx.x"):
+            with T.block("B"):
+                vi, vk = T.axis.remap("SR", [i, k])
+                T.reads([B[vi], A[vi, vk]])
+                T.writes([B[vi]])
+                with T.init():
+                    B[vi] = T.float32(0)
+                B[vi] = B[vi] + A[vi, vk]
+
+
+@T.prim_func
+def different_access_indices(a: T.handle, b: T.handle) -> None:
+    A = T.match_buffer(a, [128, 128, 128], dtype="float32")
+    B = T.match_buffer(b, [128, 128], dtype="float32")
+    for i, j in T.grid(128, 128):
+        for k in T.thread_binding(0, 128, thread="threadIdx.x"):
+            with T.block("B"):
+                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
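+                # The init writes B[vj, vi] while the update writes B[vi, vj]; the two
+                # stores index the output differently, which is why this kernel is
+                # expected to be rejected (see test_different_access_indices below).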
T.axis.remap("SSR", [i, j, k]) + T.reads([B[vi, vj], A[vi, vj, vk]]) + T.writes( + [ + B[ + T.min(vj, vi) : T.min(vj, vi) + (T.max(vj, vi) + 1 - T.min(vj, vi)), + T.min(vi, vj) : T.min(vi, vj) + (T.max(vi, vj) + 1 - T.min(vi, vj)), + ] + ] + ) + with T.init(): + B[vj, vi] = T.float32(0) + B[vi, vj] = B[vi, vj] + A[vi, vj, vk] + + +@T.prim_func +def invalid_reducer(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [128, 128], dtype="float32") + B = T.match_buffer(b, [128], dtype="float32") + for i in T.serial(0, 128): + for k in T.thread_binding(0, 128, thread="threadIdx.x"): + with T.block("B"): + vi, vk = T.axis.remap("SR", [i, k]) + T.reads([B[vi], A[vi, vk]]) + T.writes([B[vi]]) + with T.init(): + B[vi] = T.float32(0) + B[vi] = B[vi] - A[vi, vk] + + +@T.prim_func +def softmax(var_A: T.handle, var_T_softmax_norm: T.handle) -> None: + A = T.match_buffer(var_A, [256, 256], dtype="float32") + T_softmax_norm = T.match_buffer(var_T_softmax_norm, [256, 256], dtype="float32") + T_softmax_maxelem_shared = T.alloc_buffer([256], dtype="float32", scope="shared") + T_softmax_expsum_shared = T.alloc_buffer([256], dtype="float32", scope="shared") + for i0 in T.thread_binding(0, 256, thread="blockIdx.x"): + for ax0_0 in T.serial(0, 8): + for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"): + with T.block("T_softmax_maxelem"): + i0_1 = T.axis.spatial(256, i0) + k = T.axis.reduce(256, ax0_0 * 32 + ax0_1) + T.reads([T_softmax_maxelem_shared[i0_1], A[i0_1, k]]) + T.writes([T_softmax_maxelem_shared[i0_1]]) + with T.init(): + T_softmax_maxelem_shared[i0_1] = T.min_value("float32") + T_softmax_maxelem_shared[i0_1] = T.max( + T_softmax_maxelem_shared[i0_1], A[i0_1, k] + ) + for ax0_0 in T.serial(0, 8): + for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"): + with T.block("T_softmax_expsum"): + i0_2 = T.axis.spatial(256, i0) + k = T.axis.reduce(256, ax0_0 * 32 + ax0_1) + T.reads( + [ + T_softmax_expsum_shared[i0_2], + A[i0_2, k], + T_softmax_maxelem_shared[i0_2], + ] + ) + T.writes([T_softmax_expsum_shared[i0_2]]) + with T.init(): + T_softmax_expsum_shared[i0_2] = T.float32(0) + T_softmax_expsum_shared[i0_2] = T_softmax_expsum_shared[i0_2] + T.exp( + A[i0_2, k] - T_softmax_maxelem_shared[i0_2], dtype="float32" + ) + for i1_0 in T.serial(0, 8): + for i1_1 in T.thread_binding(0, 32, thread="threadIdx.x"): + with T.block("T_softmax_norm"): + i0_3 = T.axis.spatial(256, i0) + i1 = T.axis.spatial(256, i1_0 * 32 + i1_1) + T.reads( + [ + A[i0_3, i1], + T_softmax_maxelem_shared[i0_3], + T_softmax_expsum_shared[i0_3], + ] + ) + T.writes([T_softmax_norm[i0_3, i1]]) + T.block_attr({"axis": 1}) + T_softmax_norm[i0_3, i1] = ( + T.exp( + A[i0_3, i1] - T_softmax_maxelem_shared[i0_3], + dtype="float32", + ) + / T_softmax_expsum_shared[i0_3] + ) + + +@T.prim_func +def lowered_softmax(var_A: T.handle, var_T_softmax_norm: T.handle) -> None: + A = T.match_buffer(var_A, [256, 256], dtype="float32") + T_softmax_norm = T.match_buffer(var_T_softmax_norm, [256, 256], dtype="float32") + T_softmax_maxelem_shared = T.alloc_buffer([256], dtype="float32", scope="shared") + T_softmax_expsum_shared = T.alloc_buffer([256], dtype="float32", scope="shared") + reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local") + normal_reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local") + reduce_temp1 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local") + normal_reduce_temp1 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local") + for i0 in T.thread_binding(0, 
+        for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
+            with T.block("T_softmax_maxelem_normal_reduction_init"):
+                T.reads([])
+                T.writes([normal_reduce_temp0[0]])
+                normal_reduce_temp0[0] = T.min_value("float32")
+            for ax0_0 in T.serial(0, 8):
+                with T.block("T_softmax_maxelem_normal_reduction"):
+                    i0_1 = T.axis.spatial(256, i0)
+                    k = T.axis.reduce(256, ax0_0 * 32 + ax0_1)
+                    T.reads([A[i0_1, k], normal_reduce_temp0[0]])
+                    T.writes([normal_reduce_temp0[0]])
+                    normal_reduce_temp0[0] = T.max(normal_reduce_temp0[0], A[i0_1, k])
+            with T.block("T_softmax_maxelem_cross_thread_reduction"):
+                T.reads([normal_reduce_temp0[0]])
+                T.writes([reduce_temp0[0]])
+                T.attr(
+                    T.comm_reducer(lambda x, y: T.max(x, y), [T.min_value("float32")]),
+                    "reduce_scope",
+                    T.reinterpret(T.uint64(0), dtype="handle"),
+                )
+                T.evaluate(
+                    T.tvm_thread_allreduce(
+                        T.uint32(1),
+                        normal_reduce_temp0[0],
+                        True,
+                        reduce_temp0.data,
+                        ax0_1,
+                        dtype="handle",
+                    )
+                )
+            with T.block("T_softmax_maxelem_write_back"):
+                i0_2 = T.axis.spatial(256, i0)
+                T.reads([reduce_temp0[0]])
+                T.writes([T_softmax_maxelem_shared[i0_2]])
+                T_softmax_maxelem_shared[i0_2] = reduce_temp0[0]
+        for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
+            with T.block("T_softmax_expsum_normal_reduction_init"):
+                T.reads([])
+                T.writes([normal_reduce_temp1[0]])
+                normal_reduce_temp1[0] = T.float32(0)
+            for ax0_0 in T.serial(0, 8):
+                with T.block("T_softmax_expsum_normal_reduction"):
+                    i0_3 = T.axis.spatial(256, i0)
+                    k = T.axis.reduce(256, ax0_0 * 32 + ax0_1)
+                    T.reads(
+                        [
+                            A[i0_3, k],
+                            T_softmax_maxelem_shared[i0_3],
+                            normal_reduce_temp1[0],
+                        ]
+                    )
+                    T.writes([normal_reduce_temp1[0]])
+                    normal_reduce_temp1[0] = normal_reduce_temp1[0] + T.exp(
+                        A[i0_3, k] - T_softmax_maxelem_shared[i0_3], dtype="float32"
+                    )
+            with T.block("T_softmax_expsum_cross_thread_reduction"):
+                T.reads([normal_reduce_temp1[0]])
+                T.writes([reduce_temp1[0]])
+                T.attr(
+                    T.comm_reducer(lambda x_1, y_1: x_1 + y_1, [T.float32(0)]),
+                    "reduce_scope",
+                    T.reinterpret(T.uint64(0), dtype="handle"),
+                )
+                T.evaluate(
+                    T.tvm_thread_allreduce(
+                        T.uint32(1),
+                        normal_reduce_temp1[0],
+                        True,
+                        reduce_temp1.data,
+                        ax0_1,
+                        dtype="handle",
+                    )
+                )
+            with T.block("T_softmax_expsum_write_back"):
+                i0_4 = T.axis.spatial(256, i0)
+                T.reads([reduce_temp1[0]])
+                T.writes([T_softmax_expsum_shared[i0_4]])
+                T_softmax_expsum_shared[i0_4] = reduce_temp1[0]
+        for i1_0 in T.serial(0, 8):
+            for i1_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
+                with T.block("T_softmax_norm"):
+                    i0_5 = T.axis.spatial(256, i0)
+                    i1 = T.axis.spatial(256, i1_0 * 32 + i1_1)
+                    T.reads(
+                        [
+                            A[i0_5, i1],
+                            T_softmax_maxelem_shared[i0_5],
+                            T_softmax_expsum_shared[i0_5],
+                        ]
+                    )
+                    T.writes([T_softmax_norm[i0_5, i1]])
+                    T.block_attr({"axis": 1})
+                    T_softmax_norm[i0_5, i1] = (
+                        T.exp(
+                            A[i0_5, i1] - T_softmax_maxelem_shared[i0_5],
+                            dtype="float32",
+                        )
+                        / T_softmax_expsum_shared[i0_5]
+                    )
+
+
+def test_loop_split():
+    _check(loop_split, lowered_loop_split)
+
+
+def test_no_normal_reduction():
+    _check(no_normal_reduction, lowered_no_normal_reduction)
+
+
+def test_two_bound_loops():
+    _check(two_bound_loops, lowered_two_bound_loops)
+
+
+def test_multiple_blocks_under_reduction_loop():
+    _check(multiple_blocks_under_reduction_loop, lowered_multiple_blocks_under_reduction_loop)
+
+
+def test_with_block_predicate():
+    _check(with_block_predicate, lowered_with_block_predicate)
+
+
+def test_reducer_max():
+    _check(reducer_max, lowered_reducer_max)
+
+
+def test_zero_rank_buffer():
+    _check(zero_rank_buffer, lowered_zero_rank_buffer)
+
+
+def test_multiple_bufferstore():
+    _check_fail(multiple_bufferstore)
+
+
+def test_reduction_loop_not_deepest():
+    _check_fail(reduction_loop_not_deepest)
+
+
+def test_reduction_loop_bound_to_blockidx():
+    _check_fail(reduction_loop_bound_to_blockidx)
+
+
+def test_different_access_indices():
+    _check_fail(different_access_indices)
+
+
+def test_invalid_reducer():
+    _check_fail(invalid_reducer)
+
+
+def test_softmax():
+    _check(softmax, lowered_softmax)
+
+
+def test_lower_te():
+    a = te.placeholder((32, 2, 2))
+    k1 = te.reduce_axis((0, 2), "k1")
+    k2 = te.reduce_axis((0, 2), "k2")
+    b = te.compute((32,), lambda i: te.sum(a[i, k1, k2], axis=[k1, k2]))
+    s = te.create_schedule(b.op)
+    s[b].bind(k1, te.thread_axis("threadIdx.x"))
+    s[b].bind(k2, te.thread_axis("threadIdx.y"))
+    orig_mod = tvm.driver.build_module.schedule_to_module(s, [a, b])
+    mod = tvm.tir.transform.LowerCrossThreadReduction()(orig_mod)
+    tvm.ir.assert_structural_equal(
+        mod, orig_mod
+    )  # LowerCrossThreadReduction should do nothing on TE
+
+
+if __name__ == "__main__":
+    sys.exit(pytest.main([__file__] + sys.argv[1:]))
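+# Note: running this file directly forwards any extra command-line arguments to
+# pytest, e.g. `python <this file> -k softmax` to select only the softmax test.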