Skip to content

Commit

Permalink
Fix builtin macros and zero init for memory operations
Browse files Browse the repository at this point in the history
  • Loading branch information
vaithak committed Feb 29, 2024
1 parent e37039d commit e1cf160
Show file tree
Hide file tree
Showing 7 changed files with 71 additions and 14 deletions.
1 change: 1 addition & 0 deletions .clang-tidy
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ Checks: >
-readability-function-cognitive-complexity,
-readability-implicit-bool-conversion,
-cppcoreguidelines-avoid-magic-numbers,
-clang-analyzer-cplusplus.NewDeleteLeaks,
CheckOptions:
- key: readability-identifier-naming.ClassCase
Expand Down
2 changes: 1 addition & 1 deletion include/clad/Differentiator/CladUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ namespace clad {

bool IsLiteral(const clang::Expr* E);

bool IsMemoryAllocationFunction(const clang::FunctionDecl* FD);
bool IsMemoryFunction(const clang::FunctionDecl* FD);
bool IsMemoryDeallocationFunction(const clang::FunctionDecl* FD);
} // namespace utils
} // namespace clad
Expand Down
11 changes: 11 additions & 0 deletions include/clad/Differentiator/Compatibility.h
Original file line number Diff line number Diff line change
Expand Up @@ -766,5 +766,16 @@ static inline const DeclSpec& Sema_ActOnStartOfLambdaDefinition_ScopeOrDeclSpec(
,Node->isTransparent()
#endif

// Clang 9 above added isa_and_nonnull.
#if CLANG_VERSION_MAJOR < 9
template <typename X, typename Y> bool isa_and_nonnull(const Y* Val) {
return Val && isa<X>(Val);
}
#else
template <typename X, typename Y> bool isa_and_nonnull(const Y* Val) {
return llvm::isa_and_nonnull<X>(Val);
}
#endif

} // namespace clad_compat
#endif //CLAD_COMPATIBILITY
33 changes: 29 additions & 4 deletions lib/Differentiator/CladUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "clang/AST/Decl.h"
#include "clang/AST/Expr.h"
#include "clang/AST/RecursiveASTVisitor.h"
#include "clang/Basic/Builtins.h"
#include "clang/Basic/SourceLocation.h"
#include "clang/Sema/Lookup.h"
#include "llvm/ADT/SmallVector.h"
Expand Down Expand Up @@ -400,6 +401,12 @@ namespace clad {
auto& C = semaRef.getASTContext();
if (!TSI)
TSI = C.getTrivialTypeSourceInfo(qType);
if (clad_compat::isa_and_nonnull<ImplicitValueInitExpr>(initializer))
// If the initializer is an implicit value init expression, then
// we don't need to pass it explicitly to the CXXNewExpr. As, clang
// internally adds it when initializer is ParenListExpr and
// DirectInitRange is valid.
initializer = semaRef.ActOnParenListExpr(noLoc, noLoc, {}).get();

Check warning on line 409 in lib/Differentiator/CladUtils.cpp

View check run for this annotation

Codecov / codecov/patch

lib/Differentiator/CladUtils.cpp#L409

Added line #L409 was not covered by tests
auto newExpr =
semaRef
.BuildCXXNew(
Expand Down Expand Up @@ -642,18 +649,36 @@ namespace clad {
isa<GNUNullExpr>(E);
}

bool IsMemoryAllocationFunction(const clang::FunctionDecl* FD) {
bool IsMemoryFunction(const clang::FunctionDecl* FD) {

#if CLANG_VERSION_MAJOR > 12
if (FD->getBuiltinID() == Builtin::BImalloc)
return true;
if (FD->getBuiltinID() == Builtin::BIcalloc)
if (FD->getBuiltinID() == Builtin::ID::BIcalloc)
return true;
if (FD->getBuiltinID() == Builtin::ID::BIrealloc)
return true;
if (FD->getBuiltinID() == Builtin::ID::BImemset)
return true;
#else
if (FD->getNameAsString() == "malloc")
return true;
if (FD->getBuiltinID() == Builtin::BIrealloc)
if (FD->getNameAsString() == "calloc")
return true;

Check warning on line 667 in lib/Differentiator/CladUtils.cpp

View check run for this annotation

Codecov / codecov/patch

lib/Differentiator/CladUtils.cpp#L667

Added line #L667 was not covered by tests
if (FD->getNameAsString() == "realloc")
return true;
if (FD->getNameAsString() == "memset")
return true;
#endif
return false;
}

bool IsMemoryDeallocationFunction(const clang::FunctionDecl* FD) {
return FD->getBuiltinID() == Builtin::BIfree;
#if CLANG_VERSION_MAJOR > 12
return FD->getBuiltinID() == Builtin::ID::BIfree;
#else
return FD->getNameAsString() == "free";
#endif
}
} // namespace utils
} // namespace clad
29 changes: 23 additions & 6 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1445,7 +1445,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// For calls to C-style memory allocation functions, we do not need to
// differentiate the call. We just need to visit the arguments to the
// function.
if (utils::IsMemoryAllocationFunction(FD)) {
if (utils::IsMemoryFunction(FD)) {
for (const Expr* Arg : CE->arguments()) {
StmtDiff ArgDiff = Visit(Arg, dfdx());
CallArgs.push_back(ArgDiff.getExpr());
Expand Down Expand Up @@ -2649,8 +2649,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
bool isPointerType = VD->getType()->isPointerType();
bool isInitializedByNewExpr = false;
// Check if the variable is pointer type and initialized by new expression
if (isPointerType && (VD->getInit() != nullptr) &&
isa<CXXNewExpr>(VD->getInit()))
if (isPointerType && VD->getInit() && isa<CXXNewExpr>(VD->getInit()))
isInitializedByNewExpr = true;

// VDDerivedInit now serves two purposes -- as the initial derivative value
Expand Down Expand Up @@ -2842,7 +2841,20 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (m_ExternalSource)
m_ExternalSource->ActBeforeFinalizingDifferentiateSingleStmt(direction::reverse);

addToCurrentBlock(SDiff.getStmt_dx(), direction::reverse);
// If the statement is a standalone call to a memory function, we want to
// add its derived statement in the same block as the original statement.
// For ex: memset(x, 0, 10) -> memset(_d_x, 0, 10)
Stmt* stmtDx = SDiff.getStmt_dx();
bool dxInForward = false;
if (auto* callExpr = dyn_cast_or_null<CallExpr>(stmtDx))
if (auto* FD = dyn_cast<FunctionDecl>(callExpr->getCalleeDecl()))
if (utils::IsMemoryFunction(FD))
dxInForward = true;

if (dxInForward)
addToCurrentBlock(stmtDx, direction::forward);
else
addToCurrentBlock(SDiff.getStmt_dx(), direction::reverse);
CompoundStmt* RCS = endBlock(direction::reverse);
std::reverse(RCS->body_begin(), RCS->body_end());
Stmt* ReverseResult = unwrapIfSingleStmt(RCS);
Expand Down Expand Up @@ -3824,9 +3836,14 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
Expr* clonedNewE = utils::BuildCXXNewExpr(
m_Sema, CNE->getAllocatedType(), clonedArraySizeE,
initializerDiff.getExpr(), CNE->getAllocatedTypeSourceInfo());
Expr* diffInit = initializerDiff.getExpr_dx();
if (!diffInit) {
// we should initialize it implicitly using ParenListExpr.
diffInit = m_Sema.ActOnParenListExpr(noLoc, noLoc, {}).get();
}
Expr* derivedNewE = utils::BuildCXXNewExpr(
m_Sema, CNE->getAllocatedType(), derivedArraySizeE,
initializerDiff.getExpr_dx(), CNE->getAllocatedTypeSourceInfo());
m_Sema, CNE->getAllocatedType(), derivedArraySizeE, diffInit,
CNE->getAllocatedTypeSourceInfo());
return {clonedNewE, derivedNewE};
}

Expand Down
2 changes: 1 addition & 1 deletion test/ForwardMode/Pointer.C
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ double fn7(double i) {
// CHECK-NEXT: t->i = i;
// CHECK-NEXT: double _d_res = *_d_p + _d_t->i;
// CHECK-NEXT: double res = *p + t->i;
// CHECK-NEXT: unsigned long _t2 = sizeof(double);
// CHECK-NEXT: unsigned {{(int|long)}} _t2 = sizeof(double);
// CHECK-NEXT: {{(clad::)?}}ValueAndPushforward<void *, void *> _t3 = clad::custom_derivatives::realloc_pushforward(p, 2 * _t2, _d_p, 0 * _t2 + 2 * sizeof(double));
// CHECK-NEXT: _d_p = (double *)_t3.pushforward;
// CHECK-NEXT: p = (double *)_t3.value;
Expand Down
7 changes: 5 additions & 2 deletions test/Gradient/Pointers.C
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ double newAndDeletePointer(double i, double j) {
// CHECK-NEXT: double *p = new double(i);
// CHECK-NEXT: _d_q = new double(* _d_j);
// CHECK-NEXT: double *q = new double(j);
// CHECK-NEXT: _d_r = new double [2];
// CHECK-NEXT: _d_r = new double [2](/*implicit*/(double{{[ ]?}}[2])0);
// CHECK-NEXT: double *r = new double [2];
// CHECK-NEXT: _t0 = r[0];
// CHECK-NEXT: r[0] = i + j;
Expand Down Expand Up @@ -418,7 +418,7 @@ double structPointer (double x) {
// CHECK: void structPointer_grad(double x, clad::array_ref<double> _d_x) {
// CHECK-NEXT: T *_d_t = 0;
// CHECK-NEXT: double _d_res = 0;
// CHECK-NEXT: _d_t = new T;
// CHECK-NEXT: _d_t = new T();
// CHECK-NEXT: T *t = new T({x, /*implicit*/(int)0});
// CHECK-NEXT: double res = t->x;
// CHECK-NEXT: goto _label0;
Expand All @@ -432,6 +432,7 @@ double structPointer (double x) {

double cStyleMemoryAlloc(double x, size_t n) {
T* t = (T*)malloc(n * sizeof(T));
memset(t, 0, n * sizeof(T));
t->x = x;
double* p = (double*)calloc(1, sizeof(double));
*p = x;
Expand All @@ -457,6 +458,8 @@ double cStyleMemoryAlloc(double x, size_t n) {
// CHECK-NEXT: double _t5;
// CHECK-NEXT: _d_t = (T *)malloc(n * sizeof(T));
// CHECK-NEXT: T *t = (T *)malloc(n * sizeof(T));
// CHECK-NEXT: memset(_d_t, 0, n * sizeof(T));
// CHECK-NEXT: memset(t, 0, n * sizeof(T));
// CHECK-NEXT: _t0 = t->x;
// CHECK-NEXT: t->x = x;
// CHECK-NEXT: _d_p = (double *)calloc(1, sizeof(double));
Expand Down

0 comments on commit e1cf160

Please sign in to comment.