diff --git a/include/clad/Differentiator/CladConfig.h b/include/clad/Differentiator/CladConfig.h index 4ad838a31..f2c06bec2 100644 --- a/include/clad/Differentiator/CladConfig.h +++ b/include/clad/Differentiator/CladConfig.h @@ -21,23 +21,28 @@ enum order { third = 3, }; // enum order -enum opts { +enum opts : unsigned { use_enzyme = 1 << ORDER_BITS, vector_mode = 1 << (ORDER_BITS + 1), + + // Storing two bits for tbr analysis. + // 00 - default, 01 - enable, 10 - disable, 11 - not used / invalid + enable_tbr = 1 << (ORDER_BITS + 2), + disable_tbr = 1 << (ORDER_BITS + 3), }; // enum opts -constexpr unsigned GetDerivativeOrder(unsigned const bitmasked_opts) { +constexpr unsigned GetDerivativeOrder(const unsigned bitmasked_opts) { return bitmasked_opts & ORDER_MASK; } -constexpr bool HasOption(unsigned const bitmasked_opts, unsigned const option) { +constexpr bool HasOption(const unsigned bitmasked_opts, const unsigned option) { return (bitmasked_opts & option) == option; } constexpr unsigned GetBitmaskedOpts() { return 0; } -constexpr unsigned GetBitmaskedOpts(unsigned const first) { return first; } +constexpr unsigned GetBitmaskedOpts(const unsigned first) { return first; } template -constexpr unsigned GetBitmaskedOpts(unsigned const first, Opts... opts) { +constexpr unsigned GetBitmaskedOpts(const unsigned first, Opts... opts) { return first | GetBitmaskedOpts(opts...); } diff --git a/include/clad/Differentiator/DiffPlanner.h b/include/clad/Differentiator/DiffPlanner.h index 2705cc447..7c108fa23 100644 --- a/include/clad/Differentiator/DiffPlanner.h +++ b/include/clad/Differentiator/DiffPlanner.h @@ -86,6 +86,12 @@ namespace clad { using DiffSchedule = llvm::SmallVector; using DiffInterval = std::vector; + struct RequestOptions { + /// This is a flag to indicate the default behaviour to enable/disable + /// TBR analysis during reverse-mode differentiation. + bool EnableTBRAnalysis = false; + }; + class DiffCollector: public clang::RecursiveASTVisitor { /// The source interval where clad was activated. /// @@ -101,9 +107,11 @@ namespace clad { const clang::FunctionDecl* m_TopMostFD = nullptr; clang::Sema& m_Sema; + RequestOptions& m_Options; + public: DiffCollector(clang::DeclGroupRef DGR, DiffInterval& Interval, - DiffSchedule& plans, clang::Sema& S); + DiffSchedule& plans, clang::Sema& S, RequestOptions& opts); bool VisitCallExpr(clang::CallExpr* E); private: diff --git a/include/clad/Differentiator/Differentiator.h b/include/clad/Differentiator/Differentiator.h index a686e35ef..b3e708b18 100644 --- a/include/clad/Differentiator/Differentiator.h +++ b/include/clad/Differentiator/Differentiator.h @@ -358,9 +358,8 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) { /// \param[in] args independent parameters information /// \returns `CladFunction` object to access the corresponding derived /// function. - template , + template , typename = typename std::enable_if< !std::is_class>::value>::type> CladFunction, true> __attribute__(( @@ -376,9 +375,8 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) { /// Specialization for differentiating functors. /// The specialization is needed because objects have to be passed /// by reference whereas functions have to be passed by value. - template , + template , typename = typename std::enable_if< std::is_class>::value>::type> CladFunction, true> __attribute__(( @@ -397,8 +395,8 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) { /// \param[in] args independent parameters information /// \returns `CladFunction` object to access the corresponding derived /// function. - template , + template , typename = typename std::enable_if< !std::is_class>::value>::type> CladFunction> __attribute__(( @@ -406,18 +404,16 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) { hessian(F f, ArgSpec args = "", DerivedFnType derivedFn = static_cast(nullptr), const char* code = "") { - assert(f && "Must pass in a non-0 argument"); - return CladFunction< - DerivedFnType, - ExtractFunctorTraits_t>(derivedFn /* will be replaced by hessian*/, - code); + assert(f && "Must pass in a non-0 argument"); + return CladFunction>( + derivedFn /* will be replaced by hessian*/, code); } /// Specialization for differentiating functors. /// The specialization is needed because objects have to be passed /// by reference whereas functions have to be passed by value. - template , + template , typename = typename std::enable_if< std::is_class>::value>::type> CladFunction> __attribute__(( @@ -425,10 +421,8 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) { hessian(F&& f, ArgSpec args = "", DerivedFnType derivedFn = static_cast(nullptr), const char* code = "") { - return CladFunction< - DerivedFnType, - ExtractFunctorTraits_t>(derivedFn /* will be replaced by hessian*/, - code, f); + return CladFunction>( + derivedFn /* will be replaced by hessian*/, code, f); } /// Generates function which computes jacobian matrix of the given function @@ -438,8 +432,8 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) { /// \param[in] args independent parameters information /// \returns `CladFunction` object to access the corresponding derived /// function. - template , + template , typename = typename std::enable_if< !std::is_class>::value>::type> CladFunction> __attribute__(( @@ -447,18 +441,16 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) { jacobian(F f, ArgSpec args = "", DerivedFnType derivedFn = static_cast(nullptr), const char* code = "") { - assert(f && "Must pass in a non-0 argument"); - return CladFunction< - DerivedFnType, - ExtractFunctorTraits_t>(derivedFn /* will be replaced by Jacobian*/, - code); + assert(f && "Must pass in a non-0 argument"); + return CladFunction>( + derivedFn /* will be replaced by Jacobian*/, code); } /// Specialization for differentiating functors. /// The specialization is needed because objects have to be passed /// by reference whereas functions have to be passed by value. - template , + template , typename = typename std::enable_if< std::is_class>::value>::type> CladFunction> __attribute__(( @@ -466,10 +458,8 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) { jacobian(F&& f, ArgSpec args = "", DerivedFnType derivedFn = static_cast(nullptr), const char* code = "") { - return CladFunction< - DerivedFnType, - ExtractFunctorTraits_t>(derivedFn /* will be replaced by Jacobian*/, - code, f); + return CladFunction>( + derivedFn /* will be replaced by Jacobian*/, code, f); } template getAnnotation().equals("D")) { - request.Mode = DiffMode::forward; - - // bitmask_opts is a template pack of unsigned integers, so we need to - // do bitwise or of all the values to get the final value. - unsigned bitmasked_opts_value = 0; + // bitmask_opts is a template pack of unsigned integers, so we need to + // do bitwise or of all the values to get the final value. + unsigned bitmasked_opts_value = 0; + bool enable_tbr_in_req = false; + bool disable_tbr_in_req = false; + if (!A->getAnnotation().equals("E") && + FD->getTemplateSpecializationArgs()) { const auto template_arg = FD->getTemplateSpecializationArgs()->get(0); if (template_arg.getKind() == TemplateArgument::Pack) for (const auto& arg : @@ -569,14 +571,39 @@ namespace clad { bitmasked_opts_value |= arg.getAsIntegral().getExtValue(); else bitmasked_opts_value = template_arg.getAsIntegral().getExtValue(); + + // Set option for TBR analysis. + enable_tbr_in_req = + clad::HasOption(bitmasked_opts_value, clad::opts::enable_tbr); + disable_tbr_in_req = + clad::HasOption(bitmasked_opts_value, clad::opts::disable_tbr); + if (enable_tbr_in_req && disable_tbr_in_req) { + utils::EmitDiag(m_Sema, DiagnosticsEngine::Error, endLoc, + "Both enable and disable TBR options are specified."); + return true; + } + if (enable_tbr_in_req || disable_tbr_in_req) { + // override the default value of TBR analysis. + request.EnableTBRAnalysis = enable_tbr_in_req && !disable_tbr_in_req; + } else { + request.EnableTBRAnalysis = m_Options.EnableTBRAnalysis; + } + } + + if (A->getAnnotation().equals("D")) { + request.Mode = DiffMode::forward; unsigned derivative_order = clad::GetDerivativeOrder(bitmasked_opts_value); if (derivative_order == 0) { derivative_order = 1; // default to first order derivative. } request.RequestedDerivativeOrder = derivative_order; - if (clad::HasOption(bitmasked_opts_value, clad::opts::use_enzyme)) { + if (clad::HasOption(bitmasked_opts_value, clad::opts::use_enzyme)) request.use_enzyme = true; + if (enable_tbr_in_req) { + utils::EmitDiag(m_Sema, DiagnosticsEngine::Error, endLoc, + "TBR analysis is not meant for forward mode AD."); + return true; } if (clad::HasOption(bitmasked_opts_value, clad::opts::vector_mode)) { request.Mode = DiffMode::vector_forward_mode; @@ -601,17 +628,6 @@ namespace clad { request.Mode = DiffMode::jacobian; } else if (A->getAnnotation().equals("G")) { request.Mode = DiffMode::reverse; - - // bitmask_opts is a template pack of unsigned integers, so we need to - // do bitwise or of all the values to get the final value. - unsigned bitmasked_opts_value = 0; - const auto template_arg = FD->getTemplateSpecializationArgs()->get(0); - if (template_arg.getKind() == TemplateArgument::Pack) - for (const auto& arg : - FD->getTemplateSpecializationArgs()->get(0).pack_elements()) - bitmasked_opts_value |= arg.getAsIntegral().getExtValue(); - else - bitmasked_opts_value = template_arg.getAsIntegral().getExtValue(); if (clad::HasOption(bitmasked_opts_value, clad::opts::use_enzyme)) request.use_enzyme = true; // reverse vector mode is not yet supported. diff --git a/test/Analyses/TBR.cpp b/test/Analyses/TBR.cpp index 17be557b1..c87dfb93c 100644 --- a/test/Analyses/TBR.cpp +++ b/test/Analyses/TBR.cpp @@ -1,4 +1,4 @@ -// RUN: %cladclang -mllvm -debug-only=clad-tbr -Xclang -plugin-arg-clad -Xclang -enable-tbr %s -I%S/../../include -oReverseLoops.out 2>&1 | FileCheck %s +// RUN: %cladclang -mllvm -debug-only=clad-tbr %s -I%S/../../include -oReverseLoops.out 2>&1 | FileCheck %s // REQUIRES: asserts //CHECK-NOT: {{.*error|warning|note:.*}} @@ -13,7 +13,7 @@ double f1(double x) { #define TEST(F, x) { \ result[0] = 0; \ - auto F##grad = clad::gradient(F);\ + auto F##grad = clad::gradient(F);\ F##grad.execute(x, result);\ printf("{%.2f}\n", result[0]); \ } diff --git a/test/FirstDerivative/FunctionCalls.C b/test/FirstDerivative/FunctionCalls.C index 953ccc6cf..a3654f503 100644 --- a/test/FirstDerivative/FunctionCalls.C +++ b/test/FirstDerivative/FunctionCalls.C @@ -175,6 +175,7 @@ int main () { clad::differentiate(test_6, "x"); clad::differentiate(test_7, "i"); clad::differentiate(test_8, "x"); - + clad::differentiate(test_8); // expected-error {{TBR analysis is not meant for forward mode AD.}} + clad::differentiate(test_8); // expected-error {{Both enable and disable TBR options are specified.}} return 0; } diff --git a/test/Gradient/Pointers.C b/test/Gradient/Pointers.C index cf5fd322b..294325ac6 100644 --- a/test/Gradient/Pointers.C +++ b/test/Gradient/Pointers.C @@ -617,7 +617,7 @@ int main() { d_structPointer.execute(5, &d_x); printf("%.2f\n", d_x); // CHECK-EXEC: 1.00 - auto d_cStyleMemoryAlloc = clad::gradient(cStyleMemoryAlloc, "x"); + auto d_cStyleMemoryAlloc = clad::gradient(cStyleMemoryAlloc, "x"); d_x = 0; d_cStyleMemoryAlloc.execute(5, 7, &d_x); printf("%.2f\n", d_x); // CHECK-EXEC: 4.00 diff --git a/test/Misc/Args.C b/test/Misc/Args.C index 58e44c751..35b7c3e5f 100644 --- a/test/Misc/Args.C +++ b/test/Misc/Args.C @@ -5,6 +5,9 @@ // CHECK_HELP-NEXT: -fdump-derived-fn // CHECK_HELP-NEXT: -fdump-derived-fn-ast // CHECK_HELP-NEXT: -fgenerate-source-file +// CHECK_HELP-NEXT: -fno-validate-clang-version +// CHECK_HELP-NEXT: -enable-tbr +// CHECK_HELP-NEXT: -disable-tbr // CHECK_HELP-NEXT: -fcustom-estimation-model // CHECK_HELP-NEXT: -fprint-num-diff-errors // CHECK_HELP-NEXT: -help @@ -23,3 +26,7 @@ // RUN: -Xclang %t.so %S/../../demos/ErrorEstimation/CustomModel/test.cpp \ // RUN: -I%S/../../include 2>&1 | FileCheck --check-prefix=CHECK_SO_INVALID %s // CHECK_SO_INVALID: Failed to load '{{.*.so}}', {{.*}}. Aborting. + +// RUN: clang -fsyntax-only -fplugin=%cladlib -Xclang -plugin-arg-clad -Xclang -enable-tbr \ +// RUN: -Xclang -plugin-arg-clad -Xclang -disable-tbr %s 2>&1 | FileCheck --check-prefix=CHECK_TBR %s +// CHECK_TBR: -enable-tbr and -disable-tbr cannot be used together \ No newline at end of file diff --git a/tools/ClangPlugin.cpp b/tools/ClangPlugin.cpp index 0f55cf065..6ce14e591 100644 --- a/tools/ClangPlugin.cpp +++ b/tools/ClangPlugin.cpp @@ -107,18 +107,15 @@ namespace clad { if (m_HandleTopLevelDeclInternal) return true; + RequestOptions opts{}; + SetRequestOptions(opts); DiffSchedule requests{}; - DiffCollector collector(DGR, CladEnabledRange, requests, m_CI.getSema()); + DiffCollector collector(DGR, CladEnabledRange, requests, m_CI.getSema(), + opts); if (requests.empty()) return true; - // FIXME: flags have to be set manually since DiffCollector's constructor - // does not have access to m_DO. - if (m_DO.EnableTBRAnalysis) - for (DiffRequest& request : requests) - request.EnableTBRAnalysis = true; - // FIXME: Remove the PerformPendingInstantiations altogether. We should // somehow make the relevant functions referenced. // Instantiate all pending for instantiations templates, because we will @@ -318,6 +315,19 @@ namespace clad { m_HasRuntime = !R.empty(); return m_HasRuntime; } + + void CladPlugin::SetRequestOptions(RequestOptions& opts) const { + SetTBRAnalysisOptions(m_DO, opts); + } + + void CladPlugin::SetTBRAnalysisOptions(const DifferentiationOptions& DO, + RequestOptions& opts) { + // If user has explicitly specified the mode for TBR analysis, use it. + if (DO.EnableTBRAnalysis || DO.DisableTBRAnalysis) + opts.EnableTBRAnalysis = DO.EnableTBRAnalysis && !DO.DisableTBRAnalysis; + else + opts.EnableTBRAnalysis = false; // Default mode. + } } // end namespace plugin clad::CladTimerGroup::CladTimerGroup() diff --git a/tools/ClangPlugin.h b/tools/ClangPlugin.h index 808443e49..47ae51046 100644 --- a/tools/ClangPlugin.h +++ b/tools/ClangPlugin.h @@ -74,22 +74,24 @@ namespace clad { namespace plugin { struct DifferentiationOptions { - DifferentiationOptions() - : DumpSourceFn(false), DumpSourceFnAST(false), DumpDerivedFn(false), - DumpDerivedAST(false), GenerateSourceFile(false), - ValidateClangVersion(true), EnableTBRAnalysis(false), - CustomEstimationModel(false), PrintNumDiffErrorInfo(false) {} - - bool DumpSourceFn : 1; - bool DumpSourceFnAST : 1; - bool DumpDerivedFn : 1; - bool DumpDerivedAST : 1; - bool GenerateSourceFile : 1; - bool ValidateClangVersion : 1; - bool EnableTBRAnalysis : 1; - bool CustomEstimationModel : 1; - bool PrintNumDiffErrorInfo : 1; - std::string CustomModelName; + DifferentiationOptions() + : DumpSourceFn(false), DumpSourceFnAST(false), DumpDerivedFn(false), + DumpDerivedAST(false), GenerateSourceFile(false), + ValidateClangVersion(true), EnableTBRAnalysis(false), + DisableTBRAnalysis(false), CustomEstimationModel(false), + PrintNumDiffErrorInfo(false) {} + + bool DumpSourceFn : 1; + bool DumpSourceFnAST : 1; + bool DumpDerivedFn : 1; + bool DumpDerivedAST : 1; + bool GenerateSourceFile : 1; + bool ValidateClangVersion : 1; + bool EnableTBRAnalysis : 1; + bool DisableTBRAnalysis : 1; + bool CustomEstimationModel : 1; + bool PrintNumDiffErrorInfo : 1; + std::string CustomModelName; }; class CladPlugin : public clang::ASTConsumer { @@ -109,6 +111,9 @@ namespace clad { private: bool CheckBuiltins(); + void SetRequestOptions(RequestOptions& opts) const; + static void SetTBRAnalysisOptions(const DifferentiationOptions& DO, + RequestOptions& opts); void ProcessTopLevelDecl(clang::Decl* D); }; @@ -146,6 +151,8 @@ namespace clad { m_DO.ValidateClangVersion = false; } else if (args[i] == "-enable-tbr") { m_DO.EnableTBRAnalysis = true; + } else if (args[i] == "-disable-tbr") { + m_DO.DisableTBRAnalysis = true; } else if (args[i] == "-fcustom-estimation-model") { m_DO.CustomEstimationModel = true; if (++i == e) { @@ -170,6 +177,14 @@ namespace clad { "derivative.\n" << "-fgenerate-source-file - Produces a file containing the " "derivatives.\n" + << "-fno-validate-clang-version - Disables the validation of " + "the clang version.\n" + << "-enable-tbr - Ensures that TBR analysis is enabled during " + "reverse-mode differentiation unless explicitly specified " + "in an individual request.\n" + << "-disable-tbr - Ensures that TBR analysis is disabled " + "during reverse-mode differentiation unless explicitly " + "specified in an individual request.\n" << "-fcustom-estimation-model - allows user to send in a " "shared object to use as the custom estimation model.\n" << "-fprint-num-diff-errors - allows users to print the " @@ -186,6 +201,11 @@ namespace clad { if (!checkClangVersion()) return false; } + if (m_DO.EnableTBRAnalysis && m_DO.DisableTBRAnalysis) { + llvm::errs() << "clad: Error: -enable-tbr and -disable-tbr cannot " + "be used together.\n"; + return false; + } return true; }