Skip to content

Commit

Permalink
[TensorIR] Cross-Thread Reduction (apache#9360)
Browse files Browse the repository at this point in the history
* [TensorIR] Cross-Thread Reduction

* Code revision on analysis and misc

* Refactor TransformReductionBlock

* Refactor code organization

* Address comment

* Use `std::make_tuple`

Co-authored-by: Junru Shao <[email protected]>
  • Loading branch information
2 people authored and yangulei committed Jan 11, 2022
1 parent a3aa129 commit 48d77d3
Show file tree
Hide file tree
Showing 8 changed files with 1,662 additions and 177 deletions.
7 changes: 7 additions & 0 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,13 @@ TVM_DLL Pass PointerValueTypeRewrite();
*/
TVM_DLL Pass HoistIfThenElse();

/*!
* \brief Lower cross-thread reduction from thread
* bindings to intrinsic function calls.
* \return The pass.
*/
TVM_DLL Pass LowerCrossThreadReduction();

/*!
* \brief Lower block init stmt into IfThenElse stmts
* \return The pass.
Expand Down
12 changes: 12 additions & 0 deletions python/tvm/tir/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,18 @@ def HoistIfThenElse(variant: Optional[str] = None):
return _ffi_api.HoistIfThenElse() # type: ignore


def LowerCrossThreadReduction():
"""Lower cross-thread reduction from thread bindings to
intrinsic function calls.
Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.LowerCrossThreadReduction() # type: ignore


def LowerInitBlock():
"""Lower block init stmt into IfThenElse statements.
Expand Down
1 change: 1 addition & 0 deletions src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
pass_list.push_back(tir::transform::InjectPrefetch());
pass_list.push_back(tir::transform::TextureFlatten());
pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers));
pass_list.push_back(tir::transform::LowerCrossThreadReduction());
pass_list.push_back(tir::transform::LowerInitBlock());
pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation());
pass_list.push_back(tir::transform::ConvertBlocksToOpaque());
Expand Down
50 changes: 49 additions & 1 deletion src/tir/schedule/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,17 @@
#ifndef TVM_TIR_SCHEDULE_ANALYSIS_H_
#define TVM_TIR_SCHEDULE_ANALYSIS_H_

#include <tvm/arith/analyzer.h>
#include <tvm/tir/schedule/state.h>

#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "../../runtime/thread_storage_scope.h"

namespace tvm {
namespace tir {

Expand Down Expand Up @@ -323,14 +328,57 @@ struct ProducerConsumerSplit {
*/
Buffer GetNthAccessBuffer(const ScheduleState& self, const Block& block, int n, bool is_write);

/******** Reduction Block Related ********/

/*!
* \brief Convert the `init` and `body` of the input block to BufferStores
* \param self The schedule state
* \param block The block to be analyzed
* \return The BufferStores of the `init` and `body` of the input block
* \throw ScheduleError If the `init` or `body` is not BufferStore, or they don't write to the same
* buffer
*/
std::pair<BufferStore, BufferStore> GetBufferStoresFromReductionBlock(
const Optional<ScheduleState>& self, const Block& block);

/*!
* \brief Check whether the input array of IterVars only contains data-parallel and reduction block
* iters
* \param iters The input array of IterVars to be checked
* \return A boolean indicating whether the input array of IterVars only contains data-parallel and
* reduction block iters
*/
bool ContainsOnlyDataParAndReductionBlockIter(const Array<IterVar>& iters);

/*!
* \brief Check whether the block's reduction block iters are not used to index the block's output
* buffers
* \param block The block to be checked
* \return A boolean indicating whether the block's reduction block iters are not used to index the
* block's output buffer
*/
bool ReductionIterNotIndexOutputBuffer(const Block& block);

/*!
* \brief Given a reduction identity and a reduction combiner, detect the corresponding commutative
* reducer, and extract the combiner lhs and combiner rhs
* \param self The schedule state
* \param identity The reduction identity to be analyzed
* \param combiner The reduction combiner to be analyzed
* \return The corresponding CommReducer, the combiner lhs and the combiner rhs
* \throw ScheduleError If no corresponding commutative reducer can be matched
*/
std::tuple<CommReducer, PrimExpr, PrimExpr> GetReducerAndCombinerLhsRhs(
const Optional<ScheduleState>& self, const PrimExpr& identity, const BufferStore& combiner);

/******** Commutative Reducer ********/

/*!
* \brief Get the list of the registered reducer-getter functions
* \return The list of the registered reducer-getter functions
* \sa ReducerRegistry
*/
std::vector<TypedPackedFunc<CommReducer(DataType)>> GetReducerGetters();
std::vector<runtime::TypedPackedFunc<CommReducer(DataType)>> GetReducerGetters();

/*!
* \brief Given the input identity and the combiner BufferStore of a reduction, extract the
Expand Down
Loading

0 comments on commit 48d77d3

Please sign in to comment.