Skip to content

Commit

Permalink
Improve derived variable init for const pointers
Browse files Browse the repository at this point in the history
  • Loading branch information
vaithak authored and vgvassilev committed Jun 6, 2024
1 parent bfe6d65 commit b38d8cf
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 28 deletions.
73 changes: 45 additions & 28 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2432,6 +2432,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
VarDecl* VDDerived = nullptr;
bool isPointerType = VD->getType()->isPointerType();
bool isInitializedByNewExpr = false;
bool initializeDerivedVar = true;
// Check if the variable is pointer type and initialized by new expression
if (isPointerType && VD->getInit() && isa<CXXNewExpr>(VD->getInit()))
isInitializedByNewExpr = true;
Expand Down Expand Up @@ -2506,22 +2507,34 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}
// if VD is a pointer type, then the initial value is set to the derived
// expression of the corresponding pointer type.
else if (isPointerType && VD->getInit()) {
VDDerivedType = getNonConstType(VDDerivedType, m_Context, m_Sema);
// If it's a pointer to a constant type, then remove the constness.
if (VD->getType()->getPointeeType().isConstQualified()) {
// first extract the pointee type
auto pointeeType = VD->getType()->getPointeeType();
// then remove the constness
pointeeType.removeLocalConst();
// then create a new pointer type with the new pointee type
VDDerivedType = m_Context.getPointerType(pointeeType);
else if (isPointerType) {
if (!isInitializedByNewExpr)
initDiff = Visit(VD->getInit());

// If the pointer is const and derived expression is not available, then
// we should not create a derived variable for it. This will be useful
// for reducing number of differentiation variables in pullbacks.
bool constPointer = VD->getType()->getPointeeType().isConstQualified();
if (constPointer && !isInitializedByNewExpr && !initDiff.getExpr_dx())
initializeDerivedVar = false;
else {
VDDerivedType = getNonConstType(VDDerivedType, m_Context, m_Sema);
// If it's a pointer to a constant type, then remove the constness.
if (constPointer) {
// first extract the pointee type
auto pointeeType = VD->getType()->getPointeeType();
// then remove the constness
pointeeType.removeLocalConst();
// then create a new pointer type with the new pointee type
VDDerivedType = m_Context.getPointerType(pointeeType);
}
VDDerivedInit = getZeroInit(VDDerivedType);
}
VDDerivedInit = getZeroInit(VDDerivedType);
}
VDDerived =
BuildGlobalVarDecl(VDDerivedType, "_d_" + VD->getNameAsString(),
VDDerivedInit, false, nullptr, VD->getInitStyle());
if (initializeDerivedVar)
VDDerived = BuildGlobalVarDecl(
VDDerivedType, "_d_" + VD->getNameAsString(), VDDerivedInit, false,
nullptr, VD->getInitStyle());
}

// If `VD` is a reference to a local variable, then it is already
Expand Down Expand Up @@ -2562,11 +2575,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
assignToZero = GetCladZeroInit(declRef);
addToCurrentBlock(assignToZero, direction::reverse);
}
} else if (isPointerType && VD->getInit()) {
initDiff = Visit(VD->getInit());
}
VarDecl* VDClone = nullptr;
Expr* derivedVDE = BuildDeclRef(VDDerived);
Expr* derivedVDE = nullptr;
if (VDDerived)
derivedVDE = BuildDeclRef(VDDerived);

// FIXME: Add extra parantheses if derived variable pointer is pointing to a
// class type object.
Expand Down Expand Up @@ -2601,7 +2614,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
VDClone = BuildGlobalVarDecl(VDCloneType, VD->getNameAsString(),
initDiff.getExpr(), VD->isDirectInit(),
nullptr, VD->getInitStyle());
if (isPointerType) {
if (isPointerType && derivedVDE) {
Expr* assignDerivativeE = BuildOp(BinaryOperatorKind::BO_Assign,
derivedVDE, initDiff.getExpr_dx());
addToCurrentBlock(assignDerivativeE, direction::forward);
Expand All @@ -2615,7 +2628,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
derivedVDE = BuildDeclRef(reverseSweepDerivativePointerE);
}
}
m_Variables.emplace(VDClone, derivedVDE);
if (derivedVDE)
m_Variables.emplace(VDClone, derivedVDE);

