Skip to content

Commit

Permalink
Declare new variables for pullback derivatives of non-ref type args
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi committed Mar 12, 2024
1 parent bd87275 commit cb3f1a4
Show file tree
Hide file tree
Showing 24 changed files with 402 additions and 125 deletions.
6 changes: 0 additions & 6 deletions include/clad/Differentiator/CladUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -331,12 +331,6 @@ namespace clad {

bool IsMemoryFunction(const clang::FunctionDecl* FD);
bool IsMemoryDeallocationFunction(const clang::FunctionDecl* FD);

/// Checks if the argument arg involves type conversion
/// (e.g. float to double). This is used to find out
/// a similar cast can be performed for the derived
/// arg (for corresponding pointers).
bool isTriviallyCastableArg(const clang::Expr* arg);
} // namespace utils
} // namespace clad

Expand Down
11 changes: 0 additions & 11 deletions lib/Differentiator/CladUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -680,16 +680,5 @@ namespace clad {
return FD->getNameAsString() == "free";
#endif
}

bool isTriviallyCastableArg(const clang::Expr* arg) {
// `arg` is the of the type that the function expects.
const QualType expectedType = arg->getType().getCanonicalType();
// By removing all implicit casts, we can get the type of the original
// argument.
const QualType argType =
arg->IgnoreImplicit()->getType().getCanonicalType();
// Compare the unqualified types.
return expectedType.getTypePtr() == argType.getTypePtr();
}
} // namespace utils
} // namespace clad
25 changes: 10 additions & 15 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1517,7 +1517,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return call;
}

llvm::SmallVector<Expr*, 16> ArgResult{};
llvm::SmallVector<Stmt*, 16> PreCallStmts{};
// Save current index in the current block, to potentially put some
// statements there later.
Expand All @@ -1538,11 +1537,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// We do not need to create result arg for arguments passed by reference
// because the derivatives of arguments passed by reference are directly
// modified by the derived callee function.
if (utils::isArrayOrPointerType(arg->getType()) ||
(arg->IgnoreImplicit()->isLValue() &&
utils::isTriviallyCastableArg(arg))) {
if (utils::IsReferenceOrPointerType(PVD->getType()) && !isa<MaterializeTemporaryExpr>(arg)) {
argDiff = Visit(arg);
ArgResult.push_back(argDiff.getExpr_dx());
CallArgDx.push_back(argDiff.getExpr_dx());
} else {
// Create temporary variables corresponding to derivative of each
// argument, so that they can be referred to when arguments is visited.
Expand All @@ -1553,11 +1550,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
QualType dArgTy = getNonConstType(arg->getType(), m_Context, m_Sema);
VarDecl* dArgDecl = BuildVarDecl(dArgTy, "_r", getZeroInit(dArgTy));
PreCallStmts.push_back(BuildDeclStmt(dArgDecl));
ArgResult.push_back(BuildDeclRef(dArgDecl));
CallArgDx.push_back(BuildDeclRef(dArgDecl));
// Visit using uninitialized reference.
argDiff = Visit(arg, BuildDeclRef(dArgDecl));
}
CallArgDx.push_back(argDiff.getExpr_dx());

// Save cloned arg in a "global" variable, so that it is accessible from
// the reverse pass.
Expand Down Expand Up @@ -1726,9 +1722,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
for (auto* argDerivative : CallArgDx) {
Expr* gradArgExpr = nullptr;
const Expr* arg = CE->getArg(idx);
if (utils::isArrayOrPointerType(arg->getType()) ||
(arg->IgnoreImplicit()->isLValue() &&
utils::isTriviallyCastableArg(arg))) {
const auto* PVD = FD->getParamDecl(idx);
if (utils::IsReferenceOrPointerType(PVD->getType()) && !isa<MaterializeTemporaryExpr>(arg)) {
if (argDerivative) {
if (utils::isArrayOrPointerType(argDerivative->getType()) ||
isCladArrayType(argDerivative->getType()) ||
Expand All @@ -1739,7 +1734,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
BuildOp(UnaryOperatorKind::UO_AddrOf, argDerivative);
}
} else {
Expr* gradVarExpr = ArgResult[idx];
Expr* gradVarExpr = CallArgDx[idx];
gradArgExpr =
BuildOp(UO_AddrOf, gradVarExpr, m_Function->getLocation());
}
Expand Down Expand Up @@ -1881,20 +1876,20 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
m_Sema, getCurrentScope(), OverloadedDerivedFn, "pushforward");
// If the derivative is called through _darg0 instead of _grad.
Expr* d = BuildOp(BO_Mul, dfdx(), OverloadedDerivedFn);
Expr* addGrad = BuildOp(BO_AddAssign, Clone(ArgResult[0]), d);
block.insert(it, addGrad);
Expr* addGrad = BuildOp(BO_AddAssign, Clone(CallArgDx[0]), d);
it = block.insert(it, addGrad);
it++;
} else {
// Insert the CallExpr to the derived function
block.insert(it, OverloadedDerivedFn);
it = block.insert(it, OverloadedDerivedFn);
it++;
}
// Insert PostCallStmts
it = block.insert(it, PostCallStmts.begin(), PostCallStmts.end());
}
if (m_ExternalSource)
m_ExternalSource->ActBeforeFinalizingVisitCallExpr(
CE, OverloadedDerivedFn, DerivedCallArgs, ArgResult, asGrad);
CE, OverloadedDerivedFn, DerivedCallArgs, CallArgDx, asGrad);

