diff --git a/include/clad/Differentiator/CladUtils.h b/include/clad/Differentiator/CladUtils.h index 56519593d..ffc90be12 100644 --- a/include/clad/Differentiator/CladUtils.h +++ b/include/clad/Differentiator/CladUtils.h @@ -174,11 +174,12 @@ namespace clad { /// otherwise returns false. bool HasAnyReferenceOrPointerArgument(const clang::FunctionDecl* FD); - /// Returns true if `T` is a reference, pointer or array type. + /// Returns true if `arg` is an argument passed by reference or is of + /// pointer/array type. /// - /// \note Please note that this function returns true for array types as - /// well. - bool IsReferenceOrPointerType(clang::QualType T); + /// \note Please note that this function returns false for temporary + /// expressions. + bool IsReferenceOrPointerArg(const clang::Expr* arg); /// Returns true if `T1` and `T2` have same cononical type; otherwise /// returns false. diff --git a/lib/Differentiator/CladUtils.cpp b/lib/Differentiator/CladUtils.cpp index d7d599ce5..6144ae23a 100644 --- a/lib/Differentiator/CladUtils.cpp +++ b/lib/Differentiator/CladUtils.cpp @@ -301,8 +301,12 @@ namespace clad { return false; } - bool IsReferenceOrPointerType(QualType T) { - return T->isReferenceType() || isArrayOrPointerType(T); + bool IsReferenceOrPointerArg(const Expr* arg) { + // The argument is passed by reference if it's passed as an L-value. + // However, if arg is a MaterializeTemporaryExpr, then arg is a + // temporary variable passed as a const reference. + bool isRefType = arg->isLValue() && !isa(arg); + return isRefType || isArrayOrPointerType(arg->getType()); } bool SameCanonicalType(clang::QualType T1, clang::QualType T2) { diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 166b90f5f..1962971e1 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1537,8 +1537,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // We do not need to create result arg for arguments passed by reference // because the derivatives of arguments passed by reference are directly // modified by the derived callee function. - if (utils::IsReferenceOrPointerType(PVD->getType()) && - !isa(arg)) { + if (utils::IsReferenceOrPointerArg(arg)) { argDiff = Visit(arg); CallArgDx.push_back(argDiff.getExpr_dx()); } else { @@ -1723,9 +1722,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, for (auto* argDerivative : CallArgDx) { Expr* gradArgExpr = nullptr; const Expr* arg = CE->getArg(idx); - const auto* PVD = FD->getParamDecl(idx); - if (utils::IsReferenceOrPointerType(PVD->getType()) && - !isa(arg)) { + if (utils::IsReferenceOrPointerArg(arg)) { if (argDerivative) { if (utils::isArrayOrPointerType(argDerivative->getType()) || isCladArrayType(argDerivative->getType()) ||