Skip to content

Commit

Permalink
Merge pull request #160 from frasercrmck/create-loop-api
Browse files Browse the repository at this point in the history
[compiler] Move IVs into the CreateLoopOpts struct
  • Loading branch information
frasercrmck authored Oct 16, 2023
2 parents 4cfaa18 + fff995f commit eed1caf
Show file tree
Hide file tree
Showing 7 changed files with 97 additions and 85 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ Upgrade guidance:
`compiler::utils::LowerToMuxBuiltinsPass`.
* The `compiler::utils::HandleBarriersPass` has been renamed to the
`compiler::utils::WorkItemLoopsPass`.
* The `compiler::utils::createLoop` API has moved its list of `IVs` parameter
into its `compiler::utils::CreateLoopOpts` parameter. It can now also set the
IV names via a second `CreateLoopOpts` field.

## Version 3.0.0

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ PreservedAnalyses RefSiWGLoopPass::run(Module &M, ModuleAnalysisManager &AM) {
// looping through num groups in the outermost dimension
auto *const exitBlock = compiler::utils::createLoop(
loopPreheaderIR.GetInsertBlock(), nullptr, zero, numGroups[outer_dim],
{}, create_loop_opts,
create_loop_opts,
[&](BasicBlock *blockz, Value *z, ArrayRef<Value *>,
MutableArrayRef<Value *>) -> BasicBlock * {
IRBuilder<> ir(blockz);
Expand All @@ -136,8 +136,7 @@ PreservedAnalyses RefSiWGLoopPass::run(Module &M, ModuleAnalysisManager &AM) {
{i32_0, ir.getInt32(outer_dim)}));
// looping through num groups in the middle dimension
return compiler::utils::createLoop(
blockz, nullptr, zero, numGroups[middle_dim], {},
create_loop_opts,
blockz, nullptr, zero, numGroups[middle_dim], create_loop_opts,
[&](BasicBlock *blocky, Value *y, ArrayRef<Value *>,
MutableArrayRef<Value *>) -> BasicBlock * {
IRBuilder<> ir(blocky);
Expand All @@ -147,7 +146,7 @@ PreservedAnalyses RefSiWGLoopPass::run(Module &M, ModuleAnalysisManager &AM) {

// looping through num groups in the x dimension
return compiler::utils::createLoop(
blocky, nullptr, zero, numGroups[inner_dim], {},
blocky, nullptr, zero, numGroups[inner_dim],
create_loop_opts,
[&](BasicBlock *blockx, Value *x, ArrayRef<Value *>,
MutableArrayRef<Value *>) -> BasicBlock * {
Expand Down
6 changes: 3 additions & 3 deletions modules/compiler/targets/host/source/AddEntryHook.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ PreservedAnalyses AddEntryHookPass::run(Module &M, ModuleAnalysisManager &AM) {

// looping through num groups in the outermost dimension
auto exitBlock = compiler::utils::createLoop(
loopIR.GetInsertBlock(), nullptr, zero, numGroups[outer_dim], {}, opts,
loopIR.GetInsertBlock(), nullptr, zero, numGroups[outer_dim], opts,
[&](BasicBlock *blockz, Value *z, ArrayRef<Value *>,
MutableArrayRef<Value *>) -> BasicBlock * {
IRBuilder<> ir(blockz);
Expand All @@ -187,7 +187,7 @@ PreservedAnalyses AddEntryHookPass::run(Module &M, ModuleAnalysisManager &AM) {
{i32_0, ir.getInt32(outer_dim)}));
// looping through num groups in the middle dimension
return compiler::utils::createLoop(
blockz, nullptr, zero, numGroups[middle_dim], {}, opts,
blockz, nullptr, zero, numGroups[middle_dim], opts,
[&](BasicBlock *blocky, Value *y, ArrayRef<Value *>,
MutableArrayRef<Value *>) -> BasicBlock * {
IRBuilder<> ir(blocky);
Expand All @@ -197,7 +197,7 @@ PreservedAnalyses AddEntryHookPass::run(Module &M, ModuleAnalysisManager &AM) {

// looping through num groups in the x dimension
return compiler::utils::createLoop(
blocky, nullptr, sliceStart, clampedSliceEnd, {}, opts,
blocky, nullptr, sliceStart, clampedSliceEnd, opts,
[&](BasicBlock *blockx, Value *x, ArrayRef<Value *>,
MutableArrayRef<Value *>) -> BasicBlock * {
IRBuilder<> ir(blockx);
Expand Down
36 changes: 24 additions & 12 deletions modules/compiler/utils/include/compiler/utils/pass_functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,19 @@ struct CreateLoopOpts {
/// @brief headerName Optional name for the loop header block. Defaults to:
/// "loopIR".
llvm::StringRef headerName = "loopIR";
/// @brief An optional list of incoming IV values.
///
/// Each of these is used as the incoming value to a PHI created by
/// createLoop. These PHIs are provided to the 'body' function of createLoop,
/// which should in turn set the 'next' version of the IV.
std::vector<llvm::Value *> IVs;
/// @brief An optional list of IV names, to be set on the PHIs provided by
/// 'IVs' field/parameter.
///
/// If set, the names are assumed to correlate 1:1 with those IVs. The list
/// may be shorter than the list of IVs, in which case the trailing IVs are
/// not named.
std::vector<std::string> loopIVNames;
};

/// @brief Create a loop around a body, creating an implicit induction variable
Expand All @@ -200,23 +213,22 @@ struct CreateLoopOpts {
/// @param exit Loop exit block. The new loop will jump to this once it exits.
/// @param indexStart The start index
/// @param indexEnd The end index (we compare for <)
/// @param ivs A list of extra induction variables to create.
/// @param opts Set of options configuring the generation of this loop.
/// @param body Body of code to insert into loop. The parameters of this
/// function are as follows: the loop body BasicBlock; the Value corresponding
/// to the IV beginning at `indexStart` and incremented each iteration by
/// `indexInc` while less than `indexEnd`; the list of IVs for this iteration
/// of the loop (may or may not be PHIs, depending on the loop bounds); the
/// list of IVs for the next iteration of the loop (the function is required to
/// fill these in). Both these sets of IVs will be arrays of equal length to
/// the original list of IVs, in the same order. The function returns the loop
/// latch/exiting block: this block will be given the branch that decides
/// between continuing the loop and exiting from it.
/// @param body Body of code to insert into loop.
///
/// The parameters of this function are as follows: the loop body BasicBlock;
/// the Value corresponding to the IV beginning at `indexStart` and incremented
/// each iteration by `indexInc` while less than `indexEnd`; the list of IVs
/// for this iteration of the loop (may or may not be PHIs, depending on the
/// loop bounds); the list of IVs for the next iteration of the loop (the
/// function is required to fill these in). Both these sets of IVs will be
/// arrays of equal length to the original list of IVs, in the same order. The
/// function returns the loop latch/exiting block: this block will be given the
/// branch that decides between continuing the loop and exiting from it.
///
/// @return llvm::BasicBlock* The exit block
llvm::BasicBlock *createLoop(llvm::BasicBlock *entry, llvm::BasicBlock *exit,
llvm::Value *indexStart, llvm::Value *indexEnd,
llvm::ArrayRef<llvm::Value *> ivs,
const CreateLoopOpts &opts, CreateLoopBodyFn body);

/// @brief Get the last argument of a function.
Expand Down
20 changes: 12 additions & 8 deletions modules/compiler/utils/source/mux_builtin_info.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -591,12 +591,13 @@ static BasicBlock *copy1D(Module &M, BasicBlock &ParentBB, Value *DstPtr,
assert(DstPtr->getType()->isPointerTy() &&
"Mux DMA builtins are always byte-accessed");

Value *DmaIVs[] = {SrcPtr, DstPtr};
compiler::utils::CreateLoopOpts opts;
opts.IVs = {SrcPtr, DstPtr};
opts.loopIVNames = {"dma.src", "dma.dst"};

// This is a simple loop copy a byte at a time from SrcPtr to DstPtr.
BasicBlock *ExitBB = compiler::utils::createLoop(
&ParentBB, nullptr, ConstantInt::get(getSizeType(M), 0), NumBytes, DmaIVs,
compiler::utils::CreateLoopOpts{},
&ParentBB, nullptr, ConstantInt::get(getSizeType(M), 0), NumBytes, opts,
[&](BasicBlock *BB, Value *X, ArrayRef<Value *> IVsCurr,
MutableArrayRef<Value *> IVsNext) {
IRBuilder<> B(BB);
Expand Down Expand Up @@ -625,12 +626,13 @@ static BasicBlock *copy2D(Module &M, BasicBlock &ParentBB, Value *DstPtr,
assert(DstPtr->getType()->isPointerTy() &&
"Mux DMA builtins are always byte-accessed");

Value *DmaIVs[] = {SrcPtr, DstPtr};
compiler::utils::CreateLoopOpts opts;
opts.IVs = {SrcPtr, DstPtr};
opts.loopIVNames = {"dma.src", "dma.dst"};

// This is a loop over the range of lines, calling a 1D copy on each line
BasicBlock *ExitBB = compiler::utils::createLoop(
&ParentBB, nullptr, ConstantInt::get(getSizeType(M), 0), NumLines, DmaIVs,
compiler::utils::CreateLoopOpts{},
&ParentBB, nullptr, ConstantInt::get(getSizeType(M), 0), NumLines, opts,
[&](BasicBlock *block, Value *, ArrayRef<Value *> IVsCurr,
MutableArrayRef<Value *> IVsNext) {
IRBuilder<> loopIr(block);
Expand Down Expand Up @@ -735,12 +737,14 @@ Function *BIMuxInfoConcept::defineDMA3D(Function &F) {
assert(ArgDstPtr->getType()->isPointerTy() &&
"Mux DMA builtins are always byte-accessed");

Value *DmaIVs[] = {ArgSrcPtr, ArgDstPtr};
compiler::utils::CreateLoopOpts opts;
opts.IVs = {ArgSrcPtr, ArgDstPtr};
opts.loopIVNames = {"dma.src", "dma.dst"};

// Create a loop around 1D DMA memcpy, adding stride, local width each time.
BasicBlock *LoopExitBB = compiler::utils::createLoop(
LoopEntryBB, nullptr, ConstantInt::get(getSizeType(M), 0), ArgNumPlanes,
DmaIVs, compiler::utils::CreateLoopOpts{},
opts,
[&](BasicBlock *BB, Value *, ArrayRef<Value *> IVsCurr,
MutableArrayRef<Value *> IVsNext) {
IRBuilder<> loopIr(BB);
Expand Down
9 changes: 6 additions & 3 deletions modules/compiler/utils/source/pass_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,6 @@ bool addParamToAllFunctions(llvm::Module &module,

llvm::BasicBlock *createLoop(llvm::BasicBlock *entry, llvm::BasicBlock *exit,
llvm::Value *indexStart, llvm::Value *indexEnd,
llvm::ArrayRef<llvm::Value *> ivs,
const CreateLoopOpts &opts,
CreateLoopBodyFn body) {
// If the index increment is null, we default to 1 as our index.
Expand All @@ -483,8 +482,8 @@ llvm::BasicBlock *createLoop(llvm::BasicBlock *entry, llvm::BasicBlock *exit,

llvm::LLVMContext &ctx = entry->getContext();

llvm::SmallVector<llvm::Value *, 4> currIVs(ivs.begin(), ivs.end());
llvm::SmallVector<llvm::Value *, 4> nextIVs(ivs.size());
llvm::SmallVector<llvm::Value *, 4> currIVs(opts.IVs.begin(), opts.IVs.end());
llvm::SmallVector<llvm::Value *, 4> nextIVs(opts.IVs.size());

// Check if indexStart, indexEnd, and indexInc are constants.
if (llvm::isa<llvm::ConstantInt>(indexStart) &&
Expand Down Expand Up @@ -564,6 +563,10 @@ llvm::BasicBlock *createLoop(llvm::BasicBlock *entry, llvm::BasicBlock *exit,
auto *const phi = loopIR.CreatePHI(currIVs[i]->getType(), 2);
llvm::cast<llvm::PHINode>(phi)->addIncoming(currIVs[i],
entryIR.GetInsertBlock());
// Set IV names if they've been given to us.
if (i < opts.loopIVNames.size()) {
phi->setName(opts.loopIVNames[i]);
}
currIVs[i] = phi;
}

Expand Down
Loading

0 comments on commit eed1caf

Please sign in to comment.