Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve derived variable init for const pointers #919

Merged
merged 1 commit into from
Jun 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 45 additions & 28 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2432,6 +2432,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
VarDecl* VDDerived = nullptr;
bool isPointerType = VD->getType()->isPointerType();
bool isInitializedByNewExpr = false;
bool initializeDerivedVar = true;
// Check if the variable is pointer type and initialized by new expression
if (isPointerType && VD->getInit() && isa<CXXNewExpr>(VD->getInit()))
isInitializedByNewExpr = true;
Expand Down Expand Up @@ -2506,22 +2507,34 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}
// if VD is a pointer type, then the initial value is set to the derived
// expression of the corresponding pointer type.
else if (isPointerType && VD->getInit()) {
VDDerivedType = getNonConstType(VDDerivedType, m_Context, m_Sema);
// If it's a pointer to a constant type, then remove the constness.
if (VD->getType()->getPointeeType().isConstQualified()) {
// first extract the pointee type
auto pointeeType = VD->getType()->getPointeeType();
// then remove the constness
pointeeType.removeLocalConst();
// then create a new pointer type with the new pointee type
VDDerivedType = m_Context.getPointerType(pointeeType);
else if (isPointerType) {
if (!isInitializedByNewExpr)
initDiff = Visit(VD->getInit());

// If the pointee type is const-qualified and no derived expression is
// available, then we should not create a derived variable for it. This is
// useful for reducing the number of differentiation variables in pullbacks.
bool constPointer = VD->getType()->getPointeeType().isConstQualified();
if (constPointer && !isInitializedByNewExpr && !initDiff.getExpr_dx())
initializeDerivedVar = false;
else {
VDDerivedType = getNonConstType(VDDerivedType, m_Context, m_Sema);
// If it's a pointer to a constant type, then remove the constness.
if (constPointer) {
// first extract the pointee type
auto pointeeType = VD->getType()->getPointeeType();
// then remove the constness
pointeeType.removeLocalConst();
// then create a new pointer type with the new pointee type
VDDerivedType = m_Context.getPointerType(pointeeType);
}
VDDerivedInit = getZeroInit(VDDerivedType);
}
VDDerivedInit = getZeroInit(VDDerivedType);
}
VDDerived =
BuildGlobalVarDecl(VDDerivedType, "_d_" + VD->getNameAsString(),
VDDerivedInit, false, nullptr, VD->getInitStyle());
if (initializeDerivedVar)
VDDerived = BuildGlobalVarDecl(
VDDerivedType, "_d_" + VD->getNameAsString(), VDDerivedInit, false,
nullptr, VD->getInitStyle());
}

// If `VD` is a reference to a local variable, then it is already
Expand Down Expand Up @@ -2562,11 +2575,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
assignToZero = GetCladZeroInit(declRef);
addToCurrentBlock(assignToZero, direction::reverse);
}
} else if (isPointerType && VD->getInit()) {
initDiff = Visit(VD->getInit());
}
VarDecl* VDClone = nullptr;
Expr* derivedVDE = BuildDeclRef(VDDerived);
Expr* derivedVDE = nullptr;
if (VDDerived)
derivedVDE = BuildDeclRef(VDDerived);

// FIXME: Add extra parentheses if derived variable pointer is pointing to a
// class type object.
Expand Down Expand Up @@ -2601,7 +2614,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
VDClone = BuildGlobalVarDecl(VDCloneType, VD->getNameAsString(),
initDiff.getExpr(), VD->isDirectInit(),
nullptr, VD->getInitStyle());
if (isPointerType) {
if (isPointerType && derivedVDE) {
Expr* assignDerivativeE = BuildOp(BinaryOperatorKind::BO_Assign,
derivedVDE, initDiff.getExpr_dx());
addToCurrentBlock(assignDerivativeE, direction::forward);
Expand All @@ -2615,7 +2628,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
derivedVDE = BuildDeclRef(reverseSweepDerivativePointerE);
}
}
m_Variables.emplace(VDClone, derivedVDE);
if (derivedVDE)
m_Variables.emplace(VDClone, derivedVDE);

