diff --git a/include/clad/Differentiator/BaseForwardModeVisitor.h b/include/clad/Differentiator/BaseForwardModeVisitor.h index fa7c09ad6..802db880a 100644 --- a/include/clad/Differentiator/BaseForwardModeVisitor.h +++ b/include/clad/Differentiator/BaseForwardModeVisitor.h @@ -30,16 +30,12 @@ class BaseForwardModeVisitor ///\brief Produces the first derivative of a given function. /// - ///\param[in] FD - the function that will be differentiated. - /// ///\returns The differentiated and potentially created enclosing /// context. /// - DerivativeAndOverload Derive(const clang::FunctionDecl* FD, - const DiffRequest& request); + DerivativeAndOverload Derive(); - DerivativeAndOverload DerivePushforward(const clang::FunctionDecl* FD, - const DiffRequest& request); + DerivativeAndOverload DerivePushforward(); /// Returns the return type for the pushforward function of the function /// `m_DiffReq->Function`. diff --git a/include/clad/Differentiator/HessianModeVisitor.h b/include/clad/Differentiator/HessianModeVisitor.h index b9b768d30..65dca1309 100644 --- a/include/clad/Differentiator/HessianModeVisitor.h +++ b/include/clad/Differentiator/HessianModeVisitor.h @@ -41,8 +41,6 @@ namespace clad { ///\brief Produces the hessian second derivative columns of a given /// function. /// - ///\param[in] FD - the function that will be differentiated. - /// ///\returns A function containing second derivatives (columns) of a hessian /// matrix and potentially created enclosing context. /// @@ -50,8 +48,7 @@ namespace clad { /// ReverseModeVisitor to generate second derivatives that correspond to /// columns of the Hessian. uses Merge to return a FunctionDecl /// containing CallExprs to the generated second derivatives. - DerivativeAndOverload Derive(const clang::FunctionDecl* FD, - const DiffRequest& request); + DerivativeAndOverload Derive(); }; } // end namespace clad diff --git a/include/clad/Differentiator/ReverseModeForwPassVisitor.h b/include/clad/Differentiator/ReverseModeForwPassVisitor.h index 546b8b006..b1b3018cf 100644 --- a/include/clad/Differentiator/ReverseModeForwPassVisitor.h +++ b/include/clad/Differentiator/ReverseModeForwPassVisitor.h @@ -22,8 +22,7 @@ class ReverseModeForwPassVisitor : public ReverseModeVisitor { public: ReverseModeForwPassVisitor(DerivativeBuilder& builder, const DiffRequest& request); - DerivativeAndOverload Derive(const clang::FunctionDecl* FD, - const DiffRequest& request); + DerivativeAndOverload Derive(); StmtDiff ProcessSingleStmt(const clang::Stmt* S); StmtDiff VisitCompoundStmt(const clang::CompoundStmt* CS) override; diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h index ad9981bb1..dc4fb74bf 100644 --- a/include/clad/Differentiator/ReverseModeVisitor.h +++ b/include/clad/Differentiator/ReverseModeVisitor.h @@ -355,8 +355,6 @@ namespace clad { ///\brief Produces the gradient of a given function. /// - ///\param[in] FD - the function that will be differentiated. - /// ///\returns The gradient of the function, potentially created enclosing /// context and if generated, its overload. /// @@ -373,10 +371,8 @@ namespace clad { /// Improved naming scheme is required. Hence, we append the indices to of /// the requested parameters to 'f_grad', i.e. in the previous example "x, /// y" will give 'f_grad_0_1' and "x, z" will give 'f_grad_0_2'. - DerivativeAndOverload Derive(const clang::FunctionDecl* FD, - const DiffRequest& request); - DerivativeAndOverload DerivePullback(const clang::FunctionDecl* FD, - const DiffRequest& request); + DerivativeAndOverload Derive(); + DerivativeAndOverload DerivePullback(); StmtDiff VisitArraySubscriptExpr(const clang::ArraySubscriptExpr* ASE); StmtDiff VisitBinaryOperator(const clang::BinaryOperator* BinOp); StmtDiff VisitCallExpr(const clang::CallExpr* CE); diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index 9635e19f8..0d0bbcb47 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -59,19 +59,15 @@ bool IsRealNonReferenceType(QualType T) { return T.getNonReferenceType()->isRealType(); } -DerivativeAndOverload -BaseForwardModeVisitor::Derive(const FunctionDecl* FD, - const DiffRequest& request) { - assert(m_DiffReq == request && "Can't pass two different requests!"); - m_Functor = request.Functor; +DerivativeAndOverload BaseForwardModeVisitor::Derive() { + m_Functor = m_DiffReq.Functor; + const FunctionDecl* FD = m_DiffReq.Function; assert(m_DiffReq.Mode == DiffMode::forward); assert(!m_DerivativeInFlight && "Doesn't support recursive diff. Use DiffPlan."); m_DerivativeInFlight = true; - DiffInputVarsInfo DVI = request.DVI; - - DVI = request.DVI; + DiffInputVarsInfo DVI = m_DiffReq.DVI; // FIXME: Shouldn't we give error here that no arg is specified? if (DVI.empty()) @@ -84,7 +80,7 @@ BaseForwardModeVisitor::Derive(const FunctionDecl* FD, if (DVI.size() > 1 || (isArrayOrPointerType(diffVarInfo.param->getType()) && (diffVarInfo.paramIndexInterval.size() != 1))) { diag(DiagnosticsEngine::Error, - request.Args ? request.Args->getEndLoc() : noLoc, + m_DiffReq.Args ? m_DiffReq.Args->getEndLoc() : noLoc, "Forward mode differentiation w.r.t. several parameters at once is " "not " "supported, call 'clad::differentiate' for each parameter " @@ -129,7 +125,7 @@ BaseForwardModeVisitor::Derive(const FunctionDecl* FD, isField = true; } if (!IsRealNonReferenceType(T)) { - diag(DiagnosticsEngine::Error, request.Args->getEndLoc(), + diag(DiagnosticsEngine::Error, m_DiffReq.Args->getEndLoc(), "Attempted differentiation w.r.t. %0 '%1' which is not " "of real type.", {(isField ? "member" : "parameter"), diffVarInfo.source}); @@ -157,12 +153,12 @@ BaseForwardModeVisitor::Derive(const FunctionDecl* FD, argInfo += "_" + field; std::string s; - if (request.CurrentDerivativeOrder > 1) - s = std::to_string(request.CurrentDerivativeOrder); + if (m_DiffReq.CurrentDerivativeOrder > 1) + s = std::to_string(m_DiffReq.CurrentDerivativeOrder); // Check if the function is already declared as a custom derivative. - std::string gradientName = - request.BaseFunctionName + "_d" + s + "arg" + argInfo + derivativeSuffix; + std::string gradientName = m_DiffReq.BaseFunctionName + "_d" + s + "arg" + + argInfo + derivativeSuffix; // FIXME: We should not use const_cast to get the decl context here. // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) auto* DC = const_cast(m_DiffReq->getDeclContext()); @@ -221,7 +217,7 @@ BaseForwardModeVisitor::Derive(const FunctionDecl* FD, derivedFD->setParams(paramsRef); derivedFD->setBody(nullptr); - if (!request.DeclarationOnly) { + if (!m_DiffReq.DeclarationOnly) { // Function body scope beginScope(Scope::FnScope | Scope::DeclScope); m_DerivativeFnScope = getCurrentScope(); @@ -365,9 +361,10 @@ BaseForwardModeVisitor::Derive(const FunctionDecl* FD, // Size >= current derivative order means that there exists a declaration // or prototype for the currently derived function. - if (request.DerivedFDPrototypes.size() >= request.CurrentDerivativeOrder) + if (m_DiffReq.DerivedFDPrototypes.size() >= + m_DiffReq.CurrentDerivativeOrder) m_Derivative->setPreviousDeclaration( - request.DerivedFDPrototypes[request.CurrentDerivativeOrder - 1]); + m_DiffReq.DerivedFDPrototypes[m_DiffReq.CurrentDerivativeOrder - 1]); } m_Sema.PopFunctionScopeInfo(); m_Sema.PopDeclContext(); @@ -401,13 +398,9 @@ void BaseForwardModeVisitor::ExecuteInsidePushforwardFunctionBlock() { addToCurrentBlock(S); } -DerivativeAndOverload -BaseForwardModeVisitor::DerivePushforward(const FunctionDecl* FD, - const DiffRequest& request) { - // FIXME: We must not reset the diff request here. - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) - const_cast(m_DiffReq) = request; - m_Functor = request.Functor; +DerivativeAndOverload BaseForwardModeVisitor::DerivePushforward() { + const FunctionDecl* FD = m_DiffReq.Function; + m_Functor = m_DiffReq.Functor; assert(m_DiffReq.Mode == GetPushForwardMode()); assert(!m_DerivativeInFlight && "Doesn't support recursive diff. Use DiffPlan."); @@ -517,7 +510,7 @@ BaseForwardModeVisitor::DerivePushforward(const FunctionDecl* FD, m_Derivative->setParams(params); m_Derivative->setBody(nullptr); - if (!request.DeclarationOnly) { + if (!m_DiffReq.DeclarationOnly) { beginScope(Scope::FnScope | Scope::DeclScope); m_DerivativeFnScope = getCurrentScope(); beginBlock(); @@ -532,9 +525,10 @@ BaseForwardModeVisitor::DerivePushforward(const FunctionDecl* FD, // Size >= current derivative order means that there exists a declaration // or prototype for the currently derived function. - if (request.DerivedFDPrototypes.size() >= request.CurrentDerivativeOrder) + if (m_DiffReq.DerivedFDPrototypes.size() >= + m_DiffReq.CurrentDerivativeOrder) m_Derivative->setPreviousDeclaration( - request.DerivedFDPrototypes[request.CurrentDerivativeOrder - 1]); + m_DiffReq.DerivedFDPrototypes[m_DiffReq.CurrentDerivativeOrder - 1]); } m_Sema.PopFunctionScopeInfo(); diff --git a/lib/Differentiator/DerivativeBuilder.cpp b/lib/Differentiator/DerivativeBuilder.cpp index a569a90ae..e4ba3f99a 100644 --- a/lib/Differentiator/DerivativeBuilder.cpp +++ b/lib/Differentiator/DerivativeBuilder.cpp @@ -399,44 +399,44 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) { DerivativeAndOverload result{}; if (request.Mode == DiffMode::forward) { BaseForwardModeVisitor V(*this, request); - result = V.Derive(FD, request); + result = V.Derive(); } else if (request.Mode == DiffMode::experimental_pushforward) { PushForwardModeVisitor V(*this, request); - result = V.DerivePushforward(FD, request); + result = V.DerivePushforward(); } else if (request.Mode == DiffMode::vector_forward_mode) { VectorForwardModeVisitor V(*this, request); result = V.DeriveVectorMode(FD, request); } else if (request.Mode == DiffMode::experimental_vector_pushforward) { VectorPushForwardModeVisitor V(*this, request); - result = V.DerivePushforward(FD, request); + result = V.DerivePushforward(); } else if (request.Mode == DiffMode::reverse) { ReverseModeVisitor V(*this, request); - result = V.Derive(FD, request); + result = V.Derive(); } else if (request.Mode == DiffMode::experimental_pullback) { ReverseModeVisitor V(*this, request); if (!m_ErrorEstHandler.empty()) { InitErrorEstimation(m_ErrorEstHandler, m_EstModel, *this, request); V.AddExternalSource(*m_ErrorEstHandler.back()); } - result = V.DerivePullback(FD, request); + result = V.DerivePullback(); if (!m_ErrorEstHandler.empty()) CleanupErrorEstimation(m_ErrorEstHandler, m_EstModel); } else if (request.Mode == DiffMode::reverse_mode_forward_pass) { ReverseModeForwPassVisitor V(*this, request); - result = V.Derive(FD, request); + result = V.Derive(); } else if (request.Mode == DiffMode::hessian || request.Mode == DiffMode::hessian_diagonal) { HessianModeVisitor H(*this, request); - result = H.Derive(FD, request); + result = H.Derive(); } else if (request.Mode == DiffMode::jacobian) { ReverseModeVisitor R(*this, request); - result = R.Derive(FD, request); + result = R.Derive(); } else if (request.Mode == DiffMode::error_estimation) { ReverseModeVisitor R(*this, request); InitErrorEstimation(m_ErrorEstHandler, m_EstModel, *this, request); R.AddExternalSource(*m_ErrorEstHandler.back()); // Finally begin estimation. - result = R.Derive(FD, request); + result = R.Derive(); // Once we are done, we want to clear the model for any further // calls to estimate_error. CleanupErrorEstimation(m_ErrorEstHandler, m_EstModel); diff --git a/lib/Differentiator/HessianModeVisitor.cpp b/lib/Differentiator/HessianModeVisitor.cpp index 55161e6e8..af713b4d6 100644 --- a/lib/Differentiator/HessianModeVisitor.cpp +++ b/lib/Differentiator/HessianModeVisitor.cpp @@ -93,160 +93,152 @@ static FunctionDecl* DeriveUsingForwardModeTwice( return secondDerivative; } - DerivativeAndOverload - HessianModeVisitor::Derive(const clang::FunctionDecl* FD, - const DiffRequest& request) { - DiffParams args{}; - IndexIntervalTable indexIntervalTable{}; - DiffInputVarsInfo DVI; - if (request.Args) { - DVI = request.DVI; - for (auto dParam : DVI) { - args.push_back(dParam.param); - indexIntervalTable.push_back(dParam.paramIndexInterval); - } +DerivativeAndOverload HessianModeVisitor::Derive() { + const FunctionDecl* FD = m_DiffReq.Function; + DiffParams args{}; + IndexIntervalTable indexIntervalTable{}; + if (m_DiffReq.Args) + for (auto dParam : m_DiffReq.DVI) { + args.push_back(dParam.param); + indexIntervalTable.push_back(dParam.paramIndexInterval); } - else - std::copy(FD->param_begin(), FD->param_end(), std::back_inserter(args)); - - std::vector secondDerivativeFuncs; - llvm::SmallVector IndependentArgsSize{}; - size_t TotalIndependentArgsSize = 0; - - // request.Function is original function passed in from clad::hessian - assert(m_DiffReq == request); - - std::string hessianFuncName = request.BaseFunctionName + "_hessian"; - if (request.Mode == DiffMode::hessian_diagonal) - hessianFuncName += "_diagonal"; - // To be consistent with older tests, nothing is appended to 'f_hessian' if - // we differentiate w.r.t. all the parameters at once. - if (args.size() != FD->getNumParams() || - !std::equal(m_DiffReq->param_begin(), m_DiffReq->param_end(), - args.begin())) { - for (auto arg : args) { - auto it = - std::find(m_DiffReq->param_begin(), m_DiffReq->param_end(), arg); - auto idx = std::distance(m_DiffReq->param_begin(), it); - hessianFuncName += ('_' + std::to_string(idx)); - } + else + std::copy(FD->param_begin(), FD->param_end(), std::back_inserter(args)); + + std::vector secondDerivativeFuncs; + llvm::SmallVector IndependentArgsSize{}; + size_t TotalIndependentArgsSize = 0; + + std::string hessianFuncName = m_DiffReq.BaseFunctionName + "_hessian"; + if (m_DiffReq.Mode == DiffMode::hessian_diagonal) + hessianFuncName += "_diagonal"; + // To be consistent with older tests, nothing is appended to 'f_hessian' if + // we differentiate w.r.t. all the parameters at once. + if (args.size() != FD->getNumParams() || + !std::equal(m_DiffReq->param_begin(), m_DiffReq->param_end(), + args.begin())) { + for (auto arg : args) { + auto it = + std::find(m_DiffReq->param_begin(), m_DiffReq->param_end(), arg); + auto idx = std::distance(m_DiffReq->param_begin(), it); + hessianFuncName += ('_' + std::to_string(idx)); } + } - llvm::SmallVector paramTypes(m_DiffReq->getNumParams() + 1); - std::transform(m_DiffReq->param_begin(), m_DiffReq->param_end(), - std::begin(paramTypes), - [](const ParmVarDecl* PVD) { return PVD->getType(); }); - paramTypes.back() = m_Context.getPointerType(m_DiffReq->getReturnType()); - - const auto* originalFnProtoType = - cast(m_DiffReq->getType()); - QualType hessianFunctionType = m_Context.getFunctionType( - m_Context.VoidTy, - llvm::ArrayRef(paramTypes.data(), paramTypes.size()), - // Cast to function pointer. - originalFnProtoType->getExtProtoInfo()); - - // Check if the function is already declared as a custom derivative. - // FIXME: We should not use const_cast to get the decl context here. - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) - auto* DC = const_cast(m_DiffReq->getDeclContext()); - if (FunctionDecl* customDerivative = m_Builder.LookupCustomDerivativeDecl( - hessianFuncName, DC, hessianFunctionType)) - return DerivativeAndOverload{customDerivative, nullptr}; - - // Ascertains the independent arguments and differentiates the function - // in forward and reverse mode by calling ProcessDiffRequest twice each - // iteration, storing each generated second derivative function - // (corresponds to columns of Hessian matrix) in a vector for private method - // merge. - for (auto PVD : FD->parameters()) { - auto it = std::find(std::begin(args), std::end(args), PVD); - if (it != args.end()) { - // Using the properties of a vector to find the index of the requested - // arg - auto argIndex = it - args.begin(); - if (isArrayOrPointerType(PVD->getType())) { - if (indexIntervalTable.size() == 0 || - indexIntervalTable[argIndex].size() == 0) { - std::string suggestedArgsStr{}; - if (auto SL = dyn_cast( - request.Args->IgnoreParenImpCasts())) { - llvm::StringRef str = SL->getString().trim(); - llvm::StringRef name{}; - do { - std::tie(name, str) = str.split(','); - if (name.trim().str() == PVD->getNameAsString()) { - suggestedArgsStr += (suggestedArgsStr.empty() ? "" : ", ") + - PVD->getNameAsString() + - "[0:getNameAsString() + ">]"; - } else { - suggestedArgsStr += (suggestedArgsStr.empty() ? "" : ", ") + - name.trim().str(); - } - } while (!str.empty()); - } else { - suggestedArgsStr = - PVD->getNameAsString() + "[0:]"; - } - std::string helperMsg("clad::hessian(" + FD->getNameAsString() + - ", \"" + suggestedArgsStr + "\")"); - diag(DiagnosticsEngine::Error, - request.Args ? request.Args->getEndLoc() : noLoc, - "Hessian mode differentiation w.r.t. array or pointer " - "parameters needs explicit declaration of the indices of the " - "array using the args parameter; did you mean '%0'", - {helperMsg}); - return {}; + llvm::SmallVector paramTypes(m_DiffReq->getNumParams() + 1); + std::transform(m_DiffReq->param_begin(), m_DiffReq->param_end(), + std::begin(paramTypes), + [](const ParmVarDecl* PVD) { return PVD->getType(); }); + paramTypes.back() = m_Context.getPointerType(m_DiffReq->getReturnType()); + + const auto* originalFnProtoType = + cast(m_DiffReq->getType()); + QualType hessianFunctionType = m_Context.getFunctionType( + m_Context.VoidTy, + llvm::ArrayRef(paramTypes.data(), paramTypes.size()), + // Cast to function pointer. + originalFnProtoType->getExtProtoInfo()); + + // Check if the function is already declared as a custom derivative. + // FIXME: We should not use const_cast to get the decl context here. + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) + auto* DC = const_cast(m_DiffReq->getDeclContext()); + if (FunctionDecl* customDerivative = m_Builder.LookupCustomDerivativeDecl( + hessianFuncName, DC, hessianFunctionType)) + return DerivativeAndOverload{customDerivative, nullptr}; + + // Ascertains the independent arguments and differentiates the function + // in forward and reverse mode by calling ProcessDiffRequest twice each + // iteration, storing each generated second derivative function + // (corresponds to columns of Hessian matrix) in a vector for private method + // merge. + for (auto PVD : FD->parameters()) { + auto it = std::find(std::begin(args), std::end(args), PVD); + if (it != args.end()) { + // Using the properties of a vector to find the index of the requested + // arg + auto argIndex = it - args.begin(); + if (isArrayOrPointerType(PVD->getType())) { + if (indexIntervalTable.size() == 0 || + indexIntervalTable[argIndex].size() == 0) { + std::string suggestedArgsStr{}; + if (auto SL = dyn_cast( + m_DiffReq.Args->IgnoreParenImpCasts())) { + llvm::StringRef str = SL->getString().trim(); + llvm::StringRef name{}; + do { + std::tie(name, str) = str.split(','); + if (name.trim().str() == PVD->getNameAsString()) { + suggestedArgsStr += (suggestedArgsStr.empty() ? "" : ", ") + + PVD->getNameAsString() + + "[0:getNameAsString() + ">]"; + } else { + suggestedArgsStr += + (suggestedArgsStr.empty() ? "" : ", ") + name.trim().str(); + } + } while (!str.empty()); + } else { + suggestedArgsStr = PVD->getNameAsString() + "[0:]"; } + std::string helperMsg("clad::hessian(" + FD->getNameAsString() + + ", \"" + suggestedArgsStr + "\")"); + diag(DiagnosticsEngine::Error, + m_DiffReq.Args ? m_DiffReq.Args->getEndLoc() : noLoc, + "Hessian mode differentiation w.r.t. array or pointer " + "parameters needs explicit declaration of the indices of the " + "array using the args parameter; did you mean '%0'", + {helperMsg}); + return {}; + } - IndependentArgsSize.push_back(indexIntervalTable[argIndex].size()); - TotalIndependentArgsSize += indexIntervalTable[argIndex].size(); - - // Derive the function w.r.t. to each requested index of the current - // array in forward mode and then in reverse mode w.r.t to all - // requested args - for (auto i = indexIntervalTable[argIndex].Start; - i < indexIntervalTable[argIndex].Finish; i++) { - auto independentArgString = - PVD->getNameAsString() + "[" + std::to_string(i) + "]"; - auto ForwardModeIASL = - CreateStringLiteral(m_Context, independentArgString); - FunctionDecl* DFD = nullptr; - if (request.Mode == DiffMode::hessian_diagonal) - DFD = DeriveUsingForwardModeTwice(m_Sema, m_CladPlugin, m_Builder, - request, ForwardModeIASL, - m_Builder.m_DFC); - else - DFD = DeriveUsingForwardAndReverseMode( - m_Sema, m_CladPlugin, m_Builder, request, ForwardModeIASL, - request.Args, m_Builder.m_DFC); - secondDerivativeFuncs.push_back(DFD); - } - } else { - IndependentArgsSize.push_back(1); - TotalIndependentArgsSize++; - // Derive the function w.r.t. to the current arg in forward mode and - // then in reverse mode w.r.t to all requested args + IndependentArgsSize.push_back(indexIntervalTable[argIndex].size()); + TotalIndependentArgsSize += indexIntervalTable[argIndex].size(); + + // Derive the function w.r.t. to each requested index of the current + // array in forward mode and then in reverse mode w.r.t to all + // requested args + for (auto i = indexIntervalTable[argIndex].Start; + i < indexIntervalTable[argIndex].Finish; i++) { + auto independentArgString = + PVD->getNameAsString() + "[" + std::to_string(i) + "]"; auto ForwardModeIASL = - CreateStringLiteral(m_Context, PVD->getNameAsString()); + CreateStringLiteral(m_Context, independentArgString); FunctionDecl* DFD = nullptr; - if (request.Mode == DiffMode::hessian_diagonal) + if (m_DiffReq.Mode == DiffMode::hessian_diagonal) DFD = DeriveUsingForwardModeTwice(m_Sema, m_CladPlugin, m_Builder, - request, ForwardModeIASL, + m_DiffReq, ForwardModeIASL, m_Builder.m_DFC); else DFD = DeriveUsingForwardAndReverseMode( - m_Sema, m_CladPlugin, m_Builder, request, ForwardModeIASL, - request.Args, m_Builder.m_DFC); + m_Sema, m_CladPlugin, m_Builder, m_DiffReq, ForwardModeIASL, + m_DiffReq.Args, m_Builder.m_DFC); secondDerivativeFuncs.push_back(DFD); } + } else { + IndependentArgsSize.push_back(1); + TotalIndependentArgsSize++; + // Derive the function w.r.t. to the current arg in forward mode and + // then in reverse mode w.r.t to all requested args + auto ForwardModeIASL = + CreateStringLiteral(m_Context, PVD->getNameAsString()); + FunctionDecl* DFD = nullptr; + if (m_DiffReq.Mode == DiffMode::hessian_diagonal) + DFD = DeriveUsingForwardModeTwice(m_Sema, m_CladPlugin, m_Builder, + m_DiffReq, ForwardModeIASL, + m_Builder.m_DFC); + else + DFD = DeriveUsingForwardAndReverseMode( + m_Sema, m_CladPlugin, m_Builder, m_DiffReq, ForwardModeIASL, + m_DiffReq.Args, m_Builder.m_DFC); + secondDerivativeFuncs.push_back(DFD); } } - return Merge(secondDerivativeFuncs, IndependentArgsSize, - TotalIndependentArgsSize, hessianFuncName, DC, - hessianFunctionType, paramTypes); } + return Merge(secondDerivativeFuncs, IndependentArgsSize, + TotalIndependentArgsSize, hessianFuncName, DC, + hessianFunctionType, paramTypes); +} // Combines all generated second derivative functions into a // single hessian function by creating CallExprs to each individual diff --git a/lib/Differentiator/ReverseModeForwPassVisitor.cpp b/lib/Differentiator/ReverseModeForwPassVisitor.cpp index 949e44a8b..d05884095 100644 --- a/lib/Differentiator/ReverseModeForwPassVisitor.cpp +++ b/lib/Differentiator/ReverseModeForwPassVisitor.cpp @@ -16,10 +16,8 @@ ReverseModeForwPassVisitor::ReverseModeForwPassVisitor( DerivativeBuilder& builder, const DiffRequest& request) : ReverseModeVisitor(builder, request) {} -DerivativeAndOverload -ReverseModeForwPassVisitor::Derive(const FunctionDecl* FD, - const DiffRequest& request) { - assert(m_DiffReq == request); +DerivativeAndOverload ReverseModeForwPassVisitor::Derive() { + const FunctionDecl* FD = m_DiffReq.Function; assert(m_DiffReq.Mode == DiffMode::reverse_mode_forward_pass); @@ -64,7 +62,7 @@ ReverseModeForwPassVisitor::Derive(const FunctionDecl* FD, m_Derivative->setParams(params); m_Derivative->setBody(nullptr); - if (!request.DeclarationOnly) { + if (!m_DiffReq.DeclarationOnly) { beginScope(Scope::FnScope | Scope::DeclScope); m_DerivativeFnScope = getCurrentScope(); @@ -87,9 +85,10 @@ ReverseModeForwPassVisitor::Derive(const FunctionDecl* FD, // Size >= current derivative order means that there exists a declaration // or prototype for the currently derived function. - if (request.DerivedFDPrototypes.size() >= request.CurrentDerivativeOrder) + if (m_DiffReq.DerivedFDPrototypes.size() >= + m_DiffReq.CurrentDerivativeOrder) m_Derivative->setPreviousDeclaration( - request.DerivedFDPrototypes[request.CurrentDerivativeOrder - 1]); + m_DiffReq.DerivedFDPrototypes[m_DiffReq.CurrentDerivativeOrder - 1]); } m_Sema.PopFunctionScopeInfo(); m_Sema.PopDeclContext(); diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 4d600caf7..320b030b3 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -263,12 +263,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, return gradientOverloadFD; } - DerivativeAndOverload - ReverseModeVisitor::Derive(const FunctionDecl* FD, - const DiffRequest& request) { + DerivativeAndOverload ReverseModeVisitor::Derive() { + const FunctionDecl* FD = m_DiffReq.Function; if (m_ExternalSource) m_ExternalSource->ActOnStartOfDerive(); - assert(m_DiffReq == request); // FIXME: reverse mode plugins may have request mode other than // `DiffMode::reverse`, but they still need the `DiffMode::reverse` mode @@ -284,31 +282,28 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, assert(m_DiffReq.Function && "Must not be null."); DiffParams args{}; - DiffInputVarsInfo DVI; - if (request.Args) { - DVI = request.DVI; - for (const auto& dParam : DVI) + if (m_DiffReq.Args) + for (const auto& dParam : m_DiffReq.DVI) args.push_back(dParam.param); - } else std::copy(FD->param_begin(), FD->param_end(), std::back_inserter(args)); if (args.empty()) return {}; if (m_ExternalSource) - m_ExternalSource->ActAfterParsingDiffArgs(request, args); + m_ExternalSource->ActAfterParsingDiffArgs(m_DiffReq, args); // Save the type of the output parameter(s) that is add by clad to the // derived function - if (request.Mode == DiffMode::jacobian) { + if (m_DiffReq.Mode == DiffMode::jacobian) { unsigned lastArgN = m_DiffReq->getNumParams() - 1; outputArrayStr = m_DiffReq->getParamDecl(lastArgN)->getNameAsString(); } - auto derivativeBaseName = request.BaseFunctionName; + auto derivativeBaseName = m_DiffReq.BaseFunctionName; std::string gradientName = derivativeBaseName + funcPostfix(); // To be consistent with older tests, nothing is appended to 'f_grad' if // we differentiate w.r.t. all the parameters at once. - if (request.Mode == DiffMode::jacobian) { + if (m_DiffReq.Mode == DiffMode::jacobian) { // If Jacobian is asked, the last parameter is the result parameter // and should be ignored if (args.size() != FD->getNumParams()-1){ @@ -346,9 +341,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, bool shouldCreateOverload = false; // FIXME: Gradient overload doesn't know how to handle additional parameters // added by the plugins yet. - if (request.Mode != DiffMode::jacobian && numExtraParam == 0) + if (m_DiffReq.Mode != DiffMode::jacobian && numExtraParam == 0) shouldCreateOverload = true; - if (!request.DeclarationOnly && !request.DerivedFDPrototypes.empty()) + if (!m_DiffReq.DeclarationOnly && !m_DiffReq.DerivedFDPrototypes.empty()) // If the overload is already created, we don't need to create it again. shouldCreateOverload = false; @@ -411,8 +406,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, gradientFD->setParams(paramsRef); gradientFD->setBody(nullptr); - if (!request.DeclarationOnly) { - if (request.Mode == DiffMode::jacobian) { + if (!m_DiffReq.DeclarationOnly) { + if (m_DiffReq.Mode == DiffMode::jacobian) { // Reference to the output parameter. m_Result = BuildDeclRef(params.back()); numParams = args.size(); @@ -451,7 +446,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, m_DerivativeFnScope = getCurrentScope(); beginBlock(); if (m_ExternalSource) - m_ExternalSource->ActOnStartOfDerivedFnBody(request); + m_ExternalSource->ActOnStartOfDerivedFnBody(m_DiffReq); Stmt* gradientBody = nullptr; @@ -466,9 +461,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // Size >= current derivative order means that there exists a declaration // or prototype for the currently derived function. - if (request.DerivedFDPrototypes.size() >= request.CurrentDerivativeOrder) + if (m_DiffReq.DerivedFDPrototypes.size() >= + m_DiffReq.CurrentDerivativeOrder) m_Derivative->setPreviousDeclaration( - request.DerivedFDPrototypes[request.CurrentDerivativeOrder - 1]); + m_DiffReq + .DerivedFDPrototypes[m_DiffReq.CurrentDerivativeOrder - 1]); } m_Sema.PopFunctionScopeInfo(); m_Sema.PopDeclContext(); @@ -483,22 +480,18 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, return DerivativeAndOverload{result.first, gradientOverloadFD}; } - DerivativeAndOverload - ReverseModeVisitor::DerivePullback(const clang::FunctionDecl* FD, - const DiffRequest& request) { + DerivativeAndOverload ReverseModeVisitor::DerivePullback() { + const clang::FunctionDecl* FD = m_DiffReq.Function; // FIXME: Duplication of external source here is a workaround // for the two 'Derive's being different functions. if (m_ExternalSource) m_ExternalSource->ActOnStartOfDerive(); - // FIXME: We should not use const_cast to get the decl request here. - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) - const_cast(m_DiffReq) = request; assert(m_DiffReq.Mode == DiffMode::experimental_pullback); assert(m_DiffReq.Function && "Must not be null."); DiffParams args{}; - if (!request.DVI.empty()) - for (const auto& dParam : request.DVI) + if (!m_DiffReq.DVI.empty()) + for (const auto& dParam : m_DiffReq.DVI) args.push_back(dParam.param); else std::copy(FD->param_begin(), FD->param_end(), std::back_inserter(args)); @@ -510,7 +503,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, #endif if (m_ExternalSource) - m_ExternalSource->ActAfterParsingDiffArgs(request, args); + m_ExternalSource->ActAfterParsingDiffArgs(m_DiffReq, args); auto derivativeName = utils::ComputeEffectiveFnName(m_DiffReq.Function) + "_pullback"; @@ -557,7 +550,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, m_Derivative->setParams(params); m_Derivative->setBody(nullptr); - if (!request.DeclarationOnly) { + if (!m_DiffReq.DeclarationOnly) { if (m_ExternalSource) m_ExternalSource->ActBeforeCreatingDerivedFnBodyScope(); @@ -566,7 +559,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, beginBlock(); if (m_ExternalSource) - m_ExternalSource->ActOnStartOfDerivedFnBody(request); + m_ExternalSource->ActOnStartOfDerivedFnBody(m_DiffReq); StmtDiff bodyDiff = Visit(m_DiffReq->getBody()); Stmt* forward = bodyDiff.getStmt(); @@ -595,9 +588,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // Size >= current derivative order means that there exists a declaration // or prototype for the currently derived function. - if (request.DerivedFDPrototypes.size() >= request.CurrentDerivativeOrder) + if (m_DiffReq.DerivedFDPrototypes.size() >= + m_DiffReq.CurrentDerivativeOrder) m_Derivative->setPreviousDeclaration( - request.DerivedFDPrototypes[request.CurrentDerivativeOrder - 1]); + m_DiffReq + .DerivedFDPrototypes[m_DiffReq.CurrentDerivativeOrder - 1]); } m_Sema.PopFunctionScopeInfo(); m_Sema.PopDeclContext();