Skip to content

Commit

Permalink
Add shouldUseAtomicOps function
Browse files Browse the repository at this point in the history
  • Loading branch information
kchristin22 committed Oct 3, 2024
1 parent 5c0a26d commit 873ded3
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 6 deletions.
4 changes: 4 additions & 0 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,10 @@ namespace clad {
return StmtDiff{Clone(NS), Clone(NS)};
};

/// Checks whether the function being differentiated runs only on the GPU,
/// i.e. it carries the CUDA global (kernel) attribute, or the CUDA device
/// attribute without the CUDA host attribute.
///
/// Callers use the result to decide whether a derivative accumulation
/// (`target += dfdx`) is emitted as a call to CUDA's atomicAdd.
///
/// \returns true if the differentiated function is device-only code.
bool shouldUseCudaAtomicOps();

/// Add call to cuda::atomicAdd for the given LHS and RHS expressions.
///
/// \param[in] LHS The left-hand side expression.
Expand Down
14 changes: 8 additions & 6 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return CladTapeResult{*this, PushExpr, PopExpr, TapeRef};
}

/// Reports whether the function being differentiated is device-only CUDA
/// code: either a kernel (CUDA global attribute) or a device function that is
/// not additionally marked as a host function.
bool ReverseModeVisitor::shouldUseCudaAtomicOps() {
  // A kernel is always executed on the device.
  if (m_DiffReq->hasAttr<clang::CUDAGlobalAttr>())
    return true;

  // Without the device attribute the function is plain host code.
  if (!m_DiffReq->hasAttr<clang::CUDADeviceAttr>())
    return false;

  // A device function qualifies only when it is not also compiled for the
  // host.
  return !m_DiffReq->hasAttr<clang::CUDAHostAttr>();
}

clang::Expr* ReverseModeVisitor::BuildCallToCudaAtomicAdd(clang::Expr* LHS,
clang::Expr* RHS) {
DeclarationName atomicAddId = &m_Context.Idents.get("atomicAdd");
Expand Down Expand Up @@ -1518,9 +1524,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
BuildArraySubscript(target, forwSweepDerivativeIndices);
// Create the (target += dfdx) statement.
if (dfdx()) {
if (m_DiffReq->hasAttr<clang::CUDAGlobalAttr>() ||
(m_DiffReq->hasAttr<clang::CUDADeviceAttr>() &&
!m_DiffReq->hasAttr<clang::CUDAHostAttr>())) {
if (shouldUseCudaAtomicOps()) {
Expr* atomicCall = BuildCallToCudaAtomicAdd(result, dfdx());
// Add it to the body statements.
addToCurrentBlock(atomicCall, direction::reverse);
Expand Down Expand Up @@ -2320,9 +2324,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
derivedE = BuildOp(UnaryOperatorKind::UO_Deref, diff_dx);
// Create the (target += dfdx) statement.
if (dfdx()) {
if (m_DiffReq->hasAttr<clang::CUDAGlobalAttr>() ||
(m_DiffReq->hasAttr<clang::CUDADeviceAttr>() &&
!m_DiffReq->hasAttr<clang::CUDAHostAttr>())) {
if (shouldUseCudaAtomicOps()) {
Expr* atomicCall = BuildCallToCudaAtomicAdd(diff_dx, dfdx());
// Add it to the body statements.
addToCurrentBlock(atomicCall, direction::reverse);
Expand Down

0 comments on commit 873ded3

Please sign in to comment.