-
Notifications
You must be signed in to change notification settings - Fork 122
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add Varied analysis to the reverse mode #1084
base: master
Are you sure you want to change the base?
Changes from all commits
baf78b3
fedcdaa
f86259a
fec3b70
bd1807f
a915ca5
5f9b6b1
80e6799
5d8416e
945bb85
6fbbb2c
5924874
431c490
59ea27e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,179 @@ | ||
#include "ActivityAnalyzer.h" | ||
|
||
using namespace clang; | ||
|
||
namespace clad { | ||
|
||
void VariedAnalyzer::Analyze(const FunctionDecl* FD) { | ||
// Build the CFG (control-flow graph) of FD. | ||
clang::CFG::BuildOptions Options; | ||
m_CFG = clang::CFG::buildCFG(FD, FD->getBody(), &m_Context, Options); | ||
|
||
m_BlockData.resize(m_CFG->size()); | ||
// Set current block ID to the ID of entry the block. | ||
CFGBlock* entry = &m_CFG->getEntry(); | ||
m_CurBlockID = entry->getBlockID(); | ||
m_BlockData[m_CurBlockID] = createNewVarsData({}); | ||
for (const VarDecl* i : m_VariedDecls) | ||
m_BlockData[m_CurBlockID]->insert(i); | ||
// Add the entry block to the queue. | ||
m_CFGQueue.insert(m_CurBlockID); | ||
|
||
// Visit CFG blocks in the queue until it's empty. | ||
while (!m_CFGQueue.empty()) { | ||
auto IDIter = std::prev(m_CFGQueue.end()); | ||
vgvassilev marked this conversation as resolved.
Show resolved
Hide resolved
|
||
m_CurBlockID = *IDIter; | ||
m_CFGQueue.erase(IDIter); | ||
CFGBlock& nextBlock = *getCFGBlockByID(m_CurBlockID); | ||
AnalyzeCFGBlock(nextBlock); | ||
} | ||
} | ||
|
||
void mergeVarsData(VarsData* targetData, VarsData* mergeData) { | ||
for (const clang::VarDecl* i : *mergeData) | ||
targetData->insert(i); | ||
for (const clang::VarDecl* i : *targetData) | ||
mergeData->insert(i); | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It looks like this function will make |
||
|
||
CFGBlock* VariedAnalyzer::getCFGBlockByID(unsigned ID) { | ||
return *(m_CFG->begin() + ID); | ||
} | ||
|
||
void VariedAnalyzer::AnalyzeCFGBlock(const CFGBlock& block) { | ||
// Visit all the statements inside the block. | ||
for (const clang::CFGElement& Element : block) { | ||
if (Element.getKind() == clang::CFGElement::Statement) { | ||
const clang::Stmt* S = Element.castAs<clang::CFGStmt>().getStmt(); | ||
// The const_cast is inevitable, since there is no | ||
// ConstRecusiveASTVisitor. | ||
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) | ||
TraverseStmt(const_cast<clang::Stmt*>(S)); | ||
vgvassilev marked this conversation as resolved.
Show resolved
Hide resolved
vgvassilev marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
} | ||
|
||
for (const clang::CFGBlock::AdjacentBlock succ : block.succs()) { | ||
if (!succ) | ||
vgvassilev marked this conversation as resolved.
Show resolved
Hide resolved
|
||
continue; | ||
auto& succData = m_BlockData[succ->getBlockID()]; | ||
|
||
if (!succData) | ||
succData = createNewVarsData(*m_BlockData[block.getBlockID()]); | ||
|
||
bool shouldPushSucc = true; | ||
if (succ->getBlockID() > block.getBlockID()) { | ||
if (m_LoopMem == *m_BlockData[block.getBlockID()]) | ||
shouldPushSucc = false; | ||
|
||
for (const VarDecl* i : *m_BlockData[block.getBlockID()]) | ||
m_LoopMem.insert(i); | ||
} | ||
|
||
if (shouldPushSucc) | ||
m_CFGQueue.insert(succ->getBlockID()); | ||
|
||
mergeVarsData(succData.get(), m_BlockData[block.getBlockID()].get()); | ||
} | ||
// FIXME: Information about the varied variables is stored in the last block, | ||
// so we should be able to get it form there | ||
for (const VarDecl* i : *m_BlockData[block.getBlockID()]) | ||
m_VariedDecls.insert(i); | ||
vgvassilev marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
|
||
bool VariedAnalyzer::isVaried(const VarDecl* VD) const { | ||
const VarsData& curBranch = getCurBlockVarsData(); | ||
return curBranch.find(VD) != curBranch.end(); | ||
} | ||
|
||
void VariedAnalyzer::copyVarToCurBlock(const clang::VarDecl* VD) { | ||
VarsData& curBranch = getCurBlockVarsData(); | ||
curBranch.insert(VD); | ||
} | ||
|
||
bool VariedAnalyzer::VisitBinaryOperator(BinaryOperator* BinOp) { | ||
Expr* L = BinOp->getLHS(); | ||
Expr* R = BinOp->getRHS(); | ||
const auto opCode = BinOp->getOpcode(); | ||
if (BinOp->isAssignmentOp()) { | ||
m_Varied = false; | ||
TraverseStmt(R); | ||
m_Marking = m_Varied; | ||
TraverseStmt(L); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. warning: cannot initialize object parameter of type 'clang::RecursiveASTVisitorclad::VariedAnalyzer' with an expression of type 'clad::VariedAnalyzer' [clang-diagnostic-error] TraverseStmt(L);
^ |
||
m_Marking = false; | ||
} else if (opCode == BO_Add || opCode == BO_Sub || opCode == BO_Mul || | ||
opCode == BO_Div) { | ||
for (auto* subexpr : BinOp->children()) | ||
if (!isa<BinaryOperator>(subexpr)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why don't we visit binary operators recursively? Doesn't this mean we will ignore a part of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. See #1106 |
||
TraverseStmt(subexpr); | ||
} | ||
return true; | ||
} | ||
|
||
// add branching merging | ||
bool VariedAnalyzer::VisitConditionalOperator(ConditionalOperator* CO) { | ||
vgvassilev marked this conversation as resolved.
Show resolved
Hide resolved
|
||
TraverseStmt(CO->getCond()); | ||
TraverseStmt(CO->getTrueExpr()); | ||
ovdiiuv marked this conversation as resolved.
Show resolved
Hide resolved
|
||
TraverseStmt(CO->getFalseExpr()); | ||
ovdiiuv marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return true; | ||
} | ||
|
||
bool VariedAnalyzer::VisitCallExpr(CallExpr* CE) { | ||
FunctionDecl* FD = CE->getDirectCallee(); | ||
bool noHiddenParam = (CE->getNumArgs() == FD->getNumParams()); | ||
if (noHiddenParam) { | ||
MutableArrayRef<ParmVarDecl*> FDparam = FD->parameters(); | ||
for (std::size_t i = 0, e = CE->getNumArgs(); i != e; ++i) { | ||
clang::Expr* par = CE->getArg(i); | ||
TraverseStmt(par); | ||
m_VariedDecls.insert(FDparam[i]); | ||
} | ||
} | ||
m_Varied = true; | ||
return true; | ||
} | ||
|
||
bool VariedAnalyzer::VisitDeclStmt(DeclStmt* DS) { | ||
for (Decl* D : DS->decls()) { | ||
if (!isa<VarDecl>(D)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. A test here could be |
||
continue; | ||
if (Expr* init = cast<VarDecl>(D)->getInit()) { | ||
m_Varied = false; | ||
TraverseStmt(init); | ||
m_Marking = true; | ||
if (m_Varied) | ||
copyVarToCurBlock(cast<VarDecl>(D)); | ||
m_Marking = false; | ||
} | ||
} | ||
return true; | ||
} | ||
|
||
bool VariedAnalyzer::VisitUnaryOperator(UnaryOperator* UnOp) { | ||
const auto opCode = UnOp->getOpcode(); | ||
Expr* E = UnOp->getSubExpr(); | ||
if (opCode == UO_AddrOf || opCode == UO_Deref) { | ||
m_Varied = true; | ||
m_Marking = true; | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why do we need to set these modes to |
||
TraverseStmt(E); | ||
return true; | ||
} | ||
|
||
bool VariedAnalyzer::VisitInitListExpr(InitListExpr* ILE) { return true; } | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. warning: method 'VisitInitListExpr' can be made static [readability-convert-member-functions-to-static] lib/Differentiator/ActivityAnalyzer.h:81: - bool VisitInitListExpr(clang::InitListExpr* ILE);
+ static bool VisitInitListExpr(clang::InitListExpr* ILE); There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. warning: method 'VisitInitListExpr' can be made static [readability-convert-member-functions-to-static] lib/Differentiator/ActivityAnalyzer.h:79: - bool VisitInitListExpr(clang::InitListExpr* ILE);
+ static bool VisitInitListExpr(clang::InitListExpr* ILE); There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ping. |
||
|
||
bool VariedAnalyzer::VisitDeclRefExpr(DeclRefExpr* DRE) { | ||
if (isVaried(dyn_cast<VarDecl>(DRE->getDecl()))) | ||
m_Varied = true; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Doesn't this portion of the code repeat what comes next? Perhaps we should keep one of these. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes |
||
|
||
auto* VD = dyn_cast<VarDecl>(DRE->getDecl()); | ||
if (!VD) | ||
return true; | ||
|
||
if (isVaried(VD)) | ||
m_Varied = true; | ||
|
||
if (m_Varied && m_Marking) | ||
copyVarToCurBlock(VD); | ||
return true; | ||
} | ||
vgvassilev marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} // namespace clad |
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,83 @@ | ||||||||
#ifndef CLAD_DIFFERENTIATOR_ACTIVITYANALYZER_H | ||||||||
#define CLAD_DIFFERENTIATOR_ACTIVITYANALYZER_H | ||||||||
vgvassilev marked this conversation as resolved.
Show resolved
Hide resolved
vgvassilev marked this conversation as resolved.
Show resolved
Hide resolved
vgvassilev marked this conversation as resolved.
Show resolved
Hide resolved
vgvassilev marked this conversation as resolved.
Show resolved
Hide resolved
vgvassilev marked this conversation as resolved.
Show resolved
Hide resolved
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. warning: header guard does not follow preferred style [llvm-header-guard]
Suggested change
lib/Differentiator/ActivityAnalyzer.h:84: - #endif // CLAD_DIFFERENTIATOR_ACTIVITYANALYZER_H
+ #endif // GITHUB_WORKSPACE_LIB_DIFFERENTIATOR_ACTIVITYANALYZER_H
vgvassilev marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||
|
||||||||
#include "clang/AST/RecursiveASTVisitor.h" | ||||||||
ovdiiuv marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||
#include "clang/Analysis/CFG.h" | ||||||||
|
||||||||
#include "clad/Differentiator/CladUtils.h" | ||||||||
#include "clad/Differentiator/Compatibility.h" | ||||||||
|
||||||||
#include <algorithm> | ||||||||
#include <iterator> | ||||||||
#include <memory> | ||||||||
#include <set> | ||||||||
#include <utility> | ||||||||
/// Class that implemets Varied part of the Activity analysis. | ||||||||
/// By performing static data-flow analysis, so called Varied variables | ||||||||
/// are determined, meaning variables that depend on input parameters | ||||||||
/// in a differentiable way. That result enables us to remove redundant | ||||||||
/// statements in the reverse mode, improving generated codes efficiency. | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The documentation should go to the class definition. |
||||||||
namespace clad { | ||||||||
using VarsData = std::set<const clang::VarDecl*>; | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This can go in the class definition. |
||||||||
class VariedAnalyzer : public clang::RecursiveASTVisitor<VariedAnalyzer> { | ||||||||
ovdiiuv marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||
|
||||||||
bool m_Varied = false; | ||||||||
bool m_Marking = false; | ||||||||
|
||||||||
std::set<const clang::VarDecl*>& m_VariedDecls; | ||||||||
ovdiiuv marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||
/// A helper method to allocate VarsData | ||||||||
/// \param[in] toAssign - Parameter to initialize new VarsData with. | ||||||||
/// \return Unique pointer to a new object of type Varsdata. | ||||||||
static std::unique_ptr<VarsData> createNewVarsData(VarsData toAssign) { | ||||||||
return std::unique_ptr<VarsData>(new VarsData(std::move(toAssign))); | ||||||||
vgvassilev marked this conversation as resolved.
Show resolved
Hide resolved
vgvassilev marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||
} | ||||||||
VarsData m_LoopMem; | ||||||||
|
||||||||
clang::CFGBlock* getCFGBlockByID(unsigned ID); | ||||||||
|
||||||||
clang::ASTContext& m_Context; | ||||||||
vgvassilev marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||
std::unique_ptr<clang::CFG> m_CFG; | ||||||||
std::vector<std::unique_ptr<VarsData>> m_BlockData; | ||||||||
unsigned m_CurBlockID{}; | ||||||||
std::set<unsigned> m_CFGQueue; | ||||||||
/// Checks if a variable is on the current branch. | ||||||||
/// \param[in] VD - Variable declaration. | ||||||||
/// @return Whether a variable is on the current branch. | ||||||||
bool isVaried(const clang::VarDecl* VD) const; | ||||||||
/// Adds varied variable to current branch. | ||||||||
/// \param[in] VD - Variable declaration. | ||||||||
void copyVarToCurBlock(const clang::VarDecl* VD); | ||||||||
VarsData& getCurBlockVarsData() { return *m_BlockData[m_CurBlockID]; } | ||||||||
vgvassilev marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||
[[nodiscard]] const VarsData& getCurBlockVarsData() const { | ||||||||
return const_cast<VariedAnalyzer*>(this)->getCurBlockVarsData(); | ||||||||
vgvassilev marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||
} | ||||||||
void AnalyzeCFGBlock(const clang::CFGBlock& block); | ||||||||
|
||||||||
public: | ||||||||
/// Constructor | ||||||||
VariedAnalyzer(clang::ASTContext& Context, | ||||||||
std::set<const clang::VarDecl*>& Decls) | ||||||||
: m_VariedDecls(Decls), m_Context(Context) {} | ||||||||
|
||||||||
/// Destructor | ||||||||
~VariedAnalyzer() = default; | ||||||||
|
||||||||
/// Delete copy/move operators and constructors. | ||||||||
VariedAnalyzer(const VariedAnalyzer&) = delete; | ||||||||
VariedAnalyzer& operator=(const VariedAnalyzer&) = delete; | ||||||||
VariedAnalyzer(const VariedAnalyzer&&) = delete; | ||||||||
VariedAnalyzer& operator=(const VariedAnalyzer&&) = delete; | ||||||||
|
||||||||
/// Runs Varied analysis. | ||||||||
/// \param[in] FD Function to run the analysis on. | ||||||||
void Analyze(const clang::FunctionDecl* FD); | ||||||||
bool VisitBinaryOperator(clang::BinaryOperator* BinOp); | ||||||||
bool VisitCallExpr(clang::CallExpr* CE); | ||||||||
bool VisitConditionalOperator(clang::ConditionalOperator* CO); | ||||||||
bool VisitDeclRefExpr(clang::DeclRefExpr* DRE); | ||||||||
bool VisitDeclStmt(clang::DeclStmt* DS); | ||||||||
bool VisitUnaryOperator(clang::UnaryOperator* UnOp); | ||||||||
bool VisitInitListExpr(clang::InitListExpr* ILE); | ||||||||
}; | ||||||||
} // namespace clad | ||||||||
#endif // CLAD_DIFFERENTIATOR_ACTIVITYANALYZER_H | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. warning: #endif without #if [clang-diagnostic-error] #endif // CLAD_DIFFERENTIATOR_ACTIVITYANALYZER_H
^ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
warning: member variable 'EnableVariedAnalysis' has public visibility [cppcoreguidelines-non-private-member-variables-in-classes]