Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move reduction lowering from DistributeOp to TeamsOp and use teams reduction clauses to generate info. #159

Merged
merged 5 commits into from
Sep 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1660,7 +1660,6 @@ void CGOpenMPRuntimeGPU::emitReduction(
return;

bool ParallelReduction = isOpenMPParallelDirective(Options.ReductionKind);
bool DistributeReduction = isOpenMPDistributeDirective(Options.ReductionKind);
bool TeamsReduction = isOpenMPTeamsDirective(Options.ReductionKind);

ASTContext &C = CGM.getContext();
Expand Down Expand Up @@ -1756,7 +1755,7 @@ void CGOpenMPRuntimeGPU::emitReduction(

CGF.Builder.restoreIP(OMPBuilder.createReductionsGPU(
OmpLoc, AllocaIP, CodeGenIP, ReductionInfos, false, TeamsReduction,
DistributeReduction, llvm::OpenMPIRBuilder::ReductionGenCBKind::Clang,
llvm::OpenMPIRBuilder::ReductionGenCBKind::Clang,
CGF.getTarget().getGridValue(), C.getLangOpts().OpenMPCUDAReductionBufNum,
RTLoc));
return;
Expand Down
35 changes: 20 additions & 15 deletions flang/lib/Lower/OpenMP/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -796,7 +796,6 @@ genReductionVars(mlir::Operation *op, lower::AbstractConverter &converter,

mlir::Block *entryBlock = firOpBuilder.createBlock(
&op->getRegion(0), {}, reductionTypes, blockArgLocs);

// Bind the reduction arguments to their block arguments.
for (auto [arg, prv] :
llvm::zip_equal(reductionArgs, entryBlock->getArguments())) {
Expand Down Expand Up @@ -1659,14 +1658,15 @@ static void genTaskwaitClauses(lower::AbstractConverter &converter,
loc, llvm::omp::Directive::OMPD_taskwait);
}

static void
genTeamsClauses(lower::AbstractConverter &converter,
semantics::SemanticsContext &semaCtx,
lower::StatementContext &stmtCtx, const List<Clause> &clauses,
mlir::Location loc, bool evalOutsideTarget,
mlir::omp::TeamsOperands &clauseOps,
mlir::omp::NumTeamsClauseOps &numTeamsClauseOps,
mlir::omp::ThreadLimitClauseOps &threadLimitClauseOps) {
static void genTeamsClauses(
lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx,
lower::StatementContext &stmtCtx, const List<Clause> &clauses,
mlir::Location loc, bool evalOutsideTarget,
mlir::omp::TeamsOperands &clauseOps,
mlir::omp::NumTeamsClauseOps &numTeamsClauseOps,
mlir::omp::ThreadLimitClauseOps &threadLimitClauseOps,
llvm::SmallVectorImpl<mlir::Type> &reductionTypes,
llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSyms) {
ClauseProcessor cp(converter, semaCtx, clauses);
cp.processAllocate(clauseOps);
cp.processIf(llvm::omp::Directive::OMPD_teams, clauseOps);
Expand All @@ -1684,8 +1684,7 @@ genTeamsClauses(lower::AbstractConverter &converter,
cp.processNumTeams(stmtCtx, numTeamsClauseOps);
cp.processThreadLimit(stmtCtx, threadLimitClauseOps);
}

// cp.processTODO<clause::Reduction>(loc, llvm::omp::Directive::OMPD_teams);
cp.processReduction(loc, clauseOps, &reductionTypes, &reductionSyms);
}

static void genWsloopClauses(
Expand Down Expand Up @@ -1874,7 +1873,6 @@ static mlir::omp::ParallelOp genParallelOp(
llvm::ArrayRef<mlir::Type> reductionTypes, DataSharingProcessor *dsp,
bool isComposite = false, mlir::omp::TargetOp parentTarget = nullptr) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();

auto reductionCallback = [&](mlir::Operation *op) {
genReductionVars(op, converter, loc, reductionSyms, reductionTypes);
return llvm::SmallVector<const semantics::Symbol *>(reductionSyms);
Expand Down Expand Up @@ -2360,14 +2358,22 @@ genTeamsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
mlir::omp::TeamsOperands clauseOps;
mlir::omp::NumTeamsClauseOps numTeamsClauseOps;
mlir::omp::ThreadLimitClauseOps threadLimitClauseOps;
llvm::SmallVector<const semantics::Symbol *> reductionSyms;
llvm::SmallVector<mlir::Type> reductionTypes;
genTeamsClauses(converter, semaCtx, stmtCtx, item->clauses, loc,
evalOutsideTarget, clauseOps, numTeamsClauseOps,
threadLimitClauseOps);
threadLimitClauseOps, reductionTypes, reductionSyms);

auto reductionCallback = [&](mlir::Operation *op) {
genReductionVars(op, converter, loc, reductionSyms, reductionTypes);
return llvm::SmallVector<const semantics::Symbol *>(reductionSyms);
};

auto teamsOp = genOpWithBody<mlir::omp::TeamsOp>(
OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
llvm::omp::Directive::OMPD_teams)
.setClauses(&item->clauses),
.setClauses(&item->clauses)
.setGenRegionEntryCb(reductionCallback),
queue, item, clauseOps);

if (numTeamsClauseOps.numTeamsUpper) {
Expand Down Expand Up @@ -2436,7 +2442,6 @@ static void genStandaloneDo(lower::AbstractConverter &converter,
const ConstructQueue &queue,
ConstructQueue::const_iterator item) {
lower::StatementContext stmtCtx;

mlir::omp::WsloopOperands wsloopClauseOps;
llvm::SmallVector<const semantics::Symbol *> reductionSyms;
llvm::SmallVector<mlir::Type> reductionTypes;
Expand Down
15 changes: 15 additions & 0 deletions flang/test/Lower/OpenMP/reduction-target-spmd.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
! RUN: %flang_fc1 -emit-fir -fopenmp -o - %s | FileCheck %s
! RUN: bbc -emit-fir -fopenmp -o - %s | FileCheck %s

! CHECK: omp.teams
! CHECK-SAME: reduction(@add_reduction_i32 %{{.*}} -> %{{.*}} : !fir.ref<i32>)
subroutine myfun()
integer :: i, j
i = 0
j = 0
!$omp target teams distribute parallel do reduction(+:i)
do j = 1,5
i = i + j
end do
!$omp end target teams distribute parallel do
end subroutine myfun
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
! RUN: bbc -emit-fir -fopenmp -o - %s | FileCheck %s
! RUN: %flang_fc1 -emit-fir -fopenmp -o - %s | FileCheck %s
! XFAIL: *

! CHECK: omp.teams
! CHECK-SAME: reduction
Expand Down
2 changes: 1 addition & 1 deletion flang/test/Lower/OpenMP/sections-array-reduction.f90
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ subroutine sectionsReduction(x)
! CHECK: omp.parallel {
! CHECK: %[[VAL_3:.*]] = fir.alloca !fir.box<!fir.array<?xf32>>
! CHECK: fir.store %[[VAL_2]]#1 to %[[VAL_3]] : !fir.ref<!fir.box<!fir.array<?xf32>>>
! CHECK: omp.sections reduction(byref @add_reduction_byref_box_Uxf32 -> %[[VAL_3]] : !fir.ref<!fir.box<!fir.array<?xf32>>>) {
! CHECK: omp.sections reduction(byref @add_reduction_byref_box_Uxf32 %[[VAL_3]] -> %[[ARG_1:.*]] : !fir.ref<!fir.box<!fir.array<?xf32>>>) {
! CHECK: ^bb0(%[[VAL_4:.*]]: !fir.ref<!fir.box<!fir.array<?xf32>>>):
! CHECK: omp.section {
! CHECK: ^bb0(%[[VAL_5:.*]]: !fir.ref<!fir.box<!fir.array<?xf32>>>):
Expand Down
4 changes: 2 additions & 2 deletions flang/test/Lower/OpenMP/sections-reduction.f90
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ subroutine sectionsReduction(x,y)
! CHECK: %[[VAL_3:.*]]:2 = hlfir.declare %[[VAL_0]] dummy_scope %[[VAL_2]] {uniq_name = "_QFsectionsreductionEx"} : (!fir.ref<f32>, !fir.dscope) -> (!fir.ref<f32>, !fir.ref<f32>)
! CHECK: %[[VAL_4:.*]]:2 = hlfir.declare %[[VAL_1]] dummy_scope %[[VAL_2]] {uniq_name = "_QFsectionsreductionEy"} : (!fir.ref<f32>, !fir.dscope) -> (!fir.ref<f32>, !fir.ref<f32>)
! CHECK: omp.parallel {
! CHECK: omp.sections reduction(@add_reduction_f32 -> %[[VAL_3]]#0 : !fir.ref<f32>, @add_reduction_f32 -> %[[VAL_4]]#0 : !fir.ref<f32>) {
! CHECK: omp.sections reduction(@add_reduction_f32 %[[VAL_3]]#0 -> %[[ARG_0:.*]] : !fir.ref<f32>, @add_reduction_f32 %[[VAL_4]]#0 -> %[[ARG_1:.*]] : !fir.ref<f32>) {
! CHECK: ^bb0(%[[VAL_5:.*]]: !fir.ref<f32>, %[[VAL_6:.*]]: !fir.ref<f32>):
! CHECK: omp.section {
! CHECK: ^bb0(%[[VAL_7:.*]]: !fir.ref<f32>, %[[VAL_8:.*]]: !fir.ref<f32>):
Expand Down Expand Up @@ -71,7 +71,7 @@ subroutine sectionsReduction(x,y)
! CHECK: omp.terminator
! CHECK: }
! CHECK: omp.parallel {
! CHECK: omp.sections reduction(@add_reduction_f32 -> %[[VAL_3]]#0 : !fir.ref<f32>, @add_reduction_f32 -> %[[VAL_4]]#0 : !fir.ref<f32>) {
! CHECK: omp.sections reduction(@add_reduction_f32 %[[VAL_3]]#0 -> %[[ARG_2:.*]] : !fir.ref<f32>, @add_reduction_f32 %[[VAL_4]]#0 -> %[[ARG_3:.*]] : !fir.ref<f32>) {
! CHECK: ^bb0(%[[VAL_23:.*]]: !fir.ref<f32>, %[[VAL_24:.*]]: !fir.ref<f32>):
! CHECK: omp.section {
! CHECK: ^bb0(%[[VAL_25:.*]]: !fir.ref<f32>, %[[VAL_26:.*]]: !fir.ref<f32>):
Expand Down
6 changes: 1 addition & 5 deletions llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -1844,8 +1844,6 @@ class OpenMPIRBuilder {
/// nowait.
/// \param IsTeamsReduction Optional flag set if it is a teams
/// reduction.
/// \param HasDistribute Optional flag set if it is a
/// distribute reduction.
/// \param GridValue Optional GPU grid value.
/// \param ReductionBufNum Optional OpenMPCUDAReductionBufNumValue to be
/// used for teams reduction.
Expand All @@ -1854,7 +1852,6 @@ class OpenMPIRBuilder {
const LocationDescription &Loc, InsertPointTy AllocaIP,
InsertPointTy CodeGenIP, ArrayRef<ReductionInfo> ReductionInfos,
bool IsNoWait = false, bool IsTeamsReduction = false,
bool HasDistribute = false,
ReductionGenCBKind ReductionGenCBKind = ReductionGenCBKind::MLIR,
std::optional<omp::GV> GridValue = {}, unsigned ReductionBufNum = 1024,
Value *SrcLocInfo = nullptr);
Expand Down Expand Up @@ -1926,8 +1923,7 @@ class OpenMPIRBuilder {
InsertPointTy AllocaIP,
ArrayRef<ReductionInfo> ReductionInfos,
ArrayRef<bool> IsByRef, bool IsNoWait = false,
bool IsTeamsReduction = false,
bool HasDistribute = false);
bool IsTeamsReduction = false);

///}

Expand Down
22 changes: 10 additions & 12 deletions llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3412,9 +3412,9 @@ checkReductionInfos(ArrayRef<OpenMPIRBuilder::ReductionInfo> ReductionInfos,
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createReductionsGPU(
const LocationDescription &Loc, InsertPointTy AllocaIP,
InsertPointTy CodeGenIP, ArrayRef<ReductionInfo> ReductionInfos,
bool IsNoWait, bool IsTeamsReduction, bool HasDistribute,
ReductionGenCBKind ReductionGenCBKind, std::optional<omp::GV> GridValue,
unsigned ReductionBufNum, Value *SrcLocInfo) {
bool IsNoWait, bool IsTeamsReduction, ReductionGenCBKind ReductionGenCBKind,
std::optional<omp::GV> GridValue, unsigned ReductionBufNum,
Value *SrcLocInfo) {
if (!updateToLocation(Loc))
return InsertPointTy();
Builder.restoreIP(CodeGenIP);
Expand Down Expand Up @@ -3590,13 +3590,11 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createReductionsGPU(
ReductionFunc;
});
} else {
if (!HasDistribute || IsTeamsReduction) {
Value *LHSValue = Builder.CreateLoad(RI.ElementType, LHS, "final.lhs");
Value *RHSValue = Builder.CreateLoad(RI.ElementType, RHS, "final.rhs");
Value *Reduced;
RI.ReductionGen(Builder.saveIP(), RHSValue, LHSValue, Reduced);
Builder.CreateStore(Reduced, LHS, false);
}
Value *LHSValue = Builder.CreateLoad(RI.ElementType, LHS, "final.lhs");
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Remove the `HasDistribute` function argument, since it's no longer necessary.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that's a good point.

Value *RHSValue = Builder.CreateLoad(RI.ElementType, RHS, "final.rhs");
Value *Reduced;
RI.ReductionGen(Builder.saveIP(), RHSValue, LHSValue, Reduced);
Builder.CreateStore(Reduced, LHS, false);
}
}
emitBlock(ExitBB, CurFunc);
Expand Down Expand Up @@ -3685,11 +3683,11 @@ static void populateReductionFunction(
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createReductions(
const LocationDescription &Loc, InsertPointTy AllocaIP,
ArrayRef<ReductionInfo> ReductionInfos, ArrayRef<bool> IsByRef,
bool IsNoWait, bool IsTeamsReduction, bool HasDistribute) {
bool IsNoWait, bool IsTeamsReduction) {
assert(ReductionInfos.size() == IsByRef.size());
if (Config.isGPU())
return createReductionsGPU(Loc, AllocaIP, Builder.saveIP(), ReductionInfos,
IsNoWait, IsTeamsReduction, HasDistribute);
IsNoWait, IsTeamsReduction);

checkReductionInfos(ReductionInfos, /*IsGPU*/ false);

Expand Down
68 changes: 25 additions & 43 deletions mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -472,16 +472,20 @@ static void printOrderClause(OpAsmPrinter &p, Operation *op,
//===----------------------------------------------------------------------===//

static ParseResult parseClauseWithRegionArgs(
OpAsmParser &parser, Region &region,
OpAsmParser &parser,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
SmallVectorImpl<Type> &types, DenseBoolArrayAttr &byref, ArrayAttr &symbols,
SmallVectorImpl<OpAsmParser::Argument> &regionPrivateArgs) {
SmallVectorImpl<OpAsmParser::Argument> &regionPrivateArgs,
bool parseParens = true) {
SmallVector<SymbolRefAttr> reductionVec;
SmallVector<bool> isByRefVec;
unsigned regionArgOffset = regionPrivateArgs.size();

OpAsmParser::Delimiter delimiter = parseParens ? OpAsmParser::Delimiter::Paren
: OpAsmParser::Delimiter::None;

if (failed(
parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren, [&]() {
parser.parseCommaSeparatedList(delimiter, [&]() {
ParseResult optionalByref = parser.parseOptionalKeyword("byref");
if (parser.parseAttribute(reductionVec.emplace_back()) ||
parser.parseOperand(operands.emplace_back()) ||
Expand Down Expand Up @@ -536,15 +540,15 @@ static ParseResult parseParallelRegion(
llvm::SmallVector<OpAsmParser::Argument> regionPrivateArgs;

if (succeeded(parser.parseOptionalKeyword("reduction"))) {
if (failed(parseClauseWithRegionArgs(parser, region, reductionVars,
if (failed(parseClauseWithRegionArgs(parser, reductionVars,
reductionTypes, reductionByref,
reductionSyms, regionPrivateArgs)))
return failure();
}

if (succeeded(parser.parseOptionalKeyword("private"))) {
auto privateByref = DenseBoolArrayAttr::get(parser.getContext(), {});
if (failed(parseClauseWithRegionArgs(parser, region, privateVars,
if (failed(parseClauseWithRegionArgs(parser, privateVars,
privateTypes, privateByref,
privateSyms, regionPrivateArgs)))
return failure();
Expand Down Expand Up @@ -597,48 +601,26 @@ static ParseResult parseReductionVarList(
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &reductionVars,
SmallVectorImpl<Type> &reductionTypes, DenseBoolArrayAttr &reductionByref,
ArrayAttr &reductionSyms) {
SmallVector<SymbolRefAttr> reductionVec;
SmallVector<bool> isByRefVec;
if (failed(parser.parseCommaSeparatedList([&]() {
ParseResult optionalByref = parser.parseOptionalKeyword("byref");
if (parser.parseAttribute(reductionVec.emplace_back()) ||
parser.parseArrow() ||
parser.parseOperand(reductionVars.emplace_back()) ||
parser.parseColonType(reductionTypes.emplace_back()))
return failure();
isByRefVec.push_back(optionalByref.succeeded());
return success();
})))
return failure();
reductionByref = makeDenseBoolArrayAttr(parser.getContext(), isByRefVec);
SmallVector<Attribute> reductions(reductionVec.begin(), reductionVec.end());
reductionSyms = ArrayAttr::get(parser.getContext(), reductions);
return success();
llvm::SmallVector<OpAsmParser::Argument> regionPrivateArgs;
return parseClauseWithRegionArgs(parser, reductionVars, reductionTypes,
reductionByref, reductionSyms,
regionPrivateArgs, /*parseParens=*/false);
}

/// Print Reduction clause
static void
printReductionVarList(OpAsmPrinter &p, Operation *op,
OperandRange reductionVars, TypeRange reductionTypes,
std::optional<DenseBoolArrayAttr> reductionByref,
std::optional<ArrayAttr> reductionSyms) {
auto getByRef = [&](unsigned i) -> const char * {
if (!reductionByref || !*reductionByref)
return "";
assert(reductionByref->empty() || i < reductionByref->size());
if (!reductionByref->empty() && (*reductionByref)[i])
return "byref ";
return "";
};

for (unsigned i = 0, e = reductionVars.size(); i < e; ++i) {
if (i != 0)
p << ", ";
p << getByRef(i) << (*reductionSyms)[i] << " -> " << reductionVars[i]
<< " : " << reductionVars[i].getType();
static void printReductionVarList(OpAsmPrinter &p, Operation *op,
OperandRange reductionVars,
TypeRange reductionTypes,
DenseBoolArrayAttr reductionByref,
ArrayAttr reductionSyms) {
if (reductionSyms) {
auto *argsBegin = op->getRegion(0).front().getArguments().begin();
MutableArrayRef argsSubrange(argsBegin, argsBegin + reductionTypes.size());
printClauseWithRegionArgs(p, op, argsSubrange, llvm::StringRef(),
reductionVars, reductionTypes, reductionByref,
reductionSyms);
}
}

/// Verifies Reduction Clause
static LogicalResult
verifyReductionVarList(Operation *op, std::optional<ArrayAttr> reductionSyms,
Expand Down Expand Up @@ -1824,7 +1806,7 @@ parseWsloop(OpAsmParser &parser, Region &region,
// Parse an optional reduction clause
llvm::SmallVector<OpAsmParser::Argument> privates;
if (succeeded(parser.parseOptionalKeyword("reduction"))) {
if (failed(parseClauseWithRegionArgs(parser, region, reductionOperands,
if (failed(parseClauseWithRegionArgs(parser, reductionOperands,
reductionTypes, reductionByRef,
reductionSymbols, privates)))
return failure();
Expand Down
Loading