return DeclDiff<VarDecl>(VDClone, VDDerived);
}
Expand All @@ -2642,11 +2656,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (auto* FD = dyn_cast<FunctionDecl>(callExpr->getCalleeDecl()))
if (utils::IsMemoryFunction(FD))
dxInForward = true;

if (dxInForward)
addToCurrentBlock(stmtDx, direction::forward);
else
addToCurrentBlock(SDiff.getStmt_dx(), direction::reverse);
if (stmtDx) {
if (dxInForward)
addToCurrentBlock(stmtDx, direction::forward);
else
addToCurrentBlock(stmtDx, direction::reverse);
}
CompoundStmt* RCS = endBlock(direction::reverse);
std::reverse(RCS->body_begin(), RCS->body_end());
Stmt* ReverseResult = utils::unwrapIfSingleStmt(RCS);
Expand Down Expand Up @@ -2747,10 +2762,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}

decls.push_back(VDDiff.getDecl());
if (isa<VariableArrayType>(VD->getType()))
localDeclsDiff.push_back(VDDiff.getDecl_dx());
else
declsDiff.push_back(VDDiff.getDecl_dx());
if (VDDiff.getDecl_dx()) {
if (isa<VariableArrayType>(VD->getType()))
localDeclsDiff.push_back(VDDiff.getDecl_dx());
else
declsDiff.push_back(VDDiff.getDecl_dx());
}
} else if (auto* SAD = dyn_cast<StaticAssertDecl>(D)) {
DeclDiff<StaticAssertDecl> SADDiff = DifferentiateStaticAssertDecl(SAD);
if (SADDiff.getDecl())
Expand Down
34 changes: 34 additions & 0 deletions test/Gradient/FunctionCalls.C
Original file line number Diff line number Diff line change
Expand Up @@ -670,6 +670,25 @@ double fn19(double x) {
// CHECK-NEXT: }
// CHECK-NEXT: }

double weighted_sum(double* x, const double* w) {
return w[0] * x[0] + w[1] * x[1];
}

// CHECK: void weighted_sum_pullback(double *x, const double *w, double _d_y, double *_d_x);

double fn20(double* x, const double* w) {
const double* auxW = w + 1;
return weighted_sum(x, auxW);
}

// CHECK: void fn20_grad_0(double *x, const double *w, double *_d_x) {
// CHECK-NEXT: const double *auxW = w + 1;
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: _label0:
// CHECK-NEXT: weighted_sum_pullback(x, auxW, 1, _d_x);
// CHECK-NEXT: }


template<typename T>
void reset(T* arr, int n) {
for (int i=0; i<n; ++i)
Expand Down Expand Up @@ -767,6 +786,12 @@ int main() {

INIT(fn19);
TEST1(fn19, 3); // CHECK-EXEC: {1.00}

auto fn20_grad_0 = clad::gradient(fn20, "x");
double x1[] = {3.0, 5.0}, w1[] = {-1.0, 2.0, 3.0};
double dx1[] = {0.0, 0.0};
fn20_grad_0.execute(x1, w1, dx1);
printf("{%.2f, %.2f}\n", dx1[0], dx1[1]); // CHECK-EXEC: {2.00, 3.00}
}

double sq_defined_later(double x) {
Expand Down Expand Up @@ -1010,3 +1035,12 @@ double sq_defined_later(double x) {
// CHECK-NEXT: _label0:
// CHECK-NEXT: *_d_x += _d_y;
// CHECK-NEXT: }

// CHECK: void weighted_sum_pullback(double *x, const double *w, double _d_y, double *_d_x) {
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: _label0:
// CHECK-NEXT: {
// CHECK-NEXT: _d_x[0] += w[0] * _d_y;
// CHECK-NEXT: _d_x[1] += w[1] * _d_y;
// CHECK-NEXT: }
// CHECK-NEXT: }

0 comments on commit b38d8cf

Please sign in to comment.