diff --git a/include/clad/Differentiator/DiffPlanner.h b/include/clad/Differentiator/DiffPlanner.h index 4772a2445..0d96a6c27 100644 --- a/include/clad/Differentiator/DiffPlanner.h +++ b/include/clad/Differentiator/DiffPlanner.h @@ -117,7 +117,8 @@ struct DiffRequest { CurrentDerivativeOrder == other.CurrentDerivativeOrder && RequestedDerivativeOrder == other.RequestedDerivativeOrder && CallContext == other.CallContext && Args == other.Args && - Mode == other.Mode && EnableTBRAnalysis == other.EnableTBRAnalysis && EnableActivityAnalysis == other.EnableActivityAnalysis && + Mode == other.Mode && EnableTBRAnalysis == other.EnableTBRAnalysis && + EnableActivityAnalysis == other.EnableActivityAnalysis && DVI == other.DVI && use_enzyme == other.use_enzyme && DeclarationOnly == other.DeclarationOnly; } diff --git a/lib/Differentiator/ActivityAnalyzer.cpp b/lib/Differentiator/ActivityAnalyzer.cpp index 3d12e2283..b4f347fdd 100644 --- a/lib/Differentiator/ActivityAnalyzer.cpp +++ b/lib/Differentiator/ActivityAnalyzer.cpp @@ -2,35 +2,33 @@ using namespace clang; -namespace clad{ - -void VariedAnalyzer::Analyze(const FunctionDecl* FD) { - // Build the CFG (control-flow graph) of FD. - clang::CFG::BuildOptions Options; - m_CFG = clang::CFG::buildCFG(FD, FD->getBody(), &m_Context, Options); - - - m_BlockData.resize(m_CFG->size()); - m_BlockPassCounter.resize(m_CFG->size(), 0); - - // Set current block ID to the ID of entry the block. - auto* entry = &m_CFG->getEntry(); - m_CurBlockID = entry->getBlockID(); - m_BlockData[m_CurBlockID] = new VarsData(); - for(const auto& i: m_VariedDecls){ - m_BlockData[m_CurBlockID]->insert(i); - } - // Add the entry block to the queue. - m_CFGQueue.insert(m_CurBlockID); - - // Visit CFG blocks in the queue until it's empty. - while (!m_CFGQueue.empty()) { - auto IDIter = std::prev(m_CFGQueue.end()); - m_CurBlockID = *IDIter; - m_CFGQueue.erase(IDIter); - CFGBlock& nextBlock = *getCFGBlockByID(m_CurBlockID); - VisitCFGBlock(nextBlock); - } +namespace clad { + +void VariedAnalyzer::Analyze(const FunctionDecl* FD) { + // Build the CFG (control-flow graph) of FD. + clang::CFG::BuildOptions Options; + m_CFG = clang::CFG::buildCFG(FD, FD->getBody(), &m_Context, Options); + + m_BlockData.resize(m_CFG->size()); + m_BlockPassCounter.resize(m_CFG->size(), 0); + + // Set current block ID to the ID of entry the block. + auto* entry = &m_CFG->getEntry(); + m_CurBlockID = entry->getBlockID(); + m_BlockData[m_CurBlockID] = new VarsData(); + for (const auto& i : m_VariedDecls) + m_BlockData[m_CurBlockID]->insert(i); + // Add the entry block to the queue. + m_CFGQueue.insert(m_CurBlockID); + + // Visit CFG blocks in the queue until it's empty. + while (!m_CFGQueue.empty()) { + auto IDIter = std::prev(m_CFGQueue.end()); + m_CurBlockID = *IDIter; + m_CFGQueue.erase(IDIter); + CFGBlock& nextBlock = *getCFGBlockByID(m_CurBlockID); + VisitCFGBlock(nextBlock); + } } CFGBlock* VariedAnalyzer::getCFGBlockByID(unsigned ID) { @@ -40,75 +38,71 @@ CFGBlock* VariedAnalyzer::getCFGBlockByID(unsigned ID) { void VariedAnalyzer::VisitCFGBlock(const CFGBlock& block) { // Visit all the statements inside the block. for (const clang::CFGElement& Element : block) { - if (Element.getKind() == clang::CFGElement::Statement) { + if (Element.getKind() == clang::CFGElement::Statement) { const clang::Stmt* S = Element.castAs().getStmt(); TraverseStmt(const_cast(S)); - } + } } - for(const auto succ: block.succs()){ + for (const auto succ : block.succs()) { if (!succ) continue; auto& succData = m_BlockData[succ->getBlockID()]; - if(!succData){ + if (!succData) { succData = new VarsData(*m_BlockData[block.getBlockID()]); succData->m_Prev = m_BlockData[block.getBlockID()]; } - - if(succ->getBlockID() > block.getBlockID()){ - if(m_LoopMem == *m_BlockData[block.getBlockID()]) + if (succ->getBlockID() > block.getBlockID()) { + if (m_LoopMem == *m_BlockData[block.getBlockID()]) m_shouldPushSucc = false; - + // has to be changed - for(const auto& i : *m_BlockData[block.getBlockID()]) + for (const auto& i : *m_BlockData[block.getBlockID()]) m_LoopMem.insert(i); } - if(m_shouldPushSucc){ + if (m_shouldPushSucc) m_CFGQueue.insert(succ->getBlockID()); - } m_shouldPushSucc = true; merge(succData, m_BlockData[block.getBlockID()]); } // has to be changed - for(const auto& i: *m_BlockData[block.getBlockID()]) + for (const auto& i : *m_BlockData[block.getBlockID()]) m_VariedDecls.insert(i); } -bool VariedAnalyzer::isVaried(const VarDecl* VD){ - auto& curBranch = getCurBlockVarsData(); - return curBranch.find(VD) != curBranch.end(); +bool VariedAnalyzer::isVaried(const VarDecl* VD) { + auto& curBranch = getCurBlockVarsData(); + return curBranch.find(VD) != curBranch.end(); } void VariedAnalyzer::merge(VarsData* targetData, VarsData* mergeData) { - for(const auto& i: *mergeData){ + for (const auto& i : *mergeData) targetData->insert(i); - } - for(const auto& i: *targetData){ + for (const auto& i : *targetData) mergeData->insert(i); - } } void VariedAnalyzer::copyVarToCurBlock(const clang::VarDecl* VD) { - auto& curBranch = getCurBlockVarsData(); - curBranch.insert(VD); + auto& curBranch = getCurBlockVarsData(); + curBranch.insert(VD); } bool VariedAnalyzer::VisitBinaryOperator(BinaryOperator* BinOp) { Expr* L = BinOp->getLHS(); Expr* R = BinOp->getRHS(); - if(BinOp->isAssignmentOp()){ + if (BinOp->isAssignmentOp()) { m_Varied = false; TraverseStmt(R); m_Marking = m_Varied; TraverseStmt(L); m_Marking = false; - }else{ + } else { TraverseStmt(L); TraverseStmt(R); } @@ -119,21 +113,18 @@ bool VariedAnalyzer::VisitConditionalOperator(ConditionalOperator* CO) { return true; } -bool VariedAnalyzer::VisitCallExpr(CallExpr* CE) { - return true; -} +bool VariedAnalyzer::VisitCallExpr(CallExpr* CE) { return true; } bool VariedAnalyzer::VisitDeclStmt(DeclStmt* DS) { - for (auto* D : DS->decls()){ + for (auto* D : DS->decls()) { if (auto* VD = dyn_cast(D)) { - if(Expr* init = VD->getInit()){ + if (Expr* init = VD->getInit()) { m_Varied = false; TraverseStmt(init); m_Marking = true; auto& curBranch = getCurBlockVarsData(); - if(curBranch.find(VD) == curBranch.end() && m_Varied){ + if (curBranch.find(VD) == curBranch.end() && m_Varied) copyVarToCurBlock(VD); - } m_Marking = false; } } @@ -141,23 +132,21 @@ bool VariedAnalyzer::VisitDeclStmt(DeclStmt* DS) { return true; } - -bool VariedAnalyzer::VisitUnaryOperator(UnaryOperator* UnOp){ +bool VariedAnalyzer::VisitUnaryOperator(UnaryOperator* UnOp) { Expr* E = UnOp->getSubExpr(); TraverseStmt(E); return true; } bool VariedAnalyzer::VisitDeclRefExpr(DeclRefExpr* DRE) { - if(isVaried(dyn_cast(DRE->getDecl()))){ - m_Varied = true; - } + if (isVaried(dyn_cast(DRE->getDecl()))) + m_Varied = true; - if (const auto* VD = dyn_cast(DRE->getDecl())) { - auto& curBranch = getCurBlockVarsData(); - if (curBranch.find(VD) == curBranch.end() && m_Varied && m_Marking) - copyVarToCurBlock(VD); - } - return true; -} + if (const auto* VD = dyn_cast(DRE->getDecl())) { + auto& curBranch = getCurBlockVarsData(); + if (curBranch.find(VD) == curBranch.end() && m_Varied && m_Marking) + copyVarToCurBlock(VD); + } + return true; } +} // namespace clad diff --git a/lib/Differentiator/ActivityAnalyzer.h b/lib/Differentiator/ActivityAnalyzer.h index b6df42268..94a3e3f86 100644 --- a/lib/Differentiator/ActivityAnalyzer.h +++ b/lib/Differentiator/ActivityAnalyzer.h @@ -4,109 +4,107 @@ #include "clad/Differentiator/CladUtils.h" #include "clad/Differentiator/Compatibility.h" -#include -#include #include -#include -#include #include +#include +#include +#include +#include using namespace clang; -namespace clad{ -class VariedAnalyzer : public clang::RecursiveASTVisitor{ - - - bool m_Varied = false; - bool m_Marking = false; - bool m_shouldPushSucc = true; - - std::set& m_VariedDecls; - - struct VarsData { - std::set m_Data; - VarsData* m_Prev = nullptr; - - VarsData() = default; - VarsData(const VarsData& other) = default; - ~VarsData() = default; - VarsData(VarsData&& other) noexcept - : m_Data(std::move(other.m_Data)), m_Prev(other.m_Prev) {} - VarsData& operator=(const VarsData& other) = delete; - VarsData& operator=(VarsData&& other) noexcept { - if (&m_Data == &other.m_Data) { - m_Data = std::move(other.m_Data); - m_Prev = other.m_Prev; - } - return *this; - } - - bool operator==(VarsData other) noexcept{ - std::vector diff; - if(m_Data == other.m_Data) - return true; - return false; - } - - using iterator = - std::set::iterator; - int size(){return m_Data.size();} - iterator begin() { return m_Data.begin(); } - iterator end() { return m_Data.end(); } - iterator find(const clang::VarDecl* VD) { return m_Data.find(VD); } - void insert(const clang::VarDecl* VD){m_Data.insert(VD);} - void clear() { m_Data.clear(); } - std::set updateLoopMem(){return m_Data;} - }; - - VarsData m_LoopMem; - clang::CFGBlock* getCFGBlockByID(unsigned ID); - // VarData* getExprVarData(const clang::Expr* E, bool addNonConstIdx = false); - - std::set static collectDataFromPredecessors(VarsData* varsData, - VarsData* limit = nullptr); - - static VarsData* findLowestCommonAncestor(VarsData* varsData1, +namespace clad { +class VariedAnalyzer : public clang::RecursiveASTVisitor { + + bool m_Varied = false; + bool m_Marking = false; + bool m_shouldPushSucc = true; + + std::set& m_VariedDecls; + + struct VarsData { + std::set m_Data; + VarsData* m_Prev = nullptr; + + VarsData() = default; + VarsData(const VarsData& other) = default; + ~VarsData() = default; + VarsData(VarsData&& other) noexcept + : m_Data(std::move(other.m_Data)), m_Prev(other.m_Prev) {} + VarsData& operator=(const VarsData& other) = delete; + VarsData& operator=(VarsData&& other) noexcept { + if (&m_Data == &other.m_Data) { + m_Data = std::move(other.m_Data); + m_Prev = other.m_Prev; + } + return *this; + } + + bool operator==(VarsData other) noexcept { + std::vector diff; + if (m_Data == other.m_Data) + return true; + return false; + } + + using iterator = std::set::iterator; + int size() { return m_Data.size(); } + iterator begin() { return m_Data.begin(); } + iterator end() { return m_Data.end(); } + iterator find(const clang::VarDecl* VD) { return m_Data.find(VD); } + void insert(const clang::VarDecl* VD) { m_Data.insert(VD); } + void clear() { m_Data.clear(); } + std::set updateLoopMem() { return m_Data; } + }; + + VarsData m_LoopMem; + clang::CFGBlock* getCFGBlockByID(unsigned ID); + // VarData* getExprVarData(const clang::Expr* E, bool addNonConstIdx = false); + + std::set static collectDataFromPredecessors( + VarsData* varsData, VarsData* limit = nullptr); + + static VarsData* findLowestCommonAncestor(VarsData* varsData1, VarsData* varsData2); - void merge(VarsData* targetData, VarsData* mergeData); - ASTContext& m_Context; - std::unique_ptr m_CFG; - std::vector m_BlockData; - std::vector m_BlockPassCounter; - unsigned m_CurBlockID{}; - std::set m_CFGQueue; + void merge(VarsData* targetData, VarsData* mergeData); + ASTContext& m_Context; + std::unique_ptr m_CFG; + std::vector m_BlockData; + std::vector m_BlockPassCounter; + unsigned m_CurBlockID{}; + std::set m_CFGQueue; + + void addToVaried(const clang::VarDecl* VD); + bool isVaried(const clang::VarDecl* VD); - void addToVaried(const clang::VarDecl* VD); - bool isVaried(const clang::VarDecl* VD); - - void copyVarToCurBlock(const clang::VarDecl* VD); - VarsData& getCurBlockVarsData() { return *m_BlockData[m_CurBlockID]; } + void copyVarToCurBlock(const clang::VarDecl* VD); + VarsData& getCurBlockVarsData() { return *m_BlockData[m_CurBlockID]; } public: - /// Constructor - VariedAnalyzer(ASTContext& Context, std::set& Decls) - : m_VariedDecls(Decls), m_Context(Context) {} - - /// Destructor - ~VariedAnalyzer() = default; - - /// Delete copy/move operators and constructors. - VariedAnalyzer(const VariedAnalyzer&) = delete; - VariedAnalyzer& operator=(const VariedAnalyzer&) = delete; - VariedAnalyzer(const VariedAnalyzer&&) = delete; - VariedAnalyzer& operator=(const VariedAnalyzer&&) = delete; - - /// Visitors - void Analyze(const clang::FunctionDecl* FD); - - void VisitCFGBlock(const clang::CFGBlock& block); - - bool VisitBinaryOperator(clang::BinaryOperator* BinOp); - bool VisitCallExpr(clang::CallExpr* CE); - bool VisitConditionalOperator(clang::ConditionalOperator* CO); - bool VisitDeclRefExpr(clang::DeclRefExpr* DRE); - bool VisitDeclStmt(clang::DeclStmt* DS); - bool VisitUnaryOperator(clang::UnaryOperator* UnOp); + /// Constructor + VariedAnalyzer(ASTContext& Context, std::set& Decls) + : m_VariedDecls(Decls), m_Context(Context) {} + + /// Destructor + ~VariedAnalyzer() = default; + + /// Delete copy/move operators and constructors. + VariedAnalyzer(const VariedAnalyzer&) = delete; + VariedAnalyzer& operator=(const VariedAnalyzer&) = delete; + VariedAnalyzer(const VariedAnalyzer&&) = delete; + VariedAnalyzer& operator=(const VariedAnalyzer&&) = delete; + + /// Visitors + void Analyze(const clang::FunctionDecl* FD); + + void VisitCFGBlock(const clang::CFGBlock& block); + + bool VisitBinaryOperator(clang::BinaryOperator* BinOp); + bool VisitCallExpr(clang::CallExpr* CE); + bool VisitConditionalOperator(clang::ConditionalOperator* CO); + bool VisitDeclRefExpr(clang::DeclRefExpr* DRE); + bool VisitDeclStmt(clang::DeclStmt* DS); + bool VisitUnaryOperator(clang::UnaryOperator* UnOp); }; -} \ No newline at end of file +} // namespace clad \ No newline at end of file diff --git a/lib/Differentiator/DiffPlanner.cpp b/lib/Differentiator/DiffPlanner.cpp index a3c5b39f6..4f2dd8826 100644 --- a/lib/Differentiator/DiffPlanner.cpp +++ b/lib/Differentiator/DiffPlanner.cpp @@ -1,7 +1,7 @@ #include "clad/Differentiator/DiffPlanner.h" -#include "TBRAnalyzer.h" #include "ActivityAnalyzer.h" +#include "TBRAnalyzer.h" #include "clang/AST/ASTContext.h" #include "clang/AST/RecursiveASTVisitor.h" @@ -619,20 +619,22 @@ namespace clad { bool DiffRequest::shouldHaveAdjoint(VarDecl* VD) const { if (!EnableActivityAnalysis) return true; - - if(VD->getType()->isPointerType()) + + if (VD->getType()->isPointerType()) return true; if (!m_ActivityRunInfo.HasAnalysisRun) { if (Args) { for (const auto& dParam : DVI) m_ActivityRunInfo.ToBeRecorded.insert(cast(dParam.param)); - }else{ - std::copy(Function->param_begin(), Function->param_end(), std::inserter(m_ActivityRunInfo.ToBeRecorded, m_ActivityRunInfo.ToBeRecorded.end())); + } else { + std::copy(Function->param_begin(), Function->param_end(), + std::inserter(m_ActivityRunInfo.ToBeRecorded, + m_ActivityRunInfo.ToBeRecorded.end())); } - + VariedAnalyzer analyzer(Function->getASTContext(), - m_ActivityRunInfo.ToBeRecorded); + m_ActivityRunInfo.ToBeRecorded); analyzer.Analyze(Function); m_ActivityRunInfo.HasAnalysisRun = true; } @@ -711,7 +713,8 @@ namespace clad { } if (enable_aa_in_req || disable_aa_in_req) { // override the default value of TBR analysis. - request.EnableActivityAnalysis = enable_aa_in_req && !disable_aa_in_req; + request.EnableActivityAnalysis = + enable_aa_in_req && !disable_aa_in_req; } else { request.EnableActivityAnalysis = m_Options.EnableActivityAnalysis; } diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 832222769..5878849fa 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -2534,7 +2534,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } if (!ResultRef) return Clone(BinOp); - + // We need to store values of derivative pointer variables in forward pass // and restore them in reverse pass. if (isPointerOp) { @@ -3002,7 +3002,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (derivedVDE) m_Variables.emplace(VDClone, derivedVDE); - // Check if decl's name is the same as before. The name may be changed // if decl name collides with something in the derivative body. // This can happen in rare cases, e.g. when the original function @@ -3058,7 +3057,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // }else{ // adjInsert = true; // } - + if (m_ExternalSource) m_ExternalSource->ActOnStartOfDifferentiateSingleStmt(); beginBlock(direction::reverse); @@ -3089,7 +3088,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // if(!adjInsert) // ReverseResult = nullptr; - return StmtDiff(SDiff.getStmt(), ReverseResult); } diff --git a/test/Analyses/ActivityReverse.cpp b/test/Analyses/ActivityReverse.cpp index a5398b9dd..32ebb897b 100644 --- a/test/Analyses/ActivityReverse.cpp +++ b/test/Analyses/ActivityReverse.cpp @@ -106,7 +106,7 @@ double f3(double x){ //CHECK-NEXT: clad::tape _t5 = {}; //CHECK-NEXT: double _d_x1 = 0., _d_x2 = 0., _d_x3 = 0., _d_x4 = 0., _d_x5 = 0.; //CHECK-NEXT: double x1, x2, x3, x4, x5 = 0; -//CHECK-NEXT: unsigned {{int|long}} _t0 = {{0U|0UL}}; +//CHECK-NEXT: unsigned long _t0 = 0UL; //CHECK-NEXT: while (!x3) //CHECK-NEXT: { //CHECK-NEXT: _t0++; @@ -171,4 +171,5 @@ int main(){ double result[3] = {}; TEST(f1, 3);// CHECK-EXEC: {6.00} TEST(f2, 3);// CHECK-EXEC: {6.00} + TEST(f3, 3);// CHECK-EXEC: {0.00} } \ No newline at end of file diff --git a/tools/ClangPlugin.cpp b/tools/ClangPlugin.cpp index d9a3097a6..19a5e7d1b 100644 --- a/tools/ClangPlugin.cpp +++ b/tools/ClangPlugin.cpp @@ -399,10 +399,11 @@ namespace clad { } static void SetActivityAnalysisOptions(const DifferentiationOptions& DO, - RequestOptions& opts) { + RequestOptions& opts) { // If user has explicitly specified the mode for AA, use it. if (DO.EnableActivityAnalysis || DO.DisableActivityAnalysis) - opts.EnableActivityAnalysis = DO.EnableActivityAnalysis && !DO.DisableActivityAnalysis; + opts.EnableActivityAnalysis = + DO.EnableActivityAnalysis && !DO.DisableActivityAnalysis; else opts.EnableActivityAnalysis = false; // Default mode. } diff --git a/tools/ClangPlugin.h b/tools/ClangPlugin.h index 1e5d98992..43750828e 100644 --- a/tools/ClangPlugin.h +++ b/tools/ClangPlugin.h @@ -51,26 +51,27 @@ class CladTimerGroup { namespace plugin { struct DifferentiationOptions { - DifferentiationOptions() - : DumpSourceFn(false), DumpSourceFnAST(false), DumpDerivedFn(false), - DumpDerivedAST(false), GenerateSourceFile(false), - ValidateClangVersion(true), EnableTBRAnalysis(false), - DisableTBRAnalysis(false), EnableActivityAnalysis(false), DisableActivityAnalysis(false), CustomEstimationModel(false), - PrintNumDiffErrorInfo(false) {} - - bool DumpSourceFn : 1; - bool DumpSourceFnAST : 1; - bool DumpDerivedFn : 1; - bool DumpDerivedAST : 1; - bool GenerateSourceFile : 1; - bool ValidateClangVersion : 1; - bool EnableTBRAnalysis : 1; - bool DisableTBRAnalysis : 1; - bool EnableActivityAnalysis : 1; - bool DisableActivityAnalysis : 1; - bool CustomEstimationModel : 1; - bool PrintNumDiffErrorInfo : 1; - std::string CustomModelName; + DifferentiationOptions() + : DumpSourceFn(false), DumpSourceFnAST(false), DumpDerivedFn(false), + DumpDerivedAST(false), GenerateSourceFile(false), + ValidateClangVersion(true), EnableTBRAnalysis(false), + DisableTBRAnalysis(false), EnableActivityAnalysis(false), + DisableActivityAnalysis(false), CustomEstimationModel(false), + PrintNumDiffErrorInfo(false) {} + + bool DumpSourceFn : 1; + bool DumpSourceFnAST : 1; + bool DumpDerivedFn : 1; + bool DumpDerivedAST : 1; + bool GenerateSourceFile : 1; + bool ValidateClangVersion : 1; + bool EnableTBRAnalysis : 1; + bool DisableTBRAnalysis : 1; + bool EnableActivityAnalysis : 1; + bool DisableActivityAnalysis : 1; + bool CustomEstimationModel : 1; + bool PrintNumDiffErrorInfo : 1; + std::string CustomModelName; }; class CladExternalSource : public clang::ExternalSemaSource { @@ -316,11 +317,11 @@ class CladTimerGroup { m_DO.EnableTBRAnalysis = true; } else if (args[i] == "-disable-tbr") { m_DO.DisableTBRAnalysis = true; - }else if(args[i] == "-enable-aa"){ + } else if (args[i] == "-enable-aa") { m_DO.EnableActivityAnalysis = true; - }else if(args[i] == "-disable-aa"){ + } else if (args[i] == "-disable-aa") { m_DO.DisableActivityAnalysis = true; - }else if (args[i] == "-fcustom-estimation-model") { + } else if (args[i] == "-fcustom-estimation-model") { m_DO.CustomEstimationModel = true; if (++i == e) { llvm::errs() << "No shared object was specified.";