Expr* call = nullptr;

Expand Down
18 changes: 14 additions & 4 deletions test/Arrays/ArrayInputsReverseMode.C
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,9 @@ float func2(float* a) {
//CHECK-NEXT: i--;
//CHECK-NEXT: sum = clad::pop(_t1);
//CHECK-NEXT: float _r_d0 = _d_sum;
//CHECK-NEXT: helper_pullback(a[i], _r_d0, &_d_a[i]);
//CHECK-NEXT: float _r0 = 0;
//CHECK-NEXT: helper_pullback(a[i], _r_d0, &_r0);
//CHECK-NEXT: _d_a[i] += _r0;
//CHECK-NEXT: }
//CHECK-NEXT: }

Expand Down Expand Up @@ -269,7 +271,9 @@ double func5(int k) {
//CHECK-NEXT: {
//CHECK-NEXT: sum = clad::pop(_t3);
//CHECK-NEXT: double _r_d1 = _d_sum;
//CHECK-NEXT: addArr_pullback(arr, n, _r_d1, _d_arr, &_d_n);
//CHECK-NEXT: int _r0 = 0;
//CHECK-NEXT: addArr_pullback(arr, n, _r_d1, _d_arr, &_r0);
//CHECK-NEXT: _d_n += _r0;
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT: for (; _t0; _t0--) {
Expand Down Expand Up @@ -440,7 +444,11 @@ double func8(double i, double *arr, int n) {
//CHECK-NEXT: res = _t1;
//CHECK-NEXT: double _r_d1 = _d_res;
//CHECK-NEXT: _d_res -= _r_d1;
//CHECK-NEXT: helper2_pullback(i, arr, n, _r_d1, &* _d_i, _d_arr, &* _d_n);
//CHECK-NEXT: double _r0 = 0;
//CHECK-NEXT: int _r1 = 0;
//CHECK-NEXT: helper2_pullback(i, arr, n, _r_d1, &_r0, _d_arr, &_r1);
//CHECK-NEXT: * _d_i += _r0;
//CHECK-NEXT: * _d_n += _r1;
//CHECK-NEXT: }
//CHECK-NEXT: {
//CHECK-NEXT: arr[0] = _t0;
Expand Down Expand Up @@ -501,7 +509,9 @@ double func9(double i, double j) {
//CHECK-NEXT: {
//CHECK-NEXT: double _r0 = clad::pop(_t1);
//CHECK-NEXT: arr[idx] = _r0;
//CHECK-NEXT: modify_pullback(_r0, i, &_d_arr[idx], &* _d_i);
//CHECK-NEXT: double _r1 = 0;
//CHECK-NEXT: modify_pullback(_r0, i, &_d_arr[idx], &_r1);
//CHECK-NEXT: * _d_i += _r1;
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT: }
Expand Down
12 changes: 10 additions & 2 deletions test/ErrorEstimation/BasicOps.C
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,13 @@ float func4(float x, float y) { return std::pow(x, y); }
//CHECK-NEXT: _ret_value0 = std::pow(x, y);
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: clad::custom_derivatives{{(::std)?}}::pow_pullback(x, y, 1, &* _d_x, &* _d_y);
//CHECK-NEXT: {
//CHECK-NEXT: float _r0 = 0;
//CHECK-NEXT: float _r1 = 0;
//CHECK-NEXT: clad::custom_derivatives{{(::std)?}}::pow_pullback(x, y, 1, &_r0, &_r1);
//CHECK-NEXT: * _d_x += _r0;
//CHECK-NEXT: * _d_y += _r1;
//CHECK-NEXT: }
//CHECK-NEXT: double _delta_x = 0;
//CHECK-NEXT: _delta_x += std::abs(* _d_x * x * {{.+}});
//CHECK-NEXT: double _delta_y = 0;
Expand Down Expand Up @@ -221,7 +227,9 @@ float func5(float x, float y) {
//CHECK-NEXT: y = _t0;
//CHECK-NEXT: float _r_d0 = * _d_y;
//CHECK-NEXT: * _d_y -= _r_d0;
//CHECK-NEXT: * _d_x += _r_d0 * clad::custom_derivatives{{(::std)?}}::sin_pushforward(x, 1.F).pushforward;
//CHECK-NEXT: float _r0 = 0;
//CHECK-NEXT: _r0 += _r_d0 * clad::custom_derivatives{{(::std)?}}::sin_pushforward(x, 1.F).pushforward;
//CHECK-NEXT: * _d_x += _r0;
//CHECK-NEXT: _delta_y += std::abs(_r_d0 * _EERepl_y1 * {{.+}});
//CHECK-NEXT: }
//CHECK-NEXT: double _delta_x = 0;
Expand Down
38 changes: 28 additions & 10 deletions test/FirstDerivative/BuiltinDerivatives.C
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,10 @@ void f7_grad(float x, clad::array_ref<float> _d_x);
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: _label0:
// CHECK-NEXT: {
// CHECK-NEXT: double _r0 = 0;
// CHECK-NEXT: {{(clad::)?}}custom_derivatives{{(::std)?}}::pow_pullback(x, 2., 1, &* _d_x, &_r0);
// CHECK-NEXT: float _r0 = 0;
// CHECK-NEXT: double _r1 = 0;
// CHECK-NEXT: {{(clad::)?}}custom_derivatives{{(::std)?}}::pow_pullback(x, 2., 1, &_r0, &_r1);
// CHECK-NEXT: * _d_x += _r0;
// CHECK-NEXT: }
// CHECK-NEXT: }

Expand All @@ -126,8 +128,10 @@ void f8_grad(float x, clad::array_ref<float> _d_x);
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: _label0:
// CHECK-NEXT: {
// CHECK-NEXT: int _r0 = 0;
// CHECK-NEXT: {{(clad::)?}}custom_derivatives{{(::std)?}}::pow_pullback(x, 2, 1, &* _d_x, &_r0);
// CHECK-NEXT: float _r0 = 0;
// CHECK-NEXT: int _r1 = 0;
// CHECK-NEXT: {{(clad::)?}}custom_derivatives{{(::std)?}}::pow_pullback(x, 2, 1, &_r0, &_r1);
// CHECK-NEXT: * _d_x += _r0;
// CHECK-NEXT: }
// CHECK-NEXT: }

Expand All @@ -147,7 +151,13 @@ void f9_grad(float x, float y, clad::array_ref<float> _d_x, clad::array_ref<floa
// CHECK: void f9_grad(float x, float y, clad::array_ref<float> _d_x, clad::array_ref<float> _d_y) {
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: _label0:
// CHECK-NEXT: {{(clad::)?}}custom_derivatives{{(::std)?}}::pow_pullback(x, y, 1, &* _d_x, &* _d_y);
// CHECK-NEXT: {
// CHECK-NEXT: float _r0 = 0;
// CHECK-NEXT: float _r1 = 0;
// CHECK-NEXT: {{(clad::)?}}custom_derivatives{{(::std)?}}::pow_pullback(x, y, 1, &_r0, &_r1);
// CHECK-NEXT: * _d_x += _r0;
// CHECK-NEXT: * _d_y += _r1;
// CHECK-NEXT: }
// CHECK-NEXT: }

double f10(float x, int y) {
Expand All @@ -166,7 +176,13 @@ void f10_grad(float x, int y, clad::array_ref<float> _d_x, clad::array_ref<int>
// CHECK: void f10_grad(float x, int y, clad::array_ref<float> _d_x, clad::array_ref<int> _d_y) {
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: _label0:
// CHECK-NEXT: {{(clad::)?}}custom_derivatives{{(::std)?}}::pow_pullback(x, y, 1, &* _d_x, &* _d_y);
// CHECK-NEXT: {
// CHECK-NEXT: float _r0 = 0;
// CHECK-NEXT: int _r1 = 0;
// CHECK-NEXT: {{(clad::)?}}custom_derivatives{{(::std)?}}::pow_pullback(x, y, 1, &_r0, &_r1);
// CHECK-NEXT: * _d_x += _r0;
// CHECK-NEXT: * _d_y += _r1;
// CHECK-NEXT: }
// CHECK-NEXT: }

double f11(double x, double y) {
Expand All @@ -184,11 +200,13 @@ double f11(double x, double y) {
// CHECK-NEXT: {{(clad::)?}}custom_derivatives{{(::std)?}}::pow_pullback((1. - x), 2, 1, &_r0, &_r1);
// CHECK-NEXT: * _d_x += -_r0;
// CHECK-NEXT: double _r2 = 0;
// CHECK-NEXT: int _r4 = 0;
// CHECK-NEXT: {{(clad::)?}}custom_derivatives{{(::std)?}}::pow_pullback(y - std::pow(x, 2), 2, 100. * 1, &_r2, &_r4);
// CHECK-NEXT: int _r5 = 0;
// CHECK-NEXT: {{(clad::)?}}custom_derivatives{{(::std)?}}::pow_pullback(y - std::pow(x, 2), 2, 100. * 1, &_r2, &_r5);
// CHECK-NEXT: * _d_y += _r2;
// CHECK-NEXT: int _r3 = 0;
// CHECK-NEXT: {{(clad::)?}}custom_derivatives{{(::std)?}}::pow_pullback(x, 2, -_r2, &* _d_x, &_r3);
// CHECK-NEXT: double _r3 = 0;
// CHECK-NEXT: int _r4 = 0;
// CHECK-NEXT: {{(clad::)?}}custom_derivatives{{(::std)?}}::pow_pullback(x, 2, -_r2, &_r3, &_r4);
// CHECK-NEXT: * _d_x += _r3;
// CHECK-NEXT: }
// CHECK-NEXT: }

Expand Down
10 changes: 9 additions & 1 deletion test/Gradient/Assignments.C
Original file line number Diff line number Diff line change
Expand Up @@ -791,7 +791,15 @@ double f19(double a, double b) {
//CHECK: void f19_grad(double a, double b, clad::array_ref<double> _d_a, clad::array_ref<double> _d_b) {
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: clad::custom_derivatives::fma_pullback(a, b, b, 1, &* _d_a, &* _d_b, &* _d_b);
//CHECK-NEXT: {
//CHECK-NEXT: double _r0 = 0;
//CHECK-NEXT: double _r1 = 0;
//CHECK-NEXT: double _r2 = 0;
//CHECK-NEXT: clad::custom_derivatives::fma_pullback(a, b, b, 1, &_r0, &_r1, &_r2);
//CHECK-NEXT: * _d_a += _r0;
//CHECK-NEXT: * _d_b += _r1;
//CHECK-NEXT: * _d_b += _r2;
//CHECK-NEXT: }
//CHECK-NEXT: }

double f20(double x, double y) {
Expand Down
26 changes: 18 additions & 8 deletions test/Gradient/FunctionCalls.C
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,11 @@ double fn1(float i) {
// CHECK-NEXT: _d_res += _d_a * i;
// CHECK-NEXT: * _d_i += res * _d_a;
// CHECK-NEXT: }
// CHECK-NEXT: constantFn_pullback(i, _d_res, &* _d_i);
// CHECK-NEXT: {
// CHECK-NEXT: float _r0 = 0;
// CHECK-NEXT: constantFn_pullback(i, _d_res, &_r0);
// CHECK-NEXT: * _d_i += _r0;
// CHECK-NEXT: }
// CHECK-NEXT: }

double modify1(double& i, double& j) {
Expand Down Expand Up @@ -274,15 +278,17 @@ double fn4(double* arr, int n) {
// CHECK-NEXT: _d_arr[i] += _r_d1;
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: double _r0 = clad::pop(_t2);
// CHECK-NEXT: arr[i] = _r0;
// CHECK-NEXT: twice_pullback(_r0, &_d_arr[i]);
// CHECK-NEXT: double _r1 = clad::pop(_t2);
// CHECK-NEXT: arr[i] = _r1;
// CHECK-NEXT: twice_pullback(_r1, &_d_arr[i]);
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: res = _t0;
// CHECK-NEXT: double _r_d0 = _d_res;
// CHECK-NEXT: sum_pullback(arr, n, _r_d0, _d_arr, &* _d_n);
// CHECK-NEXT: int _r0 = 0;
// CHECK-NEXT: sum_pullback(arr, n, _r_d0, _d_arr, &_r0);
// CHECK-NEXT: * _d_n += _r0;
// CHECK-NEXT: }
// CHECK-NEXT: }

Expand Down Expand Up @@ -457,8 +463,10 @@ double fn8(double x, double y) {
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: _label0:
// CHECK-NEXT: {
// CHECK-NEXT: char _r0 = 0;
// CHECK-NEXT: check_and_return_pullback(x, 'a', "aa", 1 * _t0 * _t1 * y, &* _d_x, &_r0, "");
// CHECK-NEXT: double _r0 = 0;
// CHECK-NEXT: char _r1 = 0;
// CHECK-NEXT: check_and_return_pullback(x, 'a', "aa", 1 * _t0 * _t1 * y, &_r0, &_r1, "");
// CHECK-NEXT: * _d_x += _r0;
// CHECK-NEXT: * _d_y += _t2 * 1 * _t0 * _t1;
// CHECK-NEXT: }
// CHECK-NEXT: }
Expand Down Expand Up @@ -682,7 +690,9 @@ double fn14(double x, double y) {
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: x = _t0;
// CHECK-NEXT: emptyFn_pullback(_t0, y, &* _d_x, &* _d_y);
// CHECK-NEXT: double _r0 = 0;
// CHECK-NEXT: emptyFn_pullback(_t0, y, &* _d_x, &_r0);
// CHECK-NEXT: * _d_y += _r0;
// CHECK-NEXT: }
// CHECK-NEXT: }

Expand Down
30 changes: 26 additions & 4 deletions test/Gradient/Functors.C
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,13 @@ int main() {
// CHECK-NEXT: _t0 = E;
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: _label0:
// CHECK-NEXT: _t0.operator_call_pullback(i, j, 1, &_d_E, &* _d_i, &* _d_j);
// CHECK-NEXT: {
// CHECK-NEXT: double _r0 = 0;
// CHECK-NEXT: double _r1 = 0;
// CHECK-NEXT: _t0.operator_call_pullback(i, j, 1, &_d_E, &_r0, &_r1);
// CHECK-NEXT: * _d_i += _r0;
// CHECK-NEXT: * _d_j += _r1;
// CHECK-NEXT: }
// CHECK-NEXT: }

// testing differentiating a function calling operator() on a functor
Expand All @@ -259,7 +265,13 @@ int main() {
// CHECK-NEXT: _t0 = fn;
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: _label0:
// CHECK-NEXT: _t0.operator_call_pullback(i, j, 1, &(* _d_fn), &* _d_i, &* _d_j);
// CHECK-NEXT: {
// CHECK-NEXT: double _r0 = 0;
// CHECK-NEXT: double _r1 = 0;
// CHECK-NEXT: _t0.operator_call_pullback(i, j, 1, &(* _d_fn), &_r0, &_r1);
// CHECK-NEXT: * _d_i += _r0;
// CHECK-NEXT: * _d_j += _r1;
// CHECK-NEXT: }
// CHECK-NEXT: }

// testing differentiating a function taking functor as an argument
Expand All @@ -274,7 +286,13 @@ int main() {
// CHECK-NEXT: _t0 = fn;
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: _label0:
// CHECK-NEXT: _t0.operator_call_pullback(i, j, _d_y, &(* _d_fn), &* _d_i, &* _d_j);
// CHECK-NEXT: {
// CHECK-NEXT: double _r0 = 0;
// CHECK-NEXT: double _r1 = 0;
// CHECK-NEXT: _t0.operator_call_pullback(i, j, _d_y, &(* _d_fn), &_r0, &_r1);
// CHECK-NEXT: * _d_i += _r0;
// CHECK-NEXT: * _d_j += _r1;
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK: void FunctorAsArgWrapper_grad(double i, double j, clad::array_ref<double> _d_i, clad::array_ref<double> _d_j) {
Expand All @@ -284,7 +302,11 @@ int main() {
// CHECK-NEXT: _label0:
// CHECK-NEXT: {
// CHECK-NEXT: Experiment _r0 = {};
// CHECK-NEXT: FunctorAsArg_pullback(E, i, j, 1, &_r0, &* _d_i, &* _d_j);
// CHECK-NEXT: double _r1 = 0;
// CHECK-NEXT: double _r2 = 0;
// CHECK-NEXT: FunctorAsArg_pullback(E, i, j, 1, &_r0, &_r1, &_r2);
// CHECK-NEXT: * _d_i += _r1;
// CHECK-NEXT: * _d_j += _r2;
// CHECK-NEXT: }
// CHECK-NEXT: }

Expand Down
Loading

0 comments on commit cb3f1a4

Please sign in to comment.