From a560a9be3d08e5f09130218dae3622ba31be0606 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Mon, 4 Oct 2021 23:21:58 +0800 Subject: [PATCH 1/6] [TensorIR] Cross-Thread Reduction --- include/tvm/tir/transform.h | 7 + python/tvm/tir/transform/transform.py | 12 + src/driver/driver_api.cc | 1 + src/tir/schedule/analysis.h | 51 ++ src/tir/schedule/analysis/analysis.cc | 255 ++++++- src/tir/schedule/primitive/reduction.cc | 140 +--- .../lower_cross_thread_reduction.cc | 590 ++++++++++++++ ..._transform_lower_cross_thread_reduction.py | 722 ++++++++++++++++++ 8 files changed, 1602 insertions(+), 176 deletions(-) create mode 100644 src/tir/transforms/lower_cross_thread_reduction.cc create mode 100644 tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index e6b0af9773d9..7922e978c381 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -357,6 +357,13 @@ TVM_DLL Pass PointerValueTypeRewrite(); */ TVM_DLL Pass HoistIfThenElse(); +/*! + * \brief Lower cross-thread reduction from thread + * bindings to intrinsic function calls. + * \return The pass. + */ +TVM_DLL Pass LowerCrossThreadReduction(); + /*! * \brief Lower block init stmt into IfThenElse stmts * \return The pass. diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index 722810e9aa5b..86f798caceba 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -577,6 +577,18 @@ def HoistIfThenElse(variant: Optional[str] = None): return _ffi_api.HoistIfThenElse() # type: ignore +def LowerCrossThreadReduction(): + """Lower cross-thread reduction from thread bindings to + intrinsic function calls. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.LowerCrossThreadReduction() # type: ignore + + def LowerInitBlock(): """Lower block init stmt into IfThenElse statements. diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index ad1f51ba6d71..f49409c2baee 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -234,6 +234,7 @@ Array CreatePassList(bool disable_loop_partition) { pass_list.push_back(tir::transform::InjectPrefetch()); pass_list.push_back(tir::transform::TextureFlatten()); pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers)); + pass_list.push_back(tir::transform::LowerCrossThreadReduction()); pass_list.push_back(tir::transform::LowerInitBlock()); pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation()); pass_list.push_back(tir::transform::ConvertBlocksToOpaque()); diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h index 5a2f46c910b4..c437293e49c0 100644 --- a/src/tir/schedule/analysis.h +++ b/src/tir/schedule/analysis.h @@ -21,10 +21,14 @@ #include +#include #include #include +#include #include +#include "../../runtime/thread_storage_scope.h" + namespace tvm { namespace tir { @@ -323,6 +327,53 @@ struct ProducerConsumerSplit { */ Buffer GetNthAccessBuffer(const ScheduleState& self, const Block& block, int n, bool is_write); +/******** Reduction Block Related ********/ + +/*! 
+ * \brief Convert the `init` and `body` of the input block to BufferStores + * \tparam in_schedule Whether the function is called by schedule primitives + * \param self The schedule state + * \param block The block to be analyzed + * \return The BufferStores of the `init` and `body` of the input block + * \throw ScheduleError If the `init` or `body` is not BufferStore, or they don't write to the same + * buffer + */ +template +std::pair GetBufferStoresFromReductionBlock(const ScheduleState& self, + const Block& block); + +/*! + * \brief Check whether the input array of IterVars only contains data-parallel and reduction block + * iters + * \param iters The input array of IterVars to be checked + * \return A boolean indicating whether the input array of IterVars only contains data-parallel and + * reduction block iters + */ +bool ContainsOnlyDataParAndReductionBlockIter(const Array& iters); + +/*! + * \brief Check whether the block's reduction block iters are not used to index the block's output + * buffers + * \param block The block to be checked + * \return A boolean indicating whether the block's reduction block iters are not used to index the + * block's output buffer + */ +bool ReductionIterNotIndexOutputBuffer(const Block& block); + +/*! + * \brief Given a reduction identity and a reduction combiner, detect the corresponding commutative + * reducer, and extract the combiner lhs and combiner rhs + * \tparam in_schedule Whether the function is called by schedule primitives + * \param self The schedule state + * \param identity The reduction identity to be analyzed + * \param combiner The reduction combiner to be analyzed + * \return The corresponding CommReducer, the combiner lhs and the combiner rhs + * \throw ScheduleError If no corresponding commutative reducer can be matched + */ +template +std::tuple GetReducerAndCombinerLhsRhs( + const ScheduleState& self, const PrimExpr& identity, const BufferStore& combiner); + /******** Commutative Reducer ********/ /*! diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index e3a535e9b3d4..672fc0f602f4 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -153,15 +153,15 @@ Definition of a scope that is a stage pipeline: /*! * \brief Check the dominant property of a block: * the block is the only writer of its output, dominating the reader of its output buffers - * \param self The schedule state + * \param scope The block-scope of the block to be checked * \param block_sref The block whose dominant property is to be checked * \return A boolean indicating if the block is a dominant block */ -bool IsDominantBlock(const BlockScope& self, const StmtSRef& block_sref) { +bool IsDominantBlock(const BlockScope& scope, const StmtSRef& block_sref) { // Check whether the input block is the only writer of its outputs const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); const std::unordered_map, ObjectPtrHash, ObjectPtrEqual>& buffer_writers = - self->buffer_writers; + scope->buffer_writers; for (const BufferRegion& write_region : block->writes) { ICHECK(buffer_writers.count(write_region->buffer)) << "InternalError: buffer \"" << write_region->buffer->name @@ -279,14 +279,8 @@ int CheckReductionBlockErrorCode(const ScheduleState& self, const StmtSRef& bloc } // Cond 3. All block vars are either data parallel block vars or reduction block vars. Meanwhile, // we collect all the reduction block vars. 
- std::unordered_set reduction_block_vars; - reduction_block_vars.reserve(block->iter_vars.size()); - for (const IterVar& iter_var : block->iter_vars) { - if (iter_var->iter_type != kDataPar && iter_var->iter_type != kCommReduce) { - return 3; - } else if (iter_var->iter_type == kCommReduce) { - reduction_block_vars.insert(iter_var->var.get()); - } + if (!ContainsOnlyDataParAndReductionBlockIter(block->iter_vars)) { + return 3; } // Cond 4. Dominant: the block is the only writer of its output, dominating the reader of its // output buffers. @@ -294,33 +288,7 @@ int CheckReductionBlockErrorCode(const ScheduleState& self, const StmtSRef& bloc return 4; } // Cond 5. The reduction block vars are not used to index the output buffers. - std::unordered_set buffer_written; - buffer_written.reserve(block->writes.size()); - for (const BufferRegion& write_region : block->writes) { - buffer_written.insert(write_region->buffer.get()); - } - bool affected = false; - PreOrderVisit(block->body, [&](const ObjectRef& obj) { - if (affected) { - return false; - } - if (const auto* store = obj.as()) { - ICHECK(buffer_written.count(store->buffer.get())) - << "ValueError: The buffer \"" << store->buffer - << "\" is written in the block but is not in the block's signature"; - for (const PrimExpr& index : store->indices) { - if (UsesVar(index, [&reduction_block_vars](const VarNode* var) { - return reduction_block_vars.count(var); - })) { - affected = true; - return false; - } - } - return false; - } - return true; - }); - return !affected ? 0 : 5; + return ReductionIterNotIndexOutputBuffer(GetRef(block)) ? 0 : 5; } bool IsReductionBlock(const ScheduleState& self, const StmtSRef& block_sref, @@ -552,6 +520,9 @@ bool GetVarsTouchedByBlockIters(const BlockRealize& block_realize, } else { has_block_vars_of_other_types = true; } + if (set == nullptr) { + continue; + } Array vars_in_binding = UndefinedVars(iter_value); for (const Var& var : vars_in_binding) { @@ -1128,6 +1099,214 @@ class PatternMatcher : public ExprVisitor { std::unordered_map filled_map_; }; +/******** Reduction Block Related ********/ + +class InitBodyNotBufferStoreError : public ScheduleError { + public: + explicit InitBodyNotBufferStoreError(IRModule mod, Block block, bool init_is_bufferstore, + bool body_is_bufferstore) + : mod_(std::move(mod)), + block_(std::move(block)), + init_is_bufferstore_(init_is_bufferstore), + body_is_bufferstore_(body_is_bufferstore) {} + + String FastErrorString() const final { + return "ScheduleError: The `init` and `body` of reduction block are required to be both " + "BufferStore so that rfactor or cross-thread reduction can be applied"; + } + + String DetailRenderTemplate() const final { + if (!init_is_bufferstore_ && !body_is_bufferstore_) { + return "The `init` and `body` of block {0} are required to be BufferStore so that rfactor or " + "cross-thread reduction can be applied"; + } else if (!init_is_bufferstore_) { + return "The `init` of block {0} is required to be BufferStore so that rfactor or cross-thread" + " reduction can be applied"; + } else { + ICHECK(!body_is_bufferstore_); + return "The `body` of block {0} is required to be BufferStore so that rfactor or cross-thread" + " reduction can be applied"; + } + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {block_}; } + + IRModule mod_; + Block block_; + bool init_is_bufferstore_; + bool body_is_bufferstore_; +}; + +class InitBodyNotSameBufferAccessError : public ScheduleError { + public: + explicit 
InitBodyNotSameBufferAccessError(IRModule mod, Block block) + : mod_(std::move(mod)), block_(std::move(block)) {} + + String FastErrorString() const final { + return "ScheduleError: The `init` and `body` of the reduction block are required to have the " + "same buffer access pattern"; + } + + String DetailRenderTemplate() const final { + std::ostringstream os; + const auto* init = block_->init.as(); + const auto* update = block_->body.as(); + os << "The `init` and `body` of the block {0} is required to have the same buffer access " + "pattern. However, in block {0} the `init` writes to " + << init->buffer->name << init->indices << ", and the `body` writes to " + << update->buffer->name << update->indices; + return os.str(); + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {block_}; } + + IRModule mod_; + Block block_; +}; + +template +std::pair GetBufferStoresFromReductionBlock(const ScheduleState& self, + const Block& block) { + const char* error_str1 = + "ValueError: The `init` and `body` of the reduction block are required to be both " + "BufferStore so that rfactor or cross-thread reduction can be applied. However, a reduction " + "block that doesn't meet this requirement is "; + const char* error_str2 = + "ValueError: The `init` and `body` of the reduction block are required to have the same " + "buffer access pattern so that rfactor or cross-thread reduction can be applied. However, a " + "reduction block that doesn't meet this requirement is "; + + const auto* init = block->init.as(); + const auto* body = block->body.as(); + if (!(init && body)) { + if (in_schedule) { + throw InitBodyNotBufferStoreError(self->mod, block, init != nullptr, body != nullptr); + } else { + LOG(FATAL) << error_str1 << block; + } + } + if (!init->buffer.same_as(body->buffer)) { + if (in_schedule) { + throw InitBodyNotSameBufferAccessError(self->mod, block); + } else { + LOG(FATAL) << error_str2 << block; + } + } + int ndim = static_cast(init->buffer->shape.size()); + for (int i = 0; i < ndim; ++i) { + if (!ExprDeepEqual()(init->indices[i], body->indices[i])) { + if (in_schedule) { + throw InitBodyNotSameBufferAccessError(self->mod, block); + } else { + LOG(FATAL) << error_str2 << block; + } + } + } + return std::make_pair(GetRef(init), GetRef(body)); +} + +bool ContainsOnlyDataParAndReductionBlockIter(const Array& iters) { + for (const IterVar& iter_var : iters) { + if (iter_var->iter_type != kDataPar && iter_var->iter_type != kCommReduce) { + return false; + } + } + return true; +} + +bool ReductionIterNotIndexOutputBuffer(const Block& block) { + // Step 1. Collect the reduction block iters. + std::unordered_set reduction_block_iters; + reduction_block_iters.reserve(block->iter_vars.size()); + for (const IterVar& iter_var : block->iter_vars) { + if (iter_var->iter_type == kCommReduce) { + reduction_block_iters.insert(iter_var->var.get()); + } + } + // Step 2. Check if the reduction block iters are used to index the output buffer. 
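+  // The code below first collects every buffer written by the block, then visits the block body
+  // and flags any BufferStore to one of those buffers whose indices reference a reduction block
+  // iter; such a block is not treated as a reduction block.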
+ std::unordered_set buffer_written; + buffer_written.reserve(block->writes.size()); + for (const BufferRegion& write_region : block->writes) { + buffer_written.insert(write_region->buffer.get()); + } + bool affected = false; + PreOrderVisit(block->body, [&](const ObjectRef& obj) { + if (affected) { + return false; + } + if (const auto* store = obj.as()) { + ICHECK(buffer_written.count(store->buffer.get())) + << "ValueError: The buffer \"" << store->buffer + << "\" is written in the block but is not in the block's signature"; + for (const PrimExpr& index : store->indices) { + if (UsesVar(index, [&reduction_block_iters](const VarNode* var) { + return reduction_block_iters.count(var); + })) { + affected = true; + return false; + } + } + return false; + } + return true; + }); + return !affected; +} + +class NoMatchedReducerError : public ScheduleError { + public: + explicit NoMatchedReducerError(IRModule mod, PrimExpr identity, BufferStore combiner) + : mod_(std::move(mod)), identity_(std::move(identity)), combiner_(std::move(combiner)) {} + + String FastErrorString() const final { + return "ScheduleError: No matched reducer for the identity and the combiner of this reduction " + "block. So rfactor and cross-thread reduction cannot be applied."; + } + + String DetailRenderTemplate() const final { + std::ostringstream os; + os << "No matched reducer for identity " << identity_ << " and combiner " << combiner_ + << "In this case rfactor cannot be applied. You can check tvm::tir::ReducerRegistry for " + "default reducers or registering new reducers."; + return os.str(); + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {}; } + + IRModule mod_; + PrimExpr identity_; + BufferStore combiner_; +}; + +template +std::tuple GetReducerAndCombinerLhsRhs( + const ScheduleState& self, const PrimExpr& identity, const BufferStore& combiner) { + CommReducer reducer{nullptr}; + PrimExpr combiner_lhs{nullptr}, combiner_rhs{nullptr}; + bool matched = FromIdentityCombiner(identity, combiner, &reducer, &combiner_lhs, &combiner_rhs); + if (!matched) { + if (in_schedule) { + throw NoMatchedReducerError(self->mod, identity, combiner); + } else { + LOG(FATAL) << "ValueError: No matched reducer for the identity and the combiner of the " + "reduction block. 
So rfactor and cross-thread reduction cannot be applied."; + } + } + return std::make_tuple(std::move(reducer), std::move(combiner_lhs), std::move(combiner_rhs)); +} + +template std::pair GetBufferStoresFromReductionBlock( + const ScheduleState& self, const Block& block); +template std::pair GetBufferStoresFromReductionBlock( + const ScheduleState& self, const Block& block); +template std::tuple GetReducerAndCombinerLhsRhs( + const ScheduleState& self, const PrimExpr& identity, const BufferStore& combiner); +template std::tuple GetReducerAndCombinerLhsRhs( + const ScheduleState& self, const PrimExpr& identity, const BufferStore& combiner); + /******** Commutative Reducer ********/ bool MatchReducer(const CommReducer& reducer, const PrimExpr& identity, const PrimExpr& combiner, diff --git a/src/tir/schedule/primitive/reduction.cc b/src/tir/schedule/primitive/reduction.cc index 0f851684de2a..18c4e5da1315 100644 --- a/src/tir/schedule/primitive/reduction.cc +++ b/src/tir/schedule/primitive/reduction.cc @@ -370,69 +370,6 @@ class NotSerialLoopKindError : public ScheduleError { For loop_; }; -class InitBodyNotBufferStoreError : public ScheduleError { - public: - explicit InitBodyNotBufferStoreError(IRModule mod, Block block, bool init_is_bufferstore, - bool body_is_bufferstore) - : mod_(std::move(mod)), - block_(std::move(block)), - init_is_bufferstore_(init_is_bufferstore), - body_is_bufferstore_(body_is_bufferstore) {} - - String FastErrorString() const final { - return "ScheduleError: The `init` and `body` of reduction block are required to be both " - "BufferStore"; - } - - String DetailRenderTemplate() const final { - if (!init_is_bufferstore_ && !body_is_bufferstore_) { - return "The `init` and `body` of block {0} are required to be BufferStore so that rfactor " - "can be applied"; - } else if (!init_is_bufferstore_) { - return "The `init` of block {0} is required to be BufferStore so that rfactor can be applied"; - } else { - ICHECK(!body_is_bufferstore_); - return "The `body` of block {0} is required to be BufferStore so that rfactor can be applied"; - } - } - - IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {block_}; } - - IRModule mod_; - Block block_; - bool init_is_bufferstore_; - bool body_is_bufferstore_; -}; - -class InitBodyNotSameBufferAccessError : public ScheduleError { - public: - explicit InitBodyNotSameBufferAccessError(IRModule mod, Block block) - : mod_(std::move(mod)), block_(std::move(block)) {} - - String FastErrorString() const final { - return "ScheduleError: The `init` and `body` of the reduction block are required to have the " - "same buffer access pattern"; - } - - String DetailRenderTemplate() const final { - std::ostringstream os; - const auto* init = block_->init.as(); - const auto* update = block_->body.as(); - os << "The `init` and `body` of the block {0} is required to have the same buffer access " - "pattern. 
However, in block {0} the `init` writes to " - << init->buffer->name << init->indices << ", and the `body` writes to " - << update->buffer->name << update->indices; - return os.str(); - } - - IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {block_}; } - - IRModule mod_; - Block block_; -}; - class FactorAxisOutOfRangeError : public ScheduleError { public: explicit FactorAxisOutOfRangeError(IRModule mod, Buffer buffer, int factor_axis) @@ -473,32 +410,6 @@ class FactorAxisOutOfRangeError : public ScheduleError { int factor_axis_; }; -class NoMatchedReducerError : public ScheduleError { - public: - explicit NoMatchedReducerError(IRModule mod, PrimExpr identity, BufferStore combiner) - : mod_(std::move(mod)), identity_(std::move(identity)), combiner_(std::move(combiner)) {} - - String FastErrorString() const final { - return "ScheduleError: No matched reducer for the identity and the combiner of this reduction " - "block. So rfactor cannot be applied."; - } - - String DetailRenderTemplate() const final { - std::ostringstream os; - os << "No matched reducer for identity " << identity_ << " and combiner " << combiner_ - << "In this case rfactor cannot be applied. You can check tvm::tir::ReducerRegistry for " - "default reducers or registering new reducers."; - return os.str(); - } - - IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {}; } - - IRModule mod_; - PrimExpr identity_; - BufferStore combiner_; -}; - class LoopPropertyError : public ScheduleError { public: enum ErrorType { @@ -591,53 +502,6 @@ class LoopPropertyError : public ScheduleError { ErrorType error_type_; }; -/*! - * \brief Convert the `init` and `body` of the input block to BufferStores - * \param self The schedule state - * \param block The block to be analyzed - * \return The BufferStores of the `init` and `body` of the input block - * \throw ScheduleError If the `init` or `body` is not BufferStore, or they don't write to the same - * buffer - */ -std::pair GetBufferStoreNodes(const ScheduleState& self, - const Block& block) { - const auto* init = block->init.as(); - const auto* body = block->body.as(); - if (!(init && body)) { - throw InitBodyNotBufferStoreError(self->mod, block, init != nullptr, body != nullptr); - } - if (!init->buffer.same_as(body->buffer)) { - throw InitBodyNotSameBufferAccessError(self->mod, block); - } - int ndim = static_cast(init->buffer->shape.size()); - for (int i = 0; i < ndim; ++i) { - if (!ExprDeepEqual()(init->indices[i], body->indices[i])) { - throw InitBodyNotSameBufferAccessError(self->mod, block); - } - } - return std::make_pair(GetRef(init), GetRef(body)); -} - -/*! 
- * \brief Given a reduction identity and a reduction combiner, detect the corresponding commutative - * reducer, and extract the combiner lhs and combiner rhs - * \param self The schedule state - * \param identity The reduction identity to be analyzed - * \param combiner The reduction combiner to be analyzed - * \return The corresponding CommReducer, the combiner lhs and the combiner rhs - * \throw ScheduleError If no corresponding commutative reducer can be matched - */ -std::tuple GetReducerAndCombinerLhsRhs( - const ScheduleState& self, const PrimExpr& identity, const BufferStore& combiner) { - CommReducer reducer{nullptr}; - PrimExpr combiner_lhs{nullptr}, combiner_rhs{nullptr}; - bool matched = FromIdentityCombiner(identity, combiner, &reducer, &combiner_lhs, &combiner_rhs); - if (!matched) { - throw NoMatchedReducerError(self->mod, identity, combiner); - } - return std::make_tuple(std::move(reducer), std::move(combiner_lhs), std::move(combiner_rhs)); -} - /*! * \brief For each loop in the given array of loop, associate its loop var with the loop itself * using a mapping @@ -1177,9 +1041,9 @@ StmtSRef RFactor(ScheduleState self, const StmtSRef& rf_loop_sref, int factor_ax BufferStore update; CommReducer reducer; PrimExpr combiner_lhs, combiner_rhs; - std::tie(init, update) = GetBufferStoreNodes(self, block); + std::tie(init, update) = GetBufferStoresFromReductionBlock(self, block); std::tie(reducer, combiner_lhs, combiner_rhs) = - GetReducerAndCombinerLhsRhs(self, init->value, update); + GetReducerAndCombinerLhsRhs(self, init->value, update); // Step 6. Check whether `factor_axis` is in a correct range, and convert it to non-negative if it // is negative. diff --git a/src/tir/transforms/lower_cross_thread_reduction.cc b/src/tir/transforms/lower_cross_thread_reduction.cc new file mode 100644 index 000000000000..d15a5f16a783 --- /dev/null +++ b/src/tir/transforms/lower_cross_thread_reduction.cc @@ -0,0 +1,590 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file lower_cross_thread_reduction.cc + */ +#include +#include +#include +#include + +#include "../schedule/analysis.h" +#include "ir_utils.h" + +namespace tvm { +namespace tir { + +/*! + * \brief Check the dominant property of a block: + * the block is the only writer of its output, dominating the reader of its output buffers + * \param scope_block The scope block of the block to be checked + * \param block The block whose dominant property is to be checked + * \return A boolean indicating if the block is a dominant block + */ +bool IsDominantBlock(const Block& scope_block, const Block& block) { + // Step 1. Count the number of writers for each buffer written by the scope block. 
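+  // Unlike the ScheduleState-based IsDominantBlock in src/tir/schedule/analysis.cc, no
+  // BlockScope (and hence no precomputed buffer_writers map) is available in this pass, so the
+  // writer counts are recomputed by visiting the blocks nested under the scope block's body.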
+ std::unordered_map buffer_writer_cnt; + PreOrderVisit(scope_block->body, [&buffer_writer_cnt](const ObjectRef& obj) { + if (const auto* block = obj.as()) { + for (const BufferRegion& buffer_region : block->writes) { + ++buffer_writer_cnt[buffer_region->buffer.get()]; + } + return false; + } + return true; + }); + // Step 2. Check whether `block` is the only writer of its outputs. + for (const BufferRegion& buffer_region : block->writes) { + ICHECK(buffer_writer_cnt.count(buffer_region->buffer.get())); + if (buffer_writer_cnt[buffer_region->buffer.get()] != 1) { + return false; + } + } + return true; +} + +/*! + * \brief Check whether the input block is a reduction block. + * \param block_realize The block to be checked + * \param loop_range_map The mapping from the loop variables outside the input block to their ranges + * \param scope_block The scope block of the input block + * \param analyzer The analyzer + * \return A boolean indicating whether the input block is a reduction block. + * \note A similar check has been implemented in "src/tir/schedule/analysis.h", but that check is + * based on `tir.Schedule`. Here we have no schedule information, and thus we must implement the + * check again. + */ +bool IsReductionBlock(const BlockRealize& block_realize, const Map& loop_range_map, + const Block& scope_block, arith::Analyzer* analyzer) { + const auto* block = block_realize->block.as(); + // Cond 1. The block has the `init` statement. + if (!block->init.defined()) { + return false; + } + // Cond 2. All the block bindings are quasi-affine expressions. + if (!IsAffineBinding(block_realize, loop_range_map, analyzer)) { + return false; + } + // Cond 3. All block vars are either data parallel block vars or reduction block vars. Meanwhile, + // we collect all the reduction block vars. + if (!ContainsOnlyDataParAndReductionBlockIter(block->iter_vars)) { + return false; + } + // Cond 4. Dominant: the block is the only writer of its output, dominating the reader of its + // output buffers. + if (!IsDominantBlock(scope_block, GetRef(block))) { + return false; + } + // Cond 5. The reduction block vars are not used to index the output buffers. + return ReductionIterNotIndexOutputBuffer(GetRef(block)); +} + +/*! + * \brief Create an intermediate buffer with specified name and data type + * \param name The specified name + * \param dtype The specified data type + * \return The created buffer + */ +Buffer CreateReductionBuffer(String name, const DataType& dtype) { + Var var(name, PointerType(PrimType(dtype), "local")); + return Buffer(var, dtype, {1}, {1}, PrimExpr(), std::move(name), 0, 0, kDefault); +} + +/*! + * \brief Remove the BufferRegions whose buffer is the input buffer + * \param buffer_regions The array of BufferRegions to be + * \param buffer_to_remove The specified buffer + * \return The mutated array of BufferRegions, no longer containing BufferRegion of the input buffer + */ +Array RemoveBufferFromBufferRegions(const Array& buffer_regions, + const Buffer& buffer_to_remove) { + Array res; + res.reserve(buffer_regions.size()); + for (const BufferRegion& buffer_region : buffer_regions) { + if (!buffer_region->buffer.same_as(buffer_to_remove)) { + res.push_back(buffer_region); + } + } + return res; +} + +/*! 
+ * \brief Substitute a given source buffer with a given target buffer in statements or expressions + */ +class BufferAccessReplacer : public StmtExprMutator { + public: + explicit BufferAccessReplacer(Buffer src_buffer, Buffer tgt_buffer) + : src_buffer_(std::move(src_buffer)), tgt_buffer_(std::move(tgt_buffer)) {} + + private: + PrimExpr VisitExpr_(const BufferLoadNode* load) final { + return load->buffer.same_as(src_buffer_) ? BufferLoad(tgt_buffer_, {0}) + : GetRef(load); + } + + Stmt VisitStmt_(const BufferStoreNode* store) final { + if (store->buffer.same_as(src_buffer_)) { + PrimExpr value = StmtExprMutator::VisitExpr(store->value); + return BufferStore(tgt_buffer_, value, {0}); + } else { + return StmtMutator::VisitStmt_(store); + } + } + + Buffer src_buffer_; + Buffer tgt_buffer_; +}; + +/*! + * \brief Substitute a given source block with a given target block, or remove the source block + * branch from the AST if the target block is undefined + */ +class ReductionBlockReplacer : public StmtMutator { + public: + explicit ReductionBlockReplacer(const BlockRealizeNode* src_block, BlockRealize tgt_block) + : src_block_(src_block), tgt_block_(std::move(tgt_block)) {} + + private: + Stmt VisitStmt_(const BlockRealizeNode* block_realize) final { + return block_realize == src_block_ ? tgt_block_ : GetRef(block_realize); + } + + Stmt VisitStmt_(const ForNode* loop) final { + For res = Downcast(StmtMutator::VisitStmt_(loop)); + return !res.defined() ? Stmt{nullptr} : (res->thread_binding.defined() ? res->body : res); + } + + Stmt VisitStmt_(const SeqStmtNode* seq) final { + Array results; + results.reserve(seq->size()); + for (const Stmt& stmt : seq->seq) { + Stmt res = StmtMutator::VisitStmt(stmt); + if (res.defined()) { + results.push_back(res); + } + } + return results.empty() ? Stmt{nullptr} : SeqStmt(results); + } + + const BlockRealizeNode* src_block_; + BlockRealize tgt_block_; +}; + +/*! + * \brief Detect cross-thread reduction pattern and then transform + */ +class CrossThreadReductionTransformer : public StmtMutator { + private: + // Check if the input block needs cross-thread reduction. + bool NeedCrossThreadReduction(const BlockRealizeNode* block_realize) { + // Step 0. If the block is the root block, just return. + if (block_stack_.empty()) { + return false; + } + + // Step 1. If the block is not a reduction block, cross-thread reduction is not needed. + if (!IsReductionBlock(GetRef(block_realize), loop_range_map_, + GetRef(block_stack_.back()), &analyzer_)) { + return false; + } + + // Step 2. Collect all the vars that appear in the bindings of reduction block iters. + std::unordered_set reduction_vars; + GetVarsTouchedByBlockIters(GetRef(block_realize), nullptr, &reduction_vars); + + // Step 3. Collect the loops whose loop vars appear in the bindings of reduction block iters. + // We call these loops "reduction-related". + // Step 4. See whether at least one reduction-related loop is bound to thread axis in GPU - if + // so, cross-thread reduction is needed. If none of the reduction-related loops is bound to + // thread axis, cross-thread reduction is not needed for the input block. + bool need = false; + reduction_loops_.clear(); + for (const ForNode* loop : loop_stack_) { + if (reduction_vars.count(loop->loop_var.get())) { + // Step 3. Collect the loop. + reduction_loops_.push_back(loop); + // Step 4. See whether the loop is bound to some thread axis. 
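+        // For example, a loop written in TVMScript as
+        //   `for ki in T.thread_binding(0, 32, thread="threadIdx.x")`
+        // carries a thread binding, so a reduction over `ki` must be lowered to a cross-thread
+        // reduction.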
+ if (loop->thread_binding.defined()) { + need = true; + } + } + } + + return need; + } + + // Given that the input block needs cross-thread reduction, check if cross-thread reduction can + // be applied to the block (i.e., the block satisfies all necessary conditions of cross-thread + // reduction). + void CheckCanApplyCrossThreadReduction(const BlockNode* block) { + const String& block_name = block->name_hint; + + // Condition 1. The block being applied cross-thread reduction should write to single buffer. + int n_write_buffer = static_cast(block->writes.size()); + CHECK_EQ(n_write_buffer, 1) << "ValueError: Cross-thread reduction requires the block to only " + "write to single buffer. However, the block " + << block_name << " writes to " << n_write_buffer << " buffer(s)."; + + // Condition 2. All the reduction-related loops should be the deepest among all statements + // outside the block (ignoring SeqStmt here). + int n_deepest_reduction_loops = 0; + for (auto rit = statement_stack_.rbegin() + 1; rit != statement_stack_.rend(); ++rit) { + if ((*rit)->IsInstance()) { + // Skip SeqStmt. + continue; + } + if (std::find(reduction_loops_.begin(), reduction_loops_.end(), + reinterpret_cast(*rit)) == reduction_loops_.end()) { + break; + } + ++n_deepest_reduction_loops; + } + CHECK_EQ(n_deepest_reduction_loops, reduction_loops_.size()) + << "ValueError: Cross-thread reduction requires all the reduction-related loops to be the " + "deepest among all statements outside the desired block. However, block " + << block_name + << " needs cross-thread reduction, while the reduction-related loops outside of it are not " + "the deepest statements, which violates the condition."; + + // Condition 3. All the reduction-related loops that are bound to thread axes should only be + // bound to `threadIdx.x/y/z`. + n_bound_reduction_loops_ = 0; + for (const ForNode* reduction_loop : reduction_loops_) { + if (reduction_loop->thread_binding.defined()) { + ++n_bound_reduction_loops_; + const String& thread_tag = reduction_loop->thread_binding.value()->thread_tag; + CHECK(thread_tag == "threadIdx.x" || thread_tag == "threadIdx.y" || + thread_tag == "threadIdx.z") + << "ValueError: Cross-thread reduction requires all the reduction-related loops that " + "are bound to GPU thread axes to only be bound `threadIdx.x/y/z`. However, loop " + << reduction_loop->loop_var->name_hint << " is bound to " << thread_tag + << ", which violates the condition."; + } + } + + // Condition 4. Get the `init` identity and the `update` combiner of the reduction. They should + // both be BufferStores with the same buffer and indices. + BufferStore init; + BufferStore update; + std::tie(init, update) = + GetBufferStoresFromReductionBlock(ScheduleState{nullptr}, GetRef(block)); + + // Condition 5. Extract the commutative reducer, combiner lhs and combiner rhs from the + // reduction identity and the reduction combiner. + PrimExpr combiner_lhs; + std::tie(reducer_, combiner_lhs, combiner_rhs_) = + GetReducerAndCombinerLhsRhs(ScheduleState{nullptr}, init->value, update); + + // Condition 6. The block should be the last block under the first reduction-related loop. 
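+    // The visit below walks the block-realizes under the first reduction-related loop in
+    // program order: once the reduction block has been seen, encountering any further
+    // block-realize means the reduction block is not the last one, and the check fails.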
+ bool visit = false; + PreOrderVisit(GetRef(reduction_loops_[0]), [block, &visit](const ObjectRef& obj) { + if (const auto* realize = obj.as()) { + CHECK(!visit) << "ValueError: Cross-thread reduction cannot be applied when the reduction " + "block isn't the last block under its first reduction-related loop"; + if (realize->block.get() == block) { + visit = true; + } + return false; + } + return true; + }); + } + + void TransformReductionBlock(const BlockRealizeNode* block_realize, bool with_normal_reduction) { + const BlockNode* block = block_realize->block.get(); + Buffer result_buffer = block->writes[0]->buffer; + + BufferRegion ct_reduction_buf_region(cross_thread_reduction_buf_, {Range::FromMinExtent(0, 1)}); + BufferRegion normal_reduction_buf_region{nullptr}; + if (with_normal_reduction) { + normal_reduction_buf_region = + BufferRegion(normal_reduction_buf_, {Range::FromMinExtent(0, 1)}); + } + + Array seq_stmt; + seq_stmt.reserve(4); + + if (with_normal_reduction) { + // Step 1. Create the BufferStore which initializes `normal_reduction_buf_`. + seq_stmt.push_back(BufferStore(/*buffer=*/normal_reduction_buf_, + /*value=*/block->init.value().as()->value, + /*indices=*/{0})); + + // Step 2. Create the block and loops which do the normal reduction. + // - 2.1. Create the block. + ObjectPtr p_new_block = make_object(*block); + { + p_new_block->reads = RemoveBufferFromBufferRegions(block->reads, result_buffer); + p_new_block->reads.push_back(normal_reduction_buf_region); + p_new_block->writes = {normal_reduction_buf_region}; + p_new_block->name_hint = block->name_hint + "_normal_reduction"; + p_new_block->body = BufferAccessReplacer(result_buffer, normal_reduction_buf_)(block->body); + p_new_block->init = NullOpt; + } + // - 2.2. Create the block-realize. + ObjectPtr p_new_block_realize = + make_object(*block_realize); + p_new_block_realize->block = Block(p_new_block); + // - 2.3. Replace the original reduction block with the normal reduction block. + Stmt replace_result = ReductionBlockReplacer( + block_realize, BlockRealize(p_new_block_realize))(GetRef(reduction_loops_[0])); + ICHECK(replace_result.defined()); + seq_stmt.push_back(replace_result); + } else { + // Remove the original reduction block. + Stmt replace_result = ReductionBlockReplacer( + block_realize, BlockRealize{nullptr})(GetRef(reduction_loops_[0])); + if (replace_result.defined()) { + seq_stmt.push_back(replace_result); + } + } + + // Step 3. Create the statement which calls the intrinsic and does the cross-thread reduction. + // - 3.1. Create the intrinsic parameters. + PrimExpr reduction_value = + with_normal_reduction ? BufferLoad(normal_reduction_buf_, {0}) : combiner_rhs_; + Array parameters{make_const(DataType::UInt(32), static_cast(1)), + std::move(reduction_value), const_true(), + cross_thread_reduction_buf_->data}; + parameters.reserve(reduction_loops_.size() + 4); + for (const ForNode* reduction_loop : reduction_loops_) { + if (reduction_loop->thread_binding.defined()) { + parameters.push_back(reduction_loop->loop_var); + } + } + // - 3.2. Create the intrinsic call and the block body. + AttrStmt ct_reduction_body(/*node=*/reducer_, + /*attr_key=*/tir::attr::reduce_scope, + /*value=*/make_zero(DataType::Handle()), + /*body=*/ + Evaluate(Call(/*dtype=*/DataType::Handle(), + /*op=*/tir::builtin::tvm_thread_allreduce(), + /*args=*/std::move(parameters)))); + // - 3.3. Create the block and the block-realize. 
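+    // When a normal (in-thread) reduction stage exists, the cross-thread reduction block only
+    // reads the single-element normal reduction buffer; otherwise it inherits the original
+    // block's iters, bindings and reads (with the result buffer removed).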
+ { + Array iters; + Array bindings; + Array reads{nullptr}; + if (with_normal_reduction) { + reads = {std::move(normal_reduction_buf_region)}; + } else { + iters = block->iter_vars; + bindings = block_realize->iter_values; + reads = {RemoveBufferFromBufferRegions(block->reads, result_buffer)}; + } + Block ct_reduction_block(/*iter_vars=*/std::move(iters), + /*reads=*/std::move(reads), + /*writes=*/{ct_reduction_buf_region}, + /*name_hint=*/block->name_hint + "_cross_thread_reduction", + /*body=*/std::move(ct_reduction_body)); + seq_stmt.push_back(BlockRealize(/*iter_values=*/std::move(bindings), + /*predicate=*/const_true(), + /*block=*/std::move(ct_reduction_block))); + } + + // Step 4. Create the block which writes the cross-thread reduction result back to the original + // result buffer. + // - 4.1. Create the block iters and their corresponding iter bindings. + ICHECK_EQ(block->iter_vars.size(), block_realize->iter_values.size()); + int n_iter = static_cast(block->iter_vars.size()); + Array write_back_block_iters; + Array write_back_block_bindings; + std::unordered_map write_back_block_var_map; + write_back_block_iters.reserve(n_iter); + write_back_block_bindings.reserve(n_iter); + write_back_block_var_map.reserve(n_iter); + for (int i = 0; i < n_iter; ++i) { + IterVar iter = block->iter_vars[i]; + PrimExpr binding = block_realize->iter_values[i]; + if (iter->iter_type == kCommReduce) { + continue; + } + ObjectPtr p_new_iter = make_object(*iter.get()); + p_new_iter->var = Var(make_object(*iter->var.get())); + IterVar new_iter(p_new_iter); + write_back_block_iters.push_back(new_iter); + write_back_block_bindings.push_back(binding); + write_back_block_var_map[iter->var.get()] = std::move(new_iter); + } + // - 4.2. Create the body of the write-back block. + const auto* old_reduction_body = block->body.as(); + BufferStore write_back_body(/*buffer=*/std::move(result_buffer), + /*value=*/BufferLoad(cross_thread_reduction_buf_, {0}), + /*indices=*/old_reduction_body->indices); + // - 4.3. Create the block and block-realize. + Block write_back_block(/*iter_vars=*/std::move(write_back_block_iters), + /*reads=*/{std::move(ct_reduction_buf_region)}, + /*writes=*/block->writes, + /*name_hint=*/block->name_hint + "_write_back", + /*body=*/std::move(write_back_body)); + write_back_block = + Downcast(Substitute(Stmt(write_back_block), write_back_block_var_map)); + seq_stmt.push_back(BlockRealize(/*iter_values=*/std::move(write_back_block_bindings), + /*predicate=*/const_true(), + /*block=*/std::move(write_back_block))); + + // Step 5. Wrap all the above four statements with the reduction loops were bound. + Stmt new_stmt = SeqStmt::Flatten(seq_stmt); + for (auto rit = reduction_loops_.rbegin(); rit != reduction_loops_.rend(); ++rit) { + if ((*rit)->thread_binding.defined()) { + ObjectPtr p_new_loop = make_object(*(*rit)); + p_new_loop->body = std::move(new_stmt); + new_stmt = For(p_new_loop); + } + } + + // Step 6. Replace the first reduction-related loop the new statement. 
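+    // The replacement is only recorded in `loop2new_stmt_` here; it is materialized when
+    // VisitStmt_(const ForNode*) returns to that loop.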
+ loop2new_stmt_[reduction_loops_[0]] = std::move(new_stmt); + } + + Stmt VisitStmt(const Stmt& stmt) final { + statement_stack_.push_back(stmt.get()); + Stmt result = StmtMutator::VisitStmt(stmt); + statement_stack_.pop_back(); + return result; + } + + Stmt VisitStmt_(const ForNode* loop) final { + loop_stack_.push_back(loop); + loop_range_map_.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); + Stmt result = StmtMutator::VisitStmt_(loop); + loop_stack_.pop_back(); + loop_range_map_.erase(loop->loop_var); + + // Replace `result` with the pre-stored result if `loop` appears as a key in `loop2new_stmt_`. + auto it = loop2new_stmt_.find(loop); + if (it != loop2new_stmt_.end()) { + return it->second; + } else { + return result; + } + } + + Stmt VisitStmt_(const BlockNode* block) final { + Map old_loop_range_map; + + block_stack_.push_back(block); + std::swap(old_loop_range_map, loop_range_map_); + Block new_block = Downcast(StmtMutator::VisitStmt_(block)); + block_stack_.pop_back(); + std::swap(old_loop_range_map, loop_range_map_); + + // Insert the new allocated buffers into the block's `alloc_buffers` field. + auto it = block2new_buffers_.find(block); + if (it != block2new_buffers_.end()) { + BlockNode* p_new_block = new_block.CopyOnWrite(); + for (const Buffer& new_buffer : it->second) { + if (new_buffer.defined()) { + p_new_block->alloc_buffers.push_back(new_buffer); + } + } + } + return new_block; + } + + Stmt VisitStmt_(const BlockRealizeNode* block_realize) final { + const BlockNode* block = block_realize->block.get(); + + // Step 1. Check whether cross-thread reduction is needed. If no, skip this block. + if (!NeedCrossThreadReduction(block_realize)) { + return StmtMutator::VisitStmt_(block_realize); + } + ++reduction_id_; + + // Step 2. Check whether cross-thread reduction can be applied. If no, throw an exception on + // which condition the block violates. + CheckCanApplyCrossThreadReduction(block); + + // Step 3. When not all the reduction-related loops are bound to thread axes, normal reduction + // is needed in this cross-thread reduction. + bool need_normal_reduction = + n_bound_reduction_loops_ < static_cast(reduction_loops_.size()); + + // Step 4. Create intermediate buffers, storing them in `cross_thread_reduction_buf_` and + // `normal_reduction_buf_`. Let the scope block allocate these new buffers. + std::vector& new_buffers = block2new_buffers_[block_stack_.back()]; + DataType dtype = block->writes[0]->buffer->dtype; + cross_thread_reduction_buf_ = + CreateReductionBuffer("reduce_temp" + std::to_string(reduction_id_), dtype); + new_buffers.push_back(cross_thread_reduction_buf_); + if (need_normal_reduction) { + normal_reduction_buf_ = + CreateReductionBuffer("normal_reduce_temp" + std::to_string(reduction_id_), dtype); + new_buffers.push_back(normal_reduction_buf_); + } + + // Step 5. Transform. + TransformReductionBlock(block_realize, need_normal_reduction); + + // Step 6. Return an empty statement, because the transformation result will be inserted when + // returning to the first reduction-related loop. 
+ return Stmt{nullptr}; + } + + private: + int reduction_id_ = -1; + + std::vector statement_stack_; + std::vector loop_stack_; + std::vector block_stack_; + Map loop_range_map_; + + int n_bound_reduction_loops_ = 0; + std::vector reduction_loops_; + + CommReducer reducer_; + PrimExpr combiner_rhs_; + + Buffer cross_thread_reduction_buf_; + Buffer normal_reduction_buf_; + + std::unordered_map> block2new_buffers_; + std::unordered_map loop2new_stmt_; + + arith::Analyzer analyzer_; +}; + +PrimFunc LowerCrossThreadReduction(PrimFunc f) { + // Only apply this pass to TIR that is not from TE schedules + if (!IsFromLegacyTESchedule(f)) { + PrimFuncNode* fptr = f.CopyOnWrite(); + fptr->body = CrossThreadReductionTransformer()(f->body); + return f; + } else { + return f; + } +} + +namespace transform { + +Pass LowerCrossThreadReduction() { + auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { + return LowerCrossThreadReduction(std::move(f)); + }; + return CreatePrimFuncPass(pass_func, 0, "tir.LowerCrossThreadReduction", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.LowerCrossThreadReduction") + .set_body_typed(LowerCrossThreadReduction); + +} // namespace transform + +} // namespace tir +} // namespace tvm diff --git a/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py b/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py new file mode 100644 index 000000000000..353386e2f27a --- /dev/null +++ b/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py @@ -0,0 +1,722 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
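+# Each test below pairs an input PrimFunc with the PrimFunc expected after applying
+# tir.transform.LowerCrossThreadReduction (e.g. `loop_split` vs. `lowered_loop_split`):
+# `_check` runs the pass and compares the result structurally, while `_check_fail` asserts
+# that the pass rejects an unsupported input.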
+import sys + +import pytest +import tvm +from tvm import te +from tvm.script import tir as T + + +def _check(original, transformed): + mod = tvm.IRModule.from_expr(original) + mod = tvm.tir.transform.LowerCrossThreadReduction()(mod) + tvm.ir.assert_structural_equal(mod["main"], transformed, True) + + +def _check_fail(original): + mod = tvm.IRModule.from_expr(original) + with pytest.raises(ValueError): + tvm.tir.transform.LowerCrossThreadReduction()(mod) + + +@T.prim_func +def loop_split(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [128, 128], dtype="float32") + B = T.match_buffer(b, [128], dtype="float32") + for i, ko in T.grid(128, 4): + for ki in T.thread_binding(0, 32, thread="threadIdx.x"): + with T.block("B"): + vi = T.axis.S(128, i) + vk = T.axis.R(128, ko * 32 + ki) + T.reads([B[vi], A[vi, vk]]) + T.writes([B[vi]]) + with T.init(): + B[vi] = T.float32(0) + B[vi] = B[vi] + A[vi, vk] + + +@T.prim_func +def lowered_loop_split(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [128, 128], dtype="float32") + B = T.match_buffer(b, [128], dtype="float32") + reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local") + normal_reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local") + for i in T.serial(0, 128): + for ki in T.thread_binding(0, 32, thread="threadIdx.x"): + normal_reduce_temp0[0] = T.float32(0) + for ko in T.serial(0, 4): + with T.block("B_normal_reduction"): + vi = T.axis.S(128, i) + vk = T.axis.R(128, ko * 32 + ki) + T.reads([A[vi, vk], normal_reduce_temp0[0]]) + T.writes([normal_reduce_temp0[0]]) + normal_reduce_temp0[0] = normal_reduce_temp0[0] + A[vi, vk] + with T.block("B_cross_thread_reduction"): + T.reads([normal_reduce_temp0[0]]) + T.writes([reduce_temp0[0]]) + T.attr( + T.comm_reducer(lambda x, y: x + y, [T.float32(0)]), + "reduce_scope", + T.reinterpret(T.uint64(0), dtype="handle"), + ) + T.evaluate( + T.tvm_thread_allreduce( + T.uint32(1), + normal_reduce_temp0[0], + True, + reduce_temp0.data, + ki, + dtype="handle", + ) + ) + with T.block("B_write_back"): + vi = T.axis.S(128, i) + T.reads([reduce_temp0[0]]) + T.writes([B[vi]]) + B[vi] = reduce_temp0[0] + + +@T.prim_func +def no_normal_reduction(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [128, 128], dtype="float32") + B = T.match_buffer(b, [128], dtype="float32") + for i in T.serial(0, 128): + for k in T.thread_binding(0, 128, thread="threadIdx.x"): + with T.block("B"): + vi, vk = T.axis.remap("SR", [i, k]) + T.reads([B[vi], A[vi, vk]]) + T.writes([B[vi]]) + with T.init(): + B[vi] = T.float32(0) + B[vi] = B[vi] + A[vi, vk] + + +@T.prim_func +def lowered_no_normal_reduction(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [128, 128], dtype="float32") + B = T.match_buffer(b, [128], dtype="float32") + reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local") + for i in T.serial(0, 128): + for k in T.thread_binding(0, 128, thread="threadIdx.x"): + with T.block("B_cross_thread_reduction"): + vi, vk = T.axis.remap("SR", [i, k]) + T.reads([A[vi, vk]]) + T.writes([reduce_temp0[0]]) + T.attr( + T.comm_reducer(lambda x, y: x + y, [T.float32(0)]), + "reduce_scope", + T.reinterpret(T.uint64(0), dtype="handle"), + ) + T.evaluate( + T.tvm_thread_allreduce( + T.uint32(1), A[vi, vk], True, reduce_temp0.data, k, dtype="handle" + ) + ) + with T.block("B_write_back"): + vi = T.axis.spatial(128, i) + T.reads([reduce_temp0[0]]) + T.writes([B[vi]]) + B[vi] = reduce_temp0[0] + + +@T.prim_func +def two_bound_loops(a: T.handle, 
b: T.handle) -> None: + A = T.match_buffer(a, [128, 128], dtype="float32") + B = T.match_buffer(b, [128], dtype="float32") + for i in T.serial(0, 128): + for ko in T.thread_binding(0, 4, thread="threadIdx.x"): + for ki in T.thread_binding(0, 32, thread="threadIdx.y"): + with T.block("B"): + vi = T.axis.spatial(128, i) + vk = T.axis.reduce(128, ko * 32 + ki) + T.reads([B[vi], A[vi, vk]]) + T.writes([B[vi]]) + with T.init(): + B[vi] = T.float32(0) + B[vi] = B[vi] + A[vi, vk] + + +@T.prim_func +def lowered_two_bound_loops(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [128, 128], dtype="float32") + B = T.match_buffer(b, [128], dtype="float32") + reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local") + for i in T.serial(0, 128): + for ko in T.thread_binding(0, 4, thread="threadIdx.x"): + for ki in T.thread_binding(0, 32, thread="threadIdx.y"): + with T.block("B_cross_thread_reduction"): + vi = T.axis.spatial(128, i) + vk = T.axis.reduce(128, ko * 32 + ki) + T.reads([A[vi, vk]]) + T.writes([reduce_temp0[0]]) + T.attr( + T.comm_reducer(lambda x, y: x + y, [T.float32(0)]), + "reduce_scope", + T.reinterpret(T.uint64(0), dtype="handle"), + ) + T.evaluate( + T.tvm_thread_allreduce( + T.uint32(1), A[vi, vk], True, reduce_temp0.data, ko, ki, dtype="handle" + ) + ) + with T.block("B_write_back"): + vi = T.axis.spatial(128, i) + T.reads([reduce_temp0[0]]) + T.writes([B[vi]]) + B[vi] = reduce_temp0[0] + + +@T.prim_func +def multiple_blocks_under_reduction_loop(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [16, 16, 16], dtype="float32") + B = T.match_buffer(b, [16], dtype="float32") + B_rf_local = T.alloc_buffer([16, 16], dtype="float32", scope="local") + for i in T.thread_binding(0, 16, thread="blockIdx.x"): + for k0o in T.thread_binding(0, 4, thread="threadIdx.x"): + for k0i0, k1 in T.grid(4, 16): + with T.block("B_rf"): + vk0 = T.axis.spatial(16, k0o * 4 + k0i0) + vi, vk1 = T.axis.remap("SR", [i, k1]) + T.reads([B_rf_local[vk0, vi], A[vi, vk0, vk1]]) + T.writes([B_rf_local[vk0, vi]]) + with T.init(): + B_rf_local[vk0, vi] = T.float32(0) + B_rf_local[vk0, vi] = B_rf_local[vk0, vi] + A[vi, vk0, vk1] + for k0i1 in T.serial(0, 4): + with T.block("B"): + vk0 = T.axis.reduce(16, k0o * 4 + k0i1) + vi = T.axis.spatial(16, i) + T.reads([B[vi], B_rf_local[vk0, vi]]) + T.writes([B[vi]]) + with T.init(): + B[vi] = T.float32(0) + B[vi] = B[vi] + B_rf_local[vk0, vi] + + +@T.prim_func +def lowered_multiple_blocks_under_reduction_loop(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [16, 16, 16], dtype="float32") + B = T.match_buffer(b, [16], dtype="float32") + B_rf_local = T.alloc_buffer([16, 16], dtype="float32", scope="local") + reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local") + normal_reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local") + for i in T.thread_binding(0, 16, thread="blockIdx.x"): + for k0o in T.thread_binding(0, 4, thread="threadIdx.x"): + normal_reduce_temp0[0] = T.float32(0) + for k0i0, k1 in T.grid(4, 16): + with T.block("B_rf"): + vk0 = T.axis.spatial(16, k0o * 4 + k0i0) + vi, vk1 = T.axis.remap("SR", [i, k1]) + T.reads([B_rf_local[vk0, vi], A[vi, vk0, vk1]]) + T.writes([B_rf_local[vk0, vi]]) + with T.init(): + B_rf_local[vk0, vi] = T.float32(0) + B_rf_local[vk0, vi] = B_rf_local[vk0, vi] + A[vi, vk0, vk1] + for k0i1 in T.serial(0, 4): + with T.block("B_normal_reduction"): + vk0 = T.axis.reduce(16, k0o * 4 + k0i1) + vi = T.axis.spatial(16, i) + T.reads([B_rf_local[vk0, vi], 
normal_reduce_temp0[0]]) + T.writes([normal_reduce_temp0[0]]) + normal_reduce_temp0[0] = normal_reduce_temp0[0] + B_rf_local[vk0, vi] + with T.block("B_cross_thread_reduction"): + T.reads([normal_reduce_temp0[0]]) + T.writes([reduce_temp0[0]]) + T.attr( + T.comm_reducer(lambda x, y: x + y, [T.float32(0)]), + "reduce_scope", + T.reinterpret(T.uint64(0), dtype="handle"), + ) + T.evaluate( + T.tvm_thread_allreduce( + T.uint32(1), + normal_reduce_temp0[0], + True, + reduce_temp0.data, + k0o, + dtype="handle", + ) + ) + with T.block("B_write_back"): + vi = T.axis.spatial(16, i) + T.reads([reduce_temp0[0]]) + T.writes([B[vi]]) + B[vi] = reduce_temp0[0] + + +@T.prim_func +def with_block_predicate(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [128, 120], dtype="float32") + B = T.match_buffer(b, [128], dtype="float32") + for i, ko in T.grid(128, 4): + for ki in T.thread_binding(0, 32, thread="threadIdx.x"): + with T.block("B"): + vi = T.axis.spatial(128, i) + vk = T.axis.reduce(120, ko * 32 + ki) + T.where(ko * 32 + ki < 120) + T.reads([B[vi], A[vi, vk]]) + T.writes([B[vi]]) + with T.init(): + B[vi] = T.float32(0) + B[vi] = B[vi] + A[vi, vk] + + +@T.prim_func +def lowered_with_block_predicate(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [128, 120], dtype="float32") + B = T.match_buffer(b, [128], dtype="float32") + reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local") + normal_reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local") + for i in T.serial(0, 128): + for ki in T.thread_binding(0, 32, thread="threadIdx.x"): + normal_reduce_temp0[0] = T.float32(0) + for ko in T.serial(0, 4): + with T.block("B_normal_reduction"): + vi = T.axis.spatial(128, i) + vk = T.axis.reduce(120, ko * 32 + ki) + T.where(ko * 32 + ki < 120) + T.reads([A[vi, vk], normal_reduce_temp0[0]]) + T.writes([normal_reduce_temp0[0]]) + normal_reduce_temp0[0] = normal_reduce_temp0[0] + A[vi, vk] + with T.block("B_cross_thread_reduction"): + T.reads([normal_reduce_temp0[0]]) + T.writes([reduce_temp0[0]]) + T.attr( + T.comm_reducer(lambda x, y: x + y, [T.float32(0)]), + "reduce_scope", + T.reinterpret(T.uint64(0), dtype="handle"), + ) + T.evaluate( + T.tvm_thread_allreduce( + T.uint32(1), + normal_reduce_temp0[0], + True, + reduce_temp0.data, + ki, + dtype="handle", + ) + ) + with T.block("B_write_back"): + vi = T.axis.spatial(128, i) + T.reads([reduce_temp0[0]]) + T.writes([B[vi]]) + B[vi] = reduce_temp0[0] + + +@T.prim_func +def reducer_max(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [128, 128], dtype="float32") + B = T.match_buffer(b, [128], dtype="float32") + for i in T.serial(0, 128): + for k in T.thread_binding(0, 128, thread="threadIdx.x"): + with T.block("B"): + vi, vk = T.axis.remap("SR", [i, k]) + T.reads([B[vi], A[vi, vk]]) + T.writes([B[vi]]) + with T.init(): + B[vi] = T.min_value("float32") + B[vi] = T.max(B[vi], A[vi, vk]) + + +@T.prim_func +def lowered_reducer_max(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [128, 128], dtype="float32") + B = T.match_buffer(b, [128], dtype="float32") + reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local") + for i in T.serial(0, 128): + for k in T.thread_binding(0, 128, thread="threadIdx.x"): + with T.block("B_cross_thread_reduction"): + vi, vk = T.axis.remap("SR", [i, k]) + T.reads([A[vi, vk]]) + T.writes([reduce_temp0[0]]) + T.attr( + T.comm_reducer(lambda x, y: T.max(x, y), [T.min_value("float32")]), + "reduce_scope", + T.reinterpret(T.uint64(0), 
dtype="handle"), + ) + T.evaluate( + T.tvm_thread_allreduce( + T.uint32(1), A[vi, vk], True, reduce_temp0.data, k, dtype="handle" + ) + ) + with T.block("B_write_back"): + vi = T.axis.spatial(128, i) + T.reads([reduce_temp0[0]]) + T.writes([B[vi]]) + B[vi] = reduce_temp0[0] + + +@T.prim_func +def zero_rank_buffer(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [128], dtype="float32") + B = T.match_buffer(b, [], dtype="float32") + for k in T.thread_binding(0, 128, thread="threadIdx.x"): + with T.block("B"): + vk = T.axis.reduce(128, k) + T.reads([B[()], A[vk]]) + T.writes([B[()]]) + with T.init(): + B[()] = T.float32(0) + B[()] = B[()] + A[vk] + + +@T.prim_func +def lowered_zero_rank_buffer(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [128], dtype="float32") + B = T.match_buffer(b, [], dtype="float32") + reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local") + for k in T.thread_binding(0, 128, thread="threadIdx.x"): + with T.block("B_cross_thread_reduction"): + vk = T.axis.reduce(128, k) + T.reads([A[vk]]) + T.writes([reduce_temp0[0]]) + T.attr( + T.comm_reducer(lambda x, y: x + y, [T.float32(0)]), + "reduce_scope", + T.reinterpret(T.uint64(0), dtype="handle"), + ) + T.evaluate( + T.tvm_thread_allreduce( + T.uint32(1), A[vk], True, reduce_temp0.data, k, dtype="handle" + ) + ) + with T.block("B_write_back"): + T.reads([reduce_temp0[0]]) + T.writes([B[()]]) + B[()] = reduce_temp0[0] + + +@T.prim_func +def multiple_bufferstore(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [128, 128], dtype="float32") + B = T.match_buffer(b, [128], dtype="float32") + C = T.alloc_buffer([], dtype="float32") + for i in T.serial(0, 128): + for k in T.thread_binding(0, 128, thread="threadIdx.x"): + with T.block("B"): + vi, vk = T.axis.remap("SR", [i, k]) + T.reads([A[vi, vk], B[vi], C[()]]) + T.writes([B[vi], C[()]]) + with T.init(): + B[vi] = T.float32(0) + C[()] = A[vi, vk] + B[vi] = B[vi] + C[()] + + +@T.prim_func +def reduction_loop_not_deepest(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [128, 128], dtype="float32") + B = T.match_buffer(b, [128], dtype="float32") + for k in T.thread_binding(0, 128, thread="threadIdx.x"): + for i in T.serial(0, 128): + with T.block("B"): + vi, vk = T.axis.remap("SR", [i, k]) + T.reads([B[vi], A[vi, vk]]) + T.writes([B[vi]]) + with T.init(): + B[vi] = T.float32(0) + B[vi] = B[vi] + A[vi, vk] + + +@T.prim_func +def reduction_loop_bound_to_blockidx(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [128, 128], dtype="float32") + B = T.match_buffer(b, [128], dtype="float32") + for i in T.serial(0, 128): + for k in T.thread_binding(0, 128, thread="blockIdx.x"): + with T.block("B"): + vi, vk = T.axis.remap("SR", [i, k]) + T.reads([B[vi], A[vi, vk]]) + T.writes([B[vi]]) + with T.init(): + B[vi] = T.float32(0) + B[vi] = B[vi] + A[vi, vk] + + +@T.prim_func +def different_access_indices(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [128, 128, 128], dtype="float32") + B = T.match_buffer(b, [128, 128], dtype="float32") + for i, j in T.grid(128, 128): + for k in T.thread_binding(0, 128, thread="threadIdx.x"): + with T.block("B"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + T.reads([B[vi, vj], A[vi, vj, vk]]) + T.writes( + [ + B[ + T.min(vj, vi) : T.min(vj, vi) + (T.max(vj, vi) + 1 - T.min(vj, vi)), + T.min(vi, vj) : T.min(vi, vj) + (T.max(vi, vj) + 1 - T.min(vi, vj)), + ] + ] + ) + with T.init(): + B[vj, vi] = T.float32(0) + B[vi, vj] = B[vi, vj] + A[vi, vj, vk] + + +@T.prim_func +def 
invalid_reducer(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [128, 128], dtype="float32") + B = T.match_buffer(b, [128], dtype="float32") + for i in T.serial(0, 128): + for k in T.thread_binding(0, 128, thread="threadIdx.x"): + with T.block("B"): + vi, vk = T.axis.remap("SR", [i, k]) + T.reads([B[vi], A[vi, vk]]) + T.writes([B[vi]]) + with T.init(): + B[vi] = T.float32(0) + B[vi] = B[vi] - A[vi, vk] + + +@T.prim_func +def softmax(var_A: T.handle, var_T_softmax_norm: T.handle) -> None: + A = T.match_buffer(var_A, [256, 256], dtype="float32") + T_softmax_norm = T.match_buffer(var_T_softmax_norm, [256, 256], dtype="float32") + T_softmax_maxelem_shared = T.alloc_buffer([256], dtype="float32", scope="shared") + T_softmax_expsum_shared = T.alloc_buffer([256], dtype="float32", scope="shared") + for i0 in T.thread_binding(0, 256, thread="blockIdx.x"): + for ax0_0 in T.serial(0, 8): + for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"): + with T.block("T_softmax_maxelem"): + i0_1 = T.axis.spatial(256, i0) + k = T.axis.reduce(256, ax0_0 * 32 + ax0_1) + T.reads([T_softmax_maxelem_shared[i0_1], A[i0_1, k]]) + T.writes([T_softmax_maxelem_shared[i0_1]]) + with T.init(): + T_softmax_maxelem_shared[i0_1] = T.min_value("float32") + T_softmax_maxelem_shared[i0_1] = T.max( + T_softmax_maxelem_shared[i0_1], A[i0_1, k] + ) + for ax0_0 in T.serial(0, 8): + for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"): + with T.block("T_softmax_expsum"): + i0_2 = T.axis.spatial(256, i0) + k = T.axis.reduce(256, ax0_0 * 32 + ax0_1) + T.reads( + [ + T_softmax_expsum_shared[i0_2], + A[i0_2, k], + T_softmax_maxelem_shared[i0_2], + ] + ) + T.writes([T_softmax_expsum_shared[i0_2]]) + with T.init(): + T_softmax_expsum_shared[i0_2] = T.float32(0) + T_softmax_expsum_shared[i0_2] = T_softmax_expsum_shared[i0_2] + T.exp( + A[i0_2, k] - T_softmax_maxelem_shared[i0_2], dtype="float32" + ) + for i1_0 in T.serial(0, 8): + for i1_1 in T.thread_binding(0, 32, thread="threadIdx.x"): + with T.block("T_softmax_norm"): + i0_3 = T.axis.spatial(256, i0) + i1 = T.axis.spatial(256, i1_0 * 32 + i1_1) + T.reads( + [ + A[i0_3, i1], + T_softmax_maxelem_shared[i0_3], + T_softmax_expsum_shared[i0_3], + ] + ) + T.writes([T_softmax_norm[i0_3, i1]]) + T.block_attr({"axis": 1}) + T_softmax_norm[i0_3, i1] = ( + T.exp( + A[i0_3, i1] - T_softmax_maxelem_shared[i0_3], + dtype="float32", + ) + / T_softmax_expsum_shared[i0_3] + ) + + +@T.prim_func +def lowered_softmax(var_A: T.handle, var_T_softmax_norm: T.handle) -> None: + A = T.match_buffer(var_A, [256, 256], dtype="float32") + T_softmax_norm = T.match_buffer(var_T_softmax_norm, [256, 256], dtype="float32") + T_softmax_maxelem_shared = T.alloc_buffer([256], dtype="float32", scope="shared") + T_softmax_expsum_shared = T.alloc_buffer([256], dtype="float32", scope="shared") + reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local") + normal_reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local") + reduce_temp1 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local") + normal_reduce_temp1 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local") + for i0 in T.thread_binding(0, 256, thread="blockIdx.x"): + for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"): + normal_reduce_temp0[0] = T.min_value("float32") + for ax0_0 in T.serial(0, 8): + with T.block("T_softmax_maxelem_normal_reduction"): + i0_1 = T.axis.spatial(256, i0) + k = T.axis.reduce(256, ax0_0 * 32 + ax0_1) + T.reads([A[i0_1, k], normal_reduce_temp0[0]]) 
+ T.writes([normal_reduce_temp0[0]]) + normal_reduce_temp0[0] = T.max(normal_reduce_temp0[0], A[i0_1, k]) + with T.block("T_softmax_maxelem_cross_thread_reduction"): + T.reads([normal_reduce_temp0[0]]) + T.writes([reduce_temp0[0]]) + T.attr( + T.comm_reducer(lambda x, y: T.max(x, y), [T.min_value("float32")]), + "reduce_scope", + T.reinterpret(T.uint64(0), dtype="handle"), + ) + T.evaluate( + T.tvm_thread_allreduce( + T.uint32(1), + normal_reduce_temp0[0], + True, + reduce_temp0.data, + ax0_1, + dtype="handle", + ) + ) + with T.block("T_softmax_maxelem_write_back"): + i0_2 = T.axis.spatial(256, i0) + T.reads([reduce_temp0[0]]) + T.writes([T_softmax_maxelem_shared[i0_2]]) + T_softmax_maxelem_shared[i0_2] = reduce_temp0[0] + for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"): + normal_reduce_temp1[0] = T.float32(0) + for ax0_0 in T.serial(0, 8): + with T.block("T_softmax_expsum_normal_reduction"): + i0_3 = T.axis.spatial(256, i0) + k = T.axis.reduce(256, ax0_0 * 32 + ax0_1) + T.reads( + [ + A[i0_3, k], + T_softmax_maxelem_shared[i0_3], + normal_reduce_temp1[0], + ] + ) + T.writes([normal_reduce_temp1[0]]) + normal_reduce_temp1[0] = normal_reduce_temp1[0] + T.exp( + A[i0_3, k] - T_softmax_maxelem_shared[i0_3], dtype="float32" + ) + with T.block("T_softmax_expsum_cross_thread_reduction"): + T.reads([normal_reduce_temp1[0]]) + T.writes([reduce_temp1[0]]) + T.attr( + T.comm_reducer(lambda x_1, y_1: x_1 + y_1, [T.float32(0)]), + "reduce_scope", + T.reinterpret(T.uint64(0), dtype="handle"), + ) + T.evaluate( + T.tvm_thread_allreduce( + T.uint32(1), + normal_reduce_temp1[0], + True, + reduce_temp1.data, + ax0_1, + dtype="handle", + ) + ) + with T.block("T_softmax_expsum_write_back"): + i0_4 = T.axis.spatial(256, i0) + T.reads([reduce_temp1[0]]) + T.writes([T_softmax_expsum_shared[i0_4]]) + T_softmax_expsum_shared[i0_4] = reduce_temp1[0] + for i1_0 in T.serial(0, 8): + for i1_1 in T.thread_binding(0, 32, thread="threadIdx.x"): + with T.block("T_softmax_norm"): + i0_5 = T.axis.spatial(256, i0) + i1 = T.axis.spatial(256, i1_0 * 32 + i1_1) + T.reads( + [ + A[i0_5, i1], + T_softmax_maxelem_shared[i0_5], + T_softmax_expsum_shared[i0_5], + ] + ) + T.writes([T_softmax_norm[i0_5, i1]]) + T.block_attr({"axis": 1}) + T_softmax_norm[i0_5, i1] = ( + T.exp( + A[i0_5, i1] - T_softmax_maxelem_shared[i0_5], + dtype="float32", + ) + / T_softmax_expsum_shared[i0_5] + ) + + +def test_loop_split(): + _check(loop_split, lowered_loop_split) + + +def test_no_normal_reduction(): + _check(no_normal_reduction, lowered_no_normal_reduction) + + +def test_two_bound_loops(): + _check(two_bound_loops, lowered_two_bound_loops) + + +def test_multiple_blocks_under_reduction_loop(): + _check(multiple_blocks_under_reduction_loop, lowered_multiple_blocks_under_reduction_loop) + + +def test_with_block_predicate(): + _check(with_block_predicate, lowered_with_block_predicate) + + +def test_reducer_max(): + _check(reducer_max, lowered_reducer_max) + + +def test_zero_rank_buffer(): + _check(zero_rank_buffer, lowered_zero_rank_buffer) + + +def test_multiple_bufferstore(): + _check_fail(multiple_bufferstore) + + +def test_reduction_block_not_deepest(): + _check_fail(reduction_loop_not_deepest) + + +def test_reduction_loop_bound_to_blockidx(): + _check_fail(reduction_loop_bound_to_blockidx) + + +def test_different_access_indices(): + _check_fail(different_access_indices) + + +def test_invalid_reducer(): + _check_fail(invalid_reducer) + + +def test_softmax(): + _check(softmax, lowered_softmax) + + +def test_lower_te(): + a = 
te.placeholder((32, 2, 2)) + k1 = te.reduce_axis((0, 2), "k1") + k2 = te.reduce_axis((0, 2), "k2") + b = te.compute((32,), lambda i: te.sum(a[i, k1, k2], axis=[k1, k2])) + s = te.create_schedule(b.op) + s[b].bind(k1, te.thread_axis("threadIdx.x")) + s[b].bind(k2, te.thread_axis("threadIdx.y")) + orig_mod = tvm.driver.build_module.schedule_to_module(s, [a, b]) + mod = tvm.tir.transform.LowerCrossThreadReduction()(orig_mod) + tvm.ir.assert_structural_equal( + mod, orig_mod + ) # LowerCrossThreadReduction should do nothing on TE + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) From 6a162d72eccd0697f1a31f4a0b854940dd9ef702 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Fri, 12 Nov 2021 12:29:06 -0800 Subject: [PATCH 2/6] Code revision on analysis and misc --- src/tir/schedule/analysis.h | 13 ++-- src/tir/schedule/analysis/analysis.cc | 68 ++++++++----------- src/tir/schedule/primitive/reduction.cc | 4 +- .../lower_cross_thread_reduction.cc | 5 +- 4 files changed, 39 insertions(+), 51 deletions(-) diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h index c437293e49c0..42e0e00995fe 100644 --- a/src/tir/schedule/analysis.h +++ b/src/tir/schedule/analysis.h @@ -19,6 +19,7 @@ #ifndef TVM_TIR_SCHEDULE_ANALYSIS_H_ #define TVM_TIR_SCHEDULE_ANALYSIS_H_ +#include #include #include @@ -331,16 +332,14 @@ Buffer GetNthAccessBuffer(const ScheduleState& self, const Block& block, int n, /*! * \brief Convert the `init` and `body` of the input block to BufferStores - * \tparam in_schedule Whether the function is called by schedule primitives * \param self The schedule state * \param block The block to be analyzed * \return The BufferStores of the `init` and `body` of the input block * \throw ScheduleError If the `init` or `body` is not BufferStore, or they don't write to the same * buffer */ -template -std::pair GetBufferStoresFromReductionBlock(const ScheduleState& self, - const Block& block); +std::pair GetBufferStoresFromReductionBlock( + const Optional& self, const Block& block); /*! * \brief Check whether the input array of IterVars only contains data-parallel and reduction block @@ -363,16 +362,14 @@ bool ReductionIterNotIndexOutputBuffer(const Block& block); /*! * \brief Given a reduction identity and a reduction combiner, detect the corresponding commutative * reducer, and extract the combiner lhs and combiner rhs - * \tparam in_schedule Whether the function is called by schedule primitives * \param self The schedule state * \param identity The reduction identity to be analyzed * \param combiner The reduction combiner to be analyzed * \return The corresponding CommReducer, the combiner lhs and the combiner rhs * \throw ScheduleError If no corresponding commutative reducer can be matched */ -template std::tuple GetReducerAndCombinerLhsRhs( - const ScheduleState& self, const PrimExpr& identity, const BufferStore& combiner); + const Optional& self, const PrimExpr& identity, const BufferStore& combiner); /******** Commutative Reducer ********/ @@ -381,7 +378,7 @@ std::tuple GetReducerAndCombinerLhsRhs( * \return The list of the registered reducer-getter functions * \sa ReducerRegistry */ -std::vector> GetReducerGetters(); +std::vector> GetReducerGetters(); /*! 
* \brief Given the input identity and the combiner BufferStore of a reduction, extract the diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index 672fc0f602f4..7e16bc92e4ce 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -523,7 +523,6 @@ bool GetVarsTouchedByBlockIters(const BlockRealize& block_realize, if (set == nullptr) { continue; } - Array vars_in_binding = UndefinedVars(iter_value); for (const Var& var : vars_in_binding) { set->insert(var.get()); @@ -1166,14 +1165,13 @@ class InitBodyNotSameBufferAccessError : public ScheduleError { Block block_; }; -template -std::pair GetBufferStoresFromReductionBlock(const ScheduleState& self, - const Block& block) { - const char* error_str1 = +std::pair GetBufferStoresFromReductionBlock( + const Optional& self, const Block& block) { + static constexpr const char* error_str1 = "ValueError: The `init` and `body` of the reduction block are required to be both " "BufferStore so that rfactor or cross-thread reduction can be applied. However, a reduction " "block that doesn't meet this requirement is "; - const char* error_str2 = + static constexpr const char* error_str2 = "ValueError: The `init` and `body` of the reduction block are required to have the same " "buffer access pattern so that rfactor or cross-thread reduction can be applied. However, a " "reduction block that doesn't meet this requirement is "; @@ -1181,15 +1179,15 @@ std::pair GetBufferStoresFromReductionBlock(const Sche const auto* init = block->init.as(); const auto* body = block->body.as(); if (!(init && body)) { - if (in_schedule) { - throw InitBodyNotBufferStoreError(self->mod, block, init != nullptr, body != nullptr); + if (self.defined()) { + throw InitBodyNotBufferStoreError(self.value()->mod, block, init != nullptr, body != nullptr); } else { LOG(FATAL) << error_str1 << block; } } if (!init->buffer.same_as(body->buffer)) { - if (in_schedule) { - throw InitBodyNotSameBufferAccessError(self->mod, block); + if (self.defined()) { + throw InitBodyNotSameBufferAccessError(self.value()->mod, block); } else { LOG(FATAL) << error_str2 << block; } @@ -1197,8 +1195,8 @@ std::pair GetBufferStoresFromReductionBlock(const Sche int ndim = static_cast(init->buffer->shape.size()); for (int i = 0; i < ndim; ++i) { if (!ExprDeepEqual()(init->indices[i], body->indices[i])) { - if (in_schedule) { - throw InitBodyNotSameBufferAccessError(self->mod, block); + if (self.defined()) { + throw InitBodyNotSameBufferAccessError(self.value()->mod, block); } else { LOG(FATAL) << error_str2 << block; } @@ -1231,26 +1229,30 @@ bool ReductionIterNotIndexOutputBuffer(const Block& block) { for (const BufferRegion& write_region : block->writes) { buffer_written.insert(write_region->buffer.get()); } + auto f_uses_reduction_block_var = [&](const PrimExpr& expr) -> bool { + return UsesVar(expr, [&](const VarNode* var) { // + return reduction_block_iters.count(var); + }); + }; bool affected = false; PreOrderVisit(block->body, [&](const ObjectRef& obj) { if (affected) { return false; } - if (const auto* store = obj.as()) { - ICHECK(buffer_written.count(store->buffer.get())) - << "ValueError: The buffer \"" << store->buffer - << "\" is written in the block but is not in the block's signature"; - for (const PrimExpr& index : store->indices) { - if (UsesVar(index, [&reduction_block_iters](const VarNode* var) { - return reduction_block_iters.count(var); - })) { - affected = true; - return false; - } + const auto* store = obj.as(); + 
if (!store) { + return true; + } + ICHECK(buffer_written.count(store->buffer.get())) + << "ValueError: The buffer \"" << store->buffer + << "\" is written in the block but is not in the block's signature"; + for (const PrimExpr& index : store->indices) { + if (f_uses_reduction_block_var(index)) { + affected = true; + return false; } - return false; } - return true; + return false; }); return !affected; } @@ -1281,15 +1283,14 @@ class NoMatchedReducerError : public ScheduleError { BufferStore combiner_; }; -template std::tuple GetReducerAndCombinerLhsRhs( - const ScheduleState& self, const PrimExpr& identity, const BufferStore& combiner) { + const Optional& self, const PrimExpr& identity, const BufferStore& combiner) { CommReducer reducer{nullptr}; PrimExpr combiner_lhs{nullptr}, combiner_rhs{nullptr}; bool matched = FromIdentityCombiner(identity, combiner, &reducer, &combiner_lhs, &combiner_rhs); if (!matched) { - if (in_schedule) { - throw NoMatchedReducerError(self->mod, identity, combiner); + if (self.defined()) { + throw NoMatchedReducerError(self.value()->mod, identity, combiner); } else { LOG(FATAL) << "ValueError: No matched reducer for the identity and the combiner of the " "reduction block. So rfactor and cross-thread reduction cannot be applied."; @@ -1298,15 +1299,6 @@ std::tuple GetReducerAndCombinerLhsRhs( return std::make_tuple(std::move(reducer), std::move(combiner_lhs), std::move(combiner_rhs)); } -template std::pair GetBufferStoresFromReductionBlock( - const ScheduleState& self, const Block& block); -template std::pair GetBufferStoresFromReductionBlock( - const ScheduleState& self, const Block& block); -template std::tuple GetReducerAndCombinerLhsRhs( - const ScheduleState& self, const PrimExpr& identity, const BufferStore& combiner); -template std::tuple GetReducerAndCombinerLhsRhs( - const ScheduleState& self, const PrimExpr& identity, const BufferStore& combiner); - /******** Commutative Reducer ********/ bool MatchReducer(const CommReducer& reducer, const PrimExpr& identity, const PrimExpr& combiner, diff --git a/src/tir/schedule/primitive/reduction.cc b/src/tir/schedule/primitive/reduction.cc index 18c4e5da1315..9c330765ef38 100644 --- a/src/tir/schedule/primitive/reduction.cc +++ b/src/tir/schedule/primitive/reduction.cc @@ -1041,9 +1041,9 @@ StmtSRef RFactor(ScheduleState self, const StmtSRef& rf_loop_sref, int factor_ax BufferStore update; CommReducer reducer; PrimExpr combiner_lhs, combiner_rhs; - std::tie(init, update) = GetBufferStoresFromReductionBlock(self, block); + std::tie(init, update) = GetBufferStoresFromReductionBlock(self, block); std::tie(reducer, combiner_lhs, combiner_rhs) = - GetReducerAndCombinerLhsRhs(self, init->value, update); + GetReducerAndCombinerLhsRhs(self, init->value, update); // Step 6. Check whether `factor_axis` is in a correct range, and convert it to non-negative if it // is negative. diff --git a/src/tir/transforms/lower_cross_thread_reduction.cc b/src/tir/transforms/lower_cross_thread_reduction.cc index d15a5f16a783..b327cee4b6ad 100644 --- a/src/tir/transforms/lower_cross_thread_reduction.cc +++ b/src/tir/transforms/lower_cross_thread_reduction.cc @@ -283,14 +283,13 @@ class CrossThreadReductionTransformer : public StmtMutator { // both be BufferStores with the same buffer and indices. BufferStore init; BufferStore update; - std::tie(init, update) = - GetBufferStoresFromReductionBlock(ScheduleState{nullptr}, GetRef(block)); + std::tie(init, update) = GetBufferStoresFromReductionBlock(NullOpt, GetRef(block)); // Condition 5. 
Extract the commutative reducer, combiner lhs and combiner rhs from the // reduction identity and the reduction combiner. PrimExpr combiner_lhs; std::tie(reducer_, combiner_lhs, combiner_rhs_) = - GetReducerAndCombinerLhsRhs(ScheduleState{nullptr}, init->value, update); + GetReducerAndCombinerLhsRhs(NullOpt, init->value, update); // Condition 6. The block should be the last block under the first reduction-related loop. bool visit = false; From 60a91fa5d00e182c38cd9be3f90de0a898dab8c9 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sat, 13 Nov 2021 18:29:02 -0800 Subject: [PATCH 3/6] Refactor TransformReductionBlock --- .../lower_cross_thread_reduction.cc | 416 ++++++++++-------- ..._transform_lower_cross_thread_reduction.py | 25 +- 2 files changed, 248 insertions(+), 193 deletions(-) diff --git a/src/tir/transforms/lower_cross_thread_reduction.cc b/src/tir/transforms/lower_cross_thread_reduction.cc index b327cee4b6ad..b62be207e6dd 100644 --- a/src/tir/transforms/lower_cross_thread_reduction.cc +++ b/src/tir/transforms/lower_cross_thread_reduction.cc @@ -26,11 +26,25 @@ #include #include "../schedule/analysis.h" -#include "ir_utils.h" +#include "./ir_utils.h" namespace tvm { namespace tir { +/*! + * \brief Checks if a loop is bound to threadIdx.x/y/z + * \brief loop The loop to be checked + * \return True if the loop is bound to threadIdx.x/y/z + */ +bool IsBoundToThreadIdx(const ForNode* loop) { + if (!loop->thread_binding.defined()) { + return false; + } + runtime::ThreadScope scope = + runtime::ThreadScope::Create(loop->thread_binding.value()->thread_tag); + return scope.rank == 1 && scope.dim_index >= 0; +} + /*! * \brief Check the dominant property of a block: * the block is the only writer of its output, dominating the reader of its output buffers @@ -156,35 +170,52 @@ class BufferAccessReplacer : public StmtExprMutator { * \brief Substitute a given source block with a given target block, or remove the source block * branch from the AST if the target block is undefined */ -class ReductionBlockReplacer : public StmtMutator { +class InThreadReducerMaker : private StmtMutator { public: - explicit ReductionBlockReplacer(const BlockRealizeNode* src_block, BlockRealize tgt_block) - : src_block_(src_block), tgt_block_(std::move(tgt_block)) {} + static Optional Make(const BlockRealizeNode* src_realize, + Optional tgt_realize, Stmt stmt) { + return InThreadReducerMaker(src_realize, std::move(tgt_realize))(std::move(stmt)); + } private: - Stmt VisitStmt_(const BlockRealizeNode* block_realize) final { - return block_realize == src_block_ ? tgt_block_ : GetRef(block_realize); + explicit InThreadReducerMaker(const BlockRealizeNode* src_realize, + Optional tgt_realize) + : src_realize_(src_realize), tgt_realize_(tgt_realize) {} + Stmt VisitStmt_(const BlockRealizeNode* realize) final { + if (realize == src_realize_) { + return tgt_realize_.defined() // + ? tgt_realize_.value() + : Stmt{nullptr}; + } + return GetRef(realize); } Stmt VisitStmt_(const ForNode* loop) final { - For res = Downcast(StmtMutator::VisitStmt_(loop)); - return !res.defined() ? Stmt{nullptr} : (res->thread_binding.defined() ? 
res->body : res); + if (Optional opt_res = Downcast>(StmtMutator::VisitStmt_(loop))) { + For res = opt_res.value(); + if (res->thread_binding.defined()) { + return res->body; + } else { + return res; + } + } else { + return Stmt{nullptr}; + } } Stmt VisitStmt_(const SeqStmtNode* seq) final { - Array results; - results.reserve(seq->size()); + Array stmts; + stmts.reserve(seq->size()); for (const Stmt& stmt : seq->seq) { - Stmt res = StmtMutator::VisitStmt(stmt); - if (res.defined()) { - results.push_back(res); + if (Optional opt_res = VisitStmt(stmt)) { + stmts.push_back(opt_res.value()); } } - return results.empty() ? Stmt{nullptr} : SeqStmt(results); + return stmts.empty() ? Stmt{nullptr} : SeqStmt::Flatten(stmts); } - const BlockRealizeNode* src_block_; - BlockRealize tgt_block_; + const BlockRealizeNode* src_realize_; + Optional tgt_realize_; }; /*! @@ -234,32 +265,30 @@ class CrossThreadReductionTransformer : public StmtMutator { // be applied to the block (i.e., the block satisfies all necessary conditions of cross-thread // reduction). void CheckCanApplyCrossThreadReduction(const BlockNode* block) { - const String& block_name = block->name_hint; - // Condition 1. The block being applied cross-thread reduction should write to single buffer. - int n_write_buffer = static_cast(block->writes.size()); - CHECK_EQ(n_write_buffer, 1) << "ValueError: Cross-thread reduction requires the block to only " - "write to single buffer. However, the block " - << block_name << " writes to " << n_write_buffer << " buffer(s)."; + CHECK_EQ(block->writes.size(), 1) + << "ValueError: Cross-thread reduction requires the block to only " + "write to single buffer. However, the block " + << block->name_hint << " writes to " << block->writes.size() << " buffer(s)."; // Condition 2. All the reduction-related loops should be the deepest among all statements // outside the block (ignoring SeqStmt here). int n_deepest_reduction_loops = 0; for (auto rit = statement_stack_.rbegin() + 1; rit != statement_stack_.rend(); ++rit) { - if ((*rit)->IsInstance()) { - // Skip SeqStmt. - continue; - } - if (std::find(reduction_loops_.begin(), reduction_loops_.end(), - reinterpret_cast(*rit)) == reduction_loops_.end()) { - break; + const StmtNode* stmt = *rit; + if (stmt->IsInstance()) { + const ForNode* loop = static_cast(stmt); + if (std::find(reduction_loops_.begin(), reduction_loops_.end(), loop) == + reduction_loops_.end()) { + break; + } + ++n_deepest_reduction_loops; } - ++n_deepest_reduction_loops; } CHECK_EQ(n_deepest_reduction_loops, reduction_loops_.size()) << "ValueError: Cross-thread reduction requires all the reduction-related loops to be the " "deepest among all statements outside the desired block. However, block " - << block_name + << block->name_hint << " needs cross-thread reduction, while the reduction-related loops outside of it are not " "the deepest statements, which violates the condition."; @@ -269,29 +298,25 @@ class CrossThreadReductionTransformer : public StmtMutator { for (const ForNode* reduction_loop : reduction_loops_) { if (reduction_loop->thread_binding.defined()) { ++n_bound_reduction_loops_; - const String& thread_tag = reduction_loop->thread_binding.value()->thread_tag; - CHECK(thread_tag == "threadIdx.x" || thread_tag == "threadIdx.y" || - thread_tag == "threadIdx.z") + CHECK(IsBoundToThreadIdx(reduction_loop)) << "ValueError: Cross-thread reduction requires all the reduction-related loops that " "are bound to GPU thread axes to only be bound `threadIdx.x/y/z`. 
However, loop " - << reduction_loop->loop_var->name_hint << " is bound to " << thread_tag - << ", which violates the condition."; + << reduction_loop->loop_var->name_hint << " violates the condition."; } } // Condition 4. Get the `init` identity and the `update` combiner of the reduction. They should - // both be BufferStores with the same buffer and indices. - BufferStore init; - BufferStore update; + // both be BufferStores with the same buffer and indices; + // Extract the commutative reducer, combiner lhs and combiner rhs from the reduction identity + // and the reduction combiner. + BufferStore init{nullptr}; + BufferStore update{nullptr}; + PrimExpr combiner_lhs{nullptr}; std::tie(init, update) = GetBufferStoresFromReductionBlock(NullOpt, GetRef(block)); - - // Condition 5. Extract the commutative reducer, combiner lhs and combiner rhs from the - // reduction identity and the reduction combiner. - PrimExpr combiner_lhs; - std::tie(reducer_, combiner_lhs, combiner_rhs_) = + std::tie(this->reducer_, combiner_lhs, this->combiner_rhs_) = GetReducerAndCombinerLhsRhs(NullOpt, init->value, update); - // Condition 6. The block should be the last block under the first reduction-related loop. + // Condition 5. The block should be the last block under the first reduction-related loop. bool visit = false; PreOrderVisit(GetRef(reduction_loops_[0]), [block, &visit](const ObjectRef& obj) { if (const auto* realize = obj.as()) { @@ -306,150 +331,166 @@ class CrossThreadReductionTransformer : public StmtMutator { }); } - void TransformReductionBlock(const BlockRealizeNode* block_realize, bool with_normal_reduction) { - const BlockNode* block = block_realize->block.get(); - Buffer result_buffer = block->writes[0]->buffer; + void TransformReductionBlock(const BlockRealizeNode* realize, bool with_in_thread_reduction) { + const BlockNode* block = realize->block.get(); + Buffer wb_buffer = block->writes[0]->buffer; + Array wb_region = block->writes[0]->region; - BufferRegion ct_reduction_buf_region(cross_thread_reduction_buf_, {Range::FromMinExtent(0, 1)}); - BufferRegion normal_reduction_buf_region{nullptr}; - if (with_normal_reduction) { - normal_reduction_buf_region = - BufferRegion(normal_reduction_buf_, {Range::FromMinExtent(0, 1)}); + BufferRegion ct_buffer_region(ct_buffer_, {Range::FromMinExtent(0, 1)}); + Optional it_buffer_region = NullOpt; + if (with_in_thread_reduction) { + it_buffer_region = BufferRegion(it_buffer_, {Range::FromMinExtent(0, 1)}); } - - Array seq_stmt; - seq_stmt.reserve(4); - - if (with_normal_reduction) { - // Step 1. Create the BufferStore which initializes `normal_reduction_buf_`. - seq_stmt.push_back(BufferStore(/*buffer=*/normal_reduction_buf_, - /*value=*/block->init.value().as()->value, - /*indices=*/{0})); - - // Step 2. Create the block and loops which do the normal reduction. - // - 2.1. Create the block. - ObjectPtr p_new_block = make_object(*block); - { - p_new_block->reads = RemoveBufferFromBufferRegions(block->reads, result_buffer); - p_new_block->reads.push_back(normal_reduction_buf_region); - p_new_block->writes = {normal_reduction_buf_region}; - p_new_block->name_hint = block->name_hint + "_normal_reduction"; - p_new_block->body = BufferAccessReplacer(result_buffer, normal_reduction_buf_)(block->body); - p_new_block->init = NullOpt; - } - // - 2.2. Create the block-realize. - ObjectPtr p_new_block_realize = - make_object(*block_realize); - p_new_block_realize->block = Block(p_new_block); - // - 2.3. 
Replace the original reduction block with the normal reduction block. - Stmt replace_result = ReductionBlockReplacer( - block_realize, BlockRealize(p_new_block_realize))(GetRef(reduction_loops_[0])); - ICHECK(replace_result.defined()); - seq_stmt.push_back(replace_result); - } else { - // Remove the original reduction block. - Stmt replace_result = ReductionBlockReplacer( - block_realize, BlockRealize{nullptr})(GetRef(reduction_loops_[0])); - if (replace_result.defined()) { - seq_stmt.push_back(replace_result); - } + // In total, the block is transformed into at most 4 statements + // - Stmt 1: initialize the buffer for in-thread reduction + // - Stmt 2: do in-thread reduction + // - Stmt 3: do cross-thread reduction + // - Stmt 4: write cross-thread reduction result to the original buffer + Array stmts; + stmts.reserve(4); + // Stmt 1: initialize the buffer for in-thread reduction + if (with_in_thread_reduction) { + BufferStore init = Downcast(block->init); + stmts.push_back(BlockRealize( + /*iter_values=*/{}, + /*predicate=*/const_true(), + /*block=*/ + Block(/*iter_vars=*/{}, + /*reads=*/{}, + /*writes=*/{it_buffer_region.value()}, + /*name_hint=*/block->name_hint + "_in_thread_init", + /*body=*/ + BufferStore(/*buffer=*/it_buffer_, + /*value=*/init->value, + /*indices=*/{Integer(0)})))); } - - // Step 3. Create the statement which calls the intrinsic and does the cross-thread reduction. - // - 3.1. Create the intrinsic parameters. - PrimExpr reduction_value = - with_normal_reduction ? BufferLoad(normal_reduction_buf_, {0}) : combiner_rhs_; - Array parameters{make_const(DataType::UInt(32), static_cast(1)), - std::move(reduction_value), const_true(), - cross_thread_reduction_buf_->data}; - parameters.reserve(reduction_loops_.size() + 4); - for (const ForNode* reduction_loop : reduction_loops_) { - if (reduction_loop->thread_binding.defined()) { - parameters.push_back(reduction_loop->loop_var); + // Stmt 2: do in-thread reduction + { + Optional new_realize = NullOpt; + // If need to generate in-thread reduction, + // then replace `wb_buffer` with `it_buffer` accordingly in given BlockRealize + // otherwise, directly remove given BlockRealize + if (with_in_thread_reduction) { + ObjectPtr new_block = make_object(*block); + new_block->reads = RemoveBufferFromBufferRegions(std::move(new_block->reads), wb_buffer); + new_block->reads.push_back(it_buffer_region.value()); + new_block->writes = {it_buffer_region.value()}; + new_block->name_hint = new_block->name_hint + "_in_thread"; + new_block->body = BufferAccessReplacer(wb_buffer, it_buffer_)(std::move(new_block->body)); + new_block->init = NullOpt; + ObjectPtr n = make_object(*realize); + n->block = Block(new_block); + new_realize = BlockRealize(n); + } + For loop = GetRef(reduction_loops_[0]); + if (Optional stmt = InThreadReducerMaker::Make(realize, new_realize, std::move(loop))) { + stmts.push_back(stmt.value()); } } - // - 3.2. Create the intrinsic call and the block body. - AttrStmt ct_reduction_body(/*node=*/reducer_, - /*attr_key=*/tir::attr::reduce_scope, - /*value=*/make_zero(DataType::Handle()), - /*body=*/ - Evaluate(Call(/*dtype=*/DataType::Handle(), - /*op=*/tir::builtin::tvm_thread_allreduce(), - /*args=*/std::move(parameters)))); - // - 3.3. Create the block and the block-realize. + // Stmt 3: do cross-thread reduction { - Array iters; - Array bindings; + // Step 3.1. 
Create the parameters to the intrinsic + Array parameters; + parameters.reserve(reduction_loops_.size() + 4); + // 1-st argument: size + parameters.push_back(make_const(DataType::UInt(32), 1)); + // 2-nd argument: source + if (with_in_thread_reduction) { + parameters.push_back(BufferLoad(it_buffer_, {Integer(0)})); + } else { + parameters.push_back(combiner_rhs_); + } + // 3-rd argument: predicate + parameters.push_back(const_true()); + // 4-th argument: destination + parameters.push_back(ct_buffer_->data); + // next arguments: all the reduction threads + for (const ForNode* reduction_loop : reduction_loops_) { + if (reduction_loop->thread_binding.defined()) { + parameters.push_back(reduction_loop->loop_var); + } + } + // Step 3.2. Create the block and the block-realize. + Array iter_vars{nullptr}; + Array bindings{nullptr}; Array reads{nullptr}; - if (with_normal_reduction) { - reads = {std::move(normal_reduction_buf_region)}; + if (with_in_thread_reduction) { + iter_vars = Array{}; + bindings = Array{}; + reads = {it_buffer_region.value()}; } else { - iters = block->iter_vars; - bindings = block_realize->iter_values; - reads = {RemoveBufferFromBufferRegions(block->reads, result_buffer)}; + iter_vars = block->iter_vars; + bindings = realize->iter_values; + reads = {RemoveBufferFromBufferRegions(block->reads, wb_buffer)}; } - Block ct_reduction_block(/*iter_vars=*/std::move(iters), - /*reads=*/std::move(reads), - /*writes=*/{ct_reduction_buf_region}, - /*name_hint=*/block->name_hint + "_cross_thread_reduction", - /*body=*/std::move(ct_reduction_body)); - seq_stmt.push_back(BlockRealize(/*iter_values=*/std::move(bindings), - /*predicate=*/const_true(), - /*block=*/std::move(ct_reduction_block))); + stmts.push_back(BlockRealize( + /*iter_values=*/std::move(bindings), + /*predicate=*/const_true(), + /*block=*/ + Block(/*iter_vars=*/std::move(iter_vars), + /*reads=*/std::move(reads), + /*writes=*/{ct_buffer_region}, + /*name_hint=*/block->name_hint + "_cross_thread", + /*body=*/ + AttrStmt(/*node=*/reducer_, + /*attr_key=*/tir::attr::reduce_scope, + /*value=*/make_zero(DataType::Handle()), + /*body=*/ + Evaluate(Call(/*dtype=*/DataType::Handle(), + /*op=*/tir::builtin::tvm_thread_allreduce(), + /*args=*/std::move(parameters))))))); } - - // Step 4. Create the block which writes the cross-thread reduction result back to the original - // result buffer. - // - 4.1. Create the block iters and their corresponding iter bindings. 
- ICHECK_EQ(block->iter_vars.size(), block_realize->iter_values.size()); - int n_iter = static_cast(block->iter_vars.size()); - Array write_back_block_iters; - Array write_back_block_bindings; - std::unordered_map write_back_block_var_map; - write_back_block_iters.reserve(n_iter); - write_back_block_bindings.reserve(n_iter); - write_back_block_var_map.reserve(n_iter); - for (int i = 0; i < n_iter; ++i) { - IterVar iter = block->iter_vars[i]; - PrimExpr binding = block_realize->iter_values[i]; - if (iter->iter_type == kCommReduce) { - continue; + // Stmt 4: write cross-thread reduction result to the original buffer + { + ICHECK_EQ(block->iter_vars.size(), realize->iter_values.size()); + int n_iter = static_cast(block->iter_vars.size()); + Array iter_vars; + Array bindings; + Map var_map; + iter_vars.reserve(n_iter); + bindings.reserve(n_iter); + for (int i = 0; i < n_iter; ++i) { + const IterVar& iter_var = block->iter_vars[i]; + const PrimExpr& binding = realize->iter_values[i]; + if (iter_var->iter_type != kCommReduce) { + IterVar new_iter_var{nullptr}; + { + ObjectPtr n = make_object(*iter_var.get()); + ObjectPtr v = make_object(*iter_var->var.get()); + n->var = Var(v); + new_iter_var = IterVar(n); + } + iter_vars.push_back(new_iter_var); + bindings.push_back(binding); + var_map.Set(iter_var->var, new_iter_var->var); + } } - ObjectPtr p_new_iter = make_object(*iter.get()); - p_new_iter->var = Var(make_object(*iter->var.get())); - IterVar new_iter(p_new_iter); - write_back_block_iters.push_back(new_iter); - write_back_block_bindings.push_back(binding); - write_back_block_var_map[iter->var.get()] = std::move(new_iter); + BufferStore update = Downcast(block->body); + update = Downcast(Substitute(std::move(update), var_map)); + stmts.push_back(BlockRealize( + /*iter_values=*/std::move(bindings), + /*predicate=*/const_true(), + /*block=*/ + Block( + /*iter_vars=*/std::move(iter_vars), + /*reads=*/{std::move(ct_buffer_region)}, + /*writes=*/{BufferRegion(wb_buffer, Substitute(wb_region, var_map))}, + /*name_hint=*/block->name_hint + "_write_back", + /*body=*/ + BufferStore(/*buffer=*/wb_buffer, + /*value=*/BufferLoad(ct_buffer_, {0}), + /*indices=*/update->indices)))); } - // - 4.2. Create the body of the write-back block. - const auto* old_reduction_body = block->body.as(); - BufferStore write_back_body(/*buffer=*/std::move(result_buffer), - /*value=*/BufferLoad(cross_thread_reduction_buf_, {0}), - /*indices=*/old_reduction_body->indices); - // - 4.3. Create the block and block-realize. - Block write_back_block(/*iter_vars=*/std::move(write_back_block_iters), - /*reads=*/{std::move(ct_reduction_buf_region)}, - /*writes=*/block->writes, - /*name_hint=*/block->name_hint + "_write_back", - /*body=*/std::move(write_back_body)); - write_back_block = - Downcast(Substitute(Stmt(write_back_block), write_back_block_var_map)); - seq_stmt.push_back(BlockRealize(/*iter_values=*/std::move(write_back_block_bindings), - /*predicate=*/const_true(), - /*block=*/std::move(write_back_block))); - - // Step 5. Wrap all the above four statements with the reduction loops were bound. 
- Stmt new_stmt = SeqStmt::Flatten(seq_stmt); + // Final step: Wrap all the above four statements with the reduction loops bound to threadIdx + Stmt new_stmt = SeqStmt::Flatten(std::move(stmts)); for (auto rit = reduction_loops_.rbegin(); rit != reduction_loops_.rend(); ++rit) { - if ((*rit)->thread_binding.defined()) { - ObjectPtr p_new_loop = make_object(*(*rit)); - p_new_loop->body = std::move(new_stmt); - new_stmt = For(p_new_loop); + const ForNode* loop = *rit; + if (loop->thread_binding.defined()) { + ObjectPtr n = make_object(*loop); + n->body = std::move(new_stmt); + new_stmt = For(n); } } - - // Step 6. Replace the first reduction-related loop the new statement. loop2new_stmt_[reduction_loops_[0]] = std::move(new_stmt); } @@ -511,26 +552,25 @@ class CrossThreadReductionTransformer : public StmtMutator { // which condition the block violates. CheckCanApplyCrossThreadReduction(block); - // Step 3. When not all the reduction-related loops are bound to thread axes, normal reduction - // is needed in this cross-thread reduction. - bool need_normal_reduction = + // Step 3. When not all the reduction-related loops are bound to thread axes, in-thread + // reduction is needed in this cross-thread reduction. + bool need_in_thread_reduction = n_bound_reduction_loops_ < static_cast(reduction_loops_.size()); - // Step 4. Create intermediate buffers, storing them in `cross_thread_reduction_buf_` and - // `normal_reduction_buf_`. Let the scope block allocate these new buffers. + // Step 4. Create intermediate buffers, storing them in `ct_buffer_` and + // `it_buffer_`. Let the scope block allocate these new buffers. std::vector& new_buffers = block2new_buffers_[block_stack_.back()]; DataType dtype = block->writes[0]->buffer->dtype; - cross_thread_reduction_buf_ = - CreateReductionBuffer("reduce_temp" + std::to_string(reduction_id_), dtype); - new_buffers.push_back(cross_thread_reduction_buf_); - if (need_normal_reduction) { - normal_reduction_buf_ = + ct_buffer_ = CreateReductionBuffer("reduce_temp" + std::to_string(reduction_id_), dtype); + new_buffers.push_back(ct_buffer_); + if (need_in_thread_reduction) { + it_buffer_ = CreateReductionBuffer("normal_reduce_temp" + std::to_string(reduction_id_), dtype); - new_buffers.push_back(normal_reduction_buf_); + new_buffers.push_back(it_buffer_); } // Step 5. Transform. - TransformReductionBlock(block_realize, need_normal_reduction); + TransformReductionBlock(block_realize, need_in_thread_reduction); // Step 6. Return an empty statement, because the transformation result will be inserted when // returning to the first reduction-related loop. 
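The transformer above replaces each qualifying reduction block with an empty statement and re-inserts the transformed loop nest once the visitor returns to the first reduction-related loop. For reference, a minimal sketch of exercising the pass on one of the TVMScript cases from the test file in this series (assuming the `loop_split` / `lowered_loop_split` pair and the `_check` helper defined there):

    import tvm

    # Wrap the TVMScript reduction PrimFunc in an IRModule and run the new pass.
    mod = tvm.IRModule({"main": loop_split})
    mod = tvm.tir.transform.LowerCrossThreadReduction()(mod)
    # mod["main"] is now expected to be structurally equal to `lowered_loop_split`;
    # this is the property that the `_check` helper in the test file asserts.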
@@ -551,8 +591,8 @@ class CrossThreadReductionTransformer : public StmtMutator { CommReducer reducer_; PrimExpr combiner_rhs_; - Buffer cross_thread_reduction_buf_; - Buffer normal_reduction_buf_; + Buffer ct_buffer_; + Buffer it_buffer_; std::unordered_map> block2new_buffers_; std::unordered_map loop2new_stmt_; diff --git a/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py b/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py index 353386e2f27a..4fa3ab0c550c 100644 --- a/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py +++ b/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py @@ -58,7 +58,10 @@ def lowered_loop_split(a: T.handle, b: T.handle) -> None: normal_reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local") for i in T.serial(0, 128): for ki in T.thread_binding(0, 32, thread="threadIdx.x"): - normal_reduce_temp0[0] = T.float32(0) + with T.block("B_in_thread_init"): + T.reads([]) + T.writes([normal_reduce_temp0[0]]) + normal_reduce_temp0[0] = T.float32(0) for ko in T.serial(0, 4): with T.block("B_normal_reduction"): vi = T.axis.S(128, i) @@ -217,7 +220,10 @@ def lowered_multiple_blocks_under_reduction_loop(a: T.handle, b: T.handle) -> No normal_reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local") for i in T.thread_binding(0, 16, thread="blockIdx.x"): for k0o in T.thread_binding(0, 4, thread="threadIdx.x"): - normal_reduce_temp0[0] = T.float32(0) + with T.block("B_in_thread_init"): + T.reads([]) + T.writes([normal_reduce_temp0[0]]) + normal_reduce_temp0[0] = T.float32(0) for k0i0, k1 in T.grid(4, 16): with T.block("B_rf"): vk0 = T.axis.spatial(16, k0o * 4 + k0i0) @@ -284,7 +290,10 @@ def lowered_with_block_predicate(a: T.handle, b: T.handle) -> None: normal_reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local") for i in T.serial(0, 128): for ki in T.thread_binding(0, 32, thread="threadIdx.x"): - normal_reduce_temp0[0] = T.float32(0) + with T.block("B_in_thread_init"): + T.reads([]) + T.writes([normal_reduce_temp0[0]]) + normal_reduce_temp0[0] = T.float32(0) for ko in T.serial(0, 4): with T.block("B_normal_reduction"): vi = T.axis.spatial(128, i) @@ -557,7 +566,10 @@ def lowered_softmax(var_A: T.handle, var_T_softmax_norm: T.handle) -> None: normal_reduce_temp1 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local") for i0 in T.thread_binding(0, 256, thread="blockIdx.x"): for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"): - normal_reduce_temp0[0] = T.min_value("float32") + with T.block("T_softmax_maxelem_normal_reduction_init"): + T.reads([]) + T.writes([normal_reduce_temp0[0]]) + normal_reduce_temp0[0] = T.min_value("float32") for ax0_0 in T.serial(0, 8): with T.block("T_softmax_maxelem_normal_reduction"): i0_1 = T.axis.spatial(256, i0) @@ -589,7 +601,10 @@ def lowered_softmax(var_A: T.handle, var_T_softmax_norm: T.handle) -> None: T.writes([T_softmax_maxelem_shared[i0_2]]) T_softmax_maxelem_shared[i0_2] = reduce_temp0[0] for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"): - normal_reduce_temp1[0] = T.float32(0) + with T.block("T_softmax_expsum_normal_reduction_init"): + T.reads([]) + T.writes([normal_reduce_temp1[0]]) + normal_reduce_temp1[0] = T.float32(0) for ax0_0 in T.serial(0, 8): with T.block("T_softmax_expsum_normal_reduction"): i0_3 = T.axis.spatial(256, i0) From 339b1d0df39deeeeb67d312d97e0f44ce286de93 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sun, 14 Nov 2021 00:13:24 -0800 
Subject: [PATCH 4/6] Refactor code organization --- .../lower_cross_thread_reduction.cc | 463 +++++++++--------- 1 file changed, 239 insertions(+), 224 deletions(-) diff --git a/src/tir/transforms/lower_cross_thread_reduction.cc b/src/tir/transforms/lower_cross_thread_reduction.cc index b62be207e6dd..fcb5a3f41a46 100644 --- a/src/tir/transforms/lower_cross_thread_reduction.cc +++ b/src/tir/transforms/lower_cross_thread_reduction.cc @@ -76,7 +76,7 @@ bool IsDominantBlock(const Block& scope_block, const Block& block) { /*! * \brief Check whether the input block is a reduction block. - * \param block_realize The block to be checked + * \param realize The block to be checked * \param loop_range_map The mapping from the loop variables outside the input block to their ranges * \param scope_block The scope block of the input block * \param analyzer The analyzer @@ -85,15 +85,15 @@ bool IsDominantBlock(const Block& scope_block, const Block& block) { * based on `tir.Schedule`. Here we have no schedule information, and thus we must implement the * check again. */ -bool IsReductionBlock(const BlockRealize& block_realize, const Map& loop_range_map, +bool IsReductionBlock(const BlockRealize& realize, const Map& loop_range_map, const Block& scope_block, arith::Analyzer* analyzer) { - const auto* block = block_realize->block.as(); + const auto* block = realize->block.as(); // Cond 1. The block has the `init` statement. if (!block->init.defined()) { return false; } // Cond 2. All the block bindings are quasi-affine expressions. - if (!IsAffineBinding(block_realize, loop_range_map, analyzer)) { + if (!IsAffineBinding(realize, loop_range_map, analyzer)) { return false; } // Cond 3. All block vars are either data parallel block vars or reduction block vars. Meanwhile, @@ -116,9 +116,16 @@ bool IsReductionBlock(const BlockRealize& block_realize, const Map& * \param dtype The specified data type * \return The created buffer */ -Buffer CreateReductionBuffer(String name, const DataType& dtype) { - Var var(name, PointerType(PrimType(dtype), "local")); - return Buffer(var, dtype, {1}, {1}, PrimExpr(), std::move(name), 0, 0, kDefault); +Buffer MakeScratchpad(String name, const DataType& dtype) { + return Buffer(/*ptr=*/Var(name, PointerType(PrimType(dtype), "local")), + /*dtype=*/dtype, + /*shape=*/{Integer(1)}, + /*strides=*/{Integer(1)}, + /*elem_offset=*/PrimExpr{nullptr}, + /*name=*/std::move(name), + /*data_alignment=*/0, + /*offset_factor=*/0, + /*buffer_type=*/kDefault); } /*! @@ -142,12 +149,16 @@ Array RemoveBufferFromBufferRegions(const Array& buf /*! * \brief Substitute a given source buffer with a given target buffer in statements or expressions */ -class BufferAccessReplacer : public StmtExprMutator { +class BufferReplacer : private StmtExprMutator { public: - explicit BufferAccessReplacer(Buffer src_buffer, Buffer tgt_buffer) - : src_buffer_(std::move(src_buffer)), tgt_buffer_(std::move(tgt_buffer)) {} + static Stmt Run(Buffer src_buffer, Buffer tgt_buffer, Stmt stmt) { + return BufferReplacer(src_buffer, tgt_buffer)(std::move(stmt)); + } private: + explicit BufferReplacer(Buffer src_buffer, Buffer tgt_buffer) + : src_buffer_(std::move(src_buffer)), tgt_buffer_(std::move(tgt_buffer)) {} + PrimExpr VisitExpr_(const BufferLoadNode* load) final { return load->buffer.same_as(src_buffer_) ? BufferLoad(tgt_buffer_, {0}) : GetRef(load); @@ -218,27 +229,203 @@ class InThreadReducerMaker : private StmtMutator { Optional tgt_realize_; }; +/*! 
+ * \brief Create the lowered allreduce block transformed from the input reduction block + * \param reduction_block The input reduction block + * \param it_buffer The buffer to store in-thread reduction results + * \param ct_buffer The buffer to store cross-thread reduction results + * \param reducer The reduction function + * \param combiner_rhs The RHS of the combiner + * \param reduction_loops The reduction loops + */ +Stmt TransformReductionBlock(const BlockRealizeNode* realize, const Optional& it_buffer, + const Buffer& ct_buffer, const CommReducer& reducer, + const PrimExpr& combiner_rhs, + const std::vector& reduction_loops) { + const BlockNode* block = realize->block.get(); + Buffer wb_buffer = block->writes[0]->buffer; + Array wb_region = block->writes[0]->region; + + BufferRegion ct_buffer_region(ct_buffer, {Range::FromMinExtent(0, 1)}); + Optional it_buffer_region = NullOpt; + if (it_buffer.defined()) { + it_buffer_region = BufferRegion(it_buffer.value(), {Range::FromMinExtent(0, 1)}); + } + // In total, the block is transformed into at most 4 statements + // - Stmt 1: initialize the buffer for in-thread reduction + // - Stmt 2: do in-thread reduction + // - Stmt 3: do cross-thread reduction + // - Stmt 4: write cross-thread reduction result to the original buffer + Array stmts; + stmts.reserve(4); + // Stmt 1: initialize the buffer for in-thread reduction + if (it_buffer.defined()) { + BufferStore init = Downcast(block->init); + stmts.push_back(BlockRealize( + /*iter_values=*/{}, + /*predicate=*/const_true(), + /*block=*/ + Block(/*iter_vars=*/{}, + /*reads=*/{}, + /*writes=*/{it_buffer_region.value()}, + /*name_hint=*/block->name_hint + "_in_thread_init", + /*body=*/ + BufferStore(/*buffer=*/it_buffer.value(), + /*value=*/init->value, + /*indices=*/{Integer(0)})))); + } + // Stmt 2: do in-thread reduction + { + Optional new_realize = NullOpt; + // If need to generate in-thread reduction, + // then replace `wb_buffer` with `it_buffer` accordingly in given BlockRealize + // otherwise, directly remove given BlockRealize + if (it_buffer.defined()) { + ObjectPtr new_block = make_object(*block); + new_block->reads = RemoveBufferFromBufferRegions(std::move(new_block->reads), wb_buffer); + new_block->reads.push_back(it_buffer_region.value()); + new_block->writes = {it_buffer_region.value()}; + new_block->name_hint = new_block->name_hint + "_in_thread"; + new_block->body = + BufferReplacer::Run(wb_buffer, it_buffer.value(), std::move(new_block->body)); + new_block->init = NullOpt; + ObjectPtr n = make_object(*realize); + n->block = Block(new_block); + new_realize = BlockRealize(n); + } + For loop = GetRef(reduction_loops[0]); + if (Optional stmt = InThreadReducerMaker::Make(realize, new_realize, std::move(loop))) { + stmts.push_back(stmt.value()); + } + } + // Stmt 3: do cross-thread reduction + { + // Step 3.1. 
Create the parameters to the intrinsic + Array parameters; + parameters.reserve(reduction_loops.size() + 4); + // 1-st argument: size + parameters.push_back(make_const(DataType::UInt(32), 1)); + // 2-nd argument: source + if (it_buffer.defined()) { + parameters.push_back(BufferLoad(it_buffer.value(), {Integer(0)})); + } else { + parameters.push_back(combiner_rhs); + } + // 3-rd argument: predicate + parameters.push_back(const_true()); + // 4-th argument: destination + parameters.push_back(ct_buffer->data); + // next arguments: all the reduction threads + for (const ForNode* reduction_loop : reduction_loops) { + if (reduction_loop->thread_binding.defined()) { + parameters.push_back(reduction_loop->loop_var); + } + } + // Step 3.2. Create the block and the block-realize. + Array iter_vars{nullptr}; + Array bindings{nullptr}; + Array reads{nullptr}; + if (it_buffer.defined()) { + iter_vars = Array{}; + bindings = Array{}; + reads = {it_buffer_region.value()}; + } else { + iter_vars = block->iter_vars; + bindings = realize->iter_values; + reads = {RemoveBufferFromBufferRegions(block->reads, wb_buffer)}; + } + stmts.push_back(BlockRealize( + /*iter_values=*/std::move(bindings), + /*predicate=*/const_true(), + /*block=*/ + Block(/*iter_vars=*/std::move(iter_vars), + /*reads=*/std::move(reads), + /*writes=*/{ct_buffer_region}, + /*name_hint=*/block->name_hint + "_cross_thread", + /*body=*/ + AttrStmt(/*node=*/reducer, + /*attr_key=*/tir::attr::reduce_scope, + /*value=*/make_zero(DataType::Handle()), + /*body=*/ + Evaluate(Call(/*dtype=*/DataType::Handle(), + /*op=*/tir::builtin::tvm_thread_allreduce(), + /*args=*/std::move(parameters))))))); + } + // Stmt 4: write cross-thread reduction result to the original buffer + { + ICHECK_EQ(block->iter_vars.size(), realize->iter_values.size()); + int n_iter = static_cast(block->iter_vars.size()); + Array iter_vars; + Array bindings; + Map var_map; + iter_vars.reserve(n_iter); + bindings.reserve(n_iter); + for (int i = 0; i < n_iter; ++i) { + const IterVar& iter_var = block->iter_vars[i]; + const PrimExpr& binding = realize->iter_values[i]; + if (iter_var->iter_type != kCommReduce) { + IterVar new_iter_var{nullptr}; + { + ObjectPtr n = make_object(*iter_var.get()); + ObjectPtr v = make_object(*iter_var->var.get()); + n->var = Var(v); + new_iter_var = IterVar(n); + } + iter_vars.push_back(new_iter_var); + bindings.push_back(binding); + var_map.Set(iter_var->var, new_iter_var->var); + } + } + BufferStore update = Downcast(block->body); + update = Downcast(Substitute(std::move(update), var_map)); + stmts.push_back(BlockRealize( + /*iter_values=*/std::move(bindings), + /*predicate=*/const_true(), + /*block=*/ + Block( + /*iter_vars=*/std::move(iter_vars), + /*reads=*/{std::move(ct_buffer_region)}, + /*writes=*/{BufferRegion(wb_buffer, Substitute(wb_region, var_map))}, + /*name_hint=*/block->name_hint + "_write_back", + /*body=*/ + BufferStore(/*buffer=*/wb_buffer, + /*value=*/BufferLoad(ct_buffer, {Integer(0)}), + /*indices=*/update->indices)))); + } + // Final step: Wrap all the above four statements with the reduction loops bound to threadIdx + Stmt new_stmt = SeqStmt::Flatten(std::move(stmts)); + for (auto rit = reduction_loops.rbegin(); rit != reduction_loops.rend(); ++rit) { + const ForNode* loop = *rit; + if (loop->thread_binding.defined()) { + ObjectPtr n = make_object(*loop); + n->body = std::move(new_stmt); + new_stmt = For(n); + } + } + return new_stmt; +} + /*! 
* \brief Detect cross-thread reduction pattern and then transform */ class CrossThreadReductionTransformer : public StmtMutator { private: // Check if the input block needs cross-thread reduction. - bool NeedCrossThreadReduction(const BlockRealizeNode* block_realize) { + std::vector NeedCrossThreadReduction(const BlockRealizeNode* realize) { // Step 0. If the block is the root block, just return. if (block_stack_.empty()) { - return false; + return {}; } // Step 1. If the block is not a reduction block, cross-thread reduction is not needed. - if (!IsReductionBlock(GetRef(block_realize), loop_range_map_, + if (!IsReductionBlock(GetRef(realize), loop_range_map_, GetRef(block_stack_.back()), &analyzer_)) { - return false; + return {}; } // Step 2. Collect all the vars that appear in the bindings of reduction block iters. std::unordered_set reduction_vars; - GetVarsTouchedByBlockIters(GetRef(block_realize), nullptr, &reduction_vars); + GetVarsTouchedByBlockIters(GetRef(realize), nullptr, &reduction_vars); // Step 3. Collect the loops whose loop vars appear in the bindings of reduction block iters. // We call these loops "reduction-related". @@ -246,25 +433,25 @@ class CrossThreadReductionTransformer : public StmtMutator { // so, cross-thread reduction is needed. If none of the reduction-related loops is bound to // thread axis, cross-thread reduction is not needed for the input block. bool need = false; - reduction_loops_.clear(); + std::vector reduction_loops; for (const ForNode* loop : loop_stack_) { if (reduction_vars.count(loop->loop_var.get())) { // Step 3. Collect the loop. - reduction_loops_.push_back(loop); + reduction_loops.push_back(loop); // Step 4. See whether the loop is bound to some thread axis. if (loop->thread_binding.defined()) { need = true; } } } - - return need; + return need ? reduction_loops : std::vector{}; } // Given that the input block needs cross-thread reduction, check if cross-thread reduction can // be applied to the block (i.e., the block satisfies all necessary conditions of cross-thread // reduction). - void CheckCanApplyCrossThreadReduction(const BlockNode* block) { + std::tuple CheckCanApplyCrossThreadReduction( + const BlockNode* block, const std::vector& reduction_loops) const { // Condition 1. The block being applied cross-thread reduction should write to single buffer. CHECK_EQ(block->writes.size(), 1) << "ValueError: Cross-thread reduction requires the block to only " @@ -278,14 +465,14 @@ class CrossThreadReductionTransformer : public StmtMutator { const StmtNode* stmt = *rit; if (stmt->IsInstance()) { const ForNode* loop = static_cast(stmt); - if (std::find(reduction_loops_.begin(), reduction_loops_.end(), loop) == - reduction_loops_.end()) { + if (std::find(reduction_loops.begin(), reduction_loops.end(), loop) == + reduction_loops.end()) { break; } ++n_deepest_reduction_loops; } } - CHECK_EQ(n_deepest_reduction_loops, reduction_loops_.size()) + CHECK_EQ(n_deepest_reduction_loops, reduction_loops.size()) << "ValueError: Cross-thread reduction requires all the reduction-related loops to be the " "deepest among all statements outside the desired block. However, block " << block->name_hint @@ -294,10 +481,10 @@ class CrossThreadReductionTransformer : public StmtMutator { // Condition 3. All the reduction-related loops that are bound to thread axes should only be // bound to `threadIdx.x/y/z`. 
-    n_bound_reduction_loops_ = 0;
-    for (const ForNode* reduction_loop : reduction_loops_) {
+    int n_bound_reduction_loops = 0;
+    for (const ForNode* reduction_loop : reduction_loops) {
       if (reduction_loop->thread_binding.defined()) {
-        ++n_bound_reduction_loops_;
+        ++n_bound_reduction_loops;
         CHECK(IsBoundToThreadIdx(reduction_loop))
             << "ValueError: Cross-thread reduction requires all the reduction-related loops that "
                "are bound to GPU thread axes to only be bound `threadIdx.x/y/z`. However, loop "
@@ -311,14 +498,16 @@ class CrossThreadReductionTransformer : public StmtMutator {
     // and the reduction combiner.
     BufferStore init{nullptr};
     BufferStore update{nullptr};
+    CommReducer reducer{nullptr};
     PrimExpr combiner_lhs{nullptr};
+    PrimExpr combiner_rhs{nullptr};
     std::tie(init, update) =
         GetBufferStoresFromReductionBlock<false>(NullOpt, GetRef<Block>(block));
-    std::tie(this->reducer_, combiner_lhs, this->combiner_rhs_) =
+    std::tie(reducer, combiner_lhs, combiner_rhs) =
         GetReducerAndCombinerLhsRhs<false>(NullOpt, init->value, update);
     // Condition 5. The block should be the last block under the first reduction-related loop.
     bool visit = false;
-    PreOrderVisit(GetRef<For>(reduction_loops_[0]), [block, &visit](const ObjectRef& obj) {
+    PreOrderVisit(GetRef<For>(reduction_loops[0]), [block, &visit](const ObjectRef& obj) {
       if (const auto* realize = obj.as<BlockRealizeNode>()) {
         CHECK(!visit) << "ValueError: Cross-thread reduction cannot be applied when the reduction "
                          "block isn't the last block under its first reduction-related loop";
@@ -329,169 +518,7 @@ class CrossThreadReductionTransformer : public StmtMutator {
       }
       return true;
     });
-  }
-
-  void TransformReductionBlock(const BlockRealizeNode* realize, bool with_in_thread_reduction) {
-    const BlockNode* block = realize->block.get();
-    Buffer wb_buffer = block->writes[0]->buffer;
-    Array<Range> wb_region = block->writes[0]->region;
-
-    BufferRegion ct_buffer_region(ct_buffer_, {Range::FromMinExtent(0, 1)});
-    Optional<BufferRegion> it_buffer_region = NullOpt;
-    if (with_in_thread_reduction) {
-      it_buffer_region = BufferRegion(it_buffer_, {Range::FromMinExtent(0, 1)});
-    }
-    // In total, the block is transformed into at most 4 statements
-    //   - Stmt 1: initialize the buffer for in-thread reduction
-    //   - Stmt 2: do in-thread reduction
-    //   - Stmt 3: do cross-thread reduction
-    //   - Stmt 4: write cross-thread reduction result to the original buffer
-    Array<Stmt> stmts;
-    stmts.reserve(4);
-    // Stmt 1: initialize the buffer for in-thread reduction
-    if (with_in_thread_reduction) {
-      BufferStore init = Downcast<BufferStore>(block->init);
-      stmts.push_back(BlockRealize(
-          /*iter_values=*/{},
-          /*predicate=*/const_true(),
-          /*block=*/
-          Block(/*iter_vars=*/{},
-                /*reads=*/{},
-                /*writes=*/{it_buffer_region.value()},
-                /*name_hint=*/block->name_hint + "_in_thread_init",
-                /*body=*/
-                BufferStore(/*buffer=*/it_buffer_,
-                            /*value=*/init->value,
-                            /*indices=*/{Integer(0)}))));
-    }
-    // Stmt 2: do in-thread reduction
-    {
-      Optional<BlockRealize> new_realize = NullOpt;
-      // If need to generate in-thread reduction,
-      // then replace `wb_buffer` with `it_buffer` accordingly in given BlockRealize
-      // otherwise, directly remove given BlockRealize
-      if (with_in_thread_reduction) {
-        ObjectPtr<BlockNode> new_block = make_object<BlockNode>(*block);
-        new_block->reads = RemoveBufferFromBufferRegions(std::move(new_block->reads), wb_buffer);
-        new_block->reads.push_back(it_buffer_region.value());
-        new_block->writes = {it_buffer_region.value()};
-        new_block->name_hint = new_block->name_hint + "_in_thread";
-        new_block->body = BufferAccessReplacer(wb_buffer, it_buffer_)(std::move(new_block->body));
-        new_block->init = NullOpt;
-        ObjectPtr<BlockRealizeNode> n = make_object<BlockRealizeNode>(*realize);
-        n->block = Block(new_block);
-        new_realize = BlockRealize(n);
-      }
-      For loop = GetRef<For>(reduction_loops_[0]);
-      if (Optional<Stmt> stmt = InThreadReducerMaker::Make(realize, new_realize, std::move(loop))) {
-        stmts.push_back(stmt.value());
-      }
-    }
-    // Stmt 3: do cross-thread reduction
-    {
-      // Step 3.1. Create the parameters to the intrinsic
-      Array<PrimExpr> parameters;
-      parameters.reserve(reduction_loops_.size() + 4);
-      // 1-st argument: size
-      parameters.push_back(make_const(DataType::UInt(32), 1));
-      // 2-nd argument: source
-      if (with_in_thread_reduction) {
-        parameters.push_back(BufferLoad(it_buffer_, {Integer(0)}));
-      } else {
-        parameters.push_back(combiner_rhs_);
-      }
-      // 3-rd argument: predicate
-      parameters.push_back(const_true());
-      // 4-th argument: destination
-      parameters.push_back(ct_buffer_->data);
-      // next arguments: all the reduction threads
-      for (const ForNode* reduction_loop : reduction_loops_) {
-        if (reduction_loop->thread_binding.defined()) {
-          parameters.push_back(reduction_loop->loop_var);
-        }
-      }
-      // Step 3.2. Create the block and the block-realize.
-      Array<IterVar> iter_vars{nullptr};
-      Array<PrimExpr> bindings{nullptr};
-      Array<BufferRegion> reads{nullptr};
-      if (with_in_thread_reduction) {
-        iter_vars = Array<IterVar>{};
-        bindings = Array<PrimExpr>{};
-        reads = {it_buffer_region.value()};
-      } else {
-        iter_vars = block->iter_vars;
-        bindings = realize->iter_values;
-        reads = {RemoveBufferFromBufferRegions(block->reads, wb_buffer)};
-      }
-      stmts.push_back(BlockRealize(
-          /*iter_values=*/std::move(bindings),
-          /*predicate=*/const_true(),
-          /*block=*/
-          Block(/*iter_vars=*/std::move(iter_vars),
-                /*reads=*/std::move(reads),
-                /*writes=*/{ct_buffer_region},
-                /*name_hint=*/block->name_hint + "_cross_thread",
-                /*body=*/
-                AttrStmt(/*node=*/reducer_,
-                         /*attr_key=*/tir::attr::reduce_scope,
-                         /*value=*/make_zero(DataType::Handle()),
-                         /*body=*/
-                         Evaluate(Call(/*dtype=*/DataType::Handle(),
-                                       /*op=*/tir::builtin::tvm_thread_allreduce(),
-                                       /*args=*/std::move(parameters)))))));
-    }
-    // Stmt 4: write cross-thread reduction result to the original buffer
-    {
-      ICHECK_EQ(block->iter_vars.size(), realize->iter_values.size());
-      int n_iter = static_cast<int>(block->iter_vars.size());
-      Array<IterVar> iter_vars;
-      Array<PrimExpr> bindings;
-      Map<Var, Var> var_map;
-      iter_vars.reserve(n_iter);
-      bindings.reserve(n_iter);
-      for (int i = 0; i < n_iter; ++i) {
-        const IterVar& iter_var = block->iter_vars[i];
-        const PrimExpr& binding = realize->iter_values[i];
-        if (iter_var->iter_type != kCommReduce) {
-          IterVar new_iter_var{nullptr};
-          {
-            ObjectPtr<IterVarNode> n = make_object<IterVarNode>(*iter_var.get());
-            ObjectPtr<VarNode> v = make_object<VarNode>(*iter_var->var.get());
-            n->var = Var(v);
-            new_iter_var = IterVar(n);
-          }
-          iter_vars.push_back(new_iter_var);
-          bindings.push_back(binding);
-          var_map.Set(iter_var->var, new_iter_var->var);
-        }
-      }
-      BufferStore update = Downcast<BufferStore>(block->body);
-      update = Downcast<BufferStore>(Substitute(std::move(update), var_map));
-      stmts.push_back(BlockRealize(
-          /*iter_values=*/std::move(bindings),
-          /*predicate=*/const_true(),
-          /*block=*/
-          Block(
-              /*iter_vars=*/std::move(iter_vars),
-              /*reads=*/{std::move(ct_buffer_region)},
-              /*writes=*/{BufferRegion(wb_buffer, Substitute(wb_region, var_map))},
-              /*name_hint=*/block->name_hint + "_write_back",
-              /*body=*/
-              BufferStore(/*buffer=*/wb_buffer,
-                          /*value=*/BufferLoad(ct_buffer_, {0}),
-                          /*indices=*/update->indices))));
-    }
-    // Final step: Wrap all the above four statements with the reduction loops bound to threadIdx
-    Stmt new_stmt = SeqStmt::Flatten(std::move(stmts));
-    for (auto rit = reduction_loops_.rbegin(); rit != reduction_loops_.rend(); ++rit) {
-      const ForNode* loop = *rit;
-      if (loop->thread_binding.defined()) {
-        ObjectPtr<ForNode> n = make_object<ForNode>(*loop);
-        n->body = std::move(new_stmt);
-        new_stmt = For(n);
-      }
-    }
-    loop2new_stmt_[reduction_loops_[0]] = std::move(new_stmt);
+    return {n_bound_reduction_loops, reducer, combiner_rhs};
   }
 
   Stmt VisitStmt(const Stmt& stmt) final {
@@ -539,39 +566,39 @@ class CrossThreadReductionTransformer : public StmtMutator {
     return new_block;
   }
 
-  Stmt VisitStmt_(const BlockRealizeNode* block_realize) final {
-    const BlockNode* block = block_realize->block.get();
-
+  Stmt VisitStmt_(const BlockRealizeNode* realize) final {
+    const BlockNode* block = realize->block.get();
     // Step 1. Check whether cross-thread reduction is needed. If no, skip this block.
-    if (!NeedCrossThreadReduction(block_realize)) {
-      return StmtMutator::VisitStmt_(block_realize);
+    std::vector<const ForNode*> reduction_loops = NeedCrossThreadReduction(realize);
+    if (reduction_loops.empty()) {
+      return StmtMutator::VisitStmt_(realize);
     }
     ++reduction_id_;
-
     // Step 2. Check whether cross-thread reduction can be applied. If no, throw an exception on
     // which condition the block violates.
-    CheckCanApplyCrossThreadReduction(block);
-
+    int n_bound_reduction_loops = 0;
+    CommReducer reducer{nullptr};
+    PrimExpr combiner_rhs{nullptr};
+    std::tie(n_bound_reduction_loops, reducer, combiner_rhs) =
+        CheckCanApplyCrossThreadReduction(block, reduction_loops);
     // Step 3. When not all the reduction-related loops are bound to thread axes, in-thread
     // reduction is needed in this cross-thread reduction.
     bool need_in_thread_reduction =
-        n_bound_reduction_loops_ < static_cast<int>(reduction_loops_.size());
-
-    // Step 4. Create intermediate buffers, storing them in `ct_buffer_` and
-    // `it_buffer_`. Let the scope block allocate these new buffers.
+        n_bound_reduction_loops < static_cast<int>(reduction_loops.size());
+    // Step 4. Create intermediate buffers, storing them in `ct_buffer` and
+    // `it_buffer`. Let the scope block allocate these new buffers.
     std::vector<Buffer>& new_buffers = block2new_buffers_[block_stack_.back()];
     DataType dtype = block->writes[0]->buffer->dtype;
-    ct_buffer_ = CreateReductionBuffer("reduce_temp" + std::to_string(reduction_id_), dtype);
-    new_buffers.push_back(ct_buffer_);
+    Buffer ct_buffer = MakeScratchpad("cross_thread_" + std::to_string(reduction_id_), dtype);
+    new_buffers.push_back(ct_buffer);
+    Optional<Buffer> it_buffer = NullOpt;
     if (need_in_thread_reduction) {
-      it_buffer_ =
-          CreateReductionBuffer("normal_reduce_temp" + std::to_string(reduction_id_), dtype);
-      new_buffers.push_back(it_buffer_);
+      it_buffer = MakeScratchpad("in_thread_" + std::to_string(reduction_id_), dtype);
+      new_buffers.push_back(it_buffer.value());
     }
-    // Step 5. Transform.
-    TransformReductionBlock(block_realize, need_in_thread_reduction);
-
+    loop2new_stmt_[reduction_loops[0]] = TransformReductionBlock(
+        realize, it_buffer, ct_buffer, reducer, combiner_rhs, reduction_loops);
     // Step 6. Return an empty statement, because the transformation result will be inserted when
     // returning to the first reduction-related loop.
     return Stmt{nullptr};
@@ -579,24 +606,12 @@ class CrossThreadReductionTransformer : public StmtMutator {

  private:
   int reduction_id_ = -1;
-
   std::vector<const StmtNode*> statement_stack_;
   std::vector<const ForNode*> loop_stack_;
   std::vector<const BlockNode*> block_stack_;
-  Map<Var, Range> loop_range_map_;
-
-  int n_bound_reduction_loops_ = 0;
-  std::vector<const ForNode*> reduction_loops_;
-
-  CommReducer reducer_;
-  PrimExpr combiner_rhs_;
-
-  Buffer ct_buffer_;
-  Buffer it_buffer_;
-
   std::unordered_map<const BlockNode*, std::vector<Buffer>> block2new_buffers_;
   std::unordered_map<const ForNode*, Stmt> loop2new_stmt_;
-
+  Map<Var, Range> loop_range_map_;
   arith::Analyzer analyzer_;
 };

From a69a4162bb70616ab4e20ad58b91b03847efd56f Mon Sep 17 00:00:00 2001
From: Ruihang Lai
Date: Sun, 14 Nov 2021 17:45:16 +0800
Subject: [PATCH 5/6] Address comment

---
 .../transforms/lower_cross_thread_reduction.cc | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/src/tir/transforms/lower_cross_thread_reduction.cc b/src/tir/transforms/lower_cross_thread_reduction.cc
index fcb5a3f41a46..6ba7c03d8b68 100644
--- a/src/tir/transforms/lower_cross_thread_reduction.cc
+++ b/src/tir/transforms/lower_cross_thread_reduction.cc
@@ -463,14 +463,15 @@ class CrossThreadReductionTransformer : public StmtMutator {
     int n_deepest_reduction_loops = 0;
     for (auto rit = statement_stack_.rbegin() + 1; rit != statement_stack_.rend(); ++rit) {
       const StmtNode* stmt = *rit;
-      if (stmt->IsInstance<ForNode>()) {
-        const ForNode* loop = static_cast<const ForNode*>(stmt);
-        if (std::find(reduction_loops.begin(), reduction_loops.end(), loop) ==
-            reduction_loops.end()) {
-          break;
-        }
-        ++n_deepest_reduction_loops;
+      if ((*rit)->IsInstance<SeqStmtNode>()) {
+        // Skip SeqStmt.
+        continue;
+      }
+      if (std::find(reduction_loops.begin(), reduction_loops.end(),
+                    reinterpret_cast<const ForNode*>(stmt)) == reduction_loops.end()) {
+        break;
       }
+      ++n_deepest_reduction_loops;
     }
     CHECK_EQ(n_deepest_reduction_loops, reduction_loops.size())
         << "ValueError: Cross-thread reduction requires all the reduction-related loops to be the "

From e07a20ddf6d7c128a1b04ca82b769b8b1b4780ce Mon Sep 17 00:00:00 2001
From: Ruihang Lai
Date: Sun, 14 Nov 2021 18:51:33 +0800
Subject: [PATCH 6/6] Use `std::make_tuple`

---
 src/tir/transforms/lower_cross_thread_reduction.cc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/tir/transforms/lower_cross_thread_reduction.cc b/src/tir/transforms/lower_cross_thread_reduction.cc
index 6ba7c03d8b68..630c00f8c1f1 100644
--- a/src/tir/transforms/lower_cross_thread_reduction.cc
+++ b/src/tir/transforms/lower_cross_thread_reduction.cc
@@ -519,7 +519,7 @@ class CrossThreadReductionTransformer : public StmtMutator {
       }
       return true;
     });
-    return {n_bound_reduction_loops, reducer, combiner_rhs};
+    return std::make_tuple(n_bound_reduction_loops, reducer, combiner_rhs);
   }
 
   Stmt VisitStmt(const Stmt& stmt) final {