return DeclDiff<VarDecl>(VDClone, VDDerived);
}
Expand All @@ -2642,11 +2656,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (auto* FD = dyn_cast<FunctionDecl>(callExpr->getCalleeDecl()))
if (utils::IsMemoryFunction(FD))
dxInForward = true;

if (dxInForward)
addToCurrentBlock(stmtDx, direction::forward);
else
addToCurrentBlock(SDiff.getStmt_dx(), direction::reverse);
if (stmtDx) {
if (dxInForward)
addToCurrentBlock(stmtDx, direction::forward);
else
addToCurrentBlock(stmtDx, direction::reverse);
}
CompoundStmt* RCS = endBlock(direction::reverse);
std::reverse(RCS->body_begin(), RCS->body_end());
Stmt* ReverseResult = utils::unwrapIfSingleStmt(RCS);
Expand Down Expand Up @@ -2747,10 +2762,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}

decls.push_back(VDDiff.getDecl());
if (isa<VariableArrayType>(VD->getType()))
localDeclsDiff.push_back(VDDiff.getDecl_dx());
else
declsDiff.push_back(VDDiff.getDecl_dx());
if (VDDiff.getDecl_dx()) {
if (isa<VariableArrayType>(VD->getType()))
localDeclsDiff.push_back(VDDiff.getDecl_dx());
else
declsDiff.push_back(VDDiff.getDecl_dx());
}
} else if (auto* SAD = dyn_cast<StaticAssertDecl>(D)) {
DeclDiff<StaticAssertDecl> SADDiff = DifferentiateStaticAssertDecl(SAD);
if (SADDiff.getDecl())
Expand Down
34 changes: 34 additions & 0 deletions test/Gradient/FunctionCalls.C
Original file line number Diff line number Diff line change
Expand Up @@ -670,6 +670,25 @@ double fn19(double x) {
// CHECK-NEXT: }
// CHECK-NEXT: }

// Returns w[0] * x[0] + w[1] * x[1].
// Reads the first two elements of both `x` and `w`; `w` is only read (it is a
// pointer to const double). The exact form of the return expression matters:
// the expected pullback listed later in this file is derived from it term by
// term.
double weighted_sum(double* x, const double* w) {
return w[0] * x[0] + w[1] * x[1];
}

// CHECK: void weighted_sum_pullback(double *x, const double *w, double _d_y, double *_d_x);

// Forwards to weighted_sum through a local pointer-to-const alias.
// `auxW` points one element past `w`, so the call reads w[1] and w[2] and the
// result is w[1] * x[0] + w[2] * x[1]. The expected gradient listed below
// declares no derivative variable for `auxW`.
double fn20(double* x, const double* w) {
const double* auxW = w + 1;
return weighted_sum(x, auxW);
}

// CHECK: void fn20_grad_0(double *x, const double *w, double *_d_x) {
// CHECK-NEXT: const double *auxW = w + 1;
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: _label0:
// CHECK-NEXT: weighted_sum_pullback(x, auxW, 1, _d_x);
// CHECK-NEXT: }


template<typename T>
void reset(T* arr, int n) {
for (int i=0; i<n; ++i)
Expand Down Expand Up @@ -767,6 +786,12 @@ int main() {

INIT(fn19);
TEST1(fn19, 3); // CHECK-EXEC: {1.00}

auto fn20_grad_0 = clad::gradient(fn20, "x");
double x1[] = {3.0, 5.0}, w1[] = {-1.0, 2.0, 3.0};
double dx1[] = {0.0, 0.0};
fn20_grad_0.execute(x1, w1, dx1);
printf("{%.2f, %.2f}\n", dx1[0], dx1[1]); // CHECK-EXEC: {2.00, 3.00}
}

double sq_defined_later(double x) {
Expand Down Expand Up @@ -1010,3 +1035,12 @@ double sq_defined_later(double x) {
// CHECK-NEXT: _label0:
// CHECK-NEXT: *_d_x += _d_y;
// CHECK-NEXT: }

// CHECK: void weighted_sum_pullback(double *x, const double *w, double _d_y, double *_d_x) {
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: _label0:
// CHECK-NEXT: {
// CHECK-NEXT: _d_x[0] += w[0] * _d_y;
// CHECK-NEXT: _d_x[1] += w[1] * _d_y;
// CHECK-NEXT: }
// CHECK-NEXT: }
